Skip to content

Commit

Permalink
streamlined decode util to handle lists instead of single polynomials
Browse files Browse the repository at this point in the history
  • Loading branch information
PyryL committed Nov 15, 2023
1 parent e6ab86d commit 2a49ab9
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 25 deletions.
11 changes: 3 additions & 8 deletions kyber/decrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,12 @@ def decrypt(self) -> bytes:
:returns Decrypted 32-bit shared secret
"""

# split self._sk into chunks of length 32*12 and decode each one of them into a polynomial
s = np.array([
decode(self._sk[32*12*i : 32*12*(i+1)], 12) for i in range(len(self._sk)//(32*12))
])
s = np.array(decode(self._sk, 12))

u, v = self._c[:du*k*n//8], self._c[du*k*n//8:]

u = np.array([
decode(u[32*du*i : 32*du*(i+1)], du) for i in range(len(u)//(32*du))
])
v = decode(v, dv)
u = decode(u, du)
v = decode(v, dv)[0]

u = np.array([decompress(pol, du) for pol in u])
v = decompress(v, dv)
Expand Down
6 changes: 2 additions & 4 deletions kyber/encrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ def encrypt(self):
rb = self._r

t, rho = self._pk[:-32], self._pk[-32:]
t = np.array([
decode(t[32*12*i : 32*12*(i+1)], 12) for i in range(len(t)//(32*12))
])
t = np.array(decode(t, 12))

A = np.empty((k, k), Polynomial)
for i in range(k):
Expand All @@ -60,7 +58,7 @@ def encrypt(self):
e2 = polmod(e2)

u = np.matmul(A.T, r) + e1
v = np.matmul(t.T, r) + e2 + decompress(decode(m, 1), 1)
v = np.matmul(t.T, r) + e2 + decompress(decode(m, 1)[0], 1)

u = matmod(u)
v = polmod(v)
Expand Down
22 changes: 13 additions & 9 deletions kyber/utils/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,21 @@ def encode(pols: list[Polynomial], l: int) -> bytes:
assert len(result) == 32*l*len(pols)
return bytes(result)

def decode(b: bytes, l: int) -> Polynomial:
def decode(b: bytes, l: int) -> list[Polynomial]:
"""
Converts the given byte array (length `32*l`) into a polynomial (degree 255)
Converts the given byte array (length `32*l*x` for some integer x) into
a list of polynomials (length x, each degree 255)
in which each coefficient is in range `0...2**l-1` (inclusive).
"""

if len(b) != 32*l:
if len(b) % 32*l != 0:
raise ValueError()
bits = bytes_to_bits(b)
f = np.empty((256, ))
for i in range(256):
f[i] = sum(bits[i*l+j]*2**j for j in range(l)) # accesses each bit exactly once
assert 0 <= f[i] and f[i] <= 2**l-1
return Polynomial(f)
result = []
for t in range(len(b) // (32*l)):
bits = bytes_to_bits(b[32*l*t : 32*l*(t+1)])
f = np.empty((256, ))
for i in range(256):
f[i] = sum(bits[i*l+j]*2**j for j in range(l)) # accesses each bit exactly once
assert 0 <= f[i] and f[i] <= 2**l-1
result.append(Polynomial(f))
return result
8 changes: 4 additions & 4 deletions tests/test_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@ def setUp(self):
self.polynomial2 = Polynomial([randint(0, 1) for _ in range(256)])

def test_encoding_symmetry(self):
polynomial = decode(self.data, self.l)
restored_data = encode([polynomial], self.l)
polynomials = decode(self.data, self.l)
restored_data = encode(polynomials, self.l)
self.assertEqual(self.data, restored_data)

def test_decode_coefficients(self):
polynomial = decode(self.data, self.l)
polynomial = decode(self.data, self.l)[0]
for c in polynomial.coef:
self.assertTrue(0 <= int(c) or int(c) <= 2**self.l-1)

def test_decode_degree(self):
polynomial = decode(self.data, self.l)
polynomial = decode(self.data, self.l)[0]
self.assertEqual(len(polynomial.coef), 256)

def test_decode_raises_with_invalid_argument_length(self):
Expand Down

0 comments on commit 2a49ab9

Please sign in to comment.