diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/StopWatch.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/StopWatch.h index 42e4eb8..225dee5 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/StopWatch.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/StopWatch.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/compat.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/compat.h index 6d54446..225c9e9 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/compat.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/compat.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/context.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/context.h index e942af2..95ea6ec 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/context.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/context.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. 
// SPDX-License-Identifier: Apache-2.0 @@ -6,14 +21,14 @@ #pragma once #include -#include -/* #include */ +#include +/* #include */ #include #include #include -#include +#include -#include +#include #include "gemm_test.h" @@ -22,8 +37,8 @@ #ifndef SYCL_CUDA_STREAM #define SYCL_CUDA_STREAM namespace at { - namespace cuda { - inline dpct::queue_ptr getCurrentCUDAStream() { + namespace sycl { + inline dpct::queue_ptr getCurrentSYCLStream() { auto device_type = c10::DeviceType::XPU; c10::impl::VirtualGuardImpl impl(device_type); c10::Stream c10_stream = impl.getStream(c10::Device(device_type)); @@ -46,30 +61,21 @@ namespace at { #define WARP_SIZE 32 -#define CUDA_CHECK(callstr) \ - { \ - cudaError_t error_code = callstr; \ - if (error_code != cudaSuccess) { \ - std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ - assert(0); \ - } \ - } - -#define CUDA_1D_KERNEL_LOOP(i, n) \ +#define SYCL_1D_KERNEL_LOOP(i, n) \ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) -#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ +#define SYCL_2D_KERNEL_LOOP(i, n, j, m) \ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \ for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y) -#define DS_CUDA_NUM_THREADS 512 +#define DS_SYCL_NUM_THREADS 512 #define DS_MAXIMUM_NUM_BLOCKS 262144 inline int DS_GET_BLOCKS(const int N) { return (std::max)( - (std::min)((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS), - // Use at least 1 block, since CUDA does not allow empty block + (std::min)((N + DS_SYCL_NUM_THREADS - 1) / DS_SYCL_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS), + // Use at least 1 block, since SYCL does not allow empty block 1); } @@ -78,11 +84,11 @@ class TrainingContext { TrainingContext() try : _workspace(nullptr), _seed(42), _curr_offset(0) { _gen = dpct::rng::create_host_rng(dpct::rng::random_engine_type::mcg59); _gen->set_seed(123); - int stat = DPCT_CHECK_ERROR(_cublasHandle = &dpct::get_in_order_queue()); + int stat = DPCT_CHECK_ERROR(_mklHandle = &dpct::get_in_order_queue()); if (stat != 0) { - // It would be nice to use cublasGetStatusName and - // cublasGetStatusString, but they were only added in CUDA 11.4.2. - auto message = std::string("Failed to create cublas handle: cublasStatus_t was ") + + // It would be nice to use mklGetStatusName and + // mklGetStatusString, but they were only added in SYCL 11.4.2. + auto message = std::string("Failed to create mkl handle: mklStatus_t was ") + std::to_string(stat); std::cerr << message << std::endl; throw std::runtime_error(message); @@ -96,7 +102,7 @@ class TrainingContext { virtual ~TrainingContext() { - _cublasHandle = nullptr; + _mklHandle = nullptr; sycl::free(_workspace, dpct::get_in_order_queue()); } @@ -119,13 +125,13 @@ class TrainingContext { dpct::queue_ptr GetCurrentStream() { // get current pytorch stream. 
- dpct::queue_ptr stream = at::cuda::getCurrentCUDAStream(); + dpct::queue_ptr stream = at::sycl::getCurrentSYCLStream(); return stream; } - dpct::queue_ptr GetNewStream() { return at::cuda::getStreamFromPool(); } + dpct::queue_ptr GetNewStream() { return at::sycl::getStreamFromPool(); } - dpct::queue_ptr GetCublasHandle() { return _cublasHandle; } + dpct::queue_ptr GetCublasHandle() { return _mklHandle; } std::pair IncrementOffset(uint64_t offset_inc) { @@ -205,7 +211,7 @@ class TrainingContext { private: dpct::rng::host_rng_ptr _gen; - dpct::queue_ptr _cublasHandle; + dpct::queue_ptr _mklHandle; void* _workspace; uint64_t _seed; uint64_t _curr_offset; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/conversion_utils.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/conversion_utils.h index 7b7adda..43d727e 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/conversion_utils.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/conversion_utils.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 @@ -6,7 +21,7 @@ #pragma once #include -#include +#include #include "ds_kernel_utils.h" #include @@ -270,12 +285,7 @@ DS_D_INLINE sycl::float2 to(sycl::marray val) template <> DS_D_INLINE sycl::half to(double val) { -#ifdef __HIP_PLATFORM_AMD__ - float val_f = __double2float_rn(val); - return __float2half(val_f); -#else return sycl::half(val); -#endif } template <> DS_D_INLINE sycl::half to(float val) diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/custom_sycl_layers.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/custom_sycl_layers.h new file mode 100644 index 0000000..3c7fa31 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/custom_sycl_layers.h @@ -0,0 +1,341 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +// Copyright (c) Microsoft Corporation. 
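// Editor's note: a quick worked example for the DS_GET_BLOCKS helper shown in
// context.h above; the numbers follow directly from DS_SYCL_NUM_THREADS (512)
// and DS_MAXIMUM_NUM_BLOCKS (262144) and are not part of the patch.
//   DS_GET_BLOCKS(100000)  == (100000 + 511) / 512 == 196   (below the cap)
//   DS_GET_BLOCKS(1 << 30) == 262144                        (clamped to the cap)
//   DS_GET_BLOCKS(0)       == 1                              (at least one block)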
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "ds_kernel_utils.h" +#include + +#include +#include + +#include "context.h" +/* #include "mkl_wrappers.h" */ + + +#define MAX_THREADS 1024 +#define THREADS 256 + +#define MAX_THREAD_STRIDE 32 +#define TILE_DIM 32 + +// Maximum sequence-length support based on the number of threads (2048) allowed in each block and +// this MAX is 8K For higher sequence length we need to use higher Max, like for 64K : 32 +#define MAX_THREAD_ITERATIONS 8 // Maximum 8K +#define MAX_WARP_NUM 32 + +#define MAX_REGISTERS 256 + +#define MAX_REG 256 + +#define WARP_SIZE_BITS 5 + +// Fused bias add with gelu activation +template +void launch_bias_gelu(const T* input, + const T* bias, + T* output, + int intermediate_size, + int batch_size, + dpct::queue_ptr stream); + +template +void launch_gelu(const T* input, + T* output, + int intermediate_size, + int batch_size, + dpct::queue_ptr stream); + +template +void launch_d_gelu(T* d_output, + const T* input, + const T* bias, + int intermediate_size, + int batch_size, + dpct::queue_ptr stream); + +// Custom fused bias add with layer normalization +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + dpct::queue_ptr stream, + bool preLayerNorm, + bool training, + T* vars, + T* means); + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + dpct::queue_ptr stream, + bool preLayerNorm, + bool training, + T* vars); + +template +void launch_layerNorm_backward_fused_add(const T* out_grad1, + const T* out_grad2, + const T* X_data, + const T* vars, + const T* means, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + dpct::queue_ptr stream[2]); +template +void launch_layerNorm_backward_fused_add(const T* out_grad1, + const T* out_grad2, + const T* vals_hat, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + dpct::queue_ptr stream[2], + bool invertible = false, + const T* betta = nullptr); + +template +void launch_layerNorm_backward(const T* out_grad, + const T* X_data, + const T* vars, + const T* means, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + dpct::queue_ptr stream[2]); + +template +void launch_layerNorm_backward(const T* out_grad, + const T* vals_hat, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + dpct::queue_ptr stream[2], + bool invertible = false, + const T* betta = nullptr); + +template +void launch_layerNorm_backward_nreversible(const T* out_grad, + const T* vals, + const T* out_grad_trans, + const T* vals_trans, + const T* means, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + dpct::queue_ptr stream[2]); + +template +void Transpose(const T* inp_mat, T* out_mat, int rows, int cols, dpct::queue_ptr stream); + +template +void launch_attn_softmax_backward(T* out_grad, + const T* soft_inp, + int batch_size, + int heads, + int seq_length, + dpct::queue_ptr stream); + +template +void launch_attn_softmax_backward_v2(T* out_grad, + const T* soft_inp, + int batch_size, + int heads, + int seq_length, + dpct::queue_ptr 
stream); + +// Custom softmax with scaling and attention mask addition +template +void launch_attn_softmax(T* vals, + const T* attn_mask, + int batch_size, + int heads, + int sequence_length, + dpct::queue_ptr stream); + +template +void launch_transform_0213(T* output, + const T* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + dpct::queue_ptr stream); + +// Custom bias add +template +void launch_bias_add_transform_0213(T* outputs, + const T* vals, + const T* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + dpct::queue_ptr stream, + int trans_count); + +// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3] +template +void launch_transform4d_0213(T* out, + const T* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + dpct::queue_ptr stream, + int trans_count); + +template +void launch_dropout(T* vals, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + dpct::queue_ptr stream); + +template +void launch_dropout(T* vals_out, + const T* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + dpct::queue_ptr stream, + bool bwd = false); + +template +void launch_dropout(T* out, + const T* vals, + const T* residual, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + dpct::queue_ptr stream); + +template +void launch_dropout_grad(T* vals, + uint8_t* mask, + int total_count, + float ratio, + dpct::queue_ptr stream); + +template +void launch_dropout_grad(T* vals_out, + const T* vals, + uint8_t* mask, + int total_count, + float ratio, + dpct::queue_ptr stream); + +template +void launch_fuse_transpose_bias_kernel(const T* inp, + T* out, + int rows, + int cols, + dpct::queue_ptr stream); + +void launch_param_update(const float* input, sycl::half* output, int size, dpct::queue_ptr stream); +void launch_param_update_half(const float* input, + sycl::half* output, + int size, + dpct::queue_ptr stream); + +void launch_token_sort(int32_t* indices, + int layers, + int batch_size, + int reserved_size, + int original_tokens, + dpct::queue_ptr stream); + +template +void launch_gather_tokens(T* retained_tokens, + T* activations, + int32_t* gather_indices, + int32_t batch_size, + int32_t sampled_tokens, + int32_t channels, + int32_t read_batch_stride, + int32_t read_seq_stride, + int32_t write_batch_stride, + int32_t write_seq_stride, + dpct::queue_ptr stream); + +template +void launch_scatter_tokens(T* all_activations, + T* layer_activations, + int32_t* gather_indices, + int32_t batch_size, + int32_t sampled_tokens, + int32_t channels, + int32_t read_batch_stride, + int32_t read_seq_stride, + int32_t write_batch_stride, + int32_t write_seq_stride, + dpct::queue_ptr stream); + +template +void launch_slice_gpt_mask(T* output_mask, + const T* input_mask, + int batch_size, + int truncated_seq_len, + int orig_seq_len, + dpct::queue_ptr stream); + +template +void launch_slice_bert_mask(T* output_mask, + const T* input_mask, + const int32_t* retained_indices, + int32_t layers, + int32_t batch_size, + int32_t truncated_seq_len, + int32_t orig_seq_len, + dpct::queue_ptr stream); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dequantization_utils.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dequantization_utils.h index 3c54962..d8f2787 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/dequantization_utils.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dequantization_utils.h @@ -1,10 +1,25 @@ 
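// Editor's note: a hedged host-side sketch of how the launchers declared in
// custom_sycl_layers.h above are typically driven; it is not part of the patch.
// The template parameter lists were lost in extraction and are assumed to be
// `template <typename T>`, and the include paths and float instantiation are
// assumptions.
#include <sycl/sycl.hpp>
#include <dpct/dpct.hpp>
#include "custom_sycl_layers.h"

static void run_bias_gelu_fp32(int batch_size, int intermediate_size)
{
    // Same helper queue the patch uses in context.h; dpct::queue_ptr is a sycl::queue*.
    sycl::queue& q = dpct::get_in_order_queue();
    dpct::queue_ptr stream = &q;

    float* input  = sycl::malloc_device<float>(batch_size * intermediate_size, q);
    float* bias   = sycl::malloc_device<float>(intermediate_size, q);
    float* output = sycl::malloc_device<float>(batch_size * intermediate_size, q);

    // Fused bias-add + GELU on the caller-supplied queue.
    launch_bias_gelu<float>(input, bias, output, intermediate_size, batch_size, stream);
    q.wait();

    sycl::free(input, q);
    sycl::free(bias, q);
    sycl::free(output, q);
}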
+/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team #include -#include +#include #include "conversion_utils.h" #include "ds_kernel_utils.h" #include "quantization.h" diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/atomic.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/atomic.h new file mode 100644 index 0000000..4b516f5 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/atomic.h @@ -0,0 +1,842 @@ +//==---- atomic.hpp -------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_ATOMIC_HPP__ +#define __DPCT_ATOMIC_HPP__ + +#include + +namespace dpct { + +/// Atomically add the value operand to the value at the addr and assign the +/// result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to add to the value at \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_add(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_add(operand); +} + +template +inline T1 atomic_fetch_add(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_add(operand); +} + +/// Atomically add the value operand to the value at the addr and assign the +/// result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to add to the value at \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_add(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_add(addr, operand); + case sycl::memory_order::acq_rel: + return atomic_fetch_add(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_add(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } +} + +template +inline T1 atomic_fetch_add(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_fetch_add(addr, operand, memoryOrder); +} + +/// Atomically subtract the value operand from the value at the addr and assign +/// the result to the value at addr. +/// \param [in, out] addr The pointer to the data. 
+/// \param operand The value to subtract from the value at \p addr +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_sub(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_sub(operand); +} + +template +inline T1 atomic_fetch_sub(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_sub(operand); +} + +/// Atomically subtract the value operand from the value at the addr and assign +/// the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to subtract from the value at \p addr +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_sub(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_sub(addr, operand); + case sycl::memory_order::acq_rel: + return atomic_fetch_sub(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_sub(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } +} + +template +inline T1 atomic_fetch_sub(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_fetch_sub(addr, operand, memoryOrder); +} + +/// Atomically perform a bitwise AND between the value operand and the value at the addr +/// and assign the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to use in bitwise AND operation with the value at the \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_and(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_and(operand); +} + +template +inline T1 atomic_fetch_and(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_and(operand); +} + +/// Atomically perform a bitwise AND between the value operand and the value at the addr +/// and assign the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to use in bitwise AND operation with the value at the \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_and(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_and(addr, operand); + case sycl::memory_order::acq_rel: + return atomic_fetch_and(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_and(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } +} + +template +inline T1 atomic_fetch_and(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_fetch_and(addr, operand, memoryOrder); +} + +/// Atomically or the value at the addr with the value operand, and assign +/// the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to use in bitwise OR operation with the value at the \p addr. 
+/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_or(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_or(operand); +} + +template +inline T1 atomic_fetch_or(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_or(operand); +} + +/// Atomically or the value at the addr with the value operand, and assign +/// the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to use in bitwise OR operation with the value at the \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_or(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_or(addr, operand); + case sycl::memory_order::acq_rel: + return atomic_fetch_or(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_or(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } +} + +template +inline T1 atomic_fetch_or(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_fetch_or(addr, operand, memoryOrder); +} + +/// Atomically xor the value at the addr with the value operand, and assign +/// the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to use in bitwise XOR operation with the value at the \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_xor(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_xor(operand); +} + +template +inline T1 atomic_fetch_xor(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_xor(operand); +} + +/// Atomically xor the value at the addr with the value operand, and assign +/// the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to use in bitwise XOR operation with the value at the \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_xor(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_xor(addr, operand); + case sycl::memory_order::acq_rel: + return atomic_fetch_xor(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_xor(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } +} + +template +inline T1 atomic_fetch_xor(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_fetch_xor(addr, operand, memoryOrder); +} + +/// Atomically calculate the minimum of the value at addr and the value operand +/// and assign the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. 
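// Editor's note: a minimal usage sketch for the fetch-op wrappers above (not
// part of the patch). It assumes this header is reachable as "dpct/atomic.h"
// and that the template parameter lists stripped during extraction default to
// relaxed, device-scope atomics, as in the upstream dpct helper headers.
#include <cassert>
#include <sycl/sycl.hpp>
#include "dpct/atomic.h"

int main()
{
    sycl::queue q;
    int* counter = sycl::malloc_shared<int>(1, q);
    *counter = 0;

    // Every work-item bumps the shared counter; fetch_add returns the old value.
    q.parallel_for(sycl::range<1>(1024), [=](sycl::id<1>) {
        dpct::atomic_fetch_add(counter, 1);
    }).wait();

    assert(*counter == 1024);
    sycl::free(counter, q);
    return 0;
}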
+template +inline T atomic_fetch_min(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_min(operand); +} + +template +inline T1 atomic_fetch_min(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_min(operand); +} + +/// Atomically calculate the minimum of the value at addr and the value operand +/// and assign the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_min(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_min(addr, operand); + case sycl::memory_order::acq_rel: + return atomic_fetch_min(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_min(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } +} + +template +inline T1 atomic_fetch_min(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_fetch_min(addr, operand, memoryOrder); +} + +/// Atomically calculate the maximum of the value at addr and the value operand +/// and assign the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_max(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_max(operand); +} + +template +inline T1 atomic_fetch_max(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_max(operand); +} + +/// Atomically calculate the maximum of the value at addr and the value operand +/// and assign the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_max(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_max(addr, operand); + case sycl::memory_order::acq_rel: + return atomic_fetch_max(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_max(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } +} + +template +inline T1 atomic_fetch_max(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_fetch_max(addr, operand, memoryOrder); +} + +/// Atomically set \p operand to the value stored in \p addr, if old value stored in +/// \p addr is equal to zero or greater than \p operand, else decrease the value stored +/// in \p addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The threshold value. +/// \param memoryOrder The memory ordering used. +/// \returns The old value stored in \p addr. 
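// Editor's note: a behaviour sketch for atomic_fetch_compare_dec documented
// above and its atomic_fetch_compare_inc counterpart defined below (not part
// of the patch; the include path is assumed). compare_inc mirrors CUDA's
// atomicInc: it increments until the stored value reaches `operand`, then
// wraps to 0; compare_dec walks the same cycle downwards.
#include <sycl/sycl.hpp>
#include "dpct/atomic.h"

int main()
{
    sycl::queue q;
    unsigned int* v = sycl::malloc_shared<unsigned int>(1, q);
    *v = 0;

    // With operand == 3 the stored value cycles 0 -> 1 -> 2 -> 3 -> 0 -> 1.
    q.single_task([=] {
        for (int i = 0; i < 5; ++i)
            dpct::atomic_fetch_compare_inc(v, 3u);
    }).wait();

    // After the five calls above, *v == 1.
    sycl::free(v, q);
    return 0;
}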
+template +inline unsigned int atomic_fetch_compare_dec(unsigned int *addr, + unsigned int operand) { + auto atm = sycl::atomic_ref(addr[0]); + unsigned int old; + + while (true) { + old = atm.load(); + if (old == 0 || old > operand) { + if (atm.compare_exchange_strong(old, operand)) + break; + } else if (atm.compare_exchange_strong(old, old - 1)) + break; + } + + return old; +} + +/// Atomically increment the value stored in \p addr if old value stored in \p +/// addr is less than \p operand, else set 0 to the value stored in \p addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The threshold value. +/// \param memoryOrder The memory ordering used. +/// \returns The old value stored in \p addr. +template +inline unsigned int atomic_fetch_compare_inc(unsigned int *addr, + unsigned int operand) { + auto atm = sycl::atomic_ref(addr[0]); + unsigned int old; + while (true) { + old = atm.load(); + if (old >= operand) { + if (atm.compare_exchange_strong(old, 0)) + break; + } else if (atm.compare_exchange_strong(old, old + 1)) + break; + } + return old; +} + +/// Atomically increment the value stored in \p addr if old value stored in \p +/// addr is less than \p operand, else set 0 to the value stored in \p addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The threshold value. +/// \param memoryOrder The memory ordering used. +/// \returns The old value stored in \p addr. +template +inline unsigned int +atomic_fetch_compare_inc(unsigned int *addr, unsigned int operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_fetch_compare_inc(addr, + operand); + case sycl::memory_order::acq_rel: + return atomic_fetch_compare_inc(addr, + operand); + case sycl::memory_order::seq_cst: + return atomic_fetch_compare_inc(addr, + operand); + default: + assert(false && "Invalid memory_order for atomics. Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } +} + +/// Atomically exchange the value at the address addr with the value operand. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to be exchanged with the value pointed by \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_exchange(T *addr, T operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.exchange(operand); +} + +template +inline T1 atomic_exchange(T1 *addr, T2 operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.exchange(operand); +} + +/// Atomically exchange the value at the address addr with the value operand. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to be exchanged with the value pointed by \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_exchange(T *addr, T operand, + sycl::memory_order memoryOrder) { + switch (memoryOrder) { + case sycl::memory_order::relaxed: + return atomic_exchange(addr, operand); + case sycl::memory_order::acq_rel: + return atomic_exchange(addr, operand); + case sycl::memory_order::seq_cst: + return atomic_exchange(addr, operand); + default: + assert(false && "Invalid memory_order for atomics. 
Valid memory_order for " + "atomics are: sycl::memory_order::relaxed, " + "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); + } +} + +template +inline T1 atomic_exchange(T1 *addr, T2 operand, + sycl::memory_order memoryOrder) { + atomic_exchange(addr, operand, memoryOrder); +} + +/// Atomically compare the value at \p addr to the value expected and exchange +/// with the value desired if the value at \p addr is equal to the value expected. +/// Returns the value at the \p addr before the call. +/// \param [in, out] addr Multi_ptr. +/// \param expected The value to compare against the value at \p addr. +/// \param desired The value to assign to \p addr if the value at \p addr is expected. +/// \param success The memory ordering used when comparison succeeds. +/// \param fail The memory ordering used when comparison fails. +/// \returns The value at the \p addr before the call. +template +T atomic_compare_exchange_strong( + sycl::multi_ptr addr, T expected, T desired, + sycl::memory_order success = sycl::memory_order::relaxed, + sycl::memory_order fail = sycl::memory_order::relaxed) { + auto atm = sycl::atomic_ref(*addr); + + atm.compare_exchange_strong(expected, desired, success, fail); + return expected; +} + +template +T1 atomic_compare_exchange_strong( + sycl::multi_ptr addr, T2 expected, T3 desired, + sycl::memory_order success = sycl::memory_order::relaxed, + sycl::memory_order fail = sycl::memory_order::relaxed) { + auto atm = + sycl::atomic_ref(*addr); + T1 expected_value = expected; + atm.compare_exchange_strong(expected_value, desired, success, fail); + return expected_value; +} + +/// Atomically compare the value at \p addr to the value expected and exchange +/// with the value desired if the value at \p addr is equal to the value expected. +/// Returns the value at the \p addr before the call. +/// \param [in] addr The pointer to the data. +/// \param expected The value to compare against the value at \p addr. +/// \param desired The value to assign to \p addr if the value at \p addr is expected. +/// \param success The memory ordering used when comparison succeeds. +/// \param fail The memory ordering used when comparison fails. +/// \returns The value at the \p addr before the call. +template +T atomic_compare_exchange_strong( + T *addr, T expected, T desired, + sycl::memory_order success = sycl::memory_order::relaxed, + sycl::memory_order fail = sycl::memory_order::relaxed) { + auto atm = + sycl::atomic_ref(addr[0]); + atm.compare_exchange_strong(expected, desired, success, fail); + return expected; +} + +template +T1 atomic_compare_exchange_strong( + T1 *addr, T2 expected, T3 desired, + sycl::memory_order success = sycl::memory_order::relaxed, + sycl::memory_order fail = sycl::memory_order::relaxed) { + T1 expected_value = expected; + auto atm = + sycl::atomic_ref(addr[0]); + atm.compare_exchange_strong(expected_value, desired, success, fail); + return expected_value; +} + +/// Atomic extension to implement standard APIs in std::atomic +namespace detail{ +template struct IsValidAtomicType { + static constexpr bool value = + (std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_pointer::value); +}; +} // namespace detail + +template +class atomic{ + static_assert( + detail::IsValidAtomicType::value, + "Invalid atomic type. 
Valid types are int, unsigned int, long, " + "unsigned long, long long, unsigned long long, float, double " + "and pointer types"); + T __d; + +public: + /// default memory synchronization order + static constexpr sycl::memory_order default_read_order = + sycl::atomic_ref::default_read_order; + static constexpr sycl::memory_order default_write_order = + sycl::atomic_ref::default_write_order; + static constexpr sycl::memory_scope default_scope = DefaultScope; + static constexpr sycl::memory_order default_read_modify_write_order = + DefaultOrder; + + + /// Default constructor. + constexpr atomic() noexcept = default; + /// Constructor with initialize value. + constexpr atomic(T d) noexcept : __d(d){}; + + /// atomically replaces the value of the referenced object with a non-atomic argument + /// \param operand The value to replace the pointed value. + /// \param memoryOrder The memory ordering used. + /// \param memoryScope The memory scope used. + void store(T operand, sycl::memory_order memoryOrder = default_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + sycl::atomic_ref atm(__d); + atm.store(operand, memoryOrder, memoryScope); + } + + /// atomically obtains the value of the referenced object + /// \param memoryOrder The memory ordering used. + /// \param memoryScope The memory scope used. + /// \returns The value of the referenced object + T load(sycl::memory_order memoryOrder = default_read_order, + sycl::memory_scope memoryScope = default_scope) const noexcept { + sycl::atomic_ref atm( + const_cast(__d)); + return atm.load(memoryOrder, memoryScope); + } + + /// atomically replaces the value of the referenced object and obtains the value held previously + /// \param operand The value to replace the pointed value. + /// \param memoryOrder The memory ordering used. + /// \param memoryScope The memory scope used. + /// \returns The value of the referenced object before the call. + T exchange(T operand, + sycl::memory_order memoryOrder = default_read_modify_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + + sycl::atomic_ref atm(__d); + return atm.exchange(operand, memoryOrder, memoryScope); + } + + /// atomically compares the value of the referenced object with non-atomic argument + /// and performs atomic exchange if equal or atomic load if not + /// \param expected The value expected to be found in the object referenced by the atomic_ref object + /// \param desired The value to store in the referenced object if it is as expected + /// \param success The memory models for the read-modify-write + /// \param failure The memory models for load operations + /// \param memoryScope The memory scope used. + /// \returns true if the referenced object was successfully changed, false otherwise. + bool compare_exchange_weak( + T &expected, T desired, + sycl::memory_order success, sycl::memory_order failure, + sycl::memory_scope memoryScope = default_scope) noexcept { + sycl::atomic_ref atm(__d); + return atm.compare_exchange_weak(expected, desired, success, failure, memoryScope); + } + /// \param expected The value expected to be found in the object referenced by the atomic_ref object + /// \param desired The value to store in the referenced object if it is as expected + /// \param memoryOrder The memory synchronization ordering for operations + /// \param memoryScope The memory scope used. + /// \returns true if the referenced object was successfully changed, false otherwise. 
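// Editor's note: a hedged sketch of the dpct::atomic<T> wrapper documented
// above (not part of the patch). The stripped template parameters are assumed
// to default to device scope and relaxed ordering, as in upstream dpct, and
// the include path is an assumption.
#include <new>
#include <sycl/sycl.hpp>
#include "dpct/atomic.h"

int main()
{
    sycl::queue q;
    auto* running_max = sycl::malloc_shared<dpct::atomic<int>>(1, q);
    new (running_max) dpct::atomic<int>(0);

    // Classic compare-exchange loop: keep the maximum of 1024 candidate values.
    q.parallel_for(sycl::range<1>(1024), [=](sycl::id<1> i) {
        int candidate = static_cast<int>(i[0]);
        int observed = running_max->load();
        // On failure, compare_exchange_strong reloads `observed` with the current value.
        while (observed < candidate &&
               !running_max->compare_exchange_strong(observed, candidate)) {
        }
    }).wait();

    // running_max->load() == 1023 once the kernel has completed.
    sycl::free(running_max, q);
    return 0;
}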
+ bool compare_exchange_weak(T &expected, T desired, + sycl::memory_order memoryOrder = default_read_modify_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + sycl::atomic_ref atm(__d); + return atm.compare_exchange_weak(expected, desired, memoryOrder, memoryScope); + } + + /// atomically compares the value of the referenced object with non-atomic argument + /// and performs atomic exchange if equal or atomic load if not + /// \param expected The value expected to be found in the object referenced by the atomic_ref object + /// \param desired The value to store in the referenced object if it is as expected + /// \param success The memory models for the read-modify-write + /// \param failure The memory models for load operations + /// \param memoryScope The memory scope used. + /// \returns true if the referenced object was successfully changed, false otherwise. + bool compare_exchange_strong( + T &expected, T desired, + sycl::memory_order success, sycl::memory_order failure, + sycl::memory_scope memoryScope = default_scope) noexcept { + + sycl::atomic_ref atm(__d); + return atm.compare_exchange_strong(expected, desired, success, failure, memoryScope); + } + /// \param expected The value expected to be found in the object referenced by the atomic_ref object + /// \param desired The value to store in the referenced object if it is as expected + /// \param memoryOrder The memory synchronization ordering for operations + /// \param memoryScope The memory scope used. + /// \returns true if the referenced object was successfully changed, false otherwise. + bool compare_exchange_strong(T &expected, T desired, + sycl::memory_order memoryOrder = default_read_modify_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + sycl::atomic_ref atm(__d); + return atm.compare_exchange_strong(expected, desired, memoryOrder, memoryScope); + } + + /// atomically adds the argument to the value stored in the atomic object and obtains the value held previously + /// \param operand The other argument of arithmetic addition + /// \param memoryOrder The memory ordering used. + /// \param memoryScope The memory scope used. + /// \returns The value of the referenced object before the call. + T fetch_add(T operand, + sycl::memory_order memoryOrder = default_read_modify_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + + sycl::atomic_ref atm(__d); + return atm.fetch_add(operand, memoryOrder, memoryScope); + } + + /// atomically subtracts the argument from the value stored in the atomic object and obtains the value held previously + /// \param operand The other argument of arithmetic subtraction + /// \param memoryOrder The memory ordering used. + /// \param memoryScope The memory scope used. + /// \returns The value of the referenced object before the call. 
+ T fetch_sub(T operand, + sycl::memory_order memoryOrder = default_read_modify_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + + sycl::atomic_ref atm(__d); + return atm.fetch_sub(operand, memoryOrder, memoryScope); + } +}; + +} // namespace dpct +#endif // __DPCT_ATOMIC_HPP__ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/blas_utils.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/blas_utils.h new file mode 100644 index 0000000..7b25b80 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/blas_utils.h @@ -0,0 +1,1792 @@ +//==---- blas_utils.hpp----------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_BLAS_UTILS_HPP__ +#define __DPCT_BLAS_UTILS_HPP__ + +#include "memory.h" +#include "util.h" +#include "lib_common_utils.h" +#include +#include +#include +#include +#include + +namespace dpct { + +/// Get the value of \p s. +/// Copy the data to host synchronously, then return the data. +/// \param [in] p The pointer points the data. +/// \param [in] q The queue where the memory copy should be executed. +template +inline auto get_value(const T *s, sycl::queue &q) { + return detail::get_value(s, q); +} + +namespace detail { +inline void mem_free(sycl::queue *exec_queue, + std::vector pointers_array, sycl::event e) { + e.wait(); + for (auto p : pointers_array) + sycl::free(p, *exec_queue); +} + +inline int stride_for(int num_elems, int mem_align_in_elems) { + return ((num_elems - 1) / mem_align_in_elems + 1) * mem_align_in_elems; +} + +#ifndef DPCT_USM_LEVEL_NONE +template +class working_memory { + T *_input_ptr; + T *_temp_ptr; + bool _is_sycl_malloced = false; + bool _is_scalar_value = false; + sycl::queue _q; + sycl::event _e; + +public: + working_memory(size_t size, sycl::queue q) : _q(q) { + _is_scalar_value = false; + _temp_ptr = (T *)sycl::malloc_device(size, q); + } + working_memory(T *result_ptr, sycl::queue q) : _input_ptr(result_ptr), _q(q) { + _is_scalar_value = true; + _is_sycl_malloced = sycl::get_pointer_type(_input_ptr, _q.get_context()) != + sycl::usm::alloc::unknown; + if (!_is_sycl_malloced) + _temp_ptr = sycl::malloc_shared(1, _q); + } + auto get_ptr() { + if (_is_scalar_value && _is_sycl_malloced) + return _input_ptr; + return _temp_ptr; + } + void set_event(sycl::event e) { _e = e; } + ~working_memory() { + if (_is_scalar_value) { + if (!_is_sycl_malloced) { + _q.memcpy(_input_ptr, _temp_ptr, sizeof(T)).wait(); + sycl::free(_temp_ptr, _q); + } + } else { + std::vector ptrs{_temp_ptr}; + dpct::async_dpct_free(ptrs, {_e}); + } + } +}; +#endif + +template +inline void nrm2_impl(sycl::queue &q, int n, const void *x, int incx, + void *result) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else +#ifdef DPCT_USM_LEVEL_NONE + auto x_buffer = dpct::get_buffer(x); + auto r_buffer = + sycl::buffer(reinterpret_cast(result), sycl::range<1>(1)); + if (dpct::is_device_ptr(result)) + r_buffer = dpct::get_buffer(result); + oneapi::mkl::blas::column_major::nrm2(q, n, x_buffer, incx, r_buffer); +#else + working_memory res_mem(reinterpret_cast(result), q); + oneapi::mkl::blas::column_major::nrm2(q, n, reinterpret_cast(x), + incx, 
res_mem.get_ptr()); +#endif +#endif +} + +template +inline void dotuc_impl(sycl::queue &q, int n, const Txy *x, int incx, + const Txy *y, int incy, Tr *result) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else +#ifdef DPCT_USM_LEVEL_NONE + auto x_buffer = dpct::get_buffer(x); + auto y_buffer = dpct::get_buffer(y); + auto r_buffer = sycl::buffer((Tr *)result, sycl::range<1>(1)); + if (dpct::is_device_ptr(result)) + r_buffer = dpct::get_buffer(result); + if constexpr (std::is_same_v> || + std::is_same_v>) { + if constexpr (is_conjugate) + oneapi::mkl::blas::column_major::dotc(q, n, x_buffer, incx, y_buffer, + incy, r_buffer); + else + oneapi::mkl::blas::column_major::dotu(q, n, x_buffer, incx, y_buffer, + incy, r_buffer); + } else + oneapi::mkl::blas::column_major::dot(q, n, x_buffer, incx, y_buffer, incy, + r_buffer); +#else + working_memory res_mem(result, q); + if constexpr (std::is_same_v> || + std::is_same_v>) { + if constexpr (is_conjugate) + oneapi::mkl::blas::column_major::dotc(q, n, x, incx, y, incy, res_mem.get_ptr()); + else + oneapi::mkl::blas::column_major::dotu(q, n, x, incx, y, incy, res_mem.get_ptr()); + } else + oneapi::mkl::blas::column_major::dot(q, n, x, incx, y, incy, res_mem.get_ptr()); +#endif +#endif +} + +template +inline void dotuc(sycl::queue &q, int n, const void *x, + library_data_t x_type, int incx, const void *y, + library_data_t y_type, int incy, void *result, + library_data_t result_type) { + std::uint64_t key = detail::get_type_combination_id(x_type, y_type, result_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float): { + detail::dotuc_impl( + q, n, reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(result)); + break; + } + case detail::get_type_combination_id(library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double): { + detail::dotuc_impl( + q, n, reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(result)); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float, + library_data_t::complex_float): { + detail::dotuc_impl( + q, n, reinterpret_cast *>(x), incx, + reinterpret_cast *>(y), incy, + reinterpret_cast *>(result)); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::complex_double, + library_data_t::complex_double): { + detail::dotuc_impl( + q, n, reinterpret_cast *>(x), incx, + reinterpret_cast *>(y), incy, + reinterpret_cast *>(result)); + break; + } + case detail::get_type_combination_id(library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half): { + detail::dotuc_impl( + q, n, reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(result)); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +template +inline void scal_impl(sycl::queue &q, int n, const void *alpha, void *x, + int incx) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); + auto data_x = get_memory(x); + oneapi::mkl::blas::column_major::scal(q, n, alpha_val, + data_x, incx); +#endif +} + +template +inline void axpy_impl(sycl::queue &q, int n, 
const void *alpha, const void *x, + int incx, void *y, int incy) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); + auto data_x = get_memory(x); + auto data_y = get_memory(y); + oneapi::mkl::blas::column_major::axpy(q, n, alpha_val, + data_x, incx, + data_y, incy); +#endif +} + +template +inline void rot_impl(sycl::queue &q, int n, void *x, int incx, void *y, + int incy, const void *c, const void *s) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + Tc c_value = dpct::get_value(reinterpret_cast(c), q); + Ts s_value = dpct::get_value(reinterpret_cast(s), q); + auto data_x = get_memory(x); + auto data_y = get_memory(y); + oneapi::mkl::blas::column_major::rot(q, n, data_x, incx, + data_y, incy, c_value, + s_value); +#endif +} + +template +inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a, int lda, const void *b, + int ldb, const void *beta, void *c, int ldc) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::mkl::blas::column_major::gemm( + q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, + data_b, ldb, beta_value, data_c, ldc); +#endif +} + +template +inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void **a, int lda, + const void **b, int ldb, const void *beta, void **c, + int ldc, int batch_size) { + struct matrix_info_t { + oneapi::mkl::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; + }; + + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + + matrix_info_t *matrix_info = + (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); + matrix_info->transpose_info[0] = a_trans; + matrix_info->transpose_info[1] = b_trans; + matrix_info->value_info[0] = alpha_value; + matrix_info->value_info[1] = beta_value; + matrix_info->size_info[0] = m; + matrix_info->size_info[1] = n; + matrix_info->size_info[2] = k; + matrix_info->ld_info[0] = lda; + matrix_info->ld_info[1] = ldb; + matrix_info->ld_info[2] = ldc; + matrix_info->groupsize_info = batch_size; + + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + q, matrix_info->transpose_info, matrix_info->transpose_info + 1, + matrix_info->size_info, matrix_info->size_info + 1, + matrix_info->size_info + 2, matrix_info->value_info, + reinterpret_cast(a), matrix_info->ld_info, + reinterpret_cast(b), matrix_info->ld_info + 1, + matrix_info->value_info + 1, reinterpret_cast(c), + matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); + + q.submit([&](sycl::handler &cgh) { + cgh.depends_on(e); + cgh.host_task([=] { std::free(matrix_info); }); + }); +} + +template +inline void +gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose 
b_trans, int m, int n, + int k, const void *alpha, const void *a, int lda, + long long int stride_a, const void *b, int ldb, + long long int stride_b, const void *beta, void *c, + int ldc, long long int stride_c, int batch_size) { + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::mkl::blas::column_major::gemm_batch( + q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, + stride_a, data_b, ldb, stride_b, beta_value, + data_c, ldc, stride_c, batch_size); +} + +template +inline void rk_impl(sycl::queue &q, oneapi::mkl::uplo uplo, + oneapi::mkl::transpose trans, int n, int k, + const T *alpha, const T *a, int lda, const T *b, + int ldb, const Tbeta *beta, T *c, int ldc) { + // For symmetric matrix, this function performs: C = alpha*OP(A)*(OP(B))^T + beta*C + // For Hermitian matrix, this function performs: C = alpha*OP(A)*(OP(B))^H + beta*C + // The gemmt() function performs: C = alpha*OPA(A)*OPB(B) + beta*C + // So the OPB need be updated before we call gemmt(). + using Ty = typename dpct::DataType::T2; + using Ts = typename dpct::DataType::T2; + Ty alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + oneapi::mkl::transpose trans_A = trans, trans_B = trans; + int origin_b_rows = trans == oneapi::mkl::transpose::nontrans ? n : k; + int origin_b_cols = trans == oneapi::mkl::transpose::nontrans ? k : n; + + if ((is_hermitian && trans == oneapi::mkl::transpose::trans) || + (!is_hermitian && !std::is_floating_point_v && trans == oneapi::mkl::transpose::conjtrans)) { + // In this case, OPB need be a conjugate operation, + // but only notrans, conjtrans and trans are available. + // So we need do a conjtrans operation first, then do a trans operation. + trans_B = oneapi::mkl::transpose::trans; + auto data_a = get_memory(a); + auto data_c = get_memory(c); +#ifdef DPCT_USM_LEVEL_NONE + auto new_B_buffer = sycl::buffer(sycl::range<1>(origin_b_rows * origin_b_cols)); + auto from_buffer = dpct::get_buffer(b); + oneapi::mkl::blas::column_major::omatcopy_batch( + q, oneapi::mkl::transpose::conjtrans, origin_b_rows, origin_b_cols, + Ts(1.0), from_buffer, ldb, origin_b_rows * ldb, new_B_buffer, + origin_b_cols, origin_b_rows * origin_b_cols, 1); + oneapi::mkl::blas::column_major::gemmt( + q, uplo, trans_A, trans_B, n, k, alpha_value, + data_a, lda, new_B_buffer, origin_b_cols, beta_value, data_c, ldc); +#else + working_memory new_B(origin_b_rows * origin_b_cols * sizeof(T), q); + oneapi::mkl::blas::column_major::omatcopy_batch( + q, oneapi::mkl::transpose::conjtrans, origin_b_rows, origin_b_cols, + Ts(1.0), reinterpret_cast(b), ldb, origin_b_rows * ldb, + reinterpret_cast(new_B.get_ptr()), origin_b_cols, + origin_b_rows * origin_b_cols, 1); + sycl::event e = oneapi::mkl::blas::column_major::gemmt( + q, uplo, trans_A, trans_B, n, k, alpha_value, + data_a, lda, reinterpret_cast(new_B.get_ptr()), origin_b_cols, + beta_value, data_c, ldc); + new_B.set_event(e); +#endif + } else { + if constexpr (is_hermitian) { + trans_B = trans == oneapi::mkl::transpose::nontrans + ? oneapi::mkl::transpose::conjtrans + : oneapi::mkl::transpose::nontrans; + } else { + trans_B = trans == oneapi::mkl::transpose::nontrans + ? 
oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + } + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::mkl::blas::column_major::gemmt( + q, uplo, trans_A, trans_B, n, k, alpha_value, + data_a, lda, data_b, ldb, beta_value, data_c, ldc); + } +} + +template +inline void +trsm_batch_impl(sycl::queue &q, oneapi::mkl::side left_right, + oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, int m, int n, const void *alpha, + const void **a, int lda, void **b, int ldb, int batch_size) { + struct matrix_info_t { + matrix_info_t(oneapi::mkl::side side_info, oneapi::mkl::uplo uplo_info, + oneapi::mkl::transpose transpose_info, + oneapi::mkl::diag diag_info, Ts value_info, std::int64_t m, + std::int64_t n, std::int64_t lda, std::int64_t ldb, + std::int64_t groupsize_info) + : side_info(side_info), uplo_info(uplo_info), + transpose_info(transpose_info), diag_info(diag_info), + value_info(value_info), groupsize_info(groupsize_info) { + size_info[0] = m; + size_info[1] = n; + ld_info[0] = lda; + ld_info[1] = ldb; + } + oneapi::mkl::side side_info; + oneapi::mkl::uplo uplo_info; + oneapi::mkl::transpose transpose_info; + oneapi::mkl::diag diag_info; + Ts value_info; + std::int64_t size_info[2]; + std::int64_t ld_info[2]; + std::int64_t groupsize_info; + }; + + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + + matrix_info_t *matrix_info = + new matrix_info_t(left_right, upper_lower, trans, unit_diag, alpha_value, + m, n, lda, ldb, batch_size); + + sycl::event e = oneapi::mkl::blas::column_major::trsm_batch( + q, &(matrix_info->side_info), &(matrix_info->uplo_info), + &(matrix_info->transpose_info), &(matrix_info->diag_info), + matrix_info->size_info, matrix_info->size_info + 1, + &(matrix_info->value_info), reinterpret_cast(a), + matrix_info->ld_info, reinterpret_cast(b), + matrix_info->ld_info + 1, 1, &(matrix_info->groupsize_info)); + + q.submit([&](sycl::handler &cgh) { + cgh.depends_on(e); + cgh.host_task([=] { delete matrix_info; }); + }); +} + +template +inline void getrfnp_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], + int lda, int *info, int batch_size) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + using Ty = typename DataType::T2; + // Set the info array value to 0 + detail::dpct_memset(exec_queue, info, 0, sizeof(int) * batch_size); + std::int64_t stride_a = n * lda; + std::int64_t scratchpad_size = + oneapi::mkl::lapack::getrfnp_batch_scratchpad_size( + exec_queue, n, n, lda, stride_a, batch_size); + + Ty *a_strided_mem = + (Ty *)dpct::dpct_malloc(stride_a * batch_size * sizeof(Ty), exec_queue); + T **host_a = (T **)std::malloc(batch_size * sizeof(T *)); + dpct::dpct_memcpy(host_a, a, batch_size * sizeof(T *)); + for (std::int64_t i = 0; i < batch_size; ++i) + dpct::dpct_memcpy(a_strided_mem + i * stride_a, host_a[i], + n * lda * sizeof(T)); + +#ifdef DPCT_USM_LEVEL_NONE + { + sycl::buffer scratchpad{sycl::range<1>(scratchpad_size)}; + auto a_buffer = get_buffer(a_strided_mem); + oneapi::mkl::lapack::getrfnp_batch(exec_queue, n, n, a_buffer, lda, + stride_a, batch_size, scratchpad, + scratchpad_size); + } + std::vector events; + for (std::int64_t i = 0; i < batch_size; ++i) + events.push_back(detail::dpct_memcpy(exec_queue, host_a[i], + a_strided_mem + i * stride_a, + n * lda * sizeof(T), automatic)); +#else + Ty *scratchpad = 
sycl::malloc_device(scratchpad_size, exec_queue); + sycl::event e = oneapi::mkl::lapack::getrfnp_batch( + exec_queue, n, n, a_strided_mem, lda, stride_a, batch_size, scratchpad, + scratchpad_size); + std::vector events; + for (std::int64_t i = 0; i < batch_size; ++i) + events.push_back(detail::dpct_memcpy(exec_queue, host_a[i], + a_strided_mem + i * stride_a, + n * lda * sizeof(T), automatic, {e})); + + std::vector ptrs{scratchpad, a_strided_mem}; + dpct::async_dpct_free(ptrs, events, exec_queue); +#endif + + exec_queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(events); + cgh.host_task([=] { std::free(host_a); }); + }); +#endif +} + +} // namespace detail + +inline oneapi::mkl::transpose get_transpose(int t) { + if (t == 0) { + return oneapi::mkl::transpose::nontrans; + } else if (t == 1) { + return oneapi::mkl::transpose::trans; + } else { + return oneapi::mkl::transpose::conjtrans; + } +} + +/// Computes the LU factorizations of a batch of general matrices. +/// \param [in] exec_queue The queue where the routine should be executed. +/// \param [in] n The order of the matrices. +/// \param [in, out] a Array of pointers to matrices. These matrices will be +/// overwritten by lower triangulars with unit diagonal elements and upper +/// triangulars. +/// \param [in] lda The leading dimension of the matrices. +/// \param [out] ipiv An array stores the pivot indices. If \p ipiv is nullptr, +/// non-pivoting LU factorization is computed. +/// \param [out] info An array stores the error information. +/// \param [in] batch_size The size of the batch. +template +inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], + int lda, int *ipiv, int *info, int batch_size) { + if (ipiv == nullptr) { + detail::getrfnp_batch_wrapper(exec_queue, n, a, lda, info, batch_size); + return; + } + using Ty = typename DataType::T2; + // Set the info array value to 0 + detail::dpct_memset(exec_queue, info, 0, sizeof(int) * batch_size); +#ifdef DPCT_USM_LEVEL_NONE + std::int64_t stride_a = n * lda; + std::int64_t stride_ipiv = n; + std::int64_t scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size( + exec_queue, n, n, lda, stride_a, stride_ipiv, batch_size); + + T *a_buffer_ptr; + a_buffer_ptr = (T *)dpct_malloc(stride_a * batch_size * sizeof(T)); + + T **host_a = (T **)std::malloc(batch_size * sizeof(T *)); + dpct_memcpy(host_a, a, batch_size * sizeof(T *)); + for (std::int64_t i = 0; i < batch_size; ++i) + dpct_memcpy(a_buffer_ptr + i * stride_a, host_a[i], n * lda * sizeof(T)); + + { + sycl::buffer ipiv_buf( + sycl::range<1>(batch_size * stride_ipiv)); + sycl::buffer scratchpad{sycl::range<1>(scratchpad_size)}; + auto a_buffer = get_buffer(a_buffer_ptr); + oneapi::mkl::lapack::getrf_batch(exec_queue, n, n, a_buffer, lda, stride_a, + ipiv_buf, stride_ipiv, batch_size, scratchpad, + scratchpad_size); + + auto to_buffer = get_buffer(ipiv); + exec_queue.submit([&](sycl::handler &cgh) { + auto from_acc = ipiv_buf.get_access(cgh); + auto to_acc = to_buffer.get_access(cgh); + cgh.parallel_for>( + sycl::range<2>(batch_size, n), [=](sycl::id<2> id) { + to_acc[id.get(0) * n + id.get(1)] = + static_cast(from_acc[id.get(0) * stride_ipiv + id.get(1)]); + }); + }); + } + + // Copy back to the original buffers + std::vector events; + for (std::int64_t i = 0; i < batch_size; ++i) + events.push_back(detail::dpct_memcpy(exec_queue, host_a[i], + a_buffer_ptr + i * stride_a, + n * lda * sizeof(T), automatic)); + + std::vector ptrs{host_a}; + std::thread mem_free_thread( + [=](std::vector 
pointers_array, + std::vector events_array) { + sycl::event::wait(events_array); + for (auto p : pointers_array) + std::free(p); + }, + ptrs, events); + mem_free_thread.detach(); +#else + std::int64_t m_int64 = n; + std::int64_t n_int64 = n; + std::int64_t lda_int64 = lda; + std::int64_t group_sizes = batch_size; + std::int64_t scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size( + exec_queue, &m_int64, &n_int64, &lda_int64, 1, &group_sizes); + + Ty *scratchpad = sycl::malloc_device(scratchpad_size, exec_queue); + std::int64_t *ipiv_int64 = + sycl::malloc_device(batch_size * n, exec_queue); + std::int64_t **ipiv_int64_ptr = + sycl::malloc_shared(batch_size, exec_queue); + T **a_shared = sycl::malloc_shared(batch_size, exec_queue); + exec_queue.memcpy(a_shared, a, batch_size * sizeof(T *)).wait(); + for (std::int64_t i = 0; i < batch_size; ++i) + ipiv_int64_ptr[i] = ipiv_int64 + n * i; + + oneapi::mkl::lapack::getrf_batch(exec_queue, &m_int64, &n_int64, (Ty **)a_shared, &lda_int64, + ipiv_int64_ptr, 1, &group_sizes, scratchpad, + scratchpad_size); + + sycl::event e = exec_queue.submit([&](sycl::handler &cgh) { + cgh.parallel_for>( + sycl::range<1>(batch_size * n), [=](sycl::id<1> idx) { + ipiv[idx] = static_cast(ipiv_int64[idx]); + }); + }); + + std::vector ptrs{scratchpad, ipiv_int64, ipiv_int64_ptr, a_shared}; + async_dpct_free(ptrs, {e}, exec_queue); +#endif +} + +/// Solves a system of linear equations with a batch of LU-factored square +/// coefficient matrices, with multiple right-hand sides. +/// \param [in] exec_queue The queue where the routine should be executed. +/// \param [in] trans Indicates the form of the linear equations. +/// \param [in] n The order of the matrices. +/// \param [in] nrhs The number of right hand sides. +/// \param [in] a Array of pointers to matrices. +/// \param [in] lda The leading dimension of the matrices in \p a. +/// \param [in] ipiv An array stores the pivots. +/// \param [in, out] b Array of pointers to matrices, whose columns are +/// the right-hand sides for the systems of equations. +/// \param [in] ldb The leading dimension of the matrices in \p b. +/// \param [out] info A value stores the error information. +/// \param [in] batch_size The size of the batch. 
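The batched LU wrapper above dispatches either to the buffer path (under DPCT_USM_LEVEL_NONE) or to the oneMKL group USM API. Below is a minimal usage sketch, not part of the patch: it assumes USM is available and that this header is reachable as "dpct/blas_utils.hpp"; all buffer names and sizes are illustrative.

#include <sycl/sycl.hpp>
#include "dpct/blas_utils.hpp"   // assumed include path for this header

int main() {
  sycl::queue q;
  constexpr int n = 4, lda = 4, batch = 2;

  // One device allocation per matrix, plus a shared array of device pointers.
  float **a_ptrs = sycl::malloc_shared<float *>(batch, q);
  for (int i = 0; i < batch; ++i) {
    a_ptrs[i] = sycl::malloc_device<float>(lda * n, q);
    // ... copy column-major input data into a_ptrs[i] here ...
  }
  int *ipiv = sycl::malloc_device<int>(batch * n, q);  // pass nullptr for the non-pivoting path
  int *info = sycl::malloc_device<int>(batch, q);

  // Factorize every matrix in the batch; the factors overwrite a_ptrs[i].
  dpct::getrf_batch_wrapper(q, n, a_ptrs, lda, ipiv, info, batch);
  q.wait();

  for (int i = 0; i < batch; ++i) sycl::free(a_ptrs[i], q);
  sycl::free(a_ptrs, q);
  sycl::free(ipiv, q);
  sycl::free(info, q);
  return 0;
}

Passing ipiv == nullptr forwards to detail::getrfnp_batch_wrapper, as documented above; the wrapper frees its internal scratch buffers asynchronously, so only the user allocations need explicit cleanup.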
+template +inline void getrs_batch_wrapper(sycl::queue &exec_queue, + oneapi::mkl::transpose trans, int n, int nrhs, + const T *a[], int lda, const int *ipiv, T *b[], + int ldb, int *info, int batch_size) { + using Ty = typename DataType::T2; + // Set the info value to 0 + *info = 0; +#ifdef DPCT_USM_LEVEL_NONE + std::int64_t stride_a = n * lda; + std::int64_t stride_b = nrhs * ldb; + std::int64_t stride_ipiv = n; + std::int64_t scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size( + exec_queue, trans, n, nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, + batch_size); + + T *a_buffer_ptr, *b_buffer_ptr; + a_buffer_ptr = (T *)dpct_malloc(stride_a * batch_size * sizeof(T)); + b_buffer_ptr = (T *)dpct_malloc(stride_b * batch_size * sizeof(T)); + + T **host_a = (T **)std::malloc(batch_size * sizeof(T *)); + T **host_b = (T **)std::malloc(batch_size * sizeof(T *)); + dpct_memcpy(host_a, a, batch_size * sizeof(T *)); + dpct_memcpy(host_b, b, batch_size * sizeof(T *)); + for (std::int64_t i = 0; i < batch_size; ++i) { + dpct_memcpy(a_buffer_ptr + i * stride_a, host_a[i], n * lda * sizeof(T)); + dpct_memcpy(b_buffer_ptr + i * stride_b, host_b[i], nrhs * ldb * sizeof(T)); + } + + { + auto a_buffer = get_buffer(a_buffer_ptr); + auto b_buffer = get_buffer(b_buffer_ptr); + sycl::buffer scratchpad{sycl::range<1>(scratchpad_size)}; + sycl::buffer ipiv_buf( + sycl::range<1>(batch_size * stride_ipiv)); + auto from_buf = get_buffer(ipiv); + exec_queue.submit([&](sycl::handler &cgh) { + auto from_acc = from_buf.get_access(cgh); + auto to_acc = ipiv_buf.get_access(cgh); + cgh.parallel_for>( + sycl::range<2>(batch_size, n), [=](sycl::id<2> id) { + to_acc[id.get(0) * stride_ipiv + id.get(1)] = + static_cast(from_acc[id.get(0) * n + id.get(1)]); + }); + }); + + oneapi::mkl::lapack::getrs_batch(exec_queue, trans, n, nrhs, a_buffer, lda, + stride_a, ipiv_buf, stride_ipiv, b_buffer, ldb, + stride_b, batch_size, scratchpad, scratchpad_size); + } + + // Copy back to the original buffers + std::vector events; + for (std::int64_t i = 0; i < batch_size; ++i) + events.push_back(detail::dpct_memcpy(exec_queue, host_b[i], + b_buffer_ptr + i * stride_b, + nrhs * ldb * sizeof(T), automatic)); + std::vector ptrs{host_a, host_b}; + std::thread mem_free_thread( + [=](std::vector pointers_array, + std::vector events_array) { + sycl::event::wait(events_array); + for (auto p : pointers_array) + std::free(p); + }, + ptrs, events); + mem_free_thread.detach(); +#else + std::int64_t n_int64 = n; + std::int64_t nrhs_int64 = nrhs; + std::int64_t lda_int64 = lda; + std::int64_t ldb_int64 = ldb; + std::int64_t group_sizes = batch_size; + std::int64_t scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size( + exec_queue, &trans, &n_int64, &nrhs_int64, &lda_int64, &ldb_int64, 1, + &group_sizes); + + Ty *scratchpad = sycl::malloc_device(scratchpad_size, exec_queue); + std::int64_t *ipiv_int64 = + sycl::malloc_device(batch_size * n, exec_queue); + std::int64_t **ipiv_int64_ptr = + sycl::malloc_shared(batch_size, exec_queue); + T **a_shared = sycl::malloc_shared(batch_size, exec_queue); + T **b_shared = sycl::malloc_shared(batch_size, exec_queue); + exec_queue.memcpy(a_shared, a, batch_size * sizeof(T *)); + exec_queue.memcpy(b_shared, b, batch_size * sizeof(T *)); + + exec_queue.submit([&](sycl::handler &cgh) { + cgh.parallel_for>( + sycl::range<1>(batch_size * n), [=](sycl::id<1> idx) { + ipiv_int64[idx] = static_cast(ipiv[idx]); + }); + }).wait(); + + for (std::int64_t i = 0; i < batch_size; ++i) + 
ipiv_int64_ptr[i] = ipiv_int64 + n * i; + + sycl::event e = oneapi::mkl::lapack::getrs_batch( + exec_queue, &trans, &n_int64, &nrhs_int64, (Ty **)a_shared, &lda_int64, + ipiv_int64_ptr, (Ty **)b_shared, &ldb_int64, 1, &group_sizes, scratchpad, + scratchpad_size); + + std::vector ptrs{scratchpad, ipiv_int64_ptr, ipiv_int64, a_shared, b_shared}; + async_dpct_free(ptrs, {e}, exec_queue); +#endif +} + +/// Computes the inverses of a batch of LU-factored matrices. +/// \param [in] exec_queue The queue where the routine should be executed. +/// \param [in] n The order of the matrices. +/// \param [in] a Array of pointers to matrices. +/// \param [in] lda The leading dimension of the matrices in \p a. +/// \param [in] ipiv An array stores the pivots. +/// \param [out] b Array of pointers to inverse matrices. +/// \param [in] ldb The leading dimension of the matrices in \p b. +/// \param [out] info An array stores the error information. +/// \param [in] batch_size The size of the batch. +template +inline void getri_batch_wrapper(sycl::queue &exec_queue, int n, + const T *a[], int lda, int *ipiv, T *b[], + int ldb, int *info, int batch_size) { + using Ty = typename DataType::T2; + // Set the info array value to 0 + detail::dpct_memset(exec_queue, info, 0, sizeof(int) * batch_size); +#ifdef DPCT_USM_LEVEL_NONE + std::int64_t stride_b = n * ldb; + std::int64_t stride_ipiv = n; + std::int64_t scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size( + exec_queue, n, ldb, stride_b, stride_ipiv, batch_size); + + T *b_buffer_ptr; + b_buffer_ptr = (T *)dpct_malloc(stride_b * batch_size * sizeof(T)); + + T **host_a = (T **)std::malloc(batch_size * sizeof(T *)); + T **host_b = (T **)std::malloc(batch_size * sizeof(T *)); + dpct_memcpy(host_a, a, batch_size * sizeof(T *)); + dpct_memcpy(host_b, b, batch_size * sizeof(T *)); + + for (std::int64_t i = 0; i < batch_size; ++i) { + // Need to create a copy of input matrices "a" to keep them unchanged. + // Matrices "b" (copy of matrices "a") will be used as input and output + // parameter in oneapi::mkl::lapack::getri_batch call. 
+ matrix_mem_copy(b_buffer_ptr + i * stride_b, host_a[i], ldb, lda, n, n, + dpct::device_to_device, exec_queue); + } + + { + auto b_buffer = get_buffer(b_buffer_ptr); + sycl::buffer scratchpad{sycl::range<1>(scratchpad_size)}; + sycl::buffer ipiv_buf( + sycl::range<1>(batch_size * stride_ipiv)); + auto from_buf = get_buffer(ipiv); + exec_queue.submit([&](sycl::handler &cgh) { + auto from_acc = from_buf.get_access(cgh); + auto to_acc = ipiv_buf.get_access(cgh); + cgh.parallel_for>( + sycl::range<2>(batch_size, n), [=](sycl::id<2> id) { + to_acc[id.get(0) * stride_ipiv + id.get(1)] = + static_cast(from_acc[id.get(0) * n + id.get(1)]); + }); + }); + + oneapi::mkl::lapack::getri_batch(exec_queue, n, b_buffer, ldb, stride_b, ipiv_buf, + stride_ipiv, batch_size, scratchpad, + scratchpad_size); + } + + // Copy back to the original buffers + std::vector events; + for (std::int64_t i = 0; i < batch_size; ++i) + events.push_back(detail::dpct_memcpy(exec_queue, host_b[i], + b_buffer_ptr + i * stride_b, + n * ldb * sizeof(T), automatic)); + std::vector ptrs{host_a, host_b}; + std::thread mem_free_thread( + [=](std::vector pointers_array, + std::vector events_array) { + sycl::event::wait(events_array); + for (auto p : pointers_array) + std::free(p); + }, + ptrs, events); + mem_free_thread.detach(); +#else + std::int64_t n_int64 = n; + std::int64_t ldb_int64 = ldb; + std::int64_t group_sizes = batch_size; + std::int64_t scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size( + exec_queue, &n_int64, &ldb_int64, 1, &group_sizes); + + Ty *scratchpad = sycl::malloc_device(scratchpad_size, exec_queue); + std::int64_t *ipiv_int64 = + sycl::malloc_device(batch_size * n, exec_queue); + std::int64_t **ipiv_int64_ptr = + sycl::malloc_shared(batch_size, exec_queue); + + exec_queue.submit([&](sycl::handler &cgh) { + cgh.parallel_for>( + sycl::range<1>(batch_size * n), [=](sycl::id<1> idx) { + ipiv_int64[idx] = static_cast(ipiv[idx]); + }); + }); + + T **a_shared = sycl::malloc_shared(batch_size, exec_queue); + T **b_shared = sycl::malloc_shared(batch_size, exec_queue); + exec_queue.memcpy(a_shared, a, batch_size * sizeof(T *)); + exec_queue.memcpy(b_shared, b, batch_size * sizeof(T *)).wait(); + for (std::int64_t i = 0; i < batch_size; ++i) { + ipiv_int64_ptr[i] = ipiv_int64 + n * i; + // Need to create a copy of input matrices "a" to keep them unchanged. + // Matrices "b" (copy of matrices "a") will be used as input and output + // parameter in oneapi::mkl::lapack::getri_batch call. + matrix_mem_copy(b_shared[i], a_shared[i], ldb, lda, n, n, dpct::device_to_device, + exec_queue); + } + + sycl::event e = oneapi::mkl::lapack::getri_batch( + exec_queue, &n_int64, (Ty **)b_shared, &ldb_int64, ipiv_int64_ptr, 1, + &group_sizes, scratchpad, scratchpad_size); + + std::vector ptrs{scratchpad, ipiv_int64_ptr, ipiv_int64, a_shared, b_shared}; + async_dpct_free(ptrs, {e}, exec_queue); +#endif +} + +/// Computes the QR factorizations of a batch of general matrices. +/// \param [in] exec_queue The queue where the routine should be executed. +/// \param [in] m The number of rows in the matrices. +/// \param [in] n The number of columns in the matrices. +/// \param [in, out] a Array of pointers to matrices. These +/// matrices will be overwritten by the factorization data. +/// \param [in] lda The leading dimension of the matrices in \p a. +/// \param [out] tau An array stores the scalars. +/// \param [out] info A value stores the error information. +/// \param [in] batch_size The size of the batch. 
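A hedged sketch of calling the batched QR wrapper documented above (its definition follows). It assumes the USM path; the function and variable names are illustrative, and <algorithm> is needed for std::min. Note that info is dereferenced on the host by the wrapper, so a plain host int is used.

#include <sycl/sycl.hpp>
#include <algorithm>
#include "dpct/blas_utils.hpp"   // assumed include path

void qr_batch_example(sycl::queue &q) {
  constexpr int m = 8, n = 4, lda = 8, batch = 2;

  float **a_ptrs   = sycl::malloc_shared<float *>(batch, q);
  float **tau_ptrs = sycl::malloc_shared<float *>(batch, q);
  for (int i = 0; i < batch; ++i) {
    a_ptrs[i]   = sycl::malloc_device<float>(lda * n, q);       // column-major input, overwritten by QR factors
    tau_ptrs[i] = sycl::malloc_device<float>(std::min(m, n), q); // Householder scalars
  }
  int info = 0;  // written on the host by the wrapper

  dpct::geqrf_batch_wrapper(q, m, n, a_ptrs, lda, tau_ptrs, &info, batch);
  q.wait();
}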
+template +inline void geqrf_batch_wrapper(sycl::queue exec_queue, int m, int n, + T *a[], int lda, T *tau[], int *info, + int batch_size) { + using Ty = typename DataType::T2; + // Set the info value to 0 + *info = 0; +#ifdef DPCT_USM_LEVEL_NONE + std::int64_t stride_a = n * lda; + std::int64_t stride_tau = std::max(1, std::min(m, n)); + std::int64_t scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size( + exec_queue, m, n, lda, stride_a, stride_tau, batch_size); + + T *a_buffer_ptr, *tau_buffer_ptr; + a_buffer_ptr = (T *)dpct_malloc(stride_a * batch_size * sizeof(T)); + tau_buffer_ptr = (T *)dpct_malloc(stride_tau * batch_size * sizeof(T)); + + T **host_a = (T **)std::malloc(batch_size * sizeof(T *)); + T **host_tau = (T **)std::malloc(batch_size * sizeof(T *)); + dpct_memcpy(host_a, a, batch_size * sizeof(T *)); + dpct_memcpy(host_tau, tau, batch_size * sizeof(T *)); + + for (std::int64_t i = 0; i < batch_size; ++i) + dpct_memcpy(a_buffer_ptr + i * stride_a, host_a[i], n * lda * sizeof(T)); + { + auto a_buffer = get_buffer(a_buffer_ptr); + auto tau_buffer = get_buffer(tau_buffer_ptr); + sycl::buffer scratchpad{sycl::range<1>(scratchpad_size)}; + oneapi::mkl::lapack::geqrf_batch(exec_queue, m, n, a_buffer, lda, stride_a, + tau_buffer, stride_tau, batch_size, scratchpad, + scratchpad_size); + } + + // Copy back to the original buffers + std::vector events_a; + std::vector events_tau; + for (std::int64_t i = 0; i < batch_size; ++i) { + events_a.push_back(detail::dpct_memcpy(exec_queue, host_a[i], + a_buffer_ptr + i * stride_a, + n * lda * sizeof(T), automatic)); + events_tau.push_back(detail::dpct_memcpy( + exec_queue, host_tau[i], tau_buffer_ptr + i * stride_tau, + std::max(1, std::min(m, n)) * sizeof(T), automatic)); + } + std::vector ptr_a{host_a}; + std::vector ptr_tau{host_tau}; + std::thread mem_free_thread_a( + [=](std::vector pointers_array, + std::vector events_array) { + sycl::event::wait(events_array); + for (auto p : pointers_array) + std::free(p); + }, + ptr_a, events_a); + std::thread mem_free_thread_tau( + [=](std::vector pointers_array, + std::vector events_array) { + sycl::event::wait(events_array); + for (auto p : pointers_array) + std::free(p); + }, + ptr_tau, events_tau); + mem_free_thread_a.detach(); + mem_free_thread_tau.detach(); +#else + std::int64_t m_int64 = n; + std::int64_t n_int64 = n; + std::int64_t lda_int64 = lda; + std::int64_t group_sizes = batch_size; + std::int64_t scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size( + exec_queue, &m_int64, &n_int64, &lda_int64, 1, &group_sizes); + + Ty *scratchpad = sycl::malloc_device(scratchpad_size, exec_queue); + T **a_shared = sycl::malloc_shared(batch_size, exec_queue); + T **tau_shared = sycl::malloc_shared(batch_size, exec_queue); + exec_queue.memcpy(a_shared, a, batch_size * sizeof(T *)); + exec_queue.memcpy(tau_shared, tau, batch_size * sizeof(T *)).wait(); + + sycl::event e = oneapi::mkl::lapack::geqrf_batch( + exec_queue, &m_int64, &n_int64, (Ty **)a_shared, &lda_int64, (Ty **)tau_shared, 1, + &group_sizes, scratchpad, scratchpad_size); + + std::vector ptrs{scratchpad, a_shared, tau_shared}; + async_dpct_free(ptrs, {e}, exec_queue); +#endif +} + +/// Computes the Euclidean norm of a vector. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [out] result The result scalar. 
+/// \param [in] result_type Data type of the result. +inline void nrm2(sycl::queue &q, int n, const void *x, library_data_t x_type, + int incx, void *result, library_data_t result_type) { + std::uint64_t key = detail::get_type_combination_id(x_type, result_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + detail::nrm2_impl(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + detail::nrm2_impl(q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::real_float): { + detail::nrm2_impl, float>( + q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::real_double): { + detail::nrm2_impl, double>( + q, n, x, incx, result); + break; + } + case detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_half): { + detail::nrm2_impl( + q, n, x, incx, result); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Computes the dot product of two vectors. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in] y Input vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. +/// \param [out] result The result scalar. +/// \param [in] result_type Data type of the result. +inline void dot(sycl::queue &q, int n, const void *x, library_data_t x_type, + int incx, const void *y, library_data_t y_type, int incy, + void *result, library_data_t result_type) { + detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, + result_type); +} + +/// Computes the dot product of two vectors, conjugating the first vector. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] n Number of elements in vector x. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in] y Input vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. +/// \param [out] result The result scalar. +/// \param [in] result_type Data type of the result. +inline void dotc(sycl::queue &q, int n, const void *x, library_data_t x_type, + int incx, const void *y, library_data_t y_type, int incy, + void *result, library_data_t result_type) { + detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, + result_type); +} + +/// Computes the product of a vector by a scalar. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] n Number of elements in vector x. +/// \param [in] alpha The scale factor alpha. +/// \param [in] alpha_type The data type of alpha. +/// \param [in, out] x Input/Output vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. 
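The scal wrapper documented above, like the nrm2/dot wrappers before it, resolves the typed detail::*_impl through get_type_combination_id. A minimal sketch of the float path, assuming USM-shared allocations so the result is host-readable; the function name is illustrative.

#include <sycl/sycl.hpp>
#include "dpct/blas_utils.hpp"   // assumed include path

void scale_and_norm(sycl::queue &q) {
  constexpr int n = 1024;
  float *x      = sycl::malloc_shared<float>(n, q);
  float *result = sycl::malloc_shared<float>(1, q);
  for (int i = 0; i < n; ++i) x[i] = 1.0f;

  const float alpha = 0.5f;  // host-side scalar, read via dpct::get_value
  dpct::scal(q, n, &alpha, dpct::library_data_t::real_float,
             x, dpct::library_data_t::real_float, /*incx=*/1);

  dpct::nrm2(q, n, x, dpct::library_data_t::real_float, /*incx=*/1,
             result, dpct::library_data_t::real_float);
  q.wait();   // *result now holds the Euclidean norm of the scaled vector

  sycl::free(x, q);
  sycl::free(result, q);
}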
+inline void scal(sycl::queue &q, int n, const void *alpha, + library_data_t alpha_type, void *x, library_data_t x_type, + int incx) { + std::uint64_t key = detail::get_type_combination_id(x_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float): { + detail::scal_impl(q, n, alpha, x, incx); + break; + } + case detail::get_type_combination_id(library_data_t::real_double): { + detail::scal_impl(q, n, alpha, x, incx); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float): { + detail::scal_impl, std::complex>(q, n, alpha, + x, incx); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double): { + detail::scal_impl, std::complex>( + q, n, alpha, x, incx); + break; + } + case detail::get_type_combination_id(library_data_t::real_half): { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + sycl::half alaph_half(alpha_value); + detail::scal_impl(q, n, &alaph_half, x, incx); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Computes a vector-scalar product and adds the result to a vector. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] n Number of elements in vector x. +/// \param [in] alpha The scale factor alpha. +/// \param [in] alpha_type The data type of alpha. +/// \param [in] x Input vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in, out] y Input/Output vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. +inline void axpy(sycl::queue &q, int n, const void *alpha, + library_data_t alpha_type, const void *x, library_data_t x_type, + int incx, void *y, library_data_t y_type, int incy) { + std::uint64_t key = detail::get_type_combination_id(x_type, alpha_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + detail::axpy_impl(q, n, alpha, x, incx, y, incy); + break; + } + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + detail::axpy_impl(q, n, alpha, x, incx, y, incy); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + detail::axpy_impl, std::complex>( + q, n, alpha, x, incx, y, incy); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::complex_double): { + detail::axpy_impl, std::complex>( + q, n, alpha, x, incx, y, incy); + break; + } + case detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_float): { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + sycl::half alaph_half(alpha_value); + detail::axpy_impl(q, n, &alaph_half, x, incx, y, incy); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Performs rotation of points in the plane. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] n Number of elements in vector x. +/// \param [in, out] x Input/Output vector x. +/// \param [in] x_type Data type of the vector x. +/// \param [in] incx Stride of vector x. +/// \param [in, out] y Input/Output vector y. +/// \param [in] y_type Data type of the vector y. +/// \param [in] incy Stride of vector y. +/// \param [in] c Scaling factor. +/// \param [in] s Scaling factor. 
+/// \param [in] cs_type Data type of the scaling factors. +inline void rot(sycl::queue &q, int n, void *x, library_data_t x_type, + int incx, void *y, library_data_t y_type, int incy, + const void *c, const void *s, library_data_t cs_type) { + std::uint64_t key = detail::get_type_combination_id(x_type, cs_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float): { + detail::rot_impl(q, n, x, incx, y, incy, c, s); + break; + } + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double): { + detail::rot_impl(q, n, x, incx, y, incy, c, s); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::real_float): { + detail::rot_impl, float, float>(q, n, x, incx, y, incy, c, + s); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::real_double): { + detail::rot_impl, double, double>(q, n, x, incx, y, incy, c, + s); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float): { + detail::rot_impl, float, std::complex>(q, n, x, incx, y, incy, c, s); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::complex_double): { + detail::rot_impl, double, std::complex>(q, n, x, incx, y, incy, c, s); + break; + } + case detail::get_type_combination_id(library_data_t::real_half, + library_data_t::real_half): { + detail::rot_impl(q, n, x, incx, y, incy, c, s); + break; + } + case detail::get_type_combination_id(library_data_t::real_bfloat16, + library_data_t::real_bfloat16): { + detail::rot_impl(q, n, x, incx, y, incy, c, s); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Computes matrix-matrix product with general matrices. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] a_trans Specifies the operation applied to A. +/// \param [in] b_trans Specifies the operation applied to B. +/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. +/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. +/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). +/// \param [in] alpha Scaling factor for the matrix-matrix product. +/// \param [in] a Input matrix A. +/// \param [in] a_type Data type of the matrix A. +/// \param [in] lda Leading dimension of A. +/// \param [in] b Input matrix B. +/// \param [in] b_type Data type of the matrix B. +/// \param [in] ldb Leading dimension of B. +/// \param [in] beta Scaling factor for matrix C. +/// \param [in, out] c Input/Output matrix C. +/// \param [in] c_type Data type of the matrix C. +/// \param [in] ldc Leading dimension of C. +/// \param [in] scaling_type Data type of the scaling factors. 
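A hedged usage sketch for the gemm wrapper documented above, exercising the real_half x real_half -> real_float entry of its dispatch table; sizes, leading dimensions, and names are illustrative.

#include <sycl/sycl.hpp>
#include "dpct/blas_utils.hpp"   // assumed include path

void gemm_fp16_accum_fp32(sycl::queue &q) {
  constexpr int m = 64, n = 32, k = 128;
  sycl::half *A = sycl::malloc_device<sycl::half>(m * k, q);  // column-major, lda = m
  sycl::half *B = sycl::malloc_device<sycl::half>(k * n, q);  // column-major, ldb = k
  float      *C = sycl::malloc_device<float>(m * n, q);       // column-major, ldc = m
  // ... initialize A and B ...

  float alpha = 1.0f, beta = 0.0f;
  dpct::gemm(q, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
             m, n, k, &alpha,
             A, dpct::library_data_t::real_half, m,
             B, dpct::library_data_t::real_half, k, &beta,
             C, dpct::library_data_t::real_float, m,
             dpct::library_data_t::real_float);   // scaling factors are float
  q.wait();
}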
+inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a, library_data_t a_type, + int lda, const void *b, library_data_t b_type, int ldb, + const void *beta, void *c, library_data_t c_type, int ldc, + library_data_t scaling_type) { + bool matched = false; + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) { + scaling_type = library_data_t::complex_float; + } else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, + a, lda, b, ldb, &beta_half, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, 
library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// Computes a batch of matrix-matrix product with general matrices. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] a_trans Specifies the operation applied to A. +/// \param [in] b_trans Specifies the operation applied to B. +/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. +/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. +/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). +/// \param [in] alpha Scaling factor for the matrix-matrix product. +/// \param [in] a Input matrix A. +/// \param [in] a_type Data type of the matrix A. +/// \param [in] lda Leading dimension of A. +/// \param [in] b Input matrix B. +/// \param [in] b_type Data type of the matrix B. +/// \param [in] ldb Leading dimension of B. +/// \param [in] beta Scaling factor for matrix C. +/// \param [in, out] c Input/Output matrix C. +/// \param [in] c_type Data type of the matrix C. +/// \param [in] ldc Leading dimension of C. +/// \param [in] batch_size Specifies the number of matrix multiply operations to perform. +/// \param [in] scaling_type Data type of the scaling factors. 
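The grouped (array-of-pointers) gemm_batch documented above requires USM (it throws under DPCT_USM_LEVEL_NONE) and releases its internal matrix_info through a host_task once the oneMKL group call finishes. A hedged float sketch follows; shared allocations are used for the pointer arrays so the library can dereference them, and cleanup is omitted.

#include <sycl/sycl.hpp>
#include "dpct/blas_utils.hpp"   // assumed include path

void gemm_batch_grouped(sycl::queue &q) {
  constexpr int m = 16, n = 16, k = 16, batch = 4;
  const void **A = sycl::malloc_shared<const void *>(batch, q);
  const void **B = sycl::malloc_shared<const void *>(batch, q);
  void       **C = sycl::malloc_shared<void *>(batch, q);
  for (int i = 0; i < batch; ++i) {
    A[i] = sycl::malloc_device<float>(m * k, q);
    B[i] = sycl::malloc_device<float>(k * n, q);
    C[i] = sycl::malloc_device<float>(m * n, q);
    // ... initialize the i-th A and B matrices ...
  }

  float alpha = 1.0f, beta = 0.0f;
  dpct::gemm_batch(q, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
                   m, n, k, &alpha,
                   A, dpct::library_data_t::real_float, m,
                   B, dpct::library_data_t::real_float, k, &beta,
                   C, dpct::library_data_t::real_float, m,
                   batch, dpct::library_data_t::real_float);
  q.wait();
  // ... consume C[i]; freeing the per-matrix allocations is omitted in this sketch ...
}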
+inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a[], + library_data_t a_type, int lda, const void *b[], + library_data_t b_type, int ldb, const void *beta, + void *c[], library_data_t c_type, int ldc, + int batch_size, library_data_t scaling_type) { +#ifdef DPCT_USM_LEVEL_NONE + throw std::runtime_error("this API is unsupported when USM level is none"); +#else + bool matched = false; + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) { + scaling_type = library_data_t::complex_float; + } else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } +#ifdef __INTEL_MKL__ + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, + b, ldb, beta, c, ldc, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, + a, lda, b, ldb, &beta_float, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, 
library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } +#endif + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, + batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +#endif +} + +/// Computes a batch of matrix-matrix product with general matrices. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] a_trans Specifies the operation applied to A. +/// \param [in] b_trans Specifies the operation applied to B. +/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. +/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. +/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). +/// \param [in] alpha Scaling factor for the matrix-matrix product. +/// \param [in] a Input matrix A. +/// \param [in] a_type Data type of the matrix A. +/// \param [in] lda Leading dimension of A. +/// \param [in] stride_a Stride between the different A matrices. +/// \param [in] b Input matrix B. +/// \param [in] b_type Data type of the matrix B. +/// \param [in] ldb Leading dimension of B. +/// \param [in] stride_b Stride between the different B matrices. +/// \param [in] beta Scaling factor for matrix C. +/// \param [in, out] c Input/Output matrix C. +/// \param [in] c_type Data type of the matrix C. +/// \param [in] ldc Leading dimension of C. +/// \param [in] stride_c Stride between the different C matrices. +/// \param [in] batch_size Specifies the number of matrix multiply operations to perform. +/// \param [in] scaling_type Data type of the scaling factors. 
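For the strided variant documented above, one contiguous allocation per operand plus a stride is enough; the call forwards to the strided oneapi::mkl::blas::column_major::gemm_batch. A hedged float sketch, with illustrative sizes:

#include <sycl/sycl.hpp>
#include "dpct/blas_utils.hpp"   // assumed include path

void gemm_batch_strided(sycl::queue &q) {
  constexpr int m = 16, n = 16, k = 16, batch = 8;
  const long long stride_a = static_cast<long long>(m) * k;
  const long long stride_b = static_cast<long long>(k) * n;
  const long long stride_c = static_cast<long long>(m) * n;

  float *A = sycl::malloc_device<float>(stride_a * batch, q);
  float *B = sycl::malloc_device<float>(stride_b * batch, q);
  float *C = sycl::malloc_device<float>(stride_c * batch, q);
  // ... initialize A and B ...

  float alpha = 1.0f, beta = 0.0f;
  dpct::gemm_batch(q, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
                   m, n, k, &alpha,
                   A, dpct::library_data_t::real_float, m, stride_a,
                   B, dpct::library_data_t::real_float, k, stride_b, &beta,
                   C, dpct::library_data_t::real_float, m, stride_c,
                   batch, dpct::library_data_t::real_float);
  q.wait();

  sycl::free(A, q);
  sycl::free(B, q);
  sycl::free(C, q);
}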
+inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a, library_data_t a_type, + int lda, long long int stride_a, const void *b, + library_data_t b_type, int ldb, long long int stride_b, + const void *beta, void *c, library_data_t c_type, + int ldc, long long int stride_c, int batch_size, + library_data_t scaling_type) { + bool matched = false; + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) { + scaling_type = library_data_t::complex_float; + } else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } +#ifdef __INTEL_MKL__ + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case 
detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } +#endif + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, + &beta_half, c, ldc, stride_c, batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} + +/// This routines perform a special rank-k update of a symmetric matrix C by +/// general matrices A and B. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] uplo Specifies whether C's data is stored in its upper or lower triangle. +/// \param [in] trans Specifies the operation to apply. +/// \param [in] n The number of rows and columns in C. +/// \param [in] k The inner dimension of matrix multiplications. +/// \param [in] alpha Scaling factor for the rank-k update. +/// \param [in] a Input matrix A. +/// \param [in] lda Leading dimension of A. +/// \param [in] b Input matrix B. +/// \param [in] ldb Leading dimension of B. +/// \param [in] beta Scaling factor for the rank-k update. +/// \param [in, out] c Input/Output matrix C. +/// \param [in] ldc Leading dimension of C. +template +inline void syrk(sycl::queue &q, oneapi::mkl::uplo uplo, + oneapi::mkl::transpose trans, int n, int k, const T *alpha, + const T *a, int lda, const T *b, int ldb, const T *beta, T *c, + int ldc) { + detail::rk_impl(q, uplo, trans, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); +} + +/// This routines perform a special rank-k update of a Hermitian matrix C by +/// general matrices A and B. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] uplo Specifies whether C's data is stored in its upper or lower triangle. +/// \param [in] trans Specifies the operation to apply. +/// \param [in] n The number of rows and columns in C. +/// \param [in] k The inner dimension of matrix multiplications. +/// \param [in] alpha Scaling factor for the rank-k update. +/// \param [in] a Input matrix A. +/// \param [in] lda Leading dimension of A. +/// \param [in] b Input matrix B. +/// \param [in] ldb Leading dimension of B. +/// \param [in] beta Scaling factor for the rank-k update. +/// \param [in, out] c Input/Output matrix C. +/// \param [in] ldc Leading dimension of C. +template +inline void herk(sycl::queue &q, oneapi::mkl::uplo uplo, + oneapi::mkl::transpose trans, int n, int k, const T *alpha, + const T *a, int lda, const T *b, int ldb, const Tbeta *beta, + T *c, int ldc) { + detail::rk_impl(q, uplo, trans, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); +} + +/// This routine performs a group of trsm operations. 
Each trsm solves an +/// equation of the form op(A) * X = alpha * B or X * op(A) = alpha * B. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] left_right Specifies A multiplies X on the left or on the right. +/// \param [in] upper_lower Specifies A is upper or lower triangular. +/// \param [in] trans Specifies the operation applied to A. +/// \param [in] unit_diag Specifies whether A is unit triangular. +/// \param [in] m Number of rows of the B matrices. +/// \param [in] n Number of columns of the B matrices. +/// \param [in] alpha Scaling factor for the solutions. +/// \param [in] a Input matrices A. +/// \param [in] a_type Data type of the matrices A. +/// \param [in] lda Leading dimension of the matrices A. +/// \param [in, out] b Input and output matrices B. +/// \param [in] b_type Data type of the matrices B. +/// \param [in] ldb Leading dimension of the matrices B. +/// \param [in] batch_size Specifies the number of trsm operations to perform. +/// \param [in] scaling_type Data type of the scaling factors. +inline void trsm_batch(sycl::queue &q, oneapi::mkl::side left_right, + oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, int m, int n, + const void *alpha, const void **a, library_data_t a_type, + int lda, void **b, library_data_t b_type, int ldb, + int batch_size, library_data_t scaling_type) { +#ifdef DPCT_USM_LEVEL_NONE + throw std::runtime_error("this API is unsupported when USM level is none"); +#else + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, scaling_type); + switch (key) { + case detail::get_type_combination_id(library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float): { + detail::trsm_batch_impl(q, left_right, upper_lower, + trans, unit_diag, m, n, alpha, + a, lda, b, ldb, batch_size); + break; + } + case detail::get_type_combination_id(library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double): { + detail::trsm_batch_impl( + q, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb, batch_size); + break; + } + case detail::get_type_combination_id(library_data_t::complex_float, + library_data_t::complex_float, + library_data_t::complex_float): { + detail::trsm_batch_impl, std::complex, + std::complex>(q, left_right, upper_lower, + trans, unit_diag, m, n, alpha, + a, lda, b, ldb, batch_size); + break; + } + case detail::get_type_combination_id(library_data_t::complex_double, + library_data_t::complex_double, + library_data_t::complex_double): { + detail::trsm_batch_impl, std::complex, + std::complex>(q, left_right, upper_lower, + trans, unit_diag, m, n, alpha, + a, lda, b, ldb, batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +#endif +} + +/// Computes a triangular matrix-general matrix product. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] left_right Specifies A is on the left or right side of the +/// multiplication. +/// \param [in] upper_lower Specifies A is upper or lower triangular. +/// \param [in] trans Specifies the operation applied to A. +/// \param [in] unit_diag Specifies whether A is unit triangular. +/// \param [in] m Number of rows of B. +/// \param [in] n Number of columns of B. +/// \param [in] alpha Scaling factor for the matrix-matrix product. +/// \param [in] a Input matrices A. +/// \param [in] lda Leading dimension of the matrices A. 
+/// \param [in] b Input matrices B. +/// \param [in] ldb Leading dimension of the matrices B. +/// \param [out] c Output matrices C. +/// \param [in] ldc Leading dimension of the matrices C. +template +inline void trmm(sycl::queue &q, oneapi::mkl::side left_right, + oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, int m, int n, const T *alpha, + const T *a, int lda, const T *b, int ldb, T *c, int ldc) { + using Ty = typename DataType::T2; + auto alpha_val = dpct::get_value(alpha, q); + if (b != c) { + dpct::matrix_mem_copy(c, b, ldc, ldb, m, n, dpct::device_to_device, q); + } + auto data_a = detail::get_memory(a); + auto data_c = detail::get_memory(c); + oneapi::mkl::blas::column_major::trmm(q, left_right, upper_lower, trans, + unit_diag, m, n, alpha_val, data_a, lda, + data_c, ldc); +} + +} // namespace dpct +#endif // __DPCT_BLAS_UTILS_HPP__ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/device.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/device.h new file mode 100644 index 0000000..729ebf6 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/device.h @@ -0,0 +1,781 @@ +//==---- device.hpp -------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_DEVICE_HPP__ +#define __DPCT_DEVICE_HPP__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__linux__) +#include +#include +#endif +#if defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif + +namespace dpct { +namespace detail { +static void get_version(const sycl::device &dev, int &major, int &minor) { + // Version string has the following format: + // a. OpenCL + // b. + std::string ver; + ver = dev.get_info(); + std::string::size_type i = 0; + while (i < ver.size()) { + if (isdigit(ver[i])) + break; + i++; + } + major = std::stoi(&(ver[i])); + while (i < ver.size()) { + if (ver[i] == '.') + break; + i++; + } + i++; + minor = std::stoi(&(ver[i])); +} +} // namespace detail + +/// SYCL default exception handler +inline auto exception_handler = [](sycl::exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (sycl::exception const &e) { + std::cerr << "Caught asynchronous SYCL exception:" << std::endl + << e.what() << std::endl + << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + } + } +}; + +typedef sycl::event *event_ptr; + +typedef sycl::queue *queue_ptr; + +typedef char *device_ptr; + +/// Destroy \p event pointed memory. +/// +/// \param event Pointer to the sycl::event address. 
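dpct::exception_handler above only reports asynchronous SYCL errors; they are surfaced at all only if a queue is built with it. A small sketch, not part of the patch, with an illustrative selector and include path:

#include <sycl/sycl.hpp>
#include "dpct/device.h"   // assumed include path for this header

int main() {
  // Attach the handler so asynchronous errors are printed instead of silently dropped.
  sycl::queue q{sycl::default_selector_v, dpct::exception_handler};

  q.submit([](sycl::handler &cgh) {
    cgh.single_task([] { /* ... device work ... */ });
  });

  // Any pending asynchronous exceptions are delivered to the handler here.
  q.wait_and_throw();
  return 0;
}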
+static void destroy_event(event_ptr event) { + delete event; +} + +class device_info { +public: + // get interface + const char *get_name() const { return _name; } + char *get_name() { return _name; } + template , + std::enable_if_t> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() const { + if constexpr (std::is_same_v>) + return sycl::range<3>(_max_work_item_sizes_i[0], + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); + else { + return _max_work_item_sizes_i; + } + } + template , + std::enable_if_t> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() { + if constexpr (std::is_same_v>) + return sycl::range<3>(_max_work_item_sizes_i[0], + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); + else { + return _max_work_item_sizes_i; + } + } + bool get_host_unified_memory() const { return _host_unified_memory; } + int get_major_version() const { return _major; } + int get_minor_version() const { return _minor; } + int get_integrated() const { return _integrated; } + int get_max_clock_frequency() const { return _frequency; } + int get_max_compute_units() const { return _max_compute_units; } + int get_max_work_group_size() const { return _max_work_group_size; } + int get_max_sub_group_size() const { return _max_sub_group_size; } + int get_max_work_items_per_compute_unit() const { + return _max_work_items_per_compute_unit; + } + int get_max_register_size_per_work_group() const { + return _max_register_size_per_work_group; + } + template || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() const { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + template || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + size_t get_global_mem_size() const { return _global_mem_size; } + size_t get_local_mem_size() const { return _local_mem_size; } + /// Returns the maximum clock rate of device's global memory in kHz. If + /// compiler does not support this API then returns default value 3200000 kHz. + unsigned int get_memory_clock_rate() const { return _memory_clock_rate; } + /// Returns the maximum bus width between device and memory in bits. If + /// compiler does not support this API then returns default value 64 bits. + unsigned int get_memory_bus_width() const { return _memory_bus_width; } + uint32_t get_device_id() const { return _device_id; } + std::array get_uuid() const { return _uuid; } + /// Returns global memory cache size in bytes. 
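+  ///
+  /// Example (illustrative sketch; get_device_info() and get_current_device()
+  /// are declared later in this header):
+  /// \code
+  ///   dpct::device_info props;
+  ///   dpct::get_device_info(props, dpct::get_current_device());
+  ///   size_t cache_bytes = props.get_global_mem_cache_size();
+  ///   size_t mem_bytes   = props.get_global_mem_size();
+  /// \endcode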
+ unsigned int get_global_mem_cache_size() const { + return _global_mem_cache_size; + } + + // set interface + void set_name(const char* name) { + size_t length = strlen(name); + if (length < 256) { + std::memcpy(_name, name, length + 1); + } else { + std::memcpy(_name, name, 255); + _name[255] = '\0'; + } + } + void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes) { + for (int i = 0; i < 3; ++i) + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + [[deprecated]] void + set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) { + for (int i = 0; i < 3; ++i) { + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + } + void set_host_unified_memory(bool host_unified_memory) { + _host_unified_memory = host_unified_memory; + } + void set_major_version(int major) { _major = major; } + void set_minor_version(int minor) { _minor = minor; } + void set_integrated(int integrated) { _integrated = integrated; } + void set_max_clock_frequency(int frequency) { _frequency = frequency; } + void set_max_compute_units(int max_compute_units) { + _max_compute_units = max_compute_units; + } + void set_global_mem_size(size_t global_mem_size) { + _global_mem_size = global_mem_size; + } + void set_local_mem_size(size_t local_mem_size) { + _local_mem_size = local_mem_size; + } + void set_max_work_group_size(int max_work_group_size) { + _max_work_group_size = max_work_group_size; + } + void set_max_sub_group_size(int max_sub_group_size) { + _max_sub_group_size = max_sub_group_size; + } + void + set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) { + _max_work_items_per_compute_unit = max_work_items_per_compute_unit; + } + void set_max_nd_range_size(int max_nd_range_size[]) { + for (int i = 0; i < 3; i++) { + _max_nd_range_size[i] = max_nd_range_size[i]; + _max_nd_range_size_i[i] = max_nd_range_size[i]; + } + } + void set_memory_clock_rate(unsigned int memory_clock_rate) { + _memory_clock_rate = memory_clock_rate; + } + void set_memory_bus_width(unsigned int memory_bus_width) { + _memory_bus_width = memory_bus_width; + } + void + set_max_register_size_per_work_group(int max_register_size_per_work_group) { + _max_register_size_per_work_group = max_register_size_per_work_group; + } + void set_device_id(uint32_t device_id) { + _device_id = device_id; + } + void set_uuid(std::array uuid) { + _uuid = std::move(uuid); + } + void set_global_mem_cache_size(unsigned int global_mem_cache_size) { + _global_mem_cache_size = global_mem_cache_size; + } + +private: + char _name[256]; + int _max_work_item_sizes_i[3]; + bool _host_unified_memory = false; + int _major; + int _minor; + int _integrated = 0; + int _frequency; + // Set estimated value 3200000 kHz as default value. + unsigned int _memory_clock_rate = 3200000; + // Set estimated value 64 bits as default value. 
+ unsigned int _memory_bus_width = 64; + unsigned int _global_mem_cache_size; + int _max_compute_units; + int _max_work_group_size; + int _max_sub_group_size; + int _max_work_items_per_compute_unit; + int _max_register_size_per_work_group; + size_t _global_mem_size; + size_t _local_mem_size; + size_t _max_nd_range_size[3]; + int _max_nd_range_size_i[3]; + uint32_t _device_id; + std::array _uuid; +}; + +static int get_major_version(const sycl::device &dev) { + int major, minor; + detail::get_version(dev, major, minor); + return major; +} + +static int get_minor_version(const sycl::device &dev) { + int major, minor; + detail::get_version(dev, major, minor); + return minor; +} + +static void get_device_info(device_info &out, const sycl::device &dev) { + device_info prop; + prop.set_name(dev.get_info().c_str()); + + int major, minor; + detail::get_version(dev, major, minor); + prop.set_major_version(major); + prop.set_minor_version(minor); + + prop.set_max_work_item_sizes( +#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902) + // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes + // is an enum class element + dev.get_info()); +#else + // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by + // an int + dev.get_info>()); +#endif + prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations)); + + prop.set_max_clock_frequency( + dev.get_info() * 1000); + + prop.set_max_compute_units( + dev.get_info()); + prop.set_max_work_group_size( + dev.get_info()); + prop.set_global_mem_size(dev.get_info()); + prop.set_local_mem_size(dev.get_info()); + +#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6) + if (dev.has(sycl::aspect::ext_intel_memory_clock_rate)) { + unsigned int tmp = + dev.get_info(); + if (tmp != 0) + prop.set_memory_clock_rate(1000 * tmp); + } + if (dev.has(sycl::aspect::ext_intel_memory_bus_width)) { + prop.set_memory_bus_width( + dev.get_info()); + } + if (dev.has(sycl::aspect::ext_intel_device_id)) { + prop.set_device_id( + dev.get_info()); + } + if (dev.has(sycl::aspect::ext_intel_device_info_uuid)) { + prop.set_uuid(dev.get_info()); + } +#elif defined(_MSC_VER) && !defined(__clang__) +#pragma message("get_device_info: querying memory_clock_rate and \ +memory_bus_width are not supported by the compiler used. \ +Use 3200000 kHz as memory_clock_rate default value. \ +Use 64 bits as memory_bus_width default value.") +#else +#warning "get_device_info: querying memory_clock_rate and \ +memory_bus_width are not supported by the compiler used. \ +Use 3200000 kHz as memory_clock_rate default value. \ +Use 64 bits as memory_bus_width default value." +#endif + + size_t max_sub_group_size = 1; + std::vector sub_group_sizes = + dev.get_info(); + + for (const auto &sub_group_size : sub_group_sizes) { + if (max_sub_group_size < sub_group_size) + max_sub_group_size = sub_group_size; + } + + prop.set_max_sub_group_size(max_sub_group_size); + + prop.set_max_work_items_per_compute_unit( + dev.get_info()); + int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; + prop.set_max_nd_range_size(max_nd_range_size); + + // Estimates max register size per work group, feel free to update the value + // according to device properties. 
+ prop.set_max_register_size_per_work_group(65536); + + prop.set_global_mem_cache_size( + dev.get_info()); + out = prop; +} + +/// dpct device extension +class device_ext : public sycl::device { + typedef std::mutex mutex_type; + +public: + device_ext() : sycl::device(), _ctx(*this) {} + ~device_ext() { + std::lock_guard lock(m_mutex); + clear_queues(); + } + device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this) { + std::lock_guard lock(m_mutex); + init_queues(); + } + + int is_native_atomic_supported() { return 0; } + int get_major_version() const { + return dpct::get_major_version(*this); + } + + int get_minor_version() const { + return dpct::get_minor_version(*this); + } + + int get_max_compute_units() const { + return get_device_info().get_max_compute_units(); + } + + /// Return the maximum clock frequency of this device in KHz. + int get_max_clock_frequency() const { + return get_device_info().get_max_clock_frequency(); + } + + int get_integrated() const { return get_device_info().get_integrated(); } + + int get_max_sub_group_size() const { + return get_device_info().get_max_sub_group_size(); + } + + int get_max_register_size_per_work_group() const { + return get_device_info().get_max_register_size_per_work_group(); + } + + int get_max_work_group_size() const { + return get_device_info().get_max_work_group_size(); + } + + int get_mem_base_addr_align() const { + return get_info(); + } + + size_t get_global_mem_size() const { + return get_device_info().get_global_mem_size(); + } + + /// Get the number of bytes of free and total memory on the SYCL device. + /// \param [out] free_memory The number of bytes of free memory on the SYCL device. + /// \param [out] total_memory The number of bytes of total memory on the SYCL device. + void get_memory_info(size_t &free_memory, size_t &total_memory) { +#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) + if (!has(sycl::aspect::ext_intel_free_memory)) { + std::cerr << "get_memory_info: ext_intel_free_memory is not supported." << std::endl; + free_memory = 0; + } else { + free_memory = get_info(); + } +#else + std::cerr << "get_memory_info: ext_intel_free_memory is not supported." << std::endl; + free_memory = 0; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma message("Querying the number of bytes of free memory is not supported") +#else +#warning "Querying the number of bytes of free memory is not supported" +#endif +#endif + total_memory = get_device_info().get_global_mem_size(); + } + + void get_device_info(device_info &out) const { + dpct::get_device_info(out, *this); + } + + device_info get_device_info() const { + device_info prop; + dpct::get_device_info(prop, *this); + return prop; + } + + void reset() { + std::lock_guard lock(m_mutex); + clear_queues(); + init_queues(); + } + + sycl::queue &in_order_queue() { return *_q_in_order; } + + sycl::queue &out_of_order_queue() { return *_q_out_of_order; } + + sycl::queue &default_queue() { +#ifdef DPCT_USM_LEVEL_NONE + return out_of_order_queue(); +#else + return in_order_queue(); +#endif // DPCT_USM_LEVEL_NONE + } + + void queues_wait_and_throw() { + std::unique_lock lock(m_mutex); + std::vector> current_queues( + _queues); + lock.unlock(); + for (const auto &q : current_queues) { + q->wait_and_throw(); + } + // Guard the destruct of current_queues to make sure the ref count is safe. 
+ lock.lock(); + } + + sycl::queue *create_queue(bool enable_exception_handler = false) { +#ifdef DPCT_USM_LEVEL_NONE + return create_out_of_order_queue(enable_exception_handler); +#else + return create_in_order_queue(enable_exception_handler); +#endif // DPCT_USM_LEVEL_NONE + } + + sycl::queue *create_in_order_queue(bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl(enable_exception_handler, + sycl::property::queue::in_order()); + } + + sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl(enable_exception_handler); + } + + void destroy_queue(sycl::queue *&queue) { + std::lock_guard lock(m_mutex); + _queues.erase(std::remove_if(_queues.begin(), _queues.end(), + [=](const std::shared_ptr &q) -> bool { + return q.get() == queue; + }), + _queues.end()); + queue = nullptr; + } + void set_saved_queue(sycl::queue* q) { + std::lock_guard lock(m_mutex); + _saved_queue = q; + } + sycl::queue *get_saved_queue() const { + std::lock_guard lock(m_mutex); + return _saved_queue; + } + sycl::context get_context() const { return _ctx; } + +private: + void clear_queues() { + _queues.clear(); + _q_in_order = _q_out_of_order = _saved_queue = nullptr; + } + + void init_queues() { + _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); + _q_out_of_order = create_queue_impl(true); + _saved_queue = &default_queue(); + } + + /// Caller should acquire resource \p m_mutex before calling this function. + template + sycl::queue *create_queue_impl(bool enable_exception_handler, + Properties... properties) { + sycl::async_handler eh = {}; + if (enable_exception_handler) { + eh = exception_handler; + } + _queues.push_back(std::make_shared( + _ctx, *this, eh, + sycl::property_list( +#ifdef DPCT_PROFILING_ENABLED + sycl::property::queue::enable_profiling(), +#endif + properties...))); + + return _queues.back().get(); + } + + void get_version(int &major, int &minor) const { + detail::get_version(*this, major, minor); + } + sycl::queue *_q_in_order, *_q_out_of_order; + sycl::queue *_saved_queue; + sycl::context _ctx; + std::vector> _queues; + mutable mutex_type m_mutex; +}; + +static inline unsigned int get_tid() { +#if defined(__linux__) + return syscall(SYS_gettid); +#elif defined(_WIN64) + return GetCurrentThreadId(); +#else +#error "Only support Windows and Linux." +#endif +} + +/// device manager +class dev_mgr { +public: + device_ext ¤t_device() { + unsigned int dev_id=current_device_id(); + check_id(dev_id); + return *_devs[dev_id]; + } + device_ext &cpu_device() const { + std::lock_guard lock(m_mutex); + if (_cpu_device == -1) { + throw std::runtime_error("no valid cpu device"); + } else { + return *_devs[_cpu_device]; + } + } + device_ext &get_device(unsigned int id) const { + std::lock_guard lock(m_mutex); + check_id(id); + return *_devs[id]; + } + unsigned int current_device_id() const { + std::lock_guard lock(m_mutex); + auto it=_thread2dev_map.find(get_tid()); + if(it != _thread2dev_map.end()) + return it->second; + return DEFAULT_DEVICE_ID; + } + +/// Select device with a device ID. +/// \param [in] id The id of the device which can +/// be obtained through get_device_id(const sycl::device). 
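+  ///
+  /// Example (illustrative sketch):
+  /// \code
+  ///   sycl::device gpu{sycl::gpu_selector_v};
+  ///   unsigned int id = dev_mgr::instance().get_device_id(gpu);
+  ///   dev_mgr::instance().select_device(id); // bind the calling thread to it
+  /// \endcode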
+ void select_device(unsigned int id) { + std::lock_guard lock(m_mutex); + check_id(id); + _thread2dev_map[get_tid()]=id; + } + unsigned int device_count() { return _devs.size(); } + + unsigned int get_device_id(const sycl::device &dev) { + unsigned int id = 0; + for(auto dev_item : _devs) { + if (*dev_item == dev) { + break; + } + id++; + } + return id; + } + + template + std::enable_if_t< + std::is_invocable_r_v> + select_device(const DeviceSelector &selector = sycl::gpu_selector_v) { + sycl::device selected_device = sycl::device(selector); + unsigned int selected_device_id = get_device_id(selected_device); + select_device(selected_device_id); + } + + /// Returns the instance of device manager singleton. + static dev_mgr &instance() { + static dev_mgr d_m; + return d_m; + } + dev_mgr(const dev_mgr &) = delete; + dev_mgr &operator=(const dev_mgr &) = delete; + dev_mgr(dev_mgr &&) = delete; + dev_mgr &operator=(dev_mgr &&) = delete; + +private: + mutable std::recursive_mutex m_mutex; + dev_mgr() { + sycl::device default_device = + sycl::device(sycl::default_selector_v); + _devs.push_back(std::make_shared(default_device)); + + std::vector sycl_all_devs = + sycl::device::get_devices(sycl::info::device_type::all); + // Collect other devices except for the default device. + if (default_device.is_cpu()) + _cpu_device = 0; + for (auto &dev : sycl_all_devs) { + if (dev == default_device) { + continue; + } + _devs.push_back(std::make_shared(dev)); + if (_cpu_device == -1 && dev.is_cpu()) { + _cpu_device = _devs.size() - 1; + } + } + } + void check_id(unsigned int id) const { + if (id >= _devs.size()) { + throw std::runtime_error("invalid device id"); + } + } + std::vector> _devs; + /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current + /// thread id in _thread2dev_map, which means default device should be used + /// for the current thread. + const unsigned int DEFAULT_DEVICE_ID = 0; + /// thread-id to device-id map. + std::map _thread2dev_map; + int _cpu_device = -1; +}; + +/// Util function to get the default queue of current selected device depends on +/// the USM config. Return the default out-of-ordered queue when USM-none is +/// enabled, otherwise return the default in-ordered queue. +static inline sycl::queue &get_default_queue() { + return dev_mgr::instance().current_device().default_queue(); +} + +/// Util function to get the default in-ordered queue of current device in +/// dpct device manager. +static inline sycl::queue &get_in_order_queue() { + return dev_mgr::instance().current_device().in_order_queue(); +} + +/// Util function to get the default out-of-ordered queue of current device in +/// dpct device manager. +static inline sycl::queue &get_out_of_order_queue() { + return dev_mgr::instance().current_device().out_of_order_queue(); +} + +/// Util function to get the id of current device in +/// dpct device manager. +static inline unsigned int get_current_device_id() { + return dev_mgr::instance().current_device_id(); +} + +/// Util function to get the current device. +static inline device_ext &get_current_device() { + return dev_mgr::instance().current_device(); +} + +/// Util function to get a device by id. +static inline device_ext &get_device(unsigned int id) { + return dev_mgr::instance().get_device(id); +} + +/// Util function to get the context of the default queue of current +/// device in dpct device manager. 
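+///
+/// Example (illustrative sketch of the device/queue helpers; the USM
+/// allocation is hypothetical):
+/// \code
+///   dpct::device_ext &dev = dpct::get_current_device();
+///   sycl::context ctx = dpct::get_default_context();
+///   sycl::queue &q = dpct::get_in_order_queue();
+///   float *buf = sycl::malloc_device<float>(1024, q);
+///   q.fill(buf, 0.0f, 1024).wait();
+///   sycl::free(buf, q);
+/// \endcode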
+static inline sycl::context get_default_context() {
+  return dpct::get_current_device().get_context();
+}
+
+/// Util function to get a CPU device.
+static inline device_ext &cpu_device() {
+  return dev_mgr::instance().cpu_device();
+}
+
+static inline unsigned int select_device(unsigned int id) {
+  dev_mgr::instance().select_device(id);
+  return id;
+}
+
+template <class DeviceSelector>
+static inline std::enable_if_t<
+    std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>
+select_device(const DeviceSelector &selector = sycl::gpu_selector_v) {
+  dev_mgr::instance().select_device(selector);
+}
+
+static inline unsigned int get_device_id(const sycl::device &dev){
+  return dev_mgr::instance().get_device_id(dev);
+}
+
+/// Util function to check whether a device supports all of the given
+/// sycl::aspect capabilities, throwing std::runtime_error for the first
+/// aspect that is missing.
+inline void
+has_capability_or_fail(const sycl::device &dev,
+                       const std::initializer_list<sycl::aspect> &props) {
+  for (const auto &it : props) {
+    if (dev.has(it))
+      continue;
+    switch (it) {
+    case sycl::aspect::fp64:
+      throw std::runtime_error("'double' is not supported in '" +
+                               dev.get_info<sycl::info::device::name>() +
+                               "' device");
+      break;
+    case sycl::aspect::fp16:
+      throw std::runtime_error("'half' is not supported in '" +
+                               dev.get_info<sycl::info::device::name>() +
+                               "' device");
+      break;
+    default:
+#define __SYCL_ASPECT(ASPECT, ID)                                              \
+  case sycl::aspect::ASPECT:                                                   \
+    return #ASPECT;
+#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
+#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
+      auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string {
+        switch (AspectNum) {
+#include <sycl/info/aspects.def>
+#include <sycl/info/aspects_deprecated.def>
+        default:
+          return "unknown aspect";
+        }
+      };
+#undef __SYCL_ASPECT_DEPRECATED_ALIAS
+#undef __SYCL_ASPECT_DEPRECATED
+#undef __SYCL_ASPECT
+      throw std::runtime_error(
+          "'" + getAspectNameStr(it) + "' is not supported in '" +
+          dev.get_info<sycl::info::device::name>() + "' device");
+    }
+    break;
+  }
+}
+} // namespace dpct
+
+#endif // __DPCT_DEVICE_HPP__
diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/dpct.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/dpct.h
new file mode 100644
index 0000000..99f48e2
--- /dev/null
+++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/dpct.h
@@ -0,0 +1,62 @@
+//==---- dpct.h ---------------------------------*- C++ -*----------------==//
+//
+// Copyright (C) Intel Corporation
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// See https://llvm.org/LICENSE.txt for license information.
+// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_HPP__ +#define __DPCT_HPP__ + +#include +#include +#include +#include + +template class dpct_kernel_name; +template class dpct_kernel_scalar; + +#include "atomic.h" +#include "device.h" +#include "image.h" +#include "kernel.h" +#include "math.h" +#include "memory.h" +#include "util.h" + +#if defined(_MSC_VER) +#define __dpct_align__(n) __declspec(align(n)) +#define __dpct_inline__ __forceinline +#else +#define __dpct_align__(n) __attribute__((aligned(n))) +#define __dpct_inline__ __inline__ __attribute__((always_inline)) +#endif + +#if defined(_MSC_VER) +#define __dpct_noinline__ __declspec(noinline) +#else +#define __dpct_noinline__ __attribute__((noinline)) +#endif + +#define DPCT_COMPATIBILITY_TEMP (600) + +namespace dpct{ +enum error_code { success = 0, default_error = 999 }; +} + +#define DPCT_CHECK_ERROR(expr) \ + [&]() { \ + try { \ + expr; \ + return dpct::success; \ + } catch (std::exception const &e) { \ + std::cerr << e.what() << std::endl; \ + return dpct::default_error; \ + } \ + }() + +#define DPCT_PI_F (3.14159274101257f) +#define DPCT_PI (3.141592653589793115998) + +#endif // __DPCT_HPP__ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/image.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/image.h new file mode 100644 index 0000000..8c162f7 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/image.h @@ -0,0 +1,891 @@ +//==---- image.hpp --------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_IMAGE_HPP__ +#define __DPCT_IMAGE_HPP__ + +#include + +#include "memory.h" +#include "util.h" + +namespace dpct { + +enum class image_channel_data_type { + signed_int, + unsigned_int, + fp, +}; + +class image_channel; +class image_wrapper_base; +namespace detail { +/// Image object type traits, with accessor type and sampled data type defined. +/// The data type of an image accessor must be one of sycl::int4, sycl::uint4, +/// sycl::float4 and sycl::half4. The data type of accessors with 8bits/16bits +/// channel width will be 32 bits. sycl::half is an exception. +template struct image_trait { + using acc_data_t = sycl::vec; + template + using accessor_t = + sycl::accessor; + template + using array_accessor_t = + sycl::accessor; + using data_t = T; + using elem_t = T; + static constexpr image_channel_data_type data_type = + std::is_integral::value + ? (std::is_signed::value ? 
image_channel_data_type::signed_int + : image_channel_data_type::unsigned_int) + : image_channel_data_type::fp; + static constexpr int channel_num = 1; +}; +template <> +struct image_trait : public image_trait { + using data_t = std::uint8_t; + using elem_t = data_t; +}; +template <> +struct image_trait + : public image_trait { + using data_t = std::uint16_t; + using elem_t = data_t; +}; +template <> +struct image_trait : public image_trait { + using data_t = std::int8_t; + using elem_t = data_t; +}; +template <> +struct image_trait : public image_trait { + using data_t = std::int16_t; + using elem_t = data_t; +}; +template <> +struct image_trait + : public image_trait::value, signed char, unsigned char>::type> {}; + +template +struct image_trait> : public image_trait {}; + +template +struct image_trait> : public image_trait { + using data_t = sycl::vec; + static constexpr int channel_num = 2; +}; + +template +struct image_trait> + : public image_trait> { + static constexpr int channel_num = 3; +}; + +template +struct image_trait> : public image_trait { + using data_t = sycl::vec; + static constexpr int channel_num = 4; +}; + +/// Functor to fetch data from read result of an image accessor. +template struct fetch_data { + using return_t = typename image_trait::data_t; + using acc_data_t = typename image_trait::acc_data_t; + + return_t operator()(acc_data_t &&original_data) { + return (return_t)original_data.r(); + } +}; +template +struct fetch_data> : public fetch_data {}; +template struct fetch_data> { + using return_t = typename image_trait>::data_t; + using acc_data_t = typename image_trait>::acc_data_t; + + return_t operator()(acc_data_t &&origin_data) { + return return_t(origin_data.r(), origin_data.g()); + } +}; +template +struct fetch_data> + : public fetch_data> {}; +template struct fetch_data> { + using return_t = typename image_trait>::data_t; + using acc_data_t = typename image_trait>::acc_data_t; + + return_t operator()(acc_data_t &&origin_data) { + return return_t(origin_data.r(), origin_data.g(), origin_data.b(), + origin_data.a()); + } +}; + +/// Create image according with given type \p T and \p dims. +template static image_wrapper_base *create_image_wrapper(int dims); + +/// Create image with given data type \p T, channel order and dims +template +static image_wrapper_base *create_image_wrapper(unsigned channel_num, int dims); + +/// Create image with channel info and specified dimensions. +static image_wrapper_base *create_image_wrapper(image_channel channel, int dims); + +} // namespace detail + +/// Image channel info, include channel number, order, data width and type +class image_channel { + image_channel_data_type _type = image_channel_data_type::signed_int; + /// Number of channels. + unsigned _channel_num = 0; + /// Total size of all channels in bytes. + unsigned _total_size = 0; + /// Size of each channel in bytes. + unsigned _channel_size = 0; + +public: + /// Create image channel info according to template argument \p T. 
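+  ///
+  /// Example (illustrative sketch):
+  /// \code
+  ///   // 4-channel fp32 descriptor (rgba order, 32 bits per channel)
+  ///   dpct::image_channel ch = dpct::image_channel::create<sycl::float4>();
+  ///   // ch.get_channel_num() == 4, ch.get_total_size() == 16 bytes
+  /// \endcode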
+ template static image_channel create() { + image_channel channel; + channel.set_channel_size(detail::image_trait::channel_num, + sizeof(typename detail::image_trait::elem_t) * + 8); + channel.set_channel_data_type(detail::image_trait::data_type); + return channel; + } + + image_channel() = default; + + image_channel_data_type get_channel_data_type() { return _type; } + void set_channel_data_type(image_channel_data_type type) { _type = type; } + + unsigned get_total_size() { return _total_size; } + + unsigned get_channel_num() { return _channel_num; } + void set_channel_num(unsigned channel_num) { + _channel_num = channel_num; + _total_size = _channel_size * _channel_num; + } + + /// image_channel constructor. + /// \param r Channel r width in bits. + /// \param g Channel g width in bits. Should be same with \p r, or zero. + /// \param b Channel b width in bits. Should be same with \p g, or zero. + /// \param a Channel a width in bits. Should be same with \p b, or zero. + /// \param data_type Image channel data type: signed_nt, unsigned_int or fp. + image_channel(int r, int g, int b, int a, image_channel_data_type data_type) { + _type = data_type; + if (a) { + assert(r == a && "SYCL doesn't support different channel size"); + assert(r == b && "SYCL doesn't support different channel size"); + assert(r == g && "SYCL doesn't support different channel size"); + set_channel_size(4, a); + } else if (b) { + assert(r == b && "SYCL doesn't support different channel size"); + assert(r == g && "SYCL doesn't support different channel size"); + set_channel_size(3, b); + } else if (g) { + assert(r == g && "SYCL doesn't support different channel size"); + set_channel_size(2, g); + } else { + set_channel_size(1, r); + } + } + + sycl::image_channel_type get_channel_type() const { + if (_channel_size == 4) { + if (_type == image_channel_data_type::signed_int) + return sycl::image_channel_type::signed_int32; + else if (_type == image_channel_data_type::unsigned_int) + return sycl::image_channel_type::unsigned_int32; + else if (_type == image_channel_data_type::fp) + return sycl::image_channel_type::fp32; + } else if (_channel_size == 2) { + if (_type == image_channel_data_type::signed_int) + return sycl::image_channel_type::signed_int16; + else if (_type == image_channel_data_type::unsigned_int) + return sycl::image_channel_type::unsigned_int16; + else if (_type == image_channel_data_type::fp) + return sycl::image_channel_type::fp16; + } else { + if (_type == image_channel_data_type::signed_int) + return sycl::image_channel_type::signed_int8; + else if (_type == image_channel_data_type::unsigned_int) + return sycl::image_channel_type::unsigned_int8; + } + assert(false && "unexpected channel data kind and channel size"); + return sycl::image_channel_type::signed_int32; + } + void set_channel_type(sycl::image_channel_type type) { + switch (type) { + case sycl::image_channel_type::unsigned_int8: + _type = image_channel_data_type::unsigned_int; + _channel_size = 1; + break; + case sycl::image_channel_type::unsigned_int16: + _type = image_channel_data_type::unsigned_int; + _channel_size = 2; + break; + case sycl::image_channel_type::unsigned_int32: + _type = image_channel_data_type::unsigned_int; + _channel_size = 4; + break; + case sycl::image_channel_type::signed_int8: + _type = image_channel_data_type::signed_int; + _channel_size = 1; + break; + case sycl::image_channel_type::signed_int16: + _type = image_channel_data_type::signed_int; + _channel_size = 2; + break; + case 
sycl::image_channel_type::signed_int32: + _type = image_channel_data_type::signed_int; + _channel_size = 4; + break; + case sycl::image_channel_type::fp16: + _type = image_channel_data_type::fp; + _channel_size = 2; + break; + case sycl::image_channel_type::fp32: + _type = image_channel_data_type::fp; + _channel_size = 4; + break; + default: + break; + } + _total_size = _channel_size * _channel_num; + } + + sycl::image_channel_order get_channel_order() const { + switch (_channel_num) { + case 1: + return sycl::image_channel_order::r; + case 2: + return sycl::image_channel_order::rg; + case 3: + return sycl::image_channel_order::rgb; + case 4: + return sycl::image_channel_order::rgba; + default: + return sycl::image_channel_order::r; + } + } + /// Get the size for each channel in bits. + unsigned get_channel_size() const { return _channel_size * 8; } + + /// Set channel size. + /// \param in_channel_num Channels number to set. + /// \param channel_size Size for each channel in bits. + void set_channel_size(unsigned in_channel_num, + unsigned channel_size) { + if (in_channel_num < _channel_num) + return; + _channel_num = in_channel_num; + _channel_size = channel_size / 8; + _total_size = _channel_size * _channel_num; + } +}; + +/// 2D or 3D matrix data for image. +class image_matrix { + image_channel _channel; + int _range[3] = {1, 1, 1}; + int _dims = 0; + void *_host_data = nullptr; + + /// Set range of each dimension. + template void set_range(sycl::range range) { + for (int i = 0; i < dimensions; ++i) + _range[i] = range[i]; + _dims = dimensions; + } + + template + sycl::range get_range(integer_sequence) { + return sycl::range(_range[DimIdx]...); + } + +public: + /// Constructor with channel info and dimension size info. + template + image_matrix(image_channel channel, sycl::range range) + : _channel(channel) { + set_range(range); + _host_data = std::malloc(range.size() * _channel.get_total_size()); + } + image_matrix(sycl::image_channel_type channel_type, unsigned channel_num, + size_t x, size_t y) { + _channel.set_channel_type(channel_type); + _channel.set_channel_num(channel_num); + _dims = 1; + _range[0] = x; + if (y) { + _dims = 2; + _range[1] = y; + } + _host_data = std::malloc(_range[0] * _range[1] * _channel.get_total_size()); + } + + /// Construct a new image class with the matrix data. + template sycl::image *create_image() { + return create_image(_channel); + } + /// Construct a new image class with the matrix data. + template + sycl::image *create_image(image_channel channel) { + return new sycl::image( + _host_data, channel.get_channel_order(), channel.get_channel_type(), + get_range(make_index_sequence()), + sycl::property::image::use_host_ptr()); + } + + /// Get channel info. + inline image_channel get_channel() { return _channel; } + /// Get range of the image. + sycl::range<3> get_range() { + return sycl::range<3>(_range[0], _range[1], _range[2]); + } + /// Get matrix dims. + inline int get_dims() { return _dims; } + /// Convert to pitched data. + pitched_data to_pitched_data() { + return pitched_data(_host_data, _range[0], _range[0], _range[1]); + } + + ~image_matrix() { + if (_host_data) + std::free(_host_data); + _host_data = nullptr; + } +}; +using image_matrix_p = image_matrix *; + +enum class image_data_type { matrix, linear, pitch, unsupport }; + +/// Image data info. 
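+///
+/// Example (illustrative sketch; dev_ptr and n are hypothetical, and the x
+/// extent of linear data is expressed in bytes, as assumed by
+/// image_wrapper::create_image below):
+/// \code
+///   dpct::image_channel ch = dpct::image_channel::create<float>();
+///   dpct::image_data data(dev_ptr, n * sizeof(float), ch); // 1D linear memory
+/// \endcode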
+class image_data { +public: + image_data() { _type = image_data_type::unsupport; } + image_data(image_matrix_p matrix_data) { set_data(matrix_data); } + image_data(void *data_ptr, size_t x_size, image_channel channel) { + set_data(data_ptr, x_size, channel); + } + image_data(void *data_ptr, size_t x_size, size_t y_size, size_t pitch_size, + image_channel channel) { + set_data(data_ptr, x_size, y_size, pitch_size, channel); + } + void set_data(image_matrix_p matrix_data) { + _type = image_data_type::matrix; + _data = matrix_data; + _channel = matrix_data->get_channel(); + } + void set_data(void *data_ptr, size_t x_size, image_channel channel) { + _type = image_data_type::linear; + _data = data_ptr; + _x = x_size; + _channel = channel; + } + void set_data(void *data_ptr, size_t x_size, size_t y_size, size_t pitch_size, + image_channel channel) { + _type = image_data_type::pitch; + _data = data_ptr; + _x = x_size; + _y = y_size; + _pitch = pitch_size; + _channel = channel; + } + + image_data_type get_data_type() const { return _type; } + void set_data_type(image_data_type type) { _type = type; } + + void *get_data_ptr() const { return _data; } + void set_data_ptr(void *data) { _data = data; } + + size_t get_x() const { return _x; } + void set_x(size_t x) { _x = x; } + + size_t get_y() const { return _y; } + void set_y(size_t y) { _y = y; } + + size_t get_pitch() const { return _pitch; } + void set_pitch(size_t pitch) { _pitch = pitch; } + + image_channel get_channel() const { return _channel; } + void set_channel(image_channel channel) { _channel = channel; } + + image_channel_data_type get_channel_data_type() { + return _channel.get_channel_data_type(); + } + void set_channel_data_type(image_channel_data_type type) { + _channel.set_channel_data_type(type); + } + + unsigned get_channel_size() { return _channel.get_channel_size(); } + void set_channel_size(unsigned channel_num, unsigned channel_size) { + return _channel.set_channel_size(channel_num, channel_size); + } + + unsigned get_channel_num() { return _channel.get_channel_num(); } + void set_channel_num(unsigned num) { + return _channel.set_channel_num(num); + } + + sycl::image_channel_type get_channel_type() { + return _channel.get_channel_type(); + } + void set_channel_type(sycl::image_channel_type type) { + return _channel.set_channel_type(type); + } + +private: + image_data_type _type; + void *_data = nullptr; + size_t _x, _y, _pitch; + image_channel _channel; +}; + +/// Image sampling info, include addressing mode, filtering mode and +/// normalization info. 
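+///
+/// Example (illustrative sketch):
+/// \code
+///   dpct::sampling_info info;
+///   info.set(sycl::addressing_mode::clamp_to_edge,
+///            sycl::filtering_mode::linear, /*is_normalized=*/1);
+///   sycl::sampler s = info.get_sampler();
+/// \endcode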
+class sampling_info { + sycl::addressing_mode _addressing_mode = + sycl::addressing_mode::clamp_to_edge; + sycl::filtering_mode _filtering_mode = sycl::filtering_mode::nearest; + sycl::coordinate_normalization_mode _coordinate_normalization_mode = + sycl::coordinate_normalization_mode::unnormalized; + +public: + sycl::addressing_mode get_addressing_mode() { return _addressing_mode; } + void set(sycl::addressing_mode addressing_mode) { _addressing_mode = addressing_mode; } + + sycl::filtering_mode get_filtering_mode() { return _filtering_mode; } + void set(sycl::filtering_mode filtering_mode) { _filtering_mode = filtering_mode; } + + sycl::coordinate_normalization_mode get_coordinate_normalization_mode() { + return _coordinate_normalization_mode; + } + void set(sycl::coordinate_normalization_mode coordinate_normalization_mode) { + _coordinate_normalization_mode = coordinate_normalization_mode; + } + + bool is_coordinate_normalized() { + return _coordinate_normalization_mode == + sycl::coordinate_normalization_mode::normalized; + } + void set_coordinate_normalization_mode(int is_normalized) { + _coordinate_normalization_mode = + is_normalized ? sycl::coordinate_normalization_mode::normalized + : sycl::coordinate_normalization_mode::unnormalized; + } + void + set(sycl::addressing_mode addressing_mode, + sycl::filtering_mode filtering_mode, + sycl::coordinate_normalization_mode coordinate_normalization_mode) { + set(addressing_mode); + set(filtering_mode); + set(coordinate_normalization_mode); + } + void set(sycl::addressing_mode addressing_mode, + sycl::filtering_mode filtering_mode, int is_normalized) { + set(addressing_mode); + set(filtering_mode); + set_coordinate_normalization_mode(is_normalized); + } + + sycl::sampler get_sampler() { + return sycl::sampler(_coordinate_normalization_mode, _addressing_mode, + _filtering_mode); + } +}; + +/// Image base class. +class image_wrapper_base { + sampling_info _sampling_info; + image_data _data; + +public: + virtual ~image_wrapper_base() = 0; + + void attach(image_data data) { set_data(data); } + /// Attach matrix data to this class. + void attach(image_matrix *matrix) { + detach(); + image_wrapper_base::set_data(image_data(matrix)); + } + /// Attach matrix data to this class. + void attach(image_matrix *matrix, image_channel channel) { + attach(matrix); + image_wrapper_base::set_channel(channel); + } + /// Attach linear data to this class. + void attach(const void *ptr, size_t count) { + attach(ptr, count, get_channel()); + } + /// Attach linear data to this class. + void attach(const void *ptr, size_t count, image_channel channel) { + detach(); + image_wrapper_base::set_data(image_data(const_cast(ptr), count, channel)); + } + /// Attach 2D data to this class. + void attach(const void *data, size_t x, size_t y, size_t pitch) { + attach(data, x, y, pitch, get_channel()); + } + /// Attach 2D data to this class. + void attach(const void *data, size_t x, size_t y, size_t pitch, + image_channel channel) { + detach(); + image_wrapper_base::set_data( + image_data(const_cast(data), x, y, pitch, channel)); + } + /// Detach data. 
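+  /// Example of a typical attach/detach cycle (illustrative sketch; img is a
+  /// concrete dpct::image_wrapper instance and dev_ptr/nbytes are
+  /// hypothetical):
+  /// \code
+  ///   img.attach(dev_ptr, nbytes, dpct::image_channel::create<float>());
+  ///   // ... sample from the image ...
+  ///   img.detach();
+  /// \endcode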
+ virtual void detach() {} + + sampling_info get_sampling_info() { return _sampling_info; } + void set_sampling_info(sampling_info info) { + _sampling_info = info; + } + const image_data &get_data() { return _data; } + void set_data(image_data data) { _data = data; } + + image_channel get_channel() { return _data.get_channel(); } + void set_channel(image_channel channel) { _data.set_channel(channel); } + + image_channel_data_type get_channel_data_type() { + return _data.get_channel_data_type(); + } + void set_channel_data_type(image_channel_data_type type) { + _data.set_channel_data_type(type); + } + + unsigned get_channel_size() { return _data.get_channel_size(); } + void set_channel_size(unsigned channel_num, unsigned channel_size) { + return _data.set_channel_size(channel_num, channel_size); + } + + sycl::addressing_mode get_addressing_mode() { + return _sampling_info.get_addressing_mode(); + } + void set(sycl::addressing_mode addressing_mode) { + _sampling_info.set(addressing_mode); + } + + sycl::filtering_mode get_filtering_mode() { + return _sampling_info.get_filtering_mode(); + } + void set(sycl::filtering_mode filtering_mode) { + _sampling_info.set(filtering_mode); + } + + sycl::coordinate_normalization_mode get_coordinate_normalization_mode() { + return _sampling_info.get_coordinate_normalization_mode(); + } + void + set(sycl::coordinate_normalization_mode coordinate_normalization_mode) { + _sampling_info.set(coordinate_normalization_mode); + } + + bool is_coordinate_normalized() { + return _sampling_info.is_coordinate_normalized(); + } + void set_coordinate_normalization_mode(int is_normalized) { + _sampling_info.set_coordinate_normalization_mode(is_normalized); + } + void + set(sycl::addressing_mode addressing_mode, + sycl::filtering_mode filtering_mode, + sycl::coordinate_normalization_mode coordinate_normalization_mode) { + set(addressing_mode); + set(filtering_mode); + set(coordinate_normalization_mode); + } + void set(sycl::addressing_mode addressing_mode, + sycl::filtering_mode filtering_mode, int is_normalized) { + set(addressing_mode); + set(filtering_mode); + set_coordinate_normalization_mode(is_normalized); + } + + unsigned get_channel_num() { return _data.get_channel_num(); } + void set_channel_num(unsigned num) { + return _data.set_channel_num(num); + } + + sycl::image_channel_type get_channel_type() { + return _data.get_channel_type(); + } + void set_channel_type(sycl::image_channel_type type) { + return _data.set_channel_type(type); + } + + sycl::sampler get_sampler() { return _sampling_info.get_sampler(); } +}; +inline image_wrapper_base::~image_wrapper_base() {} +using image_wrapper_base_p = image_wrapper_base *; + +template class image_accessor_ext; + +/// Image class, wrapper of sycl::image. 
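+///
+/// Example (illustrative sketch; q, dev_ptr, width, height and pitch are
+/// hypothetical, with width/height in texels and pitch in bytes):
+/// \code
+///   dpct::image_wrapper<sycl::float4, 2> img;
+///   img.attach(dev_ptr, width, height, pitch);
+///   q.submit([&](sycl::handler &cgh) {
+///     auto acc = img.get_access(cgh, q);
+///     dpct::image_accessor_ext<sycl::float4, 2> tex(img.get_sampler(), acc);
+///     cgh.parallel_for(sycl::range<2>(height, width), [=](sycl::item<2> it) {
+///       sycl::float4 v = tex.read((float)it[1], (float)it[0]);
+///       // ... hypothetical use of v ...
+///     });
+///   });
+/// \endcode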
+template class image_wrapper : public image_wrapper_base { + sycl::image *_image = nullptr; + +#ifndef DPCT_USM_LEVEL_NONE + std::vector _host_buffer; +#endif + + void create_image(sycl::queue q) { + auto &data = get_data(); + if (data.get_data_type() == image_data_type::matrix) { + _image = static_cast(data.get_data_ptr()) + ->create_image(data.get_channel()); + return; + } + auto ptr = data.get_data_ptr(); + auto channel = data.get_channel(); + + if (detail::get_pointer_attribute(q, ptr) == detail::pointer_access_attribute::device_only) { +#ifdef DPCT_USM_LEVEL_NONE + ptr = get_buffer(ptr) + .template get_access() + .get_pointer(); +#else + auto sz = data.get_x(); + if (data.get_data_type() == image_data_type::pitch) + sz *= channel.get_total_size() * data.get_y(); + _host_buffer.resize(sz); + q.memcpy(_host_buffer.data(), ptr, sz).wait(); + ptr = _host_buffer.data(); +#endif + } + + if constexpr (dimensions == 1) { + assert(data.get_data_type() == image_data_type::linear); + _image = new sycl::image<1>( + ptr, channel.get_channel_order(), channel.get_channel_type(), + sycl::range<1>(data.get_x() / channel.get_total_size())); + } else if constexpr (dimensions == 2) { + assert(data.get_data_type() == image_data_type::pitch); + _image = new sycl::image<2>(ptr, channel.get_channel_order(), + channel.get_channel_type(), + sycl::range<2>(data.get_x(), data.get_y()), + sycl::range<1>(data.get_pitch())); + } else { + throw std::runtime_error("3D image only support matrix data"); + } + return; + } + +public: + using acc_data_t = typename detail::image_trait::acc_data_t; + using accessor_t = + typename image_accessor_ext::accessor_t; + + image_wrapper() { set_channel(image_channel::create()); } + ~image_wrapper() { detach(); } + + /// Get image accessor. + accessor_t get_access(sycl::handler &cgh, sycl::queue &q = get_default_queue()) { + if (!_image) + create_image(q); + return accessor_t(*_image, cgh); + } + + /// Detach data. + void detach() override { + if (_image) + delete _image; + _image = nullptr; + } +}; + +/// Wrap sampler and image accessor together. +template +class image_accessor_ext { +public: + using accessor_t = + typename detail::image_trait::template accessor_t; + using data_t = typename detail::image_trait::data_t; + sycl::sampler _sampler; + accessor_t _img_acc; + +public: + image_accessor_ext(sycl::sampler sampler, accessor_t acc) + : _sampler(sampler), _img_acc(acc) {} + + /// Read data from accessor. + template + typename std::enable_if::type read(float x, float y, + float z) { + return detail::fetch_data()( + _img_acc.read(sycl::float4(x, y, z, 0), _sampler)); + } + /// Read data from accessor. + template ::value + &&std::is_integral::value + &&std::is_integral::value> + typename std::enable_if::type read(Coord0 x, Coord1 y, + Coord2 z) { + return detail::fetch_data()( + _img_acc.read(sycl::int4(x, y, z, 0), _sampler)); + } + /// Read data from accessor. + template + typename std::enable_if::type read(float x, float y) { + return detail::fetch_data()( + _img_acc.read(sycl::float2(x, y), _sampler)); + } + /// Read data from accessor. + template ::value + &&std::is_integral::value> + typename std::enable_if::type read(Coord0 x, Coord1 y) { + return detail::fetch_data()( + _img_acc.read(sycl::int2(x, y), _sampler)); + } + /// Read data from accessor. + template + typename std::enable_if::type read(float x) { + return detail::fetch_data()(_img_acc.read(x, _sampler)); + } + /// Read data from accessor. 
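+  /// Example (illustrative sketch): for an image_accessor_ext<float, 1> named
+  /// tex, a fetch at integral coordinate i inside a kernel is simply
+  /// \code
+  ///   float v = tex.read(i);
+  /// \endcode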
+ template ::value> + typename std::enable_if::type read(CoordT x) { + return detail::fetch_data()(_img_acc.read(x, _sampler)); + } +}; + +template class image_accessor_ext { +public: + using accessor_t = + typename detail::image_trait::template array_accessor_t; + using data_t = typename detail::image_trait::data_t; + sycl::sampler _sampler; + accessor_t _img_acc; + +public: + image_accessor_ext(sycl::sampler sampler, accessor_t acc) + : _sampler(sampler), _img_acc(acc) {} + + /// Read data from accessor. + template + typename std::enable_if::type read(int index, float x, + float y) { + return detail::fetch_data()( + _img_acc[index].read(sycl::float2(x, y), _sampler)); + } + /// Read data from accessor. + template + typename std::enable_if::type read(int index, int x, int y) { + return detail::fetch_data()( + _img_acc[index].read(sycl::int2(x, y), _sampler)); + } + /// Read data from accessor. + template + typename std::enable_if::type read(int index, float x) { + return detail::fetch_data()( + _img_acc[index].read(x, _sampler)); + } + /// Read data from accessor. + template + typename std::enable_if::type read(int index, int x) { + return detail::fetch_data()( + _img_acc[index].read(x, _sampler)); + } +}; + +/// Create image wrapper according to image data and sampling info. +/// \return Pointer to image wrapper base class. +/// \param data Image data used to create image wrapper. +/// \param info Image sampling info used to create image wrapper. +/// \returns Pointer to base class of created image wrapper object. +static inline image_wrapper_base *create_image_wrapper(image_data data, + sampling_info info) { + image_channel channel; + int dims = 1; + if (data.get_data_type() == image_data_type::matrix) { + auto matrix = (image_matrix_p)data.get_data_ptr(); + channel = matrix->get_channel(); + dims = matrix->get_dims(); + } else { + if (data.get_data_type() == image_data_type::pitch) { + dims = 2; + } + channel = data.get_channel(); + } + + if (auto ret = detail::create_image_wrapper(channel, dims)) { + ret->set_sampling_info(info); + ret->set_data(data); + return ret; + } + return nullptr; +} + +namespace detail { +/// Create image according with given type \p T and \p dims. +template static image_wrapper_base *create_image_wrapper(int dims) { + switch (dims) { + case 1: + return new image_wrapper(); + case 2: + return new image_wrapper(); + case 3: + return new image_wrapper(); + default: + return nullptr; + } +} +/// Create image with given data type \p T, channel order and dims +template +static image_wrapper_base *create_image_wrapper(unsigned channel_num, int dims) { + switch (channel_num) { + case 1: + return create_image_wrapper(dims); + case 2: + return create_image_wrapper>(dims); + case 3: + return create_image_wrapper>(dims); + case 4: + return create_image_wrapper>(dims); + default: + return nullptr; + } +} + +/// Create image with channel info and specified dimensions. 
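+/// Example (illustrative sketch): a 2D wrapper for single-channel fp32 data:
+/// \code
+///   image_wrapper_base *img =
+///       create_image_wrapper(image_channel::create<float>(), 2);
+/// \endcode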
+static image_wrapper_base *create_image_wrapper(image_channel channel, int dims) { + switch (channel.get_channel_type()) { + case sycl::image_channel_type::fp16: + return create_image_wrapper(channel.get_channel_num(), dims); + case sycl::image_channel_type::fp32: + return create_image_wrapper(channel.get_channel_num(), dims); + case sycl::image_channel_type::signed_int8: + return create_image_wrapper(channel.get_channel_num(), dims); + case sycl::image_channel_type::signed_int16: + return create_image_wrapper(channel.get_channel_num(), dims); + case sycl::image_channel_type::signed_int32: + return create_image_wrapper(channel.get_channel_num(), dims); + case sycl::image_channel_type::unsigned_int8: + return create_image_wrapper(channel.get_channel_num(), dims); + case sycl::image_channel_type::unsigned_int16: + return create_image_wrapper(channel.get_channel_num(), dims); + case sycl::image_channel_type::unsigned_int32: + return create_image_wrapper(channel.get_channel_num(), dims); + default: + return nullptr; + } +} +} // namespace detail + +} // namespace dpct + +#endif // !__DPCT_IMAGE_HPP__ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/kernel.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/kernel.h new file mode 100644 index 0000000..36364e3 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/kernel.h @@ -0,0 +1,459 @@ +//==---- kernel.hpp -------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_KERNEL_HPP__ +#define __DPCT_KERNEL_HPP__ + +#include +#ifdef _WIN32 +#include +#include +#else +#include +#endif + +#if defined(__has_include) && __has_include() +#include +#elif defined(__has_include) && __has_include() +#include +#else +#error "SYCLomatic runtime requires C++ filesystem support" +#endif + +#include +#include +#include + +namespace dpct { + +typedef void (*kernel_functor)(sycl::queue &, const sycl::nd_range<3> &, + unsigned int, void **, void **); + +struct kernel_function_info { + int max_work_group_size = 0; +}; + +static inline void get_kernel_function_info(kernel_function_info *kernel_info, + const void *function) { + kernel_info->max_work_group_size = + dpct::dev_mgr::instance() + .current_device() + .get_info(); +} +static inline kernel_function_info +get_kernel_function_info(const void *function) { + kernel_function_info kernel_info; + kernel_info.max_work_group_size = + dpct::dev_mgr::instance() + .current_device() + .get_info(); + return kernel_info; +} + + +namespace detail { + +#if defined(__has_include) && __has_include() +namespace fs = std::filesystem; +#else +namespace fs = std::experimental::filesystem; +#endif + +/// Write data to temporary file and return absolute path to temporary file. +/// Temporary file is created in a temporary directory both of which have random +/// names with only the user having access permissions. Only one temporary file +/// will be created in the temporary directory. 
+static inline fs::path write_data_to_file(char const *const data, size_t size) { + std::error_code ec; + + if (sizeof(size_t) >= sizeof(std::streamsize) && + size > (std::numeric_limits::max)()) + throw std::runtime_error("data file too large"); + + // random number generator + std::random_device dev; + std::mt19937 prng(dev()); + std::uniform_int_distribution rand(0); + + // find temporary directory + auto tmp_dir = fs::temp_directory_path(ec); + if (ec) + throw std::runtime_error("could not find temporary directory"); + + // create private directory + std::stringstream directory; + fs::path directory_path; + constexpr int max_attempts = 5; + int i; + + for (i = 0; i < max_attempts; i++) { + directory << std::hex << rand(prng); + directory_path = tmp_dir / directory.str(); + if (fs::create_directory(directory_path)) { + break; + } + } + if (i == max_attempts) + throw std::runtime_error("could not create directory"); + + // only allow owner permissions to private directory + fs::permissions(directory_path, fs::perms::owner_all, ec); + if (ec) + throw std::runtime_error("could not set directory permissions"); + + // random filename in private directory + std::stringstream filename; + filename << std::hex << rand(prng); +#ifdef _WIN32 + auto filepath = directory_path / (filename.str() + ".dll"); +#else + auto filepath = directory_path / filename.str(); +#endif + + // write data to temporary file + auto outfile = std::ofstream(filepath, std::ios::out | std::ios::binary); + if (outfile) { + // only allow program to write file + fs::permissions(filepath, fs::perms::owner_write, ec); + if (ec) + throw std::runtime_error("could not set permissions"); + + outfile.write(data, size); + if (!outfile.good()) + throw std::runtime_error("could not write data"); + outfile.close(); + + // only allow program to read/execute file + fs::permissions(filepath, fs::perms::owner_read | fs::perms::owner_exec, + ec); + if (ec) + throw std::runtime_error("could not set permissions"); + } else + throw std::runtime_error("could not write data"); + + // check temporary file contents + auto infile = std::ifstream(filepath, std::ios::in | std::ios::binary); + if (infile) { + bool mismatch = false; + size_t cnt = 0; + + while (1) { + char c; + infile.get(c); + if (infile.eof()) + break; + if (c != data[cnt++]) + mismatch = true; + } + if (cnt != size || mismatch) + throw std::runtime_error("file contents not written correctly"); + } else + throw std::runtime_error("could not validate file"); + + if (!filepath.is_absolute()) + throw std::runtime_error("temporary filepath is not absolute"); + + return filepath; +} + +static inline uint16_t extract16(unsigned char const *const ptr) { + uint16_t ret = 0; + + ret |= static_cast(ptr[0]) << 0; + ret |= static_cast(ptr[1]) << 8; + + return (ret); +} + +static inline uint32_t extract32(unsigned char const *const ptr) { + uint32_t ret = 0; + + ret |= static_cast(ptr[0]) << 0; + ret |= static_cast(ptr[1]) << 8; + ret |= static_cast(ptr[2]) << 16; + ret |= static_cast(ptr[3]) << 24; + + return (ret); +} + +static inline uint64_t extract64(unsigned char const *const ptr) { + uint64_t ret = 0; + + ret |= static_cast(ptr[0]) << 0; + ret |= static_cast(ptr[1]) << 8; + ret |= static_cast(ptr[2]) << 16; + ret |= static_cast(ptr[3]) << 24; + ret |= static_cast(ptr[4]) << 32; + ret |= static_cast(ptr[5]) << 40; + ret |= static_cast(ptr[6]) << 48; + ret |= static_cast(ptr[7]) << 56; + + return (ret); +} + +static inline uint64_t get_lib_size(char const *const blob) { +#ifdef _WIN32 + 
/////////////////////////////////////////////////////////////////////// + // Analyze DOS stub + unsigned char const *const ublob = + reinterpret_cast(blob); + if (ublob[0] != 0x4d || ublob[1] != 0x5a) { + throw std::runtime_error("Blob is not a Windows DLL."); + } + uint32_t pe_header_offset = extract32(ublob + 0x3c); + + /////////////////////////////////////////////////////////////////////// + // Ananlyze PE-header + unsigned char const *const pe_header = ublob + pe_header_offset; + + // signature + uint32_t pe_signature = extract32(pe_header + 0); + if (pe_signature != 0x00004550) { + throw std::runtime_error("PE-header signature is not 0x00004550"); + } + + // machine + uint16_t machine = extract16(pe_header + 4); + if (machine != 0x8664) { + throw std::runtime_error("Only DLLs for x64 supported"); + } + + // number of sections + uint16_t number_of_sections = extract16(pe_header + 6); + + // sizeof optional header + uint16_t sizeof_optional_header = extract16(pe_header + 20); + + // magic + uint16_t magic = extract16(pe_header + 24); + if (magic != 0x10b && magic != 0x20b) { + throw std::runtime_error("MAGIC is not 0x010b or 0x020b"); + } + + /////////////////////////////////////////////////////////////////////// + // Analyze tail of optional header + constexpr int coff_header_size = 24; + + unsigned char const *const tail_of_optional_header = + pe_header + coff_header_size + sizeof_optional_header; + if (extract64(tail_of_optional_header - 8) != 0) { + throw std::runtime_error("Optional header not zero-padded"); + } + + /////////////////////////////////////////////////////////////////////// + // Analyze last section header + constexpr int section_header_size = 40; + unsigned char const *const last_section_header = + tail_of_optional_header + section_header_size * (number_of_sections - 1); + + uint32_t sizeof_raw_data = extract32(last_section_header + 16); + uint32_t pointer_to_raw_data = extract32(last_section_header + 20); + + return sizeof_raw_data + pointer_to_raw_data; +#else + if (blob[0] != 0x7F || blob[1] != 'E' || blob[2] != 'L' || blob[3] != 'F') + throw std::runtime_error("Blob is not in ELF format"); + + if (blob[4] != 0x02) + throw std::runtime_error("Only 64-bit headers are supported"); + + if (blob[5] != 0x01) + throw std::runtime_error("Only little-endian headers are supported"); + + unsigned char const *const ublob = + reinterpret_cast(blob); + uint64_t e_shoff = extract64(ublob + 0x28); + uint16_t e_shentsize = extract16(ublob + 0x3A); + uint16_t e_shnum = extract16(ublob + 0x3C); + + return e_shoff + (e_shentsize * e_shnum); +#endif +} + +#ifdef _WIN32 +class path_lib_record { +public: + void operator=(const path_lib_record &) = delete; + ~path_lib_record() { + for (auto entry : lib_to_path) { + FreeLibrary(static_cast(entry.first)); + fs::permissions(entry.second, fs::perms::owner_all); + fs::remove_all(entry.second.remove_filename()); + } + } + static void record_lib_path(fs::path path, void *library) { + lib_to_path[library] = path; + } + static void remove_lib(void *library) { + auto path = lib_to_path[library]; + std::error_code ec; + + FreeLibrary(static_cast(library)); + fs::permissions(path, fs::perms::owner_all); + if (fs::remove_all(path.remove_filename(), ec) != 2 || ec) + // one directory and one temporary file should have been deleted + throw std::runtime_error("Directory delete failed"); + + lib_to_path.erase(library); + } + +private: + static inline std::unordered_map lib_to_path; +}; +#endif + +} // namespace detail + +class kernel_library { +public: + 
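+  /// Typical load / lookup / launch / unload flow handled through this
+  /// wrapper (illustrative sketch; "my_kernels.bin", "vector_add" and the
+  /// argument variables are hypothetical):
+  /// \code
+  ///   dpct::kernel_library lib = dpct::load_kernel_library("my_kernels.bin");
+  ///   dpct::kernel_function fn = dpct::get_kernel_function(lib, "vector_add");
+  ///   void *args[] = {&a, &b, &c, &n};
+  ///   dpct::invoke_kernel_function(fn, dpct::get_default_queue(),
+  ///                                sycl::range<3>(1, 1, num_groups),
+  ///                                sycl::range<3>(1, 1, group_size),
+  ///                                /*localMemSize=*/0, args, nullptr);
+  ///   dpct::unload_kernel_library(lib);
+  /// \endcode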
+  kernel_library() : ptr{nullptr} {}
+  kernel_library(void *ptr) : ptr{ptr} {}
+
+  operator void *() const { return ptr; }
+
+private:
+  void *ptr;
+#ifdef _WIN32
+  static inline detail::path_lib_record single_instance_to_trigger_destructor;
+#endif
+};
+
+namespace detail {
+
+static inline kernel_library load_dl_from_data(char const *const data,
+                                               size_t size) {
+  fs::path filename = write_data_to_file(data, size);
+#ifdef _WIN32
+  void *so = LoadLibraryW(filename.wstring().c_str());
+#else
+  void *so = dlopen(filename.c_str(), RTLD_LAZY);
+#endif
+  if (so == nullptr)
+    throw std::runtime_error("Failed to load kernel library");
+
+#ifdef _WIN32
+  detail::path_lib_record::record_lib_path(filename, so);
+#else
+  std::error_code ec;
+
+  // Windows DLL cannot be deleted while in use
+  if (fs::remove_all(filename.remove_filename(), ec) != 2 || ec)
+    // one directory and one temporary file should have been deleted
+    throw std::runtime_error("Directory delete failed");
+#endif
+
+  return so;
+}
+
+} // namespace detail
+
+/// Load kernel library and return a handle to use the library.
+/// \param [in] name The name of the library.
+static inline kernel_library load_kernel_library(const std::string &name) {
+  std::ifstream ifs;
+  ifs.open(name, std::ios::in | std::ios::binary);
+
+  std::stringstream buffer;
+  buffer << ifs.rdbuf();
+
+  const std::string buffer_string = buffer.str();
+  return detail::load_dl_from_data(buffer_string.c_str(), buffer_string.size());
+}
+
+/// Load kernel library whose image is already in memory and return a handle to
+/// use the library.
+/// \param [in] image A pointer to the image in memory.
+static inline kernel_library load_kernel_library_mem(char const *const image) {
+  const size_t size = detail::get_lib_size(image);
+
+  return detail::load_dl_from_data(image, size);
+}
+
+/// Unload kernel library.
+/// \param [in,out] library Handle to the library to be closed.
+static inline void unload_kernel_library(const kernel_library &library) {
+#ifdef _WIN32
+  detail::path_lib_record::remove_lib(library);
+#else
+  dlclose(library);
+#endif
+}
+
+class kernel_function {
+public:
+  kernel_function() : ptr{nullptr} {}
+  kernel_function(dpct::kernel_functor ptr) : ptr{ptr} {}
+
+  operator void *() const { return ((void *)ptr); }
+
+  void operator()(sycl::queue &q, const sycl::nd_range<3> &range,
+                  unsigned int a, void **args, void **extra) {
+    ptr(q, range, a, args, extra);
+  }
+
+private:
+  dpct::kernel_functor ptr;
+};
+
+/// Find kernel function in a kernel library and return its address.
+/// \param [in] library Handle to the kernel library.
+/// \param [in] name Name of the kernel function.
+static inline dpct::kernel_function
+get_kernel_function(kernel_library &library, const std::string &name) {
+#ifdef _WIN32
+  dpct::kernel_functor fn = reinterpret_cast<dpct::kernel_functor>(
+      GetProcAddress(static_cast<HMODULE>(static_cast<void *>(library)),
+                     (name + std::string("_wrapper")).c_str()));
+#else
+  dpct::kernel_functor fn = reinterpret_cast<dpct::kernel_functor>(
+      dlsym(library, (name + std::string("_wrapper")).c_str()));
+#endif
+  if (fn == nullptr)
+    throw std::runtime_error("Failed to get function");
+  return fn;
+}
+
+/// Invoke a kernel function.
+/// \param [in] function kernel function.
+/// \param [in] queue SYCL queue used to execute kernel
+/// \param [in] groupRange SYCL group range
+/// \param [in] localRange SYCL local range
+/// \param [in] localMemSize The size of local memory required by the kernel
+/// function.
+/// \param [in] kernelParams Array of pointers to kernel arguments.
+/// \param [in] extra Extra arguments. +static inline void invoke_kernel_function(dpct::kernel_function &function, + sycl::queue &queue, + sycl::range<3> groupRange, + sycl::range<3> localRange, + unsigned int localMemSize, + void **kernelParams, void **extra) { + function(queue, sycl::nd_range<3>(groupRange * localRange, localRange), + localMemSize, kernelParams, extra); +} + +/// Find image wrapper in a kernel library and return its address. +/// \param [in] library Handle to the kernel library. +/// \param [in] name Name of the target image wrapper. +static inline dpct::image_wrapper_base_p +get_image_wrapper(dpct::kernel_library &library, const std::string &name) { +#ifdef _WIN32 + dpct::image_wrapper_base_p fn = + reinterpret_cast(GetProcAddress( + static_cast(static_cast(library)), name.c_str())); +#else + dpct::image_wrapper_base_p fn = reinterpret_cast( + dlsym(library, name.c_str())); +#endif + if (fn == nullptr) + throw std::runtime_error("Failed to get image"); + return fn; +} + +} // namespace dpct +#endif // __DPCT_KERNEL_HPP__ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/lib_common_utils.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/lib_common_utils.h new file mode 100644 index 0000000..24c000c --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/lib_common_utils.h @@ -0,0 +1,159 @@ +//==---- lib_common_utils.hpp ---------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_LIB_COMMON_UTILS_HPP__ +#define __DPCT_LIB_COMMON_UTILS_HPP__ + +#include +#include +#include "memory.h" +#include "util.h" + +namespace dpct { +namespace detail { +template inline auto get_memory(const void *x) { + T *new_x = reinterpret_cast(const_cast(x)); +#ifdef DPCT_USM_LEVEL_NONE + return dpct::get_buffer>(new_x); +#else + return new_x; +#endif +} + +template +inline typename DataType::T2 get_value(const T *s, sycl::queue &q) { + using Ty = typename DataType::T2; + Ty s_h; + if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only) + detail::dpct_memcpy(q, (void *)&s_h, (void *)s, sizeof(T), device_to_host) + .wait(); + else + s_h = *reinterpret_cast(s); + return s_h; +} +} // namespace detail + +enum class version_field : int { major, minor, update, patch }; + +/// Returns the requested field of Intel(R) oneAPI Math Kernel Library version. +/// \param field The version information field (major, minor, update or patch). +/// \param result The result value. 
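As a usage illustration only: the kernel-library helpers above (load, symbol lookup via the `_wrapper` suffix, launch, unload) are typically driven in the order sketched below. The include path, library file name and kernel name are hypothetical placeholders, not part of this patch.

// Hypothetical usage sketch of the dpct kernel-library helpers shown above.
// "dpct/kernel.h", "my_kernels.so" and "vector_add" are placeholder names.
#include <sycl/sycl.hpp>
#include "dpct/kernel.h"

void launch_from_library(sycl::queue &q, float *a, float *b, float *c, int n) {
  dpct::kernel_library lib = dpct::load_kernel_library("my_kernels.so");
  dpct::kernel_function fn = dpct::get_kernel_function(lib, "vector_add");

  void *args[] = {&a, &b, &c, &n};
  // One work-group of 256 work-items; no local memory, no extra arguments.
  dpct::invoke_kernel_function(fn, q, sycl::range<3>(1, 1, 1),
                               sycl::range<3>(1, 1, 256), /*localMemSize=*/0,
                               args, /*extra=*/nullptr);
  q.wait();

  dpct::unload_kernel_library(lib);
}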
+inline void mkl_get_version(version_field field, int *result) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + MKLVersion version; + mkl_get_version(&version); + if (version_field::major == field) { + *result = version.MajorVersion; + } else if (version_field::minor == field) { + *result = version.MinorVersion; + } else if (version_field::update == field) { + *result = version.UpdateVersion; + } else if (version_field::patch == field) { + *result = 0; + } else { + throw std::runtime_error("unknown field"); + } +#endif +} + +enum class library_data_t : unsigned char { + real_float = 0, + complex_float, + real_double, + complex_double, + real_half, + complex_half, + real_bfloat16, + complex_bfloat16, + real_int4, + complex_int4, + real_uint4, + complex_uint4, + real_int8, + complex_int8, + real_uint8, + complex_uint8, + real_int16, + complex_int16, + real_uint16, + complex_uint16, + real_int32, + complex_int32, + real_uint32, + complex_uint32, + real_int64, + complex_int64, + real_uint64, + complex_uint64, + real_int8_4, + real_int8_32, + real_uint8_4, + library_data_t_size +}; + +namespace detail { +template +inline constexpr std::uint64_t get_type_combination_id(ArgT Val) { + static_assert((unsigned char)library_data_t::library_data_t_size <= + std::numeric_limits::max() && + "library_data_t size exceeds limit."); + static_assert(std::is_same_v, "Unsupported ArgT"); + return (std::uint64_t)Val; +} + +template +inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal, + RestT... RestVal) { + static_assert((std::uint8_t)library_data_t::library_data_t_size <= + std::numeric_limits::max() && + "library_data_t size exceeds limit."); + static_assert(sizeof...(RestT) <= 8 && "Too many parameters"); + static_assert(std::is_same_v, "Unsupported FirstT"); + return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal); +} + +inline constexpr std::size_t library_data_size[] = { + 8 * sizeof(float), // real_float + 8 * sizeof(std::complex), // complex_float + 8 * sizeof(double), // real_double + 8 * sizeof(std::complex), // complex_double + 8 * sizeof(sycl::half), // real_half + 8 * sizeof(std::complex), // complex_half + 16, // real_bfloat16 + 16 * 2, // complex_bfloat16 + 4, // real_int4 + 4 * 2, // complex_int4 + 4, // real_uint4 + 4 * 2, // complex_uint4 + 8, // real_int8 + 8 * 2, // complex_int8 + 8, // real_uint8 + 8 * 2, // complex_uint8 + 16, // real_int16 + 16 * 2, // complex_int16 + 16, // real_uint16 + 16 * 2, // complex_uint16 + 32, // real_int32 + 32 * 2, // complex_int32 + 32, // real_uint32 + 32 * 2, // complex_uint32 + 64, // real_int64 + 64 * 2, // complex_int64 + 64, // real_uint64 + 64 * 2, // complex_uint64 + 8, // real_int8_4 + 8, // real_int8_32 + 8 // real_uint8_4 +}; +} // namespace detail +} // namespace dpct + +#endif // __DPCT_LIB_COMMON_UTILS_HPP__ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/math.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/math.h new file mode 100644 index 0000000..c569a28 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/math.h @@ -0,0 +1,1011 @@ +//==---- math.hpp ---------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. 
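The recursive packing in detail::get_type_combination_id above puts one library_data_t tag per byte, with the first argument landing in the least-significant byte. A minimal standalone sketch of that packing (pack_id is a hypothetical stand-in, using the enum values shown above):

#include <cstdint>
#include <cstdio>

// Standalone sketch of the byte packing used by get_type_combination_id:
// each type tag occupies one byte, first argument in the low byte.
constexpr std::uint64_t pack_id(std::uint8_t v) { return v; }
template <typename... Rest>
constexpr std::uint64_t pack_id(std::uint8_t first, Rest... rest) {
  return (pack_id(rest...) << 8) | first;
}

int main() {
  constexpr std::uint8_t real_float = 0, real_double = 2, real_half = 4;
  // real_float in byte 0, real_double in byte 1, real_half in byte 2.
  static_assert(pack_id(real_float, real_double, real_half) == 0x040200, "");
  std::printf("0x%llx\n",
              (unsigned long long)pack_id(real_float, real_double, real_half));
  return 0;
}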
+// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_MATH_HPP__ +#define __DPCT_MATH_HPP__ + +#include +#include +#include + +namespace dpct { +namespace detail { +template +class vectorized_binary { +public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) { + VecT v4; + for (size_t i = 0; i < v4.size(); ++i) { + v4[i] = binary_op(a[i], b[i]); + } + return v4; + } +}; +template +class vectorized_binary< + VecT, BinaryOperation, + std::void_t>> { +public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) { + return binary_op(a, b).template as(); + } +}; + +template inline bool isnan(const T a) { return sycl::isnan(a); } +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +inline bool isnan(const sycl::ext::oneapi::bfloat16 a) { + return sycl::ext::oneapi::experimental::isnan(a); +} +#endif +} // namespace detail + +/// Compute fast_length for variable-length array +/// \param [in] a The array +/// \param [in] len Length of the array +/// \returns The computed fast_length +inline float fast_length(const float *a, int len) { + switch (len) { + case 1: + return a[0]; + case 2: + return sycl::fast_length(sycl::float2(a[0], a[1])); + case 3: + return sycl::fast_length(sycl::float3(a[0], a[1], a[2])); + case 4: + return sycl::fast_length(sycl::float4(a[0], a[1], a[2], a[3])); + case 0: + return 0; + default: + float f = 0; + for (int i = 0; i < len; ++i) + f += a[i] * a[i]; + return sycl::sqrt(f); + } +} + +/// Calculate the square root of the input array. +/// \param [in] a The array pointer +/// \param [in] len Length of the array +/// \returns The square root +template inline T length(const T *a, const int len) { + switch (len) { + case 1: + return a[0]; + case 2: + return sycl::length(sycl::vec(a[0], a[1])); + case 3: + return sycl::length(sycl::vec(a[0], a[1], a[2])); + case 4: + return sycl::length(sycl::vec(a[0], a[1], a[2], a[3])); + default: + T ret = 0; + for (int i = 0; i < len; ++i) + ret += a[i] * a[i]; + return sycl::sqrt(ret); + } +} + +/// Returns min(max(val, min_val), max_val) +/// \param [in] val The input value +/// \param [in] min_val The minimum value +/// \param [in] max_val The maximum value +/// \returns the value between min_val and max_val +template inline T clamp(T val, T min_val, T max_val) { + return sycl::clamp(val, min_val, max_val); +} +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +template <> +inline sycl::ext::oneapi::bfloat16 clamp(sycl::ext::oneapi::bfloat16 val, + sycl::ext::oneapi::bfloat16 min_val, + sycl::ext::oneapi::bfloat16 max_val) { + if (val < min_val) + return min_val; + if (val > max_val) + return max_val; + return val; +} +#endif +template +inline sycl::marray clamp(sycl::marray val, + sycl::marray min_val, + sycl::marray max_val) { + return {clamp(val[0], min_val[0], max_val[0]), + clamp(val[1], min_val[1], max_val[1])}; +} + +/// Performs comparison. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t< + std::is_same_v, bool>, bool> +compare(const T a, const T b, const BinaryOperation binary_op) { + return binary_op(a, b); +} +template +inline std::enable_if_t< + std::is_same_v, T, T>, bool>, bool> +compare(const T a, const T b, const std::not_equal_to<> binary_op) { + return !detail::isnan(a) && !detail::isnan(b) && binary_op(a, b); +} + +/// Performs unordered comparison. 
+/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t< + std::is_same_v, bool>, bool> +unordered_compare(const T a, const T b, const BinaryOperation binary_op) { + return detail::isnan(a) || detail::isnan(b) || binary_op(a, b); +} + +/// Performs 2 element comparison and return true if both results are true. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t +compare_both(const T a, const T b, const BinaryOperation binary_op) { + return compare(a[0], b[0], binary_op) && compare(a[1], b[1], binary_op); +} + +/// Performs 2 element unordered comparison and return true if both results are +/// true. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t +unordered_compare_both(const T a, const T b, const BinaryOperation binary_op) { + return unordered_compare(a[0], b[0], binary_op) && + unordered_compare(a[1], b[1], binary_op); +} + +/// Performs 2 element comparison. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t +compare(const T a, const T b, const BinaryOperation binary_op) { + return {compare(a[0], b[0], binary_op), compare(a[1], b[1], binary_op)}; +} + +/// Performs 2 elements comparison, compare result of each element is 0 (false) +/// or 0xffff (true), returns an unsigned int by composing compare result of two +/// elements. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline unsigned compare_mask(const sycl::vec a, const sycl::vec b, + const BinaryOperation binary_op) { + return sycl::vec(-compare(a[0], b[0], binary_op), + -compare(a[1], b[1], binary_op)) + .as>(); +} +template +inline unsigned compare_mask(const sycl::marray a, + const sycl::marray b, + const BinaryOperation binary_op) { + return sycl::vec(-compare(a[0], b[0], binary_op), + -compare(a[1], b[1], binary_op)) + .as>(); +} + +/// Performs 2 element unordered comparison. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t +unordered_compare(const T a, const T b, const BinaryOperation binary_op) { + return {unordered_compare(a[0], b[0], binary_op), + unordered_compare(a[1], b[1], binary_op)}; +} + +/// Performs 2 elements unordered comparison, compare result of each element is +/// 0 (false) or 0xffff (true), returns an unsigned int by composing compare +/// result of two elements. 
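A standalone sketch (plain C++, no SYCL) of the ordered/unordered split implemented above: with a not-equal predicate, the ordered compare rejects NaN operands while the unordered compare accepts them.

#include <cmath>
#include <cstdio>
#include <functional>

// Mirrors the scalar logic above: the "ordered" compare fails when a NaN is
// involved, the "unordered" compare succeeds.
static bool compare_ne(float a, float b) {
  return !std::isnan(a) && !std::isnan(b) && std::not_equal_to<>()(a, b);
}
static bool unordered_compare_ne(float a, float b) {
  return std::isnan(a) || std::isnan(b) || std::not_equal_to<>()(a, b);
}

int main() {
  const float nan = std::nanf("");
  // Prints "0 1": ordered rejects the NaN operand, unordered accepts it.
  std::printf("%d %d\n", compare_ne(1.0f, nan), unordered_compare_ne(1.0f, nan));
  return 0;
}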
+/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline unsigned unordered_compare_mask(const sycl::vec a, + const sycl::vec b, + const BinaryOperation binary_op) { + return sycl::vec(-unordered_compare(a[0], b[0], binary_op), + -unordered_compare(a[1], b[1], binary_op)) + .as>(); +} +template +inline unsigned unordered_compare_mask(const sycl::marray a, + const sycl::marray b, + const BinaryOperation binary_op) { + return sycl::vec(-unordered_compare(a[0], b[0], binary_op), + -unordered_compare(a[1], b[1], binary_op)) + .as>(); +} + +/// Determine whether 2 element value is NaN. +/// \param [in] a The input value +/// \returns the comparison result +template +inline std::enable_if_t isnan(const T a) { + return {detail::isnan(a[0]), detail::isnan(a[1])}; +} + +/// Emulated function for __funnelshift_l +inline unsigned int funnelshift_l(unsigned int low, unsigned int high, + unsigned int shift) { + return (sycl::upsample(high, low) << (shift & 31U)) >> 32; +} + +/// Emulated function for __funnelshift_lc +inline unsigned int funnelshift_lc(unsigned int low, unsigned int high, + unsigned int shift) { + return (sycl::upsample(high, low) << sycl::min(shift, 32U)) >> 32; +} + +/// Emulated function for __funnelshift_r +inline unsigned int funnelshift_r(unsigned int low, unsigned int high, + unsigned int shift) { + return (sycl::upsample(high, low) >> (shift & 31U)) & 0xFFFFFFFF; +} + +/// Emulated function for __funnelshift_rc +inline unsigned int funnelshift_rc(unsigned int low, unsigned int high, + unsigned int shift) { + return (sycl::upsample(high, low) >> sycl::min(shift, 32U)) & 0xFFFFFFFF; +} + +/// cbrt function wrapper. +template inline T cbrt(T val) { return sycl::cbrt((T)val); } + +// min function overloads. +// For floating-point types, `float` or `double` arguments are acceptable. +// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or +// `std::int64_t` type arguments are acceptable. 
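The funnel-shift emulations above all follow one pattern: concatenate high:low into 64 bits, shift, and keep one 32-bit half. A self-contained model of the left variant, with a plain 64-bit compose standing in for sycl::upsample:

#include <cstdint>
#include <cstdio>

// Plain-C++ model of funnelshift_l: concatenate high:low into 64 bits,
// shift left by (shift & 31), return the upper 32 bits.
static std::uint32_t funnelshift_l_ref(std::uint32_t low, std::uint32_t high,
                                       std::uint32_t shift) {
  const std::uint64_t wide = (static_cast<std::uint64_t>(high) << 32) | low;
  return static_cast<std::uint32_t>((wide << (shift & 31U)) >> 32);
}

int main() {
  // Shifting by 4 pulls the top 4 bits of `low` into the result.
  std::printf("0x%08x\n", funnelshift_l_ref(0xF0000000u, 0x12345678u, 4));
  // expected output: 0x2345678f
  return 0;
}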
+inline double min(const double a, const float b) { + return sycl::fmin(a, static_cast(b)); +} +inline double min(const float a, const double b) { + return sycl::fmin(static_cast(a), b); +} +inline float min(const float a, const float b) { return sycl::fmin(a, b); } +inline double min(const double a, const double b) { return sycl::fmin(a, b); } +inline std::uint32_t min(const std::uint32_t a, const std::int32_t b) { + return sycl::min(a, static_cast(b)); +} +inline std::uint32_t min(const std::int32_t a, const std::uint32_t b) { + return sycl::min(static_cast(a), b); +} +inline std::int32_t min(const std::int32_t a, const std::int32_t b) { + return sycl::min(a, b); +} +inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b) { + return sycl::min(a, b); +} +inline std::uint64_t min(const std::uint64_t a, const std::int64_t b) { + return sycl::min(a, static_cast(b)); +} +inline std::uint64_t min(const std::int64_t a, const std::uint64_t b) { + return sycl::min(static_cast(a), b); +} +inline std::int64_t min(const std::int64_t a, const std::int64_t b) { + return sycl::min(a, b); +} +inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b) { + return sycl::min(a, b); +} +inline std::uint64_t min(const std::uint64_t a, const std::int32_t b) { + return sycl::min(a, static_cast(b)); +} +inline std::uint64_t min(const std::int32_t a, const std::uint64_t b) { + return sycl::min(static_cast(a), b); +} +inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b) { + return sycl::min(a, static_cast(b)); +} +inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b) { + return sycl::min(static_cast(a), b); +} +// max function overloads. +// For floating-point types, `float` or `double` arguments are acceptable. +// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or +// `std::int64_t` type arguments are acceptable. 
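A subtlety worth spelling out for the mixed signed/unsigned overloads above: assuming the casts convert the signed operand to the unsigned type, as the return types indicate, a negative value wraps before the comparison. A plain-C++ sketch of that behaviour:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const std::int32_t a = -1;
  const std::uint32_t b = 5;
  // Converting the signed operand first makes -1 wrap to 0xFFFFFFFF,
  // so the "minimum" is 5, not -1.
  const std::uint32_t r = std::min(static_cast<std::uint32_t>(a), b);
  std::printf("%u\n", r); // prints 5
  return 0;
}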
+inline double max(const double a, const float b) { + return sycl::fmax(a, static_cast(b)); +} +inline double max(const float a, const double b) { + return sycl::fmax(static_cast(a), b); +} +inline float max(const float a, const float b) { return sycl::fmax(a, b); } +inline double max(const double a, const double b) { return sycl::fmax(a, b); } +inline std::uint32_t max(const std::uint32_t a, const std::int32_t b) { + return sycl::max(a, static_cast(b)); +} +inline std::uint32_t max(const std::int32_t a, const std::uint32_t b) { + return sycl::max(static_cast(a), b); +} +inline std::int32_t max(const std::int32_t a, const std::int32_t b) { + return sycl::max(a, b); +} +inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b) { + return sycl::max(a, b); +} +inline std::uint64_t max(const std::uint64_t a, const std::int64_t b) { + return sycl::max(a, static_cast(b)); +} +inline std::uint64_t max(const std::int64_t a, const std::uint64_t b) { + return sycl::max(static_cast(a), b); +} +inline std::int64_t max(const std::int64_t a, const std::int64_t b) { + return sycl::max(a, b); +} +inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b) { + return sycl::max(a, b); +} +inline std::uint64_t max(const std::uint64_t a, const std::int32_t b) { + return sycl::max(a, static_cast(b)); +} +inline std::uint64_t max(const std::int32_t a, const std::uint64_t b) { + return sycl::max(static_cast(a), b); +} +inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b) { + return sycl::max(a, static_cast(b)); +} +inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b) { + return sycl::max(static_cast(a), b); +} + +// pow functions overload. +inline float pow(const float a, const int b) { return sycl::pown(a, b); } +inline double pow(const double a, const int b) { return sycl::pown(a, b); } +inline float pow(const float a, const float b) { return sycl::pow(a, b); } +inline double pow(const double a, const double b) { return sycl::pow(a, b); } +template +inline typename std::enable_if_t, T> +pow(const T a, const U b) { + return sycl::pow(a, static_cast(b)); +} +template +inline typename std::enable_if_t, double> +pow(const T a, const U b) { + return sycl::pow(static_cast(a), static_cast(b)); +} + +namespace detail { +template +constexpr bool is_floating_point = + std::disjunction_v, std::is_same +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS + , + std::is_same +#endif + >; +} // namespace detail + +/// Performs relu saturation. +/// \param [in] a The input value +/// \returns the relu saturation result +template inline T relu(T a) { + T zero{}; + if constexpr (detail::is_floating_point) + return !detail::isnan(a) && a < zero ? zero : a; + else + return a < zero ? zero : a; +} +template inline sycl::vec relu(const sycl::vec a) { + return {relu(a[0]), relu(a[1])}; +} +template inline sycl::marray relu(const sycl::marray a) { + return {relu(a[0]), relu(a[1])}; +} + +/// Performs complex number multiply addition. 
+/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns the operation result +template +inline sycl::vec complex_mul_add(const sycl::vec a, + const sycl::vec b, + const sycl::vec c) { + return sycl::vec{a[0] * b[0] - a[1] * b[1] + c[0], + a[0] * b[1] + a[1] * b[0] + c[1]}; +} +template +inline sycl::marray complex_mul_add(const sycl::marray a, + const sycl::marray b, + const sycl::marray c) { + return sycl::marray{a[0] * b[0] - a[1] * b[1] + c[0], + a[0] * b[1] + a[1] * b[0] + c[1]}; +} + +/// Performs 2 elements comparison and returns the bigger one. If either of +/// inputs is NaN, then return NaN. +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns the bigger value +template inline T fmax_nan(const T a, const T b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return sycl::fmax(a, b); +} +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +template <> +inline sycl::ext::oneapi::bfloat16 +fmax_nan(const sycl::ext::oneapi::bfloat16 a, + const sycl::ext::oneapi::bfloat16 b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return sycl::fmax(float(a), float(b)); +} +#endif +template +inline sycl::vec fmax_nan(const sycl::vec a, + const sycl::vec b) { + return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])}; +} +template +inline sycl::marray fmax_nan(const sycl::marray a, + const sycl::marray b) { + return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])}; +} + +/// Performs 2 elements comparison and returns the smaller one. If either of +/// inputs is NaN, then return NaN. +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns the smaller value +template inline T fmin_nan(const T a, const T b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return sycl::fmin(a, b); +} +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +template <> +inline sycl::ext::oneapi::bfloat16 +fmin_nan(const sycl::ext::oneapi::bfloat16 a, + const sycl::ext::oneapi::bfloat16 b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return sycl::fmin(float(a), float(b)); +} +#endif +template +inline sycl::vec fmin_nan(const sycl::vec a, + const sycl::vec b) { + return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])}; +} +template +inline sycl::marray fmin_nan(const sycl::marray a, + const sycl::marray b) { + return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])}; +} + +/// A sycl::abs wrapper functors. +struct abs { + template auto operator()(const T x) const { + return sycl::abs(x); + } +}; + +/// A sycl::abs_diff wrapper functors. +struct abs_diff { + template auto operator()(const T x, const T y) const { + return sycl::abs_diff(x, y); + } +}; + +/// A sycl::add_sat wrapper functors. +struct add_sat { + template auto operator()(const T x, const T y) const { + return sycl::add_sat(x, y); + } +}; + +/// A sycl::rhadd wrapper functors. +struct rhadd { + template auto operator()(const T x, const T y) const { + return sycl::rhadd(x, y); + } +}; + +/// A sycl::hadd wrapper functors. +struct hadd { + template auto operator()(const T x, const T y) const { + return sycl::hadd(x, y); + } +}; + +/// A sycl::max wrapper functors. +struct maximum { + template auto operator()(const T x, const T y) const { + return sycl::max(x, y); + } +}; + +/// A sycl::min wrapper functors. +struct minimum { + template auto operator()(const T x, const T y) const { + return sycl::min(x, y); + } +}; + +/// A sycl::sub_sat wrapper functors. 
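The component formulas in complex_mul_add above are complex multiply-accumulate written out on interleaved (real, imag) pairs; a quick standalone check against std::complex, with values chosen to be exactly representable:

#include <cassert>
#include <complex>

int main() {
  const float a[2] = {1.5f, -2.0f}, b[2] = {0.5f, 3.0f}, c[2] = {4.0f, 1.0f};
  // Component form used above: a * b + c with a, b, c read as complex numbers.
  const float re = a[0] * b[0] - a[1] * b[1] + c[0];
  const float im = a[0] * b[1] + a[1] * b[0] + c[1];

  const std::complex<float> ref =
      std::complex<float>(a[0], a[1]) * std::complex<float>(b[0], b[1]) +
      std::complex<float>(c[0], c[1]);
  assert(re == ref.real() && im == ref.imag());
  return 0;
}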
+struct sub_sat { + template auto operator()(const T x, const T y) const { + return sycl::sub_sat(x, y); + } +}; + +/// Compute vectorized binary operation value for two values, with each value +/// treated as a vector type \p VecT. +/// \tparam [in] VecT The type of the vector +/// \tparam [in] BinaryOperation The binary operation class +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The vectorized binary operation value of the two values +template +inline unsigned vectorized_binary(unsigned a, unsigned b, + const BinaryOperation binary_op) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.as(); + auto v3 = v1.as(); + auto v4 = + detail::vectorized_binary()(v2, v3, binary_op); + v0 = v4.template as>(); + return v0; +} + +/// Compute vectorized isgreater for two values, with each value treated as a +/// vector type \p S. +/// \tparam [in] S The type of the vector +/// \tparam [in] T The type of the original values +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The vectorized greater than of the two values +template inline T vectorized_isgreater(T a, T b) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + auto v4 = v2 > v3; + v0 = v4.template as>(); + return v0; +} + +/// Compute vectorized max for two values, with each value treated as a vector +/// type \p S. +/// \tparam [in] S The type of the vector +/// \tparam [in] T The type of the original values +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The vectorized max of the two values +template inline T vectorized_max(T a, T b) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + auto v4 = sycl::max(v2, v3); + v0 = v4.template as>(); + return v0; +} + +/// Compute vectorized min for two values, with each value treated as a vector +/// type \p S. +/// \tparam [in] S The type of the vector +/// \tparam [in] T The type of the original values +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The vectorized min of the two values +template inline T vectorized_min(T a, T b) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + auto v4 = sycl::min(v2, v3); + v0 = v4.template as>(); + return v0; +} + +/// Compute vectorized unary operation for a value, with the value treated as a +/// vector type \p VecT. +/// \tparam [in] VecT The type of the vector +/// \tparam [in] UnaryOperation The unary operation class +/// \param [in] a The input value +/// \returns The vectorized unary operation value of the input value +template +inline unsigned vectorized_unary(unsigned a, const UnaryOperation unary_op) { + sycl::vec v0{a}; + auto v1 = v0.as(); + auto v2 = unary_op(v1); + v0 = v2.template as>(); + return v0; +} + +/// Compute vectorized absolute difference for two values without modulo +/// overflow, with each value treated as a vector type \p VecT. 
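The vectorized_* helpers above reinterpret one 32-bit scalar as a small vector, apply the operation per lane, and repack. A plain-C++ model of the two-lane 16-bit case, with memcpy standing in for the sycl::vec reinterpretation:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Model of a two-lane vectorized max: split a 32-bit value into two 16-bit
// lanes, take the per-lane max, and repack.
static std::uint32_t vectorized_max_u16x2(std::uint32_t a, std::uint32_t b) {
  std::uint16_t la[2], lb[2];
  std::memcpy(la, &a, sizeof(la));
  std::memcpy(lb, &b, sizeof(lb));
  for (int i = 0; i < 2; ++i)
    la[i] = std::max(la[i], lb[i]);
  std::uint32_t out;
  std::memcpy(&out, la, sizeof(out));
  return out;
}

int main() {
  // On little-endian: lanes (0x1234, 0x9000) vs (0x4321, 0x0009),
  // per-lane max is (0x4321, 0x9000), so this prints 0x90004321.
  std::printf("0x%08x\n", vectorized_max_u16x2(0x90001234u, 0x00094321u));
  return 0;
}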
+/// \tparam [in] VecT The type of the vector +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The vectorized absolute difference of the two values +template +inline unsigned vectorized_sum_abs_diff(unsigned a, unsigned b) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.as(); + auto v3 = v1.as(); + auto v4 = sycl::abs_diff(v2, v3); + unsigned sum = 0; + for (size_t i = 0; i < v4.size(); ++i) { + sum += v4[i]; + } + return sum; +} + +namespace detail { +/// Extend the 'val' to 'bit' size, zero extend for unsigned int and signed +/// extend for signed int. +template +inline int64_t zero_or_signed_extent(T val, unsigned bit) { + if constexpr (std::is_signed_v) { + return int64_t(val) << (64 - bit) >> (64 - bit); + } + return val; +} + +template +inline constexpr RetT extend_binary(AT a, BT b, BinaryOperation binary_op) { + int64_t extend_a = zero_or_signed_extent(a, 33); + int64_t extend_b = zero_or_signed_extent(b, 33); + int64_t ret = binary_op(extend_a, extend_b); + if constexpr (NeedSat) + return dpct::clamp(ret, std::numeric_limits::min(), + std::numeric_limits::max()); + return ret; +} + +template +inline constexpr RetT extend_binary(AT a, BT b, CT c, + BinaryOperation1 binary_op, + BinaryOperation2 second_op) { + int64_t extend_a = zero_or_signed_extent(a, 33); + int64_t extend_b = zero_or_signed_extent(b, 33); + int64_t extend_temp = + zero_or_signed_extent(binary_op(extend_a, extend_b), 34); + if constexpr (NeedSat) + extend_temp = + dpct::clamp(extend_temp, std::numeric_limits::min(), + std::numeric_limits::max()); + int64_t extend_c = zero_or_signed_extent(c, 33); + return second_op(extend_temp, extend_c); +} +} // namespace detail + +/// Extend \p a and \p b to 33 bit and add them. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend addition of the two values +template +inline constexpr RetT extend_add(AT a, BT b) { + return detail::extend_binary(a, b, std::plus()); +} + +/// Extend Inputs to 33 bit, add \p a, \p b, then do \p second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend addition of \p a, \p b and \p second_op with \p c +template +inline constexpr RetT extend_add(AT a, BT b, CT c, BinaryOperation second_op) { + return detail::extend_binary(a, b, c, std::plus(), second_op); +} + +/// Extend \p a and \p b to 33 bit and add them with saturation. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend addition of the two values with saturation +template +inline constexpr RetT extend_add_sat(AT a, BT b) { + return detail::extend_binary(a, b, std::plus()); +} + +/// Extend Inputs to 33 bit, add \p a, \p b with saturation, then do \p +/// second_op with \p c. 
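The shift trick in detail::zero_or_signed_extent above widens each operand to a 33-bit value, so adding two 32-bit inputs cannot overflow before the optional clamp. A standalone sketch:

#include <cstdint>
#include <cstdio>
#include <limits>

// Sign-extend a 32-bit value to `bit` bits inside an int64_t, mirroring
// the "shift left then arithmetic-shift right by 64 - bit" idiom above.
static std::int64_t sign_extend(std::int32_t val, unsigned bit) {
  return std::int64_t(val) << (64 - bit) >> (64 - bit);
}

int main() {
  const std::int32_t a = std::numeric_limits<std::int32_t>::max();
  // In 33-bit arithmetic INT32_MAX + INT32_MAX = 4294967294 fits without
  // wrapping; a plain 32-bit add would overflow.
  const std::int64_t sum = sign_extend(a, 33) + sign_extend(a, 33);
  std::printf("%lld\n", (long long)sum); // 4294967294
  return 0;
}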
+/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend addition of \p a, \p b with saturation and \p second_op +/// with \p c +template +inline constexpr RetT extend_add_sat(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, std::plus(), second_op); +} + +/// Extend \p a and \p b to 33 bit and minus them. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend subtraction of the two values +template +inline constexpr RetT extend_sub(AT a, BT b) { + return detail::extend_binary(a, b, std::minus()); +} + +/// Extend Inputs to 33 bit, minus \p a, \p b, then do \p second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend subtraction of \p a, \p b and \p second_op with \p c +template +inline constexpr RetT extend_sub(AT a, BT b, CT c, BinaryOperation second_op) { + return detail::extend_binary(a, b, c, std::minus(), second_op); +} + +/// Extend \p a and \p b to 33 bit and minus them with saturation. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend subtraction of the two values with saturation +template +inline constexpr RetT extend_sub_sat(AT a, BT b) { + return detail::extend_binary(a, b, std::minus()); +} + +/// Extend Inputs to 33 bit, minus \p a, \p b with saturation, then do \p +/// second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend subtraction of \p a, \p b with saturation and \p +/// second_op with \p c +template +inline constexpr RetT extend_sub_sat(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, std::minus(), second_op); +} + +/// Extend \p a and \p b to 33 bit and do abs_diff. 
+/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend abs_diff of the two values +template +inline constexpr RetT extend_absdiff(AT a, BT b) { + return detail::extend_binary(a, b, abs_diff()); +} + +/// Extend Inputs to 33 bit, abs_diff \p a, \p b, then do \p second_op with \p +/// c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend abs_diff of \p a, \p b and \p second_op with \p c +template +inline constexpr RetT extend_absdiff(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, abs_diff(), second_op); +} + +/// Extend \p a and \p b to 33 bit and do abs_diff with saturation. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend abs_diff of the two values with saturation +template +inline constexpr RetT extend_absdiff_sat(AT a, BT b) { + return detail::extend_binary(a, b, abs_diff()); +} + +/// Extend Inputs to 33 bit, abs_diff \p a, \p b with saturation, then do \p +/// second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend abs_diff of \p a, \p b with saturation and \p +/// second_op with \p c +template +inline constexpr RetT extend_absdiff_sat(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, abs_diff(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return smaller one. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The smaller one of the two extended values +template +inline constexpr RetT extend_min(AT a, BT b) { + return detail::extend_binary(a, b, minimum()); +} + +/// Extend Inputs to 33 bit, find the smaller one in \p a, \p b, then do \p +/// second_op with \p c. 
+/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The smaller one of \p a, \p b and \p second_op with \p c +template +inline constexpr RetT extend_min(AT a, BT b, CT c, BinaryOperation second_op) { + return detail::extend_binary(a, b, c, minimum(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return smaller one with saturation. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The smaller one of the two extended values with saturation +template +inline constexpr RetT extend_min_sat(AT a, BT b) { + return detail::extend_binary(a, b, minimum()); +} + +/// Extend Inputs to 33 bit, find the smaller one in \p a, \p b with saturation, +/// then do \p second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The smaller one of \p a, \p b with saturation and \p +/// second_op with \p c +template +inline constexpr RetT extend_min_sat(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, minimum(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return bigger one. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The bigger one of the two extended values +template +inline constexpr RetT extend_max(AT a, BT b) { + return detail::extend_binary(a, b, maximum()); +} + +/// Extend Inputs to 33 bit, find the bigger one in \p a, \p b, then do \p +/// second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The bigger one of \p a, \p b and \p second_op with \p c +template +inline constexpr RetT extend_max(AT a, BT b, CT c, BinaryOperation second_op) { + return detail::extend_binary(a, b, c, maximum(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return bigger one with saturation. 
+/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The bigger one of the two extended values with saturation +template +inline constexpr RetT extend_max_sat(AT a, BT b) { + return detail::extend_binary(a, b, maximum()); +} + +/// Extend Inputs to 33 bit, find the bigger one in \p a, \p b with saturation, +/// then do \p second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The bigger one of \p a, \p b with saturation and \p +/// second_op with \p c +template +inline constexpr RetT extend_max_sat(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, maximum(), second_op); +} +} // namespace dpct + +#endif // __DPCT_MATH_HPP__ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/memory.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/memory.h new file mode 100644 index 0000000..bb9462a --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/memory.h @@ -0,0 +1,1493 @@ +//==---- memory.hpp -------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_MEMORY_HPP__ +#define __DPCT_MEMORY_HPP__ + +#include "device.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__linux__) +#include +#elif defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#else +#error "Only support Windows and Linux." +#endif + +namespace dpct { + +enum memcpy_direction { + host_to_host, + host_to_device, + device_to_host, + device_to_device, + automatic +}; +enum memory_region { + global = 0, // device global memory + constant, // device constant memory + local, // device local memory + shared, // memory which can be accessed by host and device +}; + +typedef uint8_t byte_t; + +/// Buffer type to be used in Memory Management runtime. +typedef sycl::buffer buffer_t; + +/// Pitched 2D/3D memory data. +class pitched_data { +public: + pitched_data() : pitched_data(nullptr, 0, 0, 0) {} + pitched_data(void *data, size_t pitch, size_t x, size_t y) + : _data(data), _pitch(pitch), _x(x), _y(y) {} + + void *get_data_ptr() { return _data; } + void set_data_ptr(void *data) { _data = data; } + + size_t get_pitch() { return _pitch; } + void set_pitch(size_t pitch) { _pitch = pitch; } + + size_t get_x() { return _x; } + void set_x(size_t x) { _x = x; }; + + size_t get_y() { return _y; } + void set_y(size_t y) { _y = y; } + +private: + void *_data; + size_t _pitch, _x, _y; +}; + +namespace detail { +class mem_mgr { + mem_mgr() { + // Reserved address space, no real memory allocation happens here. 
+#if defined(__linux__) + mapped_address_space = + (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); +#elif defined(_WIN64) + mapped_address_space = (byte_t *)VirtualAlloc( + NULL, // NULL specified as the base address parameter + mapped_region_size, // Size of allocation + MEM_RESERVE, // Allocate reserved pages + PAGE_NOACCESS); // Protection = no access +#else +#error "Only support Windows and Linux." +#endif + next_free = mapped_address_space; + }; + +public: + using buffer_id_t = int; + + struct allocation { + buffer_t buffer; + byte_t *alloc_ptr; + size_t size; + }; + + ~mem_mgr() { +#if defined(__linux__) + munmap(mapped_address_space, mapped_region_size); +#elif defined(_WIN64) + VirtualFree(mapped_address_space, 0, MEM_RELEASE); +#else +#error "Only support Windows and Linux." +#endif + }; + + mem_mgr(const mem_mgr &) = delete; + mem_mgr &operator=(const mem_mgr &) = delete; + mem_mgr(mem_mgr &&) = delete; + mem_mgr &operator=(mem_mgr &&) = delete; + + /// Allocate + void *mem_alloc(size_t size) { + if (!size) + return nullptr; + std::lock_guard lock(m_mutex); + if (next_free + size > mapped_address_space + mapped_region_size) { + throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool"); + } + // Allocation + sycl::range<1> r(size); + buffer_t buf(r); + allocation A{buf, next_free, size}; + // Map allocation to device pointer + void *result = next_free; + m_map.emplace(next_free + size, A); + // Update pointer to the next free space. + next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1); + + return result; + } + + /// Deallocate + void mem_free(const void *ptr) { + if (!ptr) + return; + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + m_map.erase(it); + } + + /// map: device pointer -> allocation(buffer, alloc_ptr, size) + allocation translate_ptr(const void *ptr) { + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + return it->second; + } + + /// Check if the pointer represents device pointer or not. + bool is_device_ptr(const void *ptr) const { + std::lock_guard lock(m_mutex); + return (mapped_address_space <= ptr) && + (ptr < mapped_address_space + mapped_region_size); + } + + /// Returns the instance of memory manager singleton. + static mem_mgr &instance() { + static mem_mgr m; + return m; + } + +private: + std::map m_map; + mutable std::mutex m_mutex; + byte_t *mapped_address_space; + byte_t *next_free; + const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024; + const size_t alignment = 256; + /// This padding may be defined to some positive value to debug + /// out of bound accesses. + const size_t extra_padding = 0; + + std::map::iterator get_map_iterator(const void *ptr) { + auto it = m_map.upper_bound((byte_t *)ptr); + if (it == m_map.end()) { + // Not a virtual pointer. + throw std::runtime_error("can not get buffer from non-virtual pointer"); + } + const allocation &alloc = it->second; + if (ptr < alloc.alloc_ptr) { + // Out of bound. + // This may happen if there's a gap between allocations due to alignment + // or extra padding and pointer points to this gap. + throw std::runtime_error("invalid virtual pointer"); + } + return it; + } +}; + +template class accessor; +template class memory_traits { +public: + static constexpr sycl::access::target target = + sycl::access::target::device; + static constexpr sycl::access_mode mode = + (Memory == constant) ? 
sycl::access_mode::read + : sycl::access_mode::read_write; + static constexpr size_t type_size = sizeof(T); + using element_t = + typename std::conditional::type; + using value_t = typename std::remove_cv::type; + template + using accessor_t = typename std::conditional< + Memory == local, sycl::local_accessor, + sycl::accessor>::type; + using pointer_t = T *; +}; + +static inline void *dpct_malloc(size_t size, sycl::queue &q) { +#ifdef DPCT_USM_LEVEL_NONE + return mem_mgr::instance().mem_alloc(size * sizeof(byte_t)); +#else + return sycl::malloc_device(size, q.get_device(), q.get_context()); +#endif // DPCT_USM_LEVEL_NONE +} + +#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F)) +static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z, + sycl::queue &q) { + pitch = PITCH_DEFAULT_ALIGN(x); + return dpct_malloc(pitch * y * z, q); +} + +/** + * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @return An event representing the memset operation. + */ +template +static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr, + valueT value, size_t size) { +#ifdef DPCT_USM_LEVEL_NONE + auto &mm = mem_mgr::instance(); + assert(mm.is_device_ptr(dev_ptr)); + auto alloc = mm.translate_ptr(dev_ptr); + size_t offset = (valueT *)dev_ptr - (valueT *)alloc.alloc_ptr; + + return q.submit([&](sycl::handler &cgh) { + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + auto new_buffer = alloc.buffer.reinterpret( + sycl::range<1>(alloc.size / sizeof(valueT))); + sycl::accessor + acc(new_buffer, cgh, r, o); + cgh.fill(acc, value); + }); +#else + return q.fill(dev_ptr, value, size); +#endif // DPCT_USM_LEVEL_NONE +} + +/** + * @brief Sets \p value to the 3D memory region pointed by \p data in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] data Pointer to the pitched device memory region. + * @param [in] value The value to be set. + * @param [in] size 3D memory region by number of elements. + * @return An event list representing the memset operations. + */ +template +static inline std::vector +dpct_memset(sycl::queue &q, pitched_data data, valueT value, + sycl::range<3> size) { + std::vector event_list; + size_t slice = data.get_pitch() * data.get_y(); + unsigned char *data_surface = (unsigned char *)data.get_data_ptr(); + for (size_t z = 0; z < size.get(2); ++z) { + unsigned char *data_ptr = data_surface; + for (size_t y = 0; y < size.get(1); ++y) { + event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0))); + data_ptr += data.get_pitch(); + } + data_surface += slice; + } + return event_list; +} + +/** + * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. 
+ * @return An event list representing the memset operations. + */ +template +static inline std::vector +dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x, + size_t y) { + return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val, + sycl::range<3>(x, y, 1)); +} + +enum class pointer_access_attribute { + host_only = 0, + device_only, + host_device, + end +}; + +static pointer_access_attribute get_pointer_attribute(sycl::queue &q, + const void *ptr) { +#ifdef DPCT_USM_LEVEL_NONE + return mem_mgr::instance().is_device_ptr(ptr) + ? pointer_access_attribute::device_only + : pointer_access_attribute::host_only; +#else + switch (sycl::get_pointer_type(ptr, q.get_context())) { + case sycl::usm::alloc::unknown: + return pointer_access_attribute::host_only; + case sycl::usm::alloc::device: + return pointer_access_attribute::device_only; + case sycl::usm::alloc::shared: + case sycl::usm::alloc::host: + return pointer_access_attribute::host_device; + } +#endif +} + +static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr, + const void *from_ptr, + memcpy_direction dir) { + switch (dir) { + case memcpy_direction::host_to_host: + case memcpy_direction::host_to_device: + case memcpy_direction::device_to_host: + case memcpy_direction::device_to_device: + return dir; + case memcpy_direction::automatic: { + // table[to_attribute][from_attribute] + static const memcpy_direction + direction_table[static_cast(pointer_access_attribute::end)] + [static_cast(pointer_access_attribute::end)] = + {{memcpy_direction::host_to_host, + memcpy_direction::device_to_host, + memcpy_direction::host_to_host}, + {memcpy_direction::host_to_device, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}, + {memcpy_direction::host_to_host, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}}; + return direction_table[static_cast(get_pointer_attribute( + q, to_ptr))][static_cast(get_pointer_attribute(q, from_ptr))]; + } + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } +} + +static sycl::event +dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction, + const std::vector &dep_events = {}) { + if (!size) + return sycl::event{}; +#ifdef DPCT_USM_LEVEL_NONE + auto &mm = mem_mgr::instance(); + auto real_direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + + switch (real_direction) { + case host_to_host: + return q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + cgh.host_task([=] { std::memcpy(to_ptr, from_ptr, size); }); + }); + case host_to_device: { + auto alloc = mm.translate_ptr(to_ptr); + size_t offset = (byte_t *)to_ptr - alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + sycl::accessor + acc(alloc.buffer, cgh, r, o); + cgh.copy(from_ptr, acc); + }); + } + case device_to_host: { + auto alloc = mm.translate_ptr(from_ptr); + size_t offset = (byte_t *)from_ptr - alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + sycl::accessor + acc(alloc.buffer, cgh, r, o); + cgh.copy(acc, to_ptr); + }); + } + case device_to_device: { + auto to_alloc = mm.translate_ptr(to_ptr); + auto from_alloc = mm.translate_ptr(from_ptr); + size_t to_offset = (byte_t *)to_ptr - to_alloc.alloc_ptr; + size_t from_offset = (byte_t *)from_ptr - 
from_alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto to_o = sycl::id<1>(to_offset); + auto from_o = sycl::id<1>(from_offset); + sycl::accessor + to_acc(to_alloc.buffer, cgh, r, to_o); + sycl::accessor + from_acc(from_alloc.buffer, cgh, r, from_o); + cgh.copy(from_acc, to_acc); + }); + } + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } +#else + return q.memcpy(to_ptr, from_ptr, size, dep_events); +#endif // DPCT_USM_LEVEL_NONE +} + +// Get actual copy range and make sure it will not exceed range. +static inline size_t get_copy_range(sycl::range<3> size, size_t slice, + size_t pitch) { + return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); +} + +static inline size_t get_offset(sycl::id<3> id, size_t slice, + size_t pitch) { + return slice * id.get(2) + pitch * id.get(1) + id.get(0); +} + +/// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr +/// and \p from_range to another specified by \p to_ptr and \p to_range. +static inline std::vector +dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + sycl::range<3> to_range, sycl::range<3> from_range, + sycl::id<3> to_id, sycl::id<3> from_id, + sycl::range<3> size, memcpy_direction direction, + const std::vector &dep_events = {}) { + // RAII for host pointer + class host_buffer { + void *_buf; + size_t _size; + sycl::queue &_q; + const std::vector &_deps; // free operation depends + + public: + host_buffer(size_t size, sycl::queue &q, + const std::vector &deps) + : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} + void *get_ptr() const { return _buf; } + size_t get_size() const { return _size; } + ~host_buffer() { + if (_buf) { + _q.submit([&](sycl::handler &cgh) { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); + }); + } + } + }; + std::vector event_list; + + size_t to_slice = to_range.get(1) * to_range.get(0), + from_slice = from_range.get(1) * from_range.get(0); + unsigned char *to_surface = + (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char *from_surface = + (const unsigned char *)from_ptr + + get_offset(from_id, from_slice, from_range.get(0)); + + if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) { + return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events)}; + } + direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + size_t size_slice = size.get(1) * size.get(0); + switch (direction) { + case host_to_host: + for (size_t z = 0; z < size.get(2); ++z) { + unsigned char *to_ptr = to_surface; + const unsigned char *from_ptr = from_surface; + if (to_range.get(0) == from_range.get(0) && + to_range.get(0) == size.get(0)) { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, + direction, dep_events)); + } else { + for (size_t y = 0; y < size.get(1); ++y) { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), + direction, dep_events)); + to_ptr += to_range.get(0); + from_ptr += from_range.get(0); + } + } + to_surface += to_slice; + from_surface += from_slice; + } + break; + case host_to_device: { + host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, + event_list); + std::vector host_events; + if (to_slice == size_slice) { + // Copy host data to a temp host buffer with the shape of target. 
+ host_events = + dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + host_to_host, dep_events); + } else { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, + // If has padding data, not sure whether it is useless. So fill temp + // buffer with it. + std::vector{ + dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), + device_to_host, dep_events)}); + } + // Copy from temp host buffer to device with only one submit. + event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), + buf.get_size(), host_to_device, + host_events)); + break; + } + case device_to_host: { + host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, + event_list); + // Copy from host temp buffer to host target with reshaping. + event_list = dpct_memcpy( + q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), size, host_to_host, + // Copy from device to temp host buffer with only one submit. + std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, + buf.get_size(), + device_to_host, dep_events)}); + break; + } + case device_to_device: +#ifdef DPCT_USM_LEVEL_NONE + { + auto &mm = mem_mgr::instance(); + auto to_alloc = mm.translate_ptr(to_surface); + auto from_alloc = mm.translate_ptr(from_surface); + size_t to_offset = (byte_t *)to_surface - to_alloc.alloc_ptr; + size_t from_offset = (byte_t *)from_surface - from_alloc.alloc_ptr; + event_list.push_back(q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + auto to_o = sycl::id<1>(to_offset); + auto from_o = sycl::id<1>(from_offset); + sycl::accessor + to_acc(to_alloc.buffer, cgh, + get_copy_range(size, to_slice, to_range.get(0)), to_o); + sycl::accessor + from_acc(from_alloc.buffer, cgh, + get_copy_range(size, from_slice, from_range.get(0)), from_o); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_acc[get_offset(id, to_slice, to_range.get(0))] = + from_acc[get_offset(id, from_slice, from_range.get(0))]; + }); + })); + } +#else + event_list.push_back(q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); + })); +#endif + break; + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + return event_list; +} + +/// memcpy 2D/3D matrix specified by pitched_data. +static inline std::vector +dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, + pitched_data from, sycl::id<3> from_id, sycl::range<3> size, + memcpy_direction direction = automatic) { + return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, + size, direction); +} + +/// memcpy 2D matrix with pitch. 
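The `get_copy_range` and `get_offset` helpers above linearize a pitched region into the 1-D span an accessor needs. A hedged worked example, not part of the patch, with the pitch, row count and coordinates chosen only for illustration:

    sycl::range<3> size(64, 32, 1);          // {x in bytes, y rows, z slices}
    size_t pitch = 256;                      // row pitch of the surface, in bytes
    size_t slice = pitch * 128;              // the surface holds 128 rows per slice
    // Last touched byte + 1: 256 * (32 - 1) + 64 = 8000.
    size_t span = dpct::detail::get_copy_range(size, slice, pitch);
    // Byte offset of element (x = 16, y = 2, z = 0): 256 * 2 + 16 = 528.
    size_t off = dpct::detail::get_offset(sycl::id<3>(16, 2, 0), slice, pitch);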
+static inline std::vector +dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + size_t to_pitch, size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic) { + return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), direction); +} + +namespace deprecated { + +template +class usm_allocator { +private: + using Alloc = sycl::usm_allocator; + Alloc _impl; + +public: + using value_type = typename std::allocator_traits::value_type; + using pointer = typename std::allocator_traits::pointer; + using const_pointer = typename std::allocator_traits::const_pointer; + using void_pointer = typename std::allocator_traits::void_pointer; + using const_void_pointer = + typename std::allocator_traits::const_void_pointer; + using reference = typename std::allocator_traits::value_type &; + using const_reference = + const typename std::allocator_traits::value_type &; + using difference_type = + typename std::allocator_traits::difference_type; + using size_type = typename std::allocator_traits::size_type; + using propagate_on_container_copy_assignment = typename std::allocator_traits< + Alloc>::propagate_on_container_copy_assignment; + using propagate_on_container_move_assignment = typename std::allocator_traits< + Alloc>::propagate_on_container_move_assignment; + using propagate_on_container_swap = + typename std::allocator_traits::propagate_on_container_swap; + using is_always_equal = + typename std::allocator_traits::is_always_equal; + + template struct rebind { + typedef usm_allocator other; + }; + + usm_allocator() : _impl(dpct::get_default_queue()) {} + ~usm_allocator() {} + usm_allocator(const usm_allocator &other) : _impl(other._impl) {} + usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {} + pointer address(reference r) { return &r; } + const_pointer address(const_reference r) { return &r; } + pointer allocate(size_type cnt, const_void_pointer hint = nullptr) { + return std::allocator_traits::allocate(_impl, cnt, hint); + } + void deallocate(pointer p, size_type cnt) { + std::allocator_traits::deallocate(_impl, p, cnt); + } + size_type max_size() const { + return std::allocator_traits::max_size(_impl); + } + bool operator==(const usm_allocator &other) const { return _impl == other._impl; } + bool operator!=(const usm_allocator &other) const { return _impl != other._impl; } +}; + +} // namespace deprecated + +inline void dpct_free(void *ptr, + const sycl::queue &q) { + if (ptr) { +#ifdef DPCT_USM_LEVEL_NONE + detail::mem_mgr::instance().mem_free(ptr); +#else + sycl::free(ptr, q.get_context()); +#endif // DPCT_USM_LEVEL_NONE + } +} +} // namespace detail + +#ifdef DPCT_USM_LEVEL_NONE +/// Check if the pointer \p ptr represents device pointer or not. +/// +/// \param ptr The pointer to be checked. +/// \returns true if \p ptr is a device pointer. +template +static inline bool is_device_ptr(T ptr) { + if constexpr (std::is_pointer::value) { + return detail::mem_mgr::instance().is_device_ptr(ptr); + } + return false; +} +#endif + +/// Get the buffer and the offset of a piece of memory pointed to by \p ptr. +/// +/// \param ptr Pointer to a piece of memory. +/// If NULL is passed as an argument, an exception will be thrown. +/// \returns a pair containing both the buffer and the offset. 
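A hedged sketch of how the deprecated allocator above is meant to be used: plugged into a legacy container through the `usm_host_allocator` alias declared further down in this header, so the container's storage is USM host memory bound to the default dpct queue. The include path is an assumption about how this patch is consumed.

    #include <vector>
    // Assumption: the dpct headers added by this patch are on the include path.

    using host_vec = std::vector<float, dpct::deprecated::usm_host_allocator<float>>;

    void fill_staging() {
      host_vec staging(1024, 0.0f);   // backed by USM host memory on the default queue
      // staging.data() can be dereferenced directly by kernels submitted to
      // dpct::get_default_queue() on devices that expose USM host access.
    }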
+static std::pair get_buffer_and_offset(const void *ptr) { + if (ptr) { + auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); + size_t offset = (byte_t *)ptr - alloc.alloc_ptr; + return std::make_pair(alloc.buffer, offset); + } else { + throw std::runtime_error( + "NULL pointer argument in get_buffer_and_offset function is invalid"); + } +} + +/// Get the data pointed from \p ptr as a 1D buffer reinterpreted as type T. +template static sycl::buffer get_buffer(const void *ptr) { + if (!ptr) + return sycl::buffer(sycl::range<1>(0)); + auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); + return alloc.buffer.reinterpret( + sycl::range<1>(alloc.size / sizeof(T))); +} + +/// Get the buffer of a piece of memory pointed to by \p ptr. +/// +/// \param ptr Pointer to a piece of memory. +/// \returns the buffer. +static buffer_t get_buffer(const void *ptr) { + return detail::mem_mgr::instance().translate_ptr(ptr).buffer; +} + +/// A wrapper class contains an accessor and an offset. +template +class access_wrapper { + sycl::accessor accessor; + size_t offset; + +public: + /// Construct the accessor wrapper for memory pointed by \p ptr. + /// + /// \param ptr Pointer to memory. + /// \param cgh The command group handler. + access_wrapper(const void *ptr, sycl::handler &cgh) + : accessor(get_buffer(ptr).get_access(cgh)), offset(0) { + auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); + offset = (byte_t *)ptr - alloc.alloc_ptr; + } + + /// Get the device pointer. + /// + /// \returns a device pointer with offset. + dataT get_raw_pointer() const { return (dataT)(&accessor[0] + offset); } +}; + +/// Get the accessor for memory pointed by \p ptr. +/// +/// \param ptr Pointer to memory. +/// If NULL is passed as an argument, an exception will be thrown. +/// \param cgh The command group handler. +/// \returns an accessor. +template +static sycl::accessor +get_access(const void *ptr, sycl::handler &cgh) { + if (ptr) { + auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); + return alloc.buffer.get_access(cgh); + } else { + throw std::runtime_error( + "NULL pointer argument in get_access function is invalid"); + } +} + +/// Allocate memory block on the device. +/// \param num_bytes Number of bytes to allocate. +/// \param q Queue to execute the allocate task. +/// \returns A pointer to the newly allocated memory. +template +static inline void *dpct_malloc(T num_bytes, + sycl::queue &q = get_default_queue()) { + return detail::dpct_malloc(static_cast(num_bytes), q); +} + +/// Get the host pointer from a buffer that is mapped to virtual pointer ptr. +/// \param ptr Virtual Pointer mapped to device buffer +/// \returns A host pointer +template static inline T *get_host_ptr(const void *ptr) { + auto BufferOffset = get_buffer_and_offset(ptr); + auto host_ptr = + BufferOffset.first.get_host_access() + .get_pointer(); + return (T *)(host_ptr + BufferOffset.second); +} + +/// Allocate memory block for 3D array on the device. +/// \param size Size of the memory block, in bytes. +/// \param q Queue to execute the allocate task. +/// \returns A pitched_data object which stores the memory info. +static inline pitched_data +dpct_malloc(sycl::range<3> size, sycl::queue &q = get_default_queue()) { + pitched_data pitch(nullptr, 0, size.get(0), size.get(1)); + size_t pitch_size; + pitch.set_data_ptr(detail::dpct_malloc(pitch_size, size.get(0), size.get(1), + size.get(2), q)); + pitch.set_pitch(pitch_size); + return pitch; +} + +/// Allocate memory block for 2D array on the device. 
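When the header is built with DPCT_USM_LEVEL_NONE, pointers handed out by `dpct_malloc` are virtual and must be resolved to accessors inside a command group. A hedged fragment of the intended `access_wrapper` pattern; the queue `q` and the virtual pointer `v_ptr` are assumptions from the surrounding code:

    q.submit([&](sycl::handler &cgh) {
      // Wrap the virtual pointer; the wrapper resolves it to a buffer plus offset.
      dpct::access_wrapper<float *> wrap(v_ptr, cgh);
      cgh.single_task([=] {
        float *p = wrap.get_raw_pointer();   // device-usable pointer
        p[0] = 1.0f;
      });
    });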
+/// \param [out] pitch Aligned size of x in bytes. +/// \param x Range in dim x. +/// \param y Range in dim y. +/// \param q Queue to execute the allocate task. +/// \returns A pointer to the newly allocated memory. +static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, + sycl::queue &q = get_default_queue()) { + return detail::dpct_malloc(pitch, x, y, 1, q); +} + +/// free +/// \param ptr Point to free. +/// \param q Queue to execute the free task. +/// \returns no return value. +static inline void dpct_free(void *ptr, + sycl::queue &q = get_default_queue()) { + detail::dpct_free(ptr, q); +} + +/// Free the device memory pointed by a batch of pointers in \p pointers which +/// are related to \p q after \p events completed. +/// +/// \param pointers The pointers point to the device memory requested to be freed. +/// \param events The events to be waited. +/// \param q The sycl::queue the memory relates to. +inline void async_dpct_free(const std::vector &pointers, + const std::vector &events, + sycl::queue &q = get_default_queue()) { + q.submit([&](sycl::handler &cgh) { + cgh.depends_on(events); + cgh.host_task([=] { + for (auto p : pointers) + if (p) { + detail::dpct_free(p, q); + } + }); + }); +} + +/// Synchronously copies \p size bytes from the address specified by \p from_ptr +/// to the address specified by \p to_ptr. The value of \p direction is used to +/// set the copy direction, it can be \a host_to_host, \a host_to_device, +/// \a device_to_host, \a device_to_device or \a automatic. The function will +/// return after the copy is completed. +/// +/// \param to_ptr Pointer to destination memory address. +/// \param from_ptr Pointer to source memory address. +/// \param size Number of bytes to be copied. +/// \param direction Direction of the copy. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static void dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction = automatic, + sycl::queue &q = get_default_queue()) { + detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction).wait(); +} + +/// Asynchronously copies \p size bytes from the address specified by \p +/// from_ptr to the address specified by \p to_ptr. The value of \p direction is +/// used to set the copy direction, it can be \a host_to_host, \a +/// host_to_device, \a device_to_host, \a device_to_device or \a automatic. The +/// return of the function does NOT guarantee the copy is completed. +/// +/// \param to_ptr Pointer to destination memory address. +/// \param from_ptr Pointer to source memory address. +/// \param size Number of bytes to be copied. +/// \param direction Direction of the copy. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction = automatic, + sycl::queue &q = dpct::get_default_queue()) { + detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction); +} + +/// Synchronously copies 2D matrix specified by \p x and \p y from the address +/// specified by \p from_ptr to the address specified by \p to_ptr, while \p +/// from_pitch and \p to_pitch are the range of dim x in bytes of the matrix +/// specified by \p from_ptr and \p to_ptr. The value of \p direction is used to +/// set the copy direction, it can be \a host_to_host, \a host_to_device, \a +/// device_to_host, \a device_to_device or \a automatic. The function will +/// return after the copy is completed. 
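A hedged end-to-end sketch of the allocation and copy wrappers above. The header name in the include is an assumption about how this patch is laid out; everything else uses only APIs defined in this file.

    #include <sycl/sycl.hpp>
    #include "dpct/memory.h"   // assumption: header name as added by this patch

    void roundtrip() {
      sycl::queue &q = dpct::get_default_queue();
      constexpr size_t n = 256;
      float src[n], dst[n] = {};
      for (size_t i = 0; i < n; ++i) src[i] = static_cast<float>(i);

      // Device (or virtual, under DPCT_USM_LEVEL_NONE) allocation of n floats.
      void *dev = dpct::dpct_malloc(n * sizeof(float), q);

      // 'automatic' lets deduce_memcpy_direction pick the direction from the
      // pointer attributes; an explicit direction works equally well.
      dpct::dpct_memcpy(dev, src, n * sizeof(float), dpct::automatic, q);
      dpct::dpct_memcpy(dst, dev, n * sizeof(float), dpct::device_to_host, q);

      dpct::dpct_free(dev, q);
    }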
+/// +/// \param to_ptr Pointer to destination memory address. +/// \param to_pitch Range of dim x in bytes of destination matrix. +/// \param from_ptr Pointer to source memory address. +/// \param from_pitch Range of dim x in bytes of source matrix. +/// \param x Range of dim x of matrix to be copied. +/// \param y Range of dim y of matrix to be copied. +/// \param direction Direction of the copy. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static inline void dpct_memcpy(void *to_ptr, size_t to_pitch, + const void *from_ptr, size_t from_pitch, + size_t x, size_t y, + memcpy_direction direction = automatic, + sycl::queue &q = dpct::get_default_queue()) { + sycl::event::wait(detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, + from_pitch, x, y, direction)); +} + +/// Asynchronously copies 2D matrix specified by \p x and \p y from the address +/// specified by \p from_ptr to the address specified by \p to_ptr, while \p +/// \p from_pitch and \p to_pitch are the range of dim x in bytes of the matrix +/// specified by \p from_ptr and \p to_ptr. The value of \p direction is used to +/// set the copy direction, it can be \a host_to_host, \a host_to_device, \a +/// device_to_host, \a device_to_device or \a automatic. The return of the +/// function does NOT guarantee the copy is completed. +/// +/// \param to_ptr Pointer to destination memory address. +/// \param to_pitch Range of dim x in bytes of destination matrix. +/// \param from_ptr Pointer to source memory address. +/// \param from_pitch Range of dim x in bytes of source matrix. +/// \param x Range of dim x of matrix to be copied. +/// \param y Range of dim y of matrix to be copied. +/// \param direction Direction of the copy. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static inline void +async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr, + size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic, + sycl::queue &q = get_default_queue()) { + detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y, + direction); +} + +/// Synchronously copies a subset of a 3D matrix specified by \p to to another +/// 3D matrix specified by \p from. The from and to position info are specified +/// by \p from_pos and \p to_pos The copied matrix size is specified by \p size. +/// The value of \p direction is used to set the copy direction, it can be \a +/// host_to_host, \a host_to_device, \a device_to_host, \a device_to_device or +/// \a automatic. The function will return after the copy is completed. +/// +/// \param to Destination matrix info. +/// \param to_pos Position of destination. +/// \param from Source matrix info. +/// \param from_pos Position of destination. +/// \param size Range of the submatrix to be copied. +/// \param direction Direction of the copy. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static inline void dpct_memcpy(pitched_data to, sycl::id<3> to_pos, + pitched_data from, sycl::id<3> from_pos, + sycl::range<3> size, + memcpy_direction direction = automatic, + sycl::queue &q = dpct::get_default_queue()) { + sycl::event::wait( + detail::dpct_memcpy(q, to, to_pos, from, from_pos, size, direction)); +} + +/// Asynchronously copies a subset of a 3D matrix specified by \p to to another +/// 3D matrix specified by \p from. The from and to position info are specified +/// by \p from_pos and \p to_pos The copied matrix size is specified by \p size. 
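A hedged fragment pairing the pitched 2D allocation with the 2D `dpct_memcpy` overload above (assumes `<vector>` and this header are included). Note that `x` and the pitches are byte counts while `y` is a row count; a tightly packed host matrix simply passes its row width as its pitch.

    const size_t width = 64, height = 128;
    const size_t width_bytes = width * sizeof(float);
    std::vector<float> host(width * height, 1.0f);

    size_t dev_pitch = 0;
    void *dev = dpct::dpct_malloc(dev_pitch, width_bytes, height);  // dev_pitch >= width_bytes

    dpct::dpct_memcpy(dev, dev_pitch, host.data(), width_bytes,
                      width_bytes, height, dpct::host_to_device);
    dpct::dpct_free(dev);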
+/// The value of \p direction is used to set the copy direction, it can be \a +/// host_to_host, \a host_to_device, \a device_to_host, \a device_to_device or +/// \a automatic. The return of the function does NOT guarantee the copy is +/// completed. +/// +/// \param to Destination matrix info. +/// \param to_pos Position of destination. +/// \param from Source matrix info. +/// \param from_pos Position of destination. +/// \param size Range of the submatrix to be copied. +/// \param direction Direction of the copy. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static inline void +async_dpct_memcpy(pitched_data to, sycl::id<3> to_pos, pitched_data from, + sycl::id<3> from_pos, sycl::range<3> size, + memcpy_direction direction = automatic, + sycl::queue &q = get_default_queue()) { + detail::dpct_memcpy(q, to, to_pos, from, from_pos, size, direction); +} +/** + * @brief Sets 1 byte data \p value to the first \p size elements starting from + * \p dev_ptr in \p q synchronously. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @param [in] q The queue in which the operation is done. + */ +static void dpct_memset(void *dev_ptr, int value, size_t size, + sycl::queue &q = get_default_queue()) { + detail::dpct_memset(q, dev_ptr, value, size).wait(); +} + +/** + * @brief Sets 2 bytes data \p value to the first \p size elements starting from + * \p dev_ptr in \p q synchronously. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @param [in] q The queue in which the operation is done. + */ +static void dpct_memset_d16(void *dev_ptr, unsigned short value, size_t size, + sycl::queue &q = get_default_queue()) { + detail::dpct_memset(q, dev_ptr, value, size).wait(); +} +/** + * @brief Sets 4 bytes data \p value to the first \p size elements starting from + * \p dev_ptr in \p q synchronously. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @param [in] q The queue in which the operation is done. + */ +static void dpct_memset_d32(void *dev_ptr, unsigned int value, size_t size, + sycl::queue &q = get_default_queue()) { + detail::dpct_memset(q, dev_ptr, value, size).wait(); +} + +/** + * @brief Sets 1 byte data \p value to the first \p size elements starting from + * \p dev_ptr in \p q asynchronously. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @param [in] q The queue in which the operation is done. + */ +static void async_dpct_memset(void *dev_ptr, int value, size_t size, + sycl::queue &q = dpct::get_default_queue()) { + detail::dpct_memset(q, dev_ptr, value, size); +} +/** + * @brief Sets 2 bytes data \p value to the first \p size elements starting from + * \p dev_ptr in \p q asynchronously. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @param [in] q The queue in which the operation is done. 
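One point worth illustrating about the memset family above: the `_d16`/`_d32` variants take a count of 2-byte/4-byte elements, whereas the plain overload counts single bytes. A hedged fragment:

    constexpr size_t n = 1024;                  // number of 32-bit elements
    void *buf = dpct::dpct_malloc(n * sizeof(unsigned int));

    dpct::dpct_memset(buf, 0, n * sizeof(unsigned int));  // byte count: 4096 bytes of 0x00
    dpct::dpct_memset_d32(buf, 0x3f800000u, n);           // element count: n copies of the 1.0f bit pattern

    dpct::dpct_free(buf);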
+ */ +static void async_dpct_memset_d16(void *dev_ptr, unsigned short value, size_t size, + sycl::queue &q = dpct::get_default_queue()) { + detail::dpct_memset(q, dev_ptr, value, size); +} +/** + * @brief Sets 4 bytes data \p value to the first \p size elements starting from + * \p dev_ptr in \p q asynchronously. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @param [in] q The queue in which the operation is done. + */ +static void async_dpct_memset_d32(void *dev_ptr, unsigned int value, size_t size, + sycl::queue &q = dpct::get_default_queue()) { + detail::dpct_memset(q, dev_ptr, value, size); +} + +/** + * @brief Sets 1 byte data \p val to the pitched 2D memory region pointed by \p ptr in \p q + * synchronously. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @param [in] q The queue in which the operation is done. + */ +static inline void dpct_memset(void *ptr, size_t pitch, int val, size_t x, + size_t y, + sycl::queue &q = get_default_queue()) { + sycl::event::wait(detail::dpct_memset(q, ptr, pitch, val, x, y)); +} +/** + * @brief Sets 2 bytes data \p val to the pitched 2D memory region pointed by \p ptr in \p q + * synchronously. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @param [in] q The queue in which the operation is done. + */ +static inline void dpct_memset_d16(void *ptr, size_t pitch, unsigned short val, size_t x, + size_t y, + sycl::queue &q = get_default_queue()) { + sycl::event::wait(detail::dpct_memset(q, ptr, pitch, val, x, y)); +} +/** + * @brief Sets 4 bytes data \p val to the pitched 2D memory region pointed by \p ptr in \p q + * synchronously. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @param [in] q The queue in which the operation is done. + */ +static inline void dpct_memset_d32(void *ptr, size_t pitch, unsigned int val, size_t x, + size_t y, + sycl::queue &q = get_default_queue()) { + sycl::event::wait(detail::dpct_memset(q, ptr, pitch, val, x, y)); +} + +/** + * @brief Sets 1 byte data \p val to the pitched 2D memory region pointed by \p ptr in \p q + * asynchronously. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @param [in] q The queue in which the operation is done. 
+ */ +static inline void async_dpct_memset(void *ptr, size_t pitch, int val, size_t x, + size_t y, + sycl::queue &q = get_default_queue()) { + detail::dpct_memset(q, ptr, pitch, val, x, y); +} + +/** + * @brief Sets 2 bytes data \p val to the pitched 2D memory region pointed by \p ptr in \p q + * asynchronously. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @param [in] q The queue in which the operation is done. + */ +static inline void async_dpct_memset_d16(void *ptr, size_t pitch, + unsigned short val, size_t x, size_t y, + sycl::queue &q = get_default_queue()) { + detail::dpct_memset(q, ptr, pitch, val, x, y); +} + +/** + * @brief Sets 4 bytes data \p val to the pitched 2D memory region pointed by \p ptr in \p q + * asynchronously. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @param [in] q The queue in which the operation is done. + */ +static inline void async_dpct_memset_d32(void *ptr, size_t pitch, + unsigned int val, size_t x, size_t y, + sycl::queue &q = get_default_queue()) { + detail::dpct_memset(q, ptr, pitch, val, x, y); +} + +/** + * @brief Sets 1 byte data \p value to the 3D memory region pointed by \p data in \p q + * synchronously. + * @param [in] data Pointer to the pitched device memory region. + * @param [in] value The value to be set. + * @param [in] size 3D memory region by number of elements. + * @param [in] q The queue in which the operation is done. + */ +static inline void dpct_memset(pitched_data pitch, int val, + sycl::range<3> size, + sycl::queue &q = get_default_queue()) { + sycl::event::wait(detail::dpct_memset(q, pitch, val, size)); +} + +/** + * @brief Sets 1 byte data \p value to the 3D memory region pointed by \p data in \p q + * asynchronously. + * @param [in] data Pointer to the pitched device memory region. + * @param [in] value The value to be set. + * @param [in] size 3D memory region by number of elements. + * @param [in] q The queue in which the operation is done. + */ +static inline void async_dpct_memset(pitched_data pitch, int val, + sycl::range<3> size, + sycl::queue &q = get_default_queue()) { + detail::dpct_memset(q, pitch, val, size); +} + +/// dpct accessor used as device function parameter. 
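The pitched overloads mirror the allocation side. A hedged fragment that zeroes a pitched 3D region allocated with the `sycl::range<3>` form of `dpct_malloc`; the extent values are illustrative, and sizes are in 1-byte elements for the plain `dpct_memset`:

    sycl::range<3> extent(512, 64, 4);                    // {x bytes, y rows, z slices}
    dpct::pitched_data surf = dpct::dpct_malloc(extent);  // pitch chosen by the allocator
    dpct::dpct_memset(surf, 0, extent);                   // synchronous byte-wise clear
    dpct::dpct_free(surf.get_data_ptr());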
+template class accessor; +template class accessor { +public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<3>; + accessor(pointer_t data, const sycl::range<3> &in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type &acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t &acc, const sycl::range<3> &in_range) + : accessor(acc.get_pointer(), in_range) {} + accessor operator[](size_t index) const { + sycl::range<2> sub(_range.get(1), _range.get(2)); + return accessor(_data + index * sub.size(), sub); + } + + pointer_t get_ptr() const { return _data; } + +private: + pointer_t _data; + sycl::range<3> _range; +}; +template class accessor { +public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<2>; + accessor(pointer_t data, const sycl::range<2> &in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type &acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t &acc, const sycl::range<2> &in_range) + : accessor(acc.get_pointer(), in_range) {} + + pointer_t operator[](size_t index) const { + return _data + _range.get(1) * index; + } + + pointer_t get_ptr() const { return _data; } + +private: + pointer_t _data; + sycl::range<2> _range; +}; + +namespace detail { +/// Device variable with address space of shared, global or constant. +template +class device_memory { +public: + using accessor_t = + typename detail::memory_traits::template accessor_t; + using value_t = typename detail::memory_traits::value_t; + using dpct_accessor_t = dpct::accessor; + + device_memory() : device_memory(sycl::range(1)) {} + + /// Constructor of 1-D array with initializer list + device_memory( + const sycl::range &in_range, + std::initializer_list &&init_list) + : device_memory(in_range) { + assert(init_list.size() <= in_range.size()); + _host_ptr = (value_t *)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T)); + } + + /// Constructor of 2-D array with initializer list + template + device_memory( + const typename std::enable_if>::type &in_range, + std::initializer_list> &&init_list) + : device_memory(in_range) { + assert(init_list.size() <= in_range[0]); + _host_ptr = (value_t *)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + auto tmp_data = _host_ptr; + for (auto sub_list : init_list) { + assert(sub_list.size() <= in_range[1]); + std::memcpy(tmp_data, sub_list.begin(), sub_list.size() * sizeof(T)); + tmp_data += in_range[1]; + } + } + + /// Constructor with range + device_memory(const sycl::range &range_in) + : _size(range_in.size() * sizeof(T)), _range(range_in), _reference(false), + _host_ptr(nullptr), _device_ptr(nullptr) { + static_assert( + (Memory == global) || (Memory == constant) || (Memory == shared), + "device memory region should be global, constant or shared"); + // Make sure that singleton class mem_mgr and dev_mgr will destruct later + // than this. + detail::mem_mgr::instance(); + dev_mgr::instance(); + } + + /// Constructor with range + template + device_memory(Args... 
Arguments) + : device_memory(sycl::range(Arguments...)) {} + + ~device_memory() { + if (_device_ptr && !_reference) + dpct::dpct_free(_device_ptr); + if (_host_ptr) + std::free(_host_ptr); + } + + /// Allocate memory with default queue, and init memory if has initial value. + void init() { + init(dpct::get_default_queue()); + } + /// Allocate memory with specified queue, and init memory if has initial value. + void init(sycl::queue &q) { + if (_device_ptr) + return; + if (!_size) + return; + allocate_device(q); + if (_host_ptr) + detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size, host_to_device); + } + + /// The variable is assigned to a device pointer. + void assign(value_t *src, size_t size) { + this->~device_memory(); + new (this) device_memory(src, size); + } + + /// Get memory pointer of the memory object, which is virtual pointer when + /// usm is not used, and device pointer when usm is used. + value_t *get_ptr() { + return get_ptr(get_default_queue()); + } + /// Get memory pointer of the memory object, which is virtual pointer when + /// usm is not used, and device pointer when usm is used. + value_t *get_ptr(sycl::queue &q) { + init(q); + return _device_ptr; + } + + /// Get the device memory object size in bytes. + size_t get_size() { return _size; } + + template + typename std::enable_if::type &operator[](size_t index) { + init(); +#ifdef DPCT_USM_LEVEL_NONE + return dpct::get_buffer::type>( + _device_ptr) + .template get_access()[index]; +#else + return _device_ptr[index]; +#endif // DPCT_USM_LEVEL_NONE + } + +#ifdef DPCT_USM_LEVEL_NONE + /// Get sycl::accessor for the device memory object when usm is not used. + accessor_t get_access(sycl::handler &cgh) { + return get_buffer(_device_ptr) + .template reinterpret(_range) + .template get_access::mode, + detail::memory_traits::target>(cgh); + } +#else + /// Get dpct::accessor with dimension info for the device memory object + /// when usm is used and dimension is greater than 1. + template + typename std::enable_if::type + get_access(sycl::handler &cgh) { + return dpct_accessor_t((T *)_device_ptr, _range); + } +#endif // DPCT_USM_LEVEL_NONE + +private: + device_memory(value_t *memory_ptr, size_t size) + : _size(size), _range(size / sizeof(T)), _reference(true), + _device_ptr(memory_ptr) {} + + void allocate_device(sycl::queue &q) { +#ifndef DPCT_USM_LEVEL_NONE + if (Memory == shared) { + _device_ptr = (value_t *)sycl::malloc_shared( + _size, q.get_device(), q.get_context()); + return; + } +#ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY + if (Memory == constant) { + _device_ptr = (value_t *)sycl::malloc_device( + _size, q.get_device(), q.get_context(), + sycl::ext::oneapi::property::usm::device_read_only()); + return; + } +#endif +#endif + _device_ptr = (value_t *)detail::dpct_malloc(_size, q); + } + + size_t _size; + sycl::range _range; + bool _reference; + value_t *_host_ptr; + value_t *_device_ptr; +}; +template +class device_memory : public device_memory { +public: + using base = device_memory; + using value_t = typename base::value_t; + using accessor_t = + typename detail::memory_traits::template accessor_t<0>; + + /// Constructor with initial value. + device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {} + + /// Default constructor + device_memory() : base(1) {} + +#ifdef DPCT_USM_LEVEL_NONE + /// Get sycl::accessor for the device memory object when usm is not used. 
+ accessor_t get_access(sycl::handler &cgh) { + auto buf = get_buffer(base::get_ptr()) + .template reinterpret(sycl::range<1>(1)); + return accessor_t(buf, cgh); + } +#endif // DPCT_USM_LEVEL_NONE +}; +} + +template +using global_memory = detail::device_memory; +template +using constant_memory = detail::device_memory; +template +using shared_memory = detail::device_memory; + +// dpct::deprecated:: is for functionality that was introduced for compatibility +// purpose, but relies on deprecated C++ features, which are either removed or +// will be removed in the future standards. +// Direct use of deprecated functionality in this namespace should be avoided. +namespace deprecated { + +template +using usm_host_allocator = detail::deprecated::usm_allocator; + +template +using usm_device_allocator = detail::deprecated::usm_allocator; +} // namespace deprecated + +class pointer_attributes { +public: + void init(const void *ptr, + sycl::queue &q = dpct::get_default_queue()) { +#ifdef DPCT_USM_LEVEL_NONE + throw std::runtime_error( + "dpct::pointer_attributes: only works for USM pointer."); +#else + memory_type = sycl::get_pointer_type(ptr, q.get_context()); + device_pointer = (memory_type != + sycl::usm::alloc::unknown) ? ptr : nullptr; + host_pointer = (memory_type != + sycl::usm::alloc::unknown) && + (memory_type != sycl::usm::alloc::device) ? ptr : nullptr; + sycl::device device_obj = sycl::get_pointer_device(ptr, q.get_context()); + device_id = dpct::dev_mgr::instance().get_device_id(device_obj); +#endif + } + + sycl::usm::alloc get_memory_type() { + return memory_type; + } + + const void *get_device_pointer() { + return device_pointer; + } + + const void *get_host_pointer() { + return host_pointer; + } + + bool is_memory_shared() { + return memory_type == sycl::usm::alloc::shared; + } + + unsigned int get_device_id() { + return device_id; + } + +private: + sycl::usm::alloc memory_type = sycl::usm::alloc::unknown; + const void *device_pointer = nullptr; + const void *host_pointer = nullptr; + unsigned int device_id = 0; +}; +} // namespace dpct +#endif // __DPCT_MEMORY_HPP__ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/rng_utils.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/rng_utils.h new file mode 100644 index 0000000..acd6c78 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/rng_utils.h @@ -0,0 +1,535 @@ +//==---- rng_utils.hpp ----------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_RNG_UTILS_HPP__ +#define __DPCT_RNG_UTILS_HPP__ + +#include +#include +#ifdef __INTEL_MKL__ // The oneMKL Interfaces Project does not support this. +#include +#endif +#include "device.h" +#include "lib_common_utils.h" + +namespace dpct { +namespace rng { +#ifdef __INTEL_MKL__ // The oneMKL Interfaces Project does not support this. +namespace device { +/// The random number generator on device. +/// \tparam engine_t The device random number generator engine. It can only be +/// oneapi::mkl::rng::device::mrg32k3a<1> or +/// oneapi::mkl::rng::device::mrg32k3a<4> or +/// oneapi::mkl::rng::device::philox4x32x10<1> or +/// oneapi::mkl::rng::device::philox4x32x10<4>. 
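A hedged sketch of how a device-side generator of this kind is typically used inside a kernel. It is a fragment: it assumes the build defines __INTEL_MKL__, and that `q`, `n` and the USM output pointer `out` come from the surrounding code.

    q.parallel_for(sycl::range<1>(n), [=](sycl::item<1> it) {
      // One independent Philox stream per work-item via the skip-ahead list.
      dpct::rng::device::rng_generator<oneapi::mkl::rng::device::philox4x32x10<1>>
          gen(/*seed=*/1234, /*num_to_skip=*/{4 * it.get_linear_id()});

      // Four Gaussian floats for this work-item; keep the first one.
      sycl::float4 r = gen.generate<oneapi::mkl::rng::device::gaussian<float>, 4>();
      out[it.get_linear_id()] = r.x();
    });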
+template class rng_generator { + static_assert( + std::disjunction_v< + std::is_same>, + std::is_same>, + std::is_same>, + std::is_same>, + std::is_same>>, + "engine_t can only be oneapi::mkl::rng::device::mrg32k3a<1> or " + "oneapi::mkl::rng::device::mrg32k3a<4> or " + "oneapi::mkl::rng::device::philox4x32x10<1> or " + "oneapi::mkl::rng::device::philox4x32x10<4> or " + "oneapi::mkl::rng::device::mcg59<1>."); + static constexpr bool _is_engine_vec_size_one = std::disjunction_v< + std::is_same>, + std::is_same>, + std::is_same>>; + static constexpr std::uint64_t default_seed = 0; + oneapi::mkl::rng::device::bits _distr_bits; + oneapi::mkl::rng::device::uniform_bits _distr_uniform_bits; + oneapi::mkl::rng::device::gaussian _distr_gaussian_float; + oneapi::mkl::rng::device::gaussian _distr_gaussian_double; + oneapi::mkl::rng::device::lognormal _distr_lognormal_float; + oneapi::mkl::rng::device::lognormal _distr_lognormal_double; + oneapi::mkl::rng::device::poisson _distr_poisson; + oneapi::mkl::rng::device::uniform _distr_uniform_float; + oneapi::mkl::rng::device::uniform _distr_uniform_double; + engine_t _engine; + +public: + /// Default constructor of rng_generator + rng_generator() { _engine = engine_t(default_seed); } + /// Constructor of rng_generator if engine type is not mcg59 + /// \param [in] seed The seed to initialize the engine state. + /// \param [in] num_to_skip Set the number of elements need to be skipped. + /// The number is calculated as: num_to_skip[0] + num_to_skip[1] * 2^64 + + /// num_to_skip[2] * 2^128 + ... + num_to_skip[n-1] * 2^(64*(n-1)) + template >>::type * = nullptr> + rng_generator(std::uint64_t seed, + std::initializer_list num_to_skip) { + _engine = engine_t(seed, num_to_skip); + } + /// Constructor of rng_generator if engine type is mcg59 + /// \param [in] seed The seed to initialize the engine state. + /// \param [in] num_to_skip Set the number of elements need to be skipped. + template >>::type * = nullptr> + rng_generator(std::uint64_t seed, std::uint64_t num_to_skip) { + _engine = engine_t(seed, num_to_skip); + } + + /// Generate random number(s) obeys distribution \tparam distr_t. + /// \tparam T The distribution of the random number. It can only be + /// oneapi::mkl::rng::device::bits, + /// oneapi::mkl::rng::device::uniform_bits, + /// oneapi::mkl::rng::device::gaussian, + /// oneapi::mkl::rng::device::gaussian, + /// oneapi::mkl::rng::device::lognormal, + /// oneapi::mkl::rng::device::lognormal, + /// oneapi::mkl::rng::device::poisson, + /// oneapi::mkl::rng::device::uniform or + /// oneapi::mkl::rng::device::uniform + /// \tparam vec_size The length of the return vector. It can only be 1, 2 + /// or 4. + /// \param distr_params The parameter(s) for lognormal or poisson + /// distribution. + /// \return The vector of the random number(s). + template + auto generate(distr_params_t... 
distr_params) { + static_assert(vec_size == 1 || vec_size == 2 || vec_size == 4, + "vec_size is not supported."); + static_assert( + std::disjunction_v< + std::is_same>, + std::is_same>, + std::is_same>, + std::is_same>, + std::is_same>, + std::is_same>, + std::is_same>, + std::is_same>, + std::is_same>>, + "distribution is not supported."); + + if constexpr (std::is_same_v< + distr_t, oneapi::mkl::rng::device::bits>) { + return generate_vec(_distr_bits); + } + if constexpr (std::is_same_v< + distr_t, + oneapi::mkl::rng::device::uniform_bits>) { + return generate_vec(_distr_uniform_bits); + } + if constexpr (std::is_same_v>) { + return generate_vec(_distr_gaussian_float); + } + if constexpr (std::is_same_v>) { + return generate_vec(_distr_gaussian_double); + } + if constexpr (std::is_same_v>) { + return generate_vec(_distr_lognormal_float, distr_params..., + 0.0f, 1.0f); + } + if constexpr (std::is_same_v>) { + return generate_vec(_distr_lognormal_double, distr_params..., + 0.0, 1.0); + } + if constexpr (std::is_same_v>) { + return generate_vec(_distr_poisson, distr_params...); + } + if constexpr (std::is_same_v>) { + return generate_vec(_distr_uniform_float); + } + if constexpr (std::is_same_v>) { + return generate_vec(_distr_uniform_double); + } + } + + /// Get the random number generator engine. + /// \return The reference of the internal random number generator engine. + engine_t &get_engine() { return _engine; } + +private: + template + auto generate_vec(distr_t &distr, distr_params_t... distr_params) { + if constexpr (sizeof...(distr_params_t)) { + typename distr_t::param_type pt(distr_params...); + distr.param(pt); + } + if constexpr (vec_size == 4) { + if constexpr (_is_engine_vec_size_one) { + sycl::vec res; + res.x() = oneapi::mkl::rng::device::generate(distr, _engine); + res.y() = oneapi::mkl::rng::device::generate(distr, _engine); + res.z() = oneapi::mkl::rng::device::generate(distr, _engine); + res.w() = oneapi::mkl::rng::device::generate(distr, _engine); + return res; + } else { + return oneapi::mkl::rng::device::generate(distr, _engine); + } + } else if constexpr (vec_size == 1) { + if constexpr (_is_engine_vec_size_one) { + return oneapi::mkl::rng::device::generate(distr, _engine); + } else { + return oneapi::mkl::rng::device::generate_single(distr, _engine); + } + } else if constexpr (vec_size == 2) { + if constexpr (_is_engine_vec_size_one) { + sycl::vec res; + res.x() = oneapi::mkl::rng::device::generate(distr, _engine); + res.y() = oneapi::mkl::rng::device::generate(distr, _engine); + return res; + } else { + sycl::vec res; + res.x() = oneapi::mkl::rng::device::generate_single(distr, _engine); + res.y() = oneapi::mkl::rng::device::generate_single(distr, _engine); + return res; + } + } + } +}; + +} // namespace device +#endif + +namespace host { +namespace detail { +class rng_generator_base { +public: + /// Set the seed of host rng_generator. + /// \param seed The engine seed. + virtual void set_seed(const std::uint64_t seed) = 0; + + /// Set the dimensions of host rng_generator. + /// \param dimensions The engine dimensions. + virtual void set_dimensions(const std::uint32_t dimensions) = 0; + + /// Set the queue of host rng_generator. + /// \param queue The engine queue. + virtual void set_queue(sycl::queue *queue) = 0; + + /// Generate unsigned int random number(s) with 'uniform_bits' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. 
+ virtual inline void generate_uniform_bits(unsigned int *output, + std::int64_t n) = 0; + + /// Generate unsigned long long random number(s) with 'uniform_bits' + /// distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + virtual inline void generate_uniform_bits(unsigned long long *output, + std::int64_t n) = 0; + + /// Generate float random number(s) with 'lognormal' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + /// \param m Mean of associated normal distribution + /// \param s Standard deviation of associated normal distribution. + virtual inline void generate_lognormal(float *output, std::int64_t n, float m, + float s) = 0; + + /// Generate double random number(s) with 'lognormal' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + /// \param m Mean of associated normal distribution + /// \param s Standard deviation of associated normal distribution. + virtual inline void generate_lognormal(double *output, std::int64_t n, + double m, double s) = 0; + + /// Generate float random number(s) with 'gaussian' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + /// \param mean Mean of normal distribution + /// \param stddev Standard deviation of normal distribution. + virtual inline void generate_gaussian(float *output, std::int64_t n, + float mean, float stddev) = 0; + + /// Generate double random number(s) with 'gaussian' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + /// \param mean Mean of normal distribution + /// \param stddev Standard deviation of normal distribution. + virtual inline void generate_gaussian(double *output, std::int64_t n, + double mean, double stddev) = 0; + + /// Generate unsigned int random number(s) with 'poisson' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + /// \param lambda Lambda for the Poisson distribution. + virtual inline void generate_poisson(unsigned int *output, std::int64_t n, + double lambda) = 0; + + /// Generate float random number(s) with 'uniform' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + virtual inline void generate_uniform(float *output, std::int64_t n) = 0; + + /// Generate double random number(s) with 'uniform' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + virtual inline void generate_uniform(double *output, std::int64_t n) = 0; + + /// Skip ahead several random number(s). + /// \param num_to_skip The number of random numbers to be skipped. + virtual void skip_ahead(const std::uint64_t num_to_skip) = 0; + + /// Set the direction numbers of host rng_generator. Only Sobol engine + /// supports this method. + /// \param direction_numbers The engine direction numbers. + virtual void set_direction_numbers( + const std::vector &direction_numbers) = 0; + +protected: + sycl::queue *_queue{&dpct::get_default_queue()}; + std::uint64_t _seed{0}; + std::uint32_t _dimensions{1}; + std::vector _direction_numbers; +}; + +/// The random number generator on host. +template +class rng_generator : public rng_generator_base { +public: + /// Constructor of rng_generator. 
+ rng_generator() : _engine(create_engine(_queue, _seed, _dimensions)) {} + + /// Set the seed of host rng_generator. + /// \param seed The engine seed. + void set_seed(const std::uint64_t seed) { + if (seed == _seed) { + return; + } + _seed = seed; + _engine = create_engine(_queue, _seed, _dimensions); + } + + /// Set the dimensions of host rng_generator. + /// \param dimensions The engine dimensions. + void set_dimensions(const std::uint32_t dimensions) { + if (dimensions == _dimensions) { + return; + } + _dimensions = dimensions; + _engine = create_engine(_queue, _seed, _dimensions); + } + + /// Set the queue of host rng_generator. + /// \param queue The engine queue. + void set_queue(sycl::queue *queue) { + if (queue == _queue) { + return; + } + _queue = queue; + _engine = create_engine(_queue, _seed, _dimensions); + } + + /// Set the direction numbers of Sobol host rng_generator. + /// \param direction_numbers The user-defined direction numbers. + void + set_direction_numbers(const std::vector &direction_numbers) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) " + "Interfaces Project does not support this API."); +#else + if constexpr (std::is_same_v) { + if (direction_numbers == _direction_numbers) { + return; + } + _direction_numbers = direction_numbers; + _engine = oneapi::mkl::rng::sobol(*_queue, _direction_numbers); + } else { + throw std::runtime_error("Only Sobol engine supports this method."); + } +#endif + } + + /// Generate unsigned int random number(s) with 'uniform_bits' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + inline void generate_uniform_bits(unsigned int *output, std::int64_t n) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) " + "Interfaces Project does not support this API."); +#else + static_assert(sizeof(unsigned int) == sizeof(std::uint32_t)); + generate>( + (std::uint32_t *)output, n); +#endif + } + + /// Generate unsigned long long random number(s) with 'uniform_bits' + /// distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + inline void generate_uniform_bits(unsigned long long *output, + std::int64_t n) { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) " + "Interfaces Project does not support this API."); +#else + static_assert(sizeof(unsigned long long) == sizeof(std::uint64_t)); + generate>( + (std::uint64_t *)output, n); +#endif + } + + /// Generate float random number(s) with 'lognormal' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + /// \param m Mean of associated normal distribution + /// \param s Standard deviation of associated normal distribution. + inline void generate_lognormal(float *output, std::int64_t n, float m, + float s) { + generate>(output, n, m, s); + } + + /// Generate double random number(s) with 'lognormal' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + /// \param m Mean of associated normal distribution + /// \param s Standard deviation of associated normal distribution. + inline void generate_lognormal(double *output, std::int64_t n, double m, + double s) { + generate>(output, n, m, s); + } + + /// Generate float random number(s) with 'gaussian' distribution. 
+ /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + /// \param mean Mean of normal distribution + /// \param stddev Standard deviation of normal distribution. + inline void generate_gaussian(float *output, std::int64_t n, float mean, + float stddev) { + generate>(output, n, mean, stddev); + } + + /// Generate double random number(s) with 'gaussian' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + /// \param mean Mean of normal distribution + /// \param stddev Standard deviation of normal distribution. + inline void generate_gaussian(double *output, std::int64_t n, double mean, + double stddev) { + generate>(output, n, mean, stddev); + } + + /// Generate unsigned int random number(s) with 'poisson' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + /// \param lambda Lambda for the Poisson distribution. + inline void generate_poisson(unsigned int *output, std::int64_t n, + double lambda) { + generate>(output, n, lambda); + } + + /// Generate float random number(s) with 'uniform' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + inline void generate_uniform(float *output, std::int64_t n) { + generate>(output, n); + } + + /// Generate double random number(s) with 'uniform' distribution. + /// \param output The pointer of the first random number. + /// \param n The number of random numbers. + inline void generate_uniform(double *output, std::int64_t n) { + generate>(output, n); + } + + /// Skip ahead several random number(s). + /// \param num_to_skip The number of random numbers to be skipped. + void skip_ahead(const std::uint64_t num_to_skip) { +#ifndef __INTEL_MKL__ + oneapi::mkl::rng::skip_ahead(_engine, num_to_skip); +#else + if constexpr (std::is_same_v) + throw std::runtime_error("no skip_ahead method of mt2203 engine."); + else + oneapi::mkl::rng::skip_ahead(_engine, num_to_skip); +#endif + } + +private: + static inline engine_t create_engine(sycl::queue *queue, + const std::uint64_t seed, + const std::uint32_t dimensions) { +#ifdef __INTEL_MKL__ + return std::is_same_v + ? engine_t(*queue, dimensions) + : engine_t(*queue, seed); +#else + return engine_t(*queue, seed); +#endif + } + + template + void generate(buffer_t *output, const std::int64_t n, + const distr_params_t... distr_params) { + auto output_buf = dpct::detail::get_memory(output); + oneapi::mkl::rng::generate(distr_t(distr_params...), _engine, n, + output_buf); + } + engine_t _engine{}; +}; +} // namespace detail +} // namespace host + +enum class random_engine_type { + philox4x32x10, + mrg32k3a, + mt2203, + mt19937, + sobol, + mcg59 +}; + +typedef std::shared_ptr host_rng_ptr; + +/// Create a host random number generator. +/// \param type The random engine type. +/// \return The pointer of random number generator. 
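A hedged host-side sketch using this factory together with the generator interface defined earlier in the file; the output pointer is assumed to be USM device memory on the same queue.

    sycl::queue &q = dpct::get_default_queue();
    dpct::rng::host_rng_ptr gen =
        dpct::rng::create_host_rng(dpct::rng::random_engine_type::philox4x32x10);
    gen->set_queue(&q);
    gen->set_seed(2024);

    float *out = sycl::malloc_device<float>(1 << 20, q);
    gen->generate_uniform(out, 1 << 20);   // enqueued on q; not yet complete
    q.wait();
    sycl::free(out, q);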
+inline host_rng_ptr create_host_rng(const random_engine_type type) { + switch (type) { + case random_engine_type::philox4x32x10: + return std::make_shared< + rng::host::detail::rng_generator>(); + case random_engine_type::mrg32k3a: + return std::make_shared< + rng::host::detail::rng_generator>(); +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) " + "Interfaces Project does not support this API."); +#else + case random_engine_type::mt2203: + return std::make_shared< + rng::host::detail::rng_generator>(); + case random_engine_type::mt19937: + return std::make_shared< + rng::host::detail::rng_generator>(); + case random_engine_type::sobol: + return std::make_shared< + rng::host::detail::rng_generator>(); + case random_engine_type::mcg59: + return std::make_shared< + rng::host::detail::rng_generator>(); +#endif + } +} +} // namespace rng +} // namespace dpct + +#endif // __DPCT_RNG_UTILS_HPP__ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/util.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/util.h new file mode 100644 index 0000000..d916c59 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/dpct/util.h @@ -0,0 +1,1030 @@ +//==---- util.hpp ---------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef __DPCT_UTIL_HPP__ +#define __DPCT_UTIL_HPP__ + +#include +#include +#include +#include +#include + +// TODO: Remove these function definitions once they exist in the DPC++ compiler +#if defined(__SYCL_DEVICE_ONLY__) && defined(__INTEL_LLVM_COMPILER) +template +__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT __attribute__((noduplicate)) +T __spirv_GroupNonUniformShuffle(__spv::Scope::Flag, T, unsigned) noexcept; + +template +__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT __attribute__((noduplicate)) +T __spirv_GroupNonUniformShuffleDown(__spv::Scope::Flag, T, unsigned) noexcept; + +template +__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT __attribute__((noduplicate)) +T __spirv_GroupNonUniformShuffleUp(__spv::Scope::Flag, T, unsigned) noexcept; +#endif + +namespace dpct { + +namespace detail { + +template class generic_error_type { +public: + generic_error_type() = default; + generic_error_type(T value) : value{value} {} + operator T() const { return value; } + +private: + T value; +}; + +} // namespace detail + +using err0 = detail::generic_error_type; +using err1 = detail::generic_error_type; + +template struct integer_sequence {}; +template +struct make_index_sequence + : public make_index_sequence {}; +template +struct make_index_sequence<0, Ints...> : public integer_sequence {}; + +template struct DataType { using T2 = T; }; +template struct DataType> { + using T2 = std::complex; +}; + +inline void matrix_mem_copy(void *to_ptr, const void *from_ptr, int to_ld, + int from_ld, int rows, int cols, int elem_size, + memcpy_direction direction = automatic, + sycl::queue &queue = dpct::get_default_queue(), + bool async = false) { + if (to_ptr == from_ptr && to_ld == from_ld) { + return; + } + + if (to_ld == from_ld) { + size_t copy_size = elem_size * ((cols - 1) * (size_t)to_ld + rows); + if (async) + detail::dpct_memcpy(queue, (void *)to_ptr, (void *)from_ptr, + copy_size, direction); + else + 
detail::dpct_memcpy(queue, (void *)to_ptr, (void *)from_ptr, + copy_size, direction).wait(); + } else { + if (async) + detail::dpct_memcpy(queue, to_ptr, from_ptr, elem_size * to_ld, + elem_size * from_ld, elem_size * rows, cols, + direction); + else + sycl::event::wait(detail::dpct_memcpy( + queue, to_ptr, from_ptr, elem_size * to_ld, elem_size * from_ld, + elem_size * rows, cols, direction)); + } +} + +/// Copy matrix data. The default leading dimension is column. +/// \param [out] to_ptr A pointer points to the destination location. +/// \param [in] from_ptr A pointer points to the source location. +/// \param [in] to_ld The leading dimension the destination matrix. +/// \param [in] from_ld The leading dimension the source matrix. +/// \param [in] rows The number of rows of the source matrix. +/// \param [in] cols The number of columns of the source matrix. +/// \param [in] direction The direction of the data copy. +/// \param [in] queue The queue where the routine should be executed. +/// \param [in] async If this argument is true, the return of the function +/// does NOT guarantee the copy is completed. +template +inline void matrix_mem_copy(T *to_ptr, const T *from_ptr, int to_ld, + int from_ld, int rows, int cols, + memcpy_direction direction = automatic, + sycl::queue &queue = dpct::get_default_queue(), + bool async = false) { + using Ty = typename DataType::T2; + matrix_mem_copy((void *)to_ptr, (void *)from_ptr, to_ld, from_ld, rows, cols, + sizeof(Ty), direction, queue, async); +} + +/// Cast the high or low 32 bits of a double to an integer. +/// \param [in] d The double value. +/// \param [in] use_high32 Cast the high 32 bits of the double if true; +/// otherwise cast the low 32 bits. +inline int cast_double_to_int(double d, bool use_high32 = true) { + sycl::vec v0{d}; + auto v1 = v0.as(); + if (use_high32) + return v1[1]; + return v1[0]; +} + +/// Combine two integers, the first as the high 32 bits and the second +/// as the low 32 bits, into a double. +/// \param [in] high32 The integer as the high 32 bits +/// \param [in] low32 The integer as the low 32 bits +inline double cast_ints_to_double(int high32, int low32) { + sycl::int2 v0{low32, high32}; + auto v1 = v0.as>(); + return v1; +} + +/// Reverse the bit order of an unsigned integer +/// \param [in] a Input unsigned integer value +/// \returns Value of a with the bit order reversed +template inline T reverse_bits(T a) { + static_assert(std::is_unsigned::value && std::is_integral::value, + "unsigned integer required"); + if (!a) + return 0; + T mask = 0; + size_t count = 4 * sizeof(T); + mask = ~mask >> count; + while (count) { + a = ((a & mask) << count) | ((a & ~mask) >> count); + count = count >> 1; + mask = mask ^ (mask << count); + } + return a; +} + +/// \param [in] a The first value contains 4 bytes +/// \param [in] b The second value contains 4 bytes +/// \param [in] s The selector value, only lower 16bit used +/// \returns the permutation result of 4 bytes selected in the way +/// specified by \p s from \p a and \p b +inline unsigned int byte_level_permute(unsigned int a, unsigned int b, + unsigned int s) { + unsigned int ret; + ret = + ((((std::uint64_t)b << 32 | a) >> (s & 0x7) * 8) & 0xff) | + (((((std::uint64_t)b << 32 | a) >> ((s >> 4) & 0x7) * 8) & 0xff) << 8) | + (((((std::uint64_t)b << 32 | a) >> ((s >> 8) & 0x7) * 8) & 0xff) << 16) | + (((((std::uint64_t)b << 32 | a) >> ((s >> 12) & 0x7) * 8) & 0xff) << 24); + return ret; +} + +/// Find position of first least significant set bit in an integer. 
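A hedged worked example of the scalar helpers above; the commented values are simply what the formulas evaluate to:

    unsigned int rev = dpct::reverse_bits(0x00000001u);        // 0x80000000
    // Selector 0x5140 picks source bytes 0, 4, 1, 5 (a = low word, b = high word):
    unsigned int perm = dpct::byte_level_permute(0x33221100u, 0x77665544u, 0x5140);  // 0x55114400
    double d = dpct::cast_ints_to_double(0x40590000, 0);       // 100.0
    int hi = dpct::cast_double_to_int(d);                      // 0x40590000 (high 32 bits)
    int lo = dpct::cast_double_to_int(d, /*use_high32=*/false); // 0x00000000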
+/// ffs(0) returns 0. +/// +/// \param [in] a Input integer value +/// \returns The position +template inline int ffs(T a) { + static_assert(std::is_integral::value, "integer required"); + return (sycl::ctz(a) + 1) % (sizeof(T) * 8 + 1); +} + +/// select_from_sub_group allows work-items to obtain a copy of a value held by +/// any other work-item in the sub_group. The input sub_group will be divided +/// into several logical sub_groups with id range [0, \p logical_sub_group_size +/// - 1]. Each work-item in logical sub_group gets value from another work-item +/// whose id is \p remote_local_id. If \p remote_local_id is outside the +/// logical sub_group id range, \p remote_local_id will modulo with \p +/// logical_sub_group_size. The \p logical_sub_group_size must be a power of 2 +/// and not exceed input sub_group size. +/// \tparam T Input value type +/// \param [in] g Input sub_group +/// \param [in] x Input value +/// \param [in] remote_local_id Input source work item id +/// \param [in] logical_sub_group_size Input logical sub_group size +/// \returns The result +template +T select_from_sub_group(sycl::sub_group g, T x, int remote_local_id, + int logical_sub_group_size = 32) { + unsigned int start_index = + g.get_local_linear_id() / logical_sub_group_size * logical_sub_group_size; + return sycl::select_from_group( + g, x, start_index + remote_local_id % logical_sub_group_size); +} + +/// shift_sub_group_left move values held by the work-items in a sub_group +/// directly to another work-item in the sub_group, by shifting values a fixed +/// number of work-items to the left. The input sub_group will be divided into +/// several logical sub_groups with id range [0, \p logical_sub_group_size - 1]. +/// Each work-item in logical sub_group gets value from another work-item whose +/// id is caller's id adds \p delta. If calculated id is outside the logical +/// sub_group id range, the work-item will get value from itself. The \p +/// logical_sub_group_size must be a power of 2 and not exceed input sub_group +/// size. +/// \tparam T Input value type +/// \param [in] g Input sub_group +/// \param [in] x Input value +/// \param [in] delta Input delta +/// \param [in] logical_sub_group_size Input logical sub_group size +/// \returns The result +template +T shift_sub_group_left(sycl::sub_group g, T x, unsigned int delta, + int logical_sub_group_size = 32) { + unsigned int id = g.get_local_linear_id(); + unsigned int end_index = + (id / logical_sub_group_size + 1) * logical_sub_group_size; + T result = sycl::shift_group_left(g, x, delta); + if ((id + delta) >= end_index) { + result = x; + } + return result; +} + +/// shift_sub_group_right move values held by the work-items in a sub_group +/// directly to another work-item in the sub_group, by shifting values a fixed +/// number of work-items to the right. The input sub_group will be divided into +/// several logical_sub_groups with id range [0, \p logical_sub_group_size - 1]. +/// Each work-item in logical_sub_group gets value from another work-item whose +/// id is caller's id subtracts \p delta. If calculated id is outside the +/// logical sub_group id range, the work-item will get value from itself. The \p +/// logical_sub_group_size must be a power of 2 and not exceed input sub_group +/// size. 
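Illustration only, not part of the patch: select_from_sub_group, shift_sub_group_left and shift_sub_group_right above are the dpct stand-ins for CUDA's __shfl, __shfl_down and __shfl_up over a logical sub-group. A minimal kernel sketch of how they might be called, assuming the header lands at dpct/util.h as in this patch and that q and out are a caller-supplied queue and a device USM buffer of at least 64 floats (both hypothetical):

#include <sycl/sycl.hpp>
#include "dpct/util.h"  // dpct::select_from_sub_group, dpct::shift_sub_group_left (this patch)

void shuffle_demo(sycl::queue& q, float* out) {
    q.parallel_for(
        sycl::nd_range<1>(sycl::range<1>(64), sycl::range<1>(64)),
        [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(32)]] {
            sycl::sub_group sg = it.get_sub_group();
            float v = static_cast<float>(it.get_global_linear_id());
            // CUDA __shfl(v, 5): every lane reads lane 5 of its 32-wide logical sub-group.
            float from_lane5 = dpct::select_from_sub_group(sg, v, 5);
            // CUDA __shfl_down(v, 1): read one lane to the right; the last lane keeps v.
            float next = dpct::shift_sub_group_left(sg, v, 1);
            out[it.get_global_linear_id()] = from_lane5 + next;
        });
}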
+/// \tparam T Input value type +/// \param [in] g Input sub_group +/// \param [in] x Input value +/// \param [in] delta Input delta +/// \param [in] logical_sub_group_size Input logical sub_group size +/// \returns The result +template +T shift_sub_group_right(sycl::sub_group g, T x, unsigned int delta, + int logical_sub_group_size = 32) { + unsigned int id = g.get_local_linear_id(); + unsigned int start_index = + id / logical_sub_group_size * logical_sub_group_size; + T result = sycl::shift_group_right(g, x, delta); + if ((id - start_index) < delta) { + result = x; + } + return result; +} + +/// permute_sub_group_by_xor permutes values by exchanging values held by pairs +/// of work-items identified by computing the bitwise exclusive OR of the +/// work-item id and some fixed mask. The input sub_group will be divided into +/// several logical sub_groups with id range [0, \p logical_sub_group_size - 1]. +/// Each work-item in logical sub_group gets value from another work-item whose +/// id is bitwise exclusive OR of the caller's id and \p mask. If calculated id +/// is outside the logical sub_group id range, the work-item will get value from +/// itself. The \p logical_sub_group_size must be a power of 2 and not exceed +/// input sub_group size. +/// \tparam T Input value type +/// \param [in] g Input sub_group +/// \param [in] x Input value +/// \param [in] mask Input mask +/// \param [in] logical_sub_group_size Input logical sub_group size +/// \returns The result +template +T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask, + int logical_sub_group_size = 32) { + unsigned int id = g.get_local_linear_id(); + unsigned int start_index = + id / logical_sub_group_size * logical_sub_group_size; + unsigned int target_offset = (id % logical_sub_group_size) ^ mask; + return sycl::select_from_group(g, x, + target_offset < logical_sub_group_size + ? start_index + target_offset + : id); +} + +/// The function match_any_over_sub_group conducts a comparison of values +/// across work-items within a sub-group. match_any_over_sub_group return a mask +/// in which some bits are set to 1, indicating that the \p value provided by +/// the work-item represented by these bits are equal. The n-th bit of mask +/// representing the work-item with id n. The parameter \p member_mask +/// indicating the work-items participating the call. +/// \tparam T Input value type +/// \param [in] g Input sub_group +/// \param [in] member_mask Input mask +/// \param [in] value Input value +/// \returns The result +template +unsigned int match_any_over_sub_group(sycl::sub_group g, unsigned member_mask, + T value) { + static_assert(std::is_arithmetic_v, "Value type must be arithmetic type."); + if (!member_mask) { + return 0; + } + unsigned int id = g.get_local_linear_id(); + unsigned int flag = 0, result = 0, reduce_result = 0; + unsigned int bit_index = 0x1 << id; + bool is_participate = member_mask & bit_index; + T broadcast_value = 0; + bool matched = false; + while (flag != member_mask) { + broadcast_value = + sycl::select_from_group(g, value, sycl::ctz((~flag & member_mask))); + reduce_result = sycl::reduce_over_group( + g, is_participate ? (broadcast_value == value ? bit_index : 0) : 0, + sycl::plus<>()); + flag |= reduce_result; + matched = reduce_result & bit_index; + result = matched * reduce_result + (1 - matched) * result; + } + return result; +} + +/// The function match_all_over_sub_group conducts a comparison of values +/// across work-items within a sub-group. 
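Illustration only, not part of the patch: match_any_over_sub_group above plays the role of CUDA's __match_any_sync, letting lanes that carry the same key discover each other and elect a leader. A sketch under the assumptions that n is a multiple of 32, keys holds values within the bounds of counts, and both buffers are device USM; all names here are hypothetical:

#include <sycl/sycl.hpp>
#include "dpct/util.h"  // dpct::match_any_over_sub_group (this patch)

void count_keys(sycl::queue& q, const int* keys, int* counts, size_t n) {
    q.parallel_for(
        sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(32)),
        [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(32)]] {
            sycl::sub_group sg = it.get_sub_group();
            int key = keys[it.get_global_linear_id()];
            // All 32 lanes participate, mirroring __match_any_sync(0xffffffff, key).
            unsigned peers = dpct::match_any_over_sub_group(sg, 0xffffffffu, key);
            // The lowest-id peer adds the whole group's contribution exactly once.
            if (sg.get_local_linear_id() == sycl::ctz(peers)) {
                sycl::atomic_ref<int, sycl::memory_order::relaxed,
                                 sycl::memory_scope::device,
                                 sycl::access::address_space::global_space>
                    bin(counts[key]);
                bin.fetch_add(static_cast<int>(sycl::popcount(peers)));
            }
        });
}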
match_all_over_sub_group return \p +/// member_mask and predicate \p pred will be set to 1 if all \p value that +/// provided by each work-item in \p member_mask are equal, otherwise return 0 +/// and the predicate \p pred will be set to 0. The n-th bit of \p member_mask +/// representing the work-item with id n. The parameter \p member_mask +/// indicating the work-items participating the call. +/// \tparam T Input value type +/// \param [in] g Input sub_group +/// \param [in] member_mask Input mask +/// \param [in] value Input value +/// \param [out] pred Output predicate +/// \returns The result +template +unsigned int match_all_over_sub_group(sycl::sub_group g, unsigned member_mask, + T value, int *pred) { + static_assert(std::is_arithmetic_v, "Value type must be arithmetic type."); + if (!member_mask) { + return 0; + } + unsigned int id = g.get_local_linear_id(); + unsigned int bit_index = 0x1 << id; + bool is_participate = member_mask & bit_index; + T broadcast_value = sycl::select_from_group(g, value, sycl::ctz(member_mask)); + unsigned int reduce_result = sycl::reduce_over_group( + g, + (member_mask & bit_index) ? (broadcast_value == value ? bit_index : 0) + : 0, + sycl::plus<>()); + bool all_equal = (reduce_result == member_mask); + *pred = is_participate & all_equal; + return all_equal * member_mask; +} + +namespace experimental { +/// Masked version of select_from_sub_group, which execute masked sub-group +/// operation. The parameter member_mask indicating the work-items participating +/// the call. Whether the n-th bit is set to 1 representing whether the +/// work-item with id n is participating the call. All work-items named in +/// member_mask must be executed with the same member_mask, or the result is +/// undefined. +/// \tparam T Input value type +/// \param [in] member_mask Input mask +/// \param [in] g Input sub_group +/// \param [in] x Input value +/// \param [in] remote_local_id Input source work item id +/// \param [in] logical_sub_group_size Input logical sub_group size +/// \returns The result +template +T select_from_sub_group(unsigned int member_mask, + sycl::sub_group g, T x, int remote_local_id, + int logical_sub_group_size = 32) { + unsigned int start_index = + g.get_local_linear_id() / logical_sub_group_size * logical_sub_group_size; + unsigned logical_remote_id = + start_index + remote_local_id % logical_sub_group_size; +#if defined(__SYCL_DEVICE_ONLY__) && defined(__INTEL_LLVM_COMPILER) +#if defined(__SPIR__) + return __spirv_GroupNonUniformShuffle(__spv::Scope::Subgroup, x, logical_remote_id); +#else + throw sycl::exception(sycl::errc::runtime, "Masked version of select_from_sub_group " + "only supports SPIR-V backends."); +#endif // __SPIR__ +#else + (void)g; + (void)x; + (void)remote_local_id; + (void)logical_sub_group_size; + (void)member_mask; + throw sycl::exception(sycl::errc::runtime, "Masked version of select_from_sub_group not " + "supported on host device and none intel compiler."); +#endif // __SYCL_DEVICE_ONLY__ && __INTEL_LLVM_COMPILER +} + +/// Masked version of shift_sub_group_left, which execute masked sub-group +/// operation. The parameter member_mask indicating the work-items participating +/// the call. Whether the n-th bit is set to 1 representing whether the +/// work-item with id n is participating the call. All work-items named in +/// member_mask must be executed with the same member_mask, or the result is +/// undefined. 
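Illustration only, not part of the patch: the masked overloads in dpct::experimental correspond to CUDA's *_sync shuffles, and per the implementation above they only compile for device code built with the Intel LLVM compiler targeting SPIR-V; elsewhere they throw at run time. A sketch of the mapping for a full-warp broadcast, with data assumed to be a device USM buffer of at least 32 floats (hypothetical):

#include <sycl/sycl.hpp>
#include "dpct/util.h"  // dpct::experimental::select_from_sub_group (this patch)

// CUDA original:  v = __shfl_sync(0xffffffffu, v, 0);   // broadcast lane 0 to the warp
void broadcast_lane0(sycl::queue& q, float* data) {
    q.parallel_for(
        sycl::nd_range<1>(sycl::range<1>(32), sycl::range<1>(32)),
        [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(32)]] {
            float v = data[it.get_local_linear_id()];
            v = dpct::experimental::select_from_sub_group(
                0xffffffffu, it.get_sub_group(), v, /*remote_local_id=*/0);
            data[it.get_local_linear_id()] = v;
        });
}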
+/// \tparam T Input value type +/// \param [in] member_mask Input mask +/// \param [in] g Input sub_group +/// \param [in] x Input value +/// \param [in] delta Input delta +/// \param [in] logical_sub_group_size Input logical sub_group size +/// \returns The result +template +T shift_sub_group_left(unsigned int member_mask, + sycl::sub_group g, T x, unsigned int delta, + int logical_sub_group_size = 32) { + unsigned int id = g.get_local_linear_id(); + unsigned int end_index = + (id / logical_sub_group_size + 1) * logical_sub_group_size; +#if defined(__SYCL_DEVICE_ONLY__) && defined(__INTEL_LLVM_COMPILER) +#if defined(__SPIR__) + T result = __spirv_GroupNonUniformShuffleDown(__spv::Scope::Subgroup, x, delta); + if ((id + delta) >= end_index) { + result = x; + } + return result; +#else + throw sycl::exception(sycl::errc::runtime, "Masked version of shift_sub_group_left " + "only supports SPIR-V backends."); +#endif // __SPIR__ +#else + (void)g; + (void)x; + (void)delta; + (void)logical_sub_group_size; + (void)member_mask; + throw sycl::exception(sycl::errc::runtime, "Masked version of select_from_sub_group not " + "supported on host device and none intel compiler."); +#endif // __SYCL_DEVICE_ONLY__ && __INTEL_LLVM_COMPILER +} + +/// Masked version of shift_sub_group_right, which execute masked sub-group +/// operation. The parameter member_mask indicating the work-items participating +/// the call. Whether the n-th bit is set to 1 representing whether the +/// work-item with id n is participating the call. All work-items named in +/// member_mask must be executed with the same member_mask, or the result is +/// undefined. +/// \tparam T Input value type +/// \param [in] member_mask Input mask +/// \param [in] g Input sub_group +/// \param [in] x Input value +/// \param [in] delta Input delta +/// \param [in] logical_sub_group_size Input logical sub_group size +/// \returns The result +template +T shift_sub_group_right(unsigned int member_mask, + sycl::sub_group g, T x, unsigned int delta, + int logical_sub_group_size = 32) { + unsigned int id = g.get_local_linear_id(); + unsigned int start_index = + id / logical_sub_group_size * logical_sub_group_size; +#if defined(__SYCL_DEVICE_ONLY__) && defined(__INTEL_LLVM_COMPILER) +#if defined(__SPIR__) + T result = __spirv_GroupNonUniformShuffleUp(__spv::Scope::Subgroup, x, delta); + if ((id - start_index) < delta) { + result = x; + } + return result; +#else + throw sycl::exception(sycl::errc::runtime, "Masked version of shift_sub_group_right " + "only supports SPIR-V backends."); +#endif // __SPIR__ +#else + (void)g; + (void)x; + (void)delta; + (void)logical_sub_group_size; + (void)member_mask; + throw sycl::exception(sycl::errc::runtime, "Masked version of select_from_sub_group not " + "supported on host device and none intel compiler."); +#endif // __SYCL_DEVICE_ONLY && __INTEL_LLVM_COMPILER +} + +/// Masked version of permute_sub_group_by_xor, which execute masked sub-group +/// operation. The parameter member_mask indicating the work-items participating +/// the call. Whether the n-th bit is set to 1 representing whether the +/// work-item with id n is participating the call. All work-items named in +/// member_mask must be executed with the same member_mask, or the result is +/// undefined. 
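Illustration only, not part of the patch: the unmasked permute_sub_group_by_xor earlier in this header is the building block for butterfly (XOR) reductions, the same pattern this patch's reduction_utils.h expresses through sycl::permute_group_by_xor. A sketch of a 32-lane sum, assuming in and out are device USM buffers (32 floats in, 1 float out; names hypothetical):

#include <sycl/sycl.hpp>
#include "dpct/util.h"  // dpct::permute_sub_group_by_xor (this patch)

// CUDA idiom:  for (int i = 1; i < 32; i *= 2) v += __shfl_xor(v, i);
void subgroup_sum(sycl::queue& q, const float* in, float* out) {
    q.parallel_for(
        sycl::nd_range<1>(sycl::range<1>(32), sycl::range<1>(32)),
        [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(32)]] {
            sycl::sub_group sg = it.get_sub_group();
            float v = in[it.get_local_linear_id()];
            // After five XOR rounds every lane holds the sum of all 32 values.
            for (int i = 1; i < 32; i *= 2)
                v += dpct::permute_sub_group_by_xor(sg, v, i);
            if (sg.get_local_linear_id() == 0) out[0] = v;
        });
}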
+/// \tparam T Input value type +/// \param [in] member_mask Input mask +/// \param [in] g Input sub_group +/// \param [in] x Input value +/// \param [in] mask Input mask +/// \param [in] logical_sub_group_size Input logical sub_group size +/// \returns The result +template +T permute_sub_group_by_xor(unsigned int member_mask, + sycl::sub_group g, T x, unsigned int mask, + int logical_sub_group_size = 32) { + unsigned int id = g.get_local_linear_id(); + unsigned int start_index = + id / logical_sub_group_size * logical_sub_group_size; + unsigned int target_offset = (id % logical_sub_group_size) ^ mask; + unsigned logical_remote_id = (target_offset < logical_sub_group_size) ? start_index + target_offset : id; +#if defined(__SYCL_DEVICE_ONLY__) && defined(__INTEL_LLVM_COMPILER) +#if defined(__SPIR__) + return __spirv_GroupNonUniformShuffle(__spv::Scope::Subgroup, x, logical_remote_id); +#else + throw sycl::exception(sycl::errc::runtime, "Masked version of permute_sub_group_by_xor " + "only supports SPIR-V backends."); +#endif // __SPIR__ +#else + (void)g; + (void)x; + (void)mask; + (void)logical_sub_group_size; + (void)member_mask; + throw sycl::exception(sycl::errc::runtime, "Masked version of select_from_sub_group not " + "supported on host device and none intel compiler."); +#endif // __SYCL_DEVICE_ONLY__ && __INTEL_LLVM_COMPILER +} +} // namespace experimental + +/// Computes the multiplication of two complex numbers. +/// \tparam T Complex element type +/// \param [in] x The first input complex number +/// \param [in] y The second input complex number +/// \returns The result +template +sycl::vec cmul(sycl::vec x, sycl::vec y) { + std::complex t1(x[0], x[1]), t2(y[0], y[1]); + t1 = t1 * t2; + return sycl::vec(t1.real(), t1.imag()); +} + +/// Computes the division of two complex numbers. +/// \tparam T Complex element type +/// \param [in] x The first input complex number +/// \param [in] y The second input complex number +/// \returns The result +template +sycl::vec cdiv(sycl::vec x, sycl::vec y) { + std::complex t1(x[0], x[1]), t2(y[0], y[1]); + t1 = t1 / t2; + return sycl::vec(t1.real(), t1.imag()); +} + +/// Computes the magnitude of a complex number. +/// \tparam T Complex element type +/// \param [in] x The input complex number +/// \returns The result +template +T cabs(sycl::vec x) { + std::complex t(x[0], x[1]); + return std::abs(t); +} + +/// Computes the complex conjugate of a complex number. +/// \tparam T Complex element type +/// \param [in] x The input complex number +/// \returns The result +template +sycl::vec conj(sycl::vec x) { + std::complex t(x[0], x[1]); + t = std::conj(t); + return sycl::vec(t.real(), t.imag()); +} + +inline int get_sycl_language_version() { +#ifdef SYCL_LANGUAGE_VERSION + return SYCL_LANGUAGE_VERSION; +#else + return 202000; +#endif +} + +namespace experimental { +/// Synchronize work items from all work groups within a SYCL kernel. +/// \param [in] item: Represents a work group. +/// \param [in] counter: An atomic object defined on a device memory which can +/// be accessed by work items in all work groups. The initial value of the +/// counter should be zero. +/// Note: Please make sure that all the work items of all work groups within +/// a SYCL kernel can be scheduled actively at the same time on a device. 
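Illustration only, not part of the patch: a usage sketch for the grid-wide nd_range_barrier defined just below. The counter must live in device-visible memory, start at zero, and be wrapped in a seq_cst, device-scope, global-space atomic_ref, and (per the note above) every work-group of the launch must be resident on the device at the same time, so the grid should stay small. Names and sizes are illustrative; n is assumed to be a multiple of 256:

#include <sycl/sycl.hpp>
#include "dpct/util.h"  // dpct::experimental::nd_range_barrier (this patch)

void two_phase(sycl::queue& q, float* data, size_t n) {
    // Counter shared by all work-groups; must start at zero.
    unsigned int* counter = sycl::malloc_device<unsigned int>(1, q);
    q.memset(counter, 0, sizeof(unsigned int)).wait();

    q.parallel_for(
         sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(256)),
         [=](sycl::nd_item<1> it) {
             sycl::atomic_ref<unsigned int, sycl::memory_order::seq_cst,
                              sycl::memory_scope::device,
                              sycl::access::address_space::global_space>
                 c(*counter);
             size_t i = it.get_global_linear_id();
             data[i] *= 2.0f;                              // phase 1
             dpct::experimental::nd_range_barrier(it, c);  // all work-groups sync here
             data[(i + 1) % n] += 1.0f;                    // phase 2 sees phase-1 results
         })
        .wait();
    sycl::free(counter, q);
}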
+template +inline void +nd_range_barrier(const sycl::nd_item &item, + sycl::atomic_ref< + unsigned int, sycl::memory_order::seq_cst, + sycl::memory_scope::device, + sycl::access::address_space::global_space> &counter) { + + static_assert(dimensions == 3, "dimensions must be 3."); + + unsigned int num_groups = item.get_group_range(2) * item.get_group_range(1) * + item.get_group_range(0); + + item.barrier(); + + if (item.get_local_linear_id() == 0) { + unsigned int inc = 1; + unsigned int old_arrive = 0; + bool is_group0 = + (item.get_group(2) + item.get_group(1) + item.get_group(0) == 0); + if (is_group0) { + inc = 0x80000000 - (num_groups - 1); + } + + old_arrive = counter.fetch_add(inc); + // Synchronize all the work groups + while (((old_arrive ^ counter.load()) & 0x80000000) == 0) + ; + } + + item.barrier(); +} + +/// Synchronize work items from all work groups within a SYCL kernel. +/// \param [in] item: Represents a work group. +/// \param [in] counter: An atomic object defined on a device memory which can +/// be accessed by work items in all work groups. The initial value of the +/// counter should be zero. +/// Note: Please make sure that all the work items of all work groups within +/// a SYCL kernel can be scheduled actively at the same time on a device. +template <> +inline void +nd_range_barrier(const sycl::nd_item<1> &item, + sycl::atomic_ref< + unsigned int, sycl::memory_order::seq_cst, + sycl::memory_scope::device, + sycl::access::address_space::global_space> &counter) { + unsigned int num_groups = item.get_group_range(0); + + item.barrier(); + + if (item.get_local_linear_id() == 0) { + unsigned int inc = 1; + unsigned int old_arrive = 0; + bool is_group0 = (item.get_group(0) == 0); + if (is_group0) { + inc = 0x80000000 - (num_groups - 1); + } + + old_arrive = counter.fetch_add(inc); + // Synchronize all the work groups + while (((old_arrive ^ counter.load()) & 0x80000000) == 0) + ; + } + + item.barrier(); +} + +/// The logical-group is a logical collection of some work-items within a +/// work-group. +/// Note: Please make sure that the logical-group size is a power of 2 in the +/// range [1, current_sub_group_size]. +template class logical_group { + sycl::nd_item _item; + sycl::group _g; + uint32_t _logical_group_size; + uint32_t _group_linear_range_in_parent; + +public: + /// Dividing \p parent_group into several logical-groups. + /// \param [in] item Current work-item. + /// \param [in] parent_group The group to be divided. + /// \param [in] size The logical-group size. + logical_group(sycl::nd_item item, + sycl::group parent_group, uint32_t size) + : _item(item), _g(parent_group), _logical_group_size(size) { + _group_linear_range_in_parent = + (_g.get_local_linear_range() - 1) / _logical_group_size + 1; + } + logical_group(sycl::nd_item item) + : _item(item), _g(item.get_group()) {} + /// Returns the index of the work-item within the logical-group. + uint32_t get_local_linear_id() const { + return _item.get_local_linear_id() % _logical_group_size; + } + /// Returns the index of the logical-group in the parent group. + uint32_t get_group_linear_id() const { + return _item.get_local_linear_id() / _logical_group_size; + } + /// Returns the number of work-items in the logical-group. 
+ uint32_t get_local_linear_range() const { + if (_g.get_local_linear_range() % _logical_group_size == 0) { + return _logical_group_size; + } + uint32_t last_item_group_id = + _g.get_local_linear_range() / _logical_group_size; + uint32_t first_of_last_group = last_item_group_id * _logical_group_size; + if (_item.get_local_linear_id() >= first_of_last_group) { + return _g.get_local_linear_range() - first_of_last_group; + } else { + return _logical_group_size; + } + } + /// Returns the number of logical-group in the parent group. + uint32_t get_group_linear_range() const { + return _group_linear_range_in_parent; + } +}; + +// The original source of the functions calculate_max_active_wg_per_xecore and +// calculate_max_potential_wg were under the license below: +// +// Copyright (C) Intel Corporation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// +/// This function is used for occupancy calculation, it computes the max active +/// work-group number per Xe-Core. Ref to +/// https://github.com/oneapi-src/oneAPI-samples/tree/master/Tools/GPU-Occupancy-Calculator +/// \param [out] num_wg Active work-group number. +/// \param [in] wg_size Work-group size. +/// \param [in] slm_size Share local memory size. +/// \param [in] sg_size Sub-group size. +/// \param [in] used_barrier Whether barrier is used. +/// \param [in] used_large_grf Whether large General Register File is used. +/// \return If no error, returns 0. +/// If \p wg_size exceeds the max work-group size, the max work-group size will +/// be used instead of \p wg_size and returns -1. 
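Illustration only, not part of the patch: a usage sketch for the occupancy helper documented above and defined just below. Given a candidate work-group size, its shared-local-memory footprint and whether it uses barriers, it reports how many such work-groups one Xe-core can keep active; a -1 return means the requested size was clamped to the device maximum. The 256 / 8 KiB figures are illustrative:

#include "dpct/util.h"  // dpct::experimental::calculate_max_active_wg_per_xecore (this patch)

int active_wg_per_xecore() {
    int num_wg = 0;
    // rc == -1 signals that wg_size exceeded the device maximum and was clamped.
    int rc = dpct::experimental::calculate_max_active_wg_per_xecore(
        &num_wg, /*wg_size=*/256,
        /*slm_size=*/8 * 1024, /*sg_size=*/32,
        /*used_barrier=*/true, /*used_large_grf=*/false);
    (void)rc;
    return num_wg;
}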
+inline int calculate_max_active_wg_per_xecore(int *num_wg, int wg_size, + int slm_size = 0, + int sg_size = 32, + bool used_barrier = false, + bool used_large_grf = false) { + int ret = 0; + const int slm_size_per_xe_core = 64 * 1024; + const int max_barrier_registers = 32; + dpct::device_ext &dev = dpct::get_current_device(); + + size_t max_wg_size = dev.get_info(); + if (wg_size > max_wg_size) { + wg_size = max_wg_size; + ret = -1; + } + + int num_threads_ss = 56; + int max_num_wg = 56; + if (dev.has(sycl::aspect::ext_intel_gpu_eu_count_per_subslice) && + dev.has(sycl::aspect::ext_intel_gpu_hw_threads_per_eu)) { + auto eu_count = + dev.get_info(); + auto threads_count = + dev.get_info(); + num_threads_ss = eu_count * threads_count; + max_num_wg = eu_count * threads_count; + } + + if (used_barrier) { + max_num_wg = max_barrier_registers; + } + + // Calculate num_wg_slm + int num_wg_slm = 0; + if (slm_size == 0) { + num_wg_slm = max_num_wg; + } else { + num_wg_slm = std::floor((float)slm_size_per_xe_core / slm_size); + } + + // Calculate num_wg_threads + if (used_large_grf) + num_threads_ss = num_threads_ss / 2; + int num_threads = std::ceil((float)wg_size / sg_size); + int num_wg_threads = std::floor((float)num_threads_ss / num_threads); + + // Calculate num_wg + *num_wg = std::min(num_wg_slm, num_wg_threads); + *num_wg = std::min(*num_wg, max_num_wg); + return ret; +} + +/// This function is used for occupancy calculation, it computes the work-group +/// number and the work-group size which achieves the maximum occupancy of the +/// device potentially. Ref to +/// https://github.com/oneapi-src/oneAPI-samples/tree/master/Tools/GPU-Occupancy-Calculator +/// \param [out] num_wg Work-group number. +/// \param [out] wg_size Work-group size. +/// \param [in] max_ws_size_for_device_code The maximum working work-group size +/// for current device code logic. Zero means no limitation. +/// \param [in] slm_size Share local memory size. +/// \param [in] sg_size Sub-group size. +/// \param [in] used_barrier Whether barrier is used. +/// \param [in] used_large_grf Whether large General Register File is used. +/// \return Returns 0. +inline int calculate_max_potential_wg(int *num_wg, int *wg_size, + int max_ws_size_for_device_code, + int slm_size = 0, int sg_size = 32, + bool used_barrier = false, + bool used_large_grf = false) { + sycl::device &dev = dpct::get_current_device(); + size_t max_wg_size = dev.get_info(); + if (max_ws_size_for_device_code == 0 || + max_ws_size_for_device_code >= max_wg_size) + *wg_size = (int)max_wg_size; + else + *wg_size = max_ws_size_for_device_code; + calculate_max_active_wg_per_xecore(num_wg, *wg_size, slm_size, sg_size, + used_barrier, used_large_grf); + std::uint32_t num_ss = 1; + if (dev.has(sycl::aspect::ext_intel_gpu_slices) && + dev.has(sycl::aspect::ext_intel_gpu_subslices_per_slice)) { + num_ss = + dev.get_info() * + dev.get_info(); + } + num_wg[0] = num_ss * num_wg[0]; + return 0; +} + +/// Supported group type during migration. +enum class group_type { work_group, sub_group, logical_group, root_group }; + +/// The group_base will dispatch the function call to the specific interface +/// based on the group type. +template class group_base { +public: + group_base(sycl::nd_item item) + : nd_item(item), logical_group(item) {} + ~group_base() {} + /// Returns the number of work-items in the group. 
+ size_t get_local_linear_range() { + switch (type) { + case group_type::work_group: + return nd_item.get_group().get_local_linear_range(); + case group_type::sub_group: + return nd_item.get_sub_group().get_local_linear_range(); + case group_type::logical_group: + return logical_group.get_local_linear_range(); + default: + return -1; // Unkonwn group type + } + } + /// Returns the index of the work-item within the group. + size_t get_local_linear_id() { + switch (type) { + case group_type::work_group: + return nd_item.get_group().get_local_linear_id(); + case group_type::sub_group: + return nd_item.get_sub_group().get_local_linear_id(); + case group_type::logical_group: + return logical_group.get_local_linear_id(); + default: + return -1; // Unkonwn group type + } + } + /// Wait for all the elements within the group to complete their execution + /// before proceeding. + void barrier() { + switch (type) { + case group_type::work_group: + sycl::group_barrier(nd_item.get_group()); + break; + case group_type::sub_group: + case group_type::logical_group: + sycl::group_barrier(nd_item.get_sub_group()); + break; + default: + break; + } + } + +protected: + logical_group logical_group; + sycl::nd_item nd_item; + group_type type; +}; + +/// The group class is a container type that can storage supported group_type. +template +class group : public group_base { + using group_base::type; + using group_base::logical_group; + +public: + group(T g, sycl::nd_item item) : group_base(item) { + if constexpr (std::is_same_v) { + type = group_type::sub_group; + } else if constexpr (std::is_same_v>) { + type = group_type::work_group; + } else if constexpr (std::is_same_v>) { + logical_group = g; + type = group_type::logical_group; + } + } +}; +} // namespace experimental + +/// If x <= 2, then return a pointer to the deafult queue; +/// otherwise, return x reinterpreted as a dpct::queue_ptr. +inline queue_ptr int_as_queue_ptr(uintptr_t x) { + return x <= 2 ? + &get_default_queue() + : reinterpret_cast(x); +} + +template +class args_selector; + +/// args_selector is a helper class for extracting arguments from an +/// array of pointers to arguments or buffer of arguments to pass to a +/// kernel function. +/// +/// \param R(Ts...) The type of the kernel +/// \param n_nondefault_params The number of nondefault parameters of the kernel +/// (excluding parameters that like sycl::nd_item, etc.) +/// \param n_default_params The number of default parameters of the kernel +/// +/// Example usage: +/// With the following kernel: +/// void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float f=.1) {} +/// and with the declaration: +/// args_selector<2, 1, decltype(foo)> selector(kernelParams, extra); +/// we have: +/// selector.get<0>() returns a reference to sycl::float*, +/// selector.get<1>() returns a reference to int, +/// selector.get<2>() returns a reference to float +template +class args_selector { +private: + void **kernel_params; + char *args_buffer; + + template + static constexpr int account_for_default_params() { + constexpr int n_total_params = sizeof...(Ts); + if constexpr (i >= n_nondefault_params) { + return n_total_params - n_default_params + (i - n_nondefault_params); + } else { + return i; + } + } + +public: + /// Get the type of the ith argument of R(Ts...) 
+ /// \param [in] i Index of parameter to get + /// \returns Type of ith parameter + template + using arg_type = std::tuple_element_t(), + std::tuple>; +private: + template + static constexpr int get_offset() { + if constexpr (i == 0) { + // we can assume args_buffer is properly aligned to the + // first argument + return 0; + } else { + constexpr int prev_off = get_offset(); + constexpr int prev_past_end = prev_off + sizeof(arg_type); + using T = arg_type; + // is the past-the-end of the i-1st element properly aligned + // with the ith element's alignment? + if constexpr (prev_past_end % alignof(T) == 0) { + return prev_past_end; + } + // otherwise bump prev_past_end to match alignment + else { + return prev_past_end + (alignof(T) - (prev_past_end % alignof(T))); + } + } + } + + static char *get_args_buffer(void **extra) { + if (!extra) + return nullptr; + for (; (std::size_t) *extra != 0; ++extra) { + if ((std::size_t) *extra == 1) { + return static_cast(*(extra+1)); + } + } + return nullptr; + } + +public: + /// If kernel_params is nonnull, then args_selector will + /// extract arguments from kernel_params. Otherwise, it + /// will extract them from extra. + /// \param [in] kernel_params Array of pointers to arguments + /// a or null pointer. + /// \param [in] extra Array containing pointer to argument buffer. + args_selector(void **kernel_params, void **extra) + : kernel_params(kernel_params), + args_buffer(get_args_buffer(extra)) + {} + + /// Get a reference to the ith argument extracted from kernel_params + /// or extra. + /// \param [in] i Index of argument to get + /// \returns Reference to the ith argument + template + arg_type &get() { + if (kernel_params) { + return *static_cast*>(kernel_params[i]); + } else { + return *reinterpret_cast*>(args_buffer + get_offset()); + } + } +}; + +#ifdef _WIN32 +#define DPCT_EXPORT __declspec(dllexport) +#else +#define DPCT_EXPORT +#endif + +} // namespace dpct + +#endif // __DPCT_UTIL_HPP__ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/ds_kernel_utils.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/ds_kernel_utils.h index b0a67c0..47c39d0 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/ds_kernel_utils.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/ds_kernel_utils.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 @@ -11,7 +26,7 @@ used throughout the codebase. #pragma once #include -#include +#include #ifdef BF16_AVAILABLE #endif @@ -19,29 +34,13 @@ used throughout the codebase. 
#define DS_HD_INLINE __dpct_inline__ #define DS_D_INLINE __dpct_inline__ -#ifdef __HIP_PLATFORM_AMD__ - -// constexpr variant of warpSize for templating -constexpr int hw_warp_size = 64; -#define HALF_PRECISION_AVAILABLE = 1 -#include -#include - -#else // !__HIP_PLATFORM_AMD__ // constexpr variant of warpSize for templating constexpr int hw_warp_size = 32; -#if DPCT_COMPATIBILITY_TEMP >= 530 #define HALF_PRECISION_AVAILABLE = 1 // #define PTX_AVAILABLE -#endif // __CUDA_ARCH__ >= 530 - -#if DPCT_COMPATIBILITY_TEMP >= 800 -#define ASYNC_COPY_AVAILABLE -#endif // __CUDA_ARCH__ >= 800 -#endif //__HIP_PLATFORM_AMD__ inline int next_pow2(const int val) { diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/gemm_test.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/gemm_test.h index f1e948e..1f6dd58 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/gemm_test.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/gemm_test.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. 
// SPDX-License-Identifier: Apache-2.0 @@ -6,12 +21,7 @@ #pragma once #include -#include -#ifndef __HIP_PLATFORM_AMD__ -#endif -#ifdef __HIP_PLATFORM_AMD__ -#include -#endif +#include #include #include #include @@ -19,19 +29,19 @@ #include #include #include "StopWatch.h" -#include "cublas_wrappers.h" +#include "mkl_wrappers.h" #include template void check(T result, char const* const func, const char* const file, int const line) { if (result) { - std::cout << (std::string("CUDA runtime error: ") + +file + ":" + std::to_string(line) + + std::cout << (std::string("SYCL runtime error: ") + +file + ":" + std::to_string(line) + " \n"); } } -#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) +#define check_sycl_error(val) check((val), #val, __FILE__, __LINE__) template class GemmTest { @@ -46,18 +56,18 @@ class GemmTest { { dpct::device_ext& dev_ct1 = dpct::get_current_device(); sycl::queue& q_ct1 = dev_ct1.in_order_queue(); - check_cuda_error(DPCT_CHECK_ERROR(A = (T*)sycl::malloc_device(sizeof(T) * M * K, q_ct1))); - check_cuda_error(DPCT_CHECK_ERROR(B = (T*)sycl::malloc_device(sizeof(T) * K * N, q_ct1))); - check_cuda_error(DPCT_CHECK_ERROR(C = (T*)sycl::malloc_device(sizeof(T) * M * N, q_ct1))); + check_sycl_error(DPCT_CHECK_ERROR(A = (T*)sycl::malloc_device(sizeof(T) * M * K, q_ct1))); + check_sycl_error(DPCT_CHECK_ERROR(B = (T*)sycl::malloc_device(sizeof(T) * K * N, q_ct1))); + check_sycl_error(DPCT_CHECK_ERROR(C = (T*)sycl::malloc_device(sizeof(T) * M * N, q_ct1))); } ~GemmTest() { dpct::device_ext& dev_ct1 = dpct::get_current_device(); sycl::queue& q_ct1 = dev_ct1.in_order_queue(); - check_cuda_error(DPCT_CHECK_ERROR(sycl::free(A, q_ct1))); - check_cuda_error(DPCT_CHECK_ERROR(sycl::free(B, q_ct1))); - check_cuda_error(DPCT_CHECK_ERROR(sycl::free(C, q_ct1))); + check_sycl_error(DPCT_CHECK_ERROR(sycl::free(A, q_ct1))); + check_sycl_error(DPCT_CHECK_ERROR(sycl::free(B, q_ct1))); + check_sycl_error(DPCT_CHECK_ERROR(sycl::free(C, q_ct1))); } std::array TestAlgo(int loops) @@ -66,7 +76,7 @@ class GemmTest { float beta = (T)0.0f; int algo_fw = Run(loops, [=](int algo) { - cublas_gemm_ex(handle, + mkl_gemm_ex(handle, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, N, @@ -77,15 +87,11 @@ class GemmTest { B, A, C, -#ifdef __HIP_PLATFORM_AMD__ - static_cast(algo)); -#else static_cast(algo)); -#endif }); int algo_bw1 = Run(loops, [=](int algo) { - cublas_gemm_ex(handle, + mkl_gemm_ex(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::trans, K, @@ -96,15 +102,11 @@ class GemmTest { A, C, B, -#ifdef __HIP_PLATFORM_AMD__ - static_cast(algo)); -#else static_cast(algo)); -#endif }); int algo_bw2 = Run(loops, [=](int algo) { - cublas_gemm_ex(handle, + mkl_gemm_ex(handle, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans, K, @@ -115,11 +117,7 @@ class GemmTest { B, C, A, -#ifdef __HIP_PLATFORM_AMD__ - static_cast(algo)); -#else static_cast(algo)); -#endif }); return std::array({algo_fw, algo_bw1, algo_bw2}); @@ -132,11 +130,7 @@ class GemmTest { float fast_latency = (std::numeric_limits::max)(); int fast_algo = 0; -#ifdef __HIP_PLATFORM_AMD__ - for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard; -#else for (int algo = (int)99; algo <= (int)115; -#endif algo++) { int warm_up = 5; for (int i = 0; i < warm_up; ++i) f(algo); @@ -186,11 +180,11 @@ class StridedGemmTest { { dpct::device_ext& dev_ct1 = dpct::get_current_device(); sycl::queue& q_ct1 = dev_ct1.in_order_queue(); - check_cuda_error( + 
check_sycl_error( DPCT_CHECK_ERROR(A = (T*)sycl::malloc_device(sizeof(T) * M * K * bsz, q_ct1))); - check_cuda_error( + check_sycl_error( DPCT_CHECK_ERROR(B = (T*)sycl::malloc_device(sizeof(T) * K * N * bsz, q_ct1))); - check_cuda_error( + check_sycl_error( DPCT_CHECK_ERROR(C = (T*)sycl::malloc_device(sizeof(T) * M * N * bsz, q_ct1))); } @@ -198,9 +192,9 @@ class StridedGemmTest { { dpct::device_ext& dev_ct1 = dpct::get_current_device(); sycl::queue& q_ct1 = dev_ct1.in_order_queue(); - check_cuda_error(DPCT_CHECK_ERROR(sycl::free(A, q_ct1))); - check_cuda_error(DPCT_CHECK_ERROR(sycl::free(B, q_ct1))); - check_cuda_error(DPCT_CHECK_ERROR(sycl::free(C, q_ct1))); + check_sycl_error(DPCT_CHECK_ERROR(sycl::free(A, q_ct1))); + check_sycl_error(DPCT_CHECK_ERROR(sycl::free(B, q_ct1))); + check_sycl_error(DPCT_CHECK_ERROR(sycl::free(C, q_ct1))); } std::array TestAlgo(int loops) @@ -213,7 +207,7 @@ class StridedGemmTest { int stride_b = N * K; int stride_c = M * N; - cublas_strided_batched_gemm(handle, + mkl_strided_batched_gemm(handle, M, N, K, @@ -228,11 +222,7 @@ class StridedGemmTest { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_AMD__ - static_cast(algo)); -#else static_cast(algo)); -#endif }); int algo_bw1 = Run(loops, [=](int algo) { @@ -249,7 +239,7 @@ class StridedGemmTest { : oneapi::mkl::transpose::trans); // Calculate d_A. - cublas_strided_batched_gemm(handle, + mkl_strided_batched_gemm(handle, mb, kb, N, @@ -264,11 +254,7 @@ class StridedGemmTest { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_AMD__ - static_cast(algo)); -#else static_cast(algo)); -#endif }); int algo_bw2 = Run(loops, [=](int algo) { @@ -282,7 +268,7 @@ class StridedGemmTest { int stride_c = N * K; // Calculate d_B. - cublas_strided_batched_gemm(handle, + mkl_strided_batched_gemm(handle, K, N, M, @@ -297,11 +283,7 @@ class StridedGemmTest { stride_b, stride_c, bsz, -#ifdef __HIP_PLATFORM_AMD__ - static_cast(algo)); -#else static_cast(algo)); -#endif }); return std::array({algo_fw, algo_bw1, algo_bw2}); @@ -314,11 +296,7 @@ class StridedGemmTest { float fast_latency = (std::numeric_limits::max)(); int fast_algo = 0; -#ifdef __HIP_PLATFORM_AMD__ - for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard; -#else for (int algo = (int)99; algo <= (int)115; -#endif algo++) { int warm_up = 5; for (int i = 0; i < warm_up; ++i) f(algo); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/memory_access_utils.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/memory_access_utils.h index 0af12e3..2784f46 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/memory_access_utils.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/memory_access_utils.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 @@ -6,7 +21,7 @@ #pragma once #include -#include +#include #include "ds_kernel_utils.h" /////////////////////////////// Memory Access Utils /////////////////////////////// @@ -657,8 +672,8 @@ namespace internal { #ifdef PTX_AVAILABLE __dpct_inline__ unsigned convert_to_shared(const void* ptr) { -#if __CUDACC_VER_MAJOR__ >= 11 - // In CUDA 11 we have a builtin intrinsic +#if __SYCLCC_VER_MAJOR__ >= 11 + // In SYCL 11 we have a builtin intrinsic return __cvta_generic_to_shared(ptr); #else unsigned ret_val; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/mkl_wrappers.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/mkl_wrappers.h new file mode 100644 index 0000000..a55c803 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/mkl_wrappers.h @@ -0,0 +1,88 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#include + +#include + +int mkl_gemm_ex(dpct::queue_ptr handle, + oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + int algo = -1); + +int mkl_gemm_ex(dpct::queue_ptr handle, + oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const sycl::half* A, + const sycl::half* B, + sycl::half* C, + int algo = 99); + +int mkl_strided_batched_gemm(dpct::queue_ptr handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + oneapi::mkl::transpose op_A, + oneapi::mkl::transpose op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + int algo = -1); + +int mkl_strided_batched_gemm(dpct::queue_ptr handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const sycl::half* A, + const sycl::half* B, + sycl::half* C, + oneapi::mkl::transpose op_A, + oneapi::mkl::transpose op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + int algo = 99); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/quantization.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/quantization.h index c70ef14..3aad3bb 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/quantization.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/quantization.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 @@ -6,7 +21,7 @@ #pragma once #include -#include +#include #include "ds_kernel_utils.h" namespace quantize { diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/quantization_utils.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/quantization_utils.h index 82bceba..2ea1894 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/quantization_utils.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/quantization_utils.h @@ -1,10 +1,25 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team #include -#include +#include #include #include "conversion_utils.h" #include "ds_kernel_utils.h" diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/reduction_utils.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/reduction_utils.h index c21d19e..fedc859 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/reduction_utils.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/reduction_utils.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 @@ -6,7 +21,7 @@ #pragma once #include -#include +#include #include "conversion_utils.h" #include "ds_kernel_utils.h" #include "memory_access_utils.h" @@ -178,23 +193,13 @@ DS_D_INLINE sycl::half element(const sycl::half lhs, const sycl::h template <> DS_D_INLINE sycl::half element(const sycl::half lhs, const sycl::half rhs) { -#if DPCT_COMPATIBILITY_TEMP >= 800 - // Intrinsic limited to Ampere + newer - return __hmax(lhs, rhs); -#else return (lhs > rhs) ? lhs : rhs; -#endif } template <> DS_D_INLINE sycl::half element(const sycl::half lhs, const sycl::half rhs) { -#if DPCT_COMPATIBILITY_TEMP >= 800 - // Intrinsic limited to Ampere + newer - return __hmin(lhs, rhs); -#else return (lhs < rhs) ? lhs : rhs; -#endif } /* sycl::half2 element reduce implementation */ @@ -207,27 +212,19 @@ DS_D_INLINE sycl::half2 element(const sycl::half2 lhs, const sycl: template <> DS_D_INLINE sycl::half2 element(const sycl::half2 lhs, const sycl::half2 rhs) { -#if DPCT_COMPATIBILITY_TEMP >= 800 - return __hmax2(lhs, rhs); -#else sycl::half2 ret_val; ret_val.x() = (lhs.x() > rhs.x()) ? lhs.x() : rhs.x(); ret_val.y() = (lhs.y() > rhs.y()) ? lhs.y() : rhs.y(); return ret_val; -#endif } template <> DS_D_INLINE sycl::half2 element(const sycl::half2 lhs, const sycl::half2 rhs) { -#if DPCT_COMPATIBILITY_TEMP >= 800 - return __hmin2(lhs, rhs); -#else sycl::half2 ret_val; ret_val.x() = (lhs.x() < rhs.x()) ? lhs.x() : rhs.x(); ret_val.y() = (lhs.y() < rhs.y()) ? 
lhs.y() : rhs.y(); return ret_val; -#endif } template <> @@ -310,55 +307,39 @@ DS_D_INLINE float init() template <> DS_D_INLINE sycl::half init() { - constexpr uint16_t zero = {0x0000}; - return sycl::half(zero); + return sycl::half(0.0); } template <> DS_D_INLINE sycl::half init() { - constexpr uint16_t inf = {0x7C00}; + constexpr sycl::half inf = std::numeric_limits::infinity(); return sycl::half(inf); } template <> DS_D_INLINE sycl::half init() { - constexpr uint16_t neg_inf = {0xFC00}; + constexpr sycl::half neg_inf = -std::numeric_limits::infinity(); return sycl::half(neg_inf); } template <> DS_D_INLINE sycl::half2 init() { -#ifdef __HIP_PLATFORM_AMD__ - return sycl::half2{_Float16_2{0x0000, 0x0000}}; -#else - constexpr sycl::half2 zero = {0x0000, 0x0000}; - return sycl::half2(zero); -#endif + return {0.0, 0.0}; } template <> DS_D_INLINE sycl::half2 init() { -#ifdef __HIP_PLATFORM_AMD__ - return sycl::half2{_Float16_2{0x7C00, 0x7C00}}; -#else - constexpr sycl::half2 inf = {0x7C00, 0x7C00}; - return sycl::half2(inf); -#endif + return {std::numeric_limits::infinity(), std::numeric_limits::infinity()}; } template <> DS_D_INLINE sycl::half2 init() { -#ifdef __HIP_PLATFORM_AMD__ - return sycl::half2{_Float16_2{0xFC00, 0xFC00}}; -#else - constexpr sycl::half2 neg_inf = {0xFC00, 0xFC00}; - return sycl::half2(neg_inf); -#endif + return {-std::numeric_limits::infinity(), -std::numeric_limits::infinity()}; } template <> @@ -478,8 +459,10 @@ huge overkill that harms readability) that would be wonderful. template DS_D_INLINE void _warp(sycl::sub_group& warp, T* data) { + auto tb = sycl::ext::oneapi::experimental::this_group<3>(); + auto reduce_width_ = tb.get_local_range(2) < reduce_width ? tb.get_local_range(2) : reduce_width; #pragma unroll - for (int i = 1; i < reduce_width; i *= 2) { + for (int i = 1; i < reduce_width_; i *= 2) { data[0] = element(data[0], sycl::permute_group_by_xor( sycl::ext::oneapi::experimental::this_sub_group(), data[0], i)); @@ -489,8 +472,10 @@ DS_D_INLINE void _warp(sycl::sub_group& warp, T* data) template DS_D_INLINE void _warp(sycl::sub_group& warp, T* data) { + auto tb = sycl::ext::oneapi::experimental::this_group<3>(); + auto reduce_width_ = tb.get_local_range(2) < reduce_width ? tb.get_local_range(2) : reduce_width; #pragma unroll - for (int i = 1; i < reduce_width; i *= 2) { + for (int i = 1; i < reduce_width_; i *= 2) { data[0] = element(data[0], sycl::permute_group_by_xor( sycl::ext::oneapi::experimental::this_sub_group(), data[0], i)); @@ -552,16 +537,11 @@ DS_D_INLINE void _block(sycl::group<3>& tb, sycl::sub_group& warp_arg, T* data) auto& reduce_buffer = *sycl::ext::oneapi::group_local_memory_for_overwrite( sycl::ext::oneapi::experimental::this_group<3>()); -#ifdef __HIP_PLATFORM_AMD__ - const int total_threads = blockDim.x * blockDim.y * blockDim.z; - const int running_warps = total_threads / hw_warp_size; -#else /* DPCT1007:7: Migration of cooperative_groups::thread_block_tile::meta_group_size is not supported. 
*/ const int running_warps = warp_arg.get_group_range().size(); -#endif // Always perform warp-scope reduction _warp(warp_arg, data); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/simd.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/simd.h index f77568b..d5855e7 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/simd.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/simd.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/type_shim.h b/intel_extension_for_deepspeed/op_builder/csrc/includes/type_shim.h index e8a9e01..5317c61 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/includes/type_shim.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/type_shim.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 @@ -5,7 +20,7 @@ /* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */ #include -#include +#include #include // Forward/backward compatibility hack around diff --git a/intel_extension_for_deepspeed/op_builder/csrc/quantization/dequantize.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/quantization/dequantize.dp.cpp index 1b9a116..cfdbd2b 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/quantization/dequantize.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/quantization/dequantize.dp.cpp @@ -1,23 +1,53 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team #include -#include +#include #include "dequantization_utils.h" #include "memory_access_utils.h" template -void dequantize_kernel(T* __restrict__ dequant_data, - const int8_t* __restrict__ q_data, - const float* __restrict__ q_params, - int elems_per_group, - int total_elems) -{ - dequantize::to_global( - dequant_data, q_data, q_params, elems_per_group, total_elems); -} +class dequantize_kernel { +private: + T* __restrict__ dequant_data; + const int8_t* __restrict__ q_data; + const float* __restrict__ q_params; + int elems_per_group; + int total_elems; + +public: + dequantize_kernel(T* __restrict__ dequant_data, + const int8_t* __restrict__ q_data, + const float* __restrict__ q_params, + int elems_per_group, + int total_elems): dequant_data(dequant_data), + q_data(q_data), + q_params(q_params), + elems_per_group(elems_per_group), + total_elems(total_elems) {} + void operator()(sycl::nd_item<3>) const + { + dequantize::to_global( + dequant_data, q_data, q_params, elems_per_group, total_elems); + } +}; /* DPCT1049:47: The work-group size passed to the SYCL kernel may exceed the limit. To get the device @@ -26,19 +56,12 @@ limit, query info::device::max_work_group_size. Adjust the work-group size if ne #define LAUNCH_DEQUANT_KERNEL(num_bits, q_type) \ { \ dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + dequantize_kernel \ + fn(dequant_data, q_data, q_params, elems_per_group, total_elems); \ stream->submit([&](sycl::handler& cgh) { \ - T* dequant_data_ct0 = dequant_data; \ - const int8_t* q_data_ct1 = q_data; \ - const float* q_params_ct2 = q_params; \ - auto elems_per_group_ct3 = elems_per_group; \ - auto total_elems_ct4 = total_elems; \ - \ cgh.parallel_for( \ sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - dequantize_kernel( \ - dequant_data_ct0, q_data_ct1, q_params_ct2, elems_per_group_ct3, total_elems_ct4); \ - }); \ + fn); \ }); \ } diff --git a/intel_extension_for_deepspeed/op_builder/csrc/quantization/fake_quantizer.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/quantization/fake_quantizer.dp.cpp index ddcafd9..c71570d 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/quantization/fake_quantizer.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/quantization/fake_quantizer.dp.cpp @@ -1,40 +1,71 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team #include -#include +#include #include -#include "custom_cuda_layers.h" +#include "custom_sycl_layers.h" #include "memory_access_utils.h" +#include "conversion_utils.h" -void fake_quantize_kernel(sycl::half* vals, int group_size, int num_bits) -{ -#if DPCT_COMPATIBILITY_TEMP >= 700 || defined(__HIP_PLATFORM_AMD__) +template +class fake_quantize_kernel {}; + +template<> +class fake_quantize_kernel { +private: + sycl::half* vals; + int group_size; + int num_bits; +public: + fake_quantize_kernel(sycl::half* vals, int group_size, int num_bits): vals(vals), group_size(group_size), num_bits(num_bits) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); sycl::group<3> b = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group g = sycl::ext::oneapi::experimental::this_sub_group(); - int gid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - int warp_num = blockDim.x >> 5; - int id = threadIdx.x; + /* int gid = threadIdx.x >> 5; */ + /* int lane = threadIdx.x & 0x1f; */ + /* int warp_num = blockDim.x >> 5; */ + /* int id = threadIdx.x; */ + + auto gid = item_ct1.get_local_id(2) >> 5; + auto lane = item_ct1.get_local_id(2) & 0x1f; + auto warp_num = item_ct1.get_local_range(2) >> 5; + auto id = item_ct1.get_local_id(2); constexpr int granularity = 16; constexpr int vals_per_access = granularity / sizeof(sycl::half); sycl::half data[vals_per_access]; - int group_id = blockIdx.x; + /* int group_id = blockIdx.x; */ + auto group_id = item_ct1.get_group(2); int thread_index = id * vals_per_access; int reg_count = 0; int offset = group_id * group_size; float max = -10000.0; for (int thread_index = id * vals_per_access; thread_index < group_size; - thread_index += blockDim.x * vals_per_access) { + thread_index += item_ct1.get_local_range(2) * vals_per_access) { mem_access::load_global(data, vals + offset + thread_index); #pragma unroll @@ -48,11 +79,14 @@ void fake_quantize_kernel(sycl::half* vals, int group_size, int num_bits) auto temp = g.shuffle_xor(max, i); if (max < temp) max = temp; } - __shared__ float partialMax[WARP_SIZE]; + /* __shared__ float partialMax[WARP_SIZE]; */ + auto& partialMax = *sycl::ext::oneapi::group_local_memory_for_overwrite( + sycl::ext::oneapi::experimental::this_group<3>()); if (lane == 0) partialMax[gid] = max; - b.sync(); + /* b.sync(); */ + item_ct1.barrier(); if (lane < warp_num) max = partialMax[lane]; @@ -70,25 +104,37 @@ void fake_quantize_kernel(sycl::half* vals, int group_size, int num_bits) int q_range_min = -(1 << (num_bits - 1)); for (int thread_index = id * vals_per_access; thread_index < group_size; - thread_index += blockDim.x * vals_per_access) { + thread_index += item_ct1.get_local_range(2) * vals_per_access) { mem_access::load_global(data, vals + offset + thread_index); #pragma unroll for (int j = 0; j < vals_per_access; j++) { float q_data; - q_data = sycl::half2float(data[j]); - q_data = 
__float2int_rn(q_data * q_scale); + /* q_data = sycl::half2float(data[j]); */ + q_data = conversion::to(data[j]); + /* q_data = __float2int_rn(q_data * q_scale); */ + q_data = sycl::vec{(q_data * q_scale)} + .convert()[0]; q_data = q_data > (q_range_max) ? (q_range_max) : (q_data < (q_range_min) ? (q_range_min) : q_data); - data[j] = __float2half_rn(q_data * q_scale_inv); + /* data[j] = __float2half_rn(q_data * q_scale_inv); */ + data[j] = conversion::to(q_data * q_scale_inv); } mem_access::store_global(vals + offset + thread_index, data); } + } -#endif -} +}; -void fake_quantize_kernel(float* vals, int group_size, int num_bits) -{ +template<> +class fake_quantize_kernel { +private: + float* vals; + int group_size; + int num_bits; +public: + fake_quantize_kernel(float* vals, int group_size, int num_bits): vals(vals), group_size(group_size), num_bits(num_bits) {} + + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); sycl::group<3> b = sycl::ext::oneapi::experimental::this_group<3>(); auto g = sycl::ext::oneapi::experimental::this_sub_group(); @@ -178,6 +224,7 @@ void fake_quantize_kernel(float* vals, int group_size, int num_bits) mem_access::store_global(vals + offset + thread_index, data); } } +}; template void launch_fake_quantize_kernel(T* vals, @@ -189,17 +236,10 @@ void launch_fake_quantize_kernel(T* vals, sycl::range<3> grid_dim(1, 1, group_num); sycl::range<3> block_dim(1, 1, 1024); - /* - DPCT1049:44: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), - [=](sycl::nd_item<3> item_ct1) { - fake_quantize_kernel(vals, total_count / group_num, num_bits); - }); - } + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + fake_quantize_kernel fn(vals, total_count / group_num, num_bits); + stream->parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), fn); + } template void launch_fake_quantize_kernel(float* vals, @@ -213,37 +253,59 @@ template void launch_fake_quantize_kernel(sycl::half* vals, int num_bits, dpct::queue_ptr stream); -void sr_fake_quantize_kernel(sycl::half* vals, - int token_size, - int token_num, - int num_bits, - std::pair seed) -{ -#if DPCT_COMPATIBILITY_TEMP >= 700 || defined(__HIP_PLATFORM_AMD__) +template +class sr_fake_quantize_kernel {}; + +template<> +class sr_fake_quantize_kernel { +private: + sycl::half* vals; + int token_size; + int token_num; + int num_bits; + std::pair seed; +public: + sr_fake_quantize_kernel(sycl::half* vals, + int token_size, + int token_num, + int num_bits, + std::pair seed): vals(vals), + token_size(token_size), + token_num(token_num), + num_bits(num_bits), + seed(seed) {} + + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); sycl::group<3> b = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group g = sycl::ext::oneapi::experimental::this_sub_group(); - int gid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - int warp_num = blockDim.x >> 5; + /* int gid = threadIdx.x >> 5; */ + auto gid = item_ct1.get_local_id(2) >> 5; + /* int lane = threadIdx.x & 0x1f; */ + auto lane = item_ct1.get_local_id(2) & 0x1f; + /* int warp_num = blockDim.x >> 5; */ + auto warp_num = item_ct1.get_local_range(2) >> 5; - 
int idx = blockIdx.x * blockDim.x + threadIdx.x; + /* int idx = blockIdx.x * blockDim.x + threadIdx.x; */ + auto idx = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); - float2* vals_cast = reinterpret_cast(vals); + sycl::float2* vals_cast = reinterpret_cast(vals); sycl::half2 data_low[128]; sycl::half2 data_high[128]; - int bid = blockIdx.x; + /* int bid = blockIdx.x; */ + auto bid = item_ct1.get_group(2); /* curandStatePhilox4_32_10_t state; */ /* curand_init(seed.first, idx, seed.second, &state); */ dpct::rng::device::rng_generator> state; state = dpct::rng::device::rng_generator>(seed.first, {seed.second, (unsigned long)idx * 4}); - unsigned int tid = threadIdx.x; + /* unsigned int tid = threadIdx.x; */ + auto tid = item_ct1.get_local_id(2); int reg_count = 0; int offset = bid * token_size; int group_index = bid * token_size + tid; @@ -253,21 +315,24 @@ void sr_fake_quantize_kernel(sycl::half* vals, // float min = 10000.0; float max = -10000.0; while (tid < token_size) { - float2 data = vals_cast[offset + tid]; + sycl::float2 data = vals_cast[offset + tid]; sycl::half2* data_h = reinterpret_cast(&data); data_low[reg_count] = data_h[0]; data_high[reg_count] = data_h[1]; - float2 data_f[2]; - data_f[0] = sycl::half22float2(data_h[0]); - data_f[1] = sycl::half22float2(data_h[1]); + sycl::float2 data_f[2]; + /* data_f[0] = sycl::half22float2(data_h[0]); */ + data_f[0] = conversion::to(data_h[0]); + /* data_f[1] = sycl::half22float2(data_h[1]); */ + data_f[1] = conversion::to(data_h[1]); - if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x); - if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y); - if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x); - if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y); + if (abs((float)data_f[0].x()) > max) max = abs((float)data_f[0].x()); + if (abs((float)data_f[0].y()) > max) max = abs((float)data_f[0].y()); + if (abs((float)data_f[1].x()) > max) max = abs((float)data_f[1].x()); + if (abs((float)data_f[1].y()) > max) max = abs((float)data_f[1].y()); - tid += blockDim.x; + /* tid += blockDim.x; */ + tid += item_ct1.get_local_range(2); reg_count++; } @@ -277,11 +342,14 @@ void sr_fake_quantize_kernel(sycl::half* vals, if (max < temp) max = temp; } - __shared__ float partialMax[WARP_SIZE]; + /* __shared__ float partialMax[WARP_SIZE]; */ + auto& partialMax = *sycl::ext::oneapi::group_local_memory_for_overwrite( + sycl::ext::oneapi::experimental::this_group<3>()); if (lane == 0) partialMax[gid] = max; - b.sync(); + /* b.sync(); */ + item_ct1.barrier(); if (lane < warp_num) max = partialMax[lane]; @@ -298,67 +366,86 @@ void sr_fake_quantize_kernel(sycl::half* vals, float low_q = (float)(-((1 << (num_bits - 1)))); for (int i = 0; i < reg_count; i++) { - int token_index = i * blockDim.x + threadIdx.x; + /* int token_index = i * blockDim.x + threadIdx.x; */ + int token_index = i * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); if (token_index < token_size) { - float2 data_f[2]; - data_f[0] = sycl::half22float2(data_low[i]); - data_f[1] = sycl::half22float2(data_high[i]); - - float2 q_data_int[2]; - q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val)); - q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val)); - q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val)); - q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val)); + sycl::float2 data_f[2]; + /* data_f[0] = sycl::half22float2(data_low[i]); */ + data_f[0] = conversion::to(data_low[i]); + /* 
data_f[1] = sycl::half22float2(data_high[i]); */ + data_f[1] = conversion::to(data_high[i]); + + sycl::float2 q_data_int[2]; + q_data_int[0].x() = (float)((int)(data_f[0].x() * q_scale_val)); + q_data_int[0].y() = (float)((int)(data_f[0].y() * q_scale_val)); + q_data_int[1].x() = (float)((int)(data_f[1].x() * q_scale_val)); + q_data_int[1].y() = (float)((int)(data_f[1].y() * q_scale_val)); // Stochastic rounding sycl::float4 rand = state.generate, 4>(); float q_error[4]; - q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val; - q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val; - q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val; - q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val; - - q_data_int[0].x = - (rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q) - ? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1)) - : q_data_int[0].x; - q_data_int[0].y = - (rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q) - ? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1)) - : q_data_int[0].y; - q_data_int[1].x = - (rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q) - ? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1)) - : q_data_int[1].x; - q_data_int[1].y = - (rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q) - ? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1)) - : q_data_int[1].y; - - data_f[0].x = q_data_int[0].x / q_scale_val; - data_f[0].y = q_data_int[0].y / q_scale_val; - data_f[1].x = q_data_int[1].x / q_scale_val; - data_f[1].y = q_data_int[1].y / q_scale_val; - - float2 result; + q_error[0] = abs(data_f[0].x() - (q_data_int[0].x() / q_scale_val)) * q_scale_val; + q_error[1] = abs(data_f[0].y() - (q_data_int[0].y() / q_scale_val)) * q_scale_val; + q_error[2] = abs(data_f[1].x() - (q_data_int[1].x() / q_scale_val)) * q_scale_val; + q_error[3] = abs(data_f[1].y() - (q_data_int[1].y() / q_scale_val)) * q_scale_val; + + q_data_int[0].x() = + (rand.x() < q_error[0] && q_data_int[0].x() > low_q && q_data_int[0].x() < high_q) + ? (q_data_int[0].x() + (data_f[0].x() > 0 ? 1 : -1)) + : q_data_int[0].x(); + q_data_int[0].y() = + (rand.y() < q_error[1] && q_data_int[0].y() > low_q && q_data_int[0].y() < high_q) + ? (q_data_int[0].y() + (data_f[0].y() > 0 ? 1 : -1)) + : q_data_int[0].y(); + q_data_int[1].x() = + (rand.w() < q_error[2] && q_data_int[1].x() > low_q && q_data_int[1].x() < high_q) + ? (q_data_int[1].x() + (data_f[1].x() > 0 ? 1 : -1)) + : q_data_int[1].x(); + q_data_int[1].y() = + (rand.z() < q_error[3] && q_data_int[1].y() > low_q && q_data_int[1].y() < high_q) + ? (q_data_int[1].y() + (data_f[1].y() > 0 ? 
1 : -1)) + : q_data_int[1].y(); + + data_f[0].x() = q_data_int[0].x() / q_scale_val; + data_f[0].y() = q_data_int[0].y() / q_scale_val; + data_f[1].x() = q_data_int[1].x() / q_scale_val; + data_f[1].y() = q_data_int[1].y() / q_scale_val; + + sycl::float2 result; sycl::half2* result_h = reinterpret_cast(&result); - result_h[0] = __float22half2_rn(data_f[0]); - result_h[1] = __float22half2_rn(data_f[1]); + /* result_h[0] = __float22half2_rn(data_f[0]); */ + result_h[0] = conversion::to(data_f[0]); + /* result_h[1] = __float22half2_rn(data_f[1]); */ + result_h[1] = conversion::to(data_f[1]); vals_cast[offset + token_index] = result; } } } -#endif -} - -void sr_fake_quantize_kernel(float* vals, - int token_size, - int token_num, - int num_bits, - std::pair seed) -{ + } +}; + +template<> +class sr_fake_quantize_kernel { +private: + float* vals; + int token_size; + int token_num; + int num_bits; + std::pair seed; +public: + sr_fake_quantize_kernel(float* vals, + int token_size, + int token_num, + int num_bits, + std::pair seed): vals(vals), + token_size(token_size), + token_num(token_num), + num_bits(num_bits), + seed(seed) {} + + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); sycl::group<3> b = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group g = sycl::ext::oneapi::experimental::this_sub_group(); @@ -478,7 +565,9 @@ void sr_fake_quantize_kernel(float* vals, } } } -} + } +}; + template void launch_sr_fake_quantize_kernel(T* vals, @@ -499,11 +588,9 @@ void launch_sr_fake_quantize_kernel(T* vals, */ { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + sr_fake_quantize_kernel fn(vals, (total_count / group_num) / 4, group_num, num_bits, seed); stream->parallel_for( - sycl::nd_range<3>(grid_dim * block_dim, block_dim), [=](sycl::nd_item<3> item_ct1) { - sr_fake_quantize_kernel( - vals, (total_count / group_num) / 4, group_num, num_bits, seed); - }); + sycl::nd_range<3>(grid_dim * block_dim, block_dim), fn); } } template void launch_sr_fake_quantize_kernel(float* vals, @@ -516,25 +603,41 @@ template void launch_sr_fake_quantize_kernel(sycl::half* vals, int group_num, int num_bits, dpct::queue_ptr stream); - -void fake_quantize_kernel_asym(sycl::half* vals, int group_size, int num_bits) -{ -#if DPCT_COMPATIBILITY_TEMP >= 700 || defined(__HIP_PLATFORM_AMD__) +template +class fake_quantize_kernel_asym {}; + +template<> +class fake_quantize_kernel_asym { +private: + sycl::half* vals; + int group_size; + int num_bits; +public: + fake_quantize_kernel_asym(sycl::half* vals, int group_size, int num_bits): vals(vals), + group_size(group_size), + num_bits(num_bits) {} + + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); sycl::group<3> b = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group g = sycl::ext::oneapi::experimental::this_sub_group(); - int gid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - int warp_num = blockDim.x >> 5; - int id = threadIdx.x; + /* int gid = threadIdx.x >> 5; */ + auto gid = item_ct1.get_local_id(2) >> 5; + /* int lane = threadIdx.x & 0x1f; */ + auto lane = item_ct1.get_local_id(2) & 0x1f; + /* int warp_num = blockDim.x >> 5; */ + auto warp_num = item_ct1.get_local_range(2) >> 5; + /* int id = threadIdx.x; */ + auto id = item_ct1.get_local_id(2); - float2* vals_cast = reinterpret_cast(vals); + sycl::float2* vals_cast = reinterpret_cast(vals); - float2 data[MAX_REG]; + sycl::float2 data[MAX_REG]; - 
int group_id = blockIdx.x; + /* int group_id = blockIdx.x; */ + auto group_id = item_ct1.get_group(2); { int group_index = id; @@ -557,7 +660,8 @@ void fake_quantize_kernel_asym(sycl::half* vals, int group_size, int num_bits) if (((float)data_h[2]) < min) min = (float)data_h[2]; if (((float)data_h[3]) < min) min = (float)data_h[3]; - group_index += blockDim.x; + /* group_index += blockDim.x; */ + group_index += item_ct1.get_local_range(2); reg_count++; } @@ -573,13 +677,18 @@ void fake_quantize_kernel_asym(sycl::half* vals, int group_size, int num_bits) if (min > temp) min = temp; } - __shared__ float partialMax[WARP_SIZE]; - __shared__ float partialMin[WARP_SIZE]; + /* __shared__ float partialMax[WARP_SIZE]; */ + auto& partialMax = *sycl::ext::oneapi::group_local_memory_for_overwrite( + sycl::ext::oneapi::experimental::this_group<3>()); + /* __shared__ float partialMin[WARP_SIZE]; */ + auto& partialMin = *sycl::ext::oneapi::group_local_memory_for_overwrite( + sycl::ext::oneapi::experimental::this_group<3>()); if (lane == 0) partialMax[gid] = max; if (lane == 0) partialMin[gid] = min; - b.sync(); + /* b.sync(); */ + item_ct1.barrier(); if (lane < warp_num) max = partialMax[lane]; if (lane < warp_num) min = partialMin[lane]; @@ -602,37 +711,52 @@ void fake_quantize_kernel_asym(sycl::half* vals, int group_size, int num_bits) float q_scale_inv = 1 / q_scale; for (int i = 0; i < reg_count; i++) { - group_index = i * blockDim.x + id; + /* group_index = i * blockDim.x + id; */ + group_index = i * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); if (group_index < group_size) { sycl::half2* data_h = reinterpret_cast(&data[i]); - float2 q_data[2]; - q_data[0] = sycl::half22float2(data_h[0]); - q_data[1] = sycl::half22float2(data_h[1]); + sycl::float2 q_data[2]; + /* q_data[0] = sycl::half22float2(data_h[0]); */ + q_data[0] = conversion::to(data_h[0]); + /* q_data[1] = sycl::half22float2(data_h[1]); */ + q_data[1] = conversion::to(data_h[1]); - float2 q_data_int[2]; + sycl::float2 q_data_int[2]; - q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv); - q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv); - q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv); - q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv); + q_data_int[0].x() = roundf((q_data[0].x() - min) * q_scale_inv); + q_data_int[0].y() = roundf((q_data[0].y() - min) * q_scale_inv); + q_data_int[1].x() = roundf((q_data[1].x() - min) * q_scale_inv); + q_data_int[1].y() = roundf((q_data[1].y() - min) * q_scale_inv); - q_data_int[0].x = q_data_int[0].x * q_scale + min; - q_data_int[0].y = q_data_int[0].y * q_scale + min; - q_data_int[1].x = q_data_int[1].x * q_scale + min; - q_data_int[1].y = q_data_int[1].y * q_scale + min; + q_data_int[0].x() = q_data_int[0].x() * q_scale + min; + q_data_int[0].y() = q_data_int[0].y() * q_scale + min; + q_data_int[1].x() = q_data_int[1].x() * q_scale + min; + q_data_int[1].y() = q_data_int[1].y() * q_scale + min; - data_h[0] = __float22half2_rn(q_data_int[0]); - data_h[1] = __float22half2_rn(q_data_int[1]); + /* data_h[0] = __float22half2_rn(q_data_int[0]); */ + data_h[0] = conversion::to(q_data_int[0]); + /* data_h[1] = __float22half2_rn(q_data_int[1]); */ + data_h[1] = conversion::to(q_data_int[1]); vals_cast[offset + group_index] = data[i]; } } } -#endif -} - -void fake_quantize_kernel_asym(float* vals, int group_size, int num_bits) -{ + } +}; + +template<> +class fake_quantize_kernel_asym { +private: + float* vals; + int group_size; + int num_bits; +public: + 
fake_quantize_kernel_asym(float* vals, int group_size, int num_bits): vals(vals), + group_size(group_size), + num_bits(num_bits) {} + + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); sycl::group<3> b = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group g = sycl::ext::oneapi::experimental::this_sub_group(); @@ -740,7 +864,8 @@ void fake_quantize_kernel_asym(float* vals, int group_size, int num_bits) vals_cast[group_index + bid * group_size] = q_data; } } -} + } +}; template void launch_fake_quantize_kernel_asym(T* vals, @@ -758,10 +883,9 @@ void launch_fake_quantize_kernel_asym(T* vals, */ { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + fake_quantize_kernel_asym fn(vals, (total_count / group_num) / 4, num_bits); stream->parallel_for( - sycl::nd_range<3>(grid_dim * block_dim, block_dim), [=](sycl::nd_item<3> item_ct1) { - fake_quantize_kernel_asym(vals, (total_count / group_num) / 4, num_bits); - }); + sycl::nd_range<3>(grid_dim * block_dim, block_dim), fn); } } @@ -782,30 +906,35 @@ void sr_fake_quantize_kernel_asym(sycl::half* vals, int num_bits, std::pair seed) { -#if DPCT_COMPATIBILITY_TEMP >= 700 || defined(__HIP_PLATFORM_AMD__) auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); sycl::group<3> b = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group g = sycl::ext::oneapi::experimental::this_sub_group(); - int gid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - int warp_num = blockDim.x >> 5; + /* int gid = threadIdx.x >> 5; */ + auto gid = item_ct1.get_local_id(2) >> 5; + /* int lane = threadIdx.x & 0x1f; */ + auto lane = item_ct1.get_local_id(2) & 0x1f; + /* int warp_num = blockDim.x >> 5; */ + auto warp_num = item_ct1.get_local_range(2) >> 5; - int idx = blockIdx.x * blockDim.x + threadIdx.x; + /* int idx = blockIdx.x * blockDim.x + threadIdx.x; */ + auto idx = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); - float2* vals_cast = reinterpret_cast(vals); + sycl::float2* vals_cast = reinterpret_cast(vals); sycl::half2 data_low[128]; sycl::half2 data_high[128]; - int bid = blockIdx.x; + /* int bid = blockIdx.x; */ + auto bid = item_ct1.get_group(2); /* curandStatePhilox4_32_10_t state; */ /* curand_init(seed.first, idx, seed.second, &state); */ dpct::rng::device::rng_generator> state; state = dpct::rng::device::rng_generator>(seed.first, {seed.second, (unsigned long)idx * 4}); - unsigned int tid = threadIdx.x; + /* unsigned int tid = threadIdx.x; */ + auto tid = item_ct1.get_local_id(2); int reg_count = 0; int offset = bid * token_size; int group_index = bid * token_size + tid; @@ -815,26 +944,29 @@ void sr_fake_quantize_kernel_asym(sycl::half* vals, float min = 10000.0; float max = -10000.0; while (tid < token_size) { - float2 data = vals_cast[offset + tid]; + sycl::float2 data = vals_cast[offset + tid]; sycl::half2* data_h = reinterpret_cast(&data); data_low[reg_count] = data_h[0]; data_high[reg_count] = data_h[1]; - float2 data_f[2]; - data_f[0] = sycl::half22float2(data_h[0]); - data_f[1] = sycl::half22float2(data_h[1]); + sycl::float2 data_f[2]; + /* data_f[0] = sycl::half22float2(data_h[0]); */ + data_f[0] = conversion::to(data_h[0]); + /* data_f[1] = sycl::half22float2(data_h[1]); */ + data_f[1] = conversion::to(data_h[1]); - if (((float)data_f[0].x) > max) max = (float)data_f[0].x; - if (((float)data_f[0].y) > max) max = (float)data_f[0].y; - if (((float)data_f[1].x) > max) max = (float)data_f[1].x; - if 
(((float)data_f[1].y) > max) max = (float)data_f[1].y; + if (((float)data_f[0].x()) > max) max = (float)data_f[0].x(); + if (((float)data_f[0].y()) > max) max = (float)data_f[0].y(); + if (((float)data_f[1].x()) > max) max = (float)data_f[1].x(); + if (((float)data_f[1].y()) > max) max = (float)data_f[1].y(); - if (((float)data_f[0].x) < min) min = (float)data_f[0].x; - if (((float)data_f[0].y) < min) min = (float)data_f[0].y; - if (((float)data_f[1].x) < min) min = (float)data_f[1].x; - if (((float)data_f[1].y) < min) min = (float)data_f[1].y; + if (((float)data_f[0].x()) < min) min = (float)data_f[0].x(); + if (((float)data_f[0].y()) < min) min = (float)data_f[0].y(); + if (((float)data_f[1].x()) < min) min = (float)data_f[1].x(); + if (((float)data_f[1].y()) < min) min = (float)data_f[1].y(); - tid += blockDim.x; + /* tid += blockDim.x; */ + tid += item_ct1.get_local_range(2); reg_count++; } @@ -850,13 +982,18 @@ void sr_fake_quantize_kernel_asym(sycl::half* vals, if (min > temp) min = temp; } - __shared__ float partialMax[WARP_SIZE]; - __shared__ float partialMin[WARP_SIZE]; + /* __shared__ float partialMax[WARP_SIZE]; */ + auto& partialMax = *sycl::ext::oneapi::group_local_memory_for_overwrite( + sycl::ext::oneapi::experimental::this_group<3>()); + /* __shared__ float partialMin[WARP_SIZE]; */ + auto& partialMin = *sycl::ext::oneapi::group_local_memory_for_overwrite( + sycl::ext::oneapi::experimental::this_group<3>()); if (lane == 0) partialMax[gid] = max; if (lane == 0) partialMin[gid] = min; - b.sync(); + /* b.sync(); */ + item_ct1.barrier(); if (lane < warp_num) max = partialMax[lane]; if (lane < warp_num) min = partialMin[lane]; @@ -880,17 +1017,20 @@ void sr_fake_quantize_kernel_asym(sycl::half* vals, float high_q = (float)((1 << num_bits) - 1); for (int i = 0; i < reg_count; i++) { - int token_index = i * blockDim.x + threadIdx.x; + /* int token_index = i * blockDim.x + threadIdx.x; */ + int token_index = i * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); if (token_index < token_size) { - float2 data_f[2]; - data_f[0] = sycl::half22float2(data_low[i]); - data_f[1] = sycl::half22float2(data_high[i]); - - float2 q_data_int[2]; - q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv)); - q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv)); - q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv)); - q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv)); + sycl::float2 data_f[2]; + /* data_f[0] = sycl::half22float2(data_low[i]); */ + data_f[0] = conversion::to(data_low[i]); + /* data_f[1] = sycl::half22float2(data_high[i]); */ + data_f[1] = conversion::to(data_high[i]); + + sycl::float2 q_data_int[2]; + q_data_int[0].x() = (float)((unsigned int)((data_f[0].x() - min) * q_scale_val_inv)); + q_data_int[0].y() = (float)((unsigned int)((data_f[0].y() - min) * q_scale_val_inv)); + q_data_int[1].x() = (float)((unsigned int)((data_f[1].x() - min) * q_scale_val_inv)); + q_data_int[1].y() = (float)((unsigned int)((data_f[1].y() - min) * q_scale_val_inv)); // Stochastic rounding /* float4 rand = curand_uniform4(&state); */ @@ -898,42 +1038,43 @@ void sr_fake_quantize_kernel_asym(sycl::half* vals, float q_error[4]; q_error[0] = - abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv; + abs(data_f[0].x() - ((q_data_int[0].x() * q_scale_val) + min)) * q_scale_val_inv; q_error[1] = - abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv; 
+ abs(data_f[0].y() - ((q_data_int[0].y() * q_scale_val) + min)) * q_scale_val_inv; q_error[2] = - abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv; + abs(data_f[1].x() - ((q_data_int[1].x() * q_scale_val) + min)) * q_scale_val_inv; q_error[3] = - abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv; - - q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q) - ? (q_data_int[0].x + 1) - : q_data_int[0].x; - q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q) - ? (q_data_int[0].y + 1) - : q_data_int[0].y; - q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q) - ? (q_data_int[1].x + 1) - : q_data_int[1].x; - q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q) - ? (q_data_int[1].y + 1) - : q_data_int[1].y; - - data_f[0].x = q_data_int[0].x * q_scale_val + min; - data_f[0].y = q_data_int[0].y * q_scale_val + min; - data_f[1].x = q_data_int[1].x * q_scale_val + min; - data_f[1].y = q_data_int[1].y * q_scale_val + min; - - float2 result; + abs(data_f[1].y() - ((q_data_int[1].y() * q_scale_val) + min)) * q_scale_val_inv; + + q_data_int[0].x() = (rand.x() < q_error[0] && q_data_int[0].x() < high_q) + ? (q_data_int[0].x() + 1) + : q_data_int[0].x(); + q_data_int[0].y() = (rand.y() < q_error[1] && q_data_int[0].y() < high_q) + ? (q_data_int[0].y() + 1) + : q_data_int[0].y(); + q_data_int[1].x() = (rand.w() < q_error[2] && q_data_int[1].x() < high_q) + ? (q_data_int[1].x() + 1) + : q_data_int[1].x(); + q_data_int[1].y() = (rand.z() < q_error[3] && q_data_int[1].y() < high_q) + ? (q_data_int[1].y() + 1) + : q_data_int[1].y(); + + data_f[0].x() = q_data_int[0].x() * q_scale_val + min; + data_f[0].y() = q_data_int[0].y() * q_scale_val + min; + data_f[1].x() = q_data_int[1].x() * q_scale_val + min; + data_f[1].y() = q_data_int[1].y() * q_scale_val + min; + + sycl::float2 result; sycl::half2* result_h = reinterpret_cast(&result); - result_h[0] = __float22half2_rn(data_f[0]); - result_h[1] = __float22half2_rn(data_f[1]); + /* result_h[0] = __float22half2_rn(data_f[0]); */ + result_h[0] = conversion::to(data_f[0]); + /* result_h[1] = __float22half2_rn(data_f[1]); */ + result_h[1] = conversion::to(data_f[1]); vals_cast[offset + token_index] = result; } } } -#endif } void sr_fake_quantize_kernel_asym(float* vals, @@ -1102,11 +1243,9 @@ void launch_sr_fake_quantize_kernel_asym(T* vals, */ { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + sr_fake_quantize_kernel fn(vals, (total_count / group_num) / 4, group_num, num_bits, seed); stream->parallel_for( - sycl::nd_range<3>(grid_dim * block_dim, block_dim), [=](sycl::nd_item<3> item_ct1) { - sr_fake_quantize_kernel( - vals, (total_count / group_num) / 4, group_num, num_bits, seed); - }); + sycl::nd_range<3>(grid_dim * block_dim, block_dim), fn); } } template void launch_sr_fake_quantize_kernel_asym(float* vals, diff --git a/intel_extension_for_deepspeed/op_builder/csrc/quantization/pt_binding.cpp b/intel_extension_for_deepspeed/op_builder/csrc/quantization/pt_binding.cpp index 12fe120..008efdf 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/quantization/pt_binding.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/quantization/pt_binding.cpp @@ -1,11 +1,26 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team #include -#include -/* #include */ +#include +/* #include */ #include #include #include @@ -23,7 +38,7 @@ at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits) if ((((size / groups) - 1) / 4096 + 1) <= 256) { launch_fake_quantize_kernel( - (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream()); + (T*)vals.data_ptr(), size, groups, bits, at::sycl::getCurrentSYCLStream()); } return vals; } @@ -37,7 +52,7 @@ at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits) if (((size / groups) / 4 / 1024) <= 256) { launch_sr_fake_quantize_kernel( - (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream()); + (T*)vals.data_ptr(), size, groups, bits, at::sycl::getCurrentSYCLStream()); } return vals; } @@ -51,7 +66,7 @@ at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits) if ((((size / groups) - 1) / 4096 + 1) <= 256) { launch_fake_quantize_kernel_asym( - (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream()); + (T*)vals.data_ptr(), size, groups, bits, at::sycl::getCurrentSYCLStream()); } return vals; } @@ -65,7 +80,7 @@ at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits) if (((size / groups) / 4 / 1024) <= 256) { launch_sr_fake_quantize_kernel_asym( - (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream()); + (T*)vals.data_ptr(), size, groups, bits, at::sycl::getCurrentSYCLStream()); } return vals; } @@ -103,7 +118,7 @@ std::vector quantize_kernel(at::Tensor& input_vals, elems_per_group, numBits, quantType, - at::cuda::getCurrentCUDAStream()); + at::sycl::getCurrentSYCLStream()); return {output, params}; } @@ -136,7 +151,7 @@ at::Tensor dequantize(at::Tensor& quantized_data, num_bits, elems_per_group, total_elems, - at::cuda::getCurrentCUDAStream()); + at::sycl::getCurrentSYCLStream()); return output; } @@ -156,7 +171,7 @@ at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in, (sycl::half*)min_val_buffer.data_ptr(), num_group, group_size, - at::cuda::getCurrentCUDAStream()); + at::sycl::getCurrentSYCLStream()); return output; } @@ -176,7 +191,7 @@ at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in, (sycl::half*)min_val_buffer.data_ptr(), num_group, group_size, - at::cuda::getCurrentCUDAStream()); + at::sycl::getCurrentSYCLStream()); return output; } @@ -219,7 +234,7 @@ std::vector ds_swizzle_quant(at::Tensor& input_vals, pipeline_size, nodes, devices_per_node, - at::cuda::getCurrentCUDAStream()); + at::sycl::getCurrentSYCLStream()); return {output, scales}; } @@ -266,27 +281,27 @@ std::vector quantized_reduction(at::Tensor& input_vals, elems_per_in_tensor, in_groups / devices_per_node, elems_per_in_group, - at::cuda::getCurrentCUDAStream()); + at::sycl::getCurrentSYCLStream()); return {output, scales}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("ds_quantize_fp32", &ds_quantize, "DeepSpeed Quantize with fp32 (CUDA)"); - 
m.def("ds_quantize_fp16", &ds_quantize, "DeepSpeed Quantize with fp16 (CUDA)"); - m.def("ds_sr_quantize_fp32", &ds_sr_quantize, "DeepSpeed Quantize with fp32 (CUDA)"); + m.def("ds_quantize_fp32", &ds_quantize, "DeepSpeed Quantize with fp32"); + m.def("ds_quantize_fp16", &ds_quantize, "DeepSpeed Quantize with fp16"); + m.def("ds_sr_quantize_fp32", &ds_sr_quantize, "DeepSpeed Quantize with fp32"); m.def( - "ds_sr_quantize_fp16", &ds_sr_quantize, "DeepSpeed Quantize with fp16 (CUDA)"); - m.def("ds_quantize_asym_fp32", &ds_quantize_asym, "DeepSpeed Quantize with fp32 (CUDA)"); + "ds_sr_quantize_fp16", &ds_sr_quantize, "DeepSpeed Quantize with fp16"); + m.def("ds_quantize_asym_fp32", &ds_quantize_asym, "DeepSpeed Quantize with fp32"); m.def("ds_quantize_asym_fp16", &ds_quantize_asym, - "DeepSpeed Quantize with fp16 (CUDA)"); + "DeepSpeed Quantize with fp16"); m.def("ds_sr_quantize_asym_fp32", &ds_sr_quantize_asym, - "DeepSpeed Quantize with fp32 (CUDA)"); + "DeepSpeed Quantize with fp32"); m.def("ds_sr_quantize_asym_fp16", &ds_sr_quantize_asym, - "DeepSpeed Quantize with fp16 (CUDA)"); + "DeepSpeed Quantize with fp16"); pybind11::enum_(m, "QuantizationType") .value("Symmetric", quantize::Type::Symmetric) .value("Asymmetric", quantize::Type::Asymmetric) diff --git a/intel_extension_for_deepspeed/op_builder/csrc/quantization/quant_reduce.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/quantization/quant_reduce.dp.cpp index ae3142f..9a792d2 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/quantization/quant_reduce.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/quantization/quant_reduce.dp.cpp @@ -1,10 +1,25 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team #include -#include +#include #include #include "dequantization_utils.h" #include "ds_kernel_utils.h" @@ -20,22 +35,38 @@ to leverage some parallel reductions here to improve performance. */ template -/* -DPCT1110:46: The total declared local variable size in device function dequant_reduce exceeds 128 -bytes and may cause high register pressure. Consult with your hardware vendor to find the total -register size available and adjust the code, or use smaller sub-group size to avoid high register -pressure. 
-*/ -void dequant_reduce(int8_t* reduced_data, - float* reduced_scales, - const int8_t* input_data, - const float* input_scales, - int elems_per_out_group, - int elems_per_in_tensor, - int groups_per_in_tensor, - int elems_per_in_group, - int num_tensors) -{ +class dequant_reduce { +private: + int8_t* reduced_data; + float* reduced_scales; + const int8_t* input_data; + const float* input_scales; + int elems_per_out_group; + int elems_per_in_tensor; + int groups_per_in_tensor; + int elems_per_in_group; + int num_tensors; +public: + dequant_reduce(int8_t* reduced_data, + float* reduced_scales, + const int8_t* input_data, + const float* input_scales, + int elems_per_out_group, + int elems_per_in_tensor, + int groups_per_in_tensor, + int elems_per_in_group, + int num_tensors): reduced_data(reduced_data), + reduced_scales(reduced_scales), + input_data(input_data), + input_scales(input_scales), + elems_per_out_group(elems_per_out_group), + elems_per_in_tensor(elems_per_in_tensor), + groups_per_in_tensor(groups_per_in_tensor), + elems_per_in_group(elems_per_in_group), + num_tensors(num_tensors) {} + + + void operator()(sycl::nd_item<3>) const { sycl::group<3> tb = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group warp = sycl::ext::oneapi::experimental::this_sub_group(); @@ -135,7 +166,10 @@ void dequant_reduce(int8_t* reduced_data, mem_access::store_global(reduced_data + iter_offset, local_output); } } -} + } + +}; + template int32_t pow2_round(int32_t raw_value) @@ -150,30 +184,17 @@ limit, query info::device::max_work_group_size. Adjust the work-group size if ne #define LAUNCH_DEQUANT_REDUCE(num_chunks) \ { \ dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + dequant_reduce fn(reduced_data, \ + reduced_scales, \ + input_data, \ + input_scales, \ + elems_per_out_group, \ + elems_per_in_tensor, \ + groups_per_in_tensor, \ + elems_per_in_group, \ + num_tensors); \ stream->submit([&](sycl::handler& cgh) { \ - int8_t* reduced_data_ct0 = reduced_data; \ - float* reduced_scales_ct1 = reduced_scales; \ - const int8_t* input_data_ct2 = input_data; \ - const float* input_scales_ct3 = input_scales; \ - auto elems_per_out_group_ct4 = elems_per_out_group; \ - auto elems_per_in_tensor_ct5 = elems_per_in_tensor; \ - auto groups_per_in_tensor_ct6 = groups_per_in_tensor; \ - auto elems_per_in_group_ct7 = elems_per_in_group; \ - auto num_tensors_ct8 = num_tensors; \ - \ - cgh.parallel_for(sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - dequant_reduce( \ - reduced_data_ct0, \ - reduced_scales_ct1, \ - input_data_ct2, \ - input_scales_ct3, \ - elems_per_out_group_ct4, \ - elems_per_in_tensor_ct5, \ - groups_per_in_tensor_ct6, \ - elems_per_in_group_ct7, \ - num_tensors_ct8); \ - }); \ + cgh.parallel_for(sycl::nd_range<3>(grid * block, block), fn); \ }); \ } diff --git a/intel_extension_for_deepspeed/op_builder/csrc/quantization/quantize.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/quantization/quantize.dp.cpp index 286258b..13faecd 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/quantization/quantize.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/quantization/quantize.dp.cpp @@ -1,10 +1,25 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
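Note (illustration, not part of the patch): every launch macro in these files begins with dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}). In plain SYCL 2020 that guard amounts to checking sycl::device::has() for each required aspect; a rough standalone sketch follows. require_aspects is our name for the helper, not a dpct or repository symbol.

// Illustrative sketch of an aspect guard written against plain SYCL 2020.
#include <sycl/sycl.hpp>
#include <initializer_list>
#include <stdexcept>
#include <string>

inline void require_aspects(const sycl::device& dev,
                            std::initializer_list<sycl::aspect> aspects)
{
    for (sycl::aspect a : aspects) {
        if (!dev.has(a)) {
            // Mirrors the behaviour of dpct::has_capability_or_fail: stop the
            // launch early instead of failing inside the kernel.
            throw std::runtime_error("Device '" +
                                     dev.get_info<sycl::info::device::name>() +
                                     "' lacks a required aspect (e.g. fp16/fp64)");
        }
    }
}

// Usage before a launch that relies on half/double support:
//   require_aspects(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16});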
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team #include -#include +#include #include "ds_kernel_utils.h" #include "memory_access_utils.h" #include "quantization.h" @@ -20,18 +35,24 @@ template -/* -DPCT1110:46: The total declared local variable size in device function cached_quantization exceeds -128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total -register size available and adjust the code, or use smaller sub-group size to avoid high register -pressure. -*/ -void cached_quantization(int8_t* __restrict__ output_data, - float* __restrict__ params, - const sycl::half* __restrict__ input_data, - int groups, - int elems_per_group) -{ +class cached_quantization { +private: + int8_t* __restrict__ output_data; + float* __restrict__ params; + const sycl::half* __restrict__ input_data; + int groups; + int elems_per_group; +public: + cached_quantization(int8_t* __restrict__ output_data, + float* __restrict__ params, + const sycl::half* __restrict__ input_data, + int groups, + int elems_per_group): output_data(output_data), + params(params), + input_data(input_data), + groups(groups), + elems_per_group(elems_per_group) {} + void operator()(sycl::nd_item<3>) const { sycl::group<3> tb = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group warp = sycl::ext::oneapi::experimental::this_sub_group(); @@ -65,33 +86,26 @@ void cached_quantization(int8_t* __restrict__ output_data, quantize:: local_array( local_buffer, params, output_data, elems_per_group, groups); -} + } +}; + /********* Launcher methods ***********/ /* DPCT1049:47: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
*/ -#define LAUNCH_CACHED_QUANT_CALL(q_bits, quant_type) \ - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ - stream->submit([&](sycl::handler& cgh) { \ - int8_t* output_data_ct0 = output_data; \ - float* params_ct1 = params; \ - const sycl::half* input_data_ct2 = input_data; \ - int groups_ct3 = groups; \ - int elems_per_group_ct4 = elems_per_group; \ - \ - cgh.parallel_for( \ - sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - cached_quantization( \ - output_data_ct0, params_ct1, input_data_ct2, groups_ct3, elems_per_group_ct4); \ - }); \ +#define LAUNCH_CACHED_QUANT_CALL(q_bits, quant_type) \ + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + cached_quantization fn(output_data, params, input_data, groups, elems_per_group); \ + stream->submit([&](sycl::handler& cgh) { \ + cgh.parallel_for( \ + sycl::nd_range<3>(grid * block, block), fn); \ }); #define LAUNCH_CACHED_QUANT( \ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/quantization/quantize_intX.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/quantization/quantize_intX.dp.cpp index 20ade04..5db498e 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/quantization/quantize_intX.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/quantization/quantize_intX.dp.cpp @@ -1,10 +1,25 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. 
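Note (illustration, not part of the patch): the DPCT1049 comments retained in these files advise querying info::device::max_work_group_size and adjusting the work-group size if the requested block is too large. A small sketch of that clamp, assuming the 1-D block shapes (up to 1024 work-items in the innermost dimension) used by the quantization launches; clamp_block_size is an illustrative helper, not an existing function in this repository.

// Illustrative sketch: clamp a requested block size to the device limit.
#include <sycl/sycl.hpp>
#include <algorithm>
#include <cstddef>

inline size_t clamp_block_size(const sycl::queue& q, size_t requested)
{
    const size_t device_max =
        q.get_device().get_info<sycl::info::device::max_work_group_size>();
    return std::min(requested, device_max);
}

// Example use before building the nd_range:
//   size_t threads = clamp_block_size(*stream, 1024);
//   sycl::range<3> block_dim(1, 1, threads);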
// SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team #include -#include +#include #include #include "memory_access_utils.h" #include @@ -185,13 +200,28 @@ __dpct_inline__ AlignedArray int4_to_half(const AlignedArray< return ret; } -void dequantize_int4_to_half(uint8_t* data_in, - sycl::half* data_out, - sycl::half* scale_buffer, - sycl::half* min_val_buffer, - int num_group, - int group_size) -{ +class dequantize_int4_to_half { +private: + uint8_t* data_in; + sycl::half* data_out; + sycl::half* scale_buffer; + sycl::half* min_val_buffer; + int num_group; + int group_size; +public: + dequantize_int4_to_half(uint8_t* data_in, + sycl::half* data_out, + sycl::half* scale_buffer, + sycl::half* min_val_buffer, + int num_group, + int group_size): data_in(data_in), + data_out(data_out), + scale_buffer(scale_buffer), + min_val_buffer(min_val_buffer), + num_group(num_group), + group_size(group_size) {} + + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); using AccessType = AlignedArray; using AccessTypeOut = AlignedArray; @@ -210,7 +240,9 @@ void dequantize_int4_to_half(uint8_t* data_in, reinterpret_cast(data_out)[idx] = output; } -} + } +}; + void launch_dequantize_int4_to_half_experimental(uint8_t* data_in, sycl::half* data_out, @@ -225,13 +257,11 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in, { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + dequantize_int4_to_half fn( + data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size); stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_block) * sycl::range<3>(1, 1, 256), - sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item_ct1) { - dequantize_int4_to_half( - data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size); - }); + sycl::range<3>(1, 1, 256)), fn); } } @@ -246,13 +276,27 @@ __dpct_inline__ AlignedArray int8_to_half(const AlignedArray) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); using AccessType = AlignedArray; using AccessTypeOut = AlignedArray; @@ -271,7 +315,9 @@ void dequantize_int8_to_half(uint8_t* data_in, reinterpret_cast(data_out)[idx] = output; } -} + } +}; + void launch_dequantize_int8_to_half_experimental(uint8_t* data_in, sycl::half* data_out, @@ -286,12 +332,10 @@ void launch_dequantize_int8_to_half_experimental(uint8_t* data_in, { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + dequantize_int8_to_half fn( + data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size); stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_block) * sycl::range<3>(1, 1, 256), - sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item_ct1) { - dequantize_int8_to_half( - data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size); - }); + sycl::range<3>(1, 1, 256)), fn); } } diff --git a/intel_extension_for_deepspeed/op_builder/csrc/quantization/swizzled_quantize.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/quantization/swizzled_quantize.dp.cpp index 2863276..f442253 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/quantization/swizzled_quantize.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/quantization/swizzled_quantize.dp.cpp @@ -1,10 +1,25 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team #include -#include +#include #include "memory_access_utils.h" #include "quantization_utils.h" #include "reduction_utils.h" @@ -20,19 +35,27 @@ constexpr int h_per_step = step_granularity * quantize::h_per_load; } // namespace swiz_quant template -/* -DPCT1110:46: The total declared local variable size in device function swizzled_quant_kernel exceeds -128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total -register size available and adjust the code, or use smaller sub-group size to avoid high register -pressure. -*/ -void swizzled_quant_kernel(int8_t* quantized_data, - float* quantized_scales, - const sycl::half* uncompressed_data, - int elems_per_group, - int nodes, - int devices_per_node) -{ +class swizzled_quant_kernel { +private: + int8_t* quantized_data; + float* quantized_scales; + const sycl::half* uncompressed_data; + int elems_per_group; + int nodes; + int devices_per_node; +public: + swizzled_quant_kernel(int8_t* quantized_data, + float* quantized_scales, + const sycl::half* uncompressed_data, + int elems_per_group, + int nodes, + int devices_per_node):quantized_data(quantized_data), + quantized_scales(quantized_scales), + uncompressed_data(uncompressed_data), + elems_per_group(elems_per_group), + nodes(nodes), + devices_per_node(devices_per_node) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); sycl::group<3> tb = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group warp = sycl::ext::oneapi::experimental::this_sub_group(); @@ -92,7 +115,8 @@ void swizzled_quant_kernel(int8_t* quantized_data, mem_access::store_global(out_base + i * out_stride, local_output); } } -} + } +}; /* DPCT1049:47: The work-group size passed to the SYCL kernel may exceed the limit. To get the device @@ -101,24 +125,16 @@ limit, query info::device::max_work_group_size. 
Adjust the work-group size if ne #define LAUNCH_SWIZZLE_QUANT(total_chunks, threads) \ { \ dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + swizzled_quant_kernel fn( \ + q_data, \ + q_scales, \ + input_data, \ + elems_per_group, \ + nodes, \ + devices_per_node); \ stream->submit([&](sycl::handler& cgh) { \ - int8_t* q_data_ct0 = q_data; \ - float* q_scales_ct1 = q_scales; \ - const sycl::half* input_data_ct2 = input_data; \ - auto elems_per_group_ct3 = elems_per_group; \ - auto nodes_ct4 = nodes; \ - auto devices_per_node_ct5 = devices_per_node; \ - \ cgh.parallel_for(sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - swizzled_quant_kernel( \ - q_data_ct0, \ - q_scales_ct1, \ - input_data_ct2, \ - elems_per_group_ct3, \ - nodes_ct4, \ - devices_per_node_ct5); \ - }); \ + fn); \ }); \ } diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/apply_rotary_pos_emb.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/apply_rotary_pos_emb.dp.cpp index 80a6b5e..3ecd0bb 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/apply_rotary_pos_emb.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/apply_rotary_pos_emb.dp.cpp @@ -1,243 +1,248 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team +#include #include -#include #include "conversion_utils.h" -#ifdef __HIP_PLATFORM_AMD__ -#include "hip/hip_cooperative_groups.h" -#else -#endif #include "ds_kernel_utils.h" -#include "inference_cuda_layers.h" +#include "inference_sycl_layers.h" #include "memory_access_utils.h" -#ifndef __HIP_PLATFORM_AMD__ -#endif - namespace rot_half { constexpr int threads = 256; -} // namespace rot_half +} // namespace rot_half template -/* -DPCT1110:3: The total declared local variable size in device function apply_rotary_pos_half exceeds -128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total -register size available and adjust the code, or use smaller sub-group size to avoid high register -pressure. 
-*/ -void apply_rotary_pos_half(T* mixed_query, - T* key_layer, - unsigned rotary_dim, - unsigned seq_len, - unsigned seq_offset, - unsigned num_heads, - unsigned head_size, - unsigned total_count, - float rope_theta, - int max_out_tokens) -{ +class apply_rotary_pos_half { + private: + T* mixed_query; + T* key_layer; + unsigned rotary_dim; + unsigned seq_len; + unsigned seq_offset; + unsigned num_heads; + unsigned head_size; + unsigned total_count; + float rope_theta; + int max_out_tokens; + + public: + apply_rotary_pos_half( + T* mixed_query, + T* key_layer, + unsigned rotary_dim, + unsigned seq_len, + unsigned seq_offset, + unsigned num_heads, + unsigned head_size, + unsigned total_count, + float rope_theta, + int max_out_tokens) + : mixed_query(mixed_query), + rotary_dim(rotary_dim), + seq_len(seq_len), + seq_offset(seq_offset), + num_heads(num_heads), + head_size(head_size), + total_count(total_count), + rope_theta(rope_theta), + max_out_tokens(max_out_tokens) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); constexpr int T_per_thread = granularity / sizeof(T); constexpr int heads_per_block = rot_half::threads / threadsPerHead; sycl::group<3> tb = sycl::ext::oneapi::experimental::this_group<3>(); - auto head_group = - sycl::ext::oneapi::experimental::this_sub_group(); + auto head_group = sycl::ext::oneapi::experimental::this_sub_group(); - const int head_idx = - item_ct1.get_group(2) * heads_per_block + item_ct1.get_local_id(2) / threadsPerHead; + const int head_idx = item_ct1.get_group(2) * heads_per_block + + item_ct1.get_local_id(2) / threadsPerHead; const int cur_seq_idx = head_idx % seq_len; const int offset = head_idx * head_size; - const int k_offset = (cur_seq_idx + (head_idx / seq_len) * max_out_tokens) * head_size; + const int k_offset = + (cur_seq_idx + (head_idx / seq_len) * max_out_tokens) * head_size; const int seq_idx = cur_seq_idx + seq_offset; const int half_dim = rotary_dim >> 1; const int half_dim_threads = half_dim / T_per_thread; if (head_idx < total_count) { - /* - DPCT1007:0: Migration of thread_rank is not supported. - */ - const int base_neuron_idx = head_group.get_local_linear_id() * T_per_thread; - - T q[T_per_thread], k[T_per_thread]; - mem_access::load_global(q, mixed_query + offset + base_neuron_idx); - mem_access::load_global(k, key_layer + k_offset + base_neuron_idx); + /* + DPCT1007:0: Migration of thread_rank is not supported. + */ + const int base_neuron_idx = + head_group.get_local_linear_id() * T_per_thread; + + T q[T_per_thread], k[T_per_thread]; + mem_access::load_global( + q, mixed_query + offset + base_neuron_idx); + mem_access::load_global( + k, key_layer + k_offset + base_neuron_idx); #pragma unroll - for (int i = 0; i < T_per_thread; i++) { - const int neuron_idx = base_neuron_idx + i; - if (neuron_idx < rotary_dim) { - float inv_freq = (float)((neuron_idx % half_dim) * 2) / (float)rotary_dim; - inv_freq = 1.0 / dpct::pow(rope_theta, inv_freq) * (float)seq_idx; - - float rotary_sign = (neuron_idx > (half_dim - 1) ? -1.0 : 1.0); - float q_rot = conversion::to(q[i]) * rotary_sign; - float k_rot = conversion::to(k[i]) * rotary_sign; - - const int target_lane = (neuron_idx < half_dim) - /* - DPCT1007:1: Migration of thread_rank is not supported. - */ - ? head_group.get_local_linear_id() + half_dim_threads - /* - DPCT1007:2: Migration of thread_rank is not supported. 
- */ - : head_group.get_local_linear_id() - half_dim_threads; - - /* - DPCT1007:5: Migration of cooperative_groups::thread_block_tile::shfl is not - supported. - */ - const float q_rot_temp = head_group.shuffle(q_rot, target_lane); - /* - DPCT1007:6: Migration of cooperative_groups::thread_block_tile::shfl is not - supported. - */ - const float k_rot_temp = head_group.shuffle(k_rot, target_lane); - - q[i] = conversion::to(conversion::to(q[i]) * sycl::cos(inv_freq) + - q_rot_temp * sycl::sin(inv_freq)); - k[i] = conversion::to(conversion::to(k[i]) * sycl::cos(inv_freq) + - k_rot_temp * sycl::sin(inv_freq)); - } + for (int i = 0; i < T_per_thread; i++) { + const int neuron_idx = base_neuron_idx + i; + if (neuron_idx < rotary_dim) { + float inv_freq = + (float)((neuron_idx % half_dim) * 2) / (float)rotary_dim; + inv_freq = 1.0 / dpct::pow(rope_theta, inv_freq) * (float)seq_idx; + + float rotary_sign = (neuron_idx > (half_dim - 1) ? -1.0 : 1.0); + float q_rot = conversion::to(q[i]) * rotary_sign; + float k_rot = conversion::to(k[i]) * rotary_sign; + + const int target_lane = (neuron_idx < half_dim) + /* + DPCT1007:1: Migration of thread_rank is not supported. + */ + ? head_group.get_local_linear_id() + half_dim_threads + /* + DPCT1007:2: Migration of thread_rank is not supported. + */ + : head_group.get_local_linear_id() - half_dim_threads; + + /* + DPCT1007:5: Migration of cooperative_groups::thread_block_tile::shfl + is not supported. + */ + const float q_rot_temp = head_group.shuffle(q_rot, target_lane); + /* + DPCT1007:6: Migration of cooperative_groups::thread_block_tile::shfl + is not supported. + */ + const float k_rot_temp = head_group.shuffle(k_rot, target_lane); + + q[i] = conversion::to( + conversion::to(q[i]) * sycl::cos(inv_freq) + + q_rot_temp * sycl::sin(inv_freq)); + k[i] = conversion::to( + conversion::to(k[i]) * sycl::cos(inv_freq) + + k_rot_temp * sycl::sin(inv_freq)); } + } - mem_access::store_global(mixed_query + offset + base_neuron_idx, q); - mem_access::store_global(key_layer + k_offset + base_neuron_idx, k); + mem_access::store_global( + mixed_query + offset + base_neuron_idx, q); + mem_access::store_global( + key_layer + k_offset + base_neuron_idx, k); } -} - -/* -DPCT1049:4: The work-group size passed to the SYCL kernel may exceed the limit. To get the device -limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
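The DPCT1049 notes that recur through these files all flag the same thing: the migrated launches keep fixed work-group sizes (256 here, 1024 in most of the kernels below), which a given device is not guaranteed to support. The patch keeps the fixed sizes and only checks fp16/fp64 aspects via dpct::has_capability_or_fail, so the sketch below only illustrates the query the diagnostic asks for; the clamp policy is an assumption, not something the patch does.

#include <sycl/sycl.hpp>
#include <algorithm>
#include <cstdio>

int main() {
  sycl::queue q;
  // Query the limit the DPCT1049 message refers to.
  const size_t max_wg =
      q.get_device().get_info<sycl::info::device::max_work_group_size>();

  // Example policy: clamp a requested 1024-item work-group to the device limit.
  const size_t requested = 1024;
  const size_t wg = std::min(requested, max_wg);
  std::printf("device max work-group size: %zu, launching with: %zu\n", max_wg, wg);
}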
-*/ -#define LAUNCH_ROT_POS_EMB_HALF(HEAD_THREADS, ALIGNMENT) \ - { \ - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ - stream->submit([&](sycl::handler& cgh) { \ - T* mixed_query_ct0 = mixed_query; \ - T* key_layer_ct1 = key_layer; \ - auto rotary_dim_ct2 = rotary_dim; \ - auto seq_len_ct3 = seq_len; \ - auto offset_ct4 = offset; \ - auto num_heads_ct5 = num_heads; \ - auto head_size_ct6 = head_size; \ - auto total_count_ct7 = total_count; \ - auto rope_theta_ct8 = rope_theta; \ - auto max_out_tokens_ct9 = max_out_tokens; \ - \ - cgh.parallel_for(sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - apply_rotary_pos_half(mixed_query_ct0, \ - key_layer_ct1, \ - rotary_dim_ct2, \ - seq_len_ct3, \ - offset_ct4, \ - num_heads_ct5, \ - head_size_ct6, \ - total_count_ct7, \ - rope_theta_ct8, \ - max_out_tokens_ct9); \ - }); \ - }); \ + } +}; + +#define LAUNCH_ROT_POS_EMB_HALF(HEAD_THREADS, ALIGNMENT) \ + { \ + dpct::has_capability_or_fail( \ + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + apply_rotary_pos_half fn( \ + mixed_query, \ + key_layer, \ + rotary_dim, \ + seq_len, \ + offset, \ + num_heads, \ + head_size, \ + total_count, \ + rope_theta, \ + max_out_tokens); \ + stream->submit([&](sycl::handler& cgh) { \ + cgh.parallel_for(sycl::nd_range<3>(grid * block, block), fn); \ + }); \ } -#ifdef __HIP_PLATFORM_AMD__ -#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ - if (threads_per_head == 4) { \ - LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ - } else if (threads_per_head == 8) { \ - LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ - } else if (threads_per_head == 16) { \ - LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ - } else if (threads_per_head == 32) { \ - LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ - } else if (threads_per_head == 64) { \ - LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \ - } else { \ - assert(false); \ - } -#else -#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ - if (threads_per_head == 4) { \ - LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ - } else if (threads_per_head == 8) { \ - LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ - } else if (threads_per_head == 16) { \ - LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ - } else if (threads_per_head == 32) { \ - LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ - } else { \ - assert(false); \ - } -#endif +#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ + if (threads_per_head == 4) { \ + LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ + } else if (threads_per_head == 8) { \ + LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ + } else if (threads_per_head == 16) { \ + LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ + } else if (threads_per_head == 32) { \ + LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ + } else { \ + assert(false); \ + } template -void launch_apply_rotary_pos_emb(T* mixed_query, - T* key_layer, - unsigned head_size, - unsigned seq_len, - unsigned rotary_dim, - unsigned offset, - unsigned num_heads, - unsigned batch, - float rope_theta, - dpct::queue_ptr stream, - int max_out_tokens) -{ - const int half_dim = rotary_dim >> 1; - - int alignment = sizeof(T); - if (half_dim % (16 / sizeof(T)) == 0) { - alignment = 16; - } else if (half_dim % (8 / sizeof(T)) == 0) { - alignment = 8; - } else if (half_dim % (4 / sizeof(T)) == 0) { - alignment = 4; - } else { - assert(false); - } - const int T_per_elem = alignment / sizeof(T); +void launch_apply_rotary_pos_emb( + T* mixed_query, + T* key_layer, + unsigned head_size, + unsigned seq_len, + unsigned rotary_dim, + unsigned 
offset, + unsigned num_heads, + unsigned batch, + float rope_theta, + dpct::queue_ptr stream, + int max_out_tokens) { + const int half_dim = rotary_dim >> 1; + + int alignment = sizeof(T); + if (half_dim % (16 / sizeof(T)) == 0) { + alignment = 16; + } else if (half_dim % (8 / sizeof(T)) == 0) { + alignment = 8; + } else if (half_dim % (4 / sizeof(T)) == 0) { + alignment = 4; + } else { + assert(false); + } + const int T_per_elem = alignment / sizeof(T); - int total_count = batch * num_heads * seq_len; + int total_count = batch * num_heads * seq_len; - const int padded_head_size = next_pow2(head_size); + const int padded_head_size = next_pow2(head_size); - assert(padded_head_size <= hw_warp_size * T_per_elem); + assert(padded_head_size <= hw_warp_size * T_per_elem); - const int threads_per_head = padded_head_size / T_per_elem; - const int heads_per_block = rot_half::threads / threads_per_head; + const int threads_per_head = padded_head_size / T_per_elem; + const int heads_per_block = rot_half::threads / threads_per_head; - sycl::range<3> block(1, 1, rot_half::threads); - sycl::range<3> grid(1, 1, (total_count + heads_per_block - 1) / heads_per_block); + sycl::range<3> block(1, 1, rot_half::threads); + sycl::range<3> grid( + 1, 1, (total_count + heads_per_block - 1) / heads_per_block); - if (alignment == 4) { - LAUNCH_FOR_ALIGNMENT(4); - } else if (alignment == 8) { - LAUNCH_FOR_ALIGNMENT(8); - } else if (alignment == 16) { - LAUNCH_FOR_ALIGNMENT(16); - } else { - assert(false); - } + if (alignment == 4) { + LAUNCH_FOR_ALIGNMENT(4); + } else if (alignment == 8) { + LAUNCH_FOR_ALIGNMENT(8); + } else if (alignment == 16) { + LAUNCH_FOR_ALIGNMENT(16); + } else { + assert(false); + } } -#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \ - template void launch_apply_rotary_pos_emb(T*, \ - T*, \ - unsigned, \ - unsigned, \ - unsigned, \ - unsigned, \ - unsigned, \ - unsigned, \ - float, \ - dpct::queue_ptr, \ - int); +#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \ + template void launch_apply_rotary_pos_emb( \ + T*, \ + T*, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + float, \ + dpct::queue_ptr, \ + int); INSTANTIATE_LAUNCH_ROTARY_POS_EMB(float); #ifdef BF16_AVAILABLE diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/dequantize.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/dequantize.dp.cpp index 26f8ff4..a45fbac 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/dequantize.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/dequantize.dp.cpp @@ -1,12 +1,27 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. 
// SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team +#include #include -#include #include "conversion_utils.h" -#include "inference_cuda_layers.h" +#include "inference_sycl_layers.h" #define MAX_QUANTIZE_GROUPING 1024 @@ -14,14 +29,32 @@ #define loop_unroll_bits 1 template -void dequantize_kernel(T* output, - const int8_t* input, - const float* qscale, - int output_size, - int hidden_dim, - int groups, - int merge_count) -{ +class dequantize_kernel { + private: + T* output; + const int8_t* input; + const float* qscale; + int output_size; + int hidden_dim; + int groups; + int merge_count; + + public: + dequantize_kernel( + T* output, + const int8_t* input, + const float* qscale, + int output_size, + int hidden_dim, + int groups, + int merge_count) + : output(output), + qscale(qscale), + output_size(output_size), + hidden_dim(hidden_dim), + groups(groups), + merge_count(merge_count) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); unsigned merge_hidden = hidden_dim >> merge_count; unsigned quantization_stride = (merge_hidden * output_size) / groups; @@ -30,56 +63,60 @@ void dequantize_kernel(T* output, unsigned tid = item_ct1.get_local_id(2); while (tid < output_size) { - unsigned w_index = bid / merge_hidden; - unsigned q_index = tid + bid * output_size; + unsigned w_index = bid / merge_hidden; + unsigned q_index = tid + bid * output_size; - auto q = input[q_index]; + auto q = input[q_index]; - unsigned merge_hidden_total = w_index * merge_hidden; - unsigned scale_index = - ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) - << merge_count) + - w_index; + unsigned merge_hidden_total = w_index * merge_hidden; + unsigned scale_index = + ((((bid - merge_hidden_total) + tid * merge_hidden) / + quantization_stride) + << merge_count) + + w_index; - float scale_data = qscale[scale_index]; + float scale_data = qscale[scale_index]; - output[q_index] = conversion::to(scale_data * (float)q); - tid += item_ct1.get_local_range(2); + output[q_index] = conversion::to(scale_data * (float)q); + tid += item_ct1.get_local_range(2); } -} + } +}; template -void launch_dequantize(T* output, - const int8_t* input, - const float* qscale, - unsigned output_size, - unsigned hidden_dim, - unsigned groups, - unsigned merge_count, - dpct::queue_ptr stream) -{ - unsigned threads = 1024; - sycl::range<3> block_dims(1, 1, threads); - sycl::range<3> grid_dims(1, 1, hidden_dim); - - /* - DPCT1049:0: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
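As with the rotary functor earlier in this patch, the dequantize_kernel constructor above appears to drop one member from its initializer list: input is a constructor parameter and operator() indexes input[q_index], but only output, qscale, output_size, hidden_dim, groups and merge_count are initialized. Assuming that is an oversight, the constructor would read:

// Sketch with input(input) restored; names match the functor above.
dequantize_kernel(T* output,
                  const int8_t* input,
                  const float* qscale,
                  int output_size,
                  int hidden_dim,
                  int groups,
                  int merge_count)
    : output(output),
      input(input),  // otherwise left uninitialized
      qscale(qscale),
      output_size(output_size),
      hidden_dim(hidden_dim),
      groups(groups),
      merge_count(merge_count) {}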
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - dequantize_kernel( - output, input, qscale, output_size, hidden_dim, groups, merge_count); - }); - } +void launch_dequantize( + T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + unsigned merge_count, + dpct::queue_ptr stream) { + unsigned threads = 1024; + sycl::range<3> block_dims(1, 1, threads); + sycl::range<3> grid_dims(1, 1, hidden_dim); + + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + dequantize_kernel fn( + output, input, qscale, output_size, hidden_dim, groups, merge_count); + stream->parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), fn); + } } #define INSTANTIATE_DEQUANTIZE_MERGE(T) \ - template void launch_dequantize( \ - T*, const int8_t*, const float*, unsigned, unsigned, unsigned, unsigned, dpct::queue_ptr); + template void launch_dequantize( \ + T*, \ + const int8_t*, \ + const float*, \ + unsigned, \ + unsigned, \ + unsigned, \ + unsigned, \ + dpct::queue_ptr); INSTANTIATE_DEQUANTIZE_MERGE(float); #ifdef BF16_AVAILABLE @@ -87,25 +124,35 @@ INSTANTIATE_DEQUANTIZE_MERGE(sycl::ext::oneapi::bfloat16); #endif INSTANTIATE_DEQUANTIZE_MERGE(sycl::half); -void dequantize_kernel(float* output, - const int8_t* input, - const float* qscale, - int hidden_dim, - unsigned merge_hidden, - int cnt) -{ -} - template -void dequantize_kernel(T* output, - const int8_t* input, - const float* qscale, - unsigned hidden_dim, - unsigned merge_hidden, - int cnt) -{ +class dequantize_kernel_2 { + private: + T* output; + const int8_t* input; + const float* qscale; + unsigned hidden_dim; + unsigned merge_hidden; + int cnt; + + public: + dequantize_kernel_2( + T* output, + const int8_t* input, + const float* qscale, + unsigned hidden_dim, + unsigned merge_hidden, + int cnt) + : output(output), + input(input), + qscale(qscale), + hidden_dim(hidden_dim), + merge_hidden(merge_hidden), + cnt(cnt) {} + + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - unsigned bid = item_ct1.get_group(2) * item_ct1.get_group_range(1) + item_ct1.get_group(1); + unsigned bid = item_ct1.get_group(2) * item_ct1.get_group_range(1) + + item_ct1.get_group(1); unsigned tid = item_ct1.get_local_id(2); float local_scale = qscale[item_ct1.get_group(2)]; @@ -117,59 +164,62 @@ void dequantize_kernel(T* output, output_cast += bid * merge_hidden; for (int c = 0; c < cnt; c++) { - if (tid < merge_hidden) { - float q = input_cast[tid]; - int8_t* q_int8 = (int8_t*)&q; - - sycl::float2 q_f; - T* q_h = (T*)&q_f; - - q_h[0] = conversion::to(local_scale * (float)q_int8[0]); - q_h[1] = conversion::to(local_scale * (float)q_int8[1]); - q_h[2] = conversion::to(local_scale * (float)q_int8[2]); - q_h[3] = conversion::to(local_scale * (float)q_int8[3]); - output_cast[tid] = q_f; - tid += item_ct1.get_local_range(2); - } + if (tid < merge_hidden) { + float q = input_cast[tid]; + int8_t* q_int8 = (int8_t*)&q; + + sycl::float2 q_f; + T* q_h = (T*)&q_f; + + q_h[0] = conversion::to(local_scale * (float)q_int8[0]); + q_h[1] = conversion::to(local_scale * (float)q_int8[1]); + q_h[2] = conversion::to(local_scale * (float)q_int8[2]); + q_h[3] = conversion::to(local_scale * (float)q_int8[3]); + output_cast[tid] = q_f; + tid += 
item_ct1.get_local_range(2); + } } -} + } +}; template -void launch_dequantize(T* output, - const int8_t* input, - const float* qscale, - unsigned output_size, - unsigned hidden_dim, - unsigned groups, - dpct::queue_ptr stream) -{ - unsigned threads = 1024; - hidden_dim /= 4; - unsigned thd_cnt = (hidden_dim - 1) / threads + 1; - - assert(output_size % groups == 0); - unsigned blocks = output_size / groups; - - sycl::range<3> block_dims(1, 1, threads); - sycl::range<3> grid_dims(1, blocks, groups); - - /* - DPCT1049:1: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - dequantize_kernel(output, input, qscale, hidden_dim, hidden_dim, thd_cnt); - }); - } +void launch_dequantize( + T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + dpct::queue_ptr stream) { + unsigned threads = 1024; + hidden_dim /= 4; + unsigned thd_cnt = (hidden_dim - 1) / threads + 1; + + assert(output_size % groups == 0); + unsigned blocks = output_size / groups; + + sycl::range<3> block_dims(1, 1, threads); + sycl::range<3> grid_dims(1, blocks, groups); + + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + dequantize_kernel_2 fn( + output, input, qscale, hidden_dim, hidden_dim, thd_cnt); + stream->parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), fn); + } } #define INSTANTIATE_DEQUANTIZE_NO_MERGE(T) \ - template void launch_dequantize( \ - T*, const int8_t*, const float*, unsigned, unsigned, unsigned, dpct::queue_ptr); + template void launch_dequantize( \ + T*, \ + const int8_t*, \ + const float*, \ + unsigned, \ + unsigned, \ + unsigned, \ + dpct::queue_ptr); INSTANTIATE_DEQUANTIZE_NO_MERGE(float); #ifdef BF16_AVAILABLE diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/gelu.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/gelu.dp.cpp index d02eaff..4ea8f64 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/gelu.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/gelu.dp.cpp @@ -1,12 +1,27 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. 
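For readers of the second dequantize path above (dequantize_kernel_2), the non-obvious part is that each float pulled from input_cast is really four packed int8 values, re-exposed through an int8_t alias, scaled, and written back as four T values inside one sycl::float2. A small host-side sketch of that unpack-and-scale step for sycl::half output; the byte values and the scale are made-up examples:

#include <sycl/sycl.hpp>
#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
  // Four quantized int8 values packed into one 32-bit word, as the kernel
  // sees them after its float load.
  std::int8_t packed_bytes[4] = {1, -2, 3, -4};
  float packed_word;
  std::memcpy(&packed_word, packed_bytes, sizeof(packed_word));

  const float scale = 0.5f;  // per-group scale, qscale[...] in the kernel
  const std::int8_t* q_int8 = reinterpret_cast<const std::int8_t*>(&packed_word);

  sycl::half out[4];
  for (int i = 0; i < 4; ++i) out[i] = sycl::half(scale * float(q_int8[i]));

  std::printf("%f %f %f %f\n",
              float(out[0]), float(out[1]), float(out[2]), float(out[3]));
  // prints 0.500000 -1.000000 1.500000 -2.000000
}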
// SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team +#include #include -#include #include "conversion_utils.h" -#include "inference_cuda_layers.h" +#include "inference_sycl_layers.h" #include "memory_access_utils.h" #define MAX_CAP 4 @@ -17,76 +32,91 @@ using __nv_bfloat162 = sycl::half2; #endif -inline float gelu(const float x) -{ - constexpr float sqrt_param = 0.79788456080286535587989211986876f; - constexpr float mul_param = 0.044715; - return x * 0.5f * (1.0f + sycl::tanh(sqrt_param * (x + mul_param * x * x * x))); +inline float gelu(const float x) { + constexpr float sqrt_param = 0.79788456080286535587989211986876f; + constexpr float mul_param = 0.044715; + return x * 0.5f * + (1.0f + sycl::tanh(sqrt_param * (x + mul_param * x * x * x))); } -/* -In-place gelu(biasAdd(x)) for channels last -*/ template -void fused_bias_gelu(T* input, const T* bias, int total_count, int intermediate_size) -{ +class fused_bias_gelu { + private: + T* input; + const T* bias; + int total_count; + int intermediate_size; + + public: + fused_bias_gelu( + T* input, + const T* bias, + int total_count, + int intermediate_size) + : input(input), + bias(bias), + total_count(total_count), + intermediate_size(intermediate_size) {} + void operator()(sycl::nd_item<3>) const { // Input restriction: intermediate_size % vals_per_access == 0 auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); constexpr int granularity = 16; constexpr int values_per_access = granularity / sizeof(T); - const int offset = - (item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2)) * + const int offset = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2)) * values_per_access; if (offset < total_count) { - T data[values_per_access]; - T data_bias[values_per_access]; - mem_access::load_global(data, input + offset); - mem_access::load_global( - data_bias, bias + (offset % intermediate_size), bias != nullptr); + T data[values_per_access]; + T data_bias[values_per_access]; + mem_access::load_global(data, input + offset); + mem_access::load_global( + data_bias, bias + (offset % intermediate_size), bias != nullptr); #pragma unroll - for (int i = 0; i < values_per_access; i++) { - float data_f = conversion::to(data[i]); - float bias_f = conversion::to(data_bias[i]); - data[i] = conversion::to(gelu(data_f + bias_f)); - } + for (int i = 0; i < values_per_access; i++) { + float data_f = conversion::to(data[i]); + float bias_f = conversion::to(data_bias[i]); + data[i] = conversion::to(gelu(data_f + bias_f)); + } - mem_access::store_global(input + offset, data); + mem_access::store_global(input + offset, data); } -} + } +}; template -void launch_bias_gelu(T* input, - const T* bias, - int intermediate_size, - int batch_size, - dpct::queue_ptr stream) -{ - constexpr int threads = 1024; - constexpr int granularity = 16; - - const int total_count = batch_size * intermediate_size; - const int elems_per_block = threads * (granularity / sizeof(T)); - sycl::range<3> block_dims(1, 1, threads); - sycl::range<3> grid_dims(1, 1, (total_count + elems_per_block - 1) / elems_per_block); - - /* - DPCT1049:0: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
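For reference, the gelu() helper above is the standard tanh approximation, gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with sqrt_param = sqrt(2/pi) ~= 0.79788456. A quick host-side sanity check against the exact erf form, purely illustrative and not part of the patch:

#include <cmath>
#include <cstdio>

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
static float gelu_exact(float x) {
  return 0.5f * x * (1.0f + std::erf(x * 0.70710678f));
}

// Tanh approximation matching the device-side gelu() in this file.
static float gelu_tanh(float x) {
  const float c = 0.79788456080286535f;  // sqrt(2 / pi)
  return 0.5f * x * (1.0f + std::tanh(c * (x + 0.044715f * x * x * x)));
}

int main() {
  for (float x = -4.0f; x <= 4.0f; x += 1.0f)
    std::printf("x=%5.1f exact=%9.6f approx=%9.6f\n",
                x, gelu_exact(x), gelu_tanh(x));
}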
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - fused_bias_gelu(input, bias, total_count, intermediate_size); - }); - } +void launch_bias_gelu( + T* input, + const T* bias, + int intermediate_size, + int batch_size, + dpct::queue_ptr stream) { + constexpr int threads = 1024; + constexpr int granularity = 16; + + const int total_count = batch_size * intermediate_size; + const int elems_per_block = threads * (granularity / sizeof(T)); + sycl::range<3> block_dims(1, 1, threads); + sycl::range<3> grid_dims( + 1, 1, (total_count + elems_per_block - 1) / elems_per_block); + + /* + DPCT1049:0: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + fused_bias_gelu fn(input, bias, total_count, intermediate_size); + stream->parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), fn); + } } #define INSTANTIATE_LAUNCH_BIAS_GELU(T) \ - template void launch_bias_gelu(T*, const T*, int, int, dpct::queue_ptr); + template void launch_bias_gelu(T*, const T*, int, int, dpct::queue_ptr); INSTANTIATE_LAUNCH_BIAS_GELU(float) #ifdef BF16_AVAILABLE @@ -98,65 +128,83 @@ INSTANTIATE_LAUNCH_BIAS_GELU(sycl::half) In-place channels-last bias add */ template -void fused_bias_add(T* input, const T* bias, int total_count, int intermediate_size) -{ +class fused_bias_add { + private: + T* input; + const T* bias; + int total_count; + int intermediate_size; + + public: + fused_bias_add( + T* input, + const T* bias, + int total_count, + int intermediate_size) + : input(input), + bias(bias), + total_count(total_count), + intermediate_size(intermediate_size) {} + void operator()(sycl::nd_item<3>) const { // Input restriction: intermediate_size % vals_per_access == 0 auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); constexpr int granularity = 16; constexpr int values_per_access = granularity / sizeof(T); - const int offset = - (item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2)) * + const int offset = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2)) * values_per_access; if (offset < total_count) { - T data[values_per_access]; - T data_bias[values_per_access]; - mem_access::load_global(data, input + offset); - mem_access::load_global( - data_bias, bias + (offset % intermediate_size), bias != nullptr); + T data[values_per_access]; + T data_bias[values_per_access]; + mem_access::load_global(data, input + offset); + mem_access::load_global( + data_bias, bias + (offset % intermediate_size), bias != nullptr); #pragma unroll - for (int i = 0; i < values_per_access; i++) { - float data_f = conversion::to(data[i]); - float bias_f = conversion::to(data_bias[i]); - data[i] = conversion::to(data_f + bias_f); - } + for (int i = 0; i < values_per_access; i++) { + float data_f = conversion::to(data[i]); + float bias_f = conversion::to(data_bias[i]); + data[i] = conversion::to(data_f + bias_f); + } - mem_access::store_global(input + offset, data); + mem_access::store_global(input + offset, data); } -} + } +}; template -void launch_bias_add(T* input, - const T* bias, - int intermediate_size, - int batch_size, - dpct::queue_ptr stream) -{ - constexpr int 
threads = 1024; - constexpr int granularity = 16; - - const int total_count = batch_size * intermediate_size; - const int elems_per_block = threads * (granularity / sizeof(T)); - sycl::range<3> block_dims(1, 1, threads); - sycl::range<3> grid_dims(1, 1, (total_count + elems_per_block - 1) / elems_per_block); - - /* - DPCT1049:1: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - fused_bias_add(input, bias, total_count, intermediate_size); - }); - } +void launch_bias_add( + T* input, + const T* bias, + int intermediate_size, + int batch_size, + dpct::queue_ptr stream) { + constexpr int threads = 1024; + constexpr int granularity = 16; + + const int total_count = batch_size * intermediate_size; + const int elems_per_block = threads * (granularity / sizeof(T)); + sycl::range<3> block_dims(1, 1, threads); + sycl::range<3> grid_dims( + 1, 1, (total_count + elems_per_block - 1) / elems_per_block); + + /* + DPCT1049:1: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + fused_bias_add fn(input, bias, total_count, intermediate_size); + stream->parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), fn); + } } #define INSTANTIATE_LAUNCH_BIAS_ADD(T) \ - template void launch_bias_add(T*, const T*, int, int, dpct::queue_ptr); + template void launch_bias_add(T*, const T*, int, int, dpct::queue_ptr); INSTANTIATE_LAUNCH_BIAS_ADD(float) #ifdef BF16_AVAILABLE @@ -164,183 +212,251 @@ INSTANTIATE_LAUNCH_BIAS_ADD(sycl::ext::oneapi::bfloat16) #endif INSTANTIATE_LAUNCH_BIAS_ADD(sycl::half) -void fused_bias_residual(float* residual, - const float* hidden_state, - const float* attn, - const float* bias, - const float* attn_bias, - const int total_count, - const int intermediate_size, - const float mp_scale, - const bool preln) -{ +template +class fused_bias_residual { + private: + T* residual; + const T* hidden_state; + const T* attn; + const T* bias; + const T* attn_bias; + const int total_count; + const int intermediate_size; + const float mp_scale; + const bool preln; + + public: + fused_bias_residual( + T* residual, + const T* hidden_state, + const T* attn, + const T* bias, + const T* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale, + const bool preln) + : residual(residual), + hidden_state(hidden_state), + attn(attn), + bias(bias), + attn_bias(attn_bias), + total_count(total_count), + intermediate_size(intermediate_size), + mp_scale(mp_scale), + preln(preln) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - sycl::float4* res_fl4_ptr = reinterpret_cast(residual); - const sycl::float4* hs_fl4_ptr = reinterpret_cast(hidden_state); - const sycl::float4* attn_fl4_ptr = reinterpret_cast(attn); - const sycl::float4* bias_fl4_ptr = reinterpret_cast(bias); - const sycl::float4* attn_bias_fl4_ptr = reinterpret_cast(attn_bias); - const int offset = - item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); 
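The launch math in launch_bias_gelu and launch_bias_add above sizes the grid so that each work-item handles one 16-byte vector of T. As a worked example under assumed sizes (sycl::half, intermediate_size 4096, batch 8, none of which come from the patch): values_per_access = 16 / 2 = 8, elems_per_block = 1024 * 8 = 8192, total_count = 8 * 4096 = 32768, giving four work-groups along dimension 2.

#include <cstdio>

// Host-side reproduction of the grid sizing used by launch_bias_gelu /
// launch_bias_add for an assumed example shape.
int main() {
  constexpr int threads = 1024;
  constexpr int granularity = 16;          // bytes moved per vectorized access
  constexpr int sizeof_T = 2;              // sycl::half
  constexpr int intermediate_size = 4096;  // assumed example width
  constexpr int batch_size = 8;            // assumed example batch

  const int total_count = batch_size * intermediate_size;
  const int elems_per_block = threads * (granularity / sizeof_T);
  const int groups = (total_count + elems_per_block - 1) / elems_per_block;
  std::printf("work-groups along dim 2: %d\n", groups);  // prints 4
}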
+ using T2 = typename std::conditional< + std::is_same::value, + sycl::half2, + sycl::marray>::type; + sycl::float2* res_fl2_ptr = reinterpret_cast(residual); + const sycl::float2* hs_fl2_ptr = + reinterpret_cast(hidden_state); + const sycl::float2* attn_fl2_ptr = + reinterpret_cast(attn); + const sycl::float2* bias_fl2_ptr = + reinterpret_cast(bias); + const sycl::float2* attn_bias_fl2_ptr = + reinterpret_cast(attn_bias); + const int offset = item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2); if (offset < total_count) { - sycl::float4 res_fl4 = res_fl4_ptr[offset]; - const sycl::float4 hs_fl4 = hs_fl4_ptr[offset]; - const sycl::float4 attn_fl4 = attn_fl4_ptr[offset]; - const sycl::float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size]; - const sycl::float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size]; - if (preln) { - // residual = (residual + attention + bias + attention_bias) * - // mp_scale + hidden_state - res_fl4.x() = - (res_fl4.x() + attn_fl4.x() + bias_fl4.x() + attn_bias_fl4.x()) * mp_scale + - (hs_fl4.x()); - res_fl4.y() = - (res_fl4.y() + attn_fl4.y() + bias_fl4.y() + attn_bias_fl4.y()) * mp_scale + - (hs_fl4.y()); - res_fl4.z() = - (res_fl4.z() + attn_fl4.z() + bias_fl4.z() + attn_bias_fl4.z()) * mp_scale + - (hs_fl4.z()); - res_fl4.w() = - (res_fl4.w() + attn_fl4.w() + bias_fl4.w() + attn_bias_fl4.w()) * mp_scale + - (hs_fl4.w()); - } else { - // residual += hidden_state + bias - res_fl4.x() = res_fl4.x() + hs_fl4.x() + bias_fl4.x(); - res_fl4.y() = res_fl4.y() + hs_fl4.y() + bias_fl4.y(); - res_fl4.z() = res_fl4.z() + hs_fl4.z() + bias_fl4.z(); - res_fl4.w() = res_fl4.w() + hs_fl4.w() + bias_fl4.w(); - } - res_fl4_ptr[offset] = res_fl4; + sycl::float2 res_fl2 = res_fl2_ptr[offset]; + const sycl::float2 hs_fl2 = hs_fl2_ptr[offset]; + const sycl::float2 attn_fl2 = attn_fl2_ptr[offset]; + const sycl::float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; + const sycl::float2 attn_bias_fl2 = + attn_bias_fl2_ptr[offset % intermediate_size]; + + T2* res_half2 = reinterpret_cast(&res_fl2); + const T2* hs_half2 = reinterpret_cast(&hs_fl2); + const T2* attn_half2 = reinterpret_cast(&attn_fl2); + const T2* bias_half2 = reinterpret_cast(&bias_fl2); + const T2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); + + sycl::float2 res_low = conversion::to(res_half2[0]); + sycl::float2 res_high = conversion::to(res_half2[1]); + + const sycl::float2 hs_low = conversion::to(hs_half2[0]); + const sycl::float2 hs_high = conversion::to(hs_half2[1]); + + const sycl::float2 attn_low = conversion::to(attn_half2[0]); + const sycl::float2 attn_high = + conversion::to(attn_half2[1]); + + const sycl::float2 bias_low = conversion::to(bias_half2[0]); + const sycl::float2 bias_high = + conversion::to(bias_half2[1]); + + const sycl::float2 attn_bias_low = + conversion::to(attn_bias_half2[0]); + const sycl::float2 attn_bias_high = + conversion::to(attn_bias_half2[1]); + + if (preln) { + // residual = (residual + attention + bias + attention_bias) * + // mp_scale + hidden_state + res_low.x() = + (res_low.x() + attn_low.x() + bias_low.x() + attn_bias_low.x()) * + mp_scale + + hs_low.x(); + res_low.y() = + (res_low.y() + attn_low.y() + bias_low.y() + attn_bias_low.y()) * + mp_scale + + hs_low.y(); + res_high.x() = (res_high.x() + attn_high.x() + bias_high.x() + + attn_bias_high.x()) * + mp_scale + + hs_high.x(); + res_high.y() = (res_high.y() + attn_high.y() + bias_high.y() + + attn_bias_high.y()) * + mp_scale + + hs_high.y(); + } else { + // residual 
+= hidden_state + bias + res_low.x() = (res_low.x() + hs_low.x() + bias_low.x()); + res_low.y() = (res_low.y() + hs_low.y() + bias_low.y()); + res_high.x() = (res_high.x() + hs_high.x() + bias_high.x()); + res_high.y() = (res_high.y() + hs_high.y() + bias_high.y()); + } + res_half2[0] = conversion::to(res_low); + res_half2[1] = conversion::to(res_high); + + res_fl2_ptr[offset] = res_fl2; } -} - -template -/* -DPCT1110:2: The total declared local variable size in device function fused_bias_residual exceeds -128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total -register size available and adjust the code, or use smaller sub-group size to avoid high register -pressure. -*/ -void fused_bias_residual(T* residual, - const T* hidden_state, - const T* attn, - const T* bias, - const T* attn_bias, - const int total_count, - const int intermediate_size, - const float mp_scale, - const bool preln) -{ + } +}; + +template <> +class fused_bias_residual { + private: + float* residual; + const float* hidden_state; + const float* attn; + const float* bias; + const float* attn_bias; + const int total_count; + const int intermediate_size; + const float mp_scale; + const bool preln; + + public: + fused_bias_residual( + float* residual, + const float* hidden_state, + const float* attn, + const float* bias, + const float* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale, + const bool preln) + : residual(residual), + hidden_state(hidden_state), + attn(attn), + bias(bias), + attn_bias(attn_bias), + total_count(total_count), + intermediate_size(intermediate_size), + mp_scale(mp_scale), + preln(preln) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - using T2 = typename std::conditional::value, - sycl::half2, - sycl::marray>::type; - sycl::float2* res_fl2_ptr = reinterpret_cast(residual); - const sycl::float2* hs_fl2_ptr = reinterpret_cast(hidden_state); - const sycl::float2* attn_fl2_ptr = reinterpret_cast(attn); - const sycl::float2* bias_fl2_ptr = reinterpret_cast(bias); - const sycl::float2* attn_bias_fl2_ptr = reinterpret_cast(attn_bias); - const int offset = - item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + sycl::float4* res_fl4_ptr = reinterpret_cast(residual); + const sycl::float4* hs_fl4_ptr = + reinterpret_cast(hidden_state); + const sycl::float4* attn_fl4_ptr = + reinterpret_cast(attn); + const sycl::float4* bias_fl4_ptr = + reinterpret_cast(bias); + const sycl::float4* attn_bias_fl4_ptr = + reinterpret_cast(attn_bias); + const int offset = item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2); if (offset < total_count) { - sycl::float2 res_fl2 = res_fl2_ptr[offset]; - const sycl::float2 hs_fl2 = hs_fl2_ptr[offset]; - const sycl::float2 attn_fl2 = attn_fl2_ptr[offset]; - const sycl::float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; - const sycl::float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size]; - - T2* res_half2 = reinterpret_cast(&res_fl2); - const T2* hs_half2 = reinterpret_cast(&hs_fl2); - const T2* attn_half2 = reinterpret_cast(&attn_fl2); - const T2* bias_half2 = reinterpret_cast(&bias_fl2); - const T2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); - - sycl::float2 res_low = conversion::to(res_half2[0]); - sycl::float2 res_high = conversion::to(res_half2[1]); - - const sycl::float2 hs_low = conversion::to(hs_half2[0]); - const sycl::float2 hs_high = 
conversion::to(hs_half2[1]); - - const sycl::float2 attn_low = conversion::to(attn_half2[0]); - const sycl::float2 attn_high = conversion::to(attn_half2[1]); - - const sycl::float2 bias_low = conversion::to(bias_half2[0]); - const sycl::float2 bias_high = conversion::to(bias_half2[1]); - - const sycl::float2 attn_bias_low = conversion::to(attn_bias_half2[0]); - const sycl::float2 attn_bias_high = conversion::to(attn_bias_half2[1]); - - if (preln) { - // residual = (residual + attention + bias + attention_bias) * - // mp_scale + hidden_state - res_low.x() = - (res_low.x() + attn_low.x() + bias_low.x() + attn_bias_low.x()) * mp_scale + - hs_low.x(); - res_low.y() = - (res_low.y() + attn_low.y() + bias_low.y() + attn_bias_low.y()) * mp_scale + - hs_low.y(); - res_high.x() = - (res_high.x() + attn_high.x() + bias_high.x() + attn_bias_high.x()) * mp_scale + - hs_high.x(); - res_high.y() = - (res_high.y() + attn_high.y() + bias_high.y() + attn_bias_high.y()) * mp_scale + - hs_high.y(); - } else { - // residual += hidden_state + bias - res_low.x() = (res_low.x() + hs_low.x() + bias_low.x()); - res_low.y() = (res_low.y() + hs_low.y() + bias_low.y()); - res_high.x() = (res_high.x() + hs_high.x() + bias_high.x()); - res_high.y() = (res_high.y() + hs_high.y() + bias_high.y()); - } - res_half2[0] = conversion::to(res_low); - res_half2[1] = conversion::to(res_high); - - res_fl2_ptr[offset] = res_fl2; + sycl::float4 res_fl4 = res_fl4_ptr[offset]; + const sycl::float4 hs_fl4 = hs_fl4_ptr[offset]; + const sycl::float4 attn_fl4 = attn_fl4_ptr[offset]; + const sycl::float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size]; + const sycl::float4 attn_bias_fl4 = + attn_bias_fl4_ptr[offset % intermediate_size]; + if (preln) { + // residual = (residual + attention + bias + attention_bias) * + // mp_scale + hidden_state + res_fl4.x() = + (res_fl4.x() + attn_fl4.x() + bias_fl4.x() + attn_bias_fl4.x()) * + mp_scale + + (hs_fl4.x()); + res_fl4.y() = + (res_fl4.y() + attn_fl4.y() + bias_fl4.y() + attn_bias_fl4.y()) * + mp_scale + + (hs_fl4.y()); + res_fl4.z() = + (res_fl4.z() + attn_fl4.z() + bias_fl4.z() + attn_bias_fl4.z()) * + mp_scale + + (hs_fl4.z()); + res_fl4.w() = + (res_fl4.w() + attn_fl4.w() + bias_fl4.w() + attn_bias_fl4.w()) * + mp_scale + + (hs_fl4.w()); + } else { + // residual += hidden_state + bias + res_fl4.x() = res_fl4.x() + hs_fl4.x() + bias_fl4.x(); + res_fl4.y() = res_fl4.y() + hs_fl4.y() + bias_fl4.y(); + res_fl4.z() = res_fl4.z() + hs_fl4.z() + bias_fl4.z(); + res_fl4.w() = res_fl4.w() + hs_fl4.w() + bias_fl4.w(); + } + res_fl4_ptr[offset] = res_fl4; } -} + } +}; template -void launch_bias_residual(T* residual, - T* hidden_state, - T* attn, - T* bias, - T* attn_bias, - int batch, - int hidden_dim, - int mp_size, - bool preln, - dpct::queue_ptr stream) -{ - int total_count = batch * hidden_dim / 4; - sycl::range<3> block_dims(1, 1, 1024); - sycl::range<3> grid_dims(1, 1, (total_count - 1) / 1024 + 1); // (batch_size); - - /* - DPCT1049:3: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
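The generic fused_bias_residual path above keeps the float4-era memory layout but reads one sycl::float2 per work-item, reinterprets it as two packed T2 pairs (sycl::half2, or a bfloat16 marray), widens each pair to float for the arithmetic, then narrows and stores the same eight bytes back. A minimal host-side sketch of that widen/compute/narrow round trip for sycl::half2, using sycl::vec::convert instead of the project's conversion_utils helpers; the bias value is a made-up example:

#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
  // One sycl::float2 of storage holds four sycl::half values, exactly as the
  // kernel's res_fl2 / res_half2 aliasing assumes.
  sycl::float2 storage;
  sycl::half2* as_half2 = reinterpret_cast<sycl::half2*>(&storage);
  as_half2[0] = sycl::half2(sycl::half(1.0f), sycl::half(2.0f));
  as_half2[1] = sycl::half2(sycl::half(3.0f), sycl::half(4.0f));

  const float bias = 0.25f;
  for (int pair = 0; pair < 2; ++pair) {
    // Widen to float, do the math in float, narrow back to half.
    sycl::float2 wide = as_half2[pair].convert<float>();
    wide += sycl::float2(bias, bias);
    as_half2[pair] = wide.convert<sycl::half>();
  }

  const sycl::half h0 = as_half2[0].x();
  const sycl::half h1 = as_half2[0].y();
  std::printf("%f %f\n", float(h0), float(h1));  // prints 1.250000 2.250000
}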
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - fused_bias_residual(residual, - hidden_state, - attn, - bias, - attn_bias, - total_count, - hidden_dim / 4, - 1.0 / mp_size, - preln); - }); - } +void launch_bias_residual( + T* residual, + T* hidden_state, + T* attn, + T* bias, + T* attn_bias, + int batch, + int hidden_dim, + int mp_size, + bool preln, + dpct::queue_ptr stream) { + int total_count = batch * hidden_dim / 4; + sycl::range<3> block_dims(1, 1, 1024); + sycl::range<3> grid_dims(1, 1, (total_count - 1) / 1024 + 1); // (batch_size); + + /* + DPCT1049:3: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + fused_bias_residual fn( + residual, + hidden_state, + attn, + bias, + attn_bias, + total_count, + hidden_dim / 4, + 1.0 / mp_size, + preln); + stream->parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), fn); + } } #define INSTANTIATE_LAUNCH_BIAS_RESIDUAL(T) \ - template void launch_bias_residual(T*, T*, T*, T*, T*, int, int, int, bool, dpct::queue_ptr); + template void launch_bias_residual( \ + T*, T*, T*, T*, T*, int, int, int, bool, dpct::queue_ptr); INSTANTIATE_LAUNCH_BIAS_RESIDUAL(float); #ifdef BF16_AVAILABLE @@ -348,161 +464,226 @@ INSTANTIATE_LAUNCH_BIAS_RESIDUAL(sycl::ext::oneapi::bfloat16); #endif INSTANTIATE_LAUNCH_BIAS_RESIDUAL(sycl::half); -void gptj_residual_add(float* residual, - const float* hidden_state, - const float* attn, - const float* bias, - const float* attn_bias, - const int total_count, - const int intermediate_size, - const float mp_scale) -{ +template +class gptj_residual_add { + private: + T* residual; + const T* hidden_state; + const T* attn; + const T* bias; + const T* attn_bias; + const int total_count; + const int intermediate_size; + const float mp_scale; + + public: + gptj_residual_add( + T* residual, + const T* hidden_state, + const T* attn, + const T* bias, + const T* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale) + : residual(residual), + hidden_state(hidden_state), + attn(attn), + bias(bias), + attn_bias(attn_bias), + total_count(total_count), + intermediate_size(intermediate_size), + mp_scale(mp_scale) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - sycl::float4* res_fl4_ptr = reinterpret_cast(residual); - const sycl::float4* hs_fl4_ptr = reinterpret_cast(hidden_state); - const sycl::float4* attn_fl4_ptr = reinterpret_cast(attn); - const sycl::float4* bias_fl4_ptr = reinterpret_cast(bias); - const sycl::float4* attn_bias_fl4_ptr = reinterpret_cast(attn_bias); - const int offset = - item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + using T2 = typename std::conditional< + std::is_same::value, + sycl::half2, + sycl::marray>::type; + sycl::float2* res_fl2_ptr = reinterpret_cast(residual); + const sycl::float2* hs_fl2_ptr = + reinterpret_cast(hidden_state); + const sycl::float2* attn_fl2_ptr = + reinterpret_cast(attn); + const sycl::float2* bias_fl2_ptr = + reinterpret_cast(bias); + const sycl::float2* attn_bias_fl2_ptr = + reinterpret_cast(attn_bias); + const int offset = 
item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2); if (offset < total_count) { - sycl::float4 res_fl4 = res_fl4_ptr[offset]; - const sycl::float4 hs_fl4 = hs_fl4_ptr[offset]; - const sycl::float4 attn_fl4 = attn_fl4_ptr[offset]; - const sycl::float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size]; - - if (attn_bias) { - sycl::float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size]; - // residual += attention_bias - res_fl4.x() += attn_bias_fl4.x(); - res_fl4.y() += attn_bias_fl4.y(); - res_fl4.z() += attn_bias_fl4.z(); - res_fl4.w() += attn_bias_fl4.w(); - } - // residual = hidden_state + attention + (residual + bias) * mp_scale - res_fl4.x() = hs_fl4.x() + attn_fl4.x() + (res_fl4.x() + bias_fl4.x()) * mp_scale; - res_fl4.y() = hs_fl4.y() + attn_fl4.y() + (res_fl4.y() + bias_fl4.y()) * mp_scale; - res_fl4.z() = hs_fl4.z() + attn_fl4.z() + (res_fl4.z() + bias_fl4.z()) * mp_scale; - res_fl4.w() = hs_fl4.w() + attn_fl4.w() + (res_fl4.w() + bias_fl4.w()) * mp_scale; + sycl::float2 res_fl2 = res_fl2_ptr[offset]; + const sycl::float2 hs_fl2 = hs_fl2_ptr[offset]; + const sycl::float2 attn_fl2 = attn_fl2_ptr[offset]; + const sycl::float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; - res_fl4_ptr[offset] = res_fl4; - } -} + T2* res_half2 = reinterpret_cast(&res_fl2); + const T2* hs_half2 = reinterpret_cast(&hs_fl2); + const T2* attn_half2 = reinterpret_cast(&attn_fl2); + const T2* bias_half2 = reinterpret_cast(&bias_fl2); -template -/* -DPCT1110:4: The total declared local variable size in device function gptj_residual_add exceeds 128 -bytes and may cause high register pressure. Consult with your hardware vendor to find the total -register size available and adjust the code, or use smaller sub-group size to avoid high register -pressure. 
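Stated as scalar arithmetic, the gptj_residual_add variants above and the float specialization that follows compute, per element: add attention_bias into residual when it is non-null, then residual = hidden_state + attention + (residual + bias) * mp_scale, where mp_scale is 1.0 / mp_size from the launcher. A scalar reference of that update, illustrative only (the kernels do it on packed half2/float4 vectors, and the example values are made up):

#include <cstdio>

// Per-element reference for the gptj residual update.
static float gptj_residual_ref(float residual, float hidden_state, float attn,
                               float bias, float attn_bias, bool has_attn_bias,
                               float mp_scale) {
  if (has_attn_bias) residual += attn_bias;
  return hidden_state + attn + (residual + bias) * mp_scale;
}

int main() {
  // With mp_scale = 0.5: 1.0 + 2.0 + (0.5 + 0.25 + 0.125) * 0.5 = 3.4375
  std::printf("%f\n",
              gptj_residual_ref(0.5f, 1.0f, 2.0f, 0.125f, 0.25f, true, 0.5f));
}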
-*/ -void gptj_residual_add(T* residual, - const T* hidden_state, - const T* attn, - const T* bias, - const T* attn_bias, - const int total_count, - const int intermediate_size, - const float mp_scale) -{ - auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - using T2 = typename std::conditional::value, - sycl::half2, - sycl::marray>::type; - sycl::float2* res_fl2_ptr = reinterpret_cast(residual); - const sycl::float2* hs_fl2_ptr = reinterpret_cast(hidden_state); - const sycl::float2* attn_fl2_ptr = reinterpret_cast(attn); - const sycl::float2* bias_fl2_ptr = reinterpret_cast(bias); - const sycl::float2* attn_bias_fl2_ptr = reinterpret_cast(attn_bias); - const int offset = - item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + sycl::float2 res_low = conversion::to(res_half2[0]); + sycl::float2 res_high = conversion::to(res_half2[1]); - if (offset < total_count) { - sycl::float2 res_fl2 = res_fl2_ptr[offset]; - const sycl::float2 hs_fl2 = hs_fl2_ptr[offset]; - const sycl::float2 attn_fl2 = attn_fl2_ptr[offset]; - const sycl::float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; - - T2* res_half2 = reinterpret_cast(&res_fl2); - const T2* hs_half2 = reinterpret_cast(&hs_fl2); - const T2* attn_half2 = reinterpret_cast(&attn_fl2); - const T2* bias_half2 = reinterpret_cast(&bias_fl2); - - sycl::float2 res_low = conversion::to(res_half2[0]); - sycl::float2 res_high = conversion::to(res_half2[1]); - - const sycl::float2 hs_low = conversion::to(hs_half2[0]); - const sycl::float2 hs_high = conversion::to(hs_half2[1]); - - const sycl::float2 attn_low = conversion::to(attn_half2[0]); - const sycl::float2 attn_high = conversion::to(attn_half2[1]); - - const sycl::float2 bias_low = conversion::to(bias_half2[0]); - const sycl::float2 bias_high = conversion::to(bias_half2[1]); - - if (attn_bias) { - const sycl::float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size]; - const T2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); - const sycl::float2 attn_bias_low = conversion::to(attn_bias_half2[0]); - const sycl::float2 attn_bias_high = conversion::to(attn_bias_half2[1]); - // residual += attention_bias - res_low.x() += attn_bias_low.x(); - res_low.y() += attn_bias_low.y(); - res_high.x() += attn_bias_high.x(); - res_high.y() += attn_bias_high.y(); - } - // residual = hidden_state + attention + (residual + bias) * mp_scale - res_low.x() = attn_low.x() + hs_low.x() + (res_low.x() + bias_low.x()) * mp_scale; - res_low.y() = attn_low.y() + hs_low.y() + (res_low.y() + bias_low.y()) * mp_scale; - res_high.x() = attn_high.x() + hs_high.x() + (res_high.x() + bias_high.x()) * mp_scale; - res_high.y() = attn_high.y() + hs_high.y() + (res_high.y() + bias_high.y()) * mp_scale; + const sycl::float2 hs_low = conversion::to(hs_half2[0]); + const sycl::float2 hs_high = conversion::to(hs_half2[1]); - res_half2[0] = conversion::to(res_low); - res_half2[1] = conversion::to(res_high); + const sycl::float2 attn_low = conversion::to(attn_half2[0]); + const sycl::float2 attn_high = + conversion::to(attn_half2[1]); - res_fl2_ptr[offset] = res_fl2; + const sycl::float2 bias_low = conversion::to(bias_half2[0]); + const sycl::float2 bias_high = + conversion::to(bias_half2[1]); + + if (attn_bias) { + const sycl::float2 attn_bias_fl2 = + attn_bias_fl2_ptr[offset % intermediate_size]; + const T2* attn_bias_half2 = reinterpret_cast(&attn_bias_fl2); + const sycl::float2 attn_bias_low = + conversion::to(attn_bias_half2[0]); + const sycl::float2 attn_bias_high = + 
conversion::to(attn_bias_half2[1]); + // residual += attention_bias + res_low.x() += attn_bias_low.x(); + res_low.y() += attn_bias_low.y(); + res_high.x() += attn_bias_high.x(); + res_high.y() += attn_bias_high.y(); + } + // residual = hidden_state + attention + (residual + bias) * mp_scale + res_low.x() = + attn_low.x() + hs_low.x() + (res_low.x() + bias_low.x()) * mp_scale; + res_low.y() = + attn_low.y() + hs_low.y() + (res_low.y() + bias_low.y()) * mp_scale; + res_high.x() = attn_high.x() + hs_high.x() + + (res_high.x() + bias_high.x()) * mp_scale; + res_high.y() = attn_high.y() + hs_high.y() + + (res_high.y() + bias_high.y()) * mp_scale; + + res_half2[0] = conversion::to(res_low); + res_half2[1] = conversion::to(res_high); + + res_fl2_ptr[offset] = res_fl2; } -} + } +}; + +template <> +class gptj_residual_add { + private: + float* residual; + const float* hidden_state; + const float* attn; + const float* bias; + const float* attn_bias; + const int total_count; + const int intermediate_size; + const float mp_scale; + + public: + gptj_residual_add( + float* residual, + const float* hidden_state, + const float* attn, + const float* bias, + const float* attn_bias, + const int total_count, + const int intermediate_size, + const float mp_scale) + : residual(residual), + hidden_state(hidden_state), + attn(attn), + bias(bias), + attn_bias(attn_bias), + total_count(total_count), + intermediate_size(intermediate_size), + mp_scale(mp_scale) {} + + void operator()(sycl::nd_item<3>) const { + auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); + sycl::float4* res_fl4_ptr = reinterpret_cast(residual); + const sycl::float4* hs_fl4_ptr = + reinterpret_cast(hidden_state); + const sycl::float4* attn_fl4_ptr = + reinterpret_cast(attn); + const sycl::float4* bias_fl4_ptr = + reinterpret_cast(bias); + const sycl::float4* attn_bias_fl4_ptr = + reinterpret_cast(attn_bias); + const int offset = item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2); -template -void launch_gptj_residual_add(T* residual, - T* hidden_state, - T* attn, - T* bias, - T* attn_bias, - int hidden_dim, - int batch, - int mp_size, - dpct::queue_ptr stream) -{ - int total_count = batch * hidden_dim / 4; - sycl::range<3> block_dims(1, 1, 1024); - sycl::range<3> grid_dims(1, 1, (total_count - 1) / 1024 + 1); // (batch_size); - - /* - DPCT1049:5: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - gptj_residual_add(residual, - hidden_state, - attn, - bias, - attn_bias, - total_count, - hidden_dim / 4, - 1.0 / mp_size); - }); + if (offset < total_count) { + sycl::float4 res_fl4 = res_fl4_ptr[offset]; + const sycl::float4 hs_fl4 = hs_fl4_ptr[offset]; + const sycl::float4 attn_fl4 = attn_fl4_ptr[offset]; + const sycl::float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size]; + + if (attn_bias) { + sycl::float4 attn_bias_fl4 = + attn_bias_fl4_ptr[offset % intermediate_size]; + // residual += attention_bias + res_fl4.x() += attn_bias_fl4.x(); + res_fl4.y() += attn_bias_fl4.y(); + res_fl4.z() += attn_bias_fl4.z(); + res_fl4.w() += attn_bias_fl4.w(); + } + // residual = hidden_state + attention + (residual + bias) * mp_scale + res_fl4.x() = + hs_fl4.x() + attn_fl4.x() + (res_fl4.x() + bias_fl4.x()) * mp_scale; + res_fl4.y() = + hs_fl4.y() + attn_fl4.y() + (res_fl4.y() + bias_fl4.y()) * mp_scale; + res_fl4.z() = + hs_fl4.z() + attn_fl4.z() + (res_fl4.z() + bias_fl4.z()) * mp_scale; + res_fl4.w() = + hs_fl4.w() + attn_fl4.w() + (res_fl4.w() + bias_fl4.w()) * mp_scale; + + res_fl4_ptr[offset] = res_fl4; } + } +}; + +template +void launch_gptj_residual_add( + T* residual, + T* hidden_state, + T* attn, + T* bias, + T* attn_bias, + int hidden_dim, + int batch, + int mp_size, + dpct::queue_ptr stream) { + int total_count = batch * hidden_dim / 4; + sycl::range<3> block_dims(1, 1, 1024); + sycl::range<3> grid_dims(1, 1, (total_count - 1) / 1024 + 1); // (batch_size); + + /* + DPCT1049:5: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + gptj_residual_add fn( + residual, + hidden_state, + attn, + bias, + attn_bias, + total_count, + hidden_dim / 4, + 1.0 / mp_size); + stream->parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), fn); + } } -#define INSTANTIATE_GPT_RES_ADD(T) \ - template void launch_gptj_residual_add(T*, T*, T*, T*, T*, int, int, int, dpct::queue_ptr); +#define INSTANTIATE_GPT_RES_ADD(T) \ + template void launch_gptj_residual_add( \ + T*, T*, T*, T*, T*, int, int, int, dpct::queue_ptr); INSTANTIATE_GPT_RES_ADD(float); INSTANTIATE_GPT_RES_ADD(sycl::half); @@ -511,8 +692,22 @@ INSTANTIATE_GPT_RES_ADD(sycl::ext::oneapi::bfloat16); #endif template -void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim) -{ +class moe_res_matmul { + private: + T* residual; + T* coef; + T* mlp_out; + int seq_len; + int hidden_dim; + + public: + moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim) + : residual(residual), + coef(coef), + mlp_out(mlp_out), + seq_len(seq_len), + hidden_dim(hidden_dim) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); constexpr int granularity = 16; constexpr int vals_per_access = granularity / sizeof(T); @@ -520,52 +715,54 @@ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_di T* residual_seq = residual + item_ct1.get_group(2) * hidden_dim; T* mlp_out_seq = mlp_out + item_ct1.get_group(2) * hidden_dim; - for (unsigned tid = item_ct1.get_local_id(2) * vals_per_access; tid < hidden_dim; + for (unsigned tid = item_ct1.get_local_id(2) * vals_per_access; + tid < hidden_dim; tid += item_ct1.get_local_range(2) * vals_per_access) { - T mlp[vals_per_access]; - T res[vals_per_access]; - T coef1[vals_per_access]; - T coef2[vals_per_access]; + T mlp[vals_per_access]; + T res[vals_per_access]; + T coef1[vals_per_access]; + T coef2[vals_per_access]; - mem_access::load_global(mlp, mlp_out_seq + tid); - mem_access::load_global(res, residual_seq + tid); - mem_access::load_global(coef1, coef + tid); - mem_access::load_global(coef2, coef + tid + hidden_dim); + mem_access::load_global(mlp, mlp_out_seq + tid); + mem_access::load_global(res, residual_seq + tid); + mem_access::load_global(coef1, coef + tid); + mem_access::load_global(coef2, coef + tid + hidden_dim); #pragma unroll - for (int idx = 0; idx < vals_per_access; idx++) { - mlp[idx] = mlp[idx] * coef2[idx] + res[idx] * coef1[idx]; - } + for (int idx = 0; idx < vals_per_access; idx++) { + mlp[idx] = mlp[idx] * coef2[idx] + res[idx] * coef1[idx]; + } - mem_access::store_global(mlp_out_seq + tid, mlp); + mem_access::store_global(mlp_out_seq + tid, mlp); } -} + } +}; template -void launch_moe_res_matmul(T* residual, - T* coef, - T* mlp_out, - int seq_len, - int hidden_dim, - dpct::queue_ptr stream) -{ - sycl::range<3> grid_dim(1, 1, seq_len); - sycl::range<3> block_dim(1, 1, 1024); - /* - DPCT1049:6: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
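Despite the name, the moe_res_matmul functor above is an elementwise blend rather than a matmul: for each sequence position it computes mlp_out[i] = mlp_out[i] * coef2[i] + residual[i] * coef1[i], where coef1 is the first hidden_dim entries of coef and coef2 the next hidden_dim entries, moving 16 bytes per access through the mem_access helpers. A host-side reference of that blend with made-up sizes and values:

#include <vector>
#include <cstdio>

int main() {
  const int hidden_dim = 4;  // assumed tiny example width
  std::vector<float> residual = {1, 2, 3, 4};
  std::vector<float> mlp_out = {10, 20, 30, 40};
  std::vector<float> coef = {0.5f, 0.5f, 0.5f, 0.5f,   // coef1
                             2.0f, 2.0f, 2.0f, 2.0f};  // coef2

  for (int i = 0; i < hidden_dim; ++i)
    mlp_out[i] = mlp_out[i] * coef[hidden_dim + i] + residual[i] * coef[i];

  for (int i = 0; i < hidden_dim; ++i) std::printf("%g ", mlp_out[i]);
  std::printf("\n");  // prints 20.5 41 61.5 82
}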
- */ - { - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), - [=](sycl::nd_item<3> item_ct1) { - moe_res_matmul(residual, coef, mlp_out, seq_len, hidden_dim); - }); - } +void launch_moe_res_matmul( + T* residual, + T* coef, + T* mlp_out, + int seq_len, + int hidden_dim, + dpct::queue_ptr stream) { + sycl::range<3> grid_dim(1, 1, seq_len); + sycl::range<3> block_dim(1, 1, 1024); + /* + DPCT1049:6: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + moe_res_matmul fn(residual, coef, mlp_out, seq_len, hidden_dim); + stream->parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), fn); + } } #define INSTANTIATE_LAUNCH_MOE_RES_MATMUL(T) \ - template void launch_moe_res_matmul(T*, T*, T*, int, int, dpct::queue_ptr); + template void launch_moe_res_matmul(T*, T*, T*, int, int, dpct::queue_ptr); INSTANTIATE_LAUNCH_MOE_RES_MATMUL(float); #ifdef BF16_AVAILABLE @@ -574,15 +771,34 @@ INSTANTIATE_LAUNCH_MOE_RES_MATMUL(sycl::ext::oneapi::bfloat16); INSTANTIATE_LAUNCH_MOE_RES_MATMUL(sycl::half); template -void pad_data_kernel(T* padded_output, T* output, int head_size, int padded_head_size) -{ +class pad_data_kernel { + private: + T* padded_output; + T* output; + int head_size; + int padded_head_size; + + public: + pad_data_kernel( + T* padded_output, + T* output, + int head_size, + int padded_head_size) + : padded_output(padded_output), + output(output), + head_size(head_size), + padded_head_size(padded_head_size) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - using T2 = typename std::conditional::value, - sycl::half2, - sycl::marray>::type; - sycl::float4* padded_output_cast = reinterpret_cast(padded_output); + using T2 = typename std::conditional< + std::is_same::value, + sycl::half2, + sycl::marray>::type; + sycl::float4* padded_output_cast = + reinterpret_cast(padded_output); sycl::float4* output_cast = reinterpret_cast(output); - int bid = item_ct1.get_group(2) * (item_ct1.get_local_range(1)) + item_ct1.get_local_id(1); + int bid = item_ct1.get_group(2) * (item_ct1.get_local_range(1)) + + item_ct1.get_local_id(1); int idx = item_ct1.get_local_id(2); padded_output_cast += (bid * padded_head_size); output_cast += (bid * head_size); @@ -590,46 +806,49 @@ void pad_data_kernel(T* padded_output, T* output, int head_size, int padded_head const T2 zero_h = conversion::to(0.f); T2* ZERO_h = reinterpret_cast(&ZERO); #pragma unroll - for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + for (int i = 0; i < 4; i++) + ZERO_h[i] = zero_h; if (idx < head_size) - padded_output_cast[idx] = output_cast[idx]; + padded_output_cast[idx] = output_cast[idx]; else - padded_output_cast[idx] = ZERO; -} + padded_output_cast[idx] = ZERO; + } +}; -void pad_data_kernel(float* padded_output, - float* output, - int head_size, - int padded_head_size) -{ -} +/* void pad_data_kernel(float* padded_output, */ +/* float* output, */ +/* int head_size, */ +/* int padded_head_size) */ +/* { */ +/* } */ template -void pad_data(T* padded_output, - T* output, - int bsz, - int head_size, - int padded_head_size, - dpct::queue_ptr stream) -{ - sycl::range<3> grid_dim(1, 1, (bsz - 1) / 16 + 1); - sycl::range<3> block_dim(1, 16, padded_head_size / 8); - /* - 
DPCT1049:7: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for( - sycl::nd_range<3>(grid_dim * block_dim, block_dim), [=](sycl::nd_item<3> item_ct1) { - pad_data_kernel(padded_output, output, head_size / 8, padded_head_size / 8); - }); - } +void pad_data( + T* padded_output, + T* output, + int bsz, + int head_size, + int padded_head_size, + dpct::queue_ptr stream) { + sycl::range<3> grid_dim(1, 1, (bsz - 1) / 16 + 1); + sycl::range<3> block_dim(1, 16, padded_head_size / 8); + /* + DPCT1049:7: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + pad_data_kernel fn( + padded_output, output, head_size / 8, padded_head_size / 8); + stream->parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), fn); + } } #define INSTANTIATE_PAD_DATA(T) \ - template void pad_data(T*, T*, int, int, int, dpct::queue_ptr stream); + template void pad_data(T*, T*, int, int, int, dpct::queue_ptr stream); INSTANTIATE_PAD_DATA(float); INSTANTIATE_PAD_DATA(sycl::half); @@ -638,21 +857,41 @@ INSTANTIATE_PAD_DATA(sycl::ext::oneapi::bfloat16); #endif template -void pad_head_seq_kernel(T* padded_output, - T* output, - int seq_len, - int padded_seq_len, - int head_size, - int padded_head_size) -{ +class pad_head_seq_kernel { + private: + T* padded_output; + T* output; + int seq_len; + int padded_seq_len; + int head_size; + int padded_head_size; + + public: + pad_head_seq_kernel( + T* padded_output, + T* output, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size) + : padded_output(padded_output), + output(output), + seq_len(seq_len), + padded_seq_len(padded_seq_len), + head_size(head_size), + padded_head_size(padded_head_size) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - using T2 = typename std::conditional::value, - sycl::half2, - sycl::marray>::type; - sycl::float4* padded_output_cast = reinterpret_cast(padded_output); + using T2 = typename std::conditional< + std::is_same::value, + sycl::half2, + sycl::marray>::type; + sycl::float4* padded_output_cast = + reinterpret_cast(padded_output); sycl::float4* output_cast = reinterpret_cast(output); int bsz = item_ct1.get_group(2); - int bid = item_ct1.get_group(1) * (item_ct1.get_local_range(1)) + item_ct1.get_local_id(1); + int bid = item_ct1.get_group(1) * (item_ct1.get_local_range(1)) + + item_ct1.get_local_id(1); int idx = item_ct1.get_local_id(2); padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size; output_cast += (bsz * seq_len + bid) * head_size; @@ -660,56 +899,60 @@ void pad_head_seq_kernel(T* padded_output, const T2 zero_h = conversion::to(0.f); T2* ZERO_h = reinterpret_cast(&ZERO); #pragma unroll - for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + for (int i = 0; i < 4; i++) + ZERO_h[i] = zero_h; if (idx < head_size && bid < seq_len) - padded_output_cast[idx] = output_cast[idx]; + padded_output_cast[idx] = output_cast[idx]; else - padded_output_cast[idx] = ZERO; -} - -void pad_head_seq_kernel(float* padded_output, - float* output, - int seq_len, - int padded_seq_len, 
- int head_size, - int padded_head_size) -{ -} + padded_output_cast[idx] = ZERO; + } +}; + +/* void pad_head_seq_kernel(float* padded_output, */ +/* float* output, */ +/* int seq_len, */ +/* int padded_seq_len, */ +/* int head_size, */ +/* int padded_head_size) */ +/* { */ +/* } */ template -void pad_head_seq(T* padded_output, - T* output, - int bsz, - int seq_len, - int padded_seq_len, - int head_size, - int padded_head_size, - dpct::queue_ptr stream) -{ - sycl::range<3> grid_dim(1, padded_seq_len / 16, bsz); - sycl::range<3> block_dim(1, 16, padded_head_size / 8); - /* - DPCT1049:8: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), - [=](sycl::nd_item<3> item_ct1) { - pad_head_seq_kernel(padded_output, - output, - seq_len, - padded_seq_len, - head_size / 8, - padded_head_size / 8); - }); - } +void pad_head_seq( + T* padded_output, + T* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + dpct::queue_ptr stream) { + sycl::range<3> grid_dim(1, padded_seq_len / 16, bsz); + sycl::range<3> block_dim(1, 16, padded_head_size / 8); + /* + DPCT1049:8: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + pad_head_seq_kernel fn( + padded_output, + output, + seq_len, + padded_seq_len, + head_size / 8, + padded_head_size / 8); + stream->parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), fn); + } } #define INSTANTIATE_PAD_HEAD_SEQ(T) \ - template void pad_head_seq(T*, T*, int, int, int, int, int, dpct::queue_ptr); + template void pad_head_seq( \ + T*, T*, int, int, int, int, int, dpct::queue_ptr); INSTANTIATE_PAD_HEAD_SEQ(sycl::half); #ifdef BF16_AVAILABLE @@ -718,134 +961,143 @@ INSTANTIATE_PAD_HEAD_SEQ(sycl::ext::oneapi::bfloat16); INSTANTIATE_PAD_HEAD_SEQ(float); // TODO(cmikeh2): evaluate different GeLU performance -__dpct_inline__ float old_gelu(float val) -{ - // 1 / sqrt(2) - constexpr float rsqrt_2 = 0.707106769084930419922; - return val * 0.5f * (1.0f + sycl::erf(val * rsqrt_2)); +__dpct_inline__ float old_gelu(float val) { + // 1 / sqrt(2) + constexpr float rsqrt_2 = 0.707106769084930419922; + return val * 0.5f * (1.0f + sycl::erf(val * rsqrt_2)); } namespace fused_geglu { constexpr int threads = 256; constexpr int steps = 2; constexpr int granularity = 16; -} // namespace fused_geglu +} // namespace fused_geglu -__dpct_inline__ float silu(float val) { return val / (1.0f + sycl::native::exp(-val)); } +__dpct_inline__ float silu(float val) { + return val / (1.0f + sycl::native::exp(-val)); +} template -void fused_gate_activation(T* output, - const T* activation, - const T* bias, - int base_channels, - int output_stride, - int total_elems) -{ +class fused_gate_activation { + private: + T* output; + const T* activation; + const T* bias; + int base_channels; + int output_stride; + int total_elems; + + public: + fused_gate_activation( + T* output, + const T* activation, + const T* bias, + int base_channels, + int output_stride, + int total_elems) + : output(output), + activation(activation), + bias(bias), 
+ base_channels(base_channels), + output_stride(output_stride), + total_elems(total_elems) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); constexpr int T_per_access = fused_geglu::granularity / sizeof(T); constexpr int T_per_step = T_per_access * fused_geglu::threads; constexpr int T_per_block = T_per_step * fused_geglu::steps; - const int id = item_ct1.get_group(2) * T_per_block + item_ct1.get_local_id(2) * T_per_access; + const int id = item_ct1.get_group(2) * T_per_block + + item_ct1.get_local_id(2) * T_per_access; #pragma unroll for (int i = 0; i < fused_geglu::steps; i++) { - T activation_buffer_1[T_per_access]; - T activation_buffer_2[T_per_access]; - T bias_buffer_1[T_per_access]; - T bias_buffer_2[T_per_access]; - - const int iter_id = id + T_per_step * i; - if (iter_id < total_elems) { - const int channel_id = iter_id % base_channels; - const int seq_id = iter_id / base_channels; - const int seq_offset = seq_id * base_channels * 2; - - mem_access::load_global(activation_buffer_1, - activation + seq_offset + channel_id); - mem_access::load_global( - activation_buffer_2, activation + seq_offset + channel_id + base_channels); - mem_access::load_global( - bias_buffer_1, bias + channel_id, bias != nullptr); - mem_access::load_global( - bias_buffer_2, bias + channel_id + base_channels, bias != nullptr); - - // Since the GeLU is going to happen at float, might as well - // convert + T activation_buffer_1[T_per_access]; + T activation_buffer_2[T_per_access]; + T bias_buffer_1[T_per_access]; + T bias_buffer_2[T_per_access]; + + const int iter_id = id + T_per_step * i; + if (iter_id < total_elems) { + const int channel_id = iter_id % base_channels; + const int seq_id = iter_id / base_channels; + const int seq_offset = seq_id * base_channels * 2; + + mem_access::load_global( + activation_buffer_1, activation + seq_offset + channel_id); + mem_access::load_global( + activation_buffer_2, + activation + seq_offset + channel_id + base_channels); + mem_access::load_global( + bias_buffer_1, bias + channel_id, bias != nullptr); + mem_access::load_global( + bias_buffer_2, bias + channel_id + base_channels, bias != nullptr); + + // Since the GeLU is going to happen at float, might as well + // convert #pragma unroll - for (int v = 0; v < T_per_access; v++) { - T hidden_state = activation_buffer_1[v] + bias_buffer_1[v]; - T pre_gate = activation_buffer_2[v] + bias_buffer_2[v]; - float pre_gate_f = conversion::to(pre_gate); - float gate_f = (useGelu) ? old_gelu(pre_gate_f) : silu(pre_gate_f); - T gate = conversion::to(gate_f); - activation_buffer_1[v] = hidden_state * gate; - } - - mem_access::store_global( - output + seq_id * output_stride + channel_id, activation_buffer_1); + for (int v = 0; v < T_per_access; v++) { + T hidden_state = activation_buffer_1[v] + bias_buffer_1[v]; + T pre_gate = activation_buffer_2[v] + bias_buffer_2[v]; + float pre_gate_f = conversion::to(pre_gate); + float gate_f = (useGelu) ? old_gelu(pre_gate_f) : silu(pre_gate_f); + T gate = conversion::to(gate_f); + activation_buffer_1[v] = hidden_state * gate; } + + mem_access::store_global( + output + seq_id * output_stride + channel_id, activation_buffer_1); + } } -} + } +}; template -void launch_gated_activation(T* output, - const T* activation, - const T* bias, - int rows, - int output_stride, - int elems_per_row, - bool use_gelu, - dpct::queue_ptr stream) -{ - /* - Fused bias GEGLU is a variant of the gated activation functions. 
- The input here is a matrix of [batch, seq_len, 2 * intermediate_dim] - where the second half of the channels act as GeLU gates for the first - half. - */ - - // Re-derive the above figures - constexpr int T_per_access = fused_geglu::granularity / sizeof(T); - constexpr int T_per_step = T_per_access * fused_geglu::threads; - constexpr int T_per_block = T_per_step * fused_geglu::steps; - - const int base_channels = elems_per_row / 2; - const int total_elems = base_channels * rows; - - sycl::range<3> block(1, 1, fused_geglu::threads); - sycl::range<3> grid(1, 1, (total_elems + T_per_block - 1) / T_per_block); - - if (use_gelu) { - /* - DPCT1049:9: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for( - sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { - fused_gate_activation( - output, activation, bias, base_channels, output_stride, total_elems); - }); - } else { - /* - DPCT1049:10: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for( - sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { - fused_gate_activation( - output, activation, bias, base_channels, output_stride, total_elems); - }); - } +void launch_gated_activation( + T* output, + const T* activation, + const T* bias, + int rows, + int output_stride, + int elems_per_row, + bool use_gelu, + dpct::queue_ptr stream) { + /* + Fused bias GEGLU is a variant of the gated activation functions. + The input here is a matrix of [batch, seq_len, 2 * intermediate_dim] + where the second half of the channels act as GeLU gates for the first + half. 
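+
+  Concretely, with d = base_channels = elems_per_row / 2 and channel c in
+  [0, d), the kernel below computes
+      gate      = useGelu ? old_gelu(act[s, c + d] + bias[c + d])
+                          : silu(act[s, c + d] + bias[c + d])
+      out[s, c] = (act[s, c] + bias[c]) * gate
+  where the gate is evaluated in float and converted back to T.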
+ */ + + // Re-derive the above figures + constexpr int T_per_access = fused_geglu::granularity / sizeof(T); + constexpr int T_per_step = T_per_access * fused_geglu::threads; + constexpr int T_per_block = T_per_step * fused_geglu::steps; + + const int base_channels = elems_per_row / 2; + const int total_elems = base_channels * rows; + + sycl::range<3> block(1, 1, fused_geglu::threads); + sycl::range<3> grid(1, 1, (total_elems + T_per_block - 1) / T_per_block); + + if (use_gelu) { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + fused_gate_activation fn( + output, activation, bias, base_channels, output_stride, total_elems); + stream->parallel_for(sycl::nd_range<3>(grid * block, block), fn); + } else { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + fused_gate_activation fn( + output, activation, bias, base_channels, output_stride, total_elems); + stream->parallel_for(sycl::nd_range<3>(grid * block, block), fn); + } } #define INSTANTIATE_LAUNCH_GATED_ACTIVATION(T) \ - template void launch_gated_activation( \ - T*, const T*, const T*, int, int, int, bool, dpct::queue_ptr); + template void launch_gated_activation( \ + T*, const T*, const T*, int, int, int, bool, dpct::queue_ptr); INSTANTIATE_LAUNCH_GATED_ACTIVATION(sycl::half); #ifdef BF16_AVAILABLE diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/layer_norm.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/layer_norm.dp.cpp index 18a0ee5..c27a4f2 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/layer_norm.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/layer_norm.dp.cpp @@ -1,13 +1,28 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team +#include #include -#include #include "conversion_utils.h" #include "ds_kernel_utils.h" -#include "inference_cuda_layers.h" +#include "inference_sycl_layers.h" #include "memory_access_utils.h" #include "reduction_utils.h" @@ -15,7 +30,7 @@ using rop = reduce::ROpType; namespace ln { constexpr int granularity = 16; -} // namespace ln +} // namespace ln /* Primary layer norm implementation. Assumes elems_per_row % 8 @@ -30,31 +45,45 @@ is equal to 0. elems_per_row: number of elements each block will normalize */ template -/* -DPCT1110:3: The total declared local variable size in device function fused_ln exceeds 128 bytes and -may cause high register pressure. Consult with your hardware vendor to find the total register size -available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
-*/ -void fused_ln(T* output, - const T* vals, - const T* gamma, - const T* beta, - float epsilon, - int elems_per_row) -{ +class fused_ln { + private: + T* output; + const T* vals; + const T* gamma; + const T* beta; + float epsilon; + int elems_per_row; + + public: + fused_ln( + T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int elems_per_row) + : output(output), + vals(vals), + gamma(gamma), + beta(beta), + epsilon(epsilon), + elems_per_row(elems_per_row) {} + + void operator()(sycl::nd_item<3>) const { constexpr int T_per_load = ln::granularity / sizeof(T); sycl::group<3> tb = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group warp = sycl::ext::oneapi::experimental::this_sub_group(); // X-dimension of the block - const int block_offset = - (tb.get_group_id()[2] * (maxThreads / threadsPerGroup) * elems_per_row) + + const int block_offset = (tb.get_group_id()[2] * + (maxThreads / threadsPerGroup) * elems_per_row) + (tb.get_local_id()[1] * elems_per_row); const int thread_offset = tb.get_local_id()[2] * T_per_load; const int base_offset = block_offset + thread_offset; const int stride = - sycl::ext::oneapi::experimental::this_nd_item<3>().get_local_range(2) * T_per_load; + sycl::ext::oneapi::experimental::this_nd_item<3>().get_local_range(2) * + T_per_load; float sum = reduce::init(); @@ -64,16 +93,18 @@ void fused_ln(T* output, #pragma unRoll for (int i = 0; i < unRoll; i++) { - T* iteration_buffer = local_buffer + i * T_per_load; + T* iteration_buffer = local_buffer + i * T_per_load; - mem_access::load_global( - iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row); + mem_access::load_global( + iteration_buffer, + input_base + i * stride, + thread_offset + i * stride < elems_per_row); #pragma unRoll - for (int j = 0; j < T_per_load; j++) { - float vals_up_cast = conversion::to(iteration_buffer[j]); - sum = reduce::element(sum, vals_up_cast); - } + for (int j = 0; j < T_per_load; j++) { + float vals_up_cast = conversion::to(iteration_buffer[j]); + sum = reduce::element(sum, vals_up_cast); + } } reduce::partitioned_block(tb, warp, sum); @@ -84,21 +115,23 @@ void fused_ln(T* output, #pragma unRoll for (int i = 0; i < unRoll; i++) { #pragma unRoll - for (int j = 0; j < T_per_load; j++) { - // Using a 0 value here skews the variance, have to if-guard - if (thread_offset + i * stride < elems_per_row) { - float diff = (conversion::to(local_buffer[i * T_per_load + j]) - mean); - mean_diff = reduce::element(mean_diff, diff * diff); - } + for (int j = 0; j < T_per_load; j++) { + // Using a 0 value here skews the variance, have to if-guard + if (thread_offset + i * stride < elems_per_row) { + float diff = + (conversion::to(local_buffer[i * T_per_load + j]) - mean); + mean_diff = reduce::element(mean_diff, diff * diff); } + } } reduce::partitioned_block(tb, warp, mean_diff); const float variance = mean_diff / elems_per_row; /* - DPCT1013:9: The rounding mode could not be specified and the generated code may have different - accuracy than the original code. Verify the correctness. SYCL math built-in function rounding - mode is aligned with OpenCL C 1.2 standard. + DPCT1013:9: The rounding mode could not be specified and the generated code + may have different accuracy than the original code. Verify the correctness. + SYCL math built-in function rounding mode is aligned with OpenCL C 1.2 + standard. 
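+
+    Note: denom below is the inverse standard deviation, 1.0f / sqrt(variance +
+    epsilon); each loaded value is then normalized as (val - mean) * denom
+    before the gamma/beta affine transform.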
*/ const float denom = sycl::rsqrt(variance + epsilon); @@ -109,123 +142,122 @@ void fused_ln(T* output, #pragma unRoll for (int i = 0; i < unRoll; i++) { - T* iteration_buffer = local_buffer + i * T_per_load; - const int iter_idx = i * stride + thread_offset; - const bool do_loads = iter_idx < elems_per_row; + T* iteration_buffer = local_buffer + i * T_per_load; + const int iter_idx = i * stride + thread_offset; + const bool do_loads = iter_idx < elems_per_row; - T gamma_local[T_per_load], beta_local[T_per_load]; + T gamma_local[T_per_load], beta_local[T_per_load]; - mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); - mem_access::load_global(beta_local, beta + iter_idx, do_loads); + mem_access::load_global( + gamma_local, gamma + iter_idx, do_loads); + mem_access::load_global( + beta_local, beta + iter_idx, do_loads); #pragma unRoll - for (int j = 0; j < T_per_load; j++) { - float val = conversion::to(iteration_buffer[j]); - val = (val - mean) * denom; - val = - val * conversion::to(gamma_local[j]) + conversion::to(beta_local[j]); - iteration_buffer[j] = conversion::to(val); - } - - if (do_loads) { - mem_access::store_global(block_output + iter_idx, iteration_buffer); - } + for (int j = 0; j < T_per_load; j++) { + float val = conversion::to(iteration_buffer[j]); + val = (val - mean) * denom; + val = val * conversion::to(gamma_local[j]) + + conversion::to(beta_local[j]); + iteration_buffer[j] = conversion::to(val); + } + + if (do_loads) { + mem_access::store_global( + block_output + iter_idx, iteration_buffer); + } } -} - -/* -DPCT1049:4: The work-group size passed to the SYCL kernel may exceed the limit. To get the device -limit, query info::device::max_work_group_size. Adjust the work-group size if needed. -*/ -#define LAUNCH_FUSED_LN(unRollFactor, threadsPerGroup, maxThreads) \ - { \ - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ - stream->submit([&](sycl::handler& cgh) { \ - T* output_ct0 = output; \ - const T* vals_ct1 = vals; \ - const T* gamma_ct2 = gamma; \ - const T* beta_ct3 = beta; \ - auto epsilon_ct4 = epsilon; \ - auto elems_per_row_ct5 = elems_per_row; \ - \ - cgh.parallel_for( \ - sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - fused_ln( \ - output_ct0, vals_ct1, gamma_ct2, beta_ct3, epsilon_ct4, elems_per_row_ct5); \ - }); \ - }); \ + } +}; + +#define LAUNCH_FUSED_LN(unRollFactor, threadsPerGroup, maxThreads) \ + { \ + dpct::has_capability_or_fail( \ + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + stream->submit([&](sycl::handler& cgh) { \ + fused_ln fn( \ + output, vals, gamma, beta, epsilon, elems_per_row); \ + \ + cgh.parallel_for(sycl::nd_range<3>(grid * block, block), fn); \ + }); \ } template -void launch_fused_ln(T* output, - const T* vals, - const T* gamma, - const T* beta, - float epsilon, - int rows, - int elems_per_row, - dpct::queue_ptr stream) -{ - // 8 for sycl::half, 4 for float - constexpr int T_per_load = ln::granularity / sizeof(T); - - constexpr int maxThreads = 256; - - // For Flaoat, unRoll 4, for sycl::half, unRoll 2 - constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; - - const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; - const int h_per_step = is_subblock_schedule ? 
T_per_load : T_per_load * internal_unRoll; - - // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of - // warp-sized blocks rather than stepping up to 64/96 threads - const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); - const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; - - const int groups_per_block_max = - is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; - const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; - const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; - - sycl::range<3> block(1, groups_per_block, threadsPerGroup); - sycl::range<3> grid(1, 1, groups_launch); - - const int elems_per_step = threadsPerGroup * h_per_step; - const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; - - if (is_subblock_schedule) { - // <=128 - if (threadsPerGroup == 1) { - LAUNCH_FUSED_LN(1, 1, maxThreads); - } else if (threadsPerGroup == 2) { - LAUNCH_FUSED_LN(1, 2, maxThreads); - } else if (threadsPerGroup == 4) { - LAUNCH_FUSED_LN(1, 4, maxThreads); - } else if (threadsPerGroup == 8) { - LAUNCH_FUSED_LN(1, 8, maxThreads); - } else if (threadsPerGroup == 16) { - LAUNCH_FUSED_LN(1, 16, maxThreads); - } - } else if (external_unRoll == 1) { - // 129 - 4096 elems - // (this can launch with 1-7 warps as well) - LAUNCH_FUSED_LN(1 * internal_unRoll, maxThreads, maxThreads); - } else if (external_unRoll == 2) { - // 4097 - 8192 elems - LAUNCH_FUSED_LN(2 * internal_unRoll, maxThreads, maxThreads); - } else if (external_unRoll == 3) { - // 8193 - 12288 elems - LAUNCH_FUSED_LN(3 * internal_unRoll, maxThreads, maxThreads); - } else if (external_unRoll == 4) { - // 12289 - 16384 elems - LAUNCH_FUSED_LN(4 * internal_unRoll, maxThreads, maxThreads); +void launch_fused_ln( + T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + dpct::queue_ptr stream) { + // 8 for sycl::half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for sycl::half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = + is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign + // multiple stages of warp-sized blocks rather than stepping up to 64/96 + // threads + const int one_step_threads = + next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = + (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = is_subblock_schedule + ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup + : 1; + const int groups_per_block = + (rows < groups_per_block_max) ? 
rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + sycl::range<3> block(1, groups_per_block, threadsPerGroup); + sycl::range<3> grid(1, 1, groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = + (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_LN(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_LN(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_LN(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_LN(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_LN(1, 16, maxThreads); } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_LN(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_LN(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_LN(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_LN(4 * internal_unRoll, maxThreads, maxThreads); + } } -#define INSTANTIATE_FUSED_LN(T) \ - template void launch_fused_ln( \ - T*, const T*, const T*, const T*, float, int, int, dpct::queue_ptr); +#define INSTANTIATE_FUSED_LN(T) \ + template void launch_fused_ln( \ + T*, const T*, const T*, const T*, float, int, int, dpct::queue_ptr); INSTANTIATE_FUSED_LN(sycl::half); #ifdef BF16_AVAILABLE @@ -254,36 +286,59 @@ Template arg: StoreResidual: controls whether the residual calculation is stored or not. When set to false, the input `res_output` is unused. */ -template -/* -DPCT1110:5: The total declared local variable size in device function fused_residual_ln exceeds 128 -bytes and may cause high register pressure. Consult with your hardware vendor to find the total -register size available and adjust the code, or use smaller sub-group size to avoid high register -pressure. 
-*/ -void fused_residual_ln(T* output, - T* res_output, - const T* vals, - const T* residual, - const T* bias, - const T* gamma, - const T* beta, - float epsilon, - int elems_per_row) -{ +template < + typename T, + int unRoll, + int threadsPerGroup, + int maxThreads, + bool preLnResidual> +class fused_residual_ln { + private: + T* output; + T* res_output; + const T* vals; + const T* residual; + const T* bias; + const T* gamma; + const T* beta; + float epsilon; + int elems_per_row; + + public: + fused_residual_ln( + T* output, + T* res_output, + const T* vals, + const T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int elems_per_row) + : output(output), + res_output(res_output), + vals(vals), + residual(residual), + bias(bias), + gamma(gamma), + beta(beta), + epsilon(epsilon), + elems_per_row(elems_per_row) {} + void operator()(sycl::nd_item<3>) const { constexpr int T_per_load = ln::granularity / sizeof(T); sycl::group<3> tb = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group warp = sycl::ext::oneapi::experimental::this_sub_group(); // X-dimension of the block - const int block_offset = - (tb.get_group_id()[2] * (maxThreads / threadsPerGroup) * elems_per_row) + + const int block_offset = (tb.get_group_id()[2] * + (maxThreads / threadsPerGroup) * elems_per_row) + (tb.get_local_id()[1] * elems_per_row); const int thread_offset = tb.get_local_id()[2] * T_per_load; const int base_offset = block_offset + thread_offset; - const int stride = - sycl::ext::oneapi::experimental::this_group<3>().get_local_linear_range() * T_per_load; + const int stride = sycl::ext::oneapi::experimental::this_group<3>() + .get_local_linear_range() * + T_per_load; float sum = reduce::init(); @@ -298,32 +353,37 @@ void fused_residual_ln(T* output, // makes the most sense if we find we are having performance issues. 
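+    // Each unRoll step below loads one vector-width chunk of vals, residual
+    // and bias, forms x = vals + bias + residual in float, accumulates x into
+    // the mean reduction, writes x back into local_buffer, and, when
+    // preLnResidual is set, also stores x to res_output so the un-normalized
+    // residual is preserved.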
#pragma unRoll for (int i = 0; i < unRoll; i++) { - T* iteration_buffer = local_buffer + i * T_per_load; - T residual_buffer[T_per_load]; - T bias_buffer[T_per_load]; - - mem_access::load_global( - iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row); - mem_access::load_global(residual_buffer, - residual_base + i * stride, - thread_offset + i * stride < elems_per_row); - mem_access::load_global( - bias_buffer, bias_base + i * stride, thread_offset + i * stride < elems_per_row); + T* iteration_buffer = local_buffer + i * T_per_load; + T residual_buffer[T_per_load]; + T bias_buffer[T_per_load]; + + mem_access::load_global( + iteration_buffer, + input_base + i * stride, + thread_offset + i * stride < elems_per_row); + mem_access::load_global( + residual_buffer, + residual_base + i * stride, + thread_offset + i * stride < elems_per_row); + mem_access::load_global( + bias_buffer, + bias_base + i * stride, + thread_offset + i * stride < elems_per_row); #pragma unRoll - for (int j = 0; j < T_per_load; j++) { - float vals_up_cast = conversion::to(iteration_buffer[j]); - float res_up_cast = conversion::to(residual_buffer[j]); - float bias_up_cast = conversion::to(bias_buffer[j]); - vals_up_cast = vals_up_cast + bias_up_cast + res_up_cast; - sum = reduce::element(sum, vals_up_cast); - iteration_buffer[j] = conversion::to(vals_up_cast); - } - - if (preLnResidual && (thread_offset + i * stride < elems_per_row)) { - mem_access::store_global(res_output + base_offset + i * stride, - iteration_buffer); - } + for (int j = 0; j < T_per_load; j++) { + float vals_up_cast = conversion::to(iteration_buffer[j]); + float res_up_cast = conversion::to(residual_buffer[j]); + float bias_up_cast = conversion::to(bias_buffer[j]); + vals_up_cast = vals_up_cast + bias_up_cast + res_up_cast; + sum = reduce::element(sum, vals_up_cast); + iteration_buffer[j] = conversion::to(vals_up_cast); + } + + if (preLnResidual && (thread_offset + i * stride < elems_per_row)) { + mem_access::store_global( + res_output + base_offset + i * stride, iteration_buffer); + } } reduce::partitioned_block(tb, warp, sum); @@ -333,21 +393,23 @@ void fused_residual_ln(T* output, #pragma unRoll for (int i = 0; i < unRoll; i++) { #pragma unRoll - for (int j = 0; j < T_per_load; j++) { - // Using a 0 value here skews the variance, have to if-guard - if (thread_offset + i * stride < elems_per_row) { - float diff = (conversion::to(local_buffer[i * T_per_load + j]) - mean); - mean_diff = reduce::element(mean_diff, diff * diff); - } + for (int j = 0; j < T_per_load; j++) { + // Using a 0 value here skews the variance, have to if-guard + if (thread_offset + i * stride < elems_per_row) { + float diff = + (conversion::to(local_buffer[i * T_per_load + j]) - mean); + mean_diff = reduce::element(mean_diff, diff * diff); } + } } reduce::partitioned_block(tb, warp, mean_diff); const float variance = mean_diff / elems_per_row; /* - DPCT1013:10: The rounding mode could not be specified and the generated code may have different - accuracy than the original code. Verify the correctness. SYCL math built-in function rounding - mode is aligned with OpenCL C 1.2 standard. + DPCT1013:10: The rounding mode could not be specified and the generated code + may have different accuracy than the original code. Verify the correctness. + SYCL math built-in function rounding mode is aligned with OpenCL C 1.2 + standard. 
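+
+    As in fused_ln, denom is the inverse standard deviation of the row; here
+    the mean and variance are taken over the fused vals + bias + residual
+    values accumulated above.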
*/ const float denom = sycl::rsqrt(variance + epsilon); @@ -355,254 +417,271 @@ void fused_residual_ln(T* output, #pragma unRoll for (int i = 0; i < unRoll; i++) { - T* iteration_buffer = local_buffer + i * T_per_load; - const int iter_idx = i * stride + thread_offset; - const bool do_loads = iter_idx < elems_per_row; + T* iteration_buffer = local_buffer + i * T_per_load; + const int iter_idx = i * stride + thread_offset; + const bool do_loads = iter_idx < elems_per_row; - T gamma_local[T_per_load], beta_local[T_per_load]; + T gamma_local[T_per_load], beta_local[T_per_load]; - mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); - mem_access::load_global(beta_local, beta + iter_idx, do_loads); + mem_access::load_global( + gamma_local, gamma + iter_idx, do_loads); + mem_access::load_global( + beta_local, beta + iter_idx, do_loads); #pragma unRoll - for (int j = 0; j < T_per_load; j++) { - // iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute; - // iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j]; - float val = conversion::to(iteration_buffer[j]); - val = (val - mean) * denom; - val = - val * conversion::to(gamma_local[j]) + conversion::to(beta_local[j]); - iteration_buffer[j] = conversion::to(val); - } - - if (do_loads) { - mem_access::store_global(block_output + iter_idx, iteration_buffer); - } + for (int j = 0; j < T_per_load; j++) { + // iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * + // denom_compute; iteration_buffer[j] = iteration_buffer[j] * + // gamma_local[j] + beta_local[j]; + float val = conversion::to(iteration_buffer[j]); + val = (val - mean) * denom; + val = val * conversion::to(gamma_local[j]) + + conversion::to(beta_local[j]); + iteration_buffer[j] = conversion::to(val); + } + + if (do_loads) { + mem_access::store_global( + block_output + iter_idx, iteration_buffer); + } } -} + } +}; -// TODO(cmikeh2): There's a bunch of redundancy here that needs to be removed/simplified. +// TODO(cmikeh2): There's a bunch of redundancy here that needs to be +// removed/simplified. /* -DPCT1049:6: The work-group size passed to the SYCL kernel may exceed the limit. To get the device -limit, query info::device::max_work_group_size. Adjust the work-group size if needed. +DPCT1049:6: The work-group size passed to the SYCL kernel may exceed the limit. +To get the device limit, query info::device::max_work_group_size. Adjust the +work-group size if needed. 
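+
+With the launch parameters chosen below, each work-group handles
+groups_per_block rows with threadsPerGroup work-items per row, and
+groups_per_block * threadsPerGroup never exceeds maxThreads (256), which
+should be within the work-group limit of typical devices; the query above
+remains the authoritative check.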
*/ -#define LAUNCH_FUSED_RES_LN(unRollFactor, threadsPerGroup, maxThreads) \ - { \ - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ - stream->submit([&](sycl::handler& cgh) { \ - T* output_ct0 = output; \ - auto nullptr_ct1 = nullptr; \ - const T* vals_ct2 = vals; \ - const T* residual_ct3 = residual; \ - const T* bias_ct4 = bias; \ - const T* gamma_ct5 = gamma; \ - const T* beta_ct6 = beta; \ - auto epsilon_ct7 = epsilon; \ - auto elems_per_row_ct8 = elems_per_row; \ - \ - cgh.parallel_for(sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - fused_residual_ln( \ - output_ct0, \ - nullptr_ct1, \ - vals_ct2, \ - residual_ct3, \ - bias_ct4, \ - gamma_ct5, \ - beta_ct6, \ - epsilon_ct7, \ - elems_per_row_ct8); \ - }); \ - }); \ +#define LAUNCH_FUSED_RES_LN(unRollFactor, threadsPerGroup, maxThreads) \ + { \ + dpct::has_capability_or_fail( \ + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + stream->submit([&](sycl::handler& cgh) { \ + fused_residual_ln \ + fn(output, \ + nullptr, \ + vals, \ + residual, \ + bias, \ + gamma, \ + beta, \ + epsilon, \ + elems_per_row); \ + \ + cgh.parallel_for(sycl::nd_range<3>(grid * block, block), fn); \ + }); \ } template -void launch_fused_residual_ln(T* output, - const T* vals, - const T* residual, - const T* bias, - const T* gamma, - const T* beta, - float epsilon, - int rows, - int elems_per_row, - dpct::queue_ptr stream) -{ - // 8 for sycl::half, 4 for float - constexpr int T_per_load = ln::granularity / sizeof(T); - - constexpr int maxThreads = 256; - - // For Flaoat, unRoll 4, for sycl::half, unRoll 2 - constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; - - const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; - const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; - - // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of - // warp-sized blocks rather than stepping up to 64/96 threads - const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); - const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; - - const int groups_per_block_max = - is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; - const int groups_per_block = (rows < groups_per_block_max) ? 
rows : groups_per_block_max; - const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; - - sycl::range<3> block(1, groups_per_block, threadsPerGroup); - sycl::range<3> grid(1, 1, groups_launch); - - const int elems_per_step = threadsPerGroup * h_per_step; - const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; - - if (is_subblock_schedule) { - // <=128 - if (threadsPerGroup == 1) { - LAUNCH_FUSED_RES_LN(1, 1, maxThreads); - } else if (threadsPerGroup == 2) { - LAUNCH_FUSED_RES_LN(1, 2, maxThreads); - } else if (threadsPerGroup == 4) { - LAUNCH_FUSED_RES_LN(1, 4, maxThreads); - } else if (threadsPerGroup == 8) { - LAUNCH_FUSED_RES_LN(1, 8, maxThreads); - } else if (threadsPerGroup == 16) { - LAUNCH_FUSED_RES_LN(1, 16, maxThreads); - } - } else if (external_unRoll == 1) { - // 129 - 4096 elems - // (this can launch with 1-7 warps as well) - LAUNCH_FUSED_RES_LN(1 * internal_unRoll, maxThreads, maxThreads); - } else if (external_unRoll == 2) { - // 4097 - 8192 elems - LAUNCH_FUSED_RES_LN(2 * internal_unRoll, maxThreads, maxThreads); - } else if (external_unRoll == 3) { - // 8193 - 12288 elems - LAUNCH_FUSED_RES_LN(3 * internal_unRoll, maxThreads, maxThreads); - } else if (external_unRoll == 4) { - // 12289 - 16384 elems - LAUNCH_FUSED_RES_LN(4 * internal_unRoll, maxThreads, maxThreads); +void launch_fused_residual_ln( + T* output, + const T* vals, + const T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + dpct::queue_ptr stream) { + // 8 for sycl::half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for sycl::half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = + is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign + // multiple stages of warp-sized blocks rather than stepping up to 64/96 + // threads + const int one_step_threads = + next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = + (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = is_subblock_schedule + ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup + : 1; + const int groups_per_block = + (rows < groups_per_block_max) ? 
rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + sycl::range<3> block(1, groups_per_block, threadsPerGroup); + sycl::range<3> grid(1, 1, groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = + (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_RES_LN(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_RES_LN(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_RES_LN(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_RES_LN(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_RES_LN(1, 16, maxThreads); } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_RES_LN(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_RES_LN(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_RES_LN(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_RES_LN(4 * internal_unRoll, maxThreads, maxThreads); + } } /* -DPCT1049:7: The work-group size passed to the SYCL kernel may exceed the limit. To get the device -limit, query info::device::max_work_group_size. Adjust the work-group size if needed. +DPCT1049:7: The work-group size passed to the SYCL kernel may exceed the limit. +To get the device limit, query info::device::max_work_group_size. Adjust the +work-group size if needed. */ -#define LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(unRollFactor, threadsPerGroup, maxThreads) \ - { \ - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ - stream->submit([&](sycl::handler& cgh) { \ - T* norm_output_ct0 = norm_output; \ - T* res_output_ct1 = res_output; \ - const T* vals_ct2 = vals; \ - const T* residual_ct3 = residual; \ - const T* bias_ct4 = bias; \ - const T* gamma_ct5 = gamma; \ - const T* beta_ct6 = beta; \ - auto epsilon_ct7 = epsilon; \ - auto elems_per_row_ct8 = elems_per_row; \ - \ - cgh.parallel_for(sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - fused_residual_ln( \ - norm_output_ct0, \ - res_output_ct1, \ - vals_ct2, \ - residual_ct3, \ - bias_ct4, \ - gamma_ct5, \ - beta_ct6, \ - epsilon_ct7, \ - elems_per_row_ct8); \ - }); \ - }); \ +#define LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES( \ + unRollFactor, threadsPerGroup, maxThreads) \ + { \ + dpct::has_capability_or_fail( \ + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + stream->submit([&](sycl::handler& cgh) { \ + fused_residual_ln \ + fn(norm_output, \ + res_output, \ + vals, \ + residual, \ + bias, \ + gamma, \ + beta, \ + epsilon, \ + elems_per_row); \ + \ + cgh.parallel_for(sycl::nd_range<3>(grid * block, block), fn); \ + }); \ } template -void launch_fused_residual_ln_store_pre_ln_res(T* norm_output, - T* res_output, - const T* vals, - const T* residual, - const T* bias, - const T* gamma, - const T* beta, - float epsilon, - int rows, - int elems_per_row, - dpct::queue_ptr stream) -{ - // 8 for sycl::half, 4 for float - constexpr int T_per_load = ln::granularity / sizeof(T); - - constexpr int maxThreads = 256; - - // For Flaoat, unRoll 4, for sycl::half, unRoll 2 - constexpr int 
internal_unRoll = sizeof(T) == 4 ? 4 : 2; - - const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; - const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; - - // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of - // warp-sized blocks rather than stepping up to 64/96 threads - const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); - const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; - - const int groups_per_block_max = - is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; - const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; - const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; - - sycl::range<3> block(1, groups_per_block, threadsPerGroup); - sycl::range<3> grid(1, 1, groups_launch); - - const int elems_per_step = threadsPerGroup * h_per_step; - const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; - - if (is_subblock_schedule) { - // <=128 - if (threadsPerGroup == 1) { - LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 1, maxThreads); - } else if (threadsPerGroup == 2) { - LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 2, maxThreads); - } else if (threadsPerGroup == 4) { - LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 4, maxThreads); - } else if (threadsPerGroup == 8) { - LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 8, maxThreads); - } else if (threadsPerGroup == 16) { - LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 16, maxThreads); - } - } else if (external_unRoll == 1) { - // 129 - 4096 elems - // (this can launch with 1-7 warps as well) - LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1 * internal_unRoll, maxThreads, maxThreads); - } else if (external_unRoll == 2) { - // 4097 - 8192 elems - LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(2 * internal_unRoll, maxThreads, maxThreads); - } else if (external_unRoll == 3) { - // 8193 - 12288 elems - LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(3 * internal_unRoll, maxThreads, maxThreads); - } else if (external_unRoll == 4) { - // 12289 - 16384 elems - LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(4 * internal_unRoll, maxThreads, maxThreads); +void launch_fused_residual_ln_store_pre_ln_res( + T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + dpct::queue_ptr stream) { + // 8 for sycl::half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for sycl::half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = + is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign + // multiple stages of warp-sized blocks rather than stepping up to 64/96 + // threads + const int one_step_threads = + next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = + (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = is_subblock_schedule + ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup + : 1; + const int groups_per_block = + (rows < groups_per_block_max) ? 
rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + sycl::range<3> block(1, groups_per_block, threadsPerGroup); + sycl::range<3> grid(1, 1, groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = + (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 16, maxThreads); } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES( + 1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES( + 2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES( + 3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES( + 4 * internal_unRoll, maxThreads, maxThreads); + } } -#define INSTANTIATE_RES_LN(T) \ - template void launch_fused_residual_ln( \ - T*, const T*, const T*, const T*, const T*, const T*, float, int, int, dpct::queue_ptr); - -#define INSTANTIATE_PRE_LN_RES(T) \ - template void launch_fused_residual_ln_store_pre_ln_res(T*, \ - T*, \ - const T*, \ - const T*, \ - const T*, \ - const T*, \ - const T*, \ - float, \ - int, \ - int, \ - dpct::queue_ptr); +#define INSTANTIATE_RES_LN(T) \ + template void launch_fused_residual_ln( \ + T*, \ + const T*, \ + const T*, \ + const T*, \ + const T*, \ + const T*, \ + float, \ + int, \ + int, \ + dpct::queue_ptr); + +#define INSTANTIATE_PRE_LN_RES(T) \ + template void launch_fused_residual_ln_store_pre_ln_res( \ + T*, \ + T*, \ + const T*, \ + const T*, \ + const T*, \ + const T*, \ + const T*, \ + float, \ + int, \ + int, \ + dpct::queue_ptr); INSTANTIATE_RES_LN(sycl::half); INSTANTIATE_RES_LN(float); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/pointwise_ops.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/pointwise_ops.dp.cpp index 22bd348..0f07a38 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/pointwise_ops.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/pointwise_ops.dp.cpp @@ -1,10 +1,25 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team +#include #include -#include #include "conversion_utils.h" #include "ds_kernel_utils.h" #include "memory_access_utils.h" @@ -13,68 +28,88 @@ namespace pwise { constexpr int granularity = 16; constexpr int unroll = 4; constexpr int threads = 256; -} // namespace pwise +} // namespace pwise template -void vector_add_kernel(T* out, const T* a, const T* b, float gamma, int num_elems) -{ +class vector_add_kernel { + private: + T* out; + const T* a; + const T* b; + float gamma; + int num_elems; + + public: + vector_add_kernel(T* out, const T* a, const T* b, float gamma, int num_elems) + : out(out), a(a), b(b), gamma(gamma), num_elems(num_elems) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); constexpr int T_per_access = pwise::granularity / sizeof(T); - const int block_offset = item_ct1.get_group(2) * pwise::threads * pwise::unroll * T_per_access; + const int block_offset = + item_ct1.get_group(2) * pwise::threads * pwise::unroll * T_per_access; const int thread_offset = item_ct1.get_local_id(2) * T_per_access; const int total_offset = block_offset + thread_offset; constexpr int stride = pwise::threads * T_per_access; #pragma unroll for (int i = 0; i < pwise::unroll; i++) { - T temp_buf_a[T_per_access], temp_buf_b[T_per_access]; + T temp_buf_a[T_per_access], temp_buf_b[T_per_access]; - const int iter_idx = total_offset + i * stride; + const int iter_idx = total_offset + i * stride; - mem_access::load_global(temp_buf_a, a + iter_idx, iter_idx < num_elems); - mem_access::load_global(temp_buf_b, b + iter_idx, iter_idx < num_elems); + mem_access::load_global( + temp_buf_a, a + iter_idx, iter_idx < num_elems); + mem_access::load_global( + temp_buf_b, b + iter_idx, iter_idx < num_elems); #pragma unroll - for (int j = 0; j < T_per_access; j++) { - float up_cast_a = conversion::to(temp_buf_a[j]); - float up_cast_b = conversion::to(temp_buf_b[j]); - temp_buf_a[j] = conversion::to((gamma * up_cast_a) + up_cast_b); - } - - if (iter_idx < num_elems) { - mem_access::store_global(out + iter_idx, temp_buf_a); - } + for (int j = 0; j < T_per_access; j++) { + float up_cast_a = conversion::to(temp_buf_a[j]); + float up_cast_b = conversion::to(temp_buf_b[j]); + temp_buf_a[j] = conversion::to((gamma * up_cast_a) + up_cast_b); + } + + if (iter_idx < num_elems) { + mem_access::store_global( + out + iter_idx, temp_buf_a); + } } -} + } +}; template -void launch_vector_add(T* out, - const T* a, - const T* b, - float gamma, - int num_elems, - dpct::queue_ptr stream) -{ - constexpr int T_per_access = pwise::granularity / sizeof(T); - constexpr int T_per_block = pwise::threads * T_per_access * pwise::unroll; - - sycl::range<3> block(1, 1, pwise::threads); - sycl::range<3> grid(1, 1, (num_elems + T_per_block - 1) / T_per_block); - - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid * block, block), - [=](sycl::nd_item<3> item_ct1) { - vector_add_kernel(out, a, b, gamma, num_elems); - }); - } +void launch_vector_add( + T* out, + const T* a, + const T* b, + float gamma, + int num_elems, + dpct::queue_ptr stream) { + constexpr int T_per_access = pwise::granularity / sizeof(T); + 
constexpr int T_per_block = pwise::threads * T_per_access * pwise::unroll; + + sycl::range<3> block(1, 1, pwise::threads); + sycl::range<3> grid(1, 1, (num_elems + T_per_block - 1) / T_per_block); + + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + stream->parallel_for( + sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { + vector_add_kernel(out, a, b, gamma, num_elems); + }); + } } -#define INSTANTIATE_VECTOR_ADD(T) \ - template void launch_vector_add( \ - T * out, const T* a, const T* b, float gamma, int num_elems, dpct::queue_ptr stream); +#define INSTANTIATE_VECTOR_ADD(T) \ + template void launch_vector_add( \ + T * out, \ + const T* a, \ + const T* b, \ + float gamma, \ + int num_elems, \ + dpct::queue_ptr stream); INSTANTIATE_VECTOR_ADD(float) INSTANTIATE_VECTOR_ADD(sycl::half) diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/pt_binding.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/pt_binding.cpp index 6d2ec84..7e6640b 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1,2139 +1,2282 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team -// #include +#include #include #include -#include #include +#include #include #include #include "inference_context.h" -#include "inference_cublas_wrappers.h" -#include "inference_cuda_layers.h" -#include +#include "inference_mkl_wrappers.h" +#include "inference_sycl_layers.h" std::array gemm_algos = std::array({99, 99, 99}); // NOTE: This activation function type enum should be always in sync // with the python counterpart, otherwise the casting from python binding // will be incorrect. -enum class ActivationFuncType { UNKNOWN = 0, GELU = 1, ReLU = 2, GATED_GELU = 3, GATED_SILU = 4 }; +enum class ActivationFuncType { + UNKNOWN = 0, + GELU = 1, + ReLU = 2, + GATED_GELU = 3, + GATED_SILU = 4 +}; enum class NormType { UNKNOWN = 0, LayerNorm = 1, GroupNorm = 2, RMSNorm = 3 }; enum class TransformerType : uint8_t { UNKNOWN = 0, GPTType = 1, BERTType = 2 }; -// NOTE: this is a temporary and dodgy solution to distinguish GPT and BERT style models -// based on the dimensions of the corresponding attention mask. 
-inline auto infer_transformer_type(at::Tensor& attn_mask) -> TransformerType -{ - auto attn_mask_num_dims = attn_mask.sizes().size(); - - if (attn_mask_num_dims > 2) { - return TransformerType::GPTType; - } else if (attn_mask_num_dims == 2) { - return TransformerType::BERTType; - } else { - return TransformerType::UNKNOWN; - } +// NOTE: this is a temporary and dodgy solution to distinguish GPT and BERT +// style models based on the dimensions of the corresponding attention mask. +inline auto infer_transformer_type(at::Tensor& attn_mask) -> TransformerType { + auto attn_mask_num_dims = attn_mask.sizes().size(); + + if (attn_mask_num_dims > 2) { + return TransformerType::GPTType; + } else if (attn_mask_num_dims == 2) { + return TransformerType::BERTType; + } else { + return TransformerType::UNKNOWN; + } } // infer stride of attention mask memory layout based on the model type. -inline auto get_attn_mask_stride(at::Tensor& attn_mask) -> int -{ - auto trnsfrmr_type = infer_transformer_type(attn_mask); - - if (trnsfrmr_type == TransformerType::GPTType) { - return attn_mask.size(2); - } else if (trnsfrmr_type == TransformerType::BERTType) { - // Bert style models have always a mask stride of 1. - return 1; - } else if (trnsfrmr_type == TransformerType::UNKNOWN) { - return 0; - } - - // this is just to make the compiler happy. +inline auto get_attn_mask_stride(at::Tensor& attn_mask) -> int { + auto trnsfrmr_type = infer_transformer_type(attn_mask); + + if (trnsfrmr_type == TransformerType::GPTType) { + return attn_mask.size(2); + } else if (trnsfrmr_type == TransformerType::BERTType) { + // Bert style models have always a mask stride of 1. + return 1; + } else if (trnsfrmr_type == TransformerType::UNKNOWN) { return 0; + } + + // this is just to make the compiler happy. + return 0; } template -at::Tensor ds_softmax(at::Tensor& attn_scores, - at::Tensor& attn_mask, - at::Tensor& alibi, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - bool async_op, - float layer_scale, - int head_offset, - int mp_size) -{ - auto attn_scores_c = attn_scores.contiguous(); - int bsz = attn_scores_c.size(0); - - int seq_len = attn_scores_c.size(1); - int len = attn_scores_c.sizes().size(); - if (len > 2) seq_len = attn_scores_c.size(2); - - int soft_len = attn_scores_c.size(2); - if (len > 3) soft_len = attn_scores_c.size(3); - - int heads = 1; - if (len > 1) heads = attn_scores_c.size(1); - - auto mask_stride = get_attn_mask_stride(attn_mask); - - launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(), - (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), - (alibi.sizes().size() > 1 ? 
(T*)alibi.data_ptr() : nullptr), - layer_scale, - triangular, - recompute, - local_attention, - window_size, - bsz, - heads, - seq_len, - soft_len, - head_offset, - mask_stride, - mp_size, - InferenceContext::Instance().GetCurrentStream(async_op)); - - return attn_scores_c; +at::Tensor ds_softmax( + at::Tensor& attn_scores, + at::Tensor& attn_mask, + at::Tensor& alibi, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + bool async_op, + float layer_scale, + int head_offset, + int mp_size) { + auto attn_scores_c = attn_scores.contiguous(); + int bsz = attn_scores_c.size(0); + + int seq_len = attn_scores_c.size(1); + int len = attn_scores_c.sizes().size(); + if (len > 2) + seq_len = attn_scores_c.size(2); + + int soft_len = attn_scores_c.size(2); + if (len > 3) + soft_len = attn_scores_c.size(3); + + int heads = 1; + if (len > 1) + heads = attn_scores_c.size(1); + + auto mask_stride = get_attn_mask_stride(attn_mask); + + launch_attn_softmax_v2( + (T*)attn_scores_c.data_ptr(), + (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), + (alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr), + layer_scale, + triangular, + recompute, + local_attention, + window_size, + bsz, + heads, + seq_len, + soft_len, + head_offset, + mask_stride, + mp_size, + InferenceContext::Instance().GetCurrentStream(async_op)); + + return attn_scores_c; } template -void allocate_workspace(unsigned hidden_dim, - unsigned num_heads, - unsigned prompt_length, - unsigned batch_size, - unsigned num_layers, - unsigned mp_size = 1, - bool external_cache = false, - unsigned rank = 0, - unsigned max_out_tokens = 1024, - unsigned min_out_tokens = 1) -{ - InferenceContext::Instance().GenWorkSpace(num_layers, - num_heads, - batch_size, - prompt_length, - hidden_dim, - mp_size, - external_cache, - sizeof(T), - rank, - max_out_tokens, - min_out_tokens); +void allocate_workspace( + unsigned hidden_dim, + unsigned num_heads, + unsigned prompt_length, + unsigned batch_size, + unsigned num_layers, + unsigned mp_size = 1, + bool external_cache = false, + unsigned rank = 0, + unsigned max_out_tokens = 1024, + unsigned min_out_tokens = 1) { + InferenceContext::Instance().GenWorkSpace( + num_layers, + num_heads, + batch_size, + prompt_length, + hidden_dim, + mp_size, + external_cache, + sizeof(T), + rank, + max_out_tokens, + min_out_tokens); } template -at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) -{ - auto options = at::TensorOptions() - .dtype(Q.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - float alpha = 1; - float gemm_beta = 0.0; - - /* - // Reallocate memory if we received a new prompt - if (!workspace || input.size(1) != 1) { - allocate_workspace(W.size(1), InferenceContext::Instance().GetMaxTokenLength(), - Q.size(0), 1, head_size); workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - } - */ - - auto O = at::from_blob(workspace, - {Q.size(1), Q.size(2), W.size(1)}, - c10::TensorType::contiguousStridesOf({Q.size(1), Q.size(2), W.size(1)}), - nullptr, - options, - Q.device()); - unsigned m = W.size(1); - unsigned n = Q.size(1) * Q.size(2); - unsigned k = Q.size(0); - cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), - oneapi::mkl::transpose::nontrans, - oneapi::mkl::transpose::trans, - m, - n, - k, - &alpha, - &gemm_beta, - (T*)W.data_ptr(), - (T*)Q.data_ptr(), - (T*)O.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); 
-#else - 99); -#endif - return O; +at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) { + auto options = at::TensorOptions() + .dtype(Q.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + float alpha = 1; + float gemm_beta = 0.0; + + /* + // Reallocate memory if we received a new prompt + if (!workspace || input.size(1) != 1) { + allocate_workspace(W.size(1), + InferenceContext::Instance().GetMaxTokenLength(), Q.size(0), 1, head_size); + workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + } + */ + + auto O = at::from_blob( + workspace, + {Q.size(1), Q.size(2), W.size(1)}, + c10::TensorType::contiguousStridesOf({Q.size(1), Q.size(2), W.size(1)}), + nullptr, + options, + Q.device()); + unsigned m = W.size(1); + unsigned n = Q.size(1) * Q.size(2); + unsigned k = Q.size(0); + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::trans, + m, + n, + k, + &alpha, + &gemm_beta, + (T*)W.data_ptr(), + (T*)Q.data_ptr(), + (T*)O.data_ptr(), + 99); + return O; } template -void attention_unfused(at::Tensor& prev_key_cont, - at::Tensor& query_cont, - at::Tensor& attn_mask, - at::Tensor& prev_value_cont, - at::Tensor& output, - int& bsz, - int& seq_len, - int& soft_len, - int& heads, - float& norm_factor, - bool triangular, - bool recompute, - bool local_attention, - int window_size) -{ - auto options = at::TensorOptions() - .dtype(query_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - float alpha = norm_factor; - float gemm_beta = 0.0; - auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options); - int k = prev_value_cont.size(2) / heads; - - auto mask_stride = get_attn_mask_stride(attn_mask); - - *(InferenceContext::Instance().GetCublasHandle()) = - *(InferenceContext::Instance().GetCurrentStream()); - cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), - soft_len, - seq_len, - k, - &alpha, - &gemm_beta, - (T*)prev_key_cont.data_ptr(), - (T*)query_cont.data_ptr(), - (T*)attn_score.data_ptr(), - oneapi::mkl::transpose::nontrans, - oneapi::mkl::transpose::nontrans, - soft_len * k, - seq_len * k, - seq_len * soft_len, - bsz * heads, -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif - launch_attn_softmax_v2((T*)attn_score.data_ptr(), - (T*)(attn_mask.sizes().size() > 1 ? 
attn_mask.data_ptr() : nullptr), - (T*)nullptr, - 1.0, - triangular, - recompute, - local_attention, - window_size, - bsz, - heads, - seq_len, - soft_len, - 0, - mask_stride, - 1, - InferenceContext::Instance().GetCurrentStream(false)); - alpha = 1.0; - cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), - k, - seq_len, - soft_len, - &alpha, - &gemm_beta, - (T*)prev_value_cont.data_ptr(), - (T*)attn_score.data_ptr(), - (T*)output.data_ptr(), - oneapi::mkl::transpose::nontrans, - oneapi::mkl::transpose::nontrans, - soft_len * k, - seq_len * soft_len, - seq_len * k, - bsz * heads, -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif +void attention_unfused( + at::Tensor& prev_key_cont, + at::Tensor& query_cont, + at::Tensor& attn_mask, + at::Tensor& prev_value_cont, + at::Tensor& output, + int& bsz, + int& seq_len, + int& soft_len, + int& heads, + float& norm_factor, + bool triangular, + bool recompute, + bool local_attention, + int window_size) { + auto options = at::TensorOptions() + .dtype(query_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + float alpha = norm_factor; + float gemm_beta = 0.0; + auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options); + int k = prev_value_cont.size(2) / heads; + + auto mask_stride = get_attn_mask_stride(attn_mask); + + *(InferenceContext::Instance().GetCublasHandle()) = + *(InferenceContext::Instance().GetCurrentStream()); + mkl_strided_batched_gemm( + InferenceContext::Instance().GetCublasHandle(), + soft_len, + seq_len, + k, + &alpha, + &gemm_beta, + (T*)prev_key_cont.data_ptr(), + (T*)query_cont.data_ptr(), + (T*)attn_score.data_ptr(), + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::nontrans, + soft_len * k, + seq_len * k, + seq_len * soft_len, + bsz * heads, + 99); + launch_attn_softmax_v2( + (T*)attn_score.data_ptr(), + (T*)(attn_mask.sizes().size() > 1 ? attn_mask.data_ptr() : nullptr), + (T*)nullptr, + 1.0, + triangular, + recompute, + local_attention, + window_size, + bsz, + heads, + seq_len, + soft_len, + 0, + mask_stride, + 1, + InferenceContext::Instance().GetCurrentStream(false)); + alpha = 1.0; + mkl_strided_batched_gemm( + InferenceContext::Instance().GetCublasHandle(), + k, + seq_len, + soft_len, + &alpha, + &gemm_beta, + (T*)prev_value_cont.data_ptr(), + (T*)attn_score.data_ptr(), + (T*)output.data_ptr(), + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::nontrans, + soft_len * k, + seq_len * soft_len, + seq_len * k, + bsz * heads, + 99); } template -std::vector ds_softmax_context1(at::Tensor& query, - at::Tensor& prev_key, - at::Tensor& new_key, - at::Tensor& attn_mask, - at::Tensor& prev_value, - at::Tensor& new_value, - int heads, - float norm_factor, - bool merging, - bool triangular, - bool local_attention, - int window_size, - bool no_masking) -{ - auto query_cont = query.contiguous(); - auto prev_key_cont = prev_key.contiguous(); - auto prev_value_cont = prev_value.contiguous(); - - int new_size = (new_value.sizes().size() > 1 ? 
new_value.size(1) : 0); - - // Attn_Score [ batch Head Sequence-length Softmax-length] - - int bsz = query_cont.size(0); - int seq_len = query_cont.size(1); - int soft_len = prev_value.size(1); - - auto options = at::TensorOptions() - .dtype(query_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - - auto output = - at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options); - attention_unfused(prev_key_cont, - query_cont, - attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()), - prev_value_cont, - output, - bsz, - seq_len, - soft_len, - heads, - norm_factor, - (triangular && (new_size == 0)), - (new_size == 0), - local_attention, - window_size); - - return {output, prev_key, prev_value}; +std::vector ds_softmax_context1( + at::Tensor& query, + at::Tensor& prev_key, + at::Tensor& new_key, + at::Tensor& attn_mask, + at::Tensor& prev_value, + at::Tensor& new_value, + int heads, + float norm_factor, + bool merging, + bool triangular, + bool local_attention, + int window_size, + bool no_masking) { + auto query_cont = query.contiguous(); + auto prev_key_cont = prev_key.contiguous(); + auto prev_value_cont = prev_value.contiguous(); + + int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0); + + // Attn_Score [ batch Head Sequence-length Softmax-length] + + int bsz = query_cont.size(0); + int seq_len = query_cont.size(1); + int soft_len = prev_value.size(1); + + auto options = at::TensorOptions() + .dtype(query_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + + auto output = at::empty( + {prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, + options); + attention_unfused( + prev_key_cont, + query_cont, + attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()), + prev_value_cont, + output, + bsz, + seq_len, + soft_len, + heads, + norm_factor, + (triangular && (new_size == 0)), + (new_size == 0), + local_attention, + window_size); + + return {output, prev_key, prev_value}; } template -void ds_softmax_internal(T* attn_scores, - at::Tensor& attn_mask, - at::Tensor& alibi, - float& layer_scale, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int bsz, - int seq_len, - int soft_len, - int heads) -{ - auto mask_stride = get_attn_mask_stride(attn_mask); - - launch_attn_softmax_v2((T*)attn_scores, - (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), - (alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr), - layer_scale, - triangular, - recompute, - local_attention, - window_size, - bsz, - heads, - seq_len, - soft_len, - 0, - mask_stride, - 1, - at::cuda::getCurrentCUDAStream()); +void ds_softmax_internal( + T* attn_scores, + at::Tensor& attn_mask, + at::Tensor& alibi, + float& layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int bsz, + int seq_len, + int soft_len, + int heads) { + auto mask_stride = get_attn_mask_stride(attn_mask); + + launch_attn_softmax_v2( + (T*)attn_scores, + (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), + (alibi.sizes().size() > 1 ? 
(T*)alibi.data_ptr() : nullptr), + layer_scale, + triangular, + recompute, + local_attention, + window_size, + bsz, + heads, + seq_len, + soft_len, + 0, + mask_stride, + 1, + at::sycl::getCurrentSYCLStream()); } template -void attention_unfused(T* prev_key_cont, - T* query_cont, - at::Tensor& attn_mask, - T* prev_value_cont, - T* output, - unsigned& bsz, - int& k, - unsigned& seq_len, - unsigned& soft_len, - int& heads, - float& norm_factor, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - at::Tensor& alibi, - int layer_id) -{ - float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0; - float alpha = norm_factor * norm_factor / layer_scale; - float gemm_beta = 0.0; - T* workspace = (T*)InferenceContext::Instance().GetAttentionUnfusedWorkspace(); - - *(InferenceContext::Instance().GetCublasHandle()) = - *(InferenceContext::Instance().GetCurrentStream()); - cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), - soft_len, - seq_len, - k, - &alpha, - &gemm_beta, - (T*)prev_key_cont, - (T*)query_cont, - workspace, - oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, - InferenceContext::Instance().GetMaxTokenLength() * k, - seq_len * k, - seq_len * soft_len, - bsz * heads, -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif - ds_softmax_internal(workspace, - attn_mask, - alibi, - layer_scale, - triangular, - recompute, - local_attention, - window_size, - bsz, - seq_len, - soft_len, - heads); - alpha = 1.0; - cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), - k, - seq_len, - soft_len, - &alpha, - &gemm_beta, - (T*)prev_value_cont, - workspace, - (T*)output, - oneapi::mkl::transpose::nontrans, - oneapi::mkl::transpose::nontrans, - InferenceContext::Instance().GetMaxTokenLength() * k, - seq_len * soft_len, - seq_len * k, - bsz * heads, -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif +void attention_unfused( + T* prev_key_cont, + T* query_cont, + at::Tensor& attn_mask, + T* prev_value_cont, + T* output, + unsigned& bsz, + int& k, + unsigned& seq_len, + unsigned& soft_len, + int& heads, + float& norm_factor, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + at::Tensor& alibi, + int layer_id) { + float layer_scale = alibi.sizes().size() > 1 ? 
std::max(1, layer_id) : 1.0; + float alpha = norm_factor * norm_factor / layer_scale; + float gemm_beta = 0.0; + T* workspace = + (T*)InferenceContext::Instance().GetAttentionUnfusedWorkspace(); + + *(InferenceContext::Instance().GetCublasHandle()) = + *(InferenceContext::Instance().GetCurrentStream()); + mkl_strided_batched_gemm( + InferenceContext::Instance().GetCublasHandle(), + soft_len, + seq_len, + k, + &alpha, + &gemm_beta, + (T*)prev_key_cont, + (T*)query_cont, + workspace, + oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, + InferenceContext::Instance().GetMaxTokenLength() * k, + seq_len * k, + seq_len * soft_len, + bsz * heads, + 99); + ds_softmax_internal( + workspace, + attn_mask, + alibi, + layer_scale, + triangular, + recompute, + local_attention, + window_size, + bsz, + seq_len, + soft_len, + heads); + alpha = 1.0; + mkl_strided_batched_gemm( + InferenceContext::Instance().GetCublasHandle(), + k, + seq_len, + soft_len, + &alpha, + &gemm_beta, + (T*)prev_value_cont, + workspace, + (T*)output, + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::nontrans, + InferenceContext::Instance().GetMaxTokenLength() * k, + seq_len * soft_len, + seq_len * k, + bsz * heads, + 99); } -void reset_cache() { InferenceContext::Instance().reset_tokens(); } +void reset_cache() { + InferenceContext::Instance().reset_tokens(); +} template -std::vector ds_softmax_context(at::Tensor& query_key_value, - at::Tensor& attn_mask, - int rotary_dim, - bool rotate_half, - bool rotate_every_two, - int heads, - int num_kv, - float norm_factor, - bool triangular, - bool local_attention, - int window_size, - bool no_masking, - unsigned layer_id, - unsigned num_layers, - at::Tensor& alibi, - float rope_theta) -{ - unsigned bsz = query_key_value.size(0); - unsigned seq_len = query_key_value.size(1); - int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads)); - unsigned hidden_dim = heads * k; - - bool is_prompt = (seq_len > 1); - - if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len); - unsigned soft_len = InferenceContext::Instance().current_tokens(); - - auto options = at::TensorOptions() - .dtype(query_key_value.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - - T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - size_t buf_size = bsz * seq_len * hidden_dim; - auto output = at::from_blob(workspace + 4 * buf_size, - {bsz, seq_len, hidden_dim}, - c10::TensorType::contiguousStridesOf({bsz, seq_len, hidden_dim}), - nullptr, - options, - query_key_value.device()); - - auto query_cont = workspace + 5 * buf_size; - size_t offset = - 10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLength()) + - layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLength() * hidden_dim; - unsigned all_tokens = soft_len; - auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); - size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLength() * hidden_dim; - - T* temp_buf = (T*)output.data_ptr() + at::numel(output); - launch_bias_add_transform_0213((T*)query_cont, - kv_cache, - kv_cache + value_offset, - (T*)query_key_value.data_ptr(), - nullptr, - bsz, - seq_len, - (is_prompt ? 0 : soft_len - 1), - soft_len, - hidden_dim, - heads, - (num_kv > 0 ? 
num_kv : heads), - rotary_dim, - rotate_half, - rotate_every_two, - InferenceContext::Instance().GetCurrentStream(), - 3, - InferenceContext::Instance().GetMaxTokenLength(), - rope_theta); - if (rotary_dim > 0 && rotate_half) - launch_apply_rotary_pos_emb(query_cont, - kv_cache, - k, - seq_len, - rotary_dim, - (is_prompt ? 0 : soft_len - 1), - heads, - bsz, - rope_theta, - InferenceContext::Instance().GetCurrentStream(), - InferenceContext::Instance().GetMaxTokenLength()); - - attention_unfused(workspace + offset, - (T*)query_cont, - attn_mask, - workspace + offset + value_offset, - temp_buf, - bsz, - k, - seq_len, - all_tokens, - heads, - norm_factor, - (triangular && is_prompt), - is_prompt, - local_attention, - window_size, - alibi, - layer_id); - launch_transform4d_0213((T*)output.data_ptr(), - temp_buf, - bsz, - heads, - seq_len, - output.size(2), - InferenceContext::Instance().GetCurrentStream(false), - 1); - - if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens(); - auto prev_key = at::from_blob(workspace + offset, - {bsz, heads, all_tokens, k}, - {hidden_dim * InferenceContext::Instance().GetMaxTokenLength(), - k * InferenceContext::Instance().GetMaxTokenLength(), - k, - 1}, - nullptr, - options, - query_key_value.device()); - - auto prev_value = at::from_blob(workspace + offset + value_offset, - {bsz, heads, all_tokens, k}, - {hidden_dim * InferenceContext::Instance().GetMaxTokenLength(), - k * InferenceContext::Instance().GetMaxTokenLength(), - k, - 1}, - nullptr, - options, - query_key_value.device()); - - return {output, prev_key, prev_value}; +std::vector ds_softmax_context( + at::Tensor& query_key_value, + at::Tensor& attn_mask, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + int heads, + int num_kv, + float norm_factor, + bool triangular, + bool local_attention, + int window_size, + bool no_masking, + unsigned layer_id, + unsigned num_layers, + at::Tensor& alibi, + float rope_theta) { + unsigned bsz = query_key_value.size(0); + unsigned seq_len = query_key_value.size(1); + int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads)); + unsigned hidden_dim = heads * k; + + bool is_prompt = (seq_len > 1); + + if (is_prompt) + InferenceContext::Instance().reset_tokens(seq_len); + unsigned soft_len = InferenceContext::Instance().current_tokens(); + + auto options = at::TensorOptions() + .dtype(query_key_value.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + size_t buf_size = bsz * seq_len * hidden_dim; + auto output = at::from_blob( + workspace + 4 * buf_size, + {bsz, seq_len, hidden_dim}, + c10::TensorType::contiguousStridesOf({bsz, seq_len, hidden_dim}), + nullptr, + options, + query_key_value.device()); + + auto query_cont = workspace + 5 * buf_size; + size_t offset = 10 * + (hidden_dim * bsz * + InferenceContext::Instance().GetMaxTokenLength()) + + layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLength() * + hidden_dim; + unsigned all_tokens = soft_len; + auto kv_cache = workspace + offset + + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); + size_t value_offset = + bsz * InferenceContext::Instance().GetMaxTokenLength() * hidden_dim; + + T* temp_buf = (T*)output.data_ptr() + at::numel(output); + launch_bias_add_transform_0213( + (T*)query_cont, + kv_cache, + kv_cache + value_offset, + (T*)query_key_value.data_ptr(), + nullptr, + bsz, + seq_len, + (is_prompt ? 
0 : soft_len - 1), + soft_len, + hidden_dim, + heads, + (num_kv > 0 ? num_kv : heads), + rotary_dim, + rotate_half, + rotate_every_two, + InferenceContext::Instance().GetCurrentStream(), + 3, + InferenceContext::Instance().GetMaxTokenLength(), + rope_theta); + if (rotary_dim > 0 && rotate_half) + launch_apply_rotary_pos_emb( + query_cont, + kv_cache, + k, + seq_len, + rotary_dim, + (is_prompt ? 0 : soft_len - 1), + heads, + bsz, + rope_theta, + InferenceContext::Instance().GetCurrentStream(), + InferenceContext::Instance().GetMaxTokenLength()); + + attention_unfused( + workspace + offset, + (T*)query_cont, + attn_mask, + workspace + offset + value_offset, + temp_buf, + bsz, + k, + seq_len, + all_tokens, + heads, + norm_factor, + (triangular && is_prompt), + is_prompt, + local_attention, + window_size, + alibi, + layer_id); + launch_transform4d_0213( + (T*)output.data_ptr(), + temp_buf, + bsz, + heads, + seq_len, + output.size(2), + InferenceContext::Instance().GetCurrentStream(false), + 1); + + if (layer_id == num_layers - 1) + InferenceContext::Instance().advance_tokens(); + auto prev_key = at::from_blob( + workspace + offset, + {bsz, heads, all_tokens, k}, + {hidden_dim * InferenceContext::Instance().GetMaxTokenLength(), + k * InferenceContext::Instance().GetMaxTokenLength(), + k, + 1}, + nullptr, + options, + query_key_value.device()); + + auto prev_value = at::from_blob( + workspace + offset + value_offset, + {bsz, heads, all_tokens, k}, + {hidden_dim * InferenceContext::Instance().GetMaxTokenLength(), + k * InferenceContext::Instance().GetMaxTokenLength(), + k, + 1}, + nullptr, + options, + query_key_value.device()); + + return {output, prev_key, prev_value}; } template -at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) -{ - auto input_cont = input.contiguous(); - - int bsz = input_cont.size(0) * input_cont.size(1); - int intermediate_size = input_cont.size(2); - - launch_bias_gelu((T*)input_cont.data_ptr(), - (T*)bias.data_ptr(), - intermediate_size, - bsz, - InferenceContext::Instance().GetCurrentStream()); - return input_cont; +at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) { + auto input_cont = input.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + int intermediate_size = input_cont.size(2); + + launch_bias_gelu( + (T*)input_cont.data_ptr(), + (T*)bias.data_ptr(), + intermediate_size, + bsz, + InferenceContext::Instance().GetCurrentStream()); + return input_cont; } -#define DISPATCH_GATED_ACT(T_TYPE, C_TYPE) \ - if (activation.options().dtype() == torch::T_TYPE) { \ - launch_gated_activation((C_TYPE*)output.data_ptr(), \ - (const C_TYPE*)activation.data_ptr(), \ - (const C_TYPE*)bias.data_ptr(), \ - rows, \ - out_channels, \ - channels, \ - activation_type == ActivationFuncType::GATED_GELU, \ - InferenceContext::Instance().GetCurrentStream()); \ - } - -at::Tensor ds_gated_activation(at::Tensor& activation, at::Tensor& bias, int actFun) -{ - /* - Used in FF of Stable diffusion - */ - - const ActivationFuncType activation_type = static_cast(actFun); - - assert(activation_type == ActivationFuncType::GATED_GELU || - activation_type == ActivationFuncType::GATED_SILU); - - const int batch_size = activation.size(0); - const int seq_len = activation.size(1); - const int channels = activation.size(2); - - const int rows = batch_size * seq_len; - // Dimensionality is cut in half - const int out_channels = channels / 2; - - auto output = at::empty({batch_size, seq_len, out_channels}, activation.options()); - - DISPATCH_GATED_ACT(kFloat, float); 
- DISPATCH_GATED_ACT(kHalf, sycl::half); +#define DISPATCH_GATED_ACT(T_TYPE, C_TYPE) \ + if (activation.options().dtype() == torch::T_TYPE) { \ + launch_gated_activation( \ + (C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)activation.data_ptr(), \ + (const C_TYPE*)bias.data_ptr(), \ + rows, \ + out_channels, \ + channels, \ + activation_type == ActivationFuncType::GATED_GELU, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor ds_gated_activation( + at::Tensor& activation, + at::Tensor& bias, + int actFun) { + /* + Used in FF of Stable diffusion + */ + + const ActivationFuncType activation_type = + static_cast(actFun); + + assert( + activation_type == ActivationFuncType::GATED_GELU || + activation_type == ActivationFuncType::GATED_SILU); + + const int batch_size = activation.size(0); + const int seq_len = activation.size(1); + const int channels = activation.size(2); + + const int rows = batch_size * seq_len; + // Dimensionality is cut in half + const int out_channels = channels / 2; + + auto output = + at::empty({batch_size, seq_len, out_channels}, activation.options()); + + DISPATCH_GATED_ACT(kFloat, float); + DISPATCH_GATED_ACT(kHalf, sycl::half); #ifdef BF16_AVAILABLE - DISPATCH_GATED_ACT(kBFloat16, sycl::ext::oneapi::bfloat16); + DISPATCH_GATED_ACT(kBFloat16, sycl::ext::oneapi::bfloat16); #endif - return output; + return output; } template -at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias) -{ - auto input_cont = input.contiguous(); - - int bsz = input_cont.size(0) * input_cont.size(1); - int intermediate_size = input_cont.size(2); - - launch_bias_relu((T*)input_cont.data_ptr(), - (T*)bias.data_ptr(), - intermediate_size, - bsz, - InferenceContext::Instance().GetCurrentStream()); - return input_cont; +at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias) { + auto input_cont = input.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + int intermediate_size = input_cont.size(2); + + launch_bias_relu( + (T*)input_cont.data_ptr(), + (T*)bias.data_ptr(), + intermediate_size, + bsz, + InferenceContext::Instance().GetCurrentStream()); + return input_cont; } template -at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias) -{ - auto input_cont = input.contiguous(); - - int bsz = input_cont.size(0) * input_cont.size(1); - int hidden_size = input_cont.size(2); - - launch_bias_add((T*)input_cont.data_ptr(), - (T*)bias.data_ptr(), - hidden_size, - bsz, - InferenceContext::Instance().GetCurrentStream()); - return input_cont; +at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias) { + auto input_cont = input.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + int hidden_size = input_cont.size(2); + + launch_bias_add( + (T*)input_cont.data_ptr(), + (T*)bias.data_ptr(), + hidden_size, + bsz, + InferenceContext::Instance().GetCurrentStream()); + return input_cont; } template -at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias) -{ - auto input_cont = input.contiguous(); - auto residual_cont = residual.contiguous(); - - int bsz = input_cont.size(0) * input_cont.size(1); - // launch_bias_residual((T*)input_cont.data_ptr(), - // (T*)residual_cont.data_ptr(), - // (T*)bias.data_ptr(), - // bsz, - // input_cont.size(2), - // (bias.size(0) > 1), - // InferenceContext::Instance().GetCurrentStream()); - return input_cont; +at::Tensor ds_bias_residual( + at::Tensor& input, + at::Tensor& residual, + at::Tensor& bias) { + auto input_cont = input.contiguous(); + auto residual_cont = 
residual.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + // launch_bias_residual((T*)input_cont.data_ptr(), + // (T*)residual_cont.data_ptr(), + // (T*)bias.data_ptr(), + // bsz, + // input_cont.size(2), + // (bias.size(0) > 1), + // InferenceContext::Instance().GetCurrentStream()); + return input_cont; } -#define DISPATCH_LAYER_NORM(T_TYPE, C_TYPE) \ - if (input.options().dtype() == torch::T_TYPE) { \ - launch_fused_ln((C_TYPE*)output.data_ptr(), \ - (const C_TYPE*)input.data_ptr(), \ - (const C_TYPE*)gamma.data_ptr(), \ - (const C_TYPE*)beta.data_ptr(), \ - epsilon, \ - rows, \ - elems_per_row, \ - InferenceContext::Instance().GetCurrentStream()); \ - } - -at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta, float epsilon) -{ - const int rows = input.size(0) * input.size(1); - const int elems_per_row = input.size(2); - auto output = at::empty_like(input); - - DISPATCH_LAYER_NORM(kFloat, float); - DISPATCH_LAYER_NORM(kHalf, sycl::half); +#define DISPATCH_LAYER_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_ln( \ + (C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor ds_layer_norm( + at::Tensor& input, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) { + const int rows = input.size(0) * input.size(1); + const int elems_per_row = input.size(2); + auto output = at::empty_like(input); + + DISPATCH_LAYER_NORM(kFloat, float); + DISPATCH_LAYER_NORM(kHalf, sycl::half); #ifdef BF16_AVAILABLE - DISPATCH_LAYER_NORM(kBFloat16, sycl::ext::oneapi::bfloat16); + DISPATCH_LAYER_NORM(kBFloat16, sycl::ext::oneapi::bfloat16); #endif - return output; + return output; } -#define DISPATCH_RMS_NORM(T_TYPE, C_TYPE) \ - if (input.options().dtype() == torch::T_TYPE) { \ - launch_rms_norm((C_TYPE*)output.data_ptr(), \ - (C_TYPE*)nullptr, \ - (const C_TYPE*)input.data_ptr(), \ - (const C_TYPE*)nullptr, \ - (const C_TYPE*)gamma.data_ptr(), \ - epsilon, \ - rows, \ - elems_per_row, \ - InferenceContext::Instance().GetCurrentStream()); \ - } - -at::Tensor ds_rms_norm(at::Tensor& input, at::Tensor& gamma, float epsilon) -{ - // Get number of dims of tensor - int num_dims = input.dim(); - const int rows = (num_dims == 2) ? input.size(0) : input.size(0) * input.size(1); - const int elems_per_row = (num_dims == 2) ? input.size(1) : input.size(2); - - auto output = at::empty_like(input); - - DISPATCH_RMS_NORM(kFloat, float); - DISPATCH_RMS_NORM(kHalf, sycl::half); +#define DISPATCH_RMS_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_rms_norm( \ + (C_TYPE*)output.data_ptr(), \ + (C_TYPE*)nullptr, \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)nullptr, \ + (const C_TYPE*)gamma.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor ds_rms_norm(at::Tensor& input, at::Tensor& gamma, float epsilon) { + // Get number of dims of tensor + int num_dims = input.dim(); + const int rows = + (num_dims == 2) ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = (num_dims == 2) ? 
input.size(1) : input.size(2); + + auto output = at::empty_like(input); + + DISPATCH_RMS_NORM(kFloat, float); + DISPATCH_RMS_NORM(kHalf, sycl::half); #ifdef BF16_AVAILABLE - DISPATCH_RMS_NORM(kBFloat16, sycl::ext::oneapi::bfloat16); + DISPATCH_RMS_NORM(kBFloat16, sycl::ext::oneapi::bfloat16); #endif - return output; + return output; } -#define DISPATCH_PRE_RMS_NORM(T_TYPE, C_TYPE) \ - if (input.options().dtype() == torch::T_TYPE) { \ - launch_rms_norm((C_TYPE*)output.data_ptr(), \ - (C_TYPE*)res_out.data_ptr(), \ - (const C_TYPE*)input.data_ptr(), \ - (const C_TYPE*)residual.data_ptr(), \ - (const C_TYPE*)gamma.data_ptr(), \ - epsilon, \ - rows, \ - elems_per_row, \ - InferenceContext::Instance().GetCurrentStream()); \ - } - -std::vector ds_pre_rms_norm(at::Tensor& input, - at::Tensor& residual, - at::Tensor& gamma, - float epsilon) -{ - // Get number of dims of tensor - int num_dims = input.dim(); - const int rows = (num_dims == 2) ? input.size(0) : input.size(0) * input.size(1); - const int elems_per_row = (num_dims == 2) ? input.size(1) : input.size(2); - - auto output = at::empty_like(input); - auto res_out = at::empty_like(residual); - - DISPATCH_PRE_RMS_NORM(kFloat, float); - DISPATCH_PRE_RMS_NORM(kHalf, sycl::half); +#define DISPATCH_PRE_RMS_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_rms_norm( \ + (C_TYPE*)output.data_ptr(), \ + (C_TYPE*)res_out.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +std::vector ds_pre_rms_norm( + at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + float epsilon) { + // Get number of dims of tensor + int num_dims = input.dim(); + const int rows = + (num_dims == 2) ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = (num_dims == 2) ? 
input.size(1) : input.size(2); + + auto output = at::empty_like(input); + auto res_out = at::empty_like(residual); + + DISPATCH_PRE_RMS_NORM(kFloat, float); + DISPATCH_PRE_RMS_NORM(kHalf, sycl::half); #ifdef BF16_AVAILABLE - DISPATCH_PRE_RMS_NORM(kBFloat16, sycl::ext::oneapi::bfloat16); + DISPATCH_PRE_RMS_NORM(kBFloat16, sycl::ext::oneapi::bfloat16); #endif - return {output, res_out}; + return {output, res_out}; } template -void ds_layer_norm_internal(T* workspace, - at::Tensor& input, - at::Tensor& gamma, - at::Tensor& beta, - float epsilon) -{ - int bsz = input.size(0) * input.size(1); - launch_fused_ln(workspace, - (const T*)input.data_ptr(), - (const T*)gamma.data_ptr(), - (const T*)beta.data_ptr(), - epsilon, - bsz, - input.size(2), - InferenceContext::Instance().GetCurrentStream()); +void ds_layer_norm_internal( + T* workspace, + at::Tensor& input, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) { + int bsz = input.size(0) * input.size(1); + launch_fused_ln( + workspace, + (const T*)input.data_ptr(), + (const T*)gamma.data_ptr(), + (const T*)beta.data_ptr(), + epsilon, + bsz, + input.size(2), + InferenceContext::Instance().GetCurrentStream()); } -#define DISPATCH_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ - if (input.options().dtype() == torch::T_TYPE) { \ - launch_fused_residual_ln((C_TYPE*)output.data_ptr(), \ - (const C_TYPE*)input.data_ptr(), \ - (const C_TYPE*)residual.data_ptr(), \ - (const C_TYPE*)bias.data_ptr(), \ - (const C_TYPE*)gamma.data_ptr(), \ - (const C_TYPE*)beta.data_ptr(), \ - epsilon, \ - rows, \ - elems_per_row, \ - InferenceContext::Instance().GetCurrentStream()); \ - } +#define DISPATCH_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_residual_ln( \ + (C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)bias.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } /* Currently only used in unit testing */ -at::Tensor ds_layer_norm_residual(at::Tensor& input, - at::Tensor& bias, - at::Tensor& residual, - at::Tensor& gamma, - at::Tensor& beta, - float epsilon) -{ - const int rows = input.size(0) * input.size(1); - const int elems_per_row = input.size(2); - auto output = at::empty_like(input); - - DISPATCH_LAYER_NORM_RESIDUAL(kFloat, float); - DISPATCH_LAYER_NORM_RESIDUAL(kHalf, sycl::half); +at::Tensor ds_layer_norm_residual( + at::Tensor& input, + at::Tensor& bias, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) { + const int rows = input.size(0) * input.size(1); + const int elems_per_row = input.size(2); + auto output = at::empty_like(input); + + DISPATCH_LAYER_NORM_RESIDUAL(kFloat, float); + DISPATCH_LAYER_NORM_RESIDUAL(kHalf, sycl::half); #ifdef BF16_AVAILABLE - DISPATCH_LAYER_NORM_RESIDUAL(kBFloat16, sycl::ext::oneapi::bfloat16); + DISPATCH_LAYER_NORM_RESIDUAL(kBFloat16, sycl::ext::oneapi::bfloat16); #endif - return output; + return output; } -#define DISPATCH_PRE_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ - if (input.options().dtype() == torch::T_TYPE) { \ - launch_fused_residual_ln_store_pre_ln_res( \ - (C_TYPE*)norm_output.data_ptr(), \ - (C_TYPE*)res_output.data_ptr(), \ - (const C_TYPE*)input.data_ptr(), \ - (const C_TYPE*)residual.data_ptr(), \ - (const C_TYPE*)bias.data_ptr(), \ - (const C_TYPE*)gamma.data_ptr(), \ - (const C_TYPE*)beta.data_ptr(), \ - epsilon, \ - 
rows, \ - elems_per_row, \ - InferenceContext::Instance().GetCurrentStream()); \ - } +#define DISPATCH_PRE_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_residual_ln_store_pre_ln_res( \ + (C_TYPE*)norm_output.data_ptr(), \ + (C_TYPE*)res_output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)bias.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + InferenceContext::Instance().GetCurrentStream()); \ + } /* Currently only used in unit testing */ -std::vector ds_layer_norm_residual_store_pre_ln_res(at::Tensor& input, - at::Tensor& bias, - at::Tensor& residual, - at::Tensor& gamma, - at::Tensor& beta, - float epsilon) -{ - const int rows = input.size(0) * input.size(1); - const int elems_per_row = input.size(2); - auto norm_output = at::empty_like(input); - auto res_output = at::empty_like(input); - - DISPATCH_PRE_LAYER_NORM_RESIDUAL(kFloat, float); - DISPATCH_PRE_LAYER_NORM_RESIDUAL(kHalf, sycl::half); +std::vector ds_layer_norm_residual_store_pre_ln_res( + at::Tensor& input, + at::Tensor& bias, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) { + const int rows = input.size(0) * input.size(1); + const int elems_per_row = input.size(2); + auto norm_output = at::empty_like(input); + auto res_output = at::empty_like(input); + + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kFloat, float); + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kHalf, sycl::half); #ifdef BF16_AVAILABLE - DISPATCH_PRE_LAYER_NORM_RESIDUAL(kBFloat16, sycl::ext::oneapi::bfloat16); + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kBFloat16, sycl::ext::oneapi::bfloat16); #endif - return {norm_output, res_output}; + return {norm_output, res_output}; } template -void quantized_gemm(void* output, - T* input, - at::Tensor& weight, - at::Tensor& qscale, - int groups, - int bsz, - int hidden_size) -{ - // T* weight16 = (T*)InferenceContext::Instance().GetWorkSpace() + 12 * hidden_size * bsz; - - auto options = at::TensorOptions() - .dtype(at::kHalf) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - auto tmp = torch::empty(weight.sizes(), options); - T* weight16 = (T*)tmp.data_ptr(); - launch_dequantize(weight16, - (int8_t*)weight.data_ptr(), - (float*)qscale.data_ptr(), - weight.size(0), - weight.size(1), - groups, - InferenceContext::Instance().GetCurrentStream()); - - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), - oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, - weight.size(0), - bsz, - weight.size(1), - &alpha, - &gemm_beta, - weight16, - (T*)input, - (T*)output, -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif -} - -template -at::Tensor qkv_unfused_cublas(at::Tensor& output, - at::Tensor& input, - at::Tensor& weight, - at::Tensor& q_scale, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool add_bias, - bool q_int8, - bool transposed_mode) -{ - int bsz = input.size(0) * input.size(1); - T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - workspace += (3 * bsz * input.size(2)); - ds_layer_norm_internal(workspace, input, gamma, beta, epsilon); - - if (q_int8) { - quantized_gemm( - output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz, input.size(2)); - } else { - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - - 
*(InferenceContext::Instance().GetCublasHandle()) = - *(InferenceContext::Instance().GetCurrentStream()); - cublas_gemm_ex( - InferenceContext::Instance().GetCublasHandle(), - (transposed_mode ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans), - oneapi::mkl::transpose::nontrans, - weight.size(transposed_mode ? 0 : 1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - workspace, - (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif - } - if (add_bias) - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), - bsz, - InferenceContext::Instance().GetCurrentStream()); - return at::from_blob(workspace, - input.sizes(), - c10::TensorType::contiguousStridesOf(input.sizes()), - nullptr, - input.options(), - input.options().device()); +void quantized_gemm( + void* output, + T* input, + at::Tensor& weight, + at::Tensor& qscale, + int groups, + int bsz, + int hidden_size) { + // T* weight16 = (T*)InferenceContext::Instance().GetWorkSpace() + 12 * + // hidden_size * bsz; + + auto options = at::TensorOptions() + .dtype(at::kHalf) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + auto tmp = torch::empty(weight.sizes(), options); + T* weight16 = (T*)tmp.data_ptr(); + launch_dequantize( + weight16, + (int8_t*)weight.data_ptr(), + (float*)qscale.data_ptr(), + weight.size(0), + weight.size(1), + groups, + InferenceContext::Instance().GetCurrentStream()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, + weight.size(0), + bsz, + weight.size(1), + &alpha, + &gemm_beta, + weight16, + (T*)input, + (T*)output, + 99); } template -std::vector ds_rms_qkv(at::Tensor& input, - at::Tensor& weight, - at::Tensor& q_scale, - at::Tensor& gamma, - const float epsilon, - bool q_int8, - bool transposed_mode) -{ - const int bsz = input.size(0) * input.size(1); - T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - T* rms_norm_ptr = workspace + (3 * bsz * input.size(2)); - int out_size = (transposed_mode || q_int8) ? 
weight.size(0) : weight.size(1); - - auto options = at::TensorOptions() - .dtype(input.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - auto rms_norm = at::from_blob(rms_norm_ptr, - input.sizes(), - c10::TensorType::contiguousStridesOf(input.sizes()), - nullptr, - options, - input.device()); - auto output = at::from_blob( +at::Tensor qkv_unfused_mkl( + at::Tensor& output, + at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool add_bias, + bool q_int8, + bool transposed_mode) { + int bsz = input.size(0) * input.size(1); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + workspace += (3 * bsz * input.size(2)); + ds_layer_norm_internal(workspace, input, gamma, beta, epsilon); + + if (q_int8) { + quantized_gemm( + output.data_ptr(), workspace, - {input.size(0), input.size(1), out_size}, - c10::TensorType::contiguousStridesOf({input.size(0), input.size(1), out_size}), - nullptr, - options, - input.device()); - - launch_rms_norm((T*)rms_norm.data_ptr(), - (T*)nullptr, - (const T*)input.data_ptr(), - (const T*)nullptr, - (const T*)gamma.data_ptr(), - epsilon, - bsz, - input.size(2), - InferenceContext::Instance().GetCurrentStream()); - - if (q_int8) { - quantized_gemm((T*)output.data_ptr(), - (T*)rms_norm.data_ptr(), - weight, - q_scale, - q_scale.size(0), - bsz, - input.size(2)); - } else { - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - - *(InferenceContext::Instance().GetCublasHandle()) = - *(InferenceContext::Instance().GetCurrentStream()); - cublas_gemm_ex( - InferenceContext::Instance().GetCublasHandle(), - (transposed_mode ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans), - oneapi::mkl::transpose::nontrans, - weight.size(transposed_mode ? 0 : 1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)rms_norm.data_ptr(), - (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif - } - - return {output, rms_norm}; -} + weight, + q_scale, + q_scale.size(0), + bsz, + input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; -template -std::vector ds_qkv_gemm(at::Tensor& input, - at::Tensor& weight, - at::Tensor& q_scale, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool add_bias, - bool q_int8, - bool transposed_mode) -{ - int bsz = input.size(0) * input.size(1); - T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1); - - auto options = at::TensorOptions() - .dtype(input.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - - auto output = at::from_blob( + *(InferenceContext::Instance().GetCublasHandle()) = + *(InferenceContext::Instance().GetCurrentStream()); + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans), + oneapi::mkl::transpose::nontrans, + weight.size(transposed_mode ? 
0 : 1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), workspace, - {input.size(0), input.size(1), out_size}, - c10::TensorType::contiguousStridesOf({input.size(0), input.size(1), out_size}), - nullptr, - options, - input.device()); - auto inp_norm = qkv_unfused_cublas(output, - input, - weight, - q_scale, - bias, - gamma, - beta, - epsilon, - add_bias, - q_int8, - transposed_mode); - - return {output, inp_norm}; + (T*)output.data_ptr(), + 99); + } + if (add_bias) + launch_bias_add( + (T*)output.data_ptr(), + (T*)bias.data_ptr(), + (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), + bsz, + InferenceContext::Instance().GetCurrentStream()); + return at::from_blob( + workspace, + input.sizes(), + c10::TensorType::contiguousStridesOf(input.sizes()), + nullptr, + input.options(), + input.options().device()); } template -void quantized_gemm(at::Tensor& output, - at::Tensor& input, - at::Tensor& weight, - at::Tensor& qscale, - int groups, - int merge_count) -{ - int bsz = input.size(0) * input.size(1); - auto options = at::TensorOptions() - .dtype(input.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); - - launch_dequantize((T*)weight16.data_ptr(), - (int8_t*)weight.data_ptr(), - (float*)qscale.data_ptr(), - weight.size(0), - weight.size(1), - groups, - merge_count, - InferenceContext::Instance().GetCurrentStream()); - +std::vector ds_rms_qkv( + at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + at::Tensor& gamma, + const float epsilon, + bool q_int8, + bool transposed_mode) { + const int bsz = input.size(0) * input.size(1); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + T* rms_norm_ptr = workspace + (3 * bsz * input.size(2)); + int out_size = (transposed_mode || q_int8) ? 
weight.size(0) : weight.size(1); + + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + auto rms_norm = at::from_blob( + rms_norm_ptr, + input.sizes(), + c10::TensorType::contiguousStridesOf(input.sizes()), + nullptr, + options, + input.device()); + auto output = at::from_blob( + workspace, + {input.size(0), input.size(1), out_size}, + c10::TensorType::contiguousStridesOf( + {input.size(0), input.size(1), out_size}), + nullptr, + options, + input.device()); + + launch_rms_norm( + (T*)rms_norm.data_ptr(), + (T*)nullptr, + (const T*)input.data_ptr(), + (const T*)nullptr, + (const T*)gamma.data_ptr(), + epsilon, + bsz, + input.size(2), + InferenceContext::Instance().GetCurrentStream()); + + if (q_int8) { + quantized_gemm( + (T*)output.data_ptr(), + (T*)rms_norm.data_ptr(), + weight, + q_scale, + q_scale.size(0), + bsz, + input.size(2)); + } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; - cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), - oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, - weight.size(0), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight16.data_ptr(), - (T*)input.data_ptr(), - (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif -} - -template -at::Tensor ds_linear_layer(at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - bool add_bias, - bool do_flash_attn, - int num_heads, - bool transposed_mode, - float rope_theta) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - - int head_size = input_cont.size(2) / num_heads; - int bsz = input.size(0) * input.size(1); - int out_size = transposed_mode ? weight.size(0) : weight.size(1); - T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - auto output = at::from_blob( - workspace, - {input.size(0), input.size(1), out_size}, - c10::TensorType::contiguousStridesOf({input.size(0), input.size(1), out_size}), - nullptr, - options, - input.device()); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; *(InferenceContext::Instance().GetCublasHandle()) = *(InferenceContext::Instance().GetCurrentStream()); - - cublas_gemm_ex( + mkl_gemm_ex( InferenceContext::Instance().GetCublasHandle(), - (transposed_mode ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans), + (transposed_mode ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans), oneapi::mkl::transpose::nontrans, weight.size(transposed_mode ? 0 : 1), bsz, - input_cont.size(2), + input.size(2), &alpha, &gemm_beta, (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), + (T*)rms_norm.data_ptr(), (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else 99); -#endif - if (add_bias) - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(transposed_mode ? 0 : 1), - bsz, - InferenceContext::Instance().GetCurrentStream()); - bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0); - if (do_flash_attn) { - if (add_padding) { - int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128); - auto padded_output = workspace + output.numel(); - auto final_output = - padded_output + (input.size(0) * input.size(1) * 3 * num_heads * padded_head_size); - pad_data(padded_output, - workspace, - 3 * bsz * num_heads, - head_size, - padded_head_size, - InferenceContext::Instance().GetCurrentStream()); - - launch_bias_add_transform_0213( - final_output, - final_output + (input.size(0) * input.size(1) * num_heads * padded_head_size), - final_output + (input.size(0) * input.size(1) * 2 * num_heads * padded_head_size), - padded_output, - nullptr, - input.size(0), - input.size(1), - 0, - input.size(1), - (num_heads * padded_head_size), - num_heads, - -1, - -1, - false, - false, - InferenceContext::Instance().GetCurrentStream(), - 3, - input.size(1), - rope_theta); - return at::from_blob( - final_output, - {3, input.size(0), num_heads, input.size(1), padded_head_size}, - c10::TensorType::contiguousStridesOf( - {3, input.size(0), num_heads, input.size(1), padded_head_size}), - nullptr, - options, - input.device()); - // return at::from_blob(padded_output, {input.size(0) * input.size(1), 3, num_heads, - // padded_head_size}, options); - } else { - auto final_output = workspace + output.numel(); - launch_bias_add_transform_0213( - final_output, - final_output + (input.size(0) * input.size(1) * input_cont.size(2)), - final_output + (input.size(0) * input.size(1) * 2 * input_cont.size(2)), - workspace, - nullptr, - input.size(0), - input.size(1), - 0, - input.size(1), - input_cont.size(2), - num_heads, - -1, - -1, - false, - false, - InferenceContext::Instance().GetCurrentStream(), - 3, - input.size(1), - rope_theta); - return at::from_blob(final_output, - {3, input.size(0), num_heads, input.size(1), head_size}, - c10::TensorType::contiguousStridesOf( - {3, input.size(0), num_heads, input.size(1), head_size}), - nullptr, - options, - input.device()); - // return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads, - // head_size}, options); - } - - } else - return output; + } + + return {output, rms_norm}; } template -std::vector add_padding(at::Tensor& query, at::Tensor& key, at::Tensor& value) -{ - int head_size = query.size(3); - int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128); - T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2); - T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128; - pad_head_seq(workspace, - (T*)query.data_ptr(), - query.size(0) * query.size(1), - query.size(2), - query.size(2), - head_size, - padded_head_size, - InferenceContext::Instance().GetCurrentStream()); - pad_head_seq(key_pad_ptr, - (T*)key.data_ptr(), - query.size(0) * query.size(1), - key.size(2), - 128, - head_size, - padded_head_size, - InferenceContext::Instance().GetCurrentStream()); - pad_head_seq(value_pad_ptr, - (T*)value.data_ptr(), - query.size(0) * query.size(1), - key.size(2), - 128, - head_size, - padded_head_size, - InferenceContext::Instance().GetCurrentStream()); - return {at::from_blob(workspace, - {query.size(0), query.size(1), query.size(2), padded_head_size}, - c10::TensorType::contiguousStridesOf( - {query.size(0), query.size(1), query.size(2), padded_head_size}), - nullptr, - query.options(), - query.options().device()), - at::from_blob(key_pad_ptr, - {query.size(0), query.size(1), 128, padded_head_size}, - c10::TensorType::contiguousStridesOf( - {query.size(0), query.size(1), 128, padded_head_size}), - nullptr, - query.options(), - query.options().device()), - at::from_blob(value_pad_ptr, - {query.size(0), query.size(1), 128, padded_head_size}, - c10::TensorType::contiguousStridesOf( - {query.size(0), query.size(1), 128, padded_head_size}), - nullptr, - query.options(), - query.options().device())}; +std::vector ds_qkv_gemm( + at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool add_bias, + bool q_int8, + bool transposed_mode) { + int bsz = input.size(0) * input.size(1); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1); + + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + + auto output = at::from_blob( + workspace, + {input.size(0), input.size(1), out_size}, + c10::TensorType::contiguousStridesOf( + {input.size(0), input.size(1), out_size}), + nullptr, + options, + input.device()); + auto inp_norm = qkv_unfused_mkl( + output, + input, + weight, + q_scale, + bias, + gamma, + beta, + epsilon, + add_bias, + q_int8, + transposed_mode); + + return {output, inp_norm}; } template -std::vector padd_add_transform(at::Tensor& query, - at::Tensor& key, - at::Tensor& value, - int heads, - bool add_padding) -{ - int head_size = query.size(2) / heads; - int key_value_length = add_padding ? 128 : key.size(1); - int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128)) - : head_size; - T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1); - T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length; - launch_pad_add_transform_0213(workspace, - (T*)query.data_ptr(), - query.size(0), - query.size(2), - query.size(1), - query.size(1), - heads, - padded_head_size, - InferenceContext::Instance().GetCurrentStream()); - launch_pad_add_transform_0213(key_pad_ptr, - (T*)key.data_ptr(), - key.size(0), - key.size(2), - key.size(1), - key_value_length, - heads, - padded_head_size, - InferenceContext::Instance().GetCurrentStream()); - launch_pad_add_transform_0213(value_pad_ptr, - (T*)value.data_ptr(), - value.size(0), - value.size(2), - value.size(1), - key_value_length, - heads, - padded_head_size, - InferenceContext::Instance().GetCurrentStream()); - return {at::from_blob(workspace, - {query.size(0), heads, query.size(1), padded_head_size}, - c10::TensorType::contiguousStridesOf( - {query.size(0), heads, query.size(1), padded_head_size}), - nullptr, - query.options(), - query.options().device()), - at::from_blob(key_pad_ptr, - {query.size(0), heads, key_value_length, padded_head_size}, - c10::TensorType::contiguousStridesOf( - {query.size(0), heads, key_value_length, padded_head_size}), - nullptr, - query.options(), - query.options().device()), - at::from_blob(value_pad_ptr, - {query.size(0), heads, key_value_length, padded_head_size}, - c10::TensorType::contiguousStridesOf( - {query.size(0), heads, key_value_length, padded_head_size}), - nullptr, - query.options(), - query.options().device())}; +void quantized_gemm( + at::Tensor& output, + at::Tensor& input, + at::Tensor& weight, + at::Tensor& qscale, + int groups, + int merge_count) { + int bsz = input.size(0) * input.size(1); + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); + + launch_dequantize( + (T*)weight16.data_ptr(), + (int8_t*)weight.data_ptr(), + (float*)qscale.data_ptr(), + weight.size(0), + weight.size(1), + groups, + merge_count, + InferenceContext::Instance().GetCurrentStream()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, + weight.size(0), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight16.data_ptr(), + (T*)input.data_ptr(), + (T*)output.data_ptr(), + 99); } template -at::Tensor ds_vector_matmul(at::Tensor& input, - at::Tensor& weight, - bool async_op, - at::Tensor& q_scale, - bool q_int8, - bool transposed_mode) -{ - auto options = at::TensorOptions() - .dtype(input.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - int out_size = (q_int8 || transposed_mode) ? 
weight.size(0) : weight.size(1); - int bsz = input.size(0) * input.size(1); - - T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); - auto output = at::from_blob( - workspace, - {input.size(0), input.size(1), out_size}, - c10::TensorType::contiguousStridesOf({input.size(0), input.size(1), out_size}), - nullptr, - options, - input.device()); - if (q_int8) { - quantized_gemm(output.data_ptr(), - (T*)input.data_ptr(), - weight, - q_scale, - q_scale.size(0), - bsz, - input.size(2)); +at::Tensor ds_linear_layer( + at::Tensor& input, + at::Tensor& weight, + at::Tensor& bias, + bool add_bias, + bool do_flash_attn, + int num_heads, + bool transposed_mode, + float rope_theta) { + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + + int head_size = input_cont.size(2) / num_heads; + int bsz = input.size(0) * input.size(1); + int out_size = transposed_mode ? weight.size(0) : weight.size(1); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + auto output = at::from_blob( + workspace, + {input.size(0), input.size(1), out_size}, + c10::TensorType::contiguousStridesOf( + {input.size(0), input.size(1), out_size}), + nullptr, + options, + input.device()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + *(InferenceContext::Instance().GetCublasHandle()) = + *(InferenceContext::Instance().GetCurrentStream()); + + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans), + oneapi::mkl::transpose::nontrans, + weight.size(transposed_mode ? 0 : 1), + bsz, + input_cont.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)output.data_ptr(), + 99); + if (add_bias) + launch_bias_add( + (T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(transposed_mode ? 0 : 1), + bsz, + InferenceContext::Instance().GetCurrentStream()); + bool add_padding = + (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0); + if (do_flash_attn) { + if (add_padding) { + int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128); + auto padded_output = workspace + output.numel(); + auto final_output = padded_output + + (input.size(0) * input.size(1) * 3 * num_heads * padded_head_size); + pad_data( + padded_output, + workspace, + 3 * bsz * num_heads, + head_size, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + + launch_bias_add_transform_0213( + final_output, + final_output + + (input.size(0) * input.size(1) * num_heads * padded_head_size), + final_output + + (input.size(0) * input.size(1) * 2 * num_heads * + padded_head_size), + padded_output, + nullptr, + input.size(0), + input.size(1), + 0, + input.size(1), + (num_heads * padded_head_size), + num_heads, + -1, + -1, + false, + false, + InferenceContext::Instance().GetCurrentStream(), + 3, + input.size(1), + rope_theta); + return at::from_blob( + final_output, + {3, input.size(0), num_heads, input.size(1), padded_head_size}, + c10::TensorType::contiguousStridesOf( + {3, input.size(0), num_heads, input.size(1), padded_head_size}), + nullptr, + options, + input.device()); + // return at::from_blob(padded_output, {input.size(0) * input.size(1), 3, + // num_heads, padded_head_size}, options); } else { - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - *(InferenceContext::Instance().GetCublasHandle()) = - *(InferenceContext::Instance().GetCurrentStream(async_op)); - cublas_gemm_ex( - InferenceContext::Instance().GetCublasHandle(), - (transposed_mode ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans), - oneapi::mkl::transpose::nontrans, - weight.size(transposed_mode ? 0 : 1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)input.data_ptr(), - (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif + auto final_output = workspace + output.numel(); + launch_bias_add_transform_0213( + final_output, + final_output + (input.size(0) * input.size(1) * input_cont.size(2)), + final_output + + (input.size(0) * input.size(1) * 2 * input_cont.size(2)), + workspace, + nullptr, + input.size(0), + input.size(1), + 0, + input.size(1), + input_cont.size(2), + num_heads, + -1, + -1, + false, + false, + InferenceContext::Instance().GetCurrentStream(), + 3, + input.size(1), + rope_theta); + return at::from_blob( + final_output, + {3, input.size(0), num_heads, input.size(1), head_size}, + c10::TensorType::contiguousStridesOf( + {3, input.size(0), num_heads, input.size(1), head_size}), + nullptr, + options, + input.device()); + // return at::from_blob(workspace, {input.size(0) * input.size(1), 3, + // num_heads, head_size}, options); } + + } else return output; } template -at::Tensor ds_vector_matmul_int8(at::Tensor& input, - at::Tensor& weight, - at::Tensor& q_scale, - int groups, - int merge_count) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - quantized_gemm(output, input_cont, weight, q_scale, groups, merge_count); - return output; +std::vector add_padding( + at::Tensor& query, + at::Tensor& key, + at::Tensor& value) { + int head_size = query.size(3); + int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + T* key_pad_ptr = workspace + + padded_head_size * query.size(0) * query.size(1) * query.size(2); + T* value_pad_ptr = + key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128; + pad_head_seq( + workspace, + (T*)query.data_ptr(), + query.size(0) * query.size(1), + query.size(2), + query.size(2), + head_size, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + pad_head_seq( + key_pad_ptr, + (T*)key.data_ptr(), + query.size(0) * query.size(1), + key.size(2), + 128, + head_size, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + pad_head_seq( + value_pad_ptr, + (T*)value.data_ptr(), + query.size(0) * query.size(1), + key.size(2), + 128, + head_size, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + return { + at::from_blob( + workspace, + {query.size(0), query.size(1), query.size(2), padded_head_size}, + c10::TensorType::contiguousStridesOf( + {query.size(0), query.size(1), query.size(2), padded_head_size}), + nullptr, + query.options(), + query.options().device()), + at::from_blob( + key_pad_ptr, + {query.size(0), query.size(1), 128, padded_head_size}, + c10::TensorType::contiguousStridesOf( + {query.size(0), query.size(1), 128, padded_head_size}), + nullptr, + query.options(), + query.options().device()), + at::from_blob( + value_pad_ptr, + {query.size(0), query.size(1), 128, padded_head_size}, + c10::TensorType::contiguousStridesOf( + {query.size(0), query.size(1), 128, padded_head_size}), + nullptr, + query.options(), + query.options().device())}; } template -at::Tensor mlp_unfused_cublas(at::Tensor& output, - at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& weight1, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool preLayerNorm, - bool mlp_after_attn, - at::Tensor& q_scale, - at::Tensor& q_scale1, - bool q_int8, - ActivationFuncType act_func_type, - bool transposed_mode) -{ - int bsz = input.size(0) * input.size(1); - T* inp_norm = (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input) + - torch::numel(output); - T* intermediate = inp_norm + torch::numel(input); - - if (mlp_after_attn) { - launch_fused_residual_ln((T*)inp_norm, - (const T*)input.data_ptr(), - (const T*)residual.data_ptr(), - (const T*)input_bias.data_ptr(), - (const T*)gamma.data_ptr(), - (const T*)beta.data_ptr(), - epsilon, - bsz, - input.size(2), - InferenceContext::Instance().GetCurrentStream()); - } else { - ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon); - } - if (q_int8) { - quantized_gemm( - intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz, input.size(2)); - } else { - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - *(InferenceContext::Instance().GetCublasHandle()) = - *(InferenceContext::Instance().GetCurrentStream()); - cublas_gemm_ex( - InferenceContext::Instance().GetCublasHandle(), - (transposed_mode ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans), - oneapi::mkl::transpose::nontrans, - weight.size(transposed_mode ? 0 : 1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - inp_norm, - intermediate, -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif - } - if (act_func_type == ActivationFuncType::GELU) { - launch_bias_gelu(intermediate, - (T*)bias.data_ptr(), - (transposed_mode || q_int8) ? 
weight.size(0) : weight.size(1), - bsz, - InferenceContext::Instance().GetCurrentStream()); - } else if (act_func_type == ActivationFuncType::ReLU) { - launch_bias_relu(intermediate, - (T*)bias.data_ptr(), - (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), - bsz, - InferenceContext::Instance().GetCurrentStream()); - } - - if (q_int8) { - quantized_gemm(output.data_ptr(), - intermediate, - weight1, - q_scale1, - q_scale1.size(0), - bsz, - input.size(2)); - } else { - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - *(InferenceContext::Instance().GetCublasHandle()) = - *(InferenceContext::Instance().GetCurrentStream()); - cublas_gemm_ex( - InferenceContext::Instance().GetCublasHandle(), - (transposed_mode ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans), - oneapi::mkl::transpose::nontrans, - weight1.size(transposed_mode ? 0 : 1), - bsz, - weight1.size(transposed_mode ? 1 : 0), - &alpha, - &gemm_beta, - (T*)weight1.data_ptr(), - intermediate, - (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif - } - - return at::from_blob(inp_norm, - input.sizes(), - c10::TensorType::contiguousStridesOf(input.sizes()), - nullptr, - input.options(), - input.options().device()); +std::vector padd_add_transform( + at::Tensor& query, + at::Tensor& key, + at::Tensor& value, + int heads, + bool add_padding) { + int head_size = query.size(2) / heads; + int key_value_length = add_padding ? 128 : key.size(1); + int padded_head_size = add_padding + ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128)) + : head_size; + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + T* key_pad_ptr = + workspace + padded_head_size * query.size(0) * heads * query.size(1); + T* value_pad_ptr = + key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length; + launch_pad_add_transform_0213( + workspace, + (T*)query.data_ptr(), + query.size(0), + query.size(2), + query.size(1), + query.size(1), + heads, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + launch_pad_add_transform_0213( + key_pad_ptr, + (T*)key.data_ptr(), + key.size(0), + key.size(2), + key.size(1), + key_value_length, + heads, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + launch_pad_add_transform_0213( + value_pad_ptr, + (T*)value.data_ptr(), + value.size(0), + value.size(2), + value.size(1), + key_value_length, + heads, + padded_head_size, + InferenceContext::Instance().GetCurrentStream()); + return { + at::from_blob( + workspace, + {query.size(0), heads, query.size(1), padded_head_size}, + c10::TensorType::contiguousStridesOf( + {query.size(0), heads, query.size(1), padded_head_size}), + nullptr, + query.options(), + query.options().device()), + at::from_blob( + key_pad_ptr, + {query.size(0), heads, key_value_length, padded_head_size}, + c10::TensorType::contiguousStridesOf( + {query.size(0), heads, key_value_length, padded_head_size}), + nullptr, + query.options(), + query.options().device()), + at::from_blob( + value_pad_ptr, + {query.size(0), heads, key_value_length, padded_head_size}, + c10::TensorType::contiguousStridesOf( + {query.size(0), heads, key_value_length, padded_head_size}), + nullptr, + query.options(), + query.options().device())}; } template -std::vector ds_mlp_gemm(at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight_interm, - at::Tensor& weight_out, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool 
preLayerNorm, - bool mlp_after_attn, - at::Tensor& q_scale, - at::Tensor& q_scale1, - bool q_int8, - int activation_type, - bool transposed_mode) -{ - auto options = at::TensorOptions() - .dtype(input.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - - int out_size = (q_int8 || transposed_mode) ? weight_out.size(0) : weight_out.size(1); - auto output = at::from_blob( - (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input), - {input.size(0), input.size(1), out_size}, - c10::TensorType::contiguousStridesOf({input.size(0), input.size(1), out_size}), - nullptr, - options, - input.device()); - int bsz = input.size(0) * input.size(1); - - auto act_func_type = static_cast(activation_type); - auto res_add = mlp_unfused_cublas(output, - mlp_after_attn ? input : residual, - residual, - input_bias, - weight_interm, - weight_out, - bias, - gamma, - beta, - epsilon, - preLayerNorm, - mlp_after_attn, - q_scale, - q_scale1, - q_int8, - act_func_type, - transposed_mode); - - return {output, res_add}; +at::Tensor ds_vector_matmul( + at::Tensor& input, + at::Tensor& weight, + bool async_op, + at::Tensor& q_scale, + bool q_int8, + bool transposed_mode) { + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + int out_size = (q_int8 || transposed_mode) ? weight.size(0) : weight.size(1); + int bsz = input.size(0) * input.size(1); + + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + auto output = at::from_blob( + workspace, + {input.size(0), input.size(1), out_size}, + c10::TensorType::contiguousStridesOf( + {input.size(0), input.size(1), out_size}), + nullptr, + options, + input.device()); + if (q_int8) { + quantized_gemm( + output.data_ptr(), + (T*)input.data_ptr(), + weight, + q_scale, + q_scale.size(0), + bsz, + input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + *(InferenceContext::Instance().GetCublasHandle()) = + *(InferenceContext::Instance().GetCurrentStream(async_op)); + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans), + oneapi::mkl::transpose::nontrans, + weight.size(transposed_mode ? 0 : 1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input.data_ptr(), + (T*)output.data_ptr(), + 99); + } + return output; } template -std::vector ds_rms_mlp_gemm(at::Tensor& input, - at::Tensor& residual, - at::Tensor& weight_interm, - at::Tensor& weight_out, - at::Tensor& gamma, - const float epsilon, - at::Tensor& q_scale, - at::Tensor& q_scale1, - bool q_int8, - int activation_type, - bool transposed_mode) -{ - const int bsz = input.size(0) * input.size(1); - const size_t input_neurons = input.size(2); - const int mlp_1_out_neurons = transposed_mode ? weight_interm.size(0) - : weight_interm.size(1); - const size_t mlp_2_in_neurons = transposed_mode ? 
weight_out.size(1) : weight_out.size(0); - - auto options = at::TensorOptions() - .dtype(input.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - - T* output_ptr = (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input); - T* inp_norm_ptr = output_ptr + torch::numel(input); - T* intermediate_ptr = inp_norm_ptr + torch::numel(input); - - auto output = at::from_blob(output_ptr, - input.sizes(), - c10::TensorType::contiguousStridesOf(input.sizes()), - nullptr, - options, - input.device()); - auto inp_norm = at::from_blob(inp_norm_ptr, - input.sizes(), - c10::TensorType::contiguousStridesOf(input.sizes()), - nullptr, - options, - input.device()); - auto intermediate_gemm = at::from_blob( - intermediate_ptr, - {input.size(0), input.size(1), mlp_1_out_neurons}, - c10::TensorType::contiguousStridesOf({input.size(0), input.size(1), mlp_1_out_neurons}), - nullptr, - options, - input.device()); - - auto act_func_type = static_cast(activation_type); - - // RMS Norm, we'll update the residual in-place - launch_rms_norm((T*)inp_norm.data_ptr(), - (T*)residual.data_ptr(), - (const T*)input.data_ptr(), - (const T*)residual.data_ptr(), - (const T*)gamma.data_ptr(), - epsilon, - bsz, - input_neurons, - InferenceContext::Instance().GetCurrentStream()); - - if (q_int8) { - quantized_gemm(intermediate_ptr, - (T*)inp_norm.data_ptr(), - weight_interm, - q_scale, - q_scale.size(0), - bsz, - input_neurons); - } else { - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - *(InferenceContext::Instance().GetCublasHandle()) = - *(InferenceContext::Instance().GetCurrentStream()); - cublas_gemm_ex( - InferenceContext::Instance().GetCublasHandle(), - (transposed_mode ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans), - oneapi::mkl::transpose::nontrans, - mlp_1_out_neurons, - bsz, - input_neurons, - &alpha, - &gemm_beta, - (T*)weight_interm.data_ptr(), - (T*)inp_norm.data_ptr(), - intermediate_ptr, -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif - } - - if (act_func_type == ActivationFuncType::GELU) { - launch_bias_gelu(intermediate_ptr, - (T*)nullptr, - mlp_1_out_neurons, - bsz, - InferenceContext::Instance().GetCurrentStream()); - } else if (act_func_type == ActivationFuncType::ReLU) { - launch_bias_relu(intermediate_ptr, - (T*)nullptr, - mlp_1_out_neurons, - bsz, - InferenceContext::Instance().GetCurrentStream()); - } else if (act_func_type == ActivationFuncType::GATED_GELU) { - launch_gated_activation(intermediate_ptr, - (const T*)intermediate_ptr, - (const T*)nullptr, - bsz, - mlp_1_out_neurons, - mlp_1_out_neurons, - true, - InferenceContext::Instance().GetCurrentStream()); - } else if (act_func_type == ActivationFuncType::GATED_SILU) { - launch_gated_activation(intermediate_ptr, - (const T*)intermediate_ptr, - (const T*)nullptr, - bsz, - mlp_1_out_neurons, - mlp_1_out_neurons, - false, - InferenceContext::Instance().GetCurrentStream()); - } +at::Tensor ds_vector_matmul_int8( + at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + int groups, + int merge_count) { + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + + auto output = at::empty( + {input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + + quantized_gemm(output, input_cont, weight, q_scale, groups, merge_count); + return output; +} - if (q_int8) { - quantized_gemm(output.data_ptr(), - 
intermediate_ptr, - weight_out, - q_scale1, - q_scale1.size(0), - bsz, - input.size(2)); - } else { - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - *(InferenceContext::Instance().GetCublasHandle()) = - *(InferenceContext::Instance().GetCurrentStream()); - cublas_gemm_ex( - InferenceContext::Instance().GetCublasHandle(), - (transposed_mode ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans), - oneapi::mkl::transpose::nontrans, - input_neurons, - bsz, - mlp_2_in_neurons, - &alpha, - &gemm_beta, - (T*)weight_out.data_ptr(), - intermediate_ptr, - (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard, -#else - 99, -#endif - mlp_1_out_neurons); - } +template +at::Tensor mlp_unfused_mkl( + at::Tensor& output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& weight1, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm, + bool mlp_after_attn, + at::Tensor& q_scale, + at::Tensor& q_scale1, + bool q_int8, + ActivationFuncType act_func_type, + bool transposed_mode) { + int bsz = input.size(0) * input.size(1); + T* inp_norm = (T*)InferenceContext::Instance().GetWorkSpace() + + torch::numel(input) + torch::numel(output); + T* intermediate = inp_norm + torch::numel(input); + + if (mlp_after_attn) { + launch_fused_residual_ln( + (T*)inp_norm, + (const T*)input.data_ptr(), + (const T*)residual.data_ptr(), + (const T*)input_bias.data_ptr(), + (const T*)gamma.data_ptr(), + (const T*)beta.data_ptr(), + epsilon, + bsz, + input.size(2), + InferenceContext::Instance().GetCurrentStream()); + } else { + ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon); + } + if (q_int8) { + quantized_gemm( + intermediate, + inp_norm, + weight, + q_scale, + q_scale.size(0), + bsz, + input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + *(InferenceContext::Instance().GetCublasHandle()) = + *(InferenceContext::Instance().GetCurrentStream()); + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans), + oneapi::mkl::transpose::nontrans, + weight.size(transposed_mode ? 0 : 1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + inp_norm, + intermediate, + 99); + } + if (act_func_type == ActivationFuncType::GELU) { + launch_bias_gelu( + intermediate, + (T*)bias.data_ptr(), + (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), + bsz, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::ReLU) { + launch_bias_relu( + intermediate, + (T*)bias.data_ptr(), + (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), + bsz, + InferenceContext::Instance().GetCurrentStream()); + } + + if (q_int8) { + quantized_gemm( + output.data_ptr(), + intermediate, + weight1, + q_scale1, + q_scale1.size(0), + bsz, + input.size(2)); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + *(InferenceContext::Instance().GetCublasHandle()) = + *(InferenceContext::Instance().GetCurrentStream()); + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans), + oneapi::mkl::transpose::nontrans, + weight1.size(transposed_mode ? 0 : 1), + bsz, + weight1.size(transposed_mode ? 
1 : 0), + &alpha, + &gemm_beta, + (T*)weight1.data_ptr(), + intermediate, + (T*)output.data_ptr(), + 99); + } + + return at::from_blob( + inp_norm, + input.sizes(), + c10::TensorType::contiguousStridesOf(input.sizes()), + nullptr, + input.options(), + input.options().device()); +} - return {output, residual}; +template +std::vector ds_mlp_gemm( + at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight_interm, + at::Tensor& weight_out, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm, + bool mlp_after_attn, + at::Tensor& q_scale, + at::Tensor& q_scale1, + bool q_int8, + int activation_type, + bool transposed_mode) { + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + + int out_size = + (q_int8 || transposed_mode) ? weight_out.size(0) : weight_out.size(1); + auto output = at::from_blob( + (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input), + {input.size(0), input.size(1), out_size}, + c10::TensorType::contiguousStridesOf( + {input.size(0), input.size(1), out_size}), + nullptr, + options, + input.device()); + int bsz = input.size(0) * input.size(1); + + auto act_func_type = static_cast(activation_type); + auto res_add = mlp_unfused_mkl( + output, + mlp_after_attn ? input : residual, + residual, + input_bias, + weight_interm, + weight_out, + bias, + gamma, + beta, + epsilon, + preLayerNorm, + mlp_after_attn, + q_scale, + q_scale1, + q_int8, + act_func_type, + transposed_mode); + + return {output, res_add}; } template -at::Tensor fused_gemm_gelu(at::Tensor& input, - at::Tensor& weight, - at::Tensor& weight_scale, - at::Tensor& bias, - at::Tensor& weight_out, - at::Tensor& weight_out_scale, - bool q_int8, - bool transposed_mode) -{ - auto options = at::TensorOptions() - .dtype(input.options().dtype()) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - - int intm_dim = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1); - - // auto output = at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() + - // torch::numel(input), - // {input.size(0), input.size(1), out_size}, - // options); - // T* intermediate = (T*)input.data_ptr() + torch::numel(input); - auto intermediate = at::empty({input.size(0), input.size(1), intm_dim}, options); - - int bsz = input.size(0) * input.size(1); +std::vector ds_rms_mlp_gemm( + at::Tensor& input, + at::Tensor& residual, + at::Tensor& weight_interm, + at::Tensor& weight_out, + at::Tensor& gamma, + const float epsilon, + at::Tensor& q_scale, + at::Tensor& q_scale1, + bool q_int8, + int activation_type, + bool transposed_mode) { + const int bsz = input.size(0) * input.size(1); + const size_t input_neurons = input.size(2); + const int mlp_1_out_neurons = + transposed_mode ? weight_interm.size(0) : weight_interm.size(1); + const size_t mlp_2_in_neurons = + transposed_mode ? 
weight_out.size(1) : weight_out.size(0); + + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + + T* output_ptr = + (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input); + T* inp_norm_ptr = output_ptr + torch::numel(input); + T* intermediate_ptr = inp_norm_ptr + torch::numel(input); + + auto output = at::from_blob( + output_ptr, + input.sizes(), + c10::TensorType::contiguousStridesOf(input.sizes()), + nullptr, + options, + input.device()); + auto inp_norm = at::from_blob( + inp_norm_ptr, + input.sizes(), + c10::TensorType::contiguousStridesOf(input.sizes()), + nullptr, + options, + input.device()); + auto intermediate_gemm = at::from_blob( + intermediate_ptr, + {input.size(0), input.size(1), mlp_1_out_neurons}, + c10::TensorType::contiguousStridesOf( + {input.size(0), input.size(1), mlp_1_out_neurons}), + nullptr, + options, + input.device()); + + auto act_func_type = static_cast(activation_type); + + // RMS Norm, we'll update the residual in-place + launch_rms_norm( + (T*)inp_norm.data_ptr(), + (T*)residual.data_ptr(), + (const T*)input.data_ptr(), + (const T*)residual.data_ptr(), + (const T*)gamma.data_ptr(), + epsilon, + bsz, + input_neurons, + InferenceContext::Instance().GetCurrentStream()); + + if (q_int8) { + quantized_gemm( + intermediate_ptr, + (T*)inp_norm.data_ptr(), + weight_interm, + q_scale, + q_scale.size(0), + bsz, + input_neurons); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + *(InferenceContext::Instance().GetCublasHandle()) = + *(InferenceContext::Instance().GetCurrentStream()); + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans), + oneapi::mkl::transpose::nontrans, + mlp_1_out_neurons, + bsz, + input_neurons, + &alpha, + &gemm_beta, + (T*)weight_interm.data_ptr(), + (T*)inp_norm.data_ptr(), + intermediate_ptr, + 99); + } + if (act_func_type == ActivationFuncType::GELU) { + launch_bias_gelu( + intermediate_ptr, + (T*)nullptr, + mlp_1_out_neurons, + bsz, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::ReLU) { + launch_bias_relu( + intermediate_ptr, + (T*)nullptr, + mlp_1_out_neurons, + bsz, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::GATED_GELU) { + launch_gated_activation( + intermediate_ptr, + (const T*)intermediate_ptr, + (const T*)nullptr, + bsz, + mlp_1_out_neurons, + mlp_1_out_neurons, + true, + InferenceContext::Instance().GetCurrentStream()); + } else if (act_func_type == ActivationFuncType::GATED_SILU) { + launch_gated_activation( + intermediate_ptr, + (const T*)intermediate_ptr, + (const T*)nullptr, + bsz, + mlp_1_out_neurons, + mlp_1_out_neurons, + false, + InferenceContext::Instance().GetCurrentStream()); + } + + if (q_int8) { + quantized_gemm( + output.data_ptr(), + intermediate_ptr, + weight_out, + q_scale1, + q_scale1.size(0), + bsz, + input.size(2)); + } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; - if (q_int8) { - quantized_gemm(intermediate.data_ptr(), - (T*)input.data_ptr(), - weight, - weight_scale, - weight_scale.size(0), - bsz, - input.size(2)); - } else { - *(InferenceContext::Instance().GetCublasHandle()) = - *(InferenceContext::Instance().GetCurrentStream()); - cublas_gemm_ex( - InferenceContext::Instance().GetCublasHandle(), - (transposed_mode ? 
oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans), - oneapi::mkl::transpose::nontrans, - intm_dim, - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)input.data_ptr(), - (T*)intermediate.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif - } - launch_bias_gelu((T*)intermediate.data_ptr(), - (T*)bias.data_ptr(), - intm_dim, - bsz, - InferenceContext::Instance().GetCurrentStream()); - - int out_size = (transposed_mode || q_int8) ? weight_out.size(0) : weight_out.size(1); - auto output = at::empty({input.size(0), input.size(1), out_size}, options); - if (q_int8) { - quantized_gemm(output.data_ptr(), - (T*)intermediate.data_ptr(), - weight_out, - weight_out_scale, - weight_out_scale.size(0), - bsz, - input.size(2)); - } else { - cublas_gemm_ex( - InferenceContext::Instance().GetCublasHandle(), - (transposed_mode ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans), - oneapi::mkl::transpose::nontrans, - out_size, - bsz, - intm_dim, - &alpha, - &gemm_beta, - (T*)weight_out.data_ptr(), - (T*)intermediate.data_ptr(), - (T*)output.data_ptr(), -#ifdef __HIP_PLATFORM_AMD__ - rocblas_gemm_algo_standard); -#else - 99); -#endif - } - // cudaEventRecord(InferenceContext::Instance().GetCompEvent(2), - // InferenceContext::Instance().GetCurrentStream(true)); - return output; + *(InferenceContext::Instance().GetCublasHandle()) = + *(InferenceContext::Instance().GetCurrentStream()); + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans), + oneapi::mkl::transpose::nontrans, + input_neurons, + bsz, + mlp_2_in_neurons, + &alpha, + &gemm_beta, + (T*)weight_out.data_ptr(), + intermediate_ptr, + (T*)output.data_ptr(), + 99, + mlp_1_out_neurons); + } + + return {output, residual}; } template -at::Tensor& residual_add_bias(at::Tensor& hidden_state, - at::Tensor& residual, - const at::Tensor& attention_output, - const at::Tensor& attention_bias, - const at::Tensor& final_bias, - const int mp_size, - const bool mlp_after_attn, - const bool add_bias, - const bool preln) -{ - int bsz = residual.size(0) * residual.size(1); - int hidden_size = residual.size(2); - if (mlp_after_attn) - launch_bias_residual(static_cast(residual.data_ptr()), - static_cast(hidden_state.data_ptr()), - static_cast(attention_output.data_ptr()), - static_cast(final_bias.data_ptr()), - static_cast(attention_bias.data_ptr()), - bsz, - hidden_size, - mp_size, - preln, - InferenceContext::Instance().GetCurrentStream()); - else - launch_gptj_residual_add( - static_cast(residual.data_ptr()), - static_cast(hidden_state.data_ptr()), - static_cast(attention_output.data_ptr()), - static_cast(final_bias.data_ptr()), - static_cast((add_bias ? attention_bias.data_ptr() : nullptr)), - hidden_size, - bsz, - mp_size, - InferenceContext::Instance().GetCurrentStream()); - return residual; +at::Tensor fused_gemm_gelu( + at::Tensor& input, + at::Tensor& weight, + at::Tensor& weight_scale, + at::Tensor& bias, + at::Tensor& weight_out, + at::Tensor& weight_out_scale, + bool q_int8, + bool transposed_mode) { + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + + int intm_dim = (transposed_mode || q_int8) ? 
weight.size(0) : weight.size(1); + + // auto output = at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() + // + torch::numel(input), + // {input.size(0), input.size(1), out_size}, + // options); + // T* intermediate = (T*)input.data_ptr() + torch::numel(input); + auto intermediate = + at::empty({input.size(0), input.size(1), intm_dim}, options); + + int bsz = input.size(0) * input.size(1); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + if (q_int8) { + quantized_gemm( + intermediate.data_ptr(), + (T*)input.data_ptr(), + weight, + weight_scale, + weight_scale.size(0), + bsz, + input.size(2)); + } else { + *(InferenceContext::Instance().GetCublasHandle()) = + *(InferenceContext::Instance().GetCurrentStream()); + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans), + oneapi::mkl::transpose::nontrans, + intm_dim, + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input.data_ptr(), + (T*)intermediate.data_ptr(), + 99); + } + launch_bias_gelu( + (T*)intermediate.data_ptr(), + (T*)bias.data_ptr(), + intm_dim, + bsz, + InferenceContext::Instance().GetCurrentStream()); + + int out_size = + (transposed_mode || q_int8) ? weight_out.size(0) : weight_out.size(1); + auto output = at::empty({input.size(0), input.size(1), out_size}, options); + if (q_int8) { + quantized_gemm( + output.data_ptr(), + (T*)intermediate.data_ptr(), + weight_out, + weight_out_scale, + weight_out_scale.size(0), + bsz, + input.size(2)); + } else { + mkl_gemm_ex( + InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans), + oneapi::mkl::transpose::nontrans, + out_size, + bsz, + intm_dim, + &alpha, + &gemm_beta, + (T*)weight_out.data_ptr(), + (T*)intermediate.data_ptr(), + (T*)output.data_ptr(), + 99); + } + return output; } -#define DISPATCH_VECTOR_ADD(T_TYPE, C_TYPE) \ - if (a.scalar_type() == at::k##T_TYPE) { \ - launch_vector_add((C_TYPE*)(a.data_ptr()), \ - (const C_TYPE*)(a.data_ptr()), \ - (const C_TYPE*)(b.data_ptr()), \ - gamma, \ - total_elems, \ - InferenceContext::Instance().GetCurrentStream()); \ - } - -at::Tensor& _vector_add(at::Tensor& a, at::Tensor& b, float gamma) -{ - const int total_elems = a.numel(); +template +at::Tensor& residual_add_bias( + at::Tensor& hidden_state, + at::Tensor& residual, + const at::Tensor& attention_output, + const at::Tensor& attention_bias, + const at::Tensor& final_bias, + const int mp_size, + const bool mlp_after_attn, + const bool add_bias, + const bool preln) { + int bsz = residual.size(0) * residual.size(1); + int hidden_size = residual.size(2); + if (mlp_after_attn) + launch_bias_residual( + static_cast(residual.data_ptr()), + static_cast(hidden_state.data_ptr()), + static_cast(attention_output.data_ptr()), + static_cast(final_bias.data_ptr()), + static_cast(attention_bias.data_ptr()), + bsz, + hidden_size, + mp_size, + preln, + InferenceContext::Instance().GetCurrentStream()); + else + launch_gptj_residual_add( + static_cast(residual.data_ptr()), + static_cast(hidden_state.data_ptr()), + static_cast(attention_output.data_ptr()), + static_cast(final_bias.data_ptr()), + static_cast((add_bias ? 
attention_bias.data_ptr() : nullptr)), + hidden_size, + bsz, + mp_size, + InferenceContext::Instance().GetCurrentStream()); + return residual; +} - DISPATCH_VECTOR_ADD(Float, float) - DISPATCH_VECTOR_ADD(Half, sycl::half) +#define DISPATCH_VECTOR_ADD(T_TYPE, C_TYPE) \ + if (a.scalar_type() == at::k##T_TYPE) { \ + launch_vector_add( \ + (C_TYPE*)(a.data_ptr()), \ + (const C_TYPE*)(a.data_ptr()), \ + (const C_TYPE*)(b.data_ptr()), \ + gamma, \ + total_elems, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor& _vector_add(at::Tensor& a, at::Tensor& b, float gamma) { + const int total_elems = a.numel(); + + DISPATCH_VECTOR_ADD(Float, float) + DISPATCH_VECTOR_ADD(Half, sycl::half) #ifdef BF16_AVAILABLE - DISPATCH_VECTOR_ADD(BFloat16, sycl::ext::oneapi::bfloat16) + DISPATCH_VECTOR_ADD(BFloat16, sycl::ext::oneapi::bfloat16) #endif - return a; + return a; } -std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, - at::Tensor& key_layer, - unsigned rotary_dim, - unsigned offset, - unsigned num_heads, - bool rotate_half, - float rope_theta) -{ - auto query_cont = mixed_query.contiguous(); - auto key_cont = key_layer.contiguous(); - - unsigned bsz = mixed_query.size(0); - unsigned head_size = mixed_query.size(2) / num_heads; - unsigned seq_len = mixed_query.size(1); - - if (mixed_query.scalar_type() == at::kFloat) - launch_apply_rotary_pos_emb((float*)query_cont.data_ptr(), - (float*)key_cont.data_ptr(), - head_size, - seq_len, - rotary_dim, - offset, - num_heads, - bsz, - rope_theta, - InferenceContext::Instance().GetCurrentStream(), - InferenceContext::Instance().GetMaxTokenLength()); - else - launch_apply_rotary_pos_emb((sycl::half*)query_cont.data_ptr(), - (sycl::half*)key_cont.data_ptr(), - head_size, - seq_len, - rotary_dim, - offset, - num_heads, - bsz, - rope_theta, - InferenceContext::Instance().GetCurrentStream(), - InferenceContext::Instance().GetMaxTokenLength()); - return {query_cont, key_cont}; +std::vector apply_rotary_pos_emb( + at::Tensor& mixed_query, + at::Tensor& key_layer, + unsigned rotary_dim, + unsigned offset, + unsigned num_heads, + bool rotate_half, + float rope_theta) { + auto query_cont = mixed_query.contiguous(); + auto key_cont = key_layer.contiguous(); + + unsigned bsz = mixed_query.size(0); + unsigned head_size = mixed_query.size(2) / num_heads; + unsigned seq_len = mixed_query.size(1); + + if (mixed_query.scalar_type() == at::kFloat) + launch_apply_rotary_pos_emb( + (float*)query_cont.data_ptr(), + (float*)key_cont.data_ptr(), + head_size, + seq_len, + rotary_dim, + offset, + num_heads, + bsz, + rope_theta, + InferenceContext::Instance().GetCurrentStream(), + InferenceContext::Instance().GetMaxTokenLength()); + else + launch_apply_rotary_pos_emb( + (sycl::half*)query_cont.data_ptr(), + (sycl::half*)key_cont.data_ptr(), + head_size, + seq_len, + rotary_dim, + offset, + num_heads, + bsz, + rope_theta, + InferenceContext::Instance().GetCurrentStream(), + InferenceContext::Instance().GetMaxTokenLength()); + return {query_cont, key_cont}; } -#define DISPATCH_MOE_RESIDUAL(T_TYPE, C_TYPE) \ - if (moe_res.scalar_type() == torch::T_TYPE) { \ - launch_moe_res_matmul((C_TYPE*)moe_res.data_ptr(), \ - (C_TYPE*)coef.data_ptr(), \ - (C_TYPE*)output.data_ptr(), \ - M, \ - N, \ - InferenceContext::Instance().GetCurrentStream()); \ - } - -at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output) -{ - int M = moe_res.size(0) * moe_res.size(1); - int N = moe_res.size(2); - InferenceContext::Instance().SynchComm(); - - 
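// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the surrounding patch): the
// apply_rotary_pos_emb wrapper above only dispatches on scalar_type and defers
// to launch_apply_rotary_pos_emb. As a reference for the per-pair math that
// kernel family implements, a plain host-side rotary embedding looks roughly
// like the function below; rope_theta and rotary_dim correspond to the
// parameters above, while the names and the adjacent-pair convention are
// simplifying assumptions (the rotate_half path pairs dimension i with
// i + rotary_dim / 2 instead).
#include <cmath>

inline void rope_reference(float* head_vec,     // one head of size head_size
                           unsigned rotary_dim, // only the first rotary_dim dims rotate
                           unsigned position,   // absolute token position (offset + index)
                           float rope_theta)
{
    for (unsigned i = 0; i + 1 < rotary_dim; i += 2) {
        // Inverse frequency for this dimension pair: theta^(-i / rotary_dim).
        const float inv_freq = std::pow(rope_theta, -static_cast<float>(i) / rotary_dim);
        const float angle = position * inv_freq;
        const float c = std::cos(angle), s = std::sin(angle);
        const float x0 = head_vec[i], x1 = head_vec[i + 1];
        head_vec[i] = x0 * c - x1 * s;      // rotate the (x0, x1) pair by `angle`
        head_vec[i + 1] = x0 * s + x1 * c;
    }
    // Dimensions past rotary_dim, if any, are left unchanged (partial rotary).
}
// ---------------------------------------------------------------------------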
DISPATCH_MOE_RESIDUAL(kFloat, float) - DISPATCH_MOE_RESIDUAL(kHalf, sycl::half) +#define DISPATCH_MOE_RESIDUAL(T_TYPE, C_TYPE) \ + if (moe_res.scalar_type() == torch::T_TYPE) { \ + launch_moe_res_matmul( \ + (C_TYPE*)moe_res.data_ptr(), \ + (C_TYPE*)coef.data_ptr(), \ + (C_TYPE*)output.data_ptr(), \ + M, \ + N, \ + InferenceContext::Instance().GetCurrentStream()); \ + } + +at::Tensor moe_res_matmul( + at::Tensor& moe_res, + at::Tensor& coef, + at::Tensor& output) { + int M = moe_res.size(0) * moe_res.size(1); + int N = moe_res.size(2); + + DISPATCH_MOE_RESIDUAL(kFloat, float) + DISPATCH_MOE_RESIDUAL(kHalf, sycl::half) #ifdef BF16_AVAILABLE - DISPATCH_MOE_RESIDUAL(kBFloat16, sycl::ext::oneapi::bfloat16) + DISPATCH_MOE_RESIDUAL(kBFloat16, sycl::ext::oneapi::bfloat16) #endif - return output; + return output; } -void ds_release_workspace() { InferenceContext::Instance().release_workspace(); } +void ds_release_workspace() { + InferenceContext::Instance().release_workspace(); +} -bool ds_retake_workspace() { return InferenceContext::Instance().retake_workspace(); } +bool ds_retake_workspace() { + return InferenceContext::Instance().retake_workspace(); +} template -at::Tensor ds_dequantize(at::Tensor& weight, at::Tensor& qscale, int groups) -{ - auto options = at::TensorOptions() - .dtype(torch::kFloat16) - .layout(at::kStrided) - .device(at::kXPU) - .requires_grad(false); - auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); - - launch_dequantize((T*)weight16.data_ptr(), - (int8_t*)weight.data_ptr(), - (float*)qscale.data_ptr(), - weight.size(0), - weight.size(1), - groups, - InferenceContext::Instance().GetCurrentStream()); - - return weight16; +at::Tensor ds_dequantize(at::Tensor& weight, at::Tensor& qscale, int groups) { + auto options = at::TensorOptions() + .dtype(torch::kFloat16) + .layout(at::kStrided) + .device(at::kXPU) + .requires_grad(false); + auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); + + launch_dequantize( + (T*)weight16.data_ptr(), + (int8_t*)weight.data_ptr(), + (float*)qscale.data_ptr(), + weight.size(0), + weight.size(1), + groups, + InferenceContext::Instance().GetCurrentStream()); + + return weight16; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("softmax_context_int8", - &ds_softmax_context1, - "DeepSpeed attention with int8 (CUDA)"); - - // The following functions handle type dispatching internally - m.def("gated_activation", &ds_gated_activation, "DeepSpeed Bias GEGLU (CUDA)"); - m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (CUDA)"); - m.def( - "_layer_norm_residual", &ds_layer_norm_residual, "DeepSpeed layer norm + residual (CUDA)"); - m.def("layer_norm_residual_store_pre_ln_res", - &ds_layer_norm_residual_store_pre_ln_res, - "DeepSpeed layer norm + store pre Layernorm residual (CUDA)"); - m.def("rms_norm", &ds_rms_norm, "DeepSpeed rms norm (CUDA)"); - m.def("pre_rms_norm", &ds_pre_rms_norm, "DeepSpeed pre rms norm (CUDA)"); - m.def("_vector_add", &_vector_add, "DeepSpeed vector add (CUDA)"); - m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)"); - m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)"); - m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks"); - m.def("release_workspace", &ds_release_workspace, "DeepSpeed Release Workspace"); - m.def("retake_workspace", &ds_retake_workspace, "DeepSpeed Retake Workspace"); - - // The following functions are templated and need to be explicitly instantiated and bound - // to 
different python methods -#define DEF_OPS(_name, _dtype) \ - m.def("softmax_" #_name, &ds_softmax<_dtype>, "DeepSpeed SoftMax with " #_name " (CUDA)"); \ - m.def("softmax_context_" #_name, \ - &ds_softmax_context<_dtype>, \ - "DeepSpeed attention with " #_name " (CUDA)"); \ - m.def("bias_gelu_" #_name, &ds_bias_gelu<_dtype>, "DeepSpeed Gelu with " #_name " (CUDA)"); \ - m.def("bias_add_" #_name, &ds_bias_add<_dtype>, "DeepSpeed Bias Add with " #_name " (CUDA)"); \ - m.def("bias_relu_" #_name, &ds_bias_relu<_dtype>, "DeepSpeed ReLU with " #_name " (CUDA)"); \ - m.def("bias_residual_" #_name, \ - &ds_bias_residual<_dtype>, \ - "DeepSpeed residual-bias add with " #_name " (CUDA)"); \ - m.def("qkv_gemm_" #_name, &ds_qkv_gemm<_dtype>, "DeepSpeed qkv gemm with " #_name " (CUDA)"); \ - m.def("rms_qkv_gemm_" #_name, \ - &ds_rms_qkv<_dtype>, \ - "DeepSpeed rms qkv gemm with " #_name " (CUDA)"); \ - m.def("mlp_gemm_" #_name, &ds_mlp_gemm<_dtype>, "DeepSpeed mlp with " #_name " (CUDA)"); \ - m.def("rms_mlp_gemm_" #_name, \ - &ds_rms_mlp_gemm<_dtype>, \ - "DeepSpeed rms mlp gemm with " #_name " (CUDA)"); \ - m.def("vector_matmul_" #_name, \ - &ds_vector_matmul<_dtype>, \ - "DeepSpeed vector-MM with " #_name " (CUDA)"); \ - m.def("linear_layer_" #_name, \ - &ds_linear_layer<_dtype>, \ - "DeepSpeed linear_layer with " #_name " (CUDA)"); \ - m.def("fused_gemm_gelu_" #_name, \ - &fused_gemm_gelu<_dtype>, \ - "DeepSpeed mlp with " #_name " (CUDA)"); \ - m.def("residual_add_bias_" #_name, \ - &residual_add_bias<_dtype>, \ - "DeepSpeed residual add with " #_name " (CUDA)"); \ - m.def("einsum_sec_sm_ecm_" #_name, \ - &einsum_sec_sm_ecm<_dtype>, \ - "DeepSpeed vector-MM with " #_name " (CUDA)"); \ - m.def("add_padding_" #_name, \ - &add_padding<_dtype>, \ - "DeepSpeed residual add with " #_name " (CUDA)"); \ - m.def("pad_transform_" #_name, \ - &padd_add_transform<_dtype>, \ - "DeepSpeed residual add with " #_name " (CUDA)"); \ - m.def("allocate_workspace_" #_name, \ - &allocate_workspace<_dtype>, \ - "DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)"); \ - m.def("dequantize_" #_name, \ - &ds_dequantize<_dtype>, \ - "DeepSpeed dequantize with " #_name " (CUDA)") - - DEF_OPS(fp32, float); - DEF_OPS(fp16, sycl::half); +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "softmax_context_int8", + &ds_softmax_context1, + "DeepSpeed attention with int8 (SYCL)"); + + // The following functions handle type dispatching internally + m.def( + "gated_activation", &ds_gated_activation, "DeepSpeed Bias GEGLU (SYCL)"); + m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (SYCL)"); + m.def( + "_layer_norm_residual", + &ds_layer_norm_residual, + "DeepSpeed layer norm + residual (SYCL)"); + m.def( + "layer_norm_residual_store_pre_ln_res", + &ds_layer_norm_residual_store_pre_ln_res, + "DeepSpeed layer norm + store pre Layernorm residual (SYCL)"); + m.def("rms_norm", &ds_rms_norm, "DeepSpeed rms norm (SYCL)"); + m.def("pre_rms_norm", &ds_pre_rms_norm, "DeepSpeed pre rms norm (SYCL)"); + m.def("_vector_add", &_vector_add, "DeepSpeed vector add (SYCL)"); + m.def( + "apply_rotary_pos_emb", + &apply_rotary_pos_emb, + "DeepSpeed mlp with fp16 (SYCL)"); + m.def( + "moe_res_matmul", + &moe_res_matmul, + "DeepSpeed moe residual matmul (SYCL)"); + m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks"); + m.def( + "release_workspace", + &ds_release_workspace, + "DeepSpeed Release Workspace"); + m.def("retake_workspace", &ds_retake_workspace, "DeepSpeed Retake Workspace"); + + // The following 
functions are templated and need to be explicitly + // instantiated and bound to different python methods +#define DEF_OPS(_name, _dtype) \ + m.def( \ + "softmax_" #_name, \ + &ds_softmax<_dtype>, \ + "DeepSpeed SoftMax with " #_name " (SYCL)"); \ + m.def( \ + "softmax_context_" #_name, \ + &ds_softmax_context<_dtype>, \ + "DeepSpeed attention with " #_name " (SYCL)"); \ + m.def( \ + "bias_gelu_" #_name, \ + &ds_bias_gelu<_dtype>, \ + "DeepSpeed Gelu with " #_name " (SYCL)"); \ + m.def( \ + "bias_add_" #_name, \ + &ds_bias_add<_dtype>, \ + "DeepSpeed Bias Add with " #_name " (SYCL)"); \ + m.def( \ + "bias_relu_" #_name, \ + &ds_bias_relu<_dtype>, \ + "DeepSpeed ReLU with " #_name " (SYCL)"); \ + m.def( \ + "bias_residual_" #_name, \ + &ds_bias_residual<_dtype>, \ + "DeepSpeed residual-bias add with " #_name " (SYCL)"); \ + m.def( \ + "qkv_gemm_" #_name, \ + &ds_qkv_gemm<_dtype>, \ + "DeepSpeed qkv gemm with " #_name " (SYCL)"); \ + m.def( \ + "rms_qkv_gemm_" #_name, \ + &ds_rms_qkv<_dtype>, \ + "DeepSpeed rms qkv gemm with " #_name " (SYCL)"); \ + m.def( \ + "mlp_gemm_" #_name, \ + &ds_mlp_gemm<_dtype>, \ + "DeepSpeed mlp with " #_name " (SYCL)"); \ + m.def( \ + "rms_mlp_gemm_" #_name, \ + &ds_rms_mlp_gemm<_dtype>, \ + "DeepSpeed rms mlp gemm with " #_name " (SYCL)"); \ + m.def( \ + "vector_matmul_" #_name, \ + &ds_vector_matmul<_dtype>, \ + "DeepSpeed vector-MM with " #_name " (SYCL)"); \ + m.def( \ + "linear_layer_" #_name, \ + &ds_linear_layer<_dtype>, \ + "DeepSpeed linear_layer with " #_name " (SYCL)"); \ + m.def( \ + "fused_gemm_gelu_" #_name, \ + &fused_gemm_gelu<_dtype>, \ + "DeepSpeed mlp with " #_name " (SYCL)"); \ + m.def( \ + "residual_add_bias_" #_name, \ + &residual_add_bias<_dtype>, \ + "DeepSpeed residual add with " #_name " (SYCL)"); \ + m.def( \ + "einsum_sec_sm_ecm_" #_name, \ + &einsum_sec_sm_ecm<_dtype>, \ + "DeepSpeed vector-MM with " #_name " (SYCL)"); \ + m.def( \ + "add_padding_" #_name, \ + &add_padding<_dtype>, \ + "DeepSpeed residual add with " #_name " (SYCL)"); \ + m.def( \ + "pad_transform_" #_name, \ + &padd_add_transform<_dtype>, \ + "DeepSpeed residual add with " #_name " (SYCL)"); \ + m.def( \ + "allocate_workspace_" #_name, \ + &allocate_workspace<_dtype>, \ + "DeepSpeed memory allocation for GPT inference with " #_name " (SYCL)"); \ + m.def( \ + "dequantize_" #_name, \ + &ds_dequantize<_dtype>, \ + "DeepSpeed dequantize with " #_name " (SYCL)") + + DEF_OPS(fp32, float); + DEF_OPS(fp16, sycl::half); #ifdef BF16_AVAILABLE - DEF_OPS(bf16, sycl::ext::oneapi::bfloat16); + DEF_OPS(bf16, sycl::ext::oneapi::bfloat16); #endif } diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/relu.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/relu.dp.cpp index 77d4807..dc178ea 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/relu.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/relu.dp.cpp @@ -1,82 +1,117 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team +#include #include -#include #include "conversion_utils.h" -#include "inference_cuda_layers.h" +#include "inference_sycl_layers.h" #include "memory_access_utils.h" #define MAX_CAP 4 #define MAX_SEQ 2048 -inline float relu(const float x) { return x < 0 ? 0 : x; } +inline float relu(const float x) { + return x < 0 ? 0 : x; +} /* In-place relu(biasAdd(x)) for channels last */ template -void fused_bias_relu(T* input, const T* bias, int total_count, int intermediate_size) -{ +class fused_bias_relu { + private: + T* input; + const T* bias; + int total_count; + int intermediate_size; + + public: + fused_bias_relu( + T* input, + const T* bias, + int total_count, + int intermediate_size) + : input(input), + bias(bias), + total_count(total_count), + intermediate_size(intermediate_size) {} + void operator()(sycl::nd_item<3>) const { // Input restriction: intermediate_size % vals_per_access == 0 auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); constexpr int granularity = 16; constexpr int values_per_access = granularity / sizeof(T); - const int offset = - (item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2)) * + const int offset = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2)) * values_per_access; if (offset < total_count) { - T data[values_per_access]; - T data_bias[values_per_access]; - mem_access::load_global(data, input + offset); - mem_access::load_global( - data_bias, bias + (offset % intermediate_size), bias != nullptr); + T data[values_per_access]; + T data_bias[values_per_access]; + mem_access::load_global(data, input + offset); + mem_access::load_global( + data_bias, bias + (offset % intermediate_size), bias != nullptr); #pragma unroll - for (int i = 0; i < values_per_access; i++) { - float data_f = conversion::to(data[i]); - float bias_f = conversion::to(data_bias[i]); - data[i] = conversion::to(relu(data_f + bias_f)); - } + for (int i = 0; i < values_per_access; i++) { + float data_f = conversion::to(data[i]); + float bias_f = conversion::to(data_bias[i]); + data[i] = conversion::to(relu(data_f + bias_f)); + } - mem_access::store_global(input + offset, data); + mem_access::store_global(input + offset, data); } -} + } +}; template -void launch_bias_relu(T* input, - const T* bias, - int intermediate_size, - int batch_size, - dpct::queue_ptr stream) -{ - constexpr int threads = 1024; - constexpr int granularity = 16; +void launch_bias_relu( + T* input, + const T* bias, + int intermediate_size, + int batch_size, + dpct::queue_ptr stream) { + constexpr int threads = 1024; + constexpr int granularity = 16; - const int total_count = batch_size * intermediate_size; - const int elems_per_block = threads * (granularity / sizeof(T)); - sycl::range<3> block_dims(1, 1, threads); - sycl::range<3> grid_dims(1, 1, (total_count + elems_per_block - 1) / elems_per_block); + const int total_count = batch_size * intermediate_size; + 
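// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the surrounding patch): the fused_bias_relu
// rewrite above follows the usual SYCL pattern of turning a CUDA-style kernel
// function into a named functor whose state is captured as members and whose
// operator()(nd_item) const is submitted via parallel_for. A minimal,
// self-contained example of that pattern (all names hypothetical):
#include <sycl/sycl.hpp>

struct scale_kernel {
    float* data;
    float factor;
    size_t n;
    void operator()(sycl::nd_item<1> it) const {
        const size_t i = it.get_global_id(0);
        if (i < n) data[i] *= factor;   // bounds check, like the total_count guard above
    }
};

inline void launch_scale(float* data, float factor, size_t n, sycl::queue& q)
{
    constexpr size_t wg = 256;                 // work-group size
    const size_t groups = (n + wg - 1) / wg;   // ceil-div, like the grid_dims computation above
    q.parallel_for(
         sycl::nd_range<1>(sycl::range<1>(groups * wg), sycl::range<1>(wg)),
         scale_kernel{data, factor, n})
        .wait();
}
// ---------------------------------------------------------------------------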
const int elems_per_block = threads * (granularity / sizeof(T)); + sycl::range<3> block_dims(1, 1, threads); + sycl::range<3> grid_dims( + 1, 1, (total_count + elems_per_block - 1) / elems_per_block); - /* - DPCT1049:0: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - fused_bias_relu(input, bias, total_count, intermediate_size); - }); - } + /* + DPCT1049:0: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + fused_bias_relu fn(input, bias, total_count, intermediate_size); + stream->parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), fn); + } } #define INSTANTIATE_LAUNCH_BIAS_RELU(T) \ - template void launch_bias_relu(T*, const T*, int, int, dpct::queue_ptr); + template void launch_bias_relu(T*, const T*, int, int, dpct::queue_ptr); INSTANTIATE_LAUNCH_BIAS_RELU(float) #ifdef BF16_AVAILABLE diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/rms_norm.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/rms_norm.dp.cpp index ce3d8cd..6f1a429 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/rms_norm.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/rms_norm.dp.cpp @@ -1,13 +1,28 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team +#include #include -#include #include "conversion_utils.h" #include "ds_kernel_utils.h" -#include "inference_cuda_layers.h" +#include "inference_sycl_layers.h" #include "memory_access_utils.h" #include "reduction_utils.h" @@ -15,29 +30,44 @@ using rop = reduce::ROpType; namespace rms { constexpr int granularity = 16; -} // namespace rms +} // namespace rms template -/* -DPCT1110:3: The total declared local variable size in device function rms_norm exceeds 128 bytes and -may cause high register pressure. Consult with your hardware vendor to find the total register size -available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
-*/ -void rms_norm(T* output, const T* vals, const T* gamma, float epsilon, int elems_per_row) -{ +class rms_norm { + private: + T* output; + const T* vals; + const T* gamma; + float epsilon; + int elems_per_row; + + public: + rms_norm( + T* output, + const T* vals, + const T* gamma, + float epsilon, + int elems_per_row) + : output(output), + vals(vals), + gamma(gamma), + epsilon(epsilon), + elems_per_row(elems_per_row) {} + void operator()(sycl::nd_item<3>) const { constexpr int T_per_load = rms::granularity / sizeof(T); sycl::group<3> tb = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group warp = sycl::ext::oneapi::experimental::this_sub_group(); // X-dimension of the block - const int block_offset = - (tb.get_group_id()[2] * (maxThreads / threadsPerGroup) * elems_per_row) + + const int block_offset = (tb.get_group_id()[2] * + (maxThreads / threadsPerGroup) * elems_per_row) + (tb.get_local_id()[1] * elems_per_row); const int thread_offset = tb.get_local_id()[2] * T_per_load; const int base_offset = block_offset + thread_offset; const int stride = - sycl::ext::oneapi::experimental::this_nd_item<3>().get_local_range(2) * T_per_load; + sycl::ext::oneapi::experimental::this_nd_item<3>().get_local_range(2) * + T_per_load; float var_sum = reduce::init(); @@ -47,26 +77,28 @@ void rms_norm(T* output, const T* vals, const T* gamma, float epsilon, int elems #pragma unroll for (int i = 0; i < UNROLL; i++) { - T* iteration_buffer = local_buffer + (i * T_per_load); + T* iteration_buffer = local_buffer + (i * T_per_load); - mem_access::load_global(iteration_buffer, - input_base + (i * stride), - thread_offset + (i * stride) < elems_per_row); + mem_access::load_global( + iteration_buffer, + input_base + (i * stride), + thread_offset + (i * stride) < elems_per_row); #pragma unroll - for (int j = 0; j < T_per_load; j++) { - float up_cast = conversion::to(iteration_buffer[j]); - float sq_val = up_cast * up_cast; - var_sum = reduce::element(var_sum, sq_val); - } + for (int j = 0; j < T_per_load; j++) { + float up_cast = conversion::to(iteration_buffer[j]); + float sq_val = up_cast * up_cast; + var_sum = reduce::element(var_sum, sq_val); + } } reduce::partitioned_block(tb, warp, var_sum); const float var = var_sum / elems_per_row; /* - DPCT1013:8: The rounding mode could not be specified and the generated code may have different - accuracy than the original code. Verify the correctness. SYCL math built-in function rounding - mode is aligned with OpenCL C 1.2 standard. + DPCT1013:8: The rounding mode could not be specified and the generated code + may have different accuracy than the original code. Verify the correctness. + SYCL math built-in function rounding mode is aligned with OpenCL C 1.2 + standard. 
*/ const T denom = conversion::to(sycl::rsqrt(var + epsilon)); @@ -74,53 +106,71 @@ void rms_norm(T* output, const T* vals, const T* gamma, float epsilon, int elems #pragma unroll for (int i = 0; i < UNROLL; i++) { - T* iteration_buffer = local_buffer + (i * T_per_load); - const int iter_idx = i * stride + thread_offset; - const bool do_loads = (iter_idx < elems_per_row); + T* iteration_buffer = local_buffer + (i * T_per_load); + const int iter_idx = i * stride + thread_offset; + const bool do_loads = (iter_idx < elems_per_row); - T gamma_local[T_per_load]; + T gamma_local[T_per_load]; - mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + mem_access::load_global( + gamma_local, gamma + iter_idx, do_loads); #pragma unroll - for (int j = 0; j < T_per_load; j++) { - iteration_buffer[j] *= denom; - iteration_buffer[j] *= gamma_local[j]; - } - - if (do_loads) { - mem_access::store_global(block_output + iter_idx, iteration_buffer); - } + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] *= denom; + iteration_buffer[j] *= gamma_local[j]; + } + + if (do_loads) { + mem_access::store_global( + block_output + iter_idx, iteration_buffer); + } } -} + } +}; template -/* -DPCT1110:4: The total declared local variable size in device function pre_rms_norm exceeds 128 bytes -and may cause high register pressure. Consult with your hardware vendor to find the total register -size available and adjust the code, or use smaller sub-group size to avoid high register pressure. -*/ -void pre_rms_norm(T* output, - T* res_out, - const T* vals, - const T* residual, - const T* gamma, - float epsilon, - int elems_per_row) -{ +class pre_rms_norm { + private: + T* output; + T* res_out; + const T* vals; + const T* residual; + const T* gamma; + float epsilon; + int elems_per_row; + + public: + pre_rms_norm( + T* output, + T* res_out, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int elems_per_row) + : output(output), + res_out(res_out), + vals(vals), + residual(residual), + gamma(gamma), + epsilon(epsilon), + elems_per_row(elems_per_row) {} + void operator()(sycl::nd_item<3>) const { constexpr int T_per_load = rms::granularity / sizeof(T); sycl::group<3> tb = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group warp = sycl::ext::oneapi::experimental::this_sub_group(); // X-dimension of the block - const int block_offset = - (tb.get_group_id()[2] * (maxThreads / threadsPerGroup) * elems_per_row) + + const int block_offset = (tb.get_group_id()[2] * + (maxThreads / threadsPerGroup) * elems_per_row) + (tb.get_local_id()[1] * elems_per_row); const int thread_offset = tb.get_local_id()[2] * T_per_load; const int base_offset = block_offset + thread_offset; const int stride = - sycl::ext::oneapi::experimental::this_nd_item<3>().get_local_range(2) * T_per_load; + sycl::ext::oneapi::experimental::this_nd_item<3>().get_local_range(2) * + T_per_load; float var_sum = reduce::init(); @@ -132,36 +182,39 @@ void pre_rms_norm(T* output, #pragma unroll for (int i = 0; i < UNROLL; i++) { - T* iteration_buffer = local_buffer + (i * T_per_load); - T residual_buffer[T_per_load]; + T* iteration_buffer = local_buffer + (i * T_per_load); + T residual_buffer[T_per_load]; - const int iter_offset = i * stride + thread_offset; - const bool do_loads = (iter_offset < elems_per_row); + const int iter_offset = i * stride + thread_offset; + const bool do_loads = (iter_offset < elems_per_row); - mem_access::load_global( - iteration_buffer, input_base + (i * stride), do_loads); - 
mem_access::load_global( - residual_buffer, residual_base + (i * stride), do_loads); + mem_access::load_global( + iteration_buffer, input_base + (i * stride), do_loads); + mem_access::load_global( + residual_buffer, residual_base + (i * stride), do_loads); #pragma unroll - for (int j = 0; j < T_per_load; j++) { - iteration_buffer[j] += residual_buffer[j]; - float vals_up_cast = conversion::to(iteration_buffer[j]); - - var_sum = reduce::element(var_sum, vals_up_cast * vals_up_cast); - } - - if (do_loads) { - mem_access::store_global(res_output + i * stride, iteration_buffer); - } + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] += residual_buffer[j]; + float vals_up_cast = conversion::to(iteration_buffer[j]); + + var_sum = reduce::element( + var_sum, vals_up_cast * vals_up_cast); + } + + if (do_loads) { + mem_access::store_global( + res_output + i * stride, iteration_buffer); + } } reduce::partitioned_block(tb, warp, var_sum); const float var = var_sum / elems_per_row; /* - DPCT1013:9: The rounding mode could not be specified and the generated code may have different - accuracy than the original code. Verify the correctness. SYCL math built-in function rounding - mode is aligned with OpenCL C 1.2 standard. + DPCT1013:9: The rounding mode could not be specified and the generated code + may have different accuracy than the original code. Verify the correctness. + SYCL math built-in function rounding mode is aligned with OpenCL C 1.2 + standard. */ const T denom = conversion::to(sycl::rsqrt(var + epsilon)); @@ -169,156 +222,151 @@ void pre_rms_norm(T* output, #pragma unroll for (int i = 0; i < UNROLL; i++) { - T* iteration_buffer = local_buffer + (i * T_per_load); - const int iter_idx = i * stride + thread_offset; - const bool do_loads = (iter_idx < elems_per_row); + T* iteration_buffer = local_buffer + (i * T_per_load); + const int iter_idx = i * stride + thread_offset; + const bool do_loads = (iter_idx < elems_per_row); - T gamma_local[T_per_load]; + T gamma_local[T_per_load]; - mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + mem_access::load_global( + gamma_local, gamma + iter_idx, do_loads); #pragma unroll - for (int j = 0; j < T_per_load; j++) { - iteration_buffer[j] *= denom; - iteration_buffer[j] *= gamma_local[j]; - } - - if (do_loads) { - mem_access::store_global(block_output + iter_idx, iteration_buffer); - } + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] *= denom; + iteration_buffer[j] *= gamma_local[j]; + } + + if (do_loads) { + mem_access::store_global( + block_output + iter_idx, iteration_buffer); + } } -} - -/* -DPCT1049:6: The work-group size passed to the SYCL kernel may exceed the limit. To get the device -limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
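// Aside (illustration only): a scalar reference of what the rms_norm and
// pre_rms_norm kernels above compute per row, usable as a correctness oracle
// for the migrated SYCL code. Function names are local to this sketch.
#include <cmath>
#include <vector>
inline void rms_norm_ref(float* out, const float* x, const float* gamma,
                         float epsilon, int n) {
  float sq_sum = 0.0f;
  for (int i = 0; i < n; ++i) sq_sum += x[i] * x[i];
  const float denom = 1.0f / std::sqrt(sq_sum / n + epsilon);  // rsqrt(var + eps)
  for (int i = 0; i < n; ++i) out[i] = x[i] * denom * gamma[i];
}
inline void pre_rms_norm_ref(float* out, float* res_out, const float* x,
                             const float* residual, const float* gamma,
                             float epsilon, int n) {
  std::vector<float> summed(n);
  for (int i = 0; i < n; ++i) {
    summed[i] = x[i] + residual[i];  // the residual add is written back via res_out
    res_out[i] = summed[i];
  }
  rms_norm_ref(out, summed.data(), gamma, epsilon, n);
}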
-*/ -#define LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ - stream->submit([&](sycl::handler& cgh) { \ - T* norm_output_ct0 = norm_output; \ - const T* vals_ct1 = vals; \ - const T* gamma_ct2 = gamma; \ - auto epsilon_ct3 = epsilon; \ - auto elems_per_row_ct4 = elems_per_row; \ - \ - cgh.parallel_for(sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - rms_norm( \ - norm_output_ct0, vals_ct1, gamma_ct2, epsilon_ct3, elems_per_row_ct4); \ - }); \ + } +}; + +#define LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + dpct::has_capability_or_fail( \ + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + stream->submit([&](sycl::handler& cgh) { \ + rms_norm fn( \ + norm_output, vals, gamma, epsilon, elems_per_row); \ + \ + cgh.parallel_for(sycl::nd_range<3>(grid * block, block), fn); \ }); /* -DPCT1049:5: The work-group size passed to the SYCL kernel may exceed the limit. To get the device -limit, query info::device::max_work_group_size. Adjust the work-group size if needed. +DPCT1049:5: The work-group size passed to the SYCL kernel may exceed the limit. +To get the device limit, query info::device::max_work_group_size. Adjust the +work-group size if needed. */ -#define LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ - stream->submit([&](sycl::handler& cgh) { \ - T* norm_output_ct0 = norm_output; \ - T* res_output_ct1 = res_output; \ - const T* vals_ct2 = vals; \ - const T* residual_ct3 = residual; \ - const T* gamma_ct4 = gamma; \ - auto epsilon_ct5 = epsilon; \ - auto elems_per_row_ct6 = elems_per_row; \ - \ - cgh.parallel_for(sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - pre_rms_norm(norm_output_ct0, \ - res_output_ct1, \ - vals_ct2, \ - residual_ct3, \ - gamma_ct4, \ - epsilon_ct5, \ - elems_per_row_ct6); \ - }); \ +#define LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + dpct::has_capability_or_fail( \ + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + stream->submit([&](sycl::handler& cgh) { \ + pre_rms_norm fn( \ + norm_output, \ + res_output, \ + vals, \ + residual, \ + gamma, \ + epsilon, \ + elems_per_row); \ + \ + cgh.parallel_for(sycl::nd_range<3>(grid * block, block), fn); \ }); #define LAUNCH_ALL_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ - if (pre_norm) { \ - LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ - } else { \ - LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ - } + if (pre_norm) { \ + LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + } else { \ + LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + } template -void launch_rms_norm(T* norm_output, - T* res_output, - const T* vals, - const T* residual, - const T* gamma, - float epsilon, - int rows, - int elems_per_row, - dpct::queue_ptr stream) -{ - // 8 for sycl::half, 4 for float - constexpr int T_per_load = rms::granularity / sizeof(T); - constexpr int maxThreads = 256; - constexpr int internalUnroll = sizeof(T) == 4 ? 4 : 2; - - const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; - const int h_per_step = is_subblock_schedule ? 
T_per_load : T_per_load * internalUnroll; - - // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of - // warp-sized blocks rather than stepping up to 64/96 threads - const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); - const int threads_per_group = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; - - const int groups_per_block_max = - is_subblock_schedule ? (maxThreads + threads_per_group - 1) / threads_per_group : 1; - const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; - const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; - - sycl::range<3> block(1, groups_per_block, threads_per_group); - sycl::range<3> grid(1, 1, groups_launch); - - const int elems_per_step = threads_per_group * h_per_step; - const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; - - bool pre_norm = (residual == nullptr) ? false : true; - - if (is_subblock_schedule) { - // <=128 - if (threads_per_group == 1) { - LAUNCH_ALL_RMS_NORM(1, 1, maxThreads); - } else if (threads_per_group == 2) { - LAUNCH_ALL_RMS_NORM(1, 2, maxThreads); - } else if (threads_per_group == 4) { - LAUNCH_ALL_RMS_NORM(1, 4, maxThreads); - } else if (threads_per_group == 8) { - LAUNCH_ALL_RMS_NORM(1, 8, maxThreads); - } else if (threads_per_group == 16) { - LAUNCH_ALL_RMS_NORM(1, 16, maxThreads); - } - } else if (external_unRoll == 1) { - // 129 - 4096 elems - // (this can launch with 1-7 warps as well) - LAUNCH_ALL_RMS_NORM(1 * internalUnroll, maxThreads, maxThreads); - } else if (external_unRoll == 2) { - // 4097 - 8192 elems - LAUNCH_ALL_RMS_NORM(2 * internalUnroll, maxThreads, maxThreads); - } else if (external_unRoll == 3) { - // 8193 - 12288 elems - LAUNCH_ALL_RMS_NORM(3 * internalUnroll, maxThreads, maxThreads); - } else if (external_unRoll == 4) { - // 12289 - 16384 elems - LAUNCH_ALL_RMS_NORM(4 * internalUnroll, maxThreads, maxThreads); +void launch_rms_norm( + T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int rows, + int elems_per_row, + dpct::queue_ptr stream) { + // 8 for sycl::half, 4 for float + constexpr int T_per_load = rms::granularity / sizeof(T); + constexpr int maxThreads = 256; + constexpr int internalUnroll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = + is_subblock_schedule ? T_per_load : T_per_load * internalUnroll; + + // Scheduling concern: may be slightly faster for some inputs to assign + // multiple stages of warp-sized blocks rather than stepping up to 64/96 + // threads + const int one_step_threads = + next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threads_per_group = + (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = is_subblock_schedule + ? (maxThreads + threads_per_group - 1) / threads_per_group + : 1; + const int groups_per_block = + (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + sycl::range<3> block(1, groups_per_block, threads_per_group); + sycl::range<3> grid(1, 1, groups_launch); + + const int elems_per_step = threads_per_group * h_per_step; + const int external_unRoll = + (elems_per_row + elems_per_step - 1) / elems_per_step; + + bool pre_norm = (residual == nullptr) ? 
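// Aside (illustration only): how launch_rms_norm derives its launch shape for
// the float path. next_pow2 comes from the kernel utility headers; a local
// stand-in is used here so the sketch compiles on its own.
#include <algorithm>
#include <cstdio>
namespace sketch {
inline int next_pow2(int v) { int p = 1; while (p < v) p <<= 1; return p; }
inline void rms_norm_schedule(int rows, int elems_per_row) {
  constexpr int T_per_load = 16 / sizeof(float);  // 4 floats per 16-byte load
  constexpr int maxThreads = 256;
  constexpr int internalUnroll = 4;               // sizeof(T) == 4 path
  const bool is_subblock = (elems_per_row <= 128);
  const int h_per_step = is_subblock ? T_per_load : T_per_load * internalUnroll;
  const int one_step_threads =
      next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
  const int threads_per_group = std::min(one_step_threads, maxThreads);
  const int groups_per_block_max =
      is_subblock ? (maxThreads + threads_per_group - 1) / threads_per_group : 1;
  const int groups_per_block = std::min(rows, groups_per_block_max);
  const int groups_launch = (rows + groups_per_block - 1) / groups_per_block;
  const int elems_per_step = threads_per_group * h_per_step;
  const int external_unroll = (elems_per_row + elems_per_step - 1) / elems_per_step;
  std::printf("rows=%d elems=%d -> block(1,%d,%d) grid(1,1,%d) external_unroll=%d\n",
              rows, elems_per_row, groups_per_block, threads_per_group,
              groups_launch, external_unroll);
}
}  // namespace sketch
// e.g. sketch::rms_norm_schedule(16, 4096) lands in the external_unroll == 1 branch.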
false : true; + + if (is_subblock_schedule) { + // <=128 + if (threads_per_group == 1) { + LAUNCH_ALL_RMS_NORM(1, 1, maxThreads); + } else if (threads_per_group == 2) { + LAUNCH_ALL_RMS_NORM(1, 2, maxThreads); + } else if (threads_per_group == 4) { + LAUNCH_ALL_RMS_NORM(1, 4, maxThreads); + } else if (threads_per_group == 8) { + LAUNCH_ALL_RMS_NORM(1, 8, maxThreads); + } else if (threads_per_group == 16) { + LAUNCH_ALL_RMS_NORM(1, 16, maxThreads); } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_ALL_RMS_NORM(1 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_ALL_RMS_NORM(2 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_ALL_RMS_NORM(3 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_ALL_RMS_NORM(4 * internalUnroll, maxThreads, maxThreads); + } } -#define INSTANTIATE_LAUNCH_RMS_NORM(T) \ - template void launch_rms_norm(T * norm_output, \ - T * res_output, \ - const T* vals, \ - const T* residual, \ - const T* gamma, \ - float epsilon, \ - int rows, \ - int elems_per_row, \ - dpct::queue_ptr stream); +#define INSTANTIATE_LAUNCH_RMS_NORM(T) \ + template void launch_rms_norm( \ + T * norm_output, \ + T * res_output, \ + const T* vals, \ + const T* residual, \ + const T* gamma, \ + float epsilon, \ + int rows, \ + int elems_per_row, \ + dpct::queue_ptr stream); INSTANTIATE_LAUNCH_RMS_NORM(float) INSTANTIATE_LAUNCH_RMS_NORM(sycl::half) diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/softmax.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/softmax.dp.cpp index 261bdbf..96c773b 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/softmax.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/softmax.dp.cpp @@ -1,16 +1,29 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team +#include #include -#include #include #include "conversion_utils.h" -#include "inference_cuda_layers.h" +#include "inference_sycl_layers.h" -#ifndef __HIP_PLATFORM_AMD__ -#endif #include #include #include @@ -19,50 +32,81 @@ #define minus_infinity -10000.0 -void CheckCudaErrorAux(const char* file, unsigned line) -{ - /* - DPCT1010:11: SYCL uses exceptions to report errors and does not use the error codes. The call - was replaced with 0. You need to rewrite this code. - */ - dpct::err0 err = 0; - if (err == 0) return; - /* - DPCT1009:12: SYCL uses exceptions to report errors and does not use the error codes. 
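// Aside (illustration only): the DPCT notes above point out that SYCL reports
// errors through exceptions rather than status codes. A queue constructed with
// an async handler is the usual replacement for cudaGetLastError-style polling;
// this is a generic sketch, not what the patch itself does (it keeps the
// stubbed check).
#include <sycl/sycl.hpp>
#include <cstdlib>
#include <iostream>
inline sycl::queue make_checked_queue() {
  return sycl::queue{sycl::default_selector_v, [](sycl::exception_list errors) {
    for (const std::exception_ptr& e : errors) {
      try {
        std::rethrow_exception(e);
      } catch (const sycl::exception& ex) {
        std::cerr << "SYCL async error: " << ex.what() << std::endl;
        std::abort();
      }
    }
  }};
}
// Synchronous failures (bad submission, out-of-memory) still surface as
// sycl::exception at the call site, so launches can also be wrapped in try/catch.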
The - original code was commented out and a warning string was inserted. You need to rewrite this - code. - */ - std::cerr << "cudaGetErrorString is not supported" /*cudaGetErrorString(err)*/ << "(" << err - << ") at " << file << ":" << line << std::endl; - throw std::runtime_error("CUDA ERROR!!!\n"); +void CheckCudaErrorAux(const char* file, unsigned line) { + /* + DPCT1010:11: SYCL uses exceptions to report errors and does not use the error + codes. The call was replaced with 0. You need to rewrite this code. + */ + dpct::err0 err = 0; + if (err == 0) + return; + /* + DPCT1009:12: SYCL uses exceptions to report errors and does not use the error + codes. The original code was commented out and a warning string was inserted. + You need to rewrite this code. + */ + std::cerr << "syclGetErrorString is not supported" /*syclGetErrorString(err)*/ + << "(" << err << ") at " << file << ":" << line << std::endl; + throw std::runtime_error("SYCL ERROR!!!\n"); } -#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__) +#define SYCL_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__) template -/* -DPCT1110:0: The total declared local variable size in device function attn_softmax_v2 exceeds 128 -bytes and may cause high register pressure. Consult with your hardware vendor to find the total -register size available and adjust the code, or use smaller sub-group size to avoid high register -pressure. -*/ -void attn_softmax_v2(T* vals, - T* mask, - T* alibi, - float layer_scale, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int total_count, - int heads, - int sequence_length, - int num_seq, - int head_offset, - int mask_stride, - int mp_size, - int reduceWidth) -{ +class attn_softmax_v2 { + private: + mutable T* vals; + T* mask; + T* alibi; + float layer_scale; + bool triangular; + bool recompute; + bool local_attention; + int window_size; + int total_count; + int heads; + int sequence_length; + int num_seq; + int head_offset; + int mask_stride; + int mp_size; + int reduceWidth; + + public: + attn_softmax_v2( + T* vals, + T* mask, + T* alibi, + float layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int total_count, + int heads, + int sequence_length, + int num_seq, + int head_offset, + int mask_stride, + int mp_size, + int reduceWidth) + : vals(vals), + mask(mask), + alibi(alibi), + layer_scale(layer_scale), + triangular(triangular), + recompute(recompute), + local_attention(local_attention), + window_size(window_size), + total_count(total_count), + heads(heads), + sequence_length(sequence_length), + num_seq(num_seq), + head_offset(head_offset), + mask_stride(mask_stride), + mp_size(mp_size), + reduceWidth(reduceWidth) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); sycl::group<3> b = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group g = sycl::ext::oneapi::experimental::this_sub_group(); @@ -78,252 +122,314 @@ void attn_softmax_v2(T* vals, int reduce_blocks = reduceWidth >> 5; int seq_lane = item_ct1.get_local_id(2) % reduceWidth; - auto& partialSum = *sycl::ext::oneapi::group_local_memory_for_overwrite( - sycl::ext::oneapi::experimental::this_group<3>()); + auto& partialSum = *sycl::ext::oneapi::group_local_memory_for_overwrite< + float[MAX_WARP_NUM]>(sycl::ext::oneapi::experimental::this_group<3>()); - int iter_offset = item_ct1.get_group(2) * (warp_num / reduce_blocks) + (wid / reduce_blocks); + int iter_offset = 
item_ct1.get_group(2) * (warp_num / reduce_blocks) + + (wid / reduce_blocks); int batch_idx = iter_offset / (num_seq * heads); int alibi_offset = batch_idx * heads * mp_size + head_offset; int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride); if (iter_offset < total_count) { - vals += (iter_offset * sequence_length); - - alibi_offset = (alibi_offset + ((iter_offset / num_seq) % heads)) * sequence_length; - mask_offset = mask_offset * sequence_length; - int seq_id = iter_offset % num_seq; - - int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); - int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) - ? (real_seq_id >> 2) - (window_size >> 2) - : 0; - int window_stride = - (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; - - float max_val = minus_infinity; - // if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset); - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane); - bool check = (data_id >> 2) >= window_stride4; - bool low_x_check = check && (data_id < sequence_length) && - (!triangular || (data_id <= seq_id)) && (data_id > window_stride); - bool low_y_check = check && ((data_id + reduceWidth) < sequence_length) && - (!triangular || ((data_id + reduceWidth) <= seq_id)) && - ((data_id + reduceWidth) > window_stride); - bool high_x_check = check && ((data_id + reduceWidth * 2) < sequence_length) && - (!triangular || ((data_id + reduceWidth * 2) <= seq_id)) && - ((data_id + reduceWidth * 2) > window_stride); - bool high_y_check = check && ((data_id + reduceWidth * 3) < sequence_length) && - (!triangular || ((data_id + reduceWidth * 3) <= seq_id)) && - ((data_id + reduceWidth * 3) > window_stride); - - if (mask && alibi) { - low_data[i].x() = low_x_check - ? conversion::to(vals[data_id]) * layer_scale + - (conversion::to(alibi[data_id + alibi_offset])) + - (conversion::to(mask[data_id + mask_offset])) - : minus_infinity; - low_data[i].y() = - low_y_check - ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + - (conversion::to(alibi[data_id + alibi_offset + reduceWidth])) + - (conversion::to(mask[data_id + mask_offset + reduceWidth])) - : minus_infinity; - high_data[i].x() = - high_x_check - ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + - (conversion::to( - alibi[data_id + alibi_offset + reduceWidth * 2])) + - (conversion::to(mask[data_id + mask_offset + reduceWidth * 2])) - : minus_infinity; - high_data[i].y() = - high_y_check - ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + - (conversion::to( - alibi[data_id + alibi_offset + reduceWidth * 3])) + - (conversion::to(mask[data_id + mask_offset + reduceWidth * 3])) - : minus_infinity; - } else if (mask) { - low_data[i].x() = low_x_check - ? conversion::to(vals[data_id]) * layer_scale + - (conversion::to(mask[data_id + mask_offset])) - : minus_infinity; - low_data[i].y() = - low_y_check - ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + - (conversion::to(mask[data_id + mask_offset + reduceWidth])) - : minus_infinity; - high_data[i].x() = - high_x_check - ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + - (conversion::to(mask[data_id + mask_offset + reduceWidth * 2])) - : minus_infinity; - high_data[i].y() = - high_y_check - ? 
conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + - (conversion::to(mask[data_id + mask_offset + reduceWidth * 3])) - : minus_infinity; - } else if (alibi) { - low_data[i].x() = low_x_check - ? conversion::to(vals[data_id]) * layer_scale + - (conversion::to(alibi[data_id + alibi_offset])) - : minus_infinity; - low_data[i].y() = - low_y_check - ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + - (conversion::to(alibi[data_id + alibi_offset + reduceWidth])) - : minus_infinity; - high_data[i].x() = - high_x_check - ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale + - (conversion::to( - alibi[data_id + alibi_offset + reduceWidth * 2])) - : minus_infinity; - high_data[i].y() = - high_y_check - ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale + - (conversion::to( - alibi[data_id + alibi_offset + reduceWidth * 3])) - : minus_infinity; - } else { - low_data[i].x() = low_x_check ? conversion::to(vals[data_id]) * layer_scale - : minus_infinity; - low_data[i].y() = - low_y_check ? conversion::to(vals[data_id + reduceWidth]) * layer_scale - : minus_infinity; - high_data[i].x() = - high_x_check - ? conversion::to(vals[data_id + reduceWidth * 2]) * layer_scale - : minus_infinity; - high_data[i].y() = - high_y_check - ? conversion::to(vals[data_id + reduceWidth * 3]) * layer_scale - : minus_infinity; - } - - // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id); - max_val = (low_data[i].x() > max_val ? low_data[i].x() : max_val); - max_val = (low_data[i].y() > max_val ? low_data[i].y() : max_val); - max_val = (high_data[i].x() > max_val ? high_data[i].x() : max_val); - max_val = (high_data[i].y() > max_val ? high_data[i].y() : max_val); + vals += (iter_offset * sequence_length); + + alibi_offset = + (alibi_offset + ((iter_offset / num_seq) % heads)) * sequence_length; + mask_offset = mask_offset * sequence_length; + int seq_id = iter_offset % num_seq; + + int real_seq_id = + seq_id + (num_seq == sequence_length ? 0 : sequence_length); + int window_stride4 = + (local_attention && (real_seq_id >> 2) > (window_size >> 2)) + ? (real_seq_id >> 2) - (window_size >> 2) + : 0; + int window_stride = (local_attention && real_seq_id >= window_size) + ? real_seq_id - window_size + : -1; + + float max_val = minus_infinity; + // if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset); + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane); + bool check = (data_id >> 2) >= window_stride4; + bool low_x_check = check && (data_id < sequence_length) && + (!triangular || (data_id <= seq_id)) && (data_id > window_stride); + bool low_y_check = check && + ((data_id + reduceWidth) < sequence_length) && + (!triangular || ((data_id + reduceWidth) <= seq_id)) && + ((data_id + reduceWidth) > window_stride); + bool high_x_check = check && + ((data_id + reduceWidth * 2) < sequence_length) && + (!triangular || ((data_id + reduceWidth * 2) <= seq_id)) && + ((data_id + reduceWidth * 2) > window_stride); + bool high_y_check = check && + ((data_id + reduceWidth * 3) < sequence_length) && + (!triangular || ((data_id + reduceWidth * 3) <= seq_id)) && + ((data_id + reduceWidth * 3) > window_stride); + + if (mask && alibi) { + low_data[i].x() = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset])) + + (conversion::to(mask[data_id + mask_offset])) + : minus_infinity; + low_data[i].y() = low_y_check + ? 
conversion::to(vals[data_id + reduceWidth]) * + layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth])) + + (conversion::to( + mask[data_id + mask_offset + reduceWidth])) + : minus_infinity; + high_data[i].x() = high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * + layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 2])) + + (conversion::to( + mask[data_id + mask_offset + reduceWidth * 2])) + : minus_infinity; + high_data[i].y() = high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * + layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 3])) + + (conversion::to( + mask[data_id + mask_offset + reduceWidth * 3])) + : minus_infinity; + } else if (mask) { + low_data[i].x() = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(mask[data_id + mask_offset])) + : minus_infinity; + low_data[i].y() = low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * + layer_scale + + (conversion::to( + mask[data_id + mask_offset + reduceWidth])) + : minus_infinity; + high_data[i].x() = high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * + layer_scale + + (conversion::to( + mask[data_id + mask_offset + reduceWidth * 2])) + : minus_infinity; + high_data[i].y() = high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * + layer_scale + + (conversion::to( + mask[data_id + mask_offset + reduceWidth * 3])) + : minus_infinity; + } else if (alibi) { + low_data[i].x() = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + + (conversion::to(alibi[data_id + alibi_offset])) + : minus_infinity; + low_data[i].y() = low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * + layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth])) + : minus_infinity; + high_data[i].x() = high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * + layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 2])) + : minus_infinity; + high_data[i].y() = high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * + layer_scale + + (conversion::to( + alibi[data_id + alibi_offset + reduceWidth * 3])) + : minus_infinity; + } else { + low_data[i].x() = low_x_check + ? conversion::to(vals[data_id]) * layer_scale + : minus_infinity; + low_data[i].y() = low_y_check + ? conversion::to(vals[data_id + reduceWidth]) * layer_scale + : minus_infinity; + high_data[i].x() = high_x_check + ? conversion::to(vals[data_id + reduceWidth * 2]) * + layer_scale + : minus_infinity; + high_data[i].y() = high_y_check + ? conversion::to(vals[data_id + reduceWidth * 3]) * + layer_scale + : minus_infinity; } - for (int i = 1; i < WARP_SIZE; i *= 2) { - auto temp = sycl::permute_group_by_xor( - sycl::ext::oneapi::experimental::this_sub_group(), max_val, i); - max_val = (temp > max_val ? temp : max_val); + // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, + // seq_id); + max_val = (low_data[i].x() > max_val ? low_data[i].x() : max_val); + max_val = (low_data[i].y() > max_val ? low_data[i].y() : max_val); + max_val = (high_data[i].x() > max_val ? high_data[i].x() : max_val); + max_val = (high_data[i].y() > max_val ? high_data[i].y() : max_val); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { + auto temp = sycl::permute_group_by_xor( + sycl::ext::oneapi::experimental::this_sub_group(), max_val, i); + max_val = (temp > max_val ? 
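// Aside (illustration only): the loop below is an XOR-butterfly reduction.
// After log2(WARP_SIZE) exchange steps every lane holds the sub-group maximum,
// which is what sycl::permute_group_by_xor is used for here. A plain C++
// simulation of the same exchange pattern over 32 lanes:
#include <algorithm>
#include <array>
inline float warp_max_butterfly(std::array<float, 32> lanes) {
  for (int i = 1; i < 32; i *= 2) {
    const std::array<float, 32> prev = lanes;
    for (int lane = 0; lane < 32; ++lane) {
      // permute_group_by_xor(sg, v, i) hands lane `lane` the value held by
      // lane `lane ^ i`; keep the larger of the two.
      lanes[lane] = std::max(prev[lane], prev[lane ^ i]);
    }
  }
  return lanes[0];  // identical on every lane at this point
}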
temp : max_val); + } + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) + partialSum[wid] = max_val; + /* + DPCT1065:1: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + + if (lane < warp_num) + max_val = partialSum[lane]; + + /* + DPCT1065:2: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + auto temp = sycl::permute_group_by_xor( + sycl::ext::oneapi::experimental::this_sub_group(), max_val, i); + max_val = (temp > max_val ? temp : max_val); } - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = max_val; - /* - DPCT1065:1: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if - there is no access to global memory. - */ - item_ct1.barrier(); - - if (lane < warp_num) max_val = partialSum[lane]; - - /* - DPCT1065:2: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if - there is no access to global memory. - */ - item_ct1.barrier(); - - for (int i = 1; i < reduce_blocks; i *= 2) { - auto temp = sycl::permute_group_by_xor( - sycl::ext::oneapi::experimental::this_sub_group(), max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - /* - DPCT1007:13: Migration of cooperative_groups::thread_block_tile::shfl is not supported. - */ - max_val = g.shuffle(max_val, item_ct1.get_local_id(2) / WARP_SIZE); - } - float sum = 0; - for (int i = 0; i < iterations; i++) { - low_data[i].x() = sycl::native::exp(low_data[i].x() - max_val); - low_data[i].y() = sycl::native::exp(low_data[i].y() - max_val); - high_data[i].x() = sycl::native::exp(high_data[i].x() - max_val); - high_data[i].y() = sycl::native::exp(high_data[i].y() - max_val); - - sum += (low_data[i].x() + low_data[i].y() + high_data[i].x() + high_data[i].y()); + /* + DPCT1007:13: Migration of cooperative_groups::thread_block_tile::shfl is + not supported. + */ + max_val = g.shuffle(max_val, item_ct1.get_local_id(2) / WARP_SIZE); + } + float sum = 0; + for (int i = 0; i < iterations; i++) { + low_data[i].x() = sycl::native::exp(low_data[i].x() - max_val); + low_data[i].y() = sycl::native::exp(low_data[i].y() - max_val); + high_data[i].x() = sycl::native::exp(high_data[i].x() - max_val); + high_data[i].y() = sycl::native::exp(high_data[i].y() - max_val); + + sum += + (low_data[i].x() + low_data[i].y() + high_data[i].x() + + high_data[i].y()); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) + sum += sycl::permute_group_by_xor( + sycl::ext::oneapi::experimental::this_sub_group(), sum, i); + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) + partialSum[wid] = sum; + /* + DPCT1065:3: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + + if (lane < warp_num) + sum = partialSum[lane]; + + /* + DPCT1065:4: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. 
+ */ + item_ct1.barrier(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + sum += sycl::permute_group_by_xor( + sycl::ext::oneapi::experimental::this_sub_group(), sum, i); } - for (int i = 1; i < WARP_SIZE; i *= 2) sum += - sycl::permute_group_by_xor(sycl::ext::oneapi::experimental::this_sub_group(), sum, i); - - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = sum; - /* - DPCT1065:3: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if - there is no access to global memory. - */ - item_ct1.barrier(); - - if (lane < warp_num) sum = partialSum[lane]; - - /* - DPCT1065:4: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if - there is no access to global memory. - */ - item_ct1.barrier(); - - for (int i = 1; i < reduce_blocks; i *= 2) { - sum += sycl::permute_group_by_xor( - sycl::ext::oneapi::experimental::this_sub_group(), sum, i); - } - - /* - DPCT1007:14: Migration of cooperative_groups::thread_block_tile::shfl is not supported. - */ - sum = g.shuffle(sum, item_ct1.get_local_id(2) / WARP_SIZE); - } - sum += 1e-6; - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane); - if (data_id < sequence_length) { - vals[data_id] = conversion::to(low_data[i].x() / sum); - if ((data_id + reduceWidth) < sequence_length) - vals[data_id + reduceWidth] = conversion::to(low_data[i].y() / sum); - if ((data_id + reduceWidth * 2) < sequence_length) - vals[data_id + reduceWidth * 2] = conversion::to(high_data[i].x() / sum); - if ((data_id + reduceWidth * 3) < sequence_length) - vals[data_id + reduceWidth * 3] = conversion::to(high_data[i].y() / sum); - } + /* + DPCT1007:14: Migration of cooperative_groups::thread_block_tile::shfl is + not supported. + */ + sum = g.shuffle(sum, item_ct1.get_local_id(2) / WARP_SIZE); + } + sum += 1e-6; + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane); + if (data_id < sequence_length) { + vals[data_id] = conversion::to(low_data[i].x() / sum); + if ((data_id + reduceWidth) < sequence_length) + vals[data_id + reduceWidth] = + conversion::to(low_data[i].y() / sum); + if ((data_id + reduceWidth * 2) < sequence_length) + vals[data_id + reduceWidth * 2] = + conversion::to(high_data[i].x() / sum); + if ((data_id + reduceWidth * 3) < sequence_length) + vals[data_id + reduceWidth * 3] = + conversion::to(high_data[i].y() / sum); } + } } -} + } +}; template -/* -DPCT1110:5: The total declared local variable size in device function attn_softmax_v2 exceeds 128 -bytes and may cause high register pressure. Consult with your hardware vendor to find the total -register size available and adjust the code, or use smaller sub-group size to avoid high register -pressure. 
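// Aside (illustration only): a scalar reference for the per-row computation of
// attn_softmax_v2 above, ignoring the tiling and the triangular/local-attention
// windowing (the kernel pins those positions to minus_infinity, -10000.0). The
// row max is subtracted for stability and the exponentials are normalized by
// their sum plus 1e-6, matching the kernel.
#include <algorithm>
#include <cmath>
#include <vector>
inline void masked_softmax_ref(float* row, const float* mask, int n) {
  std::vector<float> scores(n);
  float max_val = -10000.0f;
  for (int i = 0; i < n; ++i) {
    scores[i] = mask ? row[i] + mask[i] : row[i];  // additive attention mask
    max_val = std::max(max_val, scores[i]);
  }
  float sum = 0.0f;
  for (int i = 0; i < n; ++i) {
    scores[i] = std::exp(scores[i] - max_val);
    sum += scores[i];
  }
  sum += 1e-6f;  // same epsilon the kernel adds before dividing
  for (int i = 0; i < n; ++i) row[i] = scores[i] / sum;
}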
-*/ -void attn_softmax_v2(float* vals, - float* attn_mask, - float* alibi, - float layer_scale, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int total_count, - int heads, - int sequence_length, - int num_seq, - int head_offset, - int mask_stride, - int mp_size, - int reduceWidth) -{ +class attn_softmax_v2 { + private: + mutable float* vals; + float* attn_mask; + float* alibi; + float layer_scale; + bool triangular; + bool recompute; + bool local_attention; + int window_size; + int total_count; + int heads; + int sequence_length; + int num_seq; + int head_offset; + int mask_stride; + int mp_size; + int reduceWidth; + + public: + attn_softmax_v2( + float* vals, + float* attn_mask, + float* alibi, + float layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int total_count, + int heads, + int sequence_length, + int num_seq, + int head_offset, + int mask_stride, + int mp_size, + int reduceWidth) + : vals(vals), + attn_mask(attn_mask), + alibi(alibi), + layer_scale(layer_scale), + triangular(triangular), + recompute(recompute), + local_attention(local_attention), + window_size(window_size), + total_count(total_count), + heads(heads), + sequence_length(sequence_length), + num_seq(num_seq), + head_offset(head_offset), + mask_stride(mask_stride), + mp_size(mp_size), + reduceWidth(reduceWidth) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); sycl::group<3> b = sycl::ext::oneapi::experimental::this_group<3>(); sycl::sub_group g = sycl::ext::oneapi::experimental::this_sub_group(); @@ -337,290 +443,291 @@ void attn_softmax_v2(float* vals, int reduce_blocks = reduceWidth >> 5; int seq_lane = item_ct1.get_local_id(2) % reduceWidth; - auto& partialSum = *sycl::ext::oneapi::group_local_memory_for_overwrite( - sycl::ext::oneapi::experimental::this_group<3>()); + auto& partialSum = *sycl::ext::oneapi::group_local_memory_for_overwrite< + float[MAX_WARP_NUM]>(sycl::ext::oneapi::experimental::this_group<3>()); - int iter_offset = item_ct1.get_group(2) * (warp_num / reduce_blocks) + (wid / reduce_blocks); + int iter_offset = item_ct1.get_group(2) * (warp_num / reduce_blocks) + + (wid / reduce_blocks); if (iter_offset < total_count) { - vals += (iter_offset * sequence_length); - - int batch_idx = iter_offset / (num_seq * heads); - int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride); - mask_offset = mask_offset * sequence_length; - int seq_id = iter_offset % num_seq; - - int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); - int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) - ? (real_seq_id >> 2) - (window_size >> 2) - : 0; - int window_stride = - (local_attention && real_seq_id >= window_size) ? 
real_seq_id - window_size : -1; - - float max_val = minus_infinity; - - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane); - bool check = (data_id >> 2) >= window_stride4; - bool x_check = check && (data_id < sequence_length) && - (!triangular || (data_id <= seq_id)) && (data_id > window_stride); - bool y_check = check && ((data_id + reduceWidth) < sequence_length) && - (!triangular || ((data_id + reduceWidth) <= seq_id)) && - ((data_id + reduceWidth) > window_stride); - bool z_check = check && ((data_id + reduceWidth * 2) < sequence_length) && - (!triangular || ((data_id + reduceWidth * 2) <= seq_id)) && - ((data_id + reduceWidth * 2) > window_stride); - bool w_check = check && ((data_id + reduceWidth * 3) < sequence_length) && - (!triangular || ((data_id + reduceWidth * 3) <= seq_id)) && - ((data_id + reduceWidth * 3) > window_stride); - - if (attn_mask) { - data[i].x() = x_check ? vals[data_id] + attn_mask[data_id + mask_offset] - : minus_infinity; - data[i].y() = y_check ? vals[data_id + reduceWidth] + - attn_mask[data_id + mask_offset + reduceWidth] - : minus_infinity; - data[i].z() = z_check ? vals[data_id + reduceWidth * 2] + - attn_mask[data_id + mask_offset + reduceWidth * 2] - : minus_infinity; - data[i].w() = w_check ? vals[data_id + reduceWidth * 3] + - attn_mask[data_id + mask_offset + reduceWidth * 3] - : minus_infinity; - } else { - data[i].x() = x_check ? vals[data_id] : minus_infinity; - data[i].y() = y_check ? vals[data_id + reduceWidth] : minus_infinity; - data[i].z() = z_check ? vals[data_id + reduceWidth * 2] : minus_infinity; - data[i].w() = w_check ? vals[data_id + reduceWidth * 3] : minus_infinity; - } - - max_val = (data[i].x() > max_val ? data[i].x() : max_val); - max_val = (data[i].y() > max_val ? data[i].y() : max_val); - max_val = (data[i].z() > max_val ? data[i].z() : max_val); - max_val = (data[i].w() > max_val ? data[i].w() : max_val); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { - auto temp = sycl::permute_group_by_xor( - sycl::ext::oneapi::experimental::this_sub_group(), max_val, i); - max_val = (temp > max_val ? temp : max_val); + vals += (iter_offset * sequence_length); + + int batch_idx = iter_offset / (num_seq * heads); + int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride); + mask_offset = mask_offset * sequence_length; + int seq_id = iter_offset % num_seq; + + int real_seq_id = + seq_id + (num_seq == sequence_length ? 0 : sequence_length); + int window_stride4 = + (local_attention && (real_seq_id >> 2) > (window_size >> 2)) + ? (real_seq_id >> 2) - (window_size >> 2) + : 0; + int window_stride = (local_attention && real_seq_id >= window_size) + ? 
real_seq_id - window_size + : -1; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane); + bool check = (data_id >> 2) >= window_stride4; + bool x_check = check && (data_id < sequence_length) && + (!triangular || (data_id <= seq_id)) && (data_id > window_stride); + bool y_check = check && ((data_id + reduceWidth) < sequence_length) && + (!triangular || ((data_id + reduceWidth) <= seq_id)) && + ((data_id + reduceWidth) > window_stride); + bool z_check = check && + ((data_id + reduceWidth * 2) < sequence_length) && + (!triangular || ((data_id + reduceWidth * 2) <= seq_id)) && + ((data_id + reduceWidth * 2) > window_stride); + bool w_check = check && + ((data_id + reduceWidth * 3) < sequence_length) && + (!triangular || ((data_id + reduceWidth * 3) <= seq_id)) && + ((data_id + reduceWidth * 3) > window_stride); + + if (attn_mask) { + data[i].x() = x_check + ? vals[data_id] + attn_mask[data_id + mask_offset] + : minus_infinity; + data[i].y() = y_check ? vals[data_id + reduceWidth] + + attn_mask[data_id + mask_offset + reduceWidth] + : minus_infinity; + data[i].z() = z_check ? vals[data_id + reduceWidth * 2] + + attn_mask[data_id + mask_offset + reduceWidth * 2] + : minus_infinity; + data[i].w() = w_check ? vals[data_id + reduceWidth * 3] + + attn_mask[data_id + mask_offset + reduceWidth * 3] + : minus_infinity; + } else { + data[i].x() = x_check ? vals[data_id] : minus_infinity; + data[i].y() = y_check ? vals[data_id + reduceWidth] : minus_infinity; + data[i].z() = + z_check ? vals[data_id + reduceWidth * 2] : minus_infinity; + data[i].w() = + w_check ? vals[data_id + reduceWidth * 3] : minus_infinity; } - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = max_val; - /* - DPCT1065:6: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if - there is no access to global memory. - */ - item_ct1.barrier(); - - if (lane < warp_num) max_val = partialSum[lane]; - - /* - DPCT1065:7: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if - there is no access to global memory. - */ - item_ct1.barrier(); - - for (int i = 1; i < reduce_blocks; i *= 2) { - auto temp = sycl::permute_group_by_xor( - sycl::ext::oneapi::experimental::this_sub_group(), max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - /* - DPCT1007:15: Migration of cooperative_groups::thread_block_tile::shfl is not supported. - */ - max_val = g.shuffle(max_val, item_ct1.get_local_id(2) / WARP_SIZE); + max_val = (data[i].x() > max_val ? data[i].x() : max_val); + max_val = (data[i].y() > max_val ? data[i].y() : max_val); + max_val = (data[i].z() > max_val ? data[i].z() : max_val); + max_val = (data[i].w() > max_val ? data[i].w() : max_val); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { + auto temp = sycl::permute_group_by_xor( + sycl::ext::oneapi::experimental::this_sub_group(), max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) + partialSum[wid] = max_val; + /* + DPCT1065:6: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. 
+ */ + item_ct1.barrier(); + + if (lane < warp_num) + max_val = partialSum[lane]; + + /* + DPCT1065:7: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + auto temp = sycl::permute_group_by_xor( + sycl::ext::oneapi::experimental::this_sub_group(), max_val, i); + max_val = (temp > max_val ? temp : max_val); } - float sum = 0; - for (int i = 0; i < iterations; i++) { - data[i].x() = sycl::native::exp(data[i].x() - max_val); - data[i].y() = sycl::native::exp(data[i].y() - max_val); - data[i].z() = sycl::native::exp(data[i].z() - max_val); - data[i].w() = sycl::native::exp(data[i].w() - max_val); - - sum += (data[i].x() + data[i].y() + data[i].z() + data[i].w()); + /* + DPCT1007:15: Migration of cooperative_groups::thread_block_tile::shfl is + not supported. + */ + max_val = g.shuffle(max_val, item_ct1.get_local_id(2) / WARP_SIZE); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + data[i].x() = sycl::native::exp(data[i].x() - max_val); + data[i].y() = sycl::native::exp(data[i].y() - max_val); + data[i].z() = sycl::native::exp(data[i].z() - max_val); + data[i].w() = sycl::native::exp(data[i].w() - max_val); + + sum += (data[i].x() + data[i].y() + data[i].z() + data[i].w()); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) + sum += sycl::permute_group_by_xor( + sycl::ext::oneapi::experimental::this_sub_group(), sum, i); + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) + partialSum[wid] = sum; + /* + DPCT1065:8: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + + if (lane < warp_num) + sum = partialSum[lane]; + + /* + DPCT1065:9: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + sum += sycl::permute_group_by_xor( + sycl::ext::oneapi::experimental::this_sub_group(), sum, i); } - for (int i = 1; i < WARP_SIZE; i *= 2) sum += - sycl::permute_group_by_xor(sycl::ext::oneapi::experimental::this_sub_group(), sum, i); - - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = sum; - /* - DPCT1065:8: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if - there is no access to global memory. - */ - item_ct1.barrier(); - - if (lane < warp_num) sum = partialSum[lane]; - - /* - DPCT1065:9: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if - there is no access to global memory. - */ - item_ct1.barrier(); - - for (int i = 1; i < reduce_blocks; i *= 2) { - sum += sycl::permute_group_by_xor( - sycl::ext::oneapi::experimental::this_sub_group(), sum, i); - } - - /* - DPCT1007:16: Migration of cooperative_groups::thread_block_tile::shfl is not supported. 
- */ - sum = g.shuffle(sum, item_ct1.get_local_id(2) / WARP_SIZE); - } - sum += 1e-6; - - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane); - if (data_id < sequence_length) { - vals[data_id] = data[i].x() / sum; - if ((data_id + reduceWidth) < sequence_length) - vals[data_id + reduceWidth] = data[i].y() / sum; - if ((data_id + reduceWidth * 2) < sequence_length) - vals[data_id + reduceWidth * 2] = data[i].z() / sum; - if ((data_id + reduceWidth * 3) < sequence_length) - vals[data_id + reduceWidth * 3] = data[i].w() / sum; - } + /* + DPCT1007:16: Migration of cooperative_groups::thread_block_tile::shfl is + not supported. + */ + sum = g.shuffle(sum, item_ct1.get_local_id(2) / WARP_SIZE); + } + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane); + if (data_id < sequence_length) { + vals[data_id] = data[i].x() / sum; + if ((data_id + reduceWidth) < sequence_length) + vals[data_id + reduceWidth] = data[i].y() / sum; + if ((data_id + reduceWidth * 2) < sequence_length) + vals[data_id + reduceWidth * 2] = data[i].z() / sum; + if ((data_id + reduceWidth * 3) < sequence_length) + vals[data_id + reduceWidth * 3] = data[i].w() / sum; } + } } -} - -/* -DPCT1049:10: The work-group size passed to the SYCL kernel may exceed the limit. To get the device -limit, query info::device::max_work_group_size. Adjust the work-group size if needed. -*/ -#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \ - { \ - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ - stream->submit([&](sycl::handler& cgh) { \ - T* vals_ct0 = vals; \ - T* mask_ct1 = mask; \ - T* alibi_ct2 = alibi; \ - auto layer_scale_ct3 = layer_scale; \ - auto triangular_ct4 = triangular; \ - auto recompute_ct5 = recompute; \ - auto local_attention_ct6 = local_attention; \ - auto window_size_ct7 = window_size; \ - auto total_count_ct8 = total_count; \ - auto heads_ct9 = heads; \ - auto sequence_length_ct10 = sequence_length; \ - auto num_seq_ct11 = num_seq; \ - auto head_offset_ct12 = head_offset; \ - auto mask_stride_ct13 = mask_stride; \ - auto mp_size_ct14 = mp_size; \ - auto reduce_width_ct15 = reduce_width; \ - \ - cgh.parallel_for(sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - attn_softmax_v2(vals_ct0, \ - mask_ct1, \ - alibi_ct2, \ - layer_scale_ct3, \ - triangular_ct4, \ - recompute_ct5, \ - local_attention_ct6, \ - window_size_ct7, \ - total_count_ct8, \ - heads_ct9, \ - sequence_length_ct10, \ - num_seq_ct11, \ - head_offset_ct12, \ - mask_stride_ct13, \ - mp_size_ct14, \ - reduce_width_ct15); \ - }); \ - }); \ + } +}; + +#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \ + { \ + dpct::has_capability_or_fail( \ + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); \ + stream->submit([&](sycl::handler& cgh) { \ + attn_softmax_v2 fn( \ + vals, \ + mask, \ + alibi, \ + layer_scale, \ + triangular, \ + recompute, \ + local_attention, \ + window_size, \ + total_count, \ + heads, \ + sequence_length, \ + num_seq, \ + head_offset, \ + mask_stride, \ + mp_size, \ + reduce_width); \ + \ + cgh.parallel_for(sycl::nd_range<3>(grid * block, block), fn); \ + }); \ } template -void launch_attn_softmax_v2(T* vals, - T* mask, - T* alibi, - float layer_scale, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - int head_offset, - int mask_stride, - int 
mp_size, - dpct::queue_ptr stream) -{ - const int total_count = batch_size * heads * num_seq; - - // Scheduling Overview - // 4 element unroll with power of 2 `reduce_width` threads to a ceiling of `attn_threads` - // Each block should be partitioned into as many `reduce_width` blocks - // as can be fit. - constexpr int attn_threads = 256; - constexpr int min_reduce_width = hw_warp_size; - constexpr int internal_unroll = 4; - - // Handle internal unroll then round to next power of 2. Bump up to minimum granularity. - const int thread_steps_rounded = - next_pow2((sequence_length + internal_unroll - 1) / internal_unroll); - const int thread_steps_schedule = - (thread_steps_rounded < min_reduce_width) ? min_reduce_width : thread_steps_rounded; - // Bound reduce width to the number of threads - const int reduce_width = (thread_steps_schedule < attn_threads) ? thread_steps_schedule - : attn_threads; - // Scale for the excess - const int iterations = thread_steps_schedule / reduce_width; - // Should be safe since reduce_width is capped to attn_threads - const int partitions = attn_threads / reduce_width; - - // Launch params - sycl::range<3> grid(1, 1, (total_count + partitions - 1) / partitions); - sycl::range<3> block(1, 1, attn_threads); - - if (sequence_length <= 32768) { - if (iterations == 1) { - LAUNCH_ATTN_SOFTMAX_V2(1); - } else if (iterations == 2) { - LAUNCH_ATTN_SOFTMAX_V2(2); - } else if (iterations == 4) { - LAUNCH_ATTN_SOFTMAX_V2(4); - } else if (iterations == 8) { - LAUNCH_ATTN_SOFTMAX_V2(8); - } else if (iterations == 16) { - LAUNCH_ATTN_SOFTMAX_V2(16); - } else if (iterations == 32) { - LAUNCH_ATTN_SOFTMAX_V2(32); - } else if (iterations == 64) { - LAUNCH_ATTN_SOFTMAX_V2(64); - } - } else - throw std::runtime_error("Unsupport Seq_Length!"); +void launch_attn_softmax_v2( + T* vals, + T* mask, + T* alibi, + float layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + int head_offset, + int mask_stride, + int mp_size, + dpct::queue_ptr stream) { + const int total_count = batch_size * heads * num_seq; + + // Scheduling Overview + // 4 element unroll with power of 2 `reduce_width` threads to a ceiling of + // `attn_threads` Each block should be partitioned into as many `reduce_width` + // blocks as can be fit. + constexpr int attn_threads = 256; + constexpr int min_reduce_width = hw_warp_size; + constexpr int internal_unroll = 4; + + // Handle internal unroll then round to next power of 2. Bump up to minimum + // granularity. + const int thread_steps_rounded = + next_pow2((sequence_length + internal_unroll - 1) / internal_unroll); + const int thread_steps_schedule = (thread_steps_rounded < min_reduce_width) + ? min_reduce_width + : thread_steps_rounded; + // Bound reduce width to the number of threads + const int reduce_width = (thread_steps_schedule < attn_threads) + ? 
thread_steps_schedule + : attn_threads; + // Scale for the excess + const int iterations = thread_steps_schedule / reduce_width; + // Should be safe since reduce_width is capped to attn_threads + const int partitions = attn_threads / reduce_width; + + // Launch params + sycl::range<3> grid(1, 1, (total_count + partitions - 1) / partitions); + sycl::range<3> block(1, 1, attn_threads); + + if (sequence_length <= 32768) { + if (iterations == 1) { + LAUNCH_ATTN_SOFTMAX_V2(1); + } else if (iterations == 2) { + LAUNCH_ATTN_SOFTMAX_V2(2); + } else if (iterations == 4) { + LAUNCH_ATTN_SOFTMAX_V2(4); + } else if (iterations == 8) { + LAUNCH_ATTN_SOFTMAX_V2(8); + } else if (iterations == 16) { + LAUNCH_ATTN_SOFTMAX_V2(16); + } else if (iterations == 32) { + LAUNCH_ATTN_SOFTMAX_V2(32); + } else if (iterations == 64) { + LAUNCH_ATTN_SOFTMAX_V2(64); + } + } else + throw std::runtime_error("Unsupport Seq_Length!"); } -#define INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(T) \ - template void launch_attn_softmax_v2(T* vals, \ - T* mask, \ - T* alibi, \ - float layer_scale, \ - bool triangular, \ - bool recompute, \ - bool local_attention, \ - int window_size, \ - int batch_size, \ - int heads, \ - int num_seq, \ - int sequence_length, \ - int head_offset, \ - int mask_stride, \ - int mp_size, \ - dpct::queue_ptr stream); +#define INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(T) \ + template void launch_attn_softmax_v2( \ + T* vals, \ + T* mask, \ + T* alibi, \ + float layer_scale, \ + bool triangular, \ + bool recompute, \ + bool local_attention, \ + int window_size, \ + int batch_size, \ + int heads, \ + int num_seq, \ + int sequence_length, \ + int head_offset, \ + int mask_stride, \ + int mp_size, \ + dpct::queue_ptr stream); INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(float); #ifdef BF16_AVAILABLE @@ -628,53 +735,54 @@ INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(sycl::ext::oneapi::bfloat16); #endif INSTANTIATE_LAUNCH_ATTN_SOFTMAX_V2(sycl::half); -#define DEF_ATTN_SOFTMAX_V2_HALF(_iter) \ - template void attn_softmax_v2(sycl::half * vals, \ - sycl::half * mask, \ - sycl::half * alibi, \ - float layer_scale, \ - bool triangular, \ - bool recompute, \ - bool local_attention, \ - int window_size, \ - int total_count, \ - int heads, \ - int sequence_length, \ - int num_seq, \ - int head_offset, \ - int mask_stride, \ - int mp_size, \ - int reduceWidth) - -#define DEF_ATTN_SOFTMAX_V2_BF16(_iter) \ - template void attn_softmax_v2( \ - sycl::ext::oneapi::bfloat16 * vals, \ - sycl::ext::oneapi::bfloat16 * mask, \ - sycl::ext::oneapi::bfloat16 * alibi, \ - float layer_scale, \ - bool triangular, \ - bool recompute, \ - bool local_attention, \ - int window_size, \ - int total_count, \ - int heads, \ - int sequence_length, \ - int num_seq, \ - int head_offset, \ - int mask_stride, \ - int mp_size, \ - int reduceWidth) - -#define FOREACH_ITERATIONS(cb) \ - cb(1); \ - cb(2); \ - cb(4); \ - cb(8); \ - cb(16); \ - cb(32); \ - cb(64) - -FOREACH_ITERATIONS(DEF_ATTN_SOFTMAX_V2_HALF); -#ifdef BF16_AVAILABLE -FOREACH_ITERATIONS(DEF_ATTN_SOFTMAX_V2_BF16); -#endif +/* #define DEF_ATTN_SOFTMAX_V2_HALF(_iter) \ */ +/* template void attn_softmax_v2( \ */ +/* sycl::half * vals, \ */ +/* sycl::half * mask, \ */ +/* sycl::half * alibi, \ */ +/* float layer_scale, \ */ +/* bool triangular, \ */ +/* bool recompute, \ */ +/* bool local_attention, \ */ +/* int window_size, \ */ +/* int total_count, \ */ +/* int heads, \ */ +/* int sequence_length, \ */ +/* int num_seq, \ */ +/* int head_offset, \ */ +/* int mask_stride, \ */ +/* int mp_size, \ */ +/* int reduceWidth) */ + 
+/* #define DEF_ATTN_SOFTMAX_V2_BF16(_iter) \ */ +/* template void attn_softmax_v2( \ */ +/* sycl::ext::oneapi::bfloat16 * vals, \ */ +/* sycl::ext::oneapi::bfloat16 * mask, \ */ +/* sycl::ext::oneapi::bfloat16 * alibi, \ */ +/* float layer_scale, \ */ +/* bool triangular, \ */ +/* bool recompute, \ */ +/* bool local_attention, \ */ +/* int window_size, \ */ +/* int total_count, \ */ +/* int heads, \ */ +/* int sequence_length, \ */ +/* int num_seq, \ */ +/* int head_offset, \ */ +/* int mask_stride, \ */ +/* int mp_size, \ */ +/* int reduceWidth) */ + +/* #define FOREACH_ITERATIONS(cb) \ */ +/* cb(1); \ */ +/* cb(2); \ */ +/* cb(4); \ */ +/* cb(8); \ */ +/* cb(16); \ */ +/* cb(32); \ */ +/* cb(64) */ + +/* FOREACH_ITERATIONS(DEF_ATTN_SOFTMAX_V2_HALF); */ +/* #ifdef BF16_AVAILABLE */ +/* FOREACH_ITERATIONS(DEF_ATTN_SOFTMAX_V2_BF16); */ +/* #endif */ diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/transform.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/transform.dp.cpp index a035fca..dcd9c2f 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/transform.dp.cpp +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/csrc/transform.dp.cpp @@ -1,14 +1,27 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 // DeepSpeed Team -#ifndef __HIP_PLATFORM_AMD__ +#include #include -#include -#endif #include "conversion_utils.h" -#include "inference_cuda_layers.h" +#include "inference_sycl_layers.h" // only used to avoid compilation error due to lack of definition. 
#ifndef BF16_AVAILABLE @@ -16,47 +29,102 @@ using __nv_bfloat162 = sycl::half2; #endif // Bias add - -void bias_add_transform_0213(float* output, - float* k_cache, - float* v_cache, - const float* vals, - const float* bias, - int hidden_dim, - int seq_length, - unsigned seq_offset, - int heads, - int head_stride, - int num_kv, - int rotary_dim, - bool rotate_half, - bool rotate_every_two, - int head_ext, - int max_out_tokens, - float rope_theta) -{ +template +class bias_add_transform_0213 { + private: + T* output; + T* k_cache; + T* v_cache; + const T* vals; + const T* bias; + int hidden_dim; + int seq_length; + unsigned seq_offset; + int all_tokens; + int heads; + int head_stride; + int num_kv; + int rotary_dim; + bool rotate_half; + bool rotate_every_two; + int head_ext; + int max_out_tokens; + float rope_theta; + + public: + bias_add_transform_0213( + T* output, + T* k_cache, + T* v_cache, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + unsigned seq_offset, + int all_tokens, + int heads, + int head_stride, + int num_kv, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + int head_ext, + int max_out_tokens, + float rope_theta) + : output(output), + k_cache(k_cache), + v_cache(v_cache), + vals(vals), + bias(bias), + hidden_dim(hidden_dim), + seq_length(seq_length), + seq_offset(seq_offset), + all_tokens(all_tokens), + heads(heads), + head_stride(head_stride), + num_kv(num_kv), + rotary_dim(rotary_dim), + rotate_half(rotate_half), + rotate_every_two(rotate_every_two), + head_ext(head_ext), + max_out_tokens(max_out_tokens), + rope_theta(rope_theta) {} + + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); + using T2 = typename std::conditional< + std::is_same::value, + sycl::half2, + sycl::marray>::type; + unsigned half_dim = (rotary_dim << 3) >> 1; int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; - int d0 = item_ct1.get_group(2); // Batch - int d1 = item_ct1.get_group(1); // Sequence ID (0-127) - int cnt = item_ct1.get_group(0) / head_ext; // Hidden count + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1); // Sequence ID (0-127) + int cnt = item_ct1.get_group(0) / head_ext; // Hidden count int d2 = item_ct1.get_local_id(1) + - (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = item_ct1.get_local_id(2); // Values (groups of 4) + (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = item_ct1.get_local_id(2); // Values (groups of 4) int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens); int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens); + sycl::float4 vals_arr; + sycl::float4 output_arr; + + T2* vals_half = reinterpret_cast(&vals_arr); + T2* output_half = reinterpret_cast(&output_arr); + const sycl::float4* vals_vec = reinterpret_cast(vals); - sycl::float4* output_vec = - reinterpret_cast(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache)); + sycl::float4* output_vec = reinterpret_cast( + cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache)); vals_vec += (d0 * (d1_stride + num_kv * 2 * d2_stride) * seq_length); - vals_vec += d1 * (d1_stride + num_kv * 2 * d2_stride); - vals_vec += (cnt == 0 ? 0 : d1_stride) + (cnt == 0 ? 0 : (cnt - 1) * num_kv * d2_stride); + vals_vec += (d1 * (d1_stride + num_kv * 2 * d2_stride)); + vals_vec += (cnt == 0 ? 0 : d1_stride) + + (cnt == 0 ? 
0 : (cnt - 1) * num_kv * d2_stride); vals_vec += ((cnt == 0 ? d2 : (d2 / head_stride)) * d2_stride); output_vec += (d1 * d2_stride); @@ -64,87 +132,118 @@ void bias_add_transform_0213(float* output, output_vec += (d2 * d2_out_stride); unsigned seq_id = d1 + seq_offset; - sycl::float4 inputs = vals_vec[d3]; + int lane = d3 & 0x1f; if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) { - sycl::float4 q = vals_vec[d3]; - sycl::float2* q_f = reinterpret_cast(&q); - if (rotate_every_two) { + sycl::float4 q = vals_vec[d3]; + T2* q_h = reinterpret_cast(&q); + if (rotate_every_two) { #pragma unroll - for (int o = 0; o < 2; o++) { - float inv_freq = (float)(((d3 << 1) + o) * 2) / (float)(rotary_dim << 2); - inv_freq = 1.0 / dpct::pow(rope_theta, inv_freq) * (float)seq_id; - q_f[o].x() = - (-1.0 * q_f[o].y() * sycl::sin(inv_freq) + q_f[o].x() * sycl::cos(inv_freq)); - q_f[o].y() = (q_f[o].x() * sycl::sin(inv_freq) + q_f[o].y() * sycl::cos(inv_freq)); - } + for (int o = 0; o < 4; o++) { + float inv_freq = + (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3); + inv_freq = 1.0 / dpct::pow(rope_theta, inv_freq) * (float)seq_id; + float q_data[2]; + q_data[0] = conversion::to(q_h[o][0]); + q_data[1] = conversion::to(q_h[o][1]); + q_h[o][0] = conversion::to( + -1.0 * q_data[1] * sycl::sin(inv_freq) + + q_data[0] * sycl::cos(inv_freq)); + q_h[o][1] = conversion::to( + q_data[0] * sycl::sin(inv_freq) + + q_data[1] * sycl::cos(inv_freq)); } - output_vec[d3] = q; + } + output_vec[d3] = q; } else - output_vec[d3] = inputs; -} + output_vec[d3] = vals_vec[d3]; + } +}; -#define ATTN_H 3 -#define MAX_SEQ_LINE 10 - -template -/* -DPCT1110:0: The total declared local variable size in device function bias_add_transform_0213 -exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find -the total register size available and adjust the code, or use smaller sub-group size to avoid high -register pressure. 
-*/ -void bias_add_transform_0213(T* output, // q - T* k_cache, - T* v_cache, - const T* vals, // qkv - const T* bias, - int hidden_dim, - int seq_length, - unsigned seq_offset, - int all_tokens, - int heads, - int head_stride, - int num_kv, - int rotary_dim, - bool rotate_half, - bool rotate_every_two, - int head_ext, - int max_out_tokens, - float rope_theta) -{ +template <> +class bias_add_transform_0213 { + private: + float* output; + float* k_cache; + float* v_cache; + const float* vals; + const float* bias; + int hidden_dim; + int seq_length; + int all_tokens; + unsigned seq_offset; + int heads; + int head_stride; + int num_kv; + int rotary_dim; + bool rotate_half; + bool rotate_every_two; + int head_ext; + int max_out_tokens; + float rope_theta; + + public: + bias_add_transform_0213( + float* output, + float* k_cache, + float* v_cache, + const float* vals, + const float* bias, + int hidden_dim, + int seq_length, + int all_tokens, + unsigned seq_offset, + int heads, + int head_stride, + int num_kv, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + int head_ext, + int max_out_tokens, + float rope_theta) + : output(output), + k_cache(k_cache), + v_cache(v_cache), + vals(vals), + bias(bias), + hidden_dim(hidden_dim), + seq_length(seq_length), + seq_offset(seq_offset), + all_tokens(all_tokens), + heads(heads), + head_stride(head_stride), + num_kv(num_kv), + rotary_dim(rotary_dim), + rotate_half(rotate_half), + rotate_every_two(rotate_every_two), + head_ext(head_ext), + max_out_tokens(max_out_tokens), + rope_theta(rope_theta) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - using T2 = typename std::conditional::value, - sycl::half2, - sycl::marray>::type; - unsigned half_dim = (rotary_dim << 3) >> 1; int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; - int d0 = item_ct1.get_group(2); // Batch - int d1 = item_ct1.get_group(1); // Sequence ID (0-127) - int cnt = item_ct1.get_group(0) / head_ext; // Hidden count + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1); // Sequence ID (0-127) + int cnt = item_ct1.get_group(0) / head_ext; // Hidden count int d2 = item_ct1.get_local_id(1) + - (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = item_ct1.get_local_id(2); // Values (groups of 4) + (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = item_ct1.get_local_id(2); // Values (groups of 4) int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens); int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens); - sycl::float4 vals_arr; - sycl::float4 output_arr; - - T2* vals_half = reinterpret_cast(&vals_arr); - T2* output_half = reinterpret_cast(&output_arr); - const sycl::float4* vals_vec = reinterpret_cast(vals); - sycl::float4* output_vec = - reinterpret_cast(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache)); + sycl::float4* output_vec = reinterpret_cast( + cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache)); vals_vec += (d0 * (d1_stride + num_kv * 2 * d2_stride) * seq_length); - vals_vec += (d1 * (d1_stride + num_kv * 2 * d2_stride)); - vals_vec += (cnt == 0 ? 0 : d1_stride) + (cnt == 0 ? 0 : (cnt - 1) * num_kv * d2_stride); + vals_vec += d1 * (d1_stride + num_kv * 2 * d2_stride); + vals_vec += (cnt == 0 ? 0 : d1_stride) + + (cnt == 0 ? 0 : (cnt - 1) * num_kv * d2_stride); vals_vec += ((cnt == 0 ? 
d2 : (d2 / head_stride)) * d2_stride); output_vec += (d1 * d2_stride); @@ -152,160 +251,162 @@ void bias_add_transform_0213(T* output, // q output_vec += (d2 * d2_out_stride); unsigned seq_id = d1 + seq_offset; - + sycl::float4 inputs = vals_vec[d3]; int lane = d3 & 0x1f; if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) { - sycl::float4 q = vals_vec[d3]; - T2* q_h = reinterpret_cast(&q); - if (rotate_every_two) { + sycl::float4 q = vals_vec[d3]; + sycl::float2* q_f = reinterpret_cast(&q); + if (rotate_every_two) { #pragma unroll - for (int o = 0; o < 4; o++) { - float inv_freq = (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3); - inv_freq = 1.0 / dpct::pow(rope_theta, inv_freq) * (float)seq_id; - float q_data[2]; - q_data[0] = conversion::to(q_h[o][0]); - q_data[1] = conversion::to(q_h[o][1]); - q_h[o][0] = conversion::to(-1.0 * q_data[1] * sycl::sin(inv_freq) + - q_data[0] * sycl::cos(inv_freq)); - q_h[o][1] = conversion::to(q_data[0] * sycl::sin(inv_freq) + - q_data[1] * sycl::cos(inv_freq)); - } + for (int o = 0; o < 2; o++) { + float inv_freq = + (float)(((d3 << 1) + o) * 2) / (float)(rotary_dim << 2); + inv_freq = 1.0 / dpct::pow(rope_theta, inv_freq) * (float)seq_id; + q_f[o].x() = + (-1.0 * q_f[o].y() * sycl::sin(inv_freq) + + q_f[o].x() * sycl::cos(inv_freq)); + q_f[o].y() = + (q_f[o].x() * sycl::sin(inv_freq) + + q_f[o].y() * sycl::cos(inv_freq)); } - output_vec[d3] = q; + } + output_vec[d3] = q; } else - output_vec[d3] = vals_vec[d3]; -} + output_vec[d3] = inputs; + } +}; + +#define ATTN_H 3 +#define MAX_SEQ_LINE 10 // [B S C*H] - > C * [B A S N] template <> -void launch_bias_add_transform_0213(float* output, - float* k_cache, - float* v_cache, - const float* vals, - const float* bias, - int batch_size, - int seq_length, - unsigned seq_offset, - int all_tokens, - int hidden_dim, - int heads, - int num_kv, - int rotary_dim, - bool rotate_half, - bool rotate_every_two, - dpct::queue_ptr stream, - int trans_count, - int max_out_tokens, - float rope_theta) -{ - hidden_dim >>= 2; - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - - sycl::range<3> block_dim(1, (heads / head_ext), hidden_dim / heads); - sycl::range<3> grid_dim((trans_count * head_ext), seq_length, batch_size); - - /* - DPCT1049:1: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - stream->parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), - [=](sycl::nd_item<3> item_ct1) { - bias_add_transform_0213(output, - k_cache, - v_cache, - vals, - bias, - hidden_dim, - seq_length, - seq_offset, - heads, - num_kv > 0 ? (heads / num_kv) : 1, - num_kv > 0 ? 
num_kv : heads, - rotary_dim >> 2, - rotate_half, - rotate_every_two, - head_ext, - max_out_tokens, - rope_theta); - }); +void launch_bias_add_transform_0213( + float* output, + float* k_cache, + float* v_cache, + const float* vals, + const float* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int all_tokens, + int hidden_dim, + int heads, + int num_kv, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + dpct::queue_ptr stream, + int trans_count, + int max_out_tokens, + float rope_theta) { + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + + sycl::range<3> block_dim(1, (heads / head_ext), hidden_dim / heads); + sycl::range<3> grid_dim((trans_count * head_ext), seq_length, batch_size); + + bias_add_transform_0213 fn( + output, + k_cache, + v_cache, + vals, + bias, + hidden_dim, + seq_length, + seq_offset, + 0, + heads, + num_kv > 0 ? (heads / num_kv) : 1, + num_kv > 0 ? num_kv : heads, + rotary_dim >> 2, + rotate_half, + rotate_every_two, + head_ext, + max_out_tokens, + rope_theta); + stream->parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), fn); } template -void launch_bias_add_transform_0213(T* output, - T* k_cache, - T* v_cache, - const T* vals, - const T* bias, - int batch_size, - int seq_length, - unsigned seq_offset, - int all_tokens, - int hidden_dim, - int heads, - int num_kv, - int rotary_dim, - bool rotate_half, - bool rotate_every_two, - dpct::queue_ptr stream, - int trans_count, - int max_out_tokens, - float rope_theta) -{ - hidden_dim >>= 3; - int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1; - sycl::range<3> block_dim(1, (heads / head_ext), hidden_dim / heads); - sycl::range<3> grid_dim((trans_count * head_ext), seq_length, batch_size); - /* - DPCT1049:2: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), - [=](sycl::nd_item<3> item_ct1) { - bias_add_transform_0213(output, - k_cache, - v_cache, - vals, - bias, - hidden_dim, - seq_length, - seq_offset, - all_tokens, - heads, - num_kv > 0 ? (heads / num_kv) : 1, - num_kv > 0 ? num_kv : heads, - rotary_dim >> 3, - rotate_half, - rotate_every_two, - head_ext, - max_out_tokens, - rope_theta); - }); - } +void launch_bias_add_transform_0213( + T* output, + T* k_cache, + T* v_cache, + const T* vals, + const T* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int all_tokens, + int hidden_dim, + int heads, + int num_kv, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + dpct::queue_ptr stream, + int trans_count, + int max_out_tokens, + float rope_theta) { + hidden_dim >>= 3; + int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1; + sycl::range<3> block_dim(1, (heads / head_ext), hidden_dim / heads); + sycl::range<3> grid_dim((trans_count * head_ext), seq_length, batch_size); + /* + DPCT1049:2: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + bias_add_transform_0213 fn( + output, + k_cache, + v_cache, + vals, + bias, + hidden_dim, + seq_length, + seq_offset, + all_tokens, + heads, + num_kv > 0 ? (heads / num_kv) : 1, + num_kv > 0 ? num_kv : heads, + rotary_dim >> 3, + rotate_half, + rotate_every_two, + head_ext, + max_out_tokens, + rope_theta); + stream->parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), fn); + } } -#define INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(T) \ - template void launch_bias_add_transform_0213(T*, \ - T*, \ - T*, \ - const T*, \ - const T*, \ - int, \ - int, \ - unsigned, \ - int, \ - int, \ - int, \ - int, \ - int, \ - bool, \ - bool, \ - dpct::queue_ptr, \ - int, \ - int, \ - float) +#define INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(T) \ + template void launch_bias_add_transform_0213( \ + T*, \ + T*, \ + T*, \ + const T*, \ + const T*, \ + int, \ + int, \ + unsigned, \ + int, \ + int, \ + int, \ + int, \ + int, \ + bool, \ + bool, \ + dpct::queue_ptr, \ + int, \ + int, \ + float) #ifdef BF16_AVAILABLE INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(sycl::ext::oneapi::bfloat16); @@ -314,44 +415,65 @@ INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(sycl::half); // Bias add -void pad_add_transform_0213(float* output, - const float* vals, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size) -{ -} +/* void pad_add_transform_0213(float* output, */ +/* const float* vals, */ +/* int hidden_dim, */ +/* int seq_length, */ +/* int padded_seq_len, */ +/* int heads, */ +/* int padded_head_size) */ +/* { */ +/* } */ template -void pad_add_transform_0213(T* output, - const T* vals, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size) -{ +class pad_add_transform_0213 { + private: + T* output; + const T* vals; + int hidden_dim; + int seq_length; + int padded_seq_len; + int heads; + int padded_head_size; + + public: + pad_add_transform_0213( + T* output, + const T* vals, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size) + : output(output), + vals(vals), + hidden_dim(hidden_dim), + seq_length(seq_length), + padded_seq_len(padded_seq_len), + heads(heads), + padded_head_size(padded_head_size) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - using T2 = typename std::conditional::value, - sycl::half2, - sycl::marray>::type; + using T2 = typename std::conditional< + std::is_same::value, + sycl::half2, + sycl::marray>::type; sycl::float4 ZERO; const T2 zero_h = conversion::to(0.f); T2* ZERO_h = reinterpret_cast(&ZERO); #pragma unroll - for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + for (int i = 0; i < 4; i++) + ZERO_h[i] = zero_h; int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; - int d0 = item_ct1.get_group(2); // Batch + int d0 = item_ct1.get_group(2); // Batch int d1 = item_ct1.get_group(1) * item_ct1.get_local_range(0) + - item_ct1.get_local_id(0); // Sequence ID (0-127) - int d2 = item_ct1.get_local_id(1); // Head (0-11) - int d3 = item_ct1.get_local_id(2); // Values (groups of 4) + item_ct1.get_local_id(0); // Sequence ID (0-127) + int d2 = item_ct1.get_local_id(1); // Head (0-11) + int d3 = item_ct1.get_local_id(2); // Values (groups of 4) int d2_out_stride = padded_head_size * padded_seq_len; int d0_out_stride = heads * d2_out_stride; @@ -368,62 
+490,58 @@ void pad_add_transform_0213(T* output, output_vec += (d2 * d2_out_stride); if (d3 < d2_stride && d1 < seq_length) - output_vec[d3] = vals_vec[d3]; + output_vec[d3] = vals_vec[d3]; else - output_vec[d3] = ZERO; -} + output_vec[d3] = ZERO; + } +}; // [B S C*H] - > C * [B A S N] template <> -void launch_pad_add_transform_0213(float* output, - const float* vals, - int batch_size, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size, - dpct::queue_ptr stream) -{ -} +void launch_pad_add_transform_0213( + float* output, + const float* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + dpct::queue_ptr stream) {} template -void launch_pad_add_transform_0213(T* output, - const T* vals, - int batch_size, - int hidden_dim, - int seq_length, - int padded_seq_len, - int heads, - int padded_head_size, - dpct::queue_ptr stream) -{ - hidden_dim >>= 3; - sycl::range<3> block_dim(2, heads, (padded_head_size >> 3)); - sycl::range<3> grid_dim(1, padded_seq_len / 2, batch_size); - /* - DPCT1049:3: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp64, sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), - [=](sycl::nd_item<3> item_ct1) { - pad_add_transform_0213(output, - vals, - hidden_dim, - seq_length, - padded_seq_len, - heads, - padded_head_size >> 3); - }); - } +void launch_pad_add_transform_0213( + T* output, + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + dpct::queue_ptr stream) { + hidden_dim >>= 3; + sycl::range<3> block_dim(2, heads, (padded_head_size >> 3)); + sycl::range<3> grid_dim(1, padded_seq_len / 2, batch_size); + { + dpct::has_capability_or_fail( + stream->get_device(), {sycl::aspect::fp64, sycl::aspect::fp16}); + pad_add_transform_0213 fn( + output, + vals, + hidden_dim, + seq_length, + padded_seq_len, + heads, + padded_head_size >> 3); + stream->parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), fn); + } } #define INSTANTIATE_LAUNCH_PAD_ADD_TRANSFORM_0213_SIMPLE(T) \ - template void launch_pad_add_transform_0213( \ - T*, const T*, int, int, int, int, int, int, dpct::queue_ptr); + template void launch_pad_add_transform_0213( \ + T*, const T*, int, int, int, int, int, int, dpct::queue_ptr); INSTANTIATE_LAUNCH_PAD_ADD_TRANSFORM_0213_SIMPLE(sycl::half); #ifdef BF16_AVAILABLE @@ -431,223 +549,305 @@ INSTANTIATE_LAUNCH_PAD_ADD_TRANSFORM_0213_SIMPLE(sycl::ext::oneapi::bfloat16); #endif // Bias add -template -void bias_add_transform_0213(T* output, - const T* vals, - const T* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext); - -template <> -void bias_add_transform_0213(float* output, - const float* vals, - const float* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext) -{ - auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = d2_stride * seq_length; - - int d0 = item_ct1.get_group(2); // Batch - int d1 = item_ct1.get_group(1); // Sequence ID (0-127) - int cnt = item_ct1.get_group(0) 
/ head_ext; // Hidden count - int d2 = item_ct1.get_local_id(1) + - (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = item_ct1.get_local_id(2); // Values (groups of 4) - - const sycl::float4* vals_vec = reinterpret_cast(vals); - const sycl::float4* bias_vec = reinterpret_cast(bias); - sycl::float4* output_vec = reinterpret_cast(output); - - sycl::float4 inputs = - vals_vec[d0 * d0_stride * (item_ct1.get_group_range(0) / head_ext) + cnt * d1_stride + - d1 * d1_stride * (item_ct1.get_group_range(0) / head_ext) + d2 * d2_stride + d3]; - sycl::float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; - - sycl::float4 outputs; - outputs.x() = inputs.x() + biases.x(); - outputs.y() = inputs.y() + biases.y(); - outputs.z() = inputs.z() + biases.z(); - outputs.w() = inputs.w() + biases.w(); - - output_vec[cnt * d0_out_stride * item_ct1.get_group_range(2) + d0 * d0_out_stride + - d1 * d1_out_stride + d2 * d2_out_stride + d3] = outputs; -} - -template -void bias_add_transform_0213(T* output, - const T* vals, - const T* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext) -{ - auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - using T2 = typename std::conditional::value, - sycl::half2, - sycl::marray>::type; - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d2_out_stride = d2_stride * seq_length; - - int d0 = item_ct1.get_group(2); // Batch - int d1 = item_ct1.get_group(1); // Sequence ID (0-127) - int cnt = item_ct1.get_group(0) / head_ext; // Hidden count - int d2 = item_ct1.get_local_id(1) + - (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = item_ct1.get_local_id(2); // Values (groups of 4) - - sycl::float4 vals_arr; - sycl::float4 bias_arr; - sycl::float4 output_arr; - T2* vals_half = reinterpret_cast(&vals_arr); - T2* bias_half = reinterpret_cast(&bias_arr); - T2* output_half = reinterpret_cast(&output_arr); - - const sycl::float4* vals_vec = reinterpret_cast(vals); - const sycl::float4* bias_vec = reinterpret_cast(bias); - sycl::float4* output_vec = reinterpret_cast(output); - - vals_vec += (d0 * d0_stride * (item_ct1.get_group_range(0) / head_ext)); - vals_vec += (d1 * d1_stride * (item_ct1.get_group_range(0) / head_ext)); - vals_vec += (cnt * d1_stride); - vals_vec += (d2 * d2_stride); - - bias_vec += (cnt * d1_stride); - bias_vec += (d2 * d2_stride); - - output_vec += (cnt * d0_stride * item_ct1.get_group_range(2)); - output_vec += (d1 * d2_stride); - output_vec += (d0 * d0_stride); - output_vec += (d2 * d2_out_stride); - - bias_arr = bias_vec[d3]; - vals_arr = vals_vec[d3]; - - output_half[0] = vals_half[0] + bias_half[0]; - output_half[1] = vals_half[1] + bias_half[1]; - output_half[2] = vals_half[2] + bias_half[2]; - output_half[3] = vals_half[3] + bias_half[3]; - output_vec[d3] = output_arr; -} +/* template */ +/* class bias_add_transform_0213 { */ +/* private: */ +/* T* output; */ +/* const T* vals; */ +/* const T* bias; */ +/* int hidden_dim; */ +/* int seq_length; */ +/* int heads; */ +/* int head_ext; */ +/* }; */ + +/* template <> */ +/* void bias_add_transform_0213( */ +/* float* output, */ +/* const float* vals, */ +/* const float* bias, */ +/* int hidden_dim, */ +/* int seq_length, */ +/* int heads, */ +/* int head_ext) { */ +/* auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); */ +/* int d0_stride = hidden_dim * seq_length; */ +/* int d1_stride = hidden_dim; */ +/* int d2_stride = 
hidden_dim / heads; */ + +/* int d0_out_stride = d0_stride; */ +/* int d1_out_stride = d2_stride; */ +/* int d2_out_stride = d2_stride * seq_length; */ + +/* int d0 = item_ct1.get_group(2); // Batch */ +/* int d1 = item_ct1.get_group(1); // Sequence ID (0-127) */ +/* int cnt = item_ct1.get_group(0) / head_ext; // Hidden count */ +/* int d2 = item_ct1.get_local_id(1) + */ +/* (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) + */ +/* int d3 = item_ct1.get_local_id(2); // Values (groups of 4) */ + +/* const sycl::float4* vals_vec = reinterpret_cast(vals); + */ +/* const sycl::float4* bias_vec = reinterpret_cast(bias); + */ +/* sycl::float4* output_vec = reinterpret_cast(output); */ + +/* sycl::float4 inputs = vals_vec */ +/* [d0 * d0_stride * (item_ct1.get_group_range(0) / head_ext) + */ +/* cnt * d1_stride + */ +/* d1 * d1_stride * (item_ct1.get_group_range(0) / head_ext) + */ +/* d2 * d2_stride + d3]; */ +/* sycl::float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; */ + +/* sycl::float4 outputs; */ +/* outputs.x() = inputs.x() + biases.x(); */ +/* outputs.y() = inputs.y() + biases.y(); */ +/* outputs.z() = inputs.z() + biases.z(); */ +/* outputs.w() = inputs.w() + biases.w(); */ + +/* output_vec */ +/* [cnt * d0_out_stride * item_ct1.get_group_range(2) + d0 * d0_out_stride + * + */ +/* d1 * d1_out_stride + d2 * d2_out_stride + d3] = outputs; */ +/* } */ + +/* template */ +/* void bias_add_transform_0213( */ +/* T* output, */ +/* const T* vals, */ +/* const T* bias, */ +/* int hidden_dim, */ +/* int seq_length, */ +/* int heads, */ +/* int head_ext) { */ +/* auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); */ +/* using T2 = typename std::conditional< */ +/* std::is_same::value, */ +/* sycl::half2, */ +/* sycl::marray>::type; */ +/* int d0_stride = hidden_dim * seq_length; */ +/* int d1_stride = hidden_dim; */ +/* int d2_stride = hidden_dim / heads; */ + +/* int d2_out_stride = d2_stride * seq_length; */ + +/* int d0 = item_ct1.get_group(2); // Batch */ +/* int d1 = item_ct1.get_group(1); // Sequence ID (0-127) */ +/* int cnt = item_ct1.get_group(0) / head_ext; // Hidden count */ +/* int d2 = item_ct1.get_local_id(1) + */ +/* (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) + */ +/* int d3 = item_ct1.get_local_id(2); // Values (groups of 4) */ + +/* sycl::float4 vals_arr; */ +/* sycl::float4 bias_arr; */ +/* sycl::float4 output_arr; */ +/* T2* vals_half = reinterpret_cast(&vals_arr); */ +/* T2* bias_half = reinterpret_cast(&bias_arr); */ +/* T2* output_half = reinterpret_cast(&output_arr); */ + +/* const sycl::float4* vals_vec = reinterpret_cast(vals); + */ +/* const sycl::float4* bias_vec = reinterpret_cast(bias); + */ +/* sycl::float4* output_vec = reinterpret_cast(output); */ + +/* vals_vec += (d0 * d0_stride * (item_ct1.get_group_range(0) / head_ext)); */ +/* vals_vec += (d1 * d1_stride * (item_ct1.get_group_range(0) / head_ext)); */ +/* vals_vec += (cnt * d1_stride); */ +/* vals_vec += (d2 * d2_stride); */ + +/* bias_vec += (cnt * d1_stride); */ +/* bias_vec += (d2 * d2_stride); */ + +/* output_vec += (cnt * d0_stride * item_ct1.get_group_range(2)); */ +/* output_vec += (d1 * d2_stride); */ +/* output_vec += (d0 * d0_stride); */ +/* output_vec += (d2 * d2_out_stride); */ + +/* bias_arr = bias_vec[d3]; */ +/* vals_arr = vals_vec[d3]; */ + +/* output_half[0] = vals_half[0] + bias_half[0]; */ +/* output_half[1] = vals_half[1] + bias_half[1]; */ +/* output_half[2] = vals_half[2] + bias_half[2]; */ +/* 
output_half[3] = vals_half[3] + bias_half[3]; */ +/* output_vec[d3] = output_arr; */ +/* } */ + +/*template */ +/*void bias_add_transform_0213_v2( */ +/* T* output, */ +/* const T* vals, */ +/* const T* bias, */ +/* int hidden_dim, */ +/* int seq_length, */ +/* int heads) { */ +/* auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); */ +/* using T2 = typename std::conditional< */ +/* std::is_same::value, */ +/* sycl::half2, */ +/* sycl::marray>::type; */ +/* auto& in_data = */ +/* *sycl::ext::oneapi::group_local_memory_for_overwrite( + */ +/* sycl::ext::oneapi::experimental::this_group<3>()); */ + +/* int d0_stride = hidden_dim * seq_length; */ +/* int d1_stride = hidden_dim; */ +/* int d2_stride = hidden_dim / heads; */ +/* int iteration_stride = */ +/* d1_stride * item_ct1.get_local_range(0); // Hidden * 3 / 8 */ +/* int batch_stride = */ +/* d0_stride * item_ct1.get_local_range(0); // Hidden * S * 3 / 8 */ + +/* int d0_out_stride = d0_stride; */ +/* int d1_out_stride = d2_stride; */ +/* int d2_out_stride = d2_stride * seq_length; */ + +/* int d0 = item_ct1.get_group(2); // Batch */ +/* int d1 = item_ct1.get_group(1); // Sequence ID (0-127) */ +/* int cnt = item_ct1.get_local_id(0); // blockIdx.z; // Hidden count */ +/* int d2 = item_ct1.get_local_id(1); // Head (0-11) */ +/* int d3 = item_ct1.get_local_id(2); // Values (groups of 4) */ + +/* sycl::float4 vals_arr[1]; */ +/* sycl::float4 bias_arr[1]; */ +/* sycl::float4 output_arr[1]; */ +/* T2* vals_half = reinterpret_cast(vals_arr); */ +/* T2* bias_half = reinterpret_cast(bias_arr); */ +/* T2* output_half = reinterpret_cast(output_arr); */ + +/* const sycl::float4* vals_vec = reinterpret_cast(vals); + */ +/* const sycl::float4* bias_vec = reinterpret_cast(bias); + */ +/* sycl::float4* output_vec = reinterpret_cast(output); */ + +/* int iter_index = cnt * d1_stride + d2 * d2_stride + d3; */ +/* int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); */ +/* bias_arr[0] = bias_vec[iter_index]; */ + +/*#pragma unroll */ +/* for (int iter = 0; iter < 2; iter++) { */ +/* int iter_id = iter * iteration_stride + iter_index; */ +/* vals_arr[0] = vals_vec[input_offset + iter_id]; */ + +/* output_half[0] = vals_half[0] + bias_half[0]; */ +/* output_half[1] = vals_half[1] + bias_half[1]; */ +/* output_half[2] = vals_half[2] + bias_half[2]; */ +/* output_half[3] = vals_half[3] + bias_half[3]; */ + +/* in_data[iter_id] = output_arr[0]; */ +/* } */ +/* /1* */ +/* DPCT1065:7: Consider replacing sycl::nd_item::barrier() with */ +/* sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better */ +/* performance if there is no access to global memory. 
*/ +/* *1/ */ +/* item_ct1.barrier(); */ + +/* iteration_stride = */ +/* item_ct1.get_local_range(0) * (item_ct1.get_local_range(1) >> 1); */ +/* int matrix_stride = (d0_out_stride * item_ct1.get_group_range(2)); */ +/* int head_count = (d2 >> 1) + cnt * (item_ct1.get_local_range(1) >> 1); */ + +/* int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + */ +/* (d2 % 2) * d2_stride; */ + +/*#pragma unroll */ +/* for (int iter = 0; iter < 2; iter++) { */ +/* int iter_row = (iter * iteration_stride) + head_count; */ +/* int iter_offset = (iter_row % item_ct1.get_local_range(1)) * d2_out_stride + * + */ +/* (iter_row / item_ct1.get_local_range(1)) * matrix_stride; */ +/* output_vec[out_index + iter_offset] = in_data */ +/* [iter_row * d2_stride + d3 + */ +/* (d2 % 2) * (d1_stride * item_ct1.get_local_range(0))]; */ +/* } */ +/*} */ template -/* -DPCT1110:4: The total declared local variable size in device function bias_add_transform_0213_v2 -exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find -the total register size available and adjust the code, or use smaller sub-group size to avoid high -register pressure. -*/ -void bias_add_transform_0213_v2(T* output, - const T* vals, - const T* bias, - int hidden_dim, - int seq_length, - int heads) -{ +class transform4d_0213 { + private: + T* out; + const T* in; + int heads; + int seq_length; + int hidden_dim; + int head_ext; + + public: + transform4d_0213( + T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) + : out(out), + in(in), + heads(heads), + seq_length(seq_length), + hidden_dim(hidden_dim), + head_ext(head_ext) {} + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - using T2 = typename std::conditional::value, - sycl::half2, - sycl::marray>::type; - auto& in_data = *sycl::ext::oneapi::group_local_memory_for_overwrite( - sycl::ext::oneapi::experimental::this_group<3>()); - - int d0_stride = hidden_dim * seq_length; + int d0_stride = hidden_dim * (seq_length / head_ext); int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; - int iteration_stride = d1_stride * item_ct1.get_local_range(0); // Hidden * 3 / 8 - int batch_stride = d0_stride * item_ct1.get_local_range(0); // Hidden * S * 3 / 8 - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = d2_stride * seq_length; - - int d0 = item_ct1.get_group(2); // Batch - int d1 = item_ct1.get_group(1); // Sequence ID (0-127) - int cnt = item_ct1.get_local_id(0); // blockIdx.z; // Hidden count - int d2 = item_ct1.get_local_id(1); // Head (0-11) - int d3 = item_ct1.get_local_id(2); // Values (groups of 4) - - sycl::float4 vals_arr[1]; - sycl::float4 bias_arr[1]; - sycl::float4 output_arr[1]; - T2* vals_half = reinterpret_cast(vals_arr); - T2* bias_half = reinterpret_cast(bias_arr); - T2* output_half = reinterpret_cast(output_arr); - const sycl::float4* vals_vec = reinterpret_cast(vals); - const sycl::float4* bias_vec = reinterpret_cast(bias); - sycl::float4* output_vec = reinterpret_cast(output); - - int iter_index = cnt * d1_stride + d2 * d2_stride + d3; - int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); - bias_arr[0] = bias_vec[iter_index]; - -#pragma unroll - for (int iter = 0; iter < 2; iter++) { - int iter_id = iter * iteration_stride + iter_index; - vals_arr[0] = vals_vec[input_offset + iter_id]; - - output_half[0] = vals_half[0] + bias_half[0]; - output_half[1] = vals_half[1] + 
bias_half[1]; - output_half[2] = vals_half[2] + bias_half[2]; - output_half[3] = vals_half[3] + bias_half[3]; - - in_data[iter_id] = output_arr[0]; - } - /* - DPCT1065:7: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there - is no access to global memory. - */ - item_ct1.barrier(); + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_local_id(1) + + (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head + int d2 = item_ct1.get_group(0) / head_ext; // Sequence + int cnt = item_ct1.get_group(1); // Hidden count + int d3 = item_ct1.get_local_id(2); // Values (groups of 8) - iteration_stride = item_ct1.get_local_range(0) * (item_ct1.get_local_range(1) >> 1); - int matrix_stride = (d0_out_stride * item_ct1.get_group_range(2)); - int head_count = (d2 >> 1) + cnt * (item_ct1.get_local_range(1) >> 1); + const sycl::float4* in_vec = reinterpret_cast(in); + sycl::float4* out_vec = reinterpret_cast(out); - int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride; + in_vec += (cnt * d0_stride * item_ct1.get_group_range(2)); + in_vec += (d0 * d0_stride); + in_vec += (d2 * d2_stride); + in_vec += (d1 * d2_stride * seq_length); -#pragma unroll - for (int iter = 0; iter < 2; iter++) { - int iter_row = (iter * iteration_stride) + head_count; - int iter_offset = (iter_row % item_ct1.get_local_range(1)) * d2_out_stride + - (iter_row / item_ct1.get_local_range(1)) * matrix_stride; - output_vec[out_index + iter_offset] = - in_data[iter_row * d2_stride + d3 + - (d2 % 2) * (d1_stride * item_ct1.get_local_range(0))]; - } -} + out_vec += (cnt * d1_stride); + out_vec += (d1 * d2_stride); + out_vec += (d0 * d0_stride * item_ct1.get_group_range(1)); + out_vec += (d2 * d1_stride * item_ct1.get_group_range(1)); -template -void transform4d_0213(T* out, - const T* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext); + out_vec[d3] = in_vec[d3]; + } +}; template <> -void transform4d_0213(float* out, - const float* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext) -{ +class transform4d_0213 { + private: + float* out; + const float* in; + int heads; + int seq_length; + int hidden_dim; + int head_ext; + + public: + transform4d_0213( + float* out, + const float* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) + : out(out), + in(in), + heads(heads), + seq_length(seq_length), + hidden_dim(hidden_dim), + head_ext(head_ext) {} + + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); int d0_stride = hidden_dim * seq_length; int d1_stride = d0_stride / heads; @@ -657,171 +857,168 @@ void transform4d_0213(float* out, int d1_out_stride = d2_stride; int d2_out_stride = hidden_dim; - int d0 = item_ct1.get_group(2); // Batch - int d1 = item_ct1.get_group(1) / ((seq_length - 1) / item_ct1.get_local_range(1) + 1); // Head - int d2 = (item_ct1.get_local_id(1) + item_ct1.get_local_range(1) * item_ct1.get_group(1)) % - seq_length; + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1) / + ((seq_length - 1) / item_ct1.get_local_range(1) + 1); // Head + int d2 = (item_ct1.get_local_id(1) + + item_ct1.get_local_range(1) * item_ct1.get_group(1)) % + seq_length; int cnt = item_ct1.get_group(0); - int d3 = item_ct1.get_local_id(2); // Values (groups of 8) + int d3 = item_ct1.get_local_id(2); // Values (groups of 8) if (d2 < seq_length) { - const sycl::float4* in_vec 
= reinterpret_cast(in); - sycl::float4* out_vec = reinterpret_cast(out); - - sycl::float4 vals_vec = in_vec[cnt * d0_stride * item_ct1.get_group_range(2) + - d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; - out_vec[d0 * d0_out_stride * item_ct1.get_group_range(0) + cnt * d2_out_stride + - d1 * d1_out_stride + d2 * d2_out_stride * item_ct1.get_group_range(0) + d3] = - vals_vec; + const sycl::float4* in_vec = reinterpret_cast(in); + sycl::float4* out_vec = reinterpret_cast(out); + + sycl::float4 vals_vec = in_vec + [cnt * d0_stride * item_ct1.get_group_range(2) + d0 * d0_stride + + d1 * d1_stride + d2 * d2_stride + d3]; + out_vec + [d0 * d0_out_stride * item_ct1.get_group_range(0) + + cnt * d2_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride * item_ct1.get_group_range(0) + d3] = vals_vec; } -} - -template -void transform4d_0213(T* out, - const T* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext) -{ - auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); - int d0_stride = hidden_dim * (seq_length / head_ext); - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0 = item_ct1.get_group(2); // Batch - int d1 = - item_ct1.get_local_id(1) + (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head - int d2 = item_ct1.get_group(0) / head_ext; // Sequence - int cnt = item_ct1.get_group(1); // Hidden count - int d3 = item_ct1.get_local_id(2); // Values (groups of 8) - - const sycl::float4* in_vec = reinterpret_cast(in); - sycl::float4* out_vec = reinterpret_cast(out); - - in_vec += (cnt * d0_stride * item_ct1.get_group_range(2)); - in_vec += (d0 * d0_stride); - in_vec += (d2 * d2_stride); - in_vec += (d1 * d2_stride * seq_length); - - out_vec += (cnt * d1_stride); - out_vec += (d1 * d2_stride); - out_vec += (d0 * d0_stride * item_ct1.get_group_range(1)); - out_vec += (d2 * d1_stride * item_ct1.get_group_range(1)); - - out_vec[d3] = in_vec[d3]; -} + } +}; template -void transform4d_0213_v2(T* out, const T* in, int heads, int seq_length, int hidden_dim) -{ +class transform4d_0213_v2 { + private: + T* out; + const T* in; + int heads; + int seq_length; + int hidden_dim; + + public: + transform4d_0213_v2( + T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim) + : out(out), + in(in), + heads(heads), + seq_length(seq_length), + hidden_dim(hidden_dim) {} + + void operator()(sycl::nd_item<3>) const { auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); -auto& in_data = *sycl::ext::oneapi::group_local_memory_for_overwrite( - sycl::ext::oneapi::experimental::this_group<3>()); + auto& in_data = *sycl::ext::oneapi::group_local_memory_for_overwrite< + sycl::float4[3072]>(sycl::ext::oneapi::experimental::this_group<3>()); int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; - int d0 = item_ct1.get_group(2); // Batch - int d1 = item_ct1.get_local_id(1); // Head - int d2 = item_ct1.get_group(1); // Sequence - int cnt = item_ct1.get_local_id(0); // Hidden count - int d3 = item_ct1.get_local_id(2); // Values (groups of 8) + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_local_id(1); // Head + int d2 = item_ct1.get_group(1); // Sequence + int cnt = item_ct1.get_local_id(0); // Hidden count + int d3 = item_ct1.get_local_id(2); // Values (groups of 8) const sycl::float4* in_vec = reinterpret_cast(in); sycl::float4* out_vec = reinterpret_cast(out); - int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride; + int 
input_offset = + d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride; int head_count = (d1 >> 1) + cnt * (item_ct1.get_local_range(1) >> 1); - int iteration_stride = item_ct1.get_local_range(0) * (item_ct1.get_local_range(1) >> 1); + int iteration_stride = + item_ct1.get_local_range(0) * (item_ct1.get_local_range(1) >> 1); int matrix_stride = (d0_stride * item_ct1.get_group_range(2)); #pragma unroll for (int iter = 0; iter < 2; iter++) { - int iter_row = iter * iteration_stride + head_count; - int iter_offset = (iter_row % item_ct1.get_local_range(1)) * d2_stride; - - in_data[d3 + iter_offset + - (iter_row / item_ct1.get_local_range(1) + (d1 % 2) * item_ct1.get_local_range(0)) * - d1_stride] = in_vec[input_offset + iter_offset * seq_length + - (iter_row / item_ct1.get_local_range(1)) * matrix_stride]; + int iter_row = iter * iteration_stride + head_count; + int iter_offset = (iter_row % item_ct1.get_local_range(1)) * d2_stride; + + in_data + [d3 + iter_offset + + (iter_row / item_ct1.get_local_range(1) + + (d1 % 2) * item_ct1.get_local_range(0)) * + d1_stride] = in_vec + [input_offset + iter_offset * seq_length + + (iter_row / item_ct1.get_local_range(1)) * matrix_stride]; } /* DPCT1065:8: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there - is no access to global memory. + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. */ item_ct1.barrier(); iteration_stride = d1_stride * item_ct1.get_local_range(0); int iter_index = cnt * d1_stride + d1 * d2_stride + d3; - int output_offset = d0 * d0_stride * item_ct1.get_local_range(0) + d2 * (iteration_stride << 1); + int output_offset = d0 * d0_stride * item_ct1.get_local_range(0) + + d2 * (iteration_stride << 1); #pragma unroll for (int iter = 0; iter < 2; iter++) { - int iter_id = iter * iteration_stride + iter_index; - out_vec[output_offset + iter_id] = in_data[iter_id]; + int iter_id = iter * iteration_stride + iter_index; + out_vec[output_offset + iter_id] = in_data[iter_id]; } -} + } +}; // 3 * [B A S N] - > [B S C*H] template <> -void launch_transform4d_0213(float* out, - const float* in, - int batch_size, - int heads, - int seq_length, - int hidden_dim, - dpct::queue_ptr stream, - int trans_count) -{ - hidden_dim >>= 2; - sycl::range<3> grid_dims(trans_count, heads * ((seq_length - 1) / 8 + 1), batch_size); - sycl::range<3> block_dims(1, 8, hidden_dim / heads); - /* - DPCT1049:5: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - transform4d_0213(out, in, heads, seq_length, hidden_dim, 1); - }); - } +void launch_transform4d_0213( + float* out, + const float* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + dpct::queue_ptr stream, + int trans_count) { + hidden_dim >>= 2; + sycl::range<3> grid_dims( + trans_count, heads * ((seq_length - 1) / 8 + 1), batch_size); + sycl::range<3> block_dims(1, 8, hidden_dim / heads); + /* + DPCT1049:5: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + transform4d_0213 fn(out, in, heads, seq_length, hidden_dim, 1); + stream->parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), fn); + } } template -void launch_transform4d_0213(T* out, - const T* in, - int batch_size, - int heads, - int seq_length, - int hidden_dim, - dpct::queue_ptr stream, - int trans_count) -{ - hidden_dim >>= 3; - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - sycl::range<3> grid_dims((seq_length * head_ext), trans_count, batch_size); - sycl::range<3> block_dims(1, (heads / head_ext), hidden_dim / heads); - /* - DPCT1049:6: The work-group size passed to the SYCL kernel may exceed the limit. To get the - device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - transform4d_0213(out, in, heads, seq_length, hidden_dim, head_ext); - }); - } +void launch_transform4d_0213( + T* out, + const T* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + dpct::queue_ptr stream, + int trans_count) { + hidden_dim >>= 3; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + sycl::range<3> grid_dims((seq_length * head_ext), trans_count, batch_size); + sycl::range<3> block_dims(1, (heads / head_ext), hidden_dim / heads); + /* + DPCT1049:6: The work-group size passed to the SYCL kernel may exceed the + limit. To get the device limit, query info::device::max_work_group_size. + Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + transform4d_0213 fn(out, in, heads, seq_length, hidden_dim, head_ext); + stream->parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), fn); + } } -#define INSTANTIATE_2B_LAUNCH_TRANSFORM4D(T) \ - template void launch_transform4d_0213( \ - T*, const T*, int, int, int, int, dpct::queue_ptr, int); +#define INSTANTIATE_2B_LAUNCH_TRANSFORM4D(T) \ + template void launch_transform4d_0213( \ + T*, const T*, int, int, int, int, dpct::queue_ptr, int); INSTANTIATE_2B_LAUNCH_TRANSFORM4D(sycl::half) #ifdef BF16_AVAILABLE diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/includes/inference_context.h b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/includes/inference_context.h index 74cb853..c4e1b94 100644 --- a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/includes/inference_context.h +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/includes/inference_context.h @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *******************************************************************************/ // Copyright (c) Microsoft Corporation. // SPDX-License-Identifier: Apache-2.0 @@ -5,336 +20,361 @@ #pragma once -// // #include +#include +#include #include #include -#include #include #include #include -#include #include -#include #include +#include #include namespace at { - namespace cuda { - inline dpct::queue_ptr getCurrentCUDAStream() { - auto device_type = c10::DeviceType::XPU; - c10::impl::VirtualGuardImpl impl(device_type); - c10::Stream c10_stream = impl.getStream(c10::Device(device_type)); - auto& queue = xpu::get_queue_from_stream(c10_stream); - return &queue; - } +namespace sycl { +inline dpct::queue_ptr getCurrentSYCLStream() { + auto device_type = c10::DeviceType::XPU; + c10::impl::VirtualGuardImpl impl(device_type); + c10::Stream c10_stream = impl.getStream(c10::Device(device_type)); + auto& queue = xpu::get_queue_from_stream(c10_stream); + return &queue; +} - inline dpct::queue_ptr getStreamFromPool(bool) { - // not implemented - return nullptr; - } - } +inline dpct::queue_ptr getStreamFromPool(bool) { + // not implemented + return nullptr; } +} // namespace sycl +} // namespace at + #define MEGABYTE (1024 * 1024) #define GIGABYTE (1024 * 1024 * 1024) // TODO: refactor out #define WARP_SIZE 32 -#define CUDA_CHECK(callstr) \ - { \ - cudaError_t error_code = callstr; \ - if (error_code != cudaSuccess) { \ - std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ - assert(0); \ - } \ - } +#define SYCL_CHECK(callstr) \ + { \ + syclError_t error_code = callstr; \ + if (error_code != syclSuccess) { \ + std::cerr << "SYCL error " << error_code << " at " << __FILE__ << ":" \ + << __LINE__; \ + assert(0); \ + } \ + } -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) +#define SYCL_1D_KERNEL_LOOP(i, n) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) -#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \ - for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y) +#define SYCL_2D_KERNEL_LOOP(i, n, j, m) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) \ + for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \ + j += blockDim.y * gridDim.y) -#define DS_CUDA_NUM_THREADS 512 +#define DS_SYCL_NUM_THREADS 512 #define DS_MAXIMUM_NUM_BLOCKS 262144 -inline int DS_GET_BLOCKS(const int N) -{ - return std::max( - std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS), - // Use at least 1 block, since CUDA does not allow empty block - 1); +inline int DS_GET_BLOCKS(const int N) { + return std::max( + std::min( + (N + DS_SYCL_NUM_THREADS - 1) / DS_SYCL_NUM_THREADS, + DS_MAXIMUM_NUM_BLOCKS), + // Use at least 1 block, since SYCL does not allow empty block + 1); } class InferenceContext { -public: - InferenceContext() - try : _workspace(nullptr), _seed(42), _curr_offset(0), - _stream(&dpct::get_in_order_queue()), _free_memory_size(0), _num_tokens(1), - _attention_unfused_workspace_offset(0), _workSpaceSize(0) { - _workSpaceSize = 0; - _workspace = 0; - - int stat = DPCT_CHECK_ERROR(_cublasHandle = &dpct::get_in_order_queue()); - if (stat != 0) { - // It would be nice to use cublasGetStatusName and - // cublasGetStatusString, but they were only added 
in CUDA 11.4.2. - auto message = std::string("Failed to create cublas handle: cublasStatus_t was ") + - std::to_string(stat); - std::cerr << message << std::endl; - throw std::runtime_error(message); - } -#ifndef __HIP_PLATFORM_AMD__ - /* - DPCT1026:0: The call to cublasSetMathMode was removed because this call is redundant in - SYCL. - */ -#endif - _comp1_event = new sycl::event(); - _comp2_event = new sycl::event(); - _comp_event = new sycl::event(); - _comm_event = new sycl::event(); - } - catch (sycl::exception const& exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ - << std::endl; - std::exit(1); + public: + InferenceContext() try + : _workspace(nullptr), + _seed(42), + _curr_offset(0), + _stream(&dpct::get_in_order_queue()), + _free_memory_size(0), + _num_tokens(1), + _attention_unfused_workspace_offset(0), + _workSpaceSize(0) { + _workSpaceSize = 0; + _workspace = 0; + + int stat = DPCT_CHECK_ERROR(_mklHandle = &dpct::get_in_order_queue()); + if (stat != 0) { + // It would be nice to use mklGetStatusName and + // mklGetStatusString, but they were only added in SYCL 11.4.2. + auto message = + std::string("Failed to create mkl handle: mklStatus_t was ") + + std::to_string(stat); + std::cerr << message << std::endl; + throw std::runtime_error(message); } + _comp1_event = new sycl::event(); + _comp2_event = new sycl::event(); + _comp_event = new sycl::event(); + _comm_event = new sycl::event(); + } catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); + } - virtual ~InferenceContext() - { - _cublasHandle = nullptr; - sycl::free(_workspace, dpct::get_in_order_queue()); - dpct::destroy_event(_comp1_event); - dpct::destroy_event(_comp2_event); - dpct::destroy_event(_comp_event); - dpct::destroy_event(_comm_event); - } + virtual ~InferenceContext() { + _mklHandle = nullptr; + sycl::free(_workspace, dpct::get_in_order_queue()); + dpct::destroy_event(_comp1_event); + dpct::destroy_event(_comp2_event); + dpct::destroy_event(_comp_event); + dpct::destroy_event(_comm_event); + } - static InferenceContext& Instance() - { - static InferenceContext _ctx; - return _ctx; - } + static InferenceContext& Instance() { + static InferenceContext _ctx; + return _ctx; + } - void GenWorkSpace(const unsigned& num_layers, - const unsigned& num_heads, - const size_t& batch_size, - const size_t& prompt_len, - const size_t& hidden_dim, - const unsigned& mp_size, - const bool& external_cache, - const size_t& elem_size, - const unsigned& rank, - unsigned max_out_tokens, - unsigned min_out_tokens) - { + void GenWorkSpace( + const unsigned& num_layers, + const unsigned& num_heads, + const size_t& batch_size, + const size_t& prompt_len, + const size_t& hidden_dim, + const unsigned& mp_size, + const bool& external_cache, + const size_t& elem_size, + const unsigned& rank, + unsigned max_out_tokens, + unsigned min_out_tokens) { dpct::device_ext& dev_ct1 = dpct::get_current_device(); sycl::queue& q_ct1 = dev_ct1.in_order_queue(); - size_t total_size; - /* - DPCT1106:1: 'cudaMemGetInfo' was migrated with the Intel extensions for device information - which may not be supported by all compilers or runtimes. You may need to adjust the code. 
- */ - _free_memory_size = 21474836480; -if (!_free_memory_size) { - dpct::get_current_device().get_memory_info(_free_memory_size, total_size); - } - - // Flash attention requires padded heads and we'll conservatively allocate - // for that here. Flash attention is only enabled for head size <= 128 right now - const int head_size = hidden_dim / num_heads; - const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128); - const int effective_head_size = (head_size > 128) ? head_size : padded_head_size; - - size_t activation_size = 10 * (num_heads * effective_head_size) * batch_size; - // Other sequence length dimension is added when the final workSpaceSize is calculated - size_t temp_size = batch_size * (num_heads / mp_size) * max_out_tokens; - size_t cache_size = - num_layers * batch_size * ((num_heads * effective_head_size) / mp_size) * 2; - size_t minimal_requirements = - temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE; - if (_free_memory_size < minimal_requirements) { - printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", - minimal_requirements, - _free_memory_size, - total_size); - throw std::runtime_error("Workspace can't be allocated, no enough memory."); - } - - _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) / - (activation_size + temp_size + cache_size); - _max_seq_len = std::min((size_t)max_out_tokens, _max_seq_len); - size_t workSpaceSize = ((external_cache ? (activation_size + temp_size) - : (activation_size + temp_size + cache_size))) * - _max_seq_len * elem_size; - temp_size *= _max_seq_len * elem_size; - - if (_max_seq_len < min_out_tokens) { - printf( - "Allocatable workspace available (%ld tokens) is less than minimum requested " - "workspace (%d tokens)\n", - _max_seq_len, - min_out_tokens); - throw std::runtime_error("Workspace can't be allocated, not enough memory"); - } - - if (!_workspace) { - assert(_workspace == nullptr); - _workspace = (void*)sycl::malloc_device(workSpaceSize, q_ct1); - } else if (_workSpaceSize < workSpaceSize) { - sycl::free(_workspace, q_ct1); - _workspace = (void*)sycl::malloc_device(workSpaceSize, q_ct1); - } - if (rank == 0 && (!_workspace || _workSpaceSize < workSpaceSize)) - printf( - "------------------------------------------------------\n" - "Free memory : %f (GigaBytes) \n" - "Total memory: %f (GigaBytes) \n" - "Requested memory: %f (GigaBytes) \n" - "Setting maximum total tokens (input + output) to %lu \n" - "WorkSpace: %p \n" - "------------------------------------------------------\n", - (float)_free_memory_size / GIGABYTE, - (float)total_size / GIGABYTE, - (float)workSpaceSize / GIGABYTE, - _max_seq_len, - _workspace); - - if (!_workspace) { - printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", - workSpaceSize, - _free_memory_size, - total_size); - throw std::runtime_error("Workspace is null."); - } - _workSpaceSize = workSpaceSize; - _attention_unfused_workspace_offset = workSpaceSize - temp_size; + size_t total_size; + /* + DPCT1106:1: 'syclMemGetInfo' was migrated with the Intel extensions for + device information which may not be supported by all compilers or runtimes. + You may need to adjust the code. + */ + _free_memory_size = 21474836480; + if (!_free_memory_size) { + dpct::get_current_device().get_memory_info(_free_memory_size, total_size); } - inline int GetMaxTokenLength() const { return _max_seq_len; } - dpct::event_ptr GetCompEvent(int id) { return id == 1 ? 
_comp1_event : _comp2_event; } + // Flash attention requires padded heads and we'll conservatively allocate + // for that here. Flash attention is only enabled for head size <= 128 right + // now + const int head_size = hidden_dim / num_heads; + const int padded_head_size = + head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128); + const int effective_head_size = + (head_size > 128) ? head_size : padded_head_size; + + size_t activation_size = + 10 * (num_heads * effective_head_size) * batch_size; + // Other sequence length dimension is added when the final workSpaceSize is + // calculated + size_t temp_size = batch_size * (num_heads / mp_size) * max_out_tokens; + size_t cache_size = num_layers * batch_size * + ((num_heads * effective_head_size) / mp_size) * 2; + size_t minimal_requirements = + temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE; + if (_free_memory_size < minimal_requirements) { + printf( + "Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", + minimal_requirements, + _free_memory_size, + total_size); + throw std::runtime_error( + "Workspace can't be allocated, no enough memory."); + } - size_t get_workspace_size() const { return _workSpaceSize; } - void* GetWorkSpace() { return _workspace; } - void* GetAttentionUnfusedWorkspace() - { - return (char*)_workspace + _attention_unfused_workspace_offset; + _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) / + (activation_size + temp_size + cache_size); + _max_seq_len = std::min((size_t)max_out_tokens, _max_seq_len); + size_t workSpaceSize = + ((external_cache ? (activation_size + temp_size) + : (activation_size + temp_size + cache_size))) * + _max_seq_len * elem_size; + temp_size *= _max_seq_len * elem_size; + + if (_max_seq_len < min_out_tokens) { + printf( + "Allocatable workspace available (%ld tokens) is less than minimum requested " + "workspace (%d tokens)\n", + _max_seq_len, + min_out_tokens); + throw std::runtime_error( + "Workspace can't be allocated, not enough memory"); } - inline unsigned new_token(unsigned layer_id) - { - if (layer_id == 0) _token_length++; - return _token_length; + if (!_workspace) { + assert(_workspace == nullptr); + _workspace = (void*)sycl::malloc_device(workSpaceSize, q_ct1); + } else if (_workSpaceSize < workSpaceSize) { + sycl::free(_workspace, q_ct1); + _workspace = (void*)sycl::malloc_device(workSpaceSize, q_ct1); + } + if (rank == 0 && (!_workspace || _workSpaceSize < workSpaceSize)) + printf( + "------------------------------------------------------\n" + "Free memory : %f (GigaBytes) \n" + "Total memory: %f (GigaBytes) \n" + "Requested memory: %f (GigaBytes) \n" + "Setting maximum total tokens (input + output) to %lu \n" + "WorkSpace: %p \n" + "------------------------------------------------------\n", + (float)_free_memory_size / GIGABYTE, + (float)total_size / GIGABYTE, + (float)workSpaceSize / GIGABYTE, + _max_seq_len, + _workspace); + + if (!_workspace) { + printf( + "Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", + workSpaceSize, + _free_memory_size, + total_size); + throw std::runtime_error("Workspace is null."); } + _workSpaceSize = workSpaceSize; + _attention_unfused_workspace_offset = workSpaceSize - temp_size; + } + inline int GetMaxTokenLength() const { + return _max_seq_len; + } - inline void reset_tokens(unsigned initial_tokens = 1) - { - _num_tokens = initial_tokens; - } //_token_length = 0; } + dpct::event_ptr GetCompEvent(int id) { + return id == 1 ? 
_comp1_event : _comp2_event; + } - inline unsigned current_tokens() const { return _num_tokens; } + size_t get_workspace_size() const { + return _workSpaceSize; + } + void* GetWorkSpace() { + return _workspace; + } + void* GetAttentionUnfusedWorkspace() { + return (char*)_workspace + _attention_unfused_workspace_offset; + } - inline void advance_tokens() { _num_tokens++; } + inline unsigned new_token(unsigned layer_id) { + if (layer_id == 0) + _token_length++; + return _token_length; + } - dpct::queue_ptr GetCommStream(bool async_op = false) - { - if (!_comm_stream) - _comm_stream = async_op ? at::cuda::getStreamFromPool(true) - : at::cuda::getCurrentCUDAStream(); - return _comm_stream; - } - dpct::queue_ptr GetCurrentStream(bool other_stream = false) - { - // get current pytorch stream. - if (other_stream) { - if (!_stream) _stream = at::cuda::getStreamFromPool(true); - return _stream; - } - dpct::queue_ptr stream = at::cuda::getCurrentCUDAStream(); - return stream; - } + inline void reset_tokens(unsigned initial_tokens = 1) { + _num_tokens = initial_tokens; + } //_token_length = 0; } - void release_workspace() - { - sycl::free(_workspace, dpct::get_in_order_queue()); - _workspace = nullptr; - } - bool retake_workspace() - { - if (_workspace != nullptr || _workSpaceSize == 0) return true; - _workspace = (void*)sycl::malloc_device(_workSpaceSize, dpct::get_in_order_queue()); - return _workspace != nullptr; - } - dpct::queue_ptr GetCublasHandle() { return _cublasHandle; } + inline unsigned current_tokens() const { + return _num_tokens; + } - std::pair IncrementOffset(uint64_t offset_inc) - { - uint64_t offset = _curr_offset; - _curr_offset += offset_inc; - return std::pair(_seed, offset); + inline void advance_tokens() { + _num_tokens++; + } + + dpct::queue_ptr GetCommStream(bool async_op = false) { + if (!_comm_stream) + _comm_stream = async_op ? at::sycl::getStreamFromPool(true) + : at::sycl::getCurrentSYCLStream(); + return _comm_stream; + } + dpct::queue_ptr GetCurrentStream(bool other_stream = false) { + // get current pytorch stream. + if (other_stream) { + if (!_stream) + _stream = at::sycl::getStreamFromPool(true); + return _stream; } + dpct::queue_ptr stream = at::sycl::getCurrentSYCLStream(); + return stream; + } + + void release_workspace() { + sycl::free(_workspace, dpct::get_in_order_queue()); + _workspace = nullptr; + } + bool retake_workspace() { + if (_workspace != nullptr || _workSpaceSize == 0) + return true; + _workspace = + (void*)sycl::malloc_device(_workSpaceSize, dpct::get_in_order_queue()); + return _workspace != nullptr; + } + dpct::queue_ptr GetCublasHandle() { + return _mklHandle; + } - void SetSeed(uint64_t new_seed) { _seed = new_seed; } + std::pair IncrementOffset(uint64_t offset_inc) { + uint64_t offset = _curr_offset; + _curr_offset += offset_inc; + return std::pair(_seed, offset); + } - const std::vector>& GetGemmAlgos() const { return _gemm_algos; } + void SetSeed(uint64_t new_seed) { + _seed = new_seed; + } - inline void SynchComp() - { - /* - DPCT1012:2: Detected kernel execution time measurement pattern and generated an initial code - for time measurements in SYCL. You can change the way time is measured depending on your - goals. 
- */ - _comp_event_ct1 = std::chrono::steady_clock::now(); - *_comp_event = _comp_stream->ext_oneapi_submit_barrier(); - _comm_stream->ext_oneapi_submit_barrier({*_comp_event}); - } - inline void SynchComm() - { - /* - DPCT1012:3: Detected kernel execution time measurement pattern and generated an initial code - for time measurements in SYCL. You can change the way time is measured depending on your - goals. - */ - _comm_event_ct1 = std::chrono::steady_clock::now(); - *_comm_event = _comm_stream->ext_oneapi_submit_barrier(); - _comp_stream->ext_oneapi_submit_barrier({*_comm_event}); - } + const std::vector>& GetGemmAlgos() const { + return _gemm_algos; + } + + inline void SynchComp() { + /* + DPCT1012:2: Detected kernel execution time measurement pattern and generated + an initial code for time measurements in SYCL. You can change the way time + is measured depending on your goals. + */ + _comp_event_ct1 = std::chrono::steady_clock::now(); + *_comp_event = _comp_stream->ext_oneapi_submit_barrier(); + _comm_stream->ext_oneapi_submit_barrier({*_comp_event}); + } + inline void SynchComm() { + /* + DPCT1012:3: Detected kernel execution time measurement pattern and generated + an initial code for time measurements in SYCL. You can change the way time + is measured depending on your goals. + */ + _comm_event_ct1 = std::chrono::steady_clock::now(); + *_comm_event = _comm_stream->ext_oneapi_submit_barrier(); + _comp_stream->ext_oneapi_submit_barrier({*_comm_event}); + } -private: - dpct::queue_ptr _cublasHandle; + private: + dpct::queue_ptr _mklHandle; - dpct::event_ptr _comp_event; - std::chrono::time_point _comp_event_ct1; - dpct::event_ptr _comm_event; - std::chrono::time_point _comm_event_ct1; + dpct::event_ptr _comp_event; + std::chrono::time_point _comp_event_ct1; + dpct::event_ptr _comm_event; + std::chrono::time_point _comm_event_ct1; - void* _workspace; - // offset from _workspace for attention unfused memory - size_t _attention_unfused_workspace_offset; - uint64_t _seed; - uint64_t _curr_offset; + void* _workspace; + // offset from _workspace for attention unfused memory + size_t _attention_unfused_workspace_offset; + uint64_t _seed; + uint64_t _curr_offset; - size_t _workSpaceSize; - size_t _free_memory_size; + size_t _workSpaceSize; + size_t _free_memory_size; - size_t _max_seq_len; + size_t _max_seq_len; - dpct::event_ptr _comp1_event; - dpct::event_ptr _comp2_event; + dpct::event_ptr _comp1_event; + dpct::event_ptr _comp2_event; - dpct::queue_ptr _stream; + dpct::queue_ptr _stream; - unsigned _token_length; - unsigned _num_tokens; - std::vector> _gemm_algos; + unsigned _token_length; + unsigned _num_tokens; + std::vector> _gemm_algos; - dpct::queue_ptr _comp_stream; - dpct::queue_ptr _comm_stream; + dpct::queue_ptr _comp_stream; + dpct::queue_ptr _comm_stream; - std::unordered_map _world_sizes; + std::unordered_map _world_sizes; }; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/includes/inference_mkl_wrappers.h b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/includes/inference_mkl_wrappers.h new file mode 100644 index 0000000..e683f88 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/includes/inference_mkl_wrappers.h @@ -0,0 +1,265 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#include + +#ifdef BF16_AVAILABLE +#endif +#include +#include + +int mkl_gemm_ex( + dpct::queue_ptr handle, + oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + int algo, + int b_stride = -1) try { + const int ldb = (b_stride == -1) + ? ((transb == oneapi::mkl::transpose::nontrans) ? k : n) + : b_stride; + int status = DPCT_CHECK_ERROR(dpct::gemm( + *handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + dpct::library_data_t::real_float, + (transa == oneapi::mkl::transpose::nontrans) ? m : k, + (const void*)B, + dpct::library_data_t::real_float, + ldb, + (const void*)beta, + C, + dpct::library_data_t::real_float, + m, + dpct::library_data_t::real_float)); + + if (status != 0) { + fprintf( + stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +template +int mkl_gemm_ex( + dpct::queue_ptr handle, + oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const T* A, + const T* B, + T* C, + int algo, + int b_stride = -1) try { + const int ldb = (b_stride == -1) + ? ((transb == oneapi::mkl::transpose::nontrans) ? k : n) + : b_stride; + constexpr auto mkl_dtype_16 = std::is_same::value + ? dpct::library_data_t::real_half + : dpct::library_data_t::real_bfloat16; + int status = DPCT_CHECK_ERROR(dpct::gemm( + *handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + mkl_dtype_16, + (transa == oneapi::mkl::transpose::nontrans) ? m : k, + (const void*)B, + mkl_dtype_16, + ldb, + (const void*)beta, + (void*)C, + mkl_dtype_16, + m, + dpct::library_data_t::real_float)); + + if (status != 0) { + fprintf( + stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +int mkl_strided_batched_gemm( + dpct::queue_ptr handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + oneapi::mkl::transpose op_A, + oneapi::mkl::transpose op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + int algo) try { + int status = DPCT_CHECK_ERROR(dpct::gemm_batch( + *handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + dpct::library_data_t::real_float, + (op_A == oneapi::mkl::transpose::nontrans) ? 
m : k, + stride_A, + B, + dpct::library_data_t::real_float, + (op_B == oneapi::mkl::transpose::nontrans) ? k : n, + stride_B, + beta, + C, + dpct::library_data_t::real_float, + m, + stride_C, + batch, + dpct::library_data_t::real_float)); + + if (status != 0) { + fprintf( + stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n", + batch, + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +template +int mkl_strided_batched_gemm( + dpct::queue_ptr handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const T* A, + const T* B, + T* C, + oneapi::mkl::transpose op_A, + oneapi::mkl::transpose op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + int algo) try { + constexpr auto mkl_dtype_16 = std::is_same::value + ? dpct::library_data_t::real_half + : dpct::library_data_t::real_bfloat16; + int status = DPCT_CHECK_ERROR(dpct::gemm_batch( + *handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + mkl_dtype_16, + (op_A == oneapi::mkl::transpose::nontrans) ? m : k, + stride_A, + B, + mkl_dtype_16, + (op_B == oneapi::mkl::transpose::nontrans) ? k : n, + stride_B, + beta, + C, + mkl_dtype_16, + m, + stride_C, + batch, + dpct::library_data_t::real_float)); + + if (status != 0) { + fprintf( + stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + + return 0; +} catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/includes/inference_sycl_layers.h b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/includes/inference_sycl_layers.h new file mode 100644 index 0000000..7327b9d --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/inference/includes/inference_sycl_layers.h @@ -0,0 +1,287 @@ +/******************************************************************************* + * Copyright 2016-2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "ds_kernel_utils.h" + +#ifdef BF16_AVAILABLE +#endif +#include +#include +#include +#include + +#define MAX_WARP_NUM 32 +#define WARP_SIZE 32 + +#define MAX_THREADS 1024 +#define SMs 80 + +#define MAX_REGISTERS 256 + +template +void launch_attn_softmax_v2( + T* vals, + T* mask, + T* alibi, + float layer_scale, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + int offset, + int mask_stride, + int mp_size, + dpct::queue_ptr stream); + +// Fused bias add with gelu activation +template +void launch_bias_gelu( + T* input, + const T* bias, + int intermediate_size, + int batch_size, + dpct::queue_ptr stream); + +template +void launch_gated_activation( + T* output, + const T* activation, + const T* bias, + int rows, + int output_stride, + int elems_per_row, + bool use_gelu, + dpct::queue_ptr stream); + +// Fused bias add with relu activation +template +void launch_bias_relu( + T* input, + const T* bias, + int intermediate_size, + int batch_size, + dpct::queue_ptr stream); + +template +void launch_bias_add( + T* input, + const T* bias, + int hidden_size, + int batch_size, + dpct::queue_ptr stream); + +template +void launch_bias_residual( + T* input, + T* output, + T* attn, + T* bias, + T* attn_bias, + int batch, + int hidden_dim, + int mp_size, + bool preln, + dpct::queue_ptr stream); + +template +void launch_fused_ln( + T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + dpct::queue_ptr stream); + +template +void launch_fused_residual_ln( + T* output, + const T* vals, + const T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + dpct::queue_ptr stream); + +template +void launch_fused_residual_ln_store_pre_ln_res( + T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + dpct::queue_ptr stream); + +template +void launch_rms_norm( + T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int rows, + int elems_per_row, + dpct::queue_ptr stream); + +template +void launch_dequantize( + T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + unsigned merge_count, + dpct::queue_ptr stream); + +template +void launch_dequantize( + T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + dpct::queue_ptr stream); +template +void launch_gptj_residual_add( + T* input, + T* output, + T* attn, + T* bias, + T* attn_bias, + int batch, + int head_size, + int mp_size, + dpct::queue_ptr stream); + +template +void launch_apply_rotary_pos_emb( + T* mixed_query, + T* key_layer, + unsigned head_size, + unsigned seq_len, + unsigned rotary_dim, + unsigned offset, + unsigned num_heads, + unsigned batch, + float rope_theta, + dpct::queue_ptr stream, + int max_out_tokens); + +template +void launch_moe_res_matmul( + T* residual, + T* coef, + T* mlp_out, + int seq_len, + int hidden_dim, + dpct::queue_ptr stream); + +// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3] +template +void launch_transform4d_0213( + T* out, + const T* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + 
dpct::queue_ptr stream, + int trans_count); +template +void launch_bias_add_transform_0213( + T* outputs, + T* vals, + T* vals1, + const T* vals2, + const T* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int seq_length1, + int hidden_dim, + int heads, + int num_kv, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + dpct::queue_ptr stream, + int trans_count, + int max_out_tokens, + float rope_theta); +template +void pad_data( + T* padded_output, + T* output, + int bsz, + int head_size, + int padded_head_size, + dpct::queue_ptr stream); + +template +void pad_head_seq( + T* padded_output, + T* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + dpct::queue_ptr stream); + +template +void launch_pad_add_transform_0213( + T* output, + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + dpct::queue_ptr stream); + +template +void launch_vector_add( + T* out, + const T* a, + const T* b, + float gamma, + int num_elems, + dpct::queue_ptr stream); diff --git a/intel_extension_for_deepspeed/op_builder/quantizer.py b/intel_extension_for_deepspeed/op_builder/quantizer.py index 155dc42..8de8573 100644 --- a/intel_extension_for_deepspeed/op_builder/quantizer.py +++ b/intel_extension_for_deepspeed/op_builder/quantizer.py @@ -29,4 +29,4 @@ def sources(self): ] def include_paths(self): - return [sycl_kernel_include('csrc/includes')] + return [sycl_kernel_include('csrc/includes'), sycl_kernel_include('csrc/includes/dpct')] diff --git a/intel_extension_for_deepspeed/op_builder/transformer_inference.py b/intel_extension_for_deepspeed/op_builder/transformer_inference.py index 63c68c9..aa6dd8b 100755 --- a/intel_extension_for_deepspeed/op_builder/transformer_inference.py +++ b/intel_extension_for_deepspeed/op_builder/transformer_inference.py @@ -43,6 +43,7 @@ def include_paths(self): includes = [ sycl_kernel_include('csrc/transformer/inference/includes'), sycl_kernel_include('csrc/includes'), + sycl_kernel_include('csrc/includes/dpct'), ] return includes
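
Note on the SYCL_CHECK macro added in context.h above: it still references syclError_t and syclSuccess, which have no counterpart in SYCL, where errors surface as exceptions rather than status codes. Below is a minimal exception-based sketch of an equivalent check; the macro name SYCL_TRY and the header set are assumptions and are not part of this patch.

// Sketch only: SYCL reports failures via sycl::exception, so the CUDA-style
// status check becomes a try/catch around the call. SYCL_TRY is a
// hypothetical name, not defined anywhere in this patch.
#include <sycl/sycl.hpp>
#include <cassert>
#include <iostream>

#define SYCL_TRY(callstr)                                                    \
    {                                                                        \
        try {                                                                \
            callstr;                                                         \
        } catch (sycl::exception const& e) {                                 \
            std::cerr << "SYCL error: " << e.what() << " at " << __FILE__    \
                      << ":" << __LINE__ << std::endl;                       \
            assert(0);                                                       \
        }                                                                    \
    }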
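
For reference, a possible call site for the new mkl_gemm_ex wrapper (float overload), using the queue held by InferenceContext as the handle. Only the wrapper signature and GetCublasHandle come from the patch; the include names, buffer setup, and function name here are illustrative assumptions.

// Illustrative usage of mkl_gemm_ex. Buffers are plain USM device
// allocations; the wrapper forwards to dpct::gemm and expects column-major
// data with leading dimension m for A and C.
#include <sycl/sycl.hpp>
#include "context.h"
#include "inference_mkl_wrappers.h"

int example_gemm(int m, int n, int k)
{
    dpct::queue_ptr handle = InferenceContext::Instance().GetCublasHandle();
    sycl::queue& q = *handle;

    float* A = sycl::malloc_device<float>(m * k, q);
    float* B = sycl::malloc_device<float>(k * n, q);
    float* C = sycl::malloc_device<float>(m * n, q);

    float alpha = 1.0f, beta = 0.0f;

    // C = alpha * A * B + beta * C
    int rc = mkl_gemm_ex(handle,
                         oneapi::mkl::transpose::nontrans,
                         oneapi::mkl::transpose::nontrans,
                         m, n, k,
                         &alpha, &beta,
                         A, B, C,
                         /*algo=*/-1);

    sycl::free(A, q);
    sycl::free(B, q);
    sycl::free(C, q);
    return rc;
}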
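
Similarly, the launchers declared in inference_sycl_layers.h all take a dpct::queue_ptr as their final argument, for which the current stream from InferenceContext is the natural source. A hedged example for launch_bias_gelu follows; the float instantiation and argument names are assumptions, not taken from the patch.

// Assumed call pattern for one of the launchers declared above.
#include "context.h"
#include "inference_sycl_layers.h"

void example_bias_gelu(float* intermediate_output,
                       const float* bias,
                       int intermediate_size,
                       int batch_size)
{
    dpct::queue_ptr stream = InferenceContext::Instance().GetCurrentStream();
    launch_bias_gelu<float>(intermediate_output, bias, intermediate_size,
                            batch_size, stream);
}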