diff --git a/ably/realtime_conn.go b/ably/realtime_conn.go index 99cf7bc2..beb2a358 100644 --- a/ably/realtime_conn.go +++ b/ably/realtime_conn.go @@ -201,9 +201,6 @@ func (c *Connection) Close() { // If not in connecting or connected state, this method causes the connection to open, entering the // [ably.ConnectionStateConnecting] state (RTC1b, RTN3, RTN11). func (c *Connection) connect(arg connArgs) (result, error) { - c.mtx.Lock() - arg.mode = c.getMode() - c.mtx.Unlock() return c.connectWithRetryLoop(arg) } @@ -212,27 +209,21 @@ type connArgs struct { connDetails *connectionDetails result bool dialOnce bool - mode connectionMode retryIn time.Duration } func (c *Connection) reconnect(arg connArgs) (result, error) { c.mtx.Lock() - var mode connectionMode if arg.connDetails != nil && c.opts.Now().Sub(arg.lastActivityAt) >= time.Duration(arg.connDetails.ConnectionStateTTL+arg.connDetails.MaxIdleInterval) { // RTN15g c.msgSerial = 0 c.key = "" // c.id isn't cleared since it's used later to determine if the // reconnection resulted in a new transport-level connection. - mode = normalMode - } else { - mode = c.getMode() } c.mtx.Unlock() - arg.mode = mode r, err := c.connectWithRetryLoop(arg) if err != nil { return nil, err @@ -248,6 +239,8 @@ func (c *Connection) reconnect(arg connArgs) (result, error) { } func (c *Connection) getMode() connectionMode { + c.mtx.Lock() + defer c.mtx.Unlock() if c.key != "" { return resumeMode } @@ -392,7 +385,8 @@ func (c *Connection) connectWith(arg connArgs) (result, error) { ConnectionStateDisconnected, ) } - query, err := c.params(arg.mode) + connectMode := c.getMode() + query, err := c.params(connectMode) if err != nil { return nil, err } @@ -416,7 +410,7 @@ func (c *Connection) connectWith(arg connArgs) (result, error) { // Start eventloop go c.eventloop() - c.reconnecting = arg.mode == recoveryMode || arg.mode == resumeMode + c.reconnecting = connectMode == recoveryMode || connectMode == resumeMode c.arg = arg return res, nil } diff --git a/ably/realtime_conn_integration_test.go b/ably/realtime_conn_integration_test.go index 6fd54273..11317145 100644 --- a/ably/realtime_conn_integration_test.go +++ b/ably/realtime_conn_integration_test.go @@ -259,3 +259,64 @@ func TestRealtimeConn_SendErrorReconnects(t *testing.T) { ablytest.Soon.Recv(t, &err, publishErr, t.Fatalf) assert.NoError(t, err) } + +func TestRealtimeConn_ReconnectFromSuspendedState(t *testing.T) { + dialErr := make(chan error, 1) + msgReceiveErr := make(chan error, 1) + + dial := DialFunc(func(p string, url *url.URL, timeout time.Duration) (ably.Conn, error) { + err := <-dialErr + if err != nil { + return nil, err + } + ws, err := ably.DialWebsocket(p, url, timeout) + if err != nil { + return nil, err + } + return connMock{ + SendFunc: ws.Send, + ReceiveFunc: func(deadline time.Time) (*ably.ProtocolMessage, error) { + err := <-msgReceiveErr + if err != nil { + return nil, err + } + msg, err := ws.Receive(deadline) + if msg.Action == ably.ActionConnected { + msg.ConnectionDetails.ConnectionStateTTL = ably.DurationFromMsecs(500 * time.Millisecond) + } + return msg, err + }, + CloseFunc: ws.Close, + }, nil + }) + + // No errors for first connect + dialErr <- nil + msgReceiveErr <- nil + + app, c := ablytest.NewRealtime(ably.WithDial(dial), + ably.WithDisconnectedRetryTimeout(time.Second), + ably.WithSuspendedRetryTimeout(time.Second)) + defer func() { + msgReceiveErr <- nil // receive safe close event + safeclose(t, ablytest.FullRealtimeCloser(c), app) + }() + + err := ablytest.Wait(ablytest.ConnWaiter(c, c.Connect, ably.ConnectionEventConnected), nil) + assert.NoError(t, err) + + // Initiate disconnect and fail subsequent reconnects + msgReceiveErr <- errors.New("initiate disconnect") + dialErr <- errors.New("initiate failure for subsequent reconnects") + + ablytest.Wait(ablytest.ConnWaiter(c, c.Connect, ably.ConnectionEventDisconnected), nil) + ablytest.Wait(ablytest.ConnWaiter(c, c.Connect, ably.ConnectionEventSuspended), nil) + ablytest.Wait(ablytest.ConnWaiter(c, c.Connect, ably.ConnectionEventSuspended), nil) + + // Enable successful connection again + dialErr <- nil + msgReceiveErr <- nil + + err = ablytest.Wait(ablytest.ConnWaiter(c, c.Connect, ably.ConnectionEventConnected), nil) + assert.NoError(t, err) +}