-
-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathratelimiter.go
196 lines (171 loc) · 5.29 KB
/
ratelimiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
package caddywaf
import (
"fmt"
"log"
"regexp"
"sync"
"time"
)
// requestCounter struct
type requestCounter struct {
count int
window time.Time
}
// RateLimit struct
type RateLimit struct {
Requests int `json:"requests"`
Window time.Duration `json:"window"`
CleanupInterval time.Duration `json:"cleanup_interval"`
Paths []string `json:"paths,omitempty"` // Optional paths to apply rate limit
PathRegexes []*regexp.Regexp `json:"-"` // Compiled regexes for the given paths
MatchAllPaths bool `json:"match_all_paths,omitempty"`
}
// RateLimiter struct
type RateLimiter struct {
sync.RWMutex
requests map[string]map[string]*requestCounter // Nested map for path-based rate limiting
config RateLimit
stopCleanup chan struct{} // Channel to signal cleanup goroutine to stop
totalRequests int64 // Total requests received by this rate limiter
blockedRequests int64 // Total requests blocked by this rate limiter
muMetrics sync.RWMutex // Mutex to protect metrics access
}
// NewRateLimiter creates a new RateLimiter instance.
func NewRateLimiter(config RateLimit) (*RateLimiter, error) {
// Compile path regexes if paths are provided
if len(config.Paths) > 0 {
config.PathRegexes = make([]*regexp.Regexp, len(config.Paths))
for i, path := range config.Paths {
var err error
config.PathRegexes[i], err = regexp.Compile(path)
if err != nil {
return nil, fmt.Errorf("failed to compile regex for path %s: %v", path, err)
}
}
}
return &RateLimiter{
requests: make(map[string]map[string]*requestCounter),
config: config,
stopCleanup: make(chan struct{}), // Initialize the stopCleanup channel
}, nil
}
// isRateLimited checks if a given IP is rate limited for a specific path.
func (rl *RateLimiter) isRateLimited(ip, path string) bool {
now := time.Now()
rl.Lock() // Use Lock for write operations or potential creation of nested maps.
defer rl.Unlock()
rl.incrementTotalRequestsMetric() // Increment the total requests received
var key string
if rl.config.MatchAllPaths {
key = ip
} else {
// Check if path is matching
if len(rl.config.PathRegexes) > 0 {
matched := false
for _, regex := range rl.config.PathRegexes {
if regex.MatchString(path) {
matched = true
break
}
}
if !matched {
return false // Path does not match any configured paths, no rate limiting
}
}
key = ip + path
}
// Initialize the nested map if it doesn't exist
if _, exists := rl.requests[ip]; !exists {
rl.requests[ip] = make(map[string]*requestCounter)
}
// Get or create the counter for the specific key (path + ip)
counter, exists := rl.requests[ip][key]
if exists {
if now.Sub(counter.window) > rl.config.Window {
// Window expired, reset the counter
rl.requests[ip][key] = &requestCounter{count: 1, window: now}
return false
}
// Window not expired, increment the counter
counter.count++
if counter.count > rl.config.Requests {
rl.incrementBlockedRequestsMetric() // Increment if the request is going to be blocked.
return true
}
return false
}
// IP and path combination doesn't exist, add it
rl.requests[ip][key] = &requestCounter{count: 1, window: now}
return false
}
// cleanupExpiredEntries removes expired entries from the rate limiter.
func (rl *RateLimiter) cleanupExpiredEntries() {
now := time.Now()
rl.Lock()
defer rl.Unlock()
for ip, pathCounters := range rl.requests {
for path, counter := range pathCounters {
if now.Sub(counter.window) > rl.config.Window {
delete(pathCounters, path)
}
}
if len(pathCounters) == 0 {
delete(rl.requests, ip)
}
}
}
// startCleanup starts the goroutine to periodically clean up expired entries.
func (rl *RateLimiter) startCleanup() {
go func() {
// log.Println("[INFO] Starting rate limiter cleanup goroutine")
ticker := time.NewTicker(rl.config.CleanupInterval)
defer func() {
ticker.Stop()
// log.Println("[INFO] Rate limiter cleanup goroutine stopped")
}()
for {
select {
case <-ticker.C:
rl.cleanupExpiredEntries()
case <-rl.stopCleanup:
return
}
}
}()
}
// signalStopCleanup signals the cleanup goroutine to stop.
func (rl *RateLimiter) signalStopCleanup() {
rl.Lock()
defer rl.Unlock()
select {
case <-rl.stopCleanup:
// Channel already closed, do nothing
default:
log.Println("[INFO] Signaling rate limiter cleanup goroutine to stop")
close(rl.stopCleanup)
}
}
// GetTotalRequests returns the total number of requests received by this rate limiter.
func (rl *RateLimiter) GetTotalRequests() int64 {
rl.muMetrics.RLock()
defer rl.muMetrics.RUnlock()
return rl.totalRequests
}
// GetBlockedRequests returns the total number of requests blocked by this rate limiter.
func (rl *RateLimiter) GetBlockedRequests() int64 {
rl.muMetrics.RLock()
defer rl.muMetrics.RUnlock()
return rl.blockedRequests
}
// incrementTotalRequestsMetric increments the total requests counter
func (rl *RateLimiter) incrementTotalRequestsMetric() {
rl.muMetrics.Lock()
defer rl.muMetrics.Unlock()
rl.totalRequests++
}
// incrementBlockedRequestsMetric increments the blocked requests counter
func (rl *RateLimiter) incrementBlockedRequestsMetric() {
rl.muMetrics.Lock()
defer rl.muMetrics.Unlock()
rl.blockedRequests++
}