-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6d7f6ae
commit 642012c
Showing
46 changed files
with
8,185 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
[flake8] | ||
max-line-length = 79 | ||
extend-ignore = E203 | ||
exclude = | ||
.venv,build,dist,docs,examples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import csprng, fhe, utils |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .csprng import Csprng |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#include <torch/extension.h> | ||
#include <vector> | ||
|
||
// chacha20 mutates its inputs: the counter words of every input state are
// advanced in place, while the generated blocks are returned as new tensors.
|
||
// Input validation helpers. Each macro raises through TORCH_CHECK with a
// message that names the offending argument.
// The argument is parenthesized so that arbitrary expressions can be passed.
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_LONG(x) TORCH_CHECK((x).dtype() == torch::kInt64, #x " must be a kInt64 tensor")
// Wrapped in do { } while (0) so the three checks behave as a single statement,
// e.g. under an unbraced `if`.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_LONG(x); } while (0)
|
||
|
||
// Forward declaration. | ||
void chacha20_cuda(torch::Tensor input, torch::Tensor dest, size_t step); | ||
|
||
std::vector<torch::Tensor> chacha20(std::vector<torch::Tensor> inputs, size_t step) { | ||
// The input must be a contiguous long tensor of size 16 x N. | ||
// Also, the tensor must be contiguous to enable pointer arithmetic, | ||
// and must be stored in a cuda device. | ||
// Note that the input is a vector of inputs in different devices. | ||
|
||
std::vector<torch::Tensor> outputs; | ||
|
||
for (auto &input : inputs){ | ||
CHECK_INPUT(input); | ||
|
||
// Prepare an output. | ||
auto dest = input.clone(); | ||
|
||
// Run in cuda. | ||
chacha20_cuda(input, dest, step); | ||
|
||
// Store to the dest. | ||
outputs.push_back(dest); | ||
} | ||
|
||
return outputs; | ||
} | ||
|
||
// Python bindings: expose the batched ChaCha20 entry point as `chacha20`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    ext.def(
        "chacha20",
        &chacha20,
        "CHACHA20 (CUDA)");
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#include <torch/extension.h> | ||
#include <c10/cuda/CUDAStream.h> | ||
#include <cuda.h> | ||
#include <cuda_runtime.h> | ||
|
||
#include "chacha20_cuda_kernel.h" | ||
|
||
// Threads per block; also the number of per-thread rows in the kernel's
// static shared-memory scratch array.
#define BLOCK_SIZE 256

// ChaCha20 block function, one thread per 16-word state.
//
// Layout:  `input` and `dest` are N x 16 int64 tensors (one state per row).
//          Words are kept within 32 bits by masking with MASK.
// Launch:  1-D grid of 1-D blocks with blockDim.x <= BLOCK_SIZE. The grid may
//          over-cover N; threads past the last row exit immediately.
// Shared:  BLOCK_SIZE * 16 * sizeof(int64_t) bytes, statically allocated.
//          Every thread touches only its own row, so no barrier is needed.
// Effects: row `index` of `dest` is replaced by (initial + permuted) & MASK,
//          and words 12/13 of the matching `input` row are advanced by `step`.
__global__ void chacha20_cuda_kernel(
    torch::PackedTensorAccessor32<int64_t, 2> input,
    torch::PackedTensorAccessor32<int64_t, 2> dest,
    size_t step) {

    // input is configured as N x 16
    const int index = blockIdx.x * blockDim.x + threadIdx.x;

    // Tail guard: the grid is rounded up to a whole number of blocks.
    if (index >= dest.size(0)) {
        return;
    }

    __shared__ int64_t x[BLOCK_SIZE][16];

    // Load the initial state into this thread's scratch row.
    #pragma unroll
    for(int i=0; i<16; ++i){
        x[threadIdx.x][i] = dest[index][i];
    }

    // Repeat 10 times for chacha20 (each ONE_ROUND is a column + diagonal pass).
    #pragma unroll
    for(int i=0; i<10; ++i){
        ONE_ROUND(x, threadIdx.x);
    }

    // Feed-forward: add the initial state to the permuted state, word by word.
    #pragma unroll
    for(int i=0; i<16; ++i){
        dest[index][i] = (dest[index][i] + x[threadIdx.x][i]) & MASK;
    }

    // Step the state: add `step` to word 12 and carry the overflow above
    // 32 bits into word 13.
    input[index][12] += step;
    input[index][13] += (input[index][12] >> 32);
    input[index][12] &= MASK;
}

