diff --git a/src/attestation/dcap.rs b/src/attestation/dcap.rs index 5ed2b33..156339e 100644 --- a/src/attestation/dcap.rs +++ b/src/attestation/dcap.rs @@ -79,7 +79,7 @@ fn generate_quote(input: [u8; 64]) -> Result, QuoteGenerationError> { /// Create a quote #[cfg(not(test))] fn generate_quote(input: [u8; 64]) -> Result, QuoteGenerationError> { - configfs_tsm::create_quote(input) + configfs_tsm::create_tdx_quote(input) } /// Given a [Report] get the input data regardless of report type diff --git a/src/health_check.rs b/src/health_check.rs new file mode 100644 index 0000000..73063f8 --- /dev/null +++ b/src/health_check.rs @@ -0,0 +1,52 @@ +use axum::{routing::get, Json, Router}; +use serde::{Deserialize, Serialize}; +use std::net::SocketAddr; +use tokio::net::TcpListener; + +/// Version information +#[derive(Serialize, Deserialize, PartialEq, Debug)] +pub struct VersionDetails { + pub cargo_package_version: String, +} + +impl VersionDetails { + fn new() -> Self { + Self { + cargo_package_version: env!("CARGO_PKG_VERSION").to_string(), + } + } +} + +async fn health_handler() -> Json { + Json(VersionDetails::new()) +} + +/// Start a HTTP health check server which returns the cargo package version number +pub async fn server(listen_addr: SocketAddr) -> anyhow::Result { + let app = Router::new().fallback(get(health_handler)); + + let listener = TcpListener::bind(listen_addr).await?; + let listen_addr = listener.local_addr()?; + tracing::info!("Starting health check server at {}", listen_addr); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + Ok(listen_addr) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_health_check() { + let addr = server("127.0.0.1:0".parse().unwrap()).await.unwrap(); + + let response = reqwest::get(format!("http://{addr}")).await.unwrap(); + assert_eq!(response.status(), reqwest::StatusCode::OK); + let body = response.text().await.unwrap(); + assert_eq!(body, serde_json::to_string(&VersionDetails::new()).unwrap()) + } +} diff --git a/src/lib.rs b/src/lib.rs index d0a703d..ea9c950 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ pub mod attestation; pub mod attested_get; pub mod file_server; +pub mod health_check; pub use attestation::AttestationGenerator; use attestation::{measurements::MultiMeasurements, AttestationError, AttestationType}; diff --git a/src/main.rs b/src/main.rs index d20facc..234be65 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ use attested_tls_proxy::{ attestation::{measurements::MeasurementPolicy, AttestationType, AttestationVerifier}, attested_get::attested_get, file_server::attested_file_server, - get_tls_cert, AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey, + get_tls_cert, health_check, AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey, }; #[derive(Parser, Debug, Clone)] @@ -63,6 +63,9 @@ enum CliCommand { /// dummy #[arg(long)] dev_dummy_dcap: Option, + // Address to listen on for health checks + #[arg(long)] + listen_addr_healthcheck: Option, }, /// Run a proxy server Server { @@ -89,11 +92,9 @@ enum CliCommand { /// dummy #[arg(long)] dev_dummy_dcap: Option, - // TODO missing: - // Name: "listen-addr-healthcheck", - // EnvVars: []string{"LISTEN_ADDR_HEALTHCHECK"}, - // Value: "", - // Usage: "address to listen on for health checks", + // Address to listen on for health checks + #[arg(long)] + listen_addr_healthcheck: Option, }, /// Retrieve the attested TLS certificate from a proxy server GetTlsCert { @@ -193,12 +194,17 @@ async fn main() -> anyhow::Result<()> { tls_certificate_path, tls_ca_certificate, dev_dummy_dcap, + listen_addr_healthcheck, } => { let target_addr = target_addr .strip_prefix("https://") .unwrap_or(&target_addr) .to_string(); + if let Some(listen_addr_healthcheck) = listen_addr_healthcheck { + health_check::server(listen_addr_healthcheck).await?; + } + let tls_cert_and_chain = if let Some(private_key) = tls_private_key_path { Some(load_tls_cert_and_key( tls_certificate_path @@ -254,7 +260,12 @@ async fn main() -> anyhow::Result<()> { client_auth, server_attestation_type, dev_dummy_dcap, + listen_addr_healthcheck, } => { + if let Some(listen_addr_healthcheck) = listen_addr_healthcheck { + health_check::server(listen_addr_healthcheck).await?; + } + let tls_cert_and_chain = load_tls_cert_and_key(tls_certificate_path, tls_private_key_path)?;