
Conversation

@AgainstEntropy

Also updated the test case in test-backend-ops. Since the F32 kernel type is not supported on the CPU backend, only GGML_TYPE_F16 is kept for now; GGML_TYPE_F32 can be uncommented back in the future.

…types and improve parameter handling

- Introduced a `conv2d_transpose_params` struct for better parameter management.
- Updated `conv2d_transpose_kernel` to be templated for different kernel types (float and half).
- Modified `ggml_cuda_conv_2d_transpose_p0` to handle both F16 and F32 kernel types.
- Enhanced test cases to validate functionality for both kernel types.
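A minimal sketch of the shape of these changes, for orientation (the field list follows the launch arguments quoted later in this thread; the total field and the kernel body are assumptions, not the PR's code):

// Hypothetical sketch of the parameter struct and the templated kernel described above.
struct conv2d_transpose_params {
    int input_w, input_h, output_w, output_h;
    int kernel_w, kernel_h, stride;
    int channels_in, channels_out, batches;
    int total; // total number of output elements, used as the launch bound
};

template <typename T> // half (GGML_TYPE_F16) or float (GGML_TYPE_F32)
static __global__ void conv2d_transpose_kernel(const float * __restrict__ input,
                                               const T * __restrict__ kernel,
                                               float * __restrict__ output,
                                               const conv2d_transpose_params p) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= p.total) {
        return;
    }
    // ... compute the output element at idx, upcasting each kernel element from T
    //     to float and accumulating in float ...
}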
…ernel types

- Updated `test_conv_transpose_2d` structure to improve parameter handling by reordering constructor arguments.
- Enhanced test case generation to iterate over kernel types, allowing for flexible testing of different configurations.
- Removed hardcoded kernel type instances in favor of a loop for better maintainability and scalability.
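A minimal sketch of that loop (the shapes and the constructor argument order here are illustrative assumptions):

// Hypothetical: build conv_transpose_2d test cases once per kernel type.
// GGML_TYPE_F32 stays commented out for now because the CPU backend only supports F16 kernels.
for (ggml_type kernel_type : { GGML_TYPE_F16 /*, GGML_TYPE_F32 */ }) {
    test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1, kernel_type));
    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
}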
@github-actions github-actions bot added labels: testing (Everything test related), Nvidia GPU (Issues specific to Nvidia GPUs), ggml (changes relating to the ggml tensor library for machine learning) on Nov 8, 2025
Collaborator

@am17an am17an left a comment


Does this PR make a difference to something? From what I understand, the kernel value is upcast into float before doing any accumulation (and accumulation is anyway in f32). So unless there are kernels around which don't fit into f16 I don't see a benefit to supporting this, especially when we don't support the f16 inputs yet (which incidentally might be more relevant than kernels being f32 as we could potentially do half2 multiplications)
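For reference, a hedged sketch of the paired f16 multiply the half2 remark alludes to (illustrative only, not code from this PR; it would only apply if both the input and the kernel were f16):

#include <cuda_fp16.h>

// Illustrative: two f16 multiplies per __hmul2, then upcast and accumulate in f32,
// matching the f32 accumulation of the existing kernels. Requires a GPU with
// native f16 arithmetic (compute capability >= 5.3).
static __device__ float dot_f16x2(const half2 * a, const half2 * b, const int n2) {
    float acc = 0.0f;
    for (int i = 0; i < n2; ++i) {
        const half2  p = __hmul2(a[i], b[i]);   // two f16 products at once
        const float2 f = __half22float2(p);     // upcast the pair before accumulating
        acc += f.x + f.y;
    }
    return acc;
}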

input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
channels_in, channels_out, batches);
if (kernel->type == GGML_TYPE_F16) {
    conv2d_transpose_cuda_f16(input_data, (const half *) kernel_data, output_data, params, st);
Collaborator


You don't need separate cuda_f16 and cuda_f32 functions here; you can dispatch straight to conv2d_transpose_cuda<type> and remove those two functions.
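A minimal sketch of the suggested direct dispatch (assuming this PR's templated conv2d_transpose_cuda<T> helper and the variable names from the snippet above):

// Hypothetical: dispatch on the template directly instead of going through _f16/_f32 wrappers.
if (kernel->type == GGML_TYPE_F16) {
    conv2d_transpose_cuda<half>(input_data, (const half *) kernel_data, output_data, params, st);
} else {
    conv2d_transpose_cuda<float>(input_data, (const float *) kernel_data, output_data, params, st);
}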

Author


I referred to conv2d.cu for the current dispatching style and thought it was a convention in llama.cpp 😂

template <typename T>
static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
    conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
}

static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
}

static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
}

if (kernel->type == GGML_TYPE_F16) {
    conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
} else {
    conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
}

Collaborator


It's not a convention, and I would prefer to avoid one-line functions.

const int total;
};

template <typename T>
Collaborator


Probably a better name for this would be kernel_t.
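For illustration (a sketch only; the signature is abbreviated and assumed):

// Hypothetical rename: kernel_t makes the template parameter's role explicit.
template <typename kernel_t>
static __global__ void conv2d_transpose_kernel(const float * __restrict__ input,
                                               const kernel_t * __restrict__ kernel,
                                               float * __restrict__ output,
                                               const conv2d_transpose_params p);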

@AgainstEntropy
Author

AgainstEntropy commented Nov 12, 2025

Does this PR make a difference to something? From what I understand, the kernel value is upcast into float before doing any accumulation (and accumulation is anyway in f32). So unless there are kernels around which don't fit into f16 I don't see a benefit to supporting this, especially when we don't support the f16 inputs yet (which incidentally might be more relevant than kernels being f32 as we could potentially do half2 multiplications)

So the motivations of this PR are:

  1. Currently, ggml_backend_cuda_device_supports_op always returns true for GGML_OP_CONV_TRANSPOSE_2D without checking the kernel type, which can cause crashes when the op is actually computed. This PR fixes that mismatch (see the sketch after this list).

    case GGML_OP_CONV_TRANSPOSE_2D:
    case GGML_OP_POOL_2D:
    case GGML_OP_ACC:
        return true;

  2. Some recent models are natively BF16, and using an F16 kernel can lead to overflows. An F32 kernel is safe here and can readily be used for precision verification.
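For illustration, the kind of check point 1 argues for could look like this inside ggml_backend_cuda_device_supports_op (a sketch only; it assumes src[0] is the kernel tensor and that only F16/F32 kernels are implemented):

    case GGML_OP_CONV_TRANSPOSE_2D:
        // hypothetical: only report support for kernel types the CUDA op actually implements
        return op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;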

@am17an
Collaborator

am17an commented Nov 13, 2025

So the motivations of this PR are:

  1. Currently, ggml_backend_cuda_device_supports_op always returns true for GGML_OP_CONV_TRANSPOSE_2D without checking the kernel type, which can cause crashes when the op is actually computed. This PR fixes that mismatch.

That's because it matches the CPU capabilities exactly

  1. Some recent models are natively BF16, and using an F16 kernel can lead to overflows. An F32 kernel is safe here and can readily be used for precision verification.

That would be a problem in a conversion to GGUF, not necessarily a problem to be solved here.

@am17an
Collaborator

am17an commented Nov 13, 2025

You should add the CPU version for the f32 kernel too; that way this PR makes more sense.
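For orientation, a naive reference of what an f32-kernel CPU path has to compute (a rough sketch under an assumed per-channel row-major layout; this is not the actual ggml CPU implementation and the function name is made up):

// Hypothetical reference: scatter-style transposed 2-D convolution with an f32 kernel,
// input layout [cin][ih][iw], kernel layout [cin][cout][kh][kw], output layout [cout][oh][ow].
static void conv_transpose_2d_f32_ref(const float * kernel, const float * input, float * output,
                                      int kw, int kh, int cin, int cout,
                                      int iw, int ih, int stride) {
    const int ow = (iw - 1) * stride + kw;
    const int oh = (ih - 1) * stride + kh;
    for (int i = 0; i < ow * oh * cout; ++i) {
        output[i] = 0.0f;
    }
    for (int ci = 0; ci < cin; ++ci) {
        for (int iy = 0; iy < ih; ++iy) {
            for (int ix = 0; ix < iw; ++ix) {
                const float x = input[(ci * ih + iy) * iw + ix];
                for (int co = 0; co < cout; ++co) {
                    for (int ky = 0; ky < kh; ++ky) {
                        for (int kx = 0; kx < kw; ++kx) {
                            const float k = kernel[((ci * cout + co) * kh + ky) * kw + kx];
                            output[(co * oh + iy * stride + ky) * ow + ix * stride + kx] += x * k;
                        }
                    }
                }
            }
        }
    }
}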

