Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions cluster/agent_test.go
Original file line number Diff line number Diff line change
@@ -1,41 +1,32 @@
package cluster

import (
"net"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/wind-c/comqtt/v2/cluster/log"
"github.com/wind-c/comqtt/v2/cluster/utils"
"github.com/wind-c/comqtt/v2/config"
)

func getFreePort() (int, error) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
return 0, err
}
defer listener.Close()
return listener.Addr().(*net.TCPAddr).Port, nil
}

func TestCluster(t *testing.T) {
log.Init(log.DefaultOptions())

bindPort1, err := getFreePort()
bindPort1, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node1")
raftPort1, err := getFreePort()
raftPort1, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node1 Raft")

bindPort2, err := getFreePort()
bindPort2, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node2")
raftPort2, err := getFreePort()
raftPort2, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node2 Raft")

bindPort3, err := getFreePort()
bindPort3, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node3")
raftPort3, err := getFreePort()
raftPort3, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node3 Raft")

members := []string{
Expand Down
4 changes: 2 additions & 2 deletions cluster/discovery/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ package discovery

import (
"encoding/json"
"github.com/wind-c/comqtt/v2/mqtt"
"net"
"os"
"strconv"

"github.com/wind-c/comqtt/v2/mqtt"
)

const (
Expand All @@ -31,7 +32,6 @@ type Node interface {
BindMqttServer(server *mqtt.Server)
LocalAddr() string
LocalName() string
NumMembers() int
Members() []Member
EventChan() <-chan *Event
SendToNode(nodeName string, msg []byte) error
Expand Down
9 changes: 2 additions & 7 deletions cluster/discovery/serf/membership.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"strconv"

"github.com/hashicorp/logutils"
"github.com/hashicorp/memberlist"
"github.com/hashicorp/serf/serf"
mb "github.com/wind-c/comqtt/v2/cluster/discovery"
"github.com/wind-c/comqtt/v2/cluster/log"
Expand Down Expand Up @@ -99,8 +98,8 @@ func (m *Membership) EventChan() <-chan *mb.Event {
return m.eventCh
}

func (m *Membership) NumMembers() int {
return m.serf.NumNodes()
func (m *Membership) numMembers() int {
return len(m.aliveMembers())
}

func (m *Membership) LocalName() string {
Expand Down Expand Up @@ -195,10 +194,6 @@ func (m *Membership) eventLoop() {
}
}

func (m *Membership) send(to memberlist.Address, msg []byte) error {
return m.serf.Memberlist().SendToAddress(to, msg)
}

// SendToOthers send message to all nodes except yourself
func (m *Membership) SendToOthers(msg []byte) {
m.Broadcast(msg)
Expand Down
170 changes: 170 additions & 0 deletions cluster/discovery/serf/membership_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package serf

import (
"os"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/wind-c/comqtt/v2/cluster/log"
"github.com/wind-c/comqtt/v2/cluster/utils"
"github.com/wind-c/comqtt/v2/config"
)

func TestMain(m *testing.M) {
log.Init(log.DefaultOptions())
code := m.Run()
os.Exit(code)
}

func TestJoinAndLeave(t *testing.T) {
bindPort1, err := utils.GetFreePort()
assert.NoError(t, err)
conf1 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort1,
NodeName: "test-node-1",
}
inboundMsgCh1 := make(chan []byte)
membership1 := New(conf1, inboundMsgCh1)
err = membership1.Setup()
assert.NoError(t, err)
defer membership1.Stop()

assert.Equal(t, 1, membership1.numMembers())

bindPort2, err := utils.GetFreePort()
assert.NoError(t, err)
conf2 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort2,
NodeName: "test-node-2",
}
inboundMsgCh2 := make(chan []byte)
membership2 := New(conf2, inboundMsgCh2)
err = membership2.Setup()
assert.NoError(t, err)
defer membership2.Stop()

numJoined, err := membership2.Join([]string{"127.0.0.1:" + strconv.Itoa(bindPort1)})
assert.NoError(t, err)
time.Sleep(3 * time.Second)
assert.Equal(t, numJoined, 1)
assert.Equal(t, 2, membership1.numMembers())
assert.Equal(t, 2, membership2.numMembers())

t.Log("Leave node 2")
err = membership2.Leave()
assert.NoError(t, err)

time.Sleep(5 * time.Second)

assert.Equal(t, 1, membership1.numMembers())
}

func TestSendToNode(t *testing.T) {
bindPort1, err := utils.GetFreePort()
assert.NoError(t, err)
bindPort2, err := utils.GetFreePort()
assert.NoError(t, err)

conf1 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort1,
NodeName: "test-node-1",
}
conf2 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort2,
NodeName: "test-node-2",
Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)},
}
inboundMsgCh1 := make(chan []byte)
inboundMsgCh2 := make(chan []byte)

membership1 := New(conf1, inboundMsgCh1)
err = membership1.Setup()
assert.NoError(t, err)
defer membership1.Stop()

membership2 := New(conf2, inboundMsgCh2)
err = membership2.Setup()
assert.NoError(t, err)
defer membership2.Stop()

time.Sleep(3 * time.Second)

err = membership1.SendToNode("test-node-2", []byte("test message"))
assert.NoError(t, err)

select {
case msg := <-inboundMsgCh2:
assert.Equal(t, []byte("test message"), msg)
case <-time.After(5 * time.Second):
t.Fatal("Did not receive the message in membership2")
}
}

func TestSendToOthers(t *testing.T) {
bindPort1, err := utils.GetFreePort()
assert.NoError(t, err)
bindPort2, err := utils.GetFreePort()
assert.NoError(t, err)
bindPort3, err := utils.GetFreePort()
assert.NoError(t, err)

conf1 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort1,
NodeName: "test-node-1",
}
conf2 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort2,
NodeName: "test-node-2",
Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)},
}
conf3 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort3,
NodeName: "test-node-3",
Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)},
}
inboundMsgCh1 := make(chan []byte)
inboundMsgCh2 := make(chan []byte)
inboundMsgCh3 := make(chan []byte)

membership1 := New(conf1, inboundMsgCh1)
err = membership1.Setup()
assert.NoError(t, err)
defer membership1.Stop()

membership2 := New(conf2, inboundMsgCh2)
err = membership2.Setup()
assert.NoError(t, err)
defer membership2.Stop()

membership3 := New(conf3, inboundMsgCh3)
err = membership3.Setup()
assert.NoError(t, err)
defer membership3.Stop()

time.Sleep(3 * time.Second)

membership1.SendToOthers([]byte("test message"))

select {
case msg := <-inboundMsgCh2:
assert.Equal(t, []byte("test message"), msg)
case <-time.After(5 * time.Second):
t.Fatal("Did not receive the message in membership2")
}

select {
case msg := <-inboundMsgCh3:
assert.Equal(t, []byte("test message"), msg)
case <-time.After(5 * time.Second):
t.Fatal("Did not receive the message in membership3")
}
}
14 changes: 12 additions & 2 deletions cluster/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ package utils

import (
"fmt"
"github.com/hashicorp/go-sockaddr"
"net"
"os"
"reflect"
"strings"
"testing"

"github.com/satori/go.uuid"
"github.com/hashicorp/go-sockaddr"

uuid "github.com/satori/go.uuid"
)

func InArray(val interface{}, array interface{}) bool {
Expand Down Expand Up @@ -151,3 +152,12 @@ func PathExists(path string) bool {
}
return true
}

func GetFreePort() (int, error) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
return 0, err
}
defer listener.Close()
return listener.Addr().(*net.TCPAddr).Port, nil
}