Skip to content

Commit

Permalink
Fix and abstract away HTTP RateLimiter
Browse files Browse the repository at this point in the history
  • Loading branch information
hhvrc committed Nov 15, 2024
1 parent 5c0eb09 commit 554f36f
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 105 deletions.
38 changes: 38 additions & 0 deletions include/RateLimiter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

#include "Common.h"
#include "SimpleMutex.h"

#include <cstdint>
#include <vector>

namespace OpenShock {
class RateLimiter {
DISABLE_COPY(RateLimiter);
DISABLE_MOVE(RateLimiter);

public:
RateLimiter();
~RateLimiter();

void addLimit(uint32_t durationMs, uint16_t count);
void clearLimits();

bool tryRequest();
void clearRequests();

void blockFor(int64_t blockForMs);

private:
struct Limit {
int64_t durationMs;
uint16_t count;
};

OpenShock::SimpleMutex m_mutex;
int64_t m_nextSlot;
int64_t m_nextCleanup;
std::vector<Limit> m_limits;
std::vector<int64_t> m_requests;
};
} // namespace OpenShock
121 changes: 121 additions & 0 deletions src/RateLimiter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#include <freertos/FreeRTOS.h>

#include "RateLimiter.h"

#include "Time.h"

#include <algorithm>

const char* const TAG = "RateLimiter";

OpenShock::RateLimiter::RateLimiter()
: m_mutex()
, m_nextSlot(0)
, m_nextCleanup(0)
, m_limits()
, m_requests()
{
}

OpenShock::RateLimiter::~RateLimiter()
{
}

void OpenShock::RateLimiter::addLimit(uint32_t durationMs, uint16_t count)
{
m_mutex.lock(portMAX_DELAY);

// Insert sorted
m_limits.insert(std::upper_bound(m_limits.begin(), m_limits.end(), durationMs, [](int64_t durationMs, const Limit& limit) { return durationMs < limit.durationMs; }), {durationMs, count});

m_nextSlot = 0;
m_nextCleanup = 0;

m_mutex.unlock();
}

void OpenShock::RateLimiter::clearLimits()
{
m_mutex.lock(portMAX_DELAY);

m_limits.clear();

m_mutex.unlock();
}

bool OpenShock::RateLimiter::tryRequest()
{
int64_t now = OpenShock::millis();

OpenShock::ScopedLock lock__(&m_mutex);

if (m_limits.empty()) {
return true;
}
if (m_requests.empty()) {
m_requests.push_back(now);
return true;
}

if (m_nextCleanup <= now) {
int64_t longestLimit = m_limits.back().durationMs;
int64_t expiresAt = now - longestLimit;

auto nextToExpire = std::find_if(m_requests.begin(), m_requests.end(), [expiresAt](int64_t requestedAtMs) { return requestedAtMs > expiresAt; });
if (nextToExpire != m_requests.end()) {
m_requests.erase(m_requests.begin(), nextToExpire);
}

m_nextCleanup = m_requests.front() + longestLimit;
}

if (m_nextSlot > now) {
return false;
}

// Check if we've exceeded any limits, starting with the highest limit first
for (std::size_t i = m_limits.size(); i > 0;) {
const auto& limit = m_limits[--i];

// Calculate the window start time
int64_t windowStart = now - limit.durationMs;

// Check how many requests are inside the limit window
std::size_t insideWindow = 0;
for (int64_t request : m_requests) {
if (request > windowStart) {
insideWindow++;
}
}

// If the window is full, set the wait time until its available, and reject the request
if (insideWindow >= limit.count) {
m_nextSlot = m_requests.back() + limit.durationMs;
return false;
}
}

// Add the request
m_requests.push_back(now);

return true;
}
void OpenShock::RateLimiter::clearRequests()
{
m_mutex.lock(portMAX_DELAY);

m_requests.clear();

m_mutex.unlock();
}

void OpenShock::RateLimiter::blockFor(int64_t blockForMs)
{
int64_t blockUntil = OpenShock::millis() + blockForMs;

m_mutex.lock(portMAX_DELAY);

m_nextSlot = std::max(m_nextSlot, blockUntil);

m_mutex.unlock();
}
116 changes: 11 additions & 105 deletions src/http/HTTPRequestManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const char* const TAG = "HTTPRequestManager";

#include "Common.h"
#include "Logging.h"
#include "RateLimiter.h"
#include "SimpleMutex.h"
#include "Time.h"
#include "util/StringUtils.h"
Expand All @@ -22,100 +23,8 @@ using namespace std::string_view_literals;
const std::size_t HTTP_BUFFER_SIZE = 4096LLU;
const int HTTP_DOWNLOAD_SIZE_LIMIT = 200 * 1024 * 1024; // 200 MB

