diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..585f569 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.idea/ +.vscode/ +bin/ \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..5d528e3 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,19 @@ +Copyright (c) 2019 coinpaprika + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..4367b15 --- /dev/null +++ b/README.md @@ -0,0 +1,115 @@ +# ratelimiter + +Simple rate limiter for any resources inspired by Cloudflare's approach: [How we built rate limiting capable of scaling to millions of domains.](https://blog.cloudflare.com/counting-things-a-lot-of-different-things/) + +## Usage + +### Getting started + +```go +package main + +import ( + "fmt" + "log" + "time" + + "github.com/coinpaprika/ratelimiter" +) + +func main() { + limitedKey := "key" + windowSize := 1 * time.Minute + + dataStore := ratelimiter.NewMapLimitStore(2*windowSize, 10*time.Second) // create map data store for rate limiter and set each element's expiration time to 2*windowSize and old data flush interval to 10*time.Second + + var maxLimit int64 = 5 + rateLimiter := ratelimiter.New(dataStore, maxLimit, windowSize) // allow 5 requests per windowSize (1 minute) + + for i := 0; i < 10; i++ { + limitStatus, err := rateLimiter.Check(limitedKey) + if err != nil { + log.Fatal(err) + } + if limitStatus.IsLimited { + fmt.Printf("too high rate for key: %s: rate: %f, limit: %d\nsleep: %s", limitedKey, limitStatus.CurrentRate, maxLimit, *limitStatus.LimitDuration) + time.Sleep(*limitStatus.LimitDuration) + } else { + err := rateLimiter.Inc(limitedKey) + if err != nil { + log.Fatal(err) + } + } + } +} +``` + +### Rate-limit IP requests in http middleware + +```go +func rateLimitMiddleware(rateLimiter *ratelimiter.RateLimiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + remoteIP := GetRemoteIP([]string{"X-Forwarded-For", "RemoteAddr", "X-Real-IP"}, 0, r) + key := fmt.Sprintf("%s_%s_%s", remoteIP, r.URL.String(), r.Method) + + limitStatus, err := rateLimiter.Check(key) + if err != nil { + // if rate limit error then pass the request + next.ServeHTTP(w, r) + } + if limitStatus.IsLimited { + w.WriteHeader(http.StatusTooManyRequests) + return + } else { + rateLimiter.Inc(key) + } + + next.ServeHTTP(w, r) + }) + } +} + +func hello(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +} + +func main() { + windowSize := 1 * time.Minute + dataStore := ratelimiter.NewMapLimitStore(2*windowSize, 10*time.Second) // create map data store for rate limiter and set each element's expiration time to 2*windowSize and old data flush interval to 10*time.Second + rateLimiter := ratelimiter.New(dataStore, 5, windowSize) // allow 5 requests per windowSize (1 minute) + + rateLimiterHandler := rateLimitMiddleware(rateLimiter) + helloHandler := http.HandlerFunc(hello) + http.Handle("/", rateLimiterHandler(helloHandler)) + + log.Fatal(http.ListenAndServe(":8080", nil)) + +} +``` +See full [example](./examples/http_middleware/http_middleware.go) + +### Implement your own limit data store +To use custom data store (memcached, Redis, MySQL etc.) you just need to implement [LimitStore](./limit_store.go) interface: +```go +type FakeDataStore struct{} + +func (f FakeDataStore) Inc(key string, window time.Time) error { + return nil +} + +func (f FakeDataStore) Get(key string, previousWindow, currentWindow time.Time) (prevValue int64, currValue int64, err error) { + return 0, 0, nil +} +// ... +rateLimiter := ratelimiter.New(FakeDataStore{}, maxLimit, windowSize) +``` + +## Examples + +Check out the [examples](./examples) directory. + + +## License + +ratelimiter is available under the MIT license. See the [LICENSE file](./LICENSE.md) for more info. \ No newline at end of file diff --git a/examples/http_middleware/http_middleware.go b/examples/http_middleware/http_middleware.go new file mode 100644 index 0000000..cfc8f4e --- /dev/null +++ b/examples/http_middleware/http_middleware.go @@ -0,0 +1,90 @@ +package main + +import ( + "fmt" + "html" + "log" + "net" + "net/http" + "strings" + "time" + + "github.com/coinpaprika/ratelimiter" +) + +// copied from https://github.com/didip/tollbooth/blob/master/libstring/libstring.go#L21 +func GetRemoteIP(ipLookups []string, forwardedForIndexFromBehind int, r *http.Request) string { + realIP := r.Header.Get("X-Real-IP") + forwardedFor := r.Header.Get("X-Forwarded-For") + + for _, lookup := range ipLookups { + if lookup == "RemoteAddr" { + // 1. Cover the basic use cases for both ipv4 and ipv6 + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // 2. Upon error, just return the remote addr. + return r.RemoteAddr + } + return ip + } + if lookup == "X-Forwarded-For" && forwardedFor != "" { + // X-Forwarded-For is potentially a list of addresses separated with "," + parts := strings.Split(forwardedFor, ",") + for i, p := range parts { + parts[i] = strings.TrimSpace(p) + } + + partIndex := len(parts) - 1 - forwardedForIndexFromBehind + if partIndex < 0 { + partIndex = 0 + } + + return parts[partIndex] + } + if lookup == "X-Real-IP" && realIP != "" { + return realIP + } + } + + return "" +} + +func rateLimitMiddleware(rateLimiter *ratelimiter.RateLimiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + remoteIP := GetRemoteIP([]string{"X-Forwarded-For", "RemoteAddr", "X-Real-IP"}, 0, r) + key := fmt.Sprintf("%s_%s_%s", remoteIP, r.URL.String(), r.Method) + + limitStatus, err := rateLimiter.Check(key) + if err != nil { + // if rate limit error then pass the request + next.ServeHTTP(w, r) + } + if limitStatus.IsLimited { + w.WriteHeader(http.StatusTooManyRequests) + return + } else { + rateLimiter.Inc(key) + } + + next.ServeHTTP(w, r) + }) + } +} + +func hello(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +} + +func main() { + windowSize := 1 * time.Minute + dataStore := ratelimiter.NewMapLimitStore(2*windowSize, 10*time.Second) // create map data store for rate limiter and set each element's expiration time to 2*windowSize and old data flush interval to 10*time.Second + rateLimiter := ratelimiter.New(dataStore, 5, windowSize) // allow 5 requests per windowSize (1 minute) + + rateLimiterHandler := rateLimitMiddleware(rateLimiter) + helloHandler := http.HandlerFunc(hello) + http.Handle("/", rateLimiterHandler(helloHandler)) + + log.Fatal(http.ListenAndServe(":8080", nil)) + +} diff --git a/examples/simple/simple.go b/examples/simple/simple.go new file mode 100644 index 0000000..278519f --- /dev/null +++ b/examples/simple/simple.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "log" + "time" + + "github.com/coinpaprika/ratelimiter" +) + +func main() { + limitedKey := "key" + windowSize := 1 * time.Minute + + dataStore := ratelimiter.NewMapLimitStore(2*windowSize, 10*time.Second) // create map data store for rate limiter and set each element's expiration time to 2*windowSize and old data flush interval to 10*time.Second + + var maxLimit int64 = 5 + rateLimiter := ratelimiter.New(dataStore, maxLimit, windowSize) // allow 5 requests per windowSize (1 minute) + + for i := 0; i < 10; i++ { + limitStatus, err := rateLimiter.Check(limitedKey) + if err != nil { + log.Fatal(err) + } + if limitStatus.IsLimited { + fmt.Printf("too high rate for key: %s: rate: %f, limit: %d\nsleep: %s", limitedKey, limitStatus.CurrentRate, maxLimit, *limitStatus.LimitDuration) + time.Sleep(*limitStatus.LimitDuration) + } else { + err := rateLimiter.Inc(limitedKey) + if err != nil { + log.Fatal(err) + } + } + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1f17a97 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/coinpaprika/ratelimiter + +require github.com/stretchr/testify v1.3.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4347755 --- /dev/null +++ b/go.sum @@ -0,0 +1,7 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/limit_store.go b/limit_store.go new file mode 100644 index 0000000..50b2c21 --- /dev/null +++ b/limit_store.go @@ -0,0 +1,11 @@ +package ratelimiter + +import "time" + +// LimitStore is the interface that represents limiter internal data store. Any database struct that implements LimitStore should have functions for incrementing counter of a given key and getting counter values of a given key for previous and current window +type LimitStore interface { + // Inc increments current window limit counter for key + Inc(key string, window time.Time) error + // Get gets value of previous window counter and current window counter for key + Get(key string, previousWindow, currentWindow time.Time) (prevValue int64, currValue int64, err error) +} diff --git a/map_limit_store.go b/map_limit_store.go new file mode 100644 index 0000000..7371343 --- /dev/null +++ b/map_limit_store.go @@ -0,0 +1,67 @@ +package ratelimiter + +import ( + "fmt" + "sync" + "time" +) + +type limitValue struct { + val int64 + lastUpdate time.Time +} + +// MapLimitStore represents internal limiter data database where data are stored in golang maps +type MapLimitStore struct { + data map[string]limitValue + mutex sync.RWMutex + expirationTime time.Duration +} + +// NewMapLimitStore creates new in-memory data store for internal limiter data. Each element of MapLimitStore is set as expired after expirationTime from its last counter increment. Expired elements are removed with a period specified by the flushInterval argument +func NewMapLimitStore(expirationTime time.Duration, flushInterval time.Duration) (m *MapLimitStore) { + m = &MapLimitStore{ + data: make(map[string]limitValue), + expirationTime: expirationTime, + } + go func() { + ticker := time.NewTicker(flushInterval) + for range ticker.C { + var deletedKeys []string + for key, val := range m.data { + if val.lastUpdate.Before(time.Now().UTC().Add(-m.expirationTime)) { + m.mutex.Lock() + delete(m.data, key) + deletedKeys = append(deletedKeys, key) + m.mutex.Unlock() + } + } + } + }() + return m +} + +// Inc increments current window limit counter for key +func (m *MapLimitStore) Inc(key string, window time.Time) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + data := m.data[mapKey(key, window)] + data.val++ + data.lastUpdate = time.Now().UTC() + m.data[mapKey(key, window)] = data + return nil +} + +// Get gets value of previous window counter and current window counter for key +func (m *MapLimitStore) Get(key string, previousWindow, currentWindow time.Time) (prevValue int64, currValue int64, err error) { + m.mutex.RLock() + defer m.mutex.RUnlock() + prevValue = m.data[mapKey(key, previousWindow)].val + currValue = m.data[mapKey(key, currentWindow)].val + return prevValue, currValue, nil +} + +func mapKey(key string, window time.Time) string { + return fmt.Sprintf("%s_%s", key, window.Format(time.RFC3339)) +} diff --git a/map_limit_store_test.go b/map_limit_store_test.go new file mode 100644 index 0000000..7ed2da9 --- /dev/null +++ b/map_limit_store_test.go @@ -0,0 +1,94 @@ +package ratelimiter + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewMapLimitStore(t *testing.T) { + + type args struct { + expirationTime time.Duration + flushInterval time.Duration + } + tests := []struct { + name string + args args + }{ + { + name: "test_NewMapLimitStore", + args: args{ + expirationTime: 1 * time.Minute, + flushInterval: 2 * time.Second, + }, + }, + } + for _, tt := range tests { + mapLimitStore := NewMapLimitStore(tt.args.expirationTime, tt.args.flushInterval) + assert.Equal(t, tt.args.expirationTime, mapLimitStore.expirationTime) + } +} + +func TestMapLimitStore_Inc(t *testing.T) { + tests := []struct { + name string + key string + window time.Time + wantErr bool + }{ + { + name: "test_MapLimitStore_Inc", + key: "tt", + window: time.Now().UTC(), + wantErr: false, + }, + } + for _, tt := range tests { + m := NewMapLimitStore(1*time.Minute, 10*time.Second) + err := m.Inc(tt.key, tt.window) + assert.NoError(t, err) + prevVal, currVal, err := m.Get(tt.key, tt.window, tt.window) + assert.NoError(t, err) + assert.Equal(t, int64(1), prevVal) + assert.Equal(t, int64(1), currVal) + } +} + +func TestMapLimitStore_Get(t *testing.T) { + type args struct { + key string + previousWindow time.Time + currentWindow time.Time + } + tests := []struct { + name string + args args + wantPrevValue int64 + wantCurrValue int64 + wantErr bool + }{ + { + name: "test_MapLimitStore_Get", + args: args{ + key: "tt", + previousWindow: time.Now().UTC().Add(-1 * time.Minute), + currentWindow: time.Now().UTC(), + }, + wantPrevValue: 10, + wantCurrValue: 5, + wantErr: false, + }, + } + for _, tt := range tests { + m := NewMapLimitStore(1*time.Minute, 10*time.Second) + m.data[mapKey(tt.args.key, tt.args.previousWindow)] = limitValue{val: tt.wantPrevValue} + m.data[mapKey(tt.args.key, tt.args.currentWindow)] = limitValue{val: tt.wantCurrValue} + + prevVal, currVal, err := m.Get(tt.args.key, tt.args.previousWindow, tt.args.currentWindow) + assert.NoError(t, err) + assert.Equal(t, tt.wantPrevValue, prevVal) + assert.Equal(t, tt.wantCurrValue, currVal) + } +} diff --git a/ratelimiter.go b/ratelimiter.go new file mode 100644 index 0000000..be22aa7 --- /dev/null +++ b/ratelimiter.go @@ -0,0 +1,85 @@ +package ratelimiter + +import ( + "time" +) + +// RateLimiter is a simple rate-limiter for any resources inspired by Cloudflare's approach: https://blog.cloudflare.com/counting-things-a-lot-of-different-things/ +type RateLimiter struct { + dataStore LimitStore + requestsLimit int64 + windowSize time.Duration +} + +// New creates new rate limiter. A dataStore is internal limiter data store, requestsLimit and windowSize are parameters of limiter e.g. requestsLimit: 5 and windowSize: 1*time.Minute means that limiter allows up to 5 requests per minute +func New(dataStore LimitStore, requestsLimit int64, windowSize time.Duration) *RateLimiter { + + return &RateLimiter{ + dataStore: dataStore, + requestsLimit: requestsLimit, + windowSize: windowSize, + } +} + +// Inc increments limiter counter for a given key or returns error when it's not possible. Inc should be called when a request for a given limited resource has been fulfilled +func (r *RateLimiter) Inc(key string) error { + currentWindow := time.Now().UTC().Truncate(r.windowSize) + return r.dataStore.Inc(key, currentWindow) +} + +// LimitStatus represents current status of limitation for a given key +type LimitStatus struct { + // IsLimited is true when a given key should be rate-limited + IsLimited bool + // LimitDuration is not nil when IsLimited is true. It's the time for which a given key should be blocked before CurrentRate falls below declared in constructor requests limit + LimitDuration *time.Duration + // CurrentRate is approximated current requests rate per window size (declared in the constructor) + CurrentRate float64 +} + +// Check checks status of rate-limiting for a key. It returns error when limiter data could not be read +func (r *RateLimiter) Check(key string) (limitStatus *LimitStatus, err error) { + currentWindow := time.Now().UTC().Truncate(r.windowSize) + previousWindow := currentWindow.Add(-r.windowSize) + prevValue, currentValue, err := r.dataStore.Get(key, previousWindow, currentWindow) + if err != nil { + return nil, err + } + timeFromCurrWindow := time.Now().UTC().Sub(currentWindow) + + rate := float64((float64(r.windowSize)-float64(timeFromCurrWindow))/float64(r.windowSize))*float64(prevValue) + float64(currentValue) + limitStatus = &LimitStatus{} + if rate > float64(r.requestsLimit) { + limitStatus.IsLimited = true + limitDuration := r.calcLimitDuration(prevValue, currentValue, timeFromCurrWindow) + limitStatus.LimitDuration = &limitDuration + } + limitStatus.CurrentRate = rate + + return limitStatus, nil +} + +func (r *RateLimiter) calcRate(timeFromCurrWindow time.Duration, prevValue int64, currentValue int64) float64 { + return float64((float64(r.windowSize)-float64(timeFromCurrWindow))/float64(r.windowSize))*float64(prevValue) + float64(currentValue) +} + +func (r *RateLimiter) calcLimitDuration(prevValue, currValue int64, timeFromCurrWindow time.Duration) time.Duration { + // we should find x parameter in equation: x*prevValue+currentValue = r.requestsLimit + // then (1.0-x)*windowSize is duration from current window start when limit can be removed + // then ((1.0-x)*windowSize) - timeFromCurrWindow is duration since current time to the time when limit can be removed = limitDuration + // -- + // if prevValue is zero then unblock is in the next window so we should use equation x*currentValue+nextWindowValue = r.requestsLimit + // to calculate x parameter + var limitDuration time.Duration + if prevValue == 0 { + // unblock in the next window where prevValue is currValue and currValue is zero (assuming that since limit start all requests are blocked) + nextWindowUnblockPoint := float64(r.windowSize) * (1.0 - (float64(r.requestsLimit) / float64(currValue))) + timeToNextWindow := r.windowSize - timeFromCurrWindow + limitDuration = timeToNextWindow + time.Duration(int64(nextWindowUnblockPoint)+1) + } else { + currWindowUnblockPoint := float64(r.windowSize) * (1.0 - (float64(r.requestsLimit-currValue) / float64(prevValue))) + limitDuration = time.Duration(int64(currWindowUnblockPoint+1)) - timeFromCurrWindow + + } + return limitDuration +} diff --git a/ratelimiter_test.go b/ratelimiter_test.go new file mode 100644 index 0000000..6582081 --- /dev/null +++ b/ratelimiter_test.go @@ -0,0 +1,146 @@ +package ratelimiter + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRateLimiter_IsLimited(t *testing.T) { + + tests := []struct { + name string + requestsLimit int64 + windowSize time.Duration + incNumber int + wantLimitStatus *LimitStatus + }{ + { + name: "test_RateLimiter_IsLimited_not_limited", + requestsLimit: int64(5), + windowSize: 10 * time.Second, + incNumber: 4, + wantLimitStatus: &LimitStatus{ + IsLimited: false, + CurrentRate: 4, + }, + }, + { + name: "test_RateLimiter_IsLimited_limited", + requestsLimit: int64(5), + windowSize: 10 * time.Second, + incNumber: 6, + wantLimitStatus: &LimitStatus{ + IsLimited: true, + CurrentRate: 6, + }, + }, + } + for _, tt := range tests { + store := NewMapLimitStore(1*time.Hour, 1*time.Hour) + r := New(store, tt.requestsLimit, tt.windowSize) + for i := 0; i < tt.incNumber; i++ { + err := r.Inc("key") + assert.NoError(t, err, tt.name) + } + limitStatus, err := r.Check("key") + assert.NoError(t, err, tt.name) + assert.Equal(t, tt.wantLimitStatus.IsLimited, limitStatus.IsLimited, tt.name) + assert.Equal(t, tt.wantLimitStatus.CurrentRate, limitStatus.CurrentRate, tt.name) + + } +} + +func TestRateLimiter_calcLimitDuration(t *testing.T) { + tests := []struct { + name string + prevValue int64 + currValue int64 + timeFromCurrWindow time.Duration + requestsLimit int64 + windowSize time.Duration + want time.Duration + }{ + { + name: "TestRateLimiter_calcLimitDuration_prev_value_is_not_zero", + prevValue: 5, + currValue: 6, + timeFromCurrWindow: 1 * time.Second, + requestsLimit: 5, + windowSize: 10 * time.Second, + want: time.Duration(11 * time.Second), // 10*(1.0-( (5-6)/5)) - 1 + }, + { + name: "TestRateLimiter_calcLimitDuration_prev_value_is_zero", + prevValue: 0, + currValue: 6, + timeFromCurrWindow: 1 * time.Second, + requestsLimit: 5, + windowSize: 10 * time.Second, + want: time.Duration(10666666666 * time.Nanosecond), // 10*(1.0-(5/6)) + (10-1) + }, + } + for _, tt := range tests { + store := NewMapLimitStore(1*time.Hour, 1*time.Hour) + r := New(store, tt.requestsLimit, tt.windowSize) + dur := r.calcLimitDuration(tt.prevValue, tt.currValue, tt.timeFromCurrWindow) + assert.InDelta(t, tt.want, dur, 3) + } +} + +func TestRateLimiter_calcRate(t *testing.T) { + + tests := []struct { + name string + requestsLimit int64 + windowSize time.Duration + timeFromCurrWindow time.Duration + prevValue int64 + currentValue int64 + want float64 + }{ + { + name: "TestRateLimiter_calcRate_prev_not_zero", + requestsLimit: 5, + windowSize: 10 * time.Second, + timeFromCurrWindow: 1 * time.Second, + prevValue: 5, + currentValue: 6, + want: (0.9 * 5) + 6.0, + }, + { + name: "TestRateLimiter_calcRate_prev_zero", + requestsLimit: 5, + windowSize: 10 * time.Second, + timeFromCurrWindow: 1 * time.Second, + prevValue: 0, + currentValue: 6, + want: 6.0, + }, + { + name: "TestRateLimiter_calcRate_timeFromCurrWindow_zero", + requestsLimit: 5, + windowSize: 10 * time.Second, + timeFromCurrWindow: 0 * time.Second, + prevValue: 5, + currentValue: 0, + want: 5.0, + }, + { + name: "TestRateLimiter_calcRate_timeFromCurrWindow_max", + requestsLimit: 5, + windowSize: 10 * time.Second, + timeFromCurrWindow: 10 * time.Second, + prevValue: 5, + currentValue: 6, + want: 6.0, + }, + } + for _, tt := range tests { + store := NewMapLimitStore(1*time.Hour, 1*time.Hour) + r := New(store, tt.requestsLimit, tt.windowSize) + rate := r.calcRate(tt.timeFromCurrWindow, tt.prevValue, tt.currentValue) + assert.Equal(t, tt.want, rate) + } +}