Skip to content

Commit 05399b6

Browse files
committed
feat: add RateLimiter utility class
1 parent d941755 commit 05399b6

File tree

3 files changed

+102
-0
lines changed

3 files changed

+102
-0
lines changed

src/gradient/_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
validate_client_credentials as validate_client_credentials,
3434
validate_client_instance as validate_client_instance,
3535
ResponseCache as ResponseCache,
36+
RateLimiter as RateLimiter,
3637
)
3738
from ._compat import (
3839
get_args as get_args,

src/gradient/_utils/_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,51 @@ def size(self) -> int:
499499
return len(self._cache)
500500

501501

502+
# Rate Limiting Classes
503+
class RateLimiter:
504+
"""Simple token bucket rate limiter."""
505+
506+
def __init__(self, requests_per_minute: int = 60) -> None:
507+
"""Initialize rate limiter.
508+
509+
Args:
510+
requests_per_minute: Maximum requests allowed per minute
511+
"""
512+
self.requests_per_minute: int = requests_per_minute
513+
self.tokens: float = float(requests_per_minute)
514+
self.last_refill: float = self._now()
515+
self.refill_rate: float = requests_per_minute / 60.0 # tokens per second
516+
517+
def _now(self) -> float:
518+
"""Get current time in seconds."""
519+
import time
520+
return time.time()
521+
522+
def _refill(self) -> None:
523+
"""Refill tokens based on elapsed time."""
524+
now = self._now()
525+
elapsed = now - self.last_refill
526+
self.tokens = min(self.requests_per_minute, self.tokens + elapsed * self.refill_rate)
527+
self.last_refill = now
528+
529+
def acquire(self, tokens: int = 1) -> bool:
530+
"""Try to acquire tokens. Returns True if successful."""
531+
self._refill()
532+
if self.tokens >= tokens:
533+
self.tokens -= tokens
534+
return True
535+
return False
536+
537+
def wait_time(self, tokens: int = 1) -> float:
538+
"""Get seconds to wait for tokens to be available."""
539+
self._refill()
540+
if self.tokens >= tokens:
541+
return 0.0
542+
543+
needed = tokens - self.tokens
544+
return needed / self.refill_rate
545+
546+
502547
# API Key Validation Functions
503548
def validate_api_key(api_key: str | None) -> bool:
504549
"""Validate an API key format.

tests/test_rate_limiter.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Tests for rate limiting functionality."""
2+
3+
import time
4+
import pytest
5+
from gradient._utils import RateLimiter
6+
7+
8+
class TestRateLimiter:
9+
"""Test rate limiting functionality."""
10+
11+
def test_rate_limiter_basic(self):
12+
"""Test basic rate limiter operations."""
13+
limiter = RateLimiter(requests_per_minute=10)
14+
15+
# Should allow initial requests
16+
assert limiter.acquire() is True
17+
assert limiter.acquire() is True
18+
19+
# Should deny when tokens exhausted
20+
limiter.tokens = 0 # Force exhaustion
21+
assert limiter.acquire() is False
22+
23+
def test_rate_limiter_wait_time(self):
24+
"""Test wait time calculation."""
25+
limiter = RateLimiter(requests_per_minute=60) # 1 request per second
26+
27+
# Exhaust tokens
28+
limiter.tokens = 0
29+
30+
# Should calculate correct wait time
31+
wait_time = limiter.wait_time()
32+
assert wait_time > 0
33+
assert wait_time <= 1.0 # Should not exceed 1 second
34+
35+
def test_rate_limiter_refill(self):
36+
"""Test token refill over time."""
37+
limiter = RateLimiter(requests_per_minute=60) # 1 token per second
38+
39+
# Exhaust tokens
40+
limiter.tokens = 0
41+
start_time = limiter._now()
42+
43+
# Wait for refill
44+
time.sleep(0.1)
45+
46+
# Should have refilled some tokens
47+
limiter._refill()
48+
assert limiter.tokens > 0
49+
50+
def test_rate_limiter_custom_rate(self):
51+
"""Test custom rate limits."""
52+
limiter = RateLimiter(requests_per_minute=120) # 2 requests per second
53+
54+
# Should have double the tokens of default
55+
assert limiter.requests_per_minute == 120
56+
assert limiter.refill_rate == 2.0

0 commit comments

Comments
 (0)