diff --git a/modules/mqtt/mqtt_test.go b/modules/mqtt/mqtt_test.go index 3666db2c..8a17c1b8 100644 --- a/modules/mqtt/mqtt_test.go +++ b/modules/mqtt/mqtt_test.go @@ -20,13 +20,13 @@ func (t *mqttTester) getScanner() (*Scanner, error) { // Identifiers flags.ClientID = "testClient" // MQTT-specific - flags.ClientRandom = "blabla" // on the TCP handshake // Client and user flags.SubscribeTopics = "#,$SYS/#" flags.TopicsSeparator = "," flags.LimitMessages = 1 flags.LimitTopics = 10 + flags.UseTLS = true // Attempt anonymous auth with // an empty user and password as the @@ -69,7 +69,7 @@ func (t *mqttTester) runTest(test *testing.T, name string) { var tests = map[string]*mqttTester{ "success": { addr: "test.mosquitto.org", - port: 1883, + port: 8883, expectedStatus: zgrab2.SCAN_SUCCESS, }, } diff --git a/modules/mqtt/scanner.go b/modules/mqtt/scanner.go index 60a16308..137bdd66 100644 --- a/modules/mqtt/scanner.go +++ b/modules/mqtt/scanner.go @@ -100,7 +100,7 @@ func (scan *scan) getTLSConfig() (*tls.Config, error) { func (scan *scan) makeClient() (paho.Client, error) { // TODO: implement support for web-sockets as well? - o := paho.NewClientOptions() + opts := paho.NewClientOptions() // Add TLS scheme := "tcp" @@ -110,7 +110,7 @@ func (scan *scan) makeClient() (paho.Client, error) { if err != nil { return nil, err } - o.SetTLSConfig(cfg) + opts.SetTLSConfig(cfg) } // Add broker @@ -119,18 +119,21 @@ func (scan *scan) makeClient() (paho.Client, error) { port = scan.target.Port } t := fmt.Sprintf("%s://%s:%d", scheme, scan.target.Host(), *port) - o.AddBroker(t) + opts.AddBroker(t) // Add auth details if scan.scanner.config.UserAuth { - o.SetUsername(scan.scanner.config.Username) - o.SetPassword(scan.scanner.config.Password) + opts.SetUsername(scan.scanner.config.Username) + opts.SetPassword(scan.scanner.config.Password) } - o.SetClientID(scan.scanner.config.ClientID) - o.SetCleanSession(true) - o.SetOrderMatters(false) - return paho.NewClient(o), nil + // TODO: change the dialer to a zgrab2 one. + // opts.SetDialer() + opts.SetClientID(scan.scanner.config.ClientID) + opts.SetCleanSession(true) + opts.SetOrderMatters(false) + opts.SetAutoReconnect(true) + return paho.NewClient(opts), nil } func (scan *scan) Init() (*scan, error) { @@ -148,7 +151,11 @@ func (scan *scan) messageHandler(msgChan chan paho.Message) func(c paho.Client, tCount := make(map[string]int) var mu sync.Mutex + isFull := func(t string) bool { + mu.Lock() + defer mu.Unlock() + tc, ok := tCount[t] // if the array does not exist, check the number of topics if !ok && (tLimit > -1 && len(tCount) >= tLimit) { @@ -162,17 +169,19 @@ func (scan *scan) messageHandler(msgChan chan paho.Message) func(c paho.Client, return false } - return func(c paho.Client, m paho.Message) { - topic := m.Topic() + var addToCount = func(topic string) { mu.Lock() + defer mu.Unlock() + tCount[topic]++ + } + + var addMessage = func(c paho.Client, m paho.Message) { + topic := m.Topic() if isFull(topic) { - c.Unsubscribe(m.Topic()) - mu.Unlock() + c.Unsubscribe(topic) return } - tCount[topic]++ - mu.Unlock() - + addToCount(topic) select { case msgChan <- m: // sent @@ -180,6 +189,12 @@ func (scan *scan) messageHandler(msgChan chan paho.Message) func(c paho.Client, //ignore } } + + // We cannot block here, so call a goroutine to handle + // the message instead. + return func(c paho.Client, m paho.Message) { + go addMessage(c, m) + } } // Grab starts the scan @@ -199,6 +214,10 @@ func (scan *scan) Grab() *zgrab2.ScanError { msgs := make(chan paho.Message) handler := scan.messageHandler(msgs) + var wg *sync.WaitGroup + wg.Add(1) + go scan.handleMessages(msgs, wg) + if t := scan.client.SubscribeMultiple(filt, handler); t.Wait() && t.Error() != nil { return zgrab2.NewScanError(zgrab2.SCAN_CONNECTION_REFUSED, t.Error()) } @@ -206,11 +225,15 @@ func (scan *scan) Grab() *zgrab2.ScanError { ctx, cancel := context.WithTimeout(context.Background(), scan.scanner.config.Timeout) defer cancel() - go func() { - <-ctx.Done() - scan.client.Unsubscribe(subs...) - close(msgs) - }() + <-ctx.Done() + scan.client.Unsubscribe(subs...) + close(msgs) + wg.Wait() + return nil +} + +func (scan *scan) handleMessages(msgs chan paho.Message, wg *sync.WaitGroup) { + defer wg.Done() topics := make(map[string][]string) for m := range msgs { @@ -220,7 +243,6 @@ func (scan *scan) Grab() *zgrab2.ScanError { topics[m.Topic()] = msg } scan.results.Topics = topics - return nil } // Scanner implements the zgrab2.Scanner interface. diff --git a/modules/mqtt/test/docker-compose.yml b/modules/mqtt/test/docker-compose.yml deleted file mode 100644 index a06d5d3d..00000000 --- a/modules/mqtt/test/docker-compose.yml +++ /dev/null @@ -1,48 +0,0 @@ -services: - mosquitto: - image: eclipse-mosquitto - container_name: mosquitto - ports: - - "9001:9001" - - "8883:8883" - - "1883:1883" # mqtt://mosquitto:1883 - volumes: - - ./mosquitto.conf:/mosquitto/config/mosquitto.conf - networks: - - brokers - - rabbitmq: - image: rabbitmq:management - container_name: rabbitmq - ports: - - "5672:5672" # AMQP port - - "15672:15672" # Management UI port - environment: - - RABBITMQ_DEFAULT_USER=admin - - RABBITMQ_DEFAULT_PASS=admin - networks: - - brokers - - emqx: - image: emqx/emqx - container_name: emqx - ports: - - "1884:1883" # mqtt://emqx:1884 - - "18083:18083" # Dashboard: admin / public - environment: - EMQX_NAME: emqx - EMQX_LISTENER__TCP__EXTERNAL: 1883 - EMQX_LOADED_PLUGINS: emqx_auth_username emqx_recon emqx_retainer emqx_dashboard - networks: - - brokers - - mqtt-explorer: - image: smeagolworms4/mqtt-explorer - ports: - - "4000:4000" - networks: - - brokers - -networks: - brokers: - diff --git a/modules/mqtt/test/mosquitto.conf b/modules/mqtt/test/mosquitto.conf deleted file mode 100644 index c7b17104..00000000 --- a/modules/mqtt/test/mosquitto.conf +++ /dev/null @@ -1,2 +0,0 @@ -allow_anonymous true -listener 1883 \ No newline at end of file