diff --git a/modules/caddyhttp/reverseproxy/admin.go b/modules/caddyhttp/reverseproxy/admin.go index 7e72a4cdb51..3e445c9b7b5 100644 --- a/modules/caddyhttp/reverseproxy/admin.go +++ b/modules/caddyhttp/reverseproxy/admin.go @@ -102,6 +102,33 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er }) return true }) + // Iterate over the inflight hosts + inflightHosts.Range(func(key, val any) bool { + address, ok := key.(string) + if !ok { + rangeErr = caddy.APIError{ + HTTPStatus: http.StatusInternalServerError, + Err: fmt.Errorf("could not type assert upstream address"), + } + return false + } + + upstream, ok := val.(*Host) + if !ok { + rangeErr = caddy.APIError{ + HTTPStatus: http.StatusInternalServerError, + Err: fmt.Errorf("could not type assert upstream struct"), + } + return false + } + + results = append(results, upstreamStatus{ + Address: address, + NumRequests: upstream.NumRequests(), + Fails: upstream.Fails(), + }) + return true + }) // If an error happened during the range, return it if rangeErr != nil { diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index 300003f2b87..345f71d4a83 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -132,6 +132,16 @@ func (u *Upstream) fillHost() { u.Host = host } +func (u *Upstream) fillInfilghtHost(numRemaiRequests int) { + host := new(Host) + existingHost, loaded := inflightHosts.LoadOrStore(u.String(), host) + if loaded { + host = existingHost.(*Host) + } + _ = host.countRequest(numRemaiRequests) + u.Host = host +} + // Host is the basic, in-memory representation of the state of a remote host. // Its fields are accessed atomically and Host values must not be copied. type Host struct { @@ -268,6 +278,10 @@ func GetDialInfo(ctx context.Context) (DialInfo, bool) { // through config reloads. var hosts = caddy.NewUsagePool() +// inflightHosts is the global repository for hosts that are +// currently in use by inflight upstream request. +var inflightHosts = caddy.NewUsagePool() + // dialInfoVarKey is the key used for the variable that holds // the dial info for the upstream connection. const dialInfoVarKey = "reverse_proxy.dial_info" diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index c8b10581ae8..31bb587973a 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -394,6 +394,9 @@ func (h *Handler) Cleanup() error { // remove hosts from our config from the pool for _, upstream := range h.Upstreams { + if upstream.NumRequests() > 0 { + upstream.fillInfilghtHost(upstream.NumRequests()) + } _, _ = hosts.Delete(upstream.String()) } @@ -458,8 +461,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht } var done bool - done, proxyErr = h.proxyLoopIteration(clonedReq, r, w, proxyErr, start, retries, repl, reqHeader, reqHost, next) + done, dialInfo, proxyErr := h.proxyLoopIteration(clonedReq, r, w, proxyErr, start, retries, repl, reqHeader, reqHost, next) if done { + key := dialInfo.Address + val := inflightHosts.Load(key) + if val != nil { + host, _ := val.(*Host) + if host.NumRequests() <= 0 { + _, _ = inflightHosts.Delete(key) + } + } break } if h.VerboseLogs { @@ -490,7 +501,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht // be assigned to the proxyErr value for the next iteration of the loop (or the error handled after break). func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w http.ResponseWriter, proxyErr error, start time.Time, retries int, repl *caddy.Replacer, reqHeader http.Header, reqHost string, next caddyhttp.Handler, -) (bool, error) { +) (bool, *DialInfo, error) { // get the updated list of upstreams upstreams := h.Upstreams if h.DynamicUpstreams != nil { @@ -524,9 +535,9 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h proxyErr = caddyhttp.Error(http.StatusServiceUnavailable, errNoUpstream) } if !h.LoadBalancing.tryAgain(h.ctx, start, retries, proxyErr, r, h.logger) { - return true, proxyErr + return true, nil, proxyErr } - return false, proxyErr + return false, nil, proxyErr } // the dial address may vary per-request if placeholders are @@ -534,7 +545,7 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h // DialInfo struct should have valid network address syntax dialInfo, err := upstream.fillDialInfo(repl) if err != nil { - return true, fmt.Errorf("making dial info: %v", err) + return true, nil, fmt.Errorf("making dial info: %v", err) } if c := h.logger.Check(zapcore.DebugLevel, "selected upstream"); c != nil { @@ -574,7 +585,7 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h if proxyErr == nil || errors.Is(proxyErr, context.Canceled) { // context.Canceled happens when the downstream client // cancels the request, which is not our failure - return true, nil + return true, &dialInfo, nil } // if the roundtrip was successful, don't retry the request or @@ -582,7 +593,7 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h // occur after the roundtrip if, for example, a response handler // after the roundtrip returns an error) if succ, ok := proxyErr.(roundtripSucceededError); ok { - return true, succ.error + return true, &dialInfo, succ.error } // remember this failure (if enabled) @@ -590,10 +601,10 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h // if we've tried long enough, break if !h.LoadBalancing.tryAgain(h.ctx, start, retries, proxyErr, r, h.logger) { - return true, proxyErr + return true, &dialInfo, proxyErr } - return false, proxyErr + return false, &dialInfo, proxyErr } // Mapping of the canonical form of the headers, to the RFC 6455 form, @@ -829,8 +840,14 @@ func (h Handler) addForwardedHeaders(req *http.Request) error { func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origReq *http.Request, repl *caddy.Replacer, di DialInfo, next caddyhttp.Handler) error { _ = di.Upstream.Host.countRequest(1) //nolint:errcheck - defer di.Upstream.Host.countRequest(-1) - + defer func() { + di.Upstream.Host.countRequest(-1) + inflightHost := inflightHosts.Load(di.Address) + if inflightHost != nil { + host, _ := inflightHost.(*Host) + host.countRequest(-1) + } + }() // point the request to this upstream h.directRequest(req, di) diff --git a/usagepool.go b/usagepool.go index a6466b9b124..e410feda269 100644 --- a/usagepool.go +++ b/usagepool.go @@ -194,6 +194,16 @@ func (up *UsagePool) Delete(key any) (deleted bool, err error) { return deleted, err } +func (up *UsagePool) Load(key any) any { + up.RLock() + defer up.RUnlock() + upv, loaded := up.pool[key] + if loaded { + return upv.value + } + return nil +} + // References returns the number of references (count of usages) to a // key in the pool, and true if the key exists, or false otherwise. func (up *UsagePool) References(key any) (int, bool) {