Skip to content

Commit 2c9c951

Browse files
committed
Add DTLS support to MQTT-SN client
1 parent 65a88f9 commit 2c9c951

File tree

7 files changed

+159
-58
lines changed

7 files changed

+159
-58
lines changed

examples/mqttexample.c

+78-17
Original file line numberDiff line numberDiff line change
@@ -198,31 +198,26 @@ void mqtt_show_usage(MQTTCtx* mqttCtx)
198198
PRINTF("-? Help, print this usage");
199199
PRINTF("-h <host> Host to connect to, default: %s",
200200
mqttCtx->host);
201-
/* Remove TLS and SNI args for sn-client */
202-
if(XSTRNCMP(mqttCtx->app_name, "sn-client", 10)){
203201
#ifdef ENABLE_MQTT_TLS
204202
PRINTF("-p <num> Port to connect on, default: Normal %d, TLS %d",
205203
MQTT_DEFAULT_PORT, MQTT_SECURE_PORT);
206204
PRINTF("-t Enable TLS");
207205
PRINTF("-A <file> Load CA (validate peer)");
208206
PRINTF("-K <key> Use private key (for TLS mutual auth)");
209207
PRINTF("-c <cert> Use certificate (for TLS mutual auth)");
210-
#ifdef HAVE_SNI
211-
PRINTF("-S <str> Use Host Name Indication, blank defaults to host");
212-
#endif
213-
#ifdef HAVE_PQC
208+
#ifdef HAVE_SNI
209+
/* Remove SNI args for sn-client */
210+
if(XSTRNCMP(mqttCtx->app_name, "sn-client", 10)){
211+
PRINTF("-S <str> Use Host Name Indication, blank defaults to host");
212+
}
213+
#endif
214+
#ifdef HAVE_PQC
214215
PRINTF("-Q <str> Use Key Share with post-quantum algorithm");
215-
#endif
216+
#endif
216217
#else
217218
PRINTF("-p <num> Port to connect on, default: %d",
218219
MQTT_DEFAULT_PORT);
219220
#endif
220-
}
221-
222-
else{
223-
PRINTF("-p <num> Port to connect on, default: %d",
224-
MQTT_DEFAULT_PORT);
225-
}
226221
PRINTF("-q <num> Qos Level 0-2, default: %d",
227222
mqttCtx->qos);
228223
PRINTF("-s Disable clean session connect flag");
@@ -401,10 +396,9 @@ int mqtt_parse_args(MQTTCtx* mqttCtx, int argc, char** argv)
401396
mqtt_show_usage(mqttCtx);
402397
return MY_EX_USAGE;
403398
}
404-
405-
/* Remove TLS and SNI functionality for sn-client */
399+
400+
/* Remove SNI functionality for sn-client */
406401
if(!XSTRNCMP(mqttCtx->app_name, "sn-client", 10)){
407-
mqttCtx->use_tls = 0;
408402
#ifdef HAVE_SNI
409403
useSNI=0;
410404
#endif
@@ -646,7 +640,8 @@ int mqtt_tls_cb(MqttClient* client)
646640
return rc;
647641
}
648642
}
649-
#elif defined(WOLFMQTT_ZEPHYR)
643+
#else
644+
/* Note: Zephyr example uses NO_FILESYSTEM */
650645
#ifdef WOLFSSL_ENCRYPTED_KEYS
651646
/* Setup password callback for pkcs8 key */
652647
wolfSSL_CTX_set_default_passwd_cb(client->tls.ctx,
@@ -714,12 +709,78 @@ int mqtt_tls_cb(MqttClient* client)
714709

715710
return rc;
716711
}
712+
713+
#ifdef WOLFMQTT_SN
714+
int mqtt_dtls_cb(MqttClient* client) {
715+
#ifdef WOLFSSL_DTLS
716+
int rc = WOLFSSL_FAILURE;
717+
718+
client->tls.ctx = wolfSSL_CTX_new(wolfDTLSv1_2_client_method());
719+
if (client->tls.ctx) {
720+
wolfSSL_CTX_set_verify(client->tls.ctx, WOLFSSL_VERIFY_PEER,
721+
mqtt_tls_verify_cb);
722+
723+
/* default to success */
724+
rc = WOLFSSL_SUCCESS;
725+
726+
#if !defined(NO_CERT) && !defined(NO_FILESYSTEM)
727+
if (mTlsCaFile) {
728+
/* Load CA certificate file */
729+
rc = wolfSSL_CTX_load_verify_locations(client->tls.ctx,
730+
mTlsCaFile, NULL);
731+
if (rc != WOLFSSL_SUCCESS) {
732+
PRINTF("Error loading CA %s: %d (%s)", mTlsCaFile,
733+
rc, wolfSSL_ERR_reason_error_string(rc));
734+
return rc;
735+
}
736+
}
737+
if (mTlsCertFile && mTlsKeyFile) {
738+
/* Load If using a mutual authentication */
739+
rc = wolfSSL_CTX_use_certificate_file(client->tls.ctx,
740+
mTlsCertFile, WOLFSSL_FILETYPE_PEM);
741+
if (rc != WOLFSSL_SUCCESS) {
742+
PRINTF("Error loading certificate %s: %d (%s)", mTlsCertFile,
743+
rc, wolfSSL_ERR_reason_error_string(rc));
744+
return rc;
745+
}
746+
747+
rc = wolfSSL_CTX_use_PrivateKey_file(client->tls.ctx,
748+
mTlsKeyFile, WOLFSSL_FILETYPE_PEM);
749+
if (rc != WOLFSSL_SUCCESS) {
750+
PRINTF("Error loading key %s: %d (%s)", mTlsKeyFile,
751+
rc, wolfSSL_ERR_reason_error_string(rc));
752+
return rc;
753+
}
754+
}
755+
#endif
756+
757+
client->tls.ssl = wolfSSL_new(client->tls.ctx);
758+
if (client->tls.ssl == NULL) {
759+
rc = WOLFSSL_FAILURE;
760+
return rc;
761+
}
762+
}
763+
764+
PRINTF("MQTT DTLS Setup (%d)", rc);
765+
#else /* WOLFSSL_DTLS */
766+
(void)client;
767+
int rc = 0;
768+
PRINTF("MQTT DTLS Setup - Must enable DTLS in wolfSSL!");
769+
#endif
770+
return rc;
771+
}
772+
#endif /* WOLFMQTT_SN */
717773
#else
718774
int mqtt_tls_cb(MqttClient* client)
719775
{
720776
(void)client;
721777
return 0;
722778
}
779+
int mqtt_dtls_cb(MqttClient* client)
780+
{
781+
(void)client;
782+
return 0;
783+
}
723784
#endif /* ENABLE_MQTT_TLS */
724785

