Skip to content

Commit

Permalink
Removed length from the csprng API.
Browse files Browse the repository at this point in the history
  • Loading branch information
juwhan-k committed Jan 5, 2024
1 parent 82567d6 commit 4daf9c6
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions src/liberate/csprng/csprng.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,23 +222,18 @@ def generate_nonce(self, seed):
# nonce is 64bits.
return self.generate_initial_bytes(8, seed=None)

def randbytes(self, shares=None, repeats=0, length=None, reshape=False):
def randbytes(self, shares=None, repeats=0, reshape=False):
# Generates (shares_i + repeats) X length random bytes.
if shares is None:
shares = self.shares

if length is None:
L = self.L
else:
L = length // 16

# Set the target states.
target_states = []
for devi in range(self.num_devices):
start_channel = self.shares[devi] - shares[devi]
end_channel = self.shares[devi] + repeats
device_states = self.channeled_states[devi][
start_channel:end_channel, :L, :
start_channel:end_channel, :, :
]
target_states.append(device_states.view(-1, 16))

Expand All @@ -247,22 +242,17 @@ def randbytes(self, shares=None, repeats=0, length=None, reshape=False):

# If not reshape, flatten.
if reshape:
random_bytes = [rb.view(-1, L, 16) for rb in random_bytes]
random_bytes = [rb.view(-1, self.L, 16) for rb in random_bytes]

return random_bytes

def randint(self, amax=3, shift=0, repeats=0, length=None):
def randint(self, amax=3, shift=0, repeats=0):
# The default values are for generating the same uniform ternary
# arrays in all GPUs.

if not isinstance(amax, (list, tuple)):
amax = [[amax] for share in self.shares]

if length is None:
L = self.L
else:
L = length // 4

# Calculate shares.
# If repeats are greater than 0, those channels are
# subtracted from shares.
Expand All @@ -278,7 +268,7 @@ def randint(self, amax=3, shift=0, repeats=0, length=None):
start_channel = self.shares[devi] - shares[devi]
end_channel = self.shares[devi] + repeats
device_states = self.channeled_states[devi][
start_channel:end_channel, :L, :
start_channel:end_channel, :, :
]
target_states.append(device_states)

Expand All @@ -289,24 +279,19 @@ def randint(self, amax=3, shift=0, repeats=0, length=None):

return rand_int

def discrete_gaussian(self, non_repeats=0, repeats=1, length=None):
def discrete_gaussian(self, non_repeats=0, repeats=1):
if not isinstance(non_repeats, (list, tuple)):
shares = [non_repeats] * self.num_devices
else:
shares = non_repeats

if length is None:
L = self.L
else:
L = length // 4

# Set the target states.
target_states = []
for devi in range(self.num_devices):
start_channel = self.shares[devi] - shares[devi]
end_channel = self.shares[devi] + repeats
device_states = self.channeled_states[devi][
start_channel:end_channel, :L, :
start_channel:end_channel, :, :
]
target_states.append(device_states.view(-1, 16))

Expand All @@ -327,6 +312,9 @@ def randround(self, coef):
"""Randomly round coef. Coef must be a double tensor.
coef must reside in the fist GPU in the GPUs list"""

# The following slicing is OK, since we're using only the first
# contiguous stream of states.
# It will not make the target state strided.
L = self.num_coefs // 16
rand_bytes = chacha20_cuda.chacha20((self.states[0][:L],), self.inc)[
0
Expand Down

0 comments on commit 4daf9c6

Please sign in to comment.