diff --git a/ably/ably_test.go b/ably/ably_test.go index 608598b6..89fd0627 100644 --- a/ably/ably_test.go +++ b/ably/ably_test.go @@ -261,6 +261,21 @@ func (rec *MessageRecorder) CheckIfSent(action ably.ProtoAction, times int) func } } +func (rec *MessageRecorder) CheckIfReceived(action ably.ProtoAction, times int) func() bool { + return func() bool { + counter := 0 + for _, m := range rec.Received() { + if m.Action == action { + counter++ + if counter == times { + return true + } + } + } + return false + } +} + func (rec *MessageRecorder) FindFirst(action ably.ProtoAction) *ably.ProtocolMessage { for _, m := range rec.Sent() { if m.Action == action { diff --git a/ably/export_test.go b/ably/export_test.go index dcc12f52..b97d2094 100644 --- a/ably/export_test.go +++ b/ably/export_test.go @@ -262,6 +262,10 @@ func (c *Connection) ConnectionStateTTL() time.Duration { return c.connectionStateTTL() } +func (r *Realtime) Logger() logger { + return r.log() +} + func NewInternalLogger(l Logger) logger { return logger{l: l} } diff --git a/ably/realtime_conn_integration_test.go b/ably/realtime_conn_integration_test.go index 4916e2d9..11317145 100644 --- a/ably/realtime_conn_integration_test.go +++ b/ably/realtime_conn_integration_test.go @@ -320,47 +320,3 @@ func TestRealtimeConn_ReconnectFromSuspendedState(t *testing.T) { err = ablytest.Wait(ablytest.ConnWaiter(c, c.Connect, ably.ConnectionEventConnected), nil) assert.NoError(t, err) } - -func TestRealtimeConn_RTN22_ServerInitiatedAuth(t *testing.T) { - t.Parallel() - app, restClient := ablytest.NewREST() - defer safeclose(t, app) - - recorder := NewMessageRecorder() - - dial := func(proto string, url *url.URL, timeout time.Duration) (ably.Conn, error) { - c, err := recorder.Dial(proto, url, timeout) - if err != nil { - return nil, err - } - return c, nil - } - - authCallback := func(ctx context.Context, tp ably.TokenParams) (ably.Tokener, error) { - token, err := restClient.Auth.RequestToken(context.Background(), &ably.TokenParams{TTL: 3000}) // 3 second time - if err != nil { - return nil, err - } - return token, nil - } - - realtime, err := ably.NewRealtime( - ably.WithAutoConnect(false), - ably.WithDial(dial), - ably.WithEnvironment(ablytest.Environment), - ably.WithAuthCallback(authCallback)) - - assert.NoError(t, err) - defer realtime.Close() - - err = ablytest.Wait(ablytest.ConnWaiter(realtime, realtime.Connect, ably.ConnectionEventConnected), nil) - assert.NoError(t, err) - - for i := 0; i < 3; i++ { - err = ablytest.Wait(ablytest.ConnWaiter(realtime, nil, ably.ConnectionEventUpdate), nil) - assert.NoError(t, err) - assert.Equal(t, ably.ConnectionStateConnected, realtime.Connection.State()) - } - - // Need to add few more assertions -} diff --git a/ably/realtime_conn_spec_integration_test.go b/ably/realtime_conn_spec_integration_test.go index a6780382..9ab38666 100644 --- a/ably/realtime_conn_spec_integration_test.go +++ b/ably/realtime_conn_spec_integration_test.go @@ -1778,6 +1778,56 @@ func TestRealtimeConn_RTN15h3_Success(t *testing.T) { ablytest.Instantly.NoRecv(t, nil, stateChanges, t.Fatalf) } +func TestRealtimeConn_RTN15h_Integration_ClientInitiatedAuth(t *testing.T) { + t.Parallel() + app, restClient := ablytest.NewREST() + defer safeclose(t, app) + recorder := NewMessageRecorder() + + authCallbackTokens := []string{} + // Returns token that expires after 3 seconds causing disconnect every 3 seconds + authCallback := func(ctx context.Context, tp ably.TokenParams) (ably.Tokener, error) { + token, err := restClient.Auth.RequestToken(context.Background(), &ably.TokenParams{TTL: 3000}) + authCallbackTokens = append(authCallbackTokens, token.Token) + return token, err + } + + realtime, err := ably.NewRealtime( + ably.WithAutoConnect(false), + ably.WithDial(recorder.Dial), + ably.WithEnvironment(ablytest.Environment), + ably.WithAuthCallback(authCallback)) + + assert.NoError(t, err) + defer realtime.Close() + + err = ablytest.Wait(ablytest.ConnWaiter(realtime, realtime.Connect, ably.ConnectionEventConnected), nil) + assert.NoError(t, err) + + for i := 0; i < 3; i++ { + err = ablytest.Wait(ablytest.ConnWaiter(realtime, nil, ably.ConnectionEventConnecting), nil) + var errorInfo *ably.ErrorInfo + assert.Error(t, err) + assert.ErrorAs(t, err, &errorInfo) + assert.Equal(t, 401, errorInfo.StatusCode) + assert.Equal(t, 40142, int(errorInfo.Code)) + assert.ErrorContains(t, err, "token expired") + err = ablytest.Wait(ablytest.ConnWaiter(realtime, nil, ably.ConnectionEventConnected), nil) + assert.NoError(t, err) + assert.Equal(t, ably.ConnectionStateConnected, realtime.Connection.State()) + } + + assert.True(t, ablytest.Instantly.IsTrue(recorder.CheckIfReceived(ably.ActionDisconnected, 3))) + tokens := []string{} + assert.Len(t, recorder.URLs(), 4) // 4 connect attempts made in total, disconnect received after each one + for _, url := range recorder.URLs() { + tokens = append(tokens, url.Query().Get("access_token")) + } + assert.Len(t, tokens, 4) // 4 tokens explicitly requested and supplied for every attempt + assertUnique(t, tokens) // Make sure all tokens are unique for every connection attempt + assert.ElementsMatch(t, authCallbackTokens, tokens) +} + func TestRealtimeConn_RTN15i_OnErrorWhenConnected(t *testing.T) { in := make(chan *ably.ProtocolMessage, 1)