Skip to content

Commit

Permalink
Merge pull request #54 from DefGuard/tracing
Browse files Browse the repository at this point in the history
feat: tracing setup
  • Loading branch information
j-chmielewski authored Mar 13, 2024
2 parents 0726e53 + 4119986 commit 0bf34ae
Show file tree
Hide file tree
Showing 11 changed files with 203 additions and 174 deletions.
6 changes: 5 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use clap::Parser;
use tracing::log::LevelFilter;

#[derive(Parser)]
#[derive(Parser, Debug)]
#[command(version)]
pub struct Config {
// port the API server will listen on
Expand All @@ -21,4 +22,7 @@ pub struct Config {

#[arg(long, env = "DEFGUARD_PROXY_GRPC_KEY")]
pub grpc_key: Option<String>,

#[arg(long, env = "DEFGUARD_PROXY_LOG_LEVEL", default_value_t = LevelFilter::Info)]
pub log_level: LevelFilter,
}
3 changes: 0 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@ use axum::{
use serde_json::json;
use tonic::metadata::errors::InvalidMetadataValue;
use tonic::{Code, Status};
use tracing::error;

#[derive(thiserror::Error, Debug)]
pub enum ApiError {
#[error("Unauthorized")]
Unauthorized,
#[error("Session cookie not found")]
CookieNotFound,
#[error("Unexpected error: {0}")]
Unexpected(String),
#[error(transparent)]
Expand Down
127 changes: 127 additions & 0 deletions src/grpc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use std::{
collections::HashMap,
net::SocketAddr,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
},
};

use tokio::sync::{mpsc, oneshot};
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};
use tonic::{Request, Response, Status, Streaming};

use crate::{
error::ApiError,
proto::{core_request, core_response, proxy_server, CoreRequest, CoreResponse, DeviceInfo},
};

// connected clients
type ClientMap = HashMap<SocketAddr, mpsc::UnboundedSender<Result<CoreRequest, Status>>>;

#[derive(Debug)]
pub(crate) struct ProxyServer {
current_id: Arc<AtomicU64>,
clients: Arc<Mutex<ClientMap>>,
results: Arc<Mutex<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
}

impl ProxyServer {
#[must_use]
/// Create new `ProxyServer`.
pub fn new() -> Self {
Self {
current_id: Arc::new(AtomicU64::new(1)),
clients: Arc::new(Mutex::new(HashMap::new())),
results: Arc::new(Mutex::new(HashMap::new())),
}
}

/// Sends message to the other side of RPC, with given `payload` and optional 'device_info`.
/// Returns `tokio::sync::oneshot::Reveicer` to let the caller await reply.
#[instrument(name = "send_grpc_message", level = "debug", skip(self))]
pub fn send(
&self,
payload: Option<core_request::Payload>,
device_info: Option<DeviceInfo>,
) -> Result<oneshot::Receiver<core_response::Payload>, ApiError> {
if let Some(client_tx) = self.clients.lock().unwrap().values().next() {
let id = self.current_id.fetch_add(1, Ordering::Relaxed);
let res = CoreRequest {
id,
device_info,
payload,
};
if let Err(err) = client_tx.send(Ok(res)) {
error!("Failed to send CoreRequest: {err}");
return Err(ApiError::Unexpected("Failed to send CoreRequest".into()));
};
let (tx, rx) = oneshot::channel();
let mut results = self.results.lock().unwrap();
results.insert(id, tx);
Ok(rx)
} else {
error!("Defguard core is disconnected");
Err(ApiError::Unexpected("Defguard core is disconnected".into()))
}
}
}

impl Clone for ProxyServer {
fn clone(&self) -> Self {
Self {
current_id: Arc::clone(&self.current_id),
clients: Arc::clone(&self.clients),
results: Arc::clone(&self.results),
}
}
}

#[tonic::async_trait]
impl proxy_server::Proxy for ProxyServer {
type BidiStream = UnboundedReceiverStream<Result<CoreRequest, Status>>;

/// Handle bidirectional communication with Defguard core.
#[instrument(name = "bidirectional_communication", level = "debug", skip(self))]
async fn bidi(
&self,
request: Request<Streaming<CoreResponse>>,
) -> Result<Response<Self::BidiStream>, Status> {
let Some(address) = request.remote_addr() else {
error!("Failed to determine client address for request: {request:?}");
return Err(Status::internal("Failed to determine client address"));
};
info!("Defguard core RPC client connected from: {address}");

let (tx, rx) = mpsc::unbounded_channel();
self.clients.lock().unwrap().insert(address, tx);

let clients = Arc::clone(&self.clients);
let results = Arc::clone(&self.results);
let mut in_stream = request.into_inner();
tokio::spawn(async move {
while let Some(result) = in_stream.next().await {
match result {
Ok(response) => {
debug!("Received message from Defguard core: {response:?}");
// Discard empty payloads.
if let Some(payload) = response.payload {
if let Some(rx) = results.lock().unwrap().remove(&response.id) {
if let Err(err) = rx.send(payload) {
error!("Failed to send message to rx: {err:?}");
}
} else {
error!("Missing receiver for response #{}", response.id);
}
}
}
Err(err) => error!("RPC client error: {err}"),
}
}
info!("Defguard core client disconnected: {address}");
clients.lock().unwrap().remove(&address);
});

Ok(Response::new(UnboundedReceiverStream::new(rx)))
}
}
9 changes: 5 additions & 4 deletions src/handlers/desktop_client_mfa.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
use crate::{
error::ApiError,
handlers::get_core_response,
http::AppState,
proto::{
core_request, core_response, ClientMfaFinishRequest, ClientMfaFinishResponse,
ClientMfaStartRequest, ClientMfaStartResponse, DeviceInfo,
},
server::AppState,
};
use axum::{extract::State, routing::post, Json, Router};
use tracing::{error, info};

pub fn router() -> Router<AppState> {
Router::new()
.route("/start", post(start_client_mfa))
.route("/finish", post(finish_client_mfa))
}

#[instrument(level = "debug", skip(state))]
async fn start_client_mfa(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
Expand All @@ -30,12 +30,13 @@ async fn start_client_mfa(
match payload {
core_response::Payload::ClientMfaStart(response) => Ok(Json(response)),
_ => {
error!("Received invalid gRPC response type: {payload:#?}");
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
}
}

#[instrument(level = "debug", skip(state))]
async fn finish_client_mfa(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
Expand All @@ -50,7 +51,7 @@ async fn finish_client_mfa(
match payload {
core_response::Payload::ClientMfaFinish(response) => Ok(Json(response)),
_ => {
error!("Received invalid gRPC response type: {payload:#?}");
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
}
Expand Down
15 changes: 9 additions & 6 deletions src/handlers/enrollment.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use axum::{extract::State, routing::post, Json, Router};
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
use time::OffsetDateTime;
use tracing::{debug, error, info};

use crate::{
error::ApiError,
handlers::get_core_response,
http::{AppState, ENROLLMENT_COOKIE_NAME},
proto::{
core_request, core_response, ActivateUserRequest, DeviceConfigResponse, DeviceInfo,
EnrollmentStartRequest, EnrollmentStartResponse, ExistingDevice, NewDevice,
},
server::{AppState, ENROLLMENT_COOKIE_NAME},
};

pub fn router() -> Router<AppState> {
Expand All @@ -21,6 +20,7 @@ pub fn router() -> Router<AppState> {
.route("/network_info", post(get_network_info))
}

#[instrument(level = "debug", skip(state))]
pub async fn start_enrollment_process(
State(state): State<AppState>,
mut private_cookies: PrivateCookieJar,
Expand Down Expand Up @@ -49,12 +49,13 @@ pub async fn start_enrollment_process(
Ok((private_cookies.add(cookie), Json(response)))
}
_ => {
error!("Received invalid gRPC response type: {payload:#?}");
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
}
}

#[instrument(level = "debug", skip(state))]
pub async fn activate_user(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
Expand Down Expand Up @@ -82,12 +83,13 @@ pub async fn activate_user(
Ok(private_cookies)
}
_ => {
error!("Received invalid gRPC response type: {payload:#?}");
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
}
}

#[instrument(level = "debug", skip(state))]
pub async fn create_device(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
Expand All @@ -108,12 +110,13 @@ pub async fn create_device(
match payload {
core_response::Payload::DeviceConfig(response) => Ok(Json(response)),
_ => {
error!("Received invalid gRPC response type: {payload:#?}");
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
}
}

#[instrument(level = "debug", skip(state))]
pub async fn get_network_info(
State(state): State<AppState>,
private_cookies: PrivateCookieJar,
Expand All @@ -133,7 +136,7 @@ pub async fn get_network_info(
match payload {
core_response::Payload::DeviceConfig(response) => Ok(Json(response)),
_ => {
error!("Received invalid gRPC response type: {payload:#?}");
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use axum_client_ip::{InsecureClientIp, LeftmostXForwardedFor};
use axum_extra::{headers::UserAgent, TypedHeader};
use std::time::Duration;
use tokio::{sync::oneshot::Receiver, time::timeout};
use tracing::error;

use super::proto::DeviceInfo;

Expand Down Expand Up @@ -49,6 +48,7 @@ where
async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload, ApiError> {
match timeout(Duration::from_secs(CORE_RESPONSE_TIMEOUT), rx).await {
Ok(core_response) => {
debug!("Got gRPC response from Defguard core: {core_response:?}");
if let Ok(Payload::CoreError(core_error)) = core_response {
return Err(core_error.into());
};
Expand Down
12 changes: 7 additions & 5 deletions src/handlers/password_reset.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use axum::{extract::State, routing::post, Json, Router};
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
use time::OffsetDateTime;
use tracing::{debug, error, info};

use crate::{
error::ApiError,
handlers::get_core_response,
http::{AppState, PASSWORD_RESET_COOKIE_NAME},
proto::{
core_request, core_response, DeviceInfo, PasswordResetInitializeRequest,
PasswordResetRequest, PasswordResetStartRequest, PasswordResetStartResponse,
},
server::{AppState, PASSWORD_RESET_COOKIE_NAME},
};

pub fn router() -> Router<AppState> {
Expand All @@ -20,6 +19,7 @@ pub fn router() -> Router<AppState> {
.route("/reset", post(reset_password))
}

#[instrument(level = "debug", skip(state))]
pub async fn request_password_reset(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
Expand All @@ -35,12 +35,13 @@ pub async fn request_password_reset(
match payload {
core_response::Payload::Empty(_) => Ok(()),
_ => {
error!("Received invalid gRPC response type: {payload:#?}");
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
}
}

#[instrument(level = "debug", skip(state))]
pub async fn start_password_reset(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
Expand Down Expand Up @@ -71,12 +72,13 @@ pub async fn start_password_reset(
Ok((private_cookies.add(cookie), Json(response)))
}
_ => {
error!("Received invalid gRPC response type: {payload:#?}");
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
}
}

#[instrument(level = "debug", skip(state))]
pub async fn reset_password(
State(state): State<AppState>,
device_info: Option<DeviceInfo>,
Expand All @@ -103,7 +105,7 @@ pub async fn reset_password(
Ok(private_cookies)
}
_ => {
error!("Received invalid gRPC response type: {payload:#?}");
error!("Received invalid gRPC response type: {payload:?}");
Err(ApiError::InvalidResponseType)
}
}
Expand Down
Loading

0 comments on commit 0bf34ae

Please sign in to comment.