Skip to content

Commit

Permalink
Add DTLS support to MQTT-SN client
Browse files Browse the repository at this point in the history
  • Loading branch information
embhorn committed Aug 25, 2023
1 parent 65a88f9 commit 2c9c951
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 58 deletions.
95 changes: 78 additions & 17 deletions examples/mqttexample.c
Original file line number Diff line number Diff line change
Expand Up @@ -198,31 +198,26 @@ void mqtt_show_usage(MQTTCtx* mqttCtx)
PRINTF("-? Help, print this usage");
PRINTF("-h <host> Host to connect to, default: %s",
mqttCtx->host);
/* Remove TLS and SNI args for sn-client */
if(XSTRNCMP(mqttCtx->app_name, "sn-client", 10)){
#ifdef ENABLE_MQTT_TLS
PRINTF("-p <num> Port to connect on, default: Normal %d, TLS %d",
MQTT_DEFAULT_PORT, MQTT_SECURE_PORT);
PRINTF("-t Enable TLS");
PRINTF("-A <file> Load CA (validate peer)");
PRINTF("-K <key> Use private key (for TLS mutual auth)");
PRINTF("-c <cert> Use certificate (for TLS mutual auth)");
#ifdef HAVE_SNI
PRINTF("-S <str> Use Host Name Indication, blank defaults to host");
#endif
#ifdef HAVE_PQC
#ifdef HAVE_SNI
/* Remove SNI args for sn-client */
if(XSTRNCMP(mqttCtx->app_name, "sn-client", 10)){
PRINTF("-S <str> Use Host Name Indication, blank defaults to host");
}
#endif
#ifdef HAVE_PQC
PRINTF("-Q <str> Use Key Share with post-quantum algorithm");
#endif
#endif
#else
PRINTF("-p <num> Port to connect on, default: %d",
MQTT_DEFAULT_PORT);
#endif
}

else{
PRINTF("-p <num> Port to connect on, default: %d",
MQTT_DEFAULT_PORT);
}
PRINTF("-q <num> Qos Level 0-2, default: %d",
mqttCtx->qos);
PRINTF("-s Disable clean session connect flag");
Expand Down Expand Up @@ -401,10 +396,9 @@ int mqtt_parse_args(MQTTCtx* mqttCtx, int argc, char** argv)
mqtt_show_usage(mqttCtx);
return MY_EX_USAGE;
}
/* Remove TLS and SNI functionality for sn-client */

/* Remove SNI functionality for sn-client */
if(!XSTRNCMP(mqttCtx->app_name, "sn-client", 10)){
mqttCtx->use_tls = 0;
#ifdef HAVE_SNI
useSNI=0;
#endif
Expand Down Expand Up @@ -646,7 +640,8 @@ int mqtt_tls_cb(MqttClient* client)
return rc;
}
}
#elif defined(WOLFMQTT_ZEPHYR)
#else
/* Note: Zephyr example uses NO_FILESYSTEM */
#ifdef WOLFSSL_ENCRYPTED_KEYS
/* Setup password callback for pkcs8 key */
wolfSSL_CTX_set_default_passwd_cb(client->tls.ctx,
Expand Down Expand Up @@ -714,12 +709,78 @@ int mqtt_tls_cb(MqttClient* client)

return rc;
}

#ifdef WOLFMQTT_SN
int mqtt_dtls_cb(MqttClient* client) {
#ifdef WOLFSSL_DTLS
int rc = WOLFSSL_FAILURE;

client->tls.ctx = wolfSSL_CTX_new(wolfDTLSv1_2_client_method());
if (client->tls.ctx) {
wolfSSL_CTX_set_verify(client->tls.ctx, WOLFSSL_VERIFY_PEER,
mqtt_tls_verify_cb);

/* default to success */
rc = WOLFSSL_SUCCESS;

#if !defined(NO_CERT) && !defined(NO_FILESYSTEM)
if (mTlsCaFile) {
/* Load CA certificate file */
rc = wolfSSL_CTX_load_verify_locations(client->tls.ctx,
mTlsCaFile, NULL);
if (rc != WOLFSSL_SUCCESS) {
PRINTF("Error loading CA %s: %d (%s)", mTlsCaFile,
rc, wolfSSL_ERR_reason_error_string(rc));
return rc;
}
}
if (mTlsCertFile && mTlsKeyFile) {
/* Load If using a mutual authentication */
rc = wolfSSL_CTX_use_certificate_file(client->tls.ctx,
mTlsCertFile, WOLFSSL_FILETYPE_PEM);
if (rc != WOLFSSL_SUCCESS) {
PRINTF("Error loading certificate %s: %d (%s)", mTlsCertFile,
rc, wolfSSL_ERR_reason_error_string(rc));
return rc;
}

rc = wolfSSL_CTX_use_PrivateKey_file(client->tls.ctx,
mTlsKeyFile, WOLFSSL_FILETYPE_PEM);
if (rc != WOLFSSL_SUCCESS) {
PRINTF("Error loading key %s: %d (%s)", mTlsKeyFile,
rc, wolfSSL_ERR_reason_error_string(rc));
return rc;
}
}
#endif

client->tls.ssl = wolfSSL_new(client->tls.ctx);
if (client->tls.ssl == NULL) {
rc = WOLFSSL_FAILURE;
return rc;
}
}

PRINTF("MQTT DTLS Setup (%d)", rc);
#else /* WOLFSSL_DTLS */
(void)client;
int rc = 0;
PRINTF("MQTT DTLS Setup - Must enable DTLS in wolfSSL!");
#endif
return rc;
}
#endif /* WOLFMQTT_SN */
#else
int mqtt_tls_cb(MqttClient* client)
{
(void)client;
return 0;
}
int mqtt_dtls_cb(MqttClient* client)
{
(void)client;
return 0;
}
#endif /* ENABLE_MQTT_TLS */

