]> git.feebdaed.xyz Git - 0xmirror/liboqs.git/commitdiff
Add checks for ML-KEM keys (#2009)
authorAbhinav Saxena <abhinav.saxena@thalesgroup.com>
Wed, 26 Feb 2025 09:31:13 +0000 (15:01 +0530)
committerGitHub <noreply@github.com>
Wed, 26 Feb 2025 09:31:13 +0000 (10:31 +0100)
* add checks for ML-KEM keys

* add mod(3329) using barrett reduction

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>
tests/CMakeLists.txt
tests/vectors_kem.c

index 6d08516a8985cb9a6fe61e74ca212ed8be1389e0..fae241991d0613a188b64437665eb928580ee741 100644 (file)
@@ -173,6 +173,15 @@ target_link_libraries(vectors_sig PRIVATE ${TEST_DEPS})
 add_executable(vectors_kem vectors_kem.c)
 target_link_libraries(vectors_kem PRIVATE ${TEST_DEPS})
 
+if(CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BUILD_SHARED_LIBS)
+    # workaround for Windows .dll
+    if(MINGW OR MSYS OR CYGWIN OR CMAKE_CROSSCOMPILING)
+        target_link_options(vectors_kem PRIVATE -Wl,--allow-multiple-definition)
+    else()
+        target_link_options(vectors_kem PRIVATE "/FORCE:MULTIPLE")
+    endif()
+endif()
+
 # Enable Valgrind-based timing side-channel analysis for test_kem and test_sig
 if(OQS_ENABLE_TEST_CONSTANT_TIME AND NOT OQS_DEBUG_BUILD)
     message(WARNING "OQS_ENABLE_TEST_CONSTANT_TIME is incompatible with CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}.")
index a7a1dc6a7bd1a80474cf73d1296541d22d6eed5d..128d5e840d9310127a055f0970e66d959f12a408 100644 (file)
 #include <sys/stat.h>
 
 #include <oqs/oqs.h>
-
+#include <oqs/sha3.h>
 #include "system_info.c"
 
+#ifdef OQS_ENABLE_KEM_ML_KEM
+/* macros for sanity checks for encaps and decaps key */
+#define ML_KEM_POLYBYTES        384
+#define ML_KEM_K_MAX            4
+#define ML_KEM_N                256
+#define ML_KEM_1024_PK_SIZE     1568
+#define ML_KEM_Q                3329
+#define SHA3_256_OP_LEN         32
+#endif //OQS_ENABLE_KEM_ML_KEM
+
 struct {
        const uint8_t *pos;
 } prng_state = {
        .pos = 0
 };
 
+/* MLKEM-specific functions */
+static inline bool is_ml_kem(const char *method_name) {
+       return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512))
+              || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768))
+              || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024));
+}
+
 static void fprintBstr(FILE *fp, const char *S, const uint8_t *A, size_t L) {
        size_t i;
        fprintf(fp, "%s", S);
@@ -58,12 +75,106 @@ static void hexStringToByteArray(const char *hexString, uint8_t *byteArray) {
        }
 }
 
