From 642012c6c285a3a151593dbee55436155f7bf6fb Mon Sep 17 00:00:00 2001 From: hanyul-ryu Date: Wed, 22 Nov 2023 19:57:16 +0900 Subject: [PATCH] upload liberate.fhe --- .flake8 | 5 + liberate/__init__.py | 1 + liberate/csprng/__init__.py | 1 + liberate/csprng/chacha20.cpp | 43 + liberate/csprng/chacha20_cuda_kernel.cu | 69 + liberate/csprng/chacha20_cuda_kernel.h | 34 + liberate/csprng/chacha20_naive.py | 246 ++ liberate/csprng/csprng.py | 306 ++ liberate/csprng/discrete_gaussian.cpp | 69 + .../csprng/discrete_gaussian_cuda_kernel.cu | 237 ++ liberate/csprng/discrete_gaussian_sampler.py | 156 + liberate/csprng/randint.cpp | 55 + liberate/csprng/randint_cuda_kernel.cu | 215 ++ liberate/csprng/randround.cpp | 32 + liberate/csprng/randround_cuda_kernel.cu | 56 + liberate/csprng/setup.py | 33 + liberate/fhe/__init__.py | 4 + liberate/fhe/cache/__init__.py | 1 + liberate/fhe/cache/cache.py | 32 + liberate/fhe/cache/resources/logN_N_M.pkl | Bin 0 -> 101 bytes .../resources/message_special_primes.pkl | Bin 0 -> 1110 bytes liberate/fhe/cache/resources/scale_primes.pkl | Bin 0 -> 141973 bytes liberate/fhe/ckks_engine.py | 2615 +++++++++++++++++ liberate/fhe/context/__init__.py | 2 + liberate/fhe/context/ckks_context.py | 360 +++ liberate/fhe/context/generate_primes.py | 311 ++ liberate/fhe/context/prim_test.py | 64 + liberate/fhe/context/security_parameters.py | 201 ++ liberate/fhe/data_struct.py | 24 + liberate/fhe/encdec/__init__.py | 1 + liberate/fhe/encdec/encdec.py | 323 ++ liberate/fhe/presets/__init__.py | 3 + liberate/fhe/presets/errors.py | 166 ++ liberate/fhe/presets/params.py | 31 + liberate/fhe/presets/types.py | 11 + liberate/fhe/version.py | 1 + liberate/ntt/__init__.py | 2 + liberate/ntt/ntt.cpp | 437 +++ liberate/ntt/ntt_context.py | 523 ++++ liberate/ntt/ntt_cuda_kernel.cu | 1230 ++++++++ liberate/ntt/rns_partition.py | 141 + liberate/ntt/setup.py | 24 + liberate/utils/__init__.py | 5 + liberate/utils/helpers.py | 41 + lint.sh | 6 + setup.py | 68 + 46 files changed, 8185 insertions(+) create mode 100644 .flake8 create mode 100644 liberate/__init__.py create mode 100644 liberate/csprng/__init__.py create mode 100644 liberate/csprng/chacha20.cpp create mode 100644 liberate/csprng/chacha20_cuda_kernel.cu create mode 100644 liberate/csprng/chacha20_cuda_kernel.h create mode 100644 liberate/csprng/chacha20_naive.py create mode 100644 liberate/csprng/csprng.py create mode 100644 liberate/csprng/discrete_gaussian.cpp create mode 100644 liberate/csprng/discrete_gaussian_cuda_kernel.cu create mode 100644 liberate/csprng/discrete_gaussian_sampler.py create mode 100644 liberate/csprng/randint.cpp create mode 100644 liberate/csprng/randint_cuda_kernel.cu create mode 100644 liberate/csprng/randround.cpp create mode 100644 liberate/csprng/randround_cuda_kernel.cu create mode 100644 liberate/csprng/setup.py create mode 100644 liberate/fhe/__init__.py create mode 100644 liberate/fhe/cache/__init__.py create mode 100644 liberate/fhe/cache/cache.py create mode 100644 liberate/fhe/cache/resources/logN_N_M.pkl create mode 100644 liberate/fhe/cache/resources/message_special_primes.pkl create mode 100644 liberate/fhe/cache/resources/scale_primes.pkl create mode 100644 liberate/fhe/ckks_engine.py create mode 100644 liberate/fhe/context/__init__.py create mode 100644 liberate/fhe/context/ckks_context.py create mode 100644 liberate/fhe/context/generate_primes.py create mode 100644 liberate/fhe/context/prim_test.py create mode 100644 liberate/fhe/context/security_parameters.py create mode 100644 
liberate/fhe/data_struct.py create mode 100644 liberate/fhe/encdec/__init__.py create mode 100644 liberate/fhe/encdec/encdec.py create mode 100644 liberate/fhe/presets/__init__.py create mode 100644 liberate/fhe/presets/errors.py create mode 100644 liberate/fhe/presets/params.py create mode 100644 liberate/fhe/presets/types.py create mode 100644 liberate/fhe/version.py create mode 100644 liberate/ntt/__init__.py create mode 100644 liberate/ntt/ntt.cpp create mode 100644 liberate/ntt/ntt_context.py create mode 100644 liberate/ntt/ntt_cuda_kernel.cu create mode 100644 liberate/ntt/rns_partition.py create mode 100644 liberate/ntt/setup.py create mode 100644 liberate/utils/__init__.py create mode 100644 liberate/utils/helpers.py create mode 100755 lint.sh create mode 100644 setup.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..dfac036 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 79 +extend-ignore = E203 +exclude = + .venv,build,dist,docs,examples diff --git a/liberate/__init__.py b/liberate/__init__.py new file mode 100644 index 0000000..3d081ee --- /dev/null +++ b/liberate/__init__.py @@ -0,0 +1 @@ +from . import csprng, fhe, utils diff --git a/liberate/csprng/__init__.py b/liberate/csprng/__init__.py new file mode 100644 index 0000000..d576d02 --- /dev/null +++ b/liberate/csprng/__init__.py @@ -0,0 +1 @@ +from .csprng import Csprng diff --git a/liberate/csprng/chacha20.cpp b/liberate/csprng/chacha20.cpp new file mode 100644 index 0000000..815182c --- /dev/null +++ b/liberate/csprng/chacha20.cpp @@ -0,0 +1,43 @@ +#include +#include + +// chacha20 is a mutating function. +// That means the input is mutated and there's no need to return a value. + +// Check types. +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LONG(x) TORCH_CHECK(x.dtype() == torch::kInt64, #x, " must be a kInt64 tensor") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_LONG(x) + + +// Forward declaration. +void chacha20_cuda(torch::Tensor input, torch::Tensor dest, size_t step); + +std::vector chacha20(std::vector inputs, size_t step) { + // The input must be a contiguous long tensor of size 16 x N. + // Also, the tensor must be contiguous to enable pointer arithmetic, + // and must be stored in a cuda device. + // Note that the input is a vector of inputs in different devices. + + std::vector outputs; + + for (auto &input : inputs){ + CHECK_INPUT(input); + + // Prepare an output. + auto dest = input.clone(); + + // Run in cuda. + chacha20_cuda(input, dest, step); + + // Store to the dest. 
+ outputs.push_back(dest); + } + + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("chacha20", &chacha20, "CHACHA20 (CUDA)"); +} diff --git a/liberate/csprng/chacha20_cuda_kernel.cu b/liberate/csprng/chacha20_cuda_kernel.cu new file mode 100644 index 0000000..429a894 --- /dev/null +++ b/liberate/csprng/chacha20_cuda_kernel.cu @@ -0,0 +1,69 @@ +#include +#include +#include +#include + +#include "chacha20_cuda_kernel.h" + +#define BLOCK_SIZE 256 + +__global__ void chacha20_cuda_kernel( + torch::PackedTensorAccessor32 input, + torch::PackedTensorAccessor32 dest, + size_t step) { + + // input is configured as N x 16 + const int index = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int64_t x[BLOCK_SIZE][16]; + + #pragma unroll + for(int i=0; i<16; ++i){ + x[threadIdx.x][i] = dest[index][i]; + + // Vectorized load. + // Not much beneficial for our case, though ... + //reinterpret_cast(x[threadIdx.x])[i] = + // *reinterpret_cast(&(dest[index][i])); + } + + // Repeat 10 times for chacha20. + #pragma unroll + for(int i=0; i<10; ++i){ + ONE_ROUND(x, threadIdx.x); + } + + #pragma unroll + for(int i=0; i<16; ++i){ + dest[index][i] = (dest[index][i] + x[threadIdx.x][i]) & MASK; + } + + // Step the state. + input[index][12] += step; + input[index][13] += (input[index][12] >> 32); + input[index][12] &= MASK; +} + + +// The wrapper. +void chacha20_cuda(torch::Tensor input, torch::Tensor dest, size_t step){ + + // Required number of blocks in a grid. + // Note that we do not use grids here, since the + // tensor we're dealing with must be chopped in 1-d. + // input is configured as 16 x N + // N must be a multitude of 1024. + + const int dim_block = BLOCK_SIZE; + int dim_grid = input.size(0) / dim_block; + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = input.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Run the cuda kernel. 
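  // Each thread transforms one 16-word state, so the launch uses
  // input.size(0) / BLOCK_SIZE blocks of BLOCK_SIZE (= 256) threads; the
  // number of states therefore has to be divisible by the block size.
  // The packed accessors below give the kernel indexed 2-D views of the
  // int64 tensors.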
+ auto input_acc = input.packed_accessor32(); + auto dest_acc = dest.packed_accessor32(); + chacha20_cuda_kernel<<>>(input_acc, dest_acc, step); +} diff --git a/liberate/csprng/chacha20_cuda_kernel.h b/liberate/csprng/chacha20_cuda_kernel.h new file mode 100644 index 0000000..2e445aa --- /dev/null +++ b/liberate/csprng/chacha20_cuda_kernel.h @@ -0,0 +1,34 @@ +#define MASK 0xffffffff + +#define ROLL16(x) x = (((x) << 16) | ((x) >> 16)) & MASK +#define ROLL12(x) x = (((x) << 12) | ((x) >> 20)) & MASK +#define ROLL8(x) x = (((x) << 8) | ((x) >> 24)) & MASK +#define ROLL7(x) x = (((x) << 7) | ((x) >> 25)) & MASK + +#define QR(x, ind, a, b, c, d)\ + x[ind][a] += x[ind][b];\ + x[ind][a] &= MASK;\ + x[ind][d] ^= x[ind][a];\ + ROLL16(x[ind][d]);\ + x[ind][c] += x[ind][d];\ + x[ind][c] &= MASK;\ + x[ind][b] ^= x[ind][c];\ + ROLL12(x[ind][b]);\ + x[ind][a] += x[ind][b];\ + x[ind][a] &= MASK;\ + x[ind][d] ^= x[ind][a];\ + ROLL8(x[ind][d]);\ + x[ind][c] += x[ind][d];\ + x[ind][c] &= MASK;\ + x[ind][b] ^= x[ind][c];\ + ROLL7(x[ind][b]) + +#define ONE_ROUND(x, ind)\ + QR(x, ind, 0, 4, 8, 12);\ + QR(x, ind, 1, 5, 9, 13);\ + QR(x, ind, 2, 6, 10, 14);\ + QR(x, ind, 3, 7, 11, 15);\ + QR(x, ind, 0, 5, 10, 15);\ + QR(x, ind, 1, 6, 11, 12);\ + QR(x, ind, 2, 7, 8, 13);\ + QR(x, ind, 3, 4, 9, 14) diff --git a/liberate/csprng/chacha20_naive.py b/liberate/csprng/chacha20_naive.py new file mode 100644 index 0000000..5632f48 --- /dev/null +++ b/liberate/csprng/chacha20_naive.py @@ -0,0 +1,246 @@ +import binascii +import math +import os + +import torch + +torch.backends.cudnn.benchmark = True + + +@torch.jit.script +def roll(x: torch.Tensor, s: int) -> None: + """ + x's dtype must be torch.int64. + We are kind of forced to do this because + 1. pytorch doesn't support unit32, and + 2. >> doesn't move the sign bit. + """ + mask = 0xFFFFFFFF + right_shift = 32 - s + down = (x & mask) >> right_shift + # x <<= s + # x |= down + x.__ilshift__(s).bitwise_or_(down).bitwise_and_(mask) + + +@torch.jit.script +def roll16(x: torch.Tensor) -> None: + mask = 0xFFFFFFFF + down = x >> 16 + x.__ilshift__(16).bitwise_or_(down).bitwise_and_(mask) + + +@torch.jit.script +def roll12(x: torch.Tensor) -> None: + mask = 0xFFFFFFFF + down = x >> 20 + x.__ilshift__(12).bitwise_or_(down).bitwise_and_(mask) + + +@torch.jit.script +def roll8(x: torch.Tensor) -> None: + mask = 0xFFFFFFFF + down = x >> 24 + x.__ilshift__(8).bitwise_or_(down).bitwise_and_(mask) + + +@torch.jit.script +def roll7(x: torch.Tensor) -> None: + mask = 0xFFFFFFFF + down = x >> 25 + x.__ilshift__(7).bitwise_or_(down).bitwise_and_(mask) + + +@torch.jit.script +def QR(x: torch.Tensor, a: int, b: int, c: int, d: int) -> None: + """ + The CHACHA quarter round. + """ + mask = 0xFFFFFFFF + + x[a].add_(x[b]) + x[a].bitwise_and_(mask) + x[d].bitwise_xor_(x[a]) + roll16(x[d]) + + x[c].add_(x[d]) + x[c].bitwise_and_(mask) + x[b].bitwise_xor_(x[c]) + roll12(x[b]) + + x[a].add_(x[b]) + x[a].bitwise_and_(mask) + x[d].bitwise_xor_(x[a]) + roll8(x[d]) + + x[c].add_(x[d]) + x[c].bitwise_and_(mask) + x[b].bitwise_xor_(x[c]) + roll7(x[b]) + + +@torch.jit.script +def one_round(x: torch.Tensor) -> None: + # Odd round. + QR(x, 0, 4, 8, 12) + QR(x, 1, 5, 9, 13) + QR(x, 2, 6, 10, 14) + QR(x, 3, 7, 11, 15) + # Even round. 
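(The even-round QR calls continue directly below.) As a cross-check of the quarter-round logic that the QR macro and the in-place torch version above both implement, here is a minimal pure-Python sketch; it uses plain integers masked to 32 bits and the quarter-round test vector from RFC 7539, Section 2.1.1, and is an illustration rather than code from this patch.

    MASK = 0xFFFFFFFF

    def rotl32(v, s):
        # 32-bit left rotation; the ROLL16/12/8/7 macros specialized by s.
        return ((v << s) | (v >> (32 - s))) & MASK

    def quarter_round(x, a, b, c, d):
        x[a] = (x[a] + x[b]) & MASK; x[d] = rotl32(x[d] ^ x[a], 16)
        x[c] = (x[c] + x[d]) & MASK; x[b] = rotl32(x[b] ^ x[c], 12)
        x[a] = (x[a] + x[b]) & MASK; x[d] = rotl32(x[d] ^ x[a], 8)
        x[c] = (x[c] + x[d]) & MASK; x[b] = rotl32(x[b] ^ x[c], 7)

    # RFC 7539, Section 2.1.1 quarter-round test vector.
    state = [0] * 16
    state[0], state[1], state[2], state[3] = 0x11111111, 0x01020304, 0x9B8D6F43, 0x01234567
    quarter_round(state, 0, 1, 2, 3)
    assert state[:4] == [0xEA2A92F4, 0xCB1CF8CE, 0x4581472E, 0x5881C4BB]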
+ QR(x, 0, 5, 10, 15) + QR(x, 1, 6, 11, 12) + QR(x, 2, 7, 8, 13) + QR(x, 3, 4, 9, 14) + + +@torch.jit.script +def increment_counter(state: torch.Tensor, inc: int) -> None: + state[12] += inc + state[13] += state[12] >> 32 + state[12] = state[12] & 0xFFFFFFFF + + +@torch.jit.script +def chacha20(state: torch.Tensor) -> torch.Tensor: + x = state.clone() + + for _ in range(10): + one_round(x) + + # Return the random bytes. + return (x + state) & 0xFFFFFFFF + + +class chacha20_naive: + def __init__( + self, size, seed=None, nonce=None, count_step=1, device="cuda:0" + ): + self.size = size + self.device = device + self.count_step = count_step + + # expand 32-byte k. + # This is 1634760805, 857760878, 2036477234, 1797285236. + str2ord = lambda s: sum([2 ** (i * 8) * c for i, c in enumerate(s)]) + self.nothing_up_my_sleeve = torch.tensor( + [ + str2ord(b"expa"), + str2ord(b"nd 3"), + str2ord(b"2-by"), + str2ord(b"te k"), + ], + device=device, + dtype=torch.int64, + ) + + # Prepare the state tensor. + self.state_size = (*self.size, 16) + self.state_buffer = torch.zeros( + 16, (math.prod(self.size)), dtype=torch.int64, device=self.device + ) + + # The ind counter. + self.ind = torch.arange( + 0, math.prod(self.size), self.count_step, device=device + ) + + # Increment is the number of indices. + self.inc = math.prod(self.size) + + self.initialize_state(seed, nonce) + + # Capture computation graph. + self.capture() + + def initialize_state(self, seed=None, nonce=None): + # Generate seed if necessary. + self.key(seed) + + # Generate nonce if necessary. + self.generate_nonce(nonce) + + # Zero out the state. + self.state_buffer.zero_() + + # Set the counter. + self.state_buffer[12, :] = self.ind + + # Set the expand 32-bye k + self.state_buffer[0:4, :] = self.nothing_up_my_sleeve[:, None] + + # Set the seed. + self.state_buffer[4:12, :] = self.seed[:, None] + + # Fill in nonce. + self.state_buffer[14:, :] = self.nonce[:, None] + + def key(self, seed=None): + # 256bits seed as a key. + if seed is None: + # 256bits key as a seed, + nbytes = 32 + part_bytes = 4 + n_keys = nbytes // part_bytes + hex2int = lambda x, nbytes: int(binascii.hexlify(x), 16) + self.seed = torch.tensor( + [ + hex2int(os.urandom(part_bytes), part_bytes) + for _ in range(n_keys) + ], + device=self.device, + dtype=torch.int64, + ) + else: + self.seed = torch.tensor( + seed, device=self.device, dtype=torch.int64 + ) + + def generate_nonce(self, nonce): + # nonce is 64bits. + if nonce is None: + # 256bits key as a seed, + nbytes = 8 + part_bytes = 4 + n_keys = nbytes // part_bytes + hex2int = lambda x, nbytes: int(binascii.hexlify(x), 16) + self.nonce = torch.tensor( + [ + hex2int(os.urandom(part_bytes), part_bytes) + for _ in range(n_keys) + ], + device=self.device, + dtype=torch.int64, + ) + else: + self.nonce = torch.tensor( + nonce, device=self.device, dtype=torch.int64 + ) + + def capture(self, warmup_periods=3, fuser="fuser1"): + with torch.cuda.device(self.device): + # Reserve ample amount of excution cache. + torch.jit.set_fusion_strategy([("STATIC", 100)]) + + # Output buffer. + self.out = torch.zeros_like(self.state_buffer) + + # Warm up. + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s), torch.jit.fuser(fuser): + for _ in range(warmup_periods): + self.out = chacha20(self.state_buffer) + torch.cuda.current_stream().wait_stream(s) + + # Capture. 
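            # torch.cuda.CUDAGraph() records the kernel launches issued inside
            # the torch.cuda.graph(...) block below. Replaying the graph (see
            # randbyte) re-runs those kernels on the current contents of
            # self.state_buffer, which step() mutates in place, so every replay
            # produces a fresh block of words without relaunch overhead.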
+ self.g = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.g): + self.out = chacha20(self.state_buffer) + + def step(self): + increment_counter(self.state_buffer, self.inc) + + def randbyte(self): + self.g.replay() + self.step() + return self.out.clone() diff --git a/liberate/csprng/csprng.py b/liberate/csprng/csprng.py new file mode 100644 index 0000000..77110da --- /dev/null +++ b/liberate/csprng/csprng.py @@ -0,0 +1,306 @@ +import binascii +import os + +import numpy as np +import torch + +from . import ( + chacha20_cuda, + discrete_gaussian_cuda, + randint_cuda, + randround_cuda, +) +from .discrete_gaussian_sampler import build_CDT_binary_search_tree + +torch.backends.cudnn.benchmark = True + + +class Csprng: + def __init__(self, num_coefs=2 ** 15, num_channels=[8], num_repeating_channels=2, + sigma=3.2, devices=None, seed=None, nonce=None): + """N is the length of the polynomial, and C is the number of RNS channels. + procure the maximum (at level zero, special multiplication) at initialization.""" + + # This CSPRNG class generates + # 1. num_coefs x (num_channels + num_repeating_channels) uniform distributed + # random numbers at max. num_channels can be reduced down according + # to the input q at the time of generation. + # The numbers generated ri are 0 <= ri < qi. + # Seeds in the repeated channels are the same, and hence in those + # channels, the generated numbers are the same across GPUs. + # Generation of the repeated random numbers is optional. + # The same function can be used to generate ranged random integers + # in a fixed range. Again, generation of the repeated numbers is optional. + # 2. Generation of Discrete Gaussian random numbers. The numbers can be generated + # in the non-repeating channels (with the maximum number of channels num_channels), + # or in the repeating channels (with the maximum number of + # channels num_repeating_channels, where in the most typical scenario is same as 1). + + self.num_coefs = num_coefs + self.num_channels = num_channels + self.num_repeating_channels = num_repeating_channels + self.sigma = sigma + + # Set up GPUs. + # By default, use all the available GPUs on the system. + if devices is None: + gpu_count = torch.cuda.device_count() + self.devices = [f'cuda:{i}' for i in range(gpu_count)] + else: + self.devices = devices + + self.num_devices = len(self.devices) + + # Compute shares of channels per a GPU. + if len(self.num_channels) == 1: + # Allocate the same number of channels per every GPU. + self.shares = [self.num_channels[0]] * self.num_devices + elif len(self.num_channels) == self.num_devices: + self.shares = self.num_channels + else: + # User input was contradicting. + raise Exception("There was a contradicting mismatch between " + "num_channels, and devices.") + + # How many channels in total? + self.total_num_channels = sum(self.shares) + + # We generate random bytes 4x4 = 16 per an array and hence, + # internally only need to procure N // 4 length arrays. + # Out of the 16, we generate discrete gaussian or uniform + # samples 4 at a time. + self.L = self.num_coefs // 4 + + # We build binary search tree for discrete gaussian here. + self.btree, self.btree_ptr, self.btree_size, self.tree_depth = \ + build_CDT_binary_search_tree(security_bits=128, sigma=sigma) + + # Counter range at each GPU. + # Note that the counters only include the non-repeating channels. 
+ # We can add later counters at the end that start from the + # self.total_num_channels * self.L + self.start_ind = [0] + [s * self.L for s in self.shares[:-1]] + self.ind_increments = [s * self.L for s in self.shares] + self.end_ind = [s + e for s, e in zip(self.start_ind, self.ind_increments)] + + # Total increment to add to counters after each random bytes generation. + self.inc = (self.total_num_channels + self.num_repeating_channels) * self.L + self.repeating_start = self.total_num_channels * self.L + + # expand 32-byte k. + # This is 1634760805, 857760878, 2036477234, 1797285236. + str2ord = lambda s: sum([2 ** (i * 8) * c for i, c in enumerate(s)]) + self.nothing_up_my_sleeve = [] + for device in self.devices: + str_constant = torch.tensor( + [ + str2ord(b'expa'), str2ord(b'nd 3'), str2ord(b'2-by'), str2ord(b'te k') + ], device=device, dtype=torch.int64 + ) + self.nothing_up_my_sleeve.append(str_constant) + + # Prepare the state tensors. + self.states = [] + for dev_id in range(self.num_devices): + state_size = ( + (self.shares[dev_id] + self.num_repeating_channels) * self.L, 16) + state = torch.zeros( + state_size, + dtype=torch.int64, + device=self.devices[dev_id] + ) + self.states.append(state) + + # Prepare a channeled views. + self.channeled_states = [ + self.states[i].view( + self.shares[i] + self.num_repeating_channels, + self.L, -1) for i in range(self.num_devices) + ] + + # The counter. + self.counters = [] + repeating_counter = list(range(self.repeating_start, self.inc)) + for dev_id in range(self.num_devices): + counter = list(range( + self.start_ind[dev_id], self.end_ind[dev_id])) + repeating_counter + + counter_tensor = torch.tensor(counter, + dtype=torch.int64, + device=self.devices[dev_id]) + self.counters.append(counter_tensor) + + self.refresh(seed, nonce) + + def refresh(self, seed=None, nonce=None): + # Generate seed if necessary. + self.key = self.generate_key(seed) + + # Generate nonce if necessary. + self.nonce = self.generate_nonce(nonce) + + # Iterate over all devices. + for dev_id in range(self.num_devices): + self.initialize_states(dev_id, seed, nonce) + + def initialize_states(self, dev_id, seed=None, nonce=None): + + state = self.states[dev_id] + state.zero_() + + # Set the counter. + # It is hardly unlikely we will use CxL > 2**32. + # Just fill in the 12th element + # (The lower bytes of the counter). + state[:, 12] = self.counters[dev_id][None, :] + + # Set the expand 32-bye k + state[:, 0:4] = self.nothing_up_my_sleeve[dev_id][None, :] + + # Set the seed. + state[:, 4:12] = self.key[dev_id][None, :] + + # Fill in nonce. + state[:, 14:] = self.nonce[dev_id][None, :] + + def generate_initial_bytes(self, nbytes, part_bytes=4, seed=None): + seeds = [] + if seed is None: + n_keys = nbytes // part_bytes + hex2int = lambda x, nbytes: int(binascii.hexlify(x), 16) + seed0 = [ + hex2int(os.urandom(part_bytes), part_bytes) for _ in range(n_keys) + ] + for dev_id in range(self.num_devices): + cuda_seed = torch.tensor( + seed0, + dtype=torch.int64, + device=self.devices[dev_id] + ) + seeds.append(cuda_seed) + else: + seed0 = seed + for dev_id in range(self.num_devices): + cuda_seed = torch.tensor( + seed0, + dtype=torch.int64, + device=self.devices[dev_id] + ) + seeds.append(cuda_seed) + return seeds + + def generate_key(self, seed): + # 256bits seed as a key. + # We generate the same key seed for every GPU. + # Randomity is produced by counters, not the key. 
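To make the counter bookkeeping above concrete, here is a small arithmetic sketch for an assumed two-GPU configuration with the constructor defaults (num_coefs=2**15, num_channels=[8], num_repeating_channels=2); the numbers are derived from the expressions above, not from running the code.

    # Assumed two-GPU layout, mirroring the arithmetic in __init__ above.
    num_coefs, shares, num_repeating = 2**15, [8, 8], 2

    L = num_coefs // 4                 # each 16-word state yields 4 samples
    start_ind = [0] + [s * L for s in shares[:-1]]
    ind_increments = [s * L for s in shares]
    end_ind = [s + e for s, e in zip(start_ind, ind_increments)]

    total_channels = sum(shares)
    inc = (total_channels + num_repeating) * L     # counter step per generation
    repeating_start = total_channels * L

    print(L, start_ind, end_ind)       # 8192 [0, 65536] [65536, 131072]
    print(inc, repeating_start)        # 147456 131072

    # Device 0 counts states 0..65535, device 1 counts 65536..131071, and both
    # append the shared repeating counters 131072..147455, so with the same key
    # and nonce the repeating channels yield identical bytes on every GPU.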
+ return self.generate_initial_bytes(32, seed=None) + + def generate_nonce(self, seed): + # nonce is 64bits. + return self.generate_initial_bytes(8, seed=None) + + def randbytes(self, shares=None, repeats=0, length=None, reshape=False): + # Generates (shares_i + repeats) X length random bytes. + if shares is None: + shares = self.shares + + if length is None: + L = self.L + else: + L = length + + # Set the target states. + target_states = [] + for devi in range(self.num_devices): + start_channel = self.shares[devi] - shares[devi] + end_channel = self.shares[devi] + repeats + device_states = self.channeled_states[devi][ + start_channel:end_channel, :L, :] + target_states.append(device_states.view(-1, 16)) + + # Derive random bytes. + random_bytes = chacha20_cuda.chacha20(target_states, self.inc) + + # If not reshape, flatten. + if reshape: + random_bytes = [rb.view(-1, L, 16) for rb in random_bytes] + + return random_bytes + + def randint(self, amax=3, shift=0, repeats=0, length=None): + # The default values are for generating the same uniform ternary + # arrays in all GPUs. + + if not isinstance(amax, (list, tuple)): + amax = [[amax] for share in self.shares] + + if length is None: + L = self.L + else: + L = length + + # Calculate shares. + # If repeats are greater than 0, those channels are + # subtracted from shares. + shares = [len(am) - repeats for am in amax] + + # Convert the amax list to contiguous numpy array pointers. + q_conti = [np.ascontiguousarray(q, dtype=np.uint64) for q in amax] + q_ptr = [q.__array_interface__['data'][0] for q in q_conti] + + # Set the target states. + target_states = [] + for devi in range(self.num_devices): + start_channel = self.shares[devi] - shares[devi] + end_channel = self.shares[devi] + repeats + device_states = self.channeled_states[devi][ + start_channel:end_channel, :L, :] + target_states.append(device_states) + + # Generate the randint. + rand_int = randint_cuda.randint_fast( + target_states, q_ptr, shift, self.inc) + + return rand_int + + def discrete_gaussian(self, non_repeats=0, repeats=1, length=None): + + if not isinstance(non_repeats, (list, tuple)): + shares = [non_repeats] * self.num_devices + else: + shares = non_repeats + + if length is None: + L = self.L + else: + L = length + + # Set the target states. + target_states = [] + for devi in range(self.num_devices): + start_channel = self.shares[devi] - shares[devi] + end_channel = self.shares[devi] + repeats + device_states = self.channeled_states[devi][ + start_channel:end_channel, :L, :] + target_states.append(device_states.view(-1, 16)) + + # Generate the randint. + rand_int = discrete_gaussian_cuda.discrete_gaussian_fast(target_states, + self.btree_ptr, + self.btree_size, + self.tree_depth, + self.inc) + # Reformat the rand_int. + rand_int = [ri.view(-1, self.num_coefs) for ri in rand_int] + + return rand_int + + def randround(self, coef): + """Randomly round coef. Coef must be a double tensor. + coef must reside in the fist GPU in the GPUs list""" + + L = self.num_coefs // 16 + rand_bytes = chacha20_cuda.chacha20( + (self.states[0][:L],), self.inc)[0].ravel() + randround_cuda.randround([coef], [rand_bytes]) + return rand_bytes diff --git a/liberate/csprng/discrete_gaussian.cpp b/liberate/csprng/discrete_gaussian.cpp new file mode 100644 index 0000000..ee2d66f --- /dev/null +++ b/liberate/csprng/discrete_gaussian.cpp @@ -0,0 +1,69 @@ + +#include +#include + +// Check types. 
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LONG(x) TORCH_CHECK(x.dtype() == torch::kInt64, #x, " must be a kInt64 tensor") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_LONG(x) + + +// Forward declaration. +void discrete_gaussian_cuda(torch::Tensor rand_bytes, + uint64_t* btree, + int btree_size, + int depth); + +torch::Tensor discrete_gaussian_fast_cuda(torch::Tensor states, + uint64_t* btree, + int btree_size, + int depth, + size_t step); + + + +// The main function. +//---------------- +// Normal version. +void discrete_gaussian(std::vector inputs, + size_t btree_ptr, + int btree_size, + int depth) { + + // reinterpret pointers from numpy. + uint64_t *btree = reinterpret_cast(btree_ptr); + + for (auto &rand_bytes : inputs){ + CHECK_INPUT(rand_bytes); + + // Run in cuda. + discrete_gaussian_cuda(rand_bytes, btree, btree_size, depth); + } +} + +//-------------- +// Fast version. + +std::vector discrete_gaussian_fast(std::vector states, + size_t btree_ptr, + int btree_size, + int depth, + size_t step) { + + // reinterpret pointers from numpy. + uint64_t *btree = reinterpret_cast(btree_ptr); + + std::vector outputs; + + for (auto &my_states : states){ + auto result = discrete_gaussian_fast_cuda(my_states, btree, btree_size, depth, step); + outputs.push_back(result); + } + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("discrete_gaussian", &discrete_gaussian, "discrete gaussian sampling (128 bits)."); + m.def("discrete_gaussian_fast", &discrete_gaussian_fast, "discrete gaussian sampling fast (chacha20 fused, 128 bits)."); +} diff --git a/liberate/csprng/discrete_gaussian_cuda_kernel.cu b/liberate/csprng/discrete_gaussian_cuda_kernel.cu new file mode 100644 index 0000000..118ce23 --- /dev/null +++ b/liberate/csprng/discrete_gaussian_cuda_kernel.cu @@ -0,0 +1,237 @@ +#include +#include +#include +#include + +#include "chacha20_cuda_kernel.h" + +#define GE(x_high, x_low, y_high, y_low)\ + (((x_high) > (y_high)) | (((x_high) == (y_high)) & ((x_low) >= (y_low)))) + +#define COMBINE_TWO(high, low)\ + ((static_cast(high) << 32) | static_cast(low)) + +// Use 256 BLOCK_SIZE. 64 * 4 = 256. +#define BLOCK_SIZE 64 + +#define LUT_SIZE 128 +__constant__ uint64_t LUT[LUT_SIZE]; + +/////////////////////////////////////////////////////////// +// The implementation + +//---------------------------------------------------------- +// The fast version, chacha20 fused. +//---------------------------------------------------------- + +__global__ void discrete_gaussian_fast_cuda_kernel( + torch::PackedTensorAccessor32 states, + torch::PackedTensorAccessor32 dst, + int btree_size, + int depth, + size_t step){ + + // Where am I? + const int thread_ind = threadIdx.x; + const int poly_order = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int64_t x[BLOCK_SIZE][16]; + + #pragma unroll + for(int i=0; i<16; ++i){ + x[thread_ind][i] = states[poly_order][i]; + } + + // Repeat 10 times for chacha20. + #pragma unroll + for(int i=0; i<10; ++i){ + ONE_ROUND(x, thread_ind); + } + + #pragma unroll + for(int i=0; i<16; ++i){ + x[thread_ind][i] = (states[poly_order][i] + x[thread_ind][i]) & MASK; + } + + // Step the state. + states[poly_order][12] += step; + states[poly_order][13] += (states[poly_order][12] >> 32); + states[poly_order][12] &= MASK; + + + // Discrete gaussian + for(int i=0; i<16; i+=4){ + // Traverse the tree in the LUT. 
+ // Note that, out of the 16 32-bit randon numbers, + // we generate 4 discrete gaussian samples. + int jump = 1; + int current = 0; + int counter = 0; + + // Compose into 2 uint64 values, the 4 32 bit values stored in + // the 4 int64 storage. + uint64_t x_low = COMBINE_TWO(x[thread_ind][i], x[thread_ind][i+1]); + uint64_t x_high = COMBINE_TWO(x[thread_ind][i+2], x[thread_ind][i+3]); + + // Reserve a sign bit. + // Since we are dealing with the half plane, + // The CDT values in the LUT are at most 0.5, which means + // that the values are 127 bits. + // Also, rigorously speaking, we need to take out the MSB from the x_high + // value to take the sign, but every bit has probability of occurrence=0.5. + // Hence, it doesn't matter where we take the bit. + // For convenience, take the LSB of x_high. + int64_t sign_bit = x_high & 1; + x_high >>= 1; + + // Traverse the binary search tree. + for(int j=0; j(current); + + // Store the result. + const int new_poly_order = poly_order * 4 + i/4; + dst[new_poly_order] = sample; + } +} + + + + +//---------------------------------------------------------- +// The normal version. +//---------------------------------------------------------- + +// rand_bytes are configured as N x 16. +__global__ void discrete_gaussian_cuda_kernel( + torch::PackedTensorAccessor32 rand_bytes, + int btree_size, + int depth){ + // Where am I? + const int index = blockIdx.x * blockDim.x + threadIdx.x; + + // i is the index of the starting element at the threadIdx.y row. + const int i = threadIdx.y * 4; + + // Traverse the tree in the LUT. + // Note that, out of the 16 32-bit randon numbers, + // we generate 4 discrete gaussian samples. + int jump = 1; + int current = 0; + int counter = 0; + + // Compose into 2 uint64 values, the 4 32 bit values stored in + // the 4 int64 storage. + uint64_t x_low = COMBINE_TWO(rand_bytes[index][i], rand_bytes[index][i+1]); + uint64_t x_high = COMBINE_TWO(rand_bytes[index][i+2], rand_bytes[index][i+3]); + + // Reserve a sign bit. + // Since we are dealing with the half plane, + // The CDT values in the LUT are at most 0.5, which means + // that the values are 127 bits. + // Also, rigorously speaking, we need to take out the MSB from the x_high + // value to take the sign, but every bit has probability of occurrence=0.5. + // Hence, it doesn't matter where we take the bit. + // For convenience, take the LSB of x_high. + int64_t sign_bit = x_high & 1; + x_high >>= 1; + + // Traverse the binary search tree. + for(int j=0; j(current); + + // Store the result. + rand_bytes[index][i] = sample; +} + + + +/////////////////////////////////////////////////////////// +// The wrapper. + +//---------------------------------------------------------- +// The fast version, chacha20 fused. +//---------------------------------------------------------- + +torch::Tensor discrete_gaussian_fast_cuda(torch::Tensor states, + uint64_t* btree, + int btree_size, + int depth, + size_t step){ + + // rand_bytes has the dim N x 16. + int dim_block = BLOCK_SIZE; + int dim_grid = states.size(0) / dim_block; + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = states.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Prepare the result. + // 16 elements in each state turn into 4 random numbers. + auto result = states.new_empty({states.size(0) * 4}); + + // Fill in the LUT constant memory. 
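For clarity, here is a pure-Python sketch of the CDT traversal performed above, using float probabilities in place of the 128-bit table (sigma = 3.2 gives 2**ceil(log2(6*3.2)) = 32 sampling points and a depth-5 tree, as in the sampler further down). The probs/cdt/btree names are illustrative, and the expanded-array binary search is simply checked against a plain linear scan.

    import math, random

    sigma, n_points = 3.2, 32
    probs = [math.exp(-x * x / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))
             for x in range(n_points)]
    probs[0] /= 2                                   # half plane: halve P(0)
    cdt = [0]
    for p in probs:
        cdt.append(cdt[-1] + p)
    cdt = [int(c * 2 ** 128) for c in cdt]          # integer thresholds, about 2**127 at most

    depth = int(math.log2(n_points))                # 5
    tree_nodes = []                                 # breadth-first order, as in the sampler
    for d in range(depth):
        step = n_points >> d
        tree_nodes += list(range(step // 2, n_points, step))
    btree = [cdt[i] for i in tree_nodes]            # 31 thresholds

    for _ in range(1000):
        x = random.getrandbits(127)                 # sign bit already stripped
        jump, current, counter = 1, 0, 0
        for _ in range(depth):
            ge = x >= btree[counter + current]      # the 128-bit GE, exact on Python ints
            current = 2 * current + int(ge)
            counter += jump
            jump *= 2
        # Same answer as a linear scan over the sorted thresholds.
        assert current == sum(t <= x for t in cdt[1:n_points])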
+ cudaMemcpyToSymbol(LUT, btree, btree_size * 2 * sizeof(uint64_t)); + + // Run the cuda kernel. + auto states_acc = states.packed_accessor32(); + auto result_acc = result.packed_accessor32(); + discrete_gaussian_fast_cuda_kernel<<>>( + states_acc, result_acc, btree_size, depth, step); + + return result; +} + + +//---------------------------------------------------------- +// The normal version. +//---------------------------------------------------------- + +void discrete_gaussian_cuda(torch::Tensor rand_bytes, + uint64_t* btree, + int btree_size, + int depth){ + + // rand_bytes has the dim N x 16. + dim3 dim_block(BLOCK_SIZE, 4); + int dim_grid = rand_bytes.size(0) / dim_block.x; + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = rand_bytes.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Fill in the LUT constant memory. + cudaMemcpyToSymbol(LUT, btree, btree_size * 2 * sizeof(uint64_t)); + + // Run the cuda kernel. + auto access = rand_bytes.packed_accessor32(); + discrete_gaussian_cuda_kernel<<>>(access, btree_size, depth); +} diff --git a/liberate/csprng/discrete_gaussian_sampler.py b/liberate/csprng/discrete_gaussian_sampler.py new file mode 100644 index 0000000..3f43037 --- /dev/null +++ b/liberate/csprng/discrete_gaussian_sampler.py @@ -0,0 +1,156 @@ +# Sample from discrete gaussian distribution. +import math + +import mpmath as mpm +import numpy as np +import torch + +# import discrete_gaussian_cuda +from . import discrete_gaussian_cuda + + +def build_CDT_binary_search_tree(security_bits=128, sigma=3.2): + """Currently, ONLY the discrete gaussian sampling at the default input values + is supported. That is equivalent to the 128 security measure.""" + # Set accuracy to 258 bits = (128 * 2) bits. + # We will use higher 128 bits for the 128 security model, + # Hence, retaining the 256 bits are safe for intermediate calculations + # to carry over all the significant digits. + + mpm.mp.prec = security_bits * 2 + + # Truncation boundary. + # The minimum recommended by + # Shi Bai, Adeline Langlois, Tancr`ede Lepoint, Damien Stehl ́e, + # and Ron Ste- infeld. Improved security proofs in lattice-based cryptography: + # Using the r ́enyi divergence rather than the statistical distance., + # as tau = 6 sigma. + # We want the number tau to be the power of 2 in fact, since it makes the + # binary tree search constant time. Using a larger tau the the minumum required + # is no problem in fact and in terms of tree traversing it doesn't score a performance + # hit because the tree will be balanced as a result. + # So we calculate the smallest power of 2 bigger than the minimum tau as the number + # of sampling points. + sampling_power = math.ceil(math.log2(6 * sigma)) + num_sampling_points = 2 ** sampling_power + sampling_points = list(range(num_sampling_points)) + + # Calculate probabilities at sampling points. + # Be careful when converting the python float to mpmath float. + # No mormalization is done and the mpmath tries to retain the + # bit pattern of the original float. + # As a result, when you do mpm.mpf(3.2), you get + # mpf('3.20000000000000017763568394002504646778106689453125'). 
+ # As a workaround, we can do mpm.mpf('3.2') to get + # mpf('3.200000000000000000000000000000000000000000000000000000000000000000000000000007') + mp_sigma = mpm.mpf(str(sigma)) + mp_two = mpm.mpf("2") + S = mp_sigma * mpm.sqrt(mp_two * mpm.pi) + discrete_gaussian_prob = ( + lambda x: mpm.exp(-mpm.mpf(str(x)) ** 2 / (mp_two * mp_sigma ** 2)) / S + ) + gaussian_prob_at_sampling_points = [ + discrete_gaussian_prob(x) for x in sampling_points + ] + + # We need to halve the probability at 0. + # We need to take into account the effect of symmetry, and we have + # only calculated the probability for the half section. + gaussian_prob_at_sampling_points[0] /= 2 + + # Now, calculate the Cumulative Distribution Table. + CDT = [0] + for P in gaussian_prob_at_sampling_points: + CDT.append(CDT[-1] + P) + + # At this point, we should end up with CDT[-1] which is very close to 0.5. + # This makes sense because we are calculating the CDT of the half plane. + # That reduces the effective bits in the CDT to 127, not 128. + # This again makes sense because we need to reserve 1 bit for sign. + + # We need 128 bits integer representation of the CDT. + CDT = [int(x * mp_two ** mpm.mpf(str(security_bits))) for x in CDT] + + # Chop the numbers down in a series of 64 bit integers. + num_chops = security_bits // 64 + chopped_CDT = [] + mask = (1 << 64) - 1 + for chop in range(num_chops): + chopped_share = [(x >> (64 * chop)) & mask for x in CDT] + chopped_CDT.append(chopped_share) + + # Now we can put the chopped CDT into a numpy array. + # All the numbers in the lists are representable by int64s. + # We transpose the resulting array to make it configured as N x 2. + CDT_table = np.array(chopped_CDT, dtype=np.uint64).T + + # We want to search through this table efficiently. + # Build a binary tree. + # Note that the last leaf is the sampled values. + # The last leaf index will be calculated in-place at runtime, + # Thus ommitted. + tree_depth = sampling_power + CDT_binary_tree = [] + for depth in range(tree_depth): + num_nodes = 2 ** depth + node_index_step = num_sampling_points // num_nodes + first_node_index = num_sampling_points // num_nodes // 2 + node_indices = list( + range(first_node_index, num_sampling_points, node_index_step) + ) + CDT_binary_tree += node_indices + # Use 1D expanded binary tree. + # See https://en.wikipedia.org/wiki/Binary_tree#Arrays. + btree = CDT_table[CDT_binary_tree] + + # Return the CType pointer together with the array. + btree_size = btree.shape[0] + btree_conti = np.ascontiguousarray(btree.T.ravel(), dtype=np.uint64) + btree_ptr = btree_conti.__array_interface__["data"][0] + + # The returned tree has probably has 31 x 2 dimension. + # The 31 for the number of nodes, and + # the 2 for (lower 64 bits, higher 64 bits). + return btree, btree_ptr, btree_size, tree_depth + + +def test_discrete_gaussian(N): + btree, depth = build_CDT_binary_search_tree() + + rand = np.random.randint(0, 2 ** 64, size=(N, 2), dtype=np.uint64) + + GE = lambda x_high, x_low, y_high, y_low: ( + ((x_high) > (y_high)) | (((x_high) == (y_high)) & ((x_low) >= (y_low))) + ) + result = [] + for r in rand: + jump = 1 + current = 0 + counter = 0 + + sign_bit = int(r[0]) & 1 + r_high = int(r[0]) >> 1 + r_low = r[1] + + for j in range(depth): + ge_flag = GE( + r_high, + r_low, + btree[counter + current, 1], + btree[counter + current, 0], + ) + # ge_flag = (r_high > btree[counter+current, 1]) + + # Update current location. + current = 2 * current + int(ge_flag) + + # Update counter. 
+ counter += jump + + # Update jump + jump *= 2 + + sample = (sign_bit * 2 - 1) * current + result.append(sample) + + return result diff --git a/liberate/csprng/randint.cpp b/liberate/csprng/randint.cpp new file mode 100644 index 0000000..f6ac564 --- /dev/null +++ b/liberate/csprng/randint.cpp @@ -0,0 +1,55 @@ +#include +#include + +// Check types. +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LONG(x) TORCH_CHECK(x.dtype() == torch::kInt64, #x, " must be a kInt64 tensor") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_LONG(x) + + +// Forward declaration. +void randint_cuda(torch::Tensor rand_bytes, uint64_t *q); +torch::Tensor randint_fast_cuda(torch::Tensor states, uint64_t *q, int64_t shift, size_t step); + + + +// The main function. +//---------------- +// Normal version. +void randint(std::vector inputs, + std::vector q_ptrs) { + + for (auto i=0; i(q_ptrs[i]); + + // Run in cuda. + randint_cuda(inputs[i], q); + } +} + +//-------------- +// Fast version. +std::vector randint_fast( + std::vector states, + std::vector q_ptrs, + int64_t shift, + size_t step) { + + std::vector outputs; + + for (auto i=0; i(q_ptrs[i]); + auto result = randint_fast_cuda(states[i], q, shift, step); + outputs.push_back(result); + } + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("randint", &randint, "random integer sampling (0 to q, 64 bits)."); + m.def("randint_fast", &randint_fast, "random integer sampling (randint fused, 0 to q, 64 bits)."); +} diff --git a/liberate/csprng/randint_cuda_kernel.cu b/liberate/csprng/randint_cuda_kernel.cu new file mode 100644 index 0000000..4fbc74a --- /dev/null +++ b/liberate/csprng/randint_cuda_kernel.cu @@ -0,0 +1,215 @@ +#include +#include +#include +#include + +#include "chacha20_cuda_kernel.h" + +#define COMBINE_TWO(high, low)\ + ((static_cast(high) << 32) | static_cast(low)) + +#define BLOCK_SIZE 64 + +#define LUT_SIZE 128 +__constant__ uint64_t Q[LUT_SIZE]; + +/////////////////////////////////////////////////////////// +// The implementation + +//---------------------------------------------------------- +// The fast version, chacha20 fused. +//---------------------------------------------------------- + +__global__ void randint_fast_cuda_kernel( + torch::PackedTensorAccessor32 states, + torch::PackedTensorAccessor32 dst, + int64_t shift, + size_t step) { + + // Where am I? + // blockDim -> BLOCK_SIZE_FAST + // gridDim -> num_rns_channels, num_poly_orders / BLOCK_DIM_FAST + const int thread_ind = threadIdx.x; + const int poly_order = blockIdx.y * blockDim.x + threadIdx.x; + const int rns_channel = blockIdx.x; + + __shared__ int64_t x[BLOCK_SIZE][16]; + + #pragma unroll + for(int i=0; i<16; ++i){ + x[thread_ind][i] = states[rns_channel][poly_order][i]; + } + + // Repeat 10 times for chacha20. + #pragma unroll + for(int i=0; i<10; ++i){ + ONE_ROUND(x, thread_ind); + } + + #pragma unroll + for(int i=0; i<16; ++i){ + x[thread_ind][i] = (states[rns_channel][poly_order][i] + x[thread_ind][i]) & MASK; + } + + // Step the state. + states[rns_channel][poly_order][12] += step; + states[rns_channel][poly_order][13] += (states[rns_channel][poly_order][12] >> 32); + states[rns_channel][poly_order][12] &= MASK; + + // Randint. + #pragma unroll + for(int i=0; i<16; i+=4){ + + // What is my Q? 
+ auto p = Q[rns_channel]; + + // Compose into 2 uint64 values, the 4 32 bit values stored in + // the 4 int64 storage. + uint64_t x_low = COMBINE_TWO(x[thread_ind][i], x[thread_ind][i+1]); + + // Use CUDA integer intrinsics to calculate + // (x_low * p) >> 64. + // Refer to https://github.com/apple/swift/pull/39143 + auto alpha = __umul64hi(p, x_low); + + // We need to calculate carry. + auto pl = p & MASK; // 1-32 + auto ph = p >> 32; // 33-64 + //--------------------------------------- + auto xhh = x[thread_ind][i+2]; + auto xhl = x[thread_ind][i+3]; + //--------------------------------------- + auto plxhl = pl * xhl; // 65-128 + auto plxhh = pl * xhh; // 97-160 + auto phxhl = ph * xhl; // 97-160 + auto phxhh = ph * xhh; // 129-192 + //--------------------------------------- + auto carry = ((plxhl & MASK) + (alpha & MASK)) >> 32; + carry = (carry + + (plxhl >> 32) + + (alpha >> 32) + + (phxhl & MASK) + + (plxhh & MASK)) >> 32; + auto sample = (carry + + (phxhl >> 32) + + (plxhh >> 32) + phxhh); + + // Store the result. + // Don't forget the shift!!! + const int new_poly_order = poly_order * 4 + i/4; + dst[rns_channel][new_poly_order] = sample + shift; + } +} + + + +//---------------------------------------------------------- +// The normal version. +//---------------------------------------------------------- + + +// rand_bytes are configured as C x N x 16, where C denotes the q channels. +__global__ void randint_cuda_kernel(torch::PackedTensorAccessor32 rand_bytes){ + + // Where am I? + const int index = blockIdx.y * blockDim.x + threadIdx.x; + + // i is the index of the starting element at the threadIdx.y row. + const int i = threadIdx.y * 4; + + // What is my Q? + auto p = Q[blockIdx.x]; + + // Compose into 2 uint64 values, the 4 32 bit values stored in + // the 4 int64 storage. + uint64_t x_low = COMBINE_TWO(rand_bytes[blockIdx.x][index][i], rand_bytes[blockIdx.x][index][i+1]); + + // Use CUDA integer intrinsics to calculate + // (x_low * p) >> 64. + // Refer to https://github.com/apple/swift/pull/39143 + auto alpha = __umul64hi(p, x_low); + + // We need to calculate carry. + auto pl = p & MASK; // 1-32 + auto ph = p >> 32; // 33-64 + //--------------------------------------- + auto xhh = rand_bytes[blockIdx.x][index][i+2]; + auto xhl = rand_bytes[blockIdx.x][index][i+3]; + //--------------------------------------- + auto plxhl = pl * xhl; // 65-128 + auto plxhh = pl * xhh; // 97-160 + auto phxhl = ph * xhl; // 97-160 + auto phxhh = ph * xhh; // 129-192 + //--------------------------------------- + auto carry = ((plxhl & MASK) + (alpha & MASK)) >> 32; + carry = (carry + + (plxhl >> 32) + + (alpha >> 32) + + (phxhl & MASK) + + (plxhh & MASK)) >> 32; + auto sample = (carry + + (phxhl >> 32) + + (plxhh >> 32) + phxhh); + + // Store the result. + rand_bytes[blockIdx.x][index][i] = sample; +} + +/////////////////////////////////////////////////////////// +// The wrapper. + +//---------------------------------------------------------- +// The fast version, chacha20 fused. +//---------------------------------------------------------- + +torch::Tensor randint_fast_cuda(torch::Tensor states, uint64_t *q, int64_t shift, size_t step){ + + // rand_bytes has the dim C x N x 16. + int dim_block = BLOCK_SIZE; + dim3 dim_grid(states.size(0), states.size(1) / dim_block); + + // Retrieve the device index, then set the corresponding device and stream. 
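The carry arithmetic above computes floor(q * x / 2**128) for a 128-bit uniform x, which is how a raw ChaCha20 output is mapped into the range [0, q). Below is a small pure-Python mirror of that arithmetic, checked against exact big-integer math; the helper name mul_hi_128 and the random test loop are illustrative only.

    import random

    MASK = 0xFFFFFFFF

    def mul_hi_128(p, x_low, x_high):
        # Mirror of the kernel: floor(p * x / 2**128) with x = x_high*2**64 + x_low, p < 2**64.
        alpha = (p * x_low) >> 64                   # __umul64hi(p, x_low)
        pl, ph = p & MASK, p >> 32
        xhh, xhl = x_high >> 32, x_high & MASK
        plxhl, plxhh = pl * xhl, pl * xhh
        phxhl, phxhh = ph * xhl, ph * xhh
        carry = ((plxhl & MASK) + (alpha & MASK)) >> 32
        carry = (carry + (plxhl >> 32) + (alpha >> 32)
                 + (phxhl & MASK) + (plxhh & MASK)) >> 32
        return carry + (phxhl >> 32) + (plxhh >> 32) + phxhh

    for _ in range(10000):
        p = random.getrandbits(60)                  # an RNS-modulus-sized q, for example
        x_low, x_high = random.getrandbits(64), random.getrandbits(64)
        x = (x_high << 64) | x_low
        assert mul_hi_128(p, x_low, x_high) == (p * x) >> 128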
+ auto device_id = states.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Prepare the result. + // 16 elements in each state turn into 4 random numbers. + auto result = states.new_empty({states.size(0), states.size(1) * 4}); + + // Fill in the LUT constant memory. + cudaMemcpyToSymbol(Q, q, states.size(0) * sizeof(uint64_t)); + + // Run the cuda kernel. + auto states_acc = states.packed_accessor32(); + auto result_acc = result.packed_accessor32(); + randint_fast_cuda_kernel<<>>(states_acc, result_acc, shift, step); + + return result; +} + + + +//---------------------------------------------------------- +// The normal version. +//---------------------------------------------------------- + + +void randint_cuda(torch::Tensor rand_bytes, uint64_t* q){ + + // rand_bytes has the dim C x N x 16. + dim3 dim_block(BLOCK_SIZE, 4); + dim3 dim_grid(rand_bytes.size(0), rand_bytes.size(1) / dim_block.x); + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = rand_bytes.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Fill in the LUT constant memory. + cudaMemcpyToSymbol(Q, q, rand_bytes.size(0) * sizeof(uint64_t)); + + // Run the cuda kernel. + auto access = rand_bytes.packed_accessor32(); + randint_cuda_kernel<<>>(access); +} diff --git a/liberate/csprng/randround.cpp b/liberate/csprng/randround.cpp new file mode 100644 index 0000000..5af14d5 --- /dev/null +++ b/liberate/csprng/randround.cpp @@ -0,0 +1,32 @@ + +#include +#include + +// Check types. +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LONG(x) TORCH_CHECK(x.dtype() == torch::kInt64, #x, " must be a kInt64 tensor") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_LONG(x) + + +// Forward declaration. +void randround_cuda(torch::Tensor input, torch::Tensor rand_bytes); + +// The main function. +// rand_bytes are N 1D uint64_t tensors. +// Inputs are N 1D double tensors. +// The output will be returned in rand_bytes. +void randround(std::vector inputs, + std::vector rand_bytes) { + + for (auto i=0; i +#include +#include +#include + +#define BLOCK_SIZE 256 + +__global__ void randint_cuda_kernel(torch::PackedTensorAccessor32 input, + torch::PackedTensorAccessor32 rand_bytes){ + + // Where am I? + const int index = blockIdx.x * blockDim.x + threadIdx.x; + + auto coef = input[index]; + auto sign_bit = signbit(coef); + auto abs_coef = fabs(coef); + + auto integ = floor(abs_coef); + auto frac = abs_coef - integ; + int64_t intinteg = static_cast(integ); + + // Convert a double to a signed 64-bit int in round-to-nearest-even mode. + // https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__INTRINSIC__CAST.html#group__CUDA__MATH__INTRINSIC__CAST + constexpr double rounder = static_cast(0x100000000); + int64_t ifrac = __double2ll_rn(frac * rounder); + + // Random round. + // The bool value must be 1 for True. + int64_t round = rand_bytes[index] < ifrac; + + // Round and recover sign. + int64_t sign = (sign_bit)? -1 : 1; + int64_t rounded = sign * (intinteg + round); + + // Put back into the rand_bytes. + rand_bytes[index] = rounded; +} + + +// The wrapper. +void randround_cuda(torch::Tensor input, torch::Tensor rand_bytes){ + + // rand_bytes has the dim C x N x 16. 
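The kernel above rounds a double up with probability equal to its fractional part by comparing a uniform 32-bit word against the fraction scaled to 2**32. A small pure-Python sketch of that rule follows; the function name and the averaging check are illustrative, not taken from the patch.

    import random

    def rand_round(coef: float, rand32: int) -> int:
        # Mirror of the randround kernel for a single double.
        sign = -1 if coef < 0 else 1
        a = abs(coef)
        integ = int(a)                       # floor of the magnitude
        ifrac = round((a - integ) * 2**32)   # fraction scaled to 32 bits
        round_up = 1 if rand32 < ifrac else 0
        return sign * (integ + round_up)

    coef, n = 3.25, 200_000
    mean = sum(rand_round(coef, random.getrandbits(32)) for _ in range(n)) / n
    print(mean)   # about 3.25: rounds to 4 with probability 0.25, else to 3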
+ const int dim_block = BLOCK_SIZE; + const int dim_grid = rand_bytes.size(0) / dim_block; + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = rand_bytes.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // Run the cuda kernel. + auto input_access = input.packed_accessor32(); + auto rand_bytes_access = rand_bytes.packed_accessor32(); + randint_cuda_kernel<<>>(input_access, rand_bytes_access); +} diff --git a/liberate/csprng/setup.py b/liberate/csprng/setup.py new file mode 100644 index 0000000..13b1c92 --- /dev/null +++ b/liberate/csprng/setup.py @@ -0,0 +1,33 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +ext_modules = [ + CUDAExtension( + name="randint_cuda", + sources=["randint.cpp", "randint_cuda_kernel.cu"], + ), + CUDAExtension( + name="randround_cuda", + sources=["randround.cpp", "randround_cuda_kernel.cu"], + ), + CUDAExtension( + name="discrete_gaussian_cuda", + sources=["discrete_gaussian.cpp", "discrete_gaussian_cuda_kernel.cu"], + ), + CUDAExtension( + name="chacha20_cuda", + sources=["chacha20.cpp", "chacha20_cuda_kernel.cu"], + ), +] + +setup( + name="csprng", + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, + script_args=["build_ext"], + options={ + "build_ext": { + "inplace": True, + } + }, +) diff --git a/liberate/fhe/__init__.py b/liberate/fhe/__init__.py new file mode 100644 index 0000000..3ab2a1f --- /dev/null +++ b/liberate/fhe/__init__.py @@ -0,0 +1,4 @@ +from . import context, encdec +from .ckks_engine import ckks_engine +from .presets import params +from .cache import cache diff --git a/liberate/fhe/cache/__init__.py b/liberate/fhe/cache/__init__.py new file mode 100644 index 0000000..6f702bc --- /dev/null +++ b/liberate/fhe/cache/__init__.py @@ -0,0 +1 @@ +from . import cache diff --git a/liberate/fhe/cache/cache.py b/liberate/fhe/cache/cache.py new file mode 100644 index 0000000..1092baa --- /dev/null +++ b/liberate/fhe/cache/cache.py @@ -0,0 +1,32 @@ +import glob +import os +from ..context import generate_primes + +path_cache = os.path.abspath(__file__).replace("cache.py", "resources") + + +# logN_N_M = os.path.join(path_cache, "logN_N_M.pkl") +# message_special_primes = os.path.join(path_cache, "message_special_primes.pkl") +# scale_primes = os.path.join(path_cache, "scale_primes.pkl") + + +def clean_cache(path=None): + if path is None: + path = path_cache + files = glob.glob(os.path.join(path, "*.pkl")) + for file in files: + try: + os.unlink(file) + except Exception as e: + print(e) + pass + return + + +def generate_cache(path=None): + if path is None: + path = path_cache + # Read in pre-calculated high-quality primes. 
+    _ = generate_primes.generate_message_primes(cache_folder=path)
+    _ = generate_primes.generate_scale_primes(cache_folder=path)
+    return
diff --git a/liberate/fhe/cache/resources/logN_N_M.pkl b/liberate/fhe/cache/resources/logN_N_M.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..444186d3f1962b481f18c13b8f0f9f0c74872833
GIT binary patch
literal 101
[base85-encoded binary data omitted]

diff --git a/liberate/fhe/cache/resources/message_special_primes.pkl b/liberate/fhe/cache/resources/message_special_primes.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..3e691fa6c9caff0bbd487de6f96e50267d5afb33
GIT binary patch
literal 1110
[base85-encoded binary data omitted]

diff --git a/liberate/fhe/cache/resources/scale_primes.pkl b/liberate/fhe/cache/resources/scale_primes.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..88edd0186689fad1d44c127d5f1ec84823dda7eb
GIT binary patch
literal 141973
[base85-encoded binary data omitted; patch truncated here]
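For orientation, a rough usage sketch of the csprng package introduced above, assuming the CUDA extensions have been built in place with the bundled setup.py and at least one CUDA device is present; the argument choices (amax=3, shift=-1, repeats=1) are one plausible way to obtain ternary values and are assumptions, not taken from the patch.

    # Build the extensions in place first (setup.py already passes
    # "build_ext" with inplace=True):
    #     cd liberate/csprng && python setup.py

    from liberate.csprng import Csprng

    rng = Csprng(num_coefs=2**15, num_channels=[8], num_repeating_channels=2,
                 sigma=3.2, devices=["cuda:0"])

    # Uniform values in {-1, 0, 1} drawn in one repeating channel.
    ternary = rng.randint(amax=3, shift=-1, repeats=1)

    # Discrete Gaussian noise (sigma = 3.2) in one repeating channel.
    noise = rng.discrete_gaussian(non_repeats=0, repeats=1)

    print(ternary[0].shape, noise[0].shape)   # one tensor per configured device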
ztaDoD_z$f&xk~eV)^)A(TeoFh*t)iLmG{t#kROATb@M-?<4Ej0)*TkqHO)_1_W;7#gM)=jK)TetLP&D(gP)~j0Aw614e*t%VFW&>~E^aa{?Y~Ah~X0m+~>kc1J e$HRD=LvPW(HS0#!Ijjp_qj_Dw>AP~>#s3cl5Ih_J literal 0 HcmV?d00001 diff --git a/liberate/fhe/ckks_engine.py b/liberate/fhe/ckks_engine.py new file mode 100644 index 0000000..7753ae0 --- /dev/null +++ b/liberate/fhe/ckks_engine.py @@ -0,0 +1,2615 @@ +import datetime +import gc +import math +import pickle +from hashlib import sha256 +from pathlib import Path + +import numpy as np +import torch + +# from context.ckks_context import ckks_context +from .context.ckks_context import ckks_context +from .data_struct import data_struct +from .encdec import decode, encode, rotate, conjugate +from .version import VERSION +from .presets import types, errors +from liberate.ntt import ntt_context +from liberate.ntt import ntt_cuda +from liberate.csprng import Csprng + + +class ckks_engine: + @errors.log_error + def __init__(self, devices: list[int] = None, verbose: bool = False, + bias_guard: bool = True, norm: str = 'forward', **ctx_params): + """ + buffer_bit_length=62, + scale_bits=40, + logN=15, + num_scales=None, + num_special_primes=2, + sigma=3.2, + uniform_tenary_secret=True, + cache_folder='cache/', + security_bits=128, + quantum='post_quantum', + distribution='uniform', + read_cache=True, + save_cache=True, + verbose=False + """ + + self.bias_guard = bias_guard + + self.norm = norm + + self.version = VERSION + + self.ctx = ckks_context(**ctx_params) + self.ntt = ntt_context(self.ctx, devices=devices, verbose=verbose) + + if self.bias_guard: + if self.ctx.num_special_primes < 2: + raise errors.NotEnoughPrimesForBiasGuard( + bias_guard=self.bias_guard, + num_special_primes=self.ctx.num_special_primes) + + self.num_levels = self.ntt.num_levels - 1 + + self.num_slots = self.ctx.N // 2 + + rng_repeats = max(self.ntt.num_special_primes, 2) + self.rng = Csprng(self.ntt.ctx.N, [len(di) for di in self.ntt.p.d], rng_repeats, devices=self.ntt.devices) + + self.int_scale = 2 ** self.ctx.scale_bits + self.scale = np.float64(self.int_scale) + + qstr = ','.join([str(qi) for qi in self.ctx.q]) + hashstr = (self.ctx.generation_string + "_" + qstr).encode("utf-8") + self.hash = sha256(bytes(hashstr)).hexdigest() + + self.make_adjustments_and_corrections() + + self.device0 = self.ntt.devices[0] + + self.make_mont_PR() + + self.reserve_ksk_buffers() + + self.create_ksk_rescales() + + self.alloc_parts() + + self.leveled_devices() + + self.create_rescale_scales() + + self.galois_deltas = [2 ** i for i in range(self.ctx.logN - 1)] + + self.mult_dispatch_dict = { + (data_struct, data_struct): self.auto_cc_mult, + (list, data_struct): self.mc_mult, + (np.ndarray, data_struct): self.mc_mult, + (data_struct, np.ndarray): self.cm_mult, + (data_struct, list): self.cm_mult, + (float, data_struct): self.scalar_mult, + (data_struct, float): self.mult_scalar, + (int, data_struct): self.int_scalar_mult, + (data_struct, int): self.mult_int_scalar + } + + self.add_dispatch_dict = { + (data_struct, data_struct): self.auto_cc_add, + (list, data_struct): self.mc_add, + (np.ndarray, data_struct): self.mc_add, + (data_struct, np.ndarray): self.cm_add, + (data_struct, list): self.cm_add, + (float, data_struct): self.scalar_add, + (data_struct, float): self.add_scalar, + (int, data_struct): self.scalar_add, + (data_struct, int): self.add_scalar + } + + self.sub_dispatch_dict = { + (data_struct, data_struct): self.auto_cc_sub, + (list, data_struct): self.mc_sub, + (np.ndarray, data_struct): self.mc_sub, + (data_struct, 
np.ndarray): self.cm_sub, + (data_struct, list): self.cm_sub, + (float, data_struct): self.scalar_sub, + (data_struct, float): self.sub_scalar, + (int, data_struct): self.scalar_sub, + (data_struct, int): self.sub_scalar + } + + # ------------------------------------------------------------------------------------------- + # Various pre-calculations. + # ------------------------------------------------------------------------------------------- + def create_rescale_scales(self): + self.rescale_scales = [] + for level in range(self.num_levels): + self.rescale_scales.append([]) + + for device_id in range(self.ntt.num_devices): + dest_level = self.ntt.p.destination_arrays[level] + + if device_id < len(dest_level): + dest = dest_level[device_id] + rescaler_device_id = self.ntt.p.rescaler_loc[level] + m0 = self.ctx.q[level] + + if rescaler_device_id == device_id: + m = [self.ctx.q[i] for i in dest[1:]] + else: + m = [self.ctx.q[i] for i in dest] + + scales = [(pow(m0, -1, mi) * self.ctx.R) % mi for mi in m] + + scales = torch.tensor(scales, + dtype=self.ctx.torch_dtype, + device=self.ntt.devices[device_id]) + self.rescale_scales[level].append(scales) + + def leveled_devices(self): + self.len_devices = [] + for level in range(self.num_levels): + self.len_devices.append(len([[a] for a in self.ntt.p.p[level] if len(a) > 0])) + + self.neighbor_devices = [] + for level in range(self.num_levels): + self.neighbor_devices.append([]) + len_devices_at = self.len_devices[level] + available_devices_ids = range(len_devices_at) + for src_device_id in available_devices_ids: + neighbor_devices_at = [ + device_id for device_id in available_devices_ids if device_id != src_device_id + ] + self.neighbor_devices[level].append(neighbor_devices_at) + + def alloc_parts(self): + self.parts_alloc = [] + for level in range(self.num_levels): + num_parts = [len(parts) for parts in self.ntt.p.p[level]] + parts_alloc = [ + alloc[-num_parts[di] - 1:-1] for di, alloc in enumerate(self.ntt.p.part_allocations) + ] + self.parts_alloc.append(parts_alloc) + + self.stor_ids = [] + for level in range(self.num_levels): + self.stor_ids.append([]) + alloc = self.parts_alloc[level] + min_id = min([min(a) for a in alloc if len(a) > 0]) + for device_id in range(self.ntt.num_devices): + global_ids = self.parts_alloc[level][device_id] + new_ids = [i - min_id for i in global_ids] + self.stor_ids[level].append(new_ids) + + def create_ksk_rescales(self): + R = self.ctx.R + P = self.ctx.q[-self.ntt.num_special_primes:][::-1] + m = self.ctx.q + PiR = [[(pow(Pj, -1, mi) * R) % mi for mi in m[:-P_ind - 1]] for P_ind, Pj in enumerate(P)] + + self.PiRs = [] + + level = 0 + self.PiRs.append([]) + + for P_ind in range(self.ntt.num_special_primes): + self.PiRs[level].append([]) + + for device_id in range(self.ntt.num_devices): + dest = self.ntt.p.destination_arrays_with_special[level][device_id] + PiRi = [PiR[P_ind][i] for i in dest[:-P_ind - 1]] + PiRi = torch.tensor(PiRi, + device=self.ntt.devices[device_id], + dtype=self.ctx.torch_dtype) + self.PiRs[level][P_ind].append(PiRi) + + for level in range(1, self.num_levels): + self.PiRs.append([]) + + for P_ind in range(self.ntt.num_special_primes): + + self.PiRs[level].append([]) + + for device_id in range(self.ntt.num_devices): + start = self.ntt.starts[level][device_id] + PiRi = self.PiRs[0][P_ind][device_id][start:] + + self.PiRs[level][P_ind].append(PiRi) + + def reserve_ksk_buffers(self): + self.ksk_buffers = [] + for device_id in range(self.ntt.num_devices): + self.ksk_buffers.append([]) + for 
part_id in range(len(self.ntt.p.p[0][device_id])): + buffer = torch.empty( + [self.ntt.num_special_primes, self.ctx.N], + dtype=self.ctx.torch_dtype + ).pin_memory() + self.ksk_buffers[device_id].append(buffer) + + def make_mont_PR(self): + P = math.prod(self.ntt.ctx.q[-self.ntt.num_special_primes:]) + R = self.ctx.R + PR = P * R + self.mont_PR = [] + for device_id in range(self.ntt.num_devices): + dest = self.ntt.p.destination_arrays[0][device_id] + m = [self.ctx.q[i] for i in dest] + PRm = [PR % mi for mi in m] + PRm = torch.tensor(PRm, + device=self.ntt.devices[device_id], + dtype=self.ctx.torch_dtype) + self.mont_PR.append(PRm) + + def make_adjustments_and_corrections(self): + + self.alpha = [(self.scale / np.float64(q)) ** 2 for q in self.ctx.q[:self.ctx.num_scales]] + self.deviations = [1] + for al in self.alpha: + self.deviations.append(self.deviations[-1] ** 2 * al) + + self.final_q_ind = [da[0][0] for da in self.ntt.p.destination_arrays[:-1]] + self.final_q = [self.ctx.q[ind] for ind in self.final_q_ind] + self.final_alpha = [(self.scale / np.float64(q)) for q in self.final_q] + self.corrections = [1 / (d * fa) for d, fa in zip(self.deviations, self.final_alpha)] + + self.base_prime = self.ctx.q[self.ntt.p.base_prime_idx] + + self.final_scalar = [] + for qi, q in zip(self.final_q_ind, self.final_q): + scalar = (pow(q, -1, self.base_prime) * self.ctx.R) % self.base_prime + scalar = torch.tensor([scalar], + device=self.ntt.devices[0], + dtype=self.ctx.torch_dtype) + self.final_scalar.append(scalar) + + # ------------------------------------------------------------------------------------------- + # Example generation. + # ------------------------------------------------------------------------------------------- + + def absmax_error(self, x, y): + if type(x[0]) == np.complex128 and type(y[0]) == np.complex128: + r = np.abs(x.real - y.real).max() + np.abs(x.imag - y.imag).max() * 1j + else: + r = np.abs(np.array(x) - np.array(y)).max() + return r + + def integral_bits_available(self): + base_prime = self.base_prime + max_bits = math.floor(math.log2(base_prime)) + integral_bits = max_bits - self.ctx.scale_bits + return integral_bits + + @errors.log_error + def example(self, amin=None, amax=None, decimal_places: int = 10) -> np.array: + if amin is None: + amin = -(2 ** self.integral_bits_available()) + + if amax is None: + amax = 2 ** self.integral_bits_available() + + base = 10 ** decimal_places + a = np.random.randint(amin * base, amax * base, self.ctx.N // 2) / base + b = np.random.randint(amin * base, amax * base, self.ctx.N // 2) / base + + sample = a + b * 1j + + return sample + + # ------------------------------------------------------------------------------------------- + # Encode/Decode + # ------------------------------------------------------------------------------------------- + + def padding(self, m): + # m = m[:self.num_slots] + try: + m_len = len(m) + padding_result = np.pad(m, (0, self.num_slots - m_len), constant_values=(0, 0)) + except TypeError as e: + m_len = len([m]) + padding_result = np.pad([m], (0, self.num_slots - m_len), constant_values=(0, 0)) + except Exception as e: + raise Exception("[Error] encoding Padding Error.") + return padding_result + + @errors.log_error + def encode(self, m, level: int = 0, padding=True) -> list[torch.Tensor]: + """ + Encode a plain message m, using an encoding function. + Note that the encoded plain text is pre-permuted to yield cyclic rotation. 
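+
+        A minimal usage sketch (eng denotes an engine instance created by
+        the caller; the names here are illustrative, not part of the API):
+
+            m = eng.example(amin=-1, amax=1)  # random complex test message
+            pt = eng.encode(m, level=0)       # one RNS plaintext tensor per device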
+ """ + deviation = self.deviations[level] + if padding: + m = self.padding(m) + encoded = [encode(m, scale=self.scale, rng=self.rng, + device=self.device0, + deviation=deviation, norm=self.norm)] + + pt_buffer = self.ksk_buffers[0][0][0] + pt_buffer.copy_(encoded[-1]) + for dev_id in range(1, self.ntt.num_devices): + encoded.append(pt_buffer.cuda(self.ntt.devices[dev_id])) + return encoded + + @errors.log_error + def decode(self, m, level=0, is_real: bool = False) -> list: + """ + Base prime is located at -1 of the RNS channels in GPU0. + Assuming this is an orginary RNS deinclude_special. + """ + correction = self.corrections[level] + decoded = decode(m[0].squeeze(), scale=self.scale, correction=correction, norm=self.norm) + m = decoded[:self.ctx.N // 2].cpu().numpy() + if is_real: + m = m.real + return m + + # ------------------------------------------------------------------------------------------- + # secret key/public key generation. + # ------------------------------------------------------------------------------------------- + + @errors.log_error + def create_secret_key(self, include_special: bool = True) -> data_struct: + uniform_ternary = self.rng.randint(amax=3, shift=-1, repeats=1) + + mult_type = -2 if include_special else -1 + unsigned_ternary = self.ntt.tile_unsigned(uniform_ternary, lvl=0, mult_type=mult_type) + self.ntt.enter_ntt(unsigned_ternary, 0, mult_type) + + return data_struct( + data=unsigned_ternary, + include_special=include_special, + montgomery_state=True, + ntt_state=True, + origin=types.origins["sk"], + level=0, + hash=self.hash, + version=self.version + ) + + @errors.log_error + def create_public_key(self, sk: data_struct, include_special: bool = False, + a: list[torch.Tensor] = None) -> data_struct: + """ + Generates a public key against the secret key sk. + pk = -a * sk + e = e - a * sk + """ + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if include_special and not sk.include_special: + raise errors.SecretKeyNotIncludeSpecialPrime() + + # Set the mult_type + mult_type = -2 if include_special else -1 + + # Generate errors for the ordinary case. + level = 0 + e = self.rng.discrete_gaussian(repeats=1) + e = self.ntt.tile_unsigned(e, level, mult_type) + + self.ntt.enter_ntt(e, level, mult_type) + repeats = self.ctx.num_special_primes if sk.include_special else 0 + + # Applying mont_mult in the order of 'a', sk will + if a is None: + a = self.rng.randint( + self.ntt.q_prepack[mult_type][level][0], + repeats=repeats + ) + + sa = self.ntt.mont_mult(a, sk.data, 0, mult_type) + pk0 = self.ntt.mont_sub(e, sa, 0, mult_type) + + return data_struct( + data=(pk0, a), + include_special=include_special, + ntt_state=True, + montgomery_state=True, + origin=types.origins["pk"], + level=0, + hash=self.hash, + version=self.version + ) + + # ------------------------------------------------------------------------------------------- + # Encrypt/Decrypt + # ------------------------------------------------------------------------------------------- + + @errors.log_error + def encrypt(self, pt: list[torch.Tensor], pk: data_struct, level: int = 0) -> data_struct: + """ + We again, multiply pt by the scale. + Since pt is already multiplied by the scale, + the multiplied pt no longer can be stored + in a single RNS channel. + That means we are forced to do the multiplication + in full RNS domain. + Note that we allow encryption at + levels other than 0, and that will take care of multiplying + the deviation factors. 
+ """ + if pk.origin != types.origins["pk"]: + raise errors.NotMatchType(origin=pk.origin, to=types.origins["pk"]) + + mult_type = -2 if pk.include_special else -1 + + e0e1 = self.rng.discrete_gaussian(repeats=2) + + e0 = [e[0] for e in e0e1] + e1 = [e[1] for e in e0e1] + + e0_tiled = self.ntt.tile_unsigned(e0, level, mult_type) + e1_tiled = self.ntt.tile_unsigned(e1, level, mult_type) + + pt_tiled = self.ntt.tile_unsigned(pt, level, mult_type) + self.ntt.mont_enter_scale(pt_tiled, level, mult_type) + self.ntt.mont_redc(pt_tiled, level, mult_type) + pte0 = self.ntt.mont_add(pt_tiled, e0_tiled, level, mult_type) + + start = self.ntt.starts[level] + pk0 = [pk.data[0][di][start[di]:] for di in range(self.ntt.num_devices)] + pk1 = [pk.data[1][di][start[di]:] for di in range(self.ntt.num_devices)] + + v = self.rng.randint(amax=2, shift=0, repeats=1) + + v = self.ntt.tile_unsigned(v, level, mult_type) + self.ntt.enter_ntt(v, level, mult_type) + + vpk0 = self.ntt.mont_mult(v, pk0, level, mult_type) + vpk1 = self.ntt.mont_mult(v, pk1, level, mult_type) + + self.ntt.intt_exit(vpk0, level, mult_type) + self.ntt.intt_exit(vpk1, level, mult_type) + + ct0 = self.ntt.mont_add(vpk0, pte0, level, mult_type) + ct1 = self.ntt.mont_add(vpk1, e1_tiled, level, mult_type) + + self.ntt.reduce_2q(ct0, level, mult_type) + self.ntt.reduce_2q(ct1, level, mult_type) + + ct = data_struct( + data=(ct0, ct1), + include_special=mult_type == -2, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash, + version=self.version + ) + + return ct + + def decrypt_triplet(self, ct_mult: data_struct, sk: data_struct) -> list[torch.Tensor]: + if ct_mult.origin != types.origins["ctt"]: + raise errors.NotMatchType(origin=ct_mult.origin, to=types.origins["ctt"]) + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + + if not ct_mult.ntt_state or not ct_mult.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct_mult.origin) + if (not sk.ntt_state) or (not sk.montgomery_state): + raise errors.NotMatchDataStructState(origin=sk.origin) + + level = ct_mult.level + d0 = [ct_mult.data[0][0].clone()] + d1 = [ct_mult.data[1][0]] + d2 = [ct_mult.data[2][0]] + + self.ntt.intt_exit_reduce(d0, level) + + sk_data = [sk.data[0][self.ntt.starts[level][0]:]] + + d1_s = self.ntt.mont_mult(d1, sk_data, level) + + s2 = self.ntt.mont_mult(sk_data, sk_data, level) + d2_s2 = self.ntt.mont_mult(d2, s2, level) + + self.ntt.intt_exit(d1_s, level) + self.ntt.intt_exit(d2_s2, level) + + pt = self.ntt.mont_add(d0, d1_s, level) + pt = self.ntt.mont_add(pt, d2_s2, level) + self.ntt.reduce_2q(pt, level) + + base_at = -self.ctx.num_special_primes - 1 if ct_mult.include_special else -1 + + base = pt[0][base_at][None, :] + scaler = pt[0][0][None, :] + + final_scalar = self.final_scalar[level] + scaled = self.ntt.mont_sub([base], [scaler], -1) + self.ntt.mont_enter_scalar(scaled, [final_scalar], -1) + self.ntt.reduce_2q(scaled, -1) + self.ntt.make_signed(scaled, -1) + return scaled + + def decrypt_double(self, ct: data_struct, sk: data_struct) -> list[torch.Tensor]: + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if ct.ntt_state or ct.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct.origin) + if not sk.ntt_state or not sk.montgomery_state: + raise 
errors.NotMatchDataStructState(origin=sk.origin) + + ct0 = ct.data[0][0] + level = ct.level + sk_data = sk.data[0][self.ntt.starts[level][0]:] + a = ct.data[1][0].clone() + + self.ntt.enter_ntt([a], level) + sa = self.ntt.mont_mult([a], [sk_data], level) + self.ntt.intt_exit(sa, level) + + pt = self.ntt.mont_add([ct0], sa, level) + self.ntt.reduce_2q(pt, level) + + base_at = -self.ctx.num_special_primes - 1 if ct.include_special else -1 + + base = pt[0][base_at][None, :] + scaler = pt[0][0][None, :] + ############################################################################# + + final_scalar = self.final_scalar[level] + scaled = self.ntt.mont_sub([base], [scaler], -1) + self.ntt.mont_enter_scalar(scaled, [final_scalar], -1) + self.ntt.reduce_2q(scaled, -1) + self.ntt.make_signed(scaled, -1) + return scaled + + def decrypt(self, ct: data_struct, sk: data_struct) -> list[torch.Tensor]: + """ + Decrypt the cipher text ct using the secret key sk. + Note that the final rescaling must precede the actual decryption process. + """ + + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + + if ct.origin == types.origins["ctt"]: + pt = self.decrypt_triplet(ct_mult=ct, sk=sk) + elif ct.origin == types.origins["ct"]: + pt = self.decrypt_double(ct=ct, sk=sk) + else: + raise errors.NotMatchType(origin=ct.origin, to=f"{types.origins['ct']} or {types.origins['ctt']}") + + return pt + + # ------------------------------------------------------------------------------------------- + # Key switching. + # ------------------------------------------------------------------------------------------- + + def create_key_switching_key(self, sk_from: data_struct, sk_to: data_struct, a=None) -> data_struct: + """ + Creates a key to switch the key for sk_src to sk_dst. 
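+        In terms of the arguments below, sk_from plays the role of sk_src
+        (the key the data is currently under) and sk_to that of sk_dst
+        (the key to switch to). Both keys are expected in the NTT and
+        Montgomery domains, as returned by create_secret_key.
+
+        A minimal usage sketch (eng and sk_old are assumed to exist):
+
+            sk_new = eng.create_secret_key()
+            ksk = eng.create_key_switching_key(sk_old, sk_new)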
+        """
+
+        if sk_from.origin != types.origins["sk"] or sk_to.origin != types.origins["sk"]:
+            raise errors.NotMatchType(origin="not a secret key", to=types.origins["sk"])
+        if (not sk_from.ntt_state) or (not sk_from.montgomery_state):
+            raise errors.NotMatchDataStructState(origin=sk_from.origin)
+        if (not sk_to.ntt_state) or (not sk_to.montgomery_state):
+            raise errors.NotMatchDataStructState(origin=sk_to.origin)
+
+        level = 0
+
+        stops = self.ntt.stops[-1]
+        Psk_src = [sk_from.data[di][:stops[di]].clone() for di in range(self.ntt.num_devices)]
+
+        self.ntt.mont_enter_scalar(Psk_src, self.mont_PR, level)
+
+        ksk = [[] for _ in range(self.ntt.p.num_partitions + 1)]
+
+        for device_id in range(self.ntt.num_devices):
+            for part_id, part in enumerate(self.ntt.p.p[level][device_id]):
+                global_part_id = self.ntt.p.part_allocations[device_id][part_id]
+
+                crs = a[global_part_id] if a else None
+                pk = self.create_public_key(sk_to, include_special=True, a=crs)
+
+                key = tuple(part)
+                astart = part[0]
+                astop = part[-1] + 1
+                shard = Psk_src[device_id][astart:astop]
+                pk_data = pk.data[0][device_id][astart:astop]
+
+                _2q = self.ntt.parts_pack[device_id][key]['_2q']
+                update_part = ntt_cuda.mont_add([pk_data], [shard], _2q)[0]
+                pk_data.copy_(update_part, non_blocking=True)
+
+                pk_name = f'key switch key part index {global_part_id}'
+                pk = pk._replace(origin=pk_name)
+
+                ksk[global_part_id] = pk
+
+        return data_struct(
+            data=ksk,
+            include_special=True,
+            ntt_state=True,
+            montgomery_state=True,
+            origin=types.origins["ksk"],
+            level=level,
+            hash=self.hash,
+            version=self.version)
+
+    def pre_extend(self, a, device_id, level, part_id, exit_ntt=False):
+        text_part = self.ntt.p.parts[level][device_id][part_id]
+        param_part = self.ntt.p.p[level][device_id][part_id]
+
+        alpha = len(text_part)
+        a_part = a[device_id][text_part[0]:text_part[-1] + 1]
+
+        if exit_ntt:
+            self.ntt.intt_exit_reduce([a_part], level, device_id, part_id)
+
+        state = a_part[0].repeat(alpha, 1)
+        key = tuple(param_part)
+        for i in range(alpha - 1):
+            mont_pack = self.ntt.parts_pack[device_id][param_part[i + 1],]['mont_pack']
+            _2q = self.ntt.parts_pack[device_id][param_part[i + 1],]['_2q']
+            Y_scalar = self.ntt.parts_pack[device_id][key]['Y_scalar'][i][None]
+
+            Y = (a_part[i + 1] - state[i + 1])[None, :]
+
+            ntt_cuda.mont_enter([Y], [Y_scalar], *mont_pack)
+            ntt_cuda.reduce_2q([Y], _2q)
+
+            state[i + 1] = Y
+
+            if i + 2 < alpha:
+                state_key = tuple(param_part[i + 2:])
+                state_mont_pack = self.ntt.parts_pack[device_id][state_key]['mont_pack']
+                state_2q = self.ntt.parts_pack[device_id][state_key]['_2q']
+                L_scalar = self.ntt.parts_pack[device_id][key]['L_scalar'][i]
+                new_state_len = alpha - (i + 2)
+                new_state = Y.repeat(new_state_len, 1)
+                ntt_cuda.mont_enter([new_state], [L_scalar], *state_mont_pack)
+                ntt_cuda.reduce_2q([new_state], state_2q)
+                state[i + 2:] += new_state
+                ntt_cuda.reduce_2q([state[i + 2:]], state_2q)
+
+        return state
+
+    def extend(self, state, device_id, level, part_id, target_device_id=None):
+
+        if target_device_id is None:
+            target_device_id = device_id
+
+        rns_len = len(
+            self.ntt.p.destination_arrays_with_special[level][target_device_id])
+        alpha = len(state)
+
+        extended = state[0].repeat(rns_len, 1)
+        self.ntt.mont_enter([extended], level, target_device_id, -2)
+
+        part = self.ntt.p.p[level][device_id][part_id]
+        key = tuple(part)
+
+        L_enter = self.ntt.parts_pack[device_id][key]['L_enter'][target_device_id]
+
+        start = self.ntt.starts[level][target_device_id]
+
+        for i in range(alpha - 1):
+            Y = 
state[i + 1].repeat(rns_len, 1) + + self.ntt.mont_enter_scalar([Y], [L_enter[i][start:]], level, target_device_id, -2) + extended = self.ntt.mont_add([extended], [Y], level, target_device_id, -2)[0] + + return extended + + def create_switcher(self, a: list[torch.Tensor], ksk: data_struct, level, exit_ntt=False) -> tuple: + ksk_alloc = self.parts_alloc[level] + + len_devices = self.len_devices[level] + neighbor_devices = self.neighbor_devices[level] + + num_parts = sum([len(alloc) for alloc in ksk_alloc]) + part_results = [[[[] for _ in range(len_devices)], [[] for _ in range(len_devices)]] for _ in range(num_parts)] + + states = [[] for _ in range(num_parts)] + for src_device_id in range(len_devices): + for part_id in range(len(self.ntt.p.p[level][src_device_id])): + storage_id = self.stor_ids[level][src_device_id][part_id] + state = self.pre_extend(a, + src_device_id, + level, + part_id, + exit_ntt + ) + states[storage_id] = state + + CPU_states = [[] for _ in range(num_parts)] + for src_device_id in range(len_devices): + for part_id, part in enumerate(self.ntt.p.p[level][src_device_id]): + storage_id = self.stor_ids[level][src_device_id][part_id] + alpha = len(part) + CPU_state = self.ksk_buffers[src_device_id][part_id][:alpha] + CPU_state.copy_(states[storage_id], non_blocking=True) + CPU_states[storage_id] = CPU_state + + for src_device_id in range(len_devices): + for part_id in range(len(self.ntt.p.p[level][src_device_id])): + storage_id = self.stor_ids[level][src_device_id][part_id] + state = states[storage_id] + d0, d1 = self.switcher_later_part(state, ksk, + src_device_id, + src_device_id, + level, part_id) + + part_results[storage_id][0][src_device_id] = d0 + part_results[storage_id][1][src_device_id] = d1 + + CUDA_states = [[] for _ in range(num_parts)] + for src_device_id in range(len_devices): + for j, dst_device_id in enumerate( + neighbor_devices[src_device_id]): + for part_id, part in enumerate(self.ntt.p.p[level][src_device_id]): + storage_id = self.stor_ids[level][src_device_id][part_id] + CPU_state = CPU_states[storage_id] + CUDA_states[storage_id] = CPU_state.cuda(self.ntt.devices[dst_device_id], non_blocking=True) + + torch.cuda.synchronize() + + for src_device_id in range(len_devices): + for j, dst_device_id in enumerate( + neighbor_devices[src_device_id]): + for part_id, part in enumerate(self.ntt.p.p[level][src_device_id]): + storage_id = self.stor_ids[level][src_device_id][part_id] + CUDA_state = CUDA_states[storage_id] + d0, d1 = self.switcher_later_part(CUDA_state, + ksk, + src_device_id, + dst_device_id, + level, + part_id) + part_results[storage_id][0][dst_device_id] = d0 + part_results[storage_id][1][dst_device_id] = d1 + + summed0 = part_results[0][0] + summed1 = part_results[0][1] + + for i in range(1, len(part_results)): + summed0 = self.ntt.mont_add(summed0, part_results[i][0], level, -2) + summed1 = self.ntt.mont_add(summed1, part_results[i][1], level, -2) + + d0 = summed0 + d1 = summed1 + + current_len = [len(d) for d in self.ntt.p.destination_arrays_with_special[level]] + + for P_ind in range(self.ntt.num_special_primes): + current_len = [c - 1 for c in current_len] + + PiRi = self.PiRs[level][P_ind] + + P0 = [d[-1].repeat(current_len[di], 1) for di, d in enumerate(d0)] + P1 = [d[-1].repeat(current_len[di], 1) for di, d in enumerate(d1)] + + d0 = [d0[i][:current_len[i]] - P0[i] for i in range(len_devices)] + d1 = [d1[i][:current_len[i]] - P1[i] for i in range(len_devices)] + + self.ntt.mont_enter_scalar(d0, PiRi, level, -2) + 
self.ntt.mont_enter_scalar(d1, PiRi, level, -2) + + self.ntt.reduce_2q(d0, level, -1) + self.ntt.reduce_2q(d1, level, -1) + + return d0, d1 + + def switcher_later_part(self, + state, ksk, + src_device_id, + dst_device_id, + level, part_id): + + extended = self.extend(state, src_device_id, level, part_id, dst_device_id) + + self.ntt.ntt([extended], level, dst_device_id, -2) + + ksk_loc = self.parts_alloc[level][src_device_id][part_id] + ksk_part_data = ksk.data[ksk_loc].data + + start = self.ntt.starts[level][dst_device_id] + ksk0_data = ksk_part_data[0][dst_device_id][start:] + ksk1_data = ksk_part_data[1][dst_device_id][start:] + + d0 = self.ntt.mont_mult([extended], [ksk0_data], level, dst_device_id, -2) + d1 = self.ntt.mont_mult([extended], [ksk1_data], level, dst_device_id, -2) + + self.ntt.intt_exit_reduce(d0, level, dst_device_id, -2) + self.ntt.intt_exit_reduce(d1, level, dst_device_id, -2) + + return d0[0], d1[0] + + def switch_key(self, ct: data_struct, ksk: data_struct) -> data_struct: + include_special = ct.include_special + ntt_state = ct.ntt_state + montgomery_state = ct.montgomery_state + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + + level = ct.level + a = ct.data[1] + d0, d1 = self.create_switcher(a, ksk, level, exit_ntt=ct.ntt_state) + + new_ct0 = self.ntt.mont_add(ct.data[0], d0, level, -1) + + return data_struct( + data=(new_ct0, d1), + include_special=include_special, + ntt_state=ntt_state, + montgomery_state=montgomery_state, + origin=types.origins["ct"], + level=level, + hash=self.hash + ) + + # ------------------------------------------------------------------------------------------- + # Multiplication. + # ------------------------------------------------------------------------------------------- + + def rescale(self, ct: data_struct, exact_rounding=True) -> data_struct: + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + level = ct.level + next_level = level + 1 + + if next_level >= self.num_levels: + raise errors.MaximumLevelError(level=ct.level, level_max=self.num_levels) + + rescaler_device_id = self.ntt.p.rescaler_loc[level] + neighbor_devices_before = self.neighbor_devices[level] + neighbor_devices_after = self.neighbor_devices[next_level] + len_devices_after = len(neighbor_devices_after) + len_devices_before = len(neighbor_devices_before) + + rescaling_scales = self.rescale_scales[level] + data0 = [[] for _ in range(len_devices_after)] + data1 = [[] for _ in range(len_devices_after)] + + rescaler0 = [[] for _ in range(len_devices_before)] + rescaler1 = [[] for _ in range(len_devices_before)] + + rescaler0_at = ct.data[0][rescaler_device_id][0] + rescaler0[rescaler_device_id] = rescaler0_at + + rescaler1_at = ct.data[1][rescaler_device_id][0] + rescaler1[rescaler_device_id] = rescaler1_at + + if rescaler_device_id < len_devices_after: + data0[rescaler_device_id] = ct.data[0][rescaler_device_id][1:] + data1[rescaler_device_id] = ct.data[1][rescaler_device_id][1:] + + CPU_rescaler0 = self.ksk_buffers[0][0][0] + CPU_rescaler1 = self.ksk_buffers[0][1][0] + + CPU_rescaler0.copy_(rescaler0_at, non_blocking=False) + CPU_rescaler1.copy_(rescaler1_at, non_blocking=False) + + for device_id in neighbor_devices_before[rescaler_device_id]: + device = self.ntt.devices[device_id] + CUDA_rescaler0 = CPU_rescaler0.cuda(device) + CUDA_rescaler1 = CPU_rescaler1.cuda(device) + + rescaler0[device_id] = CUDA_rescaler0 + rescaler1[device_id] = CUDA_rescaler1 + 
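+            # Devices other than the rescaler device keep all of their channels for the
+            # next level; only the rescaler device dropped its first (rescaler) channel above.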
+            if device_id < len_devices_after:
+                data0[device_id] = ct.data[0][device_id]
+                data1[device_id] = ct.data[1][device_id]
+
+        if exact_rounding:
+            rescale_channel_prime_id = self.ntt.p.destination_arrays[level][rescaler_device_id][0]
+
+            round_at = self.ctx.q[rescale_channel_prime_id] // 2
+
+            rounder0 = [[] for _ in range(len_devices_before)]
+            rounder1 = [[] for _ in range(len_devices_before)]
+
+            for device_id in range(len_devices_after):
+                rounder0[device_id] = torch.where(rescaler0[device_id] > round_at, 1, 0)
+                rounder1[device_id] = torch.where(rescaler1[device_id] > round_at, 1, 0)
+
+        data0 = [(d - s) for d, s in zip(data0, rescaler0)]
+        data1 = [(d - s) for d, s in zip(data1, rescaler1)]
+
+        self.ntt.mont_enter_scalar(data0, self.rescale_scales[level], next_level)
+
+        self.ntt.mont_enter_scalar(data1, self.rescale_scales[level], next_level)
+
+        if exact_rounding:
+            data0 = [(d + r) for d, r in zip(data0, rounder0)]
+            data1 = [(d + r) for d, r in zip(data1, rounder1)]
+
+        self.ntt.reduce_2q(data0, next_level)
+        self.ntt.reduce_2q(data1, next_level)
+
+        return data_struct(
+            data=(data0, data1),
+            include_special=False,
+            ntt_state=False,
+            montgomery_state=False,
+            origin=types.origins["ct"],
+            level=next_level,
+            hash=self.hash,
+            version=self.version
+        )
+
+    def create_evk(self, sk: data_struct) -> data_struct:
+        if sk.origin != types.origins["sk"]:
+            raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"])
+
+        sk2_data = self.ntt.mont_mult(sk.data, sk.data, 0, -2)
+        sk2 = data_struct(
+            data=sk2_data,
+            include_special=True,
+            ntt_state=True,
+            montgomery_state=True,
+            origin=types.origins["sk"],
+            level=sk.level,
+            hash=self.hash,
+            version=self.version
+        )
+
+        return self.create_key_switching_key(sk2, sk)
+
+    def cc_mult(self, a: data_struct, b: data_struct, evk: data_struct, relin=True) -> data_struct:
+        if a.origin != types.origins["ct"]:
+            raise errors.NotMatchType(origin=a.origin, to=types.origins["ct"])
+        if b.origin != types.origins["ct"]:
+            raise errors.NotMatchType(origin=b.origin, to=types.origins["ct"])
+        # Rescale.
+        x = self.rescale(a)
+        y = self.rescale(b)
+
+        level = x.level
+
+        # Multiply.
+        x0 = x.data[0]
+        x1 = x.data[1]
+
+        y0 = y.data[0]
+        y1 = y.data[1]
+
+        self.ntt.enter_ntt(x0, level)
+        self.ntt.enter_ntt(x1, level)
+        self.ntt.enter_ntt(y0, level)
+        self.ntt.enter_ntt(y1, level)
+
+        d0 = self.ntt.mont_mult(x0, y0, level)
+
+        x0y1 = self.ntt.mont_mult(x0, y1, level)
+        x1y0 = self.ntt.mont_mult(x1, y0, level)
+        d1 = self.ntt.mont_add(x0y1, x1y0, level)
+
+        d2 = self.ntt.mont_mult(x1, y1, level)
+
+        ct_mult = data_struct(
+            data=(d0, d1, d2),
+            include_special=False,
+            ntt_state=True,
+            montgomery_state=True,
+            origin=types.origins["ctt"],
+            level=level,
+            hash=self.hash,
+            version=self.version
+        )
+        if relin:
+            ct_mult = self.relinearize(ct_triplet=ct_mult, evk=evk)
+
+        return ct_mult
+
+    def relinearize(self, ct_triplet: data_struct, evk: data_struct) -> data_struct:
+        if ct_triplet.origin != types.origins["ctt"]:
+            raise errors.NotMatchType(origin=ct_triplet.origin, to=types.origins["ctt"])
+        if not ct_triplet.ntt_state or not ct_triplet.montgomery_state:
+            raise errors.NotMatchDataStructState(origin=ct_triplet.origin)
+
+        d0, d1, d2 = ct_triplet.data
+        level = ct_triplet.level
+
+        # intt.
+        self.ntt.intt_exit_reduce(d0, level)
+        self.ntt.intt_exit_reduce(d1, level)
+        self.ntt.intt_exit_reduce(d2, level)
+
+        # Key switch the x1y1.
+        d2_0, d2_1 = self.create_switcher(d2, evk, level)
+
+        # Add the switcher to d0, d1.
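+        # d2 carries the s^2 component of the product; create_switcher re-expresses it
+        # under the original secret key via the evaluation key, so adding (d2_0, d2_1)
+        # back to (d0, d1) yields an ordinary two-part ciphertext.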
+ d0 = [p + q for p, q in zip(d0, d2_0)] + d1 = [p + q for p, q in zip(d1, d2_1)] + + # Final reduction. + self.ntt.reduce_2q(d0, level) + self.ntt.reduce_2q(d1, level) + + # Compose and return. + return data_struct( + data=(d0, d1), + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash + ) + + # ------------------------------------------------------------------------------------------- + # Rotation. + # ------------------------------------------------------------------------------------------- + + def create_rotation_key(self, sk: data_struct, delta: int, a: list[torch.Tensor] = None) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + + sk_new_data = [s.clone() for s in sk.data] + self.ntt.intt(sk_new_data) + sk_new_data = [rotate(s, delta) for s in sk_new_data] + self.ntt.ntt(sk_new_data) + sk_rotated = data_struct( + data=sk_new_data, + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["sk"], + level=0, + hash=self.hash, + version=self.version + ) + + rotk = self.create_key_switching_key(sk_rotated, sk, a=a) + rotk = rotk._replace(origin=types.origins["rotk"] + f"{delta}") + return rotk + + def rotate_single(self, ct: data_struct, rotk: data_struct) -> data_struct: + + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if types.origins["rotk"] not in rotk.origin: + raise errors.NotMatchType(origin=rotk.origin, to=types.origins["rotk"]) + + level = ct.level + include_special = ct.include_special + ntt_state = ct.ntt_state + montgomery_state = ct.montgomery_state + origin = rotk.origin + delta = int(origin.split(':')[-1]) + + rotated_ct_data = [[rotate(d, delta) for d in ct_data] for ct_data in ct.data] + + rotated_ct_rotated_sk = data_struct( + data=rotated_ct_data, + include_special=include_special, + ntt_state=ntt_state, + montgomery_state=montgomery_state, + origin=types.origins["ct"], + level=level, + hash=self.hash, + version=self.version + ) + + rotated_ct = self.switch_key(rotated_ct_rotated_sk, rotk) + return rotated_ct + + def create_galois_key(self, sk: data_struct) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + + galois_key_parts = [self.create_rotation_key(sk, delta) for delta in self.galois_deltas] + + galois_key = data_struct( + data=galois_key_parts, + include_special=True, + ntt_state=True, + montgomery_state=True, + origin=types.origins["galk"], + level=0, + hash=self.hash, + version=self.version + ) + return galois_key + + def rotate_galois(self, ct: data_struct, gk: data_struct, delta: int, return_circuit=False) -> data_struct: + + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if gk.origin != types.origins["galk"]: + raise errors.NotMatchType(origin=gk.origin, to=types.origins["galk"]) + + current_delta = delta % (self.ctx.N // 2) + galois_circuit = [] + + while current_delta: + galois_ind = int(math.log2(current_delta)) + galois_delta = self.galois_deltas[galois_ind] + galois_circuit.append(galois_ind) + current_delta -= galois_delta + + if len(galois_circuit) > 0: + rotated_ct = self.rotate_single(ct, gk.data[galois_circuit[0]]) + + for delta_ind in galois_circuit[1:]: + rotated_ct = self.rotate_single(rotated_ct, gk.data[delta_ind]) + elif len(galois_circuit) == 0: + rotated_ct = 
ct + else: + pass + + if return_circuit: + return rotated_ct, galois_circuit + else: + return rotated_ct + + # ------------------------------------------------------------------------------------------- + # Add/sub. + # ------------------------------------------------------------------------------------------- + def cc_add_double(self, a: data_struct, b: data_struct) -> data_struct: + if a.origin != types.origins["ct"] or b.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=f"{a.origin} and {b.origin}", to=types.origins["ct"]) + if a.ntt_state or a.montgomery_state: + raise errors.NotMatchDataStructState(origin=a.origin) + if b.ntt_state or b.montgomery_state: + raise errors.NotMatchDataStructState(origin=b.origin) + + level = a.level + data = [] + c0 = self.ntt.mont_add(a.data[0], b.data[0], level) + c1 = self.ntt.mont_add(a.data[1], b.data[1], level) + self.ntt.reduce_2q(c0, level) + self.ntt.reduce_2q(c1, level) + data.extend([c0, c1]) + + return data_struct( + data=data, + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash + ) + + def cc_add_triplet(self, a: data_struct, b: data_struct) -> data_struct: + if a.origin != types.origins["ctt"] or b.origin != types.origins["ctt"]: + raise errors.NotMatchType(origin=f"{a.origin} and {b.origin}", to=types.origins["ctt"]) + if not a.ntt_state or not a.montgomery_state: + raise errors.NotMatchDataStructState(origin=a.origin) + if not b.ntt_state or not b.montgomery_state: + raise errors.NotMatchDataStructState(origin=b.origin) + + level = a.level + data = [] + c0 = self.ntt.mont_add(a.data[0], b.data[0], level) + c1 = self.ntt.mont_add(a.data[1], b.data[1], level) + self.ntt.reduce_2q(c0, level) + self.ntt.reduce_2q(c1, level) + data.extend([c0, c1]) + c2 = self.ntt.mont_add(a.data[2], b.data[2], level) + self.ntt.reduce_2q(c2, level) + data.append(c2) + + return data_struct( + data=data, + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["ctt"], + level=level, + hash=self.hash, + version=self.version + ) + + def cc_add(self, a: data_struct, b: data_struct) -> data_struct: + + if a.origin == types.origins["ct"] and b.origin == types.origins["ct"]: + ct_add = self.cc_add_double(a, b) + elif a.origin == types.origins["ctt"] and b.origin == types.origins["ctt"]: + ct_add = self.cc_add_triplet(a, b) + else: + raise errors.DifferentTypeError(a=a.origin, b=b.origin) + + return ct_add + + def cc_sub_double(self, a: data_struct, b: data_struct) -> data_struct: + if a.origin != types.origins["ct"] or b.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=f"{a.origin} and {b.origin}", to=types.origins["ct"]) + if a.ntt_state or a.montgomery_state: + raise errors.NotMatchDataStructState(origin=a.origin) + if b.ntt_state or b.montgomery_state: + raise errors.NotMatchDataStructState(origin=b.origin) + + level = a.level + data = [] + + c0 = self.ntt.mont_sub(a.data[0], b.data[0], level) + c1 = self.ntt.mont_sub(a.data[1], b.data[1], level) + self.ntt.reduce_2q(c0, level) + self.ntt.reduce_2q(c1, level) + data.extend([c0, c1]) + + return data_struct( + data=data, + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash, + version=self.version + ) + + def cc_sub_triplet(self, a: data_struct, b: data_struct) -> data_struct: + if a.origin != types.origins["ctt"] or b.origin != types.origins["ctt"]: + raise 
errors.NotMatchType(origin=f"{a.origin} and {b.origin}", to=types.origins["ctt"]) + if not a.ntt_state or not a.montgomery_state: + raise errors.NotMatchDataStructState(origin=a.origin) + if not b.ntt_state or not b.montgomery_state: + raise errors.NotMatchDataStructState(origin=b.origin) + + level = a.level + data = [] + c0 = self.ntt.mont_sub(a.data[0], b.data[0], level) + c1 = self.ntt.mont_sub(a.data[1], b.data[1], level) + c2 = self.ntt.mont_sub(a.data[2], b.data[2], level) + self.ntt.reduce_2q(c0, level) + self.ntt.reduce_2q(c1, level) + self.ntt.reduce_2q(c2, level) + data.extend([c0, c1, c2]) + + return data_struct( + data=data, + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["ctt"], + level=level, + hash=self.hash, + version=self.version + ) + + def cc_sub(self, a: data_struct, b: data_struct) -> data_struct: + if a.origin != b.origin: + raise Exception(f"[Error] triplet error") + + if types.origins["ct"] == a.origin and types.origins["ct"] == b.origin: + ct_sub = self.cc_sub_double(a, b) + elif a.origin == types.origins["ctt"] and b.origin == types.origins["ctt"]: + ct_sub = self.cc_sub_triplet(a, b) + else: + raise errors.DifferentTypeError(a=a.origin, b=b.origin) + return ct_sub + + def cc_subtract(self, a, b): + return self.cc_sub(a, b) + + # ------------------------------------------------------------------------------------------- + # Level up. + # ------------------------------------------------------------------------------------------- + def level_up(self, ct: data_struct, dst_level: int): + if types.origins["ct"] != ct.origin: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + + current_level = ct.level + + new_ct = self.rescale(ct) + + src_level = current_level + 1 + + dst_len_devices = len(self.ntt.p.destination_arrays[dst_level]) + + diff_deviation = self.deviations[dst_level] / np.sqrt(self.deviations[src_level]) + + deviated_delta = round(self.scale * diff_deviation) + + if dst_level - src_level > 0: + src_rns_lens = [len(d) for d in self.ntt.p.destination_arrays[src_level]] + dst_rns_lens = [len(d) for d in self.ntt.p.destination_arrays[dst_level]] + + diff_rns_lens = [y - x for x, y in zip(dst_rns_lens, src_rns_lens)] + + new_ct_data0 = [] + new_ct_data1 = [] + + for device_id in range(dst_len_devices): + new_ct_data0.append(new_ct.data[0][device_id][diff_rns_lens[device_id]:]) + new_ct_data1.append(new_ct.data[1][device_id][diff_rns_lens[device_id]:]) + else: + new_ct_data0, new_ct_data1 = new_ct.data + + multipliers = [] + for device_id in range(dst_len_devices): + dest = self.ntt.p.destination_arrays[dst_level][device_id] + q = [self.ctx.q[i] for i in dest] + + multiplier = [(deviated_delta * self.ctx.R) % qi for qi in q] + multiplier = torch.tensor(multiplier, dtype=self.ctx.torch_dtype, device=self.ntt.devices[device_id]) + multipliers.append(multiplier) + + self.ntt.mont_enter_scalar(new_ct_data0, multipliers, dst_level) + self.ntt.mont_enter_scalar(new_ct_data1, multipliers, dst_level) + + self.ntt.reduce_2q(new_ct_data0, dst_level) + self.ntt.reduce_2q(new_ct_data1, dst_level) + + new_ct = data_struct( + data=(new_ct_data0, new_ct_data1), + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=dst_level, + hash=self.hash, + version=self.version + ) + + return new_ct + + # ------------------------------------------------------------------------------------------- + # Fused enc/dec. 
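+    # encodecrypt fuses encoding and encryption (with an optional DC bias guard),
+    # and decryptcode fuses decryption and decoding.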
+ # ------------------------------------------------------------------------------------------- + def encodecrypt(self, m, pk: data_struct, level: int = 0, padding=True) -> data_struct: + if pk.origin != types.origins["pk"]: + raise errors.NotMatchType(origin=pk.origin, to=types.origins["pk"]) + + if padding: + m = self.padding(m=m) + + deviation = self.deviations[level] + pt = encode(m, scale=self.scale, + device=self.device0, norm=self.norm, + deviation=deviation, rng=self.rng, + return_without_scaling=self.bias_guard) + + if self.bias_guard: + dc_integral = pt[0].item() // 1 + pt[0] -= dc_integral + + dc_scale = int(dc_integral) * int(self.scale) + dc_rns = [] + for device_id, dest in enumerate(self.ntt.p.destination_arrays[level]): + dci = [dc_scale % self.ctx.q[i] for i in dest] + dci = torch.tensor(dci, + dtype=self.ctx.torch_dtype, + device=self.ntt.devices[device_id]) + dc_rns.append(dci) + + pt *= np.float64(self.scale) + pt = self.rng.randround(pt) + + encoded = [pt] + + pt_buffer = self.ksk_buffers[0][0][0] + pt_buffer.copy_(encoded[-1]) + for dev_id in range(1, self.ntt.num_devices): + encoded.append(pt_buffer.cuda(self.ntt.devices[dev_id])) + + mult_type = -2 if pk.include_special else -1 + + e0e1 = self.rng.discrete_gaussian(repeats=2) + + e0 = [e[0] for e in e0e1] + e1 = [e[1] for e in e0e1] + + e0_tiled = self.ntt.tile_unsigned(e0, level, mult_type) + e1_tiled = self.ntt.tile_unsigned(e1, level, mult_type) + + pt_tiled = self.ntt.tile_unsigned(encoded, level, mult_type) + + if self.bias_guard: + for device_id, pti in enumerate(pt_tiled): + pti[:, 0] += dc_rns[device_id] + + self.ntt.mont_enter_scale(pt_tiled, level, mult_type) + self.ntt.mont_redc(pt_tiled, level, mult_type) + pte0 = self.ntt.mont_add(pt_tiled, e0_tiled, level, mult_type) + + start = self.ntt.starts[level] + pk0 = [pk.data[0][di][start[di]:] for di in range(self.ntt.num_devices)] + pk1 = [pk.data[1][di][start[di]:] for di in range(self.ntt.num_devices)] + + v = self.rng.randint(amax=2, shift=0, repeats=1) + + v = self.ntt.tile_unsigned(v, level, mult_type) + self.ntt.enter_ntt(v, level, mult_type) + + vpk0 = self.ntt.mont_mult(v, pk0, level, mult_type) + vpk1 = self.ntt.mont_mult(v, pk1, level, mult_type) + + self.ntt.intt_exit(vpk0, level, mult_type) + self.ntt.intt_exit(vpk1, level, mult_type) + + ct0 = self.ntt.mont_add(vpk0, pte0, level, mult_type) + ct1 = self.ntt.mont_add(vpk1, e1_tiled, level, mult_type) + + self.ntt.reduce_2q(ct0, level, mult_type) + self.ntt.reduce_2q(ct1, level, mult_type) + + ct = data_struct( + data=(ct0, ct1), + include_special=mult_type == -2, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash, + version=self.version + ) + + return ct + + def decryptcode(self, ct: data_struct, sk: data_struct, is_real=False) -> data_struct: + if (not sk.ntt_state) or (not sk.montgomery_state): + raise errors.NotMatchDataStructState(origin=sk.origin) + + level = ct.level + sk_data = sk.data[0][self.ntt.starts[level][0]:] + + if ct.origin == types.origins["ct"]: + if ct.ntt_state or ct.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct.origin) + + ct0 = ct.data[0][0] + a = ct.data[1][0].clone() + + self.ntt.enter_ntt([a], level) + + sa = self.ntt.mont_mult([a], [sk_data], level) + self.ntt.intt_exit(sa, level) + + pt = self.ntt.mont_add([ct0], sa, level) + self.ntt.reduce_2q(pt, level) + + elif ct.origin == types.origins["ctt"]: + if not ct.ntt_state or not ct.montgomery_state: + raise 
errors.NotMatchDataStructState(origin=ct.origin) + + d0 = [ct.data[0][0].clone()] + d1 = [ct.data[1][0]] + d2 = [ct.data[2][0]] + + self.ntt.intt_exit_reduce(d0, level) + + sk_data = [sk.data[0][self.ntt.starts[level][0]:]] + + d1_s = self.ntt.mont_mult(d1, sk_data, level) + + s2 = self.ntt.mont_mult(sk_data, sk_data, level) + d2_s2 = self.ntt.mont_mult(d2, s2, level) + + self.ntt.intt_exit(d1_s, level) + self.ntt.intt_exit(d2_s2, level) + + pt = self.ntt.mont_add(d0, d1_s, level) + pt = self.ntt.mont_add(pt, d2_s2, level) + self.ntt.reduce_2q(pt, level) + else: + raise errors.NotMatchType(origin=ct.origin, to=f"{types.origins['ct']} or {types.origins['ctt']}") + + base_at = -self.ctx.num_special_primes - 1 if ct.include_special else -1 + base = pt[0][base_at][None, :] + scaler = pt[0][0][None, :] + + len_left = len(self.ntt.p.destination_arrays[level][0]) + + if (len_left >= 3) and self.bias_guard: + dc0 = base[0][0].item() + dc1 = scaler[0][0].item() + dc2 = pt[0][1][0].item() + + base[0][0] = 0 + scaler[0][0] = 0 + + q0_ind = self.ntt.p.destination_arrays[level][0][base_at] + q1_ind = self.ntt.p.destination_arrays[level][0][0] + q2_ind = self.ntt.p.destination_arrays[level][0][1] + + q0 = self.ctx.q[q0_ind] + q1 = self.ctx.q[q1_ind] + q2 = self.ctx.q[q2_ind] + + Q = q0 * q1 * q2 + Q0 = q1 * q2 + Q1 = q0 * q2 + Q2 = q0 * q1 + + Qi0 = pow(Q0, -1, q0) + Qi1 = pow(Q1, -1, q1) + Qi2 = pow(Q2, -1, q2) + + dc = (dc0 * Qi0 * Q0 + dc1 * Qi1 * Q1 + dc2 * Qi2 * Q2) % Q + + half_Q = Q // 2 + dc = dc if dc <= half_Q else dc - Q + + dc = (dc + (q1 - 1)) // q1 + + final_scalar = self.final_scalar[level] + scaled = self.ntt.mont_sub([base], [scaler], -1) + self.ntt.mont_enter_scalar(scaled, [final_scalar], -1) + self.ntt.reduce_2q(scaled, -1) + self.ntt.make_signed(scaled, -1) + + # Decoding. + correction = self.corrections[level] + decoded = decode( + scaled[0][-1], + scale=self.scale, + correction=correction, + norm=self.norm, + return_without_scaling=self.bias_guard + ) + decoded = decoded[:self.ctx.N // 2].cpu().numpy() + ## + + decoded = decoded / self.scale * correction + + # Bias guard. + if (len_left >= 3) and self.bias_guard: + decoded += dc / self.scale * correction + if is_real: + decoded = decoded.real + return decoded + + # Shortcuts. 
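+    # encorypt and decrode below are thin aliases for encodecrypt and decryptcode.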
+ def encorypt(self, m, pk: data_struct, level: int = 0, padding=True): + return self.encodecrypt(m, pk=pk, level=level, padding=padding) + + def decrode(self, ct: data_struct, sk: data_struct, is_real=False): + return self.decryptcode(ct=ct, sk=sk, is_real=is_real) + + # ------------------------------------------------------------------------------------------- + # Conjugation + # ------------------------------------------------------------------------------------------- + + def create_conjugation_key(self, sk: data_struct) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if (not sk.ntt_state) or (not sk.montgomery_state): + raise errors.NotMatchDataStructState(origin=sk.origin) + + sk_new_data = [s.clone() for s in sk.data] + self.ntt.intt(sk_new_data) + sk_new_data = [conjugate(s) for s in sk_new_data] + self.ntt.ntt(sk_new_data) + sk_rotated = data_struct( + data=sk_new_data, + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["sk"], + level=0, + hash=self.hash, + version=self.version + ) + rotk = self.create_key_switching_key(sk_rotated, sk) + rotk = rotk._replace(origin=types.origins["conjk"]) + return rotk + + def conjugate(self, ct: data_struct, conjk: data_struct): + level = ct.level + conj_ct_data = [[conjugate(d) for d in ct_data] for ct_data in ct.data] + + conj_ct_sk = data_struct( + data=conj_ct_data, + include_special=False, + ntt_state=False, + montgomery_state=False, + origin=types.origins["ct"], + level=level, + hash=self.hash, + version=self.version + ) + + conj_ct = self.switch_key(conj_ct_sk, conjk) + return conj_ct + + # ------------------------------------------------------------------------------------------- + # Clone. + # ------------------------------------------------------------------------------------------- + + def clone_tensors(self, data: data_struct) -> data_struct: + new_data = [] + # Some data has 1 depth. + if not isinstance(data[0], list): + for device_data in data: + new_data.append(device_data.clone()) + else: + for part in data: + new_data.append([]) + for device_data in part: + new_data[-1].append(device_data.clone()) + return new_data + + def clone(self, text): + if not isinstance(text.data[0], data_struct): + # data are tensors. + data = self.clone_tensors(text.data) + + wrapper = data_struct( + data=data, + include_special=text.include_special, + ntt_state=text.ntt_state, + montgomery_state=text.montgomery_state, + origin=text.origin, + level=text.level, + hash=text.hash, + version=text.version + ) + + else: + wrapper = data_struct( + data=[], + include_special=text.include_special, + ntt_state=text.ntt_state, + montgomery_state=text.montgomery_state, + origin=text.origin, + level=text.level, + hash=text.hash, + version=text.version + ) + + for d in text.data: + wrapper.data.append(self.clone(d)) + + return wrapper + + # ------------------------------------------------------------------------------------------- + # Move data back and forth from GPUs to the CPU. + # ------------------------------------------------------------------------------------------- + + def download_to_cpu(self, gpu_data, level, include_special): + # Prepare destination arrays. + if include_special: + dest = self.ntt.p.destination_arrays_with_special[level] + else: + dest = self.ntt.p.destination_arrays[level] + + # dest contain indices that are the absolute order of primes. + # Convert them to tensor channel indices at this level. 
+ # That is, force them to start from zero. + min_ind = min([min(d) for d in dest]) + dest = [[di - min_ind for di in d] for d in dest] + + # Tensor size parameters. + num_rows = sum([len(d) for d in dest]) + num_cols = self.ctx.N + cpu_size = (num_rows, num_cols) + + # Make a cpu tensor to aggregate the data in GPUs. + cpu_tensor = torch.empty(cpu_size, dtype=self.ctx.torch_dtype, device='cpu') + + for ten, dest_i in zip(gpu_data, dest): + # Check if the tensor is in the gpu. + if ten.device.type != 'cuda': + raise Exception("To download data to the CPU, it must already be in a GPU!!!") + + # Copy in the data. + cpu_tensor[dest_i] = ten.cpu() + + # To avoid confusion, make a list with a single element (only one device, that is the CPU), + # and return it. + return [cpu_tensor] + + def upload_to_gpu(self, cpu_data, level, include_special): + # There's only one device data in the cpu data. + cpu_tensor = cpu_data[0] + + # Check if the tensor is in the cpu. + if cpu_tensor.device.type != 'cpu': + raise Exception("To upload data to GPUs, it must already be in the CPU!!!") + + # Prepare destination arrays. + if include_special: + dest = self.ntt.p.destination_arrays_with_special[level] + else: + dest = self.ntt.p.destination_arrays[level] + + # dest contain indices that are the absolute order of primes. + # Convert them to tensor channel indices at this level. + # That is, force them to start from zero. + min_ind = min([min(d) for d in dest]) + dest = [[di - min_ind for di in d] for d in dest] + + gpu_data = [] + for device_id in range(len(dest)): + # Copy in the data. + dest_device = dest[device_id] + device = self.ntt.devices[device_id] + gpu_tensor = cpu_tensor[dest_device].to(device=device) + + # Append to the gpu_data list. + gpu_data.append(gpu_tensor) + + return gpu_data + + def move_tensors(self, data, level, include_special, direction): + func = { + 'gpu2cpu': self.download_to_cpu, + 'cpu2gpu': self.upload_to_gpu + }[direction] + + # Some data has 1 depth. + if not isinstance(data[0], list): + moved = func(data, level, include_special) + new_data = moved + else: + new_data = [] + for part in data: + moved = func(part, level, include_special) + new_data.append(moved) + return new_data + + def move_to(self, text, direction='gpu2cpu'): + if not isinstance(text.data[0], data_struct): + level = text.level + include_special = text.include_special + + # data are tensors. + data = self.move_tensors(text.data, level, + include_special, direction) + + wrapper = data_struct( + data, text.include_special, + text.ntt_state, text.montgomery_state, + text.origin, text.level, text.hash, text.version + ) + + else: + wrapper = data_struct( + [], text.include_special, + text.ntt_state, text.montgomery_state, + text.origin, text.level, text.hash, text.version + ) + + for d in text.data: + moved = self.move_to(d, direction) + wrapper.data.append(moved) + + return wrapper + + # Shortcuts + + def cpu(self, ct): + return self.move_to(ct, 'gpu2cpu') + + def cuda(self, ct): + return self.move_to(ct, 'cpu2gpu') + + # ------------------------------------------------------------------------------------------- + # check device. + # ------------------------------------------------------------------------------------------- + + def tensor_device(self, data): + # Some data has 1 depth. + if not isinstance(data[0], list): + return data[0].device.type + else: + return data[0][0].device.type + + def device(self, text): + if not isinstance(text.data[0], data_struct): + # data are tensors. 
+ return self.tensor_device(text.data) + else: + return self.device(text.data[0]) + + # ------------------------------------------------------------------------------------------- + # Print data structure. + # ------------------------------------------------------------------------------------------- + + def tree_lead_text(self, level, tabs=2, final=False): + final_char = "└" if final else "├" + + if level == 0: + leader = " " * tabs + trailer = "─" * tabs + lead_text = "─" * tabs + "┬" + trailer + + elif level < 0: + level = -level + leader = " " * tabs + trailer = "─" + "─" * (tabs - 1) + lead_fence = leader + "│" * (level - 1) + lead_text = lead_fence + final_char + trailer + + else: + leader = " " * tabs + trailer = "┬" + "─" * (tabs - 1) + lead_fence = leader + "│" * (level - 1) + lead_text = lead_fence + "├" + trailer + + return lead_text + + def print_data_shapes(self, data, level): + # Some data structures have 1 depth. + if isinstance(data[0], list): + for part_i, part in enumerate(data): + for device_id, d in enumerate(part): + device = self.ntt.devices[device_id] + + if (device_id == len(part) - 1) and \ + (part_i == len(data) - 1): + final = True + else: + final = False + + lead_text = self.tree_lead_text(-level, final=final) + + print(f"{lead_text} tensor at device {device} with " + f"shape {d.shape}.") + else: + for device_id, d in enumerate(data): + device = self.ntt.devices[device_id] + + if device_id == len(data) - 1: + final = True + else: + final = False + + lead_text = self.tree_lead_text(-level, final=final) + + print(f"{lead_text} tensor at device {device} with " + f"shape {d.shape}.") + + def print_data_structure(self, text, level=0): + lead_text = self.tree_lead_text(level) + print(f"{lead_text} {text.origin}") + + if not isinstance(text.data[0], data_struct): + self.print_data_shapes(text.data, level + 1) + else: + for d in text.data: + self.print_data_structure(d, level + 1) + + # ------------------------------------------------------------------------------------------- + # Save and load. + # ------------------------------------------------------------------------------------------- + + def auto_generate_filename(self, fmt_str='%Y%m%d%H%M%s%f'): + return datetime.datetime.now().strftime(fmt_str) + '.pkl' + + def save(self, text, filename=None): + if filename is None: + filename = self.auto_generate_filename() + + savepath = Path(filename) + + # Check if the text is in the CPU. + # If not, move to CPU. + if self.device(text) != 'cpu': + cpu_text = self.cpu(text) + else: + cpu_text = text + + with savepath.open('wb') as f: + pickle.dump(cpu_text, f) + + def load(self, filename, move_to_gpu=True): + savepath = Path(filename) + with savepath.open('rb') as f: + # gc.disable() + cpu_text = pickle.load(f) + # gc.enable() + + if move_to_gpu: + text = self.cuda(cpu_text) + else: + text = cpu_text + + return text + + # ------------------------------------------------------------------------------------------- + # Negate. + # ------------------------------------------------------------------------------------------- + + def negate(self, ct: data_struct) -> data_struct: + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) # ctt + new_ct = self.clone(ct) + + new_data = new_ct.data + for part in new_data: + for d in part: + d *= -1 + self.ntt.make_signed(part, ct.level) + + return new_ct + + # ------------------------------------------------------------------------------------------- + # scalar ops. 
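+    # Scalar operands are reduced per RNS prime and, for multiplications, lifted to
+    # Montgomery form (scalar * R mod q_i) before being applied channel-wise.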
+ # ------------------------------------------------------------------------------------------- + + def mult_int_scalar(self, ct: data_struct, scalar, evk=None, relin=True): + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + + device_len = len(ct.data[0]) + + int_scalar = int(scalar) + mont_scalar = [(int_scalar * self.ctx.R) % qi for qi in self.ctx.q] + + dest = self.ntt.p.destination_arrays[ct.level] + + partitioned_mont_scalar = [[mont_scalar[i] for i in desti] for desti in dest] + tensorized_scalar = [] + for device_id in range(device_len): + scal_tensor = torch.tensor( + partitioned_mont_scalar[device_id], + dtype=self.ctx.torch_dtype, + device=self.ntt.devices[device_id] + ) + tensorized_scalar.append(scal_tensor) + + new_ct = self.clone(ct) + new_data = new_ct.data + for i in [0, 1]: + self.ntt.mont_enter_scalar(new_data[i], tensorized_scalar, ct.level) + self.ntt.reduce_2q(new_data[i], ct.level) + + return new_ct + + def mult_scalar(self, ct, scalar, evk=None, relin=True): + device_len = len(ct.data[0]) + + scaled_scalar = int( + scalar * self.scale * np.sqrt(self.deviations[ct.level + 1]) + 0.5) + + mont_scalar = [(scaled_scalar * self.ctx.R) % qi for qi in self.ctx.q] + + dest = self.ntt.p.destination_arrays[ct.level] + + partitioned_mont_scalar = [[mont_scalar[i] for i in dest_i] for dest_i in dest] + tensorized_scalar = [] + for device_id in range(device_len): + scal_tensor = torch.tensor( + partitioned_mont_scalar[device_id], + dtype=self.ctx.torch_dtype, + device=self.ntt.devices[device_id] + ) + tensorized_scalar.append(scal_tensor) + + new_ct = self.clone(ct) + new_data = new_ct.data + + for i in [0, 1]: + self.ntt.mont_enter_scalar(new_data[i], tensorized_scalar, ct.level) + self.ntt.reduce_2q(new_data[i], ct.level) + + return self.rescale(new_ct) + + def add_scalar(self, ct, scalar): + device_len = len(ct.data[0]) + + scaled_scalar = int(scalar * self.scale * self.deviations[ct.level] + 0.5) + + if self.norm == 'backward': + scaled_scalar *= self.ctx.N + + scaled_scalar *= self.int_scale + + scaled_scalar = [scaled_scalar % qi for qi in self.ctx.q] + + dest = self.ntt.p.destination_arrays[ct.level] + + partitioned_mont_scalar = [[scaled_scalar[i] for i in desti] for desti in dest] + tensorized_scalar = [] + for device_id in range(device_len): + scal_tensor = torch.tensor( + partitioned_mont_scalar[device_id], + dtype=self.ctx.torch_dtype, + device=self.ntt.devices[device_id] + ) + tensorized_scalar.append(scal_tensor) + + new_ct = self.clone(ct) + new_data = new_ct.data + + dc = [d[:, 0] for d in new_data[0]] + for device_id in range(device_len): + dc[device_id] += tensorized_scalar[device_id] + + self.ntt.reduce_2q(new_data[0], ct.level) + + return new_ct + + def sub_scalar(self, ct, scalar): + return self.add_scalar(ct, -scalar) + + def int_scalar_mult(self, scalar, ct, evk=None, relin=True): + return self.mult_int_scalar(ct, scalar) + + def scalar_mult(self, scalar, ct, evk=None, relin=True): + return self.mult_scalar(ct, scalar) + + def scalar_add(self, scalar, ct): + return self.add_scalar(ct, scalar) + + def scalar_sub(self, scalar, ct): + neg_ct = self.negate(ct) + return self.add_scalar(neg_ct, scalar) + + # ------------------------------------------------------------------------------------------- + # message ops. 
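+    # Message operands are first encoded to plaintexts at the ciphertext's level,
+    # tiled over the RNS channels, and then combined with the ciphertext data.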
+ # ------------------------------------------------------------------------------------------- + + def mc_mult(self, m, ct, evk=None, relin=True): + m = np.array(m) * np.sqrt(self.deviations[ct.level + 1]) + + pt = self.encode(m, 0) + + pt_tiled = self.ntt.tile_unsigned(pt, ct.level) + + # Transform ntt to prepare for multiplication. + self.ntt.enter_ntt(pt_tiled, ct.level) + + # Prepare a new ct. + new_ct = self.clone(ct) + + self.ntt.enter_ntt(new_ct.data[0], ct.level) + self.ntt.enter_ntt(new_ct.data[1], ct.level) + + new_d0 = self.ntt.mont_mult(pt_tiled, new_ct.data[0], ct.level) + new_d1 = self.ntt.mont_mult(pt_tiled, new_ct.data[1], ct.level) + + self.ntt.intt_exit_reduce(new_d0, ct.level) + self.ntt.intt_exit_reduce(new_d1, ct.level) + + new_ct.data[0] = new_d0 + new_ct.data[1] = new_d1 + + return self.rescale(new_ct) + + def mc_add(self, m, ct): + pt = self.encode(m, ct.level) + pt_tiled = self.ntt.tile_unsigned(pt, ct.level) + + self.ntt.mont_enter_scale(pt_tiled, ct.level) + + new_ct = self.clone(ct) + self.ntt.mont_enter(new_ct.data[0], ct.level) + new_d0 = self.ntt.mont_add(pt_tiled, new_ct.data[0], ct.level) + self.ntt.mont_redc(new_d0, ct.level) + self.ntt.reduce_2q(new_d0, ct.level) + + new_ct.data[0] = new_d0 + + return new_ct + + def mc_sub(self, m, ct): + neg_ct = self.negate(ct) + return self.mc_add(m, neg_ct) + + def cm_mult(self, ct, m, evk=None, relin=True): + return self.mc_mult(m, ct) + + def cm_add(self, ct, m): + return self.mc_add(m, ct) + + def cm_sub(self, ct, m): + return self.mc_add(-np.array(m), ct) + + # ------------------------------------------------------------------------------------------- + # Automatic cc ops. + # ------------------------------------------------------------------------------------------- + + def auto_level(self, ct0, ct1): + level_diff = ct0.level - ct1.level + if level_diff < 0: + new_ct0 = self.level_up(ct0, ct1.level) + return new_ct0, ct1 + elif level_diff > 0: + new_ct1 = self.level_up(ct1, ct0.level) + return ct0, new_ct1 + else: + return ct0, ct1 + + def auto_cc_mult(self, ct0, ct1, evk, relin=True): + lct0, lct1 = self.auto_level(ct0, ct1) + return self.cc_mult(lct0, lct1, evk, relin=relin) + + def auto_cc_add(self, ct0, ct1): + lct0, lct1 = self.auto_level(ct0, ct1) + return self.cc_add(lct0, lct1) + + def auto_cc_sub(self, ct0, ct1): + lct0, lct1 = self.auto_level(ct0, ct1) + return self.cc_sub(lct0, lct1) + + # ------------------------------------------------------------------------------------------- + # Fully automatic ops. + # ------------------------------------------------------------------------------------------- + + def mult(self, a, b, evk=None, relin=True): + type_a = type(a) + type_b = type(b) + + try: + func = self.mult_dispatch_dict[type_a, type_b] + except Exception as e: + raise Exception(f"Unsupported data types are input.\n{e}") + + return func(a, b, evk, relin) + + def add(self, a, b): + type_a = type(a) + type_b = type(b) + + try: + func = self.add_dispatch_dict[type_a, type_b] + except Exception as e: + raise Exception(f"Unsupported data types are input.\n{e}") + + return func(a, b) + + def sub(self, a, b): + type_a = type(a) + type_b = type(b) + + try: + func = self.sub_dispatch_dict[type_a, type_b] + except Exception as e: + raise Exception(f"Unsupported data types are input.\n{e}") + + return func(a, b) + + # ------------------------------------------------------------------------------------------- + # Misc. 
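+    # reduce_error multiplies by 1.0, which forces a rescale of the ciphertext.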
+ # ------------------------------------------------------------------------------------------- + + def refresh(self): + # Refreshes the rng state. + self.rng.refresh() + + def reduce_error(self, ct): + # Reduce the accumulated error in the cipher text. + return self.mult_scalar(ct, 1.0) + + # ------------------------------------------------------------------------------------------- + # Misc ops. + # ------------------------------------------------------------------------------------------- + + def sum(self, ct: data_struct, gk: data_struct, rescale_every=5) -> data_struct: + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if gk.origin != types.origins["galk"]: + raise errors.NotMatchType(origin=gk.origin, to=types.origins["galk"]) + + new_ct = self.clone(ct) + for roti in range(self.ctx.logN - 1): + rot_ct = self.rotate_single(new_ct, gk.data[roti]) + sum_ct = self.add(rot_ct, new_ct) + del new_ct, rot_ct + if roti != 0 and (roti % rescale_every) == 0: + new_ct = self.reduce_error(sum_ct) + else: + new_ct = sum_ct + return new_ct + + def mean(self, ct: data_struct, gk: data_struct, alpha=1, rescale_every=5) -> data_struct: + # Divide by num_slots. + # The cipher text is refreshed here, and hence + # doesn't need to be refreshed at roti=0 in the loop. + new_ct = self.mult(1 / self.num_slots / alpha, ct) + + for roti in range(self.ctx.logN - 1): + rotk = gk.data[roti] + rot_ct = self.rotate_single(new_ct, rotk) + sum_ct = self.add(rot_ct, new_ct) + del new_ct, rot_ct + if ((roti % rescale_every) == 0) and (roti != 0): + new_ct = self.reduce_error(sum_ct) + else: + new_ct = sum_ct + return new_ct + + def cov(self, ct_a: data_struct, ct_b: data_struct, + evk: data_struct, gk: data_struct, rescale_every=5) -> data_struct: + cta_mean = self.mean(ct_a, gk, rescale_every=rescale_every) + ctb_mean = self.mean(ct_b, gk, rescale_every=rescale_every) + + cta_dev = self.sub(ct_a, cta_mean) + ctb_dev = self.sub(ct_b, ctb_mean) + + ct_cov = self.mult(self.mult(cta_dev, ctb_dev, evk), 1 / (self.num_slots - 1)) + + return ct_cov + + def pow(self, ct: data_struct, power: int, evk: data_struct) -> data_struct: + current_exponent = 2 + pow_list = [ct] + while current_exponent <= power: + current_ct = pow_list[-1] + new_ct = self.cc_mult(current_ct, current_ct, evk) + pow_list.append(new_ct) + current_exponent *= 2 + + remaining_exponent = power - current_exponent // 2 + new_ct = pow_list[-1] + + while remaining_exponent > 0: + pow_ind = math.floor(math.log2(remaining_exponent)) + pow_term = pow_list[pow_ind] + new_ct = self.auto_cc_mult(new_ct, pow_term, evk) + remaining_exponent -= 2 ** pow_ind + + return new_ct + + def square(self, ct: data_struct, evk: data_struct, relin=True) -> data_struct: + x = self.rescale(ct) + + level = x.level + + # Multiply. + x0, x1 = x.data + + self.ntt.enter_ntt(x0, level) + self.ntt.enter_ntt(x1, level) + + d0 = self.ntt.mont_mult(x0, x0, level) + x0y1 = self.ntt.mont_mult(x0, x1, level) + d2 = self.ntt.mont_mult(x1, x1, level) + + d1 = self.ntt.mont_add(x0y1, x0y1, level) + + ct_mult = data_struct( + data=(d0, d1, d2), + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["ctt"], + level=level, + hash=self.hash, + version=self.version + ) + if relin: + ct_mult = self.relinearize(ct_triplet=ct_mult, evk=evk) + + return ct_mult + + # ------------------------------------------------------------------------------------------- + # Multiparty. 
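+    # Each party contributes shares (public keys, rotation/evaluation keys, partial
+    # decryptions) that are aggregated by the collective helpers below.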
+ # ------------------------------------------------------------------------------------------- + def multiparty_public_crs(self, pk: data_struct): + crs = self.clone(pk).data[1] + return crs + + def multiparty_create_public_key(self, sk: data_struct, a=None, include_special=False) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if include_special and not sk.include_special: + raise errors.SecretKeyNotIncludeSpecialPrime() + mult_type = -2 if include_special else -1 + + level = 0 + e = self.rng.discrete_gaussian(repeats=1) + e = self.ntt.tile_unsigned(e, level, mult_type) + + self.ntt.enter_ntt(e, level, mult_type) + repeats = self.ctx.num_special_primes if sk.include_special else 0 + + if a is None: + a = self.rng.randint( + self.ntt.q_prepack[mult_type][level][0], + repeats=repeats + ) + + sa = self.ntt.mont_mult(a, sk.data, 0, mult_type) + pk0 = self.ntt.mont_sub(e, sa, 0, mult_type) + pk = data_struct( + data=(pk0, a), + include_special=include_special, + ntt_state=True, + montgomery_state=True, + origin=types.origins["pk"], + level=level, + hash=self.hash, + version=self.version + ) + return pk + + def multiparty_create_collective_public_key(self, pks: list[data_struct]) -> data_struct: + data, include_special, ntt_state, montgomery_state, origin, level, hash_, version = pks[0] + mult_type = -2 if include_special else -1 + b = [b.clone() for b in data[0]] # num of gpus + a = [a.clone() for a in data[1]] + + for pk in pks[1:]: + b = self.ntt.mont_add(b, pk.data[0], lvl=0, mult_type=mult_type) + + cpk = data_struct( + (b, a), + include_special=include_special, + ntt_state=ntt_state, + montgomery_state=montgomery_state, + origin=types.origins["pk"], + level=level, + hash=self.hash, + version=self.version + ) + return cpk + + def multiparty_decrypt_head(self, ct: data_struct, sk: data_struct): + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if ct.ntt_state or ct.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct.origin) + if not sk.ntt_state or not sk.montgomery_state: + raise errors.NotMatchDataStructState(origin=sk.origin) + level = ct.level + + ct0 = ct.data[0][0] + a = ct.data[1][0].clone() + + self.ntt.enter_ntt([a], level) + + sk_data = sk.data[0][self.ntt.starts[level][0]:] + + sa = self.ntt.mont_mult([a], [sk_data], level) + self.ntt.intt_exit(sa, level) + + pt = self.ntt.mont_add([ct0], sa, level) + + return pt + + def multiparty_decrypt_partial(self, ct: data_struct, sk: data_struct) -> data_struct: + if ct.origin != types.origins["ct"]: + raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + if ct.ntt_state or ct.montgomery_state: + raise errors.NotMatchDataStructState(origin=ct.origin) + if not sk.ntt_state or not sk.montgomery_state: + raise errors.NotMatchDataStructState(origin=sk.origin) + + data, include_special, ntt_state, montgomery_state, origin, level, hash_, version = ct + + a = ct.data[1][0].clone() + + self.ntt.enter_ntt([a], level) + + sk_data = sk.data[0][self.ntt.starts[level][0]:] + + sa = self.ntt.mont_mult([a], [sk_data], level) + self.ntt.intt_exit(sa, level) + + return sa + + def multiparty_decrypt_fusion(self, pcts: list, level=0, include_special=False): + pt = [x.clone() 
for x in pcts[0]]
+        for pct in pcts[1:]:
+            pt = self.ntt.mont_add(pt, pct, level)
+
+        self.ntt.reduce_2q(pt, level)
+
+        base_at = -self.ctx.num_special_primes - 1 if include_special else -1
+
+        base = pt[0][base_at][None, :]
+        scaler = pt[0][0][None, :]
+
+        final_scalar = self.final_scalar[level]
+        scaled = self.ntt.mont_sub([base], [scaler], -1)
+        self.ntt.mont_enter_scalar(scaled, [final_scalar], -1)
+        self.ntt.reduce_2q(scaled, -1)
+        self.ntt.make_signed(scaled, -1)
+
+        m = self.decode(m=scaled, level=level)
+
+        return m
+
+    #### -------------------------------------------------------------------------------------------
+    #### Multiparty. ROTATION
+    #### -------------------------------------------------------------------------------------------
+
+    def multiparty_create_key_switching_key(self, sk_src: data_struct, sk_dst: data_struct, a=None) -> data_struct:
+        if sk_src.origin != types.origins["sk"] or sk_dst.origin != types.origins["sk"]:
+            raise errors.NotMatchType(origin="not a secret key", to=types.origins["sk"])
+        if (not sk_src.ntt_state) or (not sk_src.montgomery_state):
+            raise errors.NotMatchDataStructState(origin=sk_src.origin)
+        if (not sk_dst.ntt_state) or (not sk_dst.montgomery_state):
+            raise errors.NotMatchDataStructState(origin=sk_dst.origin)
+
+        level = 0
+
+        stops = self.ntt.stops[-1]
+        Psk_src = [sk_src.data[di][:stops[di]].clone() for di in range(self.ntt.num_devices)]
+
+        self.ntt.mont_enter_scalar(Psk_src, self.mont_PR, level)
+
+        ksk = [[] for _ in range(self.ntt.p.num_partitions + 1)]
+        for device_id in range(self.ntt.num_devices):
+            for part_id, part in enumerate(self.ntt.p.p[level][device_id]):
+                global_part_id = self.ntt.p.part_allocations[device_id][part_id]
+
+                crs = a[global_part_id] if a else None
+                pk = self.multiparty_create_public_key(sk_dst, include_special=True, a=crs)
+                key = tuple(part)
+                astart = part[0]
+                astop = part[-1] + 1
+                shard = Psk_src[device_id][astart:astop]
+                pk_data = pk.data[0][device_id][astart:astop]
+
+                _2q = self.ntt.parts_pack[device_id][key]['_2q']
+                update_part = ntt_cuda.mont_add([pk_data], [shard], _2q)[0]
+                pk_data.copy_(update_part, non_blocking=True)
+
+                # Name the pk.
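+                # Each shard is stored at its global partition index so the key switcher
+                # can look up the matching key part later.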
+ pk_name = f'key switch key part index {global_part_id}' + pk = pk._replace(origin=pk_name) + + ksk[global_part_id] = pk + + return data_struct( + data=ksk, + include_special=True, + ntt_state=True, + montgomery_state=True, + origin=types.origins["ksk"], + level=level, + hash=self.hash, + version=self.version + ) + + def multiparty_create_rotation_key(self, sk: data_struct, delta: int, a=None) -> data_struct: + sk_new_data = [s.clone() for s in sk.data] + self.ntt.intt(sk_new_data) + sk_new_data = [rotate(s, delta) for s in sk_new_data] + self.ntt.ntt(sk_new_data) + sk_rotated = data_struct( + data=sk_new_data, + include_special=False, + ntt_state=True, + montgomery_state=True, + origin=types.origins["sk"], + level=0, + hash=self.hash, + version=self.version + ) + rotk = self.multiparty_create_key_switching_key(sk_rotated, sk, a=a) + rotk = rotk._replace(origin=types.origins["rotk"] + f"{delta}") + return rotk + + def multiparty_generate_rotation_key(self, rotks: list[data_struct]) -> data_struct: + crotk = self.clone(rotks[0]) + for rotk in rotks[1:]: + for ksk_idx in range(len(rotk.data)): + update_parts = self.ntt.mont_add(crotk.data[ksk_idx].data[0], rotk.data[ksk_idx].data[0]) + crotk.data[ksk_idx].data[0][0].copy_(update_parts[0], non_blocking=True) + return crotk + + def generate_rotation_crs(self, rotk: data_struct): + if types.origins["rotk"] not in rotk.origin and types.origins["ksk"] != rotk.origin: + raise errors.NotMatchType(origin=rotk.origin, to=types.origins["ksk"]) + crss = [] + for ksk in rotk.data: + crss.append(ksk.data[1]) + return crss + + #### ------------------------------------------------------------------------------------------- + #### Multiparty. GALOIS + #### ------------------------------------------------------------------------------------------- + + def generate_galois_crs(self, galk: data_struct): + if galk.origin != types.origins["galk"]: + raise errors.NotMatchType(origin=galk.origin, to=types.origins["galk"]) + crs_s = [] + for rotk in galk.data: + crss = [ksk.data[1] for ksk in rotk.data] + crs_s.append(crss) + return crs_s + + def multiparty_create_galois_key(self, sk: data_struct, a: list) -> data_struct: + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + galois_key_parts = [ + self.multiparty_create_rotation_key(sk, self.galois_deltas[idx], a=a[idx]) + for idx in range(len(self.galois_deltas)) + ] + + galois_key = data_struct( + data=galois_key_parts, + include_special=True, + montgomery_state=True, + ntt_state=True, + origin=types.origins["galk"], + level=0, + hash=self.hash, + version=self.version + ) + return galois_key + + def multiparty_generate_galois_key(self, galks: list[data_struct]) -> data_struct: + cgalk = self.clone(galks[0]) + for galk in galks[1:]: # galk + for rotk_idx in range(len(galk.data)): # rotk + for ksk_idx in range(len(galk.data[rotk_idx].data)): # ksk + update_parts = self.ntt.mont_add( + cgalk.data[rotk_idx].data[ksk_idx].data[0], + galk.data[rotk_idx].data[ksk_idx].data[0] + ) + cgalk.data[rotk_idx].data[ksk_idx].data[0][0].copy_(update_parts[0], non_blocking=True) + return cgalk + + #### ------------------------------------------------------------------------------------------- + #### Multiparty. 
Evaluation Key + #### ------------------------------------------------------------------------------------------- + + def multiparty_sum_evk_share(self, evks_share: list[data_struct]): + evk_sum = self.clone(evks_share[0]) + for evk_share in evks_share[1:]: + for ksk_idx in range(len(evk_sum.data)): + update_parts = self.ntt.mont_add(evk_sum.data[ksk_idx].data[0], evk_share.data[ksk_idx].data[0]) + for dev_id in range(len(update_parts)): + evk_sum.data[ksk_idx].data[0][dev_id].copy_(update_parts[dev_id], non_blocking=True) + + return evk_sum + + def multiparty_mult_evk_share_sum(self, evk_sum: data_struct, sk: data_struct): + if sk.origin != types.origins["sk"]: + raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) + evk_sum_mult = self.clone(evk_sum) + + for ksk_idx in range(len(evk_sum.data)): + update_part_b = self.ntt.mont_mult(evk_sum_mult.data[ksk_idx].data[0], sk.data) + update_part_a = self.ntt.mont_mult(evk_sum_mult.data[ksk_idx].data[1], sk.data) + for dev_id in range(len(update_part_b)): + evk_sum_mult.data[ksk_idx].data[0][dev_id].copy_(update_part_b[dev_id], non_blocking=True) + evk_sum_mult.data[ksk_idx].data[1][dev_id].copy_(update_part_a[dev_id], non_blocking=True) + + return evk_sum_mult + + def multiparty_sum_evk_share_mult(self, evk_sum_mult: list[data_struct]) -> data_struct: + cevk = self.clone(evk_sum_mult[0]) + for evk in evk_sum_mult[1:]: + for ksk_idx in range(len(cevk.data)): + update_part_b = self.ntt.mont_add(cevk.data[ksk_idx].data[0], evk.data[ksk_idx].data[0]) + update_part_a = self.ntt.mont_add(cevk.data[ksk_idx].data[1], evk.data[ksk_idx].data[1]) + for dev_id in range(len(update_part_b)): + cevk.data[ksk_idx].data[0][dev_id].copy_(update_part_b[dev_id], non_blocking=True) + cevk.data[ksk_idx].data[1][dev_id].copy_(update_part_a[dev_id], non_blocking=True) + return cevk + + #### ------------------------------------------------------------------------------------------- + #### Statistics + #### ------------------------------------------------------------------------------------------- + + def sqrt(self, ct: data_struct, evk: data_struct, e=0.0001, alpha=0.0001) -> data_struct: + a = self.clone(ct) + b = self.clone(ct) + + while e <= 1 - alpha: + k = float(np.roots([1 - e ** 3, -6 + 6 * e ** 2, 9 - 9 * e])[1]) + t = self.mult_scalar(a, k, evk) + b0 = self.sub_scalar(t, 3) + b1 = self.mult_scalar(b, (k ** 0.5) / 2, evk) + b = self.cc_mult(b0, b1, evk) + + a0 = self.mult_scalar(a, (k ** 3) / 4) + t = self.sub_scalar(a, 3 / k) + a1 = self.square(t, evk) + a = self.cc_mult(a0, a1, evk) + e = k * (3 - k) ** 2 / 4 + + return b + + def var(self, ct: data_struct, evk: data_struct, gk: data_struct, relin=False) -> data_struct: + ct_mean = self.mean(ct=ct, gk=gk) + dev = self.sub(ct, ct_mean) + dev = self.square(ct=dev, evk=evk, relin=relin) + if not relin: + dev = self.relinearize(ct_triplet=dev, evk=evk) + ct_var = self.mean(ct=dev, gk=gk) + return ct_var + + def std(self, ct: data_struct, evk: data_struct, gk: data_struct, relin=False) -> data_struct: + ct_var = self.var(ct=ct, evk=evk, gk=gk, relin=relin) + ct_std = self.sqrt(ct=ct_var, evk=evk) + return ct_std diff --git a/liberate/fhe/context/__init__.py b/liberate/fhe/context/__init__.py new file mode 100644 index 0000000..d88f9f0 --- /dev/null +++ b/liberate/fhe/context/__init__.py @@ -0,0 +1,2 @@ +from .ckks_context import ckks_context +from .security_parameters import maximum_qbits, minimum_cyclotomic_order diff --git a/liberate/fhe/context/ckks_context.py 
b/liberate/fhe/context/ckks_context.py new file mode 100644 index 0000000..c6098a0 --- /dev/null +++ b/liberate/fhe/context/ckks_context.py @@ -0,0 +1,360 @@ +import math +import pickle +from pathlib import Path +import warnings + +import numpy as np +import torch + +from .generate_primes import generate_message_primes, generate_scale_primes +from .security_parameters import maximum_qbits +from liberate.fhe.cache import cache +from liberate.fhe.presets import errors + +# ------------------------------------------------------------------------------------------ +# NTT parameter pre-calculation. +# ------------------------------------------------------------------------------------------ +CACHE_FOLDER = cache.path_cache + + +def primitive_root_2N(q, N): + _2N = 2 * N + K = (q - 1) // _2N + for x in range(2, N): + g = pow(x, K, q) + h = pow(g, N, q) + if h != 1: + break + return g + + +def psi_power_series(psi, N, q): + series = [1] + for i in range(N - 1): + series.append(series[-1] * psi % q) + return series + + +def bit_rev_psi(q, logN): + N = 2 ** logN + psi = [primitive_root_2N(qi, N) for qi in q] + # Bit-reverse index. + ind = range(N) + brind = [bit_reverse(i, logN) for i in ind] + # The psi power and the indices are the same. + return [pow(psi, brpower, q) for brpower in brind] + + +def psi_bank(q, logN): + N = 2 ** logN + psi = [primitive_root_2N(qi, N) for qi in q] + ipsi = [pow(psii, -1, qi) for psii, qi in zip(psi, q)] + psi_series = [psi_power_series(psii, N, qi) for psii, qi in zip(psi, q)] + ipsi_series = [ + psi_power_series(ipsii, N, qi) for ipsii, qi in zip(ipsi, q) + ] + return psi_series, ipsi_series + + +def bit_reverse(a, nbits): + format_string = f"0{nbits}b" + binary_string = f"{a:{format_string}}" + reverse_binary_string = binary_string[::-1] + return int(reverse_binary_string, 2) + + +def bit_reverse_order_index(logN): + N = 2 ** logN + # Note that for a bit reversing, forward and backward permutations are the same. + # i.e., don't worry about which direction. 
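+    # For example, with logN = 3 the index map is [0, 4, 2, 6, 1, 5, 3, 7]:
+    # applying it twice returns the identity, which is why one table serves
+    # both the forward and the backward transforms.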
+ revi = np.array([bit_reverse(i, logN) for i in range(N)], dtype=np.int32) + return revi + + +def get_psi(q, logN, my_dtype): + np_dtype_dict = { + np.int32: np.int32, + np.int64: np.int64, + 30: np.int32, + 62: np.int64, + } + dtype = np_dtype_dict[my_dtype] + psi, ipsi = psi_bank(q, logN) + bit_reverse_index = bit_reverse_order_index(logN) + psi = np.array(psi, dtype=dtype)[:, bit_reverse_index] + ipsi = np.array(ipsi, dtype=dtype)[:, bit_reverse_index] + return psi, ipsi + + +def paint_butterfly_forward(logN): + N = 2 ** logN + t = N + painted_even = np.zeros((logN, N), dtype=np.bool8) + painted_odd = np.zeros((logN, N), dtype=np.bool8) + painted_psi = np.zeros((logN, N // 2), dtype=np.int32) + for logm in range(logN): + m = 2 ** logm + t //= 2 + psi_ind = 0 + for i in range(m): + j1 = 2 * i * t + j2 = j1 + t - 1 + Sind = m + i + for j in range(j1, j2 + 1): + Uind = j + Vind = j + t + painted_even[logm, Uind] = True + painted_odd[logm, Vind] = True + painted_psi[logm, psi_ind] = Sind + psi_ind += 1 + painted_eveni = np.where(painted_even)[1].reshape(logN, -1) + painted_oddi = np.where(painted_odd)[1].reshape(logN, -1) + return painted_eveni, painted_oddi, painted_psi + + +def paint_butterfly_backward(logN): + N = 2 ** logN + t = 1 + painted_even = np.zeros((logN, N), dtype=np.bool8) + painted_odd = np.zeros((logN, N), dtype=np.bool8) + painted_psi = np.zeros((logN, N // 2), dtype=np.int32) + for logm in range(logN, 0, -1): + level = logN - logm + m = 2 ** logm + j1 = 0 + h = m // 2 + psi_ind = 0 + for i in range(h): + j2 = j1 + t - 1 + Sind = h + i + for j in range(j1, j2 + 1): + Uind = j + Vind = j + t + # Paint + painted_even[level, Uind] = True + painted_odd[level, Vind] = True + painted_psi[level, psi_ind] = Sind + psi_ind += 1 + j1 += 2 * t + t *= 2 + painted_eveni = np.where(painted_even)[1].reshape(logN, -1) + painted_oddi = np.where(painted_odd)[1].reshape(logN, -1) + return painted_eveni, painted_oddi, painted_psi + + +# ------------------------------------------------------------------------------------------ +# The context class. +# ------------------------------------------------------------------------------------------ + + +@errors.log_error +class ckks_context: + def __init__( + self, + buffer_bit_length=62, + scale_bits=40, + logN=15, + num_scales=None, + num_special_primes=2, + sigma=3.2, + uniform_tenary_secret=True, + cache_folder=CACHE_FOLDER, + security_bits=128, + quantum="post_quantum", + distribution="uniform", + read_cache=True, + save_cache=True, + verbose=False, + is_secured=True + + ): + if not Path(cache_folder).exists(): + Path(cache_folder).mkdir(parents=True, exist_ok=True) + + self.generation_string = f"{buffer_bit_length}_{scale_bits}_{logN}_{num_scales}_" \ + f"{num_special_primes}_{security_bits}_{quantum}_" \ + f"{distribution}" + + self.is_secured = is_secured + # Compose cache savefile name. + savepath = Path(cache_folder) / Path(self.generation_string + ".pkl") + + if savepath.exists() and read_cache: + with savepath.open("rb") as f: + __dict__ = pickle.load(f) + self.__dict__.update(__dict__) + + if verbose: + print( + f"I have read in from the cached save file {savepath}!!!\n" + ) + self.init_print() + + return + + # Transfer input parameters. + self.buffer_bit_length = buffer_bit_length + self.scale_bits = scale_bits + self.logN = logN + self.num_special_primes = num_special_primes + self.cache_folder = cache_folder + self.security_bits = security_bits + self.quantum = quantum + self.distribution = distribution + # Sampling strategy. 
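+        # sigma parameterizes the discrete Gaussian error sampler (see liberate.csprng),
+        # while the secret key itself is drawn from a ternary distribution,
+        # either uniform or sparse, selected by uniform_tenary_secret below.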
+ self.sigma = sigma + self.uniform_tenary_secret = uniform_tenary_secret + if self.uniform_tenary_secret: + self.secret_key_sampling_method = "uniform tenary" + else: + self.secret_key_sampling_method = "sparse tenary" + + # dtypes. + self.torch_dtype = {30: torch.int32, 62: torch.int64}[ + self.buffer_bit_length + ] + self.numpy_dtype = {30: np.int32, 62: np.int64}[self.buffer_bit_length] + + # Polynomial length. + self.N = 2 ** self.logN + + # We set the message prime to of bit-length W-2. + self.message_bits = self.buffer_bit_length - 2 + + # Read in pre-calculated high-quality primes. + try: + message_special_primes = generate_message_primes(cache_folder=cache_folder)[self.message_bits][self.N] + except KeyError as e: + raise errors.NotFoundMessageSpecialPrimes(message_bit=self.message_bits, N=self.N) + + # For logN > 16, we need significantly more primes. + how_many = 64 if self.logN < 16 else 128 + try: + scale_primes = generate_scale_primes(cache_folder=cache_folder, how_many=how_many)[self.scale_bits, self.N] + except KeyError as e: + raise errors.NotFoundScalePrimes(scale_bits=self.scale_bits, N=self.N) + + # Compose the primes pack. + # Rescaling drops off primes in --> direction. + # Key switching drops off primes in <-- direction. + # Hence, [scale primes, base message prime, special primes] + self.max_qbits = int( + maximum_qbits(self.N, security_bits, quantum, distribution) + ) + base_special_primes = message_special_primes[: 1 + self.num_special_primes] + + # If num_scales is None, generate the maximal number of levels. + try: + if num_scales is None: + base_special_bits = sum( + [math.log2(p) for p in base_special_primes] + ) + available_bits = self.max_qbits - base_special_bits + num_scales = 0 + available_bits -= math.log2(scale_primes[num_scales]) + while available_bits > 0: + num_scales += 1 + available_bits -= math.log2(scale_primes[num_scales]) + + self.num_scales = num_scales + self.q = scale_primes[:num_scales] + base_special_primes + except IndexError as e: + raise errors.NotEnoughPrimes(scale_bits=self.scale_bits, N=self.N) + + # Check if security requirements are met. + self.total_qbits = math.ceil(sum([math.log2(qi) for qi in self.q])) + + if self.total_qbits > self.max_qbits: + if self.is_secured: + raise errors.ViolatedAllowedQbits( + scale_bits=self.scale_bits, N=self.N, num_scales=self.num_scales, + max_qbits=self.max_qbits, total_qbits=self.total_qbits) + else: + warnings.warn( + f"Maximum allowed qbits are violated: " + f"max_qbits={self.max_qbits:4d} and the " + f"requested total is {self.total_qbits:4d}." + ) + + # Generate Montgomery parameters and NTT paints. + self.generate_montgomery_parameters() + self.generate_paints() + + if verbose: + self.init_print() + + # Save cache. 
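+        # Persisting self.__dict__ lets a later construction with the same
+        # generation string (same logN, scale_bits, etc.) skip prime generation
+        # and NTT table setup entirely. Hypothetical usage sketch:
+        #     ctx = ckks_context(logN=15, scale_bits=40)   # computes and caches
+        #     ctx = ckks_context(logN=15, scale_bits=40)   # reloads from the cache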
+ if save_cache: + with savepath.open("wb") as f: + pickle.dump(self.__dict__, f) + + if verbose: + print(f"I have saved to the cached save file {savepath}!!!\n") + + def generate_montgomery_parameters(self): + self.R = 2 ** self.buffer_bit_length + self.R_square = [self.R ** 2 % qi for qi in self.q] + self.half_buffer_bit_length = self.buffer_bit_length // 2 + self.lower_bits_mask = (1 << self.half_buffer_bit_length) - 1 + self.full_bits_mask = (1 << self.buffer_bit_length) - 1 + + self.q_lower_bits = [qi & self.lower_bits_mask for qi in self.q] + self.q_higher_bits = [ + qi >> self.half_buffer_bit_length for qi in self.q + ] + self.q_double = [qi << 1 for qi in self.q] + + self.R_inv = [pow(self.R, -1, qi) for qi in self.q] + self.k = [ + (self.R * R_invi - 1) // qi + for R_invi, qi in zip(self.R_inv, self.q) + ] + self.k_lower_bits = [ki & self.lower_bits_mask for ki in self.k] + self.k_higher_bits = [ + ki >> self.half_buffer_bit_length for ki in self.k + ] + + def generate_paints(self): + self.N_inv = [pow(self.N, -1, qi) for qi in self.q] + + # psi and psi_inv. + psi, psi_inv = get_psi(self.q, self.logN, self.buffer_bit_length) + + # Paints. + ( + self.forward_even_indices, + self.forward_odd_indices, + forward_psi_paint, + ) = paint_butterfly_forward(self.logN) + ( + self.backward_even_indices, + self.backward_odd_indices, + backward_psi_paint, + ) = paint_butterfly_backward(self.logN) + + # Pre-painted psi and ipsi. + self.forward_psi = psi[..., forward_psi_paint.ravel()].reshape( + -1, *forward_psi_paint.shape + ) + self.backward_psi_inv = psi_inv[ + ..., backward_psi_paint.ravel() + ].reshape(-1, *backward_psi_paint.shape) + + def init_print(self): + print(f""" +I have received inputs: + buffer_bit_length\t\t= {self.buffer_bit_length:,d} + scale_bits\t\t\t= {self.scale_bits:,d} + logN\t\t\t\t= {self.logN:,d} + N\t\t\t\t= {self.N:,d} + Number of special primes\t= {self.num_special_primes:,d} + Number of scales\t\t= {self.num_scales:,d} + Cache folder\t\t\t= '{self.cache_folder:s}' + Security bits\t\t\t= {self.security_bits:,d} + Quantum security model\t\t= {self.quantum:s} + Security sampling distribution\t= {self.distribution:s} + Number of message bits\t\t= {self.message_bits:,d} + In total I will be using '{self.total_qbits:,d}' bits out of available maximum '{self.max_qbits:,d}' bits. + And is it secured?\t\t= {self.is_secured} +My RNS primes are {self.q}.""" + ) diff --git a/liberate/fhe/context/generate_primes.py b/liberate/fhe/context/generate_primes.py new file mode 100644 index 0000000..b912cf9 --- /dev/null +++ b/liberate/fhe/context/generate_primes.py @@ -0,0 +1,311 @@ +import math +import multiprocessing +import pickle +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +from joblib import Parallel, delayed + +from .prim_test import MillerRabinPrimalityTest +from .security_parameters import maximum_qbits +from liberate.fhe.cache import cache + +# Default cache folder. 
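+# The prime tables generated below (logN_N_M, message/special primes, scale primes)
+# are pickled into this folder, so repeated context creation can simply reload them.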
+CACHE_FOLDER = cache.path_cache + + +def generate_N_M(logN=None, cache_folder=CACHE_FOLDER, **kw): + if logN is None: + logN = list(range(12, 18)) + savefile = Path(cache_folder) / 'logN_N_M.pkl' + + if savefile.exists(): + with savefile.open('rb') as f: + logN_N_M = pickle.load(f) + logN = logN_N_M['logN'] + N = logN_N_M['N'] + M = logN_N_M['M'] + return logN, N, M + + N = [2 ** lN for lN in logN] + M = [2 * n for n in N] + + logN_N_M = { + 'logN': logN, + 'N': N, + 'M': M + } + + with savefile.open('wb') as f: + pickle.dump(logN_N_M, f) + + return logN, N, M + + +def check_ntt_primality(q: int, M: int): + # Is this in the KM+1 form? + NTT_comliance = (q - 1) % M + # It is compliant, go ahead. + if NTT_comliance == 0: + # Now, is q a prime? + is_prime = MillerRabinPrimalityTest(q) + if is_prime: + return True + return False + + +def generate_message_primes(mbits=None, cache_folder=CACHE_FOLDER, how_many=11, **kw): + if mbits is None: + mbits = [28, 60] + savefile = Path(cache_folder) / 'message_special_primes.pkl' + + if savefile.exists(): + with savefile.open('rb') as f: + mprimes = pickle.load(f) + # return mprimes + else: + logN, N, M = generate_N_M(cache_folder=cache_folder, **kw) + + mprimes = {} + for mb in mbits: + mprimes[mb] = {} + for m in M: + # We want to deal with N, and hide M implicitly. + N = m // 2 + mprimes[mb][N] = [] + current_query = 2 ** mb - 1 + q_count = 0 + + while True: + ok = check_ntt_primality(current_query, m) + if ok: + mprimes[mb][N].append(current_query) + q_count += 1 + + # Have we pulled out how_many primes? + if q_count == how_many: + break + + # Move onto the next query. + current_query -= 2 + + with savefile.open('wb') as f: + pickle.dump(mprimes, f) + + return mprimes + + +def maximum_levels(N: int, qbits: int = 40, mbits: int = 60, nksk: int = 2) -> int: + extra_bits = mbits * (1 + nksk) + f_levels = (maximum_qbits(N) - extra_bits) / qbits + return math.floor(f_levels) + + +def find_the_next_prime(start: int, m: int, up=True) -> int: + step: int = 2 if up else -2 + current_query: int = start + while True: + ok: bool = check_ntt_primality(q=current_query, M=m) + if ok: + break + current_query += step + return current_query + + +def generate_alternating_prime_sequence( + sb: int = 40, + N: int = 2 ** 15, + how_many: int = 60, + optimize: bool = True, + alternate_directions: bool = True, + fixed_direction: bool = False, +) -> list: + m: int = N * 2 + scale: int = 2 ** sb + + s_primes: list = [] + + up: int = scale + 1 + down: int = scale - 1 + + if alternate_directions: + # Initial search. + up0: int = find_the_next_prime(start=up, m=m) + down0: int = find_the_next_prime(start=down, m=m, up=False) + + # Initial error. + eup: int = up0 - scale + edown: int = scale - down0 + + # Set the current direction. + # True is up + # This is the NEXT direction. That means if the first item was up, then this direction is down. + current_direction: bool = False if eup < edown else True + + # Initialize count. + q_count: int = 0 + + cumulative_scale: int = 1 + + # The loop. + while True: + # Search in the given direction. + current_query: int = up if current_direction else down + next_prime: int = find_the_next_prime(start=current_query, m=m, up=current_direction) + # cumulative scale progresses as (beta_{i-1} * beta_i)**2 + + # cumulative_scale *= scale / next_prime + + # Use Pre-rescale quadratic deviation rule. + current_dev = scale / next_prime + cumulative_scale = cumulative_scale ** 2 * current_dev ** 2 + + # Set the next variable. 
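+            # Move the bound on the side just searched past the prime we found.
+            # With optimize=True, the opposite bound may also jump to the value
+            # predicted by the cumulative deviation, so that the next prime
+            # roughly cancels the scale drift accumulated so far.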
+ if current_direction: + up: int = next_prime + 2 + if optimize: + searched: int = int((cumulative_scale * scale) // 2 * 2 - 1) + down: int = searched if searched < down else down + else: + down: int = next_prime - 2 + if optimize: + searched: int = int((cumulative_scale * scale) // 2 * 2 + 1) + up: int = searched if searched > up else up + + # Switch the direction. + current_direction: bool = not current_direction + + # Store. + s_primes.append(next_prime) + q_count += 1 + + # Escape. + if q_count >= how_many: + break + + else: + q_count: int = 0 + current_query: int = up if fixed_direction else down + step: int = 2 if fixed_direction else -2 + while True: + current_query: int = find_the_next_prime( + start=current_query, m=m, up=fixed_direction + ) + s_primes.append(current_query) + # Move on. + q_count += 1 + current_query += step + + # Escape. + if q_count >= how_many: + break + return s_primes + + +def cum_prod(x: list) -> list: + ret: list = [1] + for i in range(len(x)): + ret.append(ret[-1] * x[i]) + return ret[1:] + + +def plot_cumulative_relative_error(sb: int = 40, N: int = 2 ** 15, label: str = "Optimized", **kw): + how_many: int = maximum_levels(N=N, qbits=sb) + p: list = generate_alternating_prime_sequence(sb=sb, N=N, how_many=how_many, **kw) + + # Check every prime in the sequence is unique. + unique_p: list = sorted(set(p)) + + err_msg = "There are repeating primes in the generate primes set!!!" + assert len(unique_p) == len(p), err_msg + + scale: int = 2 ** sb + e: list = [scale / pi for pi in p] + + y: list = cum_prod(e) + + plt.plot(y, label=label) + plt.grid() + plt.show() + # Error propagation. + q: np.array = np.array(y) - 1 + print(f"Error expanded {np.abs(q).max() / np.abs(q)[0]} times.") + + +def pgen_pseq(sb, N, how_many: int) -> list | str: + # TODO return 정리 + if how_many < 2: + return f"ERROR!!! sb = {sb}, N = {N}. Not enough primes." + + try: + res: list = generate_alternating_prime_sequence( + sb=sb, N=N, how_many=how_many + ) + except Exception as e: + # Try with the half how_many. + res: list = pgen_pseq(sb=sb, N=N, how_many=how_many // 2) + return res + + +def generate_scale_primes(cache_folder=CACHE_FOLDER, how_many=64, ncpu_cutdown=32, verbose=0): + savefile = Path(cache_folder) / 'scale_primes.pkl' + if savefile.exists(): + with open(savefile, 'rb') as f: + result_dict = pickle.load(f) + return result_dict + + ncpu = multiprocessing.cpu_count() + + # Cut down the number of n-cpus. It tends to slow down after 32. + ncpu = ncpu_cutdown if ncpu > ncpu_cutdown else ncpu + + logN, N, M = generate_N_M(cache_folder=cache_folder) + + # Scale. + logS = list(range(20, 55, 5)) + + # Generate input packages. + inputs = [] + for log_n, n in zip(logN, N): + how_many = 64 if log_n < 16 else 128 + for sb in logS: + inputs.append((sb, n, how_many)) + + result = Parallel(n_jobs=ncpu, + verbose=verbose)(delayed(pgen_pseq)(*inp) for inp in inputs) + + result_dict = {(sb, N): pr for (sb, N, how_many), pr in zip(inputs, result)} + + with savefile.open('wb') as f: + pickle.dump(result_dict, f) + + return result_dict + + +def measure_scale_primes_quality(sb: int = 40, N: int = 2 ** 15): + scale_primes: dict = generate_scale_primes() + + p: list = scale_primes[(sb, N)] + + # Check every prime in the sequence is unique. + unique_p = sorted(set(p)) + assert len(unique_p) == len( + p + ), "There are repeating primes in the generate primes set!!!" 
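+
+    # The quality metric below is the running product of the per-prime
+    # deviations scale / p_i; for a well balanced chain it stays close to 1.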
+
+    scale = 2 ** sb
+    e = [scale / pi for pi in p]
+
+    y = cum_prod(e)
+
+    plt.plot(y, label=f"Scale bits={sb}, logN={math.log2(N)}")
+
+    # How many primes?
+    print(f"I have {len(p)} primes in the set.")
+
+    # Error propagation.
+    q = np.array(y) - 1
+    print(f"Max. relative error is {np.abs(q).max():.3e}.")
+    print(f"Min. relative error is {np.abs(q).min():.3e}.")
+    print(f"Error expanded {np.abs(q).max() / np.abs(q)[0]:.3f} times.")
diff --git a/liberate/fhe/context/prim_test.py b/liberate/fhe/context/prim_test.py
new file mode 100644
index 0000000..ead0a5b
--- /dev/null
+++ b/liberate/fhe/context/prim_test.py
@@ -0,0 +1,64 @@
+import random
+
+
+def MillerRabinPrimalityTest(number, rounds=10):
+    # If the input is an even number, return immediately with False.
+    if number == 2:
+        return True
+    elif number == 1 or number % 2 == 0:
+        return False
+
+    # First we want to express n as 2^s * r (where r is odd).
+
+    # The odd part of the number.
+    oddPartOfNumber = number - 1
+
+    # The number of times the number is divided by two.
+    timesTwoDividNumber = 0
+
+    # While r is even, divide by 2 (integer division) to find the odd part.
+    while oddPartOfNumber % 2 == 0:
+        oddPartOfNumber = oddPartOfNumber // 2
+        timesTwoDividNumber = timesTwoDividNumber + 1
+
+    # Make oddPartOfNumber an integer.
+    oddPartOfNumber = int(oddPartOfNumber)
+
+    # Since some bases are "strong liars", we need to check more than one base.
+    for time in range(rounds):
+        # Choose a good random base.
+        while True:
+            # Draw a random number in the range of number (Z_number).
+            randomNumber = random.randint(2, number) - 1
+            if randomNumber != 0 and randomNumber != 1:
+                break
+
+        # randomNumberWithPower = randomNumber^oddPartOfNumber mod number.
+        randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
+
+        # If the result is not 1 and not -1 (mod number), keep squaring.
+        if (randomNumberWithPower != 1) and (
+            randomNumberWithPower != number - 1
+        ):
+            # Number of iterations.
+            iterationNumber = 1
+
+            # While we can square the number and the squared number is not -1 mod number.
+            while (iterationNumber <= timesTwoDividNumber - 1) and (
+                randomNumberWithPower != number - 1
+            ):
+                # Square the number.
+                randomNumberWithPower = pow(randomNumberWithPower, 2, number)
+
+                # Increment the iteration count.
+                iterationNumber = iterationNumber + 1
+
+            # If x != -1 mod number, then we did not find a strong witness;
+            # hence 1 has more than two square roots mod n ==>
+            # n is composite ==> return False for primality.
+
+            if randomNumberWithPower != (number - 1):
+                return False
+
+    # The number passed the tests ==> it is probably prime ==> return True for primality.
+    return True
diff --git a/liberate/fhe/context/security_parameters.py b/liberate/fhe/context/security_parameters.py
new file mode 100644
index 0000000..6057eaa
--- /dev/null
+++ b/liberate/fhe/context/security_parameters.py
@@ -0,0 +1,201 @@
+from scipy.interpolate import InterpolatedUnivariateSpline
+
+# These are bit security levels, i.e., the 1^lambda measure.
+security_levels = [128, 192, 256]
+
+# This is the dimension of the cyclotomic moduli, in
+# ℤ[X]/Φ𝑚(𝑋), with m = 2^l, where n is the leading (largest) power of X in the polynomial Φ𝑚(𝑋),
+# such that Φ𝑚(𝑋) = X^n + 1, where n = m / 2.
+cyclotomic_n = [1024, 2048, 4096, 8192, 16384, 32768]
+
+# The following q is the modulus of a ring ℤq. Note that the numbers are given as log(q) values,
+# where log in this context means log base 2.
+#
+# There are 2 sections in the standard documentation, namely pre- and post-quantum security.
+# We separate them in respective dictionaries. +# +# Also, there are 3 different methods of sampling the messages according to respective distributions. +# Those are uniform, error, and (-1, 1) (tenary). +# We differentiate the message distribution by dictionary keys: 'uniform', 'error', and 'tenary'. + +# This is the pre-quantum security requirements. +logq_preq = {} +logq_preq["uniform"] = [ + 29, + 21, + 16, + 56, + 39, + 31, + 111, + 77, + 60, + 220, + 154, + 120, + 440, + 307, + 239, + 880, + 612, + 478, +] +logq_preq["error"] = [ + 29, + 21, + 16, + 56, + 39, + 31, + 111, + 77, + 60, + 220, + 154, + 120, + 440, + 307, + 239, + 883, + 613, + 478, +] +logq_preq["tenary"] = [ + 27, + 19, + 14, + 54, + 37, + 29, + 109, + 75, + 58, + 218, + 152, + 118, + 438, + 305, + 237, + 881, + 611, + 476, +] + +# This is the post-quantum security requirements. +logq_postq = {} +logq_postq["uniform"] = [ + 27, + 19, + 15, + 53, + 37, + 29, + 103, + 72, + 56, + 206, + 143, + 111, + 413, + 286, + 222, + 829, + 573, + 445, +] +logq_postq["error"] = [ + 27, + 19, + 15, + 53, + 37, + 29, + 103, + 72, + 56, + 206, + 143, + 111, + 413, + 286, + 222, + 829, + 573, + 445, +] +logq_postq["tenary"] = [ + 25, + 17, + 13, + 51, + 35, + 27, + 101, + 70, + 54, + 202, + 141, + 109, + 411, + 284, + 220, + 827, + 571, + 443, +] + + +# Partition q's by security levels. +def partitq(q): + qlen = len(q) + levlen = len(security_levels) + grouped = [ + [q[i] for i in range(0 + lev, qlen, levlen)] for lev in range(levlen) + ] + by_sec_lev = {lev: grouped[l] for l, lev in enumerate(security_levels)} + return by_sec_lev + + +# Gather up. +logq = {} +distributions = ["uniform", "error", "tenary"] +logq["pre_quantum"] = { + distributions[disti]: partitq(logq_preq[dist]) + for disti, dist in enumerate(distributions) +} +logq["post_quantum"] = { + distributions[disti]: partitq(logq_postq[dist]) + for disti, dist in enumerate(distributions) +} + + +def minimum_cyclotomic_order( + q_bits, security_bits=128, quantum="post_quantum", distribution="uniform" +): + assert quantum in [ + "pre_quantum", + "post_quantum", + ], "Wrong quantum security model!!!" + assert distribution in ["uniform", "error", "tenary"] + assert security_bits in [128, 192, 256] + + x = logq[quantum][distribution][security_bits] + y = cyclotomic_n + s = InterpolatedUnivariateSpline(x, y, k=1) + return s(q_bits) + + +def maximum_qbits( + L, security_bits=128, quantum="post_quantum", distribution="uniform" +): + assert quantum in [ + "pre_quantum", + "post_quantum", + ], "Wrong quantum security model!!!" + assert distribution in ["uniform", "error", "tenary"] + assert security_bits in [128, 192, 256] + + x = cyclotomic_n + y = logq[quantum][distribution][security_bits] + s = InterpolatedUnivariateSpline(x, y, k=1) + return s(L) diff --git a/liberate/fhe/data_struct.py b/liberate/fhe/data_struct.py new file mode 100644 index 0000000..3d82e12 --- /dev/null +++ b/liberate/fhe/data_struct.py @@ -0,0 +1,24 @@ +from typing import NamedTuple +from .version import VERSION + + +class data_struct(NamedTuple): + """ + - Data structure. + - data: the data in tensor format + - include_special: Boolean, including the special prime channels or not. + - ntt_state: Boolean, whether if the data is ntt transformed or not. + - montgomery_state: Boolean, whether if the data is in the Montgomery form or not. + - origin: String, where did this data came from - cipher text, secret key, etc. + - level: Integer, the current level where this data is situated. 
+ - hash: String, a SHA256 hash of the input settings and the prime numbers used to RNS decompose numbers. + - version: String, version number. + """ + data: tuple | list + include_special: bool + ntt_state: bool + montgomery_state: bool + origin: str + level: int + hash: str + version: str = VERSION diff --git a/liberate/fhe/encdec/__init__.py b/liberate/fhe/encdec/__init__.py new file mode 100644 index 0000000..803a2da --- /dev/null +++ b/liberate/fhe/encdec/__init__.py @@ -0,0 +1 @@ +from .encdec import decode, encode, rotate, conjugate diff --git a/liberate/fhe/encdec/encdec.py b/liberate/fhe/encdec/encdec.py new file mode 100644 index 0000000..f0a3e26 --- /dev/null +++ b/liberate/fhe/encdec/encdec.py @@ -0,0 +1,323 @@ +import numpy as np +import torch + + +# --------------------------------------------------------------- +# Permutation. +# --------------------------------------------------------------- + +def circular_shift_permutation(N, shift=1): + left = np.roll(np.arange(N // 2), shift) + right = np.roll(np.arange(N // 2), -shift) + N // 2 + return np.concatenate([left, right]) + + +def canon_permutation(N, k=1, verbose=False): + """ + Permutes the coefficients of the lattice basis that yields correctly the permutation + of the decoded message. + + The canonical permutation is defined as mu_p(n) = pn mod M where p is coprime with M, + where p=2*k+1. + """ + M = 2 * N + p = int(2 * k + 1) # Make sure p is an integer. + n = np.arange(M) # n starts from 0. + pn = p * n % M + if verbose: + print(f"Canonical permutation for p={p} is\n{pn}") + return pn + + +def canon_permutation_torch(N, k=1, device='cuda:0', verbose=False): + """ + Permutes the coefficients of the lattice basis that yields correctly the permutation + of the decoded message. + + The canonical permutation is defined as mu_p(n) = pn mod M where p is coprime with M, + where p=2*k+1. + """ + M = N * 2 + p = int(2 * k + 1) # Make sure p is an integer. + n = torch.arange(N, device=device) # n starts from 0. + pn = p * n % M + if verbose: + print(f"Canonical permutation for p={p} is\n{pn}") + return pn + + +def fold_permutation(N, p, verbose=False): + """ + In application to crypto, we fold the FFT at Nyquist. + + Inverse FFT results in selection of alternating elements. + Folding should correct the indices of the permutation according to the + folding rule. + + For example, 1->0, 3->1, 5->2, and so on. + """ + fold_p = (p[1::2] - 1) // 2 + if verbose: + print(f"Folding\n{p}\nresulted in\n{fold_p}.") + return fold_p + + +def conjugate_permutation(p, q): + """ + Conjugate permutations p and q by stacking p on top of q. + + Permutations p and q must share the same cycle structures. + """ + # Calculate cycles. + pc = permutation_cycles(p) + qc = permutation_cycles(q) + + # Check if the cycle structures match. + cs1 = [len(c) for c in pc] + cs2 = [len(c) for c in qc] + assert ( + cs1 == cs2 + ), "Cycle structures of permutations must match for a conjugate to exist!!!" + + # Expand cycles. + pe = np.array([i for c in pc for i in c]) + qe = np.array([i for c in qc for i in c]) + + # Move slots. + r = np.zeros_like(p) + r[qe] = pe + + # Return. + return r + + +def permutation_cycles(perm): + """ + Transform a plain permutation into a composition of cycles. 
+ """ + pi = {i: perm[i] for i in range(len(perm))} + cycles = [] + while pi: + elem0 = next(iter(pi)) # arbitrary starting element + this_elem = pi[elem0] + next_item = pi[this_elem] + + cycle = [] + while True: + cycle.append(this_elem) + del pi[this_elem] + this_elem = next_item + if next_item in pi: + next_item = pi[next_item] + else: + break + cycles.append(cycle) + return cycles + + +def inverse_permutation(p, verbose=False): + """ + Calculates the inverse permutation. + """ + N = len(p) + ind = np.arange(N) + ip = ind[np.argsort(p)] + if verbose: + print(f"The inverse of permutation\n{p}\nis\n{ip}.") + return ip + + +# --------------------------------------------------------------- +# Negacyclic fft. +# --------------------------------------------------------------- + + +def expand2conjugate(m): + return torch.concat([m, torch.flipud(torch.conj(m))]) + + +def generate_twister(N, device='cuda:0'): + expr = -1j * torch.pi * torch.arange(N, device=device, dtype=torch.float64) / N + return torch.exp(expr) + + +def generate_skewer(N, device='cuda:0'): + expr = 1j * torch.pi * torch.arange(N, device=device, dtype=torch.float64) / N + skew = torch.exp(expr) + return skew + + +def m2poly(m, twister, norm='backward'): + """ + m is the message and this function turns the message into + polynomial coefficients. + The message must be expanded mirrored in conjugacy. + """ + + # Run fft and multiply twister. + ffted = torch.fft.fft(m, norm=norm) + + # Twist. + twisted = ffted * twister + + # Return the real part. + return twisted.real + + +def poly2m(poly, skewer, norm='backward'): + """ + poly is the polynomial coefficients and this function turns the coefficients + into a plain message. + """ + + # Multiply skewer. + t = poly * skewer + + # Recover. + recovered = torch.fft.ifft(t, norm=norm) + + # Return the real part. + return recovered + + +# --------------------------------------------------------------- +# Utilities. +# --------------------------------------------------------------- + +perm_cache = {} +twister_cache = {} +skewer_cache = {} + + +def prepost_perms(N, device='cuda:0'): + circ_shift = circular_shift_permutation(N) + canon_perm = canon_permutation(N) + fold_perm = fold_permutation(N, canon_perm) + post_perm = conjugate_permutation(circ_shift, fold_perm) + pre_perm = inverse_permutation(post_perm)[:N // 2] + + post_perm = torch.from_numpy(post_perm).to(device) + pre_perm = torch.from_numpy(pre_perm).to(device) + return pre_perm, post_perm + + +def pre_permute(m, pre_perm): + """ + Input m must be a torch tensor. + """ + N = m.size(-1) + permed_m = torch.zeros((N * 2,), dtype=m.dtype, device=m.device) + permed_m[pre_perm] = m + conj_permed_m = permed_m + permed_m.conj().flip(0) + return conj_permed_m + + +def post_permute(m, post_perm): + """ + Input m must be a torch tensor. + """ + permed_m = torch.zeros_like(m) + permed_m[post_perm] = m + return permed_m + + +def rotate(m, delta): + N = m.size(-1) + C = m.numel() // N + + shift = delta % N + leap = (3 ** shift - 1) // 2 % (N * 2) + + if (N, leap, m.device) in perm_cache.keys(): + perm = perm_cache[(N, leap, m.device)] + else: + perm = canon_permutation_torch(N, leap, device=m.device) + perm_cache[(N, leap, m.device)] = perm + + perm_folded = perm % N + perm_sign = (-1) ** (perm // N) + + # Permute! + # We want to both be capable of 2D and 1D tensors. + # Use view. 
+ rot_m = torch.zeros_like(m) + rot_m.view(C, N).T[perm_folded] = (perm_sign * m).view(C, N).T + + return rot_m + + +def conjugate(m): + N = m.size(-1) + C = m.numel() // N + + leap = N - 1 + + if (N, leap, m.device) in perm_cache.keys(): + perm = perm_cache[(N, leap, m.device)] + else: + perm = canon_permutation_torch(N, leap, device=m.device) + perm_cache[(N, leap, m.device)] = perm + + perm_folded = perm % N + perm_sign = (-1) ** (perm // N) + + # Permute! + # We want to both be capable of 2D and 1D tensors. + # Use view. + rot_m = torch.zeros_like(m) + rot_m.view(C, N).T[perm_folded] = (perm_sign * m).view(C, N).T + + return rot_m + + +def encode(m, rng=None, scale=2 ** 40, deviation=1.0, + device='cuda:0', norm='forward', + return_without_scaling=False): + N = len(m) * 2 + if (N, device) in perm_cache.keys(): + pre_perm, post_perm = perm_cache[(N, device)] + else: + pre_perm, post_perm = prepost_perms(N, device=device) + perm_cache[(N, device)] = (pre_perm, post_perm) + + mm = torch.from_numpy(np.array(m * deviation)).to(device) # check dtype m * deviation + mm = pre_permute(mm, pre_perm) + + if (N, device) in twister_cache.keys(): + twister = twister_cache[N, device] + else: + twister = generate_twister(N, device) + twister_cache[N, device] = twister + + if return_without_scaling: + return m2poly(mm, twister, norm) + else: + mm = m2poly(mm, twister, norm) * np.float64(scale) + return rng.randround(mm) + + +def decode(m, scale=2 ** 40, + correction=1.0, norm='forward', + return_without_scaling=False): + N = len(m) + device = m.device.type + ':' + str(m.device.index) + if (N, device) in perm_cache.keys(): + pre_perm, post_perm = perm_cache[(N, device)] + else: + pre_perm, post_perm = prepost_perms(N, device=device) + perm_cache[(N, device)] = (pre_perm, post_perm) + + if (N, device) in skewer_cache.keys(): + skewer = skewer_cache[N, device] + else: + skewer = generate_skewer(N, device) + skewer_cache[N, device] = skewer + + if return_without_scaling: + mm = poly2m(m, skewer, norm=norm) + mm = post_permute(mm, post_perm) + return mm + else: + mm = poly2m(m, skewer, norm=norm) / scale * correction + mm = post_permute(mm, post_perm) + return mm diff --git a/liberate/fhe/presets/__init__.py b/liberate/fhe/presets/__init__.py new file mode 100644 index 0000000..dc776af --- /dev/null +++ b/liberate/fhe/presets/__init__.py @@ -0,0 +1,3 @@ +from .params import params +from . import types +from . 
import errors diff --git a/liberate/fhe/presets/errors.py b/liberate/fhe/presets/errors.py new file mode 100644 index 0000000..d914122 --- /dev/null +++ b/liberate/fhe/presets/errors.py @@ -0,0 +1,166 @@ +from functools import wraps +import logging + + +def log_error(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logging.error(f"[Error] Error in {func.__name__} : {e}") + raise + + return wrapper + + +class TestException(Exception): + def __init__(self): + message_error = "test error" + super().__init__(message_error) + + +class NotFoundMessageSpecialPrimes(Exception): + def __init__(self, message_bit, N): + self.message_error = f"""Can't find message_bit = {message_bit:3 +#include + +//------------------------------------------------------------------ +// Main functions +//------------------------------------------------------------------ + +torch::Tensor mont_mult_cuda(const torch::Tensor a, + const torch::Tensor b, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void mont_enter_cuda(torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void ntt_cuda(torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void enter_ntt_cuda(torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void intt_cuda(torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + + +void mont_redc_cuda(torch::Tensor a, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + + +void intt_exit_cuda(torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void intt_exit_reduce_cuda(torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void intt_exit_reduce_signed_cuda(torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh); + +void reduce_2q_cuda(torch::Tensor a, + const torch::Tensor _2q); + +void make_signed_cuda(torch::Tensor a, + const torch::Tensor _2q); + +void make_unsigned_cuda(torch::Tensor a, + const torch::Tensor _2q); + +torch::Tensor mont_add_cuda(const torch::Tensor a, + const torch::Tensor b, + const torch::Tensor _2q); + +torch::Tensor mont_sub_cuda(const torch::Tensor a, + const torch::Tensor b, + const torch::Tensor _2q); + +torch::Tensor 
tile_unsigned_cuda(torch::Tensor a, + const torch::Tensor _2q); + + +//------------------------------------------------------------------ +// Main functions +//------------------------------------------------------------------ + +std::vector mont_mult( + const std::vector a, + const std::vector b, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + std::vector outputs; + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector Rs, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector Rs, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector Ninv, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector Ninv, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector Ninv, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector even, + const std::vector odd, + const std::vector psi, + const std::vector Ninv, + const std::vector _2q, + const std::vector ql, + const std::vector qh, + const std::vector kl, + const std::vector kh) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector _2q) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector _2q) { + + const auto num_devices = a.size(); + for (int i=0; i a, + const std::vector _2q) { + + const auto num_devices = a.size(); + for (int i=0; i mont_add( + const std::vector a, + const std::vector b, + const std::vector _2q) { + + std::vector outputs; + + const auto num_devices = a.size(); + for (int i=0; i mont_sub( + const std::vector a, + const std::vector b, + const std::vector _2q) { + + std::vector outputs; + + const auto num_devices = a.size(); + for (int i=0; i tile_unsigned( + std::vector a, + const std::vector _2q) { + + std::vector outputs; + + const auto num_devices = _2q.size(); + for (int i=0; i FORWARD NTT"); + m.def("intt", &intt, "INVERSE NTT"); + m.def("mont_redc", &mont_redc, "MONTGOMERY REDUCTION"); + m.def("intt_exit", &intt_exit, "INVERSE NTT -> EXIT"); + m.def("intt_exit_reduce", &intt_exit_reduce, "INVERSE NTT -> EXIT -> REDUCE"); + m.def("intt_exit_reduce_signed", &intt_exit_reduce_signed, "INVERSE 
NTT -> EXIT -> REDUCE -> MAKE SIGNED"); + m.def("reduce_2q", &reduce_2q, "REDUCE RANGE TO 2q"); + m.def("make_signed", &make_signed, "MAKE SIGNED"); + m.def("make_unsigned", &make_unsigned, "MAKE UNSIGNED"); + m.def("mont_add", &mont_add, "MONTGOMERY ADDITION"); + m.def("mont_sub", &mont_sub, "MONTGOMERY SUBTRACTION"); + m.def("tile_unsigned", &tile_unsigned, "TILE -> MAKE UNSIGNED"); +} diff --git a/liberate/ntt/ntt_context.py b/liberate/ntt/ntt_context.py new file mode 100644 index 0000000..7dc774b --- /dev/null +++ b/liberate/ntt/ntt_context.py @@ -0,0 +1,523 @@ +import datetime +import time + +import numpy as np +import torch + +from liberate.fhe.presets import errors +from . import ntt_cuda +from .rns_partition import rns_partition + + +@errors.log_error +class ntt_context: + def __init__(self, ctx, index_type=torch.int32, devices=None, verbose=False): + + # Mark the start time. + t0 = time.time() + + # Set devices first. + if devices is None: + gpu_count = torch.cuda.device_count() + self.devices = [f'cuda:{i}' for i in range(gpu_count)] + else: + self.devices = devices + + self.num_devices = len(self.devices) + + # Transfer input parameters. + self.index_type = index_type + self.verbose = verbose + self.ctx = ctx + + if self.verbose: + print(f"[{str(datetime.datetime.now())}] I have received the context:\n") + self.ctx.init_print() + print(f"[{str(datetime.datetime.now())}] Requested devices for computation are {self.devices}.") + + self.num_ordinary_primes = self.ctx.num_scales + 1 + self.num_special_primes = self.ctx.num_special_primes + self.num_levels = self.ctx.num_scales + 1 + + self.p = rns_partition(self.num_ordinary_primes, + self.num_special_primes, self.num_devices) + if verbose: + print(f"[{str(datetime.datetime.now())}] I have generated a partitioning scheme.") + print(f"[{str(datetime.datetime.now())}] I have in total {self.num_levels} levels available.") + print(f"[{str(datetime.datetime.now())}] I have {self.num_ordinary_primes} ordinary primes.") + print(f"[{str(datetime.datetime.now())}] I have {self.num_special_primes} special primes.") + + self.prepare_parameters() + + if verbose: + print(f"[{str(datetime.datetime.now())}] I prepared ntt parameters.") + + t1 = time.time() + if verbose: + print(f'[{str(datetime.datetime.now())}] ntt initialization took {(t1 - t0):.2f} seconds.') + + self.qlists = [qi.tolist() for qi in self.q] + + astop_special = [len(d) for d in self.p.destination_arrays_with_special[0]] + astop_ordinary = [len(d) for d in self.p.destination_arrays[0]] + self.starts = self.p.diff + + self.stops = [astop_special, astop_ordinary] + + self.generate_parts_pack() + self.pre_package() + + # ------------------------------------------------------------------------------------------------- + # Arrange according to partitioning scheme input variables, and copy to GPUs for fast access. 
+ # ------------------------------------------------------------------------------------------------- + + def partition_variable(self, variable): + + np_v = np.array(variable, dtype=self.ctx.numpy_dtype) + + v_special = [] + dest = self.p.d_special + for dev_id in range(self.num_devices): + d = dest[dev_id] + parted_v = np_v[d] + v_special.append( + torch.from_numpy(parted_v).to( + self.devices[dev_id])) + + return v_special + + def copy_to_devices(self, variable): + return [torch.tensor( + variable, dtype=self.index_type, device=device) for device in self.devices] + + def psi_enter(self): + Rs = self.Rs + ql = self.ql + qh = self.qh + kl = self.kl + kh = self.kh + + p = self.psi + + a = [psi.view(psi.size(0), -1) for psi in p] + + ntt_cuda.mont_enter(a, Rs, ql, qh, kl, kh) + + p = self.ipsi + a = [psi.view(psi.size(0), -1) for psi in p] + ntt_cuda.mont_enter(a, Rs, ql, qh, kl, kh) + + def Ninv_enter(self): + self.Ninv = [ + (self.ctx.N_inv[i] * self.ctx.R) % self.ctx.q[i] + for i in range(len(self.ctx.q))] + + def prepare_parameters(self): + + scale = 2 ** self.ctx.scale_bits + self.Rs_scale = self.partition_variable([ + (Rs * scale) % q for Rs, q in + zip(self.ctx.R_square, self.ctx.q) + ]) + + self.Rs = self.partition_variable(self.ctx.R_square) + + self.q = self.partition_variable(self.ctx.q) + self._2q = self.partition_variable(self.ctx.q_double) + self.ql = self.partition_variable(self.ctx.q_lower_bits) + self.qh = self.partition_variable(self.ctx.q_higher_bits) + self.kl = self.partition_variable(self.ctx.k_lower_bits) + self.kh = self.partition_variable(self.ctx.k_higher_bits) + + self.even = self.copy_to_devices(self.ctx.forward_even_indices) + self.odd = self.copy_to_devices(self.ctx.forward_odd_indices) + self.ieven = self.copy_to_devices(self.ctx.backward_even_indices) + self.iodd = self.copy_to_devices(self.ctx.backward_odd_indices) + + self.psi = self.partition_variable(self.ctx.forward_psi) + self.ipsi = self.partition_variable(self.ctx.backward_psi_inv) + + self.Ninv_enter() + self.Ninv = self.partition_variable(self.Ninv) + + self.psi_enter() + + self.mont_pack0 = [self.ql, self.qh, self.kl, self.kh] + + self.ntt_pack0 = [self.even, self.odd, + self.psi, self._2q, self.ql, self.qh, + self.kl, self.kh] + + self.intt_pack0 = [self.ieven, self.iodd, + self.ipsi, self.Ninv, self._2q, + self.ql, self.qh, + self.kl, self.kh] + + def param_pack(self, param, astart, astop, remove_empty=True): + pack = [param[dev_id][astart[dev_id]:astop[dev_id]] + for dev_id in range(self.num_devices)] + + remove_empty_f = lambda x: [xi for xi in x if len(xi) > 0] + if remove_empty: + pack = remove_empty_f(pack) + return pack + + def mont_pack(self, astart, astop, remove_empty=True): + return [self.param_pack( + param, astart, astop, remove_empty) for param in self.mont_pack0] + + def ntt_pack(self, astart, astop, remove_empty=True): + remove_empty_f_x = lambda x: [xi for xi in x if len(xi) > 0] + + remove_empty_f_xy = lambda x, y: [xi for xi, yi in zip(x, y) if len(yi) > 0] + + even_odd = self.ntt_pack0[:2] + rest = [self.param_pack( + param, astart, astop, remove_empty=False) for param in self.ntt_pack0[2:]] + + if remove_empty: + even_odd = [remove_empty_f_xy(eo, rest[0]) for eo in even_odd] + rest = [remove_empty_f_x(r) for r in rest] + + return even_odd + rest + + def intt_pack(self, astart, astop, remove_empty=True): + remove_empty_f_x = lambda x: [xi for xi in x if len(xi) > 0] + + remove_empty_f_xy = lambda x, y: [xi for xi, yi in zip(x, y) if len(yi) > 0] + + even_odd = 
self.intt_pack0[:2] + rest = [self.param_pack( + param, astart, astop, remove_empty=False) for param in self.intt_pack0[2:]] + + if remove_empty: + even_odd = [remove_empty_f_xy(eo, rest[0]) for eo in even_odd] + rest = [remove_empty_f_x(r) for r in rest] + + return even_odd + rest + + def start_stop(self, lvl, mult_type): + return self.starts[lvl], self.stops[mult_type] + + # ------------------------------------------------------------------------------------------------- + # Package by parts. + # ------------------------------------------------------------------------------------------------- + + def params_pack_device(self, device_id, astart, astop): + starts = [0] * self.num_devices + stops = [0] * self.num_devices + + starts[device_id] = astart + stops[device_id] = astop + 1 + + stst = [starts, stops] + + item = {} + + item['mont_pack'] = self.mont_pack(*stst) + item['ntt_pack'] = self.ntt_pack(*stst) + item['intt_pack'] = self.intt_pack(*stst) + item['Rs'] = self.param_pack(self.Rs, *stst) + item['Rs_scale'] = self.param_pack(self.Rs_scale, *stst) + item['_2q'] = self.param_pack(self._2q, *stst) + item['qlist'] = self.param_pack(self.qlists, *stst) + + return item + + def generate_parts_pack(self): + + blank_L_enter = [None] * self.num_devices + + self.parts_pack = [] + + for device_id in range(self.num_devices): + self.parts_pack.append({}) + + for i in range(len( + self.p.destination_arrays_with_special[0][device_id])): + self.parts_pack[device_id][i,] = self.params_pack_device( + device_id, i, i) + + for level in range(self.num_levels): + + for mult_type in [-1, -2]: + starts, stops = self.start_stop(level, mult_type) + astart = starts[device_id] + + astop = stops[device_id] - 1 + + key = tuple(range(astart, astop + 1)) + + if len(key) > 0: + if key not in self.parts_pack[device_id]: + self.parts_pack[device_id][key] = self.params_pack_device( + device_id, astart, astop) + + for p in self.p.p_special[level][device_id]: + key = tuple(p) + if key not in self.parts_pack[device_id].keys(): + astart = p[0] + astop = p[-1] + self.parts_pack[device_id][key] = self.params_pack_device( + device_id, astart, astop) + + for device_id in range(self.num_devices): + for level in range(self.num_levels): + # We do basis extension for only ordinary parts. + for part_index, part in enumerate(self.p.destination_parts[level][device_id]): + + key = tuple(self.p.p[level][device_id][part_index]) + + # Check if Y and L are already calculated for this part. 
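+                    # Y_scalar stores L_i^{-1} mod m_{i+1} (in Montgomery form),
+                    # while L_scalar / L_enter store the partial products L_i
+                    # reduced modulo the later primes of this part and modulo
+                    # every destination prime; these drive the fast base extension.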
+ if 'Y_scalar' not in self.parts_pack[device_id][key].keys(): + + alpha = len(part) + m = [self.ctx.q[idx] for idx in part] + L = [m[0]] + + for i in range(1, alpha - 1): + L.append(L[-1] * m[i]) + + Y_scalar = [] + L_scalar = [] + for i in range(alpha - 1): + L_inv = pow(L[i], -1, m[i + 1]) + L_inv_R = (L_inv * self.ctx.R) % m[i + 1] + Y_scalar.append(L_inv_R) + + if (i + 2) < alpha: + L_scalar.append([]) + for j in range(i + 2, alpha): + L_scalar[i].append((L[i] * self.ctx.R) % m[j]) + + L_enter_devices = [] + for target_device_id in range(self.num_devices): + + dest = self.p.destination_arrays_with_special[0][target_device_id] + q = [self.ctx.q[idx] for idx in dest] + Rs = [self.ctx.R_square[idx] for idx in dest] + + L_enter = [] + for i in range(alpha - 1): + + L_enter.append([]) + for j in range(len(dest)): + L_Rs = (L[i] * Rs[j]) % q[j] + L_enter[i].append(L_Rs) + L_enter_devices.append(L_enter) + + device = self.devices[device_id] + + if len(Y_scalar) > 0: + Y_scalar = torch.tensor( + Y_scalar, dtype=self.ctx.torch_dtype, device=device) + self.parts_pack[device_id][key]['Y_scalar'] = Y_scalar + + for target_device_id in range(self.num_devices): + target_device = self.devices[target_device_id] + + L_enter_devices[target_device_id] = [ + torch.tensor( + Li, + dtype=self.ctx.torch_dtype, + device=target_device + ) for Li in L_enter_devices[target_device_id]] + + self.parts_pack[device_id][key]['L_enter'] = L_enter_devices + + else: + self.parts_pack[device_id][key]['Y_scalar'] = None + self.parts_pack[device_id][key]['L_enter'] = blank_L_enter + + if len(L_scalar) > 0: + L_scalar = [torch.tensor( + Li, dtype=self.ctx.torch_dtype, device=device) for Li in L_scalar] + self.parts_pack[device_id][key]['L_scalar'] = L_scalar + else: + self.parts_pack[device_id][key]['L_scalar'] = None + + # ------------------------------------------------------------------------------------------------- + # Pre-packaging. + # ------------------------------------------------------------------------------------------------- + def pre_package(self): + self.mont_prepack = [] + self.ntt_prepack = [] + self.intt_prepack = [] + self.Rs_prepack = [] + self.Rs_scale_prepack = [] + self._2q_prepack = [] + + # q_prepack is a list of lists, not tensors. + # We need this for generating uniform samples. 
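The Y_scalar, L_scalar and L_enter tables prepared in generate_parts_pack above are the constants of a mixed-radix (Garner-style) basis extension for each part: L[i] is the partial product m[0]···m[i] of the part's moduli, Y_scalar[i] = L[i]^-1 · R mod m[i+1] turns residues into mixed-radix digits, L_scalar[i][j] = L[i] · R mod m[j] feeds the later digits, and L_enter holds L[i] · R² mod q for every destination prime so the digit expansion can be re-evaluated in the extended basis. A minimal pure-Python sketch of the same extension, without the Montgomery R factors (function and variable names are illustrative, not library API):

    # Sketch only: Garner-style CRT basis extension using the L / Y constants.
    def basis_extend(residues, part_moduli, target_moduli):
        alpha = len(part_moduli)
        L = [part_moduli[0]]
        for i in range(1, alpha - 1):
            L.append(L[-1] * part_moduli[i])                 # L[i] = m0 * ... * mi
        Y = [pow(L[i], -1, part_moduli[i + 1]) for i in range(alpha - 1)]

        # Mixed-radix digits: x = d0 + d1*L[0] + d2*L[1] + ...
        digits = [residues[0] % part_moduli[0]]
        for i in range(alpha - 1):
            m = part_moduli[i + 1]
            partial = digits[0] % m
            for j in range(1, i + 1):
                partial = (partial + digits[j] * (L[j - 1] % m)) % m
            digits.append((residues[i + 1] - partial) * Y[i] % m)

        # Re-evaluate the expansion modulo each target prime.
        out = []
        for q in target_moduli:
            acc = digits[0] % q
            for j in range(1, alpha):
                acc = (acc + digits[j] * (L[j - 1] % q)) % q
            out.append(acc)
        return out

    x, part = 100, [5, 7, 11]
    assert basis_extend([x % m for m in part], part, [13, 17]) == [x % 13, x % 17]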
+ self.q_prepack = [] + + for device_id in range(self.num_devices): + mont_prepack = [] + ntt_prepack = [] + intt_prepack = [] + Rs_prepack = [] + Rs_scale_prepack = [] + _2q_prepack = [] + q_prepack = [] + for lvl in range(self.num_levels): + mont_prepack_part = [] + ntt_prepack_part = [] + intt_prepack_part = [] + Rs_prepack_part = [] + Rs_scale_prepack_part = [] + _2q_prepack_part = [] + q_prepack_part = [] + for part in self.p.p_special[lvl][device_id]: + key = tuple(part) + item = self.parts_pack[device_id][key] + + mont_prepack_part.append(item['mont_pack']) + ntt_prepack_part.append(item['ntt_pack']) + intt_prepack_part.append(item['intt_pack']) + Rs_prepack_part.append(item['Rs']) + Rs_scale_prepack_part.append(item['Rs_scale']) + _2q_prepack_part.append(item['_2q']) + q_prepack_part.append(item['qlist']) + + for mult_type in [-2, -1]: + starts, stops = self.start_stop(lvl, mult_type) + astart = starts[device_id] + + astop = stops[device_id] - 1 + + key = tuple(range(astart, astop + 1)) + + if len(key) > 0: + item = self.parts_pack[device_id][key] + + mont_prepack_part.append(item['mont_pack']) + ntt_prepack_part.append(item['ntt_pack']) + intt_prepack_part.append(item['intt_pack']) + Rs_prepack_part.append(item['Rs']) + Rs_scale_prepack_part.append(item['Rs_scale']) + _2q_prepack_part.append(item['_2q']) + q_prepack_part.append(item['qlist']) + + else: + mont_prepack_part.append(None) + ntt_prepack_part.append(None) + intt_prepack_part.append(None) + Rs_prepack_part.append(None) + Rs_scale_prepack_part.append(None) + _2q_prepack_part.append(None) + q_prepack_part.append(None) + + mont_prepack.append(mont_prepack_part) + ntt_prepack.append(ntt_prepack_part) + intt_prepack.append(intt_prepack_part) + Rs_prepack.append(Rs_prepack_part) + Rs_scale_prepack.append(Rs_scale_prepack_part) + _2q_prepack.append(_2q_prepack_part) + q_prepack.append(q_prepack_part) + + self.mont_prepack.append(mont_prepack) + self.ntt_prepack.append(ntt_prepack) + self.intt_prepack.append(intt_prepack) + self.Rs_prepack.append(Rs_prepack) + self.Rs_scale_prepack.append(Rs_scale_prepack) + self._2q_prepack.append(_2q_prepack) + self.q_prepack.append(q_prepack) + + for mult_type in [-2, -1]: + mont_prepack = [] + ntt_prepack = [] + intt_prepack = [] + Rs_prepack = [] + Rs_scale_prepack = [] + _2q_prepack = [] + q_prepack = [] + for lvl in range(self.num_levels): + stst = self.start_stop(lvl, mult_type) + mont_prepack.append([self.mont_pack(*stst)]) + ntt_prepack.append([self.ntt_pack(*stst)]) + intt_prepack.append([self.intt_pack(*stst)]) + Rs_prepack.append([self.param_pack(self.Rs, *stst)]) + Rs_scale_prepack.append([self.param_pack(self.Rs_scale, *stst)]) + _2q_prepack.append([self.param_pack(self._2q, *stst)]) + q_prepack.append([self.param_pack(self.qlists, *stst)]) + self.mont_prepack.append(mont_prepack) + self.ntt_prepack.append(ntt_prepack) + self.intt_prepack.append(intt_prepack) + self.Rs_prepack.append(Rs_prepack) + self.Rs_scale_prepack.append(Rs_scale_prepack) + self._2q_prepack.append(_2q_prepack) + self.q_prepack.append(q_prepack) + + # ------------------------------------------------------------------------------------------------- + # Helper functions to do the Montgomery and NTT operations. 
+ # ------------------------------------------------------------------------------------------------- + + def mont_enter(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.mont_enter(a, self.Rs_prepack[mult_type][lvl][part], + *self.mont_prepack[mult_type][lvl][part]) + + def mont_enter_scale(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.mont_enter( + a, self.Rs_scale_prepack[mult_type][lvl][part], + *self.mont_prepack[mult_type][lvl][part]) + + def mont_enter_scalar(self, a, b, lvl=0, mult_type=-1, part=0): + ntt_cuda.mont_enter( + a, b, *self.mont_prepack[mult_type][lvl][part]) + + def mont_mult(self, a, b, lvl=0, mult_type=-1, part=0): + return ntt_cuda.mont_mult( + a, b, *self.mont_prepack[mult_type][lvl][part]) + + def ntt(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.ntt( + a, *self.ntt_prepack[mult_type][lvl][part]) + + def enter_ntt(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.enter_ntt( + a, self.Rs_prepack[mult_type][lvl][part], + *self.ntt_prepack[mult_type][lvl][part]) + + def intt(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.intt( + a, *self.intt_prepack[mult_type][lvl][part]) + + def mont_redc(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.mont_redc( + a, *self.mont_prepack[mult_type][lvl][part]) + + def intt_exit(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.intt_exit( + a, *self.intt_prepack[mult_type][lvl][part]) + + def intt_exit_reduce(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.intt_exit_reduce( + a, *self.intt_prepack[mult_type][lvl][part]) + + def intt_exit_reduce_signed(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.intt_exit_reduce_signed( + a, *self.intt_prepack[mult_type][lvl][part]) + + def reduce_2q(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.reduce_2q( + a, self._2q_prepack[mult_type][lvl][part]) + + def make_signed(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.make_signed( + a, self._2q_prepack[mult_type][lvl][part]) + + def make_unsigned(self, a, lvl=0, mult_type=-1, part=0): + ntt_cuda.make_unsigned( + a, self._2q_prepack[mult_type][lvl][part]) + + def mont_add(self, a, b, lvl=0, mult_type=-1, part=0): + return ntt_cuda.mont_add( + a, b, self._2q_prepack[mult_type][lvl][part]) + + def mont_sub(self, a, b, lvl=0, mult_type=-1, part=0): + return ntt_cuda.mont_sub( + a, b, self._2q_prepack[mult_type][lvl][part]) + + def tile_unsigned(self, a, lvl=0, mult_type=-1, part=0): + return ntt_cuda.tile_unsigned( + a, self._2q_prepack[mult_type][lvl][part]) diff --git a/liberate/ntt/ntt_cuda_kernel.cu b/liberate/ntt/ntt_cuda_kernel.cu new file mode 100644 index 0000000..fc96eea --- /dev/null +++ b/liberate/ntt/ntt_cuda_kernel.cu @@ -0,0 +1,1230 @@ +#include +#include +#include +#include + +#define BLOCK_SIZE 256 + +//------------------------------------------------------------------ +// pointwise mont_mult +//------------------------------------------------------------------ + +template __device__ __forceinline__ scalar_t +mont_mult_scalar_cuda_kernel( + const scalar_t a, const scalar_t b, + const scalar_t ql, const scalar_t qh, + const scalar_t kl, const scalar_t kh) { + + // Masks. 
+    constexpr scalar_t one = 1;
+    constexpr scalar_t nbits = sizeof(scalar_t) * 8 - 2;
+    constexpr scalar_t half_nbits = sizeof(scalar_t) * 4 - 1;
+    constexpr scalar_t fb_mask = ((one << nbits) - one);
+    constexpr scalar_t lb_mask = (one << half_nbits) - one;
+
+    const scalar_t al = a & lb_mask;
+    const scalar_t ah = a >> half_nbits;
+    const scalar_t bl = b & lb_mask;
+    const scalar_t bh = b >> half_nbits;
+
+    const scalar_t alpha = ah * bh;
+    const scalar_t beta = ah * bl + al * bh;
+    const scalar_t gamma = al * bl;
+
+    // s = xk mod R
+    const scalar_t gammal = gamma & lb_mask;
+    const scalar_t gammah = gamma >> half_nbits;
+    const scalar_t betal = beta & lb_mask;
+    const scalar_t betah = beta >> half_nbits;
+
+    scalar_t upper = gammal * kh;
+    upper = upper + (gammah + betal) * kl;
+    upper = upper << half_nbits;
+    scalar_t s = upper + gammal * kl;
+    s = s & fb_mask;
+
+    // t = x + sq
+    // u = t/R
+    const scalar_t sl = s & lb_mask;
+    const scalar_t sh = s >> half_nbits;
+    const scalar_t sqb = sh * ql + sl * qh;
+    const scalar_t sqbl = sqb & lb_mask;
+    const scalar_t sqbh = sqb >> half_nbits;
+
+    scalar_t carry = (gamma + sl * ql) >> half_nbits;
+    carry = (carry + betal + sqbl) >> half_nbits;
+
+    return alpha + betah + sqbh + carry + sh * qh;
+}
+
+
+//------------------------------------------------------------------
+// mont_mult
+//------------------------------------------------------------------
+
+template <typename scalar_t>
+__global__ void mont_mult_cuda_kernel(
+    const torch::PackedTensorAccessor32<scalar_t, 2> a_acc,
+    const torch::PackedTensorAccessor32<scalar_t, 2> b_acc,
+    torch::PackedTensorAccessor32<scalar_t, 2> c_acc,
+    const torch::PackedTensorAccessor32<scalar_t, 1> ql_acc,
+    const torch::PackedTensorAccessor32<scalar_t, 1> qh_acc,
+    const torch::PackedTensorAccessor32<scalar_t, 1> kl_acc,
+    const torch::PackedTensorAccessor32<scalar_t, 1> kh_acc){
+
+    // Where am I?
+    const int i = blockIdx.x;
+    const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x;
+
+    // Inputs.
+    const scalar_t a = a_acc[i][j];
+    const scalar_t b = b_acc[i][j];
+    const scalar_t ql = ql_acc[i];
+    const scalar_t qh = qh_acc[i];
+    const scalar_t kl = kl_acc[i];
+    const scalar_t kh = kh_acc[i];
+
+    // Store the result.
+    c_acc[i][j] = mont_mult_scalar_cuda_kernel(a, b, ql, qh, kl, kh);
+}
+
+template <typename scalar_t>
+void mont_mult_cuda_typed(
+    const torch::Tensor a,
+    const torch::Tensor b,
+    torch::Tensor c,
+    const torch::Tensor ql,
+    const torch::Tensor qh,
+    const torch::Tensor kl,
+    const torch::Tensor kh) {
+
+    // Retrieve the device index, then set the corresponding device and stream.
+    auto device_id = a.device().index();
+    cudaSetDevice(device_id);
+
+    // Use a preallocated pytorch stream.
+    auto stream = at::cuda::getCurrentCUDAStream(device_id);
+
+    // The problem dimension.
+    auto C = a.size(0);
+    auto N = a.size(1);
+
+    int dim_block = BLOCK_SIZE;
+    dim3 dim_grid (C, N / BLOCK_SIZE);
+
+    // Run the cuda kernel.
+    const auto a_acc = a.packed_accessor32<scalar_t, 2>();
+    const auto b_acc = b.packed_accessor32<scalar_t, 2>();
+    auto c_acc = c.packed_accessor32<scalar_t, 2>();
+    const auto ql_acc = ql.packed_accessor32<scalar_t, 1>();
+    const auto qh_acc = qh.packed_accessor32<scalar_t, 1>();
+    const auto kl_acc = kl.packed_accessor32<scalar_t, 1>();
+    const auto kh_acc = kh.packed_accessor32<scalar_t, 1>();
+    mont_mult_cuda_kernel<<<dim_grid, dim_block, 0, stream>>>(
+        a_acc, b_acc, c_acc, ql_acc, qh_acc, kl_acc, kh_acc);
+}
+
+
+torch::Tensor mont_mult_cuda(
+    const torch::Tensor a,
+    const torch::Tensor b,
+    const torch::Tensor ql,
+    const torch::Tensor qh,
+    const torch::Tensor kl,
+    const torch::Tensor kh) {
+
+    // Prepare the output.
+    torch::Tensor c = torch::empty_like(a);
+
+    // Dispatch to the correct data type.
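For reference, the limb arithmetic in mont_mult_scalar_cuda_kernel is word-level Montgomery multiplication with R = 2^62 and 31-bit half-words: s = a·b·k mod R, where k ≡ −q^-1 (mod R) is what the kl/kh halves appear to encode (it is exactly the value that makes a·b + s·q vanish modulo R), and the returned value is (a·b + s·q)/R ≡ a·b·R^-1 (mod q), which stays below 2q as long as 4q < R. A pure-Python model of the same computation that can serve as a sanity check (names are illustrative):

    # Sketch only: mirrors the 31-bit limb arithmetic of the kernel with Python ints.
    def mont_mult_model(a, b, q, nbits=62):
        half = nbits // 2                        # 31-bit half-words
        R = 1 << nbits
        fb_mask, lb_mask = R - 1, (1 << half) - 1
        k = (-pow(q, -1, R)) % R                 # assumed meaning of kl/kh: -q^-1 mod R
        ql, qh, kl, kh = q & lb_mask, q >> half, k & lb_mask, k >> half

        al, ah, bl, bh = a & lb_mask, a >> half, b & lb_mask, b >> half
        alpha, beta, gamma = ah * bh, ah * bl + al * bh, al * bl

        # s = (a * b) * k mod R, assembled from half-word products.
        gammal, gammah = gamma & lb_mask, gamma >> half
        betal, betah = beta & lb_mask, beta >> half
        upper = (gammal * kh + (gammah + betal) * kl) << half
        s = (upper + gammal * kl) & fb_mask

        # u = (a * b + s * q) / R, again from half-words.
        sl, sh = s & lb_mask, s >> half
        sqb = sh * ql + sl * qh
        sqbl, sqbh = sqb & lb_mask, sqb >> half
        carry = (gamma + sl * ql) >> half
        carry = (carry + betal + sqbl) >> half
        return alpha + betah + sqbh + carry + sh * qh

    q = (1 << 58) + 27                           # any odd q with 4q < 2^62 will do
    a, b = 123456789123456789 % q, 987654321987654321 % q
    u = mont_mult_model(a, b, q)
    assert u < 2 * q and u % q == (a * b * pow(1 << 62, -1, q)) % q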
+ AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_mont_mult_cuda", ([&] { + mont_mult_cuda_typed(a, b, c, ql, qh, kl, kh); + })); + + return c; +} + + + +//------------------------------------------------------------------ +// mont_enter +//------------------------------------------------------------------ + +template +__global__ void mont_enter_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32Rs_acc, + const torch::PackedTensorAccessor32ql_acc, + const torch::PackedTensorAccessor32qh_acc, + const torch::PackedTensorAccessor32kl_acc, + const torch::PackedTensorAccessor32kh_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + const scalar_t a = a_acc[i][j]; + const scalar_t Rs = Rs_acc[i]; + const scalar_t ql = ql_acc[i]; + const scalar_t qh = qh_acc[i]; + const scalar_t kl = kl_acc[i]; + const scalar_t kh = kh_acc[i]; + + // Store the result. + a_acc[i][j] = mont_mult_scalar_cuda_kernel(a, Rs, ql, qh, kl, kh); +} + + +template +void mont_enter_cuda_typed( + torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + auto C = a.size(0); + auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + auto a_acc = a.packed_accessor32(); + const auto Rs_acc = Rs.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + mont_enter_cuda_kernel<<>>( + a_acc, Rs_acc, ql_acc, qh_acc, kl_acc, kh_acc); +} + +void mont_enter_cuda( + torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_mont_enter_cuda", ([&] { + mont_enter_cuda_typed(a, Rs, ql, qh, kl, kh); + })); +} + + + + + +//------------------------------------------------------------------ +// ntt +//------------------------------------------------------------------ + +template +__global__ void ntt_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32even_acc, + const torch::PackedTensorAccessor32odd_acc, + const torch::PackedTensorAccessor32psi_acc, + const torch::PackedTensorAccessor32_2q_acc, + const torch::PackedTensorAccessor32ql_acc, + const torch::PackedTensorAccessor32qh_acc, + const torch::PackedTensorAccessor32kl_acc, + const torch::PackedTensorAccessor32kh_acc, + const int level){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Montgomery inputs. + const scalar_t _2q = _2q_acc[i]; + const scalar_t ql = ql_acc[i]; + const scalar_t qh = qh_acc[i]; + const scalar_t kl = kl_acc[i]; + const scalar_t kh = kh_acc[i]; + + // Butterfly. 
+ const int even_j = even_acc[level][j]; + const int odd_j = odd_acc[level][j]; + + const scalar_t U = a_acc[i][even_j]; + const scalar_t S = psi_acc[i][level][j]; + const scalar_t O = a_acc[i][odd_j]; + const scalar_t V = mont_mult_scalar_cuda_kernel(S, O, ql, qh, kl, kh); + + // Store back. + const scalar_t UplusV = U + V; + const scalar_t UminusV = U + _2q - V; + + a_acc[i][even_j] = (UplusV < _2q)? UplusV : UplusV - _2q; + a_acc[i][odd_j] = (UminusV < _2q)? UminusV : UminusV - _2q; +} + + +template +void ntt_cuda_typed( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N = even.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + auto a_acc = a.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } +} + + + +void ntt_cuda( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_ntt_cuda", ([&] { + ntt_cuda_typed(a, even, odd, psi, _2q, ql, qh, kl, kh); + })); +} + + +//------------------------------------------------------------------ +// enter_ntt +//------------------------------------------------------------------ + +template +void enter_ntt_cuda_typed( + torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + // Be careful. even and odd has half the length of the a. + const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N_half = even.size(1); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid_ntt (C, N_half / BLOCK_SIZE); + dim3 dim_grid_enter (C, N / BLOCK_SIZE); + + // Run the cuda kernel. 
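Each launch of ntt_cuda_kernel above performs one Cooley–Tukey level: thread (i, j) loads U = a[even[level][j]], forms V by Montgomery-multiplying a[odd[level][j]] with that level's twiddle psi, and writes back U + V and U − V, keeping coefficients in [0, 2q) via a conditional subtraction rather than a full reduction. In ordinary modular arithmetic one level looks like the sketch below; the per-level launch loop in ntt_cuda_typed simply repeats it for each of the logN precomputed rows (the arrays here are placeholders for one row of the even/odd/psi tensors):

    # Sketch only: one butterfly level of the NTT in plain modular arithmetic.
    def ntt_level(a, even_row, odd_row, psi_row, q):
        for j in range(len(even_row)):
            U = a[even_row[j]]
            V = (psi_row[j] * a[odd_row[j]]) % q   # mont_mult(S, O) in the kernel
            a[even_row[j]] = (U + V) % q
            a[odd_row[j]] = (U - V) % q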
+ auto a_acc = a.packed_accessor32(); + const auto Rs_acc = Rs.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + // enter. + mont_enter_cuda_kernel<<>>( + a_acc, Rs_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // ntt. + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } +} + + +void enter_ntt_cuda( + torch::Tensor a, + const torch::Tensor Rs, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_enter_ntt_cuda", ([&] { + enter_ntt_cuda_typed(a, Rs, even, odd, psi, _2q, ql, qh, kl, kh); + })); +} + + + + + +//------------------------------------------------------------------ +// intt +//------------------------------------------------------------------ + +template +__global__ void intt_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32even_acc, + const torch::PackedTensorAccessor32odd_acc, + const torch::PackedTensorAccessor32psi_acc, + const torch::PackedTensorAccessor32_2q_acc, + const torch::PackedTensorAccessor32ql_acc, + const torch::PackedTensorAccessor32qh_acc, + const torch::PackedTensorAccessor32kl_acc, + const torch::PackedTensorAccessor32kh_acc, + const int level){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Montgomery inputs. + const scalar_t _2q = _2q_acc[i]; + const scalar_t ql = ql_acc[i]; + const scalar_t qh = qh_acc[i]; + const scalar_t kl = kl_acc[i]; + const scalar_t kh = kh_acc[i]; + + // Butterfly. + const int even_j = even_acc[level][j]; + const int odd_j = odd_acc[level][j]; + + const scalar_t U = a_acc[i][even_j]; + const scalar_t S = psi_acc[i][level][j]; + const scalar_t V = a_acc[i][odd_j]; + + const scalar_t UminusV = U + _2q - V; + const scalar_t O = (UminusV < _2q)? UminusV : UminusV - _2q; + + const scalar_t W = mont_mult_scalar_cuda_kernel(S, O, ql, qh, kl, kh); + a_acc[i][odd_j] = W; + + const scalar_t UplusV = U + V; + a_acc[i][even_j] = (UplusV < _2q)? UplusV : UplusV - _2q; +} + + +template +void intt_cuda_typed( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + // Be careful. even and odd has half the length of the a. + const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N_half = even.size(1); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid_ntt (C, N_half / BLOCK_SIZE); + dim3 dim_grid_enter (C, N / BLOCK_SIZE); + + // Run the cuda kernel. 
+ auto a_acc = a.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + const auto Ninv_acc = Ninv.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } + + // Normalize. + mont_enter_cuda_kernel<<>>( + a_acc, Ninv_acc, ql_acc, qh_acc, kl_acc, kh_acc); +} + +void intt_cuda( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_intt_cuda", ([&] { + intt_cuda_typed(a, even, odd, psi, Ninv, _2q, ql, qh, kl, kh); + })); +} + + + + + + +//------------------------------------------------------------------ +// mont_redc +//------------------------------------------------------------------ + +template +__global__ void mont_redc_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32ql_acc, + const torch::PackedTensorAccessor32qh_acc, + const torch::PackedTensorAccessor32kl_acc, + const torch::PackedTensorAccessor32kh_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Masks. + constexpr scalar_t one = 1; + constexpr scalar_t nbits = sizeof(scalar_t) * 8 - 2; + constexpr scalar_t half_nbits = sizeof(scalar_t) * 4 - 1; + constexpr scalar_t fb_mask = ((one << nbits) - one); + constexpr scalar_t lb_mask = (one << half_nbits) - one; + + // Inputs. + const scalar_t x = a_acc[i][j]; + const scalar_t ql = ql_acc[i]; + const scalar_t qh = qh_acc[i]; + const scalar_t kl = kl_acc[i]; + const scalar_t kh = kh_acc[i]; + + // Implementation. + // s= xk mod R + const scalar_t xl = x & lb_mask; + const scalar_t xh = x >> half_nbits; + const scalar_t xkb = xh * kl + xl * kh; + scalar_t s = (xkb << half_nbits) + xl * kl; + s = s & fb_mask; + + // t = x + sq + // u = t/R + // Note that x gets erased in t/R operation if x < R. + const scalar_t sl = s & lb_mask; + const scalar_t sh = s >> half_nbits; + const scalar_t sqb = sh * ql + sl * qh; + const scalar_t sqbl = sqb & lb_mask; + const scalar_t sqbh = sqb >> half_nbits; + scalar_t carry = (x + sl * ql) >> half_nbits; + carry = (carry + sqbl) >> half_nbits; + + // Assume we have satisfied the condition 4*q < R. + // Return the calculated value directly without conditional subtraction. + a_acc[i][j] = sqbh + carry + sh * qh; +} + + +template +void mont_redc_cuda_typed( + torch::Tensor a, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + auto C = a.size(0); + auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. 
+ auto a_acc = a.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + mont_redc_cuda_kernel<<>>( + a_acc, ql_acc, qh_acc, kl_acc, kh_acc); +} + +void mont_redc_cuda( + torch::Tensor a, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_mont_redc_cuda", ([&] { + mont_redc_cuda_typed(a, ql, qh, kl, kh); + })); +} + + +//------------------------------------------------------------------ +// Chained intt series. +//------------------------------------------------------------------ + +/**************************************************************/ +/* CUDA kernels */ +/**************************************************************/ + +template +__global__ void reduce_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + constexpr scalar_t one = 1; + const scalar_t a = a_acc[i][j]; + const scalar_t q = _2q_acc[i] >> one; + + // Reduce. + a_acc[i][j] = (a < q)? a : a - q; +} + +template +__global__ void make_signed_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + constexpr scalar_t one = 1; + const scalar_t a = a_acc[i][j]; + const scalar_t q = _2q_acc[i] >> one; + const scalar_t q_half = q >> one; + + // Make signed. + a_acc[i][j] = (a <= q_half)? a : a - q; +} + + +/**************************************************************/ +/* Typed functions */ +/**************************************************************/ + +/////////////////////////////////////////////////////////////// +// intt exit + +template +void intt_exit_cuda_typed( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + // Be careful. even and odd has half the length of the a. + const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N_half = even.size(1); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid_ntt (C, N_half / BLOCK_SIZE); + dim3 dim_grid_enter (C, N / BLOCK_SIZE); + + // Run the cuda kernel. 
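The chained intt_exit* functions in this part of the file compose the whole inverse pipeline in one call: run the logN inverse-butterfly levels, multiply by N^-1 (stored as N^-1·R, so the normalization is again a Montgomery multiplication), call mont_redc to leave the Montgomery domain, and then optionally apply the two small kernels above, reduce_cuda_kernel and make_signed_cuda_kernel. Per coefficient, those last two steps mean simply:

    # Sketch only: per-coefficient meaning of the reduce / make_signed kernels.
    def reduce_2q(x, q):      # [0, 2q) -> [0, q)
        return x if x < q else x - q

    def make_signed(x, q):    # [0, q) -> [-(q >> 1), q >> 1] for odd q
        return x if x <= q >> 1 else x - q

    q = 17
    assert reduce_2q(20, q) == 3 and make_signed(16, q) == -1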
+ auto a_acc = a.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + const auto Ninv_acc = Ninv.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } + + // Normalize. + mont_enter_cuda_kernel<<>>( + a_acc, Ninv_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // Exit. + mont_redc_cuda_kernel<<>>( + a_acc, ql_acc, qh_acc, kl_acc, kh_acc); +} + +/////////////////////////////////////////////////////////////// +// intt exit reduce + +template +void intt_exit_reduce_cuda_typed( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + // Be careful. even and odd has half the length of the a. + const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N_half = even.size(1); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid_ntt (C, N_half / BLOCK_SIZE); + dim3 dim_grid_enter (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + auto a_acc = a.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + const auto Ninv_acc = Ninv.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } + + // Normalize. + mont_enter_cuda_kernel<<>>( + a_acc, Ninv_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // Exit. + mont_redc_cuda_kernel<<>>( + a_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // Reduce. + reduce_cuda_kernel<<>>(a_acc, _2q_acc); +} + + +/////////////////////////////////////////////////////////////// +// intt exit reduce signed + +template +void intt_exit_reduce_signed_cuda_typed( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Retrieve the device index, then set the corresponding device and stream. + auto device_id = a.device().index(); + cudaSetDevice(device_id); + + // Use a preallocated pytorch stream. + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + // The problem dimension. + // Be careful. even and odd has half the length of the a. 
+ const auto C = ql.size(0); + const auto logN = even.size(0); + const auto N_half = even.size(1); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid_ntt (C, N_half / BLOCK_SIZE); + dim3 dim_grid_enter (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + auto a_acc = a.packed_accessor32(); + + const auto even_acc = even.packed_accessor32(); + const auto odd_acc = odd.packed_accessor32(); + const auto psi_acc = psi.packed_accessor32(); + const auto Ninv_acc = Ninv.packed_accessor32(); + + const auto _2q_acc = _2q.packed_accessor32(); + const auto ql_acc = ql.packed_accessor32(); + const auto qh_acc = qh.packed_accessor32(); + const auto kl_acc = kl.packed_accessor32(); + const auto kh_acc = kh.packed_accessor32(); + + for(int i=0; i<<>>( + a_acc, even_acc, odd_acc, psi_acc, + _2q_acc, ql_acc, qh_acc, kl_acc, kh_acc, i); + } + + // Normalize. + mont_enter_cuda_kernel<<>>( + a_acc, Ninv_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // Exit. + mont_redc_cuda_kernel<<>>( + a_acc, ql_acc, qh_acc, kl_acc, kh_acc); + + // Reduce. + reduce_cuda_kernel<<>>(a_acc, _2q_acc); + + // Make signed. + make_signed_cuda_kernel<<>>(a_acc, _2q_acc); +} + + + + +/**************************************************************/ +/* Connectors */ +/**************************************************************/ + +/////////////////////////////////////////////////////////////// +// intt exit + +void intt_exit_cuda( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_intt_exit_cuda", ([&] { + intt_exit_cuda_typed(a, even, odd, psi, Ninv, _2q, ql, qh, kl, kh); + })); +} + +/////////////////////////////////////////////////////////////// +// intt exit reduce + +void intt_exit_reduce_cuda( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_intt_exit_reduce_cuda", ([&] { + intt_exit_reduce_cuda_typed(a, even, odd, psi, Ninv, _2q, ql, qh, kl, kh); + })); +} + +/////////////////////////////////////////////////////////////// +// intt exit reduce signed + +void intt_exit_reduce_signed_cuda( + torch::Tensor a, + const torch::Tensor even, + const torch::Tensor odd, + const torch::Tensor psi, + const torch::Tensor Ninv, + const torch::Tensor _2q, + const torch::Tensor ql, + const torch::Tensor qh, + const torch::Tensor kl, + const torch::Tensor kh) { + + // Dispatch to the correct data type. + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_intt_exit_reduce_signed_cuda", ([&] { + intt_exit_reduce_signed_cuda_typed(a, even, odd, psi, Ninv, _2q, ql, qh, kl, kh); + })); +} + + +//------------------------------------------------------------------ +// Misc +//------------------------------------------------------------------ + +template +__global__ void make_unsigned_cuda_kernel( + torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. 
+ constexpr scalar_t one = 1; + const scalar_t q = _2q_acc[i] >> one; + + // Make unsigned. + a_acc[i][j] += q; +} + +template +__global__ void tile_unsigned_cuda_kernel( + const torch::PackedTensorAccessor32a_acc, + torch::PackedTensorAccessor32dst_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + constexpr scalar_t one = 1; + const scalar_t q = _2q_acc[i] >> one; + const scalar_t a = a_acc[j]; + + // Make unsigned. + dst_acc[i][j] = a + q; +} + +template +__global__ void mont_add_cuda_kernel( + const torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32b_acc, + torch::PackedTensorAccessor32c_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + constexpr scalar_t one = 1; + const scalar_t a = a_acc[i][j]; + const scalar_t b = b_acc[i][j]; + const scalar_t _2q = _2q_acc[i]; + + // Add. + const scalar_t aplusb = a + b; + c_acc[i][j] = (aplusb < _2q)? aplusb : aplusb - _2q; +} + +template +__global__ void mont_sub_cuda_kernel( + const torch::PackedTensorAccessor32a_acc, + const torch::PackedTensorAccessor32b_acc, + torch::PackedTensorAccessor32c_acc, + const torch::PackedTensorAccessor32_2q_acc){ + + // Where am I? + const int i = blockIdx.x; + const int j = blockIdx.y * BLOCK_SIZE + threadIdx.x; + + // Inputs. + constexpr scalar_t one = 1; + const scalar_t a = a_acc[i][j]; + const scalar_t b = b_acc[i][j]; + const scalar_t _2q = _2q_acc[i]; + + // Sub. + const scalar_t aminusb = a + _2q - b; + c_acc[i][j] = (aminusb < _2q)? aminusb : aminusb - _2q; +} + +template +void reduce_2q_cuda_typed(torch::Tensor a, const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + const auto C = a.size(0); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + auto a_acc = a.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + + reduce_cuda_kernel<<>>(a_acc, _2q_acc); +} + +template +void make_signed_cuda_typed(torch::Tensor a, const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + const auto C = a.size(0); + const auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + auto a_acc = a.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + + make_signed_cuda_kernel<<>>(a_acc, _2q_acc); +} + +template +void tile_unsigned_cuda_typed(const torch::Tensor a, torch::Tensor dst, const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + const auto C = _2q.size(0); + const auto N = a.size(0); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + const auto a_acc = a.packed_accessor32(); + auto dst_acc = dst.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + + tile_unsigned_cuda_kernel<<>>(a_acc, dst_acc, _2q_acc); +} + +template +void make_unsigned_cuda_typed(torch::Tensor a, const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + const auto C = a.size(0); + const auto N = a.size(1); + + int dim_block = 
BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + auto a_acc = a.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + + make_unsigned_cuda_kernel<<>>(a_acc, _2q_acc); +} + +template +void mont_add_cuda_typed( + const torch::Tensor a, + const torch::Tensor b, + torch::Tensor c, + const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + auto C = a.size(0); + auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + const auto a_acc = a.packed_accessor32(); + const auto b_acc = b.packed_accessor32(); + auto c_acc = c.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + mont_add_cuda_kernel<<>>(a_acc, b_acc, c_acc, _2q_acc); +} + +template +void mont_sub_cuda_typed( + const torch::Tensor a, + const torch::Tensor b, + torch::Tensor c, + const torch::Tensor _2q) { + + auto device_id = a.device().index(); + cudaSetDevice(device_id); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + auto C = a.size(0); + auto N = a.size(1); + + int dim_block = BLOCK_SIZE; + dim3 dim_grid (C, N / BLOCK_SIZE); + + // Run the cuda kernel. + const auto a_acc = a.packed_accessor32(); + const auto b_acc = b.packed_accessor32(); + auto c_acc = c.packed_accessor32(); + const auto _2q_acc = _2q.packed_accessor32(); + mont_sub_cuda_kernel<<>>(a_acc, b_acc, c_acc, _2q_acc); +} + +void reduce_2q_cuda(torch::Tensor a, const torch::Tensor _2q) { + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_reduce_2q_cuda", ([&] { + reduce_2q_cuda_typed(a, _2q); + })); +} + +void make_signed_cuda(torch::Tensor a, const torch::Tensor _2q) { + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_make_signed_cuda", ([&] { + make_signed_cuda_typed(a, _2q); + })); +} + +void make_unsigned_cuda(torch::Tensor a, const torch::Tensor _2q) { + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_make_unsigned_cuda", ([&] { + make_unsigned_cuda_typed(a, _2q); + })); +} + +torch::Tensor tile_unsigned_cuda(const torch::Tensor a, const torch::Tensor _2q) { + a.squeeze_(); + const auto C = _2q.size(0); + const auto N = a.size(0); + auto c = a.new_empty({C, N}); + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_tile_unsigned_cuda", ([&] { + tile_unsigned_cuda_typed(a, c, _2q); + })); + return c; +} + +torch::Tensor mont_add_cuda(const torch::Tensor a, const torch::Tensor b, const torch::Tensor _2q) { + torch::Tensor c = torch::empty_like(a); + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_mont_add_cuda", ([&] { + mont_add_cuda_typed(a, b, c, _2q); + })); + return c; +} + +torch::Tensor mont_sub_cuda(const torch::Tensor a, const torch::Tensor b, const torch::Tensor _2q) { + torch::Tensor c = torch::empty_like(a); + AT_DISPATCH_INTEGRAL_TYPES(a.type(), "typed_mont_sub_cuda", ([&] { + mont_sub_cuda_typed(a, b, c, _2q); + })); + return c; +} diff --git a/liberate/ntt/rns_partition.py b/liberate/ntt/rns_partition.py new file mode 100644 index 0000000..9e54c59 --- /dev/null +++ b/liberate/ntt/rns_partition.py @@ -0,0 +1,141 @@ +import numpy as np + + +class rns_partition: + def __init__(self, num_ordinary_primes=17, + num_special_primes=2, + num_devices=2): + + primes_idx = list(range(num_ordinary_primes - 1)) + base_idx = num_ordinary_primes - 1 + + num_partitions = -(-(num_ordinary_primes - 1) // num_special_primes) + + part = lambda i: primes_idx[i * num_special_primes:(i + 1) * num_special_primes] + partitions = [part(i) for i in range(num_partitions)] + + 
partitions.append([num_ordinary_primes - 1]) + + partitions.append( + list( + range(num_ordinary_primes, num_ordinary_primes + num_special_primes))) + + alloc = lambda i: list(range(num_partitions - i - 1, -1, -num_devices))[::-1] + part_allocations = [alloc(i) for i in range(num_devices)] + + part_allocations[0].append(num_partitions) + + for p in part_allocations: p.append(num_partitions + 1) + + expand_alloc = lambda i: [partitions[part] for part in part_allocations[i]] + prime_allocations = [expand_alloc(i) for i in range(num_devices)] + + flat_prime_allocations = [sum(alloc, []) for alloc in prime_allocations] + + self.num_ordinary_primes = num_ordinary_primes + self.num_special_primes = num_special_primes + self.num_devices = num_devices + self.num_partitions = num_partitions + self.partitions = partitions + self.part_allocations = part_allocations + self.prime_allocations = prime_allocations + self.flat_prime_allocations = flat_prime_allocations + self.num_scales = self.num_ordinary_primes - 1 + + self.base_prime_idx = self.num_ordinary_primes - 1 + self.special_prime_idx = list( + range( + self.num_ordinary_primes + 1, + self.num_ordinary_primes + 1 + self.num_special_primes) + ) + + self.compute_destination_arrays() + self.compute_rescaler_locations() + self.compute_partitions() + + def compute_destination_arrays(self): + filter_alloc = lambda devi, i: [ + a for a in self.flat_prime_allocations[devi] if a >= i] + + self.destination_arrays_with_special = [] + for lvl in range(self.num_ordinary_primes): + src = [filter_alloc(devi, lvl) for devi in range(self.num_devices)] + self.destination_arrays_with_special.append(src) + + special_removed = lambda i: [ + a[:-self.num_special_primes] for a in self.destination_arrays_with_special[i]] + + self.destination_arrays = [ + special_removed(i) for i in range(self.num_ordinary_primes)] + + lint = lambda arr: [a for a in arr if len(a) > 0] + self.destination_arrays = [lint(a) for a in self.destination_arrays] + + def compute_rescaler_locations(self): + mins = lambda arr: [min(a) for a in arr] + mins_loc = lambda a: mins(a).index(min(mins(a))) + self.rescaler_loc = [mins_loc(a) for a in self.destination_arrays_with_special] + + def partings(self, lvl): + count_element_sizes = lambda arr: np.array([len(a) for a in arr]) + cumsum_element_sizes = lambda arr: np.cumsum(arr) + remove_empty_parts = lambda arr: [a for a in arr if a > 0] + regenerate_parts = lambda arr: [list(range(a, b)) for a, b in zip([0] + arr[:-1], arr)] + + part_counts = [count_element_sizes(a) for a in self.prime_allocations] + part_cumsums = [cumsum_element_sizes(a) for a in part_counts] + level_diffs = [ + len(a) - len(b) for a, b in zip( + self.destination_arrays_with_special[0], self.destination_arrays_with_special[lvl])] + + part_cumsums_lvl = [remove_empty_parts(a - d) for a, d in zip(part_cumsums, level_diffs)] + part_count_lvl = [np.diff(a, prepend=0) for a in part_cumsums_lvl] + parts_lvl = [regenerate_parts(a) for a in part_cumsums_lvl] + return part_cumsums_lvl, part_count_lvl, parts_lvl + + def compute_partitions(self): + self.part_cumsums = [] + self.part_counts = [] + self.parts = [] + self.destination_parts = [] + self.destination_parts_with_special = [] + self.p = [] + self.p_special = [] + self.diff = [] + + self.d = [ + self.destination_arrays[0][dev_i] for dev_i in range(self.num_devices)] + + self.d_special = [ + self.destination_arrays_with_special[0][dev_i] + for dev_i in range(self.num_devices)] + + for lvl in range(self.num_ordinary_primes): + pcu, pco, 
par = self.partings(lvl) + self.part_cumsums.append(pcu) + self.part_counts.append(pco) + self.parts.append(par) + + dest = self.destination_arrays_with_special[lvl] + destp_special = [ + [ + [ + d[pi] for pi in p + ] for p in dev_p + ] for d, dev_p in zip(dest, par)] + destp = [dev_dp[:-1] for dev_dp in destp_special] + + self.destination_parts.append(destp) + self.destination_parts_with_special.append(destp_special) + + diff = [len(d1) - len(d2) for d1, d2 in zip( + self.destination_arrays_with_special[0], + self.destination_arrays_with_special[lvl])] + p_special = [[[pi + d for pi in p] + for p in dev_p] + for d, dev_p in zip(diff, self.parts[lvl])] + p = [dev_p[:-1] for dev_p in p_special] + + self.p.append(p) + self.p_special.append(p_special) + self.diff.append(diff) \ No newline at end of file diff --git a/liberate/ntt/setup.py b/liberate/ntt/setup.py new file mode 100644 index 0000000..ef8f21f --- /dev/null +++ b/liberate/ntt/setup.py @@ -0,0 +1,24 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +ext_modules = [ + CUDAExtension( + name="ntt_cuda", + sources=[ + "ntt.cpp", + "ntt_cuda_kernel.cu", + ], + ) +] +if __name__ == "__main__": + setup( + name="ntt", + ext_modules=ext_modules, + script_args=["build_ext"], + cmdclass={"build_ext": BuildExtension}, + # options={ + # "build":{ + # "build_lib":"liberate/ntt", + # } + # } + ) diff --git a/liberate/utils/__init__.py b/liberate/utils/__init__.py new file mode 100644 index 0000000..7a97403 --- /dev/null +++ b/liberate/utils/__init__.py @@ -0,0 +1,5 @@ +from . import helpers + +__all__ = [ + "helpers", +] diff --git a/liberate/utils/helpers.py b/liberate/utils/helpers.py new file mode 100644 index 0000000..a45474f --- /dev/null +++ b/liberate/utils/helpers.py @@ -0,0 +1,41 @@ +import numpy as np +from matplotlib import pyplot as plt + + +def random_complex_array( + n: int = 2 ** 8, + amin: int = -(2 ** 20), + amax: int = 2 ** 20, + decimal_exponent: int = 10, +): + base = 10 ** decimal_exponent + a = np.random.randint(amin * base, amax * base, n) / base + b = np.random.randint(amin * base, amax * base, n) / base + ret = a + b * 1j + return ret + + +def check_errors(test_message, test_message_dec, idx=10, title="errors"): + errs = test_message_dec - test_message + plt.figure(figsize=(16, 9)) + plt.plot(errs) + plt.grid() + plt.title(title) + plt.show() + + print("============================================================") + for x, y in zip(test_message[:idx], test_message_dec[:idx]): + print(f"{x.real:19.10f} | {y.real:19.10f} | {(y - x).real:14.10f}") + print("============================================================") + print(f"mean\t=\t{errs.mean():10.15f}") + print(f"std\t=\t{errs.std():10.15f}") + print(f"max err\t=\t{abs(errs).max().real:10.15f}") + print(f"min err\t=\t{abs(errs).min().real:10.15f}") + + +def absmax_error(x, y): + if type(x[0]) == np.complex128 and type(y[0]) == np.complex128: + r = np.abs(x.real - y.real).max() + np.abs(x.imag - y.imag).max() * 1j + else: + r = np.abs(np.array(x) - np.array(y)).max() + return r diff --git a/lint.sh b/lint.sh new file mode 100755 index 0000000..34872f0 --- /dev/null +++ b/lint.sh @@ -0,0 +1,6 @@ +echo 'isort:' +poetry run isort . +echo 'flake8:' +poetry run flake8 . +echo 'black:' +poetry run black . 
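To make the allocation logic of the rns_partition class above concrete: the ordinary primes (all but the last one, the base prime) are grouped into partitions of num_special_primes consecutive indices, the partitions are dealt out to the devices round-robin starting from the back, device 0 additionally receives the base-prime partition, and every device receives the special primes. Re-deriving the layout for the default arguments (17 ordinary primes, 2 special primes, 2 devices) with the constructor's own expressions:

    # Sketch only: expected prime layout for rns_partition(17, 2, 2).
    num_ordinary, num_special, num_devices = 17, 2, 2
    primes_idx = list(range(num_ordinary - 1))
    num_partitions = -(-(num_ordinary - 1) // num_special)            # 8
    partitions = [primes_idx[i * num_special:(i + 1) * num_special]
                  for i in range(num_partitions)]
    partitions += [[num_ordinary - 1],                                # base prime
                   list(range(num_ordinary, num_ordinary + num_special))]
    alloc = lambda i: list(range(num_partitions - i - 1, -1, -num_devices))[::-1]
    part_allocations = [alloc(i) for i in range(num_devices)]         # [[1,3,5,7], [0,2,4,6]]
    part_allocations[0].append(num_partitions)                        # base prime -> device 0
    for p in part_allocations:
        p.append(num_partitions + 1)                                  # special primes -> all devices
    flat = [sum((partitions[k] for k in a), []) for a in part_allocations]
    assert flat[0] == [2, 3, 6, 7, 10, 11, 14, 15, 16, 17, 18]
    assert flat[1] == [0, 1, 4, 5, 8, 9, 12, 13, 17, 18]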
diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..f94e3a3 --- /dev/null +++ b/setup.py @@ -0,0 +1,68 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +ext_modules = [ + CUDAExtension( + name="randint_cuda", + sources=[ + "liberate/csprng/randint.cpp", + "liberate/csprng/randint_cuda_kernel.cu", + ], + ), + CUDAExtension( + name="randround_cuda", + sources=[ + "liberate/csprng/randround.cpp", + "liberate/csprng/randround_cuda_kernel.cu", + ], + ), + CUDAExtension( + name="discrete_gaussian_cuda", + sources=[ + "liberate/csprng/discrete_gaussian.cpp", + "liberate/csprng/discrete_gaussian_cuda_kernel.cu", + ], + ), + CUDAExtension( + name="chacha20_cuda", + sources=[ + "liberate/csprng/chacha20.cpp", + "liberate/csprng/chacha20_cuda_kernel.cu", + ], + ), +] + +ext_modules_ntt = [ + CUDAExtension( + name="ntt_cuda", + sources=[ + "liberate/ntt/ntt.cpp", + "liberate/ntt/ntt_cuda_kernel.cu", + ], + ) +] + +if __name__ == "__main__": + setup( + name="csprng", + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, + script_args=["build_ext"], + options={ + "build": { + "build_lib": "liberate/csprng", + } + }, + ) + + setup( + name="ntt", + ext_modules=ext_modules_ntt, + script_args=["build_ext"], + cmdclass={"build_ext": BuildExtension}, + options={ + "build": { + "build_lib": "liberate/ntt", + } + }, + )
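Both setup() calls above pass script_args=["build_ext"], so a plain `python setup.py` from the repository root appears intended to compile every CUDA extension in one go, with the build.build_lib option redirecting the resulting modules into liberate/csprng and liberate/ntt, next to the Python code that imports them (the per-package scripts such as liberate/ntt/setup.py build the same extensions standalone). A quick, environment-dependent way to see whether the compiled modules landed where expected:

    # Sketch only: list the built extension modules inside the package tree.
    import glob
    print(glob.glob("liberate/csprng/*.so") + glob.glob("liberate/ntt/*.so"))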