Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add RateLimit middleware using TokenBucket algorithm #557

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
11 changes: 10 additions & 1 deletion pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,23 @@ type ServerSecurityOptions struct {
// These are the Access-Control-Request-Headers that the server will respond to.
// By default, the server will allow Accept, Accept-Language, Content-Language, and Content-Type.
// Use this setting to add any additional headers which are needed.
AllowedHeaders []string `json:"allowedHeaders"`
AllowedHeaders []string `json:"allowedHeaders"`
RateLimit RateLimitOptions `json:"rateLimit"`
}

// SslOptions groups the two string settings read from the "certificateFile" and
// "keyFile" JSON keys.
// NOTE(review): presumably these are filesystem paths to the TLS certificate and its
// private key — confirm against the code that loads them.
type SslOptions struct {
// CertificateFile is populated from the "certificateFile" JSON key.
CertificateFile string `json:"certificateFile"`
// KeyFile is populated from the "keyFile" JSON key.
KeyFile string `json:"keyFile"`
}

// RateLimitOptions is a type to hold rate limit configuration options.
// Each field is exposed both as a JSON key and, through its pflag tag, as a
// command-line flag whose help text is the tag's description.
type RateLimitOptions struct {
// Enabled switches rate limiting on or off.
Enabled bool `json:"enabled" pflag:",Controls whether rate limiting is enabled. If enabled, the rate limit is applied to all requests using the TokenBucket algorithm."`
// RequestsPerSecond is the sustained number of requests allowed per second.
RequestsPerSecond int `json:"requestsPerSecond" pflag:",The number of requests allowed per second."`
// BurstSize is the token bucket capacity; per the flag description, 0 means the
// bucket cannot hold any tokens.
BurstSize int `json:"burstSize" pflag:",The number of requests allowed to burst. 0 implies the TokenBucket algorithm cannot hold any tokens."`
// CleanupInterval is how often unused rate limiter entries are cleaned up.
CleanupInterval config.Duration `json:"cleanupInterval" pflag:",The interval at which the rate limiter cleans up entries that have not been used for a certain period of time."`
}

LaPetiteSouris marked this conversation as resolved.
Show resolved Hide resolved
var defaultServerConfig = &ServerConfig{
HTTPPort: 8088,
Security: ServerSecurityOptions{
Expand Down
4 changes: 4 additions & 0 deletions pkg/config/serverconfig_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 56 additions & 0 deletions pkg/config/serverconfig_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,17 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
auth.AuthenticationLoggingInterceptor,
middlewareInterceptors,
)
if cfg.Security.RateLimit.Enabled {
LaPetiteSouris marked this conversation as resolved.
Show resolved Hide resolved
rateLimiter := plugins.NewRateLimiter(cfg.Security.RateLimit.RequestsPerSecond, cfg.Security.RateLimit.BurstSize, cfg.Security.RateLimit.CleanupInterval.Duration)
rateLimitInterceptors := plugins.RateLimiteInterceptor(*rateLimiter)
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(chainedUnaryInterceptors, rateLimitInterceptors)
}
} else {
LaPetiteSouris marked this conversation as resolved.
Show resolved Hide resolved
logger.Infof(ctx, "Creating gRPC server without authentication")
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor)
if cfg.Security.RateLimit.Enabled {
logger.Warningf(ctx, "Rate limit is enabled but auth is not")
}
}

serverOpts := []grpc.ServerOption{
Expand Down Expand Up @@ -257,6 +265,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry,
}

