diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index c7348cc26c10c..8549b422bf74d 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2831,8 +2831,12 @@ struct ggml_cplan ggml_graph_plan(
                     const int64_t ne11 = node->src[1]->ne[1]; // H
                     const int64_t ne12 = node->src[1]->ne[2]; // Channels In
 
-                    cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
-                    cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
+                    GGML_ASSERT(node->src[0]->type == GGML_TYPE_F16 || node->src[0]->type == GGML_TYPE_F32);
+                    GGML_ASSERT(node->src[1]->type == GGML_TYPE_F32);
+
+                    cur += ggml_type_size(node->src[0]->type) * ne00 * ne01 * ne02 * ne03;
+                    cur += ggml_type_size(node->src[0]->type) * ne10 * ne11 * ne12;
+
                 } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
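Note on the sizing above: both `cur +=` terms intentionally use the element size of src[0] (the kernel), even though the second term spans the input's extents, because the F32 input is converted to the kernel's type while being permuted into the work buffer. A minimal standalone sketch of the rule (the helper name is hypothetical, not ggml API):

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper mirroring the ggml_graph_plan() logic above: the permuted
// kernel (ne0x) and the permuted input (ne1x) are both staged in wdata at the
// kernel's element size, since the F32 input is converted to the kernel's type
// during the permute step.
static size_t conv_transpose_2d_wsize(size_t kernel_elt_size,
                                      int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                                      int64_t ne10, int64_t ne11, int64_t ne12) {
    size_t cur = 0;
    cur += kernel_elt_size * ne00 * ne01 * ne02 * ne03; // permuted kernel
    cur += kernel_elt_size * ne10 * ne11 * ne12;        // permuted, type-converted input
    return cur;
}
```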
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index b6209588db1e4..a41b738b674bb 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6833,16 +6833,15 @@ void ggml_compute_forward_conv_3d(
     ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
 }
 
-// ggml_compute_forward_conv_transpose_2d
-
-void ggml_compute_forward_conv_transpose_2d(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
+template <typename kernel_t>
+static void ggml_compute_forward_conv_transpose_2d_impl(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -6853,7 +6852,7 @@ void ggml_compute_forward_conv_transpose_2d(
 
     const int nk = ne00*ne01*ne02*ne03;
 
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (ith == 0) {
@@ -6861,12 +6860,12 @@ void ggml_compute_forward_conv_transpose_2d(
 
         // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
         {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+            kernel_t * const wdata = (kernel_t *) params->wdata + 0;
 
             for (int64_t i03 = 0; i03 < ne03; i03++) {
                 for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
-                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
+                    const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02);
+                    kernel_t * dst_data = wdata + i02*ne01*ne00*ne03;
                     for (int64_t i01 = 0; i01 < ne01; i01++) {
                         for (int64_t i00 = 0; i00 < ne00; i00++) {
                             dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
@@ -6878,13 +6877,17 @@ void ggml_compute_forward_conv_transpose_2d(
 
         // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            kernel_t * const wdata = (kernel_t *) params->wdata + nk;
 
             for (int i12 = 0; i12 < ne12; i12++) {
                 for (int i11 = 0; i11 < ne11; i11++) {
                     const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
-                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
+                    kernel_t * dst_data = wdata + i11*ne10*ne12;
                     for (int i10 = 0; i10 < ne10; i10++) {
-                        dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
+                        if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
+                            dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
+                        } else {
+                            dst_data[i10*ne12 + i12] = src[i10];
+                        }
                     }
                 }
             }
@@ -6906,21 +6909,27 @@ void ggml_compute_forward_conv_transpose_2d(
     const int ip0 = dp*ith;
     const int ip1 = MIN(ip0 + dp, np);
 
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + nk;
+    kernel_t * const wdata = (kernel_t *) params->wdata + 0;
+    kernel_t * const wdata_src = wdata + nk;
 
     for (int i2 = ip0; i2 < ip1; i2++) { // Cout
         float * dst_data = (float *)((char *) dst->data + i2*nb2);
-        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
+        kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
         for (int i11 = 0; i11 < ne11; i11++) {
             for (int i10 = 0; i10 < ne10; i10++) {
                 const int i1n = i11*ne10*ne12 + i10*ne12;
                 for (int i01 = 0; i01 < ne01; i01++) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         float v = 0;
-                        ggml_vec_dot_f16(ne03, &v, 0,
-                                wdata_src + i1n, 0,
-                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
+                            ggml_vec_dot_f16(ne03, &v, 0,
+                                    wdata_src + i1n, 0,
+                                    wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        } else {
+                            ggml_vec_dot_f32(ne03, &v, 0,
+                                    wdata_src + i1n, 0,
+                                    wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        }
                         dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
                     }
                 }
@@ -6929,6 +6938,28 @@ void ggml_compute_forward_conv_transpose_2d(
     }
 }
 
+void ggml_compute_forward_conv_transpose_2d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_transpose_2d_impl<ggml_fp16_t>(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_transpose_2d_impl<float>(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_conv_2d_dw
 
 struct ggml_conv_2d_dw_params {
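The CPU path above follows a common ggml pattern: a single templated implementation with `if constexpr` selecting the type conversion at compile time, plus a thin wrapper that branches on the runtime tensor type exactly once. A self-contained sketch of that pattern (illustrative names, not the ggml API):

```cpp
#include <cstdint>
#include <type_traits>

enum class elt { f16, f32 };  // stand-ins for GGML_TYPE_F16 / GGML_TYPE_F32

template <typename kernel_t>
static float sum_impl(const void * data, int n) {
    const kernel_t * p = (const kernel_t *) data;
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) {
        if constexpr (std::is_same_v<kernel_t, uint16_t>) {
            acc += (float) p[i];  // placeholder for an FP16 -> FP32 conversion
        } else {
            acc += p[i];          // F32 elements are used as-is
        }
    }
    return acc;
}

// The runtime type is branched on once; the inner loop stays branch-free,
// since `if constexpr` is resolved per instantiation at compile time.
static float sum(elt t, const void * data, int n) {
    return t == elt::f16 ? sum_impl<uint16_t>(data, n) : sum_impl<float>(data, n);
}
```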
diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cu b/ggml/src/ggml-cuda/conv2d-transpose.cu
index 03224e404d32d..6cbd6f879e6f5 100644
--- a/ggml/src/ggml-cuda/conv2d-transpose.cu
+++ b/ggml/src/ggml-cuda/conv2d-transpose.cu
@@ -1,12 +1,20 @@
-#include <algorithm>
-
 #include "conv2d-transpose.cuh"
-#include "ggml.h"
-
-__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
-                                        float * __restrict__ output, const int in_w, const int in_h, const int out_w,
-                                        const int out_h, const int kernel_w, const int kernel_h, const int stride,
-                                        const int c_in, const int c_out, const int batches) {
+#include "convert.cuh"
+
+template <typename kernel_t>
+static __global__ void conv2d_transpose_kernel(const float * __restrict__ input,
+                                               const kernel_t * __restrict__ kernel,
+                                               float * __restrict__ output,
+                                               const int in_w,
+                                               const int in_h,
+                                               const int out_w,
+                                               const int out_h,
+                                               const int kernel_w,
+                                               const int kernel_h,
+                                               const int stride,
+                                               const int c_in,
+                                               const int c_out,
+                                               const int batches) {
     const int global_idx     = blockIdx.x * blockDim.x + threadIdx.x;
 
     const int total_elements = out_w * out_h * c_out * batches;
@@ -26,24 +34,32 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const
     for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
         for (int kh = 0; kh < kernel_h; ++kh) {
             int in_y = out_y_idx - kh;
-            if (in_y < 0 || in_y % stride) continue;
+            if (in_y < 0 || in_y % stride) {
+                continue;
+            }
             in_y /= stride;
-            if (in_y >= in_h) continue;
+            if (in_y >= in_h) {
+                continue;
+            }
             for (int kw = 0; kw < kernel_w; ++kw) {
                 int in_x = out_x_idx - kw;
-                if (in_x < 0 || in_x % stride) continue;
+                if (in_x < 0 || in_x % stride) {
+                    continue;
+                }
                 in_x /= stride;
-                if (in_x >= in_w) continue;
+                if (in_x >= in_w) {
+                    continue;
+                }
 
                 const int input_idx =
                     (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
                 const int kernel_idx =
                     (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
 
-                float input_val = input[input_idx];
-                half kern_val = kernel[kernel_idx];
+                float    input_val = input[input_idx];
+                kernel_t kern_val  = kernel[kernel_idx];
 
-                accumulator += input_val * (float) kern_val;
+                accumulator += input_val * ggml_cuda_cast<float>(kern_val);
             }
         }
     }
@@ -56,11 +72,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
     const ggml_tensor * kernel = dst->src[0];
     const ggml_tensor * input  = dst->src[1];
 
-    GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+    GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
 
     const float * input_data  = (const float *) input->data;
     float *       output_data = (float *) dst->data;
-    const half *  kernel_data = (const half *) kernel->data;
+    const void *  kernel_data = kernel->data;
 
     const int input_w = input->ne[0];
     const int input_h = input->ne[1];
@@ -82,10 +99,17 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
     GGML_ASSERT(ggml_is_contiguous(kernel));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    const int total = (output_w * output_h * channels_out * batches);
+    const int total  = output_w * output_h * channels_out * batches;
     const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
 
-    conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
-        input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
-        channels_in, channels_out, batches);
+    if (kernel->type == GGML_TYPE_F16) {
+        conv2d_transpose_kernel<half><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
+            input_data, (const half *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w,
+            kernel_h, stride, channels_in, channels_out, batches);
+
+    } else {
+        conv2d_transpose_kernel<float><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
+            input_data, (const float *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w,
+            kernel_h, stride, channels_in, channels_out, batches);
+    }
 }
diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cuh b/ggml/src/ggml-cuda/conv2d-transpose.cuh
index c9430b2485021..72889c5f0fa89 100644
--- a/ggml/src/ggml-cuda/conv2d-transpose.cuh
+++ b/ggml/src/ggml-cuda/conv2d-transpose.cuh
@@ -1,4 +1,5 @@
 #include "common.cuh"
 
 #define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
+
 void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
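For reference, the gather each CUDA thread performs per output element can be written as plain C++ (F32 kernel only, batch index dropped): a kernel tap (kh, kw) contributes only when (out_y - kh) and (out_x - kw) are non-negative multiples of the stride that land inside the input. A sketch, not code from this patch:

```cpp
// Reference computation of one transposed-convolution output element,
// mirroring the index logic of conv2d_transpose_kernel above.
static float conv2d_transpose_ref_at(const float * input,   // [in_w, in_h, c_in]
                                     const float * kernel,  // [kernel_w, kernel_h, c_out, c_in]
                                     int in_w, int in_h, int c_in,
                                     int kernel_w, int kernel_h, int c_out,
                                     int stride, int out_x, int out_y, int c_idx) {
    float acc = 0.0f;
    for (int ci = 0; ci < c_in; ++ci) {
        for (int kh = 0; kh < kernel_h; ++kh) {
            int in_y = out_y - kh;
            if (in_y < 0 || in_y % stride) continue; // tap does not align with stride grid
            in_y /= stride;
            if (in_y >= in_h) continue;              // falls outside the input
            for (int kw = 0; kw < kernel_w; ++kw) {
                int in_x = out_x - kw;
                if (in_x < 0 || in_x % stride) continue;
                in_x /= stride;
                if (in_x >= in_w) continue;
                acc += input[(in_w * in_h) * ci + in_w * in_y + in_x] *
                       kernel[(kernel_h * kernel_w * c_out) * ci +
                              (kernel_h * kernel_w) * c_idx + kernel_w * kh + kw];
            }
        }
    }
    return acc;
}
```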
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index b11793963aa8e..f95ec7a577515 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4487,28 +4487,33 @@ struct test_conv_transpose_1d : public test_case {
 
 // GGML_OP_CONV_TRANSPOSE_2D
 struct test_conv_transpose_2d : public test_case {
+    // Dimensions
     const std::array<int64_t, 4> ne_input;
     const std::array<int64_t, 4> ne_kernel;
     const int stride;
 
+    // Types
+    const ggml_type kernel_type;
     std::string vars() override {
-        return VARS_TO_STR3(ne_input, ne_kernel, stride);
+        return VARS_TO_STR4(kernel_type, ne_input, ne_kernel, stride);
     }
 
     double max_nmse_err() override {
         return 5e-4; // The default 1e-7 is too small for Vulkan.
     }
 
-    test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
-                           std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
-                           int stride = 1)
-        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}
+    test_conv_transpose_2d(
+        std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+        std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1},  // [kernel_width, kernel_height, input_channels, 1]
+        int stride = 1,
+        ggml_type kernel_type = GGML_TYPE_F16
+    ) : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), kernel_type(kernel_type) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
         ggml_set_name(input, "input");
 
-        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
+        ggml_tensor * kernel = ggml_new_tensor(ctx, kernel_type, 4, ne_kernel.data());
         ggml_set_name(kernel, "kernel");
 
         ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
@@ -6841,8 +6846,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-    test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
+    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
+    }
 
     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
@@ -7813,9 +7820,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
 
-    test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
+    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
+    }
 
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
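With this change in place, user code can pass either kernel type to the public op. A usage sketch (graph construction only; context setup, allocation, and backend dispatch omitted):

```cpp
#include "ggml.h"

// Builds a transposed 2-D convolution with an F32 kernel, which this patch
// makes valid on the CPU and CUDA backends (F16 continues to work as before).
static ggml_tensor * build_conv_transpose_2d_f32(ggml_context * ctx) {
    // input:  [W, H, Cin, N], always F32
    ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 10, 10, 3, 1);
    // kernel: [Kw, Kh, Cout, Cin] - F32 is now accepted in addition to F16
    ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 8, 3);
    return ggml_conv_transpose_2d_p0(ctx, kernel, input, /*stride=*/2);
}
```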