Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/host fallback conditions #629

Merged
merged 10 commits into from
Jan 19, 2024
138 changes: 79 additions & 59 deletions ably/rest_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
_ "crypto/sha512"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"mime"
"net"
"net/http"
"net/http/httptrace"
"net/url"
Expand Down Expand Up @@ -716,70 +718,68 @@ func (c *REST) doWithHandle(ctx context.Context, r *request, handle func(*http.R
}
if err != nil {
c.log.Error("RestClient: error handling response: ", err)
if e, ok := err.(*ErrorInfo); ok {
if canFallBack(e.StatusCode, resp) {
fallbacks, _ := c.opts.getFallbackHosts()
c.log.Infof("RestClient: trying to fallback with hosts=%v", fallbacks)
if len(fallbacks) > 0 {
left := fallbacks
iteration := 0
maxLimit := c.opts.HTTPMaxRetryCount
if maxLimit == 0 {
maxLimit = defaultOptions.HTTPMaxRetryCount
}
c.log.Infof("RestClient: maximum fallback retry limit=%d", maxLimit)
if canFallBack(err, resp) {
fallbacks, _ := c.opts.getFallbackHosts()
c.log.Infof("RestClient: trying to fallback with hosts=%v", fallbacks)
if len(fallbacks) > 0 {
left := fallbacks
iteration := 0
maxLimit := c.opts.HTTPMaxRetryCount
if maxLimit == 0 {
maxLimit = defaultOptions.HTTPMaxRetryCount
}
c.log.Infof("RestClient: maximum fallback retry limit=%d", maxLimit)

for {
if len(left) == 0 {
c.log.Errorf("RestClient: exhausted fallback hosts", err)
return nil, err
}
var h string
if len(left) == 1 {
h = left[0]
} else {
h = left[rand.Intn(len(left)-1)]
}
var n []string
for _, v := range left {
if v != h {
n = append(n, v)
}
for {
if len(left) == 0 {
c.log.Errorf("RestClient: exhausted fallback hosts", err)
return nil, err
}
var h string
if len(left) == 1 {
h = left[0]
} else {
h = left[rand.Intn(len(left)-1)]
}
var n []string
for _, v := range left {
if v != h {
n = append(n, v)
}
left = n
req, err := c.newHTTPRequest(ctx, r)
if err != nil {
}
left = n
req, err := c.newHTTPRequest(ctx, r)
if err != nil {
return nil, err
}
c.log.Infof("RestClient: chose fallback host=%q ", h)
req.URL.Host = h
req.Host = ""
req.Header.Set(hostHeader, h)
resp, err := c.opts.httpclient().Do(req)
if err == nil {
resp, err = handle(resp, r.Out)
} else {
c.log.Error("RestClient: failed sending a request to a fallback host", err)
}
if err != nil {
c.log.Error("RestClient: error handling response: ", err)
if iteration == maxLimit-1 {
return nil, err
}
c.log.Infof("RestClient: chose fallback host=%q ", h)
req.URL.Host = h
req.Host = ""
req.Header.Set(hostHeader, h)
resp, err := c.opts.httpclient().Do(req)
if err == nil {
resp, err = handle(resp, r.Out)
} else {
c.log.Error("RestClient: failed sending a request to a fallback host", err)
if canFallBack(err, resp) {
iteration++
continue
}
if err != nil {
c.log.Error("RestClient: error handling response: ", err)
if iteration == maxLimit-1 {
return nil, err
}
if ev, ok := err.(*ErrorInfo); ok {
if canFallBack(ev.StatusCode, resp) {
iteration++
continue
}
}
return nil, err
}
c.successFallbackHost.put(h)
return resp, nil
return nil, err
}
c.successFallbackHost.put(h)
return resp, nil
}
return nil, err
}
return nil, err
}
if e, ok := err.(*ErrorInfo); ok {
if e.Code == ErrTokenErrorUnspecified {
if r.NoRenew || !c.Auth.isTokenRenewable() {
return nil, err
Expand All @@ -796,9 +796,29 @@ func (c *REST) doWithHandle(ctx context.Context, r *request, handle func(*http.R
return resp, nil
}

func canFallBack(statusCode int, res *http.Response) bool {
return (statusCode >= http.StatusInternalServerError && statusCode <= http.StatusGatewayTimeout) ||
(res != nil && strings.EqualFold(res.Header.Get("Server"), "CloudFront") && statusCode >= http.StatusBadRequest) // RSC15l4
func canFallBack(err error, res *http.Response) bool {
return isTimeoutOrDnsErr(err) || //RSC15l1, RSC15l2
(res.StatusCode >= http.StatusInternalServerError && //RSC15l3
res.StatusCode <= http.StatusGatewayTimeout) ||
isCloudFrontError(res) //RSC15l4
}

// RSC15l4
func isCloudFrontError(res *http.Response) bool {
return res != nil &&
strings.EqualFold(res.Header.Get("Server"), "CloudFront") &&
res.StatusCode >= http.StatusBadRequest
}

func isTimeoutOrDnsErr(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() { // RSC15l2
return true
}
}
var dnsErr *net.DNSError
return errors.As(err, &dnsErr) // RSC15l1
}

// newHTTPRequest creates a new http.Request that can be sent to ably endpoints.
Expand Down
53 changes: 51 additions & 2 deletions ably/rest_client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,7 @@ func TestRest_hostfallback(t *testing.T) {
client, err := ably.NewREST(app.Options(append(options, ably.WithHTTPClient(newHTTPClientMock(server)))...)...)
assert.NoError(t, err)
err = client.Channels.Get("test").Publish(context.Background(), "ping", "pong")
assert.Error(t, err,
"expected an error")
assert.Error(t, err, "expected an error")
return retryCount, hosts
}
t.Run("RSC15d RSC15a must use alternative host", func(t *testing.T) {
Expand Down Expand Up @@ -368,6 +367,56 @@ func TestRest_hostfallback(t *testing.T) {
uniq[h] = true
}
})

runTestServerWithRequestTimeout := func(t *testing.T, options []ably.ClientOption) (int, []string) {
var retryCount int
var hosts []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hosts = append(hosts, r.Host)
retryCount++
time.Sleep(2 * time.Second)
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
httpClientMock := &http.Client{
Timeout: 1 * time.Second,
Transport: &http.Transport{
Proxy: func(*http.Request) (*url.URL, error) { return url.Parse(server.URL) },
},
}
client, err := ably.NewREST(app.Options(append(options, ably.WithHTTPClient(httpClientMock))...)...)
assert.NoError(t, err)
err = client.Channels.Get("test").Publish(context.Background(), "ping", "pong")
assert.Contains(t, err.Error(), "context deadline exceeded (Client.Timeout exceeded while awaiting headers)")
return retryCount, hosts
}

t.Run("RSC15l2 must use alternative host on timeout", func(t *testing.T) {

options := []ably.ClientOption{
ably.WithFallbackHosts(ably.DefaultFallbackHosts()),
ably.WithTLS(false),
ably.WithUseTokenAuth(true),
}
retryCount, hosts := runTestServerWithRequestTimeout(t, options)
assert.Equal(t, 4, retryCount, "expected 4 http calls got %d", retryCount)
// make sure the host header is set. Since we are using defaults from the spec
// the hosts should be in [a..e].ably-realtime.com
expect := strings.Join(ably.DefaultFallbackHosts(), ", ")
for _, host := range hosts[1:] {
assert.Contains(t, expect, host, "expected %s got be in %s", host, expect)
}

// ensure all picked fallbacks are unique
uniq := make(map[string]bool)
for _, h := range hosts {
_, ok := uniq[h]
assert.False(t, ok,
"duplicate fallback %s", h)
uniq[h] = true
}
})

t.Run("rsc15b", func(t *testing.T) {
t.Run("must not occur when default rest.ably.io is overridden", func(t *testing.T) {

Expand Down
Loading