Skip to content

Commit

Permalink
Merge pull request wolfSSL#7712 from SparkiDev/kyber_ml_kem
Browse files Browse the repository at this point in the history
KYBER/ML-KEM: make ML-KEM available
  • Loading branch information
JacobBarthelmeh authored Jul 5, 2024
2 parents e6fbe25 + 1fd9f2a commit 5ca9b2f
Show file tree
Hide file tree
Showing 8 changed files with 4,378 additions and 16 deletions.
6 changes: 6 additions & 0 deletions configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,9 @@ do
1024)
ENABLED_KYBER1024=yes
;;
ml-kem)
ENABLED_ML_KEM=yes
;;
*)
AC_MSG_ERROR([Invalid choice for KYBER []: $ENABLED_KYBER.])
break;;
Expand All @@ -1239,6 +1242,9 @@ then
if test "$ENABLED_KYBER1024" = ""; then
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_NO_KYBER1024"
fi
if test "$ENABLED_ML_KEM" = "yes"; then
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_ML_KEM"
fi

if test "$ENABLED_WC_KYBER" = "yes"
then
Expand Down
3,839 changes: 3,839 additions & 0 deletions tests/api.c

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions wolfcrypt/benchmark/benchmark.c
Original file line number Diff line number Diff line change
Expand Up @@ -3593,17 +3593,17 @@ static void* benchmarks_do(void* args)
#ifdef WOLFSSL_HAVE_KYBER
if (bench_all || (bench_pq_asym_algs & BENCH_KYBER)) {
#ifdef WOLFSSL_KYBER512
if (bench_pq_asym_algs & BENCH_KYBER512) {
if (bench_all || (bench_pq_asym_algs & BENCH_KYBER512)) {
bench_kyber(KYBER512);
}
#endif
#ifdef WOLFSSL_KYBER768
if (bench_pq_asym_algs & BENCH_KYBER768) {
if (bench_all || (bench_pq_asym_algs & BENCH_KYBER768)) {
bench_kyber(KYBER768);
}
#endif
#ifdef WOLFSSL_KYBER1024
if (bench_pq_asym_algs & BENCH_KYBER1024) {
if (bench_all || (bench_pq_asym_algs & BENCH_KYBER1024)) {
bench_kyber(KYBER1024);
}
#endif
Expand Down
6 changes: 4 additions & 2 deletions wolfcrypt/src/ext_kyber.c
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,8 @@ int wc_KyberKey_Decapsulate(KyberKey* key, unsigned char* ss,
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the correct size.
*/
int wc_KyberKey_DecodePrivateKey(KyberKey* key, unsigned char* in, word32 len)
int wc_KyberKey_DecodePrivateKey(KyberKey* key, const unsigned char* in,
word32 len)
{
int ret = 0;
word32 privLen = 0;
Expand Down Expand Up @@ -647,7 +648,8 @@ int wc_KyberKey_DecodePrivateKey(KyberKey* key, unsigned char* in, word32 len)
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the correct size.
*/
int wc_KyberKey_DecodePublicKey(KyberKey* key, unsigned char* in, word32 len)
int wc_KyberKey_DecodePublicKey(KyberKey* key, const unsigned char* in,
word32 len)
{
int ret = 0;
word32 pubLen = 0;
Expand Down
67 changes: 63 additions & 4 deletions wolfcrypt/src/wc_kyber.c
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,9 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
byte msg[2 * KYBER_SYM_SZ];
byte kr[2 * KYBER_SYM_SZ + 1];
int ret = 0;
#ifndef WOLFSSL_ML_KEM
unsigned int ctSz = 0;
#endif

/* Validate parameters. */
if ((key == NULL) || (ct == NULL) || (ss == NULL) || (rand == NULL)) {
Expand All @@ -543,6 +545,7 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
ret = BUFFER_E;
}

#ifndef WOLFSSL_ML_KEM
if (ret == 0) {
/* Establish parameters based on key type. */
switch (key->type) {
Expand All @@ -567,6 +570,7 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
break;
}
}
#endif

/* If public hash (h) is not stored against key, calculate it. */
if ((ret == 0) && ((key->flags & KYBER_FLAG_H_SET) == 0)) {
Expand Down Expand Up @@ -596,8 +600,12 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
}

if (ret == 0) {
#ifndef WOLFSSL_ML_KEM
/* Hash random to anonymize as seed data. */
ret = KYBER_HASH_H(rand, KYBER_SYM_SZ, msg);
#else
XMEMCPY(msg, rand, KYBER_SYM_SZ);
#endif
}
if (ret == 0) {
/* Copy the hash of the public key into msg. */
Expand All @@ -612,6 +620,7 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
ret = kyberkey_encapsulate(key, msg, kr + KYBER_SYM_SZ, ct);
}

#ifndef WOLFSSL_ML_KEM
if (ret == 0) {
/* Hash the cipher text after the seed. */
ret = KYBER_HASH_H(ct, ctSz, kr + KYBER_SYM_SZ);
Expand All @@ -620,6 +629,11 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
/* Derive the secret from the seed and hash of cipher text. */
ret = KYBER_KDF(kr, 2 * KYBER_SYM_SZ, ss, KYBER_SS_SZ);
}
#else
if (ret == 0) {
XMEMCPY(ss, kr, KYBER_SS_SZ);
}
#endif

return ret;
}
Expand Down Expand Up @@ -725,6 +739,39 @@ static KYBER_NOINLINE int kyberkey_decapsulate(KyberKey* key,
return ret;
}

#ifdef WOLFSSL_ML_KEM
/* Derive the secret from z and cipher text.
*
* @param [in] z Implicit rejection value.
* @param [in] ct Cipher text.
* @param [in] ctSz Length of cipher text in bytes.
* @param [out] ss Shared secret.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation failed.
* @return Other negative when a hash error occurred.
*/
static int kyber_derive_secret(const byte* z, const byte* ct, word32 ctSz,
byte* ss)
{
int ret;
wc_Shake shake;

ret = wc_InitShake256(&shake, NULL, INVALID_DEVID);
if (ret == 0) {
ret = wc_Shake256_Update(&shake, z, KYBER_SYM_SZ);
if (ret == 0) {
ret = wc_Shake256_Update(&shake, ct, ctSz);
}
if (ret == 0) {
ret = wc_Shake256_Final(&shake, ss, KYBER_SS_SZ);
}
wc_Shake256_Free(&shake);
}

return ret;
}
#endif

/**
* Decapsulate the cipher text to calculate the shared secret.
*
Expand Down Expand Up @@ -818,6 +865,7 @@ int wc_KyberKey_Decapsulate(KyberKey* key, unsigned char* ss,
/* Compare generated cipher text with that passed in. */
fail = kyber_cmp(ct, cmp, ctSz);

#ifndef WOLFSSL_ML_KEM
/* Hash the cipher text after the seed. */
ret = KYBER_HASH_H(ct, ctSz, kr + KYBER_SYM_SZ);
}
Expand All @@ -829,6 +877,15 @@ int wc_KyberKey_Decapsulate(KyberKey* key, unsigned char* ss,

/* Derive the secret from the seed and hash of cipher text. */
ret = KYBER_KDF(kr, 2 * KYBER_SYM_SZ, ss, KYBER_SS_SZ);
#else
ret = kyber_derive_secret(key->z, ct, ctSz, msg);
}
if (ret == 0) {
/* Change seed to z on comparison failure. */
for (i = 0; i < KYBER_SYM_SZ; i++) {
ss[i] = kr[i] ^ ((kr[i] ^ msg[i]) & fail);
}
#endif
}

#ifndef USE_INTEL_SPEEDUP
Expand All @@ -854,13 +911,14 @@ int wc_KyberKey_Decapsulate(KyberKey* key, unsigned char* ss,
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the correct size.
*/
int wc_KyberKey_DecodePrivateKey(KyberKey* key, unsigned char* in, word32 len)
int wc_KyberKey_DecodePrivateKey(KyberKey* key, const unsigned char* in,
word32 len)
{
int ret = 0;
word32 privLen = 0;
word32 pubLen = 0;
unsigned int k = 0;
unsigned char* p = in;
const unsigned char* p = in;

/* Validate parameters. */
if ((key == NULL) || (in == NULL)) {
Expand Down Expand Up @@ -938,12 +996,13 @@ int wc_KyberKey_DecodePrivateKey(KyberKey* key, unsigned char* in, word32 len)
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the correct size.
*/
int wc_KyberKey_DecodePublicKey(KyberKey* key, unsigned char* in, word32 len)
int wc_KyberKey_DecodePublicKey(KyberKey* key, const unsigned char* in,
word32 len)
{
int ret = 0;
word32 pubLen = 0;
unsigned int k = 0;
unsigned char* p = in;
const unsigned char* p = in;

if ((key == NULL) || (in == NULL)) {
ret = BAD_FUNC_ARG;
Expand Down
4 changes: 1 addition & 3 deletions wolfcrypt/src/wc_kyber_poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -2719,9 +2719,6 @@ static void kyber_vec_compress_10_c(byte* r, sword16* v, unsigned int kp)
{
unsigned int i;
unsigned int j;
#ifdef WOLFSSL_KYBER_SMALL
unsigned int k;
#endif

for (i = 0; i < kp; i++) {
/* Reduce each coefficient to mod q. */
Expand All @@ -2736,6 +2733,7 @@ static void kyber_vec_compress_10_c(byte* r, sword16* v, unsigned int kp)
/* Each 4 polynomial coefficients. */
for (j = 0; j < KYBER_N; j += 4) {
#ifdef WOLFSSL_KYBER_SMALL
unsigned int k;
sword16 t[4];
/* Compress four polynomial values to 10 bits each. */
for (k = 0; k < 4; k++) {
Expand Down
Loading

0 comments on commit 5ca9b2f

Please sign in to comment.