Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request caching + add ratelimit-over-408 + endpoint rewrite support #13

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
nirn-proxy
54 changes: 54 additions & 0 deletions lib/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package lib

import (
"net/http"
"time"
)

type CacheEntry struct {
Data []byte
CreatedAt time.Time
ExpiresIn time.Duration
Headers http.Header
}

func (c *CacheEntry) Expired() bool {
return time.Since(c.CreatedAt) > c.ExpiresIn
}

type Cache struct {
entries map[string]*CacheEntry
}

func NewCache() *Cache {
return &Cache{
entries: make(map[string]*CacheEntry),
}
}

func (c *Cache) Get(key string) *CacheEntry {
entry, ok := c.entries[key]

if !ok {
return nil
}

if entry.Expired() {
c.Delete(key)
return nil
}

return entry
}

func (c *Cache) Set(key string, entry *CacheEntry) {
c.entries[key] = entry
}

func (c *Cache) Delete(key string) {
delete(c.entries, key)
}

func (c *Cache) Clear() {
c.entries = make(map[string]*CacheEntry)
}
244 changes: 232 additions & 12 deletions lib/discord.go
cheesycod marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
package lib

import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"github.com/sirupsen/logrus"
"io"
"io/ioutil"
"math"
"net"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"

"github.com/sirupsen/logrus"
)

var client *http.Client
Expand All @@ -22,8 +25,98 @@ var contextTimeout time.Duration

var globalOverrideMap = make(map[string]uint)

var endpointCache = make(map[string]*Cache)

var disableRestLimitDetection = false

// List of endpoints to cache and their expiry times
var useEndpointCache bool
var cacheEndpoints = map[string]time.Duration{
"/api/users/@me": 10 * time.Minute,
"/api/v*/users/@me": 10 * time.Minute,
"/api/gateway": 60 * time.Minute,
"/api/v*/gateway": 60 * time.Minute,
"/api/gateway/*": 30 * time.Minute,
"/api/v*/gateway/*": 30 * time.Minute,
"/api/v*/applications/@me": 5 * time.Minute,
}

// In some cases, we may want to transparently rewrite endpoints
//
// For example, when using a gateway proxy, the proxy may provide its own /api/gateway/bot endpoint
//
// This allows transparently rewriting the endpoint to the proxy's
var endpointRewrite = map[string]string{}

var wsProxy string
var ratelimitOver408 bool

func init() {
if len(os.Args) > 1 {
for _, arg := range os.Args[1:] {
argSplit := strings.SplitN(arg, "=", 2)

if len(argSplit) < 2 {
argSplit = append(argSplit, "")
}

switch argSplit[0] {
case "ws-proxy":
wsProxy = argSplit[1]
case "port":
os.Setenv("PORT", argSplit[1])
case "ratelimit-over-408":
ratelimitOver408 = true
case "use-endpoint-cache":
useEndpointCache = true
case "cache-endpoints":
if argSplit[1] == "" {
continue
}

if argSplit[1] == "false" {
cacheEndpoints = make(map[string]time.Duration)
} else {
var endpoints map[string]time.Duration

err := json.Unmarshal([]byte(argSplit[1]), &endpoints)

if err != nil {
logrus.Fatal("Failed to parse cache-endpoints: ", err)
}

cacheEndpoints = endpoints
}
case "endpoint-rewrite":
for _, rewrite := range strings.Split(argSplit[1], ",") {
// split by '->'
rewriteSplit := strings.Split(rewrite, "@")

if len(rewriteSplit) != 2 {
logrus.Fatal("Invalid endpoint rewrite: ", rewrite)
}

endpointRewrite[rewriteSplit[0]] = rewriteSplit[1]
}
default:
logrus.Fatal("Unknown argument: ", argSplit[0])
}
}
}

if wsProxy == "" {
wsProxy = EnvGet("WS_PROXY", "")
}

if !ratelimitOver408 {
ratelimitOver408 = EnvGetBool("RATELIMIT_OVER_408", false)
}

if !useEndpointCache {
useEndpointCache = EnvGetBool("USE_ENDPOINT_CACHE", false)
}
}

type BotGatewayResponse struct {
SessionStartLimit map[string]int `json:"session_start_limit"`
}
Expand Down Expand Up @@ -161,7 +254,7 @@ func GetBotGlobalLimit(token string, user *BotUserResponse) (uint, error) {
return 0, errors.New("500 on gateway/bot")
}

body, _ := ioutil.ReadAll(bot.Body)
body, _ := io.ReadAll(bot.Body)

var s BotGatewayResponse

Expand Down Expand Up @@ -200,7 +293,7 @@ func GetBotUser(token string) (*BotUserResponse, error) {
return nil, errors.New("500 on users/@me")
}

body, _ := ioutil.ReadAll(bot.Body)
body, _ := io.ReadAll(bot.Body)

var s BotUserResponse

Expand All @@ -213,7 +306,63 @@ func GetBotUser(token string) (*BotUserResponse, error) {
}