725786
int mqtt_file_load(const char* filePath, byte** fileBuf, int *fileLen)

examples/mqttexample.h

+5
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,11 @@ int mqtt_parse_args(MQTTCtx* mqttCtx, int argc, char** argv);
188188
int err_sys(const char* msg);
189189

190190
int mqtt_tls_cb(MqttClient* client);
191+
192+
#ifdef WOLFMQTT_SN
193+
int mqtt_dtls_cb(MqttClient* client);
194+
#endif
195+
191196
word16 mqtt_get_packetid(void);
192197

193198
#ifdef WOLFMQTT_NONBLOCK

examples/mqttnet.c

+6-25
Original file line numberDiff line numberDiff line change
@@ -24,38 +24,14 @@
2424
#include <config.h>
2525
#endif
2626

27-
#include "wolfmqtt/mqtt_client.h"
2827
#include "examples/mqttnet.h"
29-
#include "examples/mqttexample.h"
30-
#include "examples/mqttport.h"
31-
32-
/* Local context for Net callbacks */
33-
typedef enum {
34-
SOCK_BEGIN = 0,
35-
SOCK_CONN
36-
} NB_Stat;
3728

3829
#if 0 /* TODO: add multicast support */
3930
typedef struct MulticastCtx {
4031

4132
} MulticastCtx;
4233
#endif
4334

44-
45-
typedef struct _SocketContext {
46-
SOCKET_T fd;
47-
NB_Stat stat;
48-
SOCK_ADDR_IN addr;
49-
#ifdef MICROCHIP_MPLAB_HARMONY
50-
word32 bytes;
51-
#endif
52-
#if defined(WOLFMQTT_MULTITHREAD) && defined(WOLFMQTT_ENABLE_STDIN_CAP)
53-
/* "self pipe" -> signal wake sleep() */
54-
SOCKET_T pfd[2];
55-
#endif
56-
MQTTCtx* mqttCtx;
57-
} SocketContext;
58-
5935
/* Private functions */
6036

