diff --git a/auto_ratelimit.go b/auto_ratelimit.go new file mode 100644 index 0000000..e781d1a --- /dev/null +++ b/auto_ratelimit.go @@ -0,0 +1,263 @@ +package ratelimit + +import ( + "context" + "errors" + "sync" + "time" +) + +var ( + ErrAutoKeyAlreadyExists = errors.New("key already exists") + ErrAutoKeyMissing = errors.New("key does not exist") +) + +// AutoLimiterOption is a function that configures the AutoLimiter +type AutoLimiterOption func(*AutoLimiter) + +// WithUnlimited sets the limiter to unlimited mode +func WithUnlimited() AutoLimiterOption { + return func(e *AutoLimiter) { + e.defaultOptions.IsUnlimited = true + } +} + +// WithDuration sets the duration for the rate limiter +func WithDuration(duration time.Duration) AutoLimiterOption { + return func(e *AutoLimiter) { + e.defaultOptions.Duration = duration + } +} + +// WithMaxCount sets the maximum count for the rate limiter +func WithMaxCount(maxCount uint) AutoLimiterOption { + return func(e *AutoLimiter) { + e.defaultOptions.MaxCount = maxCount + } +} + +// AutoLimiter is an improved version of MultiLimiter with better memory management +type AutoLimiter struct { + limiters sync.Map // map of active limiters + options sync.Map // map of custom options (only for keys with custom settings) + ctx context.Context + + // Default options for automatically created limiters + defaultOptions *internalOptions +} + +// NewAutoLimiter creates a new auto limiter instance using functional options +func NewAutoLimiter(ctx context.Context, opts ...AutoLimiterOption) *AutoLimiter { + e := &AutoLimiter{ + ctx: ctx, + limiters: sync.Map{}, + options: sync.Map{}, + defaultOptions: &internalOptions{}, + } + + // Apply all options to set the defaults + for _, opt := range opts { + opt(e) + } + + return e +} + +// internalOptions holds configuration internally (not exposed to users) +type internalOptions struct { + Key string + IsUnlimited bool + MaxCount uint + Duration time.Duration +} + +// Validate internal options +func (o *internalOptions) Validate() error { + if !o.IsUnlimited { + if o.Key == "" { + return errors.New("empty keys not allowed") + } + if o.MaxCount == 0 { + return errors.New("maxcount cannot be zero") + } + if o.Duration == 0 { + return errors.New("time duration not set") + } + } + return nil +} + +// Add creates a new rate limiter with custom settings (only for keys that need specific limits) +func (e *AutoLimiter) Add(key string, opts ...AutoLimiterOption) error { + // Create options struct and apply functional options to it + options := &internalOptions{Key: key} + + // Apply all options to the options struct + for _, opt := range opts { + // Create a temporary limiter to apply the option + tempLimiter := &AutoLimiter{defaultOptions: options} + opt(tempLimiter) + } + + // Validate the configuration + if !options.IsUnlimited { + if key == "" { + return errors.New("empty keys not allowed") + } + if options.MaxCount == 0 { + return errors.New("maxcount cannot be zero") + } + if options.Duration == 0 { + return errors.New("time duration not set") + } + } + + // Check if key already exists + if _, exists := e.limiters.Load(key); exists { + return ErrAutoKeyAlreadyExists + } + + // Create new limiter with custom settings + var rlimiter *Limiter + if options.IsUnlimited { + rlimiter = NewUnlimited(e.ctx) + } else { + rlimiter = New(e.ctx, options.MaxCount, options.Duration) + } + + // Store the limiter and options + e.limiters.Store(key, rlimiter) + e.options.Store(key, options) + + return nil +} + +// Take one token from bucket - creates limiter automatically if it doesn't exist +func (e *AutoLimiter) Take(key string) error { + limiter, err := e.get(key) + if err != nil { + // Key doesn't exist, create it with default settings + limiter = e.createOrDefault(key) + } + limiter.Take() + return nil +} + +// Stop internal limiters with defined keys or all if no key is provided +func (e *AutoLimiter) Stop(keys ...string) { + if len(keys) == 0 { + e.limiters.Range(func(key, value any) bool { + if limiter, ok := value.(*Limiter); ok { + limiter.Stop() + e.limiters.Delete(key) + // Keep the options for potential recreation + } + return true + }) + return + } + for _, v := range keys { + if limiter, err := e.get(v); err == nil { + limiter.Stop() + e.limiters.Delete(v) + // Keep the options for potential recreation + } + } +} + +// Remove completely removes a key and its options +func (e *AutoLimiter) Remove(key string) { + // Stop and remove the limiter if it exists + if limiter, err := e.get(key); err == nil { + limiter.Stop() + e.limiters.Delete(key) + } + // Remove the stored options + e.options.Delete(key) +} + +// get returns *Limiter instance +func (e *AutoLimiter) get(key string) (*Limiter, error) { + val, _ := e.limiters.Load(key) + if val == nil { + return nil, ErrAutoKeyMissing + } + if limiter, ok := val.(*Limiter); ok { + return limiter, nil + } + return nil, errors.New("type assertion of rateLimiter failed in autoLimiter") +} + +// recreateLimiter recreates a limiter from stored options +func (e *AutoLimiter) recreateLimiter(key string) (*Limiter, error) { + // Check if we have stored options for this key + optsVal, exists := e.options.Load(key) + if !exists { + return nil, ErrAutoKeyMissing + } + + opts, ok := optsVal.(*internalOptions) + if !ok { + return nil, errors.New("invalid options type") + } + + // Create new limiter with stored options + var rlimiter *Limiter + if opts.IsUnlimited { + rlimiter = NewUnlimited(e.ctx) + } else { + rlimiter = New(e.ctx, opts.MaxCount, opts.Duration) + } + + // Store the new limiter + e.limiters.Store(key, rlimiter) + + return rlimiter, nil +} + +// createOrDefault creates a new limiter with default settings +func (e *AutoLimiter) createOrDefault(key string) *Limiter { + // First check if we have stored custom options for this key + if optsVal, exists := e.options.Load(key); exists { + if _, ok := optsVal.(*internalOptions); ok { + // Recreate with stored custom options + if limiter, err := e.recreateLimiter(key); err == nil { + return limiter + } + // If recreation fails, fall back to defaults + } + } + + // No custom options, create with stored default options + var limiter *Limiter + if e.defaultOptions.IsUnlimited { + limiter = NewUnlimited(e.ctx) + } else { + limiter = New(e.ctx, e.defaultOptions.MaxCount, e.defaultOptions.Duration) + } + + // Store the limiter + e.limiters.Store(key, limiter) + + // Note: We don't store options for default limiters since they can be recreated + // Only custom limiters (added via Add()) get their options stored + + return limiter +} + +// AddAndTake adds a key with custom settings if not present and then takes a token +func (e *AutoLimiter) AddAndTake(key string, opts ...AutoLimiterOption) error { + // Check if limiter already exists + if limiter, err := e.get(key); err == nil { + limiter.Take() + return nil + } + + // Add the key with custom settings + if err := e.Add(key, opts...); err != nil { + return err + } + + // Take a token + return e.Take(key) +} diff --git a/auto_ratelimit_test.go b/auto_ratelimit_test.go new file mode 100644 index 0000000..98527bd --- /dev/null +++ b/auto_ratelimit_test.go @@ -0,0 +1,249 @@ +package ratelimit + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestAutoLimiterNew(t *testing.T) { + ctx := context.Background() + + // Test with no options + limiter := NewAutoLimiter(ctx) + require.NotNil(t, limiter) + require.NotNil(t, limiter.defaultOptions) + + // Test with unlimited option + limiter = NewAutoLimiter(ctx, WithUnlimited()) + require.NotNil(t, limiter) + require.True(t, limiter.defaultOptions.IsUnlimited) + + // Test with duration and max count + limiter = NewAutoLimiter(ctx, WithDuration(2*time.Second), WithMaxCount(50)) + require.NotNil(t, limiter) + require.Equal(t, 2*time.Second, limiter.defaultOptions.Duration) + require.Equal(t, uint(50), limiter.defaultOptions.MaxCount) + require.False(t, limiter.defaultOptions.IsUnlimited) +} + +func TestAutoLimiterAdd(t *testing.T) { + ctx := context.Background() + limiter := NewAutoLimiter(ctx, WithDuration(time.Second), WithMaxCount(10)) + + // Test adding a key with custom options + err := limiter.Add("key1", WithDuration(500*time.Millisecond), WithMaxCount(5)) + require.NoError(t, err) + + // Test adding the same key again (should fail) + err = limiter.Add("key1", WithDuration(time.Second), WithMaxCount(20)) + require.Error(t, err) + require.Equal(t, ErrAutoKeyAlreadyExists, err) + + // Test adding with empty key + err = limiter.Add("", WithDuration(time.Second), WithMaxCount(10)) + require.Error(t, err) + require.Contains(t, err.Error(), "empty keys not allowed") + + // Test adding with zero max count + err = limiter.Add("key2", WithDuration(time.Second), WithMaxCount(0)) + require.Error(t, err) + require.Contains(t, err.Error(), "maxcount cannot be zero") + + // Test adding with zero duration + err = limiter.Add("key3", WithDuration(0), WithMaxCount(10)) + require.Error(t, err) + require.Contains(t, err.Error(), "time duration not set") + + // Test adding unlimited key + err = limiter.Add("unlimited", WithUnlimited()) + require.NoError(t, err) +} + +func TestAutoLimiterTake(t *testing.T) { + ctx := context.Background() + limiter := NewAutoLimiter(ctx, WithDuration(time.Second), WithMaxCount(5)) + + // Test taking from non-existent key (should create with defaults) + err := limiter.Take("key1") + require.NoError(t, err) + + // Test taking multiple times + for i := 0; i < 4; i++ { + err = limiter.Take("key1") + require.NoError(t, err) + } + + // Test taking when limit is reached - this should block, so we test with a timeout + done := make(chan bool) + go func() { + _ = limiter.Take("key1") // This should block + done <- true + }() + + // Wait a short time to see if it blocks + select { + case <-done: + t.Fatal("Take() should block when limit is reached") + case <-time.After(100 * time.Millisecond): + // Expected behavior - it's blocking + } + + // Test taking from unlimited key + _ = limiter.Add("unlimited", WithUnlimited()) + for i := 0; i < 100; i++ { + err = limiter.Take("unlimited") + require.NoError(t, err) + } +} + +func TestAutoLimiterStop(t *testing.T) { + ctx := context.Background() + limiter := NewAutoLimiter(ctx, WithDuration(time.Second), WithMaxCount(10)) + + // Add some keys + _ = limiter.Add("key1", WithDuration(500*time.Millisecond), WithMaxCount(5)) + _ = limiter.Add("key2", WithDuration(time.Second), WithMaxCount(10)) + + // Test stopping specific keys + limiter.Stop("key1") + + // key1 should be stopped but options preserved + _, err := limiter.get("key1") + require.Error(t, err) + + // key2 should still exist + _, err = limiter.get("key2") + require.NoError(t, err) + + // Test stopping all keys + limiter.Stop() + + // All limiters should be stopped + _, err = limiter.get("key2") + require.Error(t, err) +} + +func TestAutoLimiterRemove(t *testing.T) { + ctx := context.Background() + limiter := NewAutoLimiter(ctx, WithDuration(time.Second), WithMaxCount(10)) + + // Add a key with custom options + _ = limiter.Add("key1", WithDuration(500*time.Millisecond), WithMaxCount(5)) + + // Remove the key completely + limiter.Remove("key1") + + // Both limiter and options should be gone + _, err := limiter.get("key1") + require.Error(t, err) + + // Options should also be removed + _, exists := limiter.options.Load("key1") + require.False(t, exists) +} + +func TestAutoLimiterRecreateLimiter(t *testing.T) { + ctx := context.Background() + limiter := NewAutoLimiter(ctx, WithDuration(time.Second), WithMaxCount(10)) + + // Add a key with custom options + _ = limiter.Add("key1", WithDuration(500*time.Millisecond), WithMaxCount(5)) + + // Stop the limiter + limiter.Stop("key1") + + // Recreate should work with stored options + recreated, err := limiter.recreateLimiter("key1") + require.NoError(t, err) + require.NotNil(t, recreated) + + // Should be able to use the recreated limiter + require.True(t, recreated.CanTake()) +} + +func TestAutoLimiterCreateOrDefault(t *testing.T) { + ctx := context.Background() + limiter := NewAutoLimiter(ctx, WithDuration(time.Second), WithMaxCount(10)) + + // Test creating with defaults for new key + newLimiter := limiter.createOrDefault("newkey") + require.NotNil(t, newLimiter) + require.True(t, newLimiter.CanTake()) + + // Test creating with custom options for existing key + _ = limiter.Add("custom", WithDuration(500*time.Millisecond), WithMaxCount(5)) + limiter.Stop("custom") + + // Should recreate with custom options, not defaults + recreated := limiter.createOrDefault("custom") + require.NotNil(t, recreated) + + // Should be able to take 5 tokens (custom limit), not 10 (default) + for i := 0; i < 5; i++ { + require.True(t, recreated.CanTake()) + recreated.Take() + } + require.False(t, recreated.CanTake()) +} + +func TestAutoLimiterAddAndTake(t *testing.T) { + ctx := context.Background() + limiter := NewAutoLimiter(ctx, WithDuration(100*time.Millisecond), WithMaxCount(10)) + + // Test AddAndTake on new key + err := limiter.AddAndTake("key1", WithDuration(100*time.Millisecond), WithMaxCount(3)) + require.NoError(t, err) + + // Test AddAndTake on existing key (should just take a token, not change the limit) + err = limiter.AddAndTake("key1", WithDuration(time.Second), WithMaxCount(10)) + require.NoError(t, err) +} + +func TestAutoLimiterMemoryEfficiency(t *testing.T) { + ctx := context.Background() + limiter := NewAutoLimiter(ctx, WithDuration(time.Second), WithMaxCount(10)) + + // Add many keys + for i := 0; i < 100; i++ { + key := fmt.Sprintf("key%d", i) + _ = limiter.Add(key, WithDuration(time.Second), WithMaxCount(5)) + } + + // Stop some keys + for i := 0; i < 50; i++ { + key := fmt.Sprintf("key%d", i) + limiter.Stop(key) + } + + // Remove some keys completely + for i := 50; i < 75; i++ { + key := fmt.Sprintf("key%d", i) + limiter.Remove(key) + } + + // Check that stopped keys still have options stored + for i := 0; i < 50; i++ { + key := fmt.Sprintf("key%d", i) + _, exists := limiter.options.Load(key) + require.True(t, exists) + } + + // Check that removed keys have no options + for i := 50; i < 75; i++ { + key := fmt.Sprintf("key%d", i) + _, exists := limiter.options.Load(key) + require.False(t, exists) + } + + // Check that active keys still work + for i := 75; i < 100; i++ { + key := fmt.Sprintf("key%d", i) + // Just test that we can take a token (this will create the limiter if it doesn't exist) + err := limiter.Take(key) + require.NoError(t, err) + } +}