oauth2ResourceServer = oauth2Provider

} else {
oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL)
if err != nil {
Expand Down
116 changes: 116 additions & 0 deletions plugins/rate_limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package plugins

import (
"context"
"errors"
"fmt"
"sync"
"time"

auth "github.com/flyteorg/flyteadmin/auth"
"golang.org/x/time/rate"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// RateLimitExceeded is the error type returned by LimiterStore.Allow when a user's
// token bucket has no token available.
// NOTE: its underlying type is the error interface itself, so every error value
// satisfies it; a type assertion to RateLimitExceeded cannot distinguish it from
// other errors.
type RateLimitExceeded error

// accessRecords stores the rate limiter and the last access time
// for a single user. One instance is kept per user ID in LimiterStore.accessPerUser.
type accessRecords struct {
// limiter is the token bucket consulted on every Allow call for this user.
limiter *rate.Limiter
// lastAccess is the time of the most recent Allow call; clean compares it
// against the cleanup interval to decide whether to evict the record.
lastAccess time.Time
// mutex is held by both Allow and clean while they read or write this record.
mutex *sync.Mutex
}

// LimiterStore stores the access records for each user
// and the settings used to create a limiter for a user seen for the first time.
type LimiterStore struct {
// accessPerUser is a synchronized map of userID to accessRecords
// (keys are string user IDs, values are *accessRecords).
accessPerUser *sync.Map
// requestPerSec is the refill rate passed to rate.NewLimiter for new users.
requestPerSec int
// burstSize is the bucket capacity passed to rate.NewLimiter for new users.
burstSize int
// cleanupInterval is both the pause between cleanup passes and the idle time
// after which a user's record is deleted.
cleanupInterval time.Duration
}

// Allow takes a userID and returns an error if the user has exceeded the rate limit
func (l *LimiterStore) Allow(userID string) error {
accessRecord, _ := l.accessPerUser.LoadOrStore(userID, &accessRecords{
limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize),
lastAccess: time.Now(),
mutex: &sync.Mutex{},
})
accessRecord.(*accessRecords).mutex.Lock()
defer accessRecord.(*accessRecords).mutex.Unlock()

accessRecord.(*accessRecords).lastAccess = time.Now()
l.accessPerUser.Store(userID, accessRecord)

if !accessRecord.(*accessRecords).limiter.Allow() {
return RateLimitExceeded(fmt.Errorf("rate limit exceeded"))
}
Comment on lines +48 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this happen before we modify the map?


return nil
}

// clean walks every stored record and deletes those whose last access is older
// than the store's cleanup interval. Each record is inspected while holding
// its own mutex.
func (l *LimiterStore) clean() {
	evictIfIdle := func(userID, entry interface{}) bool {
		record := entry.(*accessRecords)

		record.mutex.Lock()
		defer record.mutex.Unlock()

		idleFor := time.Since(record.lastAccess)
		if idleFor > l.cleanupInterval {
			l.accessPerUser.Delete(userID)
		}

		// Always continue so that the whole map is visited.
		return true
	}

	l.accessPerUser.Range(evictIfIdle)
}

// newRateLimitStore builds a LimiterStore whose per-user limiters refill at
// requestPerSec tokens per second with a capacity of burstSize, and starts a
// background goroutine that calls clean once per cleanupInterval.
//
// When cleanupInterval is zero or negative no cleanup goroutine is started and
// records are kept for the lifetime of the store. Without this guard the loop
// would never pause and every pass would evict all records (any idle time is
// greater than a non-positive interval), silently disabling rate limiting
// while burning a CPU core.
//
// The cleanup goroutine has no stop signal; it runs until the process exits.
func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore {
	l := &LimiterStore{
		accessPerUser:   &sync.Map{},
		requestPerSec:   requestPerSec,
		burstSize:       burstSize,
		cleanupInterval: cleanupInterval,
	}

	if cleanupInterval <= 0 {
		return l
	}

	go func() {
		ticker := time.NewTicker(cleanupInterval)
		defer ticker.Stop()

		for range ticker.C {
			l.clean()
		}
	}()

	return l
}

// RateLimiter is a struct that implements the RateLimiter interface from grpc middleware
// NOTE(review): the interface itself is not visible in this file — confirm that the
// Limit(ctx) error method matches the middleware version in use.
type RateLimiter struct {
// limiter holds the per-user token buckets that Limit delegates to.
limiter *LimiterStore
}

// Limit applies the per-user rate limit to the caller identified in ctx.
// It returns an error when ctx carries no identity, or whatever error the
// underlying store reports for that user's ID; nil means the request may proceed.
func (r *RateLimiter) Limit(ctx context.Context) error {
	identity := auth.IdentityContextFromContext(ctx)
	if identity.IsEmpty() {
		return errors.New("no identity context found")
	}

	return r.limiter.Allow(identity.UserID())
}

// NewRateLimiter returns a RateLimiter backed by a freshly created per-user
// store configured with the given refill rate, burst size and cleanup interval.
func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *RateLimiter {
	return &RateLimiter{
		limiter: newRateLimitStore(requestPerSec, burstSize, cleanupInterval),
	}
}

// RateLimiteInterceptor returns a unary server interceptor that asks limiter
// whether the request may proceed before invoking the handler.
// Any error from limiter.Limit — whatever its cause — is reported to the
// client as codes.ResourceExhausted with the message "rate limit exceeded".
func RateLimiteInterceptor(limiter RateLimiter) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
		resp interface{}, err error) {
		if limitErr := limiter.Limit(ctx); limitErr == nil {
			return handler(ctx, req)
		}

		return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded")
	}
}
126 changes: 126 additions & 0 deletions plugins/rate_limit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package plugins

import (
"context"
"testing"
"time"

auth "github.com/flyteorg/flyteadmin/auth"
"github.com/stretchr/testify/assert"
)

func TestNewRateLimiter(t *testing.T) {
rlStore := newRateLimitStore(1, 1, time.Second)
assert.NotNil(t, rlStore)
}

// TestLimiterAllow checks that a bucket of size one admits a single request,
// rejects the next, and admits another after one second of refill.
func TestLimiterAllow(t *testing.T) {
	const user = "hello"
	store := newRateLimitStore(1, 1, 10*time.Second)

	// The only token is spent by the first call.
	assert.NoError(t, store.Allow(user))
	assert.Error(t, store.Allow(user))

	// One token per second: a one second wait makes a request possible again.
	time.Sleep(time.Second)
	assert.NoError(t, store.Allow(user))
}

