Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/attestation/dcap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ fn generate_quote(input: [u8; 64]) -> Result<Vec<u8>, QuoteGenerationError> {
/// Create a quote
#[cfg(not(test))]
fn generate_quote(input: [u8; 64]) -> Result<Vec<u8>, QuoteGenerationError> {
configfs_tsm::create_quote(input)
configfs_tsm::create_tdx_quote(input)
}

/// Given a [Report] get the input data regardless of report type
Expand Down
52 changes: 52 additions & 0 deletions src/health_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use axum::{routing::get, Json, Router};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use tokio::net::TcpListener;

/// Version information
#[derive(Serialize, Deserialize, PartialEq, Debug)]
pub struct VersionDetails {
pub cargo_package_version: String,
}

impl VersionDetails {
fn new() -> Self {
Self {
cargo_package_version: env!("CARGO_PKG_VERSION").to_string(),
}
}
}

async fn health_handler() -> Json<VersionDetails> {
Json(VersionDetails::new())
}

/// Start a HTTP health check server which returns the cargo package version number
pub async fn server(listen_addr: SocketAddr) -> anyhow::Result<SocketAddr> {
let app = Router::new().fallback(get(health_handler));

let listener = TcpListener::bind(listen_addr).await?;
let listen_addr = listener.local_addr()?;
tracing::info!("Starting health check server at {}", listen_addr);

tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});

Ok(listen_addr)
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn test_health_check() {
let addr = server("127.0.0.1:0".parse().unwrap()).await.unwrap();

let response = reqwest::get(format!("http://{addr}")).await.unwrap();
assert_eq!(response.status(), reqwest::StatusCode::OK);
let body = response.text().await.unwrap();
assert_eq!(body, serde_json::to_string(&VersionDetails::new()).unwrap())
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod attestation;
pub mod attested_get;
pub mod file_server;
pub mod health_check;

pub use attestation::AttestationGenerator;
use attestation::{measurements::MultiMeasurements, AttestationError, AttestationType};
Expand Down
23 changes: 17 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use attested_tls_proxy::{
attestation::{measurements::MeasurementPolicy, AttestationType, AttestationVerifier},
attested_get::attested_get,
file_server::attested_file_server,
get_tls_cert, AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey,
get_tls_cert, health_check, AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey,
};

#[derive(Parser, Debug, Clone)]
Expand Down Expand Up @@ -63,6 +63,9 @@ enum CliCommand {
/// dummy
#[arg(long)]
dev_dummy_dcap: Option<String>,
// Address to listen on for health checks
#[arg(long)]
listen_addr_healthcheck: Option<SocketAddr>,
},
/// Run a proxy server
Server {
Expand All @@ -89,11 +92,9 @@ enum CliCommand {
/// dummy
#[arg(long)]
dev_dummy_dcap: Option<String>,
// TODO missing:
// Name: "listen-addr-healthcheck",
// EnvVars: []string{"LISTEN_ADDR_HEALTHCHECK"},
// Value: "",
// Usage: "address to listen on for health checks",
// Address to listen on for health checks
#[arg(long)]
listen_addr_healthcheck: Option<SocketAddr>,
},
/// Retrieve the attested TLS certificate from a proxy server
GetTlsCert {
Expand Down Expand Up @@ -193,12 +194,17 @@ async fn main() -> anyhow::Result<()> {
tls_certificate_path,
tls_ca_certificate,
dev_dummy_dcap,
listen_addr_healthcheck,
} => {
let target_addr = target_addr
.strip_prefix("https://")
.unwrap_or(&target_addr)
.to_string();

if let Some(listen_addr_healthcheck) = listen_addr_healthcheck {
health_check::server(listen_addr_healthcheck).await?;
}

let tls_cert_and_chain = if let Some(private_key) = tls_private_key_path {
Some(load_tls_cert_and_key(
tls_certificate_path
Expand Down Expand Up @@ -254,7 +260,12 @@ async fn main() -> anyhow::Result<()> {
client_auth,
server_attestation_type,
dev_dummy_dcap,
listen_addr_healthcheck,
} => {
if let Some(listen_addr_healthcheck) = listen_addr_healthcheck {
health_check::server(listen_addr_healthcheck).await?;
}

let tls_cert_and_chain =
load_tls_cert_and_key(tls_certificate_path, tls_private_key_path)?;

Expand Down