From c8a566c671592817bd73b37d1dc4d47e04051c50 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Wed, 10 Sep 2025 14:07:29 +0530 Subject: [PATCH 01/21] initial implementation for distributed and local storage rate limiters --- pkg/gofr/container/container.go | 3 + pkg/gofr/service/circuit_breaker_test.go | 8 + pkg/gofr/service/logger.go | 1 + pkg/gofr/service/metrics.go | 2 + pkg/gofr/service/mock_metrics.go | 35 ++ pkg/gofr/service/rate_limiter.go | 84 +++++ pkg/gofr/service/rate_limiter_distributed.go | 272 ++++++++++++++ pkg/gofr/service/rate_limiter_local.go | 374 +++++++++++++++++++ 8 files changed, 779 insertions(+) create mode 100644 pkg/gofr/service/rate_limiter.go create mode 100644 pkg/gofr/service/rate_limiter_distributed.go create mode 100644 pkg/gofr/service/rate_limiter_local.go diff --git a/pkg/gofr/container/container.go b/pkg/gofr/container/container.go index 6184a6419a..bf4a3c19f4 100644 --- a/pkg/gofr/container/container.go +++ b/pkg/gofr/container/container.go @@ -259,6 +259,9 @@ func (c *Container) registerFrameworkMetrics() { httpBuckets := []float64{.001, .003, .005, .01, .02, .03, .05, .1, .2, .3, .5, .75, 1, 2, 3, 5, 10, 30} c.Metrics().NewHistogram("app_http_response", "Response time of HTTP requests in seconds.", httpBuckets...) c.Metrics().NewHistogram("app_http_service_response", "Response time of HTTP service requests in seconds.", httpBuckets...) + c.Metrics().NewCounter("app_rate_limiter_requests_total", "Total rate limiter requests") + c.Metrics().NewCounter("app_rate_limiter_denied_total", "Total denied requests") + c.Metrics().NewGauge("app_rate_limiter_tokens_available", "Current tokens available") } { // Redis metrics diff --git a/pkg/gofr/service/circuit_breaker_test.go b/pkg/gofr/service/circuit_breaker_test.go index dfb2868531..58089acfdc 100644 --- a/pkg/gofr/service/circuit_breaker_test.go +++ b/pkg/gofr/service/circuit_breaker_test.go @@ -592,6 +592,14 @@ func (m *mockMetrics) RecordHistogram(ctx context.Context, name string, value fl m.Called(ctx, name, value, labels) } +func (m *mockMetrics) IncrementCounter(ctx context.Context, name string, labels ...string) { + m.Called(ctx, name, labels) +} + +func (m *mockMetrics) SetGauge(name string, value float64, labels ...string) { + m.Called(name, value, labels) +} + type customTransport struct { } diff --git a/pkg/gofr/service/logger.go b/pkg/gofr/service/logger.go index 04cf64917b..c94b72caee 100644 --- a/pkg/gofr/service/logger.go +++ b/pkg/gofr/service/logger.go @@ -8,6 +8,7 @@ import ( type Logger interface { Log(args ...any) + Debug(args ...any) } type Log struct { diff --git a/pkg/gofr/service/metrics.go b/pkg/gofr/service/metrics.go index ae48c21494..64753d9650 100644 --- a/pkg/gofr/service/metrics.go +++ b/pkg/gofr/service/metrics.go @@ -3,5 +3,7 @@ package service import "context" type Metrics interface { + IncrementCounter(ctx context.Context, name string, labels ...string) + SetGauge(name string, value float64, labels ...string) RecordHistogram(ctx context.Context, name string, value float64, labels ...string) } diff --git a/pkg/gofr/service/mock_metrics.go b/pkg/gofr/service/mock_metrics.go index df74ef0bf2..0923f05eb2 100644 --- a/pkg/gofr/service/mock_metrics.go +++ b/pkg/gofr/service/mock_metrics.go @@ -20,6 +20,7 @@ import ( type MockMetrics struct { ctrl *gomock.Controller recorder *MockMetricsMockRecorder + isgomock struct{} } // MockMetricsMockRecorder is the mock recorder for MockMetrics. @@ -39,6 +40,23 @@ func (m *MockMetrics) EXPECT() *MockMetricsMockRecorder { return m.recorder } +// IncrementCounter mocks base method. +func (m *MockMetrics) IncrementCounter(ctx context.Context, name string, labels ...string) { + m.ctrl.T.Helper() + varargs := []any{ctx, name} + for _, a := range labels { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "IncrementCounter", varargs...) +} + +// IncrementCounter indicates an expected call of IncrementCounter. +func (mr *MockMetricsMockRecorder) IncrementCounter(ctx, name any, labels ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, name}, labels...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementCounter", reflect.TypeOf((*MockMetrics)(nil).IncrementCounter), varargs...) +} + // RecordHistogram mocks base method. func (m *MockMetrics) RecordHistogram(ctx context.Context, name string, value float64, labels ...string) { m.ctrl.T.Helper() @@ -55,3 +73,20 @@ func (mr *MockMetricsMockRecorder) RecordHistogram(ctx, name, value any, labels varargs := append([]any{ctx, name, value}, labels...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordHistogram", reflect.TypeOf((*MockMetrics)(nil).RecordHistogram), varargs...) } + +// SetGauge mocks base method. +func (m *MockMetrics) SetGauge(name string, value float64, labels ...string) { + m.ctrl.T.Helper() + varargs := []any{name, value} + for _, a := range labels { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "SetGauge", varargs...) +} + +// SetGauge indicates an expected call of SetGauge. +func (mr *MockMetricsMockRecorder) SetGauge(name, value any, labels ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{name, value}, labels...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGauge", reflect.TypeOf((*MockMetrics)(nil).SetGauge), varargs...) +} diff --git a/pkg/gofr/service/rate_limiter.go b/pkg/gofr/service/rate_limiter.go new file mode 100644 index 0000000000..90ef48c8ff --- /dev/null +++ b/pkg/gofr/service/rate_limiter.go @@ -0,0 +1,84 @@ +package service + +import ( + "errors" + "fmt" + "net/http" + "time" + + gofrRedis "gofr.dev/pkg/gofr/datasource/redis" +) + +var ( + errInvalidRequestRate = errors.New("requestsPerSecond must be greater than 0") + errInvalidBurstSize = errors.New("burst must be greater than 0") + errInvalidRedisResultType = errors.New("unexpected Redis result type") +) + +// RateLimiterConfig with custom keying support. +type RateLimiterConfig struct { + RequestsPerSecond float64 // Token refill rate (must be > 0) + Burst int // Maximum burst capacity (must be > 0) + KeyFunc func(*http.Request) string // Optional custom key extraction + RedisClient *gofrRedis.Redis `json:"-"` // Optional Redis for distributed limiting +} + +// Default key function extracts scheme://host +func defaultKeyFunc(req *http.Request) string { + if req == nil || req.URL == nil { + return "unknown" + } + + return req.URL.Scheme + "://" + req.URL.Host +} + +// Validate checks if the configuration is valid. +func (config *RateLimiterConfig) Validate() error { + if config.RequestsPerSecond <= 0 { + return fmt.Errorf("%w: %f", errInvalidRequestRate, config.RequestsPerSecond) + } + + if config.Burst <= 0 { + return fmt.Errorf("%w: %d", errInvalidBurstSize, config.Burst) + } + + // Set default key function if not provided. + if config.KeyFunc == nil { + config.KeyFunc = defaultKeyFunc + } + + return nil +} + +// AddOption implements the Options interface. +func (config *RateLimiterConfig) AddOption(h HTTP) HTTP { + if err := config.Validate(); err != nil { + if httpSvc, ok := h.(*httpService); ok { + httpSvc.Logger.Log("Invalid rate limiter config, disabling rate limiting", "error", err) + } + + return h + } + + // Choose implementation based on Redis client availability. + if config.RedisClient != nil { + return NewDistributedRateLimiter(*config, h) + } + + // Log warning for local rate limiting. + if httpSvc, ok := h.(*httpService); ok { + httpSvc.Logger.Log("Using local rate limiting - not suitable for multi-instance deployments") + } + + return NewLocalRateLimiter(*config, h) +} + +// RateLimitError represents a rate limiting error. +type RateLimitError struct { + ServiceKey string + RetryAfter time.Duration +} + +func (e *RateLimitError) Error() string { + return fmt.Sprintf("rate limit exceeded for service: %s, retry after: %v", e.ServiceKey, e.RetryAfter) +} diff --git a/pkg/gofr/service/rate_limiter_distributed.go b/pkg/gofr/service/rate_limiter_distributed.go new file mode 100644 index 0000000000..812c874780 --- /dev/null +++ b/pkg/gofr/service/rate_limiter_distributed.go @@ -0,0 +1,272 @@ +package service + +import ( + "context" + "fmt" + "net/http" + "strconv" + "time" + + gofrRedis "gofr.dev/pkg/gofr/datasource/redis" +) + +// tokenBucketScript is a Lua script for atomic token bucket rate limiting in Redis. +// +//nolint:gosec // This is a Lua script for Redis, not credentials +const tokenBucketScript = ` +local key = KEYS[1] +local burst = tonumber(ARGV[1]) +local refill_rate = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) + +-- Fetch bucket +local bucket = redis.call("HMGET", key, "tokens", "last_refill") +local tokens = tonumber(bucket[1]) +local last_refill = tonumber(bucket[2]) + +if tokens == nil then + tokens = burst + last_refill = now +end + +-- Refill tokens +local delta = math.max(0, now - last_refill) +local new_tokens = math.min(burst, tokens + delta * refill_rate) + +-- Try to consume one token +if new_tokens < 1 then + -- not enough tokens + redis.call("HMSET", key, "tokens", new_tokens, "last_refill", now) + redis.call("EXPIRE", key, 600) + return 0 +else + redis.call("HMSET", key, "tokens", new_tokens - 1, "last_refill", now) + redis.call("EXPIRE", key, 600) + return 1 +end +` + +// distributedRateLimiter with metrics support. +type distributedRateLimiter struct { + config RateLimiterConfig + redisClient gofrRedis.Redis + logger Logger + metrics Metrics + HTTP +} + +func NewDistributedRateLimiter(config RateLimiterConfig, h HTTP) HTTP { + httpSvc := h.(*httpService) + + rl := &distributedRateLimiter{ + config: config, + redisClient: *config.RedisClient, + logger: httpSvc.Logger, + metrics: httpSvc.Metrics, + HTTP: h, + } + + return rl +} + +// Safe Redis result parsing. +func toInt64(i any) (int64, error) { + switch v := i.(type) { + case int64: + return v, nil + case int: + return int64(v), nil + case float64: + return int64(v), nil + case string: + return strconv.ParseInt(v, 10, 64) + default: + return 0, fmt.Errorf("%w: %T", errInvalidRedisResultType, i) + } +} + +// checkRateLimit for distributed version with metrics. +func (rl *distributedRateLimiter) checkRateLimit(req *http.Request) error { + serviceKey := rl.config.KeyFunc(req) + now := time.Now().UnixNano() + + cmd := rl.redisClient.Eval( + context.Background(), + tokenBucketScript, + []string{"gofr:ratelimit:" + serviceKey}, + rl.config.Burst, + rl.config.RequestsPerSecond, + now, + ) + + result, err := cmd.Result() + if err != nil { + rl.logger.Log("Redis rate limiter error, allowing request", "error", err) + // Record error metric + if rl.metrics != nil { + rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_errors_total", "service", serviceKey, "type", "redis_error") + } + + return nil // Fail open + } + + // ✅ FIX: Safe result parsing + resultArray, ok := result.([]any) + if !ok || len(resultArray) != 2 { + rl.logger.Log("Invalid Redis response format, allowing request") + return nil // Fail open + } + + allowed, err := toInt64(resultArray[0]) + if err != nil { + rl.logger.Log("Invalid Redis allowed value, allowing request", "error", err) + return nil + } + + retryAfterMs, err := toInt64(resultArray[1]) + if err != nil { + rl.logger.Log("Invalid Redis retry-after value, allowing request", "error", err) + return nil + } + + // ✅ FIX: Record metrics for distributed limiter + if rl.metrics != nil { + rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_requests_total", "service", serviceKey) + + if allowed != 1 { + rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_denied_total", "service", serviceKey) + } + } + + if allowed != 1 { + retryAfter := time.Duration(retryAfterMs) * time.Millisecond + rl.logger.Debug("Distributed rate limit exceeded", + "service", serviceKey, + "retry_after", retryAfter) + + return &RateLimitError{ + ServiceKey: serviceKey, + RetryAfter: retryAfter, + } + } + + return nil +} + +// GetWithHeaders performs rate-limited HTTP GET request with custom headers. +func (rl *distributedRateLimiter) GetWithHeaders(ctx context.Context, path string, queryParams map[string]any, + headers map[string]string) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.GetWithHeaders(ctx, path, queryParams, headers) +} + +// PostWithHeaders performs rate-limited HTTP POST request with custom headers. +func (rl *distributedRateLimiter) PostWithHeaders(ctx context.Context, path string, queryParams map[string]any, + body []byte, headers map[string]string) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.PostWithHeaders(ctx, path, queryParams, body, headers) +} + +// PatchWithHeaders performs rate-limited HTTP PATCH request with custom headers. +func (rl *distributedRateLimiter) PatchWithHeaders(ctx context.Context, path string, queryParams map[string]any, + body []byte, headers map[string]string) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.PatchWithHeaders(ctx, path, queryParams, body, headers) +} + +// PutWithHeaders performs rate-limited HTTP PUT request with custom headers. +func (rl *distributedRateLimiter) PutWithHeaders(ctx context.Context, path string, queryParams map[string]any, + body []byte, headers map[string]string) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.PutWithHeaders(ctx, path, queryParams, body, headers) +} + +// DeleteWithHeaders performs rate-limited HTTP DELETE request with custom headers. +func (rl *distributedRateLimiter) DeleteWithHeaders(ctx context.Context, path string, body []byte, + headers map[string]string) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.DeleteWithHeaders(ctx, path, body, headers) +} + +// Get performs rate-limited HTTP GET request. +func (rl *distributedRateLimiter) Get(ctx context.Context, path string, queryParams map[string]any) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Get(ctx, path, queryParams) +} + +// Post performs rate-limited HTTP POST request. +func (rl *distributedRateLimiter) Post(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Post(ctx, path, queryParams, body) +} + +// Patch performs rate-limited HTTP PATCH request. +func (rl *distributedRateLimiter) Patch(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Patch(ctx, path, queryParams, body) +} + +// Put performs rate-limited HTTP PUT request. +func (rl *distributedRateLimiter) Put(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Put(ctx, path, queryParams, body) +} + +// Delete performs rate-limited HTTP DELETE request. +func (rl *distributedRateLimiter) Delete(ctx context.Context, path string, body []byte) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Delete(ctx, path, body) +} diff --git a/pkg/gofr/service/rate_limiter_local.go b/pkg/gofr/service/rate_limiter_local.go new file mode 100644 index 0000000000..d1216247a4 --- /dev/null +++ b/pkg/gofr/service/rate_limiter_local.go @@ -0,0 +1,374 @@ +package service + +import ( + "context" + "net/http" + "runtime" + "sync" + "sync/atomic" + "time" +) + +const backoffAttemptThreshold = 3 + +// tokenBucket with fractional accumulator for better precision. +type tokenBucket struct { + tokens int64 // Current tokens (scaled by scale) + fractionalTokens float64 // Fractional remainder to avoid precision loss + lastRefillTime int64 // Unix nano timestamp + maxTokens int64 // Maximum tokens (scaled by scale) + refillPerNano float64 // Tokens per nanosecond (float64 for precision) + fracMutex sync.Mutex // Protects fractionalTokens +} + +// localRateLimiter with metrics support. +type localRateLimiter struct { + config RateLimiterConfig + buckets sync.Map + logger Logger + metrics Metrics + HTTP +} + +// bucketEntry holds bucket with last access time for cleanup. +type bucketEntry struct { + bucket *tokenBucket + lastAccess int64 // Unix timestamp +} + +const ( + scale int64 = 1e9 // Scaling factor (typed constant) + cleanupInterval = 5 * time.Minute // How often to clean up unused buckets + bucketTTL = 10 * time.Minute // How long to keep unused buckets + maxCASAttempts = 10 // ✅ FIX: Max CAS attempts + maxCASTime = 100 * time.Microsecond // ✅ FIX: Max CAS time +) + +// NewLocalRateLimiter creates a new local rate limiter with metrics. +func NewLocalRateLimiter(config RateLimiterConfig, h HTTP) HTTP { + httpSvc := h.(*httpService) + + rl := &localRateLimiter{ + config: config, + logger: httpSvc.Logger, + metrics: httpSvc.Metrics, + HTTP: h, + } + + go rl.cleanupRoutine() + + return rl +} + +// newTokenBucket creates a new atomic token bucket with proper float64 scaling. +func newTokenBucket(maxTokens int, refillRate float64) *tokenBucket { + maxScaled := int64(maxTokens) * scale + + // ✅ FIX: Calculate tokens per nanosecond with float64 precision + refillPerNanoFloat := refillRate * float64(scale) / float64(time.Second) + + return &tokenBucket{ + tokens: maxScaled, + fractionalTokens: 0.0, // ✅ FIX: Initialize fractional accumulator + lastRefillTime: time.Now().UnixNano(), + maxTokens: maxScaled, + refillPerNano: refillPerNanoFloat, + } +} + +// allow with enhanced precision and metrics. +func (tb *tokenBucket) allow(ctx context.Context, metrics Metrics, serviceKey string) (bool, time.Duration) { + start := time.Now() + + for attempt := 0; attempt < maxCASAttempts && time.Since(start) < maxCASTime; attempt++ { + now := time.Now().UnixNano() + newTokens := tb.refillTokens(now) + + if newTokens < scale { + retry := tb.calculateRetry(newTokens) + tb.advanceTime(now) + tb.recordDenied(ctx, metrics, serviceKey) + + return false, retry + } + + if tb.consumeToken(newTokens, now) { + tb.recordSuccess(ctx, metrics, serviceKey, newTokens-scale) + return true, 0 + } + + tb.backoff(attempt) + } + + return false, time.Second +} + +func (tb *tokenBucket) refillTokens(now int64) int64 { + oldTime := atomic.LoadInt64(&tb.lastRefillTime) + oldTokens := atomic.LoadInt64(&tb.tokens) + + elapsed := now - oldTime + if elapsed < 0 { + elapsed = 0 + } + + tb.fracMutex.Lock() + tokensToAddFloat := float64(elapsed)*tb.refillPerNano + tb.fractionalTokens + tokensToAdd := int64(tokensToAddFloat) + tb.fractionalTokens = tokensToAddFloat - float64(tokensToAdd) + tb.fracMutex.Unlock() + + newTokens := oldTokens + tokensToAdd + if newTokens > tb.maxTokens { + newTokens = tb.maxTokens + } + + return newTokens +} + +func (tb *tokenBucket) calculateRetry(tokens int64) time.Duration { + if tb.refillPerNano == 0 { + return time.Second + } + + missing := float64(scale - tokens) + nanos := missing / tb.refillPerNano + + retry := time.Duration(nanos) + if retry < time.Second { + retry = time.Second + } + + return retry +} + +func (tb *tokenBucket) advanceTime(now int64) { + oldTime := atomic.LoadInt64(&tb.lastRefillTime) + atomic.CompareAndSwapInt64(&tb.lastRefillTime, oldTime, now) +} + +func (tb *tokenBucket) consumeToken(tokens, now int64) bool { + oldTokens := atomic.LoadInt64(&tb.tokens) + + if atomic.CompareAndSwapInt64(&tb.tokens, oldTokens, tokens-scale) { + atomic.StoreInt64(&tb.lastRefillTime, now) + + return true + } + + return false +} + +func (*tokenBucket) recordDenied(ctx context.Context, metrics Metrics, serviceKey string) { + if metrics != nil { + metrics.IncrementCounter(ctx, "app_rate_limiter_denied_total", "service", serviceKey) + } +} + +func (*tokenBucket) recordSuccess(ctx context.Context, metrics Metrics, serviceKey string, remaining int64) { + if metrics != nil { + metrics.IncrementCounter(ctx, "app_rate_limiter_requests_total", "service", serviceKey) + + availableTokens := float64(remaining) / float64(scale) + + metrics.SetGauge("app_rate_limiter_tokens_available", availableTokens, "service", serviceKey) + } +} + +func (*tokenBucket) backoff(attempt int) { + if attempt < backoffAttemptThreshold { + runtime.Gosched() + } else { + time.Sleep(time.Microsecond) + } +} + +// checkRateLimit with custom keying support. +func (rl *localRateLimiter) checkRateLimit(req *http.Request) error { + // ✅ FIX: Use configurable KeyFunc for custom keying + serviceKey := rl.config.KeyFunc(req) + now := time.Now().Unix() + + entry, _ := rl.buckets.LoadOrStore(serviceKey, &bucketEntry{ + bucket: newTokenBucket(rl.config.Burst, rl.config.RequestsPerSecond), + lastAccess: now, + }) + + bucketEntry := entry.(*bucketEntry) + atomic.StoreInt64(&bucketEntry.lastAccess, now) + + allowed, retryAfter := bucketEntry.bucket.allow(context.Background(), rl.metrics, serviceKey) + if !allowed { + // ✅ FIX: Debug level to prevent log spam + rl.logger.Debug("Rate limit exceeded", + "service", serviceKey, + "rate", rl.config.RequestsPerSecond, + "burst", rl.config.Burst, + "retry_after", retryAfter) + + return &RateLimitError{ + ServiceKey: serviceKey, + RetryAfter: retryAfter, + } + } + + return nil +} + +// updateRateLimiterMetrics follows GoFr's updateMetrics pattern. +func (rl *localRateLimiter) updateRateLimiterMetrics(ctx context.Context, serviceKey string, allowed bool, tokensAvailable float64) { + if rl.metrics != nil { + rl.metrics.IncrementCounter(ctx, "app_rate_limiter_requests_total", "service", serviceKey) + + if !allowed { + rl.metrics.IncrementCounter(ctx, "app_rate_limiter_denied_total", "service", serviceKey) + } + + rl.metrics.SetGauge("app_rate_limiter_tokens_available", tokensAvailable, "service", serviceKey) + } +} + +// cleanupRoutine removes unused buckets. +func (rl *localRateLimiter) cleanupRoutine() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for range ticker.C { + cutoff := time.Now().Unix() - int64(bucketTTL.Seconds()) + cleaned := 0 + + rl.buckets.Range(func(key, value any) bool { + entry := value.(*bucketEntry) + + if atomic.LoadInt64(&entry.lastAccess) < cutoff { + rl.buckets.Delete(key) + + cleaned++ + } + + return true + }) + + if cleaned > 0 { + rl.logger.Debug("Cleaned up rate limiter buckets", "count", cleaned) + } + } +} + +// GetWithHeaders performs rate-limited HTTP GET request with custom headers. +func (rl *localRateLimiter) GetWithHeaders(ctx context.Context, path string, queryParams map[string]any, + headers map[string]string) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.GetWithHeaders(ctx, path, queryParams, headers) +} + +// PostWithHeaders performs rate-limited HTTP POST request with custom headers. +func (rl *localRateLimiter) PostWithHeaders(ctx context.Context, path string, queryParams map[string]any, + body []byte, headers map[string]string) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.PostWithHeaders(ctx, path, queryParams, body, headers) +} + +// PatchWithHeaders performs rate-limited HTTP PATCH request with custom headers. +func (rl *localRateLimiter) PatchWithHeaders(ctx context.Context, path string, queryParams map[string]any, + body []byte, headers map[string]string) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.PatchWithHeaders(ctx, path, queryParams, body, headers) +} + +// PutWithHeaders performs rate-limited HTTP PUT request with custom headers. +func (rl *localRateLimiter) PutWithHeaders(ctx context.Context, path string, queryParams map[string]any, body []byte, + headers map[string]string) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.PutWithHeaders(ctx, path, queryParams, body, headers) +} + +// DeleteWithHeaders performs rate-limited HTTP DELETE request with custom headers. +func (rl *localRateLimiter) DeleteWithHeaders(ctx context.Context, path string, body []byte, + headers map[string]string) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.DeleteWithHeaders(ctx, path, body, headers) +} + +// Get performs rate-limited HTTP GET request. +func (rl *localRateLimiter) Get(ctx context.Context, path string, queryParams map[string]any) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Get(ctx, path, queryParams) +} + +// Post performs rate-limited HTTP POST request. +func (rl *localRateLimiter) Post(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Post(ctx, path, queryParams, body) +} + +// Patch performs rate-limited HTTP PATCH request. +func (rl *localRateLimiter) Patch(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Patch(ctx, path, queryParams, body) +} + +// Put performs rate-limited HTTP PUT request. +func (rl *localRateLimiter) Put(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Put(ctx, path, queryParams, body) +} + +// Delete performs rate-limited HTTP DELETE request. +func (rl *localRateLimiter) Delete(ctx context.Context, path string, body []byte) (*http.Response, error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, path, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Delete(ctx, path, body) +} From 2048c548c5530552bce251e8a3a6f28616cf32aa Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Wed, 10 Sep 2025 16:55:29 +0530 Subject: [PATCH 02/21] fix linters and other errors in the initial implementation --- pkg/gofr/service/rate_limiter_distributed.go | 2 -- pkg/gofr/service/rate_limiter_local.go | 37 ++++++-------------- 2 files changed, 10 insertions(+), 29 deletions(-) diff --git a/pkg/gofr/service/rate_limiter_distributed.go b/pkg/gofr/service/rate_limiter_distributed.go index 812c874780..f40d41d374 100644 --- a/pkg/gofr/service/rate_limiter_distributed.go +++ b/pkg/gofr/service/rate_limiter_distributed.go @@ -110,7 +110,6 @@ func (rl *distributedRateLimiter) checkRateLimit(req *http.Request) error { return nil // Fail open } - // ✅ FIX: Safe result parsing resultArray, ok := result.([]any) if !ok || len(resultArray) != 2 { rl.logger.Log("Invalid Redis response format, allowing request") @@ -129,7 +128,6 @@ func (rl *distributedRateLimiter) checkRateLimit(req *http.Request) error { return nil } - // ✅ FIX: Record metrics for distributed limiter if rl.metrics != nil { rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_requests_total", "service", serviceKey) diff --git a/pkg/gofr/service/rate_limiter_local.go b/pkg/gofr/service/rate_limiter_local.go index d1216247a4..a70dd01c7d 100644 --- a/pkg/gofr/service/rate_limiter_local.go +++ b/pkg/gofr/service/rate_limiter_local.go @@ -64,12 +64,11 @@ func NewLocalRateLimiter(config RateLimiterConfig, h HTTP) HTTP { func newTokenBucket(maxTokens int, refillRate float64) *tokenBucket { maxScaled := int64(maxTokens) * scale - // ✅ FIX: Calculate tokens per nanosecond with float64 precision refillPerNanoFloat := refillRate * float64(scale) / float64(time.Second) return &tokenBucket{ tokens: maxScaled, - fractionalTokens: 0.0, // ✅ FIX: Initialize fractional accumulator + fractionalTokens: 0.0, lastRefillTime: time.Now().UnixNano(), maxTokens: maxScaled, refillPerNano: refillPerNanoFloat, @@ -77,7 +76,7 @@ func newTokenBucket(maxTokens int, refillRate float64) *tokenBucket { } // allow with enhanced precision and metrics. -func (tb *tokenBucket) allow(ctx context.Context, metrics Metrics, serviceKey string) (bool, time.Duration) { +func (tb *tokenBucket) allow() (allowed bool, waitTime time.Duration, tokensRemaining int64) { start := time.Now() for attempt := 0; attempt < maxCASAttempts && time.Since(start) < maxCASTime; attempt++ { @@ -87,20 +86,18 @@ func (tb *tokenBucket) allow(ctx context.Context, metrics Metrics, serviceKey st if newTokens < scale { retry := tb.calculateRetry(newTokens) tb.advanceTime(now) - tb.recordDenied(ctx, metrics, serviceKey) - return false, retry + return false, retry, newTokens } if tb.consumeToken(newTokens, now) { - tb.recordSuccess(ctx, metrics, serviceKey, newTokens-scale) - return true, 0 + return true, 0, newTokens - scale } tb.backoff(attempt) } - return false, time.Second + return false, time.Second, 0 } func (tb *tokenBucket) refillTokens(now int64) int64 { @@ -159,22 +156,6 @@ func (tb *tokenBucket) consumeToken(tokens, now int64) bool { return false } -func (*tokenBucket) recordDenied(ctx context.Context, metrics Metrics, serviceKey string) { - if metrics != nil { - metrics.IncrementCounter(ctx, "app_rate_limiter_denied_total", "service", serviceKey) - } -} - -func (*tokenBucket) recordSuccess(ctx context.Context, metrics Metrics, serviceKey string, remaining int64) { - if metrics != nil { - metrics.IncrementCounter(ctx, "app_rate_limiter_requests_total", "service", serviceKey) - - availableTokens := float64(remaining) / float64(scale) - - metrics.SetGauge("app_rate_limiter_tokens_available", availableTokens, "service", serviceKey) - } -} - func (*tokenBucket) backoff(attempt int) { if attempt < backoffAttemptThreshold { runtime.Gosched() @@ -185,7 +166,6 @@ func (*tokenBucket) backoff(attempt int) { // checkRateLimit with custom keying support. func (rl *localRateLimiter) checkRateLimit(req *http.Request) error { - // ✅ FIX: Use configurable KeyFunc for custom keying serviceKey := rl.config.KeyFunc(req) now := time.Now().Unix() @@ -197,9 +177,12 @@ func (rl *localRateLimiter) checkRateLimit(req *http.Request) error { bucketEntry := entry.(*bucketEntry) atomic.StoreInt64(&bucketEntry.lastAccess, now) - allowed, retryAfter := bucketEntry.bucket.allow(context.Background(), rl.metrics, serviceKey) + allowed, retryAfter, tokensRemaining := bucketEntry.bucket.allow() + + tokensAvailable := float64(tokensRemaining) / float64(scale) + rl.updateRateLimiterMetrics(context.Background(), serviceKey, allowed, tokensAvailable) + if !allowed { - // ✅ FIX: Debug level to prevent log spam rl.logger.Debug("Rate limit exceeded", "service", serviceKey, "rate", rl.config.RequestsPerSecond, From 1d28995185bf1ea0ef54b5fbf6f5f2b239f7bf04 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Thu, 11 Sep 2025 16:49:51 +0530 Subject: [PATCH 03/21] fix bugs in the implementation after testing --- pkg/gofr/service/oauth.go | 2 +- pkg/gofr/service/rate_limiter.go | 28 +++++++- pkg/gofr/service/rate_limiter_distributed.go | 72 +++++++++++-------- pkg/gofr/service/rate_limiter_local.go | 74 ++++++++++++++++---- 4 files changed, 130 insertions(+), 46 deletions(-) diff --git a/pkg/gofr/service/oauth.go b/pkg/gofr/service/oauth.go index c991e93756..aed4ecb621 100644 --- a/pkg/gofr/service/oauth.go +++ b/pkg/gofr/service/oauth.go @@ -79,7 +79,7 @@ func validateTokenURL(tokenURL string) error { return AuthErr{nil, "invalid host pattern, contains `..`"} case strings.HasSuffix(u.Host, "."): return AuthErr{nil, "invalid host pattern, ends with `.`"} - case u.Scheme != "http" && u.Scheme != "https": + case u.Scheme != methodHTTP && u.Scheme != methodHTTPS: return AuthErr{nil, "invalid scheme, allowed http and https only"} default: return nil diff --git a/pkg/gofr/service/rate_limiter.go b/pkg/gofr/service/rate_limiter.go index 90ef48c8ff..6cf0c04977 100644 --- a/pkg/gofr/service/rate_limiter.go +++ b/pkg/gofr/service/rate_limiter.go @@ -23,13 +23,32 @@ type RateLimiterConfig struct { RedisClient *gofrRedis.Redis `json:"-"` // Optional Redis for distributed limiting } -// Default key function extracts scheme://host +// defaultKeyFunc extracts a normalized service key from an HTTP request. func defaultKeyFunc(req *http.Request) string { if req == nil || req.URL == nil { return "unknown" } - return req.URL.Scheme + "://" + req.URL.Host + scheme := req.URL.Scheme + host := req.URL.Host + + if scheme == "" { + if req.TLS != nil { + scheme = methodHTTPS + } else { + scheme = methodHTTP + } + } + + if host == "" { + host = req.Host + } + + if host == "" { + host = unknownServiceKey + } + + return scheme + "://" + host } // Validate checks if the configuration is valid. @@ -82,3 +101,8 @@ type RateLimitError struct { func (e *RateLimitError) Error() string { return fmt.Sprintf("rate limit exceeded for service: %s, retry after: %v", e.ServiceKey, e.RetryAfter) } + +// StatusCode Implement StatusCodeResponder so Responder picks correct HTTP code. +func (*RateLimitError) StatusCode() int { + return http.StatusTooManyRequests // 429 +} diff --git a/pkg/gofr/service/rate_limiter_distributed.go b/pkg/gofr/service/rate_limiter_distributed.go index f40d41d374..7f9e293a7d 100644 --- a/pkg/gofr/service/rate_limiter_distributed.go +++ b/pkg/gofr/service/rate_limiter_distributed.go @@ -30,22 +30,31 @@ if tokens == nil then end -- Refill tokens -local delta = math.max(0, now - last_refill) +local delta = math.max(0, (now - last_refill)/1e9) local new_tokens = math.min(burst, tokens + delta * refill_rate) --- Try to consume one token -if new_tokens < 1 then - -- not enough tokens - redis.call("HMSET", key, "tokens", new_tokens, "last_refill", now) - redis.call("EXPIRE", key, 600) - return 0 +local allowed = 0 +local retryAfter = 0 + +if new_tokens >= 1 then + allowed = 1 + new_tokens = new_tokens - 1 else - redis.call("HMSET", key, "tokens", new_tokens - 1, "last_refill", now) - redis.call("EXPIRE", key, 600) - return 1 + retryAfter = math.ceil((1 - new_tokens) / refill_rate * 1000) -- ms end + +redis.call("HSET", key, "tokens", new_tokens, "last_refill", now) +redis.call("EXPIRE", key, 600) + +return {allowed, retryAfter} ` +// DistributedRateLimiter implements Redis-based distributed rate limiting using Token Bucket algorithm. +// Strategy: Token Bucket with Redis Lua scripts for atomic operations +// - Suitable for: Multi-instance production deployments +// - Benefits: True distributed limiting across all service instances +// - Performance: Single Redis call per rate limit check with atomic Lua execution + // distributedRateLimiter with metrics support. type distributedRateLimiter struct { config RateLimiterConfig @@ -116,17 +125,8 @@ func (rl *distributedRateLimiter) checkRateLimit(req *http.Request) error { return nil // Fail open } - allowed, err := toInt64(resultArray[0]) - if err != nil { - rl.logger.Log("Invalid Redis allowed value, allowing request", "error", err) - return nil - } - - retryAfterMs, err := toInt64(resultArray[1]) - if err != nil { - rl.logger.Log("Invalid Redis retry-after value, allowing request", "error", err) - return nil - } + allowed, _ := toInt64(resultArray[0]) + retryAfterMs, _ := toInt64(resultArray[1]) if rl.metrics != nil { rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_requests_total", "service", serviceKey) @@ -154,7 +154,8 @@ func (rl *distributedRateLimiter) checkRateLimit(req *http.Request) error { // GetWithHeaders performs rate-limited HTTP GET request with custom headers. func (rl *distributedRateLimiter) GetWithHeaders(ctx context.Context, path string, queryParams map[string]any, headers map[string]string) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -166,7 +167,8 @@ func (rl *distributedRateLimiter) GetWithHeaders(ctx context.Context, path strin // PostWithHeaders performs rate-limited HTTP POST request with custom headers. func (rl *distributedRateLimiter) PostWithHeaders(ctx context.Context, path string, queryParams map[string]any, body []byte, headers map[string]string) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -178,7 +180,8 @@ func (rl *distributedRateLimiter) PostWithHeaders(ctx context.Context, path stri // PatchWithHeaders performs rate-limited HTTP PATCH request with custom headers. func (rl *distributedRateLimiter) PatchWithHeaders(ctx context.Context, path string, queryParams map[string]any, body []byte, headers map[string]string) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -190,7 +193,8 @@ func (rl *distributedRateLimiter) PatchWithHeaders(ctx context.Context, path str // PutWithHeaders performs rate-limited HTTP PUT request with custom headers. func (rl *distributedRateLimiter) PutWithHeaders(ctx context.Context, path string, queryParams map[string]any, body []byte, headers map[string]string) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPut, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -202,7 +206,8 @@ func (rl *distributedRateLimiter) PutWithHeaders(ctx context.Context, path strin // DeleteWithHeaders performs rate-limited HTTP DELETE request with custom headers. func (rl *distributedRateLimiter) DeleteWithHeaders(ctx context.Context, path string, body []byte, headers map[string]string) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -213,7 +218,8 @@ func (rl *distributedRateLimiter) DeleteWithHeaders(ctx context.Context, path st // Get performs rate-limited HTTP GET request. func (rl *distributedRateLimiter) Get(ctx context.Context, path string, queryParams map[string]any) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -225,7 +231,8 @@ func (rl *distributedRateLimiter) Get(ctx context.Context, path string, queryPar // Post performs rate-limited HTTP POST request. func (rl *distributedRateLimiter) Post(ctx context.Context, path string, queryParams map[string]any, body []byte) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -237,7 +244,8 @@ func (rl *distributedRateLimiter) Post(ctx context.Context, path string, queryPa // Patch performs rate-limited HTTP PATCH request. func (rl *distributedRateLimiter) Patch(ctx context.Context, path string, queryParams map[string]any, body []byte) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -249,7 +257,8 @@ func (rl *distributedRateLimiter) Patch(ctx context.Context, path string, queryP // Put performs rate-limited HTTP PUT request. func (rl *distributedRateLimiter) Put(ctx context.Context, path string, queryParams map[string]any, body []byte) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPut, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -260,7 +269,8 @@ func (rl *distributedRateLimiter) Put(ctx context.Context, path string, queryPar // Delete performs rate-limited HTTP DELETE request. func (rl *distributedRateLimiter) Delete(ctx context.Context, path string, body []byte) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err diff --git a/pkg/gofr/service/rate_limiter_local.go b/pkg/gofr/service/rate_limiter_local.go index a70dd01c7d..f35b62dbe6 100644 --- a/pkg/gofr/service/rate_limiter_local.go +++ b/pkg/gofr/service/rate_limiter_local.go @@ -4,12 +4,18 @@ import ( "context" "net/http" "runtime" + "strings" "sync" "sync/atomic" "time" ) -const backoffAttemptThreshold = 3 +const ( + backoffAttemptThreshold = 3 + unknownServiceKey = "unknown" + methodHTTP = "http" + methodHTTPS = "https" +) // tokenBucket with fractional accumulator for better precision. type tokenBucket struct { @@ -21,6 +27,12 @@ type tokenBucket struct { fracMutex sync.Mutex // Protects fractionalTokens } +// LocalRateLimiter implements in-memory rate limiting using the Token Bucket algorithm. +// Strategy: Token Bucket with fractional precision for sub-1 RPS support +// - Suitable for: Single-instance deployments, development, testing +// - Limitations: Per-instance limiting only, not suitable for multi-instance production +// - Performance: Lock-free atomic operations with CAS loops + // localRateLimiter with metrics support. type localRateLimiter struct { config RateLimiterConfig @@ -100,6 +112,7 @@ func (tb *tokenBucket) allow() (allowed bool, waitTime time.Duration, tokensRema return false, time.Second, 0 } +// refillTokens calculates and returns new token count after refilling based on elapsed time. func (tb *tokenBucket) refillTokens(now int64) int64 { oldTime := atomic.LoadInt64(&tb.lastRefillTime) oldTokens := atomic.LoadInt64(&tb.tokens) @@ -123,6 +136,7 @@ func (tb *tokenBucket) refillTokens(now int64) int64 { return newTokens } +// calculateRetry computes the precise time duration until the next token becomes available. func (tb *tokenBucket) calculateRetry(tokens int64) time.Duration { if tb.refillPerNano == 0 { return time.Second @@ -164,6 +178,31 @@ func (*tokenBucket) backoff(attempt int) { } } +// buildFullURL constructs an absolute URL by combining the base service URL with the given path. +func buildFullURL(path string, httpSvc HTTP) string { + if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { + return path + } + + // Get base URL from embedded HTTP service + httpSvcImpl, ok := httpSvc.(*httpService) + if !ok { + return path + } + + base := strings.TrimRight(httpSvcImpl.url, "/") + if base == "" { + return path + } + + // Ensure path starts with / + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + return base + path +} + // checkRateLimit with custom keying support. func (rl *localRateLimiter) checkRateLimit(req *http.Request) error { serviceKey := rl.config.KeyFunc(req) @@ -180,7 +219,7 @@ func (rl *localRateLimiter) checkRateLimit(req *http.Request) error { allowed, retryAfter, tokensRemaining := bucketEntry.bucket.allow() tokensAvailable := float64(tokensRemaining) / float64(scale) - rl.updateRateLimiterMetrics(context.Background(), serviceKey, allowed, tokensAvailable) + rl.updateRateLimiterMetrics(req.Context(), serviceKey, allowed, tokensAvailable) if !allowed { rl.logger.Debug("Rate limit exceeded", @@ -241,7 +280,8 @@ func (rl *localRateLimiter) cleanupRoutine() { // GetWithHeaders performs rate-limited HTTP GET request with custom headers. func (rl *localRateLimiter) GetWithHeaders(ctx context.Context, path string, queryParams map[string]any, headers map[string]string) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -253,7 +293,8 @@ func (rl *localRateLimiter) GetWithHeaders(ctx context.Context, path string, que // PostWithHeaders performs rate-limited HTTP POST request with custom headers. func (rl *localRateLimiter) PostWithHeaders(ctx context.Context, path string, queryParams map[string]any, body []byte, headers map[string]string) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -265,7 +306,9 @@ func (rl *localRateLimiter) PostWithHeaders(ctx context.Context, path string, qu // PatchWithHeaders performs rate-limited HTTP PATCH request with custom headers. func (rl *localRateLimiter) PatchWithHeaders(ctx context.Context, path string, queryParams map[string]any, body []byte, headers map[string]string) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -277,7 +320,8 @@ func (rl *localRateLimiter) PatchWithHeaders(ctx context.Context, path string, q // PutWithHeaders performs rate-limited HTTP PUT request with custom headers. func (rl *localRateLimiter) PutWithHeaders(ctx context.Context, path string, queryParams map[string]any, body []byte, headers map[string]string) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPut, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -289,7 +333,8 @@ func (rl *localRateLimiter) PutWithHeaders(ctx context.Context, path string, que // DeleteWithHeaders performs rate-limited HTTP DELETE request with custom headers. func (rl *localRateLimiter) DeleteWithHeaders(ctx context.Context, path string, body []byte, headers map[string]string) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -300,7 +345,8 @@ func (rl *localRateLimiter) DeleteWithHeaders(ctx context.Context, path string, // Get performs rate-limited HTTP GET request. func (rl *localRateLimiter) Get(ctx context.Context, path string, queryParams map[string]any) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -312,7 +358,8 @@ func (rl *localRateLimiter) Get(ctx context.Context, path string, queryParams ma // Post performs rate-limited HTTP POST request. func (rl *localRateLimiter) Post(ctx context.Context, path string, queryParams map[string]any, body []byte) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -324,7 +371,8 @@ func (rl *localRateLimiter) Post(ctx context.Context, path string, queryParams m // Patch performs rate-limited HTTP PATCH request. func (rl *localRateLimiter) Patch(ctx context.Context, path string, queryParams map[string]any, body []byte) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -336,7 +384,8 @@ func (rl *localRateLimiter) Patch(ctx context.Context, path string, queryParams // Put performs rate-limited HTTP PUT request. func (rl *localRateLimiter) Put(ctx context.Context, path string, queryParams map[string]any, body []byte) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodPut, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err @@ -347,7 +396,8 @@ func (rl *localRateLimiter) Put(ctx context.Context, path string, queryParams ma // Delete performs rate-limited HTTP DELETE request. func (rl *localRateLimiter) Delete(ctx context.Context, path string, body []byte) (*http.Response, error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, path, http.NoBody) + fullURL := buildFullURL(path, rl.HTTP) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) if err := rl.checkRateLimit(req); err != nil { return nil, err From 4e6ecd7718c0eac61a0d42811a1d96690a2899af Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Thu, 11 Sep 2025 18:09:19 +0530 Subject: [PATCH 04/21] add documentation --- docs/advanced-guide/http-communication/page.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/advanced-guide/http-communication/page.md b/docs/advanced-guide/http-communication/page.md index d90761fbc3..c8bad1327a 100644 --- a/docs/advanced-guide/http-communication/page.md +++ b/docs/advanced-guide/http-communication/page.md @@ -95,10 +95,13 @@ GoFr provides its user with additional configurational options while registering - **DefaultHeaders** - This option allows user to set some default headers that will be propagated to the downstream HTTP Service every time it is being called. - **HealthConfig** - This option allows user to add the `HealthEndpoint` along with `Timeout` to enable and perform the timely health checks for downstream HTTP Service. - **RetryConfig** - This option allows user to add the maximum number of retry count if before returning error if any downstream HTTP Service fails. +- **RateLimiterConfig** - This option allows user to configure rate limiting for downstream service calls using token bucket algorithm. It controls the request rate to prevent overwhelming dependent services and supports both in-memory and Redis-based implementations. #### Usage: ```go +rc := redis.NewClient(cfg, a.Logger(), a.Metrics()) + a.AddHTTPService("cat-facts", "https://catfact.ninja", service.NewAPIKeyConfig("some-random-key"), service.NewBasicAuthConfig("username", "password"), @@ -119,5 +122,11 @@ a.AddHTTPService("cat-facts", "https://catfact.ninja", &service.RetryConfig{ MaxRetries: 5 }, + + &service.RateLimiterConfig{ + Rate: 5, + Burst: 10, + RedisClient: rc, // if RedisClient is nil, in-memory rate limiter will be used + }, ) ``` \ No newline at end of file From a651ea76f87794faf42782bf8b59d6cdb2f0bff0 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Tue, 16 Sep 2025 11:09:18 +0530 Subject: [PATCH 05/21] add test for local rate limiter implementation --- pkg/gofr/service/rate_limiter_test.go | 413 ++++++++++++++++++++++++++ 1 file changed, 413 insertions(+) create mode 100644 pkg/gofr/service/rate_limiter_test.go diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go new file mode 100644 index 0000000000..5389806039 --- /dev/null +++ b/pkg/gofr/service/rate_limiter_test.go @@ -0,0 +1,413 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + + "gofr.dev/pkg/gofr/logging" +) + +func newBaseHTTPService(t *testing.T, hitCounter *atomic.Int64) *httpService { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + hitCounter.Add(1) + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv.Close) + + return &httpService{ + Client: http.DefaultClient, + url: srv.URL, + Logger: logging.NewMockLogger(logging.INFO), + Tracer: otel.Tracer("gofr-http-client"), + } +} + +func assertAllowed(t *testing.T, resp *http.Response, err error) { + t.Helper() + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func assertRateLimited(t *testing.T, err error, key ...string) { + t.Helper() + require.Error(t, err) + + var rlErr *RateLimitError + + require.ErrorAs(t, err, &rlErr) + + if len(key) > 0 { + assert.Equal(t, key[0], rlErr.ServiceKey) + } + + assert.GreaterOrEqual(t, rlErr.RetryAfter, time.Second) +} + +func isAllowed(t *testing.T, err error) bool { + t.Helper() + + return assert.NoError(t, err) +} + +func isRateLimitError(t *testing.T, err error) bool { + t.Helper() + + var rlErr *RateLimitError + + return assert.ErrorAs(t, err, &rlErr) +} + +func wait(d time.Duration) { time.Sleep(d) } + +// Tests ---------------------------------------------------------------------- + +// Ensures constructor wraps *httpService and first request passes. +func TestNewLocalRateLimiter_Basic(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 5, + Burst: 5, + KeyFunc: func(*http.Request) string { return "svc-basic" }, + }, base) + + resp, err := rl.Get(t.Context(), "/ok", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + assert.Equal(t, int64(1), hits.Load()) +} + +// Burst=1 then immediate second call denied; after refill allowed again. +func TestLocalRateLimiter_EnforceLimit(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 1, + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-limit" }, + }, base) + + resp, err := rl.Get(t.Context(), "/r1", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + resp, err = rl.Get(t.Context(), "/r2", nil) + require.Nil(t, resp) + assertRateLimited(t, err) + + resp.Body.Close() + + wait(1100 * time.Millisecond) + + resp, err = rl.Get(t.Context(), "/r3", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + assert.Equal(t, int64(2), hits.Load()) +} + +// Fractional RPS (0.5 -> 1 token every 2s). +func TestLocalRateLimiter_FractionalRPS(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 0.5, + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-frac" }, + }, base) + + resp, err := rl.Get(t.Context(), "/a", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + resp, err = rl.Get(t.Context(), "/b", nil) + require.Nil(t, resp) + assertRateLimited(t, err) + + if resp != nil { + _ = resp.Body.Close() + } + + wait(2100 * time.Millisecond) + + resp, err = rl.Get(t.Context(), "/c", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + assert.Equal(t, int64(2), hits.Load()) +} + +// Different paths share same bucket via custom KeyFunc. +func TestLocalRateLimiter_CustomKey_SharedBucket(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 1, + Burst: 1, + KeyFunc: func(*http.Request) string { return "shared-key" }, + }, base) + + resp, err := rl.Get(t.Context(), "/p1", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + if resp != nil { + _ = resp.Body.Close() + } + + resp, err = rl.Get(t.Context(), "/p2", nil) + require.Nil(t, resp) + assertRateLimited(t, err) + + if resp != nil { + _ = resp.Body.Close() + } + + wait(1100 * time.Millisecond) + + resp, err = rl.Get(t.Context(), "/p3", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + assert.Equal(t, int64(2), hits.Load()) +} + +// Concurrency: Burst=1 & RPS=1 => only one succeeds immediately. +func TestLocalRateLimiter_Concurrency(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 1, + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-conc" }, + }, base) + + const workers = 12 + + var wg sync.WaitGroup + + wg.Add(workers) + + results := make([]error, workers) + + for i := 0; i < workers; i++ { + go func(i int) { + defer wg.Done() + + resp, err := rl.Get(context.Background(), "/c", nil) + + if resp != nil { + _ = resp.Body.Close() + } + + results[i] = err + }(i) + } + + wg.Wait() + + var allowed, denied int + + for _, e := range results { + switch { + case isAllowed(t, e): + allowed++ + case isRateLimitError(t, e): + denied++ + } + } + + assert.Equal(t, 1, allowed) + assert.Equal(t, workers-1, denied) + assert.Equal(t, int64(1), hits.Load()) +} + +// buildFullURL behavior for relative and absolute forms. +func TestBuildFullURL(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + assert.Contains(t, buildFullURL("/x", base), "/x") + assert.Equal(t, "http://example.com/z", buildFullURL("http://example.com/z", base)) + assert.Contains(t, buildFullURL("rel", base), "/rel") +} + +// Ensures metrics calls do not panic when metrics nil (guard path). +func TestLocalRateLimiter_NoMetrics(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 2, + Burst: 2, + KeyFunc: func(*http.Request) string { return "svc-nometrics" }, + }, base) + + resp, err := rl.Get(t.Context(), "/m", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } +} + +// Denial path exposes RateLimitError fields. +func TestLocalRateLimiter_RateLimitErrorFields(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 0, // Always zero refill + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-zero" }, + }, base) + + resp, err := rl.Get(t.Context(), "/z1", nil) + + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + resp, err = rl.Get(t.Context(), "/z2", nil) + require.Nil(t, resp) + + if resp != nil { + _ = resp.Body.Close() + } + + var rlErr *RateLimitError + + require.ErrorAs(t, err, &rlErr) + + assert.Equal(t, "svc-zero", rlErr.ServiceKey) + assert.GreaterOrEqual(t, rlErr.RetryAfter, time.Second) +} + +func TestLocalRateLimiter_WrapperMethods_SuccessAndLimited(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + // Success limiter: plenty of capacity + successRL := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 100, + Burst: 100, + KeyFunc: func(*http.Request) string { return "wrapper-allow" }, + }, base) + + // Deny limiter: zero capacity (covers error branch) + denyRL := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 0, + Burst: 0, + KeyFunc: func(*http.Request) string { return "wrapper-deny" }, + }, base) + + tests := []struct { + name string + call func(h HTTP) (*http.Response, error) + }{ + {"Get", func(h HTTP) (*http.Response, error) { return h.Get(t.Context(), "/g", nil) }}, + {"GetWithHeaders", func(h HTTP) (*http.Response, error) { + return h.GetWithHeaders(t.Context(), "/gh", nil, map[string]string{"X": "1"}) + }}, + {"Post", func(h HTTP) (*http.Response, error) { return h.Post(t.Context(), "/p", nil, []byte("x")) }}, + {"PostWithHeaders", func(h HTTP) (*http.Response, error) { + return h.PostWithHeaders(t.Context(), "/ph", nil, []byte("x"), map[string]string{"X": "1"}) + }}, + {"Patch", func(h HTTP) (*http.Response, error) { return h.Patch(t.Context(), "/pa", nil, []byte("x")) }}, + {"PatchWithHeaders", func(h HTTP) (*http.Response, error) { + return h.PatchWithHeaders(t.Context(), "/pah", nil, []byte("x"), map[string]string{"X": "1"}) + }}, + {"Put", func(h HTTP) (*http.Response, error) { return h.Put(t.Context(), "/put", nil, []byte("x")) }}, + {"PutWithHeaders", func(h HTTP) (*http.Response, error) { + return h.PutWithHeaders(t.Context(), "/puth", nil, []byte("x"), map[string]string{"X": "1"}) + }}, + {"Delete", func(h HTTP) (*http.Response, error) { return h.Delete(t.Context(), "/d", []byte("x")) }}, + {"DeleteWithHeaders", func(h HTTP) (*http.Response, error) { + return h.DeleteWithHeaders(t.Context(), "/dh", []byte("x"), map[string]string{"X": "1"}) + }}, + } + + // Success path + for _, tc := range tests { + t.Run(tc.name+"_Allowed", func(t *testing.T) { + resp, err := tc.call(successRL) + + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + }) + } + + // Denied path (each should hit rate limit before underlying service) + for _, tc := range tests { + t.Run(tc.name+"_RateLimited", func(t *testing.T) { + resp, err := tc.call(denyRL) + + require.Error(t, err) + assert.Nil(t, resp) + + if resp != nil { + _ = resp.Body.Close() + } + + var rlErr *RateLimitError + + assert.ErrorAs(t, err, &rlErr) + }) + } + + // At least all success invocations should have reached downstream. + assert.Equal(t, int64(len(tests)), hits.Load()) +} From e82631849e2bc88889a13781e07847fd44a5c9b1 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Tue, 16 Sep 2025 20:56:03 +0530 Subject: [PATCH 06/21] fix test --- pkg/gofr/service/rate_limiter_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go index 5389806039..2ce6eae104 100644 --- a/pkg/gofr/service/rate_limiter_test.go +++ b/pkg/gofr/service/rate_limiter_test.go @@ -118,7 +118,9 @@ func TestLocalRateLimiter_EnforceLimit(t *testing.T) { require.Nil(t, resp) assertRateLimited(t, err) - resp.Body.Close() + if resp != nil { + _ = resp.Body.Close() + } wait(1100 * time.Millisecond) From 51616b5181be2038814aae1f6935db1ae224c67b Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Tue, 16 Sep 2025 22:54:04 +0530 Subject: [PATCH 07/21] fix rate limiter concurrency test --- pkg/gofr/service/rate_limiter_distributed.go | 8 +- pkg/gofr/service/rate_limiter_local_test.go | 403 +++++++++++++++++ pkg/gofr/service/rate_limiter_test.go | 447 ++++--------------- 3 files changed, 494 insertions(+), 364 deletions(-) create mode 100644 pkg/gofr/service/rate_limiter_local_test.go diff --git a/pkg/gofr/service/rate_limiter_distributed.go b/pkg/gofr/service/rate_limiter_distributed.go index 7f9e293a7d..b4653ee507 100644 --- a/pkg/gofr/service/rate_limiter_distributed.go +++ b/pkg/gofr/service/rate_limiter_distributed.go @@ -58,7 +58,7 @@ return {allowed, retryAfter} // distributedRateLimiter with metrics support. type distributedRateLimiter struct { config RateLimiterConfig - redisClient gofrRedis.Redis + redisClient *gofrRedis.Redis logger Logger metrics Metrics HTTP @@ -69,12 +69,16 @@ func NewDistributedRateLimiter(config RateLimiterConfig, h HTTP) HTTP { rl := &distributedRateLimiter{ config: config, - redisClient: *config.RedisClient, + redisClient: config.RedisClient, logger: httpSvc.Logger, metrics: httpSvc.Metrics, HTTP: h, } + if rl.redisClient == nil { + rl.logger.Log("Distributed rate limiter initialized without Redis client; operating pass-through") + } + return rl } diff --git a/pkg/gofr/service/rate_limiter_local_test.go b/pkg/gofr/service/rate_limiter_local_test.go new file mode 100644 index 0000000000..3c755d3a13 --- /dev/null +++ b/pkg/gofr/service/rate_limiter_local_test.go @@ -0,0 +1,403 @@ +package service + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "gofr.dev/pkg/gofr/logging" +) + +func newBaseHTTPService(t *testing.T, hitCounter *atomic.Int64) *httpService { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + hitCounter.Add(1) + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv.Close) + + return &httpService{ + Client: http.DefaultClient, + url: srv.URL, + Logger: logging.NewMockLogger(logging.INFO), + Tracer: otel.Tracer("gofr-http-client"), + } +} + +func assertAllowed(t *testing.T, resp *http.Response, err error) { + t.Helper() + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func assertRateLimited(t *testing.T, err error, key ...string) { + t.Helper() + require.Error(t, err) + + var rlErr *RateLimitError + + require.ErrorAs(t, err, &rlErr) + + if len(key) > 0 { + assert.Equal(t, key[0], rlErr.ServiceKey) + } + + assert.GreaterOrEqual(t, rlErr.RetryAfter, time.Second) +} + +func wait(d time.Duration) { time.Sleep(d) } + +func TestNewLocalRateLimiter_Basic(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 5, + Burst: 5, + KeyFunc: func(*http.Request) string { return "svc-basic" }, + }, base) + + resp, err := rl.Get(t.Context(), "/ok", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + assert.Equal(t, int64(1), hits.Load()) +} + +// Burst=1 then immediate second call denied; after refill allowed again. +func TestLocalRateLimiter_EnforceLimit(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 1, + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-limit" }, + }, base) + + resp, err := rl.Get(t.Context(), "/r1", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + resp, err = rl.Get(t.Context(), "/r2", nil) + require.Nil(t, resp) + assertRateLimited(t, err) + + if resp != nil { + _ = resp.Body.Close() + } + + wait(1100 * time.Millisecond) + + resp, err = rl.Get(t.Context(), "/r3", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + assert.Equal(t, int64(2), hits.Load()) +} + +// Fractional RPS (0.5 -> 1 token every 2s). +func TestLocalRateLimiter_FractionalRPS(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 0.5, + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-frac" }, + }, base) + + resp, err := rl.Get(t.Context(), "/a", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + resp, err = rl.Get(t.Context(), "/b", nil) + require.Nil(t, resp) + assertRateLimited(t, err) + + if resp != nil { + _ = resp.Body.Close() + } + + wait(2100 * time.Millisecond) + + resp, err = rl.Get(t.Context(), "/c", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + assert.Equal(t, int64(2), hits.Load()) +} + +// Different paths share same bucket via custom KeyFunc. +func TestLocalRateLimiter_CustomKey_SharedBucket(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 1, + Burst: 1, + KeyFunc: func(*http.Request) string { return "shared-key" }, + }, base) + + resp, err := rl.Get(t.Context(), "/p1", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + if resp != nil { + _ = resp.Body.Close() + } + + resp, err = rl.Get(t.Context(), "/p2", nil) + require.Nil(t, resp) + assertRateLimited(t, err) + + if resp != nil { + _ = resp.Body.Close() + } + + wait(1100 * time.Millisecond) + + resp, err = rl.Get(t.Context(), "/p3", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + assert.Equal(t, int64(2), hits.Load()) +} + +// Concurrency: Burst=1 & RPS=1 => only one succeeds immediately. +func TestLocalRateLimiter_Concurrency(t *testing.T) { + var hits atomic.Int64 + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 1, + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-conc" }, + }, base) + + const workers = 12 + results := make([]error, workers) + + var wg sync.WaitGroup + + wg.Add(workers) + + for i := 0; i < workers; i++ { + go func(i int) { + defer wg.Done() + + resp, err := rl.Get(context.Background(), "/c", nil) + + if resp != nil { + _ = resp.Body.Close() + } + + results[i] = err + }(i) + } + + wg.Wait() + + var allowed, denied int + + for _, e := range results { + if e == nil { + allowed++ + continue + } + + var rlErr *RateLimitError + + if errors.As(e, &rlErr) { + denied++ + continue + } + + t.Fatalf("unexpected error type: %v", e) + } + + assert.Equal(t, 1, allowed) + assert.Equal(t, workers-1, denied) + assert.Equal(t, int64(1), hits.Load()) +} + +// buildFullURL behavior for relative and absolute forms. +func TestBuildFullURL(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + assert.Contains(t, buildFullURL("/x", base), "/x") + assert.Equal(t, "http://example.com/z", buildFullURL("http://example.com/z", base)) + assert.Contains(t, buildFullURL("rel", base), "/rel") +} + +// Ensures metrics calls do not panic when metrics nil (guard path). +func TestLocalRateLimiter_NoMetrics(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 2, + Burst: 2, + KeyFunc: func(*http.Request) string { return "svc-nometrics" }, + }, base) + + resp, err := rl.Get(t.Context(), "/m", nil) + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } +} + +// Denial path exposes RateLimitError fields. +func TestLocalRateLimiter_RateLimitErrorFields(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 0, // Always zero refill + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-zero" }, + }, base) + + resp, err := rl.Get(t.Context(), "/z1", nil) + + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + + resp, err = rl.Get(t.Context(), "/z2", nil) + require.Nil(t, resp) + + if resp != nil { + _ = resp.Body.Close() + } + + var rlErr *RateLimitError + + require.ErrorAs(t, err, &rlErr) + + assert.Equal(t, "svc-zero", rlErr.ServiceKey) + assert.GreaterOrEqual(t, rlErr.RetryAfter, time.Second) +} + +func TestLocalRateLimiter_WrapperMethods_SuccessAndLimited(t *testing.T) { + var hits atomic.Int64 + + base := newBaseHTTPService(t, &hits) + + // Success limiter: plenty of capacity + successRL := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 100, + Burst: 100, + KeyFunc: func(*http.Request) string { return "wrapper-allow" }, + }, base) + + // Deny limiter: zero capacity (covers error branch) + denyRL := NewLocalRateLimiter(RateLimiterConfig{ + RequestsPerSecond: 0, + Burst: 0, + KeyFunc: func(*http.Request) string { return "wrapper-deny" }, + }, base) + + tests := []struct { + name string + call func(h HTTP) (*http.Response, error) + }{ + {"Get", func(h HTTP) (*http.Response, error) { return h.Get(t.Context(), "/g", nil) }}, + {"GetWithHeaders", func(h HTTP) (*http.Response, error) { + return h.GetWithHeaders(t.Context(), "/gh", nil, map[string]string{"X": "1"}) + }}, + {"Post", func(h HTTP) (*http.Response, error) { return h.Post(t.Context(), "/p", nil, []byte("x")) }}, + {"PostWithHeaders", func(h HTTP) (*http.Response, error) { + return h.PostWithHeaders(t.Context(), "/ph", nil, []byte("x"), map[string]string{"X": "1"}) + }}, + {"Patch", func(h HTTP) (*http.Response, error) { return h.Patch(t.Context(), "/pa", nil, []byte("x")) }}, + {"PatchWithHeaders", func(h HTTP) (*http.Response, error) { + return h.PatchWithHeaders(t.Context(), "/pah", nil, []byte("x"), map[string]string{"X": "1"}) + }}, + {"Put", func(h HTTP) (*http.Response, error) { return h.Put(t.Context(), "/put", nil, []byte("x")) }}, + {"PutWithHeaders", func(h HTTP) (*http.Response, error) { + return h.PutWithHeaders(t.Context(), "/puth", nil, []byte("x"), map[string]string{"X": "1"}) + }}, + {"Delete", func(h HTTP) (*http.Response, error) { return h.Delete(t.Context(), "/d", []byte("x")) }}, + {"DeleteWithHeaders", func(h HTTP) (*http.Response, error) { + return h.DeleteWithHeaders(t.Context(), "/dh", []byte("x"), map[string]string{"X": "1"}) + }}, + } + + // Success path + for _, tc := range tests { + t.Run(tc.name+"_Allowed", func(t *testing.T) { + resp, err := tc.call(successRL) + + assertAllowed(t, resp, err) + + if resp != nil { + _ = resp.Body.Close() + } + }) + } + + // Denied path (each should hit rate limit before underlying service) + for _, tc := range tests { + t.Run(tc.name+"_RateLimited", func(t *testing.T) { + resp, err := tc.call(denyRL) + + require.Error(t, err) + assert.Nil(t, resp) + + if resp != nil { + _ = resp.Body.Close() + } + + var rlErr *RateLimitError + + assert.ErrorAs(t, err, &rlErr) + }) + } + + // At least all success invocations should have reached downstream. + assert.Equal(t, int64(len(tests)), hits.Load()) +} diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go index 2ce6eae104..ff18533001 100644 --- a/pkg/gofr/service/rate_limiter_test.go +++ b/pkg/gofr/service/rate_limiter_test.go @@ -1,26 +1,24 @@ package service import ( - "context" + "crypto/tls" "net/http" "net/http/httptest" - "sync" - "sync/atomic" + "net/url" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" - + gofrRedis "gofr.dev/pkg/gofr/datasource/redis" "gofr.dev/pkg/gofr/logging" ) -func newBaseHTTPService(t *testing.T, hitCounter *atomic.Int64) *httpService { +func newHTTPService(t *testing.T) *httpService { t.Helper() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - hitCounter.Add(1) w.WriteHeader(http.StatusOK) })) t.Cleanup(srv.Close) @@ -33,383 +31,108 @@ func newBaseHTTPService(t *testing.T, hitCounter *atomic.Int64) *httpService { } } -func assertAllowed(t *testing.T, resp *http.Response, err error) { - t.Helper() - require.NoError(t, err) - require.NotNil(t, resp) - assert.Equal(t, http.StatusOK, resp.StatusCode) -} - -func assertRateLimited(t *testing.T, err error, key ...string) { - t.Helper() - require.Error(t, err) - - var rlErr *RateLimitError - - require.ErrorAs(t, err, &rlErr) - - if len(key) > 0 { - assert.Equal(t, key[0], rlErr.ServiceKey) - } - - assert.GreaterOrEqual(t, rlErr.RetryAfter, time.Second) -} - -func isAllowed(t *testing.T, err error) bool { - t.Helper() - - return assert.NoError(t, err) -} - -func isRateLimitError(t *testing.T, err error) bool { - t.Helper() - - var rlErr *RateLimitError - - return assert.ErrorAs(t, err, &rlErr) -} - -func wait(d time.Duration) { time.Sleep(d) } - -// Tests ---------------------------------------------------------------------- - -// Ensures constructor wraps *httpService and first request passes. -func TestNewLocalRateLimiter_Basic(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 5, - Burst: 5, - KeyFunc: func(*http.Request) string { return "svc-basic" }, - }, base) - - resp, err := rl.Get(t.Context(), "/ok", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - assert.Equal(t, int64(1), hits.Load()) -} - -// Burst=1 then immediate second call denied; after refill allowed again. -func TestLocalRateLimiter_EnforceLimit(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 1, - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-limit" }, - }, base) - - resp, err := rl.Get(t.Context(), "/r1", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - resp, err = rl.Get(t.Context(), "/r2", nil) - require.Nil(t, resp) - assertRateLimited(t, err) - - if resp != nil { - _ = resp.Body.Close() - } - - wait(1100 * time.Millisecond) - - resp, err = rl.Get(t.Context(), "/r3", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - assert.Equal(t, int64(2), hits.Load()) -} - -// Fractional RPS (0.5 -> 1 token every 2s). -func TestLocalRateLimiter_FractionalRPS(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 0.5, - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-frac" }, - }, base) - - resp, err := rl.Get(t.Context(), "/a", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - resp, err = rl.Get(t.Context(), "/b", nil) - require.Nil(t, resp) - assertRateLimited(t, err) - - if resp != nil { - _ = resp.Body.Close() - } - - wait(2100 * time.Millisecond) - - resp, err = rl.Get(t.Context(), "/c", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - assert.Equal(t, int64(2), hits.Load()) -} - -// Different paths share same bucket via custom KeyFunc. -func TestLocalRateLimiter_CustomKey_SharedBucket(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 1, - Burst: 1, - KeyFunc: func(*http.Request) string { return "shared-key" }, - }, base) - - resp, err := rl.Get(t.Context(), "/p1", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - if resp != nil { - _ = resp.Body.Close() - } - - resp, err = rl.Get(t.Context(), "/p2", nil) - require.Nil(t, resp) - assertRateLimited(t, err) - - if resp != nil { - _ = resp.Body.Close() - } - - wait(1100 * time.Millisecond) - - resp, err = rl.Get(t.Context(), "/p3", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - assert.Equal(t, int64(2), hits.Load()) +func TestRateLimiterConfig_Validate(t *testing.T) { + t.Run("invalid RPS", func(t *testing.T) { + cfg := RateLimiterConfig{RequestsPerSecond: 0, Burst: 1} + err := cfg.Validate() + require.Error(t, err) + assert.ErrorIs(t, err, errInvalidRequestRate) + }) + + t.Run("invalid Burst", func(t *testing.T) { + cfg := RateLimiterConfig{RequestsPerSecond: 1, Burst: 0} + err := cfg.Validate() + require.Error(t, err) + assert.ErrorIs(t, err, errInvalidBurstSize) + }) + + t.Run("sets default KeyFunc when nil", func(t *testing.T) { + cfg := RateLimiterConfig{RequestsPerSecond: 1.5, Burst: 2} + require.Nil(t, cfg.KeyFunc) + require.NoError(t, cfg.Validate()) + require.NotNil(t, cfg.KeyFunc) + }) } -// Concurrency: Burst=1 & RPS=1 => only one succeeds immediately. -func TestLocalRateLimiter_Concurrency(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 1, - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-conc" }, - }, base) - - const workers = 12 - - var wg sync.WaitGroup - - wg.Add(workers) - - results := make([]error, workers) - - for i := 0; i < workers; i++ { - go func(i int) { - defer wg.Done() +func TestDefaultKeyFunc(t *testing.T) { + t.Run("nil request", func(t *testing.T) { + assert.Equal(t, "unknown", defaultKeyFunc(nil)) + }) - resp, err := rl.Get(context.Background(), "/c", nil) + t.Run("nil URL", func(t *testing.T) { + req := &http.Request{} + assert.Equal(t, "unknown", defaultKeyFunc(req)) + }) - if resp != nil { - _ = resp.Body.Close() - } - - results[i] = err - }(i) - } - - wg.Wait() + t.Run("http derived scheme", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{Host: "example.com"}, + } + assert.Equal(t, "http://example.com", defaultKeyFunc(req)) + }) - var allowed, denied int + t.Run("https derived scheme", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{Host: "secure.com"}, + TLS: &tls.ConnectionState{}, + } + assert.Equal(t, "https://secure.com", defaultKeyFunc(req)) + }) - for _, e := range results { - switch { - case isAllowed(t, e): - allowed++ - case isRateLimitError(t, e): - denied++ + t.Run("host from req.Host fallback", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{}, + Host: "fallback:9090", } - } + assert.Equal(t, "http://fallback:9090", defaultKeyFunc(req)) + }) - assert.Equal(t, 1, allowed) - assert.Equal(t, workers-1, denied) - assert.Equal(t, int64(1), hits.Load()) + t.Run("unknown service key when no host present", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{}, + } + assert.Equal(t, "http://unknown", defaultKeyFunc(req)) + }) } -// buildFullURL behavior for relative and absolute forms. -func TestBuildFullURL(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - assert.Contains(t, buildFullURL("/x", base), "/x") - assert.Equal(t, "http://example.com/z", buildFullURL("http://example.com/z", base)) - assert.Contains(t, buildFullURL("rel", base), "/rel") +func TestAddOption_InvalidConfigReturnsOriginal(t *testing.T) { + h := newHTTPService(t) + cfg := RateLimiterConfig{RequestsPerSecond: 0, Burst: 1} // invalid + out := cfg.AddOption(h) + assert.Same(t, h, out) } -// Ensures metrics calls do not panic when metrics nil (guard path). -func TestLocalRateLimiter_NoMetrics(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 2, - Burst: 2, - KeyFunc: func(*http.Request) string { return "svc-nometrics" }, - }, base) +func TestAddOption_LocalLimiter(t *testing.T) { + h := newHTTPService(t) + cfg := RateLimiterConfig{RequestsPerSecond: 2, Burst: 3} + out := cfg.AddOption(h) - resp, err := rl.Get(t.Context(), "/m", nil) - assertAllowed(t, resp, err) + _, isLocal := out.(*localRateLimiter) + assert.True(t, isLocal, "expected *localRateLimiter") - if resp != nil { - _ = resp.Body.Close() - } + assert.NotNil(t, cfg.KeyFunc) } -// Denial path exposes RateLimitError fields. -func TestLocalRateLimiter_RateLimitErrorFields(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 0, // Always zero refill - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-zero" }, - }, base) - - resp, err := rl.Get(t.Context(), "/z1", nil) - - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - resp, err = rl.Get(t.Context(), "/z2", nil) - require.Nil(t, resp) - - if resp != nil { - _ = resp.Body.Close() +func TestAddOption_DistributedLimiter(t *testing.T) { + h := newHTTPService(t) + cfg := RateLimiterConfig{ + RequestsPerSecond: 5, + Burst: 5, + RedisClient: new(gofrRedis.Redis), } - var rlErr *RateLimitError + out := cfg.AddOption(h) + _, isDist := out.(*distributedRateLimiter) - require.ErrorAs(t, err, &rlErr) - - assert.Equal(t, "svc-zero", rlErr.ServiceKey) - assert.GreaterOrEqual(t, rlErr.RetryAfter, time.Second) + assert.True(t, isDist, "expected *distributedRateLimiter") } -func TestLocalRateLimiter_WrapperMethods_SuccessAndLimited(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - // Success limiter: plenty of capacity - successRL := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 100, - Burst: 100, - KeyFunc: func(*http.Request) string { return "wrapper-allow" }, - }, base) - - // Deny limiter: zero capacity (covers error branch) - denyRL := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 0, - Burst: 0, - KeyFunc: func(*http.Request) string { return "wrapper-deny" }, - }, base) - - tests := []struct { - name string - call func(h HTTP) (*http.Response, error) - }{ - {"Get", func(h HTTP) (*http.Response, error) { return h.Get(t.Context(), "/g", nil) }}, - {"GetWithHeaders", func(h HTTP) (*http.Response, error) { - return h.GetWithHeaders(t.Context(), "/gh", nil, map[string]string{"X": "1"}) - }}, - {"Post", func(h HTTP) (*http.Response, error) { return h.Post(t.Context(), "/p", nil, []byte("x")) }}, - {"PostWithHeaders", func(h HTTP) (*http.Response, error) { - return h.PostWithHeaders(t.Context(), "/ph", nil, []byte("x"), map[string]string{"X": "1"}) - }}, - {"Patch", func(h HTTP) (*http.Response, error) { return h.Patch(t.Context(), "/pa", nil, []byte("x")) }}, - {"PatchWithHeaders", func(h HTTP) (*http.Response, error) { - return h.PatchWithHeaders(t.Context(), "/pah", nil, []byte("x"), map[string]string{"X": "1"}) - }}, - {"Put", func(h HTTP) (*http.Response, error) { return h.Put(t.Context(), "/put", nil, []byte("x")) }}, - {"PutWithHeaders", func(h HTTP) (*http.Response, error) { - return h.PutWithHeaders(t.Context(), "/puth", nil, []byte("x"), map[string]string{"X": "1"}) - }}, - {"Delete", func(h HTTP) (*http.Response, error) { return h.Delete(t.Context(), "/d", []byte("x")) }}, - {"DeleteWithHeaders", func(h HTTP) (*http.Response, error) { - return h.DeleteWithHeaders(t.Context(), "/dh", []byte("x"), map[string]string{"X": "1"}) - }}, - } - - // Success path - for _, tc := range tests { - t.Run(tc.name+"_Allowed", func(t *testing.T) { - resp, err := tc.call(successRL) - - assertAllowed(t, resp, err) +func TestRateLimitError(t *testing.T) { + err := &RateLimitError{ServiceKey: "svc-x", RetryAfter: 1500 * time.Millisecond} - if resp != nil { - _ = resp.Body.Close() - } - }) - } - - // Denied path (each should hit rate limit before underlying service) - for _, tc := range tests { - t.Run(tc.name+"_RateLimited", func(t *testing.T) { - resp, err := tc.call(denyRL) - - require.Error(t, err) - assert.Nil(t, resp) - - if resp != nil { - _ = resp.Body.Close() - } - - var rlErr *RateLimitError - - assert.ErrorAs(t, err, &rlErr) - }) - } + assert.Contains(t, err.Error(), "svc-x") + assert.Contains(t, err.Error(), "retry after") + assert.Equal(t, http.StatusTooManyRequests, err.StatusCode()) - // At least all success invocations should have reached downstream. - assert.Equal(t, int64(len(tests)), hits.Load()) + assert.NotErrorIs(t, err, errInvalidBurstSize, "unexpected error match") } From 71d4f728804fc1bfa1e36fb19de9395f9cb1521b Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Wed, 17 Sep 2025 12:07:05 +0530 Subject: [PATCH 08/21] fix linters --- pkg/gofr/service/rate_limiter_local_test.go | 1 + pkg/gofr/service/rate_limiter_test.go | 1 + 2 files changed, 2 insertions(+) diff --git a/pkg/gofr/service/rate_limiter_local_test.go b/pkg/gofr/service/rate_limiter_local_test.go index 3c755d3a13..92a4c97510 100644 --- a/pkg/gofr/service/rate_limiter_local_test.go +++ b/pkg/gofr/service/rate_limiter_local_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + "gofr.dev/pkg/gofr/logging" ) diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go index ff18533001..e638529627 100644 --- a/pkg/gofr/service/rate_limiter_test.go +++ b/pkg/gofr/service/rate_limiter_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + gofrRedis "gofr.dev/pkg/gofr/datasource/redis" "gofr.dev/pkg/gofr/logging" ) From a10447452d6f001a4ea15f7d306916f8a743382b Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Thu, 25 Sep 2025 12:10:35 +0530 Subject: [PATCH 09/21] make time window generic --- pkg/gofr/service/rate_limiter.go | 23 ++++++-- pkg/gofr/service/rate_limiter_distributed.go | 20 ++++--- pkg/gofr/service/rate_limiter_local.go | 9 +-- pkg/gofr/service/rate_limiter_local_test.go | 62 +++++++++++--------- pkg/gofr/service/rate_limiter_test.go | 16 ++--- 5 files changed, 77 insertions(+), 53 deletions(-) diff --git a/pkg/gofr/service/rate_limiter.go b/pkg/gofr/service/rate_limiter.go index 6cf0c04977..fea281da92 100644 --- a/pkg/gofr/service/rate_limiter.go +++ b/pkg/gofr/service/rate_limiter.go @@ -17,10 +17,11 @@ var ( // RateLimiterConfig with custom keying support. type RateLimiterConfig struct { - RequestsPerSecond float64 // Token refill rate (must be > 0) - Burst int // Maximum burst capacity (must be > 0) - KeyFunc func(*http.Request) string // Optional custom key extraction - RedisClient *gofrRedis.Redis `json:"-"` // Optional Redis for distributed limiting + Requests float64 // Number of requests allowed + Window time.Duration // Time window (e.g., time.Minute, time.Hour) + Burst int // Maximum burst capacity (must be > 0) + KeyFunc func(*http.Request) string // Optional custom key extraction + RedisClient *gofrRedis.Redis `json:"-"` // Optional Redis for distributed limiting } // defaultKeyFunc extracts a normalized service key from an HTTP request. @@ -53,8 +54,12 @@ func defaultKeyFunc(req *http.Request) string { // Validate checks if the configuration is valid. func (config *RateLimiterConfig) Validate() error { - if config.RequestsPerSecond <= 0 { - return fmt.Errorf("%w: %f", errInvalidRequestRate, config.RequestsPerSecond) + if config.Requests <= 0 { + return fmt.Errorf("%w: %f", errInvalidRequestRate, config.Requests) + } + + if config.Window <= 0 { + config.Window = time.Minute // Default: per-minute rate limiting } if config.Burst <= 0 { @@ -92,6 +97,12 @@ func (config *RateLimiterConfig) AddOption(h HTTP) HTTP { return NewLocalRateLimiter(*config, h) } +// RequestsPerSecond converts the configured rate to requests per second. +func (config *RateLimiterConfig) RequestsPerSecond() float64 { + // Convert any time window to "requests per second" for internal math + return float64(config.Requests) / config.Window.Seconds() +} + // RateLimitError represents a rate limiting error. type RateLimitError struct { ServiceKey string diff --git a/pkg/gofr/service/rate_limiter_distributed.go b/pkg/gofr/service/rate_limiter_distributed.go index b4653ee507..3b097f65ad 100644 --- a/pkg/gofr/service/rate_limiter_distributed.go +++ b/pkg/gofr/service/rate_limiter_distributed.go @@ -16,8 +16,12 @@ import ( const tokenBucketScript = ` local key = KEYS[1] local burst = tonumber(ARGV[1]) -local refill_rate = tonumber(ARGV[2]) -local now = tonumber(ARGV[3]) +local requests = tonumber(ARGV[2]) +local window_seconds = tonumber(ARGV[3]) +local now = tonumber(ARGV[4]) + +-- Calculate refill rate as requests per second +local refill_rate = requests / window_seconds -- Fetch bucket local bucket = redis.call("HMGET", key, "tokens", "last_refill") @@ -25,8 +29,8 @@ local tokens = tonumber(bucket[1]) local last_refill = tonumber(bucket[2]) if tokens == nil then - tokens = burst - last_refill = now +tokens = burst +last_refill = now end -- Refill tokens @@ -37,10 +41,10 @@ local allowed = 0 local retryAfter = 0 if new_tokens >= 1 then - allowed = 1 - new_tokens = new_tokens - 1 +allowed = 1 +new_tokens = new_tokens - 1 else - retryAfter = math.ceil((1 - new_tokens) / refill_rate * 1000) -- ms +retryAfter = math.ceil((1 - new_tokens) / refill_rate * 1000) -- ms end redis.call("HSET", key, "tokens", new_tokens, "last_refill", now) @@ -108,7 +112,7 @@ func (rl *distributedRateLimiter) checkRateLimit(req *http.Request) error { tokenBucketScript, []string{"gofr:ratelimit:" + serviceKey}, rl.config.Burst, - rl.config.RequestsPerSecond, + int64(rl.config.Window.Seconds()), now, ) diff --git a/pkg/gofr/service/rate_limiter_local.go b/pkg/gofr/service/rate_limiter_local.go index f35b62dbe6..a0a3149a95 100644 --- a/pkg/gofr/service/rate_limiter_local.go +++ b/pkg/gofr/service/rate_limiter_local.go @@ -73,10 +73,11 @@ func NewLocalRateLimiter(config RateLimiterConfig, h HTTP) HTTP { } // newTokenBucket creates a new atomic token bucket with proper float64 scaling. -func newTokenBucket(maxTokens int, refillRate float64) *tokenBucket { - maxScaled := int64(maxTokens) * scale +func newTokenBucket(config *RateLimiterConfig) *tokenBucket { + maxScaled := int64(config.Burst) * scale - refillPerNanoFloat := refillRate * float64(scale) / float64(time.Second) + requestsPerSecond := config.RequestsPerSecond() + refillPerNanoFloat := requestsPerSecond * float64(scale) / float64(time.Second) return &tokenBucket{ tokens: maxScaled, @@ -209,7 +210,7 @@ func (rl *localRateLimiter) checkRateLimit(req *http.Request) error { now := time.Now().Unix() entry, _ := rl.buckets.LoadOrStore(serviceKey, &bucketEntry{ - bucket: newTokenBucket(rl.config.Burst, rl.config.RequestsPerSecond), + bucket: newTokenBucket(&rl.config), lastAccess: now, }) diff --git a/pkg/gofr/service/rate_limiter_local_test.go b/pkg/gofr/service/rate_limiter_local_test.go index 92a4c97510..150cb1fda5 100644 --- a/pkg/gofr/service/rate_limiter_local_test.go +++ b/pkg/gofr/service/rate_limiter_local_test.go @@ -64,9 +64,10 @@ func TestNewLocalRateLimiter_Basic(t *testing.T) { base := newBaseHTTPService(t, &hits) rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 5, - Burst: 5, - KeyFunc: func(*http.Request) string { return "svc-basic" }, + Requests: 5, + Window: time.Second, + Burst: 5, + KeyFunc: func(*http.Request) string { return "svc-basic" }, }, base) resp, err := rl.Get(t.Context(), "/ok", nil) @@ -86,9 +87,10 @@ func TestLocalRateLimiter_EnforceLimit(t *testing.T) { base := newBaseHTTPService(t, &hits) rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 1, - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-limit" }, + Requests: 1, + Window: time.Second, + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-limit" }, }, base) resp, err := rl.Get(t.Context(), "/r1", nil) @@ -125,9 +127,10 @@ func TestLocalRateLimiter_FractionalRPS(t *testing.T) { base := newBaseHTTPService(t, &hits) rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 0.5, - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-frac" }, + Requests: 0.5, + Window: time.Second, + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-frac" }, }, base) resp, err := rl.Get(t.Context(), "/a", nil) @@ -164,9 +167,10 @@ func TestLocalRateLimiter_CustomKey_SharedBucket(t *testing.T) { base := newBaseHTTPService(t, &hits) rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 1, - Burst: 1, - KeyFunc: func(*http.Request) string { return "shared-key" }, + Requests: 1, + Window: time.Second, + Burst: 1, + KeyFunc: func(*http.Request) string { return "shared-key" }, }, base) resp, err := rl.Get(t.Context(), "/p1", nil) @@ -206,9 +210,10 @@ func TestLocalRateLimiter_Concurrency(t *testing.T) { base := newBaseHTTPService(t, &hits) rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 1, - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-conc" }, + Requests: 1, + Window: time.Second, + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-conc" }, }, base) const workers = 12 @@ -275,9 +280,10 @@ func TestLocalRateLimiter_NoMetrics(t *testing.T) { base := newBaseHTTPService(t, &hits) rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 2, - Burst: 2, - KeyFunc: func(*http.Request) string { return "svc-nometrics" }, + Requests: 2, + Window: time.Second, + Burst: 2, + KeyFunc: func(*http.Request) string { return "svc-nometrics" }, }, base) resp, err := rl.Get(t.Context(), "/m", nil) @@ -295,9 +301,10 @@ func TestLocalRateLimiter_RateLimitErrorFields(t *testing.T) { base := newBaseHTTPService(t, &hits) rl := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 0, // Always zero refill - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-zero" }, + Requests: 0, // Always zero refill + Window: time.Second, + Burst: 1, + KeyFunc: func(*http.Request) string { return "svc-zero" }, }, base) resp, err := rl.Get(t.Context(), "/z1", nil) @@ -330,16 +337,17 @@ func TestLocalRateLimiter_WrapperMethods_SuccessAndLimited(t *testing.T) { // Success limiter: plenty of capacity successRL := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 100, - Burst: 100, - KeyFunc: func(*http.Request) string { return "wrapper-allow" }, + Requests: 100, + Window: time.Second, + Burst: 100, + KeyFunc: func(*http.Request) string { return "wrapper-allow" }, }, base) // Deny limiter: zero capacity (covers error branch) denyRL := NewLocalRateLimiter(RateLimiterConfig{ - RequestsPerSecond: 0, - Burst: 0, - KeyFunc: func(*http.Request) string { return "wrapper-deny" }, + Requests: 0, + Burst: 0, + KeyFunc: func(*http.Request) string { return "wrapper-deny" }, }, base) tests := []struct { diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go index e638529627..78bec3a2e5 100644 --- a/pkg/gofr/service/rate_limiter_test.go +++ b/pkg/gofr/service/rate_limiter_test.go @@ -34,21 +34,21 @@ func newHTTPService(t *testing.T) *httpService { func TestRateLimiterConfig_Validate(t *testing.T) { t.Run("invalid RPS", func(t *testing.T) { - cfg := RateLimiterConfig{RequestsPerSecond: 0, Burst: 1} + cfg := RateLimiterConfig{Requests: 0, Burst: 1} err := cfg.Validate() require.Error(t, err) assert.ErrorIs(t, err, errInvalidRequestRate) }) t.Run("invalid Burst", func(t *testing.T) { - cfg := RateLimiterConfig{RequestsPerSecond: 1, Burst: 0} + cfg := RateLimiterConfig{Requests: 1, Burst: 0} err := cfg.Validate() require.Error(t, err) assert.ErrorIs(t, err, errInvalidBurstSize) }) t.Run("sets default KeyFunc when nil", func(t *testing.T) { - cfg := RateLimiterConfig{RequestsPerSecond: 1.5, Burst: 2} + cfg := RateLimiterConfig{Requests: 1.5, Burst: 2} require.Nil(t, cfg.KeyFunc) require.NoError(t, cfg.Validate()) require.NotNil(t, cfg.KeyFunc) @@ -98,14 +98,14 @@ func TestDefaultKeyFunc(t *testing.T) { func TestAddOption_InvalidConfigReturnsOriginal(t *testing.T) { h := newHTTPService(t) - cfg := RateLimiterConfig{RequestsPerSecond: 0, Burst: 1} // invalid + cfg := RateLimiterConfig{Requests: 0, Burst: 1} // invalid out := cfg.AddOption(h) assert.Same(t, h, out) } func TestAddOption_LocalLimiter(t *testing.T) { h := newHTTPService(t) - cfg := RateLimiterConfig{RequestsPerSecond: 2, Burst: 3} + cfg := RateLimiterConfig{Requests: 2, Burst: 3} out := cfg.AddOption(h) _, isLocal := out.(*localRateLimiter) @@ -117,9 +117,9 @@ func TestAddOption_LocalLimiter(t *testing.T) { func TestAddOption_DistributedLimiter(t *testing.T) { h := newHTTPService(t) cfg := RateLimiterConfig{ - RequestsPerSecond: 5, - Burst: 5, - RedisClient: new(gofrRedis.Redis), + Requests: 5, + Burst: 5, + RedisClient: new(gofrRedis.Redis), } out := cfg.AddOption(h) From 8c8b29bfa93d71f0ca1f1a76a0daf9f938e14714 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Thu, 25 Sep 2025 12:13:34 +0530 Subject: [PATCH 10/21] update documentation --- docs/advanced-guide/http-communication/page.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/advanced-guide/http-communication/page.md b/docs/advanced-guide/http-communication/page.md index c8bad1327a..597a05d565 100644 --- a/docs/advanced-guide/http-communication/page.md +++ b/docs/advanced-guide/http-communication/page.md @@ -124,7 +124,8 @@ a.AddHTTPService("cat-facts", "https://catfact.ninja", }, &service.RateLimiterConfig{ - Rate: 5, + Requests: 5, + Window: time.Minute, Burst: 10, RedisClient: rc, // if RedisClient is nil, in-memory rate limiter will be used }, From 12bb728afe123544050a3af7c0738a7d3f5e82e1 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Mon, 29 Sep 2025 11:37:27 +0530 Subject: [PATCH 11/21] resolve review comments --- pkg/gofr/service/rate_limiter.go | 10 ++- pkg/gofr/service/rate_limiter_distributed.go | 76 +++++--------------- pkg/gofr/service/rate_limiter_local.go | 39 ++++------ pkg/gofr/service/rate_limiter_store.go | 73 +++++++++++++++++++ pkg/gofr/service/rate_limiter_test.go | 17 +---- 5 files changed, 116 insertions(+), 99 deletions(-) create mode 100644 pkg/gofr/service/rate_limiter_store.go diff --git a/pkg/gofr/service/rate_limiter.go b/pkg/gofr/service/rate_limiter.go index fea281da92..cdca11b1a1 100644 --- a/pkg/gofr/service/rate_limiter.go +++ b/pkg/gofr/service/rate_limiter.go @@ -10,8 +10,8 @@ import ( ) var ( - errInvalidRequestRate = errors.New("requestsPerSecond must be greater than 0") - errInvalidBurstSize = errors.New("burst must be greater than 0") + errInvalidRequestRate = errors.New("requests must be greater than 0 per configured time window") + errBurstLessThanRequests = errors.New("burst must be greater than requests per window") errInvalidRedisResultType = errors.New("unexpected Redis result type") ) @@ -63,7 +63,11 @@ func (config *RateLimiterConfig) Validate() error { } if config.Burst <= 0 { - return fmt.Errorf("%w: %d", errInvalidBurstSize, config.Burst) + config.Burst = int(config.Requests) + } + + if float64(config.Burst) < config.Requests { + return fmt.Errorf("%w: burst=%d, requests=%f", errBurstLessThanRequests, config.Burst, config.Requests) } // Set default key function if not provided. diff --git a/pkg/gofr/service/rate_limiter_distributed.go b/pkg/gofr/service/rate_limiter_distributed.go index 3b097f65ad..7f46433970 100644 --- a/pkg/gofr/service/rate_limiter_distributed.go +++ b/pkg/gofr/service/rate_limiter_distributed.go @@ -5,9 +5,6 @@ import ( "fmt" "net/http" "strconv" - "time" - - gofrRedis "gofr.dev/pkg/gofr/datasource/redis" ) // tokenBucketScript is a Lua script for atomic token bucket rate limiting in Redis. @@ -61,10 +58,10 @@ return {allowed, retryAfter} // distributedRateLimiter with metrics support. type distributedRateLimiter struct { - config RateLimiterConfig - redisClient *gofrRedis.Redis - logger Logger - metrics Metrics + config RateLimiterConfig + store RateLimiterStore + logger Logger + metrics Metrics HTTP } @@ -72,15 +69,11 @@ func NewDistributedRateLimiter(config RateLimiterConfig, h HTTP) HTTP { httpSvc := h.(*httpService) rl := &distributedRateLimiter{ - config: config, - redisClient: config.RedisClient, - logger: httpSvc.Logger, - metrics: httpSvc.Metrics, - HTTP: h, - } - - if rl.redisClient == nil { - rl.logger.Log("Distributed rate limiter initialized without Redis client; operating pass-through") + config: config, + store: NewRedisRateLimiterStore(config.RedisClient), + logger: httpSvc.Logger, + metrics: httpSvc.Metrics, + HTTP: h, } return rl @@ -105,55 +98,24 @@ func toInt64(i any) (int64, error) { // checkRateLimit for distributed version with metrics. func (rl *distributedRateLimiter) checkRateLimit(req *http.Request) error { serviceKey := rl.config.KeyFunc(req) - now := time.Now().UnixNano() - - cmd := rl.redisClient.Eval( - context.Background(), - tokenBucketScript, - []string{"gofr:ratelimit:" + serviceKey}, - rl.config.Burst, - int64(rl.config.Window.Seconds()), - now, - ) - - result, err := cmd.Result() + + allowed, retryAfter, err := rl.store.Allow(context.Background(), serviceKey, rl.config) if err != nil { - rl.logger.Log("Redis rate limiter error, allowing request", "error", err) - // Record error metric - if rl.metrics != nil { - rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_errors_total", "service", serviceKey, "type", "redis_error") - } + rl.logger.Log("Rate limiter store error, allowing request", "error", err) - return nil // Fail open - } + rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_errors_total", "service", serviceKey, "type", "store_error") - resultArray, ok := result.([]any) - if !ok || len(resultArray) != 2 { - rl.logger.Log("Invalid Redis response format, allowing request") - return nil // Fail open + return nil } - allowed, _ := toInt64(resultArray[0]) - retryAfterMs, _ := toInt64(resultArray[1]) - - if rl.metrics != nil { - rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_requests_total", "service", serviceKey) + rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_requests_total", "service", serviceKey) - if allowed != 1 { - rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_denied_total", "service", serviceKey) - } - } + if !allowed { + rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_denied_total", "service", serviceKey) - if allowed != 1 { - retryAfter := time.Duration(retryAfterMs) * time.Millisecond - rl.logger.Debug("Distributed rate limit exceeded", - "service", serviceKey, - "retry_after", retryAfter) + rl.logger.Debug("Distributed rate limit exceeded", "service", serviceKey, "retry_after", retryAfter) - return &RateLimitError{ - ServiceKey: serviceKey, - RetryAfter: retryAfter, - } + return &RateLimitError{ServiceKey: serviceKey, RetryAfter: retryAfter} } return nil diff --git a/pkg/gofr/service/rate_limiter_local.go b/pkg/gofr/service/rate_limiter_local.go index a0a3149a95..625662b20d 100644 --- a/pkg/gofr/service/rate_limiter_local.go +++ b/pkg/gofr/service/rate_limiter_local.go @@ -36,7 +36,7 @@ type tokenBucket struct { // localRateLimiter with metrics support. type localRateLimiter struct { config RateLimiterConfig - buckets sync.Map + store RateLimiterStore logger Logger metrics Metrics HTTP @@ -207,32 +207,16 @@ func buildFullURL(path string, httpSvc HTTP) string { // checkRateLimit with custom keying support. func (rl *localRateLimiter) checkRateLimit(req *http.Request) error { serviceKey := rl.config.KeyFunc(req) - now := time.Now().Unix() - entry, _ := rl.buckets.LoadOrStore(serviceKey, &bucketEntry{ - bucket: newTokenBucket(&rl.config), - lastAccess: now, - }) + allowed, retryAfter, _ := rl.store.Allow(req.Context(), serviceKey, rl.config) - bucketEntry := entry.(*bucketEntry) - atomic.StoreInt64(&bucketEntry.lastAccess, now) - - allowed, retryAfter, tokensRemaining := bucketEntry.bucket.allow() - - tokensAvailable := float64(tokensRemaining) / float64(scale) - rl.updateRateLimiterMetrics(req.Context(), serviceKey, allowed, tokensAvailable) + rl.updateRateLimiterMetrics(req.Context(), serviceKey, allowed, 0) if !allowed { - rl.logger.Debug("Rate limit exceeded", - "service", serviceKey, - "rate", rl.config.RequestsPerSecond, - "burst", rl.config.Burst, - "retry_after", retryAfter) - - return &RateLimitError{ - ServiceKey: serviceKey, - RetryAfter: retryAfter, - } + rl.logger.Debug("Rate limit exceeded", "service", serviceKey, "rate", rl.config.RequestsPerSecond, ""+ + "burst", rl.config.Burst, "retry_after", retryAfter) + + return &RateLimitError{ServiceKey: serviceKey, RetryAfter: retryAfter} } return nil @@ -260,11 +244,16 @@ func (rl *localRateLimiter) cleanupRoutine() { cutoff := time.Now().Unix() - int64(bucketTTL.Seconds()) cleaned := 0 - rl.buckets.Range(func(key, value any) bool { + localStore, ok := rl.store.(*LocalRateLimiterStore) + if !ok { + continue // Not a local store, skip cleanup + } + + localStore.buckets.Range(func(key, value any) bool { entry := value.(*bucketEntry) if atomic.LoadInt64(&entry.lastAccess) < cutoff { - rl.buckets.Delete(key) + localStore.buckets.Delete(key) cleaned++ } diff --git a/pkg/gofr/service/rate_limiter_store.go b/pkg/gofr/service/rate_limiter_store.go new file mode 100644 index 0000000000..6ad4afae81 --- /dev/null +++ b/pkg/gofr/service/rate_limiter_store.go @@ -0,0 +1,73 @@ +package service + +import ( + "context" + "sync" + "time" + + gofrRedis "gofr.dev/pkg/gofr/datasource/redis" +) + +// RateLimiterStore abstracts the storage for rate limiter buckets. +type RateLimiterStore interface { + Allow(ctx context.Context, key string, config RateLimiterConfig) (allowed bool, retryAfter time.Duration, err error) +} + +// RedisRateLimiterStore implements RateLimiterStore using Redis. +type RedisRateLimiterStore struct { + client *gofrRedis.Redis +} + +func NewRedisRateLimiterStore(client *gofrRedis.Redis) *RedisRateLimiterStore { + return &RedisRateLimiterStore{client: client} +} + +func (r *RedisRateLimiterStore) Allow(ctx context.Context, key string, config RateLimiterConfig) (bool, time.Duration, error) { + now := time.Now().UnixNano() + + cmd := r.client.Eval( + ctx, + tokenBucketScript, + []string{"gofr:ratelimit:" + key}, + config.Burst, + int64(config.Window.Seconds()), + now, + ) + + result, err := cmd.Result() + if err != nil { + return true, 0, err // Fail open + } + + resultArray, ok := result.([]any) + if !ok || len(resultArray) != 2 { + return true, 0, errInvalidRedisResultType // Fail open + } + + allowed, _ := toInt64(resultArray[0]) + + retryAfterMs, _ := toInt64(resultArray[1]) + + return allowed == 1, time.Duration(retryAfterMs) * time.Millisecond, nil +} + +// LocalRateLimiterStore implements RateLimiterStore using in-memory buckets. +type LocalRateLimiterStore struct { + buckets *sync.Map +} + +func NewLocalRateLimiterStore() *LocalRateLimiterStore { + return &LocalRateLimiterStore{buckets: &sync.Map{}} +} + +func (l *LocalRateLimiterStore) Allow(_ context.Context, key string, config RateLimiterConfig) (bool, time.Duration, error) { + now := time.Now().Unix() + entry, _ := l.buckets.LoadOrStore(key, &bucketEntry{ + bucket: newTokenBucket(&config), + lastAccess: now, + }) + bucketEntry := entry.(*bucketEntry) + allowed, retryAfter, _ := bucketEntry.bucket.allow() + + return allowed, retryAfter, nil +} diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go index 78bec3a2e5..3960539f09 100644 --- a/pkg/gofr/service/rate_limiter_test.go +++ b/pkg/gofr/service/rate_limiter_test.go @@ -6,7 +6,6 @@ import ( "net/http/httptest" "net/url" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -40,11 +39,11 @@ func TestRateLimiterConfig_Validate(t *testing.T) { assert.ErrorIs(t, err, errInvalidRequestRate) }) - t.Run("invalid Burst", func(t *testing.T) { - cfg := RateLimiterConfig{Requests: 1, Burst: 0} + t.Run("burst less than requests", func(t *testing.T) { + cfg := RateLimiterConfig{Requests: 5, Burst: 3} err := cfg.Validate() require.Error(t, err) - assert.ErrorIs(t, err, errInvalidBurstSize) + assert.ErrorIs(t, err, errBurstLessThanRequests) }) t.Run("sets default KeyFunc when nil", func(t *testing.T) { @@ -127,13 +126,3 @@ func TestAddOption_DistributedLimiter(t *testing.T) { assert.True(t, isDist, "expected *distributedRateLimiter") } - -func TestRateLimitError(t *testing.T) { - err := &RateLimitError{ServiceKey: "svc-x", RetryAfter: 1500 * time.Millisecond} - - assert.Contains(t, err.Error(), "svc-x") - assert.Contains(t, err.Error(), "retry after") - assert.Equal(t, http.StatusTooManyRequests, err.StatusCode()) - - assert.NotErrorIs(t, err, errInvalidBurstSize, "unexpected error match") -} From 51d3388981ed54b45fd8dc364c4644a88a39a7c4 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Mon, 29 Sep 2025 12:50:16 +0530 Subject: [PATCH 12/21] replace concrete rate limiter stores with interface --- .../advanced-guide/http-communication/page.md | 11 ++++++++- pkg/gofr/container/container.go | 9 ++++--- pkg/gofr/service/rate_limiter.go | 24 ++++++++++--------- pkg/gofr/service/rate_limiter_distributed.go | 4 ++-- pkg/gofr/service/rate_limiter_local.go | 3 ++- pkg/gofr/service/rate_limiter_local_test.go | 19 ++++++++------- pkg/gofr/service/rate_limiter_store.go | 7 +++--- pkg/gofr/service/rate_limiter_test.go | 6 ++--- 8 files changed, 50 insertions(+), 33 deletions(-) diff --git a/docs/advanced-guide/http-communication/page.md b/docs/advanced-guide/http-communication/page.md index 597a05d565..f134cb9b18 100644 --- a/docs/advanced-guide/http-communication/page.md +++ b/docs/advanced-guide/http-communication/page.md @@ -97,6 +97,15 @@ GoFr provides its user with additional configurational options while registering - **RetryConfig** - This option allows user to add the maximum number of retry count if before returning error if any downstream HTTP Service fails. - **RateLimiterConfig** - This option allows user to configure rate limiting for downstream service calls using token bucket algorithm. It controls the request rate to prevent overwhelming dependent services and supports both in-memory and Redis-based implementations. +**Rate Limiter Store: Customization** +GoFr allows you to use a custom rate limiter store by implementing the RateLimiterStore interface.This enables integration with any backend (e.g., Redis, database, or custom logic) +Interface: +```go +type RateLimiterStore interface { + Allow(ctx context.Context, key string, config RateLimiterConfig) (allowed bool, retryAfter int64, err error) +} +``` + #### Usage: ```go @@ -127,7 +136,7 @@ a.AddHTTPService("cat-facts", "https://catfact.ninja", Requests: 5, Window: time.Minute, Burst: 10, - RedisClient: rc, // if RedisClient is nil, in-memory rate limiter will be used + Store: service.NewRedisRateLimiterStore(rc)}, // Skip this field to use in-memory store }, ) ``` \ No newline at end of file diff --git a/pkg/gofr/container/container.go b/pkg/gofr/container/container.go index 7b5fc33a7a..438f8638fc 100644 --- a/pkg/gofr/container/container.go +++ b/pkg/gofr/container/container.go @@ -260,9 +260,12 @@ func (c *Container) registerFrameworkMetrics() { httpBuckets := []float64{.001, .003, .005, .01, .02, .03, .05, .1, .2, .3, .5, .75, 1, 2, 3, 5, 10, 30} c.Metrics().NewHistogram("app_http_response", "Response time of HTTP requests in seconds.", httpBuckets...) c.Metrics().NewHistogram("app_http_service_response", "Response time of HTTP service requests in seconds.", httpBuckets...) - c.Metrics().NewCounter("app_rate_limiter_requests_total", "Total rate limiter requests") - c.Metrics().NewCounter("app_rate_limiter_denied_total", "Total denied requests") - c.Metrics().NewGauge("app_rate_limiter_tokens_available", "Current tokens available") + + rateLimiterBuckets := []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5} + c.Metrics().NewHistogram("app_rate_limiter_stats", "Response time of rate limiter checks in milliseconds.", rateLimiterBuckets...) + c.Metrics().NewCounter("app_rate_limiter_requests_total", "Total rate limiter requests.") + c.Metrics().NewCounter("app_rate_limiter_denied_total", "Total rate limiter denied requests.") + c.Metrics().NewCounter("app_rate_limiter_errors_total", "Total rate limiter errors.") } { // Redis metrics diff --git a/pkg/gofr/service/rate_limiter.go b/pkg/gofr/service/rate_limiter.go index cdca11b1a1..1244030558 100644 --- a/pkg/gofr/service/rate_limiter.go +++ b/pkg/gofr/service/rate_limiter.go @@ -5,8 +5,6 @@ import ( "fmt" "net/http" "time" - - gofrRedis "gofr.dev/pkg/gofr/datasource/redis" ) var ( @@ -17,11 +15,11 @@ var ( // RateLimiterConfig with custom keying support. type RateLimiterConfig struct { - Requests float64 // Number of requests allowed - Window time.Duration // Time window (e.g., time.Minute, time.Hour) - Burst int // Maximum burst capacity (must be > 0) - KeyFunc func(*http.Request) string // Optional custom key extraction - RedisClient *gofrRedis.Redis `json:"-"` // Optional Redis for distributed limiting + Requests float64 // Number of requests allowed + Window time.Duration // Time window (e.g., time.Minute, time.Hour) + Burst int // Maximum burst capacity (must be > 0) + KeyFunc func(*http.Request) string // Optional custom key extraction + Store RateLimiterStore } // defaultKeyFunc extracts a normalized service key from an HTTP request. @@ -88,9 +86,13 @@ func (config *RateLimiterConfig) AddOption(h HTTP) HTTP { return h } - // Choose implementation based on Redis client availability. - if config.RedisClient != nil { - return NewDistributedRateLimiter(*config, h) + // Choose implementation based on Redis client availability. Default to local store if not set + if config.Store != nil { + if _, ok := config.Store.(*RedisRateLimiterStore); ok { + return NewDistributedRateLimiter(*config, h, config.Store) + } + + return NewLocalRateLimiter(*config, h, config.Store) } // Log warning for local rate limiting. @@ -98,7 +100,7 @@ func (config *RateLimiterConfig) AddOption(h HTTP) HTTP { httpSvc.Logger.Log("Using local rate limiting - not suitable for multi-instance deployments") } - return NewLocalRateLimiter(*config, h) + return NewLocalRateLimiter(*config, h, NewLocalRateLimiterStore()) } // RequestsPerSecond converts the configured rate to requests per second. diff --git a/pkg/gofr/service/rate_limiter_distributed.go b/pkg/gofr/service/rate_limiter_distributed.go index 7f46433970..5b08be7014 100644 --- a/pkg/gofr/service/rate_limiter_distributed.go +++ b/pkg/gofr/service/rate_limiter_distributed.go @@ -65,12 +65,12 @@ type distributedRateLimiter struct { HTTP } -func NewDistributedRateLimiter(config RateLimiterConfig, h HTTP) HTTP { +func NewDistributedRateLimiter(config RateLimiterConfig, h HTTP, store RateLimiterStore) HTTP { httpSvc := h.(*httpService) rl := &distributedRateLimiter{ config: config, - store: NewRedisRateLimiterStore(config.RedisClient), + store: store, logger: httpSvc.Logger, metrics: httpSvc.Metrics, HTTP: h, diff --git a/pkg/gofr/service/rate_limiter_local.go b/pkg/gofr/service/rate_limiter_local.go index 625662b20d..94f003a9f9 100644 --- a/pkg/gofr/service/rate_limiter_local.go +++ b/pkg/gofr/service/rate_limiter_local.go @@ -57,11 +57,12 @@ const ( ) // NewLocalRateLimiter creates a new local rate limiter with metrics. -func NewLocalRateLimiter(config RateLimiterConfig, h HTTP) HTTP { +func NewLocalRateLimiter(config RateLimiterConfig, h HTTP, store RateLimiterStore) HTTP { httpSvc := h.(*httpService) rl := &localRateLimiter{ config: config, + store: store, logger: httpSvc.Logger, metrics: httpSvc.Metrics, HTTP: h, diff --git a/pkg/gofr/service/rate_limiter_local_test.go b/pkg/gofr/service/rate_limiter_local_test.go index 150cb1fda5..dd09644ea9 100644 --- a/pkg/gofr/service/rate_limiter_local_test.go +++ b/pkg/gofr/service/rate_limiter_local_test.go @@ -64,11 +64,12 @@ func TestNewLocalRateLimiter_Basic(t *testing.T) { base := newBaseHTTPService(t, &hits) rl := NewLocalRateLimiter(RateLimiterConfig{ + Requests: 5, Window: time.Second, Burst: 5, KeyFunc: func(*http.Request) string { return "svc-basic" }, - }, base) + }, base, NewLocalRateLimiterStore()) resp, err := rl.Get(t.Context(), "/ok", nil) assertAllowed(t, resp, err) @@ -91,7 +92,7 @@ func TestLocalRateLimiter_EnforceLimit(t *testing.T) { Window: time.Second, Burst: 1, KeyFunc: func(*http.Request) string { return "svc-limit" }, - }, base) + }, base, NewLocalRateLimiterStore()) resp, err := rl.Get(t.Context(), "/r1", nil) assertAllowed(t, resp, err) @@ -131,7 +132,7 @@ func TestLocalRateLimiter_FractionalRPS(t *testing.T) { Window: time.Second, Burst: 1, KeyFunc: func(*http.Request) string { return "svc-frac" }, - }, base) + }, base, NewLocalRateLimiterStore()) resp, err := rl.Get(t.Context(), "/a", nil) assertAllowed(t, resp, err) @@ -171,7 +172,7 @@ func TestLocalRateLimiter_CustomKey_SharedBucket(t *testing.T) { Window: time.Second, Burst: 1, KeyFunc: func(*http.Request) string { return "shared-key" }, - }, base) + }, base, NewLocalRateLimiterStore()) resp, err := rl.Get(t.Context(), "/p1", nil) assertAllowed(t, resp, err) @@ -214,7 +215,7 @@ func TestLocalRateLimiter_Concurrency(t *testing.T) { Window: time.Second, Burst: 1, KeyFunc: func(*http.Request) string { return "svc-conc" }, - }, base) + }, base, NewLocalRateLimiterStore()) const workers = 12 results := make([]error, workers) @@ -284,7 +285,7 @@ func TestLocalRateLimiter_NoMetrics(t *testing.T) { Window: time.Second, Burst: 2, KeyFunc: func(*http.Request) string { return "svc-nometrics" }, - }, base) + }, base, NewLocalRateLimiterStore()) resp, err := rl.Get(t.Context(), "/m", nil) assertAllowed(t, resp, err) @@ -305,7 +306,7 @@ func TestLocalRateLimiter_RateLimitErrorFields(t *testing.T) { Window: time.Second, Burst: 1, KeyFunc: func(*http.Request) string { return "svc-zero" }, - }, base) + }, base, NewLocalRateLimiterStore()) resp, err := rl.Get(t.Context(), "/z1", nil) @@ -341,14 +342,14 @@ func TestLocalRateLimiter_WrapperMethods_SuccessAndLimited(t *testing.T) { Window: time.Second, Burst: 100, KeyFunc: func(*http.Request) string { return "wrapper-allow" }, - }, base) + }, base, NewLocalRateLimiterStore()) // Deny limiter: zero capacity (covers error branch) denyRL := NewLocalRateLimiter(RateLimiterConfig{ Requests: 0, Burst: 0, KeyFunc: func(*http.Request) string { return "wrapper-deny" }, - }, base) + }, base, NewLocalRateLimiterStore()) tests := []struct { name string diff --git a/pkg/gofr/service/rate_limiter_store.go b/pkg/gofr/service/rate_limiter_store.go index 6ad4afae81..0edf258ba5 100644 --- a/pkg/gofr/service/rate_limiter_store.go +++ b/pkg/gofr/service/rate_limiter_store.go @@ -29,9 +29,10 @@ func (r *RedisRateLimiterStore) Allow(ctx context.Context, key string, config Ra ctx, tokenBucketScript, []string{"gofr:ratelimit:" + key}, - config.Burst, - int64(config.Window.Seconds()), - now, + config.Burst, // ARGV[1]: burst + config.Requests, // ARGV[2]: requests + int64(config.Window.Seconds()), // ARGV[3]: window_seconds + now, // ARGV[4]: now (nanoseconds) ) result, err := cmd.Result() diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go index 3960539f09..ced05f3d77 100644 --- a/pkg/gofr/service/rate_limiter_test.go +++ b/pkg/gofr/service/rate_limiter_test.go @@ -116,9 +116,9 @@ func TestAddOption_LocalLimiter(t *testing.T) { func TestAddOption_DistributedLimiter(t *testing.T) { h := newHTTPService(t) cfg := RateLimiterConfig{ - Requests: 5, - Burst: 5, - RedisClient: new(gofrRedis.Redis), + Requests: 5, + Burst: 5, + Store: NewRedisRateLimiterStore(new(gofrRedis.Redis)), } out := cfg.AddOption(h) From a90d0e26a7ec9879d8508e3cf1219f7f7c230738 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Tue, 30 Sep 2025 11:39:25 +0530 Subject: [PATCH 13/21] add more tests --- pkg/gofr/service/rate_limiter_distributed.go | 8 +++ pkg/gofr/service/rate_limiter_local.go | 8 +++ pkg/gofr/service/rate_limiter_local_test.go | 76 +++++++++++++------- pkg/gofr/service/rate_limiter_test.go | 38 ++++++++++ 4 files changed, 104 insertions(+), 26 deletions(-) diff --git a/pkg/gofr/service/rate_limiter_distributed.go b/pkg/gofr/service/rate_limiter_distributed.go index 5b08be7014..b57a762144 100644 --- a/pkg/gofr/service/rate_limiter_distributed.go +++ b/pkg/gofr/service/rate_limiter_distributed.go @@ -66,6 +66,14 @@ type distributedRateLimiter struct { } func NewDistributedRateLimiter(config RateLimiterConfig, h HTTP, store RateLimiterStore) HTTP { + if err := config.Validate(); err != nil { + if httpSvc, ok := h.(*httpService); ok { + httpSvc.Logger.Log("Invalid rate limiter config, disabling distributed rate limiting", "error", err) + } + + return h + } + httpSvc := h.(*httpService) rl := &distributedRateLimiter{ diff --git a/pkg/gofr/service/rate_limiter_local.go b/pkg/gofr/service/rate_limiter_local.go index 94f003a9f9..14f9a03417 100644 --- a/pkg/gofr/service/rate_limiter_local.go +++ b/pkg/gofr/service/rate_limiter_local.go @@ -58,6 +58,14 @@ const ( // NewLocalRateLimiter creates a new local rate limiter with metrics. func NewLocalRateLimiter(config RateLimiterConfig, h HTTP, store RateLimiterStore) HTTP { + if err := config.Validate(); err != nil { + if httpSvc, ok := h.(*httpService); ok { + httpSvc.Logger.Log("Invalid rate limiter config, disabling local rate limiting", "error", err) + } + + return h + } + httpSvc := h.(*httpService) rl := &localRateLimiter{ diff --git a/pkg/gofr/service/rate_limiter_local_test.go b/pkg/gofr/service/rate_limiter_local_test.go index dd09644ea9..adc1f050f3 100644 --- a/pkg/gofr/service/rate_limiter_local_test.go +++ b/pkg/gofr/service/rate_limiter_local_test.go @@ -87,11 +87,13 @@ func TestLocalRateLimiter_EnforceLimit(t *testing.T) { base := newBaseHTTPService(t, &hits) + key := "svc-limit-" + time.Now().Format("150405.000000000") + rl := NewLocalRateLimiter(RateLimiterConfig{ Requests: 1, Window: time.Second, Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-limit" }, + KeyFunc: func(*http.Request) string { return key }, }, base, NewLocalRateLimiterStore()) resp, err := rl.Get(t.Context(), "/r1", nil) @@ -302,7 +304,7 @@ func TestLocalRateLimiter_RateLimitErrorFields(t *testing.T) { base := newBaseHTTPService(t, &hits) rl := NewLocalRateLimiter(RateLimiterConfig{ - Requests: 0, // Always zero refill + Requests: 1, // Always zero refill Window: time.Second, Burst: 1, KeyFunc: func(*http.Request) string { return "svc-zero" }, @@ -331,26 +333,17 @@ func TestLocalRateLimiter_RateLimitErrorFields(t *testing.T) { assert.GreaterOrEqual(t, rlErr.RetryAfter, time.Second) } -func TestLocalRateLimiter_WrapperMethods_SuccessAndLimited(t *testing.T) { +func TestLocalRateLimiter_WrapperMethods_Allowed(t *testing.T) { var hits atomic.Int64 - base := newBaseHTTPService(t, &hits) - // Success limiter: plenty of capacity - successRL := NewLocalRateLimiter(RateLimiterConfig{ + rl := NewLocalRateLimiter(RateLimiterConfig{ Requests: 100, Window: time.Second, Burst: 100, KeyFunc: func(*http.Request) string { return "wrapper-allow" }, }, base, NewLocalRateLimiterStore()) - // Deny limiter: zero capacity (covers error branch) - denyRL := NewLocalRateLimiter(RateLimiterConfig{ - Requests: 0, - Burst: 0, - KeyFunc: func(*http.Request) string { return "wrapper-deny" }, - }, base, NewLocalRateLimiterStore()) - tests := []struct { name string call func(h HTTP) (*http.Response, error) @@ -377,10 +370,9 @@ func TestLocalRateLimiter_WrapperMethods_SuccessAndLimited(t *testing.T) { }}, } - // Success path for _, tc := range tests { - t.Run(tc.name+"_Allowed", func(t *testing.T) { - resp, err := tc.call(successRL) + t.Run(tc.name, func(t *testing.T) { + resp, err := tc.call(rl) assertAllowed(t, resp, err) @@ -390,24 +382,56 @@ func TestLocalRateLimiter_WrapperMethods_SuccessAndLimited(t *testing.T) { }) } - // Denied path (each should hit rate limit before underlying service) + assert.Equal(t, int64(len(tests)), hits.Load()) +} + +func TestLocalRateLimiter_WrapperMethods_InvalidConfig(t *testing.T) { + var hits atomic.Int64 + base := newBaseHTTPService(t, &hits) + + rl := NewLocalRateLimiter(RateLimiterConfig{ + Requests: 0, + Burst: 0, + KeyFunc: func(*http.Request) string { return "wrapper-deny" }, + }, base, NewLocalRateLimiterStore()) + + tests := []struct { + name string + call func(h HTTP) (*http.Response, error) + }{ + {"Get", func(h HTTP) (*http.Response, error) { return h.Get(t.Context(), "/g", nil) }}, + {"GetWithHeaders", func(h HTTP) (*http.Response, error) { + return h.GetWithHeaders(t.Context(), "/gh", nil, map[string]string{"X": "1"}) + }}, + {"Post", func(h HTTP) (*http.Response, error) { return h.Post(t.Context(), "/p", nil, []byte("x")) }}, + {"PostWithHeaders", func(h HTTP) (*http.Response, error) { + return h.PostWithHeaders(t.Context(), "/ph", nil, []byte("x"), map[string]string{"X": "1"}) + }}, + {"Patch", func(h HTTP) (*http.Response, error) { return h.Patch(t.Context(), "/pa", nil, []byte("x")) }}, + {"PatchWithHeaders", func(h HTTP) (*http.Response, error) { + return h.PatchWithHeaders(t.Context(), "/pah", nil, []byte("x"), map[string]string{"X": "1"}) + }}, + {"Put", func(h HTTP) (*http.Response, error) { return h.Put(t.Context(), "/put", nil, []byte("x")) }}, + {"PutWithHeaders", func(h HTTP) (*http.Response, error) { + return h.PutWithHeaders(t.Context(), "/puth", nil, []byte("x"), map[string]string{"X": "1"}) + }}, + {"Delete", func(h HTTP) (*http.Response, error) { return h.Delete(t.Context(), "/d", []byte("x")) }}, + {"DeleteWithHeaders", func(h HTTP) (*http.Response, error) { + return h.DeleteWithHeaders(t.Context(), "/dh", []byte("x"), map[string]string{"X": "1"}) + }}, + } + for _, tc := range tests { - t.Run(tc.name+"_RateLimited", func(t *testing.T) { - resp, err := tc.call(denyRL) + t.Run(tc.name, func(t *testing.T) { + resp, err := tc.call(rl) - require.Error(t, err) - assert.Nil(t, resp) + assertAllowed(t, resp, err) if resp != nil { _ = resp.Body.Close() } - - var rlErr *RateLimitError - - assert.ErrorAs(t, err, &rlErr) }) } - // At least all success invocations should have reached downstream. assert.Equal(t, int64(len(tests)), hits.Load()) } diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go index ced05f3d77..d3cd28c4be 100644 --- a/pkg/gofr/service/rate_limiter_test.go +++ b/pkg/gofr/service/rate_limiter_test.go @@ -1,11 +1,13 @@ package service import ( + "context" "crypto/tls" "net/http" "net/http/httptest" "net/url" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -126,3 +128,39 @@ func TestAddOption_DistributedLimiter(t *testing.T) { assert.True(t, isDist, "expected *distributedRateLimiter") } + +type dummyStore struct{} + +func (dummyStore) Allow(_ context.Context, _ string, _ RateLimiterConfig) (allowed bool, + retryAfter time.Duration, err error) { + return true, 0, nil +} + +func TestNewDistributedRateLimiter_WithHTTPService_Success(t *testing.T) { + config := RateLimiterConfig{ + Requests: 10, + Window: time.Minute, + Burst: 10, + } + h := newHTTPService(t) + store := &dummyStore{} + + result := NewDistributedRateLimiter(config, h, store) + + _, ok := result.(*distributedRateLimiter) + assert.True(t, ok, "should return distributedRateLimiter") +} + +func TestNewDistributedRateLimiter_WithHTTPService_Error(t *testing.T) { + config := RateLimiterConfig{ + Requests: 0, // Invalid + Window: time.Minute, + Burst: 10, + } + h := newHTTPService(t) + store := &dummyStore{} + + result := NewDistributedRateLimiter(config, h, store) + + assert.Same(t, h, result, "should return original HTTP on invalid config") +} From fd7fe70adc11204eb76b72088423d8c45af4e9cb Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Tue, 30 Sep 2025 16:51:53 +0530 Subject: [PATCH 14/21] refactor implementation to unify the structs and remove duplicate codes --- pkg/gofr/service/rate_limiter.go | 258 +++++++---- pkg/gofr/service/rate_limiter_config.go | 127 ++++++ pkg/gofr/service/rate_limiter_config_test.go | 143 ++++++ pkg/gofr/service/rate_limiter_distributed.go | 258 ----------- pkg/gofr/service/rate_limiter_local.go | 406 ----------------- pkg/gofr/service/rate_limiter_local_test.go | 437 ------------------- pkg/gofr/service/rate_limiter_store.go | 230 +++++++++- pkg/gofr/service/rate_limiter_store_test.go | 106 +++++ pkg/gofr/service/rate_limiter_test.go | 292 +++++++------ 9 files changed, 938 insertions(+), 1319 deletions(-) create mode 100644 pkg/gofr/service/rate_limiter_config.go create mode 100644 pkg/gofr/service/rate_limiter_config_test.go delete mode 100644 pkg/gofr/service/rate_limiter_distributed.go delete mode 100644 pkg/gofr/service/rate_limiter_local.go delete mode 100644 pkg/gofr/service/rate_limiter_local_test.go create mode 100644 pkg/gofr/service/rate_limiter_store_test.go diff --git a/pkg/gofr/service/rate_limiter.go b/pkg/gofr/service/rate_limiter.go index 1244030558..498e493cd9 100644 --- a/pkg/gofr/service/rate_limiter.go +++ b/pkg/gofr/service/rate_limiter.go @@ -1,125 +1,231 @@ package service import ( - "errors" - "fmt" + "context" "net/http" - "time" + "strings" ) -var ( - errInvalidRequestRate = errors.New("requests must be greater than 0 per configured time window") - errBurstLessThanRequests = errors.New("burst must be greater than requests per window") - errInvalidRedisResultType = errors.New("unexpected Redis result type") -) +// rateLimiter provides unified rate limiting for HTTP clients. +type rateLimiter struct { + config RateLimiterConfig + store RateLimiterStore + logger Logger + metrics Metrics + HTTP // Embedded HTTP service +} + +// NewRateLimiter creates a new unified rate limiter. +func NewRateLimiter(config RateLimiterConfig, h HTTP) HTTP { + httpSvc := h.(*httpService) + + rl := &rateLimiter{ + config: config, + store: config.Store, + logger: httpSvc.Logger, + metrics: httpSvc.Metrics, + HTTP: h, + } + + // Start cleanup routine + ctx := context.Background() + rl.store.StartCleanup(ctx, rl.logger) -// RateLimiterConfig with custom keying support. -type RateLimiterConfig struct { - Requests float64 // Number of requests allowed - Window time.Duration // Time window (e.g., time.Minute, time.Hour) - Burst int // Maximum burst capacity (must be > 0) - KeyFunc func(*http.Request) string // Optional custom key extraction - Store RateLimiterStore + return rl } -// defaultKeyFunc extracts a normalized service key from an HTTP request. -func defaultKeyFunc(req *http.Request) string { - if req == nil || req.URL == nil { - return "unknown" +// buildFullURL constructs an absolute URL by combining the base service URL with the given path. +func (rl *rateLimiter) buildFullURL(path string) string { + if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { + return path } - scheme := req.URL.Scheme - host := req.URL.Host + // Get base URL from embedded HTTP service + httpSvcImpl, ok := rl.HTTP.(*httpService) + if !ok { + return path + } - if scheme == "" { - if req.TLS != nil { - scheme = methodHTTPS - } else { - scheme = methodHTTP - } + base := strings.TrimRight(httpSvcImpl.url, "/") + if base == "" { + return path } - if host == "" { - host = req.Host + // Ensure path starts with / + if !strings.HasPrefix(path, "/") { + path = "/" + path } - if host == "" { - host = unknownServiceKey + return base + path +} + +// checkRateLimit performs rate limit check using the configured store. +func (rl *rateLimiter) checkRateLimit(req *http.Request) error { + serviceKey := rl.config.KeyFunc(req) + allowed, retryAfter, err := rl.store.Allow(req.Context(), serviceKey, rl.config) + + // Update metrics + rl.updateRateLimiterMetrics(req.Context(), serviceKey, allowed, err) + + if err != nil { + rl.logger.Log("Rate limiter store error, allowing request", "error", err) + + return nil // Fail open } - return scheme + "://" + host + if !allowed { + rl.logger.Debug("Rate limit exceeded", "service", serviceKey, "rate", rl.config.RequestsPerSecond(), + "burst", rl.config.Burst, "retry_after", retryAfter) + + return &RateLimitError{ServiceKey: serviceKey, RetryAfter: retryAfter} + } + + return nil } -// Validate checks if the configuration is valid. -func (config *RateLimiterConfig) Validate() error { - if config.Requests <= 0 { - return fmt.Errorf("%w: %f", errInvalidRequestRate, config.Requests) +// updateRateLimiterMetrics updates metrics for rate limiting operations. +func (rl *rateLimiter) updateRateLimiterMetrics(ctx context.Context, serviceKey string, allowed bool, err error) { + if rl.metrics == nil { + return } - if config.Window <= 0 { - config.Window = time.Minute // Default: per-minute rate limiting + rl.metrics.IncrementCounter(ctx, "app_rate_limiter_requests_total", "service", serviceKey) + + if err != nil { + rl.metrics.IncrementCounter(ctx, "app_rate_limiter_errors_total", "service", serviceKey, "type", "store_error") } - if config.Burst <= 0 { - config.Burst = int(config.Requests) + if !allowed { + rl.metrics.IncrementCounter(ctx, "app_rate_limiter_denied_total", "service", serviceKey) } +} + +// HTTP Method Implementations - All methods follow the same pattern. + +// Get performs rate-limited HTTP GET request. +func (rl *rateLimiter) Get(ctx context.Context, path string, queryParams map[string]any) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Get(ctx, path, queryParams) +} + +// GetWithHeaders performs rate-limited HTTP GET request with custom headers. +func (rl *rateLimiter) GetWithHeaders(ctx context.Context, path string, queryParams map[string]any, + headers map[string]string) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) - if float64(config.Burst) < config.Requests { - return fmt.Errorf("%w: burst=%d, requests=%f", errBurstLessThanRequests, config.Burst, config.Requests) + if err := rl.checkRateLimit(req); err != nil { + return nil, err } - // Set default key function if not provided. - if config.KeyFunc == nil { - config.KeyFunc = defaultKeyFunc + return rl.HTTP.GetWithHeaders(ctx, path, queryParams, headers) +} + +// Post performs rate-limited HTTP POST request. +func (rl *rateLimiter) Post(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err } - return nil + return rl.HTTP.Post(ctx, path, queryParams, body) } -// AddOption implements the Options interface. -func (config *RateLimiterConfig) AddOption(h HTTP) HTTP { - if err := config.Validate(); err != nil { - if httpSvc, ok := h.(*httpService); ok { - httpSvc.Logger.Log("Invalid rate limiter config, disabling rate limiting", "error", err) - } +// PostWithHeaders performs rate-limited HTTP POST request with custom headers. +func (rl *rateLimiter) PostWithHeaders(ctx context.Context, path string, queryParams map[string]any, + body []byte, headers map[string]string) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) - return h + if err := rl.checkRateLimit(req); err != nil { + return nil, err } - // Choose implementation based on Redis client availability. Default to local store if not set - if config.Store != nil { - if _, ok := config.Store.(*RedisRateLimiterStore); ok { - return NewDistributedRateLimiter(*config, h, config.Store) - } + return rl.HTTP.PostWithHeaders(ctx, path, queryParams, body, headers) +} + +// Put performs rate-limited HTTP PUT request. +func (rl *rateLimiter) Put(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) - return NewLocalRateLimiter(*config, h, config.Store) + if err := rl.checkRateLimit(req); err != nil { + return nil, err } - // Log warning for local rate limiting. - if httpSvc, ok := h.(*httpService); ok { - httpSvc.Logger.Log("Using local rate limiting - not suitable for multi-instance deployments") + return rl.HTTP.Put(ctx, path, queryParams, body) +} + +// PutWithHeaders performs rate-limited HTTP PUT request with custom headers. +func (rl *rateLimiter) PutWithHeaders(ctx context.Context, path string, queryParams map[string]any, body []byte, + headers map[string]string) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err } - return NewLocalRateLimiter(*config, h, NewLocalRateLimiterStore()) + return rl.HTTP.PutWithHeaders(ctx, path, queryParams, body, headers) } -// RequestsPerSecond converts the configured rate to requests per second. -func (config *RateLimiterConfig) RequestsPerSecond() float64 { - // Convert any time window to "requests per second" for internal math - return float64(config.Requests) / config.Window.Seconds() +// Patch performs rate-limited HTTP PATCH request. +func (rl *rateLimiter) Patch(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Patch(ctx, path, queryParams, body) } -// RateLimitError represents a rate limiting error. -type RateLimitError struct { - ServiceKey string - RetryAfter time.Duration +// PatchWithHeaders performs rate-limited HTTP PATCH request with custom headers. +func (rl *rateLimiter) PatchWithHeaders(ctx context.Context, path string, queryParams map[string]any, + body []byte, headers map[string]string) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.PatchWithHeaders(ctx, path, queryParams, body, headers) } -func (e *RateLimitError) Error() string { - return fmt.Sprintf("rate limit exceeded for service: %s, retry after: %v", e.ServiceKey, e.RetryAfter) +// Delete performs rate-limited HTTP DELETE request. +func (rl *rateLimiter) Delete(ctx context.Context, path string, body []byte) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Delete(ctx, path, body) } -// StatusCode Implement StatusCodeResponder so Responder picks correct HTTP code. -func (*RateLimitError) StatusCode() int { - return http.StatusTooManyRequests // 429 +// DeleteWithHeaders performs rate-limited HTTP DELETE request with custom headers. +func (rl *rateLimiter) DeleteWithHeaders(ctx context.Context, path string, body []byte, + headers map[string]string) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.DeleteWithHeaders(ctx, path, body, headers) } diff --git a/pkg/gofr/service/rate_limiter_config.go b/pkg/gofr/service/rate_limiter_config.go new file mode 100644 index 0000000000..4698bbc44b --- /dev/null +++ b/pkg/gofr/service/rate_limiter_config.go @@ -0,0 +1,127 @@ +package service + +import ( + "errors" + "fmt" + "net/http" + "time" +) + +var ( + errInvalidRequestRate = errors.New("requests must be greater than 0 per configured time window") + errBurstLessThanRequests = errors.New("burst must be greater than requests per window") + errInvalidRedisResultType = errors.New("unexpected Redis result type") +) + +const ( + unknownServiceKey = "unknown" + methodHTTP = "http" + methodHTTPS = "https" +) + +// RateLimiterConfig with custom keying support. +type RateLimiterConfig struct { + Requests float64 // Number of requests allowed + Window time.Duration // Time window (e.g., time.Minute, time.Hour) + Burst int // Maximum burst capacity (must be > 0) + KeyFunc func(*http.Request) string // Optional custom key extraction + Store RateLimiterStore +} + +// defaultKeyFunc extracts a normalized service key from an HTTP request. +func defaultKeyFunc(req *http.Request) string { + if req == nil || req.URL == nil { + return unknownServiceKey + } + + scheme := req.URL.Scheme + host := req.URL.Host + + if scheme == "" { + if req.TLS != nil { + scheme = methodHTTPS + } else { + scheme = methodHTTP + } + } + + if host == "" { + host = req.Host + } + + if host == "" { + host = unknownServiceKey + } + + return scheme + "://" + host +} + +// Validate checks if the configuration is valid. +func (config *RateLimiterConfig) Validate() error { + if config.Requests <= 0 { + return fmt.Errorf("%w: %f", errInvalidRequestRate, config.Requests) + } + + if config.Window <= 0 { + config.Window = time.Minute // Default: per-minute rate limiting + } + + if config.Burst <= 0 { + config.Burst = int(config.Requests) + } + + if float64(config.Burst) < config.Requests { + return fmt.Errorf("%w: burst=%d, requests=%f", errBurstLessThanRequests, config.Burst, config.Requests) + } + + // Set default key function if not provided. + if config.KeyFunc == nil { + config.KeyFunc = defaultKeyFunc + } + + return nil +} + +// AddOption implements the Options interface. +func (config *RateLimiterConfig) AddOption(h HTTP) HTTP { + if err := config.Validate(); err != nil { + if httpSvc, ok := h.(*httpService); ok { + httpSvc.Logger.Log("Invalid rate limiter config, disabling rate limiting", "error", err) + } + + return h + } + + // Default to local store if not set + if config.Store == nil { + config.Store = NewLocalRateLimiterStore() + + // Log warning for local rate limiting. + if httpSvc, ok := h.(*httpService); ok { + httpSvc.Logger.Log("Using local rate limiting - not suitable for multi-instance deployments") + } + } + + return NewRateLimiter(*config, h) +} + +// RequestsPerSecond converts the configured rate to requests per second. +func (config *RateLimiterConfig) RequestsPerSecond() float64 { + // Convert any time window to "requests per second" for internal math + return float64(config.Requests) / config.Window.Seconds() +} + +// RateLimitError represents a rate limiting error. +type RateLimitError struct { + ServiceKey string + RetryAfter time.Duration +} + +func (e *RateLimitError) Error() string { + return fmt.Sprintf("rate limit exceeded for service: %s, retry after: %v", e.ServiceKey, e.RetryAfter) +} + +// StatusCode Implement StatusCodeResponder so Responder picks correct HTTP code. +func (*RateLimitError) StatusCode() int { + return http.StatusTooManyRequests // 429 +} diff --git a/pkg/gofr/service/rate_limiter_config_test.go b/pkg/gofr/service/rate_limiter_config_test.go new file mode 100644 index 0000000000..ab9909997f --- /dev/null +++ b/pkg/gofr/service/rate_limiter_config_test.go @@ -0,0 +1,143 @@ +package service + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + + "gofr.dev/pkg/gofr/logging" + "gofr.dev/pkg/gofr/testutil" +) + +func newHTTPService(t *testing.T) *httpService { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv.Close) + + return &httpService{ + Client: http.DefaultClient, + url: srv.URL, + Logger: logging.NewMockLogger(logging.INFO), + Tracer: otel.Tracer("gofr-http-client"), + } +} + +func TestRateLimiterConfig_Validate(t *testing.T) { + t.Run("invalid RPS", func(t *testing.T) { + cfg := RateLimiterConfig{Requests: 0, Burst: 1} + err := cfg.Validate() + + require.Error(t, err) + assert.ErrorIs(t, err, errInvalidRequestRate) + }) + + t.Run("burst less than requests", func(t *testing.T) { + cfg := RateLimiterConfig{Requests: 5, Burst: 3} + err := cfg.Validate() + + require.Error(t, err) + assert.ErrorIs(t, err, errBurstLessThanRequests) + }) + + t.Run("sets default KeyFunc when nil", func(t *testing.T) { + cfg := RateLimiterConfig{Requests: 1.5, Burst: 2} + + require.Nil(t, cfg.KeyFunc) + require.NoError(t, cfg.Validate()) + require.NotNil(t, cfg.KeyFunc) + }) +} + +func TestDefaultKeyFunc(t *testing.T) { + t.Run("nil request", func(t *testing.T) { + assert.Equal(t, "unknown", defaultKeyFunc(nil)) + }) + + t.Run("nil URL", func(t *testing.T) { + req := &http.Request{} + + assert.Equal(t, "unknown", defaultKeyFunc(req)) + }) + + t.Run("http derived scheme", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{Host: "example.com"}, + } + + assert.Equal(t, "http://example.com", defaultKeyFunc(req)) + }) + + t.Run("https derived scheme", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{Host: "secure.com"}, + TLS: &tls.ConnectionState{}, + } + + assert.Equal(t, "https://secure.com", defaultKeyFunc(req)) + }) + + t.Run("host from req.Host fallback", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{}, + Host: "fallback:9090", + } + + assert.Equal(t, "http://fallback:9090", defaultKeyFunc(req)) + }) + + t.Run("unknown service key when no host present", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{}, + } + + assert.Equal(t, "http://unknown", defaultKeyFunc(req)) + }) +} + +func TestAddOption_InvalidConfigReturnsOriginal(t *testing.T) { + h := newHTTPService(t) + + cfg := RateLimiterConfig{Requests: 0, Burst: 1} // invalid + + out := cfg.AddOption(h) + + assert.Same(t, h, out) +} + +func TestAddOption_DefaultsToLocalStoreAndLogsWarning(t *testing.T) { + log := testutil.StdoutOutputForFunc(func() { + h := newHTTPService(t) + + cfg := RateLimiterConfig{Requests: 2, Burst: 2, Window: time.Second} + + cfg.Store = nil + + _ = cfg.AddOption(h) + }) + + assert.Contains(t, log, "Using local rate limiting - not suitable for multi-instance deployments") +} + +func TestRequestsPerSecond(t *testing.T) { + cfg := RateLimiterConfig{Requests: 10, Window: 2 * time.Second} + + assert.InEpsilon(t, 5.0, cfg.RequestsPerSecond(), 0.001) +} + +func TestRateLimitError_ErrorAndStatusCode(t *testing.T) { + err := &RateLimitError{ServiceKey: "svc", RetryAfter: 2 * time.Second} + + assert.Contains(t, err.Error(), "rate limit exceeded for service: svc") + + assert.Equal(t, http.StatusTooManyRequests, err.StatusCode()) +} diff --git a/pkg/gofr/service/rate_limiter_distributed.go b/pkg/gofr/service/rate_limiter_distributed.go deleted file mode 100644 index b57a762144..0000000000 --- a/pkg/gofr/service/rate_limiter_distributed.go +++ /dev/null @@ -1,258 +0,0 @@ -package service - -import ( - "context" - "fmt" - "net/http" - "strconv" -) - -// tokenBucketScript is a Lua script for atomic token bucket rate limiting in Redis. -// -//nolint:gosec // This is a Lua script for Redis, not credentials -const tokenBucketScript = ` -local key = KEYS[1] -local burst = tonumber(ARGV[1]) -local requests = tonumber(ARGV[2]) -local window_seconds = tonumber(ARGV[3]) -local now = tonumber(ARGV[4]) - --- Calculate refill rate as requests per second -local refill_rate = requests / window_seconds - --- Fetch bucket -local bucket = redis.call("HMGET", key, "tokens", "last_refill") -local tokens = tonumber(bucket[1]) -local last_refill = tonumber(bucket[2]) - -if tokens == nil then -tokens = burst -last_refill = now -end - --- Refill tokens -local delta = math.max(0, (now - last_refill)/1e9) -local new_tokens = math.min(burst, tokens + delta * refill_rate) - -local allowed = 0 -local retryAfter = 0 - -if new_tokens >= 1 then -allowed = 1 -new_tokens = new_tokens - 1 -else -retryAfter = math.ceil((1 - new_tokens) / refill_rate * 1000) -- ms -end - -redis.call("HSET", key, "tokens", new_tokens, "last_refill", now) -redis.call("EXPIRE", key, 600) - -return {allowed, retryAfter} -` - -// DistributedRateLimiter implements Redis-based distributed rate limiting using Token Bucket algorithm. -// Strategy: Token Bucket with Redis Lua scripts for atomic operations -// - Suitable for: Multi-instance production deployments -// - Benefits: True distributed limiting across all service instances -// - Performance: Single Redis call per rate limit check with atomic Lua execution - -// distributedRateLimiter with metrics support. -type distributedRateLimiter struct { - config RateLimiterConfig - store RateLimiterStore - logger Logger - metrics Metrics - HTTP -} - -func NewDistributedRateLimiter(config RateLimiterConfig, h HTTP, store RateLimiterStore) HTTP { - if err := config.Validate(); err != nil { - if httpSvc, ok := h.(*httpService); ok { - httpSvc.Logger.Log("Invalid rate limiter config, disabling distributed rate limiting", "error", err) - } - - return h - } - - httpSvc := h.(*httpService) - - rl := &distributedRateLimiter{ - config: config, - store: store, - logger: httpSvc.Logger, - metrics: httpSvc.Metrics, - HTTP: h, - } - - return rl -} - -// Safe Redis result parsing. -func toInt64(i any) (int64, error) { - switch v := i.(type) { - case int64: - return v, nil - case int: - return int64(v), nil - case float64: - return int64(v), nil - case string: - return strconv.ParseInt(v, 10, 64) - default: - return 0, fmt.Errorf("%w: %T", errInvalidRedisResultType, i) - } -} - -// checkRateLimit for distributed version with metrics. -func (rl *distributedRateLimiter) checkRateLimit(req *http.Request) error { - serviceKey := rl.config.KeyFunc(req) - - allowed, retryAfter, err := rl.store.Allow(context.Background(), serviceKey, rl.config) - if err != nil { - rl.logger.Log("Rate limiter store error, allowing request", "error", err) - - rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_errors_total", "service", serviceKey, "type", "store_error") - - return nil - } - - rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_requests_total", "service", serviceKey) - - if !allowed { - rl.metrics.IncrementCounter(context.Background(), "app_rate_limiter_denied_total", "service", serviceKey) - - rl.logger.Debug("Distributed rate limit exceeded", "service", serviceKey, "retry_after", retryAfter) - - return &RateLimitError{ServiceKey: serviceKey, RetryAfter: retryAfter} - } - - return nil -} - -// GetWithHeaders performs rate-limited HTTP GET request with custom headers. -func (rl *distributedRateLimiter) GetWithHeaders(ctx context.Context, path string, queryParams map[string]any, - headers map[string]string) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.GetWithHeaders(ctx, path, queryParams, headers) -} - -// PostWithHeaders performs rate-limited HTTP POST request with custom headers. -func (rl *distributedRateLimiter) PostWithHeaders(ctx context.Context, path string, queryParams map[string]any, - body []byte, headers map[string]string) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.PostWithHeaders(ctx, path, queryParams, body, headers) -} - -// PatchWithHeaders performs rate-limited HTTP PATCH request with custom headers. -func (rl *distributedRateLimiter) PatchWithHeaders(ctx context.Context, path string, queryParams map[string]any, - body []byte, headers map[string]string) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.PatchWithHeaders(ctx, path, queryParams, body, headers) -} - -// PutWithHeaders performs rate-limited HTTP PUT request with custom headers. -func (rl *distributedRateLimiter) PutWithHeaders(ctx context.Context, path string, queryParams map[string]any, - body []byte, headers map[string]string) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.PutWithHeaders(ctx, path, queryParams, body, headers) -} - -// DeleteWithHeaders performs rate-limited HTTP DELETE request with custom headers. -func (rl *distributedRateLimiter) DeleteWithHeaders(ctx context.Context, path string, body []byte, - headers map[string]string) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.DeleteWithHeaders(ctx, path, body, headers) -} - -// Get performs rate-limited HTTP GET request. -func (rl *distributedRateLimiter) Get(ctx context.Context, path string, queryParams map[string]any) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.Get(ctx, path, queryParams) -} - -// Post performs rate-limited HTTP POST request. -func (rl *distributedRateLimiter) Post(ctx context.Context, path string, queryParams map[string]any, - body []byte) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.Post(ctx, path, queryParams, body) -} - -// Patch performs rate-limited HTTP PATCH request. -func (rl *distributedRateLimiter) Patch(ctx context.Context, path string, queryParams map[string]any, - body []byte) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.Patch(ctx, path, queryParams, body) -} - -// Put performs rate-limited HTTP PUT request. -func (rl *distributedRateLimiter) Put(ctx context.Context, path string, queryParams map[string]any, - body []byte) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.Put(ctx, path, queryParams, body) -} - -// Delete performs rate-limited HTTP DELETE request. -func (rl *distributedRateLimiter) Delete(ctx context.Context, path string, body []byte) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.Delete(ctx, path, body) -} diff --git a/pkg/gofr/service/rate_limiter_local.go b/pkg/gofr/service/rate_limiter_local.go deleted file mode 100644 index 14f9a03417..0000000000 --- a/pkg/gofr/service/rate_limiter_local.go +++ /dev/null @@ -1,406 +0,0 @@ -package service - -import ( - "context" - "net/http" - "runtime" - "strings" - "sync" - "sync/atomic" - "time" -) - -const ( - backoffAttemptThreshold = 3 - unknownServiceKey = "unknown" - methodHTTP = "http" - methodHTTPS = "https" -) - -// tokenBucket with fractional accumulator for better precision. -type tokenBucket struct { - tokens int64 // Current tokens (scaled by scale) - fractionalTokens float64 // Fractional remainder to avoid precision loss - lastRefillTime int64 // Unix nano timestamp - maxTokens int64 // Maximum tokens (scaled by scale) - refillPerNano float64 // Tokens per nanosecond (float64 for precision) - fracMutex sync.Mutex // Protects fractionalTokens -} - -// LocalRateLimiter implements in-memory rate limiting using the Token Bucket algorithm. -// Strategy: Token Bucket with fractional precision for sub-1 RPS support -// - Suitable for: Single-instance deployments, development, testing -// - Limitations: Per-instance limiting only, not suitable for multi-instance production -// - Performance: Lock-free atomic operations with CAS loops - -// localRateLimiter with metrics support. -type localRateLimiter struct { - config RateLimiterConfig - store RateLimiterStore - logger Logger - metrics Metrics - HTTP -} - -// bucketEntry holds bucket with last access time for cleanup. -type bucketEntry struct { - bucket *tokenBucket - lastAccess int64 // Unix timestamp -} - -const ( - scale int64 = 1e9 // Scaling factor (typed constant) - cleanupInterval = 5 * time.Minute // How often to clean up unused buckets - bucketTTL = 10 * time.Minute // How long to keep unused buckets - maxCASAttempts = 10 // ✅ FIX: Max CAS attempts - maxCASTime = 100 * time.Microsecond // ✅ FIX: Max CAS time -) - -// NewLocalRateLimiter creates a new local rate limiter with metrics. -func NewLocalRateLimiter(config RateLimiterConfig, h HTTP, store RateLimiterStore) HTTP { - if err := config.Validate(); err != nil { - if httpSvc, ok := h.(*httpService); ok { - httpSvc.Logger.Log("Invalid rate limiter config, disabling local rate limiting", "error", err) - } - - return h - } - - httpSvc := h.(*httpService) - - rl := &localRateLimiter{ - config: config, - store: store, - logger: httpSvc.Logger, - metrics: httpSvc.Metrics, - HTTP: h, - } - - go rl.cleanupRoutine() - - return rl -} - -// newTokenBucket creates a new atomic token bucket with proper float64 scaling. -func newTokenBucket(config *RateLimiterConfig) *tokenBucket { - maxScaled := int64(config.Burst) * scale - - requestsPerSecond := config.RequestsPerSecond() - refillPerNanoFloat := requestsPerSecond * float64(scale) / float64(time.Second) - - return &tokenBucket{ - tokens: maxScaled, - fractionalTokens: 0.0, - lastRefillTime: time.Now().UnixNano(), - maxTokens: maxScaled, - refillPerNano: refillPerNanoFloat, - } -} - -// allow with enhanced precision and metrics. -func (tb *tokenBucket) allow() (allowed bool, waitTime time.Duration, tokensRemaining int64) { - start := time.Now() - - for attempt := 0; attempt < maxCASAttempts && time.Since(start) < maxCASTime; attempt++ { - now := time.Now().UnixNano() - newTokens := tb.refillTokens(now) - - if newTokens < scale { - retry := tb.calculateRetry(newTokens) - tb.advanceTime(now) - - return false, retry, newTokens - } - - if tb.consumeToken(newTokens, now) { - return true, 0, newTokens - scale - } - - tb.backoff(attempt) - } - - return false, time.Second, 0 -} - -// refillTokens calculates and returns new token count after refilling based on elapsed time. -func (tb *tokenBucket) refillTokens(now int64) int64 { - oldTime := atomic.LoadInt64(&tb.lastRefillTime) - oldTokens := atomic.LoadInt64(&tb.tokens) - - elapsed := now - oldTime - if elapsed < 0 { - elapsed = 0 - } - - tb.fracMutex.Lock() - tokensToAddFloat := float64(elapsed)*tb.refillPerNano + tb.fractionalTokens - tokensToAdd := int64(tokensToAddFloat) - tb.fractionalTokens = tokensToAddFloat - float64(tokensToAdd) - tb.fracMutex.Unlock() - - newTokens := oldTokens + tokensToAdd - if newTokens > tb.maxTokens { - newTokens = tb.maxTokens - } - - return newTokens -} - -// calculateRetry computes the precise time duration until the next token becomes available. -func (tb *tokenBucket) calculateRetry(tokens int64) time.Duration { - if tb.refillPerNano == 0 { - return time.Second - } - - missing := float64(scale - tokens) - nanos := missing / tb.refillPerNano - - retry := time.Duration(nanos) - if retry < time.Second { - retry = time.Second - } - - return retry -} - -func (tb *tokenBucket) advanceTime(now int64) { - oldTime := atomic.LoadInt64(&tb.lastRefillTime) - atomic.CompareAndSwapInt64(&tb.lastRefillTime, oldTime, now) -} - -func (tb *tokenBucket) consumeToken(tokens, now int64) bool { - oldTokens := atomic.LoadInt64(&tb.tokens) - - if atomic.CompareAndSwapInt64(&tb.tokens, oldTokens, tokens-scale) { - atomic.StoreInt64(&tb.lastRefillTime, now) - - return true - } - - return false -} - -func (*tokenBucket) backoff(attempt int) { - if attempt < backoffAttemptThreshold { - runtime.Gosched() - } else { - time.Sleep(time.Microsecond) - } -} - -// buildFullURL constructs an absolute URL by combining the base service URL with the given path. -func buildFullURL(path string, httpSvc HTTP) string { - if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { - return path - } - - // Get base URL from embedded HTTP service - httpSvcImpl, ok := httpSvc.(*httpService) - if !ok { - return path - } - - base := strings.TrimRight(httpSvcImpl.url, "/") - if base == "" { - return path - } - - // Ensure path starts with / - if !strings.HasPrefix(path, "/") { - path = "/" + path - } - - return base + path -} - -// checkRateLimit with custom keying support. -func (rl *localRateLimiter) checkRateLimit(req *http.Request) error { - serviceKey := rl.config.KeyFunc(req) - - allowed, retryAfter, _ := rl.store.Allow(req.Context(), serviceKey, rl.config) - - rl.updateRateLimiterMetrics(req.Context(), serviceKey, allowed, 0) - - if !allowed { - rl.logger.Debug("Rate limit exceeded", "service", serviceKey, "rate", rl.config.RequestsPerSecond, ""+ - "burst", rl.config.Burst, "retry_after", retryAfter) - - return &RateLimitError{ServiceKey: serviceKey, RetryAfter: retryAfter} - } - - return nil -} - -// updateRateLimiterMetrics follows GoFr's updateMetrics pattern. -func (rl *localRateLimiter) updateRateLimiterMetrics(ctx context.Context, serviceKey string, allowed bool, tokensAvailable float64) { - if rl.metrics != nil { - rl.metrics.IncrementCounter(ctx, "app_rate_limiter_requests_total", "service", serviceKey) - - if !allowed { - rl.metrics.IncrementCounter(ctx, "app_rate_limiter_denied_total", "service", serviceKey) - } - - rl.metrics.SetGauge("app_rate_limiter_tokens_available", tokensAvailable, "service", serviceKey) - } -} - -// cleanupRoutine removes unused buckets. -func (rl *localRateLimiter) cleanupRoutine() { - ticker := time.NewTicker(cleanupInterval) - defer ticker.Stop() - - for range ticker.C { - cutoff := time.Now().Unix() - int64(bucketTTL.Seconds()) - cleaned := 0 - - localStore, ok := rl.store.(*LocalRateLimiterStore) - if !ok { - continue // Not a local store, skip cleanup - } - - localStore.buckets.Range(func(key, value any) bool { - entry := value.(*bucketEntry) - - if atomic.LoadInt64(&entry.lastAccess) < cutoff { - localStore.buckets.Delete(key) - - cleaned++ - } - - return true - }) - - if cleaned > 0 { - rl.logger.Debug("Cleaned up rate limiter buckets", "count", cleaned) - } - } -} - -// GetWithHeaders performs rate-limited HTTP GET request with custom headers. -func (rl *localRateLimiter) GetWithHeaders(ctx context.Context, path string, queryParams map[string]any, - headers map[string]string) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.GetWithHeaders(ctx, path, queryParams, headers) -} - -// PostWithHeaders performs rate-limited HTTP POST request with custom headers. -func (rl *localRateLimiter) PostWithHeaders(ctx context.Context, path string, queryParams map[string]any, - body []byte, headers map[string]string) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.PostWithHeaders(ctx, path, queryParams, body, headers) -} - -// PatchWithHeaders performs rate-limited HTTP PATCH request with custom headers. -func (rl *localRateLimiter) PatchWithHeaders(ctx context.Context, path string, queryParams map[string]any, - body []byte, headers map[string]string) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - - req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.PatchWithHeaders(ctx, path, queryParams, body, headers) -} - -// PutWithHeaders performs rate-limited HTTP PUT request with custom headers. -func (rl *localRateLimiter) PutWithHeaders(ctx context.Context, path string, queryParams map[string]any, body []byte, - headers map[string]string) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.PutWithHeaders(ctx, path, queryParams, body, headers) -} - -// DeleteWithHeaders performs rate-limited HTTP DELETE request with custom headers. -func (rl *localRateLimiter) DeleteWithHeaders(ctx context.Context, path string, body []byte, - headers map[string]string) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.DeleteWithHeaders(ctx, path, body, headers) -} - -// Get performs rate-limited HTTP GET request. -func (rl *localRateLimiter) Get(ctx context.Context, path string, queryParams map[string]any) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.Get(ctx, path, queryParams) -} - -// Post performs rate-limited HTTP POST request. -func (rl *localRateLimiter) Post(ctx context.Context, path string, queryParams map[string]any, - body []byte) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.Post(ctx, path, queryParams, body) -} - -// Patch performs rate-limited HTTP PATCH request. -func (rl *localRateLimiter) Patch(ctx context.Context, path string, queryParams map[string]any, - body []byte) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.Patch(ctx, path, queryParams, body) -} - -// Put performs rate-limited HTTP PUT request. -func (rl *localRateLimiter) Put(ctx context.Context, path string, queryParams map[string]any, - body []byte) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.Put(ctx, path, queryParams, body) -} - -// Delete performs rate-limited HTTP DELETE request. -func (rl *localRateLimiter) Delete(ctx context.Context, path string, body []byte) (*http.Response, error) { - fullURL := buildFullURL(path, rl.HTTP) - req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) - - if err := rl.checkRateLimit(req); err != nil { - return nil, err - } - - return rl.HTTP.Delete(ctx, path, body) -} diff --git a/pkg/gofr/service/rate_limiter_local_test.go b/pkg/gofr/service/rate_limiter_local_test.go deleted file mode 100644 index adc1f050f3..0000000000 --- a/pkg/gofr/service/rate_limiter_local_test.go +++ /dev/null @@ -1,437 +0,0 @@ -package service - -import ( - "context" - "errors" - "net/http" - "net/http/httptest" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - - "gofr.dev/pkg/gofr/logging" -) - -func newBaseHTTPService(t *testing.T, hitCounter *atomic.Int64) *httpService { - t.Helper() - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - hitCounter.Add(1) - w.WriteHeader(http.StatusOK) - })) - t.Cleanup(srv.Close) - - return &httpService{ - Client: http.DefaultClient, - url: srv.URL, - Logger: logging.NewMockLogger(logging.INFO), - Tracer: otel.Tracer("gofr-http-client"), - } -} - -func assertAllowed(t *testing.T, resp *http.Response, err error) { - t.Helper() - require.NoError(t, err) - require.NotNil(t, resp) - assert.Equal(t, http.StatusOK, resp.StatusCode) -} - -func assertRateLimited(t *testing.T, err error, key ...string) { - t.Helper() - require.Error(t, err) - - var rlErr *RateLimitError - - require.ErrorAs(t, err, &rlErr) - - if len(key) > 0 { - assert.Equal(t, key[0], rlErr.ServiceKey) - } - - assert.GreaterOrEqual(t, rlErr.RetryAfter, time.Second) -} - -func wait(d time.Duration) { time.Sleep(d) } - -func TestNewLocalRateLimiter_Basic(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - - Requests: 5, - Window: time.Second, - Burst: 5, - KeyFunc: func(*http.Request) string { return "svc-basic" }, - }, base, NewLocalRateLimiterStore()) - - resp, err := rl.Get(t.Context(), "/ok", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - assert.Equal(t, int64(1), hits.Load()) -} - -// Burst=1 then immediate second call denied; after refill allowed again. -func TestLocalRateLimiter_EnforceLimit(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - key := "svc-limit-" + time.Now().Format("150405.000000000") - - rl := NewLocalRateLimiter(RateLimiterConfig{ - Requests: 1, - Window: time.Second, - Burst: 1, - KeyFunc: func(*http.Request) string { return key }, - }, base, NewLocalRateLimiterStore()) - - resp, err := rl.Get(t.Context(), "/r1", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - resp, err = rl.Get(t.Context(), "/r2", nil) - require.Nil(t, resp) - assertRateLimited(t, err) - - if resp != nil { - _ = resp.Body.Close() - } - - wait(1100 * time.Millisecond) - - resp, err = rl.Get(t.Context(), "/r3", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - assert.Equal(t, int64(2), hits.Load()) -} - -// Fractional RPS (0.5 -> 1 token every 2s). -func TestLocalRateLimiter_FractionalRPS(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - Requests: 0.5, - Window: time.Second, - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-frac" }, - }, base, NewLocalRateLimiterStore()) - - resp, err := rl.Get(t.Context(), "/a", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - resp, err = rl.Get(t.Context(), "/b", nil) - require.Nil(t, resp) - assertRateLimited(t, err) - - if resp != nil { - _ = resp.Body.Close() - } - - wait(2100 * time.Millisecond) - - resp, err = rl.Get(t.Context(), "/c", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - assert.Equal(t, int64(2), hits.Load()) -} - -// Different paths share same bucket via custom KeyFunc. -func TestLocalRateLimiter_CustomKey_SharedBucket(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - Requests: 1, - Window: time.Second, - Burst: 1, - KeyFunc: func(*http.Request) string { return "shared-key" }, - }, base, NewLocalRateLimiterStore()) - - resp, err := rl.Get(t.Context(), "/p1", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - if resp != nil { - _ = resp.Body.Close() - } - - resp, err = rl.Get(t.Context(), "/p2", nil) - require.Nil(t, resp) - assertRateLimited(t, err) - - if resp != nil { - _ = resp.Body.Close() - } - - wait(1100 * time.Millisecond) - - resp, err = rl.Get(t.Context(), "/p3", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - assert.Equal(t, int64(2), hits.Load()) -} - -// Concurrency: Burst=1 & RPS=1 => only one succeeds immediately. -func TestLocalRateLimiter_Concurrency(t *testing.T) { - var hits atomic.Int64 - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - Requests: 1, - Window: time.Second, - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-conc" }, - }, base, NewLocalRateLimiterStore()) - - const workers = 12 - results := make([]error, workers) - - var wg sync.WaitGroup - - wg.Add(workers) - - for i := 0; i < workers; i++ { - go func(i int) { - defer wg.Done() - - resp, err := rl.Get(context.Background(), "/c", nil) - - if resp != nil { - _ = resp.Body.Close() - } - - results[i] = err - }(i) - } - - wg.Wait() - - var allowed, denied int - - for _, e := range results { - if e == nil { - allowed++ - continue - } - - var rlErr *RateLimitError - - if errors.As(e, &rlErr) { - denied++ - continue - } - - t.Fatalf("unexpected error type: %v", e) - } - - assert.Equal(t, 1, allowed) - assert.Equal(t, workers-1, denied) - assert.Equal(t, int64(1), hits.Load()) -} - -// buildFullURL behavior for relative and absolute forms. -func TestBuildFullURL(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - assert.Contains(t, buildFullURL("/x", base), "/x") - assert.Equal(t, "http://example.com/z", buildFullURL("http://example.com/z", base)) - assert.Contains(t, buildFullURL("rel", base), "/rel") -} - -// Ensures metrics calls do not panic when metrics nil (guard path). -func TestLocalRateLimiter_NoMetrics(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - Requests: 2, - Window: time.Second, - Burst: 2, - KeyFunc: func(*http.Request) string { return "svc-nometrics" }, - }, base, NewLocalRateLimiterStore()) - - resp, err := rl.Get(t.Context(), "/m", nil) - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } -} - -// Denial path exposes RateLimitError fields. -func TestLocalRateLimiter_RateLimitErrorFields(t *testing.T) { - var hits atomic.Int64 - - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - Requests: 1, // Always zero refill - Window: time.Second, - Burst: 1, - KeyFunc: func(*http.Request) string { return "svc-zero" }, - }, base, NewLocalRateLimiterStore()) - - resp, err := rl.Get(t.Context(), "/z1", nil) - - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - - resp, err = rl.Get(t.Context(), "/z2", nil) - require.Nil(t, resp) - - if resp != nil { - _ = resp.Body.Close() - } - - var rlErr *RateLimitError - - require.ErrorAs(t, err, &rlErr) - - assert.Equal(t, "svc-zero", rlErr.ServiceKey) - assert.GreaterOrEqual(t, rlErr.RetryAfter, time.Second) -} - -func TestLocalRateLimiter_WrapperMethods_Allowed(t *testing.T) { - var hits atomic.Int64 - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - Requests: 100, - Window: time.Second, - Burst: 100, - KeyFunc: func(*http.Request) string { return "wrapper-allow" }, - }, base, NewLocalRateLimiterStore()) - - tests := []struct { - name string - call func(h HTTP) (*http.Response, error) - }{ - {"Get", func(h HTTP) (*http.Response, error) { return h.Get(t.Context(), "/g", nil) }}, - {"GetWithHeaders", func(h HTTP) (*http.Response, error) { - return h.GetWithHeaders(t.Context(), "/gh", nil, map[string]string{"X": "1"}) - }}, - {"Post", func(h HTTP) (*http.Response, error) { return h.Post(t.Context(), "/p", nil, []byte("x")) }}, - {"PostWithHeaders", func(h HTTP) (*http.Response, error) { - return h.PostWithHeaders(t.Context(), "/ph", nil, []byte("x"), map[string]string{"X": "1"}) - }}, - {"Patch", func(h HTTP) (*http.Response, error) { return h.Patch(t.Context(), "/pa", nil, []byte("x")) }}, - {"PatchWithHeaders", func(h HTTP) (*http.Response, error) { - return h.PatchWithHeaders(t.Context(), "/pah", nil, []byte("x"), map[string]string{"X": "1"}) - }}, - {"Put", func(h HTTP) (*http.Response, error) { return h.Put(t.Context(), "/put", nil, []byte("x")) }}, - {"PutWithHeaders", func(h HTTP) (*http.Response, error) { - return h.PutWithHeaders(t.Context(), "/puth", nil, []byte("x"), map[string]string{"X": "1"}) - }}, - {"Delete", func(h HTTP) (*http.Response, error) { return h.Delete(t.Context(), "/d", []byte("x")) }}, - {"DeleteWithHeaders", func(h HTTP) (*http.Response, error) { - return h.DeleteWithHeaders(t.Context(), "/dh", []byte("x"), map[string]string{"X": "1"}) - }}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - resp, err := tc.call(rl) - - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - }) - } - - assert.Equal(t, int64(len(tests)), hits.Load()) -} - -func TestLocalRateLimiter_WrapperMethods_InvalidConfig(t *testing.T) { - var hits atomic.Int64 - base := newBaseHTTPService(t, &hits) - - rl := NewLocalRateLimiter(RateLimiterConfig{ - Requests: 0, - Burst: 0, - KeyFunc: func(*http.Request) string { return "wrapper-deny" }, - }, base, NewLocalRateLimiterStore()) - - tests := []struct { - name string - call func(h HTTP) (*http.Response, error) - }{ - {"Get", func(h HTTP) (*http.Response, error) { return h.Get(t.Context(), "/g", nil) }}, - {"GetWithHeaders", func(h HTTP) (*http.Response, error) { - return h.GetWithHeaders(t.Context(), "/gh", nil, map[string]string{"X": "1"}) - }}, - {"Post", func(h HTTP) (*http.Response, error) { return h.Post(t.Context(), "/p", nil, []byte("x")) }}, - {"PostWithHeaders", func(h HTTP) (*http.Response, error) { - return h.PostWithHeaders(t.Context(), "/ph", nil, []byte("x"), map[string]string{"X": "1"}) - }}, - {"Patch", func(h HTTP) (*http.Response, error) { return h.Patch(t.Context(), "/pa", nil, []byte("x")) }}, - {"PatchWithHeaders", func(h HTTP) (*http.Response, error) { - return h.PatchWithHeaders(t.Context(), "/pah", nil, []byte("x"), map[string]string{"X": "1"}) - }}, - {"Put", func(h HTTP) (*http.Response, error) { return h.Put(t.Context(), "/put", nil, []byte("x")) }}, - {"PutWithHeaders", func(h HTTP) (*http.Response, error) { - return h.PutWithHeaders(t.Context(), "/puth", nil, []byte("x"), map[string]string{"X": "1"}) - }}, - {"Delete", func(h HTTP) (*http.Response, error) { return h.Delete(t.Context(), "/d", []byte("x")) }}, - {"DeleteWithHeaders", func(h HTTP) (*http.Response, error) { - return h.DeleteWithHeaders(t.Context(), "/dh", []byte("x"), map[string]string{"X": "1"}) - }}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - resp, err := tc.call(rl) - - assertAllowed(t, resp, err) - - if resp != nil { - _ = resp.Body.Close() - } - }) - } - - assert.Equal(t, int64(len(tests)), hits.Load()) -} diff --git a/pkg/gofr/service/rate_limiter_store.go b/pkg/gofr/service/rate_limiter_store.go index 0edf258ba5..8c0b998922 100644 --- a/pkg/gofr/service/rate_limiter_store.go +++ b/pkg/gofr/service/rate_limiter_store.go @@ -2,17 +2,208 @@ package service import ( "context" + "fmt" + "strconv" "sync" + "sync/atomic" "time" gofrRedis "gofr.dev/pkg/gofr/datasource/redis" ) -// RateLimiterStore abstracts the storage for rate limiter buckets. +const ( + cleanupInterval = 5 * time.Minute // How often to clean up unused buckets + bucketTTL = 10 * time.Minute // How long to keep unused buckets +) + +// RateLimiterStore abstracts the storage and cleanup for rate limiter buckets. type RateLimiterStore interface { Allow(ctx context.Context, key string, config RateLimiterConfig) (allowed bool, retryAfter time.Duration, err error) + StartCleanup(ctx context.Context, logger Logger) + StopCleanup() +} + +// tokenBucket with simplified integer-only token handling +type tokenBucket struct { + tokens int64 // Current tokens + lastRefillTime int64 // Unix nano timestamp + maxTokens int64 // Maximum tokens + refillRate int64 // Tokens per second (as integer) +} + +// bucketEntry holds bucket with last access time for cleanup. +type bucketEntry struct { + bucket *tokenBucket + lastAccess int64 // Unix timestamp } +// newTokenBucket creates a new token bucket with integer-only math +func newTokenBucket(config *RateLimiterConfig) *tokenBucket { + maxTokens := int64(config.Burst) + refillRate := int64(config.RequestsPerSecond()) + + return &tokenBucket{ + tokens: maxTokens, + lastRefillTime: time.Now().UnixNano(), + maxTokens: maxTokens, + refillRate: refillRate, + } +} + +// allow checks if a token can be consumed +func (tb *tokenBucket) allow() (allowed bool, waitTime time.Duration) { + now := time.Now().UnixNano() + + // Calculate tokens to add based on elapsed time + elapsed := now - atomic.LoadInt64(&tb.lastRefillTime) + tokensToAdd := elapsed * tb.refillRate / int64(time.Second) + + // Update tokens atomically + for { + oldTokens := atomic.LoadInt64(&tb.tokens) + newTokens := oldTokens + tokensToAdd + + if newTokens > tb.maxTokens { + newTokens = tb.maxTokens + } + + // Try to consume a token + if newTokens >= 1 { + if atomic.CompareAndSwapInt64(&tb.tokens, oldTokens, newTokens-1) { + atomic.StoreInt64(&tb.lastRefillTime, now) + return true, 0 + } + } else { + // Calculate wait time + waitTime := time.Duration((1-newTokens)*int64(time.Second)/tb.refillRate) * time.Nanosecond + + if waitTime < time.Millisecond { + waitTime = time.Millisecond + } + + return false, waitTime + } + } +} + +// LocalRateLimiterStore implements RateLimiterStore using in-memory buckets. +type LocalRateLimiterStore struct { + buckets *sync.Map + stopCh chan struct{} +} + +func NewLocalRateLimiterStore() *LocalRateLimiterStore { + return &LocalRateLimiterStore{ + buckets: &sync.Map{}, + } +} + +func (l *LocalRateLimiterStore) Allow(_ context.Context, key string, config RateLimiterConfig) (bool, time.Duration, error) { + now := time.Now().Unix() + entry, _ := l.buckets.LoadOrStore(key, &bucketEntry{ + bucket: newTokenBucket(&config), + lastAccess: now, + }) + + bucketEntry := entry.(*bucketEntry) + + atomic.StoreInt64(&bucketEntry.lastAccess, now) + + allowed, retryAfter := bucketEntry.bucket.allow() + + return allowed, retryAfter, nil +} + +func (l *LocalRateLimiterStore) StartCleanup(ctx context.Context, logger Logger) { + l.stopCh = make(chan struct{}) + + go func() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + l.cleanupExpiredBuckets(logger) + case <-l.stopCh: + return + case <-ctx.Done(): + return + } + } + }() +} + +func (l *LocalRateLimiterStore) StopCleanup() { + if l.stopCh != nil { + close(l.stopCh) + } +} + +func (l *LocalRateLimiterStore) cleanupExpiredBuckets(logger Logger) { + cutoff := time.Now().Unix() - int64(bucketTTL.Seconds()) + cleaned := 0 + + l.buckets.Range(func(key, value any) bool { + entry := value.(*bucketEntry) + if atomic.LoadInt64(&entry.lastAccess) < cutoff { + l.buckets.Delete(key) + cleaned++ + } + + return true + }) + + if cleaned > 0 && logger != nil { + logger.Debug("Cleaned up rate limiter buckets", "count", cleaned) + } +} + +// tokenBucketScript is a Lua script for atomic token bucket rate limiting in Redis. +// Updated to use integer-only token math for simplicity +// +//nolint:gosec // This is a Lua script for Redis, not credentials +const tokenBucketScript = ` +local key = KEYS[1] +local burst = tonumber(ARGV[1]) +local requests = tonumber(ARGV[2]) +local window_seconds = tonumber(ARGV[3]) +local now = tonumber(ARGV[4]) + +-- Calculate refill rate as requests per second +local refill_rate = requests / window_seconds + +-- Fetch bucket +local bucket = redis.call("HMGET", key, "tokens", "last_refill") +local tokens = tonumber(bucket[1]) +local last_refill = tonumber(bucket[2]) + +if tokens == nil then + tokens = burst + last_refill = now +end + +-- Refill tokens (integer math only) +local delta = math.max(0, (now - last_refill)/1e9) +local tokens_to_add = math.floor(delta * refill_rate) +local new_tokens = math.min(burst, tokens + tokens_to_add) + +local allowed = 0 +local retryAfter = 0 + +if new_tokens >= 1 then + allowed = 1 + new_tokens = new_tokens - 1 +else + retryAfter = math.ceil((1 - new_tokens) / refill_rate * 1000) -- ms +end + +redis.call("HSET", key, "tokens", new_tokens, "last_refill", now) +redis.call("EXPIRE", key, 600) + +return {allowed, retryAfter} +` + // RedisRateLimiterStore implements RateLimiterStore using Redis. type RedisRateLimiterStore struct { client *gofrRedis.Redis @@ -24,7 +215,6 @@ func NewRedisRateLimiterStore(client *gofrRedis.Redis) *RedisRateLimiterStore { func (r *RedisRateLimiterStore) Allow(ctx context.Context, key string, config RateLimiterConfig) (bool, time.Duration, error) { now := time.Now().UnixNano() - cmd := r.client.Eval( ctx, tokenBucketScript, @@ -46,29 +236,35 @@ func (r *RedisRateLimiterStore) Allow(ctx context.Context, key string, config Ra } allowed, _ := toInt64(resultArray[0]) - retryAfterMs, _ := toInt64(resultArray[1]) return allowed == 1, time.Duration(retryAfterMs) * time.Millisecond, nil } -// LocalRateLimiterStore implements RateLimiterStore using in-memory buckets. -type LocalRateLimiterStore struct { - buckets *sync.Map +func (r *RedisRateLimiterStore) StartCleanup(_ context.Context, _ Logger) { + // No-op: Redis handles cleanup automatically via EXPIRE commands in Lua script } -func NewLocalRateLimiterStore() *LocalRateLimiterStore { - return &LocalRateLimiterStore{buckets: &sync.Map{}} +func (r *RedisRateLimiterStore) StopCleanup() { + // No-op: Redis handles cleanup automatically } -func (l *LocalRateLimiterStore) Allow(_ context.Context, key string, config RateLimiterConfig) (bool, time.Duration, error) { - now := time.Now().Unix() - entry, _ := l.buckets.LoadOrStore(key, &bucketEntry{ - bucket: newTokenBucket(&config), - lastAccess: now, - }) - bucketEntry := entry.(*bucketEntry) - allowed, retryAfter, _ := bucketEntry.bucket.allow() +// toInt64 safely converts Redis result to int64 +func toInt64(i any) (int64, error) { + switch v := i.(type) { + case int64: + return v, nil + case int: + return int64(v), nil + case float64: + return int64(v), nil + case string: + if v == "" { + return 0, nil + } - return allowed, retryAfter, nil + return strconv.ParseInt(v, 10, 64) + default: + return 0, fmt.Errorf("%w: %T", errInvalidRedisResultType, i) + } } diff --git a/pkg/gofr/service/rate_limiter_store_test.go b/pkg/gofr/service/rate_limiter_store_test.go new file mode 100644 index 0000000000..abc7f9626f --- /dev/null +++ b/pkg/gofr/service/rate_limiter_store_test.go @@ -0,0 +1,106 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "gofr.dev/pkg/gofr/logging" + "gofr.dev/pkg/gofr/testutil" +) + +func TestTokenBucket_Allow(t *testing.T) { + cfg := RateLimiterConfig{Requests: 2, Burst: 2, Window: time.Second} + tb := newTokenBucket(&cfg) + + // Should allow first two requests + allowed, wait := tb.allow() + assert.True(t, allowed) + assert.Zero(t, wait) + + allowed, wait = tb.allow() + assert.True(t, allowed) + assert.Zero(t, wait) + + // Third request should be rate limited + allowed, wait = tb.allow() + assert.False(t, allowed) + assert.GreaterOrEqual(t, wait, time.Millisecond) +} + +func TestLocalRateLimiterStore_Allow(t *testing.T) { + store := NewLocalRateLimiterStore() + cfg := RateLimiterConfig{Requests: 1, Burst: 1, Window: time.Second} + key := "test-key" + + allowed, retry, err := store.Allow(context.Background(), key, cfg) + assert.True(t, allowed) + assert.Zero(t, retry) + assert.NoError(t, err) + + allowed, retry, err = store.Allow(context.Background(), key, cfg) + assert.False(t, allowed) + assert.GreaterOrEqual(t, retry, time.Millisecond) + assert.NoError(t, err) +} + +func TestLocalRateLimiterStore_CleanupExpiredBuckets(t *testing.T) { + store := NewLocalRateLimiterStore() + cfg := RateLimiterConfig{Requests: 1, Burst: 1, Window: time.Second} + key := "cleanup-key" + + _, _, _ = store.Allow(context.Background(), key, cfg) + + // Simulate old lastAccess + entry, _ := store.buckets.Load(key) + bucketEntry := entry.(*bucketEntry) + bucketEntry.lastAccess = time.Now().Unix() - int64(bucketTTL.Seconds()) - 1 + + log := testutil.StdoutOutputForFunc(func() { + store.cleanupExpiredBuckets(logging.NewMockLogger(logging.DEBUG)) + }) + + _, exists := store.buckets.Load(key) + assert.False(t, exists) + assert.Contains(t, log, "Cleaned up rate limiter buckets") +} + +func TestLocalRateLimiterStore_StartAndStopCleanup(t *testing.T) { + store := NewLocalRateLimiterStore() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + store.StartCleanup(ctx, logging.NewMockLogger(logging.INFO)) + assert.NotNil(t, store.stopCh) + + store.StopCleanup() +} + +func TestRedisRateLimiterStore_toInt64_ValidCases(t *testing.T) { + tests := []struct { + input any + expected int64 + }{ + {int64(5), 5}, + {int(7), 7}, + {float64(3.0), 3}, + {"42", 42}, + {"", 0}, + } + + for _, tc := range tests { + val, err := toInt64(tc.input) + + assert.NoError(t, err) + assert.Equal(t, tc.expected, val) + } +} + +func TestRedisRateLimiterStore_toInt64_ErrorCases(t *testing.T) { + _, err := toInt64(struct{}{}) + + assert.ErrorIs(t, err, errInvalidRedisResultType) +} diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go index d3cd28c4be..11cc5ee07b 100644 --- a/pkg/gofr/service/rate_limiter_test.go +++ b/pkg/gofr/service/rate_limiter_test.go @@ -2,165 +2,207 @@ package service import ( "context" - "crypto/tls" + "errors" + "net/http" - "net/http/httptest" - "net/url" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - - gofrRedis "gofr.dev/pkg/gofr/datasource/redis" - "gofr.dev/pkg/gofr/logging" + "go.uber.org/mock/gomock" ) -func newHTTPService(t *testing.T) *httpService { - t.Helper() +// --- Simple logger mock --- +type mockLogger struct { + logs []string +} - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - })) - t.Cleanup(srv.Close) +func (l *mockLogger) Log(_ ...any) { l.logs = append(l.logs, "Log") } +func (l *mockLogger) Debug(_ ...any) { l.logs = append(l.logs, "Debug") } - return &httpService{ - Client: http.DefaultClient, - url: srv.URL, - Logger: logging.NewMockLogger(logging.INFO), - Tracer: otel.Tracer("gofr-http-client"), - } +type mockStore struct { + allowed bool + retryAfter time.Duration + err error } -func TestRateLimiterConfig_Validate(t *testing.T) { - t.Run("invalid RPS", func(t *testing.T) { - cfg := RateLimiterConfig{Requests: 0, Burst: 1} - err := cfg.Validate() - require.Error(t, err) - assert.ErrorIs(t, err, errInvalidRequestRate) - }) - - t.Run("burst less than requests", func(t *testing.T) { - cfg := RateLimiterConfig{Requests: 5, Burst: 3} - err := cfg.Validate() - require.Error(t, err) - assert.ErrorIs(t, err, errBurstLessThanRequests) - }) - - t.Run("sets default KeyFunc when nil", func(t *testing.T) { - cfg := RateLimiterConfig{Requests: 1.5, Burst: 2} - require.Nil(t, cfg.KeyFunc) - require.NoError(t, cfg.Validate()) - require.NotNil(t, cfg.KeyFunc) - }) +func (m *mockStore) Allow(_ context.Context, _ string, _ RateLimiterConfig) (bool, time.Duration, error) { + return m.allowed, m.retryAfter, m.err } +func (*mockStore) StartCleanup(_ context.Context, _ Logger) {} -func TestDefaultKeyFunc(t *testing.T) { - t.Run("nil request", func(t *testing.T) { - assert.Equal(t, "unknown", defaultKeyFunc(nil)) - }) - - t.Run("nil URL", func(t *testing.T) { - req := &http.Request{} - assert.Equal(t, "unknown", defaultKeyFunc(req)) - }) - - t.Run("http derived scheme", func(t *testing.T) { - req := &http.Request{ - URL: &url.URL{Host: "example.com"}, - } - assert.Equal(t, "http://example.com", defaultKeyFunc(req)) - }) - - t.Run("https derived scheme", func(t *testing.T) { - req := &http.Request{ - URL: &url.URL{Host: "secure.com"}, - TLS: &tls.ConnectionState{}, - } - assert.Equal(t, "https://secure.com", defaultKeyFunc(req)) - }) - - t.Run("host from req.Host fallback", func(t *testing.T) { - req := &http.Request{ - URL: &url.URL{}, - Host: "fallback:9090", - } - assert.Equal(t, "http://fallback:9090", defaultKeyFunc(req)) - }) - - t.Run("unknown service key when no host present", func(t *testing.T) { - req := &http.Request{ - URL: &url.URL{}, - } - assert.Equal(t, "http://unknown", defaultKeyFunc(req)) - }) -} +func (*mockStore) StopCleanup() {} -func TestAddOption_InvalidConfigReturnsOriginal(t *testing.T) { - h := newHTTPService(t) - cfg := RateLimiterConfig{Requests: 0, Burst: 1} // invalid - out := cfg.AddOption(h) - assert.Same(t, h, out) -} +func TestRateLimiter_buildFullURL(t *testing.T) { + httpSvc := &httpService{url: "http://base.com/api"} + rl := &rateLimiter{HTTP: httpSvc} -func TestAddOption_LocalLimiter(t *testing.T) { - h := newHTTPService(t) - cfg := RateLimiterConfig{Requests: 2, Burst: 3} - out := cfg.AddOption(h) + assert.Equal(t, "http://foo.com/bar", rl.buildFullURL("http://foo.com/bar")) + assert.Equal(t, "https://foo.com/bar", rl.buildFullURL("https://foo.com/bar")) + assert.Equal(t, "http://base.com/api/foo", rl.buildFullURL("foo")) + assert.Equal(t, "http://base.com/api/foo", rl.buildFullURL("/foo")) - _, isLocal := out.(*localRateLimiter) - assert.True(t, isLocal, "expected *localRateLimiter") + httpSvc.url = "" - assert.NotNil(t, cfg.KeyFunc) + assert.Equal(t, "bar", rl.buildFullURL("bar")) + + rl.HTTP = &mockHTTP{} + + assert.Equal(t, "baz", rl.buildFullURL("baz")) } -func TestAddOption_DistributedLimiter(t *testing.T) { - h := newHTTPService(t) - cfg := RateLimiterConfig{ - Requests: 5, - Burst: 5, - Store: NewRedisRateLimiterStore(new(gofrRedis.Redis)), +func TestRateLimiter_checkRateLimit_Error(t *testing.T) { + store := &mockStore{allowed: true, err: errors.New("fail")} + logger := &mockLogger{} + + ctrl := gomock.NewController(t) + metrics := NewMockMetrics(ctrl) + + metrics.EXPECT().IncrementCounter(gomock.Any(), "app_rate_limiter_requests_total", "service", "svc") + metrics.EXPECT().IncrementCounter(gomock.Any(), "app_rate_limiter_errors_total", "service", "svc", "type", "store_error") + + rl := &rateLimiter{ + config: RateLimiterConfig{ + KeyFunc: func(*http.Request) string { return "svc" }, + Store: store, + }, + store: store, + logger: logger, + metrics: metrics, } - out := cfg.AddOption(h) - _, isDist := out.(*distributedRateLimiter) + req, _ := http.NewRequest("GET", "/", nil) - assert.True(t, isDist, "expected *distributedRateLimiter") + err := rl.checkRateLimit(req) + + assert.NoError(t, err) + assert.Contains(t, logger.logs, "Log") } -type dummyStore struct{} +func TestRateLimiter_checkRateLimit_Denied(t *testing.T) { + store := &mockStore{allowed: false} + logger := &mockLogger{} + + ctrl := gomock.NewController(t) + metrics := NewMockMetrics(ctrl) + + metrics.EXPECT().IncrementCounter(gomock.Any(), "app_rate_limiter_requests_total", "service", "svc") + metrics.EXPECT().IncrementCounter(gomock.Any(), "app_rate_limiter_denied_total", "service", "svc") + + rl := &rateLimiter{ + config: RateLimiterConfig{ + KeyFunc: func(*http.Request) string { return "svc" }, + Store: store, + }, + store: store, + logger: logger, + metrics: metrics, + } + + req, _ := http.NewRequest("GET", "/", nil) + err := rl.checkRateLimit(req) -func (dummyStore) Allow(_ context.Context, _ string, _ RateLimiterConfig) (allowed bool, - retryAfter time.Duration, err error) { - return true, 0, nil + assert.IsType(t, &RateLimitError{}, err) + assert.Contains(t, logger.logs, "Debug") } -func TestNewDistributedRateLimiter_WithHTTPService_Success(t *testing.T) { - config := RateLimiterConfig{ - Requests: 10, - Window: time.Minute, - Burst: 10, +func TestRateLimiter_checkRateLimit_Allowed(t *testing.T) { + store := &mockStore{allowed: true} + + logger := &mockLogger{} + + ctrl := gomock.NewController(t) + metrics := NewMockMetrics(ctrl) + + metrics.EXPECT().IncrementCounter(gomock.Any(), "app_rate_limiter_requests_total", "service", "svc") + + rl := &rateLimiter{ + config: RateLimiterConfig{ + KeyFunc: func(*http.Request) string { return "svc" }, + Store: store, + }, + store: store, + logger: logger, + metrics: metrics, } - h := newHTTPService(t) - store := &dummyStore{} - result := NewDistributedRateLimiter(config, h, store) + req, _ := http.NewRequest("GET", "/", nil) - _, ok := result.(*distributedRateLimiter) - assert.True(t, ok, "should return distributedRateLimiter") + err := rl.checkRateLimit(req) + assert.NoError(t, err) } -func TestNewDistributedRateLimiter_WithHTTPService_Error(t *testing.T) { - config := RateLimiterConfig{ - Requests: 0, // Invalid - Window: time.Minute, - Burst: 10, +func TestRateLimiter_HTTPMethods(t *testing.T) { + mock := &mockHTTP{} + + store := &mockStore{allowed: true} + logger := &mockLogger{} + + ctrl := gomock.NewController(t) + metrics := NewMockMetrics(ctrl) + + metrics.EXPECT().IncrementCounter(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + + rl := &rateLimiter{ + config: RateLimiterConfig{ + KeyFunc: func(*http.Request) string { return "svc" }, + Store: store, + }, + store: store, + logger: logger, + metrics: metrics, + HTTP: mock, } - h := newHTTPService(t) - store := &dummyStore{} - result := NewDistributedRateLimiter(config, h, store) + ctx := context.Background() + resp, err := rl.Get(ctx, "foo", nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + resp, err = rl.GetWithHeaders(ctx, "foo", nil, nil) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + resp, err = rl.Post(ctx, "foo", nil, nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusCreated, resp.StatusCode) + + resp, err = rl.PostWithHeaders(ctx, "foo", nil, nil, nil) + + assert.NoError(t, err) + assert.Equal(t, http.StatusCreated, resp.StatusCode) + + resp, err = rl.Put(ctx, "foo", nil, nil) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + resp, err = rl.PutWithHeaders(ctx, "foo", nil, nil, nil) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + resp, err = rl.Patch(ctx, "foo", nil, nil) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + resp, err = rl.PatchWithHeaders(ctx, "foo", nil, nil, nil) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + resp, err = rl.Delete(ctx, "foo", nil) + + assert.NoError(t, err) + assert.Equal(t, http.StatusNoContent, resp.StatusCode) + + resp, err = rl.DeleteWithHeaders(ctx, "foo", nil, nil) - assert.Same(t, h, result, "should return original HTTP on invalid config") + assert.NoError(t, err) + assert.Equal(t, http.StatusNoContent, resp.StatusCode) } From 815dec7b170053e819782bfa66d3cea210bd9677 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Fri, 3 Oct 2025 11:39:01 +0530 Subject: [PATCH 15/21] re-write tests --- pkg/gofr/service/rate_limiter_store.go | 35 ++++++++------- pkg/gofr/service/rate_limiter_store_test.go | 8 ++-- pkg/gofr/service/rate_limiter_test.go | 49 ++++++++++++++------- 3 files changed, 56 insertions(+), 36 deletions(-) diff --git a/pkg/gofr/service/rate_limiter_store.go b/pkg/gofr/service/rate_limiter_store.go index 8c0b998922..8524261c5e 100644 --- a/pkg/gofr/service/rate_limiter_store.go +++ b/pkg/gofr/service/rate_limiter_store.go @@ -23,7 +23,7 @@ type RateLimiterStore interface { StopCleanup() } -// tokenBucket with simplified integer-only token handling +// tokenBucket with simplified integer-only token handling. type tokenBucket struct { tokens int64 // Current tokens lastRefillTime int64 // Unix nano timestamp @@ -37,7 +37,7 @@ type bucketEntry struct { lastAccess int64 // Unix timestamp } -// newTokenBucket creates a new token bucket with integer-only math +// newTokenBucket creates a new token bucket with integer-only math. func newTokenBucket(config *RateLimiterConfig) *tokenBucket { maxTokens := int64(config.Burst) refillRate := int64(config.RequestsPerSecond()) @@ -50,7 +50,7 @@ func newTokenBucket(config *RateLimiterConfig) *tokenBucket { } } -// allow checks if a token can be consumed +// allow checks if a token can be consumed. func (tb *tokenBucket) allow() (allowed bool, waitTime time.Duration) { now := time.Now().UnixNano() @@ -67,22 +67,22 @@ func (tb *tokenBucket) allow() (allowed bool, waitTime time.Duration) { newTokens = tb.maxTokens } - // Try to consume a token - if newTokens >= 1 { - if atomic.CompareAndSwapInt64(&tb.tokens, oldTokens, newTokens-1) { - atomic.StoreInt64(&tb.lastRefillTime, now) - return true, 0 - } - } else { - // Calculate wait time + // Early return if not enough tokens + if newTokens < 1 { waitTime := time.Duration((1-newTokens)*int64(time.Second)/tb.refillRate) * time.Nanosecond - if waitTime < time.Millisecond { waitTime = time.Millisecond } return false, waitTime } + + // Try to consume a token + if atomic.CompareAndSwapInt64(&tb.tokens, oldTokens, newTokens-1) { + atomic.StoreInt64(&tb.lastRefillTime, now) + + return true, 0 + } } } @@ -148,6 +148,7 @@ func (l *LocalRateLimiterStore) cleanupExpiredBuckets(logger Logger) { entry := value.(*bucketEntry) if atomic.LoadInt64(&entry.lastAccess) < cutoff { l.buckets.Delete(key) + cleaned++ } @@ -241,15 +242,15 @@ func (r *RedisRateLimiterStore) Allow(ctx context.Context, key string, config Ra return allowed == 1, time.Duration(retryAfterMs) * time.Millisecond, nil } -func (r *RedisRateLimiterStore) StartCleanup(_ context.Context, _ Logger) { - // No-op: Redis handles cleanup automatically via EXPIRE commands in Lua script +func (*RedisRateLimiterStore) StartCleanup(_ context.Context, _ Logger) { + // No-op: Redis handles cleanup automatically via EXPIRE commands in Lua script. } -func (r *RedisRateLimiterStore) StopCleanup() { - // No-op: Redis handles cleanup automatically +func (*RedisRateLimiterStore) StopCleanup() { + // No-op: Redis handles cleanup automatically. } -// toInt64 safely converts Redis result to int64 +// toInt64 safely converts Redis result to int64. func toInt64(i any) (int64, error) { switch v := i.(type) { case int64: diff --git a/pkg/gofr/service/rate_limiter_store_test.go b/pkg/gofr/service/rate_limiter_store_test.go index abc7f9626f..0173c73045 100644 --- a/pkg/gofr/service/rate_limiter_store_test.go +++ b/pkg/gofr/service/rate_limiter_store_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gofr.dev/pkg/gofr/logging" "gofr.dev/pkg/gofr/testutil" @@ -38,7 +39,7 @@ func TestLocalRateLimiterStore_Allow(t *testing.T) { allowed, retry, err := store.Allow(context.Background(), key, cfg) assert.True(t, allowed) assert.Zero(t, retry) - assert.NoError(t, err) + require.NoError(t, err) allowed, retry, err = store.Allow(context.Background(), key, cfg) assert.False(t, allowed) @@ -51,7 +52,8 @@ func TestLocalRateLimiterStore_CleanupExpiredBuckets(t *testing.T) { cfg := RateLimiterConfig{Requests: 1, Burst: 1, Window: time.Second} key := "cleanup-key" - _, _, _ = store.Allow(context.Background(), key, cfg) + _, _, err := store.Allow(context.Background(), key, cfg) + require.NoError(t, err) // Simulate old lastAccess entry, _ := store.buckets.Load(key) @@ -94,7 +96,7 @@ func TestRedisRateLimiterStore_toInt64_ValidCases(t *testing.T) { for _, tc := range tests { val, err := toInt64(tc.input) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, tc.expected, val) } } diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go index 11cc5ee07b..551ab71716 100644 --- a/pkg/gofr/service/rate_limiter_test.go +++ b/pkg/gofr/service/rate_limiter_test.go @@ -2,8 +2,6 @@ package service import ( "context" - "errors" - "net/http" "testing" "time" @@ -13,7 +11,6 @@ import ( "go.uber.org/mock/gomock" ) -// --- Simple logger mock --- type mockLogger struct { logs []string } @@ -53,7 +50,7 @@ func TestRateLimiter_buildFullURL(t *testing.T) { } func TestRateLimiter_checkRateLimit_Error(t *testing.T) { - store := &mockStore{allowed: true, err: errors.New("fail")} + store := &mockStore{allowed: true, err: errTest} logger := &mockLogger{} ctrl := gomock.NewController(t) @@ -72,11 +69,11 @@ func TestRateLimiter_checkRateLimit_Error(t *testing.T) { metrics: metrics, } - req, _ := http.NewRequest("GET", "/", nil) + req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) err := rl.checkRateLimit(req) - assert.NoError(t, err) + require.NoError(t, err) assert.Contains(t, logger.logs, "Log") } @@ -100,7 +97,7 @@ func TestRateLimiter_checkRateLimit_Denied(t *testing.T) { metrics: metrics, } - req, _ := http.NewRequest("GET", "/", nil) + req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) err := rl.checkRateLimit(req) assert.IsType(t, &RateLimitError{}, err) @@ -127,7 +124,7 @@ func TestRateLimiter_checkRateLimit_Allowed(t *testing.T) { metrics: metrics, } - req, _ := http.NewRequest("GET", "/", nil) + req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) err := rl.checkRateLimit(req) assert.NoError(t, err) @@ -161,48 +158,68 @@ func TestRateLimiter_HTTPMethods(t *testing.T) { require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + resp, err = rl.GetWithHeaders(ctx, "foo", nil, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + resp, err = rl.Post(ctx, "foo", nil, nil) require.NoError(t, err) assert.Equal(t, http.StatusCreated, resp.StatusCode) + defer resp.Body.Close() + resp, err = rl.PostWithHeaders(ctx, "foo", nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusCreated, resp.StatusCode) + defer resp.Body.Close() + resp, err = rl.Put(ctx, "foo", nil, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + resp, err = rl.PutWithHeaders(ctx, "foo", nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + resp, err = rl.Patch(ctx, "foo", nil, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + resp, err = rl.PatchWithHeaders(ctx, "foo", nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + resp, err = rl.Delete(ctx, "foo", nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusNoContent, resp.StatusCode) + defer resp.Body.Close() + resp, err = rl.DeleteWithHeaders(ctx, "foo", nil, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusNoContent, resp.StatusCode) + + _ = resp.Body.Close() } From c54c78e7e8d1b1f9b65802fcf728a5e3f2d58d34 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Fri, 3 Oct 2025 12:48:37 +0530 Subject: [PATCH 16/21] revert unwanted changes --- docs/advanced-guide/http-communication/page.md | 6 ++++-- pkg/gofr/container/container.go | 6 ------ pkg/gofr/service/rate_limiter_store.go | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/docs/advanced-guide/http-communication/page.md b/docs/advanced-guide/http-communication/page.md index f134cb9b18..10eaa940a8 100644 --- a/docs/advanced-guide/http-communication/page.md +++ b/docs/advanced-guide/http-communication/page.md @@ -102,14 +102,16 @@ GoFr allows you to use a custom rate limiter store by implementing the RateLimit Interface: ```go type RateLimiterStore interface { - Allow(ctx context.Context, key string, config RateLimiterConfig) (allowed bool, retryAfter int64, err error) +Allow(ctx context.Context, key string, config RateLimiterConfig) (allowed bool, retryAfter time.Duration, err error) +StartCleanup(ctx context.Context) +StopCleanup() } ``` #### Usage: ```go -rc := redis.NewClient(cfg, a.Logger(), a.Metrics()) +rc := redis.NewClient(a.Config, a.Logger(), a.Metrics()) a.AddHTTPService("cat-facts", "https://catfact.ninja", service.NewAPIKeyConfig("some-random-key"), diff --git a/pkg/gofr/container/container.go b/pkg/gofr/container/container.go index 438f8638fc..36bfdd736e 100644 --- a/pkg/gofr/container/container.go +++ b/pkg/gofr/container/container.go @@ -260,12 +260,6 @@ func (c *Container) registerFrameworkMetrics() { httpBuckets := []float64{.001, .003, .005, .01, .02, .03, .05, .1, .2, .3, .5, .75, 1, 2, 3, 5, 10, 30} c.Metrics().NewHistogram("app_http_response", "Response time of HTTP requests in seconds.", httpBuckets...) c.Metrics().NewHistogram("app_http_service_response", "Response time of HTTP service requests in seconds.", httpBuckets...) - - rateLimiterBuckets := []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5} - c.Metrics().NewHistogram("app_rate_limiter_stats", "Response time of rate limiter checks in milliseconds.", rateLimiterBuckets...) - c.Metrics().NewCounter("app_rate_limiter_requests_total", "Total rate limiter requests.") - c.Metrics().NewCounter("app_rate_limiter_denied_total", "Total rate limiter denied requests.") - c.Metrics().NewCounter("app_rate_limiter_errors_total", "Total rate limiter errors.") } { // Redis metrics diff --git a/pkg/gofr/service/rate_limiter_store.go b/pkg/gofr/service/rate_limiter_store.go index 8524261c5e..2c89a5ac89 100644 --- a/pkg/gofr/service/rate_limiter_store.go +++ b/pkg/gofr/service/rate_limiter_store.go @@ -242,7 +242,7 @@ func (r *RedisRateLimiterStore) Allow(ctx context.Context, key string, config Ra return allowed == 1, time.Duration(retryAfterMs) * time.Millisecond, nil } -func (*RedisRateLimiterStore) StartCleanup(_ context.Context, _ Logger) { +func (*RedisRateLimiterStore) StartCleanup(_ context.Context) { // No-op: Redis handles cleanup automatically via EXPIRE commands in Lua script. } From 8e6f6eadbd8c2a36e07dc37ca26c52dff43b90fc Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Fri, 3 Oct 2025 12:55:44 +0530 Subject: [PATCH 17/21] remove changes in interface of logger and metrics --- pkg/gofr/service/circuit_breaker_test.go | 8 ------ pkg/gofr/service/logger.go | 1 - pkg/gofr/service/metrics.go | 2 -- pkg/gofr/service/mock_metrics.go | 36 +----------------------- 4 files changed, 1 insertion(+), 46 deletions(-) diff --git a/pkg/gofr/service/circuit_breaker_test.go b/pkg/gofr/service/circuit_breaker_test.go index 58089acfdc..dfb2868531 100644 --- a/pkg/gofr/service/circuit_breaker_test.go +++ b/pkg/gofr/service/circuit_breaker_test.go @@ -592,14 +592,6 @@ func (m *mockMetrics) RecordHistogram(ctx context.Context, name string, value fl m.Called(ctx, name, value, labels) } -func (m *mockMetrics) IncrementCounter(ctx context.Context, name string, labels ...string) { - m.Called(ctx, name, labels) -} - -func (m *mockMetrics) SetGauge(name string, value float64, labels ...string) { - m.Called(name, value, labels) -} - type customTransport struct { } diff --git a/pkg/gofr/service/logger.go b/pkg/gofr/service/logger.go index c94b72caee..04cf64917b 100644 --- a/pkg/gofr/service/logger.go +++ b/pkg/gofr/service/logger.go @@ -8,7 +8,6 @@ import ( type Logger interface { Log(args ...any) - Debug(args ...any) } type Log struct { diff --git a/pkg/gofr/service/metrics.go b/pkg/gofr/service/metrics.go index 64753d9650..ae48c21494 100644 --- a/pkg/gofr/service/metrics.go +++ b/pkg/gofr/service/metrics.go @@ -3,7 +3,5 @@ package service import "context" type Metrics interface { - IncrementCounter(ctx context.Context, name string, labels ...string) - SetGauge(name string, value float64, labels ...string) RecordHistogram(ctx context.Context, name string, value float64, labels ...string) } diff --git a/pkg/gofr/service/mock_metrics.go b/pkg/gofr/service/mock_metrics.go index 0923f05eb2..45129e90a6 100644 --- a/pkg/gofr/service/mock_metrics.go +++ b/pkg/gofr/service/mock_metrics.go @@ -40,23 +40,6 @@ func (m *MockMetrics) EXPECT() *MockMetricsMockRecorder { return m.recorder } -// IncrementCounter mocks base method. -func (m *MockMetrics) IncrementCounter(ctx context.Context, name string, labels ...string) { - m.ctrl.T.Helper() - varargs := []any{ctx, name} - for _, a := range labels { - varargs = append(varargs, a) - } - m.ctrl.Call(m, "IncrementCounter", varargs...) -} - -// IncrementCounter indicates an expected call of IncrementCounter. -func (mr *MockMetricsMockRecorder) IncrementCounter(ctx, name any, labels ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, name}, labels...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementCounter", reflect.TypeOf((*MockMetrics)(nil).IncrementCounter), varargs...) -} - // RecordHistogram mocks base method. func (m *MockMetrics) RecordHistogram(ctx context.Context, name string, value float64, labels ...string) { m.ctrl.T.Helper() @@ -72,21 +55,4 @@ func (mr *MockMetricsMockRecorder) RecordHistogram(ctx, name, value any, labels mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, name, value}, labels...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordHistogram", reflect.TypeOf((*MockMetrics)(nil).RecordHistogram), varargs...) -} - -// SetGauge mocks base method. -func (m *MockMetrics) SetGauge(name string, value float64, labels ...string) { - m.ctrl.T.Helper() - varargs := []any{name, value} - for _, a := range labels { - varargs = append(varargs, a) - } - m.ctrl.Call(m, "SetGauge", varargs...) -} - -// SetGauge indicates an expected call of SetGauge. -func (mr *MockMetricsMockRecorder) SetGauge(name, value any, labels ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{name, value}, labels...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGauge", reflect.TypeOf((*MockMetrics)(nil).SetGauge), varargs...) -} +} \ No newline at end of file From 8568065897e35984101c4b425677dbefcb7f864b Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Fri, 3 Oct 2025 13:09:29 +0530 Subject: [PATCH 18/21] fix linters --- pkg/gofr/service/rate_limiter.go | 49 +++--------------- pkg/gofr/service/rate_limiter_store.go | 12 ++--- pkg/gofr/service/rate_limiter_store_test.go | 10 +--- pkg/gofr/service/rate_limiter_test.go | 57 +++------------------ 4 files changed, 20 insertions(+), 108 deletions(-) diff --git a/pkg/gofr/service/rate_limiter.go b/pkg/gofr/service/rate_limiter.go index 498e493cd9..628ab527df 100644 --- a/pkg/gofr/service/rate_limiter.go +++ b/pkg/gofr/service/rate_limiter.go @@ -8,28 +8,22 @@ import ( // rateLimiter provides unified rate limiting for HTTP clients. type rateLimiter struct { - config RateLimiterConfig - store RateLimiterStore - logger Logger - metrics Metrics - HTTP // Embedded HTTP service + config RateLimiterConfig + store RateLimiterStore + HTTP // Embedded HTTP service } // NewRateLimiter creates a new unified rate limiter. func NewRateLimiter(config RateLimiterConfig, h HTTP) HTTP { - httpSvc := h.(*httpService) - rl := &rateLimiter{ - config: config, - store: config.Store, - logger: httpSvc.Logger, - metrics: httpSvc.Metrics, - HTTP: h, + config: config, + store: config.Store, + HTTP: h, } // Start cleanup routine ctx := context.Background() - rl.store.StartCleanup(ctx, rl.logger) + rl.store.StartCleanup(ctx) return rl } @@ -62,46 +56,19 @@ func (rl *rateLimiter) buildFullURL(path string) string { // checkRateLimit performs rate limit check using the configured store. func (rl *rateLimiter) checkRateLimit(req *http.Request) error { serviceKey := rl.config.KeyFunc(req) - allowed, retryAfter, err := rl.store.Allow(req.Context(), serviceKey, rl.config) - - // Update metrics - rl.updateRateLimiterMetrics(req.Context(), serviceKey, allowed, err) + allowed, retryAfter, err := rl.store.Allow(req.Context(), serviceKey, rl.config) if err != nil { - rl.logger.Log("Rate limiter store error, allowing request", "error", err) - return nil // Fail open } if !allowed { - rl.logger.Debug("Rate limit exceeded", "service", serviceKey, "rate", rl.config.RequestsPerSecond(), - "burst", rl.config.Burst, "retry_after", retryAfter) - return &RateLimitError{ServiceKey: serviceKey, RetryAfter: retryAfter} } return nil } -// updateRateLimiterMetrics updates metrics for rate limiting operations. -func (rl *rateLimiter) updateRateLimiterMetrics(ctx context.Context, serviceKey string, allowed bool, err error) { - if rl.metrics == nil { - return - } - - rl.metrics.IncrementCounter(ctx, "app_rate_limiter_requests_total", "service", serviceKey) - - if err != nil { - rl.metrics.IncrementCounter(ctx, "app_rate_limiter_errors_total", "service", serviceKey, "type", "store_error") - } - - if !allowed { - rl.metrics.IncrementCounter(ctx, "app_rate_limiter_denied_total", "service", serviceKey) - } -} - -// HTTP Method Implementations - All methods follow the same pattern. - // Get performs rate-limited HTTP GET request. func (rl *rateLimiter) Get(ctx context.Context, path string, queryParams map[string]any) (*http.Response, error) { fullURL := rl.buildFullURL(path) diff --git a/pkg/gofr/service/rate_limiter_store.go b/pkg/gofr/service/rate_limiter_store.go index 2c89a5ac89..0911ae60de 100644 --- a/pkg/gofr/service/rate_limiter_store.go +++ b/pkg/gofr/service/rate_limiter_store.go @@ -19,7 +19,7 @@ const ( // RateLimiterStore abstracts the storage and cleanup for rate limiter buckets. type RateLimiterStore interface { Allow(ctx context.Context, key string, config RateLimiterConfig) (allowed bool, retryAfter time.Duration, err error) - StartCleanup(ctx context.Context, logger Logger) + StartCleanup(ctx context.Context) StopCleanup() } @@ -114,7 +114,7 @@ func (l *LocalRateLimiterStore) Allow(_ context.Context, key string, config Rate return allowed, retryAfter, nil } -func (l *LocalRateLimiterStore) StartCleanup(ctx context.Context, logger Logger) { +func (l *LocalRateLimiterStore) StartCleanup(ctx context.Context) { l.stopCh = make(chan struct{}) go func() { @@ -124,7 +124,7 @@ func (l *LocalRateLimiterStore) StartCleanup(ctx context.Context, logger Logger) for { select { case <-ticker.C: - l.cleanupExpiredBuckets(logger) + l.cleanupExpiredBuckets() case <-l.stopCh: return case <-ctx.Done(): @@ -140,7 +140,7 @@ func (l *LocalRateLimiterStore) StopCleanup() { } } -func (l *LocalRateLimiterStore) cleanupExpiredBuckets(logger Logger) { +func (l *LocalRateLimiterStore) cleanupExpiredBuckets() { cutoff := time.Now().Unix() - int64(bucketTTL.Seconds()) cleaned := 0 @@ -154,10 +154,6 @@ func (l *LocalRateLimiterStore) cleanupExpiredBuckets(logger Logger) { return true }) - - if cleaned > 0 && logger != nil { - logger.Debug("Cleaned up rate limiter buckets", "count", cleaned) - } } // tokenBucketScript is a Lua script for atomic token bucket rate limiting in Redis. diff --git a/pkg/gofr/service/rate_limiter_store_test.go b/pkg/gofr/service/rate_limiter_store_test.go index 0173c73045..5d2e762fd9 100644 --- a/pkg/gofr/service/rate_limiter_store_test.go +++ b/pkg/gofr/service/rate_limiter_store_test.go @@ -7,9 +7,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "gofr.dev/pkg/gofr/logging" - "gofr.dev/pkg/gofr/testutil" ) func TestTokenBucket_Allow(t *testing.T) { @@ -60,13 +57,10 @@ func TestLocalRateLimiterStore_CleanupExpiredBuckets(t *testing.T) { bucketEntry := entry.(*bucketEntry) bucketEntry.lastAccess = time.Now().Unix() - int64(bucketTTL.Seconds()) - 1 - log := testutil.StdoutOutputForFunc(func() { - store.cleanupExpiredBuckets(logging.NewMockLogger(logging.DEBUG)) - }) + store.cleanupExpiredBuckets() _, exists := store.buckets.Load(key) assert.False(t, exists) - assert.Contains(t, log, "Cleaned up rate limiter buckets") } func TestLocalRateLimiterStore_StartAndStopCleanup(t *testing.T) { @@ -75,7 +69,7 @@ func TestLocalRateLimiterStore_StartAndStopCleanup(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - store.StartCleanup(ctx, logging.NewMockLogger(logging.INFO)) + store.StartCleanup(ctx) assert.NotNil(t, store.stopCh) store.StopCleanup() diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go index 551ab71716..b0b2027c0c 100644 --- a/pkg/gofr/service/rate_limiter_test.go +++ b/pkg/gofr/service/rate_limiter_test.go @@ -8,16 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" ) -type mockLogger struct { - logs []string -} - -func (l *mockLogger) Log(_ ...any) { l.logs = append(l.logs, "Log") } -func (l *mockLogger) Debug(_ ...any) { l.logs = append(l.logs, "Debug") } - type mockStore struct { allowed bool retryAfter time.Duration @@ -27,7 +19,7 @@ type mockStore struct { func (m *mockStore) Allow(_ context.Context, _ string, _ RateLimiterConfig) (bool, time.Duration, error) { return m.allowed, m.retryAfter, m.err } -func (*mockStore) StartCleanup(_ context.Context, _ Logger) {} +func (*mockStore) StartCleanup(_ context.Context) {} func (*mockStore) StopCleanup() {} @@ -51,22 +43,13 @@ func TestRateLimiter_buildFullURL(t *testing.T) { func TestRateLimiter_checkRateLimit_Error(t *testing.T) { store := &mockStore{allowed: true, err: errTest} - logger := &mockLogger{} - - ctrl := gomock.NewController(t) - metrics := NewMockMetrics(ctrl) - - metrics.EXPECT().IncrementCounter(gomock.Any(), "app_rate_limiter_requests_total", "service", "svc") - metrics.EXPECT().IncrementCounter(gomock.Any(), "app_rate_limiter_errors_total", "service", "svc", "type", "store_error") rl := &rateLimiter{ config: RateLimiterConfig{ KeyFunc: func(*http.Request) string { return "svc" }, Store: store, }, - store: store, - logger: logger, - metrics: metrics, + store: store, } req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) @@ -74,54 +57,34 @@ func TestRateLimiter_checkRateLimit_Error(t *testing.T) { err := rl.checkRateLimit(req) require.NoError(t, err) - assert.Contains(t, logger.logs, "Log") } func TestRateLimiter_checkRateLimit_Denied(t *testing.T) { store := &mockStore{allowed: false} - logger := &mockLogger{} - - ctrl := gomock.NewController(t) - metrics := NewMockMetrics(ctrl) - - metrics.EXPECT().IncrementCounter(gomock.Any(), "app_rate_limiter_requests_total", "service", "svc") - metrics.EXPECT().IncrementCounter(gomock.Any(), "app_rate_limiter_denied_total", "service", "svc") rl := &rateLimiter{ config: RateLimiterConfig{ KeyFunc: func(*http.Request) string { return "svc" }, Store: store, }, - store: store, - logger: logger, - metrics: metrics, + store: store, } req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) err := rl.checkRateLimit(req) assert.IsType(t, &RateLimitError{}, err) - assert.Contains(t, logger.logs, "Debug") } func TestRateLimiter_checkRateLimit_Allowed(t *testing.T) { store := &mockStore{allowed: true} - logger := &mockLogger{} - - ctrl := gomock.NewController(t) - metrics := NewMockMetrics(ctrl) - - metrics.EXPECT().IncrementCounter(gomock.Any(), "app_rate_limiter_requests_total", "service", "svc") - rl := &rateLimiter{ config: RateLimiterConfig{ KeyFunc: func(*http.Request) string { return "svc" }, Store: store, }, - store: store, - logger: logger, - metrics: metrics, + store: store, } req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) @@ -134,22 +97,14 @@ func TestRateLimiter_HTTPMethods(t *testing.T) { mock := &mockHTTP{} store := &mockStore{allowed: true} - logger := &mockLogger{} - - ctrl := gomock.NewController(t) - metrics := NewMockMetrics(ctrl) - - metrics.EXPECT().IncrementCounter(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() rl := &rateLimiter{ config: RateLimiterConfig{ KeyFunc: func(*http.Request) string { return "svc" }, Store: store, }, - store: store, - logger: logger, - metrics: metrics, - HTTP: mock, + store: store, + HTTP: mock, } ctx := context.Background() From bb387eac2414ee44d50759febe7268f219f201f0 Mon Sep 17 00:00:00 2001 From: Umang01-hash Date: Fri, 3 Oct 2025 13:09:29 +0530 Subject: [PATCH 19/21] refactoring implementation --- pkg/gofr/service/rate_limiter.go | 10 +++++++ pkg/gofr/service/rate_limiter_config.go | 39 ++++++++++--------------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/pkg/gofr/service/rate_limiter.go b/pkg/gofr/service/rate_limiter.go index 628ab527df..58660a9b05 100644 --- a/pkg/gofr/service/rate_limiter.go +++ b/pkg/gofr/service/rate_limiter.go @@ -28,6 +28,16 @@ func NewRateLimiter(config RateLimiterConfig, h HTTP) HTTP { return rl } +// AddOption allows RateLimiterConfig to be used as a service.Options. +func (cfg *RateLimiterConfig) AddOption(h HTTP) HTTP { + // Assume cfg is already validated via constructor + if cfg.Store == nil { + cfg.Store = NewLocalRateLimiterStore() + } + + return NewRateLimiter(*cfg, h) +} + // buildFullURL constructs an absolute URL by combining the base service URL with the given path. func (rl *rateLimiter) buildFullURL(path string) string { if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { diff --git a/pkg/gofr/service/rate_limiter_config.go b/pkg/gofr/service/rate_limiter_config.go index 4698bbc44b..1c0f12b5ef 100644 --- a/pkg/gofr/service/rate_limiter_config.go +++ b/pkg/gofr/service/rate_limiter_config.go @@ -28,6 +28,22 @@ type RateLimiterConfig struct { Store RateLimiterStore } +func NewRateLimiterConfig(requests float64, window time.Duration, burst int, store RateLimiterStore, keyFunc func(*http.Request) string) (*RateLimiterConfig, error) { + cfg := &RateLimiterConfig{ + Requests: requests, + Window: window, + Burst: burst, + Store: store, + KeyFunc: keyFunc, + } + + if err := cfg.Validate(); err != nil { + return nil, err + } + + return cfg, nil +} + // defaultKeyFunc extracts a normalized service key from an HTTP request. func defaultKeyFunc(req *http.Request) string { if req == nil || req.URL == nil { @@ -82,29 +98,6 @@ func (config *RateLimiterConfig) Validate() error { return nil } -// AddOption implements the Options interface. -func (config *RateLimiterConfig) AddOption(h HTTP) HTTP { - if err := config.Validate(); err != nil { - if httpSvc, ok := h.(*httpService); ok { - httpSvc.Logger.Log("Invalid rate limiter config, disabling rate limiting", "error", err) - } - - return h - } - - // Default to local store if not set - if config.Store == nil { - config.Store = NewLocalRateLimiterStore() - - // Log warning for local rate limiting. - if httpSvc, ok := h.(*httpService); ok { - httpSvc.Logger.Log("Using local rate limiting - not suitable for multi-instance deployments") - } - } - - return NewRateLimiter(*config, h) -} - // RequestsPerSecond converts the configured rate to requests per second. func (config *RateLimiterConfig) RequestsPerSecond() float64 { // Convert any time window to "requests per second" for internal math From f9f7b042e46f5b0a436c2d322db9e8cbae51a284 Mon Sep 17 00:00:00 2001 From: Eng Zer Jun Date: Fri, 3 Oct 2025 18:03:16 +0800 Subject: [PATCH 20/21] build(deps): update github.com/grpc-ecosystem/go-grpc-middleware to v2 v2 was released in August 2023 to use the Go Protobuf v2 API. We only use the recovery interceptor so we are not affected by the breaking changes. Reference: https://github.com/grpc-ecosystem/go-grpc-middleware/releases/tag/v2.0.0 Signed-off-by: Eng Zer Jun --- go.mod | 2 +- go.sum | 35 ++--------------------------------- pkg/gofr/grpc.go | 7 +++---- 3 files changed, 6 insertions(+), 38 deletions(-) diff --git a/go.mod b/go.mod index 67d4a7d76f..9b2478bc4a 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.3 - github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 + github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 7f968c54a1..6f86848d84 100644 --- a/go.sum +++ b/go.sum @@ -24,7 +24,6 @@ github.com/XSAM/otelsql v0.40.0 h1:8jaiQ6KcoEXF46fBmPEqb+pp29w2xjWfuXjZXTXBjaA= github.com/XSAM/otelsql v0.40.0/go.mod h1:/7F+1XKt3/sTlYtwKtkHQ5Gzoom+EerXmD1VdnTqfB4= github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= -github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -58,8 +57,6 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2 github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= -github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= -github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -69,7 +66,6 @@ github.com/go-redis/redismock/v9 v9.2.0 h1:ZrMYQeKPECZPjOj5u9eyOjg8Nnb0BS9lkVIZ6 github.com/go-redis/redismock/v9 v9.2.0/go.mod h1:18KHfGDK4Y6c2R0H38EUGWAdc7ZQS9gfYxc94k7rWT0= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= @@ -81,7 +77,6 @@ github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUv github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= @@ -116,8 +111,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc h1:GN2Lv3MGO7AS6PrRoT6yV5+wkrOpcszoIsO4+4ds248= github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= -github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDaL56wXCB/5+wF6uHfaI= -github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 h1:sGm2vDRFUrQJO/Veii4h4zG2vvqG6uWNkBHSTqXOZk0= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2/go.mod h1:wd1YpapPLivG6nQgbf7ZkG1hhSOXDhhn4MLTknx2aAc= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -127,12 +122,8 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -151,12 +142,10 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= -github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/openzipkin/zipkin-go v0.4.3 h1:9EGwpqkgnwdEIJ+Od7QVSEIH+ocmm5nPat0G7sjsSdg= github.com/openzipkin/zipkin-go v0.4.3/go.mod h1:M9wCJZFWCo2RiY+o1eBCEMe0Dp2S5LDHcMZmk3RmK7c= github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU= github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -185,17 +174,11 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/segmentio/kafka-go v0.4.49 h1:GJiNX1d/g+kG6ljyJEoi9++PUMdXGAxb7JGPiDCuNmk= github.com/segmentio/kafka-go v0.4.49/go.mod h1:Y1gn60kzLEEaW28YshXyk2+VCUKbJ3Qr6DrnT3i4+9E= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= @@ -244,14 +227,10 @@ go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJr go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4= go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= -go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= -go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -266,7 +245,6 @@ golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAf golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -299,11 +277,9 @@ golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -327,7 +303,6 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= @@ -346,7 +321,6 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= @@ -358,7 +332,6 @@ google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZi google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI= google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= @@ -374,15 +347,11 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= diff --git a/pkg/gofr/grpc.go b/pkg/gofr/grpc.go index a31557a203..96615a0964 100644 --- a/pkg/gofr/grpc.go +++ b/pkg/gofr/grpc.go @@ -9,8 +9,7 @@ import ( "strconv" "strings" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" + grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery" "google.golang.org/grpc" "google.golang.org/grpc/reflection" @@ -120,8 +119,8 @@ func registerGRPCMetrics(c *container.Container) { } func (g *grpcServer) createServer() error { - interceptorOption := grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(g.interceptors...)) - streamOpt := grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(g.streamInterceptors...)) + interceptorOption := grpc.ChainUnaryInterceptor(g.interceptors...) + streamOpt := grpc.ChainStreamInterceptor(g.streamInterceptors...) g.options = append(g.options, interceptorOption, streamOpt) g.server = grpc.NewServer(g.options...) From 8bfd3fb1ec582b4e545a84020f96dccf06e89484 Mon Sep 17 00:00:00 2001 From: NishantRajZop Date: Sat, 27 Sep 2025 19:11:48 +0530 Subject: [PATCH 21/21] refactor(docs): replace {serviceName} placeholders with for clarity and consistency --- docs/advanced-guide/grpc/page.md | 46 ++++++++++++++++---------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/docs/advanced-guide/grpc/page.md b/docs/advanced-guide/grpc/page.md index 630ce7ef46..02b8adc369 100644 --- a/docs/advanced-guide/grpc/page.md +++ b/docs/advanced-guide/grpc/page.md @@ -55,8 +55,8 @@ syntax = "proto3"; // Indicates the go package where the generated file will be produced option go_package = "path/to/your/proto/file"; -service {serviceName}Service { - rpc {serviceMethod} ({serviceRequest}) returns ({serviceResponse}) {} +service Service { + rpc () returns () {} } ``` @@ -66,13 +66,13 @@ Users must define the type of message being exchanged between server and client, procedure call. Below is a generic representation for services' gRPC messages type. ```protobuf -message {serviceRequest} { +message { int64 id = 1; string name = 2; // other fields that can be passed } -message {serviceResponse} { +message { int64 id = 1; string name = 2; string address = 3; @@ -90,10 +90,10 @@ protoc \ --go_opt=paths=source_relative \ --go-grpc_out=. \ --go-grpc_opt=paths=source_relative \ - {serviceName}.proto + .proto ``` -This command generates two files, `{serviceName}.pb.go` and `{serviceName}_grpc.pb.go`, containing the necessary code for performing RPC calls. +This command generates two files, `.pb.go` and `_grpc.pb.go`, containing the necessary code for performing RPC calls. ## Prerequisite: gofr-cli must be installed To install the CLI - @@ -109,15 +109,15 @@ go install gofr.dev/cli/gofr@latest gofr wrap grpc server -proto=./path/your/proto/file ``` -This command leverages the `gofr-cli` to generate a `{serviceName}_server.go` file (e.g., `customer_server.go`) +This command leverages the `gofr-cli` to generate a `_server.go` file (e.g., `customer_server.go`) containing a template for your gRPC server implementation, including context support, in the same directory as that of the specified proto file. **2. Modify the Generated Code:** -- Customize the `{serviceName}GoFrServer` struct with required dependencies and fields. -- Implement the `{serviceMethod}` method to handle incoming requests, as required in this usecase: - - Bind the request payload using `ctx.Bind(&{serviceRequest})`. +- Customize the `GoFrServer` struct with required dependencies and fields. +- Implement the `` method to handle incoming requests, as required in this usecase: + - Bind the request payload using `ctx.Bind(&)`. - Process the request and generate a response. ## Registering the gRPC Service with Gofr @@ -138,7 +138,7 @@ import ( func main() { app := gofr.New() - packageName.Register{serviceName}ServerWithGofr(app, &{packageName}.New{serviceName}GoFrServer()) + packageName.RegisterServerWithGofr(app, &.NewGoFrServer()) app.Run() } @@ -163,7 +163,7 @@ func main() { grpc.ConnectionTimeout(10 * time.Second), ) - packageName.Register{serviceName}ServerWithGofr(app, &{packageName}.New{serviceName}GoFrServer()) + packageName.RegisterServerWithGofr(app, &.NewGoFrServer()) app.Run() } @@ -180,7 +180,7 @@ func main() { app.AddGRPCUnaryInterceptors(authInterceptor) - packageName.Register{serviceName}ServerWithGofr(app, &{packageName}.New{serviceName}GoFrServer()) + packageName.RegisterServerWithGofr(app, &.NewGoFrServer()) app.Run() } @@ -227,26 +227,26 @@ For more details on adding additional interceptors and server options, refer to ```bash gofr wrap grpc client -proto=./path/your/proto/file ``` -This command leverages the `gofr-cli` to generate a `{serviceName}_client.go` file (e.g., `customer_client.go`). This file must not be modified. +This command leverages the `gofr-cli` to generate a `_client.go` file (e.g., `customer_client.go`). This file must not be modified. -**2. Register the connection to your gRPC service inside your {serviceMethod} and make inter-service calls as follows :** +**2. Register the connection to your gRPC service inside your and make inter-service calls as follows :** ```go // gRPC Handler with context support -func {serviceMethod}(ctx *gofr.Context) (*{serviceResponse}, error) { +func (ctx *gofr.Context) (*, error) { // Create the gRPC client - srv, err := New{serviceName}GoFrClient("your-grpc-server-host", ctx.Metrics()) + srv, err := NewGoFrClient("your-grpc-server-host", ctx.Metrics()) if err != nil { return nil, err } // Prepare the request - req := &{serviceRequest}{ + req := &{ // populate fields as necessary } // Call the gRPC method with tracing/metrics enabled - res, err := srv.{serviceMethod}(ctx, req) + res, err := srv.(ctx, req) if err != nil { return nil, err } @@ -307,7 +307,7 @@ func main() { app := gofr.New() // Create a gRPC client for the service - gRPCClient, err := client.New{serviceName}GoFrClient( + gRPCClient, err := client.NewGoFrClient( app.Config.Get("GRPC_SERVER_HOST"), app.Metrics(), grpc.WithChainUnaryInterceptor(MetadataUnaryInterceptor), @@ -374,7 +374,7 @@ func main() { return } - gRPCClient, err := client.New{serviceName}GoFrClient( + gRPCClient, err := client.NewGoFrClient( app.Config.Get("GRPC_SERVER_HOST"), app.Metrics(), grpc.WithTransportCredentials(creds), @@ -409,7 +409,7 @@ GoFr provides built-in health checks for gRPC services, enabling observability, ### Client Interface ```go -type {serviceName}GoFrClient interface { +type GoFrClient interface { SayHello(*gofr.Context, *HelloRequest, ...grpc.CallOption) (*HelloResponse, error) health } @@ -422,7 +422,7 @@ type health interface { ### Server Integration ```go -type {serviceName}GoFrServer struct { +type GoFrServer struct { health *healthServer } ```