// TestLimiterAllowBurst checks that a burst size of two admits two back-to-back
// requests from one user, rejects the third, and does not affect another user.
func TestLimiterAllowBurst(t *testing.T) {
	const (
		firstUser  = "hello"
		secondUser = "world"
	)
	store := newRateLimitStore(1, 2, time.Second)

	assert.NoError(t, store.Allow(firstUser))
	assert.NoError(t, store.Allow(firstUser))
	assert.Error(t, store.Allow(firstUser))

	// Buckets are per user, so a different user still has tokens.
	assert.NoError(t, store.Allow(secondUser))
}

// TestLimiterClean checks that clean evicts a record that has been idle for
// longer than the cleanup interval.
//
// The wait is kept far below the one-second refill period of the limiter, so
// the final Allow can only succeed because the exhausted record was deleted
// and a fresh bucket was created, not because a token was refilled.
func TestLimiterClean(t *testing.T) {
	const cleanupInterval = 50 * time.Millisecond

	rlStore := newRateLimitStore(1, 1, cleanupInterval)
	assert.NoError(t, rlStore.Allow("hello"))
	assert.Error(t, rlStore.Allow("hello"))

	time.Sleep(2 * cleanupInterval)
	rlStore.clean()

	_, found := rlStore.accessPerUser.Load("hello")
	assert.False(t, found)

	assert.NoError(t, rlStore.Allow("hello"))
}

// TestLimiterAllowOnMultipleRequests checks that several users are limited
// independently of each other, both before and after a one second refill.
func TestLimiterAllowOnMultipleRequests(t *testing.T) {
	store := newRateLimitStore(1, 1, time.Second)

	// Every user gets a first request through.
	for _, user := range []string{"a", "b", "c"} {
		assert.NoError(t, store.Allow(user))
	}
	// A second immediate request is rejected.
	for _, user := range []string{"a", "b"} {
		assert.Error(t, store.Allow(user))
	}

	time.Sleep(time.Second)

	// After the wait each user has exactly one request available again.
	for _, user := range []string{"a", "b"} {
		assert.NoError(t, store.Allow(user))
		assert.Error(t, store.Allow(user))
	}
	assert.NoError(t, store.Allow("c"))
}

// TestRateLimiterLimitPass checks that the first request of an identified user
// is admitted by RateLimiter.Limit.
func TestRateLimiterLimitPass(t *testing.T) {
	limiter := NewRateLimiter(1, 1, time.Second)
	assert.NotNil(t, limiter)

	identity, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil)
	assert.NoError(t, err)

	ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identity)
	assert.NoError(t, limiter.Limit(ctx))
}

// TestRateLimiterLimitStop checks that a second immediate request from the same
// identified user is rejected by RateLimiter.Limit.
func TestRateLimiterLimitStop(t *testing.T) {
	limiter := NewRateLimiter(1, 1, time.Second)
	assert.NotNil(t, limiter)

	identity, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil)
	assert.NoError(t, err)
	ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identity)

	// First request spends the only token; the second finds the bucket empty.
	assert.NoError(t, limiter.Limit(ctx))
	assert.Error(t, limiter.Limit(ctx))
}

// TestRateLimiterLimitWithoutUserIdentity checks that Limit returns an error
// when the context carries no identity.
func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) {
	limiter := NewRateLimiter(1, 1, time.Second)
	assert.NotNil(t, limiter)

	assert.Error(t, limiter.Limit(context.TODO()))
}

// TestRateLimiterUpdateLastAccessTime checks that every call to Allow moves the
// user's last access time forward, including a call that is rejected.
//
// A short pause separates the calls so that the strict "after" comparisons do
// not depend on the clock returning distinct readings for back-to-back calls.
// The pause is far too short to refill a token at two tokens per second, so
// the third call is still rejected.
func TestRateLimiterUpdateLastAccessTime(t *testing.T) {
	rlStore := newRateLimitStore(2, 2, time.Second)

	// lastAccessOf reads the stored timestamp under the record's own lock and
	// fails the test if the user has no record.
	lastAccessOf := func(userID string) time.Time {
		value, found := rlStore.accessPerUser.Load(userID)
		if !found {
			t.Fatalf("no access record stored for %q", userID)
		}
		record := value.(*accessRecords)
		record.mutex.Lock()
		defer record.mutex.Unlock()
		return record.lastAccess
	}

	assert.NoError(t, rlStore.Allow("hello"))
	firstAccessTime := lastAccessOf("hello")

	time.Sleep(time.Millisecond)
	assert.NoError(t, rlStore.Allow("hello"))
	secondAccessTime := lastAccessOf("hello")

	assert.True(t, secondAccessTime.After(firstAccessTime))

	// Verify that the last access time is updated even when user is rate limited
	time.Sleep(time.Millisecond)
	assert.Error(t, rlStore.Allow("hello"))
	thirdAccessTime := lastAccessOf("hello")

	assert.True(t, thirdAccessTime.After(secondAccessTime))
}