6137
/* -------------------------------------------------------------------------- */
@@ -578,7 +554,7 @@ static int SN_NetConnect(void *context, const char* host, word16 port,
578554
struct addrinfo hints;
579555
MQTTCtx* mqttCtx = sock->mqttCtx;
580556

581-
PRINTF("NetConnect: Host %s, Port %u, Timeout %d ms, Use TLS %d",
557+
PRINTF("NetConnect: Host %s, Port %u, Timeout %d ms, Use DTLS %d",
582558
host, port, timeout_ms, mqttCtx->use_tls);
583559

584560
/* Get address information for host and locate IPv4 */
@@ -830,6 +806,11 @@ static int NetRead_ex(void *context, byte* buf, int buf_len,
830806
}
831807
else {
832808
bytes += rc; /* Data */
809+
#if defined(WOLFMQTT_SN) && defined(WOLFSSL_DTLS)
810+
if (wolfSSL_dtls(mqttCtx->client.tls.ssl)) {
811+
break;
812+
}
813+
#endif
833814
}
834815
}
835816

examples/mqttnet.h

+21
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,27 @@
2727
#endif
2828

2929
#include "examples/mqttexample.h"
30+
#include "examples/mqttport.h"
31+
32+
/* Local context for Net callbacks */
33+
typedef enum {
34+
SOCK_BEGIN = 0,
35+
SOCK_CONN
36+
} NB_Stat;
37+
38+
typedef struct _SocketContext {
39+
SOCKET_T fd;
40+
NB_Stat stat;
41+
SOCK_ADDR_IN addr;
42+
#ifdef MICROCHIP_MPLAB_HARMONY
43+
word32 bytes;
44+
#endif
45+
#if defined(WOLFMQTT_MULTITHREAD) && defined(WOLFMQTT_ENABLE_STDIN_CAP)
46+
/* "self pipe" -> signal wake sleep() */
47+
SOCKET_T pfd[2];
48+
#endif
49+
MQTTCtx* mqttCtx;
50+
} SocketContext;
3051

3152
/* Functions used to handle the MqttNet structure creation / destruction */
3253
int MqttClientNet_Init(MqttNet* net, MQTTCtx* mqttCtx);

examples/sn-client/sn-client.c

+7-2
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,14 @@ int sn_test(MQTTCtx *mqttCtx)
139139
goto exit;
140140
}
141141

142-
/* Setup socket direct to gateway (UDP network, so no TLS) */
142+
/* The client.ctx will be stored in the cert callback ctx during
143+
MqttSocket_Connect for use by mqtt_tls_verify_cb */
144+
mqttCtx->client.ctx = mqttCtx;
145+
146+
/* Setup socket direct to gateway */
143147
rc = MqttClient_NetConnect(&mqttCtx->client, mqttCtx->host,
144-
mqttCtx->port, DEFAULT_CON_TIMEOUT_MS, 0, NULL);
148+
mqttCtx->port, DEFAULT_CON_TIMEOUT_MS,
149+
mqttCtx->use_tls, mqtt_dtls_cb);
145150

146151
PRINTF("MQTT-SN Socket Connect: %s (%d)",
147152
MqttClient_ReturnCodeToString(rc), rc);

src/mqtt_packet.c

