Skip to content

Commit

Permalink
improve locking and releasing
Browse files Browse the repository at this point in the history
  • Loading branch information
RicYaben committed Nov 6, 2024
1 parent 7425adf commit dfdd56a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
2 changes: 0 additions & 2 deletions modules/mqtt/mqtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package mqtt

import (
"testing"
"time"

"github.com/zmap/zgrab2"
)
Expand All @@ -25,7 +24,6 @@ func (t *mqttTester) getScanner() (*Scanner, error) {

// Client and user
flags.SubscribeTopics = "#,$SYS/#"
flags.SubscribeTimeout = 10 * time.Second
flags.TopicsSeparator = ","
flags.LimitMessages = 1
flags.LimitTopics = 10
Expand Down
29 changes: 19 additions & 10 deletions modules/mqtt/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
package mqtt

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"log"
"strings"
"sync"
"time"

paho "github.com/eclipse/paho.mqtt.golang"
"github.com/zmap/zgrab2"
Expand All @@ -32,9 +32,8 @@ type Flags struct {
LimitMessages int `long:"limit-messages" description:"messages per topic, one is enough to prove read access. Default: 0; Limitless: -1;"`
LimitTopics int `long:"limit-topics" description:"number of topics to include, 100 topics cover most use cases. Default: 0; Limitless: -1;"`

SubscribeTopics string `long:"subscribe-topics" default:"#,$SYS/#" description:"list of topics to subscribe to. Defaults to wildcard all and system."`
TopicsSeparator string `long:"separator" default:"," description:"subscribe topics separator"`
SubscribeTimeout time.Duration `long:"wait" default:"10s" description:"time to accept messages from the subscribed topics. Defaults to 10 seconds"`
SubscribeTopics string `long:"subscribe-topics" default:"#,$SYS/#" description:"list of topics to subscribe to. Defaults to wildcard all and system."`
TopicsSeparator string `long:"separator" default:"," description:"subscribe topics separator"`

UserAuth bool `long:"user-auth" description:"whether to authenticate using a set of credentials"`
Username string `long:"username" description:"username to authenticate"`
Expand Down Expand Up @@ -132,6 +131,7 @@ func (scan *scan) makeClient() (paho.Client, error) {

o.SetClientID(scan.scanner.config.ClientID)
o.SetCleanSession(true)
o.SetOrderMatters(false)
return paho.NewClient(o), nil
}

Expand All @@ -149,6 +149,7 @@ func (scan *scan) messageHandler(msgChan chan paho.Message) func(c paho.Client,
tLimit := scan.scanner.config.LimitTopics
tCount := make(map[string]int)

var mu sync.Mutex
isFull := func(t string) bool {
tc, ok := tCount[t]
// if the array does not exist, check the number of topics
Expand All @@ -163,17 +164,23 @@ func (scan *scan) messageHandler(msgChan chan paho.Message) func(c paho.Client,
return false
}

var mu sync.Mutex
return func(c paho.Client, m paho.Message) {
topic := m.Topic()
mu.Lock()
if isFull(topic) {
c.Unsubscribe(m.Topic())
mu.Unlock()
return
}
tCount[topic]++
mu.Unlock()
msgChan <- m

select {
case msgChan <- m:
// sent
default:
//ignore
}
}
}

Expand All @@ -182,6 +189,7 @@ func (scan *scan) Grab() *zgrab2.ScanError {
if t := scan.client.Connect(); t.Wait() && t.Error() != nil {
return zgrab2.NewScanError(zgrab2.SCAN_CONNECTION_REFUSED, t.Error())
}
defer scan.client.Disconnect(250)

subs := strings.Split(scan.scanner.config.SubscribeTopics, scan.scanner.config.TopicsSeparator)
filt := make(map[string]byte)
Expand All @@ -192,12 +200,16 @@ func (scan *scan) Grab() *zgrab2.ScanError {
// Limit the number of messages we get
msgs := make(chan paho.Message)
handler := scan.messageHandler(msgs)

if t := scan.client.SubscribeMultiple(filt, handler); t.Wait() && t.Error() != nil {
return zgrab2.NewScanError(zgrab2.SCAN_CONNECTION_REFUSED, t.Error())
}

ctx, cancel := context.WithTimeout(context.Background(), scan.scanner.config.Timeout)
defer cancel()

go func() {
<-time.After(scan.scanner.config.SubscribeTimeout)
<-ctx.Done()
scan.client.Unsubscribe(subs...)
close(msgs)
}()
Expand Down Expand Up @@ -226,9 +238,6 @@ func (scanner *Scanner) Protocol() string {
// Init initializes the Scanner.
func (scanner *Scanner) Init(flags zgrab2.ScanFlags) error {
scanner.config = flags.(*Flags)
if scanner.config.SubscribeTimeout <= 0 {
scanner.config.SubscribeTimeout = 10 * time.Second
}
return nil
}

Expand Down

0 comments on commit dfdd56a

Please sign in to comment.