diff --git a/ROADMAP.md b/ROADMAP.md index b929b2559..0e29e72f7 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -3,19 +3,20 @@ This document shows key roadmap of Hertz development from the year of 2022 to 20 # New Features: - Community Build - - Support more middlewares for users, like sessions、gzip. - - Support reverse proxy. + - [x] Support more middlewares for users, like sessions、gzip. + - [x] Support reverse proxy. - Support swagger. - Protocol - - Support Websocket. - - Support HTTP2. + - [x] Support Websocket. + - [x] Support HTTP2. + - [x] Support HTTP3. - Service Governance - Support more extension for users. - Performance Optimization - - Improve the server throughput in small packet case. + - Improve the server throughput in tiny packet case. - User Experience Optimization - - Provide good development practices for users to develop with Hertz more easily. - - Improve code generation tool(hz) usability. + - [x] Provide good development practices for users to develop with Hertz more easily. + - [x] Improve code generation tool(hz) usability. All developers are welcome to contribute your extension to [hertz-contrib](https://github.com/hertz-contrib). diff --git a/cmd/hz/generator/custom_files.go b/cmd/hz/generator/custom_files.go index 1efaae07a..1315948f1 100644 --- a/cmd/hz/generator/custom_files.go +++ b/cmd/hz/generator/custom_files.go @@ -192,7 +192,7 @@ func renderImportTpl(tplInfo *Template, data interface{}) ([]string, error) { // renderAppendContent used to render append content for 'update' command func renderAppendContent(tplInfo *Template, renderInfo interface{}) (string, error) { - tpl, err := template.New(tplInfo.Path).Parse(tplInfo.UpdateBehavior.AppendTpl) + tpl, err := template.New(tplInfo.Path).Funcs(funcMap).Parse(tplInfo.UpdateBehavior.AppendTpl) if err != nil { return "", fmt.Errorf("parse append content template(%s) failed, err: %v", tplInfo.Path, err) } diff --git a/cmd/hz/generator/handler.go b/cmd/hz/generator/handler.go index 2a949c5d6..609e1d6cf 100644 --- a/cmd/hz/generator/handler.go +++ b/cmd/hz/generator/handler.go @@ -57,6 +57,13 @@ type Handler struct { Methods []*HttpMethod } +type SingleHandler struct { + *HttpMethod + FilePath string + PackageName string + ProjPackage string +} + type Client struct { Handler ServiceName string @@ -234,13 +241,12 @@ func (pkgGen *HttpPackageGenerator) updateHandler(handler interface{}, handlerTp if handlerSingleTpl == nil { return fmt.Errorf("tpl %s not found", handlerSingleTplName) } - data := make(map[string]string, 5) - data["Comment"] = method.Comment - data["Name"] = method.Name - data["RequestTypeName"] = method.RequestTypeName - data["ReturnTypeName"] = method.ReturnTypeName - data["Serializer"] = method.Serializer - data["OutputDir"] = method.OutputDir + data := SingleHandler{ + HttpMethod: method, + FilePath: handler.(Handler).FilePath, + PackageName: handler.(Handler).PackageName, + ProjPackage: handler.(Handler).ProjPackage, + } handlerFunc := bytes.NewBuffer(nil) err = handlerSingleTpl.Execute(handlerFunc, data) if err != nil { diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 2a66229c7..9071dbc85 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -1515,6 +1515,40 @@ func Test_Issue964(t *testing.T) { } } +type reqSameType struct { + Parent *reqSameType `json:"parent"` + Children []reqSameType `json:"children"` + Foo1 reqSameType2 `json:"foo1"` + A string `json:"a"` +} + +type reqSameType2 struct { + Foo1 *reqSameType `json:"foo1"` +} + +func TestBind_Issue1015(t *testing.T) { + req := newMockRequest(). + SetJSONContentType(). + SetBody([]byte(`{"parent":{"parent":{}, "children":[{},{}], "foo1":{"foo1":{}}}, "children":[{},{}], "a":"asd"}`)) + + var result reqSameType + + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.NotNil(t, result.Parent) + assert.NotNil(t, result.Parent.Parent) + assert.Nil(t, result.Parent.Parent.Parent) + assert.NotNil(t, result.Parent.Children) + assert.DeepEqual(t, 2, len(result.Parent.Children)) + assert.NotNil(t, result.Parent.Foo1.Foo1) + assert.DeepEqual(t, "", result.Parent.A) + assert.DeepEqual(t, 2, len(result.Children)) + assert.Nil(t, result.Foo1.Foo1) + assert.DeepEqual(t, "asd", result.A) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index 0bd13442a..f18a68127 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -80,7 +80,7 @@ func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder continue } - dec, needValidate2, err := getFieldDecoder(el.Field(i), i, []int{}, "", byTag, config) + dec, needValidate2, err := getFieldDecoder(parentInfos{[]reflect.Type{el}, []int{}, ""}, el.Field(i), i, byTag, config) if err != nil { return nil, false, err } @@ -103,7 +103,13 @@ func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder }, needValidate, nil } -func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) { +type parentInfos struct { + Types []reflect.Type + Indexes []int + JSONName string +} + +func getFieldDecoder(pInfo parentInfos, field reflect.StructField, index int, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) { for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } @@ -116,7 +122,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare } // JSONName is like 'a.b.c' for 'required validate' - fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, parentJSONName, config) + fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, pInfo.JSONName, config) if len(fieldTagInfos) == 0 && !config.DisableDefaultTag { fieldTagInfos = getDefaultFieldTags(field) } @@ -126,19 +132,19 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare // customized type decoder has the highest priority if customizedFunc, exist := config.TypeUnmarshalFuncs[field.Type]; exist { - dec, err := getCustomizedFieldDecoder(field, index, fieldTagInfos, parentIdx, customizedFunc, config) + dec, err := getCustomizedFieldDecoder(field, index, fieldTagInfos, pInfo.Indexes, customizedFunc, config) return dec, needValidate, err } // slice/array field decoder if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { - dec, err := getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx, config) + dec, err := getSliceFieldDecoder(field, index, fieldTagInfos, pInfo.Indexes, config) return dec, needValidate, err } // map filed decoder if field.Type.Kind() == reflect.Map { - dec, err := getMapTypeTextDecoder(field, index, fieldTagInfos, parentIdx, config) + dec, err := getMapTypeTextDecoder(field, index, fieldTagInfos, pInfo.Indexes, config) return dec, needValidate, err } @@ -149,11 +155,11 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare // todo: more built-in common struct binding, ex. time... switch el { case reflect.TypeOf(multipart.FileHeader{}): // file binding - dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx, config) + dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, pInfo.Indexes, config) return dec, needValidate, err } if !config.DisableStructFieldResolve { // decode struct type separately - structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, parentIdx, config) + structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, pInfo.Indexes, config) if err != nil { return nil, needValidate, err } @@ -162,17 +168,26 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare } } + // prevent infinite recursion when struct field with the same name as a struct + if hasSameType(pInfo.Types, el) { + return decoders, needValidate, nil + } + + pIdx := pInfo.Indexes for i := 0; i < el.NumField(); i++ { if el.Field(i).PkgPath != "" && !el.Field(i).Anonymous { // ignore unexported field continue } var idxes []int - if len(parentIdx) > 0 { - idxes = append(idxes, parentIdx...) + if len(pInfo.Indexes) > 0 { + idxes = append(idxes, pIdx...) } idxes = append(idxes, index) - dec, needValidate2, err := getFieldDecoder(el.Field(i), i, idxes, newParentJSONName, byTag, config) + pInfo.Indexes = idxes + pInfo.Types = append(pInfo.Types, el) + pInfo.JSONName = newParentJSONName + dec, needValidate2, err := getFieldDecoder(pInfo, el.Field(i), i, byTag, config) needValidate = needValidate || needValidate2 if err != nil { return nil, false, err @@ -186,6 +201,16 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare } // base type decoder - dec, err := getBaseTypeTextDecoder(field, index, fieldTagInfos, parentIdx, config) + dec, err := getBaseTypeTextDecoder(field, index, fieldTagInfos, pInfo.Indexes, config) return dec, needValidate, err } + +// hasSameType determine if the same type is present in the parent-child relationship +func hasSameType(pts []reflect.Type, ft reflect.Type) bool { + for _, pt := range pts { + if reflect.DeepEqual(getElemType(pt), getElemType(ft)) { + return true + } + } + return false +} diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go index 09fbc10cd..f7b04ba25 100644 --- a/pkg/protocol/http1/client.go +++ b/pkg/protocol/http1/client.go @@ -539,7 +539,7 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo begin := req.Options().StartTime() dialTimeout := rc.dialTimeout - if reqTimeout < dialTimeout || dialTimeout == 0 { + if (reqTimeout > 0 && reqTimeout < dialTimeout) || dialTimeout == 0 { dialTimeout = reqTimeout } cc, inPool, err := c.acquireConn(dialTimeout) diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go index 7c45d1406..8c0869dde 100644 --- a/pkg/protocol/http1/client_test.go +++ b/pkg/protocol/http1/client_test.go @@ -82,7 +82,7 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.SlowReadDialer(addr) }), MaxConns: 1, @@ -212,16 +212,16 @@ func testContinueReadResponseBodyStream(t *testing.T, header, body string, maxBo } } -func newSlowConnDialer(dialer func(network, addr string) (network.Conn, error)) network.Dialer { +func newSlowConnDialer(dialer func(network, addr string, timeout time.Duration) (network.Conn, error)) network.Dialer { return &mockDialer{customDialConn: dialer} } type mockDialer struct { - customDialConn func(network, addr string) (network.Conn, error) + customDialConn func(network, addr string, timeout time.Duration) (network.Conn, error) } func (m *mockDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) { - return m.customDialConn(network, address) + return m.customDialConn(network, address, timeout) } func (m *mockDialer) DialTimeout(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn net.Conn, err error) { @@ -244,7 +244,7 @@ func (s *slowDialer) DialConnection(network, address string, timeout time.Durati func TestReadTimeoutPriority(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.SlowReadDialer(addr) }), MaxConns: 1, @@ -274,7 +274,7 @@ func TestReadTimeoutPriority(t *testing.T) { func TestDoNonNilReqResp(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { return &writeErrConn{ Conn: mock.NewConn("HTTP/1.1 400 OK\nContent-Length: 6\n\n123456"), }, @@ -295,7 +295,7 @@ func TestDoNonNilReqResp(t *testing.T) { func TestDoNonNilReqResp1(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { return &writeErrConn{ Conn: mock.NewConn(""), }, @@ -314,7 +314,7 @@ func TestDoNonNilReqResp1(t *testing.T) { func TestWriteTimeoutPriority(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.SlowWriteDialer(addr) }), MaxConns: 1, @@ -376,7 +376,7 @@ func TestStateObserve(t *testing.T) { }{} c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.SlowReadDialer(addr) }), StateObserve: func(hcs config.HostClientState) { @@ -404,7 +404,7 @@ func TestStateObserve(t *testing.T) { func TestCachedTLSConfig(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.SlowReadDialer(addr) }), TLSConfig: &tls.Config{ @@ -426,7 +426,7 @@ func TestRetry(t *testing.T) { var times int32 c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { times++ if times < 3 { return &retryConn{ @@ -486,7 +486,7 @@ func (w retryConn) SetWriteTimeout(t time.Duration) error { func TestConnInPoolRetry(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.NewOneTimeConn("HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789"), nil }), }, @@ -518,7 +518,7 @@ func TestConnInPoolRetry(t *testing.T) { func TestConnNotRetry(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.NewBrokenConn(""), nil }), }, @@ -558,7 +558,7 @@ func TestStreamNoContent(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ - Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) { return conn, nil }), }, @@ -576,3 +576,24 @@ func TestStreamNoContent(t *testing.T) { assert.True(t, conn.isClose) } + +func TestDialTimeout(t *testing.T) { + c := &HostClient{ + ClientOptions: &ClientOptions{ + DialTimeout: time.Second * 10, + Dialer: &mockDialer{ + customDialConn: func(network, addr string, timeout time.Duration) (network.Conn, error) { + assert.DeepEqual(t, time.Second*10, timeout) + return nil, errors.New("test error") + }, + }, + }, + Addr: "foobar", + } + + req := protocol.AcquireRequest() + req.SetRequestURI("http://foobar/baz") + resp := protocol.AcquireResponse() + + c.Do(context.Background(), req, resp) +} diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index c17aa0ff7..a5d33f13d 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -114,7 +114,7 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { defer func() { if s.EnableTrace { - if err != nil && !errors.Is(err, errs.ErrIdleTimeout) && !errors.Is(err, errs.ErrHijacked) { + if shouldRecordInTraceError(err) { ctx.GetTraceInfo().Stats().SetError(err) } // in case of error, we need to trigger all events @@ -460,3 +460,23 @@ func (e *eventStack) pop() func(ti traceinfo.TraceInfo, err error) { *e = (*e)[:len(*e)-1] return last } + +func shouldRecordInTraceError(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, errs.ErrIdleTimeout) { + return false + } + + if errors.Is(err, errs.ErrHijacked) { + return false + } + + if errors.Is(err, errs.ErrShortConnection) { + return false + } + + return true +} diff --git a/pkg/protocol/http1/server_test.go b/pkg/protocol/http1/server_test.go index 5f8f78247..dc8790a97 100644 --- a/pkg/protocol/http1/server_test.go +++ b/pkg/protocol/http1/server_test.go @@ -443,3 +443,13 @@ type mockErrorWriter struct { func (errorWriter *mockErrorWriter) Flush() error { return errors.New("error") } + +func TestShouldRecordInTraceError(t *testing.T) { + assert.False(t, shouldRecordInTraceError(nil)) + assert.False(t, shouldRecordInTraceError(errHijacked)) + assert.False(t, shouldRecordInTraceError(errIdleTimeout)) + assert.False(t, shouldRecordInTraceError(errShortConnection)) + + assert.True(t, shouldRecordInTraceError(errTimeout)) + assert.True(t, shouldRecordInTraceError(errors.New("foo error"))) +} diff --git a/pkg/protocol/uri.go b/pkg/protocol/uri.go index 4fd8788e5..73b5b984c 100644 --- a/pkg/protocol/uri.go +++ b/pkg/protocol/uri.go @@ -49,7 +49,6 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/nocopy" - "github.com/cloudwego/hertz/pkg/common/hlog" ) // AcquireURI returns an empty URI instance from the pool. @@ -388,11 +387,7 @@ func getScheme(rawURL []byte) (scheme, path []byte) { return nil, rawURL } case c == ':': - if i == 0 { - hlog.Errorf("error happened when try to parse the rawURL(%s): missing protocol scheme", rawURL) - return nil, nil - } - return rawURL[:i], rawURL[i+1:] + return checkSchemeWhenCharIsColon(i, rawURL) default: // we have encountered an invalid character, // so there is no valid scheme diff --git a/pkg/protocol/uri_unix.go b/pkg/protocol/uri_unix.go index 0127ceef0..d3726d8aa 100644 --- a/pkg/protocol/uri_unix.go +++ b/pkg/protocol/uri_unix.go @@ -44,6 +44,8 @@ package protocol +import "github.com/cloudwego/hertz/pkg/common/hlog" + func addLeadingSlash(dst, src []byte) []byte { // add leading slash for unix paths if len(src) == 0 || src[0] != '/' { @@ -52,3 +54,13 @@ func addLeadingSlash(dst, src []byte) []byte { return dst } + +// checkSchemeWhenCharIsColon check url begin with : +// Scenarios that handle protocols like "http:" +func checkSchemeWhenCharIsColon(i int, rawURL []byte) (scheme, path []byte) { + if i == 0 { + hlog.Errorf("error happened when try to parse the rawURL(%s): missing protocol scheme", rawURL) + return + } + return rawURL[:i], rawURL[i+1:] +} diff --git a/pkg/protocol/uri_unix_test.go b/pkg/protocol/uri_unix_test.go new file mode 100644 index 000000000..f89cd112c --- /dev/null +++ b/pkg/protocol/uri_unix_test.go @@ -0,0 +1,44 @@ +//go:build !windows +// +build !windows + +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package protocol + +import ( + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +func TestGetScheme(t *testing.T) { + scheme, path := getScheme([]byte("https://foo.com")) + assert.DeepEqual(t, "https", string(scheme)) + assert.DeepEqual(t, "//foo.com", string(path)) + + scheme, path = getScheme([]byte(":")) + assert.DeepEqual(t, "", string(scheme)) + assert.DeepEqual(t, "", string(path)) + + scheme, path = getScheme([]byte("ws://127.0.0.1")) + assert.DeepEqual(t, "ws", string(scheme)) + assert.DeepEqual(t, "//127.0.0.1", string(path)) + + scheme, path = getScheme([]byte("/hertz/demo")) + assert.DeepEqual(t, "", string(scheme)) + assert.DeepEqual(t, "/hertz/demo", string(path)) +} diff --git a/pkg/protocol/uri_windows.go b/pkg/protocol/uri_windows.go index abf13e72f..a40daef98 100644 --- a/pkg/protocol/uri_windows.go +++ b/pkg/protocol/uri_windows.go @@ -44,6 +44,8 @@ package protocol +import "github.com/cloudwego/hertz/pkg/common/hlog" + func addLeadingSlash(dst, src []byte) []byte { // zero length and "C:/" case isDisk := len(src) > 2 && src[1] == ':' @@ -53,3 +55,20 @@ func addLeadingSlash(dst, src []byte) []byte { return dst } + +// checkSchemeWhenCharIsColon check url begin with : +// Scenarios that handle protocols like "http:" +// Add the path to the win file, e.g. "E:\gopath", "E:\". +func checkSchemeWhenCharIsColon(i int, rawURL []byte) (scheme, path []byte) { + if i == 0 { + hlog.Errorf("error happened when trying to parse the rawURL(%s): missing protocol scheme", rawURL) + return + } + + // case :\ + if i+1 < len(rawURL) && rawURL[i+1] == '\\' { + return nil, rawURL + } + + return rawURL[:i], rawURL[i+1:] +} diff --git a/pkg/protocol/uri_windows_test.go b/pkg/protocol/uri_windows_test.go index 507924b97..0ec5e4e78 100644 --- a/pkg/protocol/uri_windows_test.go +++ b/pkg/protocol/uri_windows_test.go @@ -14,7 +14,11 @@ package protocol -import "testing" +import ( + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) func TestURIPathNormalizeIssue86(t *testing.T) { t.Parallel() @@ -26,3 +30,25 @@ func TestURIPathNormalizeIssue86(t *testing.T) { testURIPathNormalize(t, &u, "/..\\..\\..\\..\\..\\", "/") testURIPathNormalize(t, &u, "/..%5c..%5cfoo", "/foo") } + +func TestGetScheme(t *testing.T) { + scheme, path := getScheme([]byte("E:\\file.go")) + assert.DeepEqual(t, "", string(scheme)) + assert.DeepEqual(t, "E:\\file.go", string(path)) + + scheme, path = getScheme([]byte("E:\\")) + assert.DeepEqual(t, "", string(scheme)) + assert.DeepEqual(t, "E:\\", string(path)) + + scheme, path = getScheme([]byte("https://foo.com")) + assert.DeepEqual(t, "https", string(scheme)) + assert.DeepEqual(t, "//foo.com", string(path)) + + scheme, path = getScheme([]byte("://")) + assert.DeepEqual(t, "", string(scheme)) + assert.DeepEqual(t, "", string(path)) + + scheme, path = getScheme([]byte("ws://127.0.0.1")) + assert.DeepEqual(t, "ws", string(scheme)) + assert.DeepEqual(t, "//127.0.0.1", string(path)) +} diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index 37a154bc2..b3e0adb30 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -55,6 +55,7 @@ import ( "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/server/binding" + "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" @@ -63,6 +64,7 @@ import ( "github.com/cloudwego/hertz/pkg/network/standard" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/cloudwego/hertz/pkg/protocol/suite" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -854,3 +856,176 @@ func TestCustomValidator(t *testing.T) { }) performRequest(e, "GET", "/validate?a=2") } + +var errTestDeregsitry = fmt.Errorf("test deregsitry error") + +type mockDeregsitryErr struct{} + +var _ registry.Registry = &mockDeregsitryErr{} + +func (e mockDeregsitryErr) Register(*registry.Info) error { + return nil +} + +func (e mockDeregsitryErr) Deregister(*registry.Info) error { + return errTestDeregsitry +} + +func TestEngineShutdown(t *testing.T) { + defaultTransporter = standard.NewTransporter + mockCtxCallback := func(ctx context.Context) {} + // Test case 1: serve not running error + engine := NewEngine(config.NewOptions(nil)) + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) + defer cancel1() + err := engine.Shutdown(ctx1) + assert.DeepEqual(t, errStatusNotRunning, err) + + // Test case 2: serve successfully running and shutdown + engine = NewEngine(config.NewOptions(nil)) + engine.OnShutdown = []CtxCallback{mockCtxCallback} + go func() { + engine.Run() + }() + // wait for engine to start + time.Sleep(100 * time.Millisecond) + + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) + defer cancel2() + err = engine.Shutdown(ctx2) + assert.Nil(t, err) + assert.DeepEqual(t, statusClosed, atomic.LoadUint32(&engine.status)) + + // Test case 3: serve successfully running and shutdown with deregistry error + engine = NewEngine(config.NewOptions(nil)) + engine.OnShutdown = []CtxCallback{mockCtxCallback} + engine.options.Registry = &mockDeregsitryErr{} + go func() { + engine.Run() + }() + // wait for engine to start + time.Sleep(100 * time.Millisecond) + + ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second) + defer cancel3() + err = engine.Shutdown(ctx3) + assert.DeepEqual(t, errTestDeregsitry, err) + assert.DeepEqual(t, statusShutdown, atomic.LoadUint32(&engine.status)) +} + +type mockStreamer struct{} + +type mockProtocolServer struct{} + +func (s *mockStreamer) Serve(c context.Context, conn network.StreamConn) error { + return nil +} + +func (s *mockProtocolServer) Serve(c context.Context, conn network.Conn) error { + return nil +} + +type mockStreamConn struct { + network.StreamConn + version string +} + +var _ network.StreamConn = &mockStreamConn{} + +func (m *mockStreamConn) GetVersion() uint32 { + return network.Version1 +} + +func TestEngineServeStream(t *testing.T) { + engine := &Engine{ + options: &config.Options{ + ALPN: true, + TLS: &tls.Config{}, + }, + protocolStreamServers: map[string]protocol.StreamServer{ + suite.HTTP3: &mockStreamer{}, + }, + } + + // Test ALPN path + conn := &mockStreamConn{version: suite.HTTP3} + err := engine.ServeStream(context.Background(), conn) + assert.Nil(t, err) + + // Test default path + engine.options.ALPN = false + conn = &mockStreamConn{} + err = engine.ServeStream(context.Background(), conn) + assert.Nil(t, err) + + // Test unsupported protocol + engine.protocolStreamServers = map[string]protocol.StreamServer{} + conn = &mockStreamConn{} + err = engine.ServeStream(context.Background(), conn) + assert.DeepEqual(t, errs.ErrNotSupportProtocol, err) +} + +func TestEngineServe(t *testing.T) { + engine := NewEngine(config.NewOptions(nil)) + engine.protocolServers[suite.HTTP1] = &mockProtocolServer{} + engine.protocolServers[suite.HTTP2] = &mockProtocolServer{} + + // test H2C path + ctx := context.Background() + conn := mock.NewConn("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + engine.options.H2C = true + err := engine.Serve(ctx, conn) + assert.Nil(t, err) + + // test ALPN path + ctx = context.Background() + conn = mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + engine.options.H2C = false + engine.options.ALPN = true + engine.options.TLS = &tls.Config{} + err = engine.Serve(ctx, conn) + assert.Nil(t, err) + + // test HTTP1 path + engine.options.ALPN = false + err = engine.Serve(ctx, conn) + assert.Nil(t, err) +} + +func TestOndata(t *testing.T) { + ctx := context.Background() + engine := NewEngine(config.NewOptions(nil)) + + // test stream conn + streamConn := &mockStreamConn{version: suite.HTTP3} + engine.protocolStreamServers[suite.HTTP3] = &mockStreamer{} + err := engine.onData(ctx, streamConn) + assert.Nil(t, err) + + // test conn + conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + engine.protocolServers[suite.HTTP1] = &mockProtocolServer{} + err = engine.onData(ctx, conn) + assert.Nil(t, err) +} + +func TestAcquireHijackConn(t *testing.T) { + engine := &Engine{ + NoHijackConnPool: false, + } + // test conn pool + conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + hijackConn := engine.acquireHijackConn(conn) + assert.NotNil(t, hijackConn) + assert.NotNil(t, hijackConn.Conn) + assert.DeepEqual(t, engine, hijackConn.e) + assert.DeepEqual(t, conn, hijackConn.Conn) + + // test no conn pool + engine.NoHijackConnPool = true + hijackConn = engine.acquireHijackConn(conn) + assert.NotNil(t, hijackConn) + assert.NotNil(t, hijackConn.Conn) + assert.DeepEqual(t, engine, hijackConn.e) + assert.DeepEqual(t, conn, hijackConn.Conn) +} diff --git a/version.go b/version.go index db0604763..fef300f96 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package hertz // Name and Version info of this framework, used for statistics and debug const ( Name = "Hertz" - Version = "v0.7.2" + Version = "v0.7.3" )