diff --git a/operators/cuda/transpose_cast.h b/operators/cuda/transpose_cast.h index 6a54be6ba..6ffae51c2 100644 --- a/operators/cuda/transpose_cast.h +++ b/operators/cuda/transpose_cast.h @@ -22,10 +22,10 @@ struct Transpose2DCast { if (shape.size() != 2) { ORTX_CXX_API_THROW("Input must be a 2D tensor", ORT_RUNTIME_EXCEPTION); } - size_t n_rows = shape[0]; - size_t n_cols = shape[1]; + int n_rows = static_cast(shape[0]); + int n_cols = static_cast(shape[1]); - std::vector new_shape{n_cols, n_rows}; + std::vector new_shape{static_cast(n_cols), static_cast(n_rows)}; TOUT* output_data = output.Allocate(new_shape); if (0 == n_rows || 0 == n_cols) { return {}; diff --git a/operators/cuda/transpose_cast_impl.cu b/operators/cuda/transpose_cast_impl.cu index c9d850222..cdd6a177b 100644 --- a/operators/cuda/transpose_cast_impl.cu +++ b/operators/cuda/transpose_cast_impl.cu @@ -32,26 +32,25 @@ __global__ void Transpose2DCastKernel(TOUT *output_data, const TIN *input_data, } template -cudaError_t _LaunchTranspose2DCastKernel(cudaStream_t stream, size_t n_rows, size_t n_cols, +cudaError_t _LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, const TIN* input, TOUT* output) { dim3 dimGrid((n_cols + TILE_DIM - 1) / TILE_DIM, (n_rows + TILE_DIM - 1) / TILE_DIM, 1); dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1); using TTIN = typename contrib::CudaT::MappedType; using TTOUT = typename contrib::CudaT::MappedType; Transpose2DCastKernel<<>>( - reinterpret_cast(output), reinterpret_cast(input), - static_cast(n_rows), static_cast(n_cols)); + reinterpret_cast(output), reinterpret_cast(input), n_rows, n_cols); return cudaGetLastError(); } template <> -cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, size_t n_rows, size_t n_cols, +cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, const float* input, ortc::MFloat16* output) { return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output); } template <> -cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, size_t n_rows, size_t n_cols, +cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, const ortc::MFloat16* input, float* output) { return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output); } diff --git a/operators/cuda/transpose_cast_impl.cuh b/operators/cuda/transpose_cast_impl.cuh index 3e271f35d..b3fb2c44f 100644 --- a/operators/cuda/transpose_cast_impl.cuh +++ b/operators/cuda/transpose_cast_impl.cuh @@ -6,4 +6,4 @@ #include template -cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, size_t n_rows, size_t n_cols, const TIN* input, TOUT* output); \ No newline at end of file +cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, const TIN* input, TOUT* output); \ No newline at end of file