diff --git a/docs/advanced-guide/grpc/page.md b/docs/advanced-guide/grpc/page.md index 630ce7ef4..02b8adc36 100644 --- a/docs/advanced-guide/grpc/page.md +++ b/docs/advanced-guide/grpc/page.md @@ -55,8 +55,8 @@ syntax = "proto3"; // Indicates the go package where the generated file will be produced option go_package = "path/to/your/proto/file"; -service {serviceName}Service { - rpc {serviceMethod} ({serviceRequest}) returns ({serviceResponse}) {} +service Service { + rpc () returns () {} } ``` @@ -66,13 +66,13 @@ Users must define the type of message being exchanged between server and client, procedure call. Below is a generic representation for services' gRPC messages type. ```protobuf -message {serviceRequest} { +message { int64 id = 1; string name = 2; // other fields that can be passed } -message {serviceResponse} { +message { int64 id = 1; string name = 2; string address = 3; @@ -90,10 +90,10 @@ protoc \ --go_opt=paths=source_relative \ --go-grpc_out=. \ --go-grpc_opt=paths=source_relative \ - {serviceName}.proto + .proto ``` -This command generates two files, `{serviceName}.pb.go` and `{serviceName}_grpc.pb.go`, containing the necessary code for performing RPC calls. +This command generates two files, `.pb.go` and `_grpc.pb.go`, containing the necessary code for performing RPC calls. ## Prerequisite: gofr-cli must be installed To install the CLI - @@ -109,15 +109,15 @@ go install gofr.dev/cli/gofr@latest gofr wrap grpc server -proto=./path/your/proto/file ``` -This command leverages the `gofr-cli` to generate a `{serviceName}_server.go` file (e.g., `customer_server.go`) +This command leverages the `gofr-cli` to generate a `_server.go` file (e.g., `customer_server.go`) containing a template for your gRPC server implementation, including context support, in the same directory as that of the specified proto file. **2. Modify the Generated Code:** -- Customize the `{serviceName}GoFrServer` struct with required dependencies and fields. -- Implement the `{serviceMethod}` method to handle incoming requests, as required in this usecase: - - Bind the request payload using `ctx.Bind(&{serviceRequest})`. +- Customize the `GoFrServer` struct with required dependencies and fields. +- Implement the `` method to handle incoming requests, as required in this usecase: + - Bind the request payload using `ctx.Bind(&)`. - Process the request and generate a response. ## Registering the gRPC Service with Gofr @@ -138,7 +138,7 @@ import ( func main() { app := gofr.New() - packageName.Register{serviceName}ServerWithGofr(app, &{packageName}.New{serviceName}GoFrServer()) + packageName.RegisterServerWithGofr(app, &.NewGoFrServer()) app.Run() } @@ -163,7 +163,7 @@ func main() { grpc.ConnectionTimeout(10 * time.Second), ) - packageName.Register{serviceName}ServerWithGofr(app, &{packageName}.New{serviceName}GoFrServer()) + packageName.RegisterServerWithGofr(app, &.NewGoFrServer()) app.Run() } @@ -180,7 +180,7 @@ func main() { app.AddGRPCUnaryInterceptors(authInterceptor) - packageName.Register{serviceName}ServerWithGofr(app, &{packageName}.New{serviceName}GoFrServer()) + packageName.RegisterServerWithGofr(app, &.NewGoFrServer()) app.Run() } @@ -227,26 +227,26 @@ For more details on adding additional interceptors and server options, refer to ```bash gofr wrap grpc client -proto=./path/your/proto/file ``` -This command leverages the `gofr-cli` to generate a `{serviceName}_client.go` file (e.g., `customer_client.go`). This file must not be modified. +This command leverages the `gofr-cli` to generate a `_client.go` file (e.g., `customer_client.go`). This file must not be modified. -**2. Register the connection to your gRPC service inside your {serviceMethod} and make inter-service calls as follows :** +**2. Register the connection to your gRPC service inside your and make inter-service calls as follows :** ```go // gRPC Handler with context support -func {serviceMethod}(ctx *gofr.Context) (*{serviceResponse}, error) { +func (ctx *gofr.Context) (*, error) { // Create the gRPC client - srv, err := New{serviceName}GoFrClient("your-grpc-server-host", ctx.Metrics()) + srv, err := NewGoFrClient("your-grpc-server-host", ctx.Metrics()) if err != nil { return nil, err } // Prepare the request - req := &{serviceRequest}{ + req := &{ // populate fields as necessary } // Call the gRPC method with tracing/metrics enabled - res, err := srv.{serviceMethod}(ctx, req) + res, err := srv.(ctx, req) if err != nil { return nil, err } @@ -307,7 +307,7 @@ func main() { app := gofr.New() // Create a gRPC client for the service - gRPCClient, err := client.New{serviceName}GoFrClient( + gRPCClient, err := client.NewGoFrClient( app.Config.Get("GRPC_SERVER_HOST"), app.Metrics(), grpc.WithChainUnaryInterceptor(MetadataUnaryInterceptor), @@ -374,7 +374,7 @@ func main() { return } - gRPCClient, err := client.New{serviceName}GoFrClient( + gRPCClient, err := client.NewGoFrClient( app.Config.Get("GRPC_SERVER_HOST"), app.Metrics(), grpc.WithTransportCredentials(creds), @@ -409,7 +409,7 @@ GoFr provides built-in health checks for gRPC services, enabling observability, ### Client Interface ```go -type {serviceName}GoFrClient interface { +type GoFrClient interface { SayHello(*gofr.Context, *HelloRequest, ...grpc.CallOption) (*HelloResponse, error) health } @@ -422,7 +422,7 @@ type health interface { ### Server Integration ```go -type {serviceName}GoFrServer struct { +type GoFrServer struct { health *healthServer } ``` diff --git a/docs/advanced-guide/http-communication/page.md b/docs/advanced-guide/http-communication/page.md index d90761fbc..10eaa940a 100644 --- a/docs/advanced-guide/http-communication/page.md +++ b/docs/advanced-guide/http-communication/page.md @@ -95,10 +95,24 @@ GoFr provides its user with additional configurational options while registering - **DefaultHeaders** - This option allows user to set some default headers that will be propagated to the downstream HTTP Service every time it is being called. - **HealthConfig** - This option allows user to add the `HealthEndpoint` along with `Timeout` to enable and perform the timely health checks for downstream HTTP Service. - **RetryConfig** - This option allows user to add the maximum number of retry count if before returning error if any downstream HTTP Service fails. +- **RateLimiterConfig** - This option allows user to configure rate limiting for downstream service calls using token bucket algorithm. It controls the request rate to prevent overwhelming dependent services and supports both in-memory and Redis-based implementations. + +**Rate Limiter Store: Customization** +GoFr allows you to use a custom rate limiter store by implementing the RateLimiterStore interface.This enables integration with any backend (e.g., Redis, database, or custom logic) +Interface: +```go +type RateLimiterStore interface { +Allow(ctx context.Context, key string, config RateLimiterConfig) (allowed bool, retryAfter time.Duration, err error) +StartCleanup(ctx context.Context) +StopCleanup() +} +``` #### Usage: ```go +rc := redis.NewClient(a.Config, a.Logger(), a.Metrics()) + a.AddHTTPService("cat-facts", "https://catfact.ninja", service.NewAPIKeyConfig("some-random-key"), service.NewBasicAuthConfig("username", "password"), @@ -119,5 +133,12 @@ a.AddHTTPService("cat-facts", "https://catfact.ninja", &service.RetryConfig{ MaxRetries: 5 }, + + &service.RateLimiterConfig{ + Requests: 5, + Window: time.Minute, + Burst: 10, + Store: service.NewRedisRateLimiterStore(rc)}, // Skip this field to use in-memory store + }, ) ``` \ No newline at end of file diff --git a/go.mod b/go.mod index 67d4a7d76..9b2478bc4 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.3 - github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 + github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 7f968c54a..6f86848d8 100644 --- a/go.sum +++ b/go.sum @@ -24,7 +24,6 @@ github.com/XSAM/otelsql v0.40.0 h1:8jaiQ6KcoEXF46fBmPEqb+pp29w2xjWfuXjZXTXBjaA= github.com/XSAM/otelsql v0.40.0/go.mod h1:/7F+1XKt3/sTlYtwKtkHQ5Gzoom+EerXmD1VdnTqfB4= github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= -github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -58,8 +57,6 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2 github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= -github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= -github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -69,7 +66,6 @@ github.com/go-redis/redismock/v9 v9.2.0 h1:ZrMYQeKPECZPjOj5u9eyOjg8Nnb0BS9lkVIZ6 github.com/go-redis/redismock/v9 v9.2.0/go.mod h1:18KHfGDK4Y6c2R0H38EUGWAdc7ZQS9gfYxc94k7rWT0= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= @@ -81,7 +77,6 @@ github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUv github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= @@ -116,8 +111,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc h1:GN2Lv3MGO7AS6PrRoT6yV5+wkrOpcszoIsO4+4ds248= github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc/go.mod h1:+JKpmjMGhpgPL+rXZ5nsZieVzvarn86asRlBg4uNGnk= -github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDaL56wXCB/5+wF6uHfaI= -github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 h1:sGm2vDRFUrQJO/Veii4h4zG2vvqG6uWNkBHSTqXOZk0= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2/go.mod h1:wd1YpapPLivG6nQgbf7ZkG1hhSOXDhhn4MLTknx2aAc= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -127,12 +122,8 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -151,12 +142,10 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= -github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/openzipkin/zipkin-go v0.4.3 h1:9EGwpqkgnwdEIJ+Od7QVSEIH+ocmm5nPat0G7sjsSdg= github.com/openzipkin/zipkin-go v0.4.3/go.mod h1:M9wCJZFWCo2RiY+o1eBCEMe0Dp2S5LDHcMZmk3RmK7c= github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU= github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -185,17 +174,11 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/segmentio/kafka-go v0.4.49 h1:GJiNX1d/g+kG6ljyJEoi9++PUMdXGAxb7JGPiDCuNmk= github.com/segmentio/kafka-go v0.4.49/go.mod h1:Y1gn60kzLEEaW28YshXyk2+VCUKbJ3Qr6DrnT3i4+9E= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= @@ -244,14 +227,10 @@ go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJr go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4= go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= -go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= -go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -266,7 +245,6 @@ golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAf golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -299,11 +277,9 @@ golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -327,7 +303,6 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= @@ -346,7 +321,6 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= @@ -358,7 +332,6 @@ google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZi google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI= google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= @@ -374,15 +347,11 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= diff --git a/pkg/gofr/grpc.go b/pkg/gofr/grpc.go index a31557a20..96615a096 100644 --- a/pkg/gofr/grpc.go +++ b/pkg/gofr/grpc.go @@ -9,8 +9,7 @@ import ( "strconv" "strings" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" + grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery" "google.golang.org/grpc" "google.golang.org/grpc/reflection" @@ -120,8 +119,8 @@ func registerGRPCMetrics(c *container.Container) { } func (g *grpcServer) createServer() error { - interceptorOption := grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(g.interceptors...)) - streamOpt := grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(g.streamInterceptors...)) + interceptorOption := grpc.ChainUnaryInterceptor(g.interceptors...) + streamOpt := grpc.ChainStreamInterceptor(g.streamInterceptors...) g.options = append(g.options, interceptorOption, streamOpt) g.server = grpc.NewServer(g.options...) diff --git a/pkg/gofr/service/mock_metrics.go b/pkg/gofr/service/mock_metrics.go index df74ef0bf..45129e90a 100644 --- a/pkg/gofr/service/mock_metrics.go +++ b/pkg/gofr/service/mock_metrics.go @@ -20,6 +20,7 @@ import ( type MockMetrics struct { ctrl *gomock.Controller recorder *MockMetricsMockRecorder + isgomock struct{} } // MockMetricsMockRecorder is the mock recorder for MockMetrics. @@ -54,4 +55,4 @@ func (mr *MockMetricsMockRecorder) RecordHistogram(ctx, name, value any, labels mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, name, value}, labels...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordHistogram", reflect.TypeOf((*MockMetrics)(nil).RecordHistogram), varargs...) -} +} \ No newline at end of file diff --git a/pkg/gofr/service/oauth.go b/pkg/gofr/service/oauth.go index c991e9375..aed4ecb62 100644 --- a/pkg/gofr/service/oauth.go +++ b/pkg/gofr/service/oauth.go @@ -79,7 +79,7 @@ func validateTokenURL(tokenURL string) error { return AuthErr{nil, "invalid host pattern, contains `..`"} case strings.HasSuffix(u.Host, "."): return AuthErr{nil, "invalid host pattern, ends with `.`"} - case u.Scheme != "http" && u.Scheme != "https": + case u.Scheme != methodHTTP && u.Scheme != methodHTTPS: return AuthErr{nil, "invalid scheme, allowed http and https only"} default: return nil diff --git a/pkg/gofr/service/rate_limiter.go b/pkg/gofr/service/rate_limiter.go new file mode 100644 index 000000000..58660a9b0 --- /dev/null +++ b/pkg/gofr/service/rate_limiter.go @@ -0,0 +1,208 @@ +package service + +import ( + "context" + "net/http" + "strings" +) + +// rateLimiter provides unified rate limiting for HTTP clients. +type rateLimiter struct { + config RateLimiterConfig + store RateLimiterStore + HTTP // Embedded HTTP service +} + +// NewRateLimiter creates a new unified rate limiter. +func NewRateLimiter(config RateLimiterConfig, h HTTP) HTTP { + rl := &rateLimiter{ + config: config, + store: config.Store, + HTTP: h, + } + + // Start cleanup routine + ctx := context.Background() + rl.store.StartCleanup(ctx) + + return rl +} + +// AddOption allows RateLimiterConfig to be used as a service.Options. +func (cfg *RateLimiterConfig) AddOption(h HTTP) HTTP { + // Assume cfg is already validated via constructor + if cfg.Store == nil { + cfg.Store = NewLocalRateLimiterStore() + } + + return NewRateLimiter(*cfg, h) +} + +// buildFullURL constructs an absolute URL by combining the base service URL with the given path. +func (rl *rateLimiter) buildFullURL(path string) string { + if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { + return path + } + + // Get base URL from embedded HTTP service + httpSvcImpl, ok := rl.HTTP.(*httpService) + if !ok { + return path + } + + base := strings.TrimRight(httpSvcImpl.url, "/") + if base == "" { + return path + } + + // Ensure path starts with / + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + return base + path +} + +// checkRateLimit performs rate limit check using the configured store. +func (rl *rateLimiter) checkRateLimit(req *http.Request) error { + serviceKey := rl.config.KeyFunc(req) + + allowed, retryAfter, err := rl.store.Allow(req.Context(), serviceKey, rl.config) + if err != nil { + return nil // Fail open + } + + if !allowed { + return &RateLimitError{ServiceKey: serviceKey, RetryAfter: retryAfter} + } + + return nil +} + +// Get performs rate-limited HTTP GET request. +func (rl *rateLimiter) Get(ctx context.Context, path string, queryParams map[string]any) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Get(ctx, path, queryParams) +} + +// GetWithHeaders performs rate-limited HTTP GET request with custom headers. +func (rl *rateLimiter) GetWithHeaders(ctx context.Context, path string, queryParams map[string]any, + headers map[string]string) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.GetWithHeaders(ctx, path, queryParams, headers) +} + +// Post performs rate-limited HTTP POST request. +func (rl *rateLimiter) Post(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Post(ctx, path, queryParams, body) +} + +// PostWithHeaders performs rate-limited HTTP POST request with custom headers. +func (rl *rateLimiter) PostWithHeaders(ctx context.Context, path string, queryParams map[string]any, + body []byte, headers map[string]string) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.PostWithHeaders(ctx, path, queryParams, body, headers) +} + +// Put performs rate-limited HTTP PUT request. +func (rl *rateLimiter) Put(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Put(ctx, path, queryParams, body) +} + +// PutWithHeaders performs rate-limited HTTP PUT request with custom headers. +func (rl *rateLimiter) PutWithHeaders(ctx context.Context, path string, queryParams map[string]any, body []byte, + headers map[string]string) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.PutWithHeaders(ctx, path, queryParams, body, headers) +} + +// Patch performs rate-limited HTTP PATCH request. +func (rl *rateLimiter) Patch(ctx context.Context, path string, queryParams map[string]any, + body []byte) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Patch(ctx, path, queryParams, body) +} + +// PatchWithHeaders performs rate-limited HTTP PATCH request with custom headers. +func (rl *rateLimiter) PatchWithHeaders(ctx context.Context, path string, queryParams map[string]any, + body []byte, headers map[string]string) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodPatch, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.PatchWithHeaders(ctx, path, queryParams, body, headers) +} + +// Delete performs rate-limited HTTP DELETE request. +func (rl *rateLimiter) Delete(ctx context.Context, path string, body []byte) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.Delete(ctx, path, body) +} + +// DeleteWithHeaders performs rate-limited HTTP DELETE request with custom headers. +func (rl *rateLimiter) DeleteWithHeaders(ctx context.Context, path string, body []byte, + headers map[string]string) (*http.Response, error) { + fullURL := rl.buildFullURL(path) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, fullURL, http.NoBody) + + if err := rl.checkRateLimit(req); err != nil { + return nil, err + } + + return rl.HTTP.DeleteWithHeaders(ctx, path, body, headers) +} diff --git a/pkg/gofr/service/rate_limiter_config.go b/pkg/gofr/service/rate_limiter_config.go new file mode 100644 index 000000000..1c0f12b5e --- /dev/null +++ b/pkg/gofr/service/rate_limiter_config.go @@ -0,0 +1,120 @@ +package service + +import ( + "errors" + "fmt" + "net/http" + "time" +) + +var ( + errInvalidRequestRate = errors.New("requests must be greater than 0 per configured time window") + errBurstLessThanRequests = errors.New("burst must be greater than requests per window") + errInvalidRedisResultType = errors.New("unexpected Redis result type") +) + +const ( + unknownServiceKey = "unknown" + methodHTTP = "http" + methodHTTPS = "https" +) + +// RateLimiterConfig with custom keying support. +type RateLimiterConfig struct { + Requests float64 // Number of requests allowed + Window time.Duration // Time window (e.g., time.Minute, time.Hour) + Burst int // Maximum burst capacity (must be > 0) + KeyFunc func(*http.Request) string // Optional custom key extraction + Store RateLimiterStore +} + +func NewRateLimiterConfig(requests float64, window time.Duration, burst int, store RateLimiterStore, keyFunc func(*http.Request) string) (*RateLimiterConfig, error) { + cfg := &RateLimiterConfig{ + Requests: requests, + Window: window, + Burst: burst, + Store: store, + KeyFunc: keyFunc, + } + + if err := cfg.Validate(); err != nil { + return nil, err + } + + return cfg, nil +} + +// defaultKeyFunc extracts a normalized service key from an HTTP request. +func defaultKeyFunc(req *http.Request) string { + if req == nil || req.URL == nil { + return unknownServiceKey + } + + scheme := req.URL.Scheme + host := req.URL.Host + + if scheme == "" { + if req.TLS != nil { + scheme = methodHTTPS + } else { + scheme = methodHTTP + } + } + + if host == "" { + host = req.Host + } + + if host == "" { + host = unknownServiceKey + } + + return scheme + "://" + host +} + +// Validate checks if the configuration is valid. +func (config *RateLimiterConfig) Validate() error { + if config.Requests <= 0 { + return fmt.Errorf("%w: %f", errInvalidRequestRate, config.Requests) + } + + if config.Window <= 0 { + config.Window = time.Minute // Default: per-minute rate limiting + } + + if config.Burst <= 0 { + config.Burst = int(config.Requests) + } + + if float64(config.Burst) < config.Requests { + return fmt.Errorf("%w: burst=%d, requests=%f", errBurstLessThanRequests, config.Burst, config.Requests) + } + + // Set default key function if not provided. + if config.KeyFunc == nil { + config.KeyFunc = defaultKeyFunc + } + + return nil +} + +// RequestsPerSecond converts the configured rate to requests per second. +func (config *RateLimiterConfig) RequestsPerSecond() float64 { + // Convert any time window to "requests per second" for internal math + return float64(config.Requests) / config.Window.Seconds() +} + +// RateLimitError represents a rate limiting error. +type RateLimitError struct { + ServiceKey string + RetryAfter time.Duration +} + +func (e *RateLimitError) Error() string { + return fmt.Sprintf("rate limit exceeded for service: %s, retry after: %v", e.ServiceKey, e.RetryAfter) +} + +// StatusCode Implement StatusCodeResponder so Responder picks correct HTTP code. +func (*RateLimitError) StatusCode() int { + return http.StatusTooManyRequests // 429 +} diff --git a/pkg/gofr/service/rate_limiter_config_test.go b/pkg/gofr/service/rate_limiter_config_test.go new file mode 100644 index 000000000..ab9909997 --- /dev/null +++ b/pkg/gofr/service/rate_limiter_config_test.go @@ -0,0 +1,143 @@ +package service + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + + "gofr.dev/pkg/gofr/logging" + "gofr.dev/pkg/gofr/testutil" +) + +func newHTTPService(t *testing.T) *httpService { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv.Close) + + return &httpService{ + Client: http.DefaultClient, + url: srv.URL, + Logger: logging.NewMockLogger(logging.INFO), + Tracer: otel.Tracer("gofr-http-client"), + } +} + +func TestRateLimiterConfig_Validate(t *testing.T) { + t.Run("invalid RPS", func(t *testing.T) { + cfg := RateLimiterConfig{Requests: 0, Burst: 1} + err := cfg.Validate() + + require.Error(t, err) + assert.ErrorIs(t, err, errInvalidRequestRate) + }) + + t.Run("burst less than requests", func(t *testing.T) { + cfg := RateLimiterConfig{Requests: 5, Burst: 3} + err := cfg.Validate() + + require.Error(t, err) + assert.ErrorIs(t, err, errBurstLessThanRequests) + }) + + t.Run("sets default KeyFunc when nil", func(t *testing.T) { + cfg := RateLimiterConfig{Requests: 1.5, Burst: 2} + + require.Nil(t, cfg.KeyFunc) + require.NoError(t, cfg.Validate()) + require.NotNil(t, cfg.KeyFunc) + }) +} + +func TestDefaultKeyFunc(t *testing.T) { + t.Run("nil request", func(t *testing.T) { + assert.Equal(t, "unknown", defaultKeyFunc(nil)) + }) + + t.Run("nil URL", func(t *testing.T) { + req := &http.Request{} + + assert.Equal(t, "unknown", defaultKeyFunc(req)) + }) + + t.Run("http derived scheme", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{Host: "example.com"}, + } + + assert.Equal(t, "http://example.com", defaultKeyFunc(req)) + }) + + t.Run("https derived scheme", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{Host: "secure.com"}, + TLS: &tls.ConnectionState{}, + } + + assert.Equal(t, "https://secure.com", defaultKeyFunc(req)) + }) + + t.Run("host from req.Host fallback", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{}, + Host: "fallback:9090", + } + + assert.Equal(t, "http://fallback:9090", defaultKeyFunc(req)) + }) + + t.Run("unknown service key when no host present", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{}, + } + + assert.Equal(t, "http://unknown", defaultKeyFunc(req)) + }) +} + +func TestAddOption_InvalidConfigReturnsOriginal(t *testing.T) { + h := newHTTPService(t) + + cfg := RateLimiterConfig{Requests: 0, Burst: 1} // invalid + + out := cfg.AddOption(h) + + assert.Same(t, h, out) +} + +func TestAddOption_DefaultsToLocalStoreAndLogsWarning(t *testing.T) { + log := testutil.StdoutOutputForFunc(func() { + h := newHTTPService(t) + + cfg := RateLimiterConfig{Requests: 2, Burst: 2, Window: time.Second} + + cfg.Store = nil + + _ = cfg.AddOption(h) + }) + + assert.Contains(t, log, "Using local rate limiting - not suitable for multi-instance deployments") +} + +func TestRequestsPerSecond(t *testing.T) { + cfg := RateLimiterConfig{Requests: 10, Window: 2 * time.Second} + + assert.InEpsilon(t, 5.0, cfg.RequestsPerSecond(), 0.001) +} + +func TestRateLimitError_ErrorAndStatusCode(t *testing.T) { + err := &RateLimitError{ServiceKey: "svc", RetryAfter: 2 * time.Second} + + assert.Contains(t, err.Error(), "rate limit exceeded for service: svc") + + assert.Equal(t, http.StatusTooManyRequests, err.StatusCode()) +} diff --git a/pkg/gofr/service/rate_limiter_store.go b/pkg/gofr/service/rate_limiter_store.go new file mode 100644 index 000000000..0911ae60d --- /dev/null +++ b/pkg/gofr/service/rate_limiter_store.go @@ -0,0 +1,267 @@ +package service + +import ( + "context" + "fmt" + "strconv" + "sync" + "sync/atomic" + "time" + + gofrRedis "gofr.dev/pkg/gofr/datasource/redis" +) + +const ( + cleanupInterval = 5 * time.Minute // How often to clean up unused buckets + bucketTTL = 10 * time.Minute // How long to keep unused buckets +) + +// RateLimiterStore abstracts the storage and cleanup for rate limiter buckets. +type RateLimiterStore interface { + Allow(ctx context.Context, key string, config RateLimiterConfig) (allowed bool, retryAfter time.Duration, err error) + StartCleanup(ctx context.Context) + StopCleanup() +} + +// tokenBucket with simplified integer-only token handling. +type tokenBucket struct { + tokens int64 // Current tokens + lastRefillTime int64 // Unix nano timestamp + maxTokens int64 // Maximum tokens + refillRate int64 // Tokens per second (as integer) +} + +// bucketEntry holds bucket with last access time for cleanup. +type bucketEntry struct { + bucket *tokenBucket + lastAccess int64 // Unix timestamp +} + +// newTokenBucket creates a new token bucket with integer-only math. +func newTokenBucket(config *RateLimiterConfig) *tokenBucket { + maxTokens := int64(config.Burst) + refillRate := int64(config.RequestsPerSecond()) + + return &tokenBucket{ + tokens: maxTokens, + lastRefillTime: time.Now().UnixNano(), + maxTokens: maxTokens, + refillRate: refillRate, + } +} + +// allow checks if a token can be consumed. +func (tb *tokenBucket) allow() (allowed bool, waitTime time.Duration) { + now := time.Now().UnixNano() + + // Calculate tokens to add based on elapsed time + elapsed := now - atomic.LoadInt64(&tb.lastRefillTime) + tokensToAdd := elapsed * tb.refillRate / int64(time.Second) + + // Update tokens atomically + for { + oldTokens := atomic.LoadInt64(&tb.tokens) + newTokens := oldTokens + tokensToAdd + + if newTokens > tb.maxTokens { + newTokens = tb.maxTokens + } + + // Early return if not enough tokens + if newTokens < 1 { + waitTime := time.Duration((1-newTokens)*int64(time.Second)/tb.refillRate) * time.Nanosecond + if waitTime < time.Millisecond { + waitTime = time.Millisecond + } + + return false, waitTime + } + + // Try to consume a token + if atomic.CompareAndSwapInt64(&tb.tokens, oldTokens, newTokens-1) { + atomic.StoreInt64(&tb.lastRefillTime, now) + + return true, 0 + } + } +} + +// LocalRateLimiterStore implements RateLimiterStore using in-memory buckets. +type LocalRateLimiterStore struct { + buckets *sync.Map + stopCh chan struct{} +} + +func NewLocalRateLimiterStore() *LocalRateLimiterStore { + return &LocalRateLimiterStore{ + buckets: &sync.Map{}, + } +} + +func (l *LocalRateLimiterStore) Allow(_ context.Context, key string, config RateLimiterConfig) (bool, time.Duration, error) { + now := time.Now().Unix() + entry, _ := l.buckets.LoadOrStore(key, &bucketEntry{ + bucket: newTokenBucket(&config), + lastAccess: now, + }) + + bucketEntry := entry.(*bucketEntry) + + atomic.StoreInt64(&bucketEntry.lastAccess, now) + + allowed, retryAfter := bucketEntry.bucket.allow() + + return allowed, retryAfter, nil +} + +func (l *LocalRateLimiterStore) StartCleanup(ctx context.Context) { + l.stopCh = make(chan struct{}) + + go func() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + l.cleanupExpiredBuckets() + case <-l.stopCh: + return + case <-ctx.Done(): + return + } + } + }() +} + +func (l *LocalRateLimiterStore) StopCleanup() { + if l.stopCh != nil { + close(l.stopCh) + } +} + +func (l *LocalRateLimiterStore) cleanupExpiredBuckets() { + cutoff := time.Now().Unix() - int64(bucketTTL.Seconds()) + cleaned := 0 + + l.buckets.Range(func(key, value any) bool { + entry := value.(*bucketEntry) + if atomic.LoadInt64(&entry.lastAccess) < cutoff { + l.buckets.Delete(key) + + cleaned++ + } + + return true + }) +} + +// tokenBucketScript is a Lua script for atomic token bucket rate limiting in Redis. +// Updated to use integer-only token math for simplicity +// +//nolint:gosec // This is a Lua script for Redis, not credentials +const tokenBucketScript = ` +local key = KEYS[1] +local burst = tonumber(ARGV[1]) +local requests = tonumber(ARGV[2]) +local window_seconds = tonumber(ARGV[3]) +local now = tonumber(ARGV[4]) + +-- Calculate refill rate as requests per second +local refill_rate = requests / window_seconds + +-- Fetch bucket +local bucket = redis.call("HMGET", key, "tokens", "last_refill") +local tokens = tonumber(bucket[1]) +local last_refill = tonumber(bucket[2]) + +if tokens == nil then + tokens = burst + last_refill = now +end + +-- Refill tokens (integer math only) +local delta = math.max(0, (now - last_refill)/1e9) +local tokens_to_add = math.floor(delta * refill_rate) +local new_tokens = math.min(burst, tokens + tokens_to_add) + +local allowed = 0 +local retryAfter = 0 + +if new_tokens >= 1 then + allowed = 1 + new_tokens = new_tokens - 1 +else + retryAfter = math.ceil((1 - new_tokens) / refill_rate * 1000) -- ms +end + +redis.call("HSET", key, "tokens", new_tokens, "last_refill", now) +redis.call("EXPIRE", key, 600) + +return {allowed, retryAfter} +` + +// RedisRateLimiterStore implements RateLimiterStore using Redis. +type RedisRateLimiterStore struct { + client *gofrRedis.Redis +} + +func NewRedisRateLimiterStore(client *gofrRedis.Redis) *RedisRateLimiterStore { + return &RedisRateLimiterStore{client: client} +} + +func (r *RedisRateLimiterStore) Allow(ctx context.Context, key string, config RateLimiterConfig) (bool, time.Duration, error) { + now := time.Now().UnixNano() + cmd := r.client.Eval( + ctx, + tokenBucketScript, + []string{"gofr:ratelimit:" + key}, + config.Burst, // ARGV[1]: burst + config.Requests, // ARGV[2]: requests + int64(config.Window.Seconds()), // ARGV[3]: window_seconds + now, // ARGV[4]: now (nanoseconds) + ) + + result, err := cmd.Result() + if err != nil { + return true, 0, err // Fail open + } + + resultArray, ok := result.([]any) + if !ok || len(resultArray) != 2 { + return true, 0, errInvalidRedisResultType // Fail open + } + + allowed, _ := toInt64(resultArray[0]) + retryAfterMs, _ := toInt64(resultArray[1]) + + return allowed == 1, time.Duration(retryAfterMs) * time.Millisecond, nil +} + +func (*RedisRateLimiterStore) StartCleanup(_ context.Context) { + // No-op: Redis handles cleanup automatically via EXPIRE commands in Lua script. +} + +func (*RedisRateLimiterStore) StopCleanup() { + // No-op: Redis handles cleanup automatically. +} + +// toInt64 safely converts Redis result to int64. +func toInt64(i any) (int64, error) { + switch v := i.(type) { + case int64: + return v, nil + case int: + return int64(v), nil + case float64: + return int64(v), nil + case string: + if v == "" { + return 0, nil + } + + return strconv.ParseInt(v, 10, 64) + default: + return 0, fmt.Errorf("%w: %T", errInvalidRedisResultType, i) + } +} diff --git a/pkg/gofr/service/rate_limiter_store_test.go b/pkg/gofr/service/rate_limiter_store_test.go new file mode 100644 index 000000000..5d2e762fd --- /dev/null +++ b/pkg/gofr/service/rate_limiter_store_test.go @@ -0,0 +1,102 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTokenBucket_Allow(t *testing.T) { + cfg := RateLimiterConfig{Requests: 2, Burst: 2, Window: time.Second} + tb := newTokenBucket(&cfg) + + // Should allow first two requests + allowed, wait := tb.allow() + assert.True(t, allowed) + assert.Zero(t, wait) + + allowed, wait = tb.allow() + assert.True(t, allowed) + assert.Zero(t, wait) + + // Third request should be rate limited + allowed, wait = tb.allow() + assert.False(t, allowed) + assert.GreaterOrEqual(t, wait, time.Millisecond) +} + +func TestLocalRateLimiterStore_Allow(t *testing.T) { + store := NewLocalRateLimiterStore() + cfg := RateLimiterConfig{Requests: 1, Burst: 1, Window: time.Second} + key := "test-key" + + allowed, retry, err := store.Allow(context.Background(), key, cfg) + assert.True(t, allowed) + assert.Zero(t, retry) + require.NoError(t, err) + + allowed, retry, err = store.Allow(context.Background(), key, cfg) + assert.False(t, allowed) + assert.GreaterOrEqual(t, retry, time.Millisecond) + assert.NoError(t, err) +} + +func TestLocalRateLimiterStore_CleanupExpiredBuckets(t *testing.T) { + store := NewLocalRateLimiterStore() + cfg := RateLimiterConfig{Requests: 1, Burst: 1, Window: time.Second} + key := "cleanup-key" + + _, _, err := store.Allow(context.Background(), key, cfg) + require.NoError(t, err) + + // Simulate old lastAccess + entry, _ := store.buckets.Load(key) + bucketEntry := entry.(*bucketEntry) + bucketEntry.lastAccess = time.Now().Unix() - int64(bucketTTL.Seconds()) - 1 + + store.cleanupExpiredBuckets() + + _, exists := store.buckets.Load(key) + assert.False(t, exists) +} + +func TestLocalRateLimiterStore_StartAndStopCleanup(t *testing.T) { + store := NewLocalRateLimiterStore() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + store.StartCleanup(ctx) + assert.NotNil(t, store.stopCh) + + store.StopCleanup() +} + +func TestRedisRateLimiterStore_toInt64_ValidCases(t *testing.T) { + tests := []struct { + input any + expected int64 + }{ + {int64(5), 5}, + {int(7), 7}, + {float64(3.0), 3}, + {"42", 42}, + {"", 0}, + } + + for _, tc := range tests { + val, err := toInt64(tc.input) + + require.NoError(t, err) + assert.Equal(t, tc.expected, val) + } +} + +func TestRedisRateLimiterStore_toInt64_ErrorCases(t *testing.T) { + _, err := toInt64(struct{}{}) + + assert.ErrorIs(t, err, errInvalidRedisResultType) +} diff --git a/pkg/gofr/service/rate_limiter_test.go b/pkg/gofr/service/rate_limiter_test.go new file mode 100644 index 000000000..b0b2027c0 --- /dev/null +++ b/pkg/gofr/service/rate_limiter_test.go @@ -0,0 +1,180 @@ +package service + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockStore struct { + allowed bool + retryAfter time.Duration + err error +} + +func (m *mockStore) Allow(_ context.Context, _ string, _ RateLimiterConfig) (bool, time.Duration, error) { + return m.allowed, m.retryAfter, m.err +} +func (*mockStore) StartCleanup(_ context.Context) {} + +func (*mockStore) StopCleanup() {} + +func TestRateLimiter_buildFullURL(t *testing.T) { + httpSvc := &httpService{url: "http://base.com/api"} + rl := &rateLimiter{HTTP: httpSvc} + + assert.Equal(t, "http://foo.com/bar", rl.buildFullURL("http://foo.com/bar")) + assert.Equal(t, "https://foo.com/bar", rl.buildFullURL("https://foo.com/bar")) + assert.Equal(t, "http://base.com/api/foo", rl.buildFullURL("foo")) + assert.Equal(t, "http://base.com/api/foo", rl.buildFullURL("/foo")) + + httpSvc.url = "" + + assert.Equal(t, "bar", rl.buildFullURL("bar")) + + rl.HTTP = &mockHTTP{} + + assert.Equal(t, "baz", rl.buildFullURL("baz")) +} + +func TestRateLimiter_checkRateLimit_Error(t *testing.T) { + store := &mockStore{allowed: true, err: errTest} + + rl := &rateLimiter{ + config: RateLimiterConfig{ + KeyFunc: func(*http.Request) string { return "svc" }, + Store: store, + }, + store: store, + } + + req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) + + err := rl.checkRateLimit(req) + + require.NoError(t, err) +} + +func TestRateLimiter_checkRateLimit_Denied(t *testing.T) { + store := &mockStore{allowed: false} + + rl := &rateLimiter{ + config: RateLimiterConfig{ + KeyFunc: func(*http.Request) string { return "svc" }, + Store: store, + }, + store: store, + } + + req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) + err := rl.checkRateLimit(req) + + assert.IsType(t, &RateLimitError{}, err) +} + +func TestRateLimiter_checkRateLimit_Allowed(t *testing.T) { + store := &mockStore{allowed: true} + + rl := &rateLimiter{ + config: RateLimiterConfig{ + KeyFunc: func(*http.Request) string { return "svc" }, + Store: store, + }, + store: store, + } + + req, _ := http.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) + + err := rl.checkRateLimit(req) + assert.NoError(t, err) +} + +func TestRateLimiter_HTTPMethods(t *testing.T) { + mock := &mockHTTP{} + + store := &mockStore{allowed: true} + + rl := &rateLimiter{ + config: RateLimiterConfig{ + KeyFunc: func(*http.Request) string { return "svc" }, + Store: store, + }, + store: store, + HTTP: mock, + } + + ctx := context.Background() + resp, err := rl.Get(ctx, "foo", nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + defer resp.Body.Close() + + resp, err = rl.GetWithHeaders(ctx, "foo", nil, nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + defer resp.Body.Close() + + resp, err = rl.Post(ctx, "foo", nil, nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusCreated, resp.StatusCode) + + defer resp.Body.Close() + + resp, err = rl.PostWithHeaders(ctx, "foo", nil, nil, nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusCreated, resp.StatusCode) + + defer resp.Body.Close() + + resp, err = rl.Put(ctx, "foo", nil, nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + defer resp.Body.Close() + + resp, err = rl.PutWithHeaders(ctx, "foo", nil, nil, nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + defer resp.Body.Close() + + resp, err = rl.Patch(ctx, "foo", nil, nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + defer resp.Body.Close() + + resp, err = rl.PatchWithHeaders(ctx, "foo", nil, nil, nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + defer resp.Body.Close() + + resp, err = rl.Delete(ctx, "foo", nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusNoContent, resp.StatusCode) + + defer resp.Body.Close() + + resp, err = rl.DeleteWithHeaders(ctx, "foo", nil, nil) + + require.NoError(t, err) + assert.Equal(t, http.StatusNoContent, resp.StatusCode) + + _ = resp.Body.Close() +}