-/* ML_KEM-specific functions */
-static inline bool is_ml_kem(const char *method_name) {
-       return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512))
-              || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768))
-              || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024));
+#ifdef OQS_ENABLE_KEM_ML_KEM
+/* barret reduction for mod(Q) */
+int16_t barrett_reduce(int16_t a) {
+       const int16_t v = ((1 << 26) + ML_KEM_Q / 2) / ML_KEM_Q;
+       int32_t t = ((int32_t)v * a + (1 << 25)) >> 26;
+       t *= ML_KEM_Q;
+       a -= t;
+
+       int16_t mask = a >> 15;
+       a += (ML_KEM_Q & mask);
+       return a;
+}
+/* fetch value of 'K' from ML-KEM version */
+uint8_t get_ml_kem_k(const char *method) {
+       if (0 == strcmp(method, OQS_KEM_alg_ml_kem_512)) {
+               return 2;
+       } else if (0 == strcmp(method, OQS_KEM_alg_ml_kem_768)) {
+               return 3;
+       } else if (0 == strcmp(method, OQS_KEM_alg_ml_kem_1024)) {
+               return 4;
+       } else {
+               return 0;  // Default/error case
+       }
+}
+/* sanity check for private/decaps key as specified in section 7.3 of FIPS-203 */
+static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) {
+       /* sanity checks */
+       if ((NULL == sk) || (NULL == method_name) || (false == is_ml_kem(method_name))) {
+               fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or invalid method !\n", method_name);
+               return false;
+       }
+       /* buffer to hold public key hash */
+       uint8_t pkdig[SHA3_256_OP_LEN] = {0};
+       /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203
+       K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */
+       uint8_t K = get_ml_kem_k(method_name);
+       if (0 == K) {
+               fprintf(stderr, "K value can be fetched only for ML-KEM !\n");
+               return false;
+       }
+       /* calcualte hash of the public key(len = 384k+32) stored in private key at offset of 384k */
+       OQS_SHA3_sha3_256(pkdig, sk + (ML_KEM_POLYBYTES * K), (ML_KEM_POLYBYTES * K) + 32);
+       /* compare it with public key hash stored at 768k+32 offset */
+       if (0 != memcmp(pkdig, sk + (ML_KEM_POLYBYTES * K * 2) + 32, SHA3_256_OP_LEN)) {
+               return false;
+       }
+       return true;
+}
+/* sanity check for public/encaps key as specified in section 7.2 of FIPS-203 */
+static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *method_name) {
+       /* sanity checks */
+       if ((NULL == pk) || (0 == pkLen) ||  (NULL == method_name) || (false == is_ml_kem(method_name))) {
+               fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or zero or invalid method !\n", method_name);
+               return false;
+       }
+       unsigned int i, j;
+       /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203
+       K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */
+       uint8_t K = get_ml_kem_k(method_name);
+       if (0 == K) {
+               fprintf(stderr, "K value can be fetched only for ML-KEM !\n");
+               return false;
+       }
+       /* buffer to hold decoded value. max value used, so same buffer could be used for ML-KEM versions
+       encaps key is of length 384K bytes(384K*8 bits). Grouped into 12-bit values, the buffer requires (384*K*8)/12 = 256*K entries of 12 bits */
+       uint16_t buffd[ML_KEM_N * ML_KEM_K_MAX] = {0};
+       /* buffer to hold encoded value */
+       uint8_t buffe[ML_KEM_1024_PK_SIZE - 32] = {0};
+       uint16_t *buff_dec;
+       /* perform byte decoding as per Algo 6 of FIPS 203 */
+       for (i = 0; i < K; i++) {
+               buff_dec = &buffd[i * ML_KEM_N];
+               const uint8_t *curr_pk = &pk[i * ML_KEM_POLYBYTES];
+               for (j = 0; j < ML_KEM_N / 2; j++) {
+                       buff_dec[2 * j + 0] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF;
+                       buff_dec[2 * j + 0] = (uint16_t)barrett_reduce((int16_t)buff_dec[2 * j]);
+                       buff_dec[2 * j + 1] = ((curr_pk[3 * j + 1] >> 4) | ((uint16_t)curr_pk[3 * j + 2] << 4)) & 0xFFF;
+                       buff_dec[2 * j + 1] = (uint16_t)barrett_reduce((int16_t)buff_dec[2 * j + 1]);
+               }
+       }
+       /* perform byte encoding as per Algo 5 of FIPS 203 */
+       for (i = 0; i < K; i++) {
+               uint16_t t0, t1;
+               buff_dec = &buffd[i * ML_KEM_N];
+               uint8_t *buff_enc = &buffe[i * ML_KEM_POLYBYTES];
+               for (j = 0; j < ML_KEM_N / 2; j++) {
+                       t0 = buff_dec[2 * j];
+                       t1 = buff_dec[2 * j + 1];
+                       buff_enc[3 * j + 0] = (uint8_t)(t0 >> 0);
+                       buff_enc[3 * j + 1] = (uint8_t)((t0 >> 8) | (t1 << 4));
+                       buff_enc[3 * j + 2] = (uint8_t)(t1 >> 4);
+               }
+       }
+       /* compare the encoded value with original public key. discard value of `rho(32 bytes)` during comparision as its not encoded */
+       if (0 != memcmp(buffe, pk, pkLen - 32)) {
+               return false;
+       }
+       return true;
 }
+#endif //OQS_ENABLE_KEM_ML_KEM
 
 static void MLKEM_randombytes_init(const uint8_t *entropy_input, const uint8_t *personalization_string) {
        (void) personalization_string;
@@ -134,6 +245,13 @@ static OQS_STATUS kem_kg_vector(const char *method_name,
        fprintBstr(fh, "ek: ", public_key, kem->length_public_key);
        fprintBstr(fh, "dk: ", secret_key, kem->length_secret_key);
 
+#ifdef OQS_ENABLE_KEM_ML_KEM
+       if ((false == sanityCheckPK(public_key, kem->length_public_key, method_name)) || (false == sanityCheckSK(secret_key, method_name))) {
+               fprintf(stderr, "[vectors_kem] %s ERROR: generated public key or private key are corrupted !\n", method_name);
+               goto err;
+       }
+#endif //OQS_ENABLE_KEM_ML_KEM
+
        if (!memcmp(public_key, kg_pk, kem->length_public_key) && !memcmp(secret_key, kg_sk, kem->length_secret_key)) {
                ret = OQS_SUCCESS;
        } else {
@@ -208,6 +326,13 @@ static OQS_STATUS kem_vector_encdec_aft(const char *method_name,
                goto err;
        }
 
+#ifdef OQS_ENABLE_KEM_ML_KEM
+       if (false == sanityCheckPK(encdec_pk, kem->length_public_key, method_name)) {
+               fprintf(stderr, "[vectors_kem] %s ERROR: passed encapsulation key is corrupted !\n", method_name);
+               goto err;
+       }
+#endif //OQS_ENABLE_KEM_ML_KEM
+
        rc = OQS_KEM_encaps(kem, ct_encaps, ss_encaps, encdec_pk);
        if (rc != OQS_SUCCESS) {
                fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name);
@@ -273,6 +398,13 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name,
                goto err;
        }
 
+#ifdef OQS_ENABLE_KEM_ML_KEM
+       if (false == sanityCheckSK(encdec_sk, method_name)) {
+               fprintf(stderr, "[vectors_kem] %s ERROR: passed decapsulation key is corrupted !\n", method_name);
+               goto err;
+       }
+#endif //OQS_ENABLE_KEM_ML_KEM
+
        rc = OQS_KEM_decaps(kem, ss_decaps, encdec_c, encdec_sk);
        if (rc != OQS_SUCCESS) {
                fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name);