From 95175144486e5a9fc9081820b3e2755342e17981 Mon Sep 17 00:00:00 2001 From: YongwooLee Date: Sat, 25 Nov 2023 15:17:56 +0900 Subject: [PATCH 1/7] README.md typo fix We use logN=15 for silver. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b2c1762..1f88d77 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ from liberate import fhe from liberate.fhe import presets # Generate CKKS engine with preset parameters -grade = "silver" # logN=14 +grade = "silver" # logN=15 params = presets.params[grade] engine = fhe.ckks_engine(**params, verbose=True) From ce31d094f87143ee1c89aa5c2f347c78233eebb2 Mon Sep 17 00:00:00 2001 From: desilo-hanyul <95674196+hanyul-ryu@users.noreply.github.com> Date: Tue, 12 Dec 2023 15:59:46 +0900 Subject: [PATCH 2/7] Update rns_partition.py Dead code removal request --- src/liberate/ntt/rns_partition.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/liberate/ntt/rns_partition.py b/src/liberate/ntt/rns_partition.py index b2cedcf..276afa9 100644 --- a/src/liberate/ntt/rns_partition.py +++ b/src/liberate/ntt/rns_partition.py @@ -56,12 +56,6 @@ def __init__( self.num_scales = self.num_ordinary_primes - 1 self.base_prime_idx = self.num_ordinary_primes - 1 - self.special_prime_idx = list( - range( - self.num_ordinary_primes + 1, - self.num_ordinary_primes + 1 + self.num_special_primes, - ) - ) self.compute_destination_arrays() self.compute_rescaler_locations() From 8df4f39fe011d66fcfbbe814b0efa89ed2fa39b3 Mon Sep 17 00:00:00 2001 From: desilo-hanyul <95674196+hanyul-ryu@users.noreply.github.com> Date: Tue, 12 Dec 2023 16:01:30 +0900 Subject: [PATCH 3/7] Update csprng.py Rationalized length calculation. --- src/liberate/csprng/csprng.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/liberate/csprng/csprng.py b/src/liberate/csprng/csprng.py index 8de672a..b7d41a7 100644 --- a/src/liberate/csprng/csprng.py +++ b/src/liberate/csprng/csprng.py @@ -230,7 +230,7 @@ def randbytes(self, shares=None, repeats=0, length=None, reshape=False): if length is None: L = self.L else: - L = length + L = length // 16 # Set the target states. target_states = [] @@ -261,7 +261,7 @@ def randint(self, amax=3, shift=0, repeats=0, length=None): if length is None: L = self.L else: - L = length + L = length // 4 # Calculate shares. # If repeats are greater than 0, those channels are @@ -298,7 +298,7 @@ def discrete_gaussian(self, non_repeats=0, repeats=1, length=None): if length is None: L = self.L else: - L = length + L = length // 4 # Set the target states. target_states = [] From bfa88adc52536e77374980143d313fd0be992df7 Mon Sep 17 00:00:00 2001 From: hanyul-ryu Date: Fri, 22 Dec 2023 11:35:22 +0900 Subject: [PATCH 4/7] remove bias_guard at bronze --- src/liberate/fhe/presets/params.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/liberate/fhe/presets/params.py b/src/liberate/fhe/presets/params.py index 88326e5..97c3490 100644 --- a/src/liberate/fhe/presets/params.py +++ b/src/liberate/fhe/presets/params.py @@ -3,7 +3,6 @@ "logN": 14, "num_special_primes": 1, "devices": [0], - "bias_guard": False, "scale_bits": 40, "num_scales": None, }, From 7883614937cb83c0711403a5d5c4fd74898027b8 Mon Sep 17 00:00:00 2001 From: hanyul-ryu Date: Fri, 22 Dec 2023 11:35:27 +0900 Subject: [PATCH 5/7] remove check bias_gurad at init --- src/liberate/fhe/ckks_engine.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/liberate/fhe/ckks_engine.py b/src/liberate/fhe/ckks_engine.py index 7753ae0..7aa784b 100644 --- a/src/liberate/fhe/ckks_engine.py +++ b/src/liberate/fhe/ckks_engine.py @@ -49,12 +49,6 @@ def __init__(self, devices: list[int] = None, verbose: bool = False, self.ctx = ckks_context(**ctx_params) self.ntt = ntt_context(self.ctx, devices=devices, verbose=verbose) - if self.bias_guard: - if self.ctx.num_special_primes < 2: - raise errors.NotEnoughPrimesForBiasGuard( - bias_guard=self.bias_guard, - num_special_primes=self.ctx.num_special_primes) - self.num_levels = self.ntt.num_levels - 1 self.num_slots = self.ctx.N // 2 From 408fccd67efa306cc68114acfdb23c41cb214897 Mon Sep 17 00:00:00 2001 From: Juwhan Kim Date: Wed, 3 Jan 2024 20:21:48 +0900 Subject: [PATCH 6/7] Stability improvement and decrypt rounding. --- src/liberate/fhe/ckks_engine.py | 330 +++++++++++++++++++++----------- 1 file changed, 222 insertions(+), 108 deletions(-) diff --git a/src/liberate/fhe/ckks_engine.py b/src/liberate/fhe/ckks_engine.py index 7aa784b..7cfcf6f 100644 --- a/src/liberate/fhe/ckks_engine.py +++ b/src/liberate/fhe/ckks_engine.py @@ -479,7 +479,7 @@ def encrypt(self, pt: list[torch.Tensor], pk: data_struct, level: int = 0) -> da return ct - def decrypt_triplet(self, ct_mult: data_struct, sk: data_struct) -> list[torch.Tensor]: + def decrypt_triplet(self, ct_mult: data_struct, sk: data_struct, final_round=True) -> list[torch.Tensor]: if ct_mult.origin != types.origins["ctt"]: raise errors.NotMatchType(origin=ct_mult.origin, to=types.origins["ctt"]) if sk.origin != types.origins["sk"]: @@ -521,9 +521,18 @@ def decrypt_triplet(self, ct_mult: data_struct, sk: data_struct) -> list[torch.T self.ntt.mont_enter_scalar(scaled, [final_scalar], -1) self.ntt.reduce_2q(scaled, -1) self.ntt.make_signed(scaled, -1) + + # Round? + if final_round: + # The scaler and the base channels are guaranteed to be in the + # device 0. + rounding_prime = self.ntt.qlists[0][-self.ctx.num_special_primes-2] + rounder = (scaler[0] > (rounding_prime // 2)) * 1 + scaled[0] += rounder + return scaled - def decrypt_double(self, ct: data_struct, sk: data_struct) -> list[torch.Tensor]: + def decrypt_double(self, ct: data_struct, sk: data_struct, final_round=True) -> list[torch.Tensor]: if ct.origin != types.origins["ct"]: raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) if sk.origin != types.origins["sk"]: @@ -556,9 +565,18 @@ def decrypt_double(self, ct: data_struct, sk: data_struct) -> list[torch.Tensor] self.ntt.mont_enter_scalar(scaled, [final_scalar], -1) self.ntt.reduce_2q(scaled, -1) self.ntt.make_signed(scaled, -1) + + # Round? + if final_round: + # The scaler and the base channels are guaranteed to be in the + # device 0. + rounding_prime = self.ntt.qlists[0][-self.ctx.num_special_primes-2] + rounder = (scaler[0] > (rounding_prime // 2)) * 1 + scaled[0] += rounder + return scaled - def decrypt(self, ct: data_struct, sk: data_struct) -> list[torch.Tensor]: + def decrypt(self, ct: data_struct, sk: data_struct, final_round=True) -> list[torch.Tensor]: """ Decrypt the cipher text ct using the secret key sk. Note that the final rescaling must precede the actual decryption process. @@ -568,9 +586,9 @@ def decrypt(self, ct: data_struct, sk: data_struct) -> list[torch.Tensor]: raise errors.NotMatchType(origin=sk.origin, to=types.origins["sk"]) if ct.origin == types.origins["ctt"]: - pt = self.decrypt_triplet(ct_mult=ct, sk=sk) + pt = self.decrypt_triplet(ct_mult=ct, sk=sk, final_round=final_round) elif ct.origin == types.origins["ct"]: - pt = self.decrypt_double(ct=ct, sk=sk) + pt = self.decrypt_double(ct=ct, sk=sk, final_round=final_round) else: raise errors.NotMatchType(origin=ct.origin, to=f"{types.origins['ct']} or {types.origins['ctt']}") @@ -634,123 +652,167 @@ def create_key_switching_key(self, sk_from: data_struct, sk_to: data_struct, a=N version=self.version) def pre_extend(self, a, device_id, level, part_id, exit_ntt=False): + # param_parts contain only the ordinary parts. + # Hence, loop around it. + # text_parts contain special primes. text_part = self.ntt.p.parts[level][device_id][part_id] param_part = self.ntt.p.p[level][device_id][part_id] + # Carve out the partition. alpha = len(text_part) - a_part = a[device_id][text_part[0]:text_part[-1] + 1] - + a_part = a[device_id][text_part[0]:text_part[-1]+1] + + # Release ntt. if exit_ntt: self.ntt.intt_exit_reduce([a_part], level, device_id, part_id) + # Prepare a state. + # Initially, it is x[0] % m[i]. + # However, m[i] is a monotonically increasing + # sequence, i.e., repeating the row would suffice + # to do the job. + + # 2023-10-16, Juwhan Kim, In fact, m[i] is NOT monotonically increasing. + state = a_part[0].repeat(alpha, 1) + key = tuple(param_part) - for i in range(alpha - 1): - mont_pack = self.ntt.parts_pack[device_id][param_part[i + 1],]['mont_pack'] - _2q = self.ntt.parts_pack[device_id][param_part[i + 1],]['_2q'] + for i in range(alpha-1): + mont_pack = self.ntt.parts_pack[device_id][param_part[i+1],]['mont_pack'] + _2q = self.ntt.parts_pack[device_id][param_part[i+1],]['_2q'] Y_scalar = self.ntt.parts_pack[device_id][key]['Y_scalar'][i][None] - Y = (a_part[i + 1] - state[i + 1])[None, :] + Y = (a_part[i+1] - state[i+1])[None, :] - ntt_cuda.mont_enter([Y], [Y_scalar], *mont_pack) - ntt_cuda.reduce_2q([Y], _2q) + # mont_enter will take care of signedness. + # ntt_cuda.make_unsigned([Y], _2q) + ntt_cuda.mont_enter([Y], [Y_scalar], *mont_pack) + # ntt_cuda.reduce_2q([Y], _2q) - state[i + 1] = Y + state[i+1] = Y - if i + 2 < alpha: - state_key = tuple(param_part[i + 2:]) + if i+2 < alpha: + state_key = tuple(param_part[i+2:]) state_mont_pack = self.ntt.parts_pack[device_id][state_key]['mont_pack'] state_2q = self.ntt.parts_pack[device_id][state_key]['_2q'] L_scalar = self.ntt.parts_pack[device_id][key]['L_scalar'][i] - new_state_len = alpha - (i + 2) + new_state_len = alpha - (i+2) new_state = Y.repeat(new_state_len, 1) ntt_cuda.mont_enter([new_state], [L_scalar], *state_mont_pack) - ntt_cuda.reduce_2q([new_state], state_2q) - state[i + 2:] += new_state - ntt_cuda.reduce_2q([state[i + 2:]], state_2q) - + state[i+2:] += new_state + + # Returned state is in plain integer format. return state def extend(self, state, device_id, level, part_id, target_device_id=None): + # Note that device_id, level, and part_id is from + # where the state has been originally calculated at. + # The state can reside in a different GPU than + # the original one. if target_device_id is None: target_device_id = device_id - + rns_len = len( self.ntt.p.destination_arrays_with_special[level][target_device_id]) alpha = len(state) + + # Initialize the output + extended = state[0].repeat(rns_len, 1) + self.ntt.mont_enter([extended], level, target_device_id, -2) - extended = state[0].repeat(rns_len, 1) - self.ntt.mont_enter([extended], level, target_device_id, -2) - + # Generate the search key to find the L_enter. part = self.ntt.p.p[level][device_id][part_id] key = tuple(part) + # Extract the L_enter in the target device. L_enter = self.ntt.parts_pack[device_id][key]['L_enter'][target_device_id] - + + # L_enter covers the whole rns range. + # Start from the leveled start. start = self.ntt.starts[level][target_device_id] - - for i in range(alpha - 1): - Y = state[i + 1].repeat(rns_len, 1) - + + # Loop to generate. + for i in range(alpha-1): + Y = state[i+1].repeat(rns_len, 1) + self.ntt.mont_enter_scalar([Y], [L_enter[i][start:]], level, target_device_id, -2) extended = self.ntt.mont_add([extended], [Y], level, target_device_id, -2)[0] - + + # Returned extended is in the Montgomery format. return extended + def create_switcher(self, a: list[torch.Tensor], ksk: data_struct, level, exit_ntt=False) -> tuple: + # ksk parts allocation. ksk_alloc = self.parts_alloc[level] - + + # Device lens and neighbor devices. len_devices = self.len_devices[level] neighbor_devices = self.neighbor_devices[level] + # Iterate over source device ids, and then part ids. num_parts = sum([len(alloc) for alloc in ksk_alloc]) - part_results = [[[[] for _ in range(len_devices)], [[] for _ in range(len_devices)]] for _ in range(num_parts)] - + part_results = [ + [ + [[] for _ in range(len_devices)], + [[] for _ in range(len_devices)] + ] + for _ in range(num_parts) + ] + + # 1. Generate states. states = [[] for _ in range(num_parts)] for src_device_id in range(len_devices): for part_id in range(len(self.ntt.p.p[level][src_device_id])): storage_id = self.stor_ids[level][src_device_id][part_id] - state = self.pre_extend(a, - src_device_id, - level, - part_id, - exit_ntt - ) + state = self.pre_extend( + a, + src_device_id, + level, + part_id, + exit_ntt + ) states[storage_id] = state - + + # 2. Copy to CPU. CPU_states = [[] for _ in range(num_parts)] for src_device_id in range(len_devices): for part_id, part in enumerate(self.ntt.p.p[level][src_device_id]): storage_id = self.stor_ids[level][src_device_id][part_id] alpha = len(part) - CPU_state = self.ksk_buffers[src_device_id][part_id][:alpha] + CPU_state = self.ksk_buffers[src_device_id][part_id][:alpha] CPU_state.copy_(states[storage_id], non_blocking=True) CPU_states[storage_id] = CPU_state - + + # 3. Continue on with the follow ups on source devices. for src_device_id in range(len_devices): for part_id in range(len(self.ntt.p.p[level][src_device_id])): storage_id = self.stor_ids[level][src_device_id][part_id] state = states[storage_id] d0, d1 = self.switcher_later_part(state, ksk, - src_device_id, - src_device_id, - level, part_id) + src_device_id, + src_device_id, + level, part_id) part_results[storage_id][0][src_device_id] = d0 part_results[storage_id][1][src_device_id] = d1 - + + # 4. Copy onto neighbor GPUs the states. CUDA_states = [[] for _ in range(num_parts)] for src_device_id in range(len_devices): for j, dst_device_id in enumerate( - neighbor_devices[src_device_id]): + neighbor_devices[src_device_id]): for part_id, part in enumerate(self.ntt.p.p[level][src_device_id]): storage_id = self.stor_ids[level][src_device_id][part_id] CPU_state = CPU_states[storage_id] - CUDA_states[storage_id] = CPU_state.cuda(self.ntt.devices[dst_device_id], non_blocking=True) - - torch.cuda.synchronize() - + CUDA_states[storage_id] = CPU_state.cuda( + self.ntt.devices[dst_device_id], non_blocking=True) + + # 5. Synchronize. + # torch.cuda.synchronize() + + #6. Do follow ups on neighbors. for src_device_id in range(len_devices): for j, dst_device_id in enumerate( neighbor_devices[src_device_id]): @@ -758,44 +820,81 @@ def create_switcher(self, a: list[torch.Tensor], ksk: data_struct, level, exit_n storage_id = self.stor_ids[level][src_device_id][part_id] CUDA_state = CUDA_states[storage_id] d0, d1 = self.switcher_later_part(CUDA_state, - ksk, - src_device_id, - dst_device_id, - level, - part_id) + ksk, + src_device_id, + dst_device_id, + level, + part_id) part_results[storage_id][0][dst_device_id] = d0 part_results[storage_id][1][dst_device_id] = d1 - + + # 7. Sum up. summed0 = part_results[0][0] summed1 = part_results[0][1] - + + for i in range(1, len(part_results)): - summed0 = self.ntt.mont_add(summed0, part_results[i][0], level, -2) - summed1 = self.ntt.mont_add(summed1, part_results[i][1], level, -2) - + summed0 = self.ntt.mont_add( + summed0, part_results[i][0], level, -2) + summed1 = self.ntt.mont_add( + summed1, part_results[i][1], level, -2) + + # Rename summed's. d0 = summed0 d1 = summed1 + + # 6. Divide by P. + # This is actually done in successive order. + # Rescale from the most outer prime channel. + # Start from the special len and drop channels one by one. + + # Pre-montgomery enter the ordinary part. + # Note that special prime channels remain intact. + c0 = [d[:-self.ntt.num_special_primes] for d in d0] + c1 = [d[:-self.ntt.num_special_primes] for d in d1] + + self.ntt.mont_enter(c0, level, -1) + self.ntt.mont_enter(c1, level, -1) current_len = [len(d) for d in self.ntt.p.destination_arrays_with_special[level]] - + for P_ind in range(self.ntt.num_special_primes): - current_len = [c - 1 for c in current_len] - PiRi = self.PiRs[level][P_ind] - P0 = [d[-1].repeat(current_len[di], 1) for di, d in enumerate(d0)] - P1 = [d[-1].repeat(current_len[di], 1) for di, d in enumerate(d1)] - - d0 = [d0[i][:current_len[i]] - P0[i] for i in range(len_devices)] - d1 = [d1[i][:current_len[i]] - P1[i] for i in range(len_devices)] + # Tile. + P0 = [d[-1-P_ind].repeat(current_len[di], 1) for di, d in enumerate(d0)] + P1 = [d[-1-P_ind].repeat(current_len[di], 1) for di, d in enumerate(d1)] + + # mont enter only the ordinary part. + Q0 = [d[:-self.ntt.num_special_primes] for d in P0] + Q1 = [d[:-self.ntt.num_special_primes] for d in P1] + + self.ntt.mont_enter(Q0, level, -1) + self.ntt.mont_enter(Q1, level, -1) + + # subtract P0 and P1. + # Note that by the consequence of the above mont_enter + # ordinary parts will be in montgomery form, + # while the special part remains plain. + d0 = self.ntt.mont_sub(d0, P0, level, -2) + d1 = self.ntt.mont_sub(d1, P1, level, -2) self.ntt.mont_enter_scalar(d0, PiRi, level, -2) self.ntt.mont_enter_scalar(d1, PiRi, level, -2) - self.ntt.reduce_2q(d0, level, -1) - self.ntt.reduce_2q(d1, level, -1) - - return d0, d1 + # Carve out again, since d0 and d1 are fresh new. + c0 = [d[:-self.ntt.num_special_primes] for d in d0] + c1 = [d[:-self.ntt.num_special_primes] for d in d1] + + # Exit the montgomery. + self.ntt.mont_redc(c0, level, -1) + self.ntt.mont_redc(c1, level, -1) + + self.ntt.reduce_2q(c0, level, -1) + self.ntt.reduce_2q(c1, level, -1) + + # 7. Return + return c0, c1 def switcher_later_part(self, state, ksk, @@ -803,10 +902,17 @@ def switcher_later_part(self, dst_device_id, level, part_id): - extended = self.extend(state, src_device_id, level, part_id, dst_device_id) + # Extend basis. + extended = self.extend( + state, src_device_id, + level, part_id, dst_device_id) - self.ntt.ntt([extended], level, dst_device_id, -2) + # ntt extended to prepare polynomial multiplication. + # extended is in the Montgomery format already. + self.ntt.ntt( + [extended], level, dst_device_id, -2) + # Extract the ksk. ksk_loc = self.parts_alloc[level][src_device_id][part_id] ksk_part_data = ksk.data[ksk_loc].data @@ -814,12 +920,17 @@ def switcher_later_part(self, ksk0_data = ksk_part_data[0][dst_device_id][start:] ksk1_data = ksk_part_data[1][dst_device_id][start:] - d0 = self.ntt.mont_mult([extended], [ksk0_data], level, dst_device_id, -2) - d1 = self.ntt.mont_mult([extended], [ksk1_data], level, dst_device_id, -2) + # Multiply. + d0 = self.ntt.mont_mult( + [extended], [ksk0_data], level, dst_device_id, -2) + d1 = self.ntt.mont_mult( + [extended], [ksk1_data], level, dst_device_id, -2) + # intt to prepare for division by P. self.ntt.intt_exit_reduce(d0, level, dst_device_id, -2) - self.ntt.intt_exit_reduce(d1, level, dst_device_id, -2) - + self.ntt.intt_exit_reduce(d1, level, dst_device_id, -2) + + # When returning, un-list the results by taking the 0th element. return d0[0], d1[0] def switch_key(self, ct: data_struct, ksk: data_struct) -> data_struct: @@ -834,6 +945,7 @@ def switch_key(self, ct: data_struct, ksk: data_struct) -> data_struct: d0, d1 = self.create_switcher(a, ksk, level, exit_ntt=ct.ntt_state) new_ct0 = self.ntt.mont_add(ct.data[0], d0, level, -1) + self.ntt.reduce_2q(new_ct0, level, -1) return data_struct( data=(new_ct0, d1), @@ -1078,6 +1190,12 @@ def rotate_single(self, ct: data_struct, rotk: data_struct) -> data_struct: rotated_ct_data = [[rotate(d, delta) for d in ct_data] for ct_data in ct.data] + # Rotated ct may contain negative numbers. + mult_type = -2 if include_special else -1 + for ct_data in rotated_ct_data: + self.ntt.make_unsigned(ct_data, level, mult_type) + self.ntt.reduce_2q(ct_data, level, mult_type) + rotated_ct_rotated_sk = data_struct( data=rotated_ct_data, include_special=include_special, @@ -1438,7 +1556,7 @@ def encodecrypt(self, m, pk: data_struct, level: int = 0, padding=True) -> data_ return ct - def decryptcode(self, ct: data_struct, sk: data_struct, is_real=False) -> data_struct: + def decryptcode(self, ct: data_struct, sk: data_struct, is_real=False, final_round=True) -> data_struct: if (not sk.ntt_state) or (not sk.montgomery_state): raise errors.NotMatchDataStructState(origin=sk.origin) @@ -1530,6 +1648,14 @@ def decryptcode(self, ct: data_struct, sk: data_struct, is_real=False) -> data_s self.ntt.reduce_2q(scaled, -1) self.ntt.make_signed(scaled, -1) + # Round? + if final_round: + # The scaler and the base channels are guaranteed to be in the + # device 0. + rounding_prime = self.ntt.qlists[0][-self.ctx.num_special_primes-2] + rounder = (scaler[0] > (rounding_prime // 2)) * 1 + scaled[0] += rounder + # Decoding. correction = self.corrections[level] decoded = decode( @@ -1555,8 +1681,8 @@ def decryptcode(self, ct: data_struct, sk: data_struct, is_real=False) -> data_s def encorypt(self, m, pk: data_struct, level: int = 0, padding=True): return self.encodecrypt(m, pk=pk, level=level, padding=padding) - def decrode(self, ct: data_struct, sk: data_struct, is_real=False): - return self.decryptcode(ct=ct, sk=sk, is_real=is_real) + def decrode(self, ct: data_struct, sk: data_struct, is_real=False, final_round=True): + return self.decryptcode(ct=ct, sk=sk, is_real=is_real, final_round=final_round) # ------------------------------------------------------------------------------------------- # Conjugation @@ -2169,38 +2295,24 @@ def reduce_error(self, ct): # Misc ops. # ------------------------------------------------------------------------------------------- - def sum(self, ct: data_struct, gk: data_struct, rescale_every=5) -> data_struct: - if ct.origin != types.origins["ct"]: - raise errors.NotMatchType(origin=ct.origin, to=types.origins["ct"]) - if gk.origin != types.origins["galk"]: - raise errors.NotMatchType(origin=gk.origin, to=types.origins["galk"]) - + def sum(self, ct, gk): new_ct = self.clone(ct) - for roti in range(self.ctx.logN - 1): - rot_ct = self.rotate_single(new_ct, gk.data[roti]) - sum_ct = self.add(rot_ct, new_ct) - del new_ct, rot_ct - if roti != 0 and (roti % rescale_every) == 0: - new_ct = self.reduce_error(sum_ct) - else: - new_ct = sum_ct + for roti in range(self.ctx.logN-1): + rotk = gk.data[roti] + rot_ct = self.rotate_single(new_ct, rotk) + new_ct = self.add(rot_ct, new_ct) return new_ct - def mean(self, ct: data_struct, gk: data_struct, alpha=1, rescale_every=5) -> data_struct: + def mean(self, ct, gk, alpha=1): # Divide by num_slots. # The cipher text is refreshed here, and hence - # doesn't need to be refreshed at roti=0 in the loop. - new_ct = self.mult(1 / self.num_slots / alpha, ct) - - for roti in range(self.ctx.logN - 1): + # doesn't beed to be refreshed at roti=0 in the loop. + new_ct = self.mult(1/self.num_slots/alpha, ct) + + for roti in range(self.ctx.logN-1): rotk = gk.data[roti] rot_ct = self.rotate_single(new_ct, rotk) - sum_ct = self.add(rot_ct, new_ct) - del new_ct, rot_ct - if ((roti % rescale_every) == 0) and (roti != 0): - new_ct = self.reduce_error(sum_ct) - else: - new_ct = sum_ct + new_ct = self.add(rot_ct, new_ct) return new_ct def cov(self, ct_a: data_struct, ct_b: data_struct, @@ -2607,3 +2719,5 @@ def std(self, ct: data_struct, evk: data_struct, gk: data_struct, relin=False) - ct_var = self.var(ct=ct, evk=evk, gk=gk, relin=relin) ct_std = self.sqrt(ct=ct_var, evk=evk) return ct_std + + From 9cb001c4a241536547677b03ff9e8d30552d363a Mon Sep 17 00:00:00 2001 From: Juwhan Kim Date: Wed, 3 Jan 2024 23:30:04 +0900 Subject: [PATCH 7/7] Removed rescale_every from cov --- src/liberate/fhe/ckks_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/liberate/fhe/ckks_engine.py b/src/liberate/fhe/ckks_engine.py index 7cfcf6f..1832079 100644 --- a/src/liberate/fhe/ckks_engine.py +++ b/src/liberate/fhe/ckks_engine.py @@ -2316,9 +2316,9 @@ def mean(self, ct, gk, alpha=1): return new_ct def cov(self, ct_a: data_struct, ct_b: data_struct, - evk: data_struct, gk: data_struct, rescale_every=5) -> data_struct: - cta_mean = self.mean(ct_a, gk, rescale_every=rescale_every) - ctb_mean = self.mean(ct_b, gk, rescale_every=rescale_every) + evk: data_struct, gk: data_struct) -> data_struct: + cta_mean = self.mean(ct_a, gk) + ctb_mean = self.mean(ct_b, gk) cta_dev = self.sub(ct_a, cta_mean) ctb_dev = self.sub(ct_b, ctb_mean)