From fef1ad48a919f1105d524d8279f8a9283fb4f8f7 Mon Sep 17 00:00:00 2001 From: Erik Zilber Date: Tue, 3 Sep 2024 11:40:25 -0400 Subject: [PATCH] Added new middleware system (#571) --- client.go | 45 ++++++++++++++++--- client_http.go | 4 ++ client_test.go | 117 +++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 157 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index 348e71cc9..28433b766 100644 --- a/client.go +++ b/client.go @@ -136,7 +136,7 @@ type RequestParams struct { // Generic helper to execute HTTP requests using the net/http package // // nolint:unused, funlen, gocognit -func (c *httpClient) doRequest(ctx context.Context, method, url string, params RequestParams, mutators ...func(req *http.Request) error) error { +func (c *httpClient) doRequest(ctx context.Context, method, url string, params RequestParams) error { var ( req *http.Request bodyBuffer *bytes.Buffer @@ -150,7 +150,7 @@ func (c *httpClient) doRequest(ctx context.Context, method, url string, params R return err } - if err = c.applyMutators(req, mutators); err != nil { + if err = c.applyBeforeRequest(req); err != nil { return err } @@ -180,6 +180,12 @@ func (c *httpClient) doRequest(ctx context.Context, method, url string, params R return err } } + + // Apply after-response mutations + if err = c.applyAfterResponse(resp); err != nil { + return err + } + return nil } @@ -252,13 +258,26 @@ func (c *httpClient) createRequest(ctx context.Context, method, url string, para } // nolint:unused -func (c *httpClient) applyMutators(req *http.Request, mutators []func(req *http.Request) error) error { - for _, mutate := range mutators { +func (c *httpClient) applyBeforeRequest(req *http.Request) error { + for _, mutate := range c.onBeforeRequest { if err := mutate(req); err != nil { if c.debug && c.logger != nil { - c.logger.Errorf("failed to mutate request: %v", err) + c.logger.Errorf("failed to mutate before request: %v", err) + } + return fmt.Errorf("failed to mutate before request: %w", err) + } + } + return nil +} + +// nolint:unused +func (c *httpClient) applyAfterResponse(resp *http.Response) error { + for _, mutate := range c.onAfterResponse { + if err := mutate(resp); err != nil { + if c.debug && c.logger != nil { + c.logger.Errorf("failed to mutate after response: %v", err) } - return fmt.Errorf("failed to mutate request: %w", err) + return fmt.Errorf("failed to mutate after response: %w", err) } } return nil @@ -394,6 +413,20 @@ func (c *Client) OnAfterResponse(m func(response *Response) error) { }) } +// nolint:unused +func (c *httpClient) httpOnBeforeRequest(m func(*http.Request) error) *httpClient { + c.onBeforeRequest = append(c.onBeforeRequest, m) + + return c +} + +// nolint:unused +func (c *httpClient) httpOnAfterResponse(m func(*http.Response) error) *httpClient { + c.onAfterResponse = append(c.onAfterResponse, m) + + return c +} + // UseURL parses the individual components of the given API URL and configures the client // accordingly. For example, a valid URL. // For example: diff --git a/client_http.go b/client_http.go index 85058e6cf..7f16362c5 100644 --- a/client_http.go +++ b/client_http.go @@ -49,4 +49,8 @@ type httpClient struct { cachedEntryLock *sync.RWMutex //nolint:unused logger httpLogger + //nolint:unused + onBeforeRequest []func(*http.Request) error + //nolint:unused + onAfterResponse []func(*http.Response) error } diff --git a/client_test.go b/client_test.go index 2b29204d1..93b133f97 100644 --- a/client_test.go +++ b/client_test.go @@ -313,7 +313,42 @@ func TestDoRequest_FailedDecodeResponse(t *testing.T) { } } -func TestDoRequest_MutatorError(t *testing.T) { +func TestDoRequest_BeforeRequestSuccess(t *testing.T) { + var capturedRequest *http.Request + + handler := func(w http.ResponseWriter, r *http.Request) { + capturedRequest = r // Capture the request to inspect it later + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"message":"success"}`)) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + defer server.Close() + + client := &httpClient{ + httpClient: server.Client(), + } + + // Define a mutator that successfully modifies the request + mutator := func(req *http.Request) error { + req.Header.Set("X-Custom-Header", "CustomValue") + return nil + } + + client.httpOnBeforeRequest(mutator) + + err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Check if the header was successfully added to the captured request + if reqHeader := capturedRequest.Header.Get("X-Custom-Header"); reqHeader != "CustomValue" { + t.Fatalf("expected X-Custom-Header to be set to CustomValue, got: %v", reqHeader) + } +} + +func TestDoRequest_BeforeRequestError(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") @@ -330,8 +365,71 @@ func TestDoRequest_MutatorError(t *testing.T) { return errors.New("mutator error") } - err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}, mutator) - expectedErr := "failed to mutate request" + client.httpOnBeforeRequest(mutator) + + err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + expectedErr := "failed to mutate before request" + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Fatalf("expected error %q, got: %v", expectedErr, err) + } +} + +func TestDoRequest_AfterResponseSuccess(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"message":"success"}`)) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + defer server.Close() + + // Create a custom RoundTripper to capture the response + tr := &testRoundTripper{ + Transport: server.Client().Transport, + } + client := &httpClient{ + httpClient: &http.Client{Transport: tr}, + } + + mutator := func(resp *http.Response) error { + resp.Header.Set("X-Modified-Header", "ModifiedValue") + return nil + } + + client.httpOnAfterResponse(mutator) + + err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Check if the header was successfully added to the response + if respHeader := tr.Response.Header.Get("X-Modified-Header"); respHeader != "ModifiedValue" { + t.Fatalf("expected X-Modified-Header to be set to ModifiedValue, got: %v", respHeader) + } +} + +func TestDoRequest_AfterResponseError(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"message":"success"}`)) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + defer server.Close() + + client := &httpClient{ + httpClient: server.Client(), + } + + mutator := func(resp *http.Response) error { + return errors.New("mutator error") + } + + client.httpOnAfterResponse(mutator) + + err := client.doRequest(context.Background(), http.MethodGet, server.URL, RequestParams{}) + expectedErr := "failed to mutate after response" if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected error %q, got: %v", expectedErr, err) } @@ -426,3 +524,16 @@ func removeTimestamps(log string) string { } return strings.Join(filteredLines, "\n") } + +type testRoundTripper struct { + Transport http.RoundTripper + Response *http.Response +} + +func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := t.Transport.RoundTrip(req) + if err == nil { + t.Response = resp + } + return resp, err +}