From e37d3082d28243a9175ac47c6dba4cb3eda603f9 Mon Sep 17 00:00:00 2001
From: meizz <meizz.liu@anker-in.com>
Date: Fri, 22 Nov 2024 18:49:54 +0800
Subject: [PATCH 1/3] feat:  add StartWithListener method  to api and rpc
 server

---
 rest/engine.go                  | 26 ++++++++++++++++++
 rest/internal/starter.go        | 46 ++++++++++++++++++++++++++++++++
 rest/internal/starter_test.go   |  9 +++++++
 rest/server.go                  |  8 ++++++
 rest/server_test.go             | 19 +++++++++++++
 zrpc/internal/rpcserver.go      | 30 +++++++++++++++++++++
 zrpc/internal/rpcserver_test.go | 20 ++++++++++++--
 zrpc/internal/server.go         |  2 ++
 zrpc/server.go                  | 11 ++++++++
 zrpc/server_test.go             | 47 +++++++++++++++++++++++++++++++++
 10 files changed, 216 insertions(+), 2 deletions(-)

diff --git a/rest/engine.go b/rest/engine.go
index e57786caf205..c3c7e1f033f5 100644
--- a/rest/engine.go
+++ b/rest/engine.go
@@ -4,6 +4,7 @@ import (
 	"crypto/tls"
 	"errors"
 	"fmt"
+	"net"
 	"net/http"
 	"sort"
 	"time"
@@ -330,6 +331,31 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
 		ng.conf.KeyFile, router, opts...)
 }
 
+func (ng *engine) startWithListener(listener net.Listener, router httpx.Router, opts ...StartOption) error {
+	if err := ng.bindRoutes(router); err != nil {
+		return err
+	}
+
+	// make sure user defined options overwrite default options
+	opts = append([]StartOption{ng.withTimeout()}, opts...)
+
+	if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
+		return internal.StartHttpWithListener(listener, router, opts...)
+	}
+
+	// make sure user defined options overwrite default options
+	opts = append([]StartOption{
+		func(svr *http.Server) {
+			if ng.tlsConfig != nil {
+				svr.TLSConfig = ng.tlsConfig
+			}
+		},
+	}, opts...)
+
+	return internal.StartHttpsWithListener(listener, ng.conf.CertFile,
+		ng.conf.KeyFile, router, opts...)
+}
+
 func (ng *engine) use(middleware Middleware) {
 	ng.middlewares = append(ng.middlewares, middleware)
 }
diff --git a/rest/internal/starter.go b/rest/internal/starter.go
index 174303342b7e..a2533bf57d71 100644
--- a/rest/internal/starter.go
+++ b/rest/internal/starter.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"net"
 	"net/http"
 
 	"github.com/zeromicro/go-zero/core/logx"
@@ -23,6 +24,13 @@ func StartHttp(host string, port int, handler http.Handler, opts ...StartOption)
 	}, opts...)
 }
 
+// StartHttpWithListener starts a http server with listener.
+func StartHttpWithListener(listener net.Listener, handler http.Handler, opts ...StartOption) error {
+	return startWithListener(listener, handler, func(svr *http.Server) error {
+		return svr.Serve(listener)
+	}, opts...)
+}
+
 // StartHttps starts a https server.
 func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler,
 	opts ...StartOption) error {
@@ -32,6 +40,15 @@ func StartHttps(host string, port int, certFile, keyFile string, handler http.Ha
 	}, opts...)
 }
 
