diff --git a/.gitignore b/.gitignore index 51f543d..d56d65a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ lib/* test.log .vscode +emulation.log \ No newline at end of file diff --git a/README.md b/README.md index 06e19e8..83c6b1a 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ -### Now updated on PlatformIO registry as digitaldragon/SSLClient@1.1.6 -### Updated on Arduino Libraries registry to digitaldragon/GovoroxSSLClient@1.1.6 +### Now updated on PlatformIO registry as digitaldragon/SSLClient@1.1.7 +### Updated on Arduino Libraries registry to digitaldragon/GovoroxSSLClient@1.1.7 # SSLClient Arduino library using *mbedtls* functions The SSLClient class implements support for secure connections using TLS (SSL). It Provides a transparent SSL wrapper over existing transport object of a **Client** class. diff --git a/Release Note.md b/Release Note.md index 9452bf1..dda2c72 100644 --- a/Release Note.md +++ b/Release Note.md @@ -7,4 +7,5 @@ SSL Client Updates: 5. Fix buffer issue when writing data larger than receiving buffer, Commit: 4ce6c5f 6. Fix issue where client read timeout value not being set, Commit: 59ae9f0 7. Add clarity to return values for start_ssl_client and fix early termination of ssl client, Commit: cc40266 -8. Close issue [#30](https://github.com/govorox/SSLClient/issues/30), Commit: e426936 \ No newline at end of file +8. Close issue [#30](https://github.com/govorox/SSLClient/issues/30), Commit: e426936 +9. Separate concerns from start_ssl_client into singly responsible functions and unit test private API, commit: 0f1fa36 \ No newline at end of file diff --git a/library.json b/library.json index dc1713a..a646a1b 100644 --- a/library.json +++ b/library.json @@ -1,6 +1,6 @@ { "name": "SSLClient", - "version": "1.1.6", + "version": "1.1.7", "repository": { "type": "git", diff --git a/library.properties b/library.properties index 68f9651..fe2b607 100644 --- a/library.properties +++ b/library.properties @@ -1,5 +1,5 @@ name=GovoroxSSLClient -version=1.1.6 +version=1.1.7 author=V Govorovski maintainer=Robert Byrnes sentence=Provides secure network connection over a generic Client trasport object. diff --git a/platformio.ini b/platformio.ini index 84ce5fe..a5e77d3 100644 --- a/platformio.ini +++ b/platformio.ini @@ -16,8 +16,8 @@ test_framework = unity platform = native build_type = test lib_deps = - digitaldragon/Emulation@^0.0.9 - armmbed/mbedtls@^2.23.0 + digitaldragon/Emulation@0.1.5 + # armmbed/mbedtls@^2.23.0 lib_ldf_mode = deep+ build_unflags = -std=gnu++11 build_flags = diff --git a/src/SSLClient.cpp b/src/SSLClient.cpp index e6b9b3b..4d5d9a5 100644 --- a/src/SSLClient.cpp +++ b/src/SSLClient.cpp @@ -87,13 +87,13 @@ SSLClient::~SSLClient() { void SSLClient::stop() { if (sslclient->client != nullptr) { if (sslclient->client >= 0) { - log_i("Stopping ssl client"); + log_d("Stopping ssl client"); stop_ssl_socket(sslclient, _CA_cert, _cert, _private_key); } else { - log_i("stop() not called because client is < 0"); + log_d("stop() not called because client is < 0"); } } else { - log_i("stop() not called because client is nullptr"); + log_d("stop() not called because client is nullptr"); } _connected = false; _peek = -1; @@ -277,7 +277,7 @@ int SSLClient::connect(IPAddress ip, uint16_t port, const char *pskIdent, const int SSLClient::connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey) { log_v("start_ssl_client with PSK"); - if(_timeout > 0){ + if (_timeout > 0) { sslclient->handshake_timeout = _timeout; } @@ -366,12 +366,15 @@ size_t SSLClient::write(uint8_t data) { */ size_t SSLClient::write(const uint8_t *buf, size_t size) { if (!_connected) { + log_w("SSLClient is not connected."); return 0; } + log_d("Sending data to SSL connection..."); int res = send_ssl_data(sslclient, buf, size); if (res < 0) { + log_e("Error sending data to SSL connection. Stopping SSLClient..."); stop(); res = 0; } diff --git a/src/ssl_client.cpp b/src/ssl_client.cpp index b592b02..c7c23a8 100644 --- a/src/ssl_client.cpp +++ b/src/ssl_client.cpp @@ -98,28 +98,33 @@ static int client_net_recv( void *ctx, unsigned char *buf, size_t len ) { * \return int The number of bytes received, or a non-zero error code; * with a non-blocking socket, MBEDTLS_ERR_SSL_WANT_READ * indicates read() would block. - * \return int -1 if Client* is nullptr. - * \return int -2 if connect failed. + * \return int 0 if incorrectly called and len = 0, + * \return int -1 if Client* is nullptr. + * \return int -2 if connect failed. */ int client_net_recv_timeout(void *ctx, unsigned char *buf, size_t len, uint32_t timeout) { Client *client = (Client*)ctx; - log_v("Timeout set to %u", timeout); - if (!client) { log_e("Uninitialised!"); return -1; } + if (len == 0) { + log_e("Zero length specified!"); + return 0; + } + + log_v("Timeout set to %u", timeout); + unsigned long start = millis(); unsigned long tms = start + timeout; - do { - int pending = client->available(); - if (pending < len && timeout > 0) { - delay(1); - } else break; - } while (millis() < tms); + int pending = client->available(); + while (pending < len && millis() < tms) { + delay(1); + pending = client->available(); + } int result = client->read(buf, len); @@ -128,7 +133,7 @@ int client_net_recv_timeout(void *ctx, unsigned char *buf, size_t len, uint32_t } log_v("SSL client RX (received=%d expected=%zu in %lums)", result, len, millis()-start); - + if (result > 0) { //esp_log_buffer_hexdump_internal("SSL.RD", buf, (uint16_t)result, ESP_LOG_VERBOSE); } @@ -186,7 +191,8 @@ static int client_net_send(void *ctx, const unsigned char *buf, size_t len) { } } - log_d("SSL client TX res=%d len=%zu", result, len); + log_v("SSL client TX res=%d len=%zu", result, len); + return result; } @@ -197,7 +203,7 @@ static int client_net_send(void *ctx, const unsigned char *buf, size_t len) { * \param client Client* - The client. */ void ssl_init(sslclient_context *ssl_client, Client *client) { - log_v("Init SSL"); + log_d("Init SSL"); // reset embedded pointers to zero memset(ssl_client, 0, sizeof(sslclient_context)); ssl_client->client = client; @@ -207,22 +213,69 @@ void ssl_init(sslclient_context *ssl_client, Client *client) { } /** - * \brief Start the ssl client. - * - * \param ssl_client sslclient_context* - The ssl client context. - * \param host const char* - The host to connect to. - * \param port uint32_t - The port to connect to. - * \param timeout int - The timeout in milliseconds. - * \param rootCABuff const char* - The root CA certificate. - * \param cli_cert const char* - The client certificate. - * \param cli_key const char*- The client key. - * \param pskIdent const char* - The PSK identity. - * \param psKey const char* - The PSK key. - * \return int 1 if successful. - * \return int -1 if Client* is nullptr. - * \return int -2 if connect failed. - * \return int -3 if PSK key is invalid. - * \return int -4 if SSL handshake timeout. + * \brief Cleans up allocated resources and stops the SSL socket if an error occurred. + * + * \param ssl_client Pointer to the SSL client context. + * \param ca_cert_initialized Flag indicating if the CA certificate was initialized. + * \param client_cert_initialized Flag indicating if the client certificate was initialized. + * \param client_key_initialized Flag indicating if the client key was initialized. + * \param ret Return value from the previous operations. + * \param rootCABuff Pointer to the root CA buffer. + * \param cli_cert Pointer to the client certificate. + * \param cli_key Pointer to the client key. + */ +void cleanup( + sslclient_context *ssl_client, + bool ca_cert_initialized, + bool client_cert_initialized, + bool client_key_initialized, + int ret, + const char *rootCABuff, + const char *cli_cert, + const char *cli_key +) { + if (ca_cert_initialized) { + mbedtls_x509_crt_free(&ssl_client->ca_cert); + } + if (client_cert_initialized) { + mbedtls_x509_crt_free(&ssl_client->client_cert); + } + if (client_key_initialized) { + mbedtls_pk_free(&ssl_client->client_key); + } + if (ret != 0) { + stop_ssl_socket(ssl_client, rootCABuff, cli_cert, cli_key); // Stop SSL socket on error + } + log_d("Free internal heap after TLS %u", ESP.getFreeHeap()); +} + +/** + * \brief Logs information about a failed certificate verification. + * + * \param flags Flags returned from the certificate verification process. + */ +void log_failed_cert(int flags) { + if (flags != 0) { + char buf[512]; + memset(buf, 0, sizeof(buf)); + mbedtls_x509_crt_verify_info(buf, sizeof(buf), " ! ", flags); + log_e("Failed to verify peer certificate! verification info: %s", buf); + } +} + +/** + * \brief Starts the SSL client, handling initialization, authentication, and connection processes. + * + * \param ssl_client Pointer to the SSL client context. + * \param host Pointer to the host string. + * \param port Port number for the connection. + * \param timeout Timeout value for the connection. + * \param rootCABuff Pointer to the root CA buffer. + * \param cli_cert Pointer to the client certificate. + * \param cli_key Pointer to the client key. + * \param pskIdent Pointer to the PSK identifier.s + * \param psKey Pointer to the PSK key. + * \return 1 on successful SSL client start, 0 otherwise. */ int start_ssl_client( sslclient_context *ssl_client, @@ -239,244 +292,491 @@ int start_ssl_client( log_v("Connecting to %s:%d", host, port); int ret = 0; // for mbedtls function return values - int func_ret = 0; // for start_ssl_client return values bool ca_cert_initialized = false; bool client_cert_initialized = false; bool client_key_initialized = false; - bool breakBothLoops = false; - - do { // executes once, breaks on error... - // Step 1 - Initiate TCP connection - Client *pClient = ssl_client->client; - if (!pClient) { - log_e("Client pointer is null."); - func_ret = -1; + do { + ret = init_tcp_connection(ssl_client, host, port); // Step 1 - Initiate TCP connection + if (ret != 0) { + break; + } + ret = seed_random_number_generator(ssl_client); // Step 2 - Seed the random number generator + if (ret == MBEDTLS_ERR_CTR_DRBG_ENTROPY_SOURCE_FAILED || ret != 0) { break; } - - log_v("Client pointer: %p", (void*) pClient); // log_v - - if (!pClient->connect(host, port)) { - log_e("Connection to server failed!"); - func_ret = -2; + log_v("Random number generator seeded, ret: %d", ret); + ret = set_up_tls_defaults(ssl_client); // Step 3 - Set up the SSL/TLS defaults + if (ret != 0) { // MBEDTLS_ERR_XXX_ALLOC_FAILED undefined? + break; + } + log_v("SSL config defaults set, ret: %d", ret); + ret = auth_root_ca_buff(ssl_client, rootCABuff, &ca_cert_initialized, pskIdent, psKey); // Step 4 route a - Set up required auth mode rootCaBuff + if (ret != 0) { break; } + log_v("SSL auth mode set, ret: %d", ret); + ret = auth_client_cert_key(ssl_client, cli_cert, cli_key, &client_cert_initialized, &client_key_initialized); // Step 4 route b - Set up required auth mode cli_cert and cli_key + if (ret != 0) { + break; + } + log_v("SSL client cert and key set, ret: %d", ret); + ret = set_hostname_for_tls(ssl_client, host); // Step 5 - Set the hostname for a TLS session + if (ret != 0) { + break; + } + log_v("SSL hostname set, ret: %d", ret); + ret = set_io_callbacks_and_timeout(ssl_client, timeout); // Step 6 - Configure IO callbacks and set a read timeout for the SSL client context + if (ret != 0) { + break; + } + log_v("SSL IO callbacks and timeout set, ret: %d", ret); + ret = perform_ssl_handshake(ssl_client, cli_cert, cli_key); // Step 7 - Perform SSL/TLS handshake + if (ret != 0) { + break; + } + int flags = verify_server_cert(ssl_client); // Step 8 - Verify the server certificate + if (ret != 0) { + log_failed_cert(flags); + } else { + log_v("Certificate verified."); + } + } while (0); // do once, force break on error - // Step 2 - Seed the random number generator - log_v("Seeding the random number generator"); - mbedtls_entropy_init(&ssl_client->entropy_ctx); - log_v("Entropy context initialized"); // log_v + // Step 9 - Cleanup and return + cleanup(ssl_client, ca_cert_initialized, client_cert_initialized, client_key_initialized, ret, rootCABuff, cli_cert, cli_key); - ret = mbedtls_ctr_drbg_seed(&ssl_client->drbg_ctx, mbedtls_entropy_func, - &ssl_client->entropy_ctx, (const unsigned char *) pers, strlen(pers)); + if (ret == 0) { + return 1; + } - if (ret == MBEDTLS_ERR_CTR_DRBG_ENTROPY_SOURCE_FAILED || ret != 0) { - break; - } + handle_error(ret); + return 0; +} - log_v("Random number generator seeded, ret: %d", ret); // log_v +/** + * \brief Initializes a TCP connection to a remote host on the specified port. + * + * \param ssl_client sslclient_context* - The SSL client context. + * \param host const char* - The host to connect to. + * \param port uint32_t - The port to connect to. + * + * \return int 0 if the TCP connection is successfully established. + * \return int -1 if the SSL client's Client pointer is null. + * \return int -2 if the connection to the server failed. + * + * This function initiates a TCP connection to a remote host on the specified port using the provided + * SSL client context. It checks if the Client pointer within the context is valid, attempts to + * establish the TCP connection, and returns appropriate error codes if any issues are encountered. + */ +int init_tcp_connection(sslclient_context *ssl_client, const char *host, uint32_t port) { + Client *pClient = ssl_client->client; + if (!pClient) { + log_e("Client pointer is null."); + return -1; + } - // Step 3 - Set up the SSL/TLS defaults - log_v("Setting up the SSL/TLS defaults..."); + log_v("Client pointer: %p", (void*) pClient); - ret = mbedtls_ssl_config_defaults(&ssl_client->ssl_conf, - MBEDTLS_SSL_IS_CLIENT, - MBEDTLS_SSL_TRANSPORT_STREAM, - MBEDTLS_SSL_PRESET_DEFAULT); - if (ret != 0) { // MBEDTLS_ERR_XXX_ALLOC_FAILED undefined? - break; - } + if (!pClient->connect(host, port)) { + log_e("Connection to server failed!"); + return -2; + } - log_v("SSL config defaults set, ret: %d", ret); + return 0; +} - // Step 4 route a - Set up required auth mode rootCaBuff - if (rootCABuff != NULL) { - log_v("Loading CA cert"); - mbedtls_x509_crt_init(&ssl_client->ca_cert); - mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_REQUIRED); - ret = mbedtls_x509_crt_parse(&ssl_client->ca_cert, (const unsigned char *)rootCABuff, strlen(rootCABuff) + 1); +/** + * \brief Seed the random number generator for SSL/TLS operations. + * + * \param ssl_client sslclient_context* - The SSL client context. + * + * \return int 0 if the random number generator is successfully seeded. + * \return int An error code if the seeding process fails. + * + * This function initializes the random number generator used in SSL/TLS operations. + * It sets up the entropy source and uses it to seed the deterministic random bit generator (DRBG). + * The DRBG is essential for generating secure cryptographic keys and nonces during SSL/TLS + * communication. If successful, the function returns 0; otherwise, it returns an error code. + */ +int seed_random_number_generator(sslclient_context *ssl_client) { + log_v("Seeding the random number generator"); + mbedtls_entropy_init(&ssl_client->entropy_ctx); + log_v("Entropy context initialized"); + int ret = mbedtls_ctr_drbg_seed(&ssl_client->drbg_ctx, mbedtls_entropy_func, + &ssl_client->entropy_ctx, (const unsigned char *) pers, strlen(pers)); + return ret; +} - if (ret < 0) { - break; // if ret > 0 n certs failed, ret < 0 pem or x509 error code. - } +/** + * \brief Set up SSL/TLS configuration with default settings. + * + * \param ssl_client sslclient_context* - The SSL client context. + * + * \return int 0 if SSL/TLS configuration is successfully set up with defaults. + * \return int An error code if the setup process fails. + * + * This function configures SSL/TLS settings with default values, including specifying that + * it's used as a client, operating in a stream transport mode, and applying the default preset. + * The SSL/TLS configuration is essential for establishing secure communication over the network. + * If successful, the function returns 0; otherwise, it returns an error code. + */ +int set_up_tls_defaults(sslclient_context *ssl_client) { + log_v("Setting up the SSL/TLS defaults..."); - mbedtls_ssl_conf_ca_chain(&ssl_client->ssl_conf, &ssl_client->ca_cert, NULL); - // mbedtls_ssl_conf_verify(&ssl_client->ssl_ctx, my_verify, NULL ); - - ca_cert_initialized = true; - - } else if (pskIdent != NULL && psKey != NULL) { - log_v("Setting up PSK"); - - // convert PSK from hex to binary - if ((strlen(psKey) & 1) != 0 || strlen(psKey) > 2*MBEDTLS_PSK_MAX_LEN) { - log_e("pre-shared key not valid hex or too long"); - func_ret = -3; - break; - } + int ret = mbedtls_ssl_config_defaults(&ssl_client->ssl_conf, MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); + return ret; +} - unsigned char psk[MBEDTLS_PSK_MAX_LEN]; - size_t psk_len = strlen(psKey)/2; - - for (int j=0; j= '0' && c <= '9') c -= '0'; - else if (c >= 'A' && c <= 'F') c -= 'A' - 10; - else if (c >= 'a' && c <= 'f') c -= 'a' - 10; - else return -1; - psk[j/2] = c<<4; - c = psKey[j+1]; - if (c >= '0' && c <= '9') c -= '0'; - else if (c >= 'A' && c <= 'F') c -= 'A' - 10; - else if (c >= 'a' && c <= 'f') c -= 'a' - 10; - else return -1; - psk[j/2] |= c; - } +/** + * \brief Configure SSL/TLS authentication options based on provided parameters. + * + * \param ssl_client sslclient_context* - The SSL client context. + * \param rootCABuff const char* - The root CA certificate buffer. + * \param ca_cert_initialized bool* - Indicates whether CA certificate is initialized. + * \param pskIdent const char* - The PSK identity. + * \param psKey const char* - The PSK key. + * \param func_ret int* - Pointer to an integer to hold the return value. + * + * \return int 0 if the SSL/TLS authentication options are configured successfully. + * \return int An error code if the configuration process fails. + * + * This function configures SSL/TLS authentication options based on the provided parameters. + * If `rootCABuff` is not NULL, it loads the root CA certificate and configures SSL/TLS to + * require verification. If `pskIdent` and `psKey` are not NULL, it sets up a pre-shared key + * (PSK) for authentication. If none of the options are provided, it configures SSL/TLS with + * no verification. The function may modify the value pointed to by `func_ret` to indicate errors. + * If successful, the function returns 0; otherwise, it returns an error code, -1 for a null context. + */ +int auth_root_ca_buff(sslclient_context *ssl_client, const char *rootCABuff, bool *ca_cert_initialized, + const char *pskIdent, const char *psKey) { + if (ssl_client == nullptr) { + log_e("Uninitialised context!"); + return -1; + } - // set mbedtls config - ret = mbedtls_ssl_conf_psk(&ssl_client->ssl_conf, psk, psk_len, - (const unsigned char *)pskIdent, strlen(pskIdent)); - if (ret != 0) { // MBEDTLS_ERR_SSL_XXX undefined? - log_e("mbedtls_ssl_conf_psk returned %d", ret); - break; - } - } else { - mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_NONE); - log_i("WARNING: Use certificates for a more secure communication!"); + int ret = 0; + if (rootCABuff != nullptr) { + log_v("Loading CA cert"); + mbedtls_x509_crt_init(&ssl_client->ca_cert); + mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_REQUIRED); + ret = mbedtls_x509_crt_parse(&ssl_client->ca_cert, (const unsigned char *)rootCABuff, strlen(rootCABuff) + 1); + + if (ret < 0) { + // if ret > 0 n certs failed, ret < 0 pem or x509 error code. + return ret; } - // Step 4 route b - Set up required auth mode cli_cert and cli_key - if (cli_cert != NULL && cli_key != NULL) { - mbedtls_x509_crt_init(&ssl_client->client_cert); - mbedtls_pk_init(&ssl_client->client_key); + mbedtls_ssl_conf_ca_chain(&ssl_client->ssl_conf, &ssl_client->ca_cert, NULL); - log_v("Loading CRT cert"); - ret = mbedtls_x509_crt_parse(&ssl_client->client_cert, (const unsigned char *)cli_cert, strlen(cli_cert) + 1); - if (ret != 0) { - break; // if ret > 0 n certs failed, ret < 0 pem or x509 error code. - } else { - client_cert_initialized = true; - } + if (ca_cert_initialized != nullptr) { + *ca_cert_initialized = true; + } else { + log_e("ca_cert_initialized is null!"); + return -1; + } + + } else if (pskIdent != nullptr && psKey != nullptr) { + log_v("Setting up PSK"); + + // convert PSK from hex to binary + if ((strlen(psKey) & 1) != 0 || strlen(psKey) > 2*MBEDTLS_PSK_MAX_LEN) { + log_e("pre-shared key not valid hex or too long"); + return -1; + } - log_v("Loading private key"); - ret = mbedtls_pk_parse_key(&ssl_client->client_key, (const unsigned char *)cli_key, strlen(cli_key) + 1, NULL, 0); - if (ret != 0) { // PK or PEM non-zero error codes - mbedtls_x509_crt_free(&ssl_client->client_cert); // cert+key are free'd in pair - break; - } else { - client_key_initialized = true; - } + unsigned char psk[MBEDTLS_PSK_MAX_LEN]; + size_t psk_len = strlen(psKey)/2; + + for (int j=0; j= '0' && c <= '9') c -= '0'; + else if (c >= 'A' && c <= 'F') c -= 'A' - 10; + else if (c >= 'a' && c <= 'f') c -= 'a' - 10; + else return -1; + psk[j/2] = c<<4; + c = psKey[j+1]; + if (c >= '0' && c <= '9') c -= '0'; + else if (c >= 'A' && c <= 'F') c -= 'A' - 10; + else if (c >= 'a' && c <= 'f') c -= 'a' - 10; + else return -1; + psk[j/2] |= c; + } - ret = mbedtls_ssl_conf_own_cert(&ssl_client->ssl_conf, &ssl_client->client_cert, &ssl_client->client_key); - if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED || ret != 0) { - break; - } + // set mbedtls config + ret = mbedtls_ssl_conf_psk(&ssl_client->ssl_conf, psk, psk_len, + (const unsigned char *)pskIdent, strlen(pskIdent)); + if (ret != 0) { // MBEDTLS_ERR_SSL_XXX undefined? + log_e("mbedtls_ssl_conf_psk returned %d", ret); + return ret; } + } else { + mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_NONE); + log_w("WARNING: Use certificates for a more secure communication!"); + ret = 0; + } + return ret; +} - // Step 5 - Set hostname for TLS session - log_v("Setting hostname for TLS session..."); +/** + * \brief Authenticate the client by initializing certificates and keys. + * + * This function initializes and loads the client's certificate and private key into + * the provided SSL client context. It also provides a status of the initialization + * of the certificate and key. + * + * \param[in,out] ssl_client Pointer to the SSL client context. + * \param[in] cli_cert Pointer to the client certificate in string format. + * \param[in] cli_key Pointer to the client private key in string format. + * \param[out] client_cert_initialized Pointer to a boolean indicating if the client certificate was initialized. + * \param[out] client_key_initialized Pointer to a boolean indicating if the client key was initialized. + * + * \return 0 if successful, or a non-zero error code indicating a failure during the initialization or parsing. + * Positive error codes indicate number of certs that failed. + * Negative error codes indicate a PEM or x509 error. + */ +int auth_client_cert_key(sslclient_context *ssl_client, const char *cli_cert, const char *cli_key, bool *client_cert_initialized, bool *client_key_initialized) { + int ret = 0; + // Step 4 route b - Set up required auth mode cli_cert and cli_key + if (cli_cert != NULL && cli_key != NULL) { + mbedtls_x509_crt_init(&ssl_client->client_cert); + mbedtls_pk_init(&ssl_client->client_key); + + log_v("Loading CRT cert"); + ret = mbedtls_x509_crt_parse(&ssl_client->client_cert, (const unsigned char *)cli_cert, strlen(cli_cert) + 1); + if (ret != 0) { + // if ret > 0 n certs failed, ret < 0 pem or x509 error code. + return ret; + } else { + *client_cert_initialized = true; + } - // Hostname set here should match CN in server certificate - ret = mbedtls_ssl_set_hostname(&ssl_client->ssl_ctx, host); - - if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED || ret == MBEDTLS_ERR_SSL_BAD_INPUT_DATA || ret != 0) { - break; + log_v("Loading private key"); + ret = mbedtls_pk_parse_key(&ssl_client->client_key, (const unsigned char *)cli_key, strlen(cli_key) + 1, NULL, 0); + if (ret != 0) { // PK or PEM non-zero error codes + mbedtls_x509_crt_free(&ssl_client->client_cert); // cert+key are free'd in pair + return ret; + } else { + *client_key_initialized = true; } - mbedtls_ssl_conf_rng(&ssl_client->ssl_conf, mbedtls_ctr_drbg_random, &ssl_client->drbg_ctx); + ret = mbedtls_ssl_conf_own_cert(&ssl_client->ssl_conf, &ssl_client->client_cert, &ssl_client->client_key); + } + return ret; +} - ret = mbedtls_ssl_setup(&ssl_client->ssl_ctx, &ssl_client->ssl_conf); +/** + * \brief Set the hostname for a TLS session. + * + * This function sets the hostname for a TLS session which should match + * the Common Name (CN) in the server certificate to ensure the identity + * of the remote host. It configures the provided SSL client context + * with the hostname and sets up the SSL context with the necessary + * configurations. + * + * \param ssl_client A pointer to the sslclient_context structure + * representing the SSL client context. + * \param host A pointer to a character string representing the hostname. + * + * \return int Returns 0 on success. On failure, it returns + * MBEDTLS_ERR_SSL_ALLOC_FAILED if there's a memory allocation + * failure, MBEDTLS_ERR_SSL_BAD_INPUT_DATA for bad input data, + * or other mbedtls error codes as defined in mbedtls error header file. + * + * \note The hostname set should match the CN in the server certificate. + * + * Usage: + * \code + * sslclient_context ssl_client; + * const char *host = "example.com"; + * int ret = set_hostname_for_tls(&ssl_client, host); + * if(ret != 0) { + * // handle error + * } + * \endcode + */ +int set_hostname_for_tls(sslclient_context *ssl_client, const char *host) { + int ret; + log_v("Setting hostname for TLS session..."); + + // Hostname set here should match CN in server certificate + ret = mbedtls_ssl_set_hostname(&ssl_client->ssl_ctx, host); + + if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED || ret == MBEDTLS_ERR_SSL_BAD_INPUT_DATA || ret != 0) { + log_e("Failed to set hostname for tls session"); + return ret; + } - if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED || ret != 0) { - break; - } + mbedtls_ssl_conf_rng(&ssl_client->ssl_conf, mbedtls_ctr_drbg_random, &ssl_client->drbg_ctx); - // Step 6 - Set up the I/O callbacks (this is the heart of it) - log_v("Setting up IO callbacks..."); - mbedtls_ssl_set_bio(&ssl_client->ssl_ctx, ssl_client->client, - client_net_send, NULL, client_net_recv_timeout ); - - log_v("Setting timeout to %i", timeout); - mbedtls_ssl_conf_read_timeout(&ssl_client->ssl_conf, timeout); + ret = mbedtls_ssl_setup(&ssl_client->ssl_ctx, &ssl_client->ssl_conf); - // Step 7 - Perform the SSL/TLS handshake - log_v("Performing the SSL/TLS handshake..."); - unsigned long handshake_start_time = millis(); + return ret; +} - while ((ret = mbedtls_ssl_handshake(&ssl_client->ssl_ctx)) != 0) { - if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { - break; - } - if ((millis()-handshake_start_time) > ssl_client->handshake_timeout) { - log_e("SSL handshake timeout"); - func_ret = -4; - breakBothLoops = true; - break; - } - vTaskDelay(10 / portTICK_PERIOD_MS); - } +/** + * \brief Configures IO callbacks and sets a read timeout for the SSL client context. + * + * This function sets up the IO callbacks for sending, receiving, and receiving with timeout + * for the provided SSL client context. It also configures the read timeout for the SSL client context. + * + * \param ssl_client A pointer to the sslclient_context structure representing the SSL client context. + * \param timeout The timeout value in milliseconds for reading operations. + * + * \return int Returns 0 on success, -1 + * + * Usage: + * \code + * sslclient_context ssl_client; + * int timeout = 5000; // 5 seconds + * int ret = set_io_callbacks_and_timeout(&ssl_client, timeout); + * if (ret != 0) { + * // handle error + * } + * \endcode + * + * \note The function assumes that the sslclient_context structure is properly initialized and the + * client_net_send, client_net_recv, and client_net_recv_timeout functions are correctly implemented. + */ +int set_io_callbacks_and_timeout(sslclient_context *ssl_client, int timeout) { + if (ssl_client == nullptr) { + log_e("Uninitialised context!"); + return -1; + } + + if (timeout < 0) { + log_e("Invalid timeout value"); + return -2; + } - if (breakBothLoops) { - break; // break the outer do-while loop - } + log_v("Setting up IO callbacks..."); + mbedtls_ssl_set_bio(&ssl_client->ssl_ctx, ssl_client->client, client_net_send, NULL, client_net_recv_timeout); - if (cli_cert != NULL && cli_key != NULL) { - log_v("Protocol is %s Ciphersuite is %s", mbedtls_ssl_get_version(&ssl_client->ssl_ctx), mbedtls_ssl_get_ciphersuite(&ssl_client->ssl_ctx)); - ret = mbedtls_ssl_get_record_expansion(&ssl_client->ssl_ctx); - if (ret != 0) { - if (ret == MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE) { - log_w("Record expansion is not available (compression)"); - } else { - log_e(" mbedtls_ssl_get_record_expansion returned -0x%x", -ret); - } - break; - } else { - log_w("Record expansion is unknown (compression)"); - } - } + log_v("Setting timeout to %i", timeout); + mbedtls_ssl_conf_read_timeout(&ssl_client->ssl_conf, timeout); - // Step 8 - Verify the server certificate - log_v("Verifying peer X.509 certificate..."); + return 0; +} - int flags = mbedtls_ssl_get_verify_result(&ssl_client->ssl_ctx); +/** + * \brief Performs the SSL/TLS handshake for a given SSL client context. + * + * This function initiates and manages the SSL/TLS handshake process. It also checks for + * timeout conditions and handles client certificate and key if provided. + * + * \param ssl_client A pointer to the sslclient_context structure representing the SSL client context. + * \param func_ret A pointer to an integer where a specific error code can be stored for further analysis. + * \param cli_cert A pointer to a character string representing the client's certificate. If not needed, pass NULL. + * \param cli_key A pointer to a character string representing the client's private key. If not needed, pass NULL. + * + * \return int Returns 0 on successful handshake completion. Returns -1 if the handshake process + * times out. Returns a mbedtls error code if any other error occurs during the handshake process. + * + * Usage: + * \code + * sslclient_context ssl_client; + * int func_ret = 0; + * const char *cli_cert = "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----"; + * const char *cli_key = "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----"; + * int ret = perform_ssl_handshake(&ssl_client, &func_ret, cli_cert, cli_key); + * if(ret != 0) { + * // handle error + * } + * \endcode + * + * \note This function assumes that the sslclient_context structure is properly initialized and the + * mbedtls libraries are correctly configured. + */ +int perform_ssl_handshake(sslclient_context *ssl_client, const char *cli_cert, const char *cli_key) { + if (ssl_client == nullptr) { + log_e("Uninitialised context!"); + return -1; + } - if (ret != 0) { - char buf[512]; - memset(buf, 0, sizeof(buf)); - mbedtls_x509_crt_verify_info(buf, sizeof(buf), " ! ", flags); - log_e("Failed to verify peer certificate! verification info: %s", buf); - stop_ssl_socket(ssl_client, rootCABuff, cli_cert, cli_key); // It's not safe continue. + int ret = 0; + bool breakBothLoops = false; + log_v("Performing the SSL/TLS handshake, timeout %lu ms", ssl_client->handshake_timeout); + unsigned long handshake_start_time = millis(); + log_d("calling mbedtls_ssl_handshake with ssl_ctx address %p", (void *)&ssl_client->ssl_ctx); + + while ((ret = mbedtls_ssl_handshake(&ssl_client->ssl_ctx)) != 0) { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { break; - } else { - log_v("Certificate verified."); } - } while (0); // executes once, breaks on error... + if ((millis()-handshake_start_time) > ssl_client->handshake_timeout) { + log_e("SSL handshake timeout"); + breakBothLoops = true; + break; + } - // Step 9 - Cleanup and return - if (ca_cert_initialized) { - mbedtls_x509_crt_free(&ssl_client->ca_cert); + vTaskDelay(10 / portTICK_PERIOD_MS); } - if (client_cert_initialized) { - mbedtls_x509_crt_free(&ssl_client->client_cert); + if (breakBothLoops) { + return -1; } - if (client_key_initialized) { - mbedtls_pk_free(&ssl_client->client_key); + if (cli_cert != NULL && cli_key != NULL) { + log_v("Protocol is %s Ciphersuite is %s", mbedtls_ssl_get_version(&ssl_client->ssl_ctx), mbedtls_ssl_get_ciphersuite(&ssl_client->ssl_ctx)); + ret = mbedtls_ssl_get_record_expansion(&ssl_client->ssl_ctx); + if (ret != 0) { + if (ret == MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE) { + log_w("Record expansion is not available (compression)"); + } else { + log_e("mbedtls_ssl_get_record_expansion returned -0x%x", -ret); + } + } else { + log_w("Record expansion is unknown (compression)"); + } } + return ret; +} - log_v("Free internal heap after TLS %u", ESP.getFreeHeap()); - - if (ret < 0) { - return handle_error(ret); - stop_ssl_socket(ssl_client, rootCABuff, cli_cert, cli_key); - } else { - func_ret = 1; +/** + * \brief Verifies the server's certificate using the provided SSL client context. + * + * This function performs a verification of the server's certificate to ensure it's valid and trustworthy. + * The verification process checks the server certificate against the provided root CA. + * If client certificate and key are provided, they can be used for further verification or cleanup. + * + * \param ssl_client A pointer to the sslclient_context structure representing the SSL client context. + * \param ret The return value of the mbedtls_ssl_handshake function. + * \param rootCABuff A pointer to a character string containing the root CA certificate. + * \param cli_cert A pointer to a character string representing the client's certificate. If not needed, pass NULL. + * \param cli_key A pointer to a character string representing the client's private key. If not needed, pass NULL. + * + * \return int Returns 0 on successful verification. Returns a non-zero error code on failure, + * which can be obtained from the mbedtls library. + * + * Usage: + * \code + * sslclient_context ssl_client; + * const char *rootCABuff = "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----"; + * const char *cli_cert = "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----"; + * const char *cli_key = "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----"; + * int ret = verify_server_cert(&ssl_client, rootCABuff, cli_cert, cli_key); + * if(ret != 0) { + * // handle error + * } + * \endcode + * + * \note This function assumes that the sslclient_context structure is properly initialized and the + * mbedtls libraries are correctly configured. Also, ensure that the root CA certificate is correct + * and corresponds to the CA that issued the server's certificate. + */ +int verify_server_cert(sslclient_context *ssl_client) { + if (ssl_client == nullptr) { + log_e("Uninitialised context!"); + return -1; } + + log_v("Verifying peer X.509 certificate..."); + + int flags = mbedtls_ssl_get_verify_result(&ssl_client->ssl_ctx); - return func_ret; + return flags; } /** @@ -488,21 +788,23 @@ int start_ssl_client( * \param cli_key const char* - The client key. */ void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key) { - log_v("Cleaning SSL connection."); - log_v("Stopping SSL client. Current client pointer address: %p", (void *)ssl_client->client); + log_d("Cleaning SSL connection."); // Stop the client connection - ssl_client->client->stop(); + if (ssl_client && ssl_client->client) { + log_d("Stopping SSL client. Current client pointer address: %p", (void *)ssl_client->client); + ssl_client->client->stop(); + } if (ssl_client->ssl_conf.ca_chain != NULL) { - log_v("Freeing CA cert. Current ca_cert address: %p", (void *)&ssl_client->ca_cert); + log_d("Freeing CA cert. Current ca_cert address: %p", (void *)&ssl_client->ca_cert); // Free the memory associated with the CA certificate mbedtls_x509_crt_free(&ssl_client->ca_cert); } if (ssl_client->ssl_conf.key_cert != NULL) { - log_v("Freeing client cert and client key. Current client_cert address: %p, client_key address: %p", + log_d("Freeing client cert and client key. Current client_cert address: %p, client_key address: %p", (void *)&ssl_client->client_cert, (void *)&ssl_client->client_key); // Free the memory associated with the client certificate and key @@ -511,23 +813,21 @@ void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, cons } // Free other SSL-related contexts and log their current addresses - log_v("Freeing SSL context. Current ssl_ctx address: %p", (void *)&ssl_client->ssl_ctx); + log_d("Freeing SSL context. Current ssl_ctx address: %p", (void *)&ssl_client->ssl_ctx); mbedtls_ssl_free(&ssl_client->ssl_ctx); - log_v("Freeing SSL config. Current ssl_conf address: %p", (void *)&ssl_client->ssl_conf); + log_d("Freeing SSL config. Current ssl_conf address: %p", (void *)&ssl_client->ssl_conf); mbedtls_ssl_config_free(&ssl_client->ssl_conf); - log_v("Freeing DRBG context. Current drbg_ctx address: %p", (void *)&ssl_client->drbg_ctx); + log_d("Freeing DRBG context. Current drbg_ctx address: %p", (void *)&ssl_client->drbg_ctx); mbedtls_ctr_drbg_free(&ssl_client->drbg_ctx); - log_v("Freeing entropy context. Current entropy_ctx address: %p", (void *)&ssl_client->entropy_ctx); + log_d("Freeing entropy context. Current entropy_ctx address: %p", (void *)&ssl_client->entropy_ctx); mbedtls_entropy_free(&ssl_client->entropy_ctx); - // log_v("Resetting embedded pointers to zero for ssl_client at address: %p", (void *)ssl_client); - // memset(ssl_client, 0, sizeof(sslclient_context)); + log_d("Finished cleaning SSL connection."); } - /** * \brief Check if there is data to read or not. * @@ -536,10 +836,13 @@ void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, cons */ int data_to_read(sslclient_context *ssl_client) { int ret, res; + ret = mbedtls_ssl_read(&ssl_client->ssl_ctx, NULL, 0); - //log_e("RET: %i",ret); //for low level debug + log_v("RET: %i",ret); // for low level debug + res = mbedtls_ssl_get_bytes_avail(&ssl_client->ssl_ctx); - //log_e("RES: %i",res); //for low level debug + log_v("RES: %i",res); // for low level debug + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) { return handle_error(ret); } @@ -556,7 +859,6 @@ int data_to_read(sslclient_context *ssl_client) { * \return int The number of bytes sent. */ int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len) { - // Log contents of ssl_client if(ssl_client != nullptr) { log_v("ssl_client->client: %p", (void *)ssl_client->client); log_v("ssl_client->handshake_timeout: %lu", ssl_client->handshake_timeout); @@ -565,7 +867,13 @@ int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len return -1; } - log_v("Writing SSL (%zu bytes)...", len); // for low level debug + log_v("Writing SSL (%zu bytes)...", len); + + // Print the data being sent + // for (size_t i = 0; i < len; i++) { + // log_v("Data[%zu]: %02X", i, data[i]); + // } + int ret = -1; while ((ret = mbedtls_ssl_write(&ssl_client->ssl_ctx, data, len)) <= 0) { @@ -575,7 +883,7 @@ int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len } len = ret; - log_v("%zu bytes written", len); // for low level debug + log_v("%zu bytes written", len); return ret; } @@ -585,15 +893,15 @@ int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len * \param ssl_client sslclient_context* - The ssl client context. * \param data uint8_t* - The data to receive. * \param length int - The length of the data. - * \return size_t The number of bytes received. + * \return size_t The number of bytes received. */ int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, size_t length) { - log_v( "Reading SSL (%d bytes)", length); //for low level debug + log_v( "Reading SSL (%d bytes)", length); int ret = -1; ret = mbedtls_ssl_read(&ssl_client->ssl_ctx, data, length); - log_v( "%d bytes read", ret); //for low level debug + log_v( "%d bytes read", ret); return ret; } @@ -604,7 +912,7 @@ int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, size_t length) * \param res uint8_t* - The data to receive. * \return bool True if the data was received, false otherwise. */ -static bool parseHexNibble(char pb, uint8_t* res) { +static bool parse_hex_nibble(char pb, uint8_t* res) { if (pb >= '0' && pb <= '9') { *res = (uint8_t) (pb - '0'); return true; } else if (pb >= 'a' && pb <= 'f') { @@ -622,14 +930,17 @@ static bool parseHexNibble(char pb, uint8_t* res) { * \param domainName const string& - The domain name. * \return bool True if the name from certificate and domain name match, false otherwise. */ -static bool matchName(const string& name, const string& domainName) { - size_t wildcardPos = name.find('*'); +static bool match_name(const string& name, const string& domainName) { + size_t wildcardPos = name.find("*"); + if (wildcardPos == (size_t)12) { + return false; // We don't support wildcards for subdomains + } if (wildcardPos == string::npos) { // Not a wildcard, expect an exact match return name == domainName; } - size_t firstDotPos = name.find('.'); + size_t firstDotPos = name.find("."); if (wildcardPos > firstDotPos) { // Wildcard is not part of leftmost component of domain name // Do not attempt to match (rfc6125 6.4.3.1) @@ -657,8 +968,7 @@ static bool matchName(const string& name, const string& domainName) { * \param domain_name const char* - The domain name. * \return bool True if the certificate matches the fingerprint, false otherwise. */ -bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const char* domain_name) -{ +bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const char* domain_name) { // Convert hex string to byte array uint8_t fingerprint_local[32]; int len = strlen(fp); @@ -667,15 +977,18 @@ bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const while (pos < len && ((fp[pos] == ' ') || (fp[pos] == ':'))) { ++pos; } + if (pos > len - 2) { - log_v("pos:%d len:%d fingerprint too short", pos, len); + log_d("pos:%d len:%d fingerprint too short", pos, len); return false; } + uint8_t high, low; - if (!parseHexNibble(fp[pos], &high) || !parseHexNibble(fp[pos+1], &low)) { - log_v("pos:%d len:%d invalid hex sequence: %c%c", pos, len, fp[pos], fp[pos+1]); + if (!parse_hex_nibble(fp[pos], &high) || !parse_hex_nibble(fp[pos+1], &low)) { + log_d("pos:%d len:%d invalid hex sequence: %c%c", pos, len, fp[pos], fp[pos+1]); return false; } + pos += 2; fingerprint_local[i] = low | (high << 4); } @@ -684,7 +997,7 @@ bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert(&ssl_client->ssl_ctx); if (!crt) { - log_v("could not fetch peer certificate"); + log_w("could not fetch peer certificate"); return false; } @@ -698,7 +1011,7 @@ bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const // Check if fingerprints match if (memcmp(fingerprint_local, fingerprint_remote, 32)) { - log_d("fingerprint doesn't match"); + log_w("fingerprint doesn't match"); return false; } @@ -732,7 +1045,7 @@ bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name) string san_str((const char*)san->buf.p, san->buf.len); transform(san_str.begin(), san_str.end(), san_str.begin(), ::tolower); - if (matchName(san_str, domain_name_str)) { + if (match_name(san_str, domain_name_str)) { return true; } @@ -749,7 +1062,7 @@ bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name) if (!MBEDTLS_OID_CMP(MBEDTLS_OID_AT_CN, &common_name->oid)) { string common_name_str((const char*)common_name->val.p, common_name->val.len); - if (matchName(common_name_str, domain_name_str)) { + if (match_name(common_name_str, domain_name_str)) { return true; } diff --git a/src/ssl_client.h b/src/ssl_client.h index 760121c..91d40db 100644 --- a/src/ssl_client.h +++ b/src/ssl_client.h @@ -49,7 +49,18 @@ typedef struct sslclient_context { } sslclient_context; void ssl_init(sslclient_context *ssl_client, Client *client); +void log_failed_cert(int flags); +void cleanup(sslclient_context *ssl_client, bool ca_cert_initialized, bool client_cert_initialized, bool client_key_initialized, int ret, const char *rootCABuff, const char *cli_cert, const char *cli_key); int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey); +int init_tcp_connection(sslclient_context *ssl_client, const char *host, uint32_t port); +int seed_random_number_generator(sslclient_context *ssl_client); +int set_up_tls_defaults(sslclient_context *ssl_client); +int auth_root_ca_buff(sslclient_context *ssl_client, const char *rootCABuff, bool *ca_cert_initialized, const char *pskIdent, const char *psKey); +int auth_client_cert_key(sslclient_context *ssl_client, const char *cli_cert, const char *cli_key, bool *client_cert_initialized, bool *client_key_initialized); +int set_hostname_for_tls(sslclient_context *ssl_client, const char *host); +int set_io_callbacks_and_timeout(sslclient_context *ssl_client, int timeout); +int perform_ssl_handshake(sslclient_context *ssl_client, const char *cli_cert, const char *cli_key); +int verify_server_cert(sslclient_context *ssl_client); void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key); int data_to_read(sslclient_context *ssl_client); int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len); diff --git a/test/mocks/MbedTLS.h b/test/mocks/MbedTLS.h index 768bf9a..2d1c8aa 100644 --- a/test/mocks/MbedTLS.h +++ b/test/mocks/MbedTLS.h @@ -5,6 +5,10 @@ // #define MBEDTLS_ERROR_C #define MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED +#define MBEDTLS_X509_BADCERT_EXPIRED -0x01 +#define MBEDTLS_X509_BADCERT_NOT_TRUSTED -0x08 +#define MBEDTLS_ERR_X509_CERT_VERIFY_FAILED -0x2700 +#define MBEDTLS_X509_BADCERT_OTHER -0x0100 #define MBEDTLS_ERR_SSL_WANT_READ -0x6900 #define MBEDTLS_ERR_NET_SEND_FAILED -0x004E #define MBEDTLS_ERR_CTR_DRBG_ENTROPY_SOURCE_FAILED -0x0034 @@ -12,6 +16,8 @@ #define MBEDTLS_ERR_SSL_BAD_INPUT_DATA -0x7100 #define MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE -0x7080 #define MBEDTLS_ERR_SSL_WANT_WRITE -0x6880 +#define MBEDTLS_ERR_NET_CONN_RESET -0x004C +#define MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE -0x7780 #define MBEDTLS_SSL_IS_CLIENT 0 #define MBEDTLS_SSL_TRANSPORT_STREAM 0 #define MBEDTLS_SSL_PRESET_DEFAULT 0 @@ -19,9 +25,11 @@ #define MBEDTLS_PSK_MAX_LEN 32 #define MBEDTLS_SSL_VERIFY_NONE 0 #define MBEDTLS_OID_ISO_CCITT_DS "\x55" -#define MBEDTLS_OID_AT MBEDTLS_OID_ISO_CCITT_DS "\x04" -#define MBEDTLS_OID_AT_CN MBEDTLS_OID_AT "\x03" -#define MBEDTLS_OID_CMP(oid_str, oid_buf) (strncmp((oid_str), (char*)(oid_buf)->p, (oid_buf)->len) == 0) +#define MBEDTLS_OID_AT "\x55\x04" +#define MBEDTLS_OID_AT_CN "\x55\x04\x03" +#define MBEDTLS_OID_CMP(oid_str, oid_buf) (false) +#define MBEDTLS_ASN1_IA5_STRING 0x16 +#define MBEDTLS_ASN1_OID 0x06 typedef struct mbedtls_asn1_buf { int tag; @@ -50,10 +58,6 @@ typedef int mbedtls_ssl_recv_timeout_t(void *ctx, unsigned char *buf, size_t len struct mbedtls_ssl_context {}; struct mbedtls_ctr_drbg_context {}; struct mbedtls_entropy_context {}; -struct mbedtls_ssl_config { - void* ca_chain; - void* key_cert; -}; struct rawStruct { const unsigned char *p; size_t len; @@ -63,6 +67,12 @@ struct mbedtls_x509_crt { mbedtls_x509_sequence subject_alt_names; mbedtls_asn1_named_data subject; }; +struct mbedtls_ssl_config { + void* ca_chain; + void* key_cert; + mbedtls_x509_crt* actual_ca_chain; + mbedtls_x509_crt* actual_key_cert; +}; struct mbedtls_x509_crl {}; struct mbedtls_pk_context {}; struct mbedtls_sha256_context { @@ -72,102 +82,348 @@ struct mbedtls_sha256_context { int is224; }; -const mbedtls_x509_crt dummy_cert = { - {NULL, 0}, // raw (rawStruct) - - // subject_alt_names (mbedtls_x509_sequence) - { - {0, 0, NULL}, // buf (mbedtls_asn1_buf) - NULL // next - }, - - // subject (mbedtls_asn1_named_data) - { - {0, 0, NULL}, // oid (mbedtls_asn1_buf) - {0, 0, NULL}, // val (mbedtls_asn1_buf) - NULL, // next - 0 // next_merged (unsigned char) - } +const char* mock_cert_data = "MockCertificateData"; + +mbedtls_x509_crt dummy_cert = { + {reinterpret_cast(mock_cert_data), strlen(mock_cert_data)}, + + // subject_alt_names (mbedtls_x509_sequence) + { + {0, 0, NULL}, // buf (mbedtls_asn1_buf) + NULL // next + }, + + // subject (mbedtls_asn1_named_data) + { + {0, 0, NULL}, // oid (mbedtls_asn1_buf) + {0, 0, NULL}, // val (mbedtls_asn1_buf) + NULL, // next + 0 // next_merged (unsigned char) + } +}; + +std::string dName = "example.com"; +size_t len = dName.length(); +unsigned char* uchar_ptr = reinterpret_cast(&dName[0]); + +mbedtls_asn1_buf example_com_buffer = { + MBEDTLS_ASN1_IA5_STRING, + len, + uchar_ptr +}; + +mbedtls_x509_sequence example_com_sequence = { + example_com_buffer, + NULL +}; + +mbedtls_x509_crt dummy_cert_with_san = { + {reinterpret_cast(mock_cert_data), strlen(mock_cert_data)}, + example_com_sequence, + { + {0, 0, NULL}, + {0, 0, NULL}, + NULL, + 0 + } +}; + +mbedtls_x509_crt dummy_cert_with_cn = { + {reinterpret_cast(mock_cert_data), strlen(mock_cert_data)}, + { + {0, 0, NULL}, + NULL + }, + { + {MBEDTLS_ASN1_OID, sizeof(MBEDTLS_OID_AT_CN) - 1, const_cast(reinterpret_cast(MBEDTLS_OID_AT_CN))}, + {MBEDTLS_ASN1_IA5_STRING, strlen("example.com"), const_cast(reinterpret_cast("example.com"))}, + NULL, + 0 + } }; -const mbedtls_x509_crt *mbedtls_ssl_get_peer_cert(const mbedtls_ssl_context *ssl) { return &dummy_cert; } - -void mbedtls_ssl_init(mbedtls_ssl_context *ssl) {} -void mbedtls_ssl_config_init(mbedtls_ssl_config *conf) {} -void mbedtls_entropy_init(mbedtls_entropy_context *ctx) {} -void mbedtls_ctr_drbg_init(mbedtls_ctr_drbg_context *ctx) {} -void mbedtls_x509_crt_init(mbedtls_x509_crt *crt) {} -void mbedtls_ssl_conf_authmode(mbedtls_ssl_config *conf, int authmode) {} -void mbedtls_ssl_conf_ca_chain(mbedtls_ssl_config *conf, mbedtls_x509_crt *ca_chain, mbedtls_x509_crl *ca_crl) {} -void mbedtls_pk_init(mbedtls_pk_context *ctx) {} -void mbedtls_x509_crt_free(mbedtls_x509_crt *crt) {} -void mbedtls_ssl_conf_rng(mbedtls_ssl_config *conf, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng) {} -void mbedtls_ssl_set_bio( mbedtls_ssl_context *ssl, void *p_bio, mbedtls_ssl_send_t *f_send, mbedtls_ssl_recv_t *f_recv, mbedtls_ssl_recv_timeout_t *f_recv_timeout) {} -void mbedtls_ssl_conf_read_timeout(mbedtls_ssl_config *conf, uint32_t timeout) {} -void mbedtls_pk_free(mbedtls_pk_context *ctx) {} -void mbedtls_ssl_free(mbedtls_ssl_context *ssl) {} -void mbedtls_ssl_config_free(mbedtls_ssl_config *conf) {} -void mbedtls_ctr_drbg_free(mbedtls_ctr_drbg_context *ctx) {} -void mbedtls_entropy_free(mbedtls_entropy_context *ctx) {} -void mbedtls_sha256_init(mbedtls_sha256_context *ctx) {} -void mbedtls_sha256_starts(mbedtls_sha256_context *ctx, int is224) {} -void mbedtls_sha256_update(mbedtls_sha256_context *ctx, const unsigned char *input, size_t ilen) {} -void mbedtls_sha256_finish(mbedtls_sha256_context *ctx, unsigned char output[32]) {} - -uint32_t mbedtls_ssl_get_verify_result(const mbedtls_ssl_context *ssl) { return (uint32_t)0; } -size_t mbedtls_ssl_get_bytes_avail(const mbedtls_ssl_context *ssl) { return (size_t)0; } - -int mbedtls_ctr_drbg_seed_returns = 0; -int mbedtls_entropy_func_returns = 0; -int mbedtls_ssl_config_defaults_returns = 0; -int mbedtls_x509_crt_parse_returns = 0; -int mbedtls_ssl_conf_psk_returns = 0; -int mbedtls_pk_parse_key_returns = 0; -int mbedtls_ssl_conf_own_cert_returns = 0; -int mbedtls_ssl_set_hostname_returns = 0; -int mbedtls_ctr_drbg_random_returns = 0; -int mbedtls_ssl_setup_returns = 0; -int mbedtls_ssl_handshake_returns = 0; -int mbedtls_ssl_get_record_expansion_returns = 0; -int mbedtls_x509_crt_verify_info_returns = 0; -int mbedtls_ssl_read_returns = 0; -int mbedtls_ssl_write_returns = 0; +mbedtls_x509_crt dummy_cert_without_match = { + {reinterpret_cast(mock_cert_data), strlen(mock_cert_data)}, + { + {strlen("notexample.com"), MBEDTLS_ASN1_IA5_STRING, const_cast(reinterpret_cast("notexample.com"))}, + NULL + }, + { + {sizeof(MBEDTLS_OID_AT_CN) - 1, MBEDTLS_ASN1_OID, const_cast(reinterpret_cast(MBEDTLS_OID_AT_CN))}, + {strlen("notexample.com"), MBEDTLS_ASN1_IA5_STRING, const_cast(reinterpret_cast("notexample.com"))}, + NULL, + 0 + } +}; + +// Const removed from mbedtls_ssl_get_peer_cert for mocking - const mbedtls_x509_crt * +FunctionEmulator mbedtls_ssl_get_peer_cert_stub("mbedtls_ssl_get_peer_cert"); +mbedtls_x509_crt *mbedtls_ssl_get_peer_cert(const mbedtls_ssl_context *ssl) { + mbedtls_ssl_get_peer_cert_stub.recordFunctionCall(); + return mbedtls_ssl_get_peer_cert_stub.mock("mbedtls_ssl_get_peer_cert"); +} + +FunctionEmulator mbedtls_ssl_init_stub("mbedtls_ssl_init"); +void mbedtls_ssl_init(mbedtls_ssl_context *ssl) { + mbedtls_ssl_init_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ssl_config_init_stub("mbedtls_ssl_config_init"); +void mbedtls_ssl_config_init(mbedtls_ssl_config *conf) { + mbedtls_ssl_config_init_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_entropy_init_stub("mbedtls_entropy_init"); +void mbedtls_entropy_init(mbedtls_entropy_context *ctx) { + mbedtls_entropy_init_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ctr_drbg_init_stub("mbedtls_ctr_drbg_init"); +void mbedtls_ctr_drbg_init(mbedtls_ctr_drbg_context *ctx) { + mbedtls_ctr_drbg_init_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_x509_crt_init_stub("mbedtls_x509_crt_init"); +void mbedtls_x509_crt_init(mbedtls_x509_crt *crt) { + mbedtls_x509_crt_init_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ssl_conf_authmode_stub("mbedtls_ssl_conf_authmode"); +void mbedtls_ssl_conf_authmode(mbedtls_ssl_config *conf, int authmode) { + mbedtls_ssl_conf_authmode_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ssl_conf_ca_chain_stub("mbedtls_ssl_conf_ca_chain"); +void mbedtls_ssl_conf_ca_chain(mbedtls_ssl_config *conf, mbedtls_x509_crt *ca_chain, mbedtls_x509_crl *ca_crl) { + mbedtls_ssl_conf_ca_chain_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_pk_init_stub("mbedtls_pk_init"); +void mbedtls_pk_init(mbedtls_pk_context *ctx) { + mbedtls_pk_init_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_x509_crt_free_stub("mbedtls_x509_crt_free"); +void mbedtls_x509_crt_free(mbedtls_x509_crt *crt) { + mbedtls_x509_crt_free_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ssl_conf_rng_stub("mbedtls_ssl_conf_rng"); +void mbedtls_ssl_conf_rng(mbedtls_ssl_config *conf, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng) { + mbedtls_ssl_conf_rng_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ssl_set_bio_stub("mbedtls_ssl_set_bio"); +void mbedtls_ssl_set_bio( mbedtls_ssl_context *ssl, void *p_bio, mbedtls_ssl_send_t *f_send, mbedtls_ssl_recv_t *f_recv, mbedtls_ssl_recv_timeout_t *f_recv_timeout) { + mbedtls_ssl_set_bio_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ssl_conf_read_timeout_stub("mbedtls_ssl_conf_read_timeout"); +void mbedtls_ssl_conf_read_timeout(mbedtls_ssl_config *conf, uint32_t timeout) { + mbedtls_ssl_conf_read_timeout_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_pk_free_stub("mbedtls_pk_free"); +void mbedtls_pk_free(mbedtls_pk_context *ctx) { + mbedtls_pk_free_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ssl_free_stub("mbedtls_ssl_free"); +void mbedtls_ssl_free(mbedtls_ssl_context *ssl) { + mbedtls_ssl_free_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ssl_config_free_stub("mbedtls_ssl_config_free"); +void mbedtls_ssl_config_free(mbedtls_ssl_config *conf) { + mbedtls_ssl_config_free_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ctr_drbg_free_stub("mbedtls_ctr_drbg_free"); +void mbedtls_ctr_drbg_free(mbedtls_ctr_drbg_context *ctx) { + mbedtls_ctr_drbg_free_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_entropy_free_stub("mbedtls_entropy_free"); +void mbedtls_entropy_free(mbedtls_entropy_context *ctx) { + mbedtls_entropy_free_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_sha256_init_stub("mbedtls_sha256_init"); +void mbedtls_sha256_init(mbedtls_sha256_context *ctx) { + mbedtls_sha256_init_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_sha256_starts_stub("mbedtls_sha256_starts"); +void mbedtls_sha256_starts(mbedtls_sha256_context *ctx, int is224) { + mbedtls_sha256_starts_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_sha256_update_stub("mbedtls_sha256_update"); +void mbedtls_sha256_update(mbedtls_sha256_context *ctx, const unsigned char *input, size_t ilen) { + mbedtls_sha256_update_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_sha256_finish_stub("mbedtls_sha256_finish"); +void mbedtls_sha256_finish(mbedtls_sha256_context *ctx, unsigned char output[32]) { + mbedtls_sha256_finish_stub.recordFunctionCall(); +} + +FunctionEmulator mbedtls_ssl_get_verify_result_stub("mbedtls_ssl_get_verify_result"); +uint32_t mbedtls_ssl_get_verify_result(const mbedtls_ssl_context *ssl) { + mbedtls_ssl_get_verify_result_stub.recordFunctionCall(); + return mbedtls_ssl_get_verify_result_stub.mock("mbedtls_ssl_get_verify_result"); +} + +FunctionEmulator mbedtls_ssl_get_bytes_avail_stub("mbedtls_ssl_get_bytes_avail"); +size_t mbedtls_ssl_get_bytes_avail(const mbedtls_ssl_context *ssl) { + mbedtls_ssl_get_bytes_avail_stub.recordFunctionCall(); + return mbedtls_ssl_get_bytes_avail_stub.mock("mbedtls_ssl_get_bytes_avail"); +} + +FunctionEmulator mbedtls_ctr_drbg_seed_stub("mbedtls_ctr_drbg_seed"); +int mbedtls_ctr_drbg_seed(mbedtls_ctr_drbg_context *ctx, int (*f_entropy)(void *, unsigned char *, size_t), void *p_entropy, const unsigned char *custom, size_t len) { + mbedtls_ctr_drbg_seed_stub.recordFunctionCall(); + return mbedtls_ctr_drbg_seed_stub.mock("mbedtls_ctr_drbg_seed"); +} + +FunctionEmulator mbedtls_entropy_func_stub("mbedtls_entropy_func"); +int mbedtls_entropy_func(void *data, unsigned char *output, size_t len) { + mbedtls_entropy_func_stub.recordFunctionCall(); + return mbedtls_entropy_func_stub.mock("mbedtls_entropy_func"); +} + +FunctionEmulator mbedtls_ssl_config_defaults_stub("mbedtls_ssl_config_defaults"); +int mbedtls_ssl_config_defaults(mbedtls_ssl_config *conf, int endpoint, int transport, int preset) { + mbedtls_ssl_config_defaults_stub.recordFunctionCall(); + return mbedtls_ssl_config_defaults_stub.mock("mbedtls_ssl_config_defaults"); +} + +FunctionEmulator mbedtls_x509_crt_parse_stub("mbedtls_x509_crt_parse"); +int mbedtls_x509_crt_parse(mbedtls_x509_crt *chain, const unsigned char *buf, size_t buflen) { + mbedtls_x509_crt_parse_stub.recordFunctionCall(); + return mbedtls_x509_crt_parse_stub.mock("mbedtls_x509_crt_parse"); +} + +FunctionEmulator mbedtls_ssl_conf_psk_stub("mbedtls_ssl_conf_psk"); +int mbedtls_ssl_conf_psk(mbedtls_ssl_config *conf, const unsigned char *psk, size_t psk_len, const unsigned char *psk_identity, size_t psk_identity_len) { + mbedtls_ssl_conf_psk_stub.recordFunctionCall(); + return mbedtls_ssl_conf_psk_stub.mock("mbedtls_ssl_conf_psk"); +} + +FunctionEmulator mbedtls_pk_parse_key_stub("mbedtls_pk_parse_key"); +int mbedtls_pk_parse_key(mbedtls_pk_context *pk, const unsigned char *key, size_t keylen, const unsigned char *pwd, size_t pwdlen) { + mbedtls_pk_parse_key_stub.recordFunctionCall(); + return mbedtls_pk_parse_key_stub.mock("mbedtls_pk_parse_key"); +} + +FunctionEmulator mbedtls_ssl_conf_own_cert_stub("mbedtls_ssl_conf_own_cert"); +int mbedtls_ssl_conf_own_cert(mbedtls_ssl_config *conf, mbedtls_x509_crt *own_cert, mbedtls_pk_context *pk_key) { + mbedtls_ssl_conf_own_cert_stub.recordFunctionCall(); + return mbedtls_ssl_conf_own_cert_stub.mock("mbedtls_ssl_conf_own_cert"); +} + +FunctionEmulator mbedtls_ssl_set_hostname_stub("mbedtls_ssl_set_hostname"); +int mbedtls_ssl_set_hostname(mbedtls_ssl_context *ssl, const char *hostname) { + mbedtls_ssl_set_hostname_stub.recordFunctionCall(); + return mbedtls_ssl_set_hostname_stub.mock("mbedtls_ssl_set_hostname"); +} + +FunctionEmulator mbedtls_ctr_drbg_random_stub("mbedtls_ctr_drbg_random"); +int mbedtls_ctr_drbg_random(void *p_rng, unsigned char *output, size_t output_len) { + mbedtls_ctr_drbg_random_stub.recordFunctionCall(); + return mbedtls_ctr_drbg_random_stub.mock("mbedtls_ctr_drbg_random"); +} + +FunctionEmulator mbedtls_ssl_setup_stub("mbedtls_ssl_setup"); +int mbedtls_ssl_setup(mbedtls_ssl_context *ssl, const mbedtls_ssl_config *conf) { + mbedtls_ssl_setup_stub.recordFunctionCall(); + return mbedtls_ssl_setup_stub.mock("mbedtls_ssl_setup"); +} + +FunctionEmulator mbedtls_ssl_handshake_stub("mbedtls_ssl_handshake"); +int mbedtls_ssl_handshake(mbedtls_ssl_context *ssl) { + mbedtls_ssl_handshake_stub.recordFunctionCall(); + return mbedtls_ssl_handshake_stub.mock("mbedtls_ssl_handshake"); +} + +FunctionEmulator mbedtls_ssl_get_record_expansion_stub("mbedtls_ssl_get_record_expansion"); +int mbedtls_ssl_get_record_expansion(const mbedtls_ssl_context *ssl) { + mbedtls_ssl_get_record_expansion_stub.recordFunctionCall(); + return mbedtls_ssl_get_record_expansion_stub.mock("mbedtls_ssl_get_record_expansion"); +} + +FunctionEmulator mbedtls_x509_crt_verify_info_stub("mbedtls_x509_crt_verify_info"); +int mbedtls_x509_crt_verify_info(char *buf, size_t size, const char *prefix, uint32_t flags) { + mbedtls_x509_crt_verify_info_stub.recordFunctionCall(); + return mbedtls_x509_crt_verify_info_stub.mock("mbedtls_x509_crt_verify_info"); +} + +FunctionEmulator mbedtls_ssl_read_stub("mbedtls_ssl_read"); +int mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len) { + mbedtls_ssl_read_stub.recordFunctionCall(); + return mbedtls_ssl_read_stub.mock("mbedtls_ssl_read"); +} + +FunctionEmulator mbedtls_ssl_write_stub("mbedtls_ssl_write"); +int mbedtls_ssl_write(mbedtls_ssl_context *ssl, const unsigned char *buf, size_t len) { + mbedtls_ssl_write_stub.recordFunctionCall(); + return mbedtls_ssl_write_stub.mock("mbedtls_ssl_write"); +} + +FunctionEmulator mbedtls_ssl_get_version_stub("mbedtls_ssl_get_version"); +const char *mbedtls_ssl_get_version(const mbedtls_ssl_context *ssl) { + mbedtls_ssl_get_version_stub.recordFunctionCall(); + return mbedtls_ssl_get_version_stub.mock("mbedtls_ssl_get_version"); +} + +FunctionEmulator mbedtls_ssl_get_ciphersuite_stub("mbedtls_ssl_get_ciphersuite"); +const char *mbedtls_ssl_get_ciphersuite(const mbedtls_ssl_context *ssl) { + mbedtls_ssl_get_ciphersuite_stub.recordFunctionCall(); + return mbedtls_ssl_get_ciphersuite_stub.mock("mbedtls_ssl_get_ciphersuite"); +} void mbedtls_mock_reset_return_values() { - mbedtls_ctr_drbg_seed_returns = 0; - mbedtls_entropy_func_returns = 0; - mbedtls_ssl_config_defaults_returns = 0; - mbedtls_x509_crt_parse_returns = 0; - mbedtls_ssl_conf_psk_returns = 0; - mbedtls_pk_parse_key_returns = 0; - mbedtls_ssl_conf_own_cert_returns = 0; - mbedtls_ssl_set_hostname_returns = 0; - mbedtls_ctr_drbg_random_returns = 0; - mbedtls_ssl_setup_returns = 0; - mbedtls_ssl_handshake_returns = 0; - mbedtls_ssl_get_record_expansion_returns = 0; - mbedtls_x509_crt_verify_info_returns = 0; - mbedtls_ssl_read_returns = 0; - mbedtls_ssl_write_returns = 0; -} - -int mbedtls_ctr_drbg_seed(mbedtls_ctr_drbg_context *ctx, int (*f_entropy)(void *, unsigned char *, size_t), void *p_entropy, const unsigned char *custom, size_t len) { return mbedtls_ctr_drbg_seed_returns; } -int mbedtls_entropy_func(void *data, unsigned char *output, size_t len) { return mbedtls_entropy_func_returns; } -int mbedtls_ssl_config_defaults(mbedtls_ssl_config *conf, int endpoint, int transport, int preset) { return mbedtls_ssl_config_defaults_returns; } -int mbedtls_x509_crt_parse(mbedtls_x509_crt *chain, const unsigned char *buf, size_t buflen) { return mbedtls_x509_crt_parse_returns; } -int mbedtls_ssl_conf_psk(mbedtls_ssl_config *conf, const unsigned char *psk, size_t psk_len, const unsigned char *psk_identity, size_t psk_identity_len) { return mbedtls_ssl_conf_psk_returns; } -int mbedtls_pk_parse_key(mbedtls_pk_context *pk, const unsigned char *key, size_t keylen, const unsigned char *pwd, size_t pwdlen) { return mbedtls_pk_parse_key_returns; } -int mbedtls_ssl_conf_own_cert(mbedtls_ssl_config *conf, mbedtls_x509_crt *own_cert, mbedtls_pk_context *pk_key) { return mbedtls_ssl_conf_own_cert_returns; } -int mbedtls_ssl_set_hostname(mbedtls_ssl_context *ssl, const char *hostname) { return mbedtls_ssl_set_hostname_returns; } -int mbedtls_ctr_drbg_random(void *p_rng, unsigned char *output, size_t output_len) { return mbedtls_ctr_drbg_random_returns; } -int mbedtls_ssl_setup(mbedtls_ssl_context *ssl, const mbedtls_ssl_config *conf) { return mbedtls_ssl_setup_returns; } -int mbedtls_ssl_handshake(mbedtls_ssl_context *ssl) { return mbedtls_ssl_handshake_returns; } -int mbedtls_ssl_get_record_expansion(const mbedtls_ssl_context *ssl) { return mbedtls_ssl_get_record_expansion_returns; } -int mbedtls_x509_crt_verify_info(char *buf, size_t size, const char *prefix, uint32_t flags) { return mbedtls_x509_crt_verify_info_returns; } -int mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len) { return mbedtls_ssl_read_returns; } -int mbedtls_ssl_write(mbedtls_ssl_context *ssl, const unsigned char *buf, size_t len) { return mbedtls_ssl_write_returns; } - -const char *mbedtls_ssl_get_version(const mbedtls_ssl_context *ssl) { return (const char*)""; } -const char *mbedtls_ssl_get_ciphersuite(const mbedtls_ssl_context *ssl) { return (const char*)""; } + mbedtls_ssl_get_peer_cert_stub.reset(); + mbedtls_ssl_init_stub.reset(); + mbedtls_ssl_config_init_stub.reset(); + mbedtls_entropy_init_stub.reset(); + mbedtls_ctr_drbg_init_stub.reset(); + mbedtls_x509_crt_init_stub.reset(); + mbedtls_ssl_conf_authmode_stub.reset(); + mbedtls_ssl_conf_ca_chain_stub.reset(); + mbedtls_pk_init_stub.reset(); + mbedtls_x509_crt_free_stub.reset(); + mbedtls_ssl_conf_rng_stub.reset(); + mbedtls_ssl_set_bio_stub.reset(); + mbedtls_ssl_conf_read_timeout_stub.reset(); + mbedtls_pk_free_stub.reset(); + mbedtls_ssl_free_stub.reset(); + mbedtls_ssl_config_free_stub.reset(); + mbedtls_ctr_drbg_free_stub.reset(); + mbedtls_entropy_free_stub.reset(); + mbedtls_sha256_init_stub.reset(); + mbedtls_sha256_starts_stub.reset(); + mbedtls_sha256_update_stub.reset(); + mbedtls_sha256_finish_stub.reset(); + mbedtls_ssl_get_verify_result_stub.reset(); + mbedtls_ssl_get_bytes_avail_stub.reset(); + mbedtls_ctr_drbg_seed_stub.reset(); + mbedtls_entropy_func_stub.reset(); + mbedtls_ssl_config_defaults_stub.reset(); + mbedtls_x509_crt_parse_stub.reset(); + mbedtls_ssl_conf_psk_stub.reset(); + mbedtls_pk_parse_key_stub.reset(); + mbedtls_ssl_conf_own_cert_stub.reset(); + mbedtls_ssl_set_hostname_stub.reset(); + mbedtls_ctr_drbg_random_stub.reset(); + mbedtls_ssl_setup_stub.reset(); + mbedtls_ssl_handshake_stub.reset(); + mbedtls_ssl_get_record_expansion_stub.reset(); + mbedtls_x509_crt_verify_info_stub.reset(); + mbedtls_ssl_read_stub.reset(); + mbedtls_ssl_write_stub.reset(); + mbedtls_ssl_get_version_stub.reset(); + mbedtls_ssl_get_ciphersuite_stub.reset(); +} #endif // MBEDTLS_MOCK_H \ No newline at end of file diff --git a/test/mocks/TestClient.h b/test/mocks/TestClient.h index 050efe6..eaba844 100644 --- a/test/mocks/TestClient.h +++ b/test/mocks/TestClient.h @@ -3,19 +3,23 @@ #include "Client.h" #include "Emulator.h" +#include "FunctionEmulator.h" + + +FunctionEmulator test_client_stop_stub("TestClient::stop()"); class TestClient : public Client, public Emulator { public: int connect(IPAddress ip, uint16_t port) override { - return 1; // 1 means successful connection, you can change based on test requirements. + return this->mock("connect"); } int connect(const char *host, uint16_t port) override { - return 1; + return this->mock("connect"); } size_t write(uint8_t byte) override { - return 1; // 1 byte written + return this->mock("write"); } size_t write(const uint8_t *buf, size_t size) override { @@ -23,11 +27,11 @@ class TestClient : public Client, public Emulator { } int available() override { - return 0; // No bytes available + return this->mock("available"); } int read() override { - return -1; // -1 generally indicates no bytes available + return this->mock("read"); } int read(uint8_t *buf, size_t size) override { @@ -35,12 +39,14 @@ class TestClient : public Client, public Emulator { } int peek() override { - return -1; // -1 generally indicates no bytes available + return this->mock("peek"); } void flush() override {} - void stop() override {} + void stop() override { + test_client_stop_stub.recordFunctionCall(); + } uint8_t connected() override { return this->mock("connected"); diff --git a/test/unit_test_private_api.cpp b/test/unit_test_private_api.cpp index dc5074c..f7bbb61 100644 --- a/test/unit_test_private_api.cpp +++ b/test/unit_test_private_api.cpp @@ -1,24 +1,36 @@ -#define log_d(...); printf(__VA_ARGS__); printf("\n"); -#define log_i(...); printf(__VA_ARGS__); printf("\n"); -#define log_w(...); printf(__VA_ARGS__); printf("\n"); -#define log_e(...); printf(__VA_ARGS__); printf("\n"); -#define log_v(...); printf(__VA_ARGS__); printf("\n"); +// #define EMULATOR_LOG +#include "unity.h" +#include "Arduino.h" +#include "Emulation.h" + #define portTICK_PERIOD_MS 1 #define vTaskDelay(x) delay(x) -#include "unity.h" -#include "Arduino.h" #include "mocks/ESPClass.hpp" #include "mocks/TestClient.h" #include "ssl_client.cpp" using namespace fakeit; -TestClient testClient; -sslclient_context *testContext; +TestClient testClient; // Mocked client +sslclient_context *testContext; // Context for tests + +/** + * @brief Set the up stop ssl socket object for these tests. + * + * @param ctx The sslclient_context to set up. + * @param client The client to set up. + */ +void setup_stop_ssl_socket(sslclient_context* ctx, Client* client) { + ctx->ssl_conf.actual_ca_chain = (mbedtls_x509_crt*) malloc(sizeof(mbedtls_x509_crt)); + ctx->ssl_conf.actual_key_cert = &dummy_cert; + ctx->ssl_conf.ca_chain = ctx->ssl_conf.actual_ca_chain; + ctx->ssl_conf.key_cert = ctx->ssl_conf.actual_key_cert; +} void setUp(void) { ArduinoFakeReset(); + ResetEmulators(); testClient.reset(); testClient.returns("connected", (uint8_t)1); mbedtls_mock_reset_return_values(); @@ -98,7 +110,7 @@ void test_partial_write(void) { // Act int result = client_net_send(&testClient, buf, sizeof(buf)); - + // Assert TEST_ASSERT_EQUAL_INT(1500, result); // Only half the buffer is sent } @@ -106,13 +118,14 @@ void test_partial_write(void) { void test_disconnected_client(void) { // Arrange unsigned char buf[1000]; - testClient.reset(); // Reset the mock client - testClient.returns("connected", (uint8_t)0); // Mock the client to return false for "connected" function + testClient.reset(); + testClient.returns("connected", (uint8_t)0); // Act int result = client_net_send(&testClient, buf, sizeof(buf)); // Assert + TEST_ASSERT_EQUAL_INT(1, log_e_stub.timesCalled()); TEST_ASSERT_EQUAL_INT(-2, result); // -2 indicates disconnected client } @@ -128,145 +141,1787 @@ void run_client_net_send_tests(void) { UNITY_END(); } -/* Test get_ssl_receive function */ +/* Test client_net_recv function */ -void test_get_ssl_receive_success(void) { +void test_null_client_context(void) { // Arrange - unsigned char data[1024]; - mbedtls_ssl_read_returns = 1024; + unsigned char buf[100]; // Act - int result = get_ssl_receive(testContext, data, sizeof(data)); + int result = client_net_recv(NULL, buf, sizeof(buf)); // Assert - TEST_ASSERT_EQUAL_INT(1024, result); + TEST_ASSERT_EQUAL_INT(-1, result); } -void test_get_ssl_receive_partial_read(void) { +void test_disconnected_client_client_net_recv(void) { // Arrange - unsigned char data[1024]; - mbedtls_ssl_read_returns = 512; + testClient.reset(); + testClient.returns("connected", (uint8_t)0); + unsigned char buf[100]; // Act - int result = get_ssl_receive(testContext, data, sizeof(data)); + int result = client_net_recv(&testClient, buf, sizeof(buf)); // Assert - TEST_ASSERT_EQUAL_INT(512, result); + TEST_ASSERT_EQUAL_INT(-2, result); } -void test_get_ssl_receive_failure(void) { +void test_successful_client_read(void) { // Arrange - unsigned char data[1024]; - mbedtls_ssl_read_returns = MBEDTLS_ERR_SSL_BAD_INPUT_DATA; + unsigned char buf[100]; + testClient.returns("read", (int)50); // Act - int result = get_ssl_receive(testContext, data, sizeof(data)); + int result = client_net_recv(&testClient, buf, sizeof(buf)); // Assert - TEST_ASSERT_EQUAL_INT(MBEDTLS_ERR_SSL_BAD_INPUT_DATA, result); + TEST_ASSERT_EQUAL_INT(50, result); } -void test_get_ssl_receive_zero_length(void) { +void test_failed_client_read(void) { // Arrange - unsigned char data[1]; - mbedtls_ssl_read_returns = 0; + unsigned char buf[100]; + testClient.returns("read", (int)0); // Mock a read failure // Act - int result = get_ssl_receive(testContext, data, 0); + int result = client_net_recv(&testClient, buf, sizeof(buf)); // Assert - TEST_ASSERT_EQUAL_INT(0, result); + TEST_ASSERT_EQUAL_INT(0, result); // Expecting 0 as read() has failed } -void run_get_ssl_receive_tests(void) { +void run_client_net_recv_tests(void) { UNITY_BEGIN(); - RUN_TEST(test_get_ssl_receive_success); - RUN_TEST(test_get_ssl_receive_partial_read); - RUN_TEST(test_get_ssl_receive_failure); - RUN_TEST(test_get_ssl_receive_zero_length); + RUN_TEST(test_null_client_context); + RUN_TEST(test_disconnected_client_client_net_recv); + RUN_TEST(test_successful_client_read); + RUN_TEST(test_failed_client_read); UNITY_END(); } -/* Test client_net_recv function */ +/* Test handle_error function */ -void test_null_client_context(void) { +void test_handle_error_no_logging_on_minus_30848(void) { // Arrange - unsigned char buf[100]; + int err = -30848; // Act - int result = client_net_recv(NULL, buf, sizeof(buf)); + int result = _handle_error(err, "testFunction", 123); + + // Assert + TEST_ASSERT_EQUAL_INT(-30848, result); + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); +} +void test_handle_error_logging_with_mbedtls_error_c(void) { + // Arrange + int err = MBEDTLS_ERR_NET_SEND_FAILED; + + // Act + int result = _handle_error(err, "testFunction", 123); + + // Assert + TEST_ASSERT_EQUAL_INT(-0x004E, result); + TEST_ASSERT_TRUE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(1, log_e_stub.timesCalled()); +} + +void test_handle_error_logging_without_mbedtls_error_c(void) { + // Arrange + int err = MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE; // Some error code not being specially handled + + // Act + int result = _handle_error(err, "testFunction", 123); + + // Assert + TEST_ASSERT_EQUAL_INT(err, result); + TEST_ASSERT_TRUE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(1, log_e_stub.timesCalled()); +} + +void run_handle_error_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_handle_error_no_logging_on_minus_30848); + RUN_TEST(test_handle_error_logging_with_mbedtls_error_c); + RUN_TEST(test_handle_error_logging_without_mbedtls_error_c); + UNITY_END(); +} + +/* Test client_net_recv_timeout function */ + +void test_ctx_is_null(void) { + // Arrange + unsigned char buf[10]; + + // Act + int result = client_net_recv_timeout(nullptr, buf, 10, 1000); + // Assert + TEST_ASSERT_FALSE(log_v_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(1, log_e_stub.timesCalled()); TEST_ASSERT_EQUAL_INT(-1, result); } -void test_disconnected_client_client_net_recv(void) { +void test_successful_read_without_delay(void) { + // Arrange + testClient.returns("available", (int)10); + testClient.returns("read", (int)10); + unsigned char buf[10]; + + // Act + int result = client_net_recv_timeout(&testClient, buf, 10, 1000); + + // Assert + TEST_ASSERT_EQUAL_INT(2, log_v_stub.timesCalled()); + TEST_ASSERT_GREATER_THAN(0, result); +} + +void test_successful_read_with_delay(void) { + // Arrange + testClient.returns("available", (int)10); + testClient.returns("read", (int)10); + unsigned char buf[10]; + + // Act + int result = client_net_recv_timeout(&testClient, buf, 10, 1000); + + // Assert + TEST_ASSERT_EQUAL_INT(2, log_v_stub.timesCalled()); + TEST_ASSERT_GREATER_THAN(0, result); +} + +void test_read_timeout(void) { // Arrange testClient.reset(); - testClient.returns("connected", (uint8_t)0); - unsigned char buf[100]; + testClient.returns("available", (int)0); + testClient.returns("read", (int)0); + unsigned char buf[10]; // Act - int result = client_net_recv(&testClient, buf, sizeof(buf)); + int result = client_net_recv_timeout(&testClient, buf, 10, 100); + + // Assert + TEST_ASSERT_EQUAL_INT(1, log_v_stub.timesCalled()); + TEST_ASSERT_EQUAL(MBEDTLS_ERR_SSL_WANT_READ, result); +} +void test_read_returns_zero(void) { + // Arrange + testClient.returns("available", (int)10); + testClient.returns("read", (int)0); + unsigned char buf[10]; + + // Act + int result = client_net_recv_timeout(&testClient, buf, 10, 1000); + // Assert - TEST_ASSERT_EQUAL_INT(-2, result); + TEST_ASSERT_EQUAL_INT(1, log_v_stub.timesCalled()); + TEST_ASSERT_EQUAL(MBEDTLS_ERR_SSL_WANT_READ, result); } -void test_successful_client_read(void) { +void test_len_zero(void) { // Arrange - unsigned char buf[100]; - testClient.returns("read", (int)50); + unsigned char buf[10]; // Act - int result = client_net_recv(&testClient, buf, sizeof(buf)); + int result = client_net_recv_timeout(&testClient, buf, 0, 1000); + + // Assert + TEST_ASSERT_TRUE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(0, result); +} + +void run_client_net_recv_timeout_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_ctx_is_null); + RUN_TEST(test_successful_read_without_delay); + RUN_TEST(test_successful_read_with_delay); + RUN_TEST(test_read_timeout); + RUN_TEST(test_read_returns_zero); + RUN_TEST(test_len_zero); + UNITY_END(); +} + +/* test ssl_init function */ +void test_ssl_init_correct_initialization() { + // Arrange / Act + ssl_init(testContext, &testClient); + // Assert - TEST_ASSERT_EQUAL_INT(50, result); + TEST_ASSERT_EQUAL_PTR(&testClient, testContext->client); + TEST_ASSERT_EQUAL_MEMORY(&testClient, testContext->client, sizeof(Client)); } -void test_failed_client_read(void) { +void test_ssl_init_mbedtls_functions_called() { + // Arrange / Act + ssl_init(testContext, &testClient); + + // Assert + TEST_ASSERT_TRUE(mbedtls_ssl_init_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_ssl_config_init_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_ctr_drbg_init_stub.wasCalled()); +} + +void test_ssl_init_logging() { + // Assert / Act + ssl_init(testContext, &testClient); + ArgContext args = log_v_stub.getArguments(); + + // Assert + TEST_ASSERT_EQUAL_STRING("Init SSL", args.resolve(0).c_str()); +} + +void run_ssl_init_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_ssl_init_correct_initialization); + RUN_TEST(test_ssl_init_mbedtls_functions_called); + RUN_TEST(test_ssl_init_mbedtls_functions_called); + UNITY_END(); +} + +/* test data_to_read function */ + +void test_data_to_read_success() { // Arrange - unsigned char buf[100]; - testClient.returns("read", (int)0); // Mock a read failure + mbedtls_ssl_read_stub.returns("mbedtls_ssl_read", 5); + mbedtls_ssl_get_bytes_avail_stub.returns("mbedtls_ssl_get_bytes_avail", (size_t)5); + + // Act + int result = data_to_read(testContext); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 2); + TEST_ASSERT_EQUAL(5, result); +} +void test_data_to_read_edge_case() { + // Arrange + mbedtls_ssl_read_stub.returns("mbedtls_ssl_read", MBEDTLS_ERR_SSL_WANT_READ); + mbedtls_ssl_get_bytes_avail_stub.returns("mbedtls_ssl_get_bytes_avail", (size_t)0); + // Act - int result = client_net_recv(&testClient, buf, sizeof(buf)); + int result = data_to_read(testContext); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 2); + TEST_ASSERT_EQUAL(0, result); +} +void test_data_to_read_failure() { + // Arrange + mbedtls_ssl_read_stub.returns("mbedtls_ssl_read", MBEDTLS_ERR_NET_CONN_RESET); + mbedtls_ssl_get_bytes_avail_stub.returns("mbedtls_ssl_get_bytes_avail", (size_t)0); + + // Act + int result = data_to_read(testContext); + // Assert - TEST_ASSERT_EQUAL_INT(0, result); // Expecting 0 as read() has failed + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 2); + TEST_ASSERT_EQUAL(-76, result); // -0x004C = MBEDTLS_ERR_NET_CONN_RESET } -void run_client_net_recv_tests(void) { +void run_data_to_read_tests(void) { UNITY_BEGIN(); - RUN_TEST(test_null_client_context); - RUN_TEST(test_disconnected_client_client_net_recv); - RUN_TEST(test_successful_client_read); - RUN_TEST(test_failed_client_read); + RUN_TEST(test_data_to_read_success); + RUN_TEST(test_data_to_read_edge_case); + RUN_TEST(test_data_to_read_failure); UNITY_END(); } -/* End of test functions */ +/* test log_failed_cert function */ -#ifdef ARDUINO +void test_log_failed_cert_with_some_flags(void) { + // Arrange + int flags = MBEDTLS_X509_BADCERT_EXPIRED; + + // Act + log_failed_cert(flags); + + // Assert + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 1); +} -#include +void test_log_failed_cert_with_null_flags(void) { + // Arrange + int flags = 0; + + // Act + log_failed_cert(flags); + + // Assert + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); +} -void setup() { - delay(2000); // If using Serial, allow time for serial monitor to open - run_all_tests(); +void run_log_failed_cert_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_log_failed_cert_with_some_flags); + RUN_TEST(test_log_failed_cert_with_null_flags); + UNITY_END(); } -void loop() { - // Empty loop +/* test cleanup function */ + +void test_cleanup_with_all_resources_initialized_and_no_error(void) { + // Arrange + bool ca_cert_initialized = true; + bool client_cert_initialized = true; + bool client_key_initialized = true; + int ret = 0; + + // Act + cleanup(testContext, ca_cert_initialized, client_cert_initialized, client_key_initialized, ret, NULL, NULL, NULL); + + // Assert + TEST_ASSERT_TRUE(mbedtls_x509_crt_free_stub.timesCalled() == 2); + TEST_ASSERT_TRUE(mbedtls_pk_free_stub.wasCalled()); + TEST_ASSERT_TRUE(log_d_stub.wasCalled()); } -#else +void test_cleanup_with_some_resources_initialized_and_no_error(void) { + // Arrange + sslclient_context ssl_client; + bool ca_cert_initialized = true; + bool client_cert_initialized = false; + bool client_key_initialized = true; + int ret = 0; -int main(int argc, char **argv) { - run_client_net_send_tests(); - run_get_ssl_receive_tests(); - run_client_net_recv_tests(); + // Act + cleanup(&ssl_client, ca_cert_initialized, client_cert_initialized, client_key_initialized, ret, NULL, NULL, NULL); + + // Assert + TEST_ASSERT_TRUE(mbedtls_x509_crt_free_stub.timesCalled() == 1); + TEST_ASSERT_TRUE(mbedtls_pk_free_stub.wasCalled()); + TEST_ASSERT_TRUE(log_d_stub.wasCalled()); +} + +void run_cleanup_tests() { + UNITY_BEGIN(); + RUN_TEST(test_cleanup_with_all_resources_initialized_and_no_error); + RUN_TEST(test_cleanup_with_some_resources_initialized_and_no_error); + UNITY_END(); +} + +/* test start_ssl_client function */ + +void test_successful_ssl_client_start(void) { + // Arrange + testClient.reset(); + testContext->client = &testClient; + testClient.returns("connect", (int)1); + testContext->client = &testClient; + const char *host = "example.com"; + uint32_t port = 443; + int timeout = 1000; + const char *rootCABuff = "-----BEGIN CERTIFICATE-----...-----END CERTIFICATE-----"; + const char *cli_cert = "-----BEGIN CERTIFICATE-----...-----END CERTIFICATE-----"; + const char *cli_key = "-----BEGIN PRIVATE KEY-----...-----END PRIVATE KEY-----"; + const char *pskIdent = NULL; + const char *psKey = NULL; + + mbedtls_ctr_drbg_seed_stub.returns("mbedtls_ctr_drbg_seed", 0); + mbedtls_ssl_config_defaults_stub.returns("mbedtls_ssl_config_defaults", 0); + mbedtls_x509_crt_parse_stub.returns("mbedtls_x509_crt_parse", 0); + mbedtls_pk_parse_key_stub.returns("mbedtls_pk_parse_key", 0); + mbedtls_ssl_conf_own_cert_stub.returns("mbedtls_ssl_conf_own_cert", 0); + mbedtls_ssl_set_hostname_stub.returns("mbedtls_ssl_set_hostname", 0); + mbedtls_ssl_setup_stub.returns("mbedtls_ssl_setup", 0); + mbedtls_ssl_handshake_stub.returns("mbedtls_ssl_handshake", 0); + mbedtls_ssl_get_record_expansion_stub.returns("mbedtls_ssl_get_record_expansion", 0); + mbedtls_ssl_get_verify_result_stub.returns("mbedtls_ssl_get_verify_result", (uint32_t)0); + + // Act + int result = start_ssl_client(testContext, host, port, timeout, rootCABuff, cli_cert, cli_key, pskIdent, psKey); + + // Assert + TEST_ASSERT_EQUAL(1, result); +} + +void test_ssl_client_start_with_invalid_host(void) { + // Arrange + testClient.reset(); + testContext->client = &testClient; + testClient.returns("connect", (int)0); + const char *host = "example.com"; + uint32_t port = 443; + int timeout = 1000; + const char *rootCABuff = "-----BEGIN CERTIFICATE-----...-----END CERTIFICATE-----"; + const char *cli_cert = "-----BEGIN CERTIFICATE-----...-----END CERTIFICATE-----"; + const char *cli_key = "-----BEGIN PRIVATE KEY-----...-----END PRIVATE KEY-----"; + const char *pskIdent = NULL; + const char *psKey = NULL; + + // Act + int result = start_ssl_client(testContext, "invalid_host", port, timeout, rootCABuff, cli_cert, cli_key, pskIdent, psKey); + + // Assert + TEST_ASSERT_EQUAL(0, result); +} + +void test_ssl_client_start_invalid_port(void) { + // Arrange + testClient.reset(); + testContext->client = &testClient; + testClient.returns("connect", (int)0); + const char *host = "example.com"; + int timeout = 1000; + const char *rootCABuff = "-----BEGIN CERTIFICATE-----...-----END CERTIFICATE-----"; + const char *cli_cert = "-----BEGIN CERTIFICATE-----...-----END CERTIFICATE-----"; + const char *cli_key = "-----BEGIN PRIVATE KEY-----...-----END PRIVATE KEY-----"; + const char *pskIdent = NULL; + const char *psKey = NULL; + uint32_t port = (uint32_t)432589743022453; + + // Act + int result = start_ssl_client(testContext, host, port, timeout, rootCABuff, cli_cert, cli_key, pskIdent, psKey); + + // Assert + TEST_ASSERT_EQUAL(0, result); +} + +void test_ssl_client_start_failed_tcp_connection(void) { + // Arrange + const char *host = "example.com"; + uint32_t port = 443; + int timeout = 1000; + const char *rootCABuff = "-----BEGIN CERTIFICATE-----...-----END CERTIFICATE-----"; + const char *cli_cert = "-----BEGIN CERTIFICATE-----...-----END CERTIFICATE-----"; + const char *cli_key = "-----BEGIN PRIVATE KEY-----...-----END PRIVATE KEY-----"; + const char *pskIdent = NULL; + const char *psKey = NULL; + + // Act - null testContext->client + int result = start_ssl_client(testContext, host, port, timeout, rootCABuff, cli_cert, cli_key, pskIdent, psKey); + + // Assert + TEST_ASSERT_EQUAL(0, result); +} + +void test_ssl_client_start_failed_ssl_tls_handshake(void) { + // Arrange + testClient.reset(); + testContext->client = &testClient; + testClient.returns("connect", (int)1); + testContext->client = &testClient; + const char *host = "example.com"; + uint32_t port = 443; + int timeout = 1000; + const char *rootCABuff = "-----BEGIN CERTIFICATE-----...-----END CERTIFICATE-----"; + const char *cli_cert = "-----BEGIN CERTIFICATE-----...-----END CERTIFICATE-----"; + const char *cli_key = "-----BEGIN PRIVATE KEY-----...-----END PRIVATE KEY-----"; + const char *pskIdent = NULL; + const char *psKey = NULL; + + mbedtls_ctr_drbg_seed_stub.returns("mbedtls_ctr_drbg_seed", 0); + mbedtls_ssl_config_defaults_stub.returns("mbedtls_ssl_config_defaults", 0); + mbedtls_x509_crt_parse_stub.returns("mbedtls_x509_crt_parse", 0); + mbedtls_pk_parse_key_stub.returns("mbedtls_pk_parse_key", 0); + mbedtls_ssl_conf_own_cert_stub.returns("mbedtls_ssl_conf_own_cert", 0); + mbedtls_ssl_set_hostname_stub.returns("mbedtls_ssl_set_hostname", 0); + mbedtls_ssl_setup_stub.returns("mbedtls_ssl_setup", 0); + mbedtls_ssl_handshake_stub.returns("mbedtls_ssl_handshake", 0); + mbedtls_ssl_get_record_expansion_stub.returns("mbedtls_ssl_get_record_expansion", MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE); + + // Act + int result = start_ssl_client(testContext, host, port, timeout, rootCABuff, cli_cert, cli_key, pskIdent, psKey); + + // Assert + TEST_ASSERT_EQUAL(0, result); +} + +void run_start_ssl_client_tests() { + UNITY_BEGIN(); + RUN_TEST(test_successful_ssl_client_start); + RUN_TEST(test_ssl_client_start_with_invalid_host); + RUN_TEST(test_ssl_client_start_invalid_port); + RUN_TEST(test_ssl_client_start_failed_tcp_connection); + RUN_TEST(test_ssl_client_start_failed_ssl_tls_handshake); + UNITY_END(); +} + +/* test init_tcp_connection function */ + +void test_init_tcp_connection_SuccessfulConnection_ReturnsZero(void) { + // Arrange + testContext->client = &testClient; + testClient.reset(); + testClient.returns("connect", (int)1); + const char* host = "example.com"; + uint32_t port = 443; + + // Act + int result = init_tcp_connection(testContext, host, port); + + // Assert + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(1, log_v_stub.timesCalled()); + TEST_ASSERT_EQUAL_INT(0, result); +} + +void test_init_tcp_connection_NullClient_ReturnsMinusOne(void) { + // Arrange + const char* host = "example.com"; + uint32_t port = 443; + testContext->client = nullptr; + + // Act + int result = init_tcp_connection(testContext, host, port); + + // Assert + TEST_ASSERT_FALSE(log_v_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(1, log_e_stub.timesCalled()); + TEST_ASSERT_EQUAL_INT(-1, result); +} + +void test_init_tcp_connection_FailedConnection_ReturnsMinusTwo(void) { + // Arrange + testContext->client = &testClient; + testClient.reset(); + testClient.returns("connect", (int)0); + const char* host = "example.com"; + uint32_t port = 443; + + // Act + int result = init_tcp_connection(testContext, host, port); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(1, log_e_stub.timesCalled()); + TEST_ASSERT_EQUAL_INT(-2, result); +} + +void test_init_tcp_connection_EdgeCase_LargePortNumber_SuccessfulConnection(void) { + // Arrange + testContext->client = &testClient; + testClient.reset(); + testClient.returns("connect", (int)1); + const char* host = "example.com"; + uint32_t largePort = UINT32_MAX; + + // Act + int result = init_tcp_connection(testContext, host, largePort); + + // Assert + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(1, log_v_stub.timesCalled()); + TEST_ASSERT_EQUAL_INT(0, result); +} + +void run_init_tcp_connection_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_init_tcp_connection_SuccessfulConnection_ReturnsZero); + RUN_TEST(test_init_tcp_connection_NullClient_ReturnsMinusOne); + RUN_TEST(test_init_tcp_connection_FailedConnection_ReturnsMinusTwo); + RUN_TEST(test_init_tcp_connection_EdgeCase_LargePortNumber_SuccessfulConnection); + UNITY_END(); +} + +/* test seed_random_number_generator function */ + +void test_seed_random_number_generator_SuccessfulSeed_ReturnsZero(void) { + // Arrange + mbedtls_ctr_drbg_seed_stub.returns("mbedtls_ctr_drbg_seed", 0); + + // Act + int result = seed_random_number_generator(testContext); + + // Assert + TEST_ASSERT_EQUAL_INT(2, log_v_stub.timesCalled()); + TEST_ASSERT_EQUAL_INT(0, result); +} + +void test_seed_random_number_generator_CtrDrbgSeedFails_ReturnsErrorCode(void) { + // Arrange + mbedtls_ctr_drbg_seed_stub.returns("mbedtls_ctr_drbg_seed", MBEDTLS_ERR_CTR_DRBG_ENTROPY_SOURCE_FAILED); + + // Act + int result = seed_random_number_generator(testContext); + + // Assert + TEST_ASSERT_EQUAL_INT(2, log_v_stub.timesCalled()); + TEST_ASSERT_EQUAL_INT(MBEDTLS_ERR_CTR_DRBG_ENTROPY_SOURCE_FAILED, result); +} + +void run_seed_random_number_generator_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_seed_random_number_generator_SuccessfulSeed_ReturnsZero); + RUN_TEST(test_seed_random_number_generator_CtrDrbgSeedFails_ReturnsErrorCode); + UNITY_END(); +} + +/* Test set_up_tls_defaults function */ + +void test_set_up_tls_defaults_SuccessfulSetup_ReturnsZero(void) { + // Arrange + mbedtls_ssl_config_defaults_stub.returns("mbedtls_ssl_config_defaults", 0); + + // Act + int result = set_up_tls_defaults(testContext); + + // Assert + TEST_ASSERT_EQUAL_INT(1, log_v_stub.timesCalled()); + TEST_ASSERT_EQUAL_INT(0, result); +} + +void test_set_up_tls_defaults_FailedSetup_ReturnsErrorCode(void) { + // Arrange + mbedtls_ssl_config_defaults_stub.returns("mbedtls_ssl_config_defaults", -1); + + // Act + int result = set_up_tls_defaults(testContext); + + // Assert + TEST_ASSERT_EQUAL_INT(1, log_v_stub.timesCalled()); + TEST_ASSERT_EQUAL_INT(-1, result); +} + +void run_set_up_tls_defaults_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_set_up_tls_defaults_SuccessfulSetup_ReturnsZero); + RUN_TEST(test_set_up_tls_defaults_FailedSetup_ReturnsErrorCode); + UNITY_END(); +} + +/* test stop_ssl_socket function */ + +void test_stop_ssl_socket_success(void) { + // Arrange + test_client_stop_stub.reset(); + ssl_init(testContext, &testClient); + setup_stop_ssl_socket(testContext, &testClient); + log_d_stub.reset(); + + // Act + stop_ssl_socket(testContext, "rootCABuff_example", "cli_cert_example", "cli_key_example"); + + // Assert + TEST_ASSERT_TRUE(test_client_stop_stub.wasCalled()); + TEST_ASSERT_TRUE(log_d_stub.timesCalled() == 9); + TEST_ASSERT_TRUE(mbedtls_x509_crt_free_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_pk_free_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_ssl_free_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_ssl_config_free_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_ctr_drbg_free_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_entropy_free_stub.wasCalled()); +} + +void test_stop_ssl_socket_edge_null_pointers(void) { + // Arrange + test_client_stop_stub.reset(); + ssl_init(testContext, &testClient); + log_d_stub.reset(); + + // Act + stop_ssl_socket(testContext, "rootCABuff_example", "cli_cert_example", "cli_key_example"); + + // Assert + TEST_ASSERT_TRUE(test_client_stop_stub.wasCalled()); + TEST_ASSERT_TRUE(log_d_stub.timesCalled() == 7); + TEST_ASSERT_FALSE(mbedtls_x509_crt_free_stub.wasCalled()); + TEST_ASSERT_FALSE(mbedtls_pk_free_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_ssl_free_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_ssl_config_free_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_ctr_drbg_free_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_entropy_free_stub.wasCalled()); +} + +void test_stop_ssl_socket_failure_will_not_double_free(void) { + // Arrange + test_client_stop_stub.reset(); + ssl_init(testContext, &testClient); + testContext->client = NULL; + log_d_stub.reset(); + + // Act + stop_ssl_socket(testContext, "rootCABuff_example", "cli_cert_example", "cli_key_example"); + + // Assert + TEST_ASSERT_FALSE(test_client_stop_stub.wasCalled()); +} + +void run_stop_ssl_socket_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_stop_ssl_socket_success); + RUN_TEST(test_stop_ssl_socket_edge_null_pointers); + RUN_TEST(test_stop_ssl_socket_failure_will_not_double_free); + UNITY_END(); +} + +/* test send_ssl_data function */ + +void test_send_ssl_data_successful_write(void) { + // Arrange + testContext->client = &testClient; + testContext->handshake_timeout = 100; + const uint8_t data[] = "test_data"; + int len = sizeof(data) - 1; // Excluding null terminator + mbedtls_ssl_write_stub.returns("mbedtls_ssl_write", len); + + // Act + int ret = send_ssl_data(testContext, data, len); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 4); + TEST_ASSERT_TRUE(mbedtls_ssl_write_stub.wasCalled()); + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(len, ret); +} + +void test_send_ssl_data_want_write_then_success(void) { + // Arrange + testContext->client = &testClient; + testContext->handshake_timeout = 100; + const uint8_t data[] = "test_data"; + int len = sizeof(data) - 1; // Excluding null terminator + + // First two calls to mbedtls_ssl_write will return WANT_WRITE, then it will succeed + mbedtls_ssl_write_stub.returns("mbedtls_ssl_write", MBEDTLS_ERR_SSL_WANT_WRITE) + .then(MBEDTLS_ERR_SSL_WANT_WRITE) + .then(len); + + // Act + int ret = send_ssl_data(testContext, data, len); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 4); + TEST_ASSERT_TRUE(mbedtls_ssl_write_stub.wasCalled()); + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(len, ret); +} + +void test_send_ssl_data_null_context(void) { + // Act + int ret = send_ssl_data(NULL, NULL, 0); + + // Assert + TEST_ASSERT_FALSE(log_v_stub.wasCalled()); + TEST_ASSERT_TRUE(mbedtls_ssl_write_stub.timesCalled() == 0); + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 1); + TEST_ASSERT_EQUAL_INT(-1, ret); +} + +void test_send_ssl_data_mbedtls_failure(void) { + // Arrange + testContext->client = &testClient; + testContext->handshake_timeout = 100; + const uint8_t data[] = "test_data"; + int len = sizeof(data) - 1; // Excluding null terminator + mbedtls_ssl_write_stub.returns("mbedtls_ssl_write", MBEDTLS_ERR_SSL_ALLOC_FAILED); + + // Act + int ret = send_ssl_data(testContext, data, len); + + // Assert + TEST_ASSERT_TRUE(ret < 0); + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 3); + TEST_ASSERT_TRUE(mbedtls_ssl_write_stub.wasCalled()); + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 1); +} + +void test_send_ssl_data_zero_length(void) { + // Arrange + testContext->client = &testClient; + testContext->handshake_timeout = 100; + const uint8_t data[] = "test_data"; + mbedtls_ssl_write_stub.returns("mbedtls_ssl_write", 0); + + // Act + int ret = send_ssl_data(testContext, data, 0); + + // Assert + TEST_ASSERT_EQUAL_INT(0, ret); + TEST_ASSERT_TRUE(mbedtls_ssl_write_stub.wasCalled()); + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 3); + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 1); +} + +void test_send_ssl_data_want_read_then_success(void) { + // Arrange + testContext->client = &testClient; + testContext->handshake_timeout = 100; + const uint8_t data[] = "test_data"; + int len = sizeof(data) - 1; // Excluding null terminator + + // First two calls to mbedtls_ssl_write will return WANT_READ, then it will succeed + mbedtls_ssl_write_stub.returns("mbedtls_ssl_write", MBEDTLS_ERR_SSL_WANT_WRITE) + .then(MBEDTLS_ERR_SSL_WANT_WRITE) + .then(len); + + // Act + int ret = send_ssl_data(testContext, data, len); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 4); + TEST_ASSERT_TRUE(mbedtls_ssl_write_stub.wasCalled()); + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(len, ret); +} + +void run_send_ssl_data_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_send_ssl_data_successful_write); + RUN_TEST(test_send_ssl_data_want_write_then_success); + RUN_TEST(test_send_ssl_data_null_context); + RUN_TEST(test_send_ssl_data_mbedtls_failure); + RUN_TEST(test_send_ssl_data_zero_length); + RUN_TEST(test_send_ssl_data_want_read_then_success); + UNITY_END(); +} + +/* Test get_ssl_receive function */ + +void test_get_ssl_receive_success(void) { + // Arrange + unsigned char data[1024]; + mbedtls_ssl_read_stub.returns("mbedtls_ssl_read", 1024); + + // Act + int result = get_ssl_receive(testContext, data, sizeof(data)); + + // Assert + TEST_ASSERT_EQUAL_INT(1024, result); +} + +void test_get_ssl_receive_partial_read(void) { + // Arrange + unsigned char data[1024]; + mbedtls_ssl_read_stub.returns("mbedtls_ssl_read", 512); + + // Act + int result = get_ssl_receive(testContext, data, sizeof(data)); + + // Assert + TEST_ASSERT_EQUAL_INT(512, result); +} + +void test_get_ssl_receive_failure_bad_input(void) { + // Arrange + unsigned char data[1024]; + mbedtls_ssl_read_stub.returns("mbedtls_ssl_read", MBEDTLS_ERR_SSL_BAD_INPUT_DATA); + + // Act + int result = get_ssl_receive(testContext, data, sizeof(data)); + + // Assert + TEST_ASSERT_EQUAL_INT(MBEDTLS_ERR_SSL_BAD_INPUT_DATA, result); +} + +void test_get_ssl_receive_failed_alloc(void) { + // Arrange + unsigned char data[1024]; + mbedtls_ssl_read_stub.returns("mbedtls_ssl_read", MBEDTLS_ERR_SSL_ALLOC_FAILED); + + // Act + int result = get_ssl_receive(testContext, data, sizeof(data)); + + // Assert + TEST_ASSERT_EQUAL_INT(MBEDTLS_ERR_SSL_ALLOC_FAILED, result); +} + +void test_get_ssl_receive_zero_length(void) { + // Arrange + unsigned char data[1]; + mbedtls_ssl_read_stub.returns("mbedtls_ssl_read", 0); + + // Act + int result = get_ssl_receive(testContext, data, 0); + + // Assert + TEST_ASSERT_EQUAL_INT(0, result); +} + +void run_get_ssl_receive_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_get_ssl_receive_success); + RUN_TEST(test_get_ssl_receive_partial_read); + RUN_TEST(test_get_ssl_receive_failure_bad_input); + RUN_TEST(test_get_ssl_receive_failed_alloc); + RUN_TEST(test_get_ssl_receive_zero_length); + UNITY_END(); +} + +/* test parse_hex_nibble function */ + +void test_parse_hex_nibble_digit(void) { + // Arrange + uint8_t result; + + // Act + bool success = parse_hex_nibble('5', &result); + + // Assert + TEST_ASSERT_TRUE(success); + TEST_ASSERT_EQUAL_UINT8(5, result); +} + +void test_parse_hex_nibble_lowercase(void) { + // Arrange + uint8_t result; + + // Act + bool success = parse_hex_nibble('b', &result); + + // Assert + TEST_ASSERT_TRUE(success); + TEST_ASSERT_EQUAL_UINT8(11, result); +} + +void test_parse_hex_nibble_uppercase(void) { + // Arrange + uint8_t result; + + // Act + bool success = parse_hex_nibble('D', &result); + + // Assert + TEST_ASSERT_TRUE(success); + TEST_ASSERT_EQUAL_UINT8(13, result); +} + +void test_parse_hex_nibble_below_range(void) { + // Arrange + uint8_t result; + + // Act + bool success = parse_hex_nibble('/', &result); + + // Assert + TEST_ASSERT_FALSE(success); +} + +void test_parse_hex_nibble_between_range(void) { + // Arrange + uint8_t result; + + // Act + bool success = parse_hex_nibble('h', &result); + + // Assert + TEST_ASSERT_FALSE(success); +} + +void test_parse_hex_nibble_above_range(void) { + // Arrange + uint8_t result; + + // Act + bool success = parse_hex_nibble('Z', &result); + + // Assert + TEST_ASSERT_FALSE(success); +} + +void test_parse_hex_nibble_edge_smallest(void) { + // Arrange + uint8_t result; + + // Act + bool success = parse_hex_nibble('0', &result); + + // Assert + TEST_ASSERT_TRUE(success); + TEST_ASSERT_EQUAL_UINT8(0, result); +} + +void test_parse_hex_nibble_edge_largest(void) { + // Arrange + uint8_t result; + + // Act + bool success = parse_hex_nibble('f', &result); + + // Assert + TEST_ASSERT_TRUE(success); + TEST_ASSERT_EQUAL_UINT8(15, result); +} + +void run_parse_hex_nibble_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_parse_hex_nibble_digit); + RUN_TEST(test_parse_hex_nibble_lowercase); + RUN_TEST(test_parse_hex_nibble_uppercase); + RUN_TEST(test_parse_hex_nibble_below_range); + RUN_TEST(test_parse_hex_nibble_between_range); + RUN_TEST(test_parse_hex_nibble_above_range); + RUN_TEST(test_parse_hex_nibble_edge_smallest); + RUN_TEST(test_parse_hex_nibble_edge_largest); + UNITY_END(); +} + +/* test match_name function */ + +void test_match_name_exact_match(void) { + // Arrange + string name = "example.com"; + string domainName = "example.com"; + + // Act + bool result = match_name(name, domainName); + + // Assert + TEST_ASSERT_TRUE(result); +} + +void test_match_name_simple_wildcard_match(void) { + // Arrange + string name = "*.example.com"; + string domainName = "test.example.com"; + + // Act + bool result = match_name(name, domainName); + + // Assert + TEST_ASSERT_TRUE(result); +} + +void test_match_name_exact_mismatch(void) { + // Arrange + string name = "example1.com"; + string domainName = "example2.com"; + + // Act + bool result = match_name(name, domainName); + + // Assert + TEST_ASSERT_FALSE(result); +} + +void test_match_name_wildcard_wrong_position(void) { + // Arrange + string name = "test.*.example.com"; + string domainName = "test.abc.example.com"; + + // Act + bool result = match_name(name, domainName); + + // Assert + TEST_ASSERT_FALSE(result); +} + +void test_match_name_wildcard_not_beginning(void) { + // Arrange + string name = "te*.example.com"; + string domainName = "test.example.com"; + + // Act + bool result = match_name(name, domainName); + + // Assert + TEST_ASSERT_FALSE(result); +} + +void test_match_name_wildcard_without_subdomain(void) { + // Arrange + string name = "*.example.com"; + string domainName = "example.com"; + + // Act + bool result = match_name(name, domainName); + + // Assert + TEST_ASSERT_FALSE(result); +} + +void run_match_name_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_match_name_exact_match); + RUN_TEST(test_match_name_simple_wildcard_match); + RUN_TEST(test_match_name_exact_mismatch); + RUN_TEST(test_match_name_wildcard_wrong_position); + RUN_TEST(test_match_name_wildcard_not_beginning); + RUN_TEST(test_match_name_wildcard_without_subdomain); + UNITY_END(); +} + +/* test verify_ssl_fingerprint function */ + +void test_verify_ssl_fingerprint_short_fp(void) { + // Arrange + const char* short_fp = "d83c1c1f57"; + + // Act + bool result = verify_ssl_fingerprint(testContext, short_fp, nullptr); + + // Assert + TEST_ASSERT_FALSE(result); +} + +void test_verify_ssl_fingerprint_invalid_format(void) { + // Arrange + const char* invalid_fp = "invalid_format_fp"; + + // Act + bool result = verify_ssl_fingerprint(testContext, invalid_fp, nullptr); + + // Assert + TEST_ASSERT_FALSE(result); +} + +void test_verify_ssl_fingerprint_invalid_hex_sequence(void) { + // Arrange + const char* invalid_hex = "d83c1c1f574fd9e75a7848ad8fb131302c31e224ad8c2617a9b3e24e81fc44ez"; // 'z' is not a valid hex character + + // Act + bool result = verify_ssl_fingerprint(testContext, invalid_hex, nullptr); + + // Assert + TEST_ASSERT_FALSE_MESSAGE(result, "Expected invalid hex sequence to fail."); +} + +void test_verify_ssl_fingerprint_domain_fail(void) { + // Arrange + mbedtls_ssl_get_peer_cert_stub.returns("mbedtls_ssl_get_peer_cert", &dummy_cert); + + const char* test_fp = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; + + // Act + bool result = verify_ssl_fingerprint(testContext, test_fp, "examplecom"); + + // Assert + TEST_ASSERT_FALSE(result); +} + +void test_verify_ssl_fingerprint_no_peer_cert(void) { + // Arrange + mbedtls_ssl_get_peer_cert_stub.returns("mbedtls_ssl_get_peer_cert", &dummy_cert); + const char* valid_fp = "d83c1c1f574fd9e75a7848ad8fb131302c31e224ad8c2617a9b3e24e81fc44e5"; + + // Act + bool result = verify_ssl_fingerprint(testContext, valid_fp, nullptr); + + // Assert + TEST_ASSERT_FALSE(result); +} + +void run_verify_ssl_fingerprint_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_verify_ssl_fingerprint_short_fp); + RUN_TEST(test_verify_ssl_fingerprint_invalid_format); + RUN_TEST(test_verify_ssl_fingerprint_invalid_hex_sequence); + RUN_TEST(test_verify_ssl_fingerprint_domain_fail); + RUN_TEST(test_verify_ssl_fingerprint_no_peer_cert); + UNITY_END(); +} + +/* test verify_ssl_dn function */ + +void test_verify_ssl_dn_match_in_sans(void) { + // Arrange + std:string domainName = "example.com"; + mbedtls_ssl_get_peer_cert_stub.returns("mbedtls_ssl_get_peer_cert", &dummy_cert_with_san); + + // Act + bool result = verify_ssl_dn(testContext, domainName.c_str()); + + // Assert + TEST_ASSERT_TRUE_MESSAGE(result, "Expected to match domain name in SANs."); +} + +void test_verify_ssl_dn_match_in_cn(void) { + // Arrange + std:string domainName = "example.com"; + mbedtls_ssl_get_peer_cert_stub.returns("mbedtls_ssl_get_peer_cert", &dummy_cert_with_cn); + + // Act + bool result = verify_ssl_dn(testContext, domainName.c_str()); + + // Assert + TEST_ASSERT_TRUE_MESSAGE(result, "Expected to match domain name in CN."); +} + +void test_verify_ssl_dn_no_match(void) { + // Arrange + std:string domainName = "example.com"; + mbedtls_ssl_get_peer_cert_stub.returns("mbedtls_ssl_get_peer_cert", &dummy_cert_without_match); + + // Act + bool result = verify_ssl_dn(testContext, domainName.c_str()); + + // Assert + TEST_ASSERT_FALSE_MESSAGE(result, "Expected no domain name match in both SANs and CN."); +} + +void test_verify_ssl_dn_empty_domain_name(void) { + // Arrange + std::string emptyDomainName = ""; + mbedtls_ssl_get_peer_cert_stub.returns("mbedtls_ssl_get_peer_cert", &dummy_cert_without_match); + + // Act + bool result = verify_ssl_dn(testContext, emptyDomainName.c_str()); + + // Assert + TEST_ASSERT_FALSE_MESSAGE(result, "Expected to fail with an empty domain name."); +} + +void test_verify_ssl_dn_no_peer_cert(void) { + // Arrange + std:string domainName = "example.com"; + mbedtls_ssl_get_peer_cert_stub.returns("mbedtls_ssl_get_peer_cert", &dummy_cert); + + // Act + bool result = verify_ssl_dn(testContext, domainName.c_str()); + + // Assert + TEST_ASSERT_FALSE_MESSAGE(result, "Expected to fail when no peer certificate is found."); +} + +void run_verify_ssl_dn_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_verify_ssl_dn_match_in_sans); + RUN_TEST(test_verify_ssl_dn_match_in_cn); + RUN_TEST(test_verify_ssl_dn_no_match); + RUN_TEST(test_verify_ssl_dn_empty_domain_name); + RUN_TEST(test_verify_ssl_dn_no_peer_cert); + UNITY_END(); +} + +/* test auth_root_ca_buff function */ + +void test_auth_root_ca_buff_success(void) { + // Arrange + const char *valid_ca_buff = ""; + bool ca_cert_initialized = false; + mbedtls_x509_crt_parse_stub.returns("mbedtls_x509_crt_parse", 0); + + // Act + int result = auth_root_ca_buff(testContext, valid_ca_buff, &ca_cert_initialized, NULL, NULL); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT_MESSAGE(0, result, "Expected successful configuration."); +} + +void test_auth_root_ca_buff_failure(void) { + // Arrange + const char *invalid_ca_buff = ""; + bool ca_cert_initialized = false; + mbedtls_x509_crt_parse_stub.returns("mbedtls_x509_crt_parse", MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE); + + // Act + int result = auth_root_ca_buff(testContext, invalid_ca_buff, &ca_cert_initialized, NULL, NULL); + + // Assert + TEST_ASSERT_EQUAL_INT_MESSAGE(MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE, result, "Expected failure in configuration."); +} + +void test_auth_root_ca_buff_edge(void) { + // Arrange + int returnVal = -1; + + // Act + int result = auth_root_ca_buff(testContext, NULL, NULL, "", ""); + + // Assert + TEST_ASSERT_EQUAL_INT(returnVal, result); +} + +void test_auth_root_ca_buff_null_ssl_client(void) { + // Arrange + int func_ret = 0; + int returnVal = -1; + + // Act + int result = auth_root_ca_buff(NULL, NULL, NULL, NULL, NULL); + + // Assert + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 1); + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 0); + TEST_ASSERT_EQUAL_INT(returnVal, result); +} + +void test_auth_root_ca_buff_invalid_ca_valid_psk(void) { + // Arrange + const char *invalid_ca_buff = ""; + const char *valid_pskIdent = ""; + const char *valid_psKey = ""; + bool ca_cert_initialized = false; + mbedtls_x509_crt_parse_stub.returns("mbedtls_x509_crt_parse", MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE); + + // Act + int result = auth_root_ca_buff(testContext, invalid_ca_buff, &ca_cert_initialized, valid_pskIdent, valid_psKey); + + // Assert + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 0); + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 1); + TEST_ASSERT_EQUAL_INT(MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE, result); +} + +void test_auth_root_ca_buff_valid_ca_valid_psk(void) { + // Arrange + const char *valid_ca_buff = ""; + const char *valid_pskIdent = ""; + const char *valid_psKey = ""; + int returnVal = -1; + + // Act + int result = auth_root_ca_buff(testContext, valid_ca_buff, NULL, valid_pskIdent, valid_psKey); + + // Assert + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 1); + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 1); + TEST_ASSERT_EQUAL_INT(returnVal, result); +} + +void test_auth_root_ca_buff_long_psk(void) { + // Arrange + const char *long_psKey = ""; + + // Act + int result = auth_root_ca_buff(testContext, NULL, NULL, "", long_psKey); + + // Assert + TEST_ASSERT_EQUAL_INT(-1, result); +} + +void test_auth_root_ca_buff_malformed_psk(void) { + // Arrange + const char *malformed_psKey = ""; + + // Act + int result = auth_root_ca_buff(testContext, NULL, NULL, "", malformed_psKey); + + // Assert + TEST_ASSERT_EQUAL_INT(-1, result); +} + +void run_auth_root_ca_buff_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_auth_root_ca_buff_success); + RUN_TEST(test_auth_root_ca_buff_failure); + RUN_TEST(test_auth_root_ca_buff_edge); + RUN_TEST(test_auth_root_ca_buff_null_ssl_client); + RUN_TEST(test_auth_root_ca_buff_invalid_ca_valid_psk); + RUN_TEST(test_auth_root_ca_buff_valid_ca_valid_psk); + RUN_TEST(test_auth_root_ca_buff_long_psk); + RUN_TEST(test_auth_root_ca_buff_malformed_psk); + UNITY_END(); +} + +/* test auth_client_cert_key function */ + +void test_auth_client_cert_key_both_null() { + // Arrange / Act + int result = auth_client_cert_key(testContext, NULL, NULL, nullptr, nullptr); + + // Assert + TEST_ASSERT_EQUAL_INT(0, result); +} + +void test_auth_client_cert_key_cert_null() { + // Arrange / Act + int result = auth_client_cert_key(testContext, NULL, "", nullptr, nullptr); + + // Assert + TEST_ASSERT_EQUAL_INT(0, result); +} + +void test_auth_client_cert_key_key_null() { + // Arrange / Act + int result = auth_client_cert_key(testContext, "", NULL, nullptr, nullptr); + + // Assert + TEST_ASSERT_EQUAL_INT(0, result); +} + +void test_auth_client_cert_key_valid() { + // Arrange + mbedtls_x509_crt_parse_stub.returns("mbedtls_x509_crt_parse", 0); + bool cert_init = false, key_init = false; + + // Act + int result = auth_client_cert_key(testContext, "", "", &cert_init, &key_init); + + // Assert + TEST_ASSERT_EQUAL_INT(0, result); + TEST_ASSERT_TRUE(cert_init); + TEST_ASSERT_TRUE(key_init); +} + +void test_auth_client_cert_key_invalid_cert() { + // Arrange + mbedtls_x509_crt_parse_stub.returns("mbedtls_x509_crt_parse", MBEDTLS_X509_BADCERT_NOT_TRUSTED); + bool cert_init = false, key_init = false; + + // Act + int result = auth_client_cert_key(testContext, "", "", &cert_init, &key_init); + + // Assert + TEST_ASSERT_EQUAL_INT(MBEDTLS_X509_BADCERT_NOT_TRUSTED, result); + TEST_ASSERT_FALSE(cert_init); +} + +void test_auth_client_cert_key_invalid_key() { + // Arrange + mbedtls_x509_crt_parse_stub.returns("mbedtls_x509_crt_parse", MBEDTLS_X509_BADCERT_NOT_TRUSTED); + bool cert_init = false, key_init = false; + + // Act + int result = auth_client_cert_key(testContext, "", "", &cert_init, &key_init); + + // Assert + TEST_ASSERT_EQUAL_INT(MBEDTLS_X509_BADCERT_NOT_TRUSTED, result); + TEST_ASSERT_FALSE(key_init); +} + +void run_auth_client_cert_key_tests(void) { + UNITY_BEGIN(); + RUN_TEST(test_auth_client_cert_key_both_null); + RUN_TEST(test_auth_client_cert_key_cert_null); + RUN_TEST(test_auth_client_cert_key_key_null); + RUN_TEST(test_auth_client_cert_key_valid); + RUN_TEST(test_auth_client_cert_key_invalid_cert); + RUN_TEST(test_auth_client_cert_key_invalid_key); + UNITY_END(); +} + +/* test set_hostname_for_tls function */ + +void test_set_hostname_for_tls_success(void) { + // Arrange + const char *host = "example.com"; + mbedtls_ssl_set_hostname_stub.returns("mbedtls_ssl_set_hostname", 0); + mbedtls_ssl_setup_stub.returns("mbedtls_ssl_setup", 0); + + // Act + int result = set_hostname_for_tls(testContext, host); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.wasCalled()); + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(0, result); +} + +void test_set_hostname_for_tls_null_host(void) { + // Arrange + mbedtls_ssl_set_hostname_stub.returns("mbedtls_ssl_set_hostname", MBEDTLS_ERR_SSL_BAD_INPUT_DATA); + const char *host = NULL; + + // Act + int result = set_hostname_for_tls(testContext, host); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.wasCalled()); + TEST_ASSERT_TRUE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(MBEDTLS_ERR_SSL_BAD_INPUT_DATA, result); +} + +void test_set_hostname_for_tls_empty_host(void) { + // Arrange + mbedtls_ssl_set_hostname_stub.returns("mbedtls_ssl_set_hostname", MBEDTLS_ERR_SSL_BAD_INPUT_DATA); + const char *host = ""; + + // Act + int result = set_hostname_for_tls(testContext, host); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.wasCalled()); + TEST_ASSERT_TRUE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(MBEDTLS_ERR_SSL_BAD_INPUT_DATA, result); +} + +void test_set_hostname_for_tls_alloc_failed(void) { + // Arrange + mbedtls_ssl_set_hostname_stub.returns("mbedtls_ssl_set_hostname", MBEDTLS_ERR_SSL_ALLOC_FAILED); + const char *host = "example.com"; + + // Act + int result = set_hostname_for_tls(testContext, host); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.wasCalled()); + TEST_ASSERT_TRUE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(MBEDTLS_ERR_SSL_ALLOC_FAILED, result); +} + +void test_set_hostname_for_tls_ssl_setup_failed(void) { + // Arrange + mbedtls_ssl_set_hostname_stub.returns("mbedtls_ssl_set_hostname", 0); + mbedtls_ssl_setup_stub.returns("mbedtls_ssl_setup", MBEDTLS_ERR_SSL_ALLOC_FAILED); + const char *host = "example.com"; + + // Act + int result = set_hostname_for_tls(testContext, host); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.wasCalled()); + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 0); + TEST_ASSERT_EQUAL_INT(MBEDTLS_ERR_SSL_ALLOC_FAILED, result); +} + +void run_set_hostname_for_tls_tests() { + UNITY_BEGIN(); + RUN_TEST(test_set_hostname_for_tls_success); + RUN_TEST(test_set_hostname_for_tls_null_host); + RUN_TEST(test_set_hostname_for_tls_empty_host); + RUN_TEST(test_set_hostname_for_tls_alloc_failed); + RUN_TEST(test_set_hostname_for_tls_ssl_setup_failed); + UNITY_END(); +} + +/* test set_io_callbacks function */ + +void test_set_io_callbacks_and_timeout_success(void) { + // Arrange + int successfulReturn = 0; + + // Act + int result = set_io_callbacks_and_timeout(testContext, 5000); + + // Assert + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 2); + TEST_ASSERT_EQUAL_INT(successfulReturn, result); +} + +void test_set_io_callbacks_and_timeout_zero_timeout(void) { + // Arrange + int successfulReturn = 0; + + // Act + int result = set_io_callbacks_and_timeout(testContext, 0); + + // Assert + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 2); + TEST_ASSERT_EQUAL_INT(successfulReturn, result); +} + +void test_set_io_callbacks_and_timeout_negative_timeout(void) { + // Arrange + int failedReturn = -2; + + // Act + int result = set_io_callbacks_and_timeout(testContext, -5000); + + // Assert + TEST_ASSERT_FALSE(log_v_stub.wasCalled()); + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 1); + TEST_ASSERT_EQUAL_INT(failedReturn, result); +} + +void test_set_io_callbacks_and_timeout_null_context(void) { + // Arrange + int failedReturn = -1; + + // Act + int result = set_io_callbacks_and_timeout(NULL, 5000); + + // Assert + TEST_ASSERT_EQUAL_INT(failedReturn, result); +} + +void test_set_io_callbacks_and_timeout_large_timeout(void) { + // Arrange + int successfulReturn = 0; + + // Act + int result = set_io_callbacks_and_timeout(testContext, INT_MAX); + + + // Assert + TEST_ASSERT_EQUAL_INT(0, result); +} + +void run_set_io_callbacks_tests() { + UNITY_BEGIN(); + RUN_TEST(test_set_io_callbacks_and_timeout_success); + RUN_TEST(test_set_io_callbacks_and_timeout_zero_timeout); + RUN_TEST(test_set_io_callbacks_and_timeout_negative_timeout); + // RUN_TEST(test_set_io_callbacks_and_timeout_null_context); + RUN_TEST(test_set_io_callbacks_and_timeout_large_timeout); + UNITY_END(); +} + +/* test perform_ssl_handshake function */ + +void test_perform_ssl_handshake_success(void) { + // Arrange + const char *cli_cert = NULL; + const char *cli_key = NULL; + mbedtls_ssl_handshake_stub.returns("mbedtls_ssl_handshake", 0); + + // Act + int result = perform_ssl_handshake(testContext, cli_cert, cli_key); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 1); + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(0, result); +} + +void test_perform_ssl_handshake_timeout(void) { + // Arrange + const char *cli_cert = NULL; + const char *cli_key = NULL; + testContext->handshake_timeout = 1; + mbedtls_ssl_handshake_stub.returns("mbedtls_ssl_handshake", MBEDTLS_ERR_SSL_WANT_READ); + + // Act + int result = perform_ssl_handshake(testContext, cli_cert, cli_key); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 1); + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 1); + TEST_ASSERT_EQUAL_INT(-1, result); +} + +void test_perform_ssl_handshake_cert_key_provided(void) { + // Arrange + const char *cli_cert = "dummy_cert"; + const char *cli_key = "dummy_key"; + mbedtls_ssl_handshake_stub.returns("mbedtls_ssl_handshake", 0); + mbedtls_ssl_get_record_expansion_stub.returns("mbedtls_ssl_get_record_expansion", 0); + + // Act + int result = perform_ssl_handshake(testContext, cli_cert, cli_key); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 2); + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_TRUE(log_w_stub.timesCalled() == 1); + TEST_ASSERT_EQUAL_INT(0, result); +} + +void test_perform_ssl_handshake_null_context(void) { + // Arrange + const char *cli_cert = NULL; + const char *cli_key = NULL; + + // Act + int result = perform_ssl_handshake(NULL, cli_cert, cli_key); + + // Assert + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 1); + TEST_ASSERT_FALSE(log_v_stub.wasCalled()); + TEST_ASSERT_EQUAL(-1, result); +} + +void test_perform_ssl_handshake_record_expansion_failure(void) { + // Arrange + const char *cli_cert = "dummy_cert"; + const char *cli_key = "dummy_key"; + mbedtls_ssl_handshake_stub.returns("mbedtls_ssl_handshake", 0); + mbedtls_ssl_get_record_expansion_stub.returns("mbedtls_ssl_get_record_expansion", MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE); + + // Act + int result = perform_ssl_handshake(testContext, cli_cert, cli_key); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 2); + TEST_ASSERT_TRUE(log_w_stub.wasCalled()); + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_EQUAL_INT(MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE, result); +} + +void run_perform_ssl_handshake_tests() { + UNITY_BEGIN(); + RUN_TEST(test_perform_ssl_handshake_success); + RUN_TEST(test_perform_ssl_handshake_timeout); + RUN_TEST(test_perform_ssl_handshake_cert_key_provided); + RUN_TEST(test_perform_ssl_handshake_null_context); + RUN_TEST(test_perform_ssl_handshake_record_expansion_failure); + UNITY_END(); +} + +/* test verify_server_cert function */ + +void test_verify_server_cert_success(void) { + // Arrange + mbedtls_ssl_get_verify_result_stub.returns("mbedtls_ssl_get_verify_result", (uint32_t)0); + + // Act + int result = verify_server_cert(testContext); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 1); + TEST_ASSERT_EQUAL_INT(0, result); +} + +void test_verify_server_cert_fail_handshake(void) { + // Arrange + mbedtls_ssl_get_verify_result_stub.returns("mbedtls_ssl_get_verify_result", (uint32_t)-1u); + + // Act + uint32_t result = verify_server_cert(testContext); + + // Assert + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 1); + TEST_ASSERT_EQUAL((uint32_t)-1u, result); +} + +void test_verify_server_cert_null_context(void) { + // Arrange / Act + int result = verify_server_cert(NULL); + + // Assert + TEST_ASSERT_FALSE(log_v_stub.wasCalled()); + TEST_ASSERT_TRUE(log_e_stub.timesCalled() == 1); + TEST_ASSERT_EQUAL(-1, result); +} + +void test_verify_server_cert_mismatched_cert_key(void) { + // Arrange + mbedtls_ssl_get_verify_result_stub.returns("mbedtls_ssl_get_verify_result", (uint32_t)MBEDTLS_ERR_X509_CERT_VERIFY_FAILED); + + // Act + uint32_t result = verify_server_cert(testContext); + + // Assert + TEST_ASSERT_FALSE(log_e_stub.wasCalled()); + TEST_ASSERT_TRUE(log_v_stub.timesCalled() == 1); + TEST_ASSERT_EQUAL((uint32_t)MBEDTLS_ERR_X509_CERT_VERIFY_FAILED, result); +} + +void run_verify_server_cert_tests() { + UNITY_BEGIN(); + RUN_TEST(test_verify_server_cert_success); + RUN_TEST(test_verify_server_cert_fail_handshake); + RUN_TEST(test_verify_server_cert_null_context); + RUN_TEST(test_verify_server_cert_mismatched_cert_key); + UNITY_END(); +} + +/* End of test functions */ + +#ifdef ARDUINO + +#include + +void setup() { + run_all_tests(); +} + +void loop() {} + +#else + +int main(int argc, char **argv) { + run_handle_error_tests(); + run_client_net_recv_tests(); + run_client_net_recv_timeout_tests(); + run_client_net_send_tests(); + run_ssl_init_tests(); + run_log_failed_cert_tests(); + run_cleanup_tests(); + run_start_ssl_client_tests(); + run_init_tcp_connection_tests(); + run_seed_random_number_generator_tests(); + run_set_up_tls_defaults_tests(); + run_auth_root_ca_buff_tests(); + run_auth_client_cert_key_tests(); + run_set_hostname_for_tls_tests(); + run_set_io_callbacks_tests(); + run_perform_ssl_handshake_tests(); + run_verify_server_cert_tests(); + run_stop_ssl_socket_tests(); + run_data_to_read_tests(); + run_send_ssl_data_tests(); + run_get_ssl_receive_tests(); + run_parse_hex_nibble_tests(); + run_match_name_tests(); + run_verify_ssl_fingerprint_tests(); // We are currently not testing the fingerprint verification + run_verify_ssl_dn_tests(); return 0; }