int mqtt_file_load(const char* filePath, byte** fileBuf, int *fileLen)
Expand Down
5 changes: 5 additions & 0 deletions examples/mqttexample.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ int mqtt_parse_args(MQTTCtx* mqttCtx, int argc, char** argv);
int err_sys(const char* msg);

int mqtt_tls_cb(MqttClient* client);

#ifdef WOLFMQTT_SN
int mqtt_dtls_cb(MqttClient* client);
#endif

word16 mqtt_get_packetid(void);

#ifdef WOLFMQTT_NONBLOCK
Expand Down
31 changes: 6 additions & 25 deletions examples/mqttnet.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,14 @@
#include <config.h>
#endif

#include "wolfmqtt/mqtt_client.h"
#include "examples/mqttnet.h"
#include "examples/mqttexample.h"
#include "examples/mqttport.h"

/* Local context for Net callbacks */
typedef enum {
SOCK_BEGIN = 0,
SOCK_CONN
} NB_Stat;

#if 0 /* TODO: add multicast support */
typedef struct MulticastCtx {

} MulticastCtx;
#endif


typedef struct _SocketContext {
SOCKET_T fd;
NB_Stat stat;
SOCK_ADDR_IN addr;
#ifdef MICROCHIP_MPLAB_HARMONY
word32 bytes;
#endif
#if defined(WOLFMQTT_MULTITHREAD) && defined(WOLFMQTT_ENABLE_STDIN_CAP)
/* "self pipe" -> signal wake sleep() */
SOCKET_T pfd[2];
#endif
MQTTCtx* mqttCtx;
} SocketContext;

/* Private functions */

/* -------------------------------------------------------------------------- */
Expand Down Expand Up @@ -578,7 +554,7 @@ static int SN_NetConnect(void *context, const char* host, word16 port,
struct addrinfo hints;
MQTTCtx* mqttCtx = sock->mqttCtx;

PRINTF("NetConnect: Host %s, Port %u, Timeout %d ms, Use TLS %d",
PRINTF("NetConnect: Host %s, Port %u, Timeout %d ms, Use DTLS %d",
host, port, timeout_ms, mqttCtx->use_tls);

/* Get address information for host and locate IPv4 */
Expand Down Expand Up @@ -830,6 +806,11 @@ static int NetRead_ex(void *context, byte* buf, int buf_len,
}
else {
bytes += rc; /* Data */
#if defined(WOLFMQTT_SN) && defined(WOLFSSL_DTLS)
if (wolfSSL_dtls(mqttCtx->client.tls.ssl)) {
break;
}
#endif
}
}

Expand Down
21 changes: 21 additions & 0 deletions examples/mqttnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,27 @@
#endif

#include "examples/mqttexample.h"
#include "examples/mqttport.h"

/* Local context for Net callbacks */
typedef enum {
SOCK_BEGIN = 0,
SOCK_CONN
} NB_Stat;

typedef struct _SocketContext {
SOCKET_T fd;
NB_Stat stat;
SOCK_ADDR_IN addr;
#ifdef MICROCHIP_MPLAB_HARMONY
word32 bytes;
#endif
#if defined(WOLFMQTT_MULTITHREAD) && defined(WOLFMQTT_ENABLE_STDIN_CAP)
/* "self pipe" -> signal wake sleep() */
SOCKET_T pfd[2];
#endif
MQTTCtx* mqttCtx;
} SocketContext;

/* Functions used to handle the MqttNet structure creation / destruction */
int MqttClientNet_Init(MqttNet* net, MQTTCtx* mqttCtx);
Expand Down
9 changes: 7 additions & 2 deletions examples/sn-client/sn-client.c
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,14 @@ int sn_test(MQTTCtx *mqttCtx)
goto exit;
}

/* Setup socket direct to gateway (UDP network, so no TLS) */
/* The client.ctx will be stored in the cert callback ctx during
MqttSocket_Connect for use by mqtt_tls_verify_cb */
mqttCtx->client.ctx = mqttCtx;