+// StartHttpsWithListener starts a https server with listener.
+func StartHttpsWithListener(listener net.Listener, certFile, keyFile string, handler http.Handler,
+	opts ...StartOption) error {
+	return startWithListener(listener, handler, func(svr *http.Server) error {
+		// certFile and keyFile are set in buildHttpsServer
+		return svr.ServeTLS(listener, certFile, keyFile)
+	}, opts...)
+}
+
 func start(host string, port int, handler http.Handler, run func(svr *http.Server) error,
 	opts ...StartOption) (err error) {
 	server := &http.Server{
@@ -59,3 +76,32 @@ func start(host string, port int, handler http.Handler, run func(svr *http.Serve
 	health.AddProbe(healthManager)
 	return run(server)
 }
+
+func startWithListener(listener net.Listener, handler http.Handler, run func(svr *http.Server) error,
+	opts ...StartOption) (err error) {
+
+	server := &http.Server{
+		Addr:    fmt.Sprintf("%s", listener.Addr().String()),
+		Handler: handler,
+	}
+	for _, opt := range opts {
+		opt(server)
+	}
+
+	healthManager := health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, listener.Addr().String()))
+	waitForCalled := proc.AddShutdownListener(func() {
+		healthManager.MarkNotReady()
+		if e := server.Shutdown(context.Background()); e != nil {
+			logx.Error(e)
+		}
+	})
+	defer func() {
+		if errors.Is(err, http.ErrServerClosed) {
+			waitForCalled()
+		}
+	}()
+
+	healthManager.MarkReady()
+	health.AddProbe(healthManager)
+	return run(server)
+}
diff --git a/rest/internal/starter_test.go b/rest/internal/starter_test.go
index a54c215f9f56..b21f837c71b0 100644
--- a/rest/internal/starter_test.go
+++ b/rest/internal/starter_test.go
@@ -34,3 +34,12 @@ func TestStartHttps(t *testing.T) {
 	assert.NotNil(t, err)
 	proc.WrapUp()
 }
+
+func TestStartHttpsWithListener(t *testing.T) {
+	svr := httptest.NewUnstartedServer(http.NotFoundHandler())
+	err := StartHttpsWithListener(svr.Listener, "", "", http.NotFoundHandler(), func(svr *http.Server) {
+		svr.IdleTimeout = 0
+	})
+	assert.NotNil(t, err)
+	proc.WrapUp()
+}
diff --git a/rest/server.go b/rest/server.go
index b1e5487bd8a5..c55419f1f95c 100644
--- a/rest/server.go
+++ b/rest/server.go
@@ -3,6 +3,7 @@ package rest
 import (
 	"crypto/tls"
 	"errors"
+	"net"
 	"net/http"
 	"path"
 	"time"
@@ -121,6 +122,13 @@ func (s *Server) Start() {
 	handleError(s.ngin.start(s.router))
 }
 
+// StartWithListener starts the Server with listener
+// Graceful shutdown is enabled by default.
+// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
+func (s *Server) StartWithListener(listener net.Listener) {
+	handleError(s.ngin.startWithListener(listener, s.router))
+}
+
 // StartWithOpts starts the Server.
 // Graceful shutdown is enabled by default.
 // Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
diff --git a/rest/server_test.go b/rest/server_test.go
index 9a92d58f8203..3a298b48c826 100644
--- a/rest/server_test.go
+++ b/rest/server_test.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"os"
@@ -124,6 +125,24 @@ Port: 0
 			})
 			svr.Stop()
 		}()
+
+		func() {
+			defer func() {
+				p := recover()
+				switch v := p.(type) {
+				case error:
+					assert.Equal(t, "foo", v.Error())
+				default:
+					t.Fail()
+				}
+			}()
+
+			address := fmt.Sprintf("%s:%d", cnf.Host, cnf.Port)
+			listener, err := net.Listen("tcp", address)
+			assert.Nil(t, err)
+			svr.StartWithListener(listener)
+			svr.Stop()
+		}()
 	}
 }
 
diff --git a/zrpc/internal/rpcserver.go b/zrpc/internal/rpcserver.go
index c1302f03ae52..7d94c4cc4b20 100644
--- a/zrpc/internal/rpcserver.go
+++ b/zrpc/internal/rpcserver.go
@@ -78,6 +78,36 @@ func (s *rpcServer) Start(register RegisterFn) error {
 	return server.Serve(lis)
 }
 
