Skip to content

Commit

Permalink
Fixed host fallback condition for rest client
Browse files Browse the repository at this point in the history
  • Loading branch information
sacOO7 committed Jan 19, 2024
1 parent b8e6391 commit 7ed5ace
Showing 1 changed file with 79 additions and 59 deletions.
138 changes: 79 additions & 59 deletions ably/rest_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
_ "crypto/sha512"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"mime"
"net"
"net/http"
"net/http/httptrace"
"net/url"
Expand Down Expand Up @@ -716,70 +718,68 @@ func (c *REST) doWithHandle(ctx context.Context, r *request, handle func(*http.R
}
if err != nil {
c.log.Error("RestClient: error handling response: ", err)
if e, ok := err.(*ErrorInfo); ok {
if canFallBack(e.StatusCode, resp) {
fallbacks, _ := c.opts.getFallbackHosts()
c.log.Infof("RestClient: trying to fallback with hosts=%v", fallbacks)
if len(fallbacks) > 0 {
left := fallbacks
iteration := 0
maxLimit := c.opts.HTTPMaxRetryCount
if maxLimit == 0 {
maxLimit = defaultOptions.HTTPMaxRetryCount
}
c.log.Infof("RestClient: maximum fallback retry limit=%d", maxLimit)
if canFallBack(err, resp) {
fallbacks, _ := c.opts.getFallbackHosts()
c.log.Infof("RestClient: trying to fallback with hosts=%v", fallbacks)
if len(fallbacks) > 0 {
left := fallbacks
iteration := 0
maxLimit := c.opts.HTTPMaxRetryCount
if maxLimit == 0 {
maxLimit = defaultOptions.HTTPMaxRetryCount
}
c.log.Infof("RestClient: maximum fallback retry limit=%d", maxLimit)

for {
if len(left) == 0 {
c.log.Errorf("RestClient: exhausted fallback hosts", err)
return nil, err
}
var h string
if len(left) == 1 {
h = left[0]
} else {
h = left[rand.Intn(len(left)-1)]
}
var n []string
for _, v := range left {
if v != h {
n = append(n, v)
}
for {
if len(left) == 0 {
c.log.Errorf("RestClient: exhausted fallback hosts", err)
return nil, err
}
var h string
if len(left) == 1 {
h = left[0]
} else {
h = left[rand.Intn(len(left)-1)]
}
var n []string
for _, v := range left {
if v != h {
n = append(n, v)
}
left = n
req, err := c.newHTTPRequest(ctx, r)
if err != nil {
}
left = n
req, err := c.newHTTPRequest(ctx, r)
if err != nil {
return nil, err
}
c.log.Infof("RestClient: chose fallback host=%q ", h)
req.URL.Host = h
req.Host = ""
req.Header.Set(hostHeader, h)
resp, err := c.opts.httpclient().Do(req)
if err == nil {
resp, err = handle(resp, r.Out)
} else {
c.log.Error("RestClient: failed sending a request to a fallback host", err)
}
if err != nil {
c.log.Error("RestClient: error handling response: ", err)
if iteration == maxLimit-1 {
return nil, err
}
c.log.Infof("RestClient: chose fallback host=%q ", h)
req.URL.Host = h
req.Host = ""
req.Header.Set(hostHeader, h)
resp, err := c.opts.httpclient().Do(req)
if err == nil {
resp, err = handle(resp, r.Out)
} else {
c.log.Error("RestClient: failed sending a request to a fallback host", err)
if canFallBack(err, resp) {
iteration++
continue
}
if err != nil {
c.log.Error("RestClient: error handling response: ", err)
if iteration == maxLimit-1 {
return nil, err
}
if ev, ok := err.(*ErrorInfo); ok {
if canFallBack(ev.StatusCode, resp) {
iteration++
continue
}
}
return nil, err
}
c.successFallbackHost.put(h)
return resp, nil
return nil, err
}
c.successFallbackHost.put(h)
return resp, nil
}
return nil, err
}
return nil, err
}
if e, ok := err.(*ErrorInfo); ok {
if e.Code == ErrTokenErrorUnspecified {
if r.NoRenew || !c.Auth.isTokenRenewable() {
return nil, err
Expand All @@ -796,9 +796,29 @@ func (c *REST) doWithHandle(ctx context.Context, r *request, handle func(*http.R
return resp, nil
}

func canFallBack(statusCode int, res *http.Response) bool {
return (statusCode >= http.StatusInternalServerError && statusCode <= http.StatusGatewayTimeout) ||
(res != nil && strings.EqualFold(res.Header.Get("Server"), "CloudFront") && statusCode >= http.StatusBadRequest) // RSC15l4
func canFallBack(err error, res *http.Response) bool {
return isTimeoutOrDnsErr(err) || //RSC15l1, RSC15l2
(res.StatusCode >= http.StatusInternalServerError && //RSC15l3
res.StatusCode <= http.StatusGatewayTimeout) ||
isCloudFrontError(res) //RSC15l4
}

// RSC15l4
func isCloudFrontError(res *http.Response) bool {
return res != nil &&
strings.EqualFold(res.Header.Get("Server"), "CloudFront") &&
res.StatusCode >= http.StatusBadRequest
}

func isTimeoutOrDnsErr(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() { // RSC15l2
return true
}
}
var dnsErr *net.DNSError
return errors.As(err, &dnsErr) // RSC15l1
}

// newHTTPRequest creates a new http.Request that can be sent to ably endpoints.
Expand Down

0 comments on commit 7ed5ace

Please sign in to comment.