diff --git a/src/config.rs b/src/config.rs index e5573ee..038a541 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,7 @@ use clap::Parser; +use tracing::log::LevelFilter; -#[derive(Parser)] +#[derive(Parser, Debug)] #[command(version)] pub struct Config { // port the API server will listen on @@ -21,4 +22,7 @@ pub struct Config { #[arg(long, env = "DEFGUARD_PROXY_GRPC_KEY")] pub grpc_key: Option, + + #[arg(long, env = "DEFGUARD_PROXY_LOG_LEVEL", default_value_t = LevelFilter::Info)] + pub log_level: LevelFilter, } diff --git a/src/error.rs b/src/error.rs index d1cf42e..d3bad06 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,14 +7,11 @@ use axum::{ use serde_json::json; use tonic::metadata::errors::InvalidMetadataValue; use tonic::{Code, Status}; -use tracing::error; #[derive(thiserror::Error, Debug)] pub enum ApiError { #[error("Unauthorized")] Unauthorized, - #[error("Session cookie not found")] - CookieNotFound, #[error("Unexpected error: {0}")] Unexpected(String), #[error(transparent)] diff --git a/src/grpc.rs b/src/grpc.rs new file mode 100644 index 0000000..ba35fa7 --- /dev/null +++ b/src/grpc.rs @@ -0,0 +1,127 @@ +use std::{ + collections::HashMap, + net::SocketAddr, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Mutex, + }, +}; + +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt}; +use tonic::{Request, Response, Status, Streaming}; + +use crate::{ + error::ApiError, + proto::{core_request, core_response, proxy_server, CoreRequest, CoreResponse, DeviceInfo}, +}; + +// connected clients +type ClientMap = HashMap>>; + +#[derive(Debug)] +pub(crate) struct ProxyServer { + current_id: Arc, + clients: Arc>, + results: Arc>>>, +} + +impl ProxyServer { + #[must_use] + /// Create new `ProxyServer`. + pub fn new() -> Self { + Self { + current_id: Arc::new(AtomicU64::new(1)), + clients: Arc::new(Mutex::new(HashMap::new())), + results: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Sends message to the other side of RPC, with given `payload` and optional 'device_info`. + /// Returns `tokio::sync::oneshot::Reveicer` to let the caller await reply. + #[instrument(name = "send_grpc_message", level = "debug", skip(self))] + pub fn send( + &self, + payload: Option, + device_info: Option, + ) -> Result, ApiError> { + if let Some(client_tx) = self.clients.lock().unwrap().values().next() { + let id = self.current_id.fetch_add(1, Ordering::Relaxed); + let res = CoreRequest { + id, + device_info, + payload, + }; + if let Err(err) = client_tx.send(Ok(res)) { + error!("Failed to send CoreRequest: {err}"); + return Err(ApiError::Unexpected("Failed to send CoreRequest".into())); + }; + let (tx, rx) = oneshot::channel(); + let mut results = self.results.lock().unwrap(); + results.insert(id, tx); + Ok(rx) + } else { + error!("Defguard core is disconnected"); + Err(ApiError::Unexpected("Defguard core is disconnected".into())) + } + } +} + +impl Clone for ProxyServer { + fn clone(&self) -> Self { + Self { + current_id: Arc::clone(&self.current_id), + clients: Arc::clone(&self.clients), + results: Arc::clone(&self.results), + } + } +} + +#[tonic::async_trait] +impl proxy_server::Proxy for ProxyServer { + type BidiStream = UnboundedReceiverStream>; + + /// Handle bidirectional communication with Defguard core. + #[instrument(name = "bidirectional_communication", level = "debug", skip(self))] + async fn bidi( + &self, + request: Request>, + ) -> Result, Status> { + let Some(address) = request.remote_addr() else { + error!("Failed to determine client address for request: {request:?}"); + return Err(Status::internal("Failed to determine client address")); + }; + info!("Defguard core RPC client connected from: {address}"); + + let (tx, rx) = mpsc::unbounded_channel(); + self.clients.lock().unwrap().insert(address, tx); + + let clients = Arc::clone(&self.clients); + let results = Arc::clone(&self.results); + let mut in_stream = request.into_inner(); + tokio::spawn(async move { + while let Some(result) = in_stream.next().await { + match result { + Ok(response) => { + debug!("Received message from Defguard core: {response:?}"); + // Discard empty payloads. + if let Some(payload) = response.payload { + if let Some(rx) = results.lock().unwrap().remove(&response.id) { + if let Err(err) = rx.send(payload) { + error!("Failed to send message to rx: {err:?}"); + } + } else { + error!("Missing receiver for response #{}", response.id); + } + } + } + Err(err) => error!("RPC client error: {err}"), + } + } + info!("Defguard core client disconnected: {address}"); + clients.lock().unwrap().remove(&address); + }); + + Ok(Response::new(UnboundedReceiverStream::new(rx))) + } +} diff --git a/src/handlers/desktop_client_mfa.rs b/src/handlers/desktop_client_mfa.rs index 224951e..a64ab69 100644 --- a/src/handlers/desktop_client_mfa.rs +++ b/src/handlers/desktop_client_mfa.rs @@ -1,14 +1,13 @@ use crate::{ error::ApiError, handlers::get_core_response, + http::AppState, proto::{ core_request, core_response, ClientMfaFinishRequest, ClientMfaFinishResponse, ClientMfaStartRequest, ClientMfaStartResponse, DeviceInfo, }, - server::AppState, }; use axum::{extract::State, routing::post, Json, Router}; -use tracing::{error, info}; pub fn router() -> Router { Router::new() @@ -16,6 +15,7 @@ pub fn router() -> Router { .route("/finish", post(finish_client_mfa)) } +#[instrument(level = "debug", skip(state))] async fn start_client_mfa( State(state): State, device_info: Option, @@ -30,12 +30,13 @@ async fn start_client_mfa( match payload { core_response::Payload::ClientMfaStart(response) => Ok(Json(response)), _ => { - error!("Received invalid gRPC response type: {payload:#?}"); + error!("Received invalid gRPC response type: {payload:?}"); Err(ApiError::InvalidResponseType) } } } +#[instrument(level = "debug", skip(state))] async fn finish_client_mfa( State(state): State, device_info: Option, @@ -50,7 +51,7 @@ async fn finish_client_mfa( match payload { core_response::Payload::ClientMfaFinish(response) => Ok(Json(response)), _ => { - error!("Received invalid gRPC response type: {payload:#?}"); + error!("Received invalid gRPC response type: {payload:?}"); Err(ApiError::InvalidResponseType) } } diff --git a/src/handlers/enrollment.rs b/src/handlers/enrollment.rs index fefae89..30aad81 100644 --- a/src/handlers/enrollment.rs +++ b/src/handlers/enrollment.rs @@ -1,16 +1,15 @@ use axum::{extract::State, routing::post, Json, Router}; use axum_extra::extract::{cookie::Cookie, PrivateCookieJar}; use time::OffsetDateTime; -use tracing::{debug, error, info}; use crate::{ error::ApiError, handlers::get_core_response, + http::{AppState, ENROLLMENT_COOKIE_NAME}, proto::{ core_request, core_response, ActivateUserRequest, DeviceConfigResponse, DeviceInfo, EnrollmentStartRequest, EnrollmentStartResponse, ExistingDevice, NewDevice, }, - server::{AppState, ENROLLMENT_COOKIE_NAME}, }; pub fn router() -> Router { @@ -21,6 +20,7 @@ pub fn router() -> Router { .route("/network_info", post(get_network_info)) } +#[instrument(level = "debug", skip(state))] pub async fn start_enrollment_process( State(state): State, mut private_cookies: PrivateCookieJar, @@ -49,12 +49,13 @@ pub async fn start_enrollment_process( Ok((private_cookies.add(cookie), Json(response))) } _ => { - error!("Received invalid gRPC response type: {payload:#?}"); + error!("Received invalid gRPC response type: {payload:?}"); Err(ApiError::InvalidResponseType) } } } +#[instrument(level = "debug", skip(state))] pub async fn activate_user( State(state): State, device_info: Option, @@ -82,12 +83,13 @@ pub async fn activate_user( Ok(private_cookies) } _ => { - error!("Received invalid gRPC response type: {payload:#?}"); + error!("Received invalid gRPC response type: {payload:?}"); Err(ApiError::InvalidResponseType) } } } +#[instrument(level = "debug", skip(state))] pub async fn create_device( State(state): State, device_info: Option, @@ -108,12 +110,13 @@ pub async fn create_device( match payload { core_response::Payload::DeviceConfig(response) => Ok(Json(response)), _ => { - error!("Received invalid gRPC response type: {payload:#?}"); + error!("Received invalid gRPC response type: {payload:?}"); Err(ApiError::InvalidResponseType) } } } +#[instrument(level = "debug", skip(state))] pub async fn get_network_info( State(state): State, private_cookies: PrivateCookieJar, @@ -133,7 +136,7 @@ pub async fn get_network_info( match payload { core_response::Payload::DeviceConfig(response) => Ok(Json(response)), _ => { - error!("Received invalid gRPC response type: {payload:#?}"); + error!("Received invalid gRPC response type: {payload:?}"); Err(ApiError::InvalidResponseType) } } diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index caaad10..b88bb5f 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -8,7 +8,6 @@ use axum_client_ip::{InsecureClientIp, LeftmostXForwardedFor}; use axum_extra::{headers::UserAgent, TypedHeader}; use std::time::Duration; use tokio::{sync::oneshot::Receiver, time::timeout}; -use tracing::error; use super::proto::DeviceInfo; @@ -49,6 +48,7 @@ where async fn get_core_response(rx: Receiver) -> Result { match timeout(Duration::from_secs(CORE_RESPONSE_TIMEOUT), rx).await { Ok(core_response) => { + debug!("Got gRPC response from Defguard core: {core_response:?}"); if let Ok(Payload::CoreError(core_error)) = core_response { return Err(core_error.into()); }; diff --git a/src/handlers/password_reset.rs b/src/handlers/password_reset.rs index e66f6cd..88e9e44 100644 --- a/src/handlers/password_reset.rs +++ b/src/handlers/password_reset.rs @@ -1,16 +1,15 @@ use axum::{extract::State, routing::post, Json, Router}; use axum_extra::extract::{cookie::Cookie, PrivateCookieJar}; use time::OffsetDateTime; -use tracing::{debug, error, info}; use crate::{ error::ApiError, handlers::get_core_response, + http::{AppState, PASSWORD_RESET_COOKIE_NAME}, proto::{ core_request, core_response, DeviceInfo, PasswordResetInitializeRequest, PasswordResetRequest, PasswordResetStartRequest, PasswordResetStartResponse, }, - server::{AppState, PASSWORD_RESET_COOKIE_NAME}, }; pub fn router() -> Router { @@ -20,6 +19,7 @@ pub fn router() -> Router { .route("/reset", post(reset_password)) } +#[instrument(level = "debug", skip(state))] pub async fn request_password_reset( State(state): State, device_info: Option, @@ -35,12 +35,13 @@ pub async fn request_password_reset( match payload { core_response::Payload::Empty(_) => Ok(()), _ => { - error!("Received invalid gRPC response type: {payload:#?}"); + error!("Received invalid gRPC response type: {payload:?}"); Err(ApiError::InvalidResponseType) } } } +#[instrument(level = "debug", skip(state))] pub async fn start_password_reset( State(state): State, device_info: Option, @@ -71,12 +72,13 @@ pub async fn start_password_reset( Ok((private_cookies.add(cookie), Json(response))) } _ => { - error!("Received invalid gRPC response type: {payload:#?}"); + error!("Received invalid gRPC response type: {payload:?}"); Err(ApiError::InvalidResponseType) } } } +#[instrument(level = "debug", skip(state))] pub async fn reset_password( State(state): State, device_info: Option, @@ -103,7 +105,7 @@ pub async fn reset_password( Ok(private_cookies) } _ => { - error!("Received invalid gRPC response type: {payload:#?}"); + error!("Received invalid gRPC response type: {payload:?}"); Err(ApiError::InvalidResponseType) } } diff --git a/src/server.rs b/src/http.rs similarity index 80% rename from src/server.rs rename to src/http.rs index 256ccec..28dc364 100644 --- a/src/server.rs +++ b/src/http.rs @@ -5,7 +5,7 @@ use std::{ use anyhow::Context; use axum::{ - extract::FromRef, + extract::{ConnectInfo, FromRef}, handler::HandlerWithoutStateExt, http::{Request, StatusCode}, routing::get, @@ -20,14 +20,14 @@ use tower_http::{ services::{ServeDir, ServeFile}, trace::{self, TraceLayer}, }; -use tracing::{debug, info, info_span, Level}; +use tracing::{info_span, Level}; use crate::{ config::Config, error::ApiError, + grpc::ProxyServer, handlers::{desktop_client_mfa, enrollment, password_reset}, proto::proxy_server, - ProxyServer, }; pub(crate) static ENROLLMENT_COOKIE_NAME: &str = "defguard_proxy"; @@ -65,6 +65,7 @@ async fn healthcheck() -> &'static str { pub async fn run_server(config: Config) -> anyhow::Result<()> { info!("Starting Defguard proxy server"); + debug!("Using config: {config:?}"); let mut tasks = JoinSet::new(); @@ -80,6 +81,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { }; // read gRPC TLS cert and key + debug!("Configuring grpc certificates"); let grpc_cert = config .grpc_cert .as_ref() @@ -88,8 +90,10 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { .grpc_key .as_ref() .and_then(|path| read_to_string(path).ok()); + debug!("Configured grpc certificates, cert: {grpc_cert:?}, key: {grpc_key:?}"); // Start gRPC server. + debug!("Spawning gRPC server"); tasks.spawn(async move { let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), config.grpc_port); info!("gRPC server is listening on {addr}"); @@ -103,10 +107,11 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { .add_service(proxy_server::ProxyServer::new(grpc_server)) .serve(addr) .await - .context("Error running RPC server") + .context("Error running gRPC server") }); // Serve static frontend files. + debug!("Configuring API server routing"); let serve_web_dir = ServeDir::new("web/dist").fallback(ServeFile::new("web/dist/index.html")); let serve_images = ServeDir::new("web/src/shared/images/svg").not_found_service(handle_404.into_service()); @@ -126,20 +131,31 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { .layer( TraceLayer::new_for_http() .make_span_with(|request: &Request<_>| { + // extract client address + let addr = request + .extensions() + .get::>() + .map(|addr| addr.0.to_string()) + .unwrap_or_else(|| "unknown".to_string()); info_span!( "http_request", method = ?request.method(), path = ?request.uri(), + // TODO: headers only in debug logs + // headers = ?request.headers(), + client_addr = addr, ) }) .on_response(trace::DefaultOnResponse::new().level(Level::DEBUG)), ); + debug!("Configured API server routing: {app:?}"); // Start web server. + debug!("Spawning API web server"); tasks.spawn(async move { let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), config.http_port); let listener = TcpListener::bind(&addr).await?; - info!("Web server is listening on {addr}"); + info!("API web server is listening on {addr}"); serve( listener, app.into_make_service_with_connect_info::(), @@ -148,6 +164,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { .context("Error running HTTP server") }); + info!("Defguard proxy server initialization complete"); while let Some(Ok(result)) = tasks.join_next().await { result?; } diff --git a/src/lib.rs b/src/lib.rs index 3d03c99..9f35c30 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,142 +1,13 @@ pub mod config; -pub mod error; +mod error; +mod grpc; mod handlers; -pub mod server; +pub mod http; +pub mod tracing; pub(crate) mod proto { tonic::include_proto!("defguard.proxy"); } -use std::{ - collections::HashMap, - net::SocketAddr, - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, Mutex, - }, -}; - -use tokio::sync::{mpsc, oneshot}; -use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt}; -use tonic::{Request, Response, Status, Streaming}; -use tracing::{debug, error, info}; - -use crate::error::ApiError; -use proto::{core_request, core_response, proxy_server, CoreRequest, CoreResponse, DeviceInfo}; - -// connected clients -type ClientMap = HashMap>>; - -#[derive(Debug)] -pub(crate) struct ProxyServer { - current_id: Arc, - clients: Arc>, - results: Arc>>>, -} - -impl ProxyServer { - #[must_use] - /// Create new `ProxyServer`. - pub fn new() -> Self { - Self { - current_id: Arc::new(AtomicU64::new(1)), - clients: Arc::new(Mutex::new(HashMap::new())), - results: Arc::new(Mutex::new(HashMap::new())), - } - } - - /// Sends message to the other side of RPC, with given `payload` and optional 'device_info`. - /// Returns `tokio::sync::oneshot::Reveicer` to let the caller await reply. - pub fn send( - &self, - payload: Option, - device_info: Option, - ) -> Result, ApiError> { - if let Some(client_tx) = self.clients.lock().unwrap().values().next() { - let id = self.current_id.fetch_add(1, Ordering::Relaxed); - let res = CoreRequest { - id, - device_info, - payload, - }; - if client_tx.send(Ok(res)).is_ok() { - let (tx, rx) = oneshot::channel(); - if let Ok(mut results) = self.results.lock() { - results.insert(id, tx); - return Ok(rx); - } - } - - debug!("Failed to send CoreRequest"); - } - - Err(ApiError::Unexpected( - "Failed to communicate with Defguard core".into(), - )) - } -} - -impl Clone for ProxyServer { - fn clone(&self) -> Self { - Self { - current_id: Arc::clone(&self.current_id), - clients: Arc::clone(&self.clients), - results: Arc::clone(&self.results), - } - } -} - -#[tonic::async_trait] -impl proxy_server::Proxy for ProxyServer { - type BidiStream = UnboundedReceiverStream>; - - /// Handle bidirectional communication with Defguard core. - async fn bidi( - &self, - request: Request>, - ) -> Result, Status> { - let Some(address) = request.remote_addr() else { - return Err(Status::internal("failed to determine client address")); - }; - info!("RPC client connected from: {address}"); - - let (tx, rx) = mpsc::unbounded_channel(); - self.clients.lock().unwrap().insert(address, tx); - - let clients = Arc::clone(&self.clients); - let results = Arc::clone(&self.results); - let mut in_stream = request.into_inner(); - tokio::spawn(async move { - while let Some(result) = in_stream.next().await { - match result { - Ok(response) => { - debug!("RPC message received {response:?}"); - // Discard empty payloads. - if let Some(payload) = response.payload { - if let Ok(mut results) = results.lock() { - if let Some(rx) = results.remove(&response.id) { - if rx.send(payload).is_err() { - debug!("failed to send to rx"); - } - } else { - debug!("missing receiver for response #{}", response.id); - } - } else { - error!("failed to obtain mutex on results"); - } - } - } - Err(err) => error!("RPC client error: {err}"), - } - } - debug!("client disconnected {address}"); - if let Ok(mut clients) = clients.lock() { - clients.remove(&address); - } else { - error!("failed to obtain mutex on clients"); - } - }); - - Ok(Response::new(UnboundedReceiverStream::new(rx))) - } -} +#[macro_use] +extern crate tracing as rust_tracing; diff --git a/src/main.rs b/src/main.rs index c6b2ae9..d2ac2ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,26 +1,16 @@ use clap::Parser; -use defguard_proxy::{config::Config, server::run_server}; -use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; +use defguard_proxy::{config::Config, http::run_server, tracing::init_tracing}; #[tokio::main] async fn main() -> anyhow::Result<()> { - // initialize tracing - tracing_subscriber::registry() - .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| { - "defguard_proxy=debug,tower_http=debug,axum::rejection=trace".into() - })) - .with(fmt::layer()) - .init(); - - // load .env + // configuration if dotenvy::from_filename(".env.local").is_err() { dotenvy::dotenv().ok(); } - - // read config from env let config = Config::parse(); + init_tracing(&config.log_level); - // run API server + // run API web server run_server(config).await?; Ok(()) diff --git a/src/tracing.rs b/src/tracing.rs new file mode 100644 index 0000000..4e7f5c8 --- /dev/null +++ b/src/tracing.rs @@ -0,0 +1,17 @@ +use tracing::log::LevelFilter; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +// Initializes tracing with the specified log level. +// Allows fine-grained filtering with `EnvFilter` directives. +// The directives are read from `DEFGUARD_PROXY_LOG_FILTER` env variable. +// For more info check: +pub fn init_tracing(level: &LevelFilter) { + tracing_subscriber::registry() + .with( + EnvFilter::try_from_env("DEFGUARD_PROXY_LOG_FILTER") + .unwrap_or_else(|_| level.to_string().into()), + ) + .with(fmt::layer()) + .init(); + info!("Tracing initialized"); +}