Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add SetClientBehavior method to allow users can select proxy client's do behavior #19

Merged
merged 10 commits into from
May 31, 2024
44 changes: 44 additions & 0 deletions proxy_client_behavior.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package reverseproxy

import "time"

type clientBehaviorType int

const (
do clientBehaviorType = iota
doDeadline
doRedirects
doTimeout
)

type clientBehavior struct {
clientBehaviorType clientBehaviorType
param interface{}
}

func ClientDo() clientBehavior {
return clientBehavior{
clientBehaviorType: do,
}
}

func ClientDoRedirects(param int) clientBehavior {
return clientBehavior{
clientBehaviorType: doRedirects,
param: param,
}
}

func ClientDoDeadline(param time.Time) clientBehavior {
return clientBehavior{
clientBehaviorType: doDeadline,
param: param,
}
}

func ClientDoTimeout(param time.Time) clientBehavior {
return clientBehavior{
clientBehaviorType: doTimeout,
param: param,
}
}
40 changes: 35 additions & 5 deletions reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"reflect"
"strings"
"sync"
"time"
"unsafe"

"github.com/cloudwego/hertz/pkg/app"
Expand All @@ -45,6 +46,8 @@ import (
type ReverseProxy struct {
client *client.Client

clientBehavior clientBehavior

// target is set as a reverse proxy address
Target string

Expand Down Expand Up @@ -275,11 +278,8 @@ func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) {
req.Header.Add("X-Forwarded-For", ip)
}
}
fn := client.Do
if r.client != nil {
fn = r.client.Do
}
err := fn(c, req, resp)

err := r.doClientBehavior(c, req, resp)
if err != nil {
hlog.CtxErrorf(c, "HERTZ: Client request error: %#v", err.Error())
r.getErrorHandler()(ctx, err)
Expand Down Expand Up @@ -345,13 +345,43 @@ func (r *ReverseProxy) SetSaveOriginResHeader(b bool) {
r.saveOriginResHeader = b
}

func (r *ReverseProxy) SetClientBehavior(cb clientBehavior) {
r.clientBehavior = cb
}

func (r *ReverseProxy) getErrorHandler() func(c *app.RequestContext, err error) {
if r.errorHandler != nil {
return r.errorHandler
}
return r.defaultErrorHandler
}

func (r *ReverseProxy) doClientBehavior(ctx context.Context, req *protocol.Request, resp *protocol.Response) error {
var err error
switch r.clientBehavior.clientBehaviorType {
case do:
err = r.client.Do(ctx, req, resp)
case doDeadline:
deadline := r.clientBehavior.param.(time.Time)
err = r.client.DoDeadline(ctx, req, resp, deadline)
case doRedirects:
maxRedirectsCount := r.clientBehavior.param.(int)
err = r.client.DoRedirects(ctx, req, resp, maxRedirectsCount)
case doTimeout:
timeout := r.clientBehavior.param.(time.Time)
err = r.client.DoDeadline(ctx, req, resp, timeout)
}
return err
}

func (r *ReverseProxy) do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error {
dragonYang200 marked this conversation as resolved.
Show resolved Hide resolved
if r.client != nil {
return r.client.Do(ctx, req, resp)
} else {
return client.Do(ctx, req, resp)
}
}

// b2s converts byte slice to a string without memory allocation.
// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ .
//
Expand Down
22 changes: 22 additions & 0 deletions reverse_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,3 +589,25 @@ func TestReverseProxySaveRespHeader(t *testing.T) {
}
assert.DeepEqual(t, "bbb", res.Header.Get("aaa"))
}

func TestReverseProxySetClientBehavior(t *testing.T) {
r := server.New(server.WithHostPorts("127.0.0.1:9998"))

r.POST("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) {
ctx.GetConn().Close()
})
proxy, _ := NewSingleHostReverseProxy("http://127.0.0.1:9998/proxy")
proxy.SetClientBehavior(ClientDo())
r.POST("/backend", proxy.ServeHTTP)
go r.Spin()
time.Sleep(time.Second)
cli, _ := client.NewClient()
req := protocol.AcquireRequest()
req.SetMethod("POST")
resp := protocol.AcquireResponse()
req.SetRequestURI("http://127.0.0.1:9998/backend")
cli.Do(context.Background(), req, resp)
if g, e := resp.StatusCode(), http.StatusTeapot; g != e {
t.Errorf("got res.StatusCode %d; expected %d", g, e)
}
}
Loading