diff --git a/pkg/config/config.go b/pkg/config/config.go index a855f75dd..d86c07333 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -81,7 +81,7 @@ type ServerSecurityOptions struct { // By default, the server will allow Accept, Accept-Language, Content-Language, and Content-Type. // DeprecatedUser this setting to add any additional headers which are needed AllowedHeaders []string `json:"allowedHeaders"` - RateLimit RateLimitOptions `json:"rateLimitOptions"` + RateLimit RateLimitOptions `json:"rateLimit"` } type SslOptions struct { @@ -89,12 +89,12 @@ type SslOptions struct { KeyFile string `json:"keyFile"` } -// declare RateLimitConfig +// RateLimitOptions is a type to hold rate limit configuration options. type RateLimitOptions struct { - Enabled bool `json:"enabled"` - RequestsPerSecond int `json:"requestsPerSecond"` - BurstSize int `json:"burstSize"` - CleanupInterval config.Duration `json:"cleanupInterval"` + Enabled bool `json:"enabled" pflag:",Controls whether rate limiting is enabled. If enabled, the rate limit is applied to all requests using the TokenBucket algorithm."` + RequestsPerSecond int `json:"requestsPerSecond" pflag:",The number of requests allowed per second."` + BurstSize int `json:"burstSize" pflag:",The number of requests allowed to burst. 0 implies the TokenBucket algorithm cannot hold any tokens."` + CleanupInterval config.Duration `json:"cleanupInterval" pflag:",The interval at which the rate limiter cleans up entries that have not been used for a certain period of time."` } var defaultServerConfig = &ServerConfig{ diff --git a/pkg/config/serverconfig_flags.go b/pkg/config/serverconfig_flags.go index c37a82603..fe7e00e00 100755 --- a/pkg/config/serverconfig_flags.go +++ b/pkg/config/serverconfig_flags.go @@ -63,6 +63,10 @@ func (cfg ServerConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.allowCors"), defaultServerConfig.Security.AllowCors, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "security.allowedOrigins"), defaultServerConfig.Security.AllowedOrigins, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "security.allowedHeaders"), defaultServerConfig.Security.AllowedHeaders, "") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.rateLimit.enabled"), defaultServerConfig.Security.RateLimit.Enabled, "Controls whether rate limiting is enabled. If enabled, the rate limit is applied to all requests using the TokenBucket algorithm.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "security.rateLimit.requestsPerSecond"), defaultServerConfig.Security.RateLimit.RequestsPerSecond, "The number of requests allowed per second.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "security.rateLimit.burstSize"), defaultServerConfig.Security.RateLimit.BurstSize, "The number of requests allowed to burst. 0 implies the TokenBucket algorithm cannot hold any tokens.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "security.rateLimit.cleanupInterval"), defaultServerConfig.Security.RateLimit.CleanupInterval.String(), "The interval at which the rate limiter cleans up entries that have not been used for a certain period of time.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpc.port"), defaultServerConfig.GrpcConfig.Port, "On which grpc port to serve admin") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "grpc.serverReflection"), defaultServerConfig.GrpcConfig.ServerReflection, "Enable GRPC Server Reflection") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpc.maxMessageSizeBytes"), defaultServerConfig.GrpcConfig.MaxMessageSizeBytes, "The max size in bytes for incoming gRPC messages") diff --git a/pkg/config/serverconfig_flags_test.go b/pkg/config/serverconfig_flags_test.go index b16e0416d..f5eaa0752 100755 --- a/pkg/config/serverconfig_flags_test.go +++ b/pkg/config/serverconfig_flags_test.go @@ -281,6 +281,62 @@ func TestServerConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_security.rateLimit.enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("security.rateLimit.enabled", testValue) + if vBool, err := cmdFlags.GetBool("security.rateLimit.enabled"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vBool), &actual.Security.RateLimit.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_security.rateLimit.requestsPerSecond", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("security.rateLimit.requestsPerSecond", testValue) + if vInt, err := cmdFlags.GetInt("security.rateLimit.requestsPerSecond"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vInt), &actual.Security.RateLimit.RequestsPerSecond) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_security.rateLimit.burstSize", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("security.rateLimit.burstSize", testValue) + if vInt, err := cmdFlags.GetInt("security.rateLimit.burstSize"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vInt), &actual.Security.RateLimit.BurstSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_security.rateLimit.cleanupInterval", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := defaultServerConfig.Security.RateLimit.CleanupInterval.String() + + cmdFlags.Set("security.rateLimit.cleanupInterval", testValue) + if vString, err := cmdFlags.GetString("security.rateLimit.cleanupInterval"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vString), &actual.Security.RateLimit.CleanupInterval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_grpc.port", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/pkg/server/service.go b/pkg/server/service.go index 795ffbdfd..4522a2a4f 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -98,6 +98,9 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c } else { logger.Infof(ctx, "Creating gRPC server without authentication") chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor) + if cfg.Security.RateLimit.Enabled { + logger.Warningf(ctx, "Rate limit is enabled but auth is not") + } } serverOpts := []grpc.ServerOption{ diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go index d32cdf985..07800b728 100644 --- a/plugins/rate_limit.go +++ b/plugins/rate_limit.go @@ -16,13 +16,15 @@ import ( type RateLimitExceeded error -// define a struct that contains a map of rate limiters, and a time stamp of last access and a mutex to protect the map +// accessRecords stores the rate limiter and the last access time type accessRecords struct { limiter *rate.Limiter lastAccess time.Time } +// LimiterStore stores the access records for each user type LimiterStore struct { + // accessPerUser is a synchronized map of userID to accessRecords accessPerUser map[string]*accessRecords mutex *sync.Mutex requestPerSec int @@ -30,9 +32,7 @@ type LimiterStore struct { cleanupInterval time.Duration } -// define a function named Allow that takes userID and returns RateLimitError -// the function check if the user is in the map, if not, create a new accessRecords for the user -// then it check if the user can access the resource, if not, return RateLimitError +// Allow takes a userID and returns an error if the user has exceeded the rate limit func (l *LimiterStore) Allow(userID string) error { l.mutex.Lock() defer l.mutex.Unlock()