+41-13
Original file line numberDiff line numberDiff line change
@@ -3524,14 +3524,18 @@ int SN_Packet_Read(MqttClient *client, byte* rx_buf, int rx_buf_len,
35243524
int timeout_ms)
35253525
{
35263526
int rc, len = 0, remain_read = 0;
3527-
word16 total_len = 0;
3527+
word16 total_len = 0, idx = 0;
35283528

35293529
switch (client->packet.stat)
35303530
{
35313531
case MQTT_PK_BEGIN:
35323532
{
3533-
/* Read first 2 bytes using MSG_PEEK */
3534-
rc = MqttSocket_Peek(client, rx_buf, 2, timeout_ms);
3533+
/* Read first 2 bytes */
3534+
if (client->flags & MQTT_CLIENT_FLAG_IS_TLS) {
3535+
rc = MqttSocket_Read(client, rx_buf, 2, timeout_ms);
3536+
} else {
3537+
rc = MqttSocket_Peek(client, rx_buf, 2, timeout_ms);
3538+
}
35353539
if (rc < 0) {
35363540
return MqttPacket_HandleNetError(client, rc);
35373541
}
@@ -3544,16 +3548,29 @@ int SN_Packet_Read(MqttClient *client, byte* rx_buf, int rx_buf_len,
35443548

35453549
if (rx_buf[0] == SN_PACKET_LEN_IND){
35463550
/* Read length stored in first three bytes, type in fourth */
3547-
rc = MqttSocket_Peek(client, rx_buf, 4, timeout_ms);
3548-
if (rc < 0) {
3549-
return MqttPacket_HandleNetError(client, rc);
3551+
if (client->flags & MQTT_CLIENT_FLAG_IS_TLS) {
3552+
rc = MqttSocket_Read(client, rx_buf+len, 2, timeout_ms);
3553+
if (rc < 0) {
3554+
return MqttPacket_HandleNetError(client, rc);
3555+
}
3556+
else if (rc != 2) {
3557+
return MqttPacket_HandleNetError(client,
3558+
MQTT_TRACE_ERROR(MQTT_CODE_ERROR_NETWORK));
3559+
}
3560+
rc += len;
35503561
}
3551-
else if (rc != 4) {
3552-
return MqttPacket_HandleNetError(client,
3553-
MQTT_TRACE_ERROR(MQTT_CODE_ERROR_NETWORK));
3562+
else {
3563+
rc = MqttSocket_Peek(client, rx_buf, 4, timeout_ms);
3564+
if (rc < 0) {
3565+
return MqttPacket_HandleNetError(client, rc);
3566+
}
3567+
else if (rc != 4) {
3568+
return MqttPacket_HandleNetError(client,
3569+
MQTT_TRACE_ERROR(MQTT_CODE_ERROR_NETWORK));
3570+
}
3571+
len = rc;
35543572
}
35553573

3556-
len = rc;
35573574
(void)MqttDecode_Num(&rx_buf[1], &total_len);
35583575
client->packet.header_len = len;
35593576
}
@@ -3580,7 +3597,12 @@ int SN_Packet_Read(MqttClient *client, byte* rx_buf, int rx_buf_len,
35803597
}
35813598
else if ((total_len == 2) || (total_len == 4)) {
35823599
/* Handle peek */
3583-
client->packet.remain_len = total_len;
3600+
if (client->flags & MQTT_CLIENT_FLAG_IS_TLS) {
3601+
client->packet.remain_len = total_len - len;
3602+
}
3603+
else {
3604+
client->packet.remain_len = total_len;
3605+
}
35843606
}
35853607
else {
35863608
client->packet.remain_len = 0;
@@ -3592,16 +3614,22 @@ int SN_Packet_Read(MqttClient *client, byte* rx_buf, int rx_buf_len,
35923614
client->packet.remain_len = rx_buf_len -
35933615
client->packet.header_len;
35943616
}
3595-
3617+
if (client->flags & MQTT_CLIENT_FLAG_IS_TLS) {
3618+
total_len -= client->packet.header_len;
3619+
idx = client->packet.header_len;
3620+
}
35963621
/* Read whole message */
35973622
if (client->packet.remain_len > 0) {
3598-
rc = MqttSocket_Read(client, &rx_buf[0],
3623+
rc = MqttSocket_Read(client, &rx_buf[idx],
35993624
total_len, timeout_ms);
36003625
if (rc <= 0) {
36013626
return MqttPacket_HandleNetError(client, rc);
36023627
}
36033628
remain_read = rc;
36043629
}
3630+
if (client->flags & MQTT_CLIENT_FLAG_IS_TLS) {
3631+
remain_read += client->packet.header_len;
3632+
}
36053633

36063634
break;
36073635
}

wolfmqtt/mqtt_client.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -510,11 +510,11 @@ WOLFMQTT_API int MqttClient_IsMessageActive(
510510
* \param client Pointer to MqttClient structure
511511
* \param host Address of the broker server
512512
* \param port Optional custom port. If zero will use defaults
513+
* \param timeout_ms Milliseconds until read timeout
513514
* \param use_tls If non-zero value will connect with and use TLS for
514515
encryption of data
515516
* \param cb A function callback for configuration of the SSL
516517
context certificate checking
517-
* \param timeout_ms Milliseconds until read timeout
518518
* \return MQTT_CODE_SUCCESS or MQTT_CODE_ERROR_*
519519
(see enum MqttPacketResponseCodes)
520520
*/

0 commit comments

Comments
 (0)