Skip to content

Commit

Permalink
Fix rate limiting behind proxy, make configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Heckel committed Nov 5, 2021
1 parent 86a16e3 commit 0170f67
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 43 deletions.
19 changes: 17 additions & 2 deletions cmd/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ func New() *cli.App {
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: config.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "default interval of keepalive messages"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "default interval of for message pruning and stats printing"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: config.DefaultGlobalTopicLimit, Usage: "total number of topics allowed"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"V"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: config.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"B"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: config.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"R"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: config.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
}
return &cli.App{
Name: "ntfy",
Expand All @@ -50,6 +55,11 @@ func execRun(c *cli.Context) error {
cacheDuration := c.Duration("cache-duration")
keepaliveInterval := c.Duration("keepalive-interval")
managerInterval := c.Duration("manager-interval")
globalTopicLimit := c.Int("global-topic-limit")
visitorSubscriptionLimit := c.Int("visitor-subscription-limit")
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
behindProxy := c.Bool("behind-proxy")

// Check values
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
Expand All @@ -69,6 +79,11 @@ func execRun(c *cli.Context) error {
conf.CacheDuration = cacheDuration
conf.KeepaliveInterval = keepaliveInterval
conf.ManagerInterval = managerInterval
conf.GlobalTopicLimit = globalTopicLimit
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
conf.BehindProxy = behindProxy
s, err := server.New(conf)
if err != nil {
log.Fatalln(err)
Expand Down
57 changes: 29 additions & 28 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
package config

import (
"golang.org/x/time/rate"
"time"
)

Expand All @@ -15,42 +14,44 @@ const (
)

// Defines all the limits
// - request limit: max number of PUT/GET/.. requests (here: 50 requests bucket, replenished at a rate of one per 10 seconds)
// - global topic limit: max number of topics overall
// - subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
var (
defaultGlobalTopicLimit = 5000
defaultVisitorRequestLimit = rate.Every(10 * time.Second)
defaultVisitorRequestLimitBurst = 60
defaultVisitorSubscriptionLimit = 30
// - per visistor request limit: max number of PUT/GET/.. requests (here: 60 requests bucket, replenished at a rate of one per 10 seconds)
// - per visistor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
const (
DefaultGlobalTopicLimit = 5000
DefaultVisitorRequestLimitBurst = 60
DefaultVisitorRequestLimitReplenish = 10 * time.Second
DefaultVisitorSubscriptionLimit = 30
)

// Config is the main config struct for the application. Use New to instantiate a default config struct.
type Config struct {
ListenHTTP string
FirebaseKeyFile string
CacheFile string
CacheDuration time.Duration
KeepaliveInterval time.Duration
ManagerInterval time.Duration
GlobalTopicLimit int
VisitorRequestLimit rate.Limit
VisitorRequestLimitBurst int
VisitorSubscriptionLimit int
ListenHTTP string
FirebaseKeyFile string
CacheFile string
CacheDuration time.Duration
KeepaliveInterval time.Duration
ManagerInterval time.Duration
GlobalTopicLimit int
VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration
VisitorSubscriptionLimit int
BehindProxy bool
}

// New instantiates a default new config
func New(listenHTTP string) *Config {
return &Config{
ListenHTTP: listenHTTP,
FirebaseKeyFile: "",
CacheFile: "",
CacheDuration: DefaultCacheDuration,
KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval,
GlobalTopicLimit: defaultGlobalTopicLimit,
VisitorRequestLimit: defaultVisitorRequestLimit,
VisitorRequestLimitBurst: defaultVisitorRequestLimitBurst,
VisitorSubscriptionLimit: defaultVisitorSubscriptionLimit,
ListenHTTP: listenHTTP,
FirebaseKeyFile: "",
CacheFile: "",
CacheDuration: DefaultCacheDuration,
KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval,
GlobalTopicLimit: DefaultGlobalTopicLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
BehindProxy: false,
}
}
26 changes: 25 additions & 1 deletion config/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@
#
# keepalive-interval: 30s

# Interval in which the manager prunes old messages, deletes topics and prints the stats.
# Interval in which the manager prunes old messages, deletes topics
# and prints the stats.
#
# manager-interval: 1m

# Rate limiting: Total number of topics before the server rejects new topics.
#
# global-topic-limit: 5000

# Rate limiting: Number of subscriptions per visitor (IP address)
#
# visitor-subscription-limit: 30

# Rate limiting: Allowed GET/PUT/POST requests per second, per visitor:
# - visitor-request-limit-burst is the initial bucket of requests each visitor has
# - visitor-request-limit-replenish is the rate at which the bucket is refilled
#
# visitor-request-limit-burst: 60
# visitor-request-limit-replenish: 10s

# If set, the X-Forwarded-For header is used to determine the visitor IP address
# instead of the remote address of the connection.
#
# WARNING: If you are behind a proxy, you must set this, otherwise all visitors are rate limited
# as if they are one.
#
# behind-proxy: false
36 changes: 25 additions & 11 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,24 +159,22 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
v := s.visitor(r.RemoteAddr)
if err := v.RequestAllowed(); err != nil {
return err
}
if r.Method == http.MethodGet && r.URL.Path == "/" {
return s.handleHome(w, r)
} else if r.Method == http.MethodHead && r.URL.Path == "/" {
return s.handleEmpty(w, r)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
return s.handleStatic(w, r)
} else if r.Method == http.MethodOptions {
return s.handleOptions(w, r)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
return s.handlePublish(w, r, v)
return s.withRateLimit(w, r, s.handlePublish)
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
return s.handleSubscribeJSON(w, r, v)
return s.withRateLimit(w, r, s.handleSubscribeJSON)
} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
return s.handleSubscribeSSE(w, r, v)
return s.withRateLimit(w, r, s.handleSubscribeSSE)
} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
return s.handleSubscribeRaw(w, r, v)
} else if r.Method == http.MethodOptions {
return s.handleOptions(w, r)
return s.withRateLimit(w, r, s.handleSubscribeRaw)
}
return errHTTPNotFound
}
Expand All @@ -186,6 +184,10 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
return err
}

func (s *Server) handleEmpty(w http.ResponseWriter, r *http.Request) error {
return nil
}

func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r)
return nil
Expand Down Expand Up @@ -394,15 +396,27 @@ func (s *Server) updateStatsAndExpire() {
s.messages, len(s.topics), subscribers, messages, len(s.visitors))
}

func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
v := s.visitor(r)
if err := v.RequestAllowed(); err != nil {
return err
}
return handler(w, r, v)
}

// visitor creates or retrieves a rate.Limiter for the given visitor.
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(remoteAddr string) *visitor {
func (s *Server) visitor(r *http.Request) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
ip = remoteAddr // This should not happen in real life; only in tests.
}
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
ip = r.Header.Get("X-Forwarded-For")
}
v, exists := s.visitors[ip]
if !exists {
s.visitors[ip] = newVisitor(s.config)
Expand Down
2 changes: 1 addition & 1 deletion server/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type visitor struct {
func newVisitor(conf *config.Config) *visitor {
return &visitor{
config: conf,
limiter: rate.NewLimiter(conf.VisitorRequestLimit, conf.VisitorRequestLimitBurst),
limiter: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
seen: time.Now(),
}
Expand Down

0 comments on commit 0170f67

Please sign in to comment.