Skip to content

Commit 5271e76

Browse files
authored
feat: rate-limit proxy requests
1 parent 81c9551 commit 5271e76

File tree

9 files changed

+472
-150
lines changed

9 files changed

+472
-150
lines changed

Cargo.lock

+331-35
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ clap = { version = "4.4", features = ["derive", "env", "cargo"] }
3737
# other utils
3838
dotenvy = "0.15"
3939
url = "2.4"
40+
tower_governor = "0.4"
4041

4142
[build-dependencies]
4243
tonic-build = { version = "0.10" }

src/config.rs

+6
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,10 @@ pub struct Config {
2525

2626
#[arg(long, env = "DEFGUARD_PROXY_LOG_LEVEL", default_value_t = LevelFilter::Info)]
2727
pub log_level: LevelFilter,
28+
29+
#[arg(long, env = "DEFGUARD_PROXY_RATELIMIT_PERSECOND", default_value_t = 0)]
30+
pub rate_limit_per_second: u64,
31+
32+
#[arg(long, env = "DEFGUARD_PROXY_RATELIMIT_BURST", default_value_t = 0)]
33+
pub rate_limit_burst: u32,
2834
}

src/grpc.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ impl ProxyServer {
3737
}
3838
}
3939

40-
/// Sends message to the other side of RPC, with given `payload` and optional 'device_info`.
40+
/// Sends message to the other side of RPC, with given `payload` and optional `device_info`.
4141
/// Returns `tokio::sync::oneshot::Reveicer` to let the caller await reply.
4242
#[instrument(name = "send_grpc_message", level = "debug", skip(self))]
4343
pub fn send(

src/handlers/desktop_client_mfa.rs

+17-20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
use axum::{extract::State, routing::post, Json, Router};
2+
use tracing::{error, info};
3+
14
use crate::{
25
error::ApiError,
36
handlers::get_core_response,
@@ -7,9 +10,8 @@ use crate::{
710
ClientMfaStartRequest, ClientMfaStartResponse, DeviceInfo,
811
},
912
};
10-
use axum::{extract::State, routing::post, Json, Router};
1113

12-
pub fn router() -> Router<AppState> {
14+
pub(crate) fn router() -> Router<AppState> {
1315
Router::new()
1416
.route("/start", post(start_client_mfa))
1517
.route("/finish", post(finish_client_mfa))
@@ -27,15 +29,13 @@ async fn start_client_mfa(
2729
device_info,
2830
)?;
2931
let payload = get_core_response(rx).await?;
30-
match payload {
31-
core_response::Payload::ClientMfaStart(response) => {
32-
info!("Started desktop client authorization {req:?}");
33-
Ok(Json(response))
34-
}
35-
_ => {
36-
error!("Received invalid gRPC response type: {payload:?}");
37-
Err(ApiError::InvalidResponseType)
38-
}
32+
33+
if let core_response::Payload::ClientMfaStart(response) = payload {
34+
info!("Started desktop client authorization {req:?}");
35+
Ok(Json(response))
36+
} else {
37+
error!("Received invalid gRPC response type: {payload:#?}");
38+
Err(ApiError::InvalidResponseType)
3939
}
4040
}
4141

@@ -51,14 +51,11 @@ async fn finish_client_mfa(
5151
device_info,
5252
)?;
5353
let payload = get_core_response(rx).await?;
54-
match payload {
55-
core_response::Payload::ClientMfaFinish(response) => {
56-
info!("Finished desktop client authorization");
57-
Ok(Json(response))
58-
}
59-
_ => {
60-
error!("Received invalid gRPC response type: {payload:?}");
61-
Err(ApiError::InvalidResponseType)
62-
}
54+
if let core_response::Payload::ClientMfaFinish(response) = payload {
55+
info!("Finished desktop client authorization");
56+
Ok(Json(response))
57+
} else {
58+
error!("Received invalid gRPC response type: {payload:#?}");
59+
Err(ApiError::InvalidResponseType)
6360
}
6461
}

src/handlers/enrollment.rs

+35-47
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,19 @@ pub async fn start_enrollment_process(
4040
.grpc_server
4141
.send(Some(core_request::Payload::EnrollmentStart(req)), None)?;
4242
let payload = get_core_response(rx).await?;
43-
match payload {
44-
core_response::Payload::EnrollmentStart(response) => {
45-
info!(
46-
"Started enrollment process for user {:?} by admin {:?}",
47-
response.user, response.admin
48-
);
49-
// set session cookie
50-
let cookie = Cookie::build((ENROLLMENT_COOKIE_NAME, token))
51-
.expires(OffsetDateTime::from_unix_timestamp(response.deadline_timestamp).unwrap());
52-
53-
Ok((private_cookies.add(cookie), Json(response)))
54-
}
55-
_ => {
56-
error!("Received invalid gRPC response type: {payload:?}");
57-
Err(ApiError::InvalidResponseType)
58-
}
43+
if let core_response::Payload::EnrollmentStart(response) = payload {
44+
info!(
45+
"Started enrollment process for user {:?} by admin {:?}",
46+
response.user, response.admin
47+
);
48+
// set session cookie
49+
let cookie = Cookie::build((ENROLLMENT_COOKIE_NAME, token))
50+
.expires(OffsetDateTime::from_unix_timestamp(response.deadline_timestamp).unwrap());
51+
52+
Ok((private_cookies.add(cookie), Json(response)))
53+
} else {
54+
error!("Received invalid gRPC response type: {payload:#?}");
55+
Err(ApiError::InvalidResponseType)
5956
}
6057
}
6158

@@ -78,20 +75,17 @@ pub async fn activate_user(
7875
.grpc_server
7976
.send(Some(core_request::Payload::ActivateUser(req)), device_info)?;
8077
let payload = get_core_response(rx).await?;
81-
match payload {
82-
core_response::Payload::Empty(_) => {
83-
if let Some(cookie) = private_cookies.get(ENROLLMENT_COOKIE_NAME) {
84-
info!("Activated user - phone number {phone:?}");
85-
debug!("Enrollment finished. Removing session cookie");
86-
private_cookies = private_cookies.remove(cookie);
87-
}
88-
89-
Ok(private_cookies)
90-
}
91-
_ => {
92-
error!("Received invalid gRPC response type: {payload:?}");
93-
Err(ApiError::InvalidResponseType)
78+
if let core_response::Payload::Empty(()) = payload {
79+
if let Some(cookie) = private_cookies.get(ENROLLMENT_COOKIE_NAME) {
80+
info!("Activated user - phone number {phone:?}");
81+
debug!("Enrollment finished. Removing session cookie");
82+
private_cookies = private_cookies.remove(cookie);
9483
}
84+
85+
Ok(private_cookies)
86+
} else {
87+
error!("Received invalid gRPC response type: {payload:#?}");
88+
Err(ApiError::InvalidResponseType)
9589
}
9690
}
9791

@@ -114,15 +108,12 @@ pub async fn create_device(
114108
.grpc_server
115109
.send(Some(core_request::Payload::NewDevice(req)), device_info)?;
116110
let payload = get_core_response(rx).await?;
117-
match payload {
118-
core_response::Payload::DeviceConfig(response) => {
119-
info!("Added new device {name} {pubkey}");
120-
Ok(Json(response))
121-
}
122-
_ => {
123-
error!("Received invalid gRPC response type: {payload:?}");
124-
Err(ApiError::InvalidResponseType)
125-
}
111+
if let core_response::Payload::DeviceConfig(response) = payload {
112+
info!("Added new device {name} {pubkey}");
113+
Ok(Json(response))
114+
} else {
115+
error!("Received invalid gRPC response type: {payload:#?}");
116+
Err(ApiError::InvalidResponseType)
126117
}
127118
}
128119

@@ -144,14 +135,11 @@ pub async fn get_network_info(
144135
.grpc_server
145136
.send(Some(core_request::Payload::ExistingDevice(req)), None)?;
146137
let payload = get_core_response(rx).await?;
147-
match payload {
148-
core_response::Payload::DeviceConfig(response) => {
149-
info!("Got network info for device {pubkey}");
150-
Ok(Json(response))
151-
}
152-
_ => {
153-
error!("Received invalid gRPC response type: {payload:?}");
154-
Err(ApiError::InvalidResponseType)
155-
}
138+
if let core_response::Payload::DeviceConfig(response) = payload {
139+
info!("Got network info for device {pubkey}");
140+
Ok(Json(response))
141+
} else {
142+
error!("Received invalid gRPC response type: {payload:#?}");
143+
Err(ApiError::InvalidResponseType)
156144
}
157145
}

src/handlers/mod.rs

+10-14
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,15 @@ where
4646
///
4747
/// Waits for core response with a given timeout and returns the response payload.
4848
async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload, ApiError> {
49-
match timeout(Duration::from_secs(CORE_RESPONSE_TIMEOUT), rx).await {
50-
Ok(core_response) => {
51-
debug!("Got gRPC response from Defguard core: {core_response:?}");
52-
if let Ok(Payload::CoreError(core_error)) = core_response {
53-
return Err(core_error.into());
54-
};
55-
core_response.map_err(|err| {
56-
ApiError::Unexpected(format!("Failed to receive core response: {err}"))
57-
})
58-
}
59-
Err(_) => {
60-
error!("Did not receive core response within {CORE_RESPONSE_TIMEOUT} seconds");
61-
Err(ApiError::CoreTimeout)
62-
}
49+
if let Ok(core_response) = timeout(Duration::from_secs(CORE_RESPONSE_TIMEOUT), rx).await {
50+
debug!("Got gRPC response from Defguard core: {core_response:?}");
51+
if let Ok(Payload::CoreError(core_error)) = core_response {
52+
return Err(core_error.into());
53+
};
54+
core_response
55+
.map_err(|err| ApiError::Unexpected(format!("Failed to receive core response: {err}")))
56+
} else {
57+
error!("Did not receive core response within {CORE_RESPONSE_TIMEOUT} seconds");
58+
Err(ApiError::CoreTimeout)
6359
}
6460
}

src/handlers/password_reset.rs

+23-32
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,12 @@ pub async fn request_password_reset(
3232
device_info,
3333
)?;
3434
let payload = get_core_response(rx).await?;
35-
match payload {
36-
core_response::Payload::Empty(_) => {
37-
info!("Started password reset request for {}", req.email);
38-
Ok(())
39-
}
40-
_ => {
41-
error!("Received invalid gRPC response type: {payload:?}");
42-
Err(ApiError::InvalidResponseType)
43-
}
35+
if let core_response::Payload::Empty(()) = payload {
36+
info!("Started password reset request for {}", req.email);
37+
Ok(())
38+
} else {
39+
error!("Received invalid gRPC response type: {payload:#?}");
40+
Err(ApiError::InvalidResponseType)
4441
}
4542
}
4643

@@ -66,19 +63,16 @@ pub async fn start_password_reset(
6663
device_info,
6764
)?;
6865
let payload = get_core_response(rx).await?;
69-
match payload {
70-
core_response::Payload::PasswordResetStart(response) => {
71-
// set session cookie
72-
let cookie = Cookie::build((PASSWORD_RESET_COOKIE_NAME, token))
73-
.expires(OffsetDateTime::from_unix_timestamp(response.deadline_timestamp).unwrap());
66+
if let core_response::Payload::PasswordResetStart(response) = payload {
67+
// set session cookie
68+
let cookie = Cookie::build((PASSWORD_RESET_COOKIE_NAME, token))
69+
.expires(OffsetDateTime::from_unix_timestamp(response.deadline_timestamp).unwrap());
7470

75-
info!("Started password reset process");
76-
Ok((private_cookies.add(cookie), Json(response)))
77-
}
78-
_ => {
79-
error!("Received invalid gRPC response type: {payload:?}");
80-
Err(ApiError::InvalidResponseType)
81-
}
71+
info!("Started password reset process");
72+
Ok((private_cookies.add(cookie), Json(response)))
73+
} else {
74+
error!("Received invalid gRPC response type: {payload:#?}");
75+
Err(ApiError::InvalidResponseType)
8276
}
8377
}
8478

@@ -100,17 +94,14 @@ pub async fn reset_password(
10094
.grpc_server
10195
.send(Some(core_request::Payload::PasswordReset(req)), device_info)?;
10296
let payload = get_core_response(rx).await?;
103-
match payload {
104-
core_response::Payload::Empty(_) => {
105-
if let Some(cookie) = private_cookies.get(PASSWORD_RESET_COOKIE_NAME) {
106-
info!("Password reset finished. Removing session cookie");
107-
private_cookies = private_cookies.remove(cookie);
108-
}
109-
Ok(private_cookies)
110-
}
111-
_ => {
112-
error!("Received invalid gRPC response type: {payload:?}");
113-
Err(ApiError::InvalidResponseType)
97+
if let core_response::Payload::Empty(()) = payload {
98+
if let Some(cookie) = private_cookies.get(PASSWORD_RESET_COOKIE_NAME) {
99+
info!("Password reset finished. Removing session cookie");
100+
private_cookies = private_cookies.remove(cookie);
114101
}
102+
Ok(private_cookies)
103+
} else {
104+
error!("Received invalid gRPC response type: {payload:#?}");
105+
Err(ApiError::InvalidResponseType)
115106
}
116107
}

src/http.rs

+48-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::{
22
fs::read_to_string,
33
net::{IpAddr, Ipv4Addr, SocketAddr},
4+
time::Duration,
45
};
56

67
use anyhow::Context;
@@ -17,6 +18,9 @@ use clap::crate_version;
1718
use serde::Serialize;
1819
use tokio::{net::TcpListener, task::JoinSet};
1920
use tonic::transport::{Identity, Server, ServerTlsConfig};
21+
use tower_governor::{
22+
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
23+
};
2024
use tower_http::{
2125
services::{ServeDir, ServeFile},
2226
trace::{self, TraceLayer},
@@ -33,6 +37,7 @@ use crate::{
3337

3438
pub(crate) static ENROLLMENT_COOKIE_NAME: &str = "defguard_proxy";
3539
pub(crate) static PASSWORD_RESET_COOKIE_NAME: &str = "defguard_proxy_password_reset";
40+
const RATE_LIMITER_CLEANUP_PERIOD: Duration = Duration::from_secs(60);
3641

3742
#[derive(Clone)]
3843
pub(crate) struct AppState {
@@ -133,7 +138,44 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
133138
let serve_web_dir = ServeDir::new("web/dist").fallback(ServeFile::new("web/dist/index.html"));
134139
let serve_images =
135140
ServeDir::new("web/src/shared/images/svg").not_found_service(handle_404.into_service());
136-
let app = Router::new()
141+
142+
// Setup tower_governor rate-limiter
143+
debug!(
144+
"Configuring rate limiter, per_second: {}, burst: {}",
145+
config.rate_limit_per_second, config.rate_limit_burst
146+
);
147+
let governor_conf = GovernorConfigBuilder::default()
148+
.key_extractor(SmartIpKeyExtractor)
149+
.per_second(config.rate_limit_per_second)
150+
.burst_size(config.rate_limit_burst)
151+
.finish();
152+
153+
let governor_conf = if let Some(conf) = governor_conf {
154+
let governor_limiter = conf.limiter().clone();
155+
156+
// Start background task to cleanup rate-limiter data
157+
tokio::spawn(async move {
158+
loop {
159+
tokio::time::sleep(RATE_LIMITER_CLEANUP_PERIOD).await;
160+
tracing::debug!(
161+
"Cleaning-up rate limiter storage, current size: {}",
162+
governor_limiter.len()
163+
);
164+
governor_limiter.retain_recent();
165+
}
166+
});
167+
info!(
168+
"Configured rate limiter, per_second: {}, burst: {}",
169+
config.rate_limit_per_second, config.rate_limit_burst
170+
);
171+
Some(conf)
172+
} else {
173+
info!("Skipping rate limiter setup");
174+
None
175+
};
176+
177+
// Build axum app
178+
let mut app = Router::new()
137179
.nest(
138180
"/api/v1",
139181
Router::new()
@@ -161,6 +203,11 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
161203
})
162204
.on_response(trace::DefaultOnResponse::new().level(Level::DEBUG)),
163205
);
206+
if let Some(conf) = governor_conf {
207+
app = app.layer(GovernorLayer {
208+
config: conf.into(),
209+
});
210+
}
164211
debug!("Configured API server routing: {app:?}");
165212

166213
// Start web server.

0 commit comments

Comments
 (0)