diff --git a/src/liberate/csprng/csprng.py b/src/liberate/csprng/csprng.py index b7d41a7..17173e4 100644 --- a/src/liberate/csprng/csprng.py +++ b/src/liberate/csprng/csprng.py @@ -222,23 +222,18 @@ def generate_nonce(self, seed): # nonce is 64bits. return self.generate_initial_bytes(8, seed=None) - def randbytes(self, shares=None, repeats=0, length=None, reshape=False): + def randbytes(self, shares=None, repeats=0, reshape=False): # Generates (shares_i + repeats) X length random bytes. if shares is None: shares = self.shares - if length is None: - L = self.L - else: - L = length // 16 - # Set the target states. target_states = [] for devi in range(self.num_devices): start_channel = self.shares[devi] - shares[devi] end_channel = self.shares[devi] + repeats device_states = self.channeled_states[devi][ - start_channel:end_channel, :L, : + start_channel:end_channel, :, : ] target_states.append(device_states.view(-1, 16)) @@ -247,22 +242,17 @@ def randbytes(self, shares=None, repeats=0, length=None, reshape=False): # If not reshape, flatten. if reshape: - random_bytes = [rb.view(-1, L, 16) for rb in random_bytes] + random_bytes = [rb.view(-1, self.L, 16) for rb in random_bytes] return random_bytes - def randint(self, amax=3, shift=0, repeats=0, length=None): + def randint(self, amax=3, shift=0, repeats=0): # The default values are for generating the same uniform ternary # arrays in all GPUs. if not isinstance(amax, (list, tuple)): amax = [[amax] for share in self.shares] - if length is None: - L = self.L - else: - L = length // 4 - # Calculate shares. # If repeats are greater than 0, those channels are # subtracted from shares. @@ -278,7 +268,7 @@ def randint(self, amax=3, shift=0, repeats=0, length=None): start_channel = self.shares[devi] - shares[devi] end_channel = self.shares[devi] + repeats device_states = self.channeled_states[devi][ - start_channel:end_channel, :L, : + start_channel:end_channel, :, : ] target_states.append(device_states) @@ -289,24 +279,19 @@ def randint(self, amax=3, shift=0, repeats=0, length=None): return rand_int - def discrete_gaussian(self, non_repeats=0, repeats=1, length=None): + def discrete_gaussian(self, non_repeats=0, repeats=1): if not isinstance(non_repeats, (list, tuple)): shares = [non_repeats] * self.num_devices else: shares = non_repeats - if length is None: - L = self.L - else: - L = length // 4 - # Set the target states. target_states = [] for devi in range(self.num_devices): start_channel = self.shares[devi] - shares[devi] end_channel = self.shares[devi] + repeats device_states = self.channeled_states[devi][ - start_channel:end_channel, :L, : + start_channel:end_channel, :, : ] target_states.append(device_states.view(-1, 16)) @@ -327,6 +312,9 @@ def randround(self, coef): """Randomly round coef. Coef must be a double tensor. coef must reside in the fist GPU in the GPUs list""" + # The following slicing is OK, since we're using only the first + # contiguous stream of states. + # It will not make the target state strided. L = self.num_coefs // 16 rand_bytes = chacha20_cuda.chacha20((self.states[0][:L],), self.inc)[ 0