diff --git a/src/gpuRIR_cuda.cu b/src/gpuRIR_cuda.cu index 2593e96..7b4fe29 100644 --- a/src/gpuRIR_cuda.cu +++ b/src/gpuRIR_cuda.cu @@ -473,11 +473,14 @@ __global__ void complexPointwiseMulAndScale(cufftComplex *signal_segments, cufft /* Mixed precision KERNELS */ /***************************/ -__global__ void generateRIR_mp_kernel(half2* initialRIR, scalar_t* amp, scalar_t* tau, int T, int M, int N, int iniRIR_N, int ini_red, scalar_t Fs) { +#if CUDART_VERSION < 9020 +__global__ void generateRIR_mp_kernel(half2* initialRIR, scalar_t* amp, scalar_t* tau, int T, int M, int N, int iniRIR_N, int ini_red, scalar_t Fs, scalar_t Tw_2, scalar_t Tw_inv) { + half2 h2Tw_2 = __float2half2_rn(Tw_2); + half2 h2Tw_inv = __float2half2_rn(Tw_inv); +#else +__global__ void generateRIR_mp_kernel(half2* initialRIR, scalar_t* amp, scalar_t* tau, int T, int M, int N, int iniRIR_N, int ini_red, scalar_t Fs, half2 h2Tw_2, half2 h2Tw_inv) { +#endif #if __CUDA_ARCH__ >= 530 - half2 Tw_2 = __float2half2_rn(8e-3f * Fs / 2); - half2 Tw_inv = __float2half2_rn(1.0f / (8e-3f * Fs)); - int t = blockIdx.x * blockDim.x + threadIdx.x; int m = blockIdx.y * blockDim.y + threadIdx.y; int n_ini = blockIdx.z * ini_red; @@ -489,7 +492,7 @@ __global__ void generateRIR_mp_kernel(half2* initialRIR, scalar_t* amp, scalar_t scalar_t loc_tim_2 = 2*t+1; for (int n=n_ini; n>>( initialRIR, amp, tau, T/2, M, N, iniRIR_N, initialReduction, Fs ); + #if CUDART_VERSION < 9020 + // For CUDA versions older than 9.2 it is nos possible to call from host code __float2half2_rn, + // but doing it in the kernel is slower + scalar_t Tw_2 = 8e-3f * Fs / 2; + scalar_t Tw_inv = 1.0f / (8e-3f * Fs); + #else + half2 Tw_2 = __float2half2_rn(8e-3f * Fs / 2); + half2 Tw_inv = __float2half2_rn(1.0f / (8e-3f * Fs)); + #endif + generateRIR_mp_kernel<<>>( initialRIR, amp, tau, T/2, M, N, iniRIR_N, initialReduction, Fs, Tw_2, Tw_inv ); gpuErrchk( cudaDeviceSynchronize() ); gpuErrchk( cudaPeekAtLastError() );