8 changes: 6 additions & 2 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -2831,8 +2831,12 @@ struct ggml_cplan ggml_graph_plan(
                     const int64_t ne11 = node->src[1]->ne[1]; // H
                     const int64_t ne12 = node->src[1]->ne[2]; // Channels In

-                    cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
-                    cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
+                    GGML_ASSERT(node->src[0]->type == GGML_TYPE_F16 || node->src[0]->type == GGML_TYPE_F32);
+                    GGML_ASSERT(node->src[1]->type == GGML_TYPE_F32);
+
+                    cur += ggml_type_size(node->src[0]->type) * ne00 * ne01 * ne02 * ne03;
+                    cur += ggml_type_size(node->src[0]->type) * ne10 * ne11 * ne12;
+
                 } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
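Note on the sizing change above: the CONV_TRANSPOSE_2D work buffer holds the permuted kernel followed by the input converted to the kernel's element type, so both terms are now scaled by `ggml_type_size(node->src[0]->type)` instead of a hard-coded `sizeof(ggml_fp16_t)`. A minimal sketch of the arithmetic (illustrative only, reusing the `ne*` names from the hunk above):

    // Scratch layout: [permuted kernel | converted src1], both stored in the kernel's type.
    // An F32 kernel therefore needs twice the scratch of the F16 case.
    const size_t ts    = ggml_type_size(node->src[0]->type); // 2 bytes for F16, 4 for F32
    const size_t wsize = ts*ne00*ne01*ne02*ne03               // permuted kernel (nk elements)
                       + ts*ne10*ne11*ne12;                   // src1 converted to the kernel type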
69 changes: 50 additions & 19 deletions ggml/src/ggml-cpu/ops.cpp
@@ -6833,16 +6833,15 @@ void ggml_compute_forward_conv_3d(
     ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
 }

-// ggml_compute_forward_conv_transpose_2d
-
-void ggml_compute_forward_conv_transpose_2d(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
+template <typename kernel_t>
+static void ggml_compute_forward_conv_transpose_2d_impl(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {

     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];

-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -6853,20 +6852,20 @@ void ggml_compute_forward_conv_transpose_2d(

     const int nk = ne00*ne01*ne02*ne03;

-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
     GGML_ASSERT(nb10 == sizeof(float));

     if (ith == 0) {
         memset(params->wdata, 0, params->wsize);

         // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
         {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+            kernel_t * const wdata = (kernel_t *) params->wdata + 0;

             for (int64_t i03 = 0; i03 < ne03; i03++) {
                 for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
-                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
+                    const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02);
+                    kernel_t * dst_data = wdata + i02*ne01*ne00*ne03;
                     for (int64_t i01 = 0; i01 < ne01; i01++) {
                         for (int64_t i00 = 0; i00 < ne00; i00++) {
                             dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
@@ -6878,13 +6877,17 @@ void ggml_compute_forward_conv_transpose_2d(

         // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
         {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            kernel_t * const wdata = (kernel_t *) params->wdata + nk;
             for (int i12 = 0; i12 < ne12; i12++) {
                 for (int i11 = 0; i11 < ne11; i11++) {
                     const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
-                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
+                    kernel_t * dst_data = wdata + i11*ne10*ne12;
                     for (int i10 = 0; i10 < ne10; i10++) {
-                        dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
+                        if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
+                            dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
+                        } else {
+                            dst_data[i10*ne12 + i12] = src[i10];
+                        }
                     }
                 }
             }
@@ -6906,21 +6909,27 @@ void ggml_compute_forward_conv_transpose_2d(
     const int ip0 = dp*ith;
     const int ip1 = MIN(ip0 + dp, np);

-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + nk;
+    kernel_t * const wdata = (kernel_t *) params->wdata + 0;
+    kernel_t * const wdata_src = wdata + nk;

     for (int i2 = ip0; i2 < ip1; i2++) { // Cout
         float * dst_data = (float *)((char *) dst->data + i2*nb2);
-        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
+        kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
         for (int i11 = 0; i11 < ne11; i11++) {
             for (int i10 = 0; i10 < ne10; i10++) {
                 const int i1n = i11*ne10*ne12 + i10*ne12;
                 for (int i01 = 0; i01 < ne01; i01++) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         float v = 0;
-                        ggml_vec_dot_f16(ne03, &v, 0,
-                                wdata_src + i1n, 0,
-                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
+                            ggml_vec_dot_f16(ne03, &v, 0,
+                                    wdata_src + i1n, 0,
+                                    wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        } else {
+                            ggml_vec_dot_f32(ne03, &v, 0,
+                                    wdata_src + i1n, 0,
+                                    wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        }
                         dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
                     }
}
Expand All @@ -6929,6 +6938,28 @@ void ggml_compute_forward_conv_transpose_2d(
}
}

void ggml_compute_forward_conv_transpose_2d(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_conv_transpose_2d_impl<ggml_fp16_t>(params, dst);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_conv_transpose_2d_impl<float>(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_conv_2d_dw

struct ggml_conv_2d_dw_params {
Expand Down
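With the `src0->type` dispatch above, the CPU path now accepts both F16 and F32 kernels end to end. A minimal usage sketch (assuming an initialized `ggml_context * ctx` and the usual graph-compute plumbing; shapes are illustrative):

    // Build a conv_transpose_2d node with an F32 kernel, which the old
    // GGML_ASSERT(src0->type == GGML_TYPE_F16) would have rejected.
    ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 10, 10, 3, 1); // W x H x Cin x N
    ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,  3,  3, 4, 3); // Kw x Kh x Cout x Cin
    ggml_tensor * out    = ggml_conv_transpose_2d_p0(ctx, kernel, input, /*stride=*/2);
    // out should come back as F32, 21 x 21 x 4 x 1 (W = (10 - 1)*2 + 3 = 21)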
66 changes: 45 additions & 21 deletions ggml/src/ggml-cuda/conv2d-transpose.cu
@@ -1,12 +1,20 @@
 #include <algorithm>

 #include "conv2d-transpose.cuh"
 #include "ggml.h"

-__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
-                                        float * __restrict__ output, const int in_w, const int in_h, const int out_w,
-                                        const int out_h, const int kernel_w, const int kernel_h, const int stride,
-                                        const int c_in, const int c_out, const int batches) {
+#include "convert.cuh"
+
+template <typename kernel_t>
+static __global__ void conv2d_transpose_kernel(const float * __restrict__ input,
+                                               const kernel_t * __restrict__ kernel,
+                                               float * __restrict__ output,
+                                               const int in_w,
+                                               const int in_h,
+                                               const int out_w,
+                                               const int out_h,
+                                               const int kernel_w,
+                                               const int kernel_h,
+                                               const int stride,
+                                               const int c_in,
+                                               const int c_out,
+                                               const int batches) {
     const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;

     const int total_elements = out_w * out_h * c_out * batches;
@@ -26,24 +34,32 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
     for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
         for (int kh = 0; kh < kernel_h; ++kh) {
             int in_y = out_y_idx - kh;
-            if (in_y < 0 || in_y % stride) continue;
+            if (in_y < 0 || in_y % stride) {
+                continue;
+            }
             in_y /= stride;
-            if (in_y >= in_h) continue;
+            if (in_y >= in_h) {
+                continue;
+            }

             for (int kw = 0; kw < kernel_w; ++kw) {
                 int in_x = out_x_idx - kw;
-                if (in_x < 0 || in_x % stride) continue;
+                if (in_x < 0 || in_x % stride) {
+                    continue;
+                }
                 in_x /= stride;
-                if (in_x >= in_w) continue;
+                if (in_x >= in_w) {
+                    continue;
+                }

                 const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
                 const int kernel_idx =
                     (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;

-                float input_val = input[input_idx];
-                half  kern_val  = kernel[kernel_idx];
+                float    input_val = input[input_idx];
+                kernel_t kern_val  = kernel[kernel_idx];

-                accumulator += input_val * (float) kern_val;
+                accumulator += input_val * ggml_cuda_cast<float>(kern_val);
             }
         }
     }
@@ -56,11 +72,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
     const ggml_tensor * kernel = dst->src[0];
     const ggml_tensor * input  = dst->src[1];

-    GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+    GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

     const float * input_data  = (const float *) input->data;
     float *       output_data = (float *) dst->data;
-    const half *  kernel_data = (const half *) kernel->data;
+    const void *  kernel_data = kernel->data;

     const int input_w = input->ne[0];
     const int input_h = input->ne[1];
@@ -82,10 +99,17 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
     GGML_ASSERT(ggml_is_contiguous(kernel));
     GGML_ASSERT(ggml_is_contiguous(dst));

-    const int total  = (output_w * output_h * channels_out * batches);
+    const int total  = output_w * output_h * channels_out * batches;
     const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;

-    conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
-        input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
-        channels_in, channels_out, batches);
+    if (kernel->type == GGML_TYPE_F16) {
+        conv2d_transpose_kernel<half><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
+            input_data, (const half *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w,
+            kernel_h, stride, channels_in, channels_out, batches);
+
+    } else {
+        conv2d_transpose_kernel<float><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
+            input_data, (const float *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w,
+            kernel_h, stride, channels_in, channels_out, batches);
+    }
 }
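The launch configuration itself is unchanged by this patch: one thread per output element, rounded up to whole blocks of CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE; only the template instantiation now differs by kernel type. A worked example of the grid arithmetic (values illustrative, not from this PR):

    // For a 21 x 21 x 4 output with one batch:
    const int total  = 21 * 21 * 4 * 1;         // 1764 output elements
    const int blocks = (total + 256 - 1) / 256; // ceil(1764/256) = 7 blocks, 1792 threads
    // the 28 surplus threads fall outside total_elements and do no work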
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/conv2d-transpose.cuh
@@ -1,4 +1,5 @@
 #include "common.cuh"
+
 #define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256

 void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
31 changes: 20 additions & 11 deletions tests/test-backend-ops.cpp
@@ -4487,28 +4487,33 @@ struct test_conv_transpose_1d : public test_case {

 // GGML_OP_CONV_TRANSPOSE_2D
 struct test_conv_transpose_2d : public test_case {
+    // Dimensions
     const std::array<int64_t, 4> ne_input;
     const std::array<int64_t, 4> ne_kernel;
     const int stride;
+    // Types
+    const ggml_type kernel_type;

     std::string vars() override {
-        return VARS_TO_STR3(ne_input, ne_kernel, stride);
+        return VARS_TO_STR4(kernel_type, ne_input, ne_kernel, stride);
     }

     double max_nmse_err() override {
         return 5e-4; // The default 1e-7 is too small for Vulkan.
     }

-    test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
-                           std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
-                           int stride = 1)
-        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}
+    test_conv_transpose_2d(
+            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+            std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
+            int stride = 1,
+            ggml_type kernel_type = GGML_TYPE_F16
+    ) : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), kernel_type(kernel_type) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
         ggml_set_name(input, "input");

-        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
+        ggml_tensor * kernel = ggml_new_tensor(ctx, kernel_type, 4, ne_kernel.data());
         ggml_set_name(kernel, "kernel");

         ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
@@ -6841,8 +6846,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

-    test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
+    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
+    }

     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
@@ -7813,9 +7820,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));

-    test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
+    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
+    }

     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));