diff --git a/src/liberate/fhe/ckks_engine.py b/src/liberate/fhe/ckks_engine.py index 1832079..e264949 100644 --- a/src/liberate/fhe/ckks_engine.py +++ b/src/liberate/fhe/ckks_engine.py @@ -842,6 +842,10 @@ def create_switcher(self, a: list[torch.Tensor], ksk: data_struct, level, exit_n # Rename summed's. d0 = summed0 d1 = summed1 + + # intt to prepare for division by P. + self.ntt.intt_exit_reduce(d0, level, -2) + self.ntt.intt_exit_reduce(d1, level, -2) # 6. Divide by P. # This is actually done in successive order. @@ -924,11 +928,7 @@ def switcher_later_part(self, d0 = self.ntt.mont_mult( [extended], [ksk0_data], level, dst_device_id, -2) d1 = self.ntt.mont_mult( - [extended], [ksk1_data], level, dst_device_id, -2) - - # intt to prepare for division by P. - self.ntt.intt_exit_reduce(d0, level, dst_device_id, -2) - self.ntt.intt_exit_reduce(d1, level, dst_device_id, -2) + [extended], [ksk1_data], level, dst_device_id, -2) # When returning, un-list the results by taking the 0th element. return d0[0], d1[0]