From 4565eed44c69cb52fb6f7b96e12892dc460971ab Mon Sep 17 00:00:00 2001 From: Pyry Lahtinen Date: Tue, 12 Dec 2023 15:26:52 +0200 Subject: [PATCH] fixed some lint errors --- kyber/ccakem.py | 4 ++-- kyber/encryption/__init__.py | 2 +- kyber/encryption/decrypt.py | 42 ++++++++++++++++-------------------- tests/test_cbd.py | 2 +- tests/test_decrypt.py | 8 +++---- tests/test_encryption.py | 4 ++-- 6 files changed, 29 insertions(+), 33 deletions(-) diff --git a/kyber/ccakem.py b/kyber/ccakem.py index e0891eb..3e51762 100644 --- a/kyber/ccakem.py +++ b/kyber/ccakem.py @@ -1,5 +1,5 @@ from secrets import token_bytes -from kyber.encryption import generate_keys, Encrypt, Decrypt +from kyber.encryption import generate_keys, Encrypt, decrypt from kyber.utils.pseudo_random import H, G, kdf from kyber.constants import k, n, du, dv @@ -56,7 +56,7 @@ def ccakem_decrypt(ciphertext: bytes, private_key: bytes, shared_secret_length: assert h == H(pk) - m = Decrypt(sk, ciphertext).decrypt() + m = decrypt(sk, ciphertext) Kr = G(m + h) K, r = Kr[:32], Kr[32:] c = Encrypt(pk, m, r).encrypt() diff --git a/kyber/encryption/__init__.py b/kyber/encryption/__init__.py index 0cbab35..da79c33 100644 --- a/kyber/encryption/__init__.py +++ b/kyber/encryption/__init__.py @@ -1,3 +1,3 @@ from kyber.encryption.keygen import generate_keys from kyber.encryption.encrypt import Encrypt -from kyber.encryption.decrypt import Decrypt +from kyber.encryption.decrypt import decrypt diff --git a/kyber/encryption/decrypt.py b/kyber/encryption/decrypt.py index 25b19f3..7453560 100644 --- a/kyber/encryption/decrypt.py +++ b/kyber/encryption/decrypt.py @@ -4,33 +4,29 @@ from kyber.constants import n, k, du, dv from kyber.entities.polring import PolynomialRing -class Decrypt: - def __init__(self, private_key, ciphertext) -> None: - self._sk = private_key - self._c = ciphertext - if len(self._sk) != 32*12*k: - raise ValueError() - if len(self._c) != du*k*n//8 + dv*n//8: - raise ValueError() +def decrypt(private_key, ciphertext) -> bytes: + """ + Decrypts the given ciphertext with the given private key. + :returns Decrypted 32-bit shared secret + """ - def decrypt(self) -> bytes: - """ - Decrypts the given ciphertext with the given private key. - :returns Decrypted 32-bit shared secret - """ + if len(private_key) != 32*12*k: + raise ValueError() + if len(ciphertext) != du*k*n//8 + dv*n//8: + raise ValueError() - s = np.array(decode(self._sk, 12)) + s = np.array(decode(private_key, 12)) - u, v = self._c[:du*k*n//8], self._c[du*k*n//8:] + u, v = ciphertext[:du*k*n//8], ciphertext[du*k*n//8:] - u = decode(u, du) - v = decode(v, dv)[0] + u = decode(u, du) + v = decode(v, dv)[0] - u = np.array([decompress(pol, du) for pol in u]) - v = decompress(v, dv) + u = np.array([decompress(pol, du) for pol in u]) + v = decompress(v, dv) - m: PolynomialRing = v - np.matmul(s.T, u) - m: bytes = encode(compress([m], 1), 1) + m: PolynomialRing = v - np.matmul(s.T, u) + m: bytes = encode(compress([m], 1), 1) - assert len(m) == 32 - return m + assert len(m) == 32 + return m diff --git a/tests/test_cbd.py b/tests/test_cbd.py index cc48bb0..c28893a 100644 --- a/tests/test_cbd.py +++ b/tests/test_cbd.py @@ -1,5 +1,5 @@ import unittest -from random import seed,randbytes +from random import seed, randbytes from base64 import b64decode from kyber.utils.cbd import cbd diff --git a/tests/test_decrypt.py b/tests/test_decrypt.py index b608df5..31d4ab8 100644 --- a/tests/test_decrypt.py +++ b/tests/test_decrypt.py @@ -1,6 +1,6 @@ import unittest from random import seed, randbytes -from kyber.encryption import Decrypt +from kyber.encryption import decrypt from kyber.constants import k, n, du, dv class TestDecrypt(unittest.TestCase): @@ -10,7 +10,7 @@ def setUp(self): def test_decryption_outputs_valid_shared_secret(self): private_key = randbytes(32*12*k) ciphertext = randbytes(du*k*n//8 + dv*n//8) - shared_secret = Decrypt(private_key, ciphertext).decrypt() + shared_secret = decrypt(private_key, ciphertext) self.assertEqual(type(shared_secret), bytes) self.assertEqual(len(shared_secret), 32) @@ -19,11 +19,11 @@ def test_decryption_raises_with_invalid_private_key(self): invalid_private_key = randbytes(32*12*k + 1) valid_ciphertext = randbytes(du*k*n//8 + dv*n//8) with self.assertRaises(ValueError): - Decrypt(invalid_private_key, valid_ciphertext) + decrypt(invalid_private_key, valid_ciphertext) def test_decryption_raises_with_invalid_ciphertext(self): # this ciphertext is one byte too short valid_private_key = randbytes(32*12*k) invalid_ciphertext = randbytes(du*k*n//8 + dv*n//8 - 1) with self.assertRaises(ValueError): - Decrypt(valid_private_key, invalid_ciphertext) + decrypt(valid_private_key, invalid_ciphertext) diff --git a/tests/test_encryption.py b/tests/test_encryption.py index be71363..df447b7 100644 --- a/tests/test_encryption.py +++ b/tests/test_encryption.py @@ -1,5 +1,5 @@ import unittest -from kyber.encryption import generate_keys, Encrypt, Decrypt +from kyber.encryption import generate_keys, Encrypt, decrypt class TestIntegration(unittest.TestCase): def test_encryption_symmetry(self): @@ -7,6 +7,6 @@ def test_encryption_symmetry(self): private_key, public_key = generate_keys() encrypter = Encrypt(public_key) ciphertext = encrypter.encrypt() - decrypted_shared_secret = Decrypt(private_key, ciphertext).decrypt() + decrypted_shared_secret = decrypt(private_key, ciphertext) self.assertEqual(encrypter.secret, decrypted_shared_secret) self.assertEqual(len(encrypter.secret), 32)