Skip to content

Commit

Permalink
fix compilation issues (row/column counts are now passed as int instead of size_t)
Browse files (browse the repository at this point in the history)
  • Loading branch information
xadupre committed Jun 6, 2024
1 parent ad0b83f commit fe57c87
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
6 changes: 3 additions & 3 deletions operators/cuda/transpose_cast.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,10 +22,10 @@ struct Transpose2DCast {
if (shape.size() != 2) {
ORTX_CXX_API_THROW("Input must be a 2D tensor", ORT_RUNTIME_EXCEPTION);
}
size_t n_rows = shape[0];
size_t n_cols = shape[1];
int n_rows = static_cast<int>(shape[0]);
int n_cols = static_cast<int>(shape[1]);

std::vector<int64_t> new_shape{n_cols, n_rows};
std::vector<int64_t> new_shape{static_cast<int64_t>(n_cols), static_cast<int64_t>(n_rows)};
TOUT* output_data = output.Allocate(new_shape);
if (0 == n_rows || 0 == n_cols) {
return {};
Expand Down
9 changes: 4 additions & 5 deletions operators/cuda/transpose_cast_impl.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,26 +32,25 @@ __global__ void Transpose2DCastKernel(TOUT *output_data, const TIN *input_data,
}

template <typename TIN, typename TOUT>
cudaError_t _LaunchTranspose2DCastKernel(cudaStream_t stream, size_t n_rows, size_t n_cols,
cudaError_t _LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols,
const TIN* input, TOUT* output) {
dim3 dimGrid((n_cols + TILE_DIM - 1) / TILE_DIM, (n_rows + TILE_DIM - 1) / TILE_DIM, 1);
dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
using TTIN = typename contrib::CudaT<TIN>::MappedType;
using TTOUT = typename contrib::CudaT<TOUT>::MappedType;
Transpose2DCastKernel<TTOUT, TTIN><<<dimGrid, dimBlock, TILE_DIM * TILE_DIM + TILE_DIM, stream>>>(
reinterpret_cast<TTOUT*>(output), reinterpret_cast<const TTIN*>(input),
static_cast<int>(n_rows), static_cast<int>(n_cols));
reinterpret_cast<TTOUT*>(output), reinterpret_cast<const TTIN*>(input), n_rows, n_cols);
return cudaGetLastError();
}

template <>
cudaError_t LaunchTranspose2DCastKernel<float, ortc::MFloat16>(cudaStream_t stream, size_t n_rows, size_t n_cols,
cudaError_t LaunchTranspose2DCastKernel<float, ortc::MFloat16>(cudaStream_t stream, int n_rows, int n_cols,
const float* input, ortc::MFloat16* output) {
return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output);
}

template <>
cudaError_t LaunchTranspose2DCastKernel<ortc::MFloat16, float>(cudaStream_t stream, size_t n_rows, size_t n_cols,
cudaError_t LaunchTranspose2DCastKernel<ortc::MFloat16, float>(cudaStream_t stream, int n_rows, int n_cols,
const ortc::MFloat16* input, float* output) {
return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output);
}
2 changes: 1 addition & 1 deletion operators/cuda/transpose_cast_impl.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -6,4 +6,4 @@
#include <cuda_runtime.h>

template <typename TIN, typename TOUT>
cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, size_t n_rows, size_t n_cols, const TIN* input, TOUT* output);
cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, const TIN* input, TOUT* output);

0 comments on commit fe57c87

Please sign in to comment.