/* Setup socket direct to gateway */
rc = MqttClient_NetConnect(&mqttCtx->client, mqttCtx->host,
mqttCtx->port, DEFAULT_CON_TIMEOUT_MS, 0, NULL);
mqttCtx->port, DEFAULT_CON_TIMEOUT_MS,
mqttCtx->use_tls, mqtt_dtls_cb);

PRINTF("MQTT-SN Socket Connect: %s (%d)",
MqttClient_ReturnCodeToString(rc), rc);
Expand Down
54 changes: 41 additions & 13 deletions src/mqtt_packet.c
Original file line number Diff line number Diff line change
Expand Up @@ -3524,14 +3524,18 @@ int SN_Packet_Read(MqttClient *client, byte* rx_buf, int rx_buf_len,
int timeout_ms)
{
int rc, len = 0, remain_read = 0;
word16 total_len = 0;
word16 total_len = 0, idx = 0;

switch (client->packet.stat)
{
case MQTT_PK_BEGIN:
{
/* Read first 2 bytes using MSG_PEEK */
rc = MqttSocket_Peek(client, rx_buf, 2, timeout_ms);
/* Read first 2 bytes */
if (client->flags & MQTT_CLIENT_FLAG_IS_TLS) {
rc = MqttSocket_Read(client, rx_buf, 2, timeout_ms);
} else {
rc = MqttSocket_Peek(client, rx_buf, 2, timeout_ms);
}
if (rc < 0) {
return MqttPacket_HandleNetError(client, rc);
}
Expand All @@ -3544,16 +3548,29 @@ int SN_Packet_Read(MqttClient *client, byte* rx_buf, int rx_buf_len,

if (rx_buf[0] == SN_PACKET_LEN_IND){
/* Read length stored in first three bytes, type in fourth */
rc = MqttSocket_Peek(client, rx_buf, 4, timeout_ms);
if (rc < 0) {
return MqttPacket_HandleNetError(client, rc);
if (client->flags & MQTT_CLIENT_FLAG_IS_TLS) {
rc = MqttSocket_Read(client, rx_buf+len, 2, timeout_ms);
if (rc < 0) {
return MqttPacket_HandleNetError(client, rc);
}
else if (rc != 2) {
return MqttPacket_HandleNetError(client,
MQTT_TRACE_ERROR(MQTT_CODE_ERROR_NETWORK));
}
rc += len;
}
else if (rc != 4) {
return MqttPacket_HandleNetError(client,
MQTT_TRACE_ERROR(MQTT_CODE_ERROR_NETWORK));
else {
rc = MqttSocket_Peek(client, rx_buf, 4, timeout_ms);
if (rc < 0) {
return MqttPacket_HandleNetError(client, rc);
}
else if (rc != 4) {
return MqttPacket_HandleNetError(client,
MQTT_TRACE_ERROR(MQTT_CODE_ERROR_NETWORK));
}
len = rc;
}

len = rc;
(void)MqttDecode_Num(&rx_buf[1], &total_len);
client->packet.header_len = len;
}
Expand All @@ -3580,7 +3597,12 @@ int SN_Packet_Read(MqttClient *client, byte* rx_buf, int rx_buf_len,
}
else if ((total_len == 2) || (total_len == 4)) {
/* Handle peek */
client->packet.remain_len = total_len;
if (client->flags & MQTT_CLIENT_FLAG_IS_TLS) {
client->packet.remain_len = total_len - len;
}
else {
client->packet.remain_len = total_len;
}
}
else {
client->packet.remain_len = 0;
Expand All @@ -3592,16 +3614,22 @@ int SN_Packet_Read(MqttClient *client, byte* rx_buf, int rx_buf_len,
client->packet.remain_len = rx_buf_len -
client->packet.header_len;
}

if (client->flags & MQTT_CLIENT_FLAG_IS_TLS) {
total_len -= client->packet.header_len;
idx = client->packet.header_len;
}
/* Read whole message */
if (client->packet.remain_len > 0) {
rc = MqttSocket_Read(client, &rx_buf[0],
rc = MqttSocket_Read(client, &rx_buf[idx],
total_len, timeout_ms);
if (rc <= 0) {
return MqttPacket_HandleNetError(client, rc);
}
remain_read = rc;
}
if (client->flags & MQTT_CLIENT_FLAG_IS_TLS) {
remain_read += client->packet.header_len;
}

break;
}
Expand Down
2 changes: 1 addition & 1 deletion wolfmqtt/mqtt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,11 +510,11 @@ WOLFMQTT_API int MqttClient_IsMessageActive(
* \param client Pointer to MqttClient structure
* \param host Address of the broker server
* \param port Optional custom port. If zero will use defaults
* \param timeout_ms Milliseconds until read timeout
* \param use_tls If non-zero value will connect with and use TLS for
encryption of data
* \param cb A function callback for configuration of the SSL
context certificate checking
* \param timeout_ms Milliseconds until read timeout
* \return MQTT_CODE_SUCCESS or MQTT_CODE_ERROR_*
(see enum MqttPacketResponseCodes)
*/
Expand Down

0 comments on commit 2c9c951

Please sign in to comment.