// Host-side launcher.
//
// `input` and `dest` are expected to be 2-D int64 CUDA tensors with the same
// number of rows (N x 16), on the same device. The kernel is enqueued on the
// current CUDA stream of `input`'s device; no synchronization is performed.
// Any N is supported: the grid is rounded up and the kernel guards the tail.
//
// Note: the current CUDA device is switched to `input`'s device and is not
// restored afterwards.
void chacha20_cuda(torch::Tensor input, torch::Tensor dest, size_t step){

    // Building the accessors first validates dtype and rank.
    auto input_acc = input.packed_accessor32<int64_t, 2>();
    auto dest_acc = dest.packed_accessor32<int64_t, 2>();

    const int64_t num_states = input.size(0);
    TORCH_CHECK(dest.size(0) == num_states,
                "input and dest must have the same number of rows");

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (num_states == 0) {
        return;
    }

    // One thread per state, rounded up to a whole number of blocks.
    const int dim_block = BLOCK_SIZE;
    const int dim_grid =
        static_cast<int>((num_states + dim_block - 1) / dim_block);

    // Retrieve the device index, then set the corresponding device and stream.
    auto device_id = input.device().index();
    cudaError_t status = cudaSetDevice(device_id);
    TORCH_CHECK(status == cudaSuccess,
                "cudaSetDevice failed: ", cudaGetErrorString(status));
    auto stream = at::cuda::getCurrentCUDAStream(device_id);

    // Run the cuda kernel.
    chacha20_cuda_kernel<<<dim_grid, dim_block, 0, stream>>>(input_acc, dest_acc, step);

    // Launch-configuration errors are only reported through cudaGetLastError().
    status = cudaGetLastError();
    TORCH_CHECK(status == cudaSuccess,
                "chacha20_cuda_kernel launch failed: ", cudaGetErrorString(status));
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// ChaCha20 building blocks, written as macros over a 2-D array `x` where
// x[ind] is one 16-word state. Words are stored in an integer type wider than
// 32 bits and are masked back to 32 bits after every add/rotate.

// Keeps the low 32 bits of a word.
#define MASK 0xffffffff

// 32-bit left rotations of an lvalue `x`, in place. `x` must already be
// within 32 bits so that the right shift brings in only the wrapped-around bits.
#define ROLL16(x) ((x) = ((((x) << 16) | ((x) >> 16)) & MASK))
#define ROLL12(x) ((x) = ((((x) << 12) | ((x) >> 20)) & MASK))
#define ROLL8(x) ((x) = ((((x) << 8) | ((x) >> 24)) & MASK))
#define ROLL7(x) ((x) = ((((x) << 7) | ((x) >> 25)) & MASK))

// Quarter round on words a, b, c, d of state x[ind].
// Wrapped in do { } while (0) so the expansion is a single statement.
#define QR(x, ind, a, b, c, d)\
    do {\
        (x)[(ind)][(a)] += (x)[(ind)][(b)];\
        (x)[(ind)][(a)] &= MASK;\
        (x)[(ind)][(d)] ^= (x)[(ind)][(a)];\
        ROLL16((x)[(ind)][(d)]);\
        (x)[(ind)][(c)] += (x)[(ind)][(d)];\
        (x)[(ind)][(c)] &= MASK;\
        (x)[(ind)][(b)] ^= (x)[(ind)][(c)];\
        ROLL12((x)[(ind)][(b)]);\
        (x)[(ind)][(a)] += (x)[(ind)][(b)];\
        (x)[(ind)][(a)] &= MASK;\
        (x)[(ind)][(d)] ^= (x)[(ind)][(a)];\
        ROLL8((x)[(ind)][(d)]);\
        (x)[(ind)][(c)] += (x)[(ind)][(d)];\
        (x)[(ind)][(c)] &= MASK;\
        (x)[(ind)][(b)] ^= (x)[(ind)][(c)];\
        ROLL7((x)[(ind)][(b)]);\
    } while (0)

// One double round: four column quarter rounds followed by four diagonal ones.
#define ONE_ROUND(x, ind)\
    do {\
        QR(x, ind, 0, 4, 8, 12);\
        QR(x, ind, 1, 5, 9, 13);\
        QR(x, ind, 2, 6, 10, 14);\
        QR(x, ind, 3, 7, 11, 15);\
        QR(x, ind, 0, 5, 10, 15);\
        QR(x, ind, 1, 6, 11, 12);\
        QR(x, ind, 2, 7, 8, 13);\
        QR(x, ind, 3, 4, 9, 14);\
    } while (0)
Oops, something went wrong.