diff --git a/server.go b/server.go index d097e85c..0a21815c 100644 --- a/server.go +++ b/server.go @@ -45,6 +45,7 @@ var ( // Capabilities indicates the capabilities and features provided by the server. type Capabilities struct { + MaximumClients int64 `yaml:"maximum_clients" json:"maximum_clients"` // maximum number of connected clients MaximumMessageExpiryInterval int64 `yaml:"maximum_message_expiry_interval" json:"maximum_message_expiry_interval"` // maximum message expiry if message expiry is 0 or over MaximumClientWritesPending int32 `yaml:"maximum_client_writes_pending" json:"maximum_client_writes_pending"` // maximum number of pending message writes for a client MaximumSessionExpiryInterval uint32 `yaml:"maximum_session_expiry_interval" json:"maximum_session_expiry_interval"` // maximum number of seconds to keep disconnected sessions @@ -65,6 +66,7 @@ type Capabilities struct { // NewDefaultServerCapabilities defines the default features and capabilities provided by the server. func NewDefaultServerCapabilities() *Capabilities { return &Capabilities{ + MaximumClients: math.MaxInt64, // maximum number of connected clients MaximumMessageExpiryInterval: 60 * 60 * 24, // maximum message expiry if message expiry is 0 or over MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client MaximumSessionExpiryInterval: math.MaxUint32, // maximum number of seconds to keep disconnected sessions @@ -414,6 +416,16 @@ func (s *Server) attachClient(cl *Client, listener string) error { } cl.ParseConnect(listener, pk) + if atomic.LoadInt64(&s.Info.ClientsConnected) >= s.Options.Capabilities.MaximumClients { + if cl.Properties.ProtocolVersion < 5 { + s.SendConnack(cl, packets.ErrServerUnavailable, false, nil) + } else { + s.SendConnack(cl, packets.ErrServerBusy, false, nil) + } + + return packets.ErrServerBusy + } + code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2] if code != packets.CodeSuccess { if err := s.SendConnack(cl, code, false, nil); err != nil { diff --git a/server_test.go b/server_test.go index 711ae839..2a255cf4 100644 --- a/server_test.go +++ b/server_test.go @@ -944,6 +944,41 @@ func TestServerEstablishConnectionInvalidConnect(t *testing.T) { _ = r.Close() } +func TestEstablishConnectionMaximumClientsReached(t *testing.T) { + cc := NewDefaultServerCapabilities() + cc.MaximumClients = 0 + s := New(&Options{ + Logger: logger, + Capabilities: cc, + }) + _ = s.AddHook(new(AllowHook), nil) + defer s.Close() + + r, w := net.Pipe() + o := make(chan error) + go func() { + o <- s.EstablishConnection("tcp", r) + }() + + go func() { + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + }() + + // receive the connack + recv := make(chan []byte) + go func() { + buf, err := io.ReadAll(w) + require.NoError(t, err) + recv <- buf + }() + + err := <-o + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrServerBusy) + + _ = r.Close() +} + // See https://github.com/mochi-mqtt/server/issues/178 func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { s := newServer()