func doDiscordReq(ctx context.Context, path string, method string, body io.ReadCloser, header http.Header, query string) (*http.Response, error) {
discordReq, err := http.NewRequestWithContext(ctx, method, "https://discord.com"+path+"?"+query, body)
identifier := ctx.Value("identifier")
if identifier == nil {
identifier = "Internal"
}

logger.Info(method, " ", path+"?"+query)

identifierStr, ok := identifier.(string)

if ok {
cheesycod marked this conversation as resolved.
Show resolved Hide resolved
if useEndpointCache && identifier != "internal" {
cache, ok := endpointCache[identifierStr]

if !ok {
endpointCache[identifierStr] = NewCache()
cache = endpointCache[identifierStr]
}

// Check endpoint cache
cacheEntry := cache.Get(path)

if cacheEntry != nil {
// Send cached response
logger.WithFields(logrus.Fields{
"method": method,
"path": path,
"status": "200 (cached)",
}).Debug("Discord request")

headers := cacheEntry.Headers.Clone()
headers.Set("X-Cached", "true")

// Set rl headers so bot won't be perpetually stuck
headers.Set("X-RateLimit-Limit", "5")
headers.Set("X-RateLimit-Remaining", "5")
headers.Set("X-RateLimit-Bucket", "cache")

return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewBuffer(cacheEntry.Data)),
Header: headers,
}, nil
}
}
}

// Check for a rewrite
var urlBase = "https://discord.com"
for rw := range endpointRewrite {
if ok, _ := filepath.Match(rw, path); ok {
urlBase = endpointRewrite[rw]
break

}
}

discordReq, err := http.NewRequestWithContext(ctx, method, urlBase+path+"?"+query, body)
if err != nil {
return nil, err
}
Expand All @@ -222,12 +371,6 @@ func doDiscordReq(ctx context.Context, path string, method string, body io.ReadC
startTime := time.Now()
discordResp, err := client.Do(discordReq)

identifier := ctx.Value("identifier")
if identifier == nil {
// Queues always have an identifier, if there's none in the context, we called the method from outside a queue
identifier = "Internal"
}

if err == nil {
route := GetMetricsPath(path)
status := discordResp.Status
Expand All @@ -242,6 +385,59 @@ func doDiscordReq(ctx context.Context, path string, method string, body io.ReadC

RequestHistogram.With(map[string]string{"route": route, "status": status, "method": method, "clientId": identifier.(string)}).Observe(elapsed)
}

if wsProxy != "" && discordResp.StatusCode == 200 {
var isGwProxyUrl bool

if strings.HasSuffix(path, "/gateway") || strings.HasSuffix(path, "/gateway/bot") {
isGwProxyUrl = true
}

if isGwProxyUrl {
var data map[string]any

err := json.NewDecoder(discordResp.Body).Decode(&data)

if err != nil {
return nil, err
}

data["url"] = wsProxy

bytes, err := json.Marshal(data)

if err != nil {
return nil, err
}

discordResp.Body = io.NopCloser(strings.NewReader(string(bytes)))
}
}

if useEndpointCache {
var expiry *time.Duration

for endpoint, exp := range cacheEndpoints {
if ok, _ := filepath.Match(endpoint, path); ok {
expiry = &exp
break
}
}

if expiry != nil && discordResp.StatusCode == 200 {
body, _ := io.ReadAll(discordResp.Body)
endpointCache[identifierStr].Set(path, &CacheEntry{
Data: body,
CreatedAt: time.Now(),
ExpiresIn: *expiry,
Headers: discordResp.Header,
})

// Put body back into response
discordResp.Body = io.NopCloser(bytes.NewBuffer(body))
}
}

return discordResp, err
}

Expand All @@ -255,7 +451,30 @@ func ProcessRequest(ctx context.Context, item *QueueItem) (*http.Response, error

if err != nil {
if ctx.Err() == context.DeadlineExceeded {
res.WriteHeader(408)
if ratelimitOver408 {
res.WriteHeader(429)
res.Header().Add("Reset-After", "3")

// Set rl headers so bot won't be perpetually stuck
if res.Header().Get("X-RateLimit-Limit") == "" {
res.Header().Set("X-RateLimit-Limit", "5")
}
if res.Header().Get("X-RateLimit-Remaining") == "" {
res.Header().Set("X-RateLimit-Remaining", "0")
}

if res.Header().Get("X-RateLimit-Bucket") == "" {
res.Header().Set("X-RateLimit-Bucket", "proxyTimeout")
}

// Default to 'shared' so the bot doesn't think its
// against them
if res.Header().Get("X-RateLimit-Scope") == "" {
res.Header().Set("X-RateLimit-Scope", "shared")
}
} else {
res.WriteHeader(408)
}
} else {
res.WriteHeader(500)
}
Expand All @@ -279,3 +498,4 @@ func ProcessRequest(ctx context.Context, item *QueueItem) (*http.Response, error

return discordResp, nil
}

7 changes: 4 additions & 3 deletions lib/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ package lib
import (
"context"
"errors"
"github.com/Clever/leakybucket"
"github.com/Clever/leakybucket/memory"
"github.com/sirupsen/logrus"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/Clever/leakybucket"
"github.com/Clever/leakybucket/memory"
"github.com/sirupsen/logrus"
)

type QueueItem struct {
Expand Down