From 1209633b0fb45ec89725fd63820ffacd2e596bf0 Mon Sep 17 00:00:00 2001 From: knakul853 Date: Tue, 23 Sep 2025 16:58:23 +0530 Subject: [PATCH] client: start keepalive only in StartPolling; stop on StopPolling/Close; tests server: add HTTPServer.Close(ctx); ignore ErrServerClosed; goleak test --- go.mod | 1 + go.sum | 2 + pkg/client/client.go | 70 +++++---- pkg/client/client_test.go | 269 +++++++++++++++++++++++++++++++++ pkg/server/http_server.go | 24 ++- pkg/server/http_server_test.go | 50 +++++- 6 files changed, 377 insertions(+), 39 deletions(-) create mode 100644 pkg/client/client_test.go diff --git a/go.mod b/go.mod index c163deed..87dadbf2 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/rs/xid v1.5.0 github.com/stretchr/testify v1.9.0 github.com/syndtr/goleveldb v1.0.0 + go.uber.org/goleak v1.2.0 go.uber.org/multierr v1.11.0 go.uber.org/ratelimit v0.3.0 go.uber.org/zap v1.25.0 diff --git a/go.sum b/go.sum index a914b6f3..24640cd8 100644 --- a/go.sum +++ b/go.sum @@ -368,6 +368,8 @@ golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo= golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= diff --git a/pkg/client/client.go b/pkg/client/client.go index f76b5d3d..123e2976 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -64,6 +64,7 @@ type Client struct { pubKey *rsa.PublicKey quitChan chan struct{} quitKeepAliveChan chan struct{} + keepAliveInterval time.Duration disableHTTPFallback bool token string correlationIdLength int @@ -157,6 +158,7 @@ func New(options *Options) (*Client, error) { httpClient: httpclient, token: token, disableHTTPFallback: options.DisableHTTPFallback, + keepAliveInterval: options.KeepAliveInterval, correlationIdLength: options.CorrelationIdLength, CorrelationIdNonceLength: options.CorrelationIdNonceLength, } @@ -192,39 +194,41 @@ func New(options *Options) (*Client, error) { } } - // start a keep alive routine - client.quitKeepAliveChan = make(chan struct{}) - if options.KeepAliveInterval > 0 { - ticker := time.NewTicker(options.KeepAliveInterval) - go func() { - for { - // exit if the client is closed - if client.State.Load() == Closed { + return client, nil +} + +// startKeepAlive starts the keepalive goroutine if configured +func (c *Client) startKeepAlive() { + if c.keepAliveInterval <= 0 { + return + } + if c.quitKeepAliveChan != nil { + return + } + c.quitKeepAliveChan = make(chan struct{}) + ticker := time.NewTicker(c.keepAliveInterval) + go func() { + for { + if c.State.Load() == Closed { + return + } + select { + case <-ticker.C: + pubKeyData, err := encodePublicKey(c.pubKey) + if err != nil { return } - select { - case <-ticker.C: - // todo: internal logic needs a complete redesign - pubKeyData, err := encodePublicKey(client.pubKey) - if err != nil { - return - } - // attempts to re-register - server will reject is already existing - registrationRequest, err := encodeRegistrationRequest(pubKeyData, client.secretKey, client.correlationID) - if err != nil { - return - } - // silently fails to re-register if the session is still alive - _ = client.performRegistration(client.serverURL.String(), registrationRequest) - case <-client.quitKeepAliveChan: - ticker.Stop() + registrationRequest, err := encodeRegistrationRequest(pubKeyData, c.secretKey, c.correlationID) + if err != nil { return } + _ = c.performRegistration(c.serverURL.String(), registrationRequest) + case <-c.quitKeepAliveChan: + ticker.Stop() + return } - }() - } - - return client, nil + } + }() } // initializeRSAKeys does the one-time initialization for RSA crypto mechanism @@ -367,6 +371,7 @@ func (c *Client) StartPolling(duration time.Duration, callback InteractionCallba } c.State.Store(Polling) + c.startKeepAlive() ticker := time.NewTicker(duration) c.quitChan = make(chan struct{}) @@ -518,6 +523,10 @@ func (c *Client) StopPolling() error { return errors.New("client is not polling") } close(c.quitChan) + if c.quitKeepAliveChan != nil { + close(c.quitKeepAliveChan) + c.quitKeepAliveChan = nil + } c.State.Store(Idle) @@ -537,7 +546,10 @@ func (c *Client) Close() error { return errors.New("client is already closed") } - close(c.quitKeepAliveChan) + if c.quitKeepAliveChan != nil { + close(c.quitKeepAliveChan) + c.quitKeepAliveChan = nil + } register := server.DeregisterRequest{ CorrelationID: c.correlationID, diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go new file mode 100644 index 00000000..d23e2fb4 --- /dev/null +++ b/pkg/client/client_test.go @@ -0,0 +1,269 @@ +package client + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/projectdiscovery/interactsh/pkg/server" + "github.com/projectdiscovery/retryablehttp-go" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestClient_NoLeak_AfterNew(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "registration successful"}) + }) + mux.HandleFunc("/poll", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(&server.PollResponse{Data: []string{}, Extra: []string{}, AESKey: "", TLDData: []string{}}) + }) + mux.HandleFunc("/deregister", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "deregistration successful"}) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + httpOpts := retryablehttp.DefaultOptionsSpraying + httpOpts.Timeout = 2 * time.Second + httpClient := retryablehttp.NewClient(httpOpts) + + opts := &Options{ + ServerURL: ts.URL, + HTTPClient: httpClient, + KeepAliveInterval: 10 * time.Millisecond, + } + c, err := New(opts) + if err != nil { + t.Fatalf("unexpected error creating client: %v", err) + } + + _ = c + time.Sleep(50 * time.Millisecond) +} + +func TestClient_NoLeaks_WhenStoppedAndClosed(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "registration successful"}) + }) + mux.HandleFunc("/poll", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(&server.PollResponse{Data: []string{}, Extra: []string{}, AESKey: "", TLDData: []string{}}) + }) + mux.HandleFunc("/deregister", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "deregistration successful"}) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + httpOpts := retryablehttp.DefaultOptionsSpraying + httpOpts.Timeout = 2 * time.Second + httpClient := retryablehttp.NewClient(httpOpts) + + opts := &Options{ + ServerURL: ts.URL, + HTTPClient: httpClient, + // disable keepalive here to avoid race clobbering Polling -> Idle during re-register + KeepAliveInterval: 0, + } + c, err := New(opts) + if err != nil { + t.Fatalf("unexpected error creating client: %v", err) + } + + if err := c.StartPolling(10*time.Millisecond, func(_ *server.Interaction) {}); err != nil { + t.Fatalf("unexpected error starting polling: %v", err) + } + + time.Sleep(30 * time.Millisecond) + + if err := c.StopPolling(); err != nil { + t.Fatalf("unexpected error stopping polling: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("unexpected error closing client: %v", err) + } + + time.Sleep(30 * time.Millisecond) +} + +func TestClient_StartPolling_StateErrors(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "registration successful"}) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + httpOpts := retryablehttp.DefaultOptionsSpraying + httpOpts.Timeout = 2 * time.Second + httpClient := retryablehttp.NewClient(httpOpts) + + opts := &Options{ServerURL: ts.URL, HTTPClient: httpClient, KeepAliveInterval: 0} + c, err := New(opts) + if err != nil { + t.Fatalf("unexpected error creating client: %v", err) + } + + c.State.Store(Polling) + if err := c.StartPolling(5*time.Millisecond, func(_ *server.Interaction) {}); err == nil { + t.Fatalf("expected error when already polling") + } + + c.State.Store(Closed) + if err := c.StartPolling(5*time.Millisecond, func(_ *server.Interaction) {}); err == nil { + t.Fatalf("expected error when client is closed") + } +} + +func TestClient_StopPolling_NotPolling(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "registration successful"}) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + httpOpts := retryablehttp.DefaultOptionsSpraying + httpOpts.Timeout = 2 * time.Second + httpClient := retryablehttp.NewClient(httpOpts) + + opts := &Options{ServerURL: ts.URL, HTTPClient: httpClient, KeepAliveInterval: 0} + c, err := New(opts) + if err != nil { + t.Fatalf("unexpected error creating client: %v", err) + } + + if err := c.StopPolling(); err == nil { + t.Fatalf("expected error when not polling") + } +} + +func TestClient_Close_ErrorsAndDeregister(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "registration successful"}) + }) + mux.HandleFunc("/poll", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(&server.PollResponse{Data: []string{}, Extra: []string{}, AESKey: "", TLDData: []string{}}) + }) + mux.HandleFunc("/deregister", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "deregistration successful"}) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + httpOpts := retryablehttp.DefaultOptionsSpraying + httpOpts.Timeout = 2 * time.Second + httpClient := retryablehttp.NewClient(httpOpts) + + opts := &Options{ServerURL: ts.URL, HTTPClient: httpClient, KeepAliveInterval: 0} + c, err := New(opts) + if err != nil { + t.Fatalf("unexpected error creating client: %v", err) + } + + if err := c.StartPolling(5*time.Millisecond, func(_ *server.Interaction) {}); err != nil { + t.Fatalf("unexpected error starting polling: %v", err) + } + if err := c.Close(); err == nil { + _ = c.StopPolling() + t.Fatalf("expected error when closing while polling") + } + + if err := c.StopPolling(); err != nil { + t.Fatalf("unexpected error stopping polling: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("unexpected error closing client: %v", err) + } + if err := c.Close(); err == nil { + t.Fatalf("expected error when closing already closed client") + } +} + +func TestClient_URL_Builds_And_EmptyOnClosed(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "registration successful"}) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + httpOpts := retryablehttp.DefaultOptionsSpraying + httpOpts.Timeout = 2 * time.Second + httpClient := retryablehttp.NewClient(httpOpts) + + opts := &Options{ServerURL: ts.URL, HTTPClient: httpClient, KeepAliveInterval: 0} + c, err := New(opts) + if err != nil { + t.Fatalf("unexpected error creating client: %v", err) + } + + u := c.URL() + if u == "" { + t.Fatalf("expected non-empty URL") + } + + c.State.Store(Closed) + if got := c.URL(); got != "" { + t.Fatalf("expected empty URL when closed, got %q", got) + } +} + +func TestClient_SaveSessionTo_WritesYAML(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "registration successful"}) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + httpOpts := retryablehttp.DefaultOptionsSpraying + httpOpts.Timeout = 2 * time.Second + httpClient := retryablehttp.NewClient(httpOpts) + + opts := &Options{ServerURL: ts.URL, HTTPClient: httpClient, KeepAliveInterval: 0} + c, err := New(opts) + if err != nil { + t.Fatalf("unexpected error creating client: %v", err) + } + + tmp := t.TempDir() + "/sess.yaml" + if err := c.SaveSessionTo(tmp); err != nil { + t.Fatalf("unexpected error saving session: %v", err) + } +} diff --git a/pkg/server/http_server.go b/pkg/server/http_server.go index 20310aeb..17d2b924 100644 --- a/pkg/server/http_server.go +++ b/pkg/server/http_server.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "context" "crypto/tls" "encoding/base64" "fmt" @@ -92,7 +93,9 @@ func (h *HTTPServer) ListenAndServe(tlsConfig *tls.Config, httpAlive, httpsAlive httpsAlive <- true if err := h.tlsserver.ListenAndServeTLS("", ""); err != nil { - gologger.Error().Msgf("Could not serve http on tls: %s\n", err) + if err != http.ErrServerClosed { + gologger.Error().Msgf("Could not serve http on tls: %s\n", err) + } httpsAlive <- false } }() @@ -100,8 +103,25 @@ func (h *HTTPServer) ListenAndServe(tlsConfig *tls.Config, httpAlive, httpsAlive httpAlive <- true if err := h.nontlsserver.ListenAndServe(); err != nil { httpAlive <- false - gologger.Error().Msgf("Could not serve http: %s\n", err) + if err != http.ErrServerClosed { + gologger.Error().Msgf("Could not serve http: %s\n", err) + } + } +} + +// Close gracefully shuts down both HTTP and HTTPS servers +func (h *HTTPServer) Close(ctx context.Context) error { + var err1, err2 error + if h.nontlsserver.Addr != "" { + err1 = h.nontlsserver.Shutdown(ctx) + } + if h.tlsserver.Addr != "" { + err2 = h.tlsserver.Shutdown(ctx) + } + if err1 != nil { + return err1 } + return err2 } func (h *HTTPServer) logger(handler http.Handler) http.HandlerFunc { diff --git a/pkg/server/http_server_test.go b/pkg/server/http_server_test.go index be666ad8..ac125cb3 100644 --- a/pkg/server/http_server_test.go +++ b/pkg/server/http_server_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "io" "net/http" "net/http/httptest" @@ -8,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/require" + "go.uber.org/goleak" ) func TestWriteResponseFromDynamicRequest(t *testing.T) { @@ -38,15 +40,15 @@ func TestWriteResponseFromDynamicRequest(t *testing.T) { require.Equal(t, "this is example body", string(body), "could not get correct result") }) - t.Run("b64_body", func(t *testing.T) { - req := httptest.NewRequest("GET", "http://example.com/?b64_body=dGhpcyBpcyBleGFtcGxlIGJvZHk=", nil) - w := httptest.NewRecorder() - writeResponseFromDynamicRequest(w, req) + t.Run("b64_body", func(t *testing.T) { + req := httptest.NewRequest("GET", "http://example.com/?b64_body=dGhpcyBpcyBleGFtcGxlIGJvZHk=", nil) + w := httptest.NewRecorder() + writeResponseFromDynamicRequest(w, req) - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - require.Equal(t, "this is example body", string(body), "could not get correct result") - }) + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + require.Equal(t, "this is example body", string(body), "could not get correct result") + }) t.Run("header", func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/?header=Key:value&header=Test:Another", nil) w := httptest.NewRecorder() @@ -57,3 +59,35 @@ func TestWriteResponseFromDynamicRequest(t *testing.T) { require.Equal(t, resp.Header.Get("Test"), "Another", "could not get correct result") }) } + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestHTTPServer_NoLeak_WhenClosed(t *testing.T) { + t.Parallel() + + opts := &Options{ + Domains: []string{"example.com"}, + ListenIP: "127.0.0.1", + HttpPort: 0, + HttpsPort: 0, + CorrelationIdLength: 8, + CorrelationIdNonceLength: 6, + } + s, err := NewHTTPServer(opts) + require.NoError(t, err) + + httpAlive := make(chan bool, 1) + httpsAlive := make(chan bool, 1) + go s.ListenAndServe(nil, httpAlive, httpsAlive) + select { + case <-httpAlive: + case <-time.After(200 * time.Millisecond): + t.Fatalf("server did not start") + } + + _ = s.Close(context.Background()) + + time.Sleep(200 * time.Millisecond) +}