diff --git a/infini_train/src/kernels/cuda/accumulate_grad.cu b/infini_train/src/kernels/cuda/accumulate_grad.cu index f0f5a587..5ef70eab 100644 --- a/infini_train/src/kernels/cuda/accumulate_grad.cu +++ b/infini_train/src/kernels/cuda/accumulate_grad.cu @@ -42,31 +42,120 @@ template __global__ void AdamAccumulateGradKernel(const T *grad_data, T *param_data, size_t num_elements, T *m_data, T *v_data, float learning_rate, float beta1, float beta2, float eps, const float bias_correction_m, const float bias_correction_v) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < num_elements) { - m_data[idx] = common::cuda::Fma(common::cuda::Cast(beta1), m_data[idx], - common::cuda::Cast(1 - beta1) * grad_data[idx]); - v_data[idx] = common::cuda::Fma(common::cuda::Cast(beta2), v_data[idx], - common::cuda::Cast(1 - beta2) * grad_data[idx] * grad_data[idx]); - - const float m_hat = common::cuda::Cast(m_data[idx]) / bias_correction_m; - const float v_hat = common::cuda::Cast(v_data[idx]) / bias_correction_v; - param_data[idx] = common::cuda::Sub( - param_data[idx], common::cuda::Cast(learning_rate * m_hat * __frcp_rn(__fsqrt_rn(v_hat) + eps))); + // size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + // if (idx < num_elements) { + // m_data[idx] = common::cuda::Fma(common::cuda::Cast(beta1), m_data[idx], + // common::cuda::Cast(1 - beta1) * grad_data[idx]); + // v_data[idx] = common::cuda::Fma(common::cuda::Cast(beta2), v_data[idx], + // common::cuda::Cast(1 - beta2) * grad_data[idx] * grad_data[idx]); + + // const float m_hat = common::cuda::Cast(m_data[idx]) / bias_correction_m; + // const float v_hat = common::cuda::Cast(v_data[idx]) / bias_correction_v; + + // param_data[idx] = common::cuda::Sub( + // param_data[idx], common::cuda::Cast(learning_rate * m_hat * __frcp_rn(__fsqrt_rn(v_hat) + eps))); + // } + + //先搞向量化内存 + constexpr int VEC_SIZE = 16 / sizeof(T); + size_t vec_idx = (blockIdx.x * blockDim.x + threadIdx.x) * VEC_SIZE; + size_t e_start = 
num_elements / VEC_SIZE * VEC_SIZE; + if (vec_idx < e_start) { + + //能向量化搬运就向量化搬运 + T local_grad[VEC_SIZE], local_param[VEC_SIZE], local_m[VEC_SIZE], local_v[VEC_SIZE]; + + // 开始搬运 + *reinterpret_cast(local_grad) = *reinterpret_cast(grad_data + vec_idx); + *reinterpret_cast(local_param) = *reinterpret_cast(param_data + vec_idx); + *reinterpret_cast(local_m) = *reinterpret_cast(m_data + vec_idx); + *reinterpret_cast(local_v) = *reinterpret_cast(v_data + vec_idx); + + + # pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + + //将存储的 int4* 转换为float + float g = common::cuda::Cast(local_grad[i]); + float p = common::cuda::Cast(local_param[i]); + float m = common::cuda::Cast(local_m[i]); + float v = common::cuda::Cast(local_v[i]); + + + m = beta1 * m + (1.0f - beta1) * g; + v = beta2 * v + (1.0f - beta2) * g * g; + + float m_hat = m / bias_correction_m; + float v_hat = v / bias_correction_v; + + // 使用内置的快速数学函数处理 float + p -= learning_rate * m_hat * __frcp_rn(__fsqrt_rn(v_hat) + eps); + + // 计算完毕,转回 T 存储到 local 数组 + local_m[i] = common::cuda::Cast(m); + local_v[i] = common::cuda::Cast(v); + local_param[i] = common::cuda::Cast(p); + } + + // 写回原数组 + *reinterpret_cast(param_data + vec_idx) = *reinterpret_cast(local_param); + *reinterpret_cast(m_data + vec_idx) = *reinterpret_cast(local_m); + *reinterpret_cast(v_data + vec_idx) = *reinterpret_cast(local_v); + + }else if(vec_idx == e_start){ + + # pragma unroll + for(size_t idx = vec_idx; idx < num_elements; ++ idx){ + + m_data[idx] = common::cuda::Fma(common::cuda::Cast(beta1), m_data[idx], + common::cuda::Cast(1 - beta1) * grad_data[idx]); + v_data[idx] = common::cuda::Fma(common::cuda::Cast(beta2), v_data[idx], + common::cuda::Cast(1 - beta2) * grad_data[idx] * grad_data[idx]); + + const float m_hat = common::cuda::Cast(m_data[idx]) / bias_correction_m; + const float v_hat = common::cuda::Cast(v_data[idx]) / bias_correction_v; + + param_data[idx] = common::cuda::Sub( + param_data[idx], common::cuda::Cast(learning_rate * 
m_hat * __frcp_rn(__fsqrt_rn(v_hat) + eps))); + } } } void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_ptr ¶m, const std::shared_ptr &m, const std::shared_ptr &v, float learning_rate, float beta1, float beta2, float eps, int64_t t) { + // size_t num_elements = grad->NumElements(); + + // const float bias_correction_m = 1.0f - std::pow(beta1, t); + // const float bias_correction_v = 1.0f - std::pow(beta2, t); + + // int threads_per_block = 256; + // int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; + + // auto device = grad->GetDevice(); + // const auto &cuda_stream = dynamic_cast( + // infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + // ->cuda_stream(); + + // DispatchFunc( + // grad->Dtype(), + // [=]() { + // AdamAccumulateGradKernel<<>>( + // static_cast(grad->DataPtr()), static_cast(param->DataPtr()), num_elements, + // static_cast(m->DataPtr()), static_cast(v->DataPtr()), learning_rate, beta1, beta2, eps, + // bias_correction_m, bias_correction_v); + // }, + // "CUDA AdamAccumulateGrad"); + size_t num_elements = grad->NumElements(); + + const float bias_correction_m = 1.0f - std::pow(beta1, t); const float bias_correction_v = 1.0f - std::pow(beta2, t); - int threads_per_block = 256; - int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; + auto device = grad->GetDevice(); const auto &cuda_stream = dynamic_cast( @@ -76,13 +165,18 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p DispatchFunc( grad->Dtype(), [=]() { + int element_size = sizeof(T); + int VEC_SIZE = 16 / element_size; + int threads_per_block = 256; + int total_threads = (num_elements + VEC_SIZE - 1) / VEC_SIZE; + int num_blocks = (total_threads + threads_per_block - 1) / threads_per_block; AdamAccumulateGradKernel<<>>( static_cast(grad->DataPtr()), static_cast(param->DataPtr()), num_elements, static_cast(m->DataPtr()), static_cast(v->DataPtr()), learning_rate, beta1, beta2, 
eps, bias_correction_m, bias_correction_v); }, "CUDA AdamAccumulateGrad"); -} + } } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_ACCUMULATE_GRAD_KERNEL(kernel_name) \ diff --git a/infini_train/src/kernels/cuda/cast.cu b/infini_train/src/kernels/cuda/cast.cu index 62a5b0d2..1d5d3ad7 100644 --- a/infini_train/src/kernels/cuda/cast.cu +++ b/infini_train/src/kernels/cuda/cast.cu @@ -14,10 +14,48 @@ namespace infini_train::kernels::cuda { template __global__ void CastKernel(Tdst *dst, const Tsrc *src, size_t num_elements, size_t offset) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + // size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + // if (idx < num_elements) { + // dst[idx] = common::cuda::Cast(src[idx]); + // } - if (idx < num_elements) { - dst[idx] = common::cuda::Cast(src[idx]); + // 统一每个线程处理 4 个元素 + constexpr int VEC_SIZE = 4; + size_t idx = (size_t)(blockIdx.x * blockDim.x + threadIdx.x) * VEC_SIZE + offset; + + if (idx + VEC_SIZE <= num_elements) { + Tsrc s_vec[VEC_SIZE]; + Tdst d_vec[VEC_SIZE]; + + // 根据 Tsrc 宽度决定加载指令 (如果是 2 字节读 8 字节, 如果是 4 字节读 16 字节) + if constexpr (sizeof(Tsrc) == 2) { + *reinterpret_cast(s_vec) = *reinterpret_cast(src + idx); + } else if constexpr (sizeof(Tsrc) == 4) { + *reinterpret_cast(s_vec) = *reinterpret_cast(src + idx); + } else { + for (int i = 0; i < VEC_SIZE; ++i) s_vec[i] = src[idx + i]; + } + + // 寄存器内完成类型转换 + #pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + d_vec[i] = common::cuda::Cast(s_vec[i]); + } + + // 根据 Tdst 宽度决定写回指令 + if constexpr (sizeof(Tdst) == 2) { + *reinterpret_cast(d_vec) = *reinterpret_cast(d_vec); + *reinterpret_cast(dst + idx) = *reinterpret_cast(d_vec); + } else if constexpr (sizeof(Tdst) == 4) { + *reinterpret_cast(dst + idx) = *reinterpret_cast(d_vec); + } else { + for (int i = 0; i < VEC_SIZE; ++i) dst[idx + i] = d_vec[i]; + } + } else { + // 处理末尾非对齐数据 + for (size_t i = idx; i < num_elements && i < idx + VEC_SIZE; ++i) { + dst[i] = 
common::cuda::Cast(src[i]); + } } } @@ -29,15 +67,25 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { ->cuda_stream(); const size_t num_elements = input->NumElements(); + + // const size_t num_elements = input->NumElements(); + // dim3 block_dims(256); + // dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x)); + // const size_t step = grid_dims.x * block_dims.x; + + // 这里的 VEC_SIZE 必须与 Kernel 内部保持一致 + int VEC_SIZE = 4; dim3 block_dims(256); - dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x)); - const size_t step = grid_dims.x * block_dims.x; + // 每个线程干 4 个人的活,所以线程总数除以 4 + dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x * VEC_SIZE)); + const size_t step = grid_dims.x * block_dims.x * VEC_SIZE; DispatchFunc, DataTypeList>( {dtype, input->Dtype()}, [=]() { auto dst = static_cast(dst_tensor->DataPtr()); auto src = static_cast(input->DataPtr()); + // 网格步进循环处理超大规模 Tensor for (size_t offset = 0; offset < num_elements; offset += step) { CastKernel<<>>(dst, src, num_elements, offset); } @@ -53,4 +101,4 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { REGISTER_CUDA_CAST_KERNEL(Cast) -#undef REGISTER_CUDA_CAST_KERNEL +#undef REGISTER_CUDA_CAST_KERNEL \ No newline at end of file diff --git a/infini_train/src/kernels/cuda/fill.cu b/infini_train/src/kernels/cuda/fill.cu index d00bd0f2..733c2516 100644 --- a/infini_train/src/kernels/cuda/fill.cu +++ b/infini_train/src/kernels/cuda/fill.cu @@ -11,17 +11,50 @@ namespace infini_train::kernels::cuda { template __global__ void FillKernel(T *data, T value, size_t size) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < size) { - data[idx] = value; + // size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + // if (idx < size) { + // data[idx] = value; + // } + + // 计算一个线程处理的向量步长(16字节 / 类型大小) + constexpr int VEC_SIZE = 16 / sizeof(T); + // 重新计算向量化后的全局索引 + size_t idx = (size_t)(blockIdx.x * blockDim.x + threadIdx.x) * VEC_SIZE; + + if (idx + VEC_SIZE <= size) { + T local[VEC_SIZE]; + 
// 强制循环展开,在寄存器中准备好填充值 + #pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + local[i] = value; + } + // 将寄存器数组转为 int4,单条指令完成 128-bit 写入,压榨显存带宽 + *reinterpret_cast(data + idx) = *reinterpret_cast(local); + } else { + // 处理末尾不足 VEC_SIZE 的非对齐部分 + for (size_t i = idx; i < size; ++i) { + data[i] = value; + } } } -// TODO(dcj): refactor Fill kernel with elementwise template void Fill(std::shared_ptr tensor, void *value_ptr) { - const int num_tokens = tensor->NumElements(); - const int threads_per_block = 256; - const int num_blocks = (num_tokens + threads_per_block - 1) / threads_per_block; + // const int num_tokens = tensor->NumElements(); + // const int threads_per_block = 256; + // const int num_blocks = (num_tokens + threads_per_block - 1) / threads_per_block; + // auto device = tensor->GetDevice(); + // const auto &cuda_stream = dynamic_cast( + // infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + // ->cuda_stream(); + // DispatchFunc( + // tensor->Dtype(), + // [=]() { + // FillKernel<<>>( + // static_cast(tensor->DataPtr()), *(static_cast(value_ptr)), tensor->NumElements()); + // }, + // "CUDA Fill"); + + size_t num_elements = tensor->NumElements(); auto device = tensor->GetDevice(); const auto &cuda_stream = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) @@ -30,11 +63,21 @@ void Fill(std::shared_ptr tensor, void *value_ptr) { DispatchFunc( tensor->Dtype(), [=]() { + // 每一个 T 类型的大小 + int element_size = sizeof(T); + // 计算向量化步长,通常 float 是 4 个,half 是 8 个 + int VEC_SIZE = 16 / element_size; + int threads_per_block = 256; + // 因为每个线程处理 VEC_SIZE 个元素,所以 Block 总数要除以步长 + int total_threads = (num_elements + VEC_SIZE - 1) / VEC_SIZE; + int num_blocks = (total_threads + threads_per_block - 1) / threads_per_block; + FillKernel<<>>( - static_cast(tensor->DataPtr()), *(static_cast(value_ptr)), tensor->NumElements()); + static_cast(tensor->DataPtr()), *(static_cast(value_ptr)), num_elements); }, "CUDA Fill"); } 
+ } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_FILL_KERNEL(kernel_name) \ @@ -42,4 +85,4 @@ void Fill(std::shared_ptr tensor, void *value_ptr) { REGISTER_CUDA_FILL_KERNEL(Fill) -#undef REGISTER_CUDA_FILL_KERNEL +#undef REGISTER_CUDA_FILL_KERNEL \ No newline at end of file diff --git a/infini_train/src/kernels/cuda/linear.cu b/infini_train/src/kernels/cuda/linear.cu index 334b257c..fb53813b 100644 --- a/infini_train/src/kernels/cuda/linear.cu +++ b/infini_train/src/kernels/cuda/linear.cu @@ -16,123 +16,136 @@ namespace infini_train::kernels::cuda { -std::shared_ptr MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { - /* - output[*, m, n] = input[*, m, k] * other[*, k, n] - */ - const auto &input_dims = input->Dims(); - const auto &other_dims = other->Dims(); +// template __global__ void BiasCopyKernel(T *output, const T *bias, int bs, int out_features) { +// int idx = blockIdx.x * blockDim.x + threadIdx.x; +// if (idx >= bs * out_features) { +// return; +// } +// int j = idx % out_features; +// output[idx] = bias[j]; +// } + +// 向量化写入优化 +template +__global__ void BiasCopyKernel(T *output, const T *bias, int bs, int out_features) { + constexpr int VEC_SIZE = 16 / sizeof(T); + size_t idx = (size_t)(blockIdx.x * blockDim.x + threadIdx.x) * VEC_SIZE; + size_t total_elements = (size_t)bs * out_features; + + if (idx + VEC_SIZE <= total_elements) { + T local_vals[VEC_SIZE]; + #pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + local_vals[i] = bias[(idx + i) % out_features]; + } + *reinterpret_cast(output + idx) = *reinterpret_cast(local_vals); + } else { + for (size_t i = idx; i < total_elements; ++i) { + output[i] = bias[i % out_features]; + } + } +} + +// template +// __global__ void ReduceColumnsKernel(const T *__restrict__ input, T *__restrict__ output, int num_rows, int num_cols) { +// using BlockReduce = cub::BlockReduce; +// __shared__ typename BlockReduce::TempStorage temp_storage; +// int row = blockIdx.x; +// float 
sum = 0.0f; +// for (int col = threadIdx.x; col < num_cols; col += blockDim.x) { +// sum += common::cuda::Cast(input[row * num_cols + col]); +// } +// float reduced = BlockReduce(temp_storage).Sum(sum); +// if (threadIdx.x == 0) { +// output[row] = reduced; +// } +// } + +// 向量化加载与规约优化 +template +__global__ void ReduceColumnsKernel(const T *__restrict__ input, T *__restrict__ output, int num_rows, int num_cols) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; - CHECK_GE(input_dims.size(), 2); - CHECK_GE(other_dims.size(), 2); - CHECK_EQ(input_dims.size(), other_dims.size()); + int row = blockIdx.x; + float sum = 0.0f; + constexpr int VEC_SIZE = 16 / sizeof(T); + int col = threadIdx.x * VEC_SIZE; + + for (; col + VEC_SIZE <= num_cols; col += blockDim.x * VEC_SIZE) { + T local_vals[VEC_SIZE]; + *reinterpret_cast(local_vals) = *reinterpret_cast(input + (size_t)row * num_cols + col); + #pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + sum += common::cuda::Cast(local_vals[i]); + } + } + for (int c = col + (threadIdx.x % VEC_SIZE); c < num_cols; c += blockDim.x) { + sum += common::cuda::Cast(input[(size_t)row * num_cols + c]); + } + + float reduced = BlockReduce(temp_storage).Sum(sum); + if (threadIdx.x == 0) { + output[row] = common::cuda::Cast(reduced); + } +} +std::shared_ptr MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); const int64_t m = input_dims[input_dims.size() - 2]; const int64_t k = input_dims[input_dims.size() - 1]; - CHECK_EQ(k, other_dims[other_dims.size() - 2]); const int64_t n = other_dims[other_dims.size() - 1]; - const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); - for (int64_t i = 0; i < input_dims.size() - 2; ++i) { - CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; - } - - auto dtype = input->Dtype(); - 
std::vector output_dims = input_dims; - output_dims[output_dims.size() - 1] = n; - auto output = std::make_shared(output_dims, dtype, input->GetDevice()); + auto output = std::make_shared(input_dims, input->Dtype(), input->GetDevice()); auto device = input->GetDevice(); const float alpha = 1.0f, beta = 0.0f; cublasHandle_t handle = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) ->cublas_handle(); - // cuBLAS is colmun-major - // output = input * other --> output.T = other.T * input.T - // C = A * B ==> output.T[*, n, m] = other.T[*, n, k] * input.T[*, k, m] - // C = output.T[*, n, m] - // A = other.T[*, n, k] - // B = input.T[*, k, m] - int lda = n; - int ldb = k; - int ldc = n; - int64_t stride_a = n * k; - int64_t stride_b = k * m; - int64_t stride_c = m * n; - // NOTE(zbl): the last cublasGemmAlgo_t param has no effect on GPU arch >= sm_80(Ampere) - - switch (dtype) { - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), CUDA_R_32F, lda, - stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, output->DataPtr(), CUDA_R_32F, - ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), CUDA_R_16BF, lda, - stride_a, input->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, output->DataPtr(), CUDA_R_16BF, - ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) - default: - LOG_UNSUPPORTED_DTYPE(dtype, "CUDA MatmulForward"); - } + // 开启TF32加速 + CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH)); + int lda = n, ldb = k, ldc = n; + int64_t stride_a = n * k, stride_b = k * m, stride_c = m * n; + + switch (input->Dtype()) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), 
CUDA_R_32F, lda, stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, output->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), infini_train::DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), CUDA_R_16BF, lda, stride_a, input->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, output->DataPtr(), CUDA_R_16BF, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), infini_train::DataType::kBFLOAT16) + } return output; } -std::tuple, std::shared_ptr> -MatmulBackward(const std::shared_ptr &input, const std::shared_ptr &other, - const std::shared_ptr &grad_output) { - /* - grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T - grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] - */ - +std::tuple, std::shared_ptr> +MatmulBackward(const std::shared_ptr &input, const std::shared_ptr &other, + const std::shared_ptr &grad_output) { auto input_dtype = input->Dtype(); auto other_dtype = other->Dtype(); auto grad_output_dtype = grad_output->Dtype(); - DataType promoted_type - = DispatchFunc, DataTypeList, DataTypeList>( - {input_dtype, other_dtype, grad_output_dtype}, - [=]() { return DataTypeMap_v>; }, - "CUDA MatmulBackward"); + + DataType promoted_type = DispatchFunc, DataTypeList, DataTypeList>( + {input_dtype, other_dtype, grad_output_dtype}, + [=]() { return DataTypeMap_v>; }, + "CUDA MatmulBackward"); - auto input_promoted = input_dtype == promoted_type ? input : std::make_shared(input->To(promoted_type)); - auto other_promoted = other_dtype == promoted_type ? other : std::make_shared(other->To(promoted_type)); - auto grad_output_promoted - = grad_output_dtype == promoted_type ? grad_output : std::make_shared(grad_output->To(promoted_type)); + auto input_promoted = input_dtype == promoted_type ? input : std::make_shared(input->To(promoted_type)); + auto other_promoted = other_dtype == promoted_type ? 
other : std::make_shared(other->To(promoted_type)); + auto grad_output_promoted = grad_output_dtype == promoted_type ? grad_output : std::make_shared(grad_output->To(promoted_type)); const auto &input_dims = input->Dims(); const auto &other_dims = other->Dims(); - const auto &grad_output_dims = grad_output->Dims(); - - CHECK_GE(input_dims.size(), 2); - CHECK_EQ(input_dims.size(), other_dims.size()); - CHECK_EQ(input_dims.size(), grad_output_dims.size()); - const int64_t m = input_dims[input_dims.size() - 2]; const int64_t k = input_dims[input_dims.size() - 1]; const int64_t n = other_dims[other_dims.size() - 1]; - CHECK_EQ(k, other_dims[other_dims.size() - 2]); - CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); - CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); - const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); - for (int64_t i = 0; i < input_dims.size() - 2; ++i) { - CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; - CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; - } - auto grad_input = std::make_shared(input_dims, promoted_type, grad_output->GetDevice()); - auto grad_other = std::make_shared(other_dims, promoted_type, grad_output->GetDevice()); + auto grad_input = std::make_shared(input_dims, promoted_type, grad_output->GetDevice()); + auto grad_other = std::make_shared(other_dims, promoted_type, grad_output->GetDevice()); DispatchFunc( - promoted_type, - [=]() { - grad_input->Fill(0); - grad_other->Fill(0); - }, - "CUDA MatmulBackward"); + promoted_type, [=]() { grad_input->Fill(0); grad_other->Fill(0); }, "CUDA MatmulBackward"); auto device = input_promoted->GetDevice(); const float alpha = 1.0f, beta = 0.0f; @@ -140,328 +153,124 @@ MatmulBackward(const std::shared_ptr &input, const std::shared_ptrGetBlasHandle(device)) ->cublas_handle(); - { - // cuBLAS is colmun-major - // grad_input = grad_output * other.T --> grad_input.T = other * 
grad_output.T - // C = A.T * B ==> grad_input.T[*, k, m] = other[*, k, n] * grad_output.T[*, n, m] - // C = grad_input.T[*, k, m] - // A = other.T[*, n, k] - // B = grad_output.T[*, n, m] - const int lda = n, ldb = n, ldc = k; - const int64_t stride_a = k * n; - const int64_t stride_b = n * m; - const int64_t stride_c = m * k; - switch (promoted_type) { - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other_promoted->DataPtr(), CUDA_R_32F, - lda, stride_a, grad_output_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, - grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kFLOAT32) - DISPATCH_CASE( - WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other_promoted->DataPtr(), CUDA_R_16BF, lda, - stride_a, grad_output_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, grad_input->DataPtr(), - CUDA_R_16BF, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) - } + CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH)); + + const int lda = n, ldb = n, ldc = k; + const int64_t stride_a = k * n, stride_b = n * m, stride_c = m * k; + + switch (promoted_type) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other_promoted->DataPtr(), CUDA_R_32F, lda, stride_a, grad_output_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other_promoted->DataPtr(), CUDA_R_16BF, lda, stride_a, grad_output_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, grad_input->DataPtr(), CUDA_R_16BF, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), DataType::kBFLOAT16) } - { - // cuBLAS is 
colmun-major - // grad_other = input.T * grad_output --> grad_other.T = grad_output.T * input - // C = A * B.T ==> grad_other.T[*, n, k] = grad_output.T[*, n, m] * input[*, m, k] - // C = grad_other.T[*, n, k] - // A = grad_output.T[*, n, m] - // B = input.T[*, k, m] - const int lda = n, ldb = k, ldc = n; - const int64_t stride_a = n * m; - const int64_t stride_b = k * m; - const int64_t stride_c = n * k; - switch (promoted_type) { - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), - CUDA_R_32F, lda, stride_a, input_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, - grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kFLOAT32) - DISPATCH_CASE( - WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), CUDA_R_16BF, - lda, stride_a, input_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, grad_other->DataPtr(), - CUDA_R_16BF, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) - } + const int lda2 = n, ldb2 = k, ldc2 = n; + const int64_t stride_a2 = n * m, stride_b2 = k * m, stride_c2 = n * k; + switch (promoted_type) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), CUDA_R_32F, lda2, stride_a2, input_promoted->DataPtr(), CUDA_R_32F, ldb2, stride_b2, &beta, grad_other->DataPtr(), CUDA_R_32F, ldc2, stride_c2, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), CUDA_R_16BF, lda2, stride_a2, input_promoted->DataPtr(), CUDA_R_16BF, ldb2, stride_b2, &beta, grad_other->DataPtr(), CUDA_R_16BF, ldc2, stride_c2, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), DataType::kBFLOAT16) 
} return {grad_input, grad_other}; } -template __global__ void BiasCopyKernel(T *output, const T *bias, int bs, int out_features) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= bs * out_features) { - return; - } - int j = idx % out_features; - output[idx] = bias[j]; -} - -std::shared_ptr LinearForward(const std::shared_ptr &input, const std::shared_ptr &weight, - bool transpose, const std::shared_ptr &bias) { - - /* - !transpose: output = input * weight + bias - output[*, out_features] = input[*, in_features] * weight[in_features, out_features] + bias[out_features] - - transpose: output = input * weight^T + bias - output[*, out_features] = input[*, in_features] * weight[out_features, in_features]^T + bias[out_features] - */ - +std::shared_ptr LinearForward(const std::shared_ptr &input, const std::shared_ptr &weight, + bool transpose, const std::shared_ptr &bias) { const auto &input_dims = input->Dims(); - CHECK_GE(input_dims.size(), 2); const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies{}); const int64_t in_features = *input_dims.rbegin(); + const int64_t out_features = weight->Dims()[transpose ? 0 : 1]; - const auto &weight_dims = weight->Dims(); - CHECK_EQ(weight_dims.size(), 2); - CHECK_EQ(in_features, weight_dims[transpose ? 1 : 0]); - - // As for cublas: - // C = alpha * op(B) * op(A) + beta * C - // Dimensions: - // input: (bs, in_features) - // weight: (in_features, out_features) or (out_features, in_features) if transposed - // output: (bs, out_features) - const int64_t out_features = weight_dims[transpose ? 
0 : 1]; - - auto dtype = input->Dtype(); - auto output_dims = input_dims; - *output_dims.rbegin() = out_features; - auto output = std::make_shared(output_dims, dtype, input->GetDevice()); - + auto output = std::make_shared(input_dims, input->Dtype(), input->GetDevice()); auto device = input->GetDevice(); const auto &cuda_stream = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); if (bias) { - CHECK_EQ(bias->Dims().size(), 1); - CHECK_EQ(bias->Dims()[0], out_features); int threads_per_block = 256; int num_blocks = (bs * out_features + threads_per_block - 1) / threads_per_block; - - DispatchFunc( - dtype, - [=]() { - BiasCopyKernel<<>>( - static_cast(output->DataPtr()), static_cast(bias->DataPtr()), bs, out_features); - }, - "CUDA LinearForward"); - } else { - DispatchFunc( - input->Dtype(), [=]() { output->Fill(0); }, "CUDA LinearForward"); + DispatchFunc(input->Dtype(), [=]() { BiasCopyKernel<<>>(static_cast(output->DataPtr()), static_cast(bias->DataPtr()), bs, out_features); }, "CUDA LinearForward"); } - const float alpha = 1.0f; - const float beta = 1.0f; - auto trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - auto trans_b = CUBLAS_OP_N; - auto lda = transpose ? 
in_features : out_features; cublasHandle_t handle = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) ->cublas_handle(); - // TODO(zbl): use cublasSgemv if possible for convenience and simplicity - // - // - if a is transposed: - // weight is [out_features, in_features] here - // output = input * weight.T --> output.T = weight * input.T - // C = output.T[out_features, bs] - // A = weight.T[in_features, out_features] - // B = input.T[in_features, bs] - // - // - if a is not transposed: - // output = input * weight --> output.T = weight.T * input.T - // C = output.T[out_features, bs] - // A = weight.T[out_features, in_features] - // B = input.T[in_features, bs] - switch (input->Dtype()) { - DISPATCH_CASE(WRAP({ - CUBLAS_CHECK(cublasSgemm(handle, trans_a, trans_b, out_features, bs, in_features, &alpha, - static_cast(weight->DataPtr()), lda, - static_cast(input->DataPtr()), in_features, &beta, - static_cast(output->DataPtr()), out_features)); - }), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP({ - CUBLAS_CHECK(cublasGemmEx(handle, trans_a, trans_b, out_features, bs, in_features, &alpha, - weight->DataPtr(), CUDA_R_16BF, lda, input->DataPtr(), CUDA_R_16BF, - in_features, &beta, output->DataPtr(), CUDA_R_16BF, out_features, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - }), - DataType::kBFLOAT16) - } - - return output; -} - -template -__global__ void ReduceColumnsKernel(const T *__restrict__ input, T *__restrict__ output, int num_rows, int num_cols) { - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int row = blockIdx.x; - float sum = 0.0f; + CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH)); - for (int col = threadIdx.x; col < num_cols; col += blockDim.x) { - sum += common::cuda::Cast(input[row * num_cols + col]); - } - - float reduced = BlockReduce(temp_storage).Sum(sum); + const float alpha = 1.0f, beta = bias ? 1.0f : 0.0f; + auto trans_a = transpose ? 
CUBLAS_OP_T : CUBLAS_OP_N; // completes `auto trans_a = transpose ? ...` begun on the previous line
    int lda = transpose ? in_features : out_features;

    // cuBLAS is column-major, so we compute output.T = op(weight).T * input.T:
    //   C = output.T [out_features, bs]
    //   A = weight.T  (op / lda depend on whether `weight` is stored transposed)
    //   B = input.T  [in_features, bs]
    switch (input->Dtype()) {
        DISPATCH_CASE(WRAP({
                          CUBLAS_CHECK(cublasSgemm(handle, trans_a, CUBLAS_OP_N, out_features, bs, in_features,
                                                   &alpha, static_cast<const float *>(weight->DataPtr()), lda,
                                                   static_cast<const float *>(input->DataPtr()), in_features, &beta,
                                                   static_cast<float *>(output->DataPtr()), out_features));
                      }),
                      DataType::kFLOAT32)
        DISPATCH_CASE(WRAP({
                          CUBLAS_CHECK(cublasGemmEx(handle, trans_a, CUBLAS_OP_N, out_features, bs, in_features,
                                                    &alpha, weight->DataPtr(), CUDA_R_16BF, lda, input->DataPtr(),
                                                    CUDA_R_16BF, in_features, &beta, output->DataPtr(), CUDA_R_16BF,
                                                    out_features, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
                      }),
                      DataType::kBFLOAT16)
    }
    return output;
}

// Sums each row of a row-major [num_rows, num_cols] buffer into output[row].
// One block per row; each thread accumulates a strided slice in fp32, then the
// block reduces. NOTE(review): this kernel was deleted by the patch while
// LinearBackward below still launches it — it must stay defined.
// Launch config must satisfy blockDim.x == 256 (BlockReduce template arg).
// NOTE(review): layout assumes grad_output is addressable as
// [out_features, bs] here (row * num_cols + col) — confirm against callers.
template <typename T>
__global__ void ReduceColumnsKernel(const T *__restrict__ input, T *__restrict__ output, int num_rows, int num_cols) {
    // NOTE(review): BlockReduce template args were lost in extraction; the
    // block size must match the BLOCK_SIZE = 256 used at the launch site.
    using BlockReduce = cub::BlockReduce<float, 256>;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int row = blockIdx.x;
    float sum = 0.0f;
    for (int col = threadIdx.x; col < num_cols; col += blockDim.x) {
        sum += common::cuda::Cast<float>(input[row * num_cols + col]);
    }

    float reduced = BlockReduce(temp_storage).Sum(sum);

    // Only lane 0 of the block holds the valid block-wide sum.
    if (threadIdx.x == 0) {
        output[row] = reduced;
    }
}

// Backward pass of Linear. Given grad_output = dL/d(output), computes:
//   grad_input  = dL/d(input)   [same dims as input]
//   grad_weight = dL/d(weight)  [same dims as weight]
//   grad_bias   = dL/d(bias) = column sums of grad_output (only when `bias`)
// `transpose` selects the storage layout of `weight`:
//   transpose:  weight is [out_features, in_features]
//   otherwise:  weight is [in_features, out_features]
// All outputs are produced in the promoted dtype of the three inputs.
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
               int64_t out_features, const std::shared_ptr<Tensor> &grad_output, const bool bias) {
    const auto &input_dims = input->Dims();
    // Need at least [..., in_features]; bs is the product of all leading dims.
    CHECK_GE(input_dims.size(), 2);
    const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies<int64_t>{});
    const int64_t in_features = *input_dims.rbegin();

    // Validate weight shape against the declared layout before touching cuBLAS.
    const auto &weight_dims = weight->Dims();
    CHECK_EQ(weight_dims.size(), 2);
    CHECK_EQ(in_features, weight_dims[transpose ? 1 : 0]);
    CHECK_EQ(out_features, weight_dims[transpose ? 0 : 1]);

    // NOTE(review): the template-argument lists on DispatchFunc / DataTypeList /
    // DataTypeMap_v were lost in the patch extraction; restore them verbatim
    // from the original file.
    DataType promoted_type = DispatchFunc, DataTypeList, DataTypeList>(
        {input->Dtype(), weight->Dtype(), grad_output->Dtype()},
        [=]() { return DataTypeMap_v>; },
        "CUDA LinearBackward");

    // The GEMMs below reinterpret raw buffers as promoted_type, so all three
    // operands must actually be stored in that dtype. Convert lazily — no copy
    // when the dtype already matches. (The patch under review dropped this
    // step, which silently misreads mixed-dtype inputs.)
    auto input_promoted = input->Dtype() == promoted_type ? input : std::make_shared<Tensor>(input->To(promoted_type));
    auto weight_promoted
        = weight->Dtype() == promoted_type ? weight : std::make_shared<Tensor>(weight->To(promoted_type));
    auto grad_output_promoted = grad_output->Dtype() == promoted_type
                                    ? grad_output
                                    : std::make_shared<Tensor>(grad_output->To(promoted_type));

    auto grad_input = std::make_shared<Tensor>(input_dims, promoted_type, grad_output->GetDevice());
    auto grad_weight = std::make_shared<Tensor>(weight_dims, promoted_type, grad_output->GetDevice());
    std::shared_ptr<Tensor> grad_bias = nullptr;

    // Zero-initialize outputs. The beta = 0 GEMMs overwrite grad_input /
    // grad_weight completely, but the fill keeps the buffers defined even on
    // degenerate (zero-sized) inputs. NOTE(review): dispatch type-list
    // arguments lost in extraction.
    DispatchFunc(promoted_type,
                 [=, &grad_bias]() {
                     grad_input->Fill(0);
                     grad_weight->Fill(0);
                     if (bias) {
                         grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, promoted_type,
                                                              grad_output->GetDevice());
                         grad_bias->Fill(0);
                     }
                 },
                 "CUDA LinearBackward");

    auto device = input_promoted->GetDevice();
    // NOTE(review): the dynamic_cast target types below were lost in the patch
    // extraction — restore the concrete guard/stream/handle types from the
    // original file.
    const auto &cuda_stream = dynamic_cast(
                                  infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
                                  ->cuda_stream();
    cublasHandle_t handle = dynamic_cast(
                                infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device))
                                ->cublas_handle();

    // TF32 tensor-op math trades a few fp32 mantissa bits for large speedups on
    // Ampere+; deliberate numeric trade-off, kept in sync with LinearForward.
    CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));

    const float alpha = 1.0f, beta = 0.0f;

    // d_input (column-major view): C = d_input.T [in_features, bs]
    //   transpose:  weight [out_features, in_features]: d_input = d_output * weight
    //               A = weight.T [in_features, out_features] (no-op)
    //   otherwise:  weight [in_features, out_features]: d_input = d_output * weight.T
    //               A = weight.T [out_features, in_features] (transposed)
    //   B = d_output.T [out_features, bs]
    auto trans_a1 = transpose ? CUBLAS_OP_N : CUBLAS_OP_T;
    auto lda1 = transpose ? in_features : out_features;
    switch (promoted_type) {
        DISPATCH_CASE(WRAP({
                          CUBLAS_CHECK(cublasSgemm(handle, trans_a1, CUBLAS_OP_N, in_features, bs, out_features,
                                                   &alpha, static_cast<const float *>(weight_promoted->DataPtr()),
                                                   lda1, static_cast<const float *>(grad_output_promoted->DataPtr()),
                                                   out_features, &beta, static_cast<float *>(grad_input->DataPtr()),
                                                   in_features));
                      }),
                      DataType::kFLOAT32)
        DISPATCH_CASE(WRAP({
                          CUBLAS_CHECK(cublasGemmEx(handle, trans_a1, CUBLAS_OP_N, in_features, bs, out_features,
                                                    &alpha, weight_promoted->DataPtr(), CUDA_R_16BF, lda1,
                                                    grad_output_promoted->DataPtr(), CUDA_R_16BF, out_features, &beta,
                                                    grad_input->DataPtr(), CUDA_R_16BF, in_features, CUDA_R_32F,
                                                    CUBLAS_GEMM_DEFAULT));
                      }),
                      DataType::kBFLOAT16)
    }

    // d_weight: operand roles swap with `transpose`:
    //   transpose:  d_weight = d_output.T * input --> C = d_weight.T [in, out]
    //   otherwise:  d_weight = input.T * d_output --> C = d_weight.T [out, in]
    const int m2 = transpose ? in_features : out_features;
    const int n2 = transpose ? out_features : in_features;
    auto trans_a2 = CUBLAS_OP_N, trans_b2 = CUBLAS_OP_T;
    const void *a2 = transpose ? input_promoted->DataPtr() : grad_output_promoted->DataPtr();
    const void *b2 = transpose ? grad_output_promoted->DataPtr() : input_promoted->DataPtr();
    const int lda2 = transpose ? in_features : out_features;
    const int ldb2 = transpose ? out_features : in_features;
    const int ldc2 = transpose ? in_features : out_features;
    switch (promoted_type) {
        DISPATCH_CASE(WRAP({
                          CUBLAS_CHECK(cublasSgemm(handle, trans_a2, trans_b2, m2, n2, bs, &alpha,
                                                   static_cast<const float *>(a2), lda2,
                                                   static_cast<const float *>(b2), ldb2, &beta,
                                                   static_cast<float *>(grad_weight->DataPtr()), ldc2));
                      }),
                      DataType::kFLOAT32)
        DISPATCH_CASE(WRAP({
                          CUBLAS_CHECK(cublasGemmEx(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, a2, CUDA_R_16BF,
                                                    lda2, b2, CUDA_R_16BF, ldb2, &beta, grad_weight->DataPtr(),
                                                    CUDA_R_16BF, ldc2, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
                      }),
                      DataType::kBFLOAT16)
    }

    // d_bias = sum over the batch of d_output rows; one block per out_feature.
    // NOTE(review): dispatch type-list args lost in extraction; T below is the
    // dispatched element type.
    if (bias) {
        constexpr int BLOCK_SIZE = 256;
        const int num_blocks = out_features;
        DispatchFunc(promoted_type,
                     [=]() {
                         ReduceColumnsKernel<<<num_blocks, BLOCK_SIZE, 0, cuda_stream>>>(
                             static_cast<const T *>(grad_output_promoted->DataPtr()),
                             static_cast<T *>(grad_bias->DataPtr()), out_features, bs);
                     },
                     "CUDA LinearBackward");
    }

    return {grad_input, grad_weight, grad_bias};
}

} // namespace infini_train::kernels::cuda

#define REGISTER_CUDA_LINEAR_KERNEL(kernel_name)                                                                       \
    REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name)

REGISTER_CUDA_LINEAR_KERNEL(MatmulForward)
REGISTER_CUDA_LINEAR_KERNEL(MatmulBackward)
REGISTER_CUDA_LINEAR_KERNEL(LinearForward)
REGISTER_CUDA_LINEAR_KERNEL(LinearBackward)
#undef REGISTER_CUDA_LINEAR_KERNEL