diff --git a/ably/ably_test.go b/ably/ably_test.go index 55f65d5b5..13ed51655 100644 --- a/ably/ably_test.go +++ b/ably/ably_test.go @@ -191,7 +191,7 @@ func (pc pipeConn) Close() error { // MessageRecorder type MessageRecorder struct { mu sync.Mutex - url []*url.URL + urls []*url.URL sent []*ably.ProtocolMessage received []*ably.ProtocolMessage } @@ -207,7 +207,7 @@ func NewMessageRecorder() *MessageRecorder { // Reset resets the recorded urls, sent and received messages func (rec *MessageRecorder) Reset() { rec.mu.Lock() - rec.url = nil + rec.urls = nil rec.sent = nil rec.received = nil rec.mu.Unlock() @@ -216,7 +216,7 @@ func (rec *MessageRecorder) Reset() { // Dial func (rec *MessageRecorder) Dial(proto string, u *url.URL, timeout time.Duration) (ably.Conn, error) { rec.mu.Lock() - rec.url = append(rec.url, u) + rec.urls = append(rec.urls, u) rec.mu.Unlock() conn, err := ably.DialWebsocket(proto, u, timeout) if err != nil { @@ -229,11 +229,11 @@ func (rec *MessageRecorder) Dial(proto string, u *url.URL, timeout time.Duration } // URL -func (rec *MessageRecorder) URL() []*url.URL { +func (rec *MessageRecorder) URLs() []*url.URL { rec.mu.Lock() defer rec.mu.Unlock() - newUrl := make([]*url.URL, len(rec.url)) - copy(newUrl, rec.url) + newUrl := make([]*url.URL, len(rec.urls)) + copy(newUrl, rec.urls) return newUrl } @@ -522,3 +522,24 @@ var canceledCtx context.Context = func() context.Context { cancel() return ctx }() + +func assertSubset(t *testing.T, set []string, subset []string) { + t.Helper() + for _, item := range subset { + if !ablyutil.SliceContains(set, item) { + t.Errorf("expected %s got be in %s", item, set) + } + } +} + +func assertUnique(t *testing.T, list []string) { + t.Helper() + hashSet := ablyutil.NewHashSet() + for _, item := range list { + if hashSet.Has(item) { + t.Errorf("duplicate item %s", item) + } else { + hashSet.Add(item) + } + } +} diff --git a/ably/auth_integration_test.go b/ably/auth_integration_test.go index 498b4539f..6803b1d60 100644 --- a/ably/auth_integration_test.go +++ b/ably/auth_integration_test.go @@ -606,7 +606,7 @@ func TestAuth_RealtimeAccessToken(t *testing.T) { err = ablytest.FullRealtimeCloser(client).Close() assert.NoError(t, err, "Close()=%v", err) - recUrls := rec.URL() + recUrls := rec.URLs() assert.NotEqual(t, 0, len(recUrls), "want urls to be non-empty") for _, recUrl := range recUrls { diff --git a/ably/error_test.go b/ably/error_test.go index 59bdcd6ff..c5eb860d8 100644 --- a/ably/error_test.go +++ b/ably/error_test.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "net" "net/http" "net/http/httptest" "net/url" @@ -146,3 +147,41 @@ func TestIssue_154(t *testing.T) { assert.Equal(t, http.StatusMethodNotAllowed, et.StatusCode, "expected %d got %d: %v", http.StatusMethodNotAllowed, et.StatusCode, err) } + +func Test_DNSOrTimeoutErr(t *testing.T) { + dnsErr := net.DNSError{ + Err: "Can't resolve host", + Name: "Host unresolvable", + Server: "rest.ably.com", + IsTimeout: false, + IsTemporary: false, + IsNotFound: false, + } + + WrappedDNSErr := fmt.Errorf("custom error occured %w", &dnsErr) + if !ably.IsTimeoutOrDnsErr(WrappedDNSErr) { + t.Fatalf("%v is a DNS error", WrappedDNSErr) + } + + urlErr := url.Error{ + URL: "rest.ably.io", + Err: errors.New("URL error occured"), + Op: "IO read OP", + } + + if ably.IsTimeoutOrDnsErr(&urlErr) { + t.Fatalf("%v is not a DNS or timeout error", urlErr) + } + + urlErr.Err = &dnsErr + + if !ably.IsTimeoutOrDnsErr(WrappedDNSErr) { + t.Fatalf("%v is a DNS error", urlErr) + } + + dnsErr.IsTimeout = true + + if !ably.IsTimeoutOrDnsErr(WrappedDNSErr) { + t.Fatalf("%v is a timeout error", urlErr) + } +} diff --git a/ably/export_test.go b/ably/export_test.go index 55aaef87f..7fcbd64e2 100644 --- a/ably/export_test.go +++ b/ably/export_test.go @@ -65,6 +65,10 @@ func UnwrapStatusCode(err error) int { return statusCode(err) } +func IsTimeoutOrDnsErr(err error) bool { + return isTimeoutOrDnsErr(err) +} + func (a *Auth) Timestamp(ctx context.Context, query bool) (time.Time, error) { return a.timestamp(ctx, query) } diff --git a/ably/http_paginated_response_integration_test.go b/ably/http_paginated_response_integration_test.go index ed196accf..d47285b29 100644 --- a/ably/http_paginated_response_integration_test.go +++ b/ably/http_paginated_response_integration_test.go @@ -20,7 +20,9 @@ func TestHTTPPaginatedFallback(t *testing.T) { app, err := ablytest.NewSandbox(nil) assert.NoError(t, err) defer app.Close() - opts := app.Options(ably.WithUseBinaryProtocol(false), ably.WithRESTHost("ably.invalid"), ably.WithFallbackHostsUseDefault(true)) + opts := app.Options(ably.WithUseBinaryProtocol(false), + ably.WithRESTHost("ably.invalid"), + ably.WithFallbackHosts(nil)) client, err := ably.NewREST(opts...) assert.NoError(t, err) t.Run("request_time", func(t *testing.T) { diff --git a/ably/internal/ablyutil/strings.go b/ably/internal/ablyutil/strings.go new file mode 100644 index 000000000..5c3514a8d --- /dev/null +++ b/ably/internal/ablyutil/strings.go @@ -0,0 +1,74 @@ +package ablyutil + +import ( + "math/rand" + "sort" + "strings" + "time" +) + +const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) + +func GenerateRandomString(length int) string { + b := make([]byte, length) + for i := range b { + b[i] = charset[seededRand.Intn(len(charset))] + } + return string(b) +} + +type HashSet map[string]struct{} // struct {} has zero space complexity + +func NewHashSet() HashSet { + return make(HashSet) +} + +func (s HashSet) Add(item string) { + s[item] = struct{}{} +} + +func (s HashSet) Remove(item string) { + delete(s, item) +} + +func (s HashSet) Has(item string) bool { + _, ok := s[item] + return ok +} + +func Copy(list []string) []string { + copiedList := make([]string, len(list)) + copy(copiedList, list) + return copiedList +} + +func Sort(list []string) []string { + copiedList := Copy(list) + sort.Strings(copiedList) + return copiedList +} + +func Shuffle(list []string) []string { + copiedList := Copy(list) + if len(copiedList) <= 1 { + return copiedList + } + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(copiedList), func(i, j int) { copiedList[i], copiedList[j] = copiedList[j], copiedList[i] }) + return copiedList +} + +func SliceContains(s []string, str string) bool { + for _, v := range s { + if v == str { + return true + } + } + return false +} + +func Empty(s string) bool { + return len(strings.TrimSpace(s)) == 0 +} diff --git a/ably/internal/ablyutil/strings_test.go b/ably/internal/ablyutil/strings_test.go new file mode 100644 index 000000000..0ae8f7b59 --- /dev/null +++ b/ably/internal/ablyutil/strings_test.go @@ -0,0 +1,111 @@ +package ablyutil_test + +import ( + "testing" + + "github.com/ably/ably-go/ably/internal/ablyutil" + "github.com/stretchr/testify/assert" +) + +func Test_string(t *testing.T) { + t.Run("String array Shuffle", func(t *testing.T) { + t.Parallel() + + strList := []string{} + shuffledList := ablyutil.Shuffle(strList) + assert.Equal(t, strList, shuffledList) + + strList = []string{"a"} + shuffledList = ablyutil.Shuffle(strList) + assert.Equal(t, strList, shuffledList) + + strList = []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p"} + shuffledList = ablyutil.Shuffle(strList) + assert.NotEqual(t, strList, shuffledList) + assert.Equal(t, ablyutil.Sort(strList), ablyutil.Sort(shuffledList)) + }) + + t.Run("String array contains", func(t *testing.T) { + t.Parallel() + strarr := []string{"apple", "banana", "dragonfruit"} + + if !ablyutil.SliceContains(strarr, "apple") { + t.Error("String array should contain apple") + } + if ablyutil.SliceContains(strarr, "orange") { + t.Error("String array should not contain orange") + } + }) + + t.Run("Empty String", func(t *testing.T) { + t.Parallel() + str := "" + if !ablyutil.Empty(str) { + t.Error("String should be empty") + } + str = " " + if !ablyutil.Empty(str) { + t.Error("String should be empty") + } + str = "ab" + if ablyutil.Empty(str) { + t.Error("String should not be empty") + } + }) +} + +func TestHashSet(t *testing.T) { + t.Run("Add should not duplicate entries", func(t *testing.T) { + hashSet := ablyutil.NewHashSet() + hashSet.Add("apple") + hashSet.Add("apple") + assert.Len(t, hashSet, 1) + + hashSet.Add("banana") + assert.Len(t, hashSet, 2) + + hashSet.Add("orange") + assert.Len(t, hashSet, 3) + + hashSet.Add("banana") + hashSet.Add("apple") + hashSet.Add("orange") + hashSet.Add("orange") + + assert.Len(t, hashSet, 3) + }) + + t.Run("Should check if item is present", func(t *testing.T) { + hashSet := ablyutil.NewHashSet() + hashSet.Add("apple") + hashSet.Add("orange") + if !hashSet.Has("apple") { + t.Fatalf("Set should contain apple") + } + if hashSet.Has("banana") { + t.Fatalf("Set shouldm't contain banana") + } + if !hashSet.Has("orange") { + t.Fatalf("Set should contain orange") + } + }) + + t.Run("Should remove element", func(t *testing.T) { + hashSet := ablyutil.NewHashSet() + hashSet.Add("apple") + assert.Len(t, hashSet, 1) + + hashSet.Add("orange") + assert.Len(t, hashSet, 2) + + hashSet.Remove("apple") + assert.Len(t, hashSet, 1) + + if hashSet.Has("apple") { + t.Fatalf("Set shouldm't contain apple") + } + hashSet.Remove("orange") + assert.Len(t, hashSet, 0) + + }) +} diff --git a/ably/realtime_channel_integration_test.go b/ably/realtime_channel_integration_test.go index 88dadd5d5..c5a0ec0a5 100644 --- a/ably/realtime_channel_integration_test.go +++ b/ably/realtime_channel_integration_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/ably/ably-go/ably" + "github.com/ably/ably-go/ably/internal/ablyutil" "github.com/ably/ably-go/ablytest" "github.com/stretchr/testify/assert" @@ -320,7 +321,7 @@ func TestRealtimeChannel_ShouldReturnErrorIfReadLimitExceeded(t *testing.T) { assert.NoError(t, err, "client2:.Subscribe(context.Background())=%v", err) defer unsub2() - messageWith2MbSize := ablytest.GenerateRandomString(2048) + messageWith2MbSize := ablyutil.GenerateRandomString(2048) err = channel1.Publish(context.Background(), "hello", messageWith2MbSize) assert.NoError(t, err, "client1: Publish()=%v", err) diff --git a/ably/rest_client.go b/ably/rest_client.go index b25a88b99..0698bf510 100644 --- a/ably/rest_client.go +++ b/ably/rest_client.go @@ -6,11 +6,13 @@ import ( _ "crypto/sha512" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "io/ioutil" "math/rand" "mime" + "net" "net/http" "net/http/httptrace" "net/url" @@ -709,6 +711,7 @@ func (c *REST) doWithHandle(ctx context.Context, r *request, handle func(*http.R c.log.Verbose("RestClient: enabling httptrace") } resp, err := c.opts.httpclient().Do(req) + serverResp := resp if err == nil { resp, err = handle(resp, r.Out) } else { @@ -716,70 +719,69 @@ func (c *REST) doWithHandle(ctx context.Context, r *request, handle func(*http.R } if err != nil { c.log.Error("RestClient: error handling response: ", err) - if e, ok := err.(*ErrorInfo); ok { - if canFallBack(e.StatusCode, resp) { - fallbacks, _ := c.opts.getFallbackHosts() - c.log.Infof("RestClient: trying to fallback with hosts=%v", fallbacks) - if len(fallbacks) > 0 { - left := fallbacks - iteration := 0 - maxLimit := c.opts.HTTPMaxRetryCount - if maxLimit == 0 { - maxLimit = defaultOptions.HTTPMaxRetryCount - } - c.log.Infof("RestClient: maximum fallback retry limit=%d", maxLimit) + if canFallBack(err, serverResp) { + fallbacks, _ := c.opts.getFallbackHosts() + c.log.Infof("RestClient: trying to fallback with hosts=%v", fallbacks) + if len(fallbacks) > 0 { + left := fallbacks + iteration := 0 + maxLimit := c.opts.HTTPMaxRetryCount + if maxLimit == 0 { + maxLimit = defaultOptions.HTTPMaxRetryCount + } + c.log.Infof("RestClient: maximum fallback retry limit=%d", maxLimit) - for { - if len(left) == 0 { - c.log.Errorf("RestClient: exhausted fallback hosts", err) - return nil, err - } - var h string - if len(left) == 1 { - h = left[0] - } else { - h = left[rand.Intn(len(left)-1)] - } - var n []string - for _, v := range left { - if v != h { - n = append(n, v) - } + for { + if len(left) == 0 { + c.log.Errorf("RestClient: exhausted fallback hosts", err) + return nil, err + } + var h string + if len(left) == 1 { + h = left[0] + } else { + h = left[rand.Intn(len(left)-1)] + } + var n []string + for _, v := range left { + if v != h { + n = append(n, v) } - left = n - req, err := c.newHTTPRequest(ctx, r) - if err != nil { + } + left = n + req, err := c.newHTTPRequest(ctx, r) + if err != nil { + return nil, err + } + c.log.Infof("RestClient: chose fallback host=%q ", h) + req.URL.Host = h + req.Host = "" + req.Header.Set(hostHeader, h) + resp, err := c.opts.httpclient().Do(req) + serverResp := resp + if err == nil { + resp, err = handle(resp, r.Out) + } else { + c.log.Error("RestClient: failed sending a request to a fallback host", err) + } + if err != nil { + c.log.Error("RestClient: error handling response: ", err) + if iteration == maxLimit-1 { return nil, err } - c.log.Infof("RestClient: chose fallback host=%q ", h) - req.URL.Host = h - req.Host = "" - req.Header.Set(hostHeader, h) - resp, err := c.opts.httpclient().Do(req) - if err == nil { - resp, err = handle(resp, r.Out) - } else { - c.log.Error("RestClient: failed sending a request to a fallback host", err) - } - if err != nil { - c.log.Error("RestClient: error handling response: ", err) - if iteration == maxLimit-1 { - return nil, err - } - if ev, ok := err.(*ErrorInfo); ok { - if canFallBack(ev.StatusCode, resp) { - iteration++ - continue - } - } - return nil, err + if canFallBack(err, serverResp) { + iteration++ + continue } - c.successFallbackHost.put(h) - return resp, nil + return nil, err } + c.successFallbackHost.put(h) + return resp, nil } - return nil, err } + return nil, err + } + if e, ok := err.(*ErrorInfo); ok { if e.Code == ErrTokenErrorUnspecified { if r.NoRenew || !c.Auth.isTokenRenewable() { return nil, err @@ -796,9 +798,35 @@ func (c *REST) doWithHandle(ctx context.Context, r *request, handle func(*http.R return resp, nil } -func canFallBack(statusCode int, res *http.Response) bool { - return (statusCode >= http.StatusInternalServerError && statusCode <= http.StatusGatewayTimeout) || - (res != nil && strings.EqualFold(res.Header.Get("Server"), "CloudFront") && statusCode >= http.StatusBadRequest) // RSC15l4 +func canFallBack(err error, res *http.Response) bool { + return isStatusCodeBetween500_504(res) || // RSC15l3 + isCloudFrontError(res) || //RSC15l4 + isTimeoutOrDnsErr(err) //RSC15l1, RSC15l2 +} + +func isTimeoutOrDnsErr(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { // RSC15l2 + return true + } + } + var dnsErr *net.DNSError + return errors.As(err, &dnsErr) // RSC15l1 +} + +// RSC15l3 +func isStatusCodeBetween500_504(res *http.Response) bool { + return res != nil && + res.StatusCode >= http.StatusInternalServerError && + res.StatusCode <= http.StatusGatewayTimeout +} + +// RSC15l4 +func isCloudFrontError(res *http.Response) bool { + return res != nil && + strings.EqualFold(res.Header.Get("Server"), "CloudFront") && + res.StatusCode >= http.StatusBadRequest } // newHTTPRequest creates a new http.Request that can be sent to ably endpoints. diff --git a/ably/rest_client_integration_test.go b/ably/rest_client_integration_test.go index 597c231f5..bf4c823b0 100644 --- a/ably/rest_client_integration_test.go +++ b/ably/rest_client_integration_test.go @@ -17,6 +17,7 @@ import ( "net/url" "regexp" "strings" + "sync" "sync/atomic" "testing" "time" @@ -320,7 +321,7 @@ func TestRest_RSC7_AblyAgent(t *testing.T) { }) } -func TestRest_hostfallback(t *testing.T) { +func TestRest_RSC15_HostFallback(t *testing.T) { app, err := ablytest.NewSandbox(nil) assert.NoError(t, err) @@ -329,7 +330,7 @@ func TestRest_hostfallback(t *testing.T) { var retryCount int var hosts []string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - hosts = append(hosts, r.Host) + hosts = append(hosts, strings.Split(r.Host, ":")[0]) retryCount++ w.WriteHeader(http.StatusInternalServerError) })) @@ -337,37 +338,93 @@ func TestRest_hostfallback(t *testing.T) { client, err := ably.NewREST(app.Options(append(options, ably.WithHTTPClient(newHTTPClientMock(server)))...)...) assert.NoError(t, err) err = client.Channels.Get("test").Publish(context.Background(), "ping", "pong") - assert.Error(t, err, - "expected an error") + assert.Error(t, err, "expected an error") return retryCount, hosts } - t.Run("RSC15d RSC15a must use alternative host", func(t *testing.T) { + t.Run("RSC15a, RSC15b, RSC15d, RSC15g3: must use alternative host", func(t *testing.T) { options := []ably.ClientOption{ ably.WithFallbackHosts(ably.DefaultFallbackHosts()), ably.WithTLS(false), + ably.WithEnvironment(""), // remove default sandbox env + ably.WithHTTPMaxRetryCount(10), ably.WithUseTokenAuth(true), } retryCount, hosts := runTestServer(t, options) - assert.Equal(t, 4, retryCount, - "expected 4 http calls got %d", retryCount) - // make sure the host header is set. Since we are using defaults from the spec - // the hosts should be in [a..e].ably-realtime.com - expect := strings.Join(ably.DefaultFallbackHosts(), ", ") - for _, host := range hosts[1:] { - assert.Contains(t, expect, host, - "expected %s got be in %s", host, expect) + assert.Equal(t, 6, retryCount) // 1 primary and 5 default fallback hosts + assert.Equal(t, "rest.ably.io", hosts[0]) // primary host + assertSubset(t, ably.DefaultFallbackHosts(), hosts[1:]) // remaining fallback hosts + assertUnique(t, hosts) // ensure all picked fallbacks are unique + }) + + runTestServerWithRequestTimeout := func(t *testing.T, options []ably.ClientOption) (int, []string) { + var retryCount int + var hosts []string + allHostsTried := make(chan struct{}, 1) + var mtx sync.Mutex + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mtx.Lock() + defer mtx.Unlock() + hosts = append(hosts, strings.Split(r.Host, ":")[0]) + retryCount++ + time.Sleep(2 * time.Second) + if retryCount == 6 { + allHostsTried <- struct{}{} + } + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + httpClientMock := &http.Client{ + Timeout: 1 * time.Second, + Transport: &http.Transport{ + Proxy: func(*http.Request) (*url.URL, error) { return url.Parse(server.URL) }, + }, + } + client, err := ably.NewREST(app.Options(append(options, ably.WithHTTPClient(httpClientMock))...)...) + assert.NoError(t, err) + err = client.Channels.Get("test").Publish(context.Background(), "ping", "pong") + <-allHostsTried + assert.Contains(t, err.Error(), "context deadline exceeded (Client.Timeout exceeded while awaiting headers)") + return retryCount, hosts + } + + t.Run("RSC15l2 must use alternative host on timeout", func(t *testing.T) { + + options := []ably.ClientOption{ + ably.WithFallbackHosts(ably.DefaultFallbackHosts()), + ably.WithTLS(false), + ably.WithEnvironment(""), // remove default sandbox env + ably.WithHTTPMaxRetryCount(10), + ably.WithUseTokenAuth(true), } + retryCount, hosts := runTestServerWithRequestTimeout(t, options) + assert.Equal(t, 6, retryCount) // 1 primary and 5 default fallback hosts + assert.Equal(t, "rest.ably.io", hosts[0]) // primary host + assertSubset(t, ably.DefaultFallbackHosts(), hosts[1:]) // remaining fallback hosts + assertUnique(t, hosts) // ensure all picked fallbacks are unique + }) - // ensure all picked fallbacks are unique - uniq := make(map[string]bool) - for _, h := range hosts { - _, ok := uniq[h] - assert.False(t, ok, - "duplicate fallback %s", h) - uniq[h] = true + t.Run("RSC15l1 must use alternative host on host unresolvable or unreachable", func(t *testing.T) { + options := []ably.ClientOption{ + ably.WithFallbackHosts(ably.DefaultFallbackHosts()), + ably.WithRESTHost("foobar.ably.com"), + ably.WithFallbackHosts([]string{ + "spam.ably.com", + "tatto.ably.com", + "rest.ably.io"}), + ably.WithTLS(false), + ably.WithUseTokenAuth(true), } + client, err := ably.NewREST(app.Options(options...)...) + assert.NoError(t, err) + tm, err := client.Time(context.Background()) + assert.Nil(t, err) + assert.NotNil(t, tm) + time.Sleep(2 * time.Second) + cachedFallbackHost := client.GetCachedFallbackHost() + assert.Equal(t, "rest.ably.io", cachedFallbackHost) }) + t.Run("rsc15b", func(t *testing.T) { t.Run("must not occur when default rest.ably.io is overridden", func(t *testing.T) { diff --git a/ablytest/ablytest.go b/ablytest/ablytest.go index dba905ad0..fc7b83630 100644 --- a/ablytest/ablytest.go +++ b/ablytest/ablytest.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "math/rand" "net/http" "os" "reflect" @@ -210,15 +209,3 @@ func TimeFuncs(afterCalls chan<- AfterCall) ( return now, after } - -const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - -var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) - -func GenerateRandomString(length int) string { - b := make([]byte, length) - for i := range b { - b[i] = charset[seededRand.Intn(len(charset))] - } - return string(b) -}