-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathext.cu
More file actions
20 lines (15 loc) · 861 Bytes
/
ext.cu
File metadata and controls
20 lines (15 loc) · 861 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <torch/extension.h>
// #include "float_grad.h"
// #include "vector_kernel_impl.h"
// Prototype only: the definition of test_floatgrad() is not in this file, so it
// must come from another translation unit linked into this extension.
// NOTE(review): the meaning of the int result is not established here —
// confirm against the definition.
int test_floatgrad();

// Prototypes for bindings that are currently disabled (kept for reference).
// torch::Tensor matmul_cuda(torch::Tensor A, torch::Tensor B);
// torch::Tensor matmul_cuda_jvp(torch::Tensor A, torch::Tensor B);
// template <typename FloatType, int len>
// torch::Tensor float_dot_cuda(torch::Tensor A, torch::Tensor B);

// Python module definition. TORCH_EXTENSION_NAME is not defined in this file;
// it is expected to be supplied as a compiler define by the
// torch.utils.cpp_extension build.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    // The only binding active at present; exposed to Python with no arguments.
    mod.def("test_floatgrad", test_floatgrad, "Test FloatGrad functionality");

    // Disabled bindings. The float2_dot_cuda_jvp variant names FloatGrad, which
    // is not declared in this file; presumably it comes from the commented-out
    // "float_grad.h" include, which would need to be restored as well — verify.
    // mod.def("matmul_cuda", &matmul_cuda, "Matrix multiplication (CUDA)");
    // mod.def("matmul_cuda_jvp", &matmul_cuda_jvp, "Matrix multiplication (CUDA) with gradient propagation");
    // mod.def("float2_dot_cuda", &float_dot_cuda<float, 2>, "Float2 dot product (CUDA)");
    // mod.def("float2_dot_cuda_jvp", &float_dot_cuda<FloatGrad, 2>, "Float2 dot product with JVP (CUDA)");
}