diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..dfac036 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 79 +extend-ignore = E203 +exclude = + .venv,build,dist,docs,examples diff --git a/liberate/__init__.py b/liberate/__init__.py new file mode 100644 index 0000000..3d081ee --- /dev/null +++ b/liberate/__init__.py @@ -0,0 +1 @@ +from . import csprng, fhe, utils diff --git a/liberate/csprng/__init__.py b/liberate/csprng/__init__.py new file mode 100644 index 0000000..d576d02 --- /dev/null +++ b/liberate/csprng/__init__.py @@ -0,0 +1 @@ +from .csprng import Csprng diff --git a/liberate/csprng/chacha20.cpp b/liberate/csprng/chacha20.cpp new file mode 100644 index 0000000..815182c --- /dev/null +++ b/liberate/csprng/chacha20.cpp @@ -0,0 +1,43 @@ +#include +#include + +// chacha20 is a mutating function. +// That means the input is mutated and there's no need to return a value. + +// Check types. +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LONG(x) TORCH_CHECK(x.dtype() == torch::kInt64, #x, " must be a kInt64 tensor") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_LONG(x) + + +// Forward declaration. +void chacha20_cuda(torch::Tensor input, torch::Tensor dest, size_t step); + +std::vector chacha20(std::vector inputs, size_t step) { + // The input must be a contiguous long tensor of size 16 x N. + // Also, the tensor must be contiguous to enable pointer arithmetic, + // and must be stored in a cuda device. + // Note that the input is a vector of inputs in different devices. + + std::vector outputs; + + for (auto &input : inputs){ + CHECK_INPUT(input); + + // Prepare an output. + auto dest = input.clone(); + + // Run in cuda. + chacha20_cuda(input, dest, step); + + // Store to the dest. + outputs.push_back(dest); + } + + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("chacha20", &chacha20, "CHACHA20 (CUDA)"); +} diff --git a/liberate/csprng/chacha20_cuda_kernel.cu b/liberate/csprng/chacha20_cuda_kernel.cu new file mode 100644 index 0000000..429a894 --- /dev/null +++ b/liberate/csprng/chacha20_cuda_kernel.cu @@ -0,0 +1,69 @@ +#include +#include +#include +#include + +#include "chacha20_cuda_kernel.h" + +#define BLOCK_SIZE 256 + +__global__ void chacha20_cuda_kernel( + torch::PackedTensorAccessor32 input, + torch::PackedTensorAccessor32 dest, + size_t step) { + + // input is configured as N x 16 + const int index = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int64_t x[BLOCK_SIZE][16]; + + #pragma unroll + for(int i=0; i<16; ++i){ + x[threadIdx.x][i] = dest[index][i]; + + // Vectorized load. + // Not much beneficial for our case, though ... + //reinterpret_cast(x[threadIdx.x])[i] = + // *reinterpret_cast(&(dest[index][i])); + } + + // Repeat 10 times for chacha20. + #pragma unroll + for(int i=0; i<10; ++i){ + ONE_ROUND(x, threadIdx.x); + } + + #pragma unroll + for(int i=0; i<16; ++i){ + dest[index][i] = (dest[index][i] + x[threadIdx.x][i]) & MASK; + } + + // Step the state. + input[index][12] += step; + input[index][13] += (input[index][12] >> 32); + input[index][12] &= MASK; +} + + +// The wrapper. +void chacha20_cuda(torch::Tensor input, torch::Tensor dest, size_t step){ + + // Required number of blocks in a grid. + // Note that we do not use grids here, since the + // tensor we're dealing with must be chopped in 1-d. 
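For orientation, a minimal Python sketch of driving the compiled binding declared above; the import path and the all-zero state are assumptions for illustration, and csprng.py further down shows the real call site with a fully initialized state:

import torch
from liberate.csprng import chacha20_cuda  # built in-place by setup.py (assumed import path)

rows = 1024  # the 256-thread launch leaves a tail unprocessed unless rows % 256 == 0
states = torch.zeros(rows, 16, dtype=torch.int64, device="cuda:0")  # contiguous kInt64, rows x 16
# ... fill each row with the constants/key/counter/nonce layout, as Csprng.initialize_states does ...

out = chacha20_cuda.chacha20([states], rows)  # one output tensor per input; counters in `states` are stepped
# out[0] holds 16 masked 32-bit keystream words per row.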
+ // input is configured as 16 x N + // N must be a multitude of 1024. + + const int dim_block = BLOCK_SIZE; + int dim_grid = input.size(0) / dim_block; + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = input.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Run the cuda kernel. + auto input_acc = input.packed_accessor32(); + auto dest_acc = dest.packed_accessor32(); + chacha20_cuda_kernel<<>>(input_acc, dest_acc, step); +} diff --git a/liberate/csprng/chacha20_cuda_kernel.h b/liberate/csprng/chacha20_cuda_kernel.h new file mode 100644 index 0000000..2e445aa --- /dev/null +++ b/liberate/csprng/chacha20_cuda_kernel.h @@ -0,0 +1,34 @@ +#define MASK 0xffffffff + +#define ROLL16(x) x = (((x) << 16) | ((x) >> 16)) & MASK +#define ROLL12(x) x = (((x) << 12) | ((x) >> 20)) & MASK +#define ROLL8(x) x = (((x) << 8) | ((x) >> 24)) & MASK +#define ROLL7(x) x = (((x) << 7) | ((x) >> 25)) & MASK + +#define QR(x, ind, a, b, c, d)\ + x[ind][a] += x[ind][b];\ + x[ind][a] &= MASK;\ + x[ind][d] ^= x[ind][a];\ + ROLL16(x[ind][d]);\ + x[ind][c] += x[ind][d];\ + x[ind][c] &= MASK;\ + x[ind][b] ^= x[ind][c];\ + ROLL12(x[ind][b]);\ + x[ind][a] += x[ind][b];\ + x[ind][a] &= MASK;\ + x[ind][d] ^= x[ind][a];\ + ROLL8(x[ind][d]);\ + x[ind][c] += x[ind][d];\ + x[ind][c] &= MASK;\ + x[ind][b] ^= x[ind][c];\ + ROLL7(x[ind][b]) + +#define ONE_ROUND(x, ind)\ + QR(x, ind, 0, 4, 8, 12);\ + QR(x, ind, 1, 5, 9, 13);\ + QR(x, ind, 2, 6, 10, 14);\ + QR(x, ind, 3, 7, 11, 15);\ + QR(x, ind, 0, 5, 10, 15);\ + QR(x, ind, 1, 6, 11, 12);\ + QR(x, ind, 2, 7, 8, 13);\ + QR(x, ind, 3, 4, 9, 14) diff --git a/liberate/csprng/chacha20_naive.py b/liberate/csprng/chacha20_naive.py new file mode 100644 index 0000000..5632f48 --- /dev/null +++ b/liberate/csprng/chacha20_naive.py @@ -0,0 +1,246 @@ +import binascii +import math +import os + +import torch + +torch.backends.cudnn.benchmark = True + + +@torch.jit.script +def roll(x: torch.Tensor, s: int) -> None: + """ + x's dtype must be torch.int64. + We are kind of forced to do this because + 1. pytorch doesn't support unit32, and + 2. >> doesn't move the sign bit. + """ + mask = 0xFFFFFFFF + right_shift = 32 - s + down = (x & mask) >> right_shift + # x <<= s + # x |= down + x.__ilshift__(s).bitwise_or_(down).bitwise_and_(mask) + + +@torch.jit.script +def roll16(x: torch.Tensor) -> None: + mask = 0xFFFFFFFF + down = x >> 16 + x.__ilshift__(16).bitwise_or_(down).bitwise_and_(mask) + + +@torch.jit.script +def roll12(x: torch.Tensor) -> None: + mask = 0xFFFFFFFF + down = x >> 20 + x.__ilshift__(12).bitwise_or_(down).bitwise_and_(mask) + + +@torch.jit.script +def roll8(x: torch.Tensor) -> None: + mask = 0xFFFFFFFF + down = x >> 24 + x.__ilshift__(8).bitwise_or_(down).bitwise_and_(mask) + + +@torch.jit.script +def roll7(x: torch.Tensor) -> None: + mask = 0xFFFFFFFF + down = x >> 25 + x.__ilshift__(7).bitwise_or_(down).bitwise_and_(mask) + + +@torch.jit.script +def QR(x: torch.Tensor, a: int, b: int, c: int, d: int) -> None: + """ + The CHACHA quarter round. 
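The tensor operations below mirror the plain-integer quarter round; a throwaway cross-check sketch (not part of this module) can be written directly on Python ints:

MASK = 0xFFFFFFFF

def rotl32(v, s):
    # 32-bit left rotation on a plain Python int.
    return ((v << s) | (v >> (32 - s))) & MASK

def qr_int(x, a, b, c, d):
    # Reference ChaCha quarter round over a 16-element list of ints.
    x[a] = (x[a] + x[b]) & MASK; x[d] = rotl32(x[d] ^ x[a], 16)
    x[c] = (x[c] + x[d]) & MASK; x[b] = rotl32(x[b] ^ x[c], 12)
    x[a] = (x[a] + x[b]) & MASK; x[d] = rotl32(x[d] ^ x[a], 8)
    x[c] = (x[c] + x[d]) & MASK; x[b] = rotl32(x[b] ^ x[c], 7)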
+ """ + mask = 0xFFFFFFFF + + x[a].add_(x[b]) + x[a].bitwise_and_(mask) + x[d].bitwise_xor_(x[a]) + roll16(x[d]) + + x[c].add_(x[d]) + x[c].bitwise_and_(mask) + x[b].bitwise_xor_(x[c]) + roll12(x[b]) + + x[a].add_(x[b]) + x[a].bitwise_and_(mask) + x[d].bitwise_xor_(x[a]) + roll8(x[d]) + + x[c].add_(x[d]) + x[c].bitwise_and_(mask) + x[b].bitwise_xor_(x[c]) + roll7(x[b]) + + +@torch.jit.script +def one_round(x: torch.Tensor) -> None: + # Odd round. + QR(x, 0, 4, 8, 12) + QR(x, 1, 5, 9, 13) + QR(x, 2, 6, 10, 14) + QR(x, 3, 7, 11, 15) + # Even round. + QR(x, 0, 5, 10, 15) + QR(x, 1, 6, 11, 12) + QR(x, 2, 7, 8, 13) + QR(x, 3, 4, 9, 14) + + +@torch.jit.script +def increment_counter(state: torch.Tensor, inc: int) -> None: + state[12] += inc + state[13] += state[12] >> 32 + state[12] = state[12] & 0xFFFFFFFF + + +@torch.jit.script +def chacha20(state: torch.Tensor) -> torch.Tensor: + x = state.clone() + + for _ in range(10): + one_round(x) + + # Return the random bytes. + return (x + state) & 0xFFFFFFFF + + +class chacha20_naive: + def __init__( + self, size, seed=None, nonce=None, count_step=1, device="cuda:0" + ): + self.size = size + self.device = device + self.count_step = count_step + + # expand 32-byte k. + # This is 1634760805, 857760878, 2036477234, 1797285236. + str2ord = lambda s: sum([2 ** (i * 8) * c for i, c in enumerate(s)]) + self.nothing_up_my_sleeve = torch.tensor( + [ + str2ord(b"expa"), + str2ord(b"nd 3"), + str2ord(b"2-by"), + str2ord(b"te k"), + ], + device=device, + dtype=torch.int64, + ) + + # Prepare the state tensor. + self.state_size = (*self.size, 16) + self.state_buffer = torch.zeros( + 16, (math.prod(self.size)), dtype=torch.int64, device=self.device + ) + + # The ind counter. + self.ind = torch.arange( + 0, math.prod(self.size), self.count_step, device=device + ) + + # Increment is the number of indices. + self.inc = math.prod(self.size) + + self.initialize_state(seed, nonce) + + # Capture computation graph. + self.capture() + + def initialize_state(self, seed=None, nonce=None): + # Generate seed if necessary. + self.key(seed) + + # Generate nonce if necessary. + self.generate_nonce(nonce) + + # Zero out the state. + self.state_buffer.zero_() + + # Set the counter. + self.state_buffer[12, :] = self.ind + + # Set the expand 32-bye k + self.state_buffer[0:4, :] = self.nothing_up_my_sleeve[:, None] + + # Set the seed. + self.state_buffer[4:12, :] = self.seed[:, None] + + # Fill in nonce. + self.state_buffer[14:, :] = self.nonce[:, None] + + def key(self, seed=None): + # 256bits seed as a key. + if seed is None: + # 256bits key as a seed, + nbytes = 32 + part_bytes = 4 + n_keys = nbytes // part_bytes + hex2int = lambda x, nbytes: int(binascii.hexlify(x), 16) + self.seed = torch.tensor( + [ + hex2int(os.urandom(part_bytes), part_bytes) + for _ in range(n_keys) + ], + device=self.device, + dtype=torch.int64, + ) + else: + self.seed = torch.tensor( + seed, device=self.device, dtype=torch.int64 + ) + + def generate_nonce(self, nonce): + # nonce is 64bits. 
+ if nonce is None: + # 256bits key as a seed, + nbytes = 8 + part_bytes = 4 + n_keys = nbytes // part_bytes + hex2int = lambda x, nbytes: int(binascii.hexlify(x), 16) + self.nonce = torch.tensor( + [ + hex2int(os.urandom(part_bytes), part_bytes) + for _ in range(n_keys) + ], + device=self.device, + dtype=torch.int64, + ) + else: + self.nonce = torch.tensor( + nonce, device=self.device, dtype=torch.int64 + ) + + def capture(self, warmup_periods=3, fuser="fuser1"): + with torch.cuda.device(self.device): + # Reserve ample amount of excution cache. + torch.jit.set_fusion_strategy([("STATIC", 100)]) + + # Output buffer. + self.out = torch.zeros_like(self.state_buffer) + + # Warm up. + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s), torch.jit.fuser(fuser): + for _ in range(warmup_periods): + self.out = chacha20(self.state_buffer) + torch.cuda.current_stream().wait_stream(s) + + # Capture. + self.g = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.g): + self.out = chacha20(self.state_buffer) + + def step(self): + increment_counter(self.state_buffer, self.inc) + + def randbyte(self): + self.g.replay() + self.step() + return self.out.clone() diff --git a/liberate/csprng/csprng.py b/liberate/csprng/csprng.py new file mode 100644 index 0000000..77110da --- /dev/null +++ b/liberate/csprng/csprng.py @@ -0,0 +1,306 @@ +import binascii +import os + +import numpy as np +import torch + +from . import ( + chacha20_cuda, + discrete_gaussian_cuda, + randint_cuda, + randround_cuda, +) +from .discrete_gaussian_sampler import build_CDT_binary_search_tree + +torch.backends.cudnn.benchmark = True + + +class Csprng: + def __init__(self, num_coefs=2 ** 15, num_channels=[8], num_repeating_channels=2, + sigma=3.2, devices=None, seed=None, nonce=None): + """N is the length of the polynomial, and C is the number of RNS channels. + procure the maximum (at level zero, special multiplication) at initialization.""" + + # This CSPRNG class generates + # 1. num_coefs x (num_channels + num_repeating_channels) uniform distributed + # random numbers at max. num_channels can be reduced down according + # to the input q at the time of generation. + # The numbers generated ri are 0 <= ri < qi. + # Seeds in the repeated channels are the same, and hence in those + # channels, the generated numbers are the same across GPUs. + # Generation of the repeated random numbers is optional. + # The same function can be used to generate ranged random integers + # in a fixed range. Again, generation of the repeated numbers is optional. + # 2. Generation of Discrete Gaussian random numbers. The numbers can be generated + # in the non-repeating channels (with the maximum number of channels num_channels), + # or in the repeating channels (with the maximum number of + # channels num_repeating_channels, where in the most typical scenario is same as 1). + + self.num_coefs = num_coefs + self.num_channels = num_channels + self.num_repeating_channels = num_repeating_channels + self.sigma = sigma + + # Set up GPUs. + # By default, use all the available GPUs on the system. + if devices is None: + gpu_count = torch.cuda.device_count() + self.devices = [f'cuda:{i}' for i in range(gpu_count)] + else: + self.devices = devices + + self.num_devices = len(self.devices) + + # Compute shares of channels per a GPU. + if len(self.num_channels) == 1: + # Allocate the same number of channels per every GPU. 
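A worked illustration of the share rule that the assignment just below implements; the parameter values are assumptions:

num_channels, devices = [8], ["cuda:0", "cuda:1"]
# A single entry is broadcast to every device; a per-device list is taken verbatim.
shares = num_channels * len(devices) if len(num_channels) == 1 else num_channels
assert shares == [8, 8]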
+ self.shares = [self.num_channels[0]] * self.num_devices + elif len(self.num_channels) == self.num_devices: + self.shares = self.num_channels + else: + # User input was contradicting. + raise Exception("There was a contradicting mismatch between " + "num_channels, and devices.") + + # How many channels in total? + self.total_num_channels = sum(self.shares) + + # We generate random bytes 4x4 = 16 per an array and hence, + # internally only need to procure N // 4 length arrays. + # Out of the 16, we generate discrete gaussian or uniform + # samples 4 at a time. + self.L = self.num_coefs // 4 + + # We build binary search tree for discrete gaussian here. + self.btree, self.btree_ptr, self.btree_size, self.tree_depth = \ + build_CDT_binary_search_tree(security_bits=128, sigma=sigma) + + # Counter range at each GPU. + # Note that the counters only include the non-repeating channels. + # We can add later counters at the end that start from the + # self.total_num_channels * self.L + self.start_ind = [0] + [s * self.L for s in self.shares[:-1]] + self.ind_increments = [s * self.L for s in self.shares] + self.end_ind = [s + e for s, e in zip(self.start_ind, self.ind_increments)] + + # Total increment to add to counters after each random bytes generation. + self.inc = (self.total_num_channels + self.num_repeating_channels) * self.L + self.repeating_start = self.total_num_channels * self.L + + # expand 32-byte k. + # This is 1634760805, 857760878, 2036477234, 1797285236. + str2ord = lambda s: sum([2 ** (i * 8) * c for i, c in enumerate(s)]) + self.nothing_up_my_sleeve = [] + for device in self.devices: + str_constant = torch.tensor( + [ + str2ord(b'expa'), str2ord(b'nd 3'), str2ord(b'2-by'), str2ord(b'te k') + ], device=device, dtype=torch.int64 + ) + self.nothing_up_my_sleeve.append(str_constant) + + # Prepare the state tensors. + self.states = [] + for dev_id in range(self.num_devices): + state_size = ( + (self.shares[dev_id] + self.num_repeating_channels) * self.L, 16) + state = torch.zeros( + state_size, + dtype=torch.int64, + device=self.devices[dev_id] + ) + self.states.append(state) + + # Prepare a channeled views. + self.channeled_states = [ + self.states[i].view( + self.shares[i] + self.num_repeating_channels, + self.L, -1) for i in range(self.num_devices) + ] + + # The counter. + self.counters = [] + repeating_counter = list(range(self.repeating_start, self.inc)) + for dev_id in range(self.num_devices): + counter = list(range( + self.start_ind[dev_id], self.end_ind[dev_id])) + repeating_counter + + counter_tensor = torch.tensor(counter, + dtype=torch.int64, + device=self.devices[dev_id]) + self.counters.append(counter_tensor) + + self.refresh(seed, nonce) + + def refresh(self, seed=None, nonce=None): + # Generate seed if necessary. + self.key = self.generate_key(seed) + + # Generate nonce if necessary. + self.nonce = self.generate_nonce(nonce) + + # Iterate over all devices. + for dev_id in range(self.num_devices): + self.initialize_states(dev_id, seed, nonce) + + def initialize_states(self, dev_id, seed=None, nonce=None): + + state = self.states[dev_id] + state.zero_() + + # Set the counter. + # It is hardly unlikely we will use CxL > 2**32. + # Just fill in the 12th element + # (The lower bytes of the counter). + state[:, 12] = self.counters[dev_id][None, :] + + # Set the expand 32-bye k + state[:, 0:4] = self.nothing_up_my_sleeve[dev_id][None, :] + + # Set the seed. + state[:, 4:12] = self.key[dev_id][None, :] + + # Fill in nonce. 
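To make the counter bookkeeping above concrete, a small worked example under assumed parameters (two GPUs, otherwise the class defaults):

L = 2 ** 15 // 4                            # 8192 ChaCha blocks per channel
shares, repeating = [8, 8], 2
start_ind = [0, shares[0] * L]              # [0, 65536]
end_ind = [shares[0] * L, sum(shares) * L]  # [65536, 131072]
repeating_start = sum(shares) * L           # 131072: repeated channels reuse these counters on every GPU
inc = (sum(shares) + repeating) * L         # 147456: counter stride applied after each generation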
+ state[:, 14:] = self.nonce[dev_id][None, :] + + def generate_initial_bytes(self, nbytes, part_bytes=4, seed=None): + seeds = [] + if seed is None: + n_keys = nbytes // part_bytes + hex2int = lambda x, nbytes: int(binascii.hexlify(x), 16) + seed0 = [ + hex2int(os.urandom(part_bytes), part_bytes) for _ in range(n_keys) + ] + for dev_id in range(self.num_devices): + cuda_seed = torch.tensor( + seed0, + dtype=torch.int64, + device=self.devices[dev_id] + ) + seeds.append(cuda_seed) + else: + seed0 = seed + for dev_id in range(self.num_devices): + cuda_seed = torch.tensor( + seed0, + dtype=torch.int64, + device=self.devices[dev_id] + ) + seeds.append(cuda_seed) + return seeds + + def generate_key(self, seed): + # 256bits seed as a key. + # We generate the same key seed for every GPU. + # Randomity is produced by counters, not the key. + return self.generate_initial_bytes(32, seed=None) + + def generate_nonce(self, seed): + # nonce is 64bits. + return self.generate_initial_bytes(8, seed=None) + + def randbytes(self, shares=None, repeats=0, length=None, reshape=False): + # Generates (shares_i + repeats) X length random bytes. + if shares is None: + shares = self.shares + + if length is None: + L = self.L + else: + L = length + + # Set the target states. + target_states = [] + for devi in range(self.num_devices): + start_channel = self.shares[devi] - shares[devi] + end_channel = self.shares[devi] + repeats + device_states = self.channeled_states[devi][ + start_channel:end_channel, :L, :] + target_states.append(device_states.view(-1, 16)) + + # Derive random bytes. + random_bytes = chacha20_cuda.chacha20(target_states, self.inc) + + # If not reshape, flatten. + if reshape: + random_bytes = [rb.view(-1, L, 16) for rb in random_bytes] + + return random_bytes + + def randint(self, amax=3, shift=0, repeats=0, length=None): + # The default values are for generating the same uniform ternary + # arrays in all GPUs. + + if not isinstance(amax, (list, tuple)): + amax = [[amax] for share in self.shares] + + if length is None: + L = self.L + else: + L = length + + # Calculate shares. + # If repeats are greater than 0, those channels are + # subtracted from shares. + shares = [len(am) - repeats for am in amax] + + # Convert the amax list to contiguous numpy array pointers. + q_conti = [np.ascontiguousarray(q, dtype=np.uint64) for q in amax] + q_ptr = [q.__array_interface__['data'][0] for q in q_conti] + + # Set the target states. + target_states = [] + for devi in range(self.num_devices): + start_channel = self.shares[devi] - shares[devi] + end_channel = self.shares[devi] + repeats + device_states = self.channeled_states[devi][ + start_channel:end_channel, :L, :] + target_states.append(device_states) + + # Generate the randint. + rand_int = randint_cuda.randint_fast( + target_states, q_ptr, shift, self.inc) + + return rand_int + + def discrete_gaussian(self, non_repeats=0, repeats=1, length=None): + + if not isinstance(non_repeats, (list, tuple)): + shares = [non_repeats] * self.num_devices + else: + shares = non_repeats + + if length is None: + L = self.L + else: + L = length + + # Set the target states. + target_states = [] + for devi in range(self.num_devices): + start_channel = self.shares[devi] - shares[devi] + end_channel = self.shares[devi] + repeats + device_states = self.channeled_states[devi][ + start_channel:end_channel, :L, :] + target_states.append(device_states.view(-1, 16)) + + # Generate the randint. 
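Before the fused kernel call below, a usage sketch of the public API defined above; the constructor arguments shown are the class defaults, and create_secret_key in ckks_engine.py later calls randint exactly this way:

from liberate.csprng import Csprng

rng = Csprng(num_coefs=2 ** 15, num_channels=[8], num_repeating_channels=2)
ternary = rng.randint(amax=3, shift=-1, repeats=1)  # uniform values in {-1, 0, 1}, shared across GPUs in the repeated channel
noise = rng.discrete_gaussian(repeats=1)            # CDT-sampled discrete gaussian, sigma = 3.2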
+ rand_int = discrete_gaussian_cuda.discrete_gaussian_fast(target_states, + self.btree_ptr, + self.btree_size, + self.tree_depth, + self.inc) + # Reformat the rand_int. + rand_int = [ri.view(-1, self.num_coefs) for ri in rand_int] + + return rand_int + + def randround(self, coef): + """Randomly round coef. Coef must be a double tensor. + coef must reside in the fist GPU in the GPUs list""" + + L = self.num_coefs // 16 + rand_bytes = chacha20_cuda.chacha20( + (self.states[0][:L],), self.inc)[0].ravel() + randround_cuda.randround([coef], [rand_bytes]) + return rand_bytes diff --git a/liberate/csprng/discrete_gaussian.cpp b/liberate/csprng/discrete_gaussian.cpp new file mode 100644 index 0000000..ee2d66f --- /dev/null +++ b/liberate/csprng/discrete_gaussian.cpp @@ -0,0 +1,69 @@ + +#include +#include + +// Check types. +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LONG(x) TORCH_CHECK(x.dtype() == torch::kInt64, #x, " must be a kInt64 tensor") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_LONG(x) + + +// Forward declaration. +void discrete_gaussian_cuda(torch::Tensor rand_bytes, + uint64_t* btree, + int btree_size, + int depth); + +torch::Tensor discrete_gaussian_fast_cuda(torch::Tensor states, + uint64_t* btree, + int btree_size, + int depth, + size_t step); + + + +// The main function. +//---------------- +// Normal version. +void discrete_gaussian(std::vector inputs, + size_t btree_ptr, + int btree_size, + int depth) { + + // reinterpret pointers from numpy. + uint64_t *btree = reinterpret_cast(btree_ptr); + + for (auto &rand_bytes : inputs){ + CHECK_INPUT(rand_bytes); + + // Run in cuda. + discrete_gaussian_cuda(rand_bytes, btree, btree_size, depth); + } +} + +//-------------- +// Fast version. + +std::vector discrete_gaussian_fast(std::vector states, + size_t btree_ptr, + int btree_size, + int depth, + size_t step) { + + // reinterpret pointers from numpy. + uint64_t *btree = reinterpret_cast(btree_ptr); + + std::vector outputs; + + for (auto &my_states : states){ + auto result = discrete_gaussian_fast_cuda(my_states, btree, btree_size, depth, step); + outputs.push_back(result); + } + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("discrete_gaussian", &discrete_gaussian, "discrete gaussian sampling (128 bits)."); + m.def("discrete_gaussian_fast", &discrete_gaussian_fast, "discrete gaussian sampling fast (chacha20 fused, 128 bits)."); +} diff --git a/liberate/csprng/discrete_gaussian_cuda_kernel.cu b/liberate/csprng/discrete_gaussian_cuda_kernel.cu new file mode 100644 index 0000000..118ce23 --- /dev/null +++ b/liberate/csprng/discrete_gaussian_cuda_kernel.cu @@ -0,0 +1,237 @@ +#include +#include +#include +#include + +#include "chacha20_cuda_kernel.h" + +#define GE(x_high, x_low, y_high, y_low)\ + (((x_high) > (y_high)) | (((x_high) == (y_high)) & ((x_low) >= (y_low)))) + +#define COMBINE_TWO(high, low)\ + ((static_cast(high) << 32) | static_cast(low)) + +// Use 256 BLOCK_SIZE. 64 * 4 = 256. +#define BLOCK_SIZE 64 + +#define LUT_SIZE 128 +__constant__ uint64_t LUT[LUT_SIZE]; + +/////////////////////////////////////////////////////////// +// The implementation + +//---------------------------------------------------------- +// The fast version, chacha20 fused. 
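The size_t the bindings above receive is produced on the Python side from a contiguous numpy array; a minimal sketch mirroring discrete_gaussian_sampler.py further down (the zero-filled tree is a stand-in):

import numpy as np

btree = np.zeros((31, 2), dtype=np.uint64)  # placeholder for the CDT binary search tree
btree_conti = np.ascontiguousarray(btree.T.ravel(), dtype=np.uint64)
btree_ptr = btree_conti.__array_interface__["data"][0]
# btree_conti must stay alive as long as btree_ptr is in use, since only the raw
# address crosses the pybind11 boundary and is reinterpret_cast back to uint64_t*.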
+//---------------------------------------------------------- + +__global__ void discrete_gaussian_fast_cuda_kernel( + torch::PackedTensorAccessor32 states, + torch::PackedTensorAccessor32 dst, + int btree_size, + int depth, + size_t step){ + + // Where am I? + const int thread_ind = threadIdx.x; + const int poly_order = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int64_t x[BLOCK_SIZE][16]; + + #pragma unroll + for(int i=0; i<16; ++i){ + x[thread_ind][i] = states[poly_order][i]; + } + + // Repeat 10 times for chacha20. + #pragma unroll + for(int i=0; i<10; ++i){ + ONE_ROUND(x, thread_ind); + } + + #pragma unroll + for(int i=0; i<16; ++i){ + x[thread_ind][i] = (states[poly_order][i] + x[thread_ind][i]) & MASK; + } + + // Step the state. + states[poly_order][12] += step; + states[poly_order][13] += (states[poly_order][12] >> 32); + states[poly_order][12] &= MASK; + + + // Discrete gaussian + for(int i=0; i<16; i+=4){ + // Traverse the tree in the LUT. + // Note that, out of the 16 32-bit randon numbers, + // we generate 4 discrete gaussian samples. + int jump = 1; + int current = 0; + int counter = 0; + + // Compose into 2 uint64 values, the 4 32 bit values stored in + // the 4 int64 storage. + uint64_t x_low = COMBINE_TWO(x[thread_ind][i], x[thread_ind][i+1]); + uint64_t x_high = COMBINE_TWO(x[thread_ind][i+2], x[thread_ind][i+3]); + + // Reserve a sign bit. + // Since we are dealing with the half plane, + // The CDT values in the LUT are at most 0.5, which means + // that the values are 127 bits. + // Also, rigorously speaking, we need to take out the MSB from the x_high + // value to take the sign, but every bit has probability of occurrence=0.5. + // Hence, it doesn't matter where we take the bit. + // For convenience, take the LSB of x_high. + int64_t sign_bit = x_high & 1; + x_high >>= 1; + + // Traverse the binary search tree. + for(int j=0; j(current); + + // Store the result. + const int new_poly_order = poly_order * 4 + i/4; + dst[new_poly_order] = sample; + } +} + + + + +//---------------------------------------------------------- +// The normal version. +//---------------------------------------------------------- + +// rand_bytes are configured as N x 16. +__global__ void discrete_gaussian_cuda_kernel( + torch::PackedTensorAccessor32 rand_bytes, + int btree_size, + int depth){ + // Where am I? + const int index = blockIdx.x * blockDim.x + threadIdx.x; + + // i is the index of the starting element at the threadIdx.y row. + const int i = threadIdx.y * 4; + + // Traverse the tree in the LUT. + // Note that, out of the 16 32-bit randon numbers, + // we generate 4 discrete gaussian samples. + int jump = 1; + int current = 0; + int counter = 0; + + // Compose into 2 uint64 values, the 4 32 bit values stored in + // the 4 int64 storage. + uint64_t x_low = COMBINE_TWO(rand_bytes[index][i], rand_bytes[index][i+1]); + uint64_t x_high = COMBINE_TWO(rand_bytes[index][i+2], rand_bytes[index][i+3]); + + // Reserve a sign bit. + // Since we are dealing with the half plane, + // The CDT values in the LUT are at most 0.5, which means + // that the values are 127 bits. + // Also, rigorously speaking, we need to take out the MSB from the x_high + // value to take the sign, but every bit has probability of occurrence=0.5. + // Hence, it doesn't matter where we take the bit. + // For convenience, take the LSB of x_high. + int64_t sign_bit = x_high & 1; + x_high >>= 1; + + // Traverse the binary search tree. + for(int j=0; j(current); + + // Store the result. 
+ rand_bytes[index][i] = sample; +} + + + +/////////////////////////////////////////////////////////// +// The wrapper. + +//---------------------------------------------------------- +// The fast version, chacha20 fused. +//---------------------------------------------------------- + +torch::Tensor discrete_gaussian_fast_cuda(torch::Tensor states, + uint64_t* btree, + int btree_size, + int depth, + size_t step){ + + // rand_bytes has the dim N x 16. + int dim_block = BLOCK_SIZE; + int dim_grid = states.size(0) / dim_block; + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = states.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Prepare the result. + // 16 elements in each state turn into 4 random numbers. + auto result = states.new_empty({states.size(0) * 4}); + + // Fill in the LUT constant memory. + cudaMemcpyToSymbol(LUT, btree, btree_size * 2 * sizeof(uint64_t)); + + // Run the cuda kernel. + auto states_acc = states.packed_accessor32(); + auto result_acc = result.packed_accessor32(); + discrete_gaussian_fast_cuda_kernel<<>>( + states_acc, result_acc, btree_size, depth, step); + + return result; +} + + +//---------------------------------------------------------- +// The normal version. +//---------------------------------------------------------- + +void discrete_gaussian_cuda(torch::Tensor rand_bytes, + uint64_t* btree, + int btree_size, + int depth){ + + // rand_bytes has the dim N x 16. + dim3 dim_block(BLOCK_SIZE, 4); + int dim_grid = rand_bytes.size(0) / dim_block.x; + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = rand_bytes.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Fill in the LUT constant memory. + cudaMemcpyToSymbol(LUT, btree, btree_size * 2 * sizeof(uint64_t)); + + // Run the cuda kernel. + auto access = rand_bytes.packed_accessor32(); + discrete_gaussian_cuda_kernel<<>>(access, btree_size, depth); +} diff --git a/liberate/csprng/discrete_gaussian_sampler.py b/liberate/csprng/discrete_gaussian_sampler.py new file mode 100644 index 0000000..3f43037 --- /dev/null +++ b/liberate/csprng/discrete_gaussian_sampler.py @@ -0,0 +1,156 @@ +# Sample from discrete gaussian distribution. +import math + +import mpmath as mpm +import numpy as np +import torch + +# import discrete_gaussian_cuda +from . import discrete_gaussian_cuda + + +def build_CDT_binary_search_tree(security_bits=128, sigma=3.2): + """Currently, ONLY the discrete gaussian sampling at the default input values + is supported. That is equivalent to the 128 security measure.""" + # Set accuracy to 258 bits = (128 * 2) bits. + # We will use higher 128 bits for the 128 security model, + # Hence, retaining the 256 bits are safe for intermediate calculations + # to carry over all the significant digits. + + mpm.mp.prec = security_bits * 2 + + # Truncation boundary. + # The minimum recommended by + # Shi Bai, Adeline Langlois, Tancr`ede Lepoint, Damien Stehl ́e, + # and Ron Ste- infeld. Improved security proofs in lattice-based cryptography: + # Using the r ́enyi divergence rather than the statistical distance., + # as tau = 6 sigma. + # We want the number tau to be the power of 2 in fact, since it makes the + # binary tree search constant time. 
Using a larger tau the the minumum required + # is no problem in fact and in terms of tree traversing it doesn't score a performance + # hit because the tree will be balanced as a result. + # So we calculate the smallest power of 2 bigger than the minimum tau as the number + # of sampling points. + sampling_power = math.ceil(math.log2(6 * sigma)) + num_sampling_points = 2 ** sampling_power + sampling_points = list(range(num_sampling_points)) + + # Calculate probabilities at sampling points. + # Be careful when converting the python float to mpmath float. + # No mormalization is done and the mpmath tries to retain the + # bit pattern of the original float. + # As a result, when you do mpm.mpf(3.2), you get + # mpf('3.20000000000000017763568394002504646778106689453125'). + # As a workaround, we can do mpm.mpf('3.2') to get + # mpf('3.200000000000000000000000000000000000000000000000000000000000000000000000000007') + mp_sigma = mpm.mpf(str(sigma)) + mp_two = mpm.mpf("2") + S = mp_sigma * mpm.sqrt(mp_two * mpm.pi) + discrete_gaussian_prob = ( + lambda x: mpm.exp(-mpm.mpf(str(x)) ** 2 / (mp_two * mp_sigma ** 2)) / S + ) + gaussian_prob_at_sampling_points = [ + discrete_gaussian_prob(x) for x in sampling_points + ] + + # We need to halve the probability at 0. + # We need to take into account the effect of symmetry, and we have + # only calculated the probability for the half section. + gaussian_prob_at_sampling_points[0] /= 2 + + # Now, calculate the Cumulative Distribution Table. + CDT = [0] + for P in gaussian_prob_at_sampling_points: + CDT.append(CDT[-1] + P) + + # At this point, we should end up with CDT[-1] which is very close to 0.5. + # This makes sense because we are calculating the CDT of the half plane. + # That reduces the effective bits in the CDT to 127, not 128. + # This again makes sense because we need to reserve 1 bit for sign. + + # We need 128 bits integer representation of the CDT. + CDT = [int(x * mp_two ** mpm.mpf(str(security_bits))) for x in CDT] + + # Chop the numbers down in a series of 64 bit integers. + num_chops = security_bits // 64 + chopped_CDT = [] + mask = (1 << 64) - 1 + for chop in range(num_chops): + chopped_share = [(x >> (64 * chop)) & mask for x in CDT] + chopped_CDT.append(chopped_share) + + # Now we can put the chopped CDT into a numpy array. + # All the numbers in the lists are representable by int64s. + # We transpose the resulting array to make it configured as N x 2. + CDT_table = np.array(chopped_CDT, dtype=np.uint64).T + + # We want to search through this table efficiently. + # Build a binary tree. + # Note that the last leaf is the sampled values. + # The last leaf index will be calculated in-place at runtime, + # Thus ommitted. + tree_depth = sampling_power + CDT_binary_tree = [] + for depth in range(tree_depth): + num_nodes = 2 ** depth + node_index_step = num_sampling_points // num_nodes + first_node_index = num_sampling_points // num_nodes // 2 + node_indices = list( + range(first_node_index, num_sampling_points, node_index_step) + ) + CDT_binary_tree += node_indices + # Use 1D expanded binary tree. + # See https://en.wikipedia.org/wiki/Binary_tree#Arrays. + btree = CDT_table[CDT_binary_tree] + + # Return the CType pointer together with the array. + btree_size = btree.shape[0] + btree_conti = np.ascontiguousarray(btree.T.ravel(), dtype=np.uint64) + btree_ptr = btree_conti.__array_interface__["data"][0] + + # The returned tree has probably has 31 x 2 dimension. 
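A quick sanity check of that shape under the default sigma = 3.2 (illustrative arithmetic only):

import math

sampling_power = math.ceil(math.log2(6 * 3.2))           # 5
num_sampling_points = 2 ** sampling_power                # 32 sampling points
num_nodes = sum(2 ** d for d in range(sampling_power))   # 1 + 2 + 4 + 8 + 16 = 31
# Breadth-first node order: CDT index 16, then [8, 24], [4, 12, 20, 28], and so on.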
+ # The 31 for the number of nodes, and + # the 2 for (lower 64 bits, higher 64 bits). + return btree, btree_ptr, btree_size, tree_depth + + +def test_discrete_gaussian(N): + btree, depth = build_CDT_binary_search_tree() + + rand = np.random.randint(0, 2 ** 64, size=(N, 2), dtype=np.uint64) + + GE = lambda x_high, x_low, y_high, y_low: ( + ((x_high) > (y_high)) | (((x_high) == (y_high)) & ((x_low) >= (y_low))) + ) + result = [] + for r in rand: + jump = 1 + current = 0 + counter = 0 + + sign_bit = int(r[0]) & 1 + r_high = int(r[0]) >> 1 + r_low = r[1] + + for j in range(depth): + ge_flag = GE( + r_high, + r_low, + btree[counter + current, 1], + btree[counter + current, 0], + ) + # ge_flag = (r_high > btree[counter+current, 1]) + + # Update current location. + current = 2 * current + int(ge_flag) + + # Update counter. + counter += jump + + # Update jump + jump *= 2 + + sample = (sign_bit * 2 - 1) * current + result.append(sample) + + return result diff --git a/liberate/csprng/randint.cpp b/liberate/csprng/randint.cpp new file mode 100644 index 0000000..f6ac564 --- /dev/null +++ b/liberate/csprng/randint.cpp @@ -0,0 +1,55 @@ +#include +#include + +// Check types. +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LONG(x) TORCH_CHECK(x.dtype() == torch::kInt64, #x, " must be a kInt64 tensor") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_LONG(x) + + +// Forward declaration. +void randint_cuda(torch::Tensor rand_bytes, uint64_t *q); +torch::Tensor randint_fast_cuda(torch::Tensor states, uint64_t *q, int64_t shift, size_t step); + + + +// The main function. +//---------------- +// Normal version. +void randint(std::vector inputs, + std::vector q_ptrs) { + + for (auto i=0; i(q_ptrs[i]); + + // Run in cuda. + randint_cuda(inputs[i], q); + } +} + +//-------------- +// Fast version. +std::vector randint_fast( + std::vector states, + std::vector q_ptrs, + int64_t shift, + size_t step) { + + std::vector outputs; + + for (auto i=0; i(q_ptrs[i]); + auto result = randint_fast_cuda(states[i], q, shift, step); + outputs.push_back(result); + } + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("randint", &randint, "random integer sampling (0 to q, 64 bits)."); + m.def("randint_fast", &randint_fast, "random integer sampling (randint fused, 0 to q, 64 bits)."); +} diff --git a/liberate/csprng/randint_cuda_kernel.cu b/liberate/csprng/randint_cuda_kernel.cu new file mode 100644 index 0000000..4fbc74a --- /dev/null +++ b/liberate/csprng/randint_cuda_kernel.cu @@ -0,0 +1,215 @@ +#include +#include +#include +#include + +#include "chacha20_cuda_kernel.h" + +#define COMBINE_TWO(high, low)\ + ((static_cast(high) << 32) | static_cast(low)) + +#define BLOCK_SIZE 64 + +#define LUT_SIZE 128 +__constant__ uint64_t Q[LUT_SIZE]; + +/////////////////////////////////////////////////////////// +// The implementation + +//---------------------------------------------------------- +// The fast version, chacha20 fused. +//---------------------------------------------------------- + +__global__ void randint_fast_cuda_kernel( + torch::PackedTensorAccessor32 states, + torch::PackedTensorAccessor32 dst, + int64_t shift, + size_t step) { + + // Where am I? 
+ // blockDim -> BLOCK_SIZE_FAST + // gridDim -> num_rns_channels, num_poly_orders / BLOCK_DIM_FAST + const int thread_ind = threadIdx.x; + const int poly_order = blockIdx.y * blockDim.x + threadIdx.x; + const int rns_channel = blockIdx.x; + + __shared__ int64_t x[BLOCK_SIZE][16]; + + #pragma unroll + for(int i=0; i<16; ++i){ + x[thread_ind][i] = states[rns_channel][poly_order][i]; + } + + // Repeat 10 times for chacha20. + #pragma unroll + for(int i=0; i<10; ++i){ + ONE_ROUND(x, thread_ind); + } + + #pragma unroll + for(int i=0; i<16; ++i){ + x[thread_ind][i] = (states[rns_channel][poly_order][i] + x[thread_ind][i]) & MASK; + } + + // Step the state. + states[rns_channel][poly_order][12] += step; + states[rns_channel][poly_order][13] += (states[rns_channel][poly_order][12] >> 32); + states[rns_channel][poly_order][12] &= MASK; + + // Randint. + #pragma unroll + for(int i=0; i<16; i+=4){ + + // What is my Q? + auto p = Q[rns_channel]; + + // Compose into 2 uint64 values, the 4 32 bit values stored in + // the 4 int64 storage. + uint64_t x_low = COMBINE_TWO(x[thread_ind][i], x[thread_ind][i+1]); + + // Use CUDA integer intrinsics to calculate + // (x_low * p) >> 64. + // Refer to https://github.com/apple/swift/pull/39143 + auto alpha = __umul64hi(p, x_low); + + // We need to calculate carry. + auto pl = p & MASK; // 1-32 + auto ph = p >> 32; // 33-64 + //--------------------------------------- + auto xhh = x[thread_ind][i+2]; + auto xhl = x[thread_ind][i+3]; + //--------------------------------------- + auto plxhl = pl * xhl; // 65-128 + auto plxhh = pl * xhh; // 97-160 + auto phxhl = ph * xhl; // 97-160 + auto phxhh = ph * xhh; // 129-192 + //--------------------------------------- + auto carry = ((plxhl & MASK) + (alpha & MASK)) >> 32; + carry = (carry + + (plxhl >> 32) + + (alpha >> 32) + + (phxhl & MASK) + + (plxhh & MASK)) >> 32; + auto sample = (carry + + (phxhl >> 32) + + (plxhh >> 32) + phxhh); + + // Store the result. + // Don't forget the shift!!! + const int new_poly_order = poly_order * 4 + i/4; + dst[rns_channel][new_poly_order] = sample + shift; + } +} + + + +//---------------------------------------------------------- +// The normal version. +//---------------------------------------------------------- + + +// rand_bytes are configured as C x N x 16, where C denotes the q channels. +__global__ void randint_cuda_kernel(torch::PackedTensorAccessor32 rand_bytes){ + + // Where am I? + const int index = blockIdx.y * blockDim.x + threadIdx.x; + + // i is the index of the starting element at the threadIdx.y row. + const int i = threadIdx.y * 4; + + // What is my Q? + auto p = Q[blockIdx.x]; + + // Compose into 2 uint64 values, the 4 32 bit values stored in + // the 4 int64 storage. + uint64_t x_low = COMBINE_TWO(rand_bytes[blockIdx.x][index][i], rand_bytes[blockIdx.x][index][i+1]); + + // Use CUDA integer intrinsics to calculate + // (x_low * p) >> 64. + // Refer to https://github.com/apple/swift/pull/39143 + auto alpha = __umul64hi(p, x_low); + + // We need to calculate carry. 
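The 32-bit-limb carry chain that follows (and the identical one in the fused kernel above) reproduces a single big-integer operation: mapping a uniform 128-bit value into [0, p) by a multiply-and-shift. A hedged Python cross-check:

def reduce_into_range(p, x_high, x_low):
    # floor(p * x / 2**128) for a 128-bit x; equals the __umul64hi + carry-chain result.
    x = (x_high << 64) | x_low
    return (p * x) >> 128

assert reduce_into_range(3, (1 << 64) - 1, (1 << 64) - 1) == 2  # never reaches p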
+ auto pl = p & MASK; // 1-32 + auto ph = p >> 32; // 33-64 + //--------------------------------------- + auto xhh = rand_bytes[blockIdx.x][index][i+2]; + auto xhl = rand_bytes[blockIdx.x][index][i+3]; + //--------------------------------------- + auto plxhl = pl * xhl; // 65-128 + auto plxhh = pl * xhh; // 97-160 + auto phxhl = ph * xhl; // 97-160 + auto phxhh = ph * xhh; // 129-192 + //--------------------------------------- + auto carry = ((plxhl & MASK) + (alpha & MASK)) >> 32; + carry = (carry + + (plxhl >> 32) + + (alpha >> 32) + + (phxhl & MASK) + + (plxhh & MASK)) >> 32; + auto sample = (carry + + (phxhl >> 32) + + (plxhh >> 32) + phxhh); + + // Store the result. + rand_bytes[blockIdx.x][index][i] = sample; +} + +/////////////////////////////////////////////////////////// +// The wrapper. + +//---------------------------------------------------------- +// The fast version, chacha20 fused. +//---------------------------------------------------------- + +torch::Tensor randint_fast_cuda(torch::Tensor states, uint64_t *q, int64_t shift, size_t step){ + + // rand_bytes has the dim C x N x 16. + int dim_block = BLOCK_SIZE; + dim3 dim_grid(states.size(0), states.size(1) / dim_block); + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = states.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Prepare the result. + // 16 elements in each state turn into 4 random numbers. + auto result = states.new_empty({states.size(0), states.size(1) * 4}); + + // Fill in the LUT constant memory. + cudaMemcpyToSymbol(Q, q, states.size(0) * sizeof(uint64_t)); + + // Run the cuda kernel. + auto states_acc = states.packed_accessor32(); + auto result_acc = result.packed_accessor32(); + randint_fast_cuda_kernel<<>>(states_acc, result_acc, shift, step); + + return result; +} + + + +//---------------------------------------------------------- +// The normal version. +//---------------------------------------------------------- + + +void randint_cuda(torch::Tensor rand_bytes, uint64_t* q){ + + // rand_bytes has the dim C x N x 16. + dim3 dim_block(BLOCK_SIZE, 4); + dim3 dim_grid(rand_bytes.size(0), rand_bytes.size(1) / dim_block.x); + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = rand_bytes.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Fill in the LUT constant memory. + cudaMemcpyToSymbol(Q, q, rand_bytes.size(0) * sizeof(uint64_t)); + + // Run the cuda kernel. + auto access = rand_bytes.packed_accessor32(); + randint_cuda_kernel<<>>(access); +} diff --git a/liberate/csprng/randround.cpp b/liberate/csprng/randround.cpp new file mode 100644 index 0000000..5af14d5 --- /dev/null +++ b/liberate/csprng/randround.cpp @@ -0,0 +1,32 @@ + +#include +#include + +// Check types. +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LONG(x) TORCH_CHECK(x.dtype() == torch::kInt64, #x, " must be a kInt64 tensor") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_LONG(x) + + +// Forward declaration. +void randround_cuda(torch::Tensor input, torch::Tensor rand_bytes); + +// The main function. +// rand_bytes are N 1D uint64_t tensors. +// Inputs are N 1D double tensors. +// The output will be returned in rand_bytes. 
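For reference, the rounding rule the kernel implements, written out in plain Python (a sketch of the behaviour, not the extension's API): each coefficient is rounded away from zero with probability equal to its fractional part, using one 32-bit uniform word.

import math

def randround_ref(coef, rand_word_32):
    sign = -1 if math.copysign(1.0, coef) < 0 else 1
    integ, frac = divmod(abs(coef), 1.0)
    round_up = rand_word_32 < int(round(frac * 2 ** 32))
    return sign * (int(integ) + int(round_up))

# Averaged over uniform rand_word_32 in [0, 2**32), randround_ref(2.25, .) gives 2.25.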
+void randround(std::vector inputs, + std::vector rand_bytes) { + + for (auto i=0; i +#include +#include +#include + +#define BLOCK_SIZE 256 + +__global__ void randint_cuda_kernel(torch::PackedTensorAccessor32 input, + torch::PackedTensorAccessor32 rand_bytes){ + + // Where am I? + const int index = blockIdx.x * blockDim.x + threadIdx.x; + + auto coef = input[index]; + auto sign_bit = signbit(coef); + auto abs_coef = fabs(coef); + + auto integ = floor(abs_coef); + auto frac = abs_coef - integ; + int64_t intinteg = static_cast(integ); + + // Convert a double to a signed 64-bit int in round-to-nearest-even mode. + // https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__INTRINSIC__CAST.html#group__CUDA__MATH__INTRINSIC__CAST + constexpr double rounder = static_cast(0x100000000); + int64_t ifrac = __double2ll_rn(frac * rounder); + + // Random round. + // The bool value must be 1 for True. + int64_t round = rand_bytes[index] < ifrac; + + // Round and recover sign. + int64_t sign = (sign_bit)? -1 : 1; + int64_t rounded = sign * (intinteg + round); + + // Put back into the rand_bytes. + rand_bytes[index] = rounded; +} + + +// The wrapper. +void randround_cuda(torch::Tensor input, torch::Tensor rand_bytes){ + + // rand_bytes has the dim C x N x 16. + const int dim_block = BLOCK_SIZE; + const int dim_grid = rand_bytes.size(0) / dim_block; + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = rand_bytes.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Run the cuda kernel. + auto input_access = input.packed_accessor32(); + auto rand_bytes_access = rand_bytes.packed_accessor32(); + randint_cuda_kernel<<>>(input_access, rand_bytes_access); +} diff --git a/liberate/csprng/setup.py b/liberate/csprng/setup.py new file mode 100644 index 0000000..13b1c92 --- /dev/null +++ b/liberate/csprng/setup.py @@ -0,0 +1,33 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +ext_modules = [ + CUDAExtension( + name="randint_cuda", + sources=["randint.cpp", "randint_cuda_kernel.cu"], + ), + CUDAExtension( + name="randround_cuda", + sources=["randround.cpp", "randround_cuda_kernel.cu"], + ), + CUDAExtension( + name="discrete_gaussian_cuda", + sources=["discrete_gaussian.cpp", "discrete_gaussian_cuda_kernel.cu"], + ), + CUDAExtension( + name="chacha20_cuda", + sources=["chacha20.cpp", "chacha20_cuda_kernel.cu"], + ), +] + +setup( + name="csprng", + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, + script_args=["build_ext"], + options={ + "build_ext": { + "inplace": True, + } + }, +) diff --git a/liberate/fhe/__init__.py b/liberate/fhe/__init__.py new file mode 100644 index 0000000..3ab2a1f --- /dev/null +++ b/liberate/fhe/__init__.py @@ -0,0 +1,4 @@ +from . import context, encdec +from .ckks_engine import ckks_engine +from .presets import params +from .cache import cache diff --git a/liberate/fhe/cache/__init__.py b/liberate/fhe/cache/__init__.py new file mode 100644 index 0000000..6f702bc --- /dev/null +++ b/liberate/fhe/cache/__init__.py @@ -0,0 +1 @@ +from . 
import cache diff --git a/liberate/fhe/cache/cache.py b/liberate/fhe/cache/cache.py new file mode 100644 index 0000000..1092baa --- /dev/null +++ b/liberate/fhe/cache/cache.py @@ -0,0 +1,32 @@ +import glob +import os +from ..context import generate_primes + +path_cache = os.path.abspath(__file__).replace("cache.py", "resources") + + +# logN_N_M = os.path.join(path_cache, "logN_N_M.pkl") +# message_special_primes = os.path.join(path_cache, "message_special_primes.pkl") +# scale_primes = os.path.join(path_cache, "scale_primes.pkl") + + +def clean_cache(path=None): + if path is None: + path = path_cache + files = glob.glob(os.path.join(path, "*.pkl")) + for file in files: + try: + os.unlink(file) + except Exception as e: + print(e) + pass + return + + +def generate_cache(path=None): + if path is None: + path = path_cache + # Read in pre-calculated high-quality primes. + _ = generate_primes.generate_message_primes(cache_folder=path) + _ = generate_primes.generate_scale_primes(cache_folder=path) + return diff --git a/liberate/fhe/cache/resources/logN_N_M.pkl b/liberate/fhe/cache/resources/logN_N_M.pkl new file mode 100644 index 0000000..444186d Binary files /dev/null and b/liberate/fhe/cache/resources/logN_N_M.pkl differ diff --git a/liberate/fhe/cache/resources/message_special_primes.pkl b/liberate/fhe/cache/resources/message_special_primes.pkl new file mode 100644 index 0000000..3e691fa Binary files /dev/null and b/liberate/fhe/cache/resources/message_special_primes.pkl differ diff --git a/liberate/fhe/cache/resources/scale_primes.pkl b/liberate/fhe/cache/resources/scale_primes.pkl new file mode 100644 index 0000000..88edd01 Binary files /dev/null and b/liberate/fhe/cache/resources/scale_primes.pkl differ diff --git a/liberate/fhe/ckks_engine.py b/liberate/fhe/ckks_engine.py new file mode 100644 index 0000000..7753ae0 --- /dev/null +++ b/liberate/fhe/ckks_engine.py @@ -0,0 +1,2615 @@ +import datetime +import gc +import math +import pickle +from hashlib import sha256 +from pathlib import Path + +import numpy as np +import torch + +# from context.ckks_context import ckks_context +from .context.ckks_context import ckks_context +from .data_struct import data_struct +from .encdec import decode, encode, rotate, conjugate +from .version import VERSION +from .presets import types, errors +from liberate.ntt import ntt_context +from liberate.ntt import ntt_cuda +from liberate.csprng import Csprng + + +class ckks_engine: + @errors.log_error + def __init__(self, devices: list[int] = None, verbose: bool = False, + bias_guard: bool = True, norm: str = 'forward', **ctx_params): + """ + buffer_bit_length=62, + scale_bits=40, + logN=15, + num_scales=None, + num_special_primes=2, + sigma=3.2, + uniform_tenary_secret=True, + cache_folder='cache/', + security_bits=128, + quantum='post_quantum', + distribution='uniform', + read_cache=True, + save_cache=True, + verbose=False + """ + + self.bias_guard = bias_guard + + self.norm = norm + + self.version = VERSION + + self.ctx = ckks_context(**ctx_params) + self.ntt = ntt_context(self.ctx, devices=devices, verbose=verbose) + + if self.bias_guard: + if self.ctx.num_special_primes < 2: + raise errors.NotEnoughPrimesForBiasGuard( + bias_guard=self.bias_guard, + num_special_primes=self.ctx.num_special_primes) + + self.num_levels = self.ntt.num_levels - 1 + + self.num_slots = self.ctx.N // 2 + + rng_repeats = max(self.ntt.num_special_primes, 2) + self.rng = Csprng(self.ntt.ctx.N, [len(di) for di in self.ntt.p.d], rng_repeats, devices=self.ntt.devices) + + self.int_scale = 
2 ** self.ctx.scale_bits + self.scale = np.float64(self.int_scale) + + qstr = ','.join([str(qi) for qi in self.ctx.q]) + hashstr = (self.ctx.generation_string + "_" + qstr).encode("utf-8") + self.hash = sha256(bytes(hashstr)).hexdigest() + + self.make_adjustments_and_corrections() + + self.device0 = self.ntt.devices[0] + + self.make_mont_PR() + + self.reserve_ksk_buffers() + + self.create_ksk_rescales() + + self.alloc_parts() + + self.leveled_devices() + + self.create_rescale_scales() + + self.galois_deltas = [2 ** i for i in range(self.ctx.logN - 1)] + + self.mult_dispatch_dict = { + (data_struct, data_struct): self.auto_cc_mult, + (list, data_struct): self.mc_mult, + (np.ndarray, data_struct): self.mc_mult, + (data_struct, np.ndarray): self.cm_mult, + (data_struct, list): self.cm_mult, + (float, data_struct): self.scalar_mult, + (data_struct, float): self.mult_scalar, + (int, data_struct): self.int_scalar_mult, + (data_struct, int): self.mult_int_scalar + } + + self.add_dispatch_dict = { + (data_struct, data_struct): self.auto_cc_add, + (list, data_struct): self.mc_add, + (np.ndarray, data_struct): self.mc_add, + (data_struct, np.ndarray): self.cm_add, + (data_struct, list): self.cm_add, + (float, data_struct): self.scalar_add, + (data_struct, float): self.add_scalar, + (int, data_struct): self.scalar_add, + (data_struct, int): self.add_scalar + } + + self.sub_dispatch_dict = { + (data_struct, data_struct): self.auto_cc_sub, + (list, data_struct): self.mc_sub, + (np.ndarray, data_struct): self.mc_sub, + (data_struct, np.ndarray): self.cm_sub, + (data_struct, list): self.cm_sub, + (float, data_struct): self.scalar_sub, + (data_struct, float): self.sub_scalar, + (int, data_struct): self.scalar_sub, + (data_struct, int): self.sub_scalar + } + + # ------------------------------------------------------------------------------------------- + # Various pre-calculations. 
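The operand-type dispatch tables registered above are keyed by (type(a), type(b)); a hedged sketch of how a generic multiply would route through them (the actual public wrapper lives outside this section):

def dispatched_mult(engine, a, b):
    # Handlers such as auto_cc_mult are bound methods, so they are called directly.
    handler = engine.mult_dispatch_dict[(type(a), type(b))]
    return handler(a, b)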
+ # ------------------------------------------------------------------------------------------- + def create_rescale_scales(self): + self.rescale_scales = [] + for level in range(self.num_levels): + self.rescale_scales.append([]) + + for device_id in range(self.ntt.num_devices): + dest_level = self.ntt.p.destination_arrays[level] + + if device_id < len(dest_level): + dest = dest_level[device_id] + rescaler_device_id = self.ntt.p.rescaler_loc[level] + m0 = self.ctx.q[level] + + if rescaler_device_id == device_id: + m = [self.ctx.q[i] for i in dest[1:]] + else: + m = [self.ctx.q[i] for i in dest] + + scales = [(pow(m0, -1, mi) * self.ctx.R) % mi for mi in m] + + scales = torch.tensor(scales, + dtype=self.ctx.torch_dtype, + device=self.ntt.devices[device_id]) + self.rescale_scales[level].append(scales) + + def leveled_devices(self): + self.len_devices = [] + for level in range(self.num_levels): + self.len_devices.append(len([[a] for a in self.ntt.p.p[level] if len(a) > 0])) + + self.neighbor_devices = [] + for level in range(self.num_levels): + self.neighbor_devices.append([]) + len_devices_at = self.len_devices[level] + available_devices_ids = range(len_devices_at) + for src_device_id in available_devices_ids: + neighbor_devices_at = [ + device_id for device_id in available_devices_ids if device_id != src_device_id + ] + self.neighbor_devices[level].append(neighbor_devices_at) + + def alloc_parts(self): + self.parts_alloc = [] + for level in range(self.num_levels): + num_parts = [len(parts) for parts in self.ntt.p.p[level]] + parts_alloc = [ + alloc[-num_parts[di] - 1:-1] for di, alloc in enumerate(self.ntt.p.part_allocations) + ] + self.parts_alloc.append(parts_alloc) + + self.stor_ids = [] + for level in range(self.num_levels): + self.stor_ids.append([]) + alloc = self.parts_alloc[level] + min_id = min([min(a) for a in alloc if len(a) > 0]) + for device_id in range(self.ntt.num_devices): + global_ids = self.parts_alloc[level][device_id] + new_ids = [i - min_id for i in global_ids] + self.stor_ids[level].append(new_ids) + + def create_ksk_rescales(self): + R = self.ctx.R + P = self.ctx.q[-self.ntt.num_special_primes:][::-1] + m = self.ctx.q + PiR = [[(pow(Pj, -1, mi) * R) % mi for mi in m[:-P_ind - 1]] for P_ind, Pj in enumerate(P)] + + self.PiRs = [] + + level = 0 + self.PiRs.append([]) + + for P_ind in range(self.ntt.num_special_primes): + self.PiRs[level].append([]) + + for device_id in range(self.ntt.num_devices): + dest = self.ntt.p.destination_arrays_with_special[level][device_id] + PiRi = [PiR[P_ind][i] for i in dest[:-P_ind - 1]] + PiRi = torch.tensor(PiRi, + device=self.ntt.devices[device_id], + dtype=self.ctx.torch_dtype) + self.PiRs[level][P_ind].append(PiRi) + + for level in range(1, self.num_levels): + self.PiRs.append([]) + + for P_ind in range(self.ntt.num_special_primes): + + self.PiRs[level].append([]) + + for device_id in range(self.ntt.num_devices): + start = self.ntt.starts[level][device_id] + PiRi = self.PiRs[0][P_ind][device_id][start:] + + self.PiRs[level][P_ind].append(PiRi) + + def reserve_ksk_buffers(self): + self.ksk_buffers = [] + for device_id in range(self.ntt.num_devices): + self.ksk_buffers.append([]) + for part_id in range(len(self.ntt.p.p[0][device_id])): + buffer = torch.empty( + [self.ntt.num_special_primes, self.ctx.N], + dtype=self.ctx.torch_dtype + ).pin_memory() + self.ksk_buffers[device_id].append(buffer) + + def make_mont_PR(self): + P = math.prod(self.ntt.ctx.q[-self.ntt.num_special_primes:]) + R = self.ctx.R + PR = P * R + self.mont_PR = [] + for 
device_id in range(self.ntt.num_devices): + dest = self.ntt.p.destination_arrays[0][device_id] + m = [self.ctx.q[i] for i in dest] + PRm = [PR % mi for mi in m] + PRm = torch.tensor(PRm, + device=self.ntt.devices[device_id], + dtype=self.ctx.torch_dtype) + self.mont_PR.append(PRm) + + def make_adjustments_and_corrections(self): + + self.alpha = [(self.scale / np.float64(q)) ** 2 for q in self.ctx.q[:self.ctx.num_scales]] + self.deviations = [1] + for al in self.alpha: + self.deviations.append(self.deviations[-1] ** 2 * al) + + self.final_q_ind = [da[0][0] for da in self.ntt.p.destination_arrays[:-1]] + self.final_q = [self.ctx.q[ind] for ind in self.final_q_ind] + self.final_alpha = [(self.scale / np.float64(q)) for q in self.final_q] + self.corrections = [1 / (d * fa) for d, fa in zip(self.deviations, self.final_alpha)] + + self.base_prime = self.ctx.q[self.ntt.p.base_prime_idx] + + self.final_scalar = [] + for qi, q in zip(self.final_q_ind, self.final_q): + scalar = (pow(q, -1, self.base_prime) * self.ctx.R) % self.base_prime + scalar = torch.tensor([scalar], + device=self.ntt.devices[0], + dtype=self.ctx.torch_dtype) + self.final_scalar.append(scalar) + + # ------------------------------------------------------------------------------------------- + # Example generation. + # ------------------------------------------------------------------------------------------- + + def absmax_error(self, x, y): + if type(x[0]) == np.complex128 and type(y[0]) == np.complex128: + r = np.abs(x.real - y.real).max() + np.abs(x.imag - y.imag).max() * 1j + else: + r = np.abs(np.array(x) - np.array(y)).max() + return r + + def integral_bits_available(self): + base_prime = self.base_prime + max_bits = math.floor(math.log2(base_prime)) + integral_bits = max_bits - self.ctx.scale_bits + return integral_bits + + @errors.log_error + def example(self, amin=None, amax=None, decimal_places: int = 10) -> np.array: + if amin is None: + amin = -(2 ** self.integral_bits_available()) + + if amax is None: + amax = 2 ** self.integral_bits_available() + + base = 10 ** decimal_places + a = np.random.randint(amin * base, amax * base, self.ctx.N // 2) / base + b = np.random.randint(amin * base, amax * base, self.ctx.N // 2) / base + + sample = a + b * 1j + + return sample + + # ------------------------------------------------------------------------------------------- + # Encode/Decode + # ------------------------------------------------------------------------------------------- + + def padding(self, m): + # m = m[:self.num_slots] + try: + m_len = len(m) + padding_result = np.pad(m, (0, self.num_slots - m_len), constant_values=(0, 0)) + except TypeError as e: + m_len = len([m]) + padding_result = np.pad([m], (0, self.num_slots - m_len), constant_values=(0, 0)) + except Exception as e: + raise Exception("[Error] encoding Padding Error.") + return padding_result + + @errors.log_error + def encode(self, m, level: int = 0, padding=True) -> list[torch.Tensor]: + """ + Encode a plain message m, using an encoding function. + Note that the encoded plain text is pre-permuted to yield cyclic rotation. 
+ """ + deviation = self.deviations[level] + if padding: + m = self.padding(m) + encoded = [encode(m, scale=self.scale, rng=self.rng, + device=self.device0, + deviation=deviation, norm=self.norm)] + + pt_buffer = self.ksk_buffers[0][0][0] + pt_buffer.copy_(encoded[-1]) + for dev_id in range(1, self.ntt.num_devices): + encoded.append(pt_buffer.cuda(self.ntt.devices[dev_id])) + return encoded + + @errors.log_error + def decode(self, m, level=0, is_real: bool = False) -> list: + """ + Base prime is located at -1 of the RNS channels in GPU0. + Assuming this is an orginary RNS deinclude_special. + """ + correction = self.corrections[level] + decoded = decode(m[0].squeeze(), scale=self.scale, correction=correction, norm=self.norm) + m = decoded[:self.ctx.N // 2].cpu().numpy() + if is_real: + m = m.real + return m + + # ------------------------------------------------------------------------------------------- + # secret key/public key generation. + # ------------------------------------------------------------------------------------------- + + @errors.log_error + def create_secret_key(self, include_special: bool = True) -> data_struct: + uniform_ternary = self.rng.randint(amax=3, shift=-1, repeats=1) + + mult_type = -2 if include_special else -1 + unsigned_ternary = self.ntt.tile_unsigned(uniform_ternary, lvl=0, mult_type=mult_type) + self.ntt.enter_ntt(unsigned_ternary, 0, mult_type) + + return data_struct( + data=unsigned_ternary, + include_special=include_special, + montgomery_state=True, + ntt_state=True, + origin=types.origins["sk"], + level=0, + hash=self.hash, + version=self.version + ) + + @errors.log_error + def create_public_key(self, sk: data_struct, include_special: bool = False, + a: list[torch.Tensor] = None) -> data_struct: + """ + Generates a public key against the secret key sk. + pk = -a * sk + e = e - a * sk + """ + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if include_special and not sk.include_special: + raise errors.SecretKeyNotIncludeSpecialPrime() + + # Set the mult_type + mult_type = -2 if include_special else -1 + + # Generate errors for the ordinary case. + level = 0 + e = self.rng.discrete_gaussian(repeats=1) + e = self.ntt.tile_unsigned(e, level, mult_type) + + self.ntt.enter_ntt(e, level, mult_type) + repeats = self.ctx.num_special_primes if sk.include_special else 0 + + # Applying mont_mult in the order of 'a', sk will + if a is None: + a = self.rng.randint( + self.ntt.q_prepack[mult_type][level][0], + repeats=repeats + ) + + sa = self.ntt.mont_mult(a, sk.data, 0, mult_type) + pk0 = self.ntt.mont_sub(e, sa, 0, mult_type) + + return data_struct( + data=(pk0, a), + include_special=include_special, + ntt_state=True, + montgomery_state=True, + origin=types.origins["pk"], + level=0, + hash=self.hash, + version=self.version + ) + + # ------------------------------------------------------------------------------------------- + # Encrypt/Decrypt + # ------------------------------------------------------------------------------------------- + + @errors.log_error + def encrypt(self, pt: list[torch.Tensor], pk: data_struct, level: int = 0) -> data_struct: + """ + We again, multiply pt by the scale. + Since pt is already multiplied by the scale, + the multiplied pt no longer can be stored + in a single RNS channel. + That means we are forced to do the multiplication + in full RNS domain. + Note that we allow encryption at + levels other than 0, and that will take care of multiplying + the deviation factors. 
+ """ + if pk.origin != types.origins["pk"]: + raise errors.NotMatchType(origin=pk.origin, to=types.origins["pk"]) + + mult_type = -2 if pk.include_special else -1 + + e0e1 = self.rng.discrete_gaussian(repeats=2) + + e0 = [e[0] for e in e0e1] + e1 = [e[1] for e in e0e1] + + e0_tiled = self.ntt.tile_unsigned(e0, level, mult_type) + e1_tiled = self.ntt.tile_unsigned(e1, level, mult_type) + + pt_tiled = self.ntt.tile_unsigned(pt, level, mult_type) + self.ntt.mont_enter_scale(pt_tiled, level, mult_type) + self.ntt.mont_redc(pt_tiled, level, mult_type) + pte0 = self.ntt.mont_add(pt_tiled, e0_tiled, level, mult_type) + + start = self.ntt.starts[level] + pk0 = [pk.data[0][di][start[di]:] for di in range(self.ntt.num_devices)] + pk1 = [pk.data[1][di][start[di]:] for di in range(self.ntt.num_devices)] + + v = self.rng.randint(amax=2, shift=0, repeats=1) + + v = self.ntt.tile_unsigned(v, level, mult_type) + self.ntt.enter_ntt(v, level, mult_type) + + vpk0 = self.ntt.mont_mult(v, pk0, level, mult_type) + vpk1 = self.ntt.mont_mult(v, pk1, level, mult_type) + + self.ntt.intt_exit(vpk0, level, mult_type) + self.ntt.intt_exit(vpk1, level, mult_type) + + ct0 = self.ntt.mont_add(vpk0, pte0, level, mult_type) + ct1 = self.ntt.mont_add(vpk1, e1_tiled, level, mult_type) + + self.ntt.reduce_2q(ct0, level, mult_type) + self.ntt.reduce_2q(ct1, level, mult_type) + + ct = data_struct( + data=(ct0, ct1), + include_special=mult_type == -2, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash, + version=self.version + ) + + return ct + + def decrypt_triplet(self, ct_mult: data_struct, sk: data_struct) -> list[torch.Tensor]: + if ct_mult.origin != types.origins["ctt"]: + raise errors.NotMatchType(origin=ct_mult.origin, to=types.origins["ctt"]) + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + + if not ct_mult.ntt_state or not ct_mult.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct_mult.origin) + if (not sk.ntt_state) or (not sk.montgomery_state): + raise errors.NotMatchDataStructState(origin=sk.origin) + + level = ct_mult.level + d0 = [ct_mult.data[0][0].clone()] + d1 = [ct_mult.data[1][0]] + d2 = [ct_mult.data[2][0]] + + self.ntt.intt_exit_reduce(d0, level) + + sk_data = [sk.data[0][self.ntt.starts[level][0]:]] + + d1_s = self.ntt.mont_mult(d1, sk_data, level) + + s2 = self.ntt.mont_mult(sk_data, sk_data, level) + d2_s2 = self.ntt.mont_mult(d2, s2, level) + + self.ntt.intt_exit(d1_s, level) + self.ntt.intt_exit(d2_s2, level) + + pt = self.ntt.mont_add(d0, d1_s, level) + pt = self.ntt.mont_add(pt, d2_s2, level) + self.ntt.reduce_2q(pt, level) + + base_at = -self.ctx.num_special_primes - 1 if ct_mult.include_special else -1 + + base = pt[0][base_at][None, :] + scaler = pt[0][0][None, :] + + final_scalar = self.final_scalar[level] + scaled = self.ntt.mont_sub([base], [scaler], -1) + self.ntt.mont_enter_scalar(scaled, [final_scalar], -1) + self.ntt.reduce_2q(scaled, -1) + self.ntt.make_signed(scaled, -1) + return scaled + + def decrypt_double(self, ct: data_struct, sk: data_struct) -> list[torch.Tensor]: + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if ct.ntt_state or ct.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct.origin) + if not sk.ntt_state or not sk.montgomery_state: + raise 
errors.NotMatchDataStructState(origin=sk.origin) + + ct0 = ct.data[0][0] + level = ct.level + sk_data = sk.data[0][self.ntt.starts[level][0]:] + a = ct.data[1][0].clone() + + self.ntt.enter_ntt([a], level) + sa = self.ntt.mont_mult([a], [sk_data], level) + self.ntt.intt_exit(sa, level) + + pt = self.ntt.mont_add([ct0], sa, level) + self.ntt.reduce_2q(pt, level) + + base_at = -self.ctx.num_special_primes - 1 if ct.include_special else -1 + + base = pt[0][base_at][None, :] + scaler = pt[0][0][None, :] + ############################################################################# + + final_scalar = self.final_scalar[level] + scaled = self.ntt.mont_sub([base], [scaler], -1) + self.ntt.mont_enter_scalar(scaled, [final_scalar], -1) + self.ntt.reduce_2q(scaled, -1) + self.ntt.make_signed(scaled, -1) + return scaled + + def decrypt(self, ct: data_struct, sk: data_struct) -> list[torch.Tensor]: + """ + Decrypt the cipher text ct using the secret key sk. + Note that the final rescaling must precede the actual decryption process. + """ + + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + + if ct.origin == types.origins["ctt"]: + pt = self.decrypt_triplet(ct_mult=ct, sk=sk) + elif ct.origin == types.origins["ct"]: + pt = self.decrypt_double(ct=ct, sk=sk) + else: + raise errors.NotMatchType(origin=ct.origin, to=f"{types.origins['ct']} or {types.origins['ctt']}") + + return pt + + # ------------------------------------------------------------------------------------------- + # Key switching. + # ------------------------------------------------------------------------------------------- + + def create_key_switching_key(self, sk_from: data_struct, sk_to: data_struct, a=None) -> data_struct: + """ + Creates a key to switch the key for sk_src to sk_dst. 
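+
+        In sketch form, for each RNS partition i the returned structure holds
+
+            ksk_i = (-a_i * sk_to + e_i + P * sk_from_i,  a_i)
+
+        i.e. a public key of sk_to whose first component additionally carries
+        that partition of P * sk_from, where P is the product of the special
+        primes (applied here through self.mont_PR).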
+ """ + + if sk_from.origin != types.origins["sk"] or sk_from.origin != types.origins["sk"]: + raise errors.NotMatchType(origin="not a secret key", to=types.origins["sk"]) + if (not sk_from.ntt_state) or (not sk_from.montgomery_state): + raise errors.NotMatchDataStructState(origin=sk_from.origin) + if (not sk_to.ntt_state) or (not sk_to.montgomery_state): + raise errors.NotMatchDataStructState(origin=sk_to.origin) + + level = 0 + + stops = self.ntt.stops[-1] + Psk_src = [sk_from.data[di][:stops[di]].clone() for di in range(self.ntt.num_devices)] + + self.ntt.mont_enter_scalar(Psk_src, self.mont_PR, level) + + ksk = [[] for _ in range(self.ntt.p.num_partitions + 1)] + + for device_id in range(self.ntt.num_devices): + for part_id, part in enumerate(self.ntt.p.p[level][device_id]): + global_part_id = self.ntt.p.part_allocations[device_id][part_id] + + crs = a[global_part_id] if a else None + pk = self.create_public_key(sk_to, include_special=True, a=crs) + + key = tuple(part) + astart = part[0] + astop = part[-1] + 1 + shard = Psk_src[device_id][astart:astop] + pk_data = pk.data[0][device_id][astart:astop] + + _2q = self.ntt.parts_pack[device_id][key]['_2q'] + update_part = ntt_cuda.mont_add([pk_data], [shard], _2q)[0] + pk_data.copy_(update_part, non_blocking=True) + + pk_name = f'key switch key part index {global_part_id}' + pk = pk._replace(origin=pk_name) + + ksk[global_part_id] = pk + + return data_struct( + data=ksk, + include_special=True, + ntt_state=True, + montgomery_state=True, + origin=types.origins["ksk"], + level=level, + hash=self.hash, + version=self.version) + + def pre_extend(self, a, device_id, level, part_id, exit_ntt=False): + text_part = self.ntt.p.parts[level][device_id][part_id] + param_part = self.ntt.p.p[level][device_id][part_id] + + alpha = len(text_part) + a_part = a[device_id][text_part[0]:text_part[-1] + 1] + + if exit_ntt: + self.ntt.intt_exit_reduce([a_part], level, device_id, part_id) + + state = a_part[0].repeat(alpha, 1) + key = tuple(param_part) + for i in range(alpha - 1): + mont_pack = self.ntt.parts_pack[device_id][param_part[i + 1],]['mont_pack'] + _2q = self.ntt.parts_pack[device_id][param_part[i + 1],]['_2q'] + Y_scalar = self.ntt.parts_pack[device_id][key]['Y_scalar'][i][None] + + Y = (a_part[i + 1] - state[i + 1])[None, :] + + ntt_cuda.mont_enter([Y], [Y_scalar], *mont_pack) + ntt_cuda.reduce_2q([Y], _2q) + + state[i + 1] = Y + + if i + 2 < alpha: + state_key = tuple(param_part[i + 2:]) + state_mont_pack = self.ntt.parts_pack[device_id][state_key]['mont_pack'] + state_2q = self.ntt.parts_pack[device_id][state_key]['_2q'] + L_scalar = self.ntt.parts_pack[device_id][key]['L_scalar'][i] + new_state_len = alpha - (i + 2) + new_state = Y.repeat(new_state_len, 1) + ntt_cuda.mont_enter([new_state], [L_scalar], *state_mont_pack) + ntt_cuda.reduce_2q([new_state], state_2q) + state[i + 2:] += new_state + ntt_cuda.reduce_2q([state[i + 2:]], state_2q) + + return state + + def extend(self, state, device_id, level, part_id, target_device_id=None): + + if target_device_id is None: + target_device_id = device_id + + rns_len = len( + self.ntt.p.destination_arrays_with_special[level][target_device_id]) + alpha = len(state) + + extended = state[0].repeat(rns_len, 1) + self.ntt.mont_enter([extended], level, target_device_id, -2) + + part = self.ntt.p.p[level][device_id][part_id] + key = tuple(part) + + L_enter = self.ntt.parts_pack[device_id][key]['L_enter'][target_device_id] + + start = self.ntt.starts[level][target_device_id] + + for i in range(alpha - 1): + Y = 
state[i + 1].repeat(rns_len, 1) + + self.ntt.mont_enter_scalar([Y], [L_enter[i][start:]], level, target_device_id, -2) + extended = self.ntt.mont_add([extended], [Y], level, target_device_id, -2)[0] + + return extended + + def create_switcher(self, a: list[torch.Tensor], ksk: data_struct, level, exit_ntt=False) -> tuple: + ksk_alloc = self.parts_alloc[level] + + len_devices = self.len_devices[level] + neighbor_devices = self.neighbor_devices[level] + + num_parts = sum([len(alloc) for alloc in ksk_alloc]) + part_results = [[[[] for _ in range(len_devices)], [[] for _ in range(len_devices)]] for _ in range(num_parts)] + + states = [[] for _ in range(num_parts)] + for src_device_id in range(len_devices): + for part_id in range(len(self.ntt.p.p[level][src_device_id])): + storage_id = self.stor_ids[level][src_device_id][part_id] + state = self.pre_extend(a, + src_device_id, + level, + part_id, + exit_ntt + ) + states[storage_id] = state + + CPU_states = [[] for _ in range(num_parts)] + for src_device_id in range(len_devices): + for part_id, part in enumerate(self.ntt.p.p[level][src_device_id]): + storage_id = self.stor_ids[level][src_device_id][part_id] + alpha = len(part) + CPU_state = self.ksk_buffers[src_device_id][part_id][:alpha] + CPU_state.copy_(states[storage_id], non_blocking=True) + CPU_states[storage_id] = CPU_state + + for src_device_id in range(len_devices): + for part_id in range(len(self.ntt.p.p[level][src_device_id])): + storage_id = self.stor_ids[level][src_device_id][part_id] + state = states[storage_id] + d0, d1 = self.switcher_later_part(state, ksk, + src_device_id, + src_device_id, + level, part_id) + + part_results[storage_id][0][src_device_id] = d0 + part_results[storage_id][1][src_device_id] = d1 + + CUDA_states = [[] for _ in range(num_parts)] + for src_device_id in range(len_devices): + for j, dst_device_id in enumerate( + neighbor_devices[src_device_id]): + for part_id, part in enumerate(self.ntt.p.p[level][src_device_id]): + storage_id = self.stor_ids[level][src_device_id][part_id] + CPU_state = CPU_states[storage_id] + CUDA_states[storage_id] = CPU_state.cuda(self.ntt.devices[dst_device_id], non_blocking=True) + + torch.cuda.synchronize() + + for src_device_id in range(len_devices): + for j, dst_device_id in enumerate( + neighbor_devices[src_device_id]): + for part_id, part in enumerate(self.ntt.p.p[level][src_device_id]): + storage_id = self.stor_ids[level][src_device_id][part_id] + CUDA_state = CUDA_states[storage_id] + d0, d1 = self.switcher_later_part(CUDA_state, + ksk, + src_device_id, + dst_device_id, + level, + part_id) + part_results[storage_id][0][dst_device_id] = d0 + part_results[storage_id][1][dst_device_id] = d1 + + summed0 = part_results[0][0] + summed1 = part_results[0][1] + + for i in range(1, len(part_results)): + summed0 = self.ntt.mont_add(summed0, part_results[i][0], level, -2) + summed1 = self.ntt.mont_add(summed1, part_results[i][1], level, -2) + + d0 = summed0 + d1 = summed1 + + current_len = [len(d) for d in self.ntt.p.destination_arrays_with_special[level]] + + for P_ind in range(self.ntt.num_special_primes): + current_len = [c - 1 for c in current_len] + + PiRi = self.PiRs[level][P_ind] + + P0 = [d[-1].repeat(current_len[di], 1) for di, d in enumerate(d0)] + P1 = [d[-1].repeat(current_len[di], 1) for di, d in enumerate(d1)] + + d0 = [d0[i][:current_len[i]] - P0[i] for i in range(len_devices)] + d1 = [d1[i][:current_len[i]] - P1[i] for i in range(len_devices)] + + self.ntt.mont_enter_scalar(d0, PiRi, level, -2) + 
self.ntt.mont_enter_scalar(d1, PiRi, level, -2) + + self.ntt.reduce_2q(d0, level, -1) + self.ntt.reduce_2q(d1, level, -1) + + return d0, d1 + + def switcher_later_part(self, + state, ksk, + src_device_id, + dst_device_id, + level, part_id): + + extended = self.extend(state, src_device_id, level, part_id, dst_device_id) + + self.ntt.ntt([extended], level, dst_device_id, -2) + + ksk_loc = self.parts_alloc[level][src_device_id][part_id] + ksk_part_data = ksk.data[ksk_loc].data + + start = self.ntt.starts[level][dst_device_id] + ksk0_data = ksk_part_data[0][dst_device_id][start:] + ksk1_data = ksk_part_data[1][dst_device_id][start:] + + d0 = self.ntt.mont_mult([extended], [ksk0_data], level, dst_device_id, -2) + d1 = self.ntt.mont_mult([extended], [ksk1_data], level, dst_device_id, -2) + + self.ntt.intt_exit_reduce(d0, level, dst_device_id, -2) + self.ntt.intt_exit_reduce(d1, level, dst_device_id, -2) + + return d0[0], d1[0] + + def switch_key(self, ct: data_struct, ksk: data_struct) -> data_struct: + include_special = ct.include_special + ntt_state = ct.ntt_state + montgomery_state = ct.montgomery_state + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + + level = ct.level + a = ct.data[1] + d0, d1 = self.create_switcher(a, ksk, level, exit_ntt=ct.ntt_state) + + new_ct0 = self.ntt.mont_add(ct.data[0], d0, level, -1) + + return data_struct( + data=(new_ct0, d1), + include_special=include_special, + ntt_state=ntt_state, + montgomery_state=montgomery_state, + origin=types.origins["ct"], + level=level, + hash=self.hash + ) + + # ------------------------------------------------------------------------------------------- + # Multiplication. + # ------------------------------------------------------------------------------------------- + + def rescale(self, ct: data_struct, exact_rounding=True) -> data_struct: + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + level = ct.level + next_level = level + 1 + + if next_level >= self.num_levels: + raise errors.MaximumLevelError(level=ct.level, level_max=self.num_levels) + + rescaler_device_id = self.ntt.p.rescaler_loc[level] + neighbor_devices_before = self.neighbor_devices[level] + neighbor_devices_after = self.neighbor_devices[next_level] + len_devices_after = len(neighbor_devices_after) + len_devices_before = len(neighbor_devices_before) + + rescaling_scales = self.rescale_scales[level] + data0 = [[] for _ in range(len_devices_after)] + data1 = [[] for _ in range(len_devices_after)] + + rescaler0 = [[] for _ in range(len_devices_before)] + rescaler1 = [[] for _ in range(len_devices_before)] + + rescaler0_at = ct.data[0][rescaler_device_id][0] + rescaler0[rescaler_device_id] = rescaler0_at + + rescaler1_at = ct.data[1][rescaler_device_id][0] + rescaler1[rescaler_device_id] = rescaler1_at + + if rescaler_device_id < len_devices_after: + data0[rescaler_device_id] = ct.data[0][rescaler_device_id][1:] + data1[rescaler_device_id] = ct.data[1][rescaler_device_id][1:] + + CPU_rescaler0 = self.ksk_buffers[0][0][0] + CPU_rescaler1 = self.ksk_buffers[0][1][0] + + CPU_rescaler0.copy_(rescaler0_at, non_blocking=False) + CPU_rescaler1.copy_(rescaler1_at, non_blocking=False) + + for device_id in neighbor_devices_before[rescaler_device_id]: + device = self.ntt.devices[device_id] + CUDA_rescaler0 = CPU_rescaler0.cuda(device) + CUDA_rescaler1 = CPU_rescaler1.cuda(device) + + rescaler0[device_id] = CUDA_rescaler0 + rescaler1[device_id] = CUDA_rescaler1 + 
+ if device_id < len_devices_after: + data0[device_id] = ct.data[0][device_id] + data1[device_id] = ct.data[1][device_id] + + if exact_rounding: + rescale_channel_prime_id = self.ntt.p.destination_arrays[level][rescaler_device_id][0] + + round_at = self.ctx.q[rescale_channel_prime_id] // 2 + + rounder0 = [[] for _ in range(len_devices_before)] + rounder1 = [[] for _ in range(len_devices_before)] + + for device_id in range(len_devices_after): + rounder0[device_id] = torch.where(rescaler0[device_id] > round_at, 1, 0) + rounder1[device_id] = torch.where(rescaler1[device_id] > round_at, 1, 0) + + data0 = [(d - s) for d, s in zip(data0, rescaler0)] + data1 = [(d - s) for d, s in zip(data1, rescaler1)] + + self.ntt.mont_enter_scalar(data0, self.rescale_scales[level], next_level) + + self.ntt.mont_enter_scalar(data1, self.rescale_scales[level], next_level) + + if exact_rounding: + data0 = [(d + r) for d, r in zip(data0, rounder0)] + data1 = [(d + r) for d, r in zip(data1, rounder1)] + + self.ntt.reduce_2q(data0, next_level) + self.ntt.reduce_2q(data1, next_level) + + return data_struct( + data=(data0, data1), + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=next_level, + hash=self.hash, + version=self.version + ) + + def create_evk(self, sk: data_struct) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + + sk2_data = self.ntt.mont_mult(sk.data, sk.data, 0, -2) + sk2 = data_struct( + data=sk2_data, + include_special=True, + ntt_state=True, + montgomery_state=True, + origin=types.origins["sk"], + level=sk.level, + hash=self.hash, + version=self.version + ) + + return self.create_key_switching_key(sk2, sk) + + def cc_mult(self, a: data_struct, b: data_struct, evk: data_struct, relin=True) -> data_struct: + if a.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=a.origin, to=types.origins["sk"]) + if b.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=b.origin, to=types.origins["sk"]) + # Rescale. + x = self.rescale(a) + y = self.rescale(b) + + level = x.level + + # Multiply. + x0 = x.data[0] + x1 = x.data[1] + + y0 = y.data[0] + y1 = y.data[1] + + self.ntt.enter_ntt(x0, level) + self.ntt.enter_ntt(x1, level) + self.ntt.enter_ntt(y0, level) + self.ntt.enter_ntt(y1, level) + + d0 = self.ntt.mont_mult(x0, y0, level) + + x0y1 = self.ntt.mont_mult(x0, y1, level) + x1y0 = self.ntt.mont_mult(x1, y0, level) + d1 = self.ntt.mont_add(x0y1, x1y0, level) + + d2 = self.ntt.mont_mult(x1, y1, level) + + ct_mult = data_struct( + data=(d0, d1, d2), + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["ctt"], + level=level, + hash=self.hash + ) + if relin: + ct_mult = self.relinearize(ct_triplet=ct_mult, evk=evk) + + return ct_mult + + def relinearize(self, ct_triplet: data_struct, evk: data_struct) -> data_struct: + if ct_triplet.origin != types.origins["ctt"]: + raise errors.NotMatchType(origin=ct_triplet.origin, to=types.origins["ctt"]) + if not ct_triplet.ntt_state or not ct_triplet.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct_triplet.origin) + + d0, d1, d2 = ct_triplet.data + level = ct_triplet.level + + # intt. + self.ntt.intt_exit_reduce(d0, level) + self.ntt.intt_exit_reduce(d1, level) + self.ntt.intt_exit_reduce(d2, level) + + # Key switch the x1y1. + d2_0, d2_1 = self.create_switcher(d2, evk, level) + + # Add the switcher to d0, d1. 
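+        # (d2_0, d2_1) encrypts the d2 * sk^2 term under sk alone, so folding
+        # it back in yields an ordinary two-component ciphertext:
+        #     ct' = (d0 + d2_0, d1 + d2_1)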
+ d0 = [p + q for p, q in zip(d0, d2_0)] + d1 = [p + q for p, q in zip(d1, d2_1)] + + # Final reduction. + self.ntt.reduce_2q(d0, level) + self.ntt.reduce_2q(d1, level) + + # Compose and return. + return data_struct( + data=(d0, d1), + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash + ) + + # ------------------------------------------------------------------------------------------- + # Rotation. + # ------------------------------------------------------------------------------------------- + + def create_rotation_key(self, sk: data_struct, delta: int, a: list[torch.Tensor] = None) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + + sk_new_data = [s.clone() for s in sk.data] + self.ntt.intt(sk_new_data) + sk_new_data = [rotate(s, delta) for s in sk_new_data] + self.ntt.ntt(sk_new_data) + sk_rotated = data_struct( + data=sk_new_data, + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["sk"], + level=0, + hash=self.hash, + version=self.version + ) + + rotk = self.create_key_switching_key(sk_rotated, sk, a=a) + rotk = rotk._replace(origin=types.origins["rotk"] + f"{delta}") + return rotk + + def rotate_single(self, ct: data_struct, rotk: data_struct) -> data_struct: + + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if types.origins["rotk"] not in rotk.origin: + raise errors.NotMatchType(origin=rotk.origin, to=types.origins["rotk"]) + + level = ct.level + include_special = ct.include_special + ntt_state = ct.ntt_state + montgomery_state = ct.montgomery_state + origin = rotk.origin + delta = int(origin.split(':')[-1]) + + rotated_ct_data = [[rotate(d, delta) for d in ct_data] for ct_data in ct.data] + + rotated_ct_rotated_sk = data_struct( + data=rotated_ct_data, + include_special=include_special, + ntt_state=ntt_state, + montgomery_state=montgomery_state, + origin=types.origins["ct"], + level=level, + hash=self.hash, + version=self.version + ) + + rotated_ct = self.switch_key(rotated_ct_rotated_sk, rotk) + return rotated_ct + + def create_galois_key(self, sk: data_struct) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + + galois_key_parts = [self.create_rotation_key(sk, delta) for delta in self.galois_deltas] + + galois_key = data_struct( + data=galois_key_parts, + include_special=True, + ntt_state=True, + montgomery_state=True, + origin=types.origins["galk"], + level=0, + hash=self.hash, + version=self.version + ) + return galois_key + + def rotate_galois(self, ct: data_struct, gk: data_struct, delta: int, return_circuit=False) -> data_struct: + + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if gk.origin != types.origins["galk"]: + raise errors.NotMatchType(origin=gk.origin, to=types.origins["galk"]) + + current_delta = delta % (self.ctx.N // 2) + galois_circuit = [] + + while current_delta: + galois_ind = int(math.log2(current_delta)) + galois_delta = self.galois_deltas[galois_ind] + galois_circuit.append(galois_ind) + current_delta -= galois_delta + + if len(galois_circuit) > 0: + rotated_ct = self.rotate_single(ct, gk.data[galois_circuit[0]]) + + for delta_ind in galois_circuit[1:]: + rotated_ct = self.rotate_single(rotated_ct, gk.data[delta_ind]) + elif len(galois_circuit) == 0: + rotated_ct = 
ct + else: + pass + + if return_circuit: + return rotated_ct, galois_circuit + else: + return rotated_ct + + # ------------------------------------------------------------------------------------------- + # Add/sub. + # ------------------------------------------------------------------------------------------- + def cc_add_double(self, a: data_struct, b: data_struct) -> data_struct: + if a.origin != types.origins["ct"] or b.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=f"{a.origin} and {b.origin}", to=types.origins["ct"]) + if a.ntt_state or a.montgomery_state: + raise errors.NotMatchDataStructState(origin=a.origin) + if b.ntt_state or b.montgomery_state: + raise errors.NotMatchDataStructState(origin=b.origin) + + level = a.level + data = [] + c0 = self.ntt.mont_add(a.data[0], b.data[0], level) + c1 = self.ntt.mont_add(a.data[1], b.data[1], level) + self.ntt.reduce_2q(c0, level) + self.ntt.reduce_2q(c1, level) + data.extend([c0, c1]) + + return data_struct( + data=data, + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash + ) + + def cc_add_triplet(self, a: data_struct, b: data_struct) -> data_struct: + if a.origin != types.origins["ctt"] or b.origin != types.origins["ctt"]: + raise errors.NotMatchType(origin=f"{a.origin} and {b.origin}", to=types.origins["ctt"]) + if not a.ntt_state or not a.montgomery_state: + raise errors.NotMatchDataStructState(origin=a.origin) + if not b.ntt_state or not b.montgomery_state: + raise errors.NotMatchDataStructState(origin=b.origin) + + level = a.level + data = [] + c0 = self.ntt.mont_add(a.data[0], b.data[0], level) + c1 = self.ntt.mont_add(a.data[1], b.data[1], level) + self.ntt.reduce_2q(c0, level) + self.ntt.reduce_2q(c1, level) + data.extend([c0, c1]) + c2 = self.ntt.mont_add(a.data[2], b.data[2], level) + self.ntt.reduce_2q(c2, level) + data.append(c2) + + return data_struct( + data=data, + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["ctt"], + level=level, + hash=self.hash, + version=self.version + ) + + def cc_add(self, a: data_struct, b: data_struct) -> data_struct: + + if a.origin == types.origins["ct"] and b.origin == types.origins["ct"]: + ct_add = self.cc_add_double(a, b) + elif a.origin == types.origins["ctt"] and b.origin == types.origins["ctt"]: + ct_add = self.cc_add_triplet(a, b) + else: + raise errors.DifferentTypeError(a=a.origin, b=b.origin) + + return ct_add + + def cc_sub_double(self, a: data_struct, b: data_struct) -> data_struct: + if a.origin != types.origins["ct"] or b.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=f"{a.origin} and {b.origin}", to=types.origins["ct"]) + if a.ntt_state or a.montgomery_state: + raise errors.NotMatchDataStructState(origin=a.origin) + if b.ntt_state or b.montgomery_state: + raise errors.NotMatchDataStructState(origin=b.origin) + + level = a.level + data = [] + + c0 = self.ntt.mont_sub(a.data[0], b.data[0], level) + c1 = self.ntt.mont_sub(a.data[1], b.data[1], level) + self.ntt.reduce_2q(c0, level) + self.ntt.reduce_2q(c1, level) + data.extend([c0, c1]) + + return data_struct( + data=data, + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash, + version=self.version + ) + + def cc_sub_triplet(self, a: data_struct, b: data_struct) -> data_struct: + if a.origin != types.origins["ctt"] or b.origin != types.origins["ctt"]: + raise 
errors.NotMatchType(origin=f"{a.origin} and {b.origin}", to=types.origins["ctt"]) + if not a.ntt_state or not a.montgomery_state: + raise errors.NotMatchDataStructState(origin=a.origin) + if not b.ntt_state or not b.montgomery_state: + raise errors.NotMatchDataStructState(origin=b.origin) + + level = a.level + data = [] + c0 = self.ntt.mont_sub(a.data[0], b.data[0], level) + c1 = self.ntt.mont_sub(a.data[1], b.data[1], level) + c2 = self.ntt.mont_sub(a.data[2], b.data[2], level) + self.ntt.reduce_2q(c0, level) + self.ntt.reduce_2q(c1, level) + self.ntt.reduce_2q(c2, level) + data.extend([c0, c1, c2]) + + return data_struct( + data=data, + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["ctt"], + level=level, + hash=self.hash, + version=self.version + ) + + def cc_sub(self, a: data_struct, b: data_struct) -> data_struct: + if a.origin != b.origin: + raise Exception(f"[Error] triplet error") + + if types.origins["ct"] == a.origin and types.origins["ct"] == b.origin: + ct_sub = self.cc_sub_double(a, b) + elif a.origin == types.origins["ctt"] and b.origin == types.origins["ctt"]: + ct_sub = self.cc_sub_triplet(a, b) + else: + raise errors.DifferentTypeError(a=a.origin, b=b.origin) + return ct_sub + + def cc_subtract(self, a, b): + return self.cc_sub(a, b) + + # ------------------------------------------------------------------------------------------- + # Level up. + # ------------------------------------------------------------------------------------------- + def level_up(self, ct: data_struct, dst_level: int): + if types.origins["ct"] != ct.origin: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + + current_level = ct.level + + new_ct = self.rescale(ct) + + src_level = current_level + 1 + + dst_len_devices = len(self.ntt.p.destination_arrays[dst_level]) + + diff_deviation = self.deviations[dst_level] / np.sqrt(self.deviations[src_level]) + + deviated_delta = round(self.scale * diff_deviation) + + if dst_level - src_level > 0: + src_rns_lens = [len(d) for d in self.ntt.p.destination_arrays[src_level]] + dst_rns_lens = [len(d) for d in self.ntt.p.destination_arrays[dst_level]] + + diff_rns_lens = [y - x for x, y in zip(dst_rns_lens, src_rns_lens)] + + new_ct_data0 = [] + new_ct_data1 = [] + + for device_id in range(dst_len_devices): + new_ct_data0.append(new_ct.data[0][device_id][diff_rns_lens[device_id]:]) + new_ct_data1.append(new_ct.data[1][device_id][diff_rns_lens[device_id]:]) + else: + new_ct_data0, new_ct_data1 = new_ct.data + + multipliers = [] + for device_id in range(dst_len_devices): + dest = self.ntt.p.destination_arrays[dst_level][device_id] + q = [self.ctx.q[i] for i in dest] + + multiplier = [(deviated_delta * self.ctx.R) % qi for qi in q] + multiplier = torch.tensor(multiplier, dtype=self.ctx.torch_dtype, device=self.ntt.devices[device_id]) + multipliers.append(multiplier) + + self.ntt.mont_enter_scalar(new_ct_data0, multipliers, dst_level) + self.ntt.mont_enter_scalar(new_ct_data1, multipliers, dst_level) + + self.ntt.reduce_2q(new_ct_data0, dst_level) + self.ntt.reduce_2q(new_ct_data1, dst_level) + + new_ct = data_struct( + data=(new_ct_data0, new_ct_data1), + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=dst_level, + hash=self.hash, + version=self.version + ) + + return new_ct + + # ------------------------------------------------------------------------------------------- + # Fused enc/dec. 
+ # ------------------------------------------------------------------------------------------- + def encodecrypt(self, m, pk: data_struct, level: int = 0, padding=True) -> data_struct: + if pk.origin != types.origins["pk"]: + raise errors.NotMatchType(origin=pk.origin, to=types.origins["pk"]) + + if padding: + m = self.padding(m=m) + + deviation = self.deviations[level] + pt = encode(m, scale=self.scale, + device=self.device0, norm=self.norm, + deviation=deviation, rng=self.rng, + return_without_scaling=self.bias_guard) + + if self.bias_guard: + dc_integral = pt[0].item() // 1 + pt[0] -= dc_integral + + dc_scale = int(dc_integral) * int(self.scale) + dc_rns = [] + for device_id, dest in enumerate(self.ntt.p.destination_arrays[level]): + dci = [dc_scale % self.ctx.q[i] for i in dest] + dci = torch.tensor(dci, + dtype=self.ctx.torch_dtype, + device=self.ntt.devices[device_id]) + dc_rns.append(dci) + + pt *= np.float64(self.scale) + pt = self.rng.randround(pt) + + encoded = [pt] + + pt_buffer = self.ksk_buffers[0][0][0] + pt_buffer.copy_(encoded[-1]) + for dev_id in range(1, self.ntt.num_devices): + encoded.append(pt_buffer.cuda(self.ntt.devices[dev_id])) + + mult_type = -2 if pk.include_special else -1 + + e0e1 = self.rng.discrete_gaussian(repeats=2) + + e0 = [e[0] for e in e0e1] + e1 = [e[1] for e in e0e1] + + e0_tiled = self.ntt.tile_unsigned(e0, level, mult_type) + e1_tiled = self.ntt.tile_unsigned(e1, level, mult_type) + + pt_tiled = self.ntt.tile_unsigned(encoded, level, mult_type) + + if self.bias_guard: + for device_id, pti in enumerate(pt_tiled): + pti[:, 0] += dc_rns[device_id] + + self.ntt.mont_enter_scale(pt_tiled, level, mult_type) + self.ntt.mont_redc(pt_tiled, level, mult_type) + pte0 = self.ntt.mont_add(pt_tiled, e0_tiled, level, mult_type) + + start = self.ntt.starts[level] + pk0 = [pk.data[0][di][start[di]:] for di in range(self.ntt.num_devices)] + pk1 = [pk.data[1][di][start[di]:] for di in range(self.ntt.num_devices)] + + v = self.rng.randint(amax=2, shift=0, repeats=1) + + v = self.ntt.tile_unsigned(v, level, mult_type) + self.ntt.enter_ntt(v, level, mult_type) + + vpk0 = self.ntt.mont_mult(v, pk0, level, mult_type) + vpk1 = self.ntt.mont_mult(v, pk1, level, mult_type) + + self.ntt.intt_exit(vpk0, level, mult_type) + self.ntt.intt_exit(vpk1, level, mult_type) + + ct0 = self.ntt.mont_add(vpk0, pte0, level, mult_type) + ct1 = self.ntt.mont_add(vpk1, e1_tiled, level, mult_type) + + self.ntt.reduce_2q(ct0, level, mult_type) + self.ntt.reduce_2q(ct1, level, mult_type) + + ct = data_struct( + data=(ct0, ct1), + include_special=mult_type == -2, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash, + version=self.version + ) + + return ct + + def decryptcode(self, ct: data_struct, sk: data_struct, is_real=False) -> data_struct: + if (not sk.ntt_state) or (not sk.montgomery_state): + raise errors.NotMatchDataStructState(origin=sk.origin) + + level = ct.level + sk_data = sk.data[0][self.ntt.starts[level][0]:] + + if ct.origin == types.origins["ct"]: + if ct.ntt_state or ct.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct.origin) + + ct0 = ct.data[0][0] + a = ct.data[1][0].clone() + + self.ntt.enter_ntt([a], level) + + sa = self.ntt.mont_mult([a], [sk_data], level) + self.ntt.intt_exit(sa, level) + + pt = self.ntt.mont_add([ct0], sa, level) + self.ntt.reduce_2q(pt, level) + + elif ct.origin == types.origins["ctt"]: + if not ct.ntt_state or not ct.montgomery_state: + raise 
errors.NotMatchDataStructState(origin=ct.origin) + + d0 = [ct.data[0][0].clone()] + d1 = [ct.data[1][0]] + d2 = [ct.data[2][0]] + + self.ntt.intt_exit_reduce(d0, level) + + sk_data = [sk.data[0][self.ntt.starts[level][0]:]] + + d1_s = self.ntt.mont_mult(d1, sk_data, level) + + s2 = self.ntt.mont_mult(sk_data, sk_data, level) + d2_s2 = self.ntt.mont_mult(d2, s2, level) + + self.ntt.intt_exit(d1_s, level) + self.ntt.intt_exit(d2_s2, level) + + pt = self.ntt.mont_add(d0, d1_s, level) + pt = self.ntt.mont_add(pt, d2_s2, level) + self.ntt.reduce_2q(pt, level) + else: + raise errors.NotMatchType(origin=ct.origin, to=f"{types.origins['ct']} or {types.origins['ctt']}") + + base_at = -self.ctx.num_special_primes - 1 if ct.include_special else -1 + base = pt[0][base_at][None, :] + scaler = pt[0][0][None, :] + + len_left = len(self.ntt.p.destination_arrays[level][0]) + + if (len_left >= 3) and self.bias_guard: + dc0 = base[0][0].item() + dc1 = scaler[0][0].item() + dc2 = pt[0][1][0].item() + + base[0][0] = 0 + scaler[0][0] = 0 + + q0_ind = self.ntt.p.destination_arrays[level][0][base_at] + q1_ind = self.ntt.p.destination_arrays[level][0][0] + q2_ind = self.ntt.p.destination_arrays[level][0][1] + + q0 = self.ctx.q[q0_ind] + q1 = self.ctx.q[q1_ind] + q2 = self.ctx.q[q2_ind] + + Q = q0 * q1 * q2 + Q0 = q1 * q2 + Q1 = q0 * q2 + Q2 = q0 * q1 + + Qi0 = pow(Q0, -1, q0) + Qi1 = pow(Q1, -1, q1) + Qi2 = pow(Q2, -1, q2) + + dc = (dc0 * Qi0 * Q0 + dc1 * Qi1 * Q1 + dc2 * Qi2 * Q2) % Q + + half_Q = Q // 2 + dc = dc if dc <= half_Q else dc - Q + + dc = (dc + (q1 - 1)) // q1 + + final_scalar = self.final_scalar[level] + scaled = self.ntt.mont_sub([base], [scaler], -1) + self.ntt.mont_enter_scalar(scaled, [final_scalar], -1) + self.ntt.reduce_2q(scaled, -1) + self.ntt.make_signed(scaled, -1) + + # Decoding. + correction = self.corrections[level] + decoded = decode( + scaled[0][-1], + scale=self.scale, + correction=correction, + norm=self.norm, + return_without_scaling=self.bias_guard + ) + decoded = decoded[:self.ctx.N // 2].cpu().numpy() + ## + + decoded = decoded / self.scale * correction + + # Bias guard. + if (len_left >= 3) and self.bias_guard: + decoded += dc / self.scale * correction + if is_real: + decoded = decoded.real + return decoded + + # Shortcuts. 
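+    # A hedged end-to-end sketch of the fused path (the engine construction
+    # is illustrative and not part of this file):
+    #
+    #     sk = engine.create_secret_key()
+    #     pk = engine.create_public_key(sk)
+    #     ct = engine.encorypt([1.0, 2.0, 3.0], pk)
+    #     m  = engine.decrode(ct, sk, is_real=True)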
+ def encorypt(self, m, pk: data_struct, level: int = 0, padding=True): + return self.encodecrypt(m, pk=pk, level=level, padding=padding) + + def decrode(self, ct: data_struct, sk: data_struct, is_real=False): + return self.decryptcode(ct=ct, sk=sk, is_real=is_real) + + # ------------------------------------------------------------------------------------------- + # Conjugation + # ------------------------------------------------------------------------------------------- + + def create_conjugation_key(self, sk: data_struct) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if (not sk.ntt_state) or (not sk.montgomery_state): + raise errors.NotMatchDataStructState(origin=sk.origin) + + sk_new_data = [s.clone() for s in sk.data] + self.ntt.intt(sk_new_data) + sk_new_data = [conjugate(s) for s in sk_new_data] + self.ntt.ntt(sk_new_data) + sk_rotated = data_struct( + data=sk_new_data, + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["sk"], + level=0, + hash=self.hash, + version=self.version + ) + rotk = self.create_key_switching_key(sk_rotated, sk) + rotk = rotk._replace(origin=types.origins["conjk"]) + return rotk + + def conjugate(self, ct: data_struct, conjk: data_struct): + level = ct.level + conj_ct_data = [[conjugate(d) for d in ct_data] for ct_data in ct.data] + + conj_ct_sk = data_struct( + data=conj_ct_data, + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash, + version=self.version + ) + + conj_ct = self.switch_key(conj_ct_sk, conjk) + return conj_ct + + # ------------------------------------------------------------------------------------------- + # Clone. + # ------------------------------------------------------------------------------------------- + + def clone_tensors(self, data: data_struct) -> data_struct: + new_data = [] + # Some data has 1 depth. + if not isinstance(data[0], list): + for device_data in data: + new_data.append(device_data.clone()) + else: + for part in data: + new_data.append([]) + for device_data in part: + new_data[-1].append(device_data.clone()) + return new_data + + def clone(self, text): + if not isinstance(text.data[0], data_struct): + # data are tensors. + data = self.clone_tensors(text.data) + + wrapper = data_struct( + data=data, + include_special=text.include_special, + ntt_state=text.ntt_state, + montgomery_state=text.montgomery_state, + origin=text.origin, + level=text.level, + hash=text.hash, + version=text.version + ) + + else: + wrapper = data_struct( + data=[], + include_special=text.include_special, + ntt_state=text.ntt_state, + montgomery_state=text.montgomery_state, + origin=text.origin, + level=text.level, + hash=text.hash, + version=text.version + ) + + for d in text.data: + wrapper.data.append(self.clone(d)) + + return wrapper + + # ------------------------------------------------------------------------------------------- + # Move data back and forth from GPUs to the CPU. + # ------------------------------------------------------------------------------------------- + + def download_to_cpu(self, gpu_data, level, include_special): + # Prepare destination arrays. + if include_special: + dest = self.ntt.p.destination_arrays_with_special[level] + else: + dest = self.ntt.p.destination_arrays[level] + + # dest contain indices that are the absolute order of primes. + # Convert them to tensor channel indices at this level. 
+ # That is, force them to start from zero. + min_ind = min([min(d) for d in dest]) + dest = [[di - min_ind for di in d] for d in dest] + + # Tensor size parameters. + num_rows = sum([len(d) for d in dest]) + num_cols = self.ctx.N + cpu_size = (num_rows, num_cols) + + # Make a cpu tensor to aggregate the data in GPUs. + cpu_tensor = torch.empty(cpu_size, dtype=self.ctx.torch_dtype, device='cpu') + + for ten, dest_i in zip(gpu_data, dest): + # Check if the tensor is in the gpu. + if ten.device.type != 'cuda': + raise Exception("To download data to the CPU, it must already be in a GPU!!!") + + # Copy in the data. + cpu_tensor[dest_i] = ten.cpu() + + # To avoid confusion, make a list with a single element (only one device, that is the CPU), + # and return it. + return [cpu_tensor] + + def upload_to_gpu(self, cpu_data, level, include_special): + # There's only one device data in the cpu data. + cpu_tensor = cpu_data[0] + + # Check if the tensor is in the cpu. + if cpu_tensor.device.type != 'cpu': + raise Exception("To upload data to GPUs, it must already be in the CPU!!!") + + # Prepare destination arrays. + if include_special: + dest = self.ntt.p.destination_arrays_with_special[level] + else: + dest = self.ntt.p.destination_arrays[level] + + # dest contain indices that are the absolute order of primes. + # Convert them to tensor channel indices at this level. + # That is, force them to start from zero. + min_ind = min([min(d) for d in dest]) + dest = [[di - min_ind for di in d] for d in dest] + + gpu_data = [] + for device_id in range(len(dest)): + # Copy in the data. + dest_device = dest[device_id] + device = self.ntt.devices[device_id] + gpu_tensor = cpu_tensor[dest_device].to(device=device) + + # Append to the gpu_data list. + gpu_data.append(gpu_tensor) + + return gpu_data + + def move_tensors(self, data, level, include_special, direction): + func = { + 'gpu2cpu': self.download_to_cpu, + 'cpu2gpu': self.upload_to_gpu + }[direction] + + # Some data has 1 depth. + if not isinstance(data[0], list): + moved = func(data, level, include_special) + new_data = moved + else: + new_data = [] + for part in data: + moved = func(part, level, include_special) + new_data.append(moved) + return new_data + + def move_to(self, text, direction='gpu2cpu'): + if not isinstance(text.data[0], data_struct): + level = text.level + include_special = text.include_special + + # data are tensors. + data = self.move_tensors(text.data, level, + include_special, direction) + + wrapper = data_struct( + data, text.include_special, + text.ntt_state, text.montgomery_state, + text.origin, text.level, text.hash, text.version + ) + + else: + wrapper = data_struct( + [], text.include_special, + text.ntt_state, text.montgomery_state, + text.origin, text.level, text.hash, text.version + ) + + for d in text.data: + moved = self.move_to(d, direction) + wrapper.data.append(moved) + + return wrapper + + # Shortcuts + + def cpu(self, ct): + return self.move_to(ct, 'gpu2cpu') + + def cuda(self, ct): + return self.move_to(ct, 'cpu2gpu') + + # ------------------------------------------------------------------------------------------- + # check device. + # ------------------------------------------------------------------------------------------- + + def tensor_device(self, data): + # Some data has 1 depth. + if not isinstance(data[0], list): + return data[0].device.type + else: + return data[0][0].device.type + + def device(self, text): + if not isinstance(text.data[0], data_struct): + # data are tensors. 
+ return self.tensor_device(text.data) + else: + return self.device(text.data[0]) + + # ------------------------------------------------------------------------------------------- + # Print data structure. + # ------------------------------------------------------------------------------------------- + + def tree_lead_text(self, level, tabs=2, final=False): + final_char = "└" if final else "├" + + if level == 0: + leader = " " * tabs + trailer = "─" * tabs + lead_text = "─" * tabs + "┬" + trailer + + elif level < 0: + level = -level + leader = " " * tabs + trailer = "─" + "─" * (tabs - 1) + lead_fence = leader + "│" * (level - 1) + lead_text = lead_fence + final_char + trailer + + else: + leader = " " * tabs + trailer = "┬" + "─" * (tabs - 1) + lead_fence = leader + "│" * (level - 1) + lead_text = lead_fence + "├" + trailer + + return lead_text + + def print_data_shapes(self, data, level): + # Some data structures have 1 depth. + if isinstance(data[0], list): + for part_i, part in enumerate(data): + for device_id, d in enumerate(part): + device = self.ntt.devices[device_id] + + if (device_id == len(part) - 1) and \ + (part_i == len(data) - 1): + final = True + else: + final = False + + lead_text = self.tree_lead_text(-level, final=final) + + print(f"{lead_text} tensor at device {device} with " + f"shape {d.shape}.") + else: + for device_id, d in enumerate(data): + device = self.ntt.devices[device_id] + + if device_id == len(data) - 1: + final = True + else: + final = False + + lead_text = self.tree_lead_text(-level, final=final) + + print(f"{lead_text} tensor at device {device} with " + f"shape {d.shape}.") + + def print_data_structure(self, text, level=0): + lead_text = self.tree_lead_text(level) + print(f"{lead_text} {text.origin}") + + if not isinstance(text.data[0], data_struct): + self.print_data_shapes(text.data, level + 1) + else: + for d in text.data: + self.print_data_structure(d, level + 1) + + # ------------------------------------------------------------------------------------------- + # Save and load. + # ------------------------------------------------------------------------------------------- + + def auto_generate_filename(self, fmt_str='%Y%m%d%H%M%s%f'): + return datetime.datetime.now().strftime(fmt_str) + '.pkl' + + def save(self, text, filename=None): + if filename is None: + filename = self.auto_generate_filename() + + savepath = Path(filename) + + # Check if the text is in the CPU. + # If not, move to CPU. + if self.device(text) != 'cpu': + cpu_text = self.cpu(text) + else: + cpu_text = text + + with savepath.open('wb') as f: + pickle.dump(cpu_text, f) + + def load(self, filename, move_to_gpu=True): + savepath = Path(filename) + with savepath.open('rb') as f: + # gc.disable() + cpu_text = pickle.load(f) + # gc.enable() + + if move_to_gpu: + text = self.cuda(cpu_text) + else: + text = cpu_text + + return text + + # ------------------------------------------------------------------------------------------- + # Negate. + # ------------------------------------------------------------------------------------------- + + def negate(self, ct: data_struct) -> data_struct: + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) # ctt + new_ct = self.clone(ct) + + new_data = new_ct.data + for part in new_data: + for d in part: + d *= -1 + self.ntt.make_signed(part, ct.level) + + return new_ct + + # ------------------------------------------------------------------------------------------- + # scalar ops. 
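+    # mult_int_scalar multiplies by an exact integer and needs no rescale,
+    # mult_scalar multiplies by an encoded real scalar and rescales once, and
+    # add_scalar/sub_scalar only touch the DC coefficient of ct0.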
+ # ------------------------------------------------------------------------------------------- + + def mult_int_scalar(self, ct: data_struct, scalar, evk=None, relin=True): + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + + device_len = len(ct.data[0]) + + int_scalar = int(scalar) + mont_scalar = [(int_scalar * self.ctx.R) % qi for qi in self.ctx.q] + + dest = self.ntt.p.destination_arrays[ct.level] + + partitioned_mont_scalar = [[mont_scalar[i] for i in desti] for desti in dest] + tensorized_scalar = [] + for device_id in range(device_len): + scal_tensor = torch.tensor( + partitioned_mont_scalar[device_id], + dtype=self.ctx.torch_dtype, + device=self.ntt.devices[device_id] + ) + tensorized_scalar.append(scal_tensor) + + new_ct = self.clone(ct) + new_data = new_ct.data + for i in [0, 1]: + self.ntt.mont_enter_scalar(new_data[i], tensorized_scalar, ct.level) + self.ntt.reduce_2q(new_data[i], ct.level) + + return new_ct + + def mult_scalar(self, ct, scalar, evk=None, relin=True): + device_len = len(ct.data[0]) + + scaled_scalar = int( + scalar * self.scale * np.sqrt(self.deviations[ct.level + 1]) + 0.5) + + mont_scalar = [(scaled_scalar * self.ctx.R) % qi for qi in self.ctx.q] + + dest = self.ntt.p.destination_arrays[ct.level] + + partitioned_mont_scalar = [[mont_scalar[i] for i in dest_i] for dest_i in dest] + tensorized_scalar = [] + for device_id in range(device_len): + scal_tensor = torch.tensor( + partitioned_mont_scalar[device_id], + dtype=self.ctx.torch_dtype, + device=self.ntt.devices[device_id] + ) + tensorized_scalar.append(scal_tensor) + + new_ct = self.clone(ct) + new_data = new_ct.data + + for i in [0, 1]: + self.ntt.mont_enter_scalar(new_data[i], tensorized_scalar, ct.level) + self.ntt.reduce_2q(new_data[i], ct.level) + + return self.rescale(new_ct) + + def add_scalar(self, ct, scalar): + device_len = len(ct.data[0]) + + scaled_scalar = int(scalar * self.scale * self.deviations[ct.level] + 0.5) + + if self.norm == 'backward': + scaled_scalar *= self.ctx.N + + scaled_scalar *= self.int_scale + + scaled_scalar = [scaled_scalar % qi for qi in self.ctx.q] + + dest = self.ntt.p.destination_arrays[ct.level] + + partitioned_mont_scalar = [[scaled_scalar[i] for i in desti] for desti in dest] + tensorized_scalar = [] + for device_id in range(device_len): + scal_tensor = torch.tensor( + partitioned_mont_scalar[device_id], + dtype=self.ctx.torch_dtype, + device=self.ntt.devices[device_id] + ) + tensorized_scalar.append(scal_tensor) + + new_ct = self.clone(ct) + new_data = new_ct.data + + dc = [d[:, 0] for d in new_data[0]] + for device_id in range(device_len): + dc[device_id] += tensorized_scalar[device_id] + + self.ntt.reduce_2q(new_data[0], ct.level) + + return new_ct + + def sub_scalar(self, ct, scalar): + return self.add_scalar(ct, -scalar) + + def int_scalar_mult(self, scalar, ct, evk=None, relin=True): + return self.mult_int_scalar(ct, scalar) + + def scalar_mult(self, scalar, ct, evk=None, relin=True): + return self.mult_scalar(ct, scalar) + + def scalar_add(self, scalar, ct): + return self.add_scalar(ct, scalar) + + def scalar_sub(self, scalar, ct): + neg_ct = self.negate(ct) + return self.add_scalar(neg_ct, scalar) + + # ------------------------------------------------------------------------------------------- + # message ops. 
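+    # mc_* take a plain message, encode it on the fly and combine it with the
+    # ciphertext; cm_* are the argument-swapped aliases.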
+ # ------------------------------------------------------------------------------------------- + + def mc_mult(self, m, ct, evk=None, relin=True): + m = np.array(m) * np.sqrt(self.deviations[ct.level + 1]) + + pt = self.encode(m, 0) + + pt_tiled = self.ntt.tile_unsigned(pt, ct.level) + + # Transform ntt to prepare for multiplication. + self.ntt.enter_ntt(pt_tiled, ct.level) + + # Prepare a new ct. + new_ct = self.clone(ct) + + self.ntt.enter_ntt(new_ct.data[0], ct.level) + self.ntt.enter_ntt(new_ct.data[1], ct.level) + + new_d0 = self.ntt.mont_mult(pt_tiled, new_ct.data[0], ct.level) + new_d1 = self.ntt.mont_mult(pt_tiled, new_ct.data[1], ct.level) + + self.ntt.intt_exit_reduce(new_d0, ct.level) + self.ntt.intt_exit_reduce(new_d1, ct.level) + + new_ct.data[0] = new_d0 + new_ct.data[1] = new_d1 + + return self.rescale(new_ct) + + def mc_add(self, m, ct): + pt = self.encode(m, ct.level) + pt_tiled = self.ntt.tile_unsigned(pt, ct.level) + + self.ntt.mont_enter_scale(pt_tiled, ct.level) + + new_ct = self.clone(ct) + self.ntt.mont_enter(new_ct.data[0], ct.level) + new_d0 = self.ntt.mont_add(pt_tiled, new_ct.data[0], ct.level) + self.ntt.mont_redc(new_d0, ct.level) + self.ntt.reduce_2q(new_d0, ct.level) + + new_ct.data[0] = new_d0 + + return new_ct + + def mc_sub(self, m, ct): + neg_ct = self.negate(ct) + return self.mc_add(m, neg_ct) + + def cm_mult(self, ct, m, evk=None, relin=True): + return self.mc_mult(m, ct) + + def cm_add(self, ct, m): + return self.mc_add(m, ct) + + def cm_sub(self, ct, m): + return self.mc_add(-np.array(m), ct) + + # ------------------------------------------------------------------------------------------- + # Automatic cc ops. + # ------------------------------------------------------------------------------------------- + + def auto_level(self, ct0, ct1): + level_diff = ct0.level - ct1.level + if level_diff < 0: + new_ct0 = self.level_up(ct0, ct1.level) + return new_ct0, ct1 + elif level_diff > 0: + new_ct1 = self.level_up(ct1, ct0.level) + return ct0, new_ct1 + else: + return ct0, ct1 + + def auto_cc_mult(self, ct0, ct1, evk, relin=True): + lct0, lct1 = self.auto_level(ct0, ct1) + return self.cc_mult(lct0, lct1, evk, relin=relin) + + def auto_cc_add(self, ct0, ct1): + lct0, lct1 = self.auto_level(ct0, ct1) + return self.cc_add(lct0, lct1) + + def auto_cc_sub(self, ct0, ct1): + lct0, lct1 = self.auto_level(ct0, ct1) + return self.cc_sub(lct0, lct1) + + # ------------------------------------------------------------------------------------------- + # Fully automatic ops. + # ------------------------------------------------------------------------------------------- + + def mult(self, a, b, evk=None, relin=True): + type_a = type(a) + type_b = type(b) + + try: + func = self.mult_dispatch_dict[type_a, type_b] + except Exception as e: + raise Exception(f"Unsupported data types are input.\n{e}") + + return func(a, b, evk, relin) + + def add(self, a, b): + type_a = type(a) + type_b = type(b) + + try: + func = self.add_dispatch_dict[type_a, type_b] + except Exception as e: + raise Exception(f"Unsupported data types are input.\n{e}") + + return func(a, b) + + def sub(self, a, b): + type_a = type(a) + type_b = type(b) + + try: + func = self.sub_dispatch_dict[type_a, type_b] + except Exception as e: + raise Exception(f"Unsupported data types are input.\n{e}") + + return func(a, b) + + # ------------------------------------------------------------------------------------------- + # Misc. 
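+    # refresh refreshes the CSPRNG state; reduce_error multiplies by the
+    # scalar 1.0 (and therefore rescales), spending one level to shrink the
+    # accumulated deviation.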
+ # ------------------------------------------------------------------------------------------- + + def refresh(self): + # Refreshes the rng state. + self.rng.refresh() + + def reduce_error(self, ct): + # Reduce the accumulated error in the cipher text. + return self.mult_scalar(ct, 1.0) + + # ------------------------------------------------------------------------------------------- + # Misc ops. + # ------------------------------------------------------------------------------------------- + + def sum(self, ct: data_struct, gk: data_struct, rescale_every=5) -> data_struct: + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if gk.origin != types.origins["galk"]: + raise errors.NotMatchType(origin=gk.origin, to=types.origins["galk"]) + + new_ct = self.clone(ct) + for roti in range(self.ctx.logN - 1): + rot_ct = self.rotate_single(new_ct, gk.data[roti]) + sum_ct = self.add(rot_ct, new_ct) + del new_ct, rot_ct + if roti != 0 and (roti % rescale_every) == 0: + new_ct = self.reduce_error(sum_ct) + else: + new_ct = sum_ct + return new_ct + + def mean(self, ct: data_struct, gk: data_struct, alpha=1, rescale_every=5) -> data_struct: + # Divide by num_slots. + # The cipher text is refreshed here, and hence + # doesn't need to be refreshed at roti=0 in the loop. + new_ct = self.mult(1 / self.num_slots / alpha, ct) + + for roti in range(self.ctx.logN - 1): + rotk = gk.data[roti] + rot_ct = self.rotate_single(new_ct, rotk) + sum_ct = self.add(rot_ct, new_ct) + del new_ct, rot_ct + if ((roti % rescale_every) == 0) and (roti != 0): + new_ct = self.reduce_error(sum_ct) + else: + new_ct = sum_ct + return new_ct + + def cov(self, ct_a: data_struct, ct_b: data_struct, + evk: data_struct, gk: data_struct, rescale_every=5) -> data_struct: + cta_mean = self.mean(ct_a, gk, rescale_every=rescale_every) + ctb_mean = self.mean(ct_b, gk, rescale_every=rescale_every) + + cta_dev = self.sub(ct_a, cta_mean) + ctb_dev = self.sub(ct_b, ctb_mean) + + ct_cov = self.mult(self.mult(cta_dev, ctb_dev, evk), 1 / (self.num_slots - 1)) + + return ct_cov + + def pow(self, ct: data_struct, power: int, evk: data_struct) -> data_struct: + current_exponent = 2 + pow_list = [ct] + while current_exponent <= power: + current_ct = pow_list[-1] + new_ct = self.cc_mult(current_ct, current_ct, evk) + pow_list.append(new_ct) + current_exponent *= 2 + + remaining_exponent = power - current_exponent // 2 + new_ct = pow_list[-1] + + while remaining_exponent > 0: + pow_ind = math.floor(math.log2(remaining_exponent)) + pow_term = pow_list[pow_ind] + new_ct = self.auto_cc_mult(new_ct, pow_term, evk) + remaining_exponent -= 2 ** pow_ind + + return new_ct + + def square(self, ct: data_struct, evk: data_struct, relin=True) -> data_struct: + x = self.rescale(ct) + + level = x.level + + # Multiply. + x0, x1 = x.data + + self.ntt.enter_ntt(x0, level) + self.ntt.enter_ntt(x1, level) + + d0 = self.ntt.mont_mult(x0, x0, level) + x0y1 = self.ntt.mont_mult(x0, x1, level) + d2 = self.ntt.mont_mult(x1, x1, level) + + d1 = self.ntt.mont_add(x0y1, x0y1, level) + + ct_mult = data_struct( + data=(d0, d1, d2), + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["ctt"], + level=level, + hash=self.hash, + version=self.version + ) + if relin: + ct_mult = self.relinearize(ct_triplet=ct_mult, evk=evk) + + return ct_mult + + # ------------------------------------------------------------------------------------------- + # Multiparty. 
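+    # Each party builds a public-key share against a common random polynomial
+    # a (the CRS); the b components of the shares are summed into a collective
+    # public key.  Decryption is split the same way: one party runs the
+    # *_decrypt_head step, the others contribute *_decrypt_partial shares, and
+    # *_decrypt_fusion recombines the shares.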
+ # ------------------------------------------------------------------------------------------- + def multiparty_public_crs(self, pk: data_struct): + crs = self.clone(pk).data[1] + return crs + + def multiparty_create_public_key(self, sk: data_struct, a=None, include_special=False) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if include_special and not sk.include_special: + raise errors.SecretKeyNotIncludeSpecialPrime() + mult_type = -2 if include_special else -1 + + level = 0 + e = self.rng.discrete_gaussian(repeats=1) + e = self.ntt.tile_unsigned(e, level, mult_type) + + self.ntt.enter_ntt(e, level, mult_type) + repeats = self.ctx.num_special_primes if sk.include_special else 0 + + if a is None: + a = self.rng.randint( + self.ntt.q_prepack[mult_type][level][0], + repeats=repeats + ) + + sa = self.ntt.mont_mult(a, sk.data, 0, mult_type) + pk0 = self.ntt.mont_sub(e, sa, 0, mult_type) + pk = data_struct( + data=(pk0, a), + include_special=include_special, + ntt_state=True, + montgomery_state=True, + origin=types.origins["pk"], + level=level, + hash=self.hash, + version=self.version + ) + return pk + + def multiparty_create_collective_public_key(self, pks: list[data_struct]) -> data_struct: + data, include_special, ntt_state, montgomery_state, origin, level, hash_, version = pks[0] + mult_type = -2 if include_special else -1 + b = [b.clone() for b in data[0]] # num of gpus + a = [a.clone() for a in data[1]] + + for pk in pks[1:]: + b = self.ntt.mont_add(b, pk.data[0], lvl=0, mult_type=mult_type) + + cpk = data_struct( + (b, a), + include_special=include_special, + ntt_state=ntt_state, + montgomery_state=montgomery_state, + origin=types.origins["pk"], + level=level, + hash=self.hash, + version=self.version + ) + return cpk + + def multiparty_decrypt_head(self, ct: data_struct, sk: data_struct): + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if ct.ntt_state or ct.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct.origin) + if not sk.ntt_state or not sk.montgomery_state: + raise errors.NotMatchDataStructState(origin=sk.origin) + level = ct.level + + ct0 = ct.data[0][0] + a = ct.data[1][0].clone() + + self.ntt.enter_ntt([a], level) + + sk_data = sk.data[0][self.ntt.starts[level][0]:] + + sa = self.ntt.mont_mult([a], [sk_data], level) + self.ntt.intt_exit(sa, level) + + pt = self.ntt.mont_add([ct0], sa, level) + + return pt + + def multiparty_decrypt_partial(self, ct: data_struct, sk: data_struct) -> data_struct: + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if ct.ntt_state or ct.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct.origin) + if not sk.ntt_state or not sk.montgomery_state: + raise errors.NotMatchDataStructState(origin=sk.origin) + + data, include_special, ntt_state, montgomery_state, origin, level, hash_, version = ct + + a = ct.data[1][0].clone() + + self.ntt.enter_ntt([a], level) + + sk_data = sk.data[0][self.ntt.starts[level][0]:] + + sa = self.ntt.mont_mult([a], [sk_data], level) + self.ntt.intt_exit(sa, level) + + return sa + + def multiparty_decrypt_fusion(self, pcts: list, level=0, include_special=False): + pt = [x.clone() 
for x in pcts[0]] + for pct in pcts[1:]: + pt = self.ntt.mont_add(pt, pct, level) + + self.ntt.reduce_2q(pt, level) + + base_at = -self.ctx.num_special_primes - 1 if include_special else -1 + + base = pt[0][base_at][None, :] + scaler = pt[0][0][None, :] + + final_scalar = self.final_scalar[level] + scaled = self.ntt.mont_sub([base], [scaler], -1) + self.ntt.mont_enter_scalar(scaled, [final_scalar], -1) + self.ntt.reduce_2q(scaled, -1) + self.ntt.make_signed(scaled, -1) + + m = self.decode(m=scaled, level=level) + + return m + + #### ------------------------------------------------------------------------------------------- + #### Multiparty. ROTATION + #### ------------------------------------------------------------------------------------------- + + def multiparty_create_key_switching_key(self, sk_src: data_struct, sk_dst: data_struct, a=None) -> data_struct: + if sk_src.origin != types.origins["sk"] or sk_src.origin != types.origins["sk"]: + raise errors.NotMatchType(origin="not a secret key", to=types.origins["sk"]) + if (not sk_src.ntt_state) or (not sk_src.montgomery_state): + raise errors.NotMatchDataStructState(origin=sk_src.origin) + if (not sk_dst.ntt_state) or (not sk_dst.montgomery_state): + raise errors.NotMatchDataStructState(origin=sk_dst.origin) + + level = 0 + + stops = self.ntt.stops[-1] + Psk_src = [sk_src.data[di][:stops[di]].clone() for di in range(self.ntt.num_devices)] + + self.ntt.mont_enter_scalar(Psk_src, self.mont_PR, level) + + ksk = [[] for _ in range(self.ntt.p.num_partitions + 1)] + for device_id in range(self.ntt.num_devices): + for part_id, part in enumerate(self.ntt.p.p[level][device_id]): + global_part_id = self.ntt.p.part_allocations[device_id][part_id] + + crs = a[global_part_id] if a else None + pk = self.multiparty_create_public_key(sk_dst, include_special=True, a=crs) + key = tuple(part) + astart = part[0] + astop = part[-1] + 1 + shard = Psk_src[device_id][astart:astop] + pk_data = pk.data[0][device_id][astart:astop] + + _2q = self.ntt.parts_pack[device_id][key]['_2q'] + update_part = ntt_cuda.mont_add([pk_data], [shard], _2q)[0] + pk_data.copy_(update_part, non_blocking=True) + + # Name the pk. 
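multiparty_decrypt_head, multiparty_decrypt_partial and multiparty_decrypt_fusion above rely on the collective secret being an additive sum of shares, s = s1 + s2 + ...; the toy below checks only that share algebra with scalar integers, ignoring the polynomial/RNS structure and the Gaussian noise the real routines carry. The modulus and all names are illustrative.

import numpy as np

q = 2**31 - 1                                  # toy modulus, not one of the library's RNS primes
rng = np.random.default_rng(0)
rand = lambda: int(rng.integers(1, q))

# Additive secret shares held by two parties (noise omitted for clarity).
s1, s2 = rand(), rand()

# Collective public key: common randomness a, each party contributes b_i = -a * s_i,
# and the b_i are summed as in multiparty_create_collective_public_key.
a = rand()
b = (-a * s1 - a * s2) % q

# Encrypt m under the collective key: ct = (b*u + m, a*u) with u random.
m, u = 12345, rand()
c0, c1 = (b * u + m) % q, (a * u) % q

# The lead party computes c0 + c1*s1 (decrypt_head); the other reveals only c1*s2
# (decrypt_partial); fusion adds the shares and recovers m.
head = (c0 + c1 * s1) % q
partial = (c1 * s2) % q
assert (head + partial) % q == m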
+ pk_name = f'key switch key part index {global_part_id}' + pk = pk._replace(origin=pk_name) + + ksk[global_part_id] = pk + + return data_struct( + data=ksk, + include_special=True, + ntt_state=True, + montgomery_state=True, + origin=types.origins["ksk"], + level=level, + hash=self.hash, + version=self.version + ) + + def multiparty_create_rotation_key(self, sk: data_struct, delta: int, a=None) -> data_struct: + sk_new_data = [s.clone() for s in sk.data] + self.ntt.intt(sk_new_data) + sk_new_data = [rotate(s, delta) for s in sk_new_data] + self.ntt.ntt(sk_new_data) + sk_rotated = data_struct( + data=sk_new_data, + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["sk"], + level=0, + hash=self.hash, + version=self.version + ) + rotk = self.multiparty_create_key_switching_key(sk_rotated, sk, a=a) + rotk = rotk._replace(origin=types.origins["rotk"] + f"{delta}") + return rotk + + def multiparty_generate_rotation_key(self, rotks: list[data_struct]) -> data_struct: + crotk = self.clone(rotks[0]) + for rotk in rotks[1:]: + for ksk_idx in range(len(rotk.data)): + update_parts = self.ntt.mont_add(crotk.data[ksk_idx].data[0], rotk.data[ksk_idx].data[0]) + crotk.data[ksk_idx].data[0][0].copy_(update_parts[0], non_blocking=True) + return crotk + + def generate_rotation_crs(self, rotk: data_struct): + if types.origins["rotk"] not in rotk.origin and types.origins["ksk"] != rotk.origin: + raise errors.NotMatchType(origin=rotk.origin, to=types.origins["ksk"]) + crss = [] + for ksk in rotk.data: + crss.append(ksk.data[1]) + return crss + + #### ------------------------------------------------------------------------------------------- + #### Multiparty. GALOIS + #### ------------------------------------------------------------------------------------------- + + def generate_galois_crs(self, galk: data_struct): + if galk.origin != types.origins["galk"]: + raise errors.NotMatchType(origin=galk.origin, to=types.origins["galk"]) + crs_s = [] + for rotk in galk.data: + crss = [ksk.data[1] for ksk in rotk.data] + crs_s.append(crss) + return crs_s + + def multiparty_create_galois_key(self, sk: data_struct, a: list) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + galois_key_parts = [ + self.multiparty_create_rotation_key(sk, self.galois_deltas[idx], a=a[idx]) + for idx in range(len(self.galois_deltas)) + ] + + galois_key = data_struct( + data=galois_key_parts, + include_special=True, + montgomery_state=True, + ntt_state=True, + origin=types.origins["galk"], + level=0, + hash=self.hash, + version=self.version + ) + return galois_key + + def multiparty_generate_galois_key(self, galks: list[data_struct]) -> data_struct: + cgalk = self.clone(galks[0]) + for galk in galks[1:]: # galk + for rotk_idx in range(len(galk.data)): # rotk + for ksk_idx in range(len(galk.data[rotk_idx].data)): # ksk + update_parts = self.ntt.mont_add( + cgalk.data[rotk_idx].data[ksk_idx].data[0], + galk.data[rotk_idx].data[ksk_idx].data[0] + ) + cgalk.data[rotk_idx].data[ksk_idx].data[0][0].copy_(update_parts[0], non_blocking=True) + return cgalk + + #### ------------------------------------------------------------------------------------------- + #### Multiparty. 
Evaluation Key + #### ------------------------------------------------------------------------------------------- + + def multiparty_sum_evk_share(self, evks_share: list[data_struct]): + evk_sum = self.clone(evks_share[0]) + for evk_share in evks_share[1:]: + for ksk_idx in range(len(evk_sum.data)): + update_parts = self.ntt.mont_add(evk_sum.data[ksk_idx].data[0], evk_share.data[ksk_idx].data[0]) + for dev_id in range(len(update_parts)): + evk_sum.data[ksk_idx].data[0][dev_id].copy_(update_parts[dev_id], non_blocking=True) + + return evk_sum + + def multiparty_mult_evk_share_sum(self, evk_sum: data_struct, sk: data_struct): + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + evk_sum_mult = self.clone(evk_sum) + + for ksk_idx in range(len(evk_sum.data)): + update_part_b = self.ntt.mont_mult(evk_sum_mult.data[ksk_idx].data[0], sk.data) + update_part_a = self.ntt.mont_mult(evk_sum_mult.data[ksk_idx].data[1], sk.data) + for dev_id in range(len(update_part_b)): + evk_sum_mult.data[ksk_idx].data[0][dev_id].copy_(update_part_b[dev_id], non_blocking=True) + evk_sum_mult.data[ksk_idx].data[1][dev_id].copy_(update_part_a[dev_id], non_blocking=True) + + return evk_sum_mult + + def multiparty_sum_evk_share_mult(self, evk_sum_mult: list[data_struct]) -> data_struct: + cevk = self.clone(evk_sum_mult[0]) + for evk in evk_sum_mult[1:]: + for ksk_idx in range(len(cevk.data)): + update_part_b = self.ntt.mont_add(cevk.data[ksk_idx].data[0], evk.data[ksk_idx].data[0]) + update_part_a = self.ntt.mont_add(cevk.data[ksk_idx].data[1], evk.data[ksk_idx].data[1]) + for dev_id in range(len(update_part_b)): + cevk.data[ksk_idx].data[0][dev_id].copy_(update_part_b[dev_id], non_blocking=True) + cevk.data[ksk_idx].data[1][dev_id].copy_(update_part_a[dev_id], non_blocking=True) + return cevk + + #### ------------------------------------------------------------------------------------------- + #### Statistics + #### ------------------------------------------------------------------------------------------- + + def sqrt(self, ct: data_struct, evk: data_struct, e=0.0001, alpha=0.0001) -> data_struct: + a = self.clone(ct) + b = self.clone(ct) + + while e <= 1 - alpha: + k = float(np.roots([1 - e ** 3, -6 + 6 * e ** 2, 9 - 9 * e])[1]) + t = self.mult_scalar(a, k, evk) + b0 = self.sub_scalar(t, 3) + b1 = self.mult_scalar(b, (k ** 0.5) / 2, evk) + b = self.cc_mult(b0, b1, evk) + + a0 = self.mult_scalar(a, (k ** 3) / 4) + t = self.sub_scalar(a, 3 / k) + a1 = self.square(t, evk) + a = self.cc_mult(a0, a1, evk) + e = k * (3 - k) ** 2 / 4 + + return b + + def var(self, ct: data_struct, evk: data_struct, gk: data_struct, relin=False) -> data_struct: + ct_mean = self.mean(ct=ct, gk=gk) + dev = self.sub(ct, ct_mean) + dev = self.square(ct=dev, evk=evk, relin=relin) + if not relin: + dev = self.relinearize(ct_triplet=dev, evk=evk) + ct_var = self.mean(ct=dev, gk=gk) + return ct_var + + def std(self, ct: data_struct, evk: data_struct, gk: data_struct, relin=False) -> data_struct: + ct_var = self.var(ct=ct, evk=evk, gk=gk, relin=relin) + ct_std = self.sqrt(ct=ct_var, evk=evk) + return ct_std diff --git a/liberate/fhe/context/__init__.py b/liberate/fhe/context/__init__.py new file mode 100644 index 0000000..d88f9f0 --- /dev/null +++ b/liberate/fhe/context/__init__.py @@ -0,0 +1,2 @@ +from .ckks_context import ckks_context +from .security_parameters import maximum_qbits, minimum_cyclotomic_order diff --git a/liberate/fhe/context/ckks_context.py 
b/liberate/fhe/context/ckks_context.py new file mode 100644 index 0000000..c6098a0 --- /dev/null +++ b/liberate/fhe/context/ckks_context.py @@ -0,0 +1,360 @@ +import math +import pickle +from pathlib import Path +import warnings + +import numpy as np +import torch + +from .generate_primes import generate_message_primes, generate_scale_primes +from .security_parameters import maximum_qbits +from liberate.fhe.cache import cache +from liberate.fhe.presets import errors + +# ------------------------------------------------------------------------------------------ +# NTT parameter pre-calculation. +# ------------------------------------------------------------------------------------------ +CACHE_FOLDER = cache.path_cache + + +def primitive_root_2N(q, N): + _2N = 2 * N + K = (q - 1) // _2N + for x in range(2, N): + g = pow(x, K, q) + h = pow(g, N, q) + if h != 1: + break + return g + + +def psi_power_series(psi, N, q): + series = [1] + for i in range(N - 1): + series.append(series[-1] * psi % q) + return series + + +def bit_rev_psi(q, logN): + N = 2 ** logN + psi = [primitive_root_2N(qi, N) for qi in q] + # Bit-reverse index. + ind = range(N) + brind = [bit_reverse(i, logN) for i in ind] + # The psi power and the indices are the same. + return [pow(psi, brpower, q) for brpower in brind] + + +def psi_bank(q, logN): + N = 2 ** logN + psi = [primitive_root_2N(qi, N) for qi in q] + ipsi = [pow(psii, -1, qi) for psii, qi in zip(psi, q)] + psi_series = [psi_power_series(psii, N, qi) for psii, qi in zip(psi, q)] + ipsi_series = [ + psi_power_series(ipsii, N, qi) for ipsii, qi in zip(ipsi, q) + ] + return psi_series, ipsi_series + + +def bit_reverse(a, nbits): + format_string = f"0{nbits}b" + binary_string = f"{a:{format_string}}" + reverse_binary_string = binary_string[::-1] + return int(reverse_binary_string, 2) + + +def bit_reverse_order_index(logN): + N = 2 ** logN + # Note that for a bit reversing, forward and backward permutations are the same. + # i.e., don't worry about which direction. 
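primitive_root_2N above looks for an element of order exactly 2N modulo an NTT-friendly prime q with q ≡ 1 (mod 2N); the two properties the negacyclic NTT needs are psi**N ≡ -1 and psi**(2N) ≡ 1. A condensed, self-contained check with a toy prime (q = 97, N = 8), not one of the cached primes:

q, N = 97, 8
M = 2 * N
assert (q - 1) % M == 0                         # NTT-friendly: q ≡ 1 (mod 2N)

def primitive_root_2N(q, N):
    # Condensed version of the routine above, same search bound range(2, N).
    K = (q - 1) // (2 * N)
    for x in range(2, N):
        g = pow(x, K, q)
        if pow(g, N, q) != 1:                   # order does not divide N, so it is exactly 2N
            return g

psi = primitive_root_2N(q, N)
assert pow(psi, N, q) == q - 1                  # psi**N ≡ -1 (mod q): the negacyclic twist
assert pow(psi, 2 * N, q) == 1                  # psi has order 2N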
+ revi = np.array([bit_reverse(i, logN) for i in range(N)], dtype=np.int32) + return revi + + +def get_psi(q, logN, my_dtype): + np_dtype_dict = { + np.int32: np.int32, + np.int64: np.int64, + 30: np.int32, + 62: np.int64, + } + dtype = np_dtype_dict[my_dtype] + psi, ipsi = psi_bank(q, logN) + bit_reverse_index = bit_reverse_order_index(logN) + psi = np.array(psi, dtype=dtype)[:, bit_reverse_index] + ipsi = np.array(ipsi, dtype=dtype)[:, bit_reverse_index] + return psi, ipsi + + +def paint_butterfly_forward(logN): + N = 2 ** logN + t = N + painted_even = np.zeros((logN, N), dtype=np.bool8) + painted_odd = np.zeros((logN, N), dtype=np.bool8) + painted_psi = np.zeros((logN, N // 2), dtype=np.int32) + for logm in range(logN): + m = 2 ** logm + t //= 2 + psi_ind = 0 + for i in range(m): + j1 = 2 * i * t + j2 = j1 + t - 1 + Sind = m + i + for j in range(j1, j2 + 1): + Uind = j + Vind = j + t + painted_even[logm, Uind] = True + painted_odd[logm, Vind] = True + painted_psi[logm, psi_ind] = Sind + psi_ind += 1 + painted_eveni = np.where(painted_even)[1].reshape(logN, -1) + painted_oddi = np.where(painted_odd)[1].reshape(logN, -1) + return painted_eveni, painted_oddi, painted_psi + + +def paint_butterfly_backward(logN): + N = 2 ** logN + t = 1 + painted_even = np.zeros((logN, N), dtype=np.bool8) + painted_odd = np.zeros((logN, N), dtype=np.bool8) + painted_psi = np.zeros((logN, N // 2), dtype=np.int32) + for logm in range(logN, 0, -1): + level = logN - logm + m = 2 ** logm + j1 = 0 + h = m // 2 + psi_ind = 0 + for i in range(h): + j2 = j1 + t - 1 + Sind = h + i + for j in range(j1, j2 + 1): + Uind = j + Vind = j + t + # Paint + painted_even[level, Uind] = True + painted_odd[level, Vind] = True + painted_psi[level, psi_ind] = Sind + psi_ind += 1 + j1 += 2 * t + t *= 2 + painted_eveni = np.where(painted_even)[1].reshape(logN, -1) + painted_oddi = np.where(painted_odd)[1].reshape(logN, -1) + return painted_eveni, painted_oddi, painted_psi + + +# ------------------------------------------------------------------------------------------ +# The context class. +# ------------------------------------------------------------------------------------------ + + +@errors.log_error +class ckks_context: + def __init__( + self, + buffer_bit_length=62, + scale_bits=40, + logN=15, + num_scales=None, + num_special_primes=2, + sigma=3.2, + uniform_tenary_secret=True, + cache_folder=CACHE_FOLDER, + security_bits=128, + quantum="post_quantum", + distribution="uniform", + read_cache=True, + save_cache=True, + verbose=False, + is_secured=True + + ): + if not Path(cache_folder).exists(): + Path(cache_folder).mkdir(parents=True, exist_ok=True) + + self.generation_string = f"{buffer_bit_length}_{scale_bits}_{logN}_{num_scales}_" \ + f"{num_special_primes}_{security_bits}_{quantum}_" \ + f"{distribution}" + + self.is_secured = is_secured + # Compose cache savefile name. + savepath = Path(cache_folder) / Path(self.generation_string + ".pkl") + + if savepath.exists() and read_cache: + with savepath.open("rb") as f: + __dict__ = pickle.load(f) + self.__dict__.update(__dict__) + + if verbose: + print( + f"I have read in from the cached save file {savepath}!!!\n" + ) + self.init_print() + + return + + # Transfer input parameters. + self.buffer_bit_length = buffer_bit_length + self.scale_bits = scale_bits + self.logN = logN + self.num_special_primes = num_special_primes + self.cache_folder = cache_folder + self.security_bits = security_bits + self.quantum = quantum + self.distribution = distribution + # Sampling strategy. 
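paint_butterfly_forward above just unrolls the radix-2 Cooley-Tukey index pattern into flat gather tables (which element feeds the "even" input of each butterfly, which feeds the "odd" input, and which psi power it uses) so the CUDA kernels avoid index arithmetic. A condensed copy for logN = 2 with the expected tables spelled out; plain bool is used here instead of np.bool8, which newer numpy versions have dropped.

import numpy as np

def paint_forward(logN):
    N = 2 ** logN
    t = N
    even = np.zeros((logN, N), dtype=bool)
    odd = np.zeros((logN, N), dtype=bool)
    psi = np.zeros((logN, N // 2), dtype=np.int32)
    for logm in range(logN):
        m = 2 ** logm
        t //= 2
        k = 0
        for i in range(m):
            j1 = 2 * i * t
            Sind = m + i
            for j in range(j1, j1 + t):
                even[logm, j] = True            # upper butterfly input
                odd[logm, j + t] = True         # lower butterfly input
                psi[logm, k] = Sind             # index of the twiddle for this butterfly
                k += 1
    ev = np.where(even)[1].reshape(logN, -1)
    od = np.where(odd)[1].reshape(logN, -1)
    return ev, od, psi

ev, od, ps = paint_forward(2)
# Stage 0 pairs (0,2) and (1,3); stage 1 pairs (0,1) and (2,3), as in a radix-2 CT NTT.
assert ev.tolist() == [[0, 1], [0, 2]]
assert od.tolist() == [[2, 3], [1, 3]]
assert ps.tolist() == [[1, 1], [2, 3]]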
+ self.sigma = sigma + self.uniform_tenary_secret = uniform_tenary_secret + if self.uniform_tenary_secret: + self.secret_key_sampling_method = "uniform tenary" + else: + self.secret_key_sampling_method = "sparse tenary" + + # dtypes. + self.torch_dtype = {30: torch.int32, 62: torch.int64}[ + self.buffer_bit_length + ] + self.numpy_dtype = {30: np.int32, 62: np.int64}[self.buffer_bit_length] + + # Polynomial length. + self.N = 2 ** self.logN + + # We set the message prime to of bit-length W-2. + self.message_bits = self.buffer_bit_length - 2 + + # Read in pre-calculated high-quality primes. + try: + message_special_primes = generate_message_primes(cache_folder=cache_folder)[self.message_bits][self.N] + except KeyError as e: + raise errors.NotFoundMessageSpecialPrimes(message_bit=self.message_bits, N=self.N) + + # For logN > 16, we need significantly more primes. + how_many = 64 if self.logN < 16 else 128 + try: + scale_primes = generate_scale_primes(cache_folder=cache_folder, how_many=how_many)[self.scale_bits, self.N] + except KeyError as e: + raise errors.NotFoundScalePrimes(scale_bits=self.scale_bits, N=self.N) + + # Compose the primes pack. + # Rescaling drops off primes in --> direction. + # Key switching drops off primes in <-- direction. + # Hence, [scale primes, base message prime, special primes] + self.max_qbits = int( + maximum_qbits(self.N, security_bits, quantum, distribution) + ) + base_special_primes = message_special_primes[: 1 + self.num_special_primes] + + # If num_scales is None, generate the maximal number of levels. + try: + if num_scales is None: + base_special_bits = sum( + [math.log2(p) for p in base_special_primes] + ) + available_bits = self.max_qbits - base_special_bits + num_scales = 0 + available_bits -= math.log2(scale_primes[num_scales]) + while available_bits > 0: + num_scales += 1 + available_bits -= math.log2(scale_primes[num_scales]) + + self.num_scales = num_scales + self.q = scale_primes[:num_scales] + base_special_primes + except IndexError as e: + raise errors.NotEnoughPrimes(scale_bits=self.scale_bits, N=self.N) + + # Check if security requirements are met. + self.total_qbits = math.ceil(sum([math.log2(qi) for qi in self.q])) + + if self.total_qbits > self.max_qbits: + if self.is_secured: + raise errors.ViolatedAllowedQbits( + scale_bits=self.scale_bits, N=self.N, num_scales=self.num_scales, + max_qbits=self.max_qbits, total_qbits=self.total_qbits) + else: + warnings.warn( + f"Maximum allowed qbits are violated: " + f"max_qbits={self.max_qbits:4d} and the " + f"requested total is {self.total_qbits:4d}." + ) + + # Generate Montgomery parameters and NTT paints. + self.generate_montgomery_parameters() + self.generate_paints() + + if verbose: + self.init_print() + + # Save cache. 
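The num_scales derivation above is a greedy bit-budget count: subtract the base message prime and special prime bits from the maximum allowed log q, then keep adding scale primes while the budget lasts. A back-of-the-envelope version with round numbers taken from the security tables in security_parameters.py (post-quantum, uniform secret, N = 2**15 allows roughly 829 bits); the real code uses the exact log2 of each generated prime, so the count below is only approximate.

max_qbits = 829                 # approximate 128-bit post-quantum budget for N = 2**15
message_bits = 60               # base message prime, about buffer_bit_length - 2 bits
num_special = 2                 # special primes come from the same ~60-bit pool
scale_bits = 40                 # target size of every rescaling prime

base_special_bits = message_bits * (1 + num_special)
available = max_qbits - base_special_bits
num_scales = available // scale_bits
print(num_scales)               # roughly 16 levels of ~40-bit rescaling primes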
+ if save_cache: + with savepath.open("wb") as f: + pickle.dump(self.__dict__, f) + + if verbose: + print(f"I have saved to the cached save file {savepath}!!!\n") + + def generate_montgomery_parameters(self): + self.R = 2 ** self.buffer_bit_length + self.R_square = [self.R ** 2 % qi for qi in self.q] + self.half_buffer_bit_length = self.buffer_bit_length // 2 + self.lower_bits_mask = (1 << self.half_buffer_bit_length) - 1 + self.full_bits_mask = (1 << self.buffer_bit_length) - 1 + + self.q_lower_bits = [qi & self.lower_bits_mask for qi in self.q] + self.q_higher_bits = [ + qi >> self.half_buffer_bit_length for qi in self.q + ] + self.q_double = [qi << 1 for qi in self.q] + + self.R_inv = [pow(self.R, -1, qi) for qi in self.q] + self.k = [ + (self.R * R_invi - 1) // qi + for R_invi, qi in zip(self.R_inv, self.q) + ] + self.k_lower_bits = [ki & self.lower_bits_mask for ki in self.k] + self.k_higher_bits = [ + ki >> self.half_buffer_bit_length for ki in self.k + ] + + def generate_paints(self): + self.N_inv = [pow(self.N, -1, qi) for qi in self.q] + + # psi and psi_inv. + psi, psi_inv = get_psi(self.q, self.logN, self.buffer_bit_length) + + # Paints. + ( + self.forward_even_indices, + self.forward_odd_indices, + forward_psi_paint, + ) = paint_butterfly_forward(self.logN) + ( + self.backward_even_indices, + self.backward_odd_indices, + backward_psi_paint, + ) = paint_butterfly_backward(self.logN) + + # Pre-painted psi and ipsi. + self.forward_psi = psi[..., forward_psi_paint.ravel()].reshape( + -1, *forward_psi_paint.shape + ) + self.backward_psi_inv = psi_inv[ + ..., backward_psi_paint.ravel() + ].reshape(-1, *backward_psi_paint.shape) + + def init_print(self): + print(f""" +I have received inputs: + buffer_bit_length\t\t= {self.buffer_bit_length:,d} + scale_bits\t\t\t= {self.scale_bits:,d} + logN\t\t\t\t= {self.logN:,d} + N\t\t\t\t= {self.N:,d} + Number of special primes\t= {self.num_special_primes:,d} + Number of scales\t\t= {self.num_scales:,d} + Cache folder\t\t\t= '{self.cache_folder:s}' + Security bits\t\t\t= {self.security_bits:,d} + Quantum security model\t\t= {self.quantum:s} + Security sampling distribution\t= {self.distribution:s} + Number of message bits\t\t= {self.message_bits:,d} + In total I will be using '{self.total_qbits:,d}' bits out of available maximum '{self.max_qbits:,d}' bits. + And is it secured?\t\t= {self.is_secured} +My RNS primes are {self.q}.""" + ) diff --git a/liberate/fhe/context/generate_primes.py b/liberate/fhe/context/generate_primes.py new file mode 100644 index 0000000..b912cf9 --- /dev/null +++ b/liberate/fhe/context/generate_primes.py @@ -0,0 +1,311 @@ +import math +import multiprocessing +import pickle +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +from joblib import Parallel, delayed + +from .prim_test import MillerRabinPrimalityTest +from .security_parameters import maximum_qbits +from liberate.fhe.cache import cache + +# Default cache folder. 
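generate_montgomery_parameters above stores R = 2**buffer_bit_length, R**2 mod q and k = (R * R_inv - 1) // q, which is the usual Montgomery constant -q**-1 mod R. A minimal pure-Python REDC built from exactly those quantities, using a toy Mersenne prime rather than one of the generated NTT primes; outputs land in [0, 2q), which is why the kernels carry the separate reduce_2q step.

R = 1 << 62                     # Montgomery radix, matching buffer_bit_length = 62
q = (1 << 61) - 1               # odd toy modulus below R (the Mersenne prime 2**61 - 1)

R_inv = pow(R, -1, q)
k = (R * R_inv - 1) // q        # same formula as in generate_montgomery_parameters

def mont_redc(x):
    # Maps x -> x * R**-1 (mod q); result lies in [0, 2q) whenever x < R * q.
    t = (x * k) & (R - 1)        # x * (-q**-1) mod R
    return (x + t * q) // R      # exact division: x + t*q ≡ 0 (mod R)

def mont_mult(a, b):
    # Operands in Montgomery form (a*R, b*R) give back a*b*R, still in [0, 2q).
    return mont_redc(a * b)

a, b = 123456789, 987654321
aR, bR = (a * R) % q, (b * R) % q
out = mont_redc(mont_mult(aR, bR))   # leave Montgomery form
assert out % q == (a * b) % q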
+CACHE_FOLDER = cache.path_cache + + +def generate_N_M(logN=None, cache_folder=CACHE_FOLDER, **kw): + if logN is None: + logN = list(range(12, 18)) + savefile = Path(cache_folder) / 'logN_N_M.pkl' + + if savefile.exists(): + with savefile.open('rb') as f: + logN_N_M = pickle.load(f) + logN = logN_N_M['logN'] + N = logN_N_M['N'] + M = logN_N_M['M'] + return logN, N, M + + N = [2 ** lN for lN in logN] + M = [2 * n for n in N] + + logN_N_M = { + 'logN': logN, + 'N': N, + 'M': M + } + + with savefile.open('wb') as f: + pickle.dump(logN_N_M, f) + + return logN, N, M + + +def check_ntt_primality(q: int, M: int): + # Is this in the KM+1 form? + NTT_comliance = (q - 1) % M + # It is compliant, go ahead. + if NTT_comliance == 0: + # Now, is q a prime? + is_prime = MillerRabinPrimalityTest(q) + if is_prime: + return True + return False + + +def generate_message_primes(mbits=None, cache_folder=CACHE_FOLDER, how_many=11, **kw): + if mbits is None: + mbits = [28, 60] + savefile = Path(cache_folder) / 'message_special_primes.pkl' + + if savefile.exists(): + with savefile.open('rb') as f: + mprimes = pickle.load(f) + # return mprimes + else: + logN, N, M = generate_N_M(cache_folder=cache_folder, **kw) + + mprimes = {} + for mb in mbits: + mprimes[mb] = {} + for m in M: + # We want to deal with N, and hide M implicitly. + N = m // 2 + mprimes[mb][N] = [] + current_query = 2 ** mb - 1 + q_count = 0 + + while True: + ok = check_ntt_primality(current_query, m) + if ok: + mprimes[mb][N].append(current_query) + q_count += 1 + + # Have we pulled out how_many primes? + if q_count == how_many: + break + + # Move onto the next query. + current_query -= 2 + + with savefile.open('wb') as f: + pickle.dump(mprimes, f) + + return mprimes + + +def maximum_levels(N: int, qbits: int = 40, mbits: int = 60, nksk: int = 2) -> int: + extra_bits = mbits * (1 + nksk) + f_levels = (maximum_qbits(N) - extra_bits) / qbits + return math.floor(f_levels) + + +def find_the_next_prime(start: int, m: int, up=True) -> int: + step: int = 2 if up else -2 + current_query: int = start + while True: + ok: bool = check_ntt_primality(q=current_query, M=m) + if ok: + break + current_query += step + return current_query + + +def generate_alternating_prime_sequence( + sb: int = 40, + N: int = 2 ** 15, + how_many: int = 60, + optimize: bool = True, + alternate_directions: bool = True, + fixed_direction: bool = False, +) -> list: + m: int = N * 2 + scale: int = 2 ** sb + + s_primes: list = [] + + up: int = scale + 1 + down: int = scale - 1 + + if alternate_directions: + # Initial search. + up0: int = find_the_next_prime(start=up, m=m) + down0: int = find_the_next_prime(start=down, m=m, up=False) + + # Initial error. + eup: int = up0 - scale + edown: int = scale - down0 + + # Set the current direction. + # True is up + # This is the NEXT direction. That means if the first item was up, then this direction is down. + current_direction: bool = False if eup < edown else True + + # Initialize count. + q_count: int = 0 + + cumulative_scale: int = 1 + + # The loop. + while True: + # Search in the given direction. + current_query: int = up if current_direction else down + next_prime: int = find_the_next_prime(start=current_query, m=m, up=current_direction) + # cumulative scale progresses as (beta_{i-1} * beta_i)**2 + + # cumulative_scale *= scale / next_prime + + # Use Pre-rescale quadratic deviation rule. + current_dev = scale / next_prime + cumulative_scale = cumulative_scale ** 2 * current_dev ** 2 + + # Set the next variable. 
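generate_alternating_prime_sequence above alternates between primes just above and just below 2**scale_bits so that the per-level deviations scale / p cancel in sign; the library additionally tracks a quadratic pre-rescale deviation rule, which is not reproduced here. The deviations below are hypothetical numbers, used only to show why alternating beats a one-sided search.

import numpy as np

# Hypothetical relative deviations scale / p for six alternating primes:
# slightly below 1 (prime just above the scale), then slightly above 1, and so on.
e = np.array([1 - 3e-7, 1 + 2.8e-7, 1 - 2.5e-7, 1 + 2.2e-7, 1 - 2.0e-7, 1 + 1.9e-7])

cum = np.cumprod(e)
one_sided = np.cumprod(np.full_like(e, 1 - 2.4e-7))   # what a fixed-direction search would give

print(np.abs(cum - 1).max())        # stays on the order of a single deviation
print(np.abs(one_sided - 1).max())  # grows roughly linearly with the number of levels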
+ if current_direction: + up: int = next_prime + 2 + if optimize: + searched: int = int((cumulative_scale * scale) // 2 * 2 - 1) + down: int = searched if searched < down else down + else: + down: int = next_prime - 2 + if optimize: + searched: int = int((cumulative_scale * scale) // 2 * 2 + 1) + up: int = searched if searched > up else up + + # Switch the direction. + current_direction: bool = not current_direction + + # Store. + s_primes.append(next_prime) + q_count += 1 + + # Escape. + if q_count >= how_many: + break + + else: + q_count: int = 0 + current_query: int = up if fixed_direction else down + step: int = 2 if fixed_direction else -2 + while True: + current_query: int = find_the_next_prime( + start=current_query, m=m, up=fixed_direction + ) + s_primes.append(current_query) + # Move on. + q_count += 1 + current_query += step + + # Escape. + if q_count >= how_many: + break + return s_primes + + +def cum_prod(x: list) -> list: + ret: list = [1] + for i in range(len(x)): + ret.append(ret[-1] * x[i]) + return ret[1:] + + +def plot_cumulative_relative_error(sb: int = 40, N: int = 2 ** 15, label: str = "Optimized", **kw): + how_many: int = maximum_levels(N=N, qbits=sb) + p: list = generate_alternating_prime_sequence(sb=sb, N=N, how_many=how_many, **kw) + + # Check every prime in the sequence is unique. + unique_p: list = sorted(set(p)) + + err_msg = "There are repeating primes in the generate primes set!!!" + assert len(unique_p) == len(p), err_msg + + scale: int = 2 ** sb + e: list = [scale / pi for pi in p] + + y: list = cum_prod(e) + + plt.plot(y, label=label) + plt.grid() + plt.show() + # Error propagation. + q: np.array = np.array(y) - 1 + print(f"Error expanded {np.abs(q).max() / np.abs(q)[0]} times.") + + +def pgen_pseq(sb, N, how_many: int) -> list | str: + # TODO return 정리 + if how_many < 2: + return f"ERROR!!! sb = {sb}, N = {N}. Not enough primes." + + try: + res: list = generate_alternating_prime_sequence( + sb=sb, N=N, how_many=how_many + ) + except Exception as e: + # Try with the half how_many. + res: list = pgen_pseq(sb=sb, N=N, how_many=how_many // 2) + return res + + +def generate_scale_primes(cache_folder=CACHE_FOLDER, how_many=64, ncpu_cutdown=32, verbose=0): + savefile = Path(cache_folder) / 'scale_primes.pkl' + if savefile.exists(): + with open(savefile, 'rb') as f: + result_dict = pickle.load(f) + return result_dict + + ncpu = multiprocessing.cpu_count() + + # Cut down the number of n-cpus. It tends to slow down after 32. + ncpu = ncpu_cutdown if ncpu > ncpu_cutdown else ncpu + + logN, N, M = generate_N_M(cache_folder=cache_folder) + + # Scale. + logS = list(range(20, 55, 5)) + + # Generate input packages. + inputs = [] + for log_n, n in zip(logN, N): + how_many = 64 if log_n < 16 else 128 + for sb in logS: + inputs.append((sb, n, how_many)) + + result = Parallel(n_jobs=ncpu, + verbose=verbose)(delayed(pgen_pseq)(*inp) for inp in inputs) + + result_dict = {(sb, N): pr for (sb, N, how_many), pr in zip(inputs, result)} + + with savefile.open('wb') as f: + pickle.dump(result_dict, f) + + return result_dict + + +def measure_scale_primes_quality(sb: int = 40, N: int = 2 ** 15): + scale_primes: dict = generate_scale_primes() + + p: list = scale_primes[(sb, N)] + + # Check every prime in the sequence is unique. + unique_p = sorted(set(p)) + assert len(unique_p) == len( + p + ), "There are repeating primes in the generate primes set!!!" 
+ + scale = 2 ** sb + e = [scale / pi for pi in p] + + y = cum_prod(e) + + plt.plot(y, label=f"Scale bits={sb}, logN={math.log2(N)}") + + # How many primes? + print(f"I have {len(p)} primes in the set.") + + # Error propagation. + q = np.array(y) - 1 + print(f"Max. relative error is {np.abs(q).max():.3e}.") + print(f"Min. relative error is {np.abs(q).min():.3e}.") + print(f"Error expanded {np.abs(q).max() / np.abs(q)[0]:.3f} times.") diff --git a/liberate/fhe/context/prim_test.py b/liberate/fhe/context/prim_test.py new file mode 100644 index 0000000..ead0a5b --- /dev/null +++ b/liberate/fhe/context/prim_test.py @@ -0,0 +1,64 @@ +import random + + +def MillerRabinPrimalityTest(number, rounds=10): + # If the input is an even number, return immediately with False. + if number == 2: + return True + elif number == 1 or number % 2 == 0: + return False + + # First we want to express n as : 2^s * r ( were r is odd ) + + # The odd part of the number + oddPartOfNumber = number - 1 + + # The number of time that the number is divided by two + timesTwoDividNumber = 0 + + # while r is even divid by 2 to find the odd part + while oddPartOfNumber % 2 == 0: + oddPartOfNumber = oddPartOfNumber / 2 + timesTwoDividNumber = timesTwoDividNumber + 1 + + # Make oddPartOfNumber integer. + oddPartOfNumber = int(oddPartOfNumber) + + # Since there are number that are cases of "strong liar" we need to check more than one number + for time in range(rounds): + # Choose "Good" random number + while True: + # Draw a RANDOM number in range of number ( Z_number ) + randomNumber = random.randint(2, number) - 1 + if randomNumber != 0 and randomNumber != 1: + break + + # randomNumberWithPower = randomNumber^oddPartOfNumber mod number + randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number) + + # If random number is not 1 and not -1 ( in mod n ) + if (randomNumberWithPower != 1) and ( + randomNumberWithPower != number - 1 + ): + # number of iteration + iterationNumber = 1 + + # While we can squre the number and the squered number is not -1 mod number + while (iterationNumber <= timesTwoDividNumber - 1) and ( + randomNumberWithPower != number - 1 + ): + # Squre the number + randomNumberWithPower = pow(randomNumberWithPower, 2, number) + + # inc the number of iteration + iterationNumber = iterationNumber + 1 + + # If x != -1 mod number then it is because we did not find strong witnesses + # hence 1 have more then two roots in mod n ==> + # n is composite ==> return false for primality + + if randomNumberWithPower != (number - 1): + return False + + # The number pass the tests ==> it is probably prime ==> return true for primality + return True diff --git a/liberate/fhe/context/security_parameters.py b/liberate/fhe/context/security_parameters.py new file mode 100644 index 0000000..6057eaa --- /dev/null +++ b/liberate/fhe/context/security_parameters.py @@ -0,0 +1,201 @@ +from scipy.interpolate import InterpolatedUnivariateSpline + +# These are bits security levels, i.e., the 1^lambda measure. +security_levels = [128, 192, 256] + +# This is the dimension of the cyclotomic moduli, in +# ℤ[X]/Φ𝑚(𝑋), and m = 2^l. Where n is the leading (biggest) power of X in the polynomial Φ𝑚(𝑋). +# Such that, Φ𝑚(𝑋) = X^n + 1, where n = m / 2. +cyclotomic_n = [1024, 2048, 4096, 8192, 16384, 32768] + +# The following q is the moduli of a ring ℤq. Note that the numbers are given in log(q) values, +# where log in this context means log base 2. +# +# There are 2 sections in the standard documentation, namely pre- and post- quantum security. 
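MillerRabinPrimalityTest above is the textbook probabilistic test written out long-hand; a compact equivalent is shown below for reference (an illustrative rewrite, not the library routine), together with the q ≡ 1 (mod 2N) condition that check_ntt_primality adds on top.

import random

def miller_rabin(n, rounds=10):
    if n in (2, 3):
        return True
    if n < 2 or n % 2 == 0:
        return False
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1                          # n - 1 = 2**s * d with d odd
    for _ in range(rounds):
        a = random.randrange(2, n - 1)
        x = pow(a, d, n)
        if x in (1, n - 1):
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False                # a is a witness: n is composite
    return True                         # probably prime

# NTT-friendly prime check: probable prime and q ≡ 1 (mod 2N).
q, N = 97, 8
assert miller_rabin(q) and (q - 1) % (2 * N) == 0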
+# We separate them in respective dictionaries. +# +# Also, there are 3 different methods of sampling the messages according to respective distributions. +# Those are uniform, error, and (-1, 1) (tenary). +# We differentiate the message distribution by dictionary keys: 'uniform', 'error', and 'tenary'. + +# This is the pre-quantum security requirements. +logq_preq = {} +logq_preq["uniform"] = [ + 29, + 21, + 16, + 56, + 39, + 31, + 111, + 77, + 60, + 220, + 154, + 120, + 440, + 307, + 239, + 880, + 612, + 478, +] +logq_preq["error"] = [ + 29, + 21, + 16, + 56, + 39, + 31, + 111, + 77, + 60, + 220, + 154, + 120, + 440, + 307, + 239, + 883, + 613, + 478, +] +logq_preq["tenary"] = [ + 27, + 19, + 14, + 54, + 37, + 29, + 109, + 75, + 58, + 218, + 152, + 118, + 438, + 305, + 237, + 881, + 611, + 476, +] + +# This is the post-quantum security requirements. +logq_postq = {} +logq_postq["uniform"] = [ + 27, + 19, + 15, + 53, + 37, + 29, + 103, + 72, + 56, + 206, + 143, + 111, + 413, + 286, + 222, + 829, + 573, + 445, +] +logq_postq["error"] = [ + 27, + 19, + 15, + 53, + 37, + 29, + 103, + 72, + 56, + 206, + 143, + 111, + 413, + 286, + 222, + 829, + 573, + 445, +] +logq_postq["tenary"] = [ + 25, + 17, + 13, + 51, + 35, + 27, + 101, + 70, + 54, + 202, + 141, + 109, + 411, + 284, + 220, + 827, + 571, + 443, +] + + +# Partition q's by security levels. +def partitq(q): + qlen = len(q) + levlen = len(security_levels) + grouped = [ + [q[i] for i in range(0 + lev, qlen, levlen)] for lev in range(levlen) + ] + by_sec_lev = {lev: grouped[l] for l, lev in enumerate(security_levels)} + return by_sec_lev + + +# Gather up. +logq = {} +distributions = ["uniform", "error", "tenary"] +logq["pre_quantum"] = { + distributions[disti]: partitq(logq_preq[dist]) + for disti, dist in enumerate(distributions) +} +logq["post_quantum"] = { + distributions[disti]: partitq(logq_postq[dist]) + for disti, dist in enumerate(distributions) +} + + +def minimum_cyclotomic_order( + q_bits, security_bits=128, quantum="post_quantum", distribution="uniform" +): + assert quantum in [ + "pre_quantum", + "post_quantum", + ], "Wrong quantum security model!!!" + assert distribution in ["uniform", "error", "tenary"] + assert security_bits in [128, 192, 256] + + x = logq[quantum][distribution][security_bits] + y = cyclotomic_n + s = InterpolatedUnivariateSpline(x, y, k=1) + return s(q_bits) + + +def maximum_qbits( + L, security_bits=128, quantum="post_quantum", distribution="uniform" +): + assert quantum in [ + "pre_quantum", + "post_quantum", + ], "Wrong quantum security model!!!" + assert distribution in ["uniform", "error", "tenary"] + assert security_bits in [128, 192, 256] + + x = cyclotomic_n + y = logq[quantum][distribution][security_bits] + s = InterpolatedUnivariateSpline(x, y, k=1) + return s(L) diff --git a/liberate/fhe/data_struct.py b/liberate/fhe/data_struct.py new file mode 100644 index 0000000..3d82e12 --- /dev/null +++ b/liberate/fhe/data_struct.py @@ -0,0 +1,24 @@ +from typing import NamedTuple +from .version import VERSION + + +class data_struct(NamedTuple): + """ + - Data structure. + - data: the data in tensor format + - include_special: Boolean, including the special prime channels or not. + - ntt_state: Boolean, whether if the data is ntt transformed or not. + - montgomery_state: Boolean, whether if the data is in the Montgomery form or not. + - origin: String, where did this data came from - cipher text, secret key, etc. + - level: Integer, the current level where this data is situated. 
+ - hash: String, a SHA256 hash of the input settings and the prime numbers used to RNS decompose numbers. + - version: String, version number. + """ + data: tuple | list + include_special: bool + ntt_state: bool + montgomery_state: bool + origin: str + level: int + hash: str + version: str = VERSION diff --git a/liberate/fhe/encdec/__init__.py b/liberate/fhe/encdec/__init__.py new file mode 100644 index 0000000..803a2da --- /dev/null +++ b/liberate/fhe/encdec/__init__.py @@ -0,0 +1 @@ +from .encdec import decode, encode, rotate, conjugate diff --git a/liberate/fhe/encdec/encdec.py b/liberate/fhe/encdec/encdec.py new file mode 100644 index 0000000..f0a3e26 --- /dev/null +++ b/liberate/fhe/encdec/encdec.py @@ -0,0 +1,323 @@ +import numpy as np +import torch + + +# --------------------------------------------------------------- +# Permutation. +# --------------------------------------------------------------- + +def circular_shift_permutation(N, shift=1): + left = np.roll(np.arange(N // 2), shift) + right = np.roll(np.arange(N // 2), -shift) + N // 2 + return np.concatenate([left, right]) + + +def canon_permutation(N, k=1, verbose=False): + """ + Permutes the coefficients of the lattice basis that yields correctly the permutation + of the decoded message. + + The canonical permutation is defined as mu_p(n) = pn mod M where p is coprime with M, + where p=2*k+1. + """ + M = 2 * N + p = int(2 * k + 1) # Make sure p is an integer. + n = np.arange(M) # n starts from 0. + pn = p * n % M + if verbose: + print(f"Canonical permutation for p={p} is\n{pn}") + return pn + + +def canon_permutation_torch(N, k=1, device='cuda:0', verbose=False): + """ + Permutes the coefficients of the lattice basis that yields correctly the permutation + of the decoded message. + + The canonical permutation is defined as mu_p(n) = pn mod M where p is coprime with M, + where p=2*k+1. + """ + M = N * 2 + p = int(2 * k + 1) # Make sure p is an integer. + n = torch.arange(N, device=device) # n starts from 0. + pn = p * n % M + if verbose: + print(f"Canonical permutation for p={p} is\n{pn}") + return pn + + +def fold_permutation(N, p, verbose=False): + """ + In application to crypto, we fold the FFT at Nyquist. + + Inverse FFT results in selection of alternating elements. + Folding should correct the indices of the permutation according to the + folding rule. + + For example, 1->0, 3->1, 5->2, and so on. + """ + fold_p = (p[1::2] - 1) // 2 + if verbose: + print(f"Folding\n{p}\nresulted in\n{fold_p}.") + return fold_p + + +def conjugate_permutation(p, q): + """ + Conjugate permutations p and q by stacking p on top of q. + + Permutations p and q must share the same cycle structures. + """ + # Calculate cycles. + pc = permutation_cycles(p) + qc = permutation_cycles(q) + + # Check if the cycle structures match. + cs1 = [len(c) for c in pc] + cs2 = [len(c) for c in qc] + assert ( + cs1 == cs2 + ), "Cycle structures of permutations must match for a conjugate to exist!!!" + + # Expand cycles. + pe = np.array([i for c in pc for i in c]) + qe = np.array([i for c in qc for i in c]) + + # Move slots. + r = np.zeros_like(p) + r[qe] = pe + + # Return. + return r + + +def permutation_cycles(perm): + """ + Transform a plain permutation into a composition of cycles. 
+ """ + pi = {i: perm[i] for i in range(len(perm))} + cycles = [] + while pi: + elem0 = next(iter(pi)) # arbitrary starting element + this_elem = pi[elem0] + next_item = pi[this_elem] + + cycle = [] + while True: + cycle.append(this_elem) + del pi[this_elem] + this_elem = next_item + if next_item in pi: + next_item = pi[next_item] + else: + break + cycles.append(cycle) + return cycles + + +def inverse_permutation(p, verbose=False): + """ + Calculates the inverse permutation. + """ + N = len(p) + ind = np.arange(N) + ip = ind[np.argsort(p)] + if verbose: + print(f"The inverse of permutation\n{p}\nis\n{ip}.") + return ip + + +# --------------------------------------------------------------- +# Negacyclic fft. +# --------------------------------------------------------------- + + +def expand2conjugate(m): + return torch.concat([m, torch.flipud(torch.conj(m))]) + + +def generate_twister(N, device='cuda:0'): + expr = -1j * torch.pi * torch.arange(N, device=device, dtype=torch.float64) / N + return torch.exp(expr) + + +def generate_skewer(N, device='cuda:0'): + expr = 1j * torch.pi * torch.arange(N, device=device, dtype=torch.float64) / N + skew = torch.exp(expr) + return skew + + +def m2poly(m, twister, norm='backward'): + """ + m is the message and this function turns the message into + polynomial coefficients. + The message must be expanded mirrored in conjugacy. + """ + + # Run fft and multiply twister. + ffted = torch.fft.fft(m, norm=norm) + + # Twist. + twisted = ffted * twister + + # Return the real part. + return twisted.real + + +def poly2m(poly, skewer, norm='backward'): + """ + poly is the polynomial coefficients and this function turns the coefficients + into a plain message. + """ + + # Multiply skewer. + t = poly * skewer + + # Recover. + recovered = torch.fft.ifft(t, norm=norm) + + # Return the real part. + return recovered + + +# --------------------------------------------------------------- +# Utilities. +# --------------------------------------------------------------- + +perm_cache = {} +twister_cache = {} +skewer_cache = {} + + +def prepost_perms(N, device='cuda:0'): + circ_shift = circular_shift_permutation(N) + canon_perm = canon_permutation(N) + fold_perm = fold_permutation(N, canon_perm) + post_perm = conjugate_permutation(circ_shift, fold_perm) + pre_perm = inverse_permutation(post_perm)[:N // 2] + + post_perm = torch.from_numpy(post_perm).to(device) + pre_perm = torch.from_numpy(pre_perm).to(device) + return pre_perm, post_perm + + +def pre_permute(m, pre_perm): + """ + Input m must be a torch tensor. + """ + N = m.size(-1) + permed_m = torch.zeros((N * 2,), dtype=m.dtype, device=m.device) + permed_m[pre_perm] = m + conj_permed_m = permed_m + permed_m.conj().flip(0) + return conj_permed_m + + +def post_permute(m, post_perm): + """ + Input m must be a torch tensor. + """ + permed_m = torch.zeros_like(m) + permed_m[post_perm] = m + return permed_m + + +def rotate(m, delta): + N = m.size(-1) + C = m.numel() // N + + shift = delta % N + leap = (3 ** shift - 1) // 2 % (N * 2) + + if (N, leap, m.device) in perm_cache.keys(): + perm = perm_cache[(N, leap, m.device)] + else: + perm = canon_permutation_torch(N, leap, device=m.device) + perm_cache[(N, leap, m.device)] = perm + + perm_folded = perm % N + perm_sign = (-1) ** (perm // N) + + # Permute! + # We want to both be capable of 2D and 1D tensors. + # Use view. 
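generate_twister above produces exp(-i*pi*n/N); multiplying by it before a length-N FFT turns the FFT into an evaluation of the coefficient vector at the odd powers of a primitive 2N-th root of unity, i.e. at the roots of X**N + 1. That is what makes slot-wise multiplication correspond to negacyclic polynomial multiplication. A small numpy check of that fact, independent of the encode/decode plumbing here:

import numpy as np

N = 8
twist = np.exp(-1j * np.pi * np.arange(N) / N)     # same form as generate_twister

def nega_eval(a):
    # Evaluate the polynomial a(X) at the odd powers of the 2N-th root of unity.
    return np.fft.fft(a * twist)

rng = np.random.default_rng(0)
a, b = rng.standard_normal(N), rng.standard_normal(N)

# Schoolbook negacyclic product: reduce modulo X**N + 1, so X**N wraps to -1.
c = np.zeros(N)
for i in range(N):
    for j in range(N):
        if i + j < N:
            c[i + j] += a[i] * b[j]
        else:
            c[i + j - N] -= a[i] * b[j]

assert np.allclose(nega_eval(c), nega_eval(a) * nega_eval(b))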
+ rot_m = torch.zeros_like(m) + rot_m.view(C, N).T[perm_folded] = (perm_sign * m).view(C, N).T + + return rot_m + + +def conjugate(m): + N = m.size(-1) + C = m.numel() // N + + leap = N - 1 + + if (N, leap, m.device) in perm_cache.keys(): + perm = perm_cache[(N, leap, m.device)] + else: + perm = canon_permutation_torch(N, leap, device=m.device) + perm_cache[(N, leap, m.device)] = perm + + perm_folded = perm % N + perm_sign = (-1) ** (perm // N) + + # Permute! + # We want to both be capable of 2D and 1D tensors. + # Use view. + rot_m = torch.zeros_like(m) + rot_m.view(C, N).T[perm_folded] = (perm_sign * m).view(C, N).T + + return rot_m + + +def encode(m, rng=None, scale=2 ** 40, deviation=1.0, + device='cuda:0', norm='forward', + return_without_scaling=False): + N = len(m) * 2 + if (N, device) in perm_cache.keys(): + pre_perm, post_perm = perm_cache[(N, device)] + else: + pre_perm, post_perm = prepost_perms(N, device=device) + perm_cache[(N, device)] = (pre_perm, post_perm) + + mm = torch.from_numpy(np.array(m * deviation)).to(device) # check dtype m * deviation + mm = pre_permute(mm, pre_perm) + + if (N, device) in twister_cache.keys(): + twister = twister_cache[N, device] + else: + twister = generate_twister(N, device) + twister_cache[N, device] = twister + + if return_without_scaling: + return m2poly(mm, twister, norm) + else: + mm = m2poly(mm, twister, norm) * np.float64(scale) + return rng.randround(mm) + + +def decode(m, scale=2 ** 40, + correction=1.0, norm='forward', + return_without_scaling=False): + N = len(m) + device = m.device.type + ':' + str(m.device.index) + if (N, device) in perm_cache.keys(): + pre_perm, post_perm = perm_cache[(N, device)] + else: + pre_perm, post_perm = prepost_perms(N, device=device) + perm_cache[(N, device)] = (pre_perm, post_perm) + + if (N, device) in skewer_cache.keys(): + skewer = skewer_cache[N, device] + else: + skewer = generate_skewer(N, device) + skewer_cache[N, device] = skewer + + if return_without_scaling: + mm = poly2m(m, skewer, norm=norm) + mm = post_permute(mm, post_perm) + return mm + else: + mm = poly2m(m, skewer, norm=norm) / scale * correction + mm = post_permute(mm, post_perm) + return mm diff --git a/liberate/fhe/presets/__init__.py b/liberate/fhe/presets/__init__.py new file mode 100644 index 0000000..dc776af --- /dev/null +++ b/liberate/fhe/presets/__init__.py @@ -0,0 +1,3 @@ +from .params import params +from . import types +from . 
import errors diff --git a/liberate/fhe/presets/errors.py b/liberate/fhe/presets/errors.py new file mode 100644 index 0000000..d914122 --- /dev/null +++ b/liberate/fhe/presets/errors.py @@ -0,0 +1,166 @@ +from functools import wraps +import logging + + +def log_error(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logging.error(f"[Error] Error in {func.__name__} : {e}") + raise + + return wrapper + + +class TestException(Exception): + def __init__(self): + message_error = "test error" + super().__init__(message_error) + + +class NotFoundMessageSpecialPrimes(Exception): + def __init__(self, message_bit, N): + self.message_error = f"""Can't find message_bit = {message_bit:3 +#include + +//------------------------------------------------------------------ +// Main functions +//------------------------------------------------------------------ + +torch::Tensor mont_mult_cuda(const torch::Tensor a, + const torch::Tensor b, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void mont_enter_cuda(torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void ntt_cuda(torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void enter_ntt_cuda(torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void intt_cuda(torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + + +void mont_redc_cuda(torch::Tensor a, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + + +void intt_exit_cuda(torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void intt_exit_reduce_cuda(torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void intt_exit_reduce_signed_cuda(torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void reduce_2q_cuda(torch::Tensor a, + const torch::Tensor _2q); + +void make_signed_cuda(torch::Tensor a, + const torch::Tensor _2q); + +void make_unsigned_cuda(torch::Tensor a, + const torch::Tensor _2q); + +torch::Tensor mont_add_cuda(const torch::Tensor a, + const torch::Tensor b, + const torch::Tensor _2q); + +torch::Tensor mont_sub_cuda(const torch::Tensor a, + const torch::Tensor b, + const torch::Tensor _2q); + +torch::Tensor 
tile_unsigned_cuda(torch::Tensor a, + const torch::Tensor _2q); + + +//------------------------------------------------------------------ +// Main functions +//------------------------------------------------------------------ + +std::vector mont_mult( + const std::vector a, + const std::vector b, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + std::vector outputs; + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector Rs, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector Rs, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector Ninv, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector Ninv, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector Ninv, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector Ninv, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector _2q) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector _2q) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector _2q) { + + const auto num_devices = a.size(); + for (int i=0; i mont_add( + const std::vector a, + const std::vector b, + const std::vector _2q) { + + std::vector outputs; + + const auto num_devices = a.size(); + for (int i=0; i mont_sub( + const std::vector a, + const std::vector b, + const std::vector _2q) { + + std::vector outputs; + + const auto num_devices = a.size(); + for (int i=0; i tile_unsigned( + std::vector a, + const std::vector _2q) { + + std::vector outputs; + + const auto num_devices = _2q.size(); + for (int i=0; i FORWARD NTT"); + m.def("intt", &intt, "INVERSE NTT"); + m.def("mont_redc", &mont_redc, "MONTGOMERY REDUCTION"); + m.def("intt_exit", &intt_exit, "INVERSE NTT -> EXIT"); + m.def("intt_exit_reduce", &intt_exit_reduce, "INVERSE NTT -> EXIT -> REDUCE"); + m.def("intt_exit_reduce_signed", &intt_exit_reduce_signed, "INVERSE 
NTT -> EXIT -> REDUCE -> MAKE SIGNED"); + m.def("reduce_2q", &reduce_2q, "REDUCE RANGE TO 2q"); + m.def("make_signed", &make_signed, "MAKE SIGNED"); + m.def("make_unsigned", &make_unsigned, "MAKE UNSIGNED"); + m.def("mont_add", &mont_add, "MONTGOMERY ADDITION"); + m.def("mont_sub", &mont_sub, "MONTGOMERY SUBTRACTION"); + m.def("tile_unsigned", &tile_unsigned, "TILE -> MAKE UNSIGNED"); +} diff --git a/liberate/ntt/ntt_context.py b/liberate/ntt/ntt_context.py new file mode 100644 index 0000000..7dc774b --- /dev/null +++ b/liberate/ntt/ntt_context.py @@ -0,0 +1,523 @@ +import datetime +import time + +import numpy as np +import torch + +from liberate.fhe.presets import errors +from . import ntt_cuda +from .rns_partition import rns_partition + + +@errors.log_error +class ntt_context: + def __init__(self, ctx, index_type=torch.int32, devices=None, verbose=False): + + # Mark the start time. + t0 = time.time() + + # Set devices first. + if devices is None: + gpu_count = torch.cuda.device_count() + self.devices = [f'cuda:{i}' for i in range(gpu_count)] + else: + self.devices = devices + + self.num_devices = len(self.devices) + + # Transfer input parameters. + self.index_type = index_type + self.verbose = verbose + self.ctx = ctx + + if self.verbose: + print(f"[{str(datetime.datetime.now())}] I have received the context:\n") + self.ctx.init_print() + print(f"[{str(datetime.datetime.now())}] Requested devices for computation are {self.devices}.") + + self.num_ordinary_primes = self.ctx.num_scales + 1 + self.num_special_primes = self.ctx.num_special_primes + self.num_levels = self.ctx.num_scales + 1 + + self.p = rns_partition(self.num_ordinary_primes, + self.num_special_primes, self.num_devices) + if verbose: + print(f"[{str(datetime.datetime.now())}] I have generated a partitioning scheme.") + print(f"[{str(datetime.datetime.now())}] I have in total {self.num_levels} levels available.") + print(f"[{str(datetime.datetime.now())}] I have {self.num_ordinary_primes} ordinary primes.") + print(f"[{str(datetime.datetime.now())}] I have {self.num_special_primes} special primes.") + + self.prepare_parameters() + + if verbose: + print(f"[{str(datetime.datetime.now())}] I prepared ntt parameters.") + + t1 = time.time() + if verbose: + print(f'[{str(datetime.datetime.now())}] ntt initialization took {(t1 - t0):.2f} seconds.') + + self.qlists = [qi.tolist() for qi in self.q] + + astop_special = [len(d) for d in self.p.destination_arrays_with_special[0]] + astop_ordinary = [len(d) for d in self.p.destination_arrays[0]] + self.starts = self.p.diff + + self.stops = [astop_special, astop_ordinary] + + self.generate_parts_pack() + self.pre_package() + + # ------------------------------------------------------------------------------------------------- + # Arrange according to partitioning scheme input variables, and copy to GPUs for fast access. 
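The banner above introduces the per-device slicing done by partition_variable and copy_to_devices just below: every per-prime parameter list is cut according to the RNS partition and pinned to its own device, so each ntt_cuda wrapper (see the C++ dispatch loops above) receives one tensor per device. A schematic with the CUDA side stubbed out; the moduli, the two-device split and the destination map are placeholders.

import numpy as np
import torch

# Stand-in RNS channel moduli (placeholders, not the generated primes).
q = [97, 113, 193, 257, 337, 353]
devices = ["cpu", "cpu"]                 # stand-ins for ["cuda:0", "cuda:1"]
dest = [[0, 1, 2], [3, 4, 5]]            # like rns_partition.d_special: channels per device

def partition_variable(values):
    out = []
    for dev_id, device in enumerate(devices):
        sel = np.array(values, dtype=np.int64)[dest[dev_id]]
        out.append(torch.from_numpy(sel).to(device))
    return out

q_parts = partition_variable(q)
print([t.tolist() for t in q_parts])     # -> [[97, 113, 193], [257, 337, 353]]
# Each ntt_cuda.* wrapper then loops over these per-device tensors and launches
# one kernel per device, as in the C++ dispatch code above.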
+ # ------------------------------------------------------------------------------------------------- + + def partition_variable(self, variable): + + np_v = np.array(variable, dtype=self.ctx.numpy_dtype) + + v_special = [] + dest = self.p.d_special + for dev_id in range(self.num_devices): + d = dest[dev_id] + parted_v = np_v[d] + v_special.append( + torch.from_numpy(parted_v).to( + self.devices[dev_id])) + + return v_special + + def copy_to_devices(self, variable): + return [torch.tensor( + variable, dtype=self.index_type, device=device) for device in self.devices] + + def psi_enter(self): + Rs = self.Rs + ql = self.ql + qh = self.qh + kl = self.kl + kh = self.kh + + p = self.psi + + a = [psi.view(psi.size(0), -1) for psi in p] + + ntt_cuda.mont_enter(a, Rs, ql, qh, kl, kh) + + p = self.ipsi + a = [psi.view(psi.size(0), -1) for psi in p] + ntt_cuda.mont_enter(a, Rs, ql, qh, kl, kh) + + def Ninv_enter(self): + self.Ninv = [ + (self.ctx.N_inv[i] * self.ctx.R) % self.ctx.q[i] + for i in range(len(self.ctx.q))] + + def prepare_parameters(self): + + scale = 2 ** self.ctx.scale_bits + self.Rs_scale = self.partition_variable([ + (Rs * scale) % q for Rs, q in + zip(self.ctx.R_square, self.ctx.q) + ]) + + self.Rs = self.partition_variable(self.ctx.R_square) + + self.q = self.partition_variable(self.ctx.q) + self._2q = self.partition_variable(self.ctx.q_double) + self.ql = self.partition_variable(self.ctx.q_lower_bits) + self.qh = self.partition_variable(self.ctx.q_higher_bits) + self.kl = self.partition_variable(self.ctx.k_lower_bits) + self.kh = self.partition_variable(self.ctx.k_higher_bits) + + self.even = self.copy_to_devices(self.ctx.forward_even_indices) + self.odd = self.copy_to_devices(self.ctx.forward_odd_indices) + self.ieven = self.copy_to_devices(self.ctx.backward_even_indices) + self.iodd = self.copy_to_devices(self.ctx.backward_odd_indices) + + self.psi = self.partition_variable(self.ctx.forward_psi) + self.ipsi = self.partition_variable(self.ctx.backward_psi_inv) + + self.Ninv_enter() + self.Ninv = self.partition_variable(self.Ninv) + + self.psi_enter() + + self.mont_pack0 = [self.ql, self.qh, self.kl, self.kh] + + self.ntt_pack0 = [self.even, self.odd, + self.psi, self._2q, self.ql, self.qh, + self.kl, self.kh] + + self.intt_pack0 = [self.ieven, self.iodd, + self.ipsi, self.Ninv, self._2q, + self.ql, self.qh, + self.kl, self.kh] + + def param_pack(self, param, astart, astop, remove_empty=True): + pack = [param[dev_id][astart[dev_id]:astop[dev_id]] + for dev_id in range(self.num_devices)] + + remove_empty_f = lambda x: [xi for xi in x if len(xi) > 0] + if remove_empty: + pack = remove_empty_f(pack) + return pack + + def mont_pack(self, astart, astop, remove_empty=True): + return [self.param_pack( + param, astart, astop, remove_empty) for param in self.mont_pack0] + + def ntt_pack(self, astart, astop, remove_empty=True): + remove_empty_f_x = lambda x: [xi for xi in x if len(xi) > 0] + + remove_empty_f_xy = lambda x, y: [xi for xi, yi in zip(x, y) if len(yi) > 0] + + even_odd = self.ntt_pack0[:2] + rest = [self.param_pack( + param, astart, astop, remove_empty=False) for param in self.ntt_pack0[2:]] + + if remove_empty: + even_odd = [remove_empty_f_xy(eo, rest[0]) for eo in even_odd] + rest = [remove_empty_f_x(r) for r in rest] + + return even_odd + rest + + def intt_pack(self, astart, astop, remove_empty=True): + remove_empty_f_x = lambda x: [xi for xi in x if len(xi) > 0] + + remove_empty_f_xy = lambda x, y: [xi for xi, yi in zip(x, y) if len(yi) > 0] + + even_odd = 
self.intt_pack0[:2] + rest = [self.param_pack( + param, astart, astop, remove_empty=False) for param in self.intt_pack0[2:]] + + if remove_empty: + even_odd = [remove_empty_f_xy(eo, rest[0]) for eo in even_odd] + rest = [remove_empty_f_x(r) for r in rest] + + return even_odd + rest + + def start_stop(self, lvl, mult_type): + return self.starts[lvl], self.stops[mult_type] + + # ------------------------------------------------------------------------------------------------- + # Package by parts. + # ------------------------------------------------------------------------------------------------- + + def params_pack_device(self, device_id, astart, astop): + starts = [0] * self.num_devices + stops = [0] * self.num_devices + + starts[device_id] = astart + stops[device_id] = astop + 1 + + stst = [starts, stops] + + item = {} + + item['mont_pack'] = self.mont_pack(*stst) + item['ntt_pack'] = self.ntt_pack(*stst) + item['intt_pack'] = self.intt_pack(*stst) + item['Rs'] = self.param_pack(self.Rs, *stst) + item['Rs_scale'] = self.param_pack(self.Rs_scale, *stst) + item['_2q'] = self.param_pack(self._2q, *stst) + item['qlist'] = self.param_pack(self.qlists, *stst) + + return item + + def generate_parts_pack(self): + + blank_L_enter = [None] * self.num_devices + + self.parts_pack = [] + + for device_id in range(self.num_devices): + self.parts_pack.append({}) + + for i in range(len( + self.p.destination_arrays_with_special[0][device_id])): + self.parts_pack[device_id][i,] = self.params_pack_device( + device_id, i, i) + + for level in range(self.num_levels): + + for mult_type in [-1, -2]: + starts, stops = self.start_stop(level, mult_type) + astart = starts[device_id] + + astop = stops[device_id] - 1 + + key = tuple(range(astart, astop + 1)) + + if len(key) > 0: + if key not in self.parts_pack[device_id]: + self.parts_pack[device_id][key] = self.params_pack_device( + device_id, astart, astop) + + for p in self.p.p_special[level][device_id]: + key = tuple(p) + if key not in self.parts_pack[device_id].keys(): + astart = p[0] + astop = p[-1] + self.parts_pack[device_id][key] = self.params_pack_device( + device_id, astart, astop) + + for device_id in range(self.num_devices): + for level in range(self.num_levels): + # We do basis extension for only ordinary parts. + for part_index, part in enumerate(self.p.destination_parts[level][device_id]): + + key = tuple(self.p.p[level][device_id][part_index]) + + # Check if Y and L are already calculated for this part. 
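+ # What the block below precomputes for this part (inferred from the code itself):
+ #   m        : the part's ordinary primes m[0..alpha-1]
+ #   L[i]     : running products m[0]*...*m[i]
+ #   Y_scalar : L[i]^-1 mod m[i+1], lifted to the Montgomery domain by multiplying with R
+ #   L_scalar : L[i]*R mod m[j] for the later primes of the part
+ #   L_enter  : L[i]*R_square mod q for every destination prime, kept per target device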
+ if 'Y_scalar' not in self.parts_pack[device_id][key].keys(): + + alpha = len(part) + m = [self.ctx.q[idx] for idx in part] + L = [m[0]] + + for i in range(1, alpha - 1): + L.append(L[-1] * m[i]) + + Y_scalar = [] + L_scalar = [] + for i in range(alpha - 1): + L_inv = pow(L[i], -1, m[i + 1]) + L_inv_R = (L_inv * self.ctx.R) % m[i + 1] + Y_scalar.append(L_inv_R) + + if (i + 2) < alpha: + L_scalar.append([]) + for j in range(i + 2, alpha): + L_scalar[i].append((L[i] * self.ctx.R) % m[j]) + + L_enter_devices = [] + for target_device_id in range(self.num_devices): + + dest = self.p.destination_arrays_with_special[0][target_device_id] + q = [self.ctx.q[idx] for idx in dest] + Rs = [self.ctx.R_square[idx] for idx in dest] + + L_enter = [] + for i in range(alpha - 1): + + L_enter.append([]) + for j in range(len(dest)): + L_Rs = (L[i] * Rs[j]) % q[j] + L_enter[i].append(L_Rs) + L_enter_devices.append(L_enter) + + device = self.devices[device_id] + + if len(Y_scalar) > 0: + Y_scalar = torch.tensor( + Y_scalar, dtype=self.ctx.torch_dtype, device=device) + self.parts_pack[device_id][key]['Y_scalar'] = Y_scalar + + for target_device_id in range(self.num_devices): + target_device = self.devices[target_device_id] + + L_enter_devices[target_device_id] = [ + torch.tensor( + Li, + dtype=self.ctx.torch_dtype, + device=target_device + ) for Li in L_enter_devices[target_device_id]] + + self.parts_pack[device_id][key]['L_enter'] = L_enter_devices + + else: + self.parts_pack[device_id][key]['Y_scalar'] = None + self.parts_pack[device_id][key]['L_enter'] = blank_L_enter + + if len(L_scalar) > 0: + L_scalar = [torch.tensor( + Li, dtype=self.ctx.torch_dtype, device=device) for Li in L_scalar] + self.parts_pack[device_id][key]['L_scalar'] = L_scalar + else: + self.parts_pack[device_id][key]['L_scalar'] = None + + # ------------------------------------------------------------------------------------------------- + # Pre-packaging. + # ------------------------------------------------------------------------------------------------- + def pre_package(self): + self.mont_prepack = [] + self.ntt_prepack = [] + self.intt_prepack = [] + self.Rs_prepack = [] + self.Rs_scale_prepack = [] + self._2q_prepack = [] + + # q_prepack is a list of lists, not tensors. + # We need this for generating uniform samples. 
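+ # Indexing convention for the *_prepack containers built below:
+ # prepack[device or mult_type][level][part].  The first num_devices entries
+ # hold per-device packs (one per partition in p_special); the two trailing
+ # entries, reached with mult_type = -2 / -1, hold packs assembled across all
+ # devices for the whole level, as used by the helper methods at the end of
+ # this class.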
+ self.q_prepack = [] + + for device_id in range(self.num_devices): + mont_prepack = [] + ntt_prepack = [] + intt_prepack = [] + Rs_prepack = [] + Rs_scale_prepack = [] + _2q_prepack = [] + q_prepack = [] + for lvl in range(self.num_levels): + mont_prepack_part = [] + ntt_prepack_part = [] + intt_prepack_part = [] + Rs_prepack_part = [] + Rs_scale_prepack_part = [] + _2q_prepack_part = [] + q_prepack_part = [] + for part in self.p.p_special[lvl][device_id]: + key = tuple(part) + item = self.parts_pack[device_id][key] + + mont_prepack_part.append(item['mont_pack']) + ntt_prepack_part.append(item['ntt_pack']) + intt_prepack_part.append(item['intt_pack']) + Rs_prepack_part.append(item['Rs']) + Rs_scale_prepack_part.append(item['Rs_scale']) + _2q_prepack_part.append(item['_2q']) + q_prepack_part.append(item['qlist']) + + for mult_type in [-2, -1]: + starts, stops = self.start_stop(lvl, mult_type) + astart = starts[device_id] + + astop = stops[device_id] - 1 + + key = tuple(range(astart, astop + 1)) + + if len(key) > 0: + item = self.parts_pack[device_id][key] + + mont_prepack_part.append(item['mont_pack']) + ntt_prepack_part.append(item['ntt_pack']) + intt_prepack_part.append(item['intt_pack']) + Rs_prepack_part.append(item['Rs']) + Rs_scale_prepack_part.append(item['Rs_scale']) + _2q_prepack_part.append(item['_2q']) + q_prepack_part.append(item['qlist']) + + else: + mont_prepack_part.append(None) + ntt_prepack_part.append(None) + intt_prepack_part.append(None) + Rs_prepack_part.append(None) + Rs_scale_prepack_part.append(None) + _2q_prepack_part.append(None) + q_prepack_part.append(None) + + mont_prepack.append(mont_prepack_part) + ntt_prepack.append(ntt_prepack_part) + intt_prepack.append(intt_prepack_part) + Rs_prepack.append(Rs_prepack_part) + Rs_scale_prepack.append(Rs_scale_prepack_part) + _2q_prepack.append(_2q_prepack_part) + q_prepack.append(q_prepack_part) + + self.mont_prepack.append(mont_prepack) + self.ntt_prepack.append(ntt_prepack) + self.intt_prepack.append(intt_prepack) + self.Rs_prepack.append(Rs_prepack) + self.Rs_scale_prepack.append(Rs_scale_prepack) + self._2q_prepack.append(_2q_prepack) + self.q_prepack.append(q_prepack) + + for mult_type in [-2, -1]: + mont_prepack = [] + ntt_prepack = [] + intt_prepack = [] + Rs_prepack = [] + Rs_scale_prepack = [] + _2q_prepack = [] + q_prepack = [] + for lvl in range(self.num_levels): + stst = self.start_stop(lvl, mult_type) + mont_prepack.append([self.mont_pack(*stst)]) + ntt_prepack.append([self.ntt_pack(*stst)]) + intt_prepack.append([self.intt_pack(*stst)]) + Rs_prepack.append([self.param_pack(self.Rs, *stst)]) + Rs_scale_prepack.append([self.param_pack(self.Rs_scale, *stst)]) + _2q_prepack.append([self.param_pack(self._2q, *stst)]) + q_prepack.append([self.param_pack(self.qlists, *stst)]) + self.mont_prepack.append(mont_prepack) + self.ntt_prepack.append(ntt_prepack) + self.intt_prepack.append(intt_prepack) + self.Rs_prepack.append(Rs_prepack) + self.Rs_scale_prepack.append(Rs_scale_prepack) + self._2q_prepack.append(_2q_prepack) + self.q_prepack.append(q_prepack) + + # ------------------------------------------------------------------------------------------------- + # Helper functions to do the Montgomery and NTT operations. 
+ # ------------------------------------------------------------------------------------------------- + + def mont_enter(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.mont_enter(a, self.Rs_prepack[mult_type][lvl][part], + *self.mont_prepack[mult_type][lvl][part]) + + def mont_enter_scale(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.mont_enter( + a, self.Rs_scale_prepack[mult_type][lvl][part], + *self.mont_prepack[mult_type][lvl][part]) + + def mont_enter_scalar(self, a, b, lvl=0, mult_type=-1, part=0): + ntt_cuda.mont_enter( + a, b, *self.mont_prepack[mult_type][lvl][part]) + + def mont_mult(self, a, b, lvl=0, mult_type=-1, part=0): + return ntt_cuda.mont_mult( + a, b, *self.mont_prepack[mult_type][lvl][part]) + + def ntt(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.ntt( + a, *self.ntt_prepack[mult_type][lvl][part]) + + def enter_ntt(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.enter_ntt( + a, self.Rs_prepack[mult_type][lvl][part], + *self.ntt_prepack[mult_type][lvl][part]) + + def intt(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.intt( + a, *self.intt_prepack[mult_type][lvl][part]) + + def mont_redc(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.mont_redc( + a, *self.mont_prepack[mult_type][lvl][part]) + + def intt_exit(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.intt_exit( + a, *self.intt_prepack[mult_type][lvl][part]) + + def intt_exit_reduce(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.intt_exit_reduce( + a, *self.intt_prepack[mult_type][lvl][part]) + + def intt_exit_reduce_signed(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.intt_exit_reduce_signed( + a, *self.intt_prepack[mult_type][lvl][part]) + + def reduce_2q(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.reduce_2q( + a, self._2q_prepack[mult_type][lvl][part]) + + def make_signed(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.make_signed( + a, self._2q_prepack[mult_type][lvl][part]) + + def make_unsigned(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.make_unsigned( + a, self._2q_prepack[mult_type][lvl][part]) + + def mont_add(self, a, b, lvl=0, mult_type=-1, part=0): + return ntt_cuda.mont_add( + a, b, self._2q_prepack[mult_type][lvl][part]) + + def mont_sub(self, a, b, lvl=0, mult_type=-1, part=0): + return ntt_cuda.mont_sub( + a, b, self._2q_prepack[mult_type][lvl][part]) + + def tile_unsigned(self, a, lvl=0, mult_type=-1, part=0): + return ntt_cuda.tile_unsigned( + a, self._2q_prepack[mult_type][lvl][part]) diff --git a/liberate/ntt/ntt_cuda_kernel.cu b/liberate/ntt/ntt_cuda_kernel.cu new file mode 100644 index 0000000..fc96eea --- /dev/null +++ b/liberate/ntt/ntt_cuda_kernel.cu @@ -0,0 +1,1230 @@ +#include +#include +#include +#include + +#define BLOCK_SIZE 256 + +//------------------------------------------------------------------ +// pointwise mont_mult +//------------------------------------------------------------------ + +template __device__ __forceinline__ scalar_t +mont_mult_scalar_cuda_kernel( + const scalar_t a, const scalar_t b, + const scalar_t ql, const scalar_t qh, + const scalar_t kl, const scalar_t kh) { + + // Masks. 
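+ // fb_mask keeps the low nbits bits (so R = 2^nbits, with two guard bits below
+ // the word width keeping every intermediate sum inside the signed range), and
+ // lb_mask keeps the low half-word of half_nbits bits.  a and b are multiplied
+ // schoolbook-style on half-words (alpha, beta, gamma below) and reduced with
+ // Montgomery REDC:
+ //   s = (a*b*k) mod R,  where k = -q^(-1) mod R is supplied as halves kl/kh,
+ //   t = a*b + s*q       (divisible by R by construction),
+ //   u = t / R,
+ // so the return value is congruent to a*b*R^(-1) modulo q.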
+ constexpr scalar_t one = 1; + constexpr scalar_t nbits = sizeof(scalar_t) * 8 - 2; + constexpr scalar_t half_nbits = sizeof(scalar_t) * 4 - 1; + constexpr scalar_t fb_mask = ((one << nbits) - one); + constexpr scalar_t lb_mask = (one << half_nbits) - one; + + const scalar_t al = a & lb_mask; + const scalar_t ah = a >> half_nbits; + const scalar_t bl = b & lb_mask; + const scalar_t bh = b >> half_nbits; + + const scalar_t alpha = ah * bh; + const scalar_t beta = ah * bl + al * bh; + const scalar_t gamma = al * bl; + + // s = xk mod R + const scalar_t gammal = gamma & lb_mask; + const scalar_t gammah = gamma >> half_nbits; + const scalar_t betal = beta & lb_mask; + const scalar_t betah = beta >> half_nbits; + + scalar_t upper = gammal * kh; + upper = upper + (gammah + betal) * kl; + upper = upper << half_nbits; + scalar_t s = upper + gammal * kl; + s = upper + gammal * kl; + s = s & fb_mask; + + // t = x + sq + // u = t/R + const scalar_t sl = s & lb_mask; + const scalar_t sh = s >> half_nbits; + const scalar_t sqb = sh * ql + sl * qh; + const scalar_t sqbl = sqb & lb_mask; + const scalar_t sqbh = sqb >> half_nbits; + + scalar_t carry = (gamma + sl * ql) >> half_nbits; + carry = (carry + betal + sqbl) >> half_nbits; + + return alpha + betah + sqbh + carry + sh * qh; +} + + +//------------------------------------------------------------------ +// mont_mult +//------------------------------------------------------------------ + +template +__global__ void mont_mult_cuda_kernel( + const torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32b_acc, + torch::PackedTensorAccessor32c_acc, + const torch::PackedTensorAccessor32ql_acc, + const torch::PackedTensorAccessor32qh_acc, + const torch::PackedTensorAccessor32kl_acc, + const torch::PackedTensorAccessor32kh_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + const scalar_t a = a_acc[i][j]; + const scalar_t b = b_acc[i][j]; + const scalar_t ql = ql_acc[i]; + const scalar_t qh = qh_acc[i]; + const scalar_t kl = kl_acc[i]; + const scalar_t kh = kh_acc[i]; + + // Store the result. + c_acc[i][j] = mont_mult_scalar_cuda_kernel(a, b, ql, qh, kl, kh); +} + +template +void mont_mult_cuda_typed( + const torch::Tensor a, + const torch::Tensor b, + torch::Tensor c, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + auto C = a.size(0); + auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + const auto a_acc = a.packed_accessor32(); + const auto b_acc = b.packed_accessor32(); + auto c_acc = c.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + mont_mult_cuda_kernel<<>>( + a_acc, b_acc, c_acc, ql_acc, qh_acc, kl_acc, kh_acc); +} + + +torch::Tensor mont_mult_cuda( + const torch::Tensor a, + const torch::Tensor b, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Prepare the output. + torch::Tensor c = torch::empty_like(a); + + // Dispatch to the correct data type. 
+ AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_mont_mult_cuda", ([&] { + mont_mult_cuda_typed(a, b, c, ql, qh, kl, kh); + })); + + return c; +} + + + +//------------------------------------------------------------------ +// mont_enter +//------------------------------------------------------------------ + +template +__global__ void mont_enter_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32Rs_acc, + const torch::PackedTensorAccessor32ql_acc, + const torch::PackedTensorAccessor32qh_acc, + const torch::PackedTensorAccessor32kl_acc, + const torch::PackedTensorAccessor32kh_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + const scalar_t a = a_acc[i][j]; + const scalar_t Rs = Rs_acc[i]; + const scalar_t ql = ql_acc[i]; + const scalar_t qh = qh_acc[i]; + const scalar_t kl = kl_acc[i]; + const scalar_t kh = kh_acc[i]; + + // Store the result. + a_acc[i][j] = mont_mult_scalar_cuda_kernel(a, Rs, ql, qh, kl, kh); +} + + +template +void mont_enter_cuda_typed( + torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + auto C = a.size(0); + auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + auto a_acc = a.packed_accessor32(); + const auto Rs_acc = Rs.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + mont_enter_cuda_kernel<<>>( + a_acc, Rs_acc, ql_acc, qh_acc, kl_acc, kh_acc); +} + +void mont_enter_cuda( + torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_mont_enter_cuda", ([&] { + mont_enter_cuda_typed(a, Rs, ql, qh, kl, kh); + })); +} + + + + + +//------------------------------------------------------------------ +// ntt +//------------------------------------------------------------------ + +template +__global__ void ntt_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32even_acc, + const torch::PackedTensorAccessor32odd_acc, + const torch::PackedTensorAccessor32psi_acc, + const torch::PackedTensorAccessor32_2q_acc, + const torch::PackedTensorAccessor32ql_acc, + const torch::PackedTensorAccessor32qh_acc, + const torch::PackedTensorAccessor32kl_acc, + const torch::PackedTensorAccessor32kh_acc, + const int level){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Montgomery inputs. + const scalar_t _2q = _2q_acc[i]; + const scalar_t ql = ql_acc[i]; + const scalar_t qh = qh_acc[i]; + const scalar_t kl = kl_acc[i]; + const scalar_t kh = kh_acc[i]; + + // Butterfly. 
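+ // One Cooley-Tukey butterfly per thread, with lazy reduction modulo 2q:
+ //   U = a[even_j],  V = psi * a[odd_j]  (a Montgomery product),
+ //   a[even_j] <- U + V   (mod 2q),
+ //   a[odd_j]  <- U - V   (computed as U + 2q - V to stay non-negative).
+ // The even/odd index tables and the per-level twiddle factors psi are
+ // precomputed on the host in ntt_context.prepare_parameters.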
+ const int even_j = even_acc[level][j]; + const int odd_j = odd_acc[level][j]; + + const scalar_t U = a_acc[i][even_j]; + const scalar_t S = psi_acc[i][level][j]; + const scalar_t O = a_acc[i][odd_j]; + const scalar_t V = mont_mult_scalar_cuda_kernel(S, O, ql, qh, kl, kh); + + // Store back. + const scalar_t UplusV = U + V; + const scalar_t UminusV = U + _2q - V; + + a_acc[i][even_j] = (UplusV < _2q)? UplusV : UplusV - _2q; + a_acc[i][odd_j] = (UminusV < _2q)? UminusV : UminusV - _2q; +} + + +template +void ntt_cuda_typed( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N = even.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + auto a_acc = a.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } +} + + + +void ntt_cuda( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_ntt_cuda", ([&] { + ntt_cuda_typed(a, even, odd, psi, _2q, ql, qh, kl, kh); + })); +} + + +//------------------------------------------------------------------ +// enter_ntt +//------------------------------------------------------------------ + +template +void enter_ntt_cuda_typed( + torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + // Be careful. even and odd has half the length of the a. + const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N_half = even.size(1); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid_ntt (C, N_half / BLOCK_SIZE); + dim3 dim_grid_enter (C, N / BLOCK_SIZE); + + // Run the cuda kernel. 
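+ // enter_ntt fuses the Montgomery entry with the forward transform: one
+ // mont_enter launch over the full length N (multiplication by Rs = R^2 mod q,
+ // which takes a into the Montgomery domain), followed by logN butterfly
+ // launches over N/2 index pairs each, all queued on this device's stream.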
+ auto a_acc = a.packed_accessor32(); + const auto Rs_acc = Rs.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + // enter. + mont_enter_cuda_kernel<<>>( + a_acc, Rs_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // ntt. + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } +} + + +void enter_ntt_cuda( + torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_enter_ntt_cuda", ([&] { + enter_ntt_cuda_typed(a, Rs, even, odd, psi, _2q, ql, qh, kl, kh); + })); +} + + + + + +//------------------------------------------------------------------ +// intt +//------------------------------------------------------------------ + +template +__global__ void intt_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32even_acc, + const torch::PackedTensorAccessor32odd_acc, + const torch::PackedTensorAccessor32psi_acc, + const torch::PackedTensorAccessor32_2q_acc, + const torch::PackedTensorAccessor32ql_acc, + const torch::PackedTensorAccessor32qh_acc, + const torch::PackedTensorAccessor32kl_acc, + const torch::PackedTensorAccessor32kh_acc, + const int level){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Montgomery inputs. + const scalar_t _2q = _2q_acc[i]; + const scalar_t ql = ql_acc[i]; + const scalar_t qh = qh_acc[i]; + const scalar_t kl = kl_acc[i]; + const scalar_t kh = kh_acc[i]; + + // Butterfly. + const int even_j = even_acc[level][j]; + const int odd_j = odd_acc[level][j]; + + const scalar_t U = a_acc[i][even_j]; + const scalar_t S = psi_acc[i][level][j]; + const scalar_t V = a_acc[i][odd_j]; + + const scalar_t UminusV = U + _2q - V; + const scalar_t O = (UminusV < _2q)? UminusV : UminusV - _2q; + + const scalar_t W = mont_mult_scalar_cuda_kernel(S, O, ql, qh, kl, kh); + a_acc[i][odd_j] = W; + + const scalar_t UplusV = U + V; + a_acc[i][even_j] = (UplusV < _2q)? UplusV : UplusV - _2q; +} + + +template +void intt_cuda_typed( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + // Be careful. even and odd has half the length of the a. + const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N_half = even.size(1); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid_ntt (C, N_half / BLOCK_SIZE); + dim3 dim_grid_enter (C, N / BLOCK_SIZE); + + // Run the cuda kernel. 
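+ // The inverse transform mirrors ntt_cuda_typed: logN launches of the
+ // Gentleman-Sande style butterfly (even <- U + V, odd <- psi * (U - V)),
+ // followed by one mont_enter launch with Ninv = N^(-1)*R mod q (prepared in
+ // ntt_context.Ninv_enter), which applies the 1/N normalization.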
+ auto a_acc = a.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + const auto Ninv_acc = Ninv.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } + + // Normalize. + mont_enter_cuda_kernel<<>>( + a_acc, Ninv_acc, ql_acc, qh_acc, kl_acc, kh_acc); +} + +void intt_cuda( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_intt_cuda", ([&] { + intt_cuda_typed(a, even, odd, psi, Ninv, _2q, ql, qh, kl, kh); + })); +} + + + + + + +//------------------------------------------------------------------ +// mont_redc +//------------------------------------------------------------------ + +template +__global__ void mont_redc_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32ql_acc, + const torch::PackedTensorAccessor32qh_acc, + const torch::PackedTensorAccessor32kl_acc, + const torch::PackedTensorAccessor32kh_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Masks. + constexpr scalar_t one = 1; + constexpr scalar_t nbits = sizeof(scalar_t) * 8 - 2; + constexpr scalar_t half_nbits = sizeof(scalar_t) * 4 - 1; + constexpr scalar_t fb_mask = ((one << nbits) - one); + constexpr scalar_t lb_mask = (one << half_nbits) - one; + + // Inputs. + const scalar_t x = a_acc[i][j]; + const scalar_t ql = ql_acc[i]; + const scalar_t qh = qh_acc[i]; + const scalar_t kl = kl_acc[i]; + const scalar_t kh = kh_acc[i]; + + // Implementation. + // s= xk mod R + const scalar_t xl = x & lb_mask; + const scalar_t xh = x >> half_nbits; + const scalar_t xkb = xh * kl + xl * kh; + scalar_t s = (xkb << half_nbits) + xl * kl; + s = s & fb_mask; + + // t = x + sq + // u = t/R + // Note that x gets erased in t/R operation if x < R. + const scalar_t sl = s & lb_mask; + const scalar_t sh = s >> half_nbits; + const scalar_t sqb = sh * ql + sl * qh; + const scalar_t sqbl = sqb & lb_mask; + const scalar_t sqbh = sqb >> half_nbits; + scalar_t carry = (x + sl * ql) >> half_nbits; + carry = (carry + sqbl) >> half_nbits; + + // Assume we have satisfied the condition 4*q < R. + // Return the calculated value directly without conditional subtraction. + a_acc[i][j] = sqbh + carry + sh * qh; +} + + +template +void mont_redc_cuda_typed( + torch::Tensor a, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + auto C = a.size(0); + auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. 
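+ // mont_redc applies a single REDC pass elementwise, x -> x*R^(-1) mod q,
+ // which is how values leave the Montgomery domain.  As noted in the kernel
+ // above, the conditional subtraction is skipped on the assumption 4*q < R.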
+ auto a_acc = a.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + mont_redc_cuda_kernel<<>>( + a_acc, ql_acc, qh_acc, kl_acc, kh_acc); +} + +void mont_redc_cuda( + torch::Tensor a, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_mont_redc_cuda", ([&] { + mont_redc_cuda_typed(a, ql, qh, kl, kh); + })); +} + + +//------------------------------------------------------------------ +// Chained intt series. +//------------------------------------------------------------------ + +/**************************************************************/ +/* CUDA kernels */ +/**************************************************************/ + +template +__global__ void reduce_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + constexpr scalar_t one = 1; + const scalar_t a = a_acc[i][j]; + const scalar_t q = _2q_acc[i] >> one; + + // Reduce. + a_acc[i][j] = (a < q)? a : a - q; +} + +template +__global__ void make_signed_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + constexpr scalar_t one = 1; + const scalar_t a = a_acc[i][j]; + const scalar_t q = _2q_acc[i] >> one; + const scalar_t q_half = q >> one; + + // Make signed. + a_acc[i][j] = (a <= q_half)? a : a - q; +} + + +/**************************************************************/ +/* Typed functions */ +/**************************************************************/ + +/////////////////////////////////////////////////////////////// +// intt exit + +template +void intt_exit_cuda_typed( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + // Be careful. even and odd has half the length of the a. + const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N_half = even.size(1); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid_ntt (C, N_half / BLOCK_SIZE); + dim3 dim_grid_enter (C, N / BLOCK_SIZE); + + // Run the cuda kernel. 
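+ // The _exit family chains epilogue kernels onto the inverse transform:
+ //   intt_exit               : butterflies -> Ninv normalization -> mont_redc,
+ //   intt_exit_reduce        : the same, then reduce_2q into [0, q),
+ //   intt_exit_reduce_signed : the same, then recentering to signed residues.
+ // All launches go to the same per-device stream, so they run in order.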
+ auto a_acc = a.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + const auto Ninv_acc = Ninv.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } + + // Normalize. + mont_enter_cuda_kernel<<>>( + a_acc, Ninv_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // Exit. + mont_redc_cuda_kernel<<>>( + a_acc, ql_acc, qh_acc, kl_acc, kh_acc); +} + +/////////////////////////////////////////////////////////////// +// intt exit reduce + +template +void intt_exit_reduce_cuda_typed( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + // Be careful. even and odd has half the length of the a. + const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N_half = even.size(1); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid_ntt (C, N_half / BLOCK_SIZE); + dim3 dim_grid_enter (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + auto a_acc = a.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + const auto Ninv_acc = Ninv.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } + + // Normalize. + mont_enter_cuda_kernel<<>>( + a_acc, Ninv_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // Exit. + mont_redc_cuda_kernel<<>>( + a_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // Reduce. + reduce_cuda_kernel<<>>(a_acc, _2q_acc); +} + + +/////////////////////////////////////////////////////////////// +// intt exit reduce signed + +template +void intt_exit_reduce_signed_cuda_typed( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + // Be careful. even and odd has half the length of the a. 
+ const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N_half = even.size(1); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid_ntt (C, N_half / BLOCK_SIZE); + dim3 dim_grid_enter (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + auto a_acc = a.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + const auto Ninv_acc = Ninv.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } + + // Normalize. + mont_enter_cuda_kernel<<>>( + a_acc, Ninv_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // Exit. + mont_redc_cuda_kernel<<>>( + a_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // Reduce. + reduce_cuda_kernel<<>>(a_acc, _2q_acc); + + // Make signed. + make_signed_cuda_kernel<<>>(a_acc, _2q_acc); +} + + + + +/**************************************************************/ +/* Connectors */ +/**************************************************************/ + +/////////////////////////////////////////////////////////////// +// intt exit + +void intt_exit_cuda( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_intt_exit_cuda", ([&] { + intt_exit_cuda_typed(a, even, odd, psi, Ninv, _2q, ql, qh, kl, kh); + })); +} + +/////////////////////////////////////////////////////////////// +// intt exit reduce + +void intt_exit_reduce_cuda( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_intt_exit_reduce_cuda", ([&] { + intt_exit_reduce_cuda_typed(a, even, odd, psi, Ninv, _2q, ql, qh, kl, kh); + })); +} + +/////////////////////////////////////////////////////////////// +// intt exit reduce signed + +void intt_exit_reduce_signed_cuda( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_intt_exit_reduce_signed_cuda", ([&] { + intt_exit_reduce_signed_cuda_typed(a, even, odd, psi, Ninv, _2q, ql, qh, kl, kh); + })); +} + + +//------------------------------------------------------------------ +// Misc +//------------------------------------------------------------------ + +template +__global__ void make_unsigned_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. 
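+ // make_unsigned undoes make_signed: centered (signed) residues are shifted up
+ // by q, giving non-negative representatives below 2q, which the remaining
+ // kernels accept thanks to the lazy mod-2q representation.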
+ constexpr scalar_t one = 1; + const scalar_t q = _2q_acc[i] >> one; + + // Make unsigned. + a_acc[i][j] += q; +} + +template +__global__ void tile_unsigned_cuda_kernel( + const torch::PackedTensorAccessor32a_acc, + torch::PackedTensorAccessor32dst_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + constexpr scalar_t one = 1; + const scalar_t q = _2q_acc[i] >> one; + const scalar_t a = a_acc[j]; + + // Make unsigned. + dst_acc[i][j] = a + q; +} + +template +__global__ void mont_add_cuda_kernel( + const torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32b_acc, + torch::PackedTensorAccessor32c_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + constexpr scalar_t one = 1; + const scalar_t a = a_acc[i][j]; + const scalar_t b = b_acc[i][j]; + const scalar_t _2q = _2q_acc[i]; + + // Add. + const scalar_t aplusb = a + b; + c_acc[i][j] = (aplusb < _2q)? aplusb : aplusb - _2q; +} + +template +__global__ void mont_sub_cuda_kernel( + const torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32b_acc, + torch::PackedTensorAccessor32c_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + constexpr scalar_t one = 1; + const scalar_t a = a_acc[i][j]; + const scalar_t b = b_acc[i][j]; + const scalar_t _2q = _2q_acc[i]; + + // Sub. + const scalar_t aminusb = a + _2q - b; + c_acc[i][j] = (aminusb < _2q)? aminusb : aminusb - _2q; +} + +template +void reduce_2q_cuda_typed(torch::Tensor a, const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + const auto C = a.size(0); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + auto a_acc = a.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + + reduce_cuda_kernel<<>>(a_acc, _2q_acc); +} + +template +void make_signed_cuda_typed(torch::Tensor a, const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + const auto C = a.size(0); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + auto a_acc = a.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + + make_signed_cuda_kernel<<>>(a_acc, _2q_acc); +} + +template +void tile_unsigned_cuda_typed(const torch::Tensor a, torch::Tensor dst, const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + const auto C = _2q.size(0); + const auto N = a.size(0); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + const auto a_acc = a.packed_accessor32(); + auto dst_acc = dst.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + + tile_unsigned_cuda_kernel<<>>(a_acc, dst_acc, _2q_acc); +} + +template +void make_unsigned_cuda_typed(torch::Tensor a, const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + const auto C = a.size(0); + const auto N = a.size(1); + + int dim_block = 
BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + auto a_acc = a.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + + make_unsigned_cuda_kernel<<>>(a_acc, _2q_acc); +} + +template +void mont_add_cuda_typed( + const torch::Tensor a, + const torch::Tensor b, + torch::Tensor c, + const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + auto C = a.size(0); + auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + const auto a_acc = a.packed_accessor32(); + const auto b_acc = b.packed_accessor32(); + auto c_acc = c.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + mont_add_cuda_kernel<<>>(a_acc, b_acc, c_acc, _2q_acc); +} + +template +void mont_sub_cuda_typed( + const torch::Tensor a, + const torch::Tensor b, + torch::Tensor c, + const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + auto C = a.size(0); + auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + const auto a_acc = a.packed_accessor32(); + const auto b_acc = b.packed_accessor32(); + auto c_acc = c.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + mont_sub_cuda_kernel<<>>(a_acc, b_acc, c_acc, _2q_acc); +} + +void reduce_2q_cuda(torch::Tensor a, const torch::Tensor _2q) { + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_reduce_2q_cuda", ([&] { + reduce_2q_cuda_typed(a, _2q); + })); +} + +void make_signed_cuda(torch::Tensor a, const torch::Tensor _2q) { + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_make_signed_cuda", ([&] { + make_signed_cuda_typed(a, _2q); + })); +} + +void make_unsigned_cuda(torch::Tensor a, const torch::Tensor _2q) { + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_make_unsigned_cuda", ([&] { + make_unsigned_cuda_typed(a, _2q); + })); +} + +torch::Tensor tile_unsigned_cuda(const torch::Tensor a, const torch::Tensor _2q) { + a.squeeze_(); + const auto C = _2q.size(0); + const auto N = a.size(0); + auto c = a.new_empty({C, N}); + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_tile_unsigned_cuda", ([&] { + tile_unsigned_cuda_typed(a, c, _2q); + })); + return c; +} + +torch::Tensor mont_add_cuda(const torch::Tensor a, const torch::Tensor b, const torch::Tensor _2q) { + torch::Tensor c = torch::empty_like(a); + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_mont_add_cuda", ([&] { + mont_add_cuda_typed(a, b, c, _2q); + })); + return c; +} + +torch::Tensor mont_sub_cuda(const torch::Tensor a, const torch::Tensor b, const torch::Tensor _2q) { + torch::Tensor c = torch::empty_like(a); + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_mont_sub_cuda", ([&] { + mont_sub_cuda_typed(a, b, c, _2q); + })); + return c; +} diff --git a/liberate/ntt/rns_partition.py b/liberate/ntt/rns_partition.py new file mode 100644 index 0000000..9e54c59 --- /dev/null +++ b/liberate/ntt/rns_partition.py @@ -0,0 +1,141 @@ +import numpy as np + + +class rns_partition: + def __init__(self, num_ordinary_primes=17, + num_special_primes=2, + num_devices=2): + + primes_idx = list(range(num_ordinary_primes - 1)) + base_idx = num_ordinary_primes - 1 + + num_partitions = -(-(num_ordinary_primes - 1) // num_special_primes) + + part = lambda i: primes_idx[i * num_special_primes:(i + 1) * num_special_primes] + partitions = [part(i) for i in range(num_partitions)] + + 
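+ # Partitioning scheme: the ordinary primes, excluding the base prime, are cut
+ # into groups of num_special_primes; the base prime then forms its own
+ # partition, and the special primes are appended as the final partition.
+ # Partitions are dealt out to the devices in a strided, interleaved fashion
+ # below; the base partition is attached to device 0 and the special partition
+ # to every device.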
partitions.append([num_ordinary_primes - 1]) + + partitions.append( + list( + range(num_ordinary_primes, num_ordinary_primes + num_special_primes))) + + alloc = lambda i: list(range(num_partitions - i - 1, -1, -num_devices))[::-1] + part_allocations = [alloc(i) for i in range(num_devices)] + + part_allocations[0].append(num_partitions) + + for p in part_allocations: p.append(num_partitions + 1) + + expand_alloc = lambda i: [partitions[part] for part in part_allocations[i]] + prime_allocations = [expand_alloc(i) for i in range(num_devices)] + + flat_prime_allocations = [sum(alloc, []) for alloc in prime_allocations] + + self.num_ordinary_primes = num_ordinary_primes + self.num_special_primes = num_special_primes + self.num_devices = num_devices + self.num_partitions = num_partitions + self.partitions = partitions + self.part_allocations = part_allocations + self.prime_allocations = prime_allocations + self.flat_prime_allocations = flat_prime_allocations + self.num_scales = self.num_ordinary_primes - 1 + + self.base_prime_idx = self.num_ordinary_primes - 1 + self.special_prime_idx = list( + range( + self.num_ordinary_primes + 1, + self.num_ordinary_primes + 1 + self.num_special_primes) + ) + + self.compute_destination_arrays() + self.compute_rescaler_locations() + self.compute_partitions() + + def compute_destination_arrays(self): + filter_alloc = lambda devi, i: [ + a for a in self.flat_prime_allocations[devi] if a >= i] + + self.destination_arrays_with_special = [] + for lvl in range(self.num_ordinary_primes): + src = [filter_alloc(devi, lvl) for devi in range(self.num_devices)] + self.destination_arrays_with_special.append(src) + + special_removed = lambda i: [ + a[:-self.num_special_primes] for a in self.destination_arrays_with_special[i]] + + self.destination_arrays = [ + special_removed(i) for i in range(self.num_ordinary_primes)] + + lint = lambda arr: [a for a in arr if len(a) > 0] + self.destination_arrays = [lint(a) for a in self.destination_arrays] + + def compute_rescaler_locations(self): + mins = lambda arr: [min(a) for a in arr] + mins_loc = lambda a: mins(a).index(min(mins(a))) + self.rescaler_loc = [mins_loc(a) for a in self.destination_arrays_with_special] + + def partings(self, lvl): + count_element_sizes = lambda arr: np.array([len(a) for a in arr]) + cumsum_element_sizes = lambda arr: np.cumsum(arr) + remove_empty_parts = lambda arr: [a for a in arr if a > 0] + regenerate_parts = lambda arr: [list(range(a, b)) for a, b in zip([0] + arr[:-1], arr)] + + part_counts = [count_element_sizes(a) for a in self.prime_allocations] + part_cumsums = [cumsum_element_sizes(a) for a in part_counts] + level_diffs = [ + len(a) - len(b) for a, b in zip( + self.destination_arrays_with_special[0], self.destination_arrays_with_special[lvl])] + + part_cumsums_lvl = [remove_empty_parts(a - d) for a, d in zip(part_cumsums, level_diffs)] + part_count_lvl = [np.diff(a, prepend=0) for a in part_cumsums_lvl] + parts_lvl = [regenerate_parts(a) for a in part_cumsums_lvl] + return part_cumsums_lvl, part_count_lvl, parts_lvl + + def compute_partitions(self): + self.part_cumsums = [] + self.part_counts = [] + self.parts = [] + self.destination_parts = [] + self.destination_parts_with_special = [] + self.p = [] + self.p_special = [] + self.diff = [] + + self.d = [ + self.destination_arrays[0][dev_i] for dev_i in range(self.num_devices)] + + self.d_special = [ + self.destination_arrays_with_special[0][dev_i] + for dev_i in range(self.num_devices)] + + for lvl in range(self.num_ordinary_primes): + pcu, pco, 
par = self.partings(lvl) + self.part_cumsums.append(pcu) + self.part_counts.append(pco) + self.parts.append(par) + + dest = self.destination_arrays_with_special[lvl] + destp_special = [ + [ + [ + d[pi] for pi in p + ] for p in dev_p + ] for d, dev_p in zip(dest, par)] + destp = [dev_dp[:-1] for dev_dp in destp_special] + + self.destination_parts.append(destp) + self.destination_parts_with_special.append(destp_special) + + diff = [len(d1) - len(d2) for d1, d2 in zip( + self.destination_arrays_with_special[0], + self.destination_arrays_with_special[lvl])] + p_special = [[[pi + d for pi in p] + for p in dev_p] + for d, dev_p in zip(diff, self.parts[lvl])] + p = [dev_p[:-1] for dev_p in p_special] + + self.p.append(p) + self.p_special.append(p_special) + self.diff.append(diff) \ No newline at end of file diff --git a/liberate/ntt/setup.py b/liberate/ntt/setup.py new file mode 100644 index 0000000..ef8f21f --- /dev/null +++ b/liberate/ntt/setup.py @@ -0,0 +1,24 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +ext_modules = [ + CUDAExtension( + name="ntt_cuda", + sources=[ + "ntt.cpp", + "ntt_cuda_kernel.cu", + ], + ) +] +if __name__ == "__main__": + setup( + name="ntt", + ext_modules=ext_modules, + script_args=["build_ext"], + cmdclass={"build_ext": BuildExtension}, + # options={ + # "build":{ + # "build_lib":"liberate/ntt", + # } + # } + ) diff --git a/liberate/utils/__init__.py b/liberate/utils/__init__.py new file mode 100644 index 0000000..7a97403 --- /dev/null +++ b/liberate/utils/__init__.py @@ -0,0 +1,5 @@ +from . import helpers + +__all__ = [ + "helpers", +] diff --git a/liberate/utils/helpers.py b/liberate/utils/helpers.py new file mode 100644 index 0000000..a45474f --- /dev/null +++ b/liberate/utils/helpers.py @@ -0,0 +1,41 @@ +import numpy as np +from matplotlib import pyplot as plt + + +def random_complex_array( + n: int = 2 ** 8, + amin: int = -(2 ** 20), + amax: int = 2 ** 20, + decimal_exponent: int = 10, +): + base = 10 ** decimal_exponent + a = np.random.randint(amin * base, amax * base, n) / base + b = np.random.randint(amin * base, amax * base, n) / base + ret = a + b * 1j + return ret + + +def check_errors(test_message, test_message_dec, idx=10, title="errors"): + errs = test_message_dec - test_message + plt.figure(figsize=(16, 9)) + plt.plot(errs) + plt.grid() + plt.title(title) + plt.show() + + print("============================================================") + for x, y in zip(test_message[:idx], test_message_dec[:idx]): + print(f"{x.real:19.10f} | {y.real:19.10f} | {(y - x).real:14.10f}") + print("============================================================") + print(f"mean\t=\t{errs.mean():10.15f}") + print(f"std\t=\t{errs.std():10.15f}") + print(f"max err\t=\t{abs(errs).max().real:10.15f}") + print(f"min err\t=\t{abs(errs).min().real:10.15f}") + + +def absmax_error(x, y): + if type(x[0]) == np.complex128 and type(y[0]) == np.complex128: + r = np.abs(x.real - y.real).max() + np.abs(x.imag - y.imag).max() * 1j + else: + r = np.abs(np.array(x) - np.array(y)).max() + return r diff --git a/lint.sh b/lint.sh new file mode 100755 index 0000000..34872f0 --- /dev/null +++ b/lint.sh @@ -0,0 +1,6 @@ +echo 'isort:' +poetry run isort . +echo 'flake8:' +poetry run flake8 . +echo 'black:' +poetry run black . 
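The partitioning above is easiest to see on a small configuration. The sketch below is illustrative only and not part of the patch; it assumes the package is importable as liberate.ntt, and it uses smaller constructor arguments than the defaults declared in rns_partition.py purely to keep the printed lists short.

# Illustrative only: exercises the rns_partition class added above.
from liberate.ntt.rns_partition import rns_partition

p = rns_partition(num_ordinary_primes=5, num_special_primes=2, num_devices=2)

# Ordinary primes 0..3 are grouped in pairs, prime 4 is the base prime,
# and primes 5..6 are the special primes.
print(p.partitions)               # [[0, 1], [2, 3], [4], [5, 6]]
print(p.part_allocations)         # which partitions each device owns
print(p.destination_arrays[0])    # level-0 ordinary primes per device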
diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..f94e3a3 --- /dev/null +++ b/setup.py @@ -0,0 +1,68 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +ext_modules = [ + CUDAExtension( + name="randint_cuda", + sources=[ + "liberate/csprng/randint.cpp", + "liberate/csprng/randint_cuda_kernel.cu", + ], + ), + CUDAExtension( + name="randround_cuda", + sources=[ + "liberate/csprng/randround.cpp", + "liberate/csprng/randround_cuda_kernel.cu", + ], + ), + CUDAExtension( + name="discrete_gaussian_cuda", + sources=[ + "liberate/csprng/discrete_gaussian.cpp", + "liberate/csprng/discrete_gaussian_cuda_kernel.cu", + ], + ), + CUDAExtension( + name="chacha20_cuda", + sources=[ + "liberate/csprng/chacha20.cpp", + "liberate/csprng/chacha20_cuda_kernel.cu", + ], + ), +] + +ext_modules_ntt = [ + CUDAExtension( + name="ntt_cuda", + sources=[ + "liberate/ntt/ntt.cpp", + "liberate/ntt/ntt_cuda_kernel.cu", + ], + ) +] + +if __name__ == "__main__": + setup( + name="csprng", + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, + script_args=["build_ext"], + options={ + "build": { + "build_lib": "liberate/csprng", + } + }, + ) + + setup( + name="ntt", + ext_modules=ext_modules_ntt, + script_args=["build_ext"], + cmdclass={"build_ext": BuildExtension}, + options={ + "build": { + "build_lib": "liberate/ntt", + } + }, + )
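Since the CUDA kernels above implement their own Montgomery arithmetic, a small host-side reference model is convenient when writing tests against them. The sketch below is not part of the patch: the helper names are hypothetical, and it only assumes the kernels' convention R = 2^(word width - 2), i.e. R = 2^62 for 64-bit words, with odd moduli.

# Illustrative reference model (hypothetical helpers, not part of the patch)
# for the Montgomery arithmetic in liberate/ntt/ntt_cuda_kernel.cu.
R_BITS = 62             # 64-bit words with two guard bits, as in the kernels
R = 1 << R_BITS


def to_mont(a: int, q: int) -> int:
    # What mont_enter achieves on device via Rs = R^2 mod q.
    return (a * R) % q


def mont_mult_ref(a: int, b: int, q: int) -> int:
    # Return a*b*R^-1 mod q; the CUDA kernel computes the same value up to
    # a lazy representative below 2q, which we fold into [0, q) here.
    k = (-pow(q, -1, R)) % R     # k = -q^-1 mod R, so t below is divisible by R
    x = a * b
    s = (x * k) % R
    t = x + s * q
    return (t >> R_BITS) % q


# Example: 3 * 5 modulo an odd q, carried out in the Montgomery domain.
q = (1 << 40) - 87               # odd modulus; the kernels further assume 4*q < R
am, bm = to_mont(3, q), to_mont(5, q)
assert mont_mult_ref(am, bm, q) == to_mont(15, q)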