Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
366 changes: 331 additions & 35 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ clap = { version = "4.4", features = ["derive", "env", "cargo"] }
# other utils
dotenvy = "0.15"
url = "2.4"
tower_governor = "0.4"

[build-dependencies]
tonic-build = { version = "0.10" }
Expand Down
6 changes: 6 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,10 @@ pub struct Config {

#[arg(long, env = "DEFGUARD_PROXY_LOG_LEVEL", default_value_t = LevelFilter::Info)]
pub log_level: LevelFilter,

#[arg(long, env = "DEFGUARD_PROXY_RATELIMIT_PERSECOND", default_value_t = 0)]
pub rate_limit_per_second: u64,

#[arg(long, env = "DEFGUARD_PROXY_RATELIMIT_BURST", default_value_t = 0)]
pub rate_limit_burst: u32,
}
2 changes: 1 addition & 1 deletion src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl ProxyServer {
}
}

/// Sends message to the other side of RPC, with given `payload` and optional 'device_info`.
/// Sends message to the other side of RPC, with given `payload` and optional `device_info`.
/// Returns `tokio::sync::oneshot::Reveicer` to let the caller await reply.
#[instrument(name = "send_grpc_message", level = "debug", skip(self))]
pub fn send(
Expand Down
37 changes: 17 additions & 20 deletions src/handlers/desktop_client_mfa.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use axum::{extract::State, routing::post, Json, Router};
use tracing::{error, info};

use crate::{
error::ApiError,
handlers::get_core_response,
Expand All @@ -7,9 +10,8 @@ use crate::{
ClientMfaStartRequest, ClientMfaStartResponse, DeviceInfo,
},
};
use axum::{extract::State, routing::post, Json, Router};

pub fn router() -> Router<AppState> {
pub(crate) fn router() -> Router<AppState> {
Router::new()
.route("/start", post(start_client_mfa))
.route("/finish", post(finish_client_mfa))
Expand All @@ -27,15 +29,13 @@ async fn start_client_mfa(
device_info,
)?;
let payload = get_core_response(rx).await?;
match payload {
core_response::Payload::ClientMfaStart(response) => {
info!("Started desktop client authorization {req:?}");
Ok(Json(response))
}
_ => {
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}

if let core_response::Payload::ClientMfaStart(response) = payload {
info!("Started desktop client authorization {req:?}");
Ok(Json(response))
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}

Expand All @@ -51,14 +51,11 @@ async fn finish_client_mfa(
device_info,
)?;
let payload = get_core_response(rx).await?;
match payload {
core_response::Payload::ClientMfaFinish(response) => {
info!("Finished desktop client authorization");
Ok(Json(response))
}
_ => {
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
if let core_response::Payload::ClientMfaFinish(response) = payload {
info!("Finished desktop client authorization");
Ok(Json(response))
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}
82 changes: 35 additions & 47 deletions src/handlers/enrollment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,19 @@ pub async fn start_enrollment_process(
.grpc_server
.send(Some(core_request::Payload::EnrollmentStart(req)), None)?;
let payload = get_core_response(rx).await?;
match payload {
core_response::Payload::EnrollmentStart(response) => {
info!(
"Started enrollment process for user {:?} by admin {:?}",
response.user, response.admin
);
// set session cookie
let cookie = Cookie::build((ENROLLMENT_COOKIE_NAME, token))
.expires(OffsetDateTime::from_unix_timestamp(response.deadline_timestamp).unwrap());

Ok((private_cookies.add(cookie), Json(response)))
}
_ => {
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
if let core_response::Payload::EnrollmentStart(response) = payload {
info!(
"Started enrollment process for user {:?} by admin {:?}",
response.user, response.admin
);
// set session cookie
let cookie = Cookie::build((ENROLLMENT_COOKIE_NAME, token))
.expires(OffsetDateTime::from_unix_timestamp(response.deadline_timestamp).unwrap());

Ok((private_cookies.add(cookie), Json(response)))
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}

Expand All @@ -78,20 +75,17 @@ pub async fn activate_user(
.grpc_server
.send(Some(core_request::Payload::ActivateUser(req)), device_info)?;
let payload = get_core_response(rx).await?;
match payload {
core_response::Payload::Empty(_) => {
if let Some(cookie) = private_cookies.get(ENROLLMENT_COOKIE_NAME) {
info!("Activated user - phone number {phone:?}");
debug!("Enrollment finished. Removing session cookie");
private_cookies = private_cookies.remove(cookie);
}

Ok(private_cookies)
}
_ => {
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
if let core_response::Payload::Empty(()) = payload {
if let Some(cookie) = private_cookies.get(ENROLLMENT_COOKIE_NAME) {
info!("Activated user - phone number {phone:?}");
debug!("Enrollment finished. Removing session cookie");
private_cookies = private_cookies.remove(cookie);
}

Ok(private_cookies)
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}

Expand All @@ -114,15 +108,12 @@ pub async fn create_device(
.grpc_server
.send(Some(core_request::Payload::NewDevice(req)), device_info)?;
let payload = get_core_response(rx).await?;
match payload {
core_response::Payload::DeviceConfig(response) => {
info!("Added new device {name} {pubkey}");
Ok(Json(response))
}
_ => {
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
if let core_response::Payload::DeviceConfig(response) = payload {
info!("Added new device {name} {pubkey}");
Ok(Json(response))
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}

Expand All @@ -144,14 +135,11 @@ pub async fn get_network_info(
.grpc_server
.send(Some(core_request::Payload::ExistingDevice(req)), None)?;
let payload = get_core_response(rx).await?;
match payload {
core_response::Payload::DeviceConfig(response) => {
info!("Got network info for device {pubkey}");
Ok(Json(response))
}
_ => {
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
if let core_response::Payload::DeviceConfig(response) = payload {
info!("Got network info for device {pubkey}");
Ok(Json(response))
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}
24 changes: 10 additions & 14 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,15 @@ where
///
/// Waits for core response with a given timeout and returns the response payload.
async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload, ApiError> {
match timeout(Duration::from_secs(CORE_RESPONSE_TIMEOUT), rx).await {
Ok(core_response) => {
debug!("Got gRPC response from Defguard core: {core_response:?}");
if let Ok(Payload::CoreError(core_error)) = core_response {
return Err(core_error.into());
};
core_response.map_err(|err| {
ApiError::Unexpected(format!("Failed to receive core response: {err}"))
})
}
Err(_) => {
error!("Did not receive core response within {CORE_RESPONSE_TIMEOUT} seconds");
Err(ApiError::CoreTimeout)
}
if let Ok(core_response) = timeout(Duration::from_secs(CORE_RESPONSE_TIMEOUT), rx).await {
debug!("Got gRPC response from Defguard core: {core_response:?}");
if let Ok(Payload::CoreError(core_error)) = core_response {
return Err(core_error.into());
};
core_response
.map_err(|err| ApiError::Unexpected(format!("Failed to receive core response: {err}")))
} else {
error!("Did not receive core response within {CORE_RESPONSE_TIMEOUT} seconds");
Err(ApiError::CoreTimeout)
}
}
55 changes: 23 additions & 32 deletions src/handlers/password_reset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,12 @@ pub async fn request_password_reset(
device_info,
)?;
let payload = get_core_response(rx).await?;
match payload {
core_response::Payload::Empty(_) => {
info!("Started password reset request for {}", req.email);
Ok(())
}
_ => {
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
if let core_response::Payload::Empty(()) = payload {
info!("Started password reset request for {}", req.email);
Ok(())
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}

Expand All @@ -66,19 +63,16 @@ pub async fn start_password_reset(
device_info,
)?;
let payload = get_core_response(rx).await?;
match payload {
core_response::Payload::PasswordResetStart(response) => {
// set session cookie
let cookie = Cookie::build((PASSWORD_RESET_COOKIE_NAME, token))
.expires(OffsetDateTime::from_unix_timestamp(response.deadline_timestamp).unwrap());
if let core_response::Payload::PasswordResetStart(response) = payload {
// set session cookie
let cookie = Cookie::build((PASSWORD_RESET_COOKIE_NAME, token))
.expires(OffsetDateTime::from_unix_timestamp(response.deadline_timestamp).unwrap());

info!("Started password reset process");
Ok((private_cookies.add(cookie), Json(response)))
}
_ => {
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
info!("Started password reset process");
Ok((private_cookies.add(cookie), Json(response)))
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}

Expand All @@ -100,17 +94,14 @@ pub async fn reset_password(
.grpc_server
.send(Some(core_request::Payload::PasswordReset(req)), device_info)?;
let payload = get_core_response(rx).await?;
match payload {
core_response::Payload::Empty(_) => {
if let Some(cookie) = private_cookies.get(PASSWORD_RESET_COOKIE_NAME) {
info!("Password reset finished. Removing session cookie");
private_cookies = private_cookies.remove(cookie);
}
Ok(private_cookies)
}
_ => {
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
if let core_response::Payload::Empty(()) = payload {
if let Some(cookie) = private_cookies.get(PASSWORD_RESET_COOKIE_NAME) {
info!("Password reset finished. Removing session cookie");
private_cookies = private_cookies.remove(cookie);
}
Ok(private_cookies)
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}
49 changes: 48 additions & 1 deletion src/http.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
fs::read_to_string,
net::{IpAddr, Ipv4Addr, SocketAddr},
time::Duration,
};

use anyhow::Context;
Expand All @@ -17,6 +18,9 @@ use clap::crate_version;
use serde::Serialize;
use tokio::{net::TcpListener, task::JoinSet};
use tonic::transport::{Identity, Server, ServerTlsConfig};
use tower_governor::{
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
};
use tower_http::{
services::{ServeDir, ServeFile},
trace::{self, TraceLayer},
Expand All @@ -33,6 +37,7 @@ use crate::{

pub(crate) static ENROLLMENT_COOKIE_NAME: &str = "defguard_proxy";
pub(crate) static PASSWORD_RESET_COOKIE_NAME: &str = "defguard_proxy_password_reset";
const RATE_LIMITER_CLEANUP_PERIOD: Duration = Duration::from_secs(60);

#[derive(Clone)]
pub(crate) struct AppState {
Expand Down Expand Up @@ -133,7 +138,44 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
let serve_web_dir = ServeDir::new("web/dist").fallback(ServeFile::new("web/dist/index.html"));
let serve_images =
ServeDir::new("web/src/shared/images/svg").not_found_service(handle_404.into_service());
let app = Router::new()

// Setup tower_governor rate-limiter
debug!(
"Configuring rate limiter, per_second: {}, burst: {}",
config.rate_limit_per_second, config.rate_limit_burst
);
let governor_conf = GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_second(config.rate_limit_per_second)
.burst_size(config.rate_limit_burst)
.finish();

let governor_conf = if let Some(conf) = governor_conf {
let governor_limiter = conf.limiter().clone();

// Start background task to cleanup rate-limiter data
tokio::spawn(async move {
loop {
tokio::time::sleep(RATE_LIMITER_CLEANUP_PERIOD).await;
tracing::debug!(
"Cleaning-up rate limiter storage, current size: {}",
governor_limiter.len()
);
governor_limiter.retain_recent();
}
});
info!(
"Configured rate limiter, per_second: {}, burst: {}",
config.rate_limit_per_second, config.rate_limit_burst
);
Some(conf)
} else {
info!("Skipping rate limiter setup");
None
};

// Build axum app
let mut app = Router::new()
.nest(
"/api/v1",
Router::new()
Expand Down Expand Up @@ -161,6 +203,11 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
})
.on_response(trace::DefaultOnResponse::new().level(Level::DEBUG)),
);
if let Some(conf) = governor_conf {
app = app.layer(GovernorLayer {
config: conf.into(),
});
}
debug!("Configured API server routing: {app:?}");

// Start web server.
Expand Down