struct RateLimit {
RateLimit()
: m_mutex()
, m_blockUntilMs(0)
, m_limits()
, m_requests()
{
}

void addLimit(uint32_t durationMs, uint16_t count)
{
m_mutex.lock(portMAX_DELAY);

// Insert sorted
m_limits.insert(std::upper_bound(m_limits.begin(), m_limits.end(), durationMs, [](int64_t durationMs, const Limit& limit) { return durationMs > limit.durationMs; }), {durationMs, count});

m_mutex.unlock();
}

void clearLimits()
{
m_mutex.lock(portMAX_DELAY);

m_limits.clear();

m_mutex.unlock();
}

bool tryRequest()
{
int64_t now = OpenShock::millis();

OpenShock::ScopedLock lock__(&m_mutex);

if (m_blockUntilMs > now) {
return false;
}

// Remove all requests that are older than the biggest limit
while (!m_requests.empty() && m_requests.front() < now - m_limits.back().durationMs) {
m_requests.erase(m_requests.begin());
}

// Check if we've exceeded any limits
auto it = std::find_if(m_limits.begin(), m_limits.end(), [this](const RateLimit::Limit& limit) { return m_requests.size() >= limit.count; });
if (it != m_limits.end()) {
m_blockUntilMs = now + it->durationMs;
return false;
}

// Add the request
m_requests.push_back(now);

return true;
}
void clearRequests()
{
m_mutex.lock(portMAX_DELAY);

m_requests.clear();

m_mutex.unlock();
}

void blockUntil(int64_t blockUntilMs)
{
m_mutex.lock(portMAX_DELAY);

m_blockUntilMs = blockUntilMs;

m_mutex.unlock();
}

uint32_t requestsSince(int64_t sinceMs)
{
OpenShock::ScopedLock lock__(&m_mutex);

return std::count_if(m_requests.begin(), m_requests.end(), [sinceMs](int64_t requestMs) { return requestMs >= sinceMs; });
}

private:
struct Limit {
int64_t durationMs;
uint16_t count;
};

OpenShock::SimpleMutex m_mutex;
int64_t m_blockUntilMs;
std::vector<Limit> m_limits;
std::vector<int64_t> m_requests;
};

static OpenShock::SimpleMutex s_rateLimitsMutex = {};
static std::unordered_map<std::string, std::shared_ptr<RateLimit>> s_rateLimits = {};
static OpenShock::SimpleMutex s_rateLimitsMutex = {};
static std::unordered_map<std::string, std::shared_ptr<OpenShock::RateLimiter>> s_rateLimits = {};

using namespace OpenShock;

Expand Down Expand Up @@ -156,9 +65,9 @@ std::string_view _getDomain(std::string_view url)
return url;
}

std::shared_ptr<RateLimit> _rateLimitFactory(std::string_view domain)
std::shared_ptr<OpenShock::RateLimiter> _rateLimiterFactory(std::string_view domain)
{
auto rateLimit = std::make_shared<RateLimit>();
auto rateLimit = std::make_shared<OpenShock::RateLimiter>();

// Add default limits
rateLimit->addLimit(1000, 5); // 5 per second
Expand All @@ -173,7 +82,7 @@ std::shared_ptr<RateLimit> _rateLimitFactory(std::string_view domain)
return rateLimit;
}

std::shared_ptr<RateLimit> _getRateLimiter(std::string_view url)
std::shared_ptr<OpenShock::RateLimiter> _getRateLimiter(std::string_view url)
{
auto domain = std::string(_getDomain(url));
if (domain.empty()) {
Expand All @@ -184,7 +93,7 @@ std::shared_ptr<RateLimit> _getRateLimiter(std::string_view url)

auto it = s_rateLimits.find(domain);
if (it == s_rateLimits.end()) {
s_rateLimits.emplace(domain, _rateLimitFactory(domain));
s_rateLimits.emplace(domain, _rateLimiterFactory(domain));
it = s_rateLimits.find(domain);
}

Expand Down Expand Up @@ -469,7 +378,7 @@ HTTP::Response<std::size_t> _doGetStream(
std::string_view url,
const std::map<String, String>& headers,
const std::vector<int>& acceptedCodes,
std::shared_ptr<RateLimit> rateLimiter,
std::shared_ptr<OpenShock::RateLimiter> rateLimiter,
HTTP::GotContentLengthCallback contentLengthCallback,
HTTP::DownloadCallback downloadCallback,
uint32_t timeoutMs
Expand Down Expand Up @@ -509,11 +418,8 @@ HTTP::Response<std::size_t> _doGetStream(
retryAfter = 15;
}

// Get the block-until time
int64_t blockUntilMs = OpenShock::millis() + retryAfter * 1000;

// Apply the block-until time
rateLimiter->blockUntil(blockUntilMs);
// Apply the block-for time
rateLimiter->blockFor(retryAfter * 1000);

return {HTTP::RequestResult::RateLimited, responseCode, 0};
}
Expand Down Expand Up @@ -563,7 +469,7 @@ HTTP::Response<std::size_t> _doGetStream(
HTTP::Response<std::size_t>
HTTP::Download(std::string_view url, const std::map<String, String>& headers, HTTP::GotContentLengthCallback contentLengthCallback, HTTP::DownloadCallback downloadCallback, const std::vector<int>& acceptedCodes, uint32_t timeoutMs)
{
std::shared_ptr<RateLimit> rateLimiter = _getRateLimiter(url);
std::shared_ptr<OpenShock::RateLimiter> rateLimiter = _getRateLimiter(url);
if (rateLimiter == nullptr) {
return {RequestResult::InvalidURL, 0, 0};
}
Expand Down

0 comments on commit 554f36f

Please sign in to comment.