From ccf25f01c1129a6c9607862e1ac8ad5c9b8dae2d Mon Sep 17 00:00:00 2001 From: Changyu Moon <121847433+window9u@users.noreply.github.com> Date: Thu, 6 Feb 2025 18:20:12 +0900 Subject: [PATCH] Add Request Timing Controls for Auth Webhook Client (#1142) Introduces MinWaitInterval and RequestTimeout configurations to enhance authentication request handling. These settings prevent connection timeouts, memory leaks, and enable fine-tuning of retry logic through exponential backoff parameters. Default values are optimized for common usage patterns with 100ms MinWaitInterval and 3s RequestTimeout. --------- Co-authored-by: Youngteac Hong --- cmd/yorkie/server.go | 20 ++- pkg/webhook/client.go | 124 ++++++++------- pkg/webhook/client_test.go | 291 +++++++++++++++++++++++++--------- server/backend/backend.go | 33 +++- server/backend/config.go | 44 +++++ server/backend/config_test.go | 14 +- server/config.go | 10 ++ server/config.sample.yml | 10 +- server/config_test.go | 8 + server/packs/packs_test.go | 22 +-- server/rpc/auth/webhook.go | 76 ++++++--- server/rpc/server_test.go | 22 +-- test/complex/main_test.go | 1 + test/helper/helper.go | 4 + 14 files changed, 484 insertions(+), 195 deletions(-) diff --git a/cmd/yorkie/server.go b/cmd/yorkie/server.go index 4e5ef04e4..6d9714f9f 100644 --- a/cmd/yorkie/server.go +++ b/cmd/yorkie/server.go @@ -48,6 +48,8 @@ var ( mongoPingTimeout time.Duration authWebhookMaxWaitInterval time.Duration + authWebhookMinWaitInterval time.Duration + authWebhookRequestTimeout time.Duration authWebhookCacheTTL time.Duration projectCacheTTL time.Duration @@ -64,6 +66,8 @@ func newServerCmd() *cobra.Command { conf.Backend.ClientDeactivateThreshold = clientDeactivateThreshold conf.Backend.AuthWebhookMaxWaitInterval = authWebhookMaxWaitInterval.String() + conf.Backend.AuthWebhookMinWaitInterval = authWebhookMinWaitInterval.String() + conf.Backend.AuthWebhookRequestTimeout = authWebhookRequestTimeout.String() conf.Backend.AuthWebhookCacheTTL = authWebhookCacheTTL.String() conf.Backend.ProjectCacheTTL = projectCacheTTL.String() @@ -295,17 +299,29 @@ func init() { server.DefaultSnapshotDisableGC, "Whether to disable garbage collection of snapshots.", ) + cmd.Flags().DurationVar( + &authWebhookRequestTimeout, + "auth-webhook-request-timeout", + server.DefaultAuthWebhookRequestTimeout, + "Timeout for each authorization webhook request.", + ) cmd.Flags().Uint64Var( &conf.Backend.AuthWebhookMaxRetries, "auth-webhook-max-retries", server.DefaultAuthWebhookMaxRetries, - "Maximum number of retries for an authorization webhook.", + "Maximum number of retries for authorization webhook.", + ) + cmd.Flags().DurationVar( + &authWebhookMinWaitInterval, + "auth-webhook-min-wait-interval", + server.DefaultAuthWebhookMinWaitInterval, + "Minimum wait interval between retries(exponential backoff).", ) cmd.Flags().DurationVar( &authWebhookMaxWaitInterval, "auth-webhook-max-wait-interval", server.DefaultAuthWebhookMaxWaitInterval, - "Maximum wait interval for authorization webhook.", + "Maximum wait interval between retries(exponential backoff).", ) cmd.Flags().IntVar( &conf.Backend.AuthWebhookCacheSize, diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 9fb76bde4..f91afa34b 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -31,8 +31,6 @@ import ( "syscall" "time" - "github.com/yorkie-team/yorkie/pkg/cache" - "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/server/logging" ) @@ -47,54 +45,53 @@ var ( ErrWebhookTimeout = errors.New("webhook timeout") ) -// Options are the options for the webhook client. +// Options are the options for the webhook httpClient. type Options struct { - CacheKeyPrefix string - CacheTTL time.Duration - + RequestTimeout time.Duration MaxRetries uint64 + MinWaitInterval time.Duration MaxWaitInterval time.Duration - - HMACKey string } -// Client is a client for the webhook. +// Client is a httpClient for the webhook. type Client[Req any, Res any] struct { - cache *cache.LRUExpireCache[string, types.Pair[int, *Res]] - url string - options Options + httpClient *http.Client + options Options } // NewClient creates a new instance of Client. func NewClient[Req any, Res any]( - url string, - Cache *cache.LRUExpireCache[string, types.Pair[int, *Res]], options Options, ) *Client[Req, Res] { return &Client[Req, Res]{ - url: url, - cache: Cache, + httpClient: &http.Client{ + Timeout: options.RequestTimeout, + }, options: options, } } // Send sends the given request to the webhook. -func (c *Client[Req, Res]) Send(ctx context.Context, req Req) (*Res, int, error) { - body, err := json.Marshal(req) +func (c *Client[Req, Res]) Send( + ctx context.Context, + url, hmacKey string, + body []byte, +) (*Res, int, error) { + signature, err := createSignature(body, hmacKey) if err != nil { - return nil, 0, fmt.Errorf("marshal webhook request: %w", err) - } - - cacheKey := c.options.CacheKeyPrefix + ":" + string(body) - if entry, ok := c.cache.Get(cacheKey); ok { - return entry.Second, entry.First, nil + return nil, 0, fmt.Errorf("create signature: %w", err) } var res Res status, err := c.withExponentialBackoff(ctx, func() (int, error) { - resp, err := c.post("application/json", body) + req, err := c.buildRequest(ctx, url, signature, body) + if err != nil { + return 0, fmt.Errorf("build request: %w", err) + } + + resp, err := c.httpClient.Do(req) if err != nil { - return 0, fmt.Errorf("post to webhook: %w", err) + return 0, fmt.Errorf("do request: %w", err) } defer func() { if err := resp.Body.Close(); err != nil { @@ -103,9 +100,7 @@ func (c *Client[Req, Res]) Send(ctx context.Context, req Req) (*Res, int, error) } }() - if resp.StatusCode != http.StatusOK && - resp.StatusCode != http.StatusUnauthorized && - resp.StatusCode != http.StatusForbidden { + if !isExpectedStatus(resp.StatusCode) { return resp.StatusCode, ErrUnexpectedStatusCode } @@ -119,55 +114,58 @@ func (c *Client[Req, Res]) Send(ctx context.Context, req Req) (*Res, int, error) return nil, status, err } - // TODO(hackerwins): We should consider caching the response of Unauthorized as well. - if status != http.StatusUnauthorized { - c.cache.Add(cacheKey, types.Pair[int, *Res]{First: status, Second: &res}, c.options.CacheTTL) - } - return &res, status, nil } -// post sends an HTTP POST request with HMAC-SHA256 signature headers. -// If key is empty, post sends an HTTP POST without signature. -func (c *Client[Req, Res]) post(contentType string, body []byte) (*http.Response, error) { - req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(body)) +// buildRequest creates a new HTTP POST request with the appropriate headers. +func (c *Client[Req, Res]) buildRequest( + ctx context.Context, + url, hmac string, + body []byte, +) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(body)) if err != nil { - return nil, fmt.Errorf("create HTTP request: %w", err) + return nil, fmt.Errorf("create POST request with context: %w", err) } - req.Header.Set("Content-Type", contentType) - if c.options.HMACKey != "" { - mac := hmac.New(sha256.New, []byte(c.options.HMACKey)) - if _, err := mac.Write(body); err != nil { - return nil, fmt.Errorf("write HMAC body: %w", err) - } - signature := mac.Sum(nil) - signatureHex := hex.EncodeToString(signature) // Convert to hex string - req.Header.Set("X-Signature-256", fmt.Sprintf("sha256=%s", signatureHex)) - } + req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("send to %s: %w", c.url, err) // Wrapped with context + if hmac != "" { + req.Header.Set("X-Signature-256", hmac) } - return resp, nil + return req, nil +} + +// createSignature sets the HMAC signature header for the request. +func createSignature(data []byte, hmacKey string) (string, error) { + if hmacKey == "" { + return "", nil + } + mac := hmac.New(sha256.New, []byte(hmacKey)) + if _, err := mac.Write(data); err != nil { + return "", fmt.Errorf("write HMAC body: %w", err) + } + signatureHex := hex.EncodeToString(mac.Sum(nil)) + return fmt.Sprintf("sha256=%s", signatureHex), nil } func (c *Client[Req, Res]) withExponentialBackoff(ctx context.Context, webhookFn func() (int, error)) (int, error) { var retries uint64 var statusCode int + var err error + for retries <= c.options.MaxRetries { - statusCode, err := webhookFn() + statusCode, err = webhookFn() if !shouldRetry(statusCode, err) { - if err == ErrUnexpectedStatusCode { + if errors.Is(err, ErrUnexpectedStatusCode) { return statusCode, fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode) } return statusCode, err } - waitBeforeRetry := waitInterval(retries, c.options.MaxWaitInterval) + waitBeforeRetry := waitInterval(retries, c.options.MinWaitInterval, c.options.MaxWaitInterval) select { case <-ctx.Done(): @@ -181,9 +179,9 @@ func (c *Client[Req, Res]) withExponentialBackoff(ctx context.Context, webhookFn return statusCode, fmt.Errorf("unexpected status code from webhook %d: %w", statusCode, ErrWebhookTimeout) } -// waitInterval returns the interval of given retries. (2^retries * 100) milliseconds. -func waitInterval(retries uint64, maxWaitInterval time.Duration) time.Duration { - interval := time.Duration(math.Pow(2, float64(retries))) * 100 * time.Millisecond +// waitInterval returns the interval of given retries. (2^retries * minWaitInterval) . +func waitInterval(retries uint64, minWaitInterval, maxWaitInterval time.Duration) time.Duration { + interval := time.Duration(math.Pow(2, float64(retries))) * minWaitInterval if maxWaitInterval < interval { return maxWaitInterval } @@ -197,7 +195,7 @@ func shouldRetry(statusCode int, err error) bool { // If the connection is reset, we should retry. var errno syscall.Errno if errors.As(err, &errno) { - return errno == syscall.ECONNRESET + return errors.Is(errno, syscall.ECONNRESET) } return statusCode == http.StatusInternalServerError || @@ -205,3 +203,9 @@ func shouldRetry(statusCode int, err error) bool { statusCode == http.StatusGatewayTimeout || statusCode == http.StatusTooManyRequests } + +func isExpectedStatus(statusCode int) bool { + return statusCode == http.StatusOK || + statusCode == http.StatusUnauthorized || + statusCode == http.StatusForbidden +} diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go index 61f7004eb..b0afa1a34 100644 --- a/pkg/webhook/client_test.go +++ b/pkg/webhook/client_test.go @@ -11,13 +11,12 @@ import ( "io" "net/http" "net/http/httptest" + "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/yorkie-team/yorkie/pkg/cache" - "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/pkg/webhook" ) @@ -31,6 +30,7 @@ type testResponse struct { Greeting string `json:"greeting"` } +// verifySignature verifies that the HMAC signature in the header matches the expected value. func verifySignature(signatureHeader, secret string, body []byte) error { mac := hmac.New(sha256.New, []byte(secret)) mac.Write(body) @@ -39,115 +39,252 @@ func verifySignature(signatureHeader, secret string, body []byte) error { if !hmac.Equal([]byte(signatureHeader), []byte(expectedSigHeader)) { return errors.New("signature validation failed") } - return nil } -func TestHMAC(t *testing.T) { - const secretKey = "my-secret-key" - const wrongKey = "wrong-key" - reqData := testRequest{Name: "HMAC Tester"} - resData := testResponse{Greeting: "HMAC OK"} - - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// newHMACTestServer creates a new httptest.Server that verifies the HMAC signature. +// It returns a valid JSON response if the signature is correct. +func newHMACTestServer(t *testing.T, validSecret string, responseData testResponse) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { signatureHeader := r.Header.Get("X-Signature-256") if signatureHeader == "" { - w.WriteHeader(http.StatusUnauthorized) + http.Error(w, "unauthorized", http.StatusUnauthorized) return } + bodyBytes, err := io.ReadAll(r.Body) if err != nil { http.Error(w, "bad request", http.StatusBadRequest) return } - if err := verifySignature(signatureHeader, secretKey, bodyBytes); err != nil { - w.WriteHeader(http.StatusForbidden) + if err := verifySignature(signatureHeader, validSecret, bodyBytes); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + assert.NoError(t, json.NewEncoder(w).Encode(responseData)) + })) +} + +func newRetryServer(t *testing.T, replyAfter int, responseData testResponse) *httptest.Server { + var requestCount int32 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := int(atomic.AddInt32(&requestCount, 1)) + if count < replyAfter { + w.WriteHeader(http.StatusServiceUnavailable) return } + + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - assert.NoError(t, json.NewEncoder(w).Encode(resData)) + assert.NoError(t, json.NewEncoder(w).Encode(responseData)) + })) +} + +func newDelayServer(t *testing.T, delayTime time.Duration, responseData testResponse) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), delayTime) + defer cancel() + + select { + case <-ctx.Done(): + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + assert.NoError(t, json.NewEncoder(w).Encode(responseData)) + } })) +} +func TestHMAC(t *testing.T) { + const validSecret = "my-secret-key" + const invalidSecret = "wrong-key" + expectedResponse := testResponse{Greeting: "HMAC OK"} + + testServer := newHMACTestServer(t, validSecret, expectedResponse) defer testServer.Close() - t.Run("webhook client with valid HMAC key test", func(t *testing.T) { - client := webhook.NewClient[testRequest, testResponse]( - testServer.URL, - cache.NewLRUExpireCache[string, types.Pair[int, *testResponse]](100), - webhook.Options{ - CacheKeyPrefix: "testPrefix-hmac", - CacheTTL: 5 * time.Second, - MaxRetries: 0, - MaxWaitInterval: 200 * time.Millisecond, - HMACKey: secretKey, - }, - ) - - ctx := context.Background() - resp, statusCode, err := client.Send(ctx, reqData) + client := webhook.NewClient[testRequest, testResponse](webhook.Options{ + MaxRetries: 0, + MinWaitInterval: 0, + MaxWaitInterval: 0, + RequestTimeout: 1 * time.Second, + }) + + t.Run("valid HMAC key test", func(t *testing.T) { + reqPayload := testRequest{Name: "ValidHMAC"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := client.Send(context.Background(), testServer.URL, validSecret, body) assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) assert.NotNil(t, resp) - assert.Equal(t, resData.Greeting, resp.Greeting) - }) - - t.Run("webhook client with invalid HMAC key test", func(t *testing.T) { - client := webhook.NewClient[testRequest, testResponse]( - testServer.URL, - cache.NewLRUExpireCache[string, types.Pair[int, *testResponse]](100), - webhook.Options{ - CacheKeyPrefix: "testPrefix-hmac", - CacheTTL: 5 * time.Second, - MaxRetries: 0, - MaxWaitInterval: 200 * time.Millisecond, - HMACKey: wrongKey, - }, - ) - - ctx := context.Background() - resp, statusCode, err := client.Send(ctx, reqData) + assert.Equal(t, expectedResponse.Greeting, resp.Greeting) + }) + + t.Run("invalid HMAC key test", func(t *testing.T) { + reqPayload := testRequest{Name: "InvalidHMAC"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := client.Send(context.Background(), testServer.URL, invalidSecret, body) assert.Error(t, err) + // The server responds with 403 Forbidden if the signature is invalid. assert.Equal(t, http.StatusForbidden, statusCode) assert.Nil(t, resp) }) - t.Run("webhook client without HMAC key test", func(t *testing.T) { - client := webhook.NewClient[testRequest]( - testServer.URL, - cache.NewLRUExpireCache[string, types.Pair[int, *testResponse]](100), - webhook.Options{ - CacheKeyPrefix: "testPrefix-hmac", - CacheTTL: 5 * time.Second, - MaxRetries: 0, - MaxWaitInterval: 200 * time.Millisecond, - }, - ) - - ctx := context.Background() - resp, statusCode, err := client.Send(ctx, reqData) + t.Run("missing HMAC key test", func(t *testing.T) { + reqPayload := testRequest{Name: "MissingHMAC"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := client.Send(context.Background(), testServer.URL, "", body) assert.Error(t, err) + // The server responds with 401 Unauthorized if no signature header is provided. assert.Equal(t, http.StatusUnauthorized, statusCode) assert.Nil(t, resp) }) - t.Run("webhook client with empty body test", func(t *testing.T) { - client := webhook.NewClient[testRequest]( - testServer.URL, - cache.NewLRUExpireCache[string, types.Pair[int, *testResponse]](100), - webhook.Options{ - CacheKeyPrefix: "testPrefix-hmac", - CacheTTL: 5 * time.Second, - MaxRetries: 0, - MaxWaitInterval: 200 * time.Millisecond, - HMACKey: secretKey, - }, - ) - - ctx := context.Background() - resp, statusCode, err := client.Send(ctx, testRequest{}) + t.Run("empty body test", func(t *testing.T) { + reqPayload := testRequest{} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := client.Send(context.Background(), testServer.URL, validSecret, body) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, statusCode) + assert.NotNil(t, resp) + assert.Equal(t, expectedResponse.Greeting, resp.Greeting) + }) +} + +func TestBackoff(t *testing.T) { + replyAfter := 4 + reachableRetries := replyAfter - 1 + unreachableRetries := replyAfter - 2 + expectedResponse := testResponse{Greeting: "retry succeed"} + server := newRetryServer(t, replyAfter, expectedResponse) + defer server.Close() + + reachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + RequestTimeout: 10 * time.Millisecond, + MaxRetries: uint64(reachableRetries), + MinWaitInterval: 1 * time.Millisecond, + MaxWaitInterval: 5 * time.Millisecond, + }) + + unreachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + RequestTimeout: 10 * time.Millisecond, + MaxRetries: uint64(unreachableRetries), + MinWaitInterval: 1 * time.Millisecond, + MaxWaitInterval: 5 * time.Millisecond, + }) + + t.Run("retry fail test", func(t *testing.T) { + reqPayload := testRequest{Name: "retry fails"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := unreachableClient.Send(context.Background(), server.URL, "", body) + assert.Error(t, err) + assert.ErrorContains(t, err, webhook.ErrWebhookTimeout.Error()) + assert.Equal(t, http.StatusServiceUnavailable, statusCode) + assert.Nil(t, resp) + }) + + t.Run("retry succeed timeout", func(t *testing.T) { + reqPayload := testRequest{Name: "retry succeed"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := reachableClient.Send(context.Background(), server.URL, "", body) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, statusCode) + assert.NotNil(t, resp) + assert.Equal(t, expectedResponse.Greeting, resp.Greeting) + }) +} + +func TestRequestTimeout(t *testing.T) { + delayTime := 10 * time.Millisecond + expectedResponse := testResponse{Greeting: "hello"} + server := newDelayServer(t, delayTime, expectedResponse) + defer server.Close() + + reachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + RequestTimeout: 15 * time.Millisecond, + MaxRetries: 0, + MinWaitInterval: 0, + MaxWaitInterval: 0, + }) + + unreachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + RequestTimeout: 5 * time.Millisecond, + MaxRetries: 0, + MinWaitInterval: 0, + MaxWaitInterval: 0, + }) + + t.Run("request succeed after timeout", func(t *testing.T) { + reqPayload := testRequest{Name: "TimeoutTest"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := reachableClient.Send(context.Background(), server.URL, "", body) assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) assert.NotNil(t, resp) - assert.Equal(t, resData.Greeting, resp.Greeting) + assert.Equal(t, expectedResponse.Greeting, resp.Greeting) + }) + + t.Run("request fails with timeout test", func(t *testing.T) { + reqPayload := testRequest{Name: "TimeoutTest"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := unreachableClient.Send(context.Background(), server.URL, "", body) + assert.Error(t, err) + assert.Equal(t, 0, statusCode) + assert.Nil(t, resp) + }) +} + +func TestErrorHandling(t *testing.T) { + expectedResponse := testResponse{Greeting: "hello"} + server := newRetryServer(t, 2, expectedResponse) + defer server.Close() + + unreachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + RequestTimeout: 50 * time.Millisecond, + MaxRetries: 0, + MinWaitInterval: 0, + MaxWaitInterval: 0, + }) + + t.Run("request fails with context done test", func(t *testing.T) { + reqPayload := testRequest{Name: "ContextDone"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + resp, statusCode, err := unreachableClient.Send(ctx, server.URL, "", body) + assert.Error(t, err) + assert.Equal(t, http.StatusServiceUnavailable, statusCode) + assert.Nil(t, resp) + }) + + t.Run("request fails with unreachable url test", func(t *testing.T) { + reqPayload := testRequest{Name: "invalidURL"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := unreachableClient.Send(context.Background(), "", "", body) + assert.Error(t, err) + assert.Equal(t, 0, statusCode) + assert.Nil(t, resp) }) } diff --git a/server/backend/backend.go b/server/backend/backend.go index cf9bd4bf3..9d1a95839 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -27,6 +27,7 @@ import ( "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/pkg/cache" pkgtypes "github.com/yorkie-team/yorkie/pkg/types" + "github.com/yorkie-team/yorkie/pkg/webhook" "github.com/yorkie-team/yorkie/server/backend/background" "github.com/yorkie-team/yorkie/server/backend/database" memdb "github.com/yorkie-team/yorkie/server/backend/database/memory" @@ -48,6 +49,9 @@ type Backend struct { int, *types.AuthWebhookResponse, ]] + // WebhookClient is used to send auth webhook. + WebhookClient *webhook.Client[types.AuthWebhookRequest, types.AuthWebhookResponse] + // PubSub is used to publish/subscribe events to/from clients. PubSub *pubsub.PubSub // Locker is used to lock/unlock resources. @@ -82,18 +86,28 @@ func New( conf.Hostname = hostname } - // 02. Create in-memory cache, pubsub, and locker. - cache := cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]]( + // 02. Create the webhook webhookCache and client. + webhookCache := cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]]( conf.AuthWebhookCacheSize, ) + webhookClient := webhook.NewClient[types.AuthWebhookRequest, types.AuthWebhookResponse]( + webhook.Options{ + MaxRetries: conf.AuthWebhookMaxRetries, + MinWaitInterval: conf.ParseAuthWebhookMinWaitInterval(), + MaxWaitInterval: conf.ParseAuthWebhookMaxWaitInterval(), + RequestTimeout: conf.ParseAuthWebhookRequestTimeout(), + }, + ) + + // 03. Create pubsub, and locker. locker := sync.New() pubsub := pubsub.New() - // 03. Create the background instance. The background instance is used to + // 04. Create the background instance. The background instance is used to // manage background tasks. bg := background.New(metrics) - // 04. Create the database instance. If the MongoDB configuration is given, + // 05. Create the database instance. If the MongoDB configuration is given, // create a MongoDB instance. Otherwise, create a memory database instance. var err error var db database.Database @@ -141,10 +155,13 @@ func New( ) return &Backend{ - Config: conf, - WebhookCache: cache, - Locker: locker, - PubSub: pubsub, + Config: conf, + + WebhookCache: webhookCache, + WebhookClient: webhookClient, + + Locker: locker, + PubSub: pubsub, Metrics: metrics, DB: db, diff --git a/server/backend/config.go b/server/backend/config.go index 89399ddb7..9f1767c6c 100644 --- a/server/backend/config.go +++ b/server/backend/config.go @@ -64,6 +64,12 @@ type Config struct { // AuthWebhookMaxWaitInterval is the max interval that waits before retrying the authorization webhook. AuthWebhookMaxWaitInterval string `yaml:"AuthWebhookMaxWaitInterval"` + // AuthWebhookMinWaitInterval is the min interval that waits before retrying the authorization webhook. + AuthWebhookMinWaitInterval string `yaml:"AuthWebhookMinWaitInterval"` + + // AuthWebhookRequestTimeout is the max waiting time per auth webhook request + AuthWebhookRequestTimeout string `yaml:"AuthWebhookRequestTimeout"` + // AuthWebhookCacheSize is the cache size of the authorization webhook. AuthWebhookCacheSize int `yaml:"AuthWebhookCacheSize"` @@ -101,6 +107,22 @@ func (c *Config) Validate() error { ) } + if _, err := time.ParseDuration(c.AuthWebhookMinWaitInterval); err != nil { + return fmt.Errorf( + `invalid argument "%s" for "--auth-webhook-min-wait-interval" flag: %w`, + c.AuthWebhookMinWaitInterval, + err, + ) + } + + if _, err := time.ParseDuration(c.AuthWebhookRequestTimeout); err != nil { + return fmt.Errorf( + `invalid argument "%s" for "--auth-webhook-request-timeout" flag: %w`, + c.AuthWebhookRequestTimeout, + err, + ) + } + if _, err := time.ParseDuration(c.AuthWebhookCacheTTL); err != nil { return fmt.Errorf( `invalid argument "%s" for "--auth-webhook-cache-ttl" flag: %w`, @@ -142,6 +164,28 @@ func (c *Config) ParseAuthWebhookMaxWaitInterval() time.Duration { return result } +// ParseAuthWebhookMinWaitInterval returns min wait interval. +func (c *Config) ParseAuthWebhookMinWaitInterval() time.Duration { + result, err := time.ParseDuration(c.AuthWebhookMinWaitInterval) + if err != nil { + fmt.Fprintln(os.Stderr, "parse auth webhook min wait interval: %w", err) + os.Exit(1) + } + + return result +} + +// ParseAuthWebhookRequestTimeout returns request timeout. +func (c *Config) ParseAuthWebhookRequestTimeout() time.Duration { + result, err := time.ParseDuration(c.AuthWebhookRequestTimeout) + if err != nil { + fmt.Fprintln(os.Stderr, "parse auth webhook request timeout: %w", err) + os.Exit(1) + } + + return result +} + // ParseAuthWebhookCacheTTL returns TTL for authorized cache. func (c *Config) ParseAuthWebhookCacheTTL() time.Duration { result, err := time.ParseDuration(c.AuthWebhookCacheTTL) diff --git a/server/backend/config_test.go b/server/backend/config_test.go index c4b8247eb..a7b3a34ba 100644 --- a/server/backend/config_test.go +++ b/server/backend/config_test.go @@ -29,6 +29,8 @@ func TestConfig(t *testing.T) { validConf := backend.Config{ ClientDeactivateThreshold: "1h", AuthWebhookMaxWaitInterval: "0ms", + AuthWebhookMinWaitInterval: "0ms", + AuthWebhookRequestTimeout: "0ms", AuthWebhookCacheTTL: "10s", ProjectCacheTTL: "10m", } @@ -43,11 +45,19 @@ func TestConfig(t *testing.T) { assert.Error(t, conf2.Validate()) conf3 := validConf - conf3.AuthWebhookCacheTTL = "s" + conf3.AuthWebhookMinWaitInterval = "3" assert.Error(t, conf3.Validate()) conf4 := validConf - conf4.ProjectCacheTTL = "10 minutes" + conf4.AuthWebhookRequestTimeout = "1" assert.Error(t, conf4.Validate()) + + conf5 := validConf + conf5.AuthWebhookCacheTTL = "s" + assert.Error(t, conf5.Validate()) + + conf6 := validConf + conf6.ProjectCacheTTL = "10 minutes" + assert.Error(t, conf6.Validate()) }) } diff --git a/server/config.go b/server/config.go index fd8589dbd..ec63a8958 100644 --- a/server/config.go +++ b/server/config.go @@ -60,8 +60,10 @@ const ( DefaultSnapshotWithPurgingChanges = false DefaultSnapshotDisableGC = false + DefaultAuthWebhookRequestTimeout = 3 * time.Second DefaultAuthWebhookMaxRetries = 10 DefaultAuthWebhookMaxWaitInterval = 3000 * time.Millisecond + DefaultAuthWebhookMinWaitInterval = 100 * time.Millisecond DefaultAuthWebhookCacheSize = 5000 DefaultAuthWebhookCacheTTL = 10 * time.Second DefaultProjectCacheSize = 256 @@ -185,6 +187,14 @@ func (c *Config) ensureDefaultValue() { c.Backend.AuthWebhookMaxWaitInterval = DefaultAuthWebhookMaxWaitInterval.String() } + if c.Backend.AuthWebhookMinWaitInterval == "" { + c.Backend.AuthWebhookMinWaitInterval = DefaultAuthWebhookMinWaitInterval.String() + } + + if c.Backend.AuthWebhookRequestTimeout == "" { + c.Backend.AuthWebhookRequestTimeout = DefaultAuthWebhookRequestTimeout.String() + } + if c.Backend.AuthWebhookCacheTTL == "" { c.Backend.AuthWebhookCacheTTL = DefaultAuthWebhookCacheTTL.String() } diff --git a/server/config.sample.yml b/server/config.sample.yml index b7f124b11..71b09fb5d 100644 --- a/server/config.sample.yml +++ b/server/config.sample.yml @@ -65,10 +65,16 @@ Backend: # AuthWebhookMethods is the list of methods to use for authorization. AuthWebhookMethods: [] - # AuthWebhookMaxRetries is the max count that retries the authorization webhook. + # AuthWebhookRequestTimeout is the timeout for each authorization webhook request. + AuthWebhookRequestTimeout: "3s" + + # AuthWebhookMaxRetries is the max number of retries for the authorization webhook. AuthWebhookMaxRetries: 10 - # AuthWebhookMaxWaitInterval is the max interval that waits before retrying the authorization webhook. + # AuthWebhookMinWaitInterval is the minimum wait interval between retries(exponential backoff). + AuthWebhookMinWaitInterval: "100ms" + + # AuthWebhookMaxWaitInterval is the maximum wait interval between retries(exponential backoff). AuthWebhookMaxWaitInterval: "3s" # AuthWebhookCacheTTL is the TTL value to set when caching the authorized result. diff --git a/server/config_test.go b/server/config_test.go index 5ca5008a9..d4b83ee41 100644 --- a/server/config_test.go +++ b/server/config_test.go @@ -70,6 +70,14 @@ func TestNewConfigFromFile(t *testing.T) { assert.NoError(t, err) assert.Equal(t, authWebhookMaxWaitInterval, server.DefaultAuthWebhookMaxWaitInterval) + authWebhookMinWaitInterval, err := time.ParseDuration(conf.Backend.AuthWebhookMinWaitInterval) + assert.NoError(t, err) + assert.Equal(t, authWebhookMinWaitInterval, server.DefaultAuthWebhookMinWaitInterval) + + authWebhookRequestTimeout, err := time.ParseDuration(conf.Backend.AuthWebhookRequestTimeout) + assert.NoError(t, err) + assert.Equal(t, authWebhookRequestTimeout, server.DefaultAuthWebhookRequestTimeout) + authWebhookCacheTTL, err := time.ParseDuration(conf.Backend.AuthWebhookCacheTTL) assert.NoError(t, err) assert.Equal(t, authWebhookCacheTTL, server.DefaultAuthWebhookCacheTTL) diff --git a/server/packs/packs_test.go b/server/packs/packs_test.go index fa7413cfd..b5ac2bede 100644 --- a/server/packs/packs_test.go +++ b/server/packs/packs_test.go @@ -94,15 +94,19 @@ func TestMain(m *testing.M) { testBackend, err = backend.New( &backend.Config{ - AdminUser: helper.AdminUser, - AdminPassword: helper.AdminPassword, - UseDefaultProject: helper.UseDefaultProject, - ClientDeactivateThreshold: helper.ClientDeactivateThreshold, - SnapshotThreshold: helper.SnapshotThreshold, - AuthWebhookCacheSize: helper.AuthWebhookSize, - ProjectCacheSize: helper.ProjectCacheSize, - ProjectCacheTTL: helper.ProjectCacheTTL.String(), - AdminTokenDuration: helper.AdminTokenDuration, + AdminUser: helper.AdminUser, + AdminPassword: helper.AdminPassword, + UseDefaultProject: helper.UseDefaultProject, + ClientDeactivateThreshold: helper.ClientDeactivateThreshold, + SnapshotThreshold: helper.SnapshotThreshold, + AuthWebhookCacheSize: helper.AuthWebhookSize, + AuthWebhookCacheTTL: helper.AuthWebhookCacheTTL.String(), + AuthWebhookMaxWaitInterval: helper.AuthWebhookMaxWaitInterval.String(), + AuthWebhookMinWaitInterval: helper.AuthWebhookMinWaitInterval.String(), + AuthWebhookRequestTimeout: helper.AuthWebhookRequestTimeout.String(), + ProjectCacheSize: helper.ProjectCacheSize, + ProjectCacheTTL: helper.ProjectCacheTTL.String(), + AdminTokenDuration: helper.AdminTokenDuration, }, &mongo.Config{ ConnectionURI: helper.MongoConnectionURI, YorkieDatabase: helper.TestDBName(), diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index 218390f78..cb27e5df3 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -18,12 +18,14 @@ package auth import ( "context" + "encoding/json" "errors" "fmt" "net/http" "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/internal/metaerrors" + pkgtypes "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/pkg/webhook" "github.com/yorkie-team/yorkie/server/backend" ) @@ -44,41 +46,63 @@ func verifyAccess( token string, accessInfo *types.AccessInfo, ) error { - cli := webhook.NewClient[types.AuthWebhookRequest]( - prj.AuthWebhookURL, - be.WebhookCache, - webhook.Options{ - CacheKeyPrefix: prj.PublicKey + ":auth", - CacheTTL: be.Config.ParseAuthWebhookCacheTTL(), - MaxRetries: be.Config.AuthWebhookMaxRetries, - MaxWaitInterval: be.Config.ParseAuthWebhookMaxWaitInterval(), - }, - ) - - res, status, err := cli.Send(ctx, types.AuthWebhookRequest{ + req := types.AuthWebhookRequest{ Token: token, Method: accessInfo.Method, Attributes: accessInfo.Attributes, - }) + } + + body, err := json.Marshal(req) if err != nil { - return fmt.Errorf("send to webhook: %w", err) + return fmt.Errorf("marshal webhook request: %w", err) } - if status == http.StatusOK && res.Allowed { - return nil + cacheKey := generateCacheKey(prj.PublicKey, body) + if entry, ok := be.WebhookCache.Get(cacheKey); ok { + return handleWebhookResponse(entry.First, entry.Second) } - if status == http.StatusForbidden && !res.Allowed { - return metaerrors.New( - ErrPermissionDenied, - map[string]string{"reason": res.Reason}, - ) + + res, status, err := be.WebhookClient.Send( + ctx, + prj.AuthWebhookURL, + "", + body, + ) + if err != nil { + return fmt.Errorf("send to webhook: %w", err) } - if status == http.StatusUnauthorized && !res.Allowed { - return metaerrors.New( - ErrUnauthenticated, - map[string]string{"reason": res.Reason}, + + // TODO(hackerwins): We should consider caching the response of Unauthorized as well. + if status != http.StatusUnauthorized { + be.WebhookCache.Add( + cacheKey, + pkgtypes.Pair[int, *types.AuthWebhookResponse]{First: status, Second: res}, + be.Config.ParseAuthWebhookCacheTTL(), ) } - return fmt.Errorf("%d: %w", status, webhook.ErrUnexpectedResponse) + return handleWebhookResponse(status, res) +} + +// generateCacheKey creates a unique key for caching webhook responses. +func generateCacheKey(publicKey string, body []byte) string { + return fmt.Sprintf("%s:auth:%s", publicKey, body) +} + +// handleWebhookResponse processes the webhook response and returns an error if necessary. +func handleWebhookResponse(status int, res *types.AuthWebhookResponse) error { + if res == nil { + return fmt.Errorf("nil response for status %d: %w", status, webhook.ErrUnexpectedResponse) + } + + switch { + case status == http.StatusOK && res.Allowed: + return nil + case status == http.StatusForbidden && !res.Allowed: + return metaerrors.New(ErrPermissionDenied, map[string]string{"reason": res.Reason}) + case status == http.StatusUnauthorized && !res.Allowed: + return metaerrors.New(ErrUnauthenticated, map[string]string{"reason": res.Reason}) + default: + return fmt.Errorf("%d: %w", status, webhook.ErrUnexpectedResponse) + } } diff --git a/server/rpc/server_test.go b/server/rpc/server_test.go index 205733b35..a0175417b 100644 --- a/server/rpc/server_test.go +++ b/server/rpc/server_test.go @@ -68,15 +68,19 @@ func TestMain(m *testing.M) { } be, err := backend.New(&backend.Config{ - AdminUser: helper.AdminUser, - AdminPassword: helper.AdminPassword, - UseDefaultProject: helper.UseDefaultProject, - ClientDeactivateThreshold: helper.ClientDeactivateThreshold, - SnapshotThreshold: helper.SnapshotThreshold, - AuthWebhookCacheSize: helper.AuthWebhookSize, - ProjectCacheSize: helper.ProjectCacheSize, - ProjectCacheTTL: helper.ProjectCacheTTL.String(), - AdminTokenDuration: helper.AdminTokenDuration, + AdminUser: helper.AdminUser, + AdminPassword: helper.AdminPassword, + UseDefaultProject: helper.UseDefaultProject, + ClientDeactivateThreshold: helper.ClientDeactivateThreshold, + SnapshotThreshold: helper.SnapshotThreshold, + AuthWebhookCacheSize: helper.AuthWebhookSize, + AuthWebhookCacheTTL: helper.AuthWebhookCacheTTL.String(), + AuthWebhookMaxWaitInterval: helper.AuthWebhookMaxWaitInterval.String(), + AuthWebhookMinWaitInterval: helper.AuthWebhookMinWaitInterval.String(), + AuthWebhookRequestTimeout: helper.AuthWebhookRequestTimeout.String(), + ProjectCacheSize: helper.ProjectCacheSize, + ProjectCacheTTL: helper.ProjectCacheTTL.String(), + AdminTokenDuration: helper.AdminTokenDuration, }, &mongo.Config{ ConnectionURI: helper.MongoConnectionURI, YorkieDatabase: helper.TestDBName(), diff --git a/test/complex/main_test.go b/test/complex/main_test.go index 26b530b99..cd96b3886 100644 --- a/test/complex/main_test.go +++ b/test/complex/main_test.go @@ -79,6 +79,7 @@ func TestMain(m *testing.M) { ClientDeactivateThreshold: helper.ClientDeactivateThreshold, SnapshotThreshold: helper.SnapshotThreshold, AuthWebhookCacheSize: helper.AuthWebhookSize, + AuthWebhookCacheTTL: helper.AuthWebhookCacheTTL.String(), ProjectCacheSize: helper.ProjectCacheSize, ProjectCacheTTL: helper.ProjectCacheTTL.String(), AdminTokenDuration: helper.AdminTokenDuration, diff --git a/test/helper/helper.go b/test/helper/helper.go index 9068b5302..4137c7321 100644 --- a/test/helper/helper.go +++ b/test/helper/helper.go @@ -75,6 +75,8 @@ var ( SnapshotThreshold = int64(10) SnapshotWithPurgingChanges = false AuthWebhookMaxWaitInterval = 3 * gotime.Millisecond + AuthWebhookMinWaitInterval = 3 * gotime.Millisecond + AuthWebhookRequestTimeout = 100 * gotime.Millisecond AuthWebhookSize = 100 AuthWebhookCacheTTL = 10 * gotime.Second ProjectCacheSize = 256 @@ -264,6 +266,8 @@ func TestConfig() *server.Config { SnapshotThreshold: SnapshotThreshold, SnapshotWithPurgingChanges: SnapshotWithPurgingChanges, AuthWebhookMaxWaitInterval: AuthWebhookMaxWaitInterval.String(), + AuthWebhookMinWaitInterval: AuthWebhookMinWaitInterval.String(), + AuthWebhookRequestTimeout: AuthWebhookRequestTimeout.String(), AuthWebhookCacheSize: AuthWebhookSize, AuthWebhookCacheTTL: AuthWebhookCacheTTL.String(), ProjectCacheSize: ProjectCacheSize,