+func (s *rpcServer) StartWithListener(listener net.Listener, register RegisterFn) error {
+
+	unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.unaryInterceptors...)
+	streamInterceptorOption := grpc.ChainStreamInterceptor(s.streamInterceptors...)
+
+	options := append(s.options, unaryInterceptorOption, streamInterceptorOption)
+	server := grpc.NewServer(options...)
+	register(server)
+
+	// register the health check service
+	if s.health != nil {
+		grpc_health_v1.RegisterHealthServer(server, s.health)
+		s.health.Resume()
+	}
+	s.healthManager.MarkReady()
+	health.AddProbe(s.healthManager)
+
+	// we need to make sure all others are wrapped up,
+	// so we do graceful stop at shutdown phase instead of wrap up phase
+	waitForCalled := proc.AddShutdownListener(func() {
+		if s.health != nil {
+			s.health.Shutdown()
+		}
+		server.GracefulStop()
+	})
+	defer waitForCalled()
+
+	return server.Serve(listener)
+}
+
 // WithRpcHealth returns a func that sets rpc health switch to a Server.
 func WithRpcHealth(health bool) ServerOption {
 	return func(options *rpcServerOptions) {
diff --git a/zrpc/internal/rpcserver_test.go b/zrpc/internal/rpcserver_test.go
index 696dae68713c..b415f15eba8a 100644
--- a/zrpc/internal/rpcserver_test.go
+++ b/zrpc/internal/rpcserver_test.go
@@ -1,6 +1,7 @@
 package internal
 
 import (
+	"net"
 	"sync"
 	"testing"
 	"time"
@@ -17,8 +18,8 @@ func TestRpcServer(t *testing.T) {
 	var wg, wgDone sync.WaitGroup
 	var grpcServer *grpc.Server
 	var lock sync.Mutex
-	wg.Add(1)
-	wgDone.Add(1)
+	wg.Add(2)
+	wgDone.Add(2)
 	go func() {
 		err := server.Start(func(server *grpc.Server) {
 			lock.Lock()
@@ -31,6 +32,21 @@ func TestRpcServer(t *testing.T) {
 		wgDone.Done()
 	}()
 
+	go func() {
+		listener, err := net.Listen("tcp", "localhost:54322")
+		assert.Nil(t, err)
+		serverWithListener := NewRpcServer(listener.Addr().String(), WithRpcHealth(true))
+		err = serverWithListener.StartWithListener(listener, func(server *grpc.Server) {
+			lock.Lock()
+			mock.RegisterDepositServiceServer(server, new(mock.DepositServer))
+			grpcServer = server
+			lock.Unlock()
+			wg.Done()
+		})
+		assert.Nil(t, err)
+		wgDone.Done()
+	}()
+
 	wg.Wait()
 	time.Sleep(100 * time.Millisecond)
 
diff --git a/zrpc/internal/server.go b/zrpc/internal/server.go
index fc9eea0cbb50..40f4972deb72 100644
--- a/zrpc/internal/server.go
+++ b/zrpc/internal/server.go
@@ -1,6 +1,7 @@
 package internal
 
 import (
+	"net"
 	"time"
 
 	"google.golang.org/grpc"
@@ -21,6 +22,7 @@ type (
 		AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor)
 		SetName(string)
 		Start(register RegisterFn) error
+		StartWithListener(listener net.Listener, register RegisterFn) error
 	}
 
 	baseRpcServer struct {
diff --git a/zrpc/server.go b/zrpc/server.go
index 813fc358d298..d8463514331a 100644
--- a/zrpc/server.go
+++ b/zrpc/server.go
@@ -1,6 +1,7 @@
 package zrpc
 
 import (
+	"net"
 	"time"
 
 	"github.com/zeromicro/go-zero/core/load"
@@ -92,6 +93,16 @@ func (rs *RpcServer) Start() {
 	}
 }
 
+// StartWithListener starts the RpcServer with listener.
+// Graceful shutdown is enabled by default.
+// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
+func (rs *RpcServer) StartWithListener(listener net.Listener) {
+	if err := rs.server.StartWithListener(listener, rs.register); err != nil {
+		logx.Error(err)
+		panic(err)
+	}
+}
+
 // Stop stops the RpcServer.
 func (rs *RpcServer) Stop() {
 	logx.Close()
diff --git a/zrpc/server_test.go b/zrpc/server_test.go
index e42e379a6cd1..03c344c34f1d 100644
--- a/zrpc/server_test.go
+++ b/zrpc/server_test.go
@@ -2,6 +2,7 @@ package zrpc
 
 import (
 	"context"
+	"net"
 	"testing"
 	"time"
 
@@ -56,6 +57,48 @@ func TestServer(t *testing.T) {
 	svr.Stop()
 }
 
+func TestServer_StartWithListener(t *testing.T) {
+	DontLogContentForMethod("foo")
+	SetServerSlowThreshold(time.Second)
+	svr := MustNewServer(RpcServerConf{
+		ServiceConf: service.ServiceConf{
+			Log: logx.LogConf{
+				ServiceName: "foo",
+				Mode:        "console",
+			},
+		},
+		ListenOn:      "localhost:8081",
+		Etcd:          discov.EtcdConf{},
+		Auth:          false,
+		Redis:         redis.RedisKeyConf{},
+		StrictControl: false,
+		Timeout:       0,
+		CpuThreshold:  0,
+		Middlewares: ServerMiddlewaresConf{
+			Trace:      true,
+			Recover:    true,
+			Stat:       true,
+			Prometheus: true,
+			Breaker:    true,
+		},
+		MethodTimeouts: []MethodTimeoutConf{
+			{
+				FullMethod: "/foo",
+				Timeout:    time.Second,
+			},
+		},
+	}, func(server *grpc.Server) {
+	})
+	svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
+	svr.AddUnaryInterceptors(serverinterceptors.UnaryRecoverInterceptor)
+	svr.AddStreamInterceptors(serverinterceptors.StreamRecoverInterceptor)
+
+	listener, err := net.Listen("tcp", "localhost:8081")
+	assert.Nil(t, err)
+	go svr.StartWithListener(listener)
+	svr.Stop()
+}
+
 func TestServerError(t *testing.T) {
 	_, err := NewServer(RpcServerConf{
 		ServiceConf: service.ServiceConf{
@@ -159,6 +202,10 @@ func (m *mockedServer) Start(_ internal.RegisterFn) error {
 	return nil
 }
 
+func (m *mockedServer) StartWithListener(_ net.Listener, _ internal.RegisterFn) error {
+	return nil
+}
+
 func Test_setupUnaryInterceptors(t *testing.T) {
 	tests := []struct {
 		name string

From 58342b7c2c4ea28aa6ba32c5841f4f8cced33b81 Mon Sep 17 00:00:00 2001
From: meizz <meizz.liu@anker-in.com>
Date: Mon, 25 Nov 2024 15:25:47 +0800
Subject: [PATCH 2/3] feat: Supplement test cases to improve coverage

---
 rest/engine_test.go           | 38 +++++++++++++++++++++++++++++++++++
 rest/internal/starter_test.go |  9 +++++++++
 2 files changed, 47 insertions(+)

diff --git a/rest/engine_test.go b/rest/engine_test.go
index 4f86d2173efd..12fc7c4fedd9 100644
--- a/rest/engine_test.go
+++ b/rest/engine_test.go
@@ -5,6 +5,7 @@ import (
 	"crypto/tls"
 	"errors"
 	"fmt"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"os"
@@ -429,6 +430,43 @@ func TestEngine_start(t *testing.T) {
 	})
 }
 
+func TestEngine_startWithListener(t *testing.T) {
+	logx.Disable()
+
+	t.Run("http", func(t *testing.T) {
+		ng := newEngine(RestConf{
+			Host: "localhost",
+			Port: -1,
+		})
+		address := fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port)
+		listener, err := net.Listen("tcp", address)
+		assert.Error(t, err)
+		if listener != nil {
+			assert.Error(t, ng.startWithListener(listener, router.NewRouter()))
+		} else {
+			assert.Error(t, ng.start(router.NewRouter()))
+		}
+	})
+
+	t.Run("https", func(t *testing.T) {
+		ng := newEngine(RestConf{
+			Host:     "localhost",
+			Port:     -1,
+			CertFile: "foo",
+			KeyFile:  "bar",
+		})
+		ng.tlsConfig = &tls.Config{}
+		address := fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port)
+		listener, err := net.Listen("tcp", address)
+		assert.Error(t, err)
+		if listener != nil {
+			assert.Error(t, ng.startWithListener(listener, router.NewRouter()))
+		} else {
+			assert.Error(t, ng.start(router.NewRouter()))
+		}
+	})
+}
+
 type mockedRouter struct {
 }
 
diff --git a/rest/internal/starter_test.go b/rest/internal/starter_test.go
index b21f837c71b0..9e59ddfaa8f4 100644
--- a/rest/internal/starter_test.go
+++ b/rest/internal/starter_test.go
@@ -35,6 +35,15 @@ func TestStartHttps(t *testing.T) {
 	proc.WrapUp()
 }
 
+func TestStartHttpWithListener(t *testing.T) {
+	svr := httptest.NewUnstartedServer(http.NotFoundHandler())
+	err := StartHttpWithListener(svr.Listener, http.NotFoundHandler(), func(svr *http.Server) {
+		svr.IdleTimeout = 0
+	})
+	assert.NotNil(t, err)
+	proc.WrapUp()
+}
+
 func TestStartHttpsWithListener(t *testing.T) {
 	svr := httptest.NewUnstartedServer(http.NotFoundHandler())
 	err := StartHttpsWithListener(svr.Listener, "", "", http.NotFoundHandler(), func(svr *http.Server) {

From cdc80d8cbd47bc5c27f5b2ce3152d27a70db060b Mon Sep 17 00:00:00 2001
From: meizz <meizz.liu@anker-in.com>
Date: Wed, 4 Dec 2024 12:21:33 +0800
Subject: [PATCH 3/3] fix: startWithListener register etcd

---
 zrpc/internal/rpcpubserver.go      | 9 +++++++++
 zrpc/internal/rpcpubserver_test.go | 4 ++++
 2 files changed, 13 insertions(+)

diff --git a/zrpc/internal/rpcpubserver.go b/zrpc/internal/rpcpubserver.go
index 70b481323d92..4c795ad7c271 100644
--- a/zrpc/internal/rpcpubserver.go
+++ b/zrpc/internal/rpcpubserver.go
@@ -1,6 +1,7 @@
 package internal
 
 import (
+	"net"
 	"os"
 	"strings"
 
@@ -53,6 +54,14 @@ func (s keepAliveServer) Start(fn RegisterFn) error {
 	return s.Server.Start(fn)
 }
 
+func (s keepAliveServer) StartWithListener(listener net.Listener, fn RegisterFn) error {
+	if err := s.registerEtcd(); err != nil {
+		return err
+	}
+
+	return s.Server.StartWithListener(listener, fn)
+}
+
 func figureOutListenOn(listenOn string) string {
 	fields := strings.Split(listenOn, ":")
 	if len(fields) == 0 {
diff --git a/zrpc/internal/rpcpubserver_test.go b/zrpc/internal/rpcpubserver_test.go
index cc36e4653357..ec009820736c 100644
--- a/zrpc/internal/rpcpubserver_test.go
+++ b/zrpc/internal/rpcpubserver_test.go
@@ -18,6 +18,10 @@ func TestNewRpcPubServer(t *testing.T) {
 	assert.NotPanics(t, func() {
 		s.Start(nil)
 	})
+
+	assert.NotPanics(t, func() {
+		s.StartWithListener(nil, nil)
+	})
 }
 
 func TestFigureOutListenOn(t *testing.T) {