Skip to content

Commit fbe7a04

Browse files
committed
Add health check server
1 parent 4939bd4 commit fbe7a04

File tree

4 files changed

+71
-7
lines changed

4 files changed

+71
-7
lines changed

src/attestation/dcap.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ fn generate_quote(input: [u8; 64]) -> Result<Vec<u8>, QuoteGenerationError> {
7979
/// Create a quote
8080
#[cfg(not(test))]
8181
fn generate_quote(input: [u8; 64]) -> Result<Vec<u8>, QuoteGenerationError> {
82-
configfs_tsm::create_quote(input)
82+
configfs_tsm::create_tdx_quote(input)
8383
}
8484

8585
/// Given a [Report] get the input data regardless of report type

src/health_check.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use axum::{routing::get, Json, Router};
2+
use serde::{Deserialize, Serialize};
3+
use std::net::SocketAddr;
4+
use tokio::net::TcpListener;
5+
6+
/// Version information
7+
#[derive(Serialize, Deserialize, PartialEq, Debug)]
8+
pub struct VersionDetails {
9+
pub cargo_package_version: String,
10+
}
11+
12+
impl VersionDetails {
13+
fn new() -> Self {
14+
Self {
15+
cargo_package_version: env!("CARGO_PKG_VERSION").to_string(),
16+
}
17+
}
18+
}
19+
20+
async fn health_handler() -> Json<VersionDetails> {
21+
Json(VersionDetails::new())
22+
}
23+
24+
/// Start a HTTP health check server which returns the cargo package version number
25+
pub async fn server(listen_addr: SocketAddr) -> anyhow::Result<SocketAddr> {
26+
let app = Router::new().fallback(get(health_handler));
27+
28+
let listener = TcpListener::bind(listen_addr).await?;
29+
let listen_addr = listener.local_addr()?;
30+
tracing::info!("Starting health check server at {}", listen_addr);
31+
32+
tokio::spawn(async move {
33+
axum::serve(listener, app).await.unwrap();
34+
});
35+
36+
Ok(listen_addr)
37+
}
38+
39+
#[cfg(test)]
40+
mod tests {
41+
use super::*;
42+
43+
#[tokio::test]
44+
async fn test_health_check() {
45+
let addr = server("127.0.0.1:0".parse().unwrap()).await.unwrap();
46+
47+
let response = reqwest::get(format!("http://{addr}")).await.unwrap();
48+
assert_eq!(response.status(), reqwest::StatusCode::OK);
49+
let body = response.text().await.unwrap();
50+
assert_eq!(body, serde_json::to_string(&VersionDetails::new()).unwrap())
51+
}
52+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
pub mod attestation;
22
pub mod attested_get;
33
pub mod file_server;
4+
pub mod health_check;
45

56
pub use attestation::AttestationGenerator;
67
use attestation::{measurements::MultiMeasurements, AttestationError, AttestationType};

src/main.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use attested_tls_proxy::{
99
attestation::{measurements::MeasurementPolicy, AttestationType, AttestationVerifier},
1010
attested_get::attested_get,
1111
file_server::attested_file_server,
12-
get_tls_cert, AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey,
12+
get_tls_cert, health_check, AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey,
1313
};
1414

1515
#[derive(Parser, Debug, Clone)]
@@ -63,6 +63,9 @@ enum CliCommand {
6363
/// dummy
6464
#[arg(long)]
6565
dev_dummy_dcap: Option<String>,
66+
// Address to listen on for health checks
67+
#[arg(long)]
68+
listen_addr_healthcheck: Option<SocketAddr>,
6669
},
6770
/// Run a proxy server
6871
Server {
@@ -89,11 +92,9 @@ enum CliCommand {
8992
/// dummy
9093
#[arg(long)]
9194
dev_dummy_dcap: Option<String>,
92-
// TODO missing:
93-
// Name: "listen-addr-healthcheck",
94-
// EnvVars: []string{"LISTEN_ADDR_HEALTHCHECK"},
95-
// Value: "",
96-
// Usage: "address to listen on for health checks",
95+
// Address to listen on for health checks
96+
#[arg(long)]
97+
listen_addr_healthcheck: Option<SocketAddr>,
9798
},
9899
/// Retrieve the attested TLS certificate from a proxy server
99100
GetTlsCert {
@@ -193,12 +194,17 @@ async fn main() -> anyhow::Result<()> {
193194
tls_certificate_path,
194195
tls_ca_certificate,
195196
dev_dummy_dcap,
197+
listen_addr_healthcheck,
196198
} => {
197199
let target_addr = target_addr
198200
.strip_prefix("https://")
199201
.unwrap_or(&target_addr)
200202
.to_string();
201203

204+
if let Some(listen_addr_healthcheck) = listen_addr_healthcheck {
205+
health_check::server(listen_addr_healthcheck).await?;
206+
}
207+
202208
let tls_cert_and_chain = if let Some(private_key) = tls_private_key_path {
203209
Some(load_tls_cert_and_key(
204210
tls_certificate_path
@@ -254,7 +260,12 @@ async fn main() -> anyhow::Result<()> {
254260
client_auth,
255261
server_attestation_type,
256262
dev_dummy_dcap,
263+
listen_addr_healthcheck,
257264
} => {
265+
if let Some(listen_addr_healthcheck) = listen_addr_healthcheck {
266+
health_check::server(listen_addr_healthcheck).await?;
267+
}
268+
258269
let tls_cert_and_chain =
259270
load_tls_cert_and_key(tls_certificate_path, tls_private_key_path)?;
260271

0 commit comments

Comments
 (0)