Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 26 additions & 56 deletions crates/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,35 +401,28 @@ pub async fn init_domain_services_with_pool(
pub async fn init_inference_providers(
config: &ApiConfig,
) -> Arc<services::inference_provider_pool::InferenceProviderPool> {
let discovery_url = config.model_discovery.discovery_server_url.clone();
let api_key = config.model_discovery.api_key.clone();
let router_url = config.inference_router.router_url.clone();
let api_key = config.inference_router.api_key.clone();
let inference_timeout = config.inference_router.inference_timeout;

// Create pool with discovery URL and API key
let pool = Arc::new(
tracing::info!(
router_url = %router_url,
"Initializing inference provider pool with router endpoint"
);

// Create pool with single router endpoint
Arc::new(
services::inference_provider_pool::InferenceProviderPool::new(
discovery_url,
router_url,
api_key,
config.model_discovery.timeout,
config.model_discovery.inference_timeout,
inference_timeout,
),
);

// Initialize model discovery during startup
if pool.initialize().await.is_err() {
tracing::warn!("Failed to initialize model discovery during startup");
tracing::info!("Models will be discovered on first request");
}

// Start periodic refresh task with handle management
let refresh_interval = config.model_discovery.refresh_interval as u64;
pool.clone().start_refresh_task(refresh_interval).await;

pool
)
}

/// Initialize inference provider pool with mock providers for testing
/// This function uses the existing MockProvider from inference_providers::mock
/// and registers it for common test models without changing any implementations
/// and creates a pool that uses the mock provider directly
pub async fn init_inference_providers_with_mocks(
_config: &ApiConfig,
) -> (
Expand All @@ -439,38 +432,19 @@ pub async fn init_inference_providers_with_mocks(
use inference_providers::MockProvider;
use std::sync::Arc;

// Create pool with dummy discovery URL (won't be used since we're registering providers directly)
let pool = Arc::new(
services::inference_provider_pool::InferenceProviderPool::new(
"http://localhost:8080/models".to_string(),
None,
5,
30 * 60,
),
);

// Create a MockProvider that accepts all models (using new_accept_all)
let mock_provider = Arc::new(MockProvider::new_accept_all());

// Cast to the trait type for the pool
let mock_provider_trait: Arc<dyn inference_providers::InferenceProvider + Send + Sync> =
mock_provider.clone();

// Register providers for models commonly used in tests
let test_models = vec![
"Qwen/Qwen3-30B-A3B-Instruct-2507".to_string(),
"zai-org/GLM-4.6".to_string(),
"nearai/gpt-oss-120b".to_string(),
"dphn/Dolphin-Mistral-24B-Venice-Edition".to_string(),
];

let providers: Vec<(
String,
Arc<dyn inference_providers::InferenceProvider + Send + Sync>,
)> = test_models
.into_iter()
.map(|model_id| (model_id, mock_provider_trait.clone()))
.collect();

pool.register_providers(providers).await;
// Create pool with the mock provider
let pool = Arc::new(
services::inference_provider_pool::InferenceProviderPool::new_with_provider(
mock_provider_trait,
),
);

tracing::info!("Initialized inference provider pool with MockProvider for testing");

Expand Down Expand Up @@ -1081,11 +1055,9 @@ mod tests {
host: "127.0.0.1".to_string(),
port: 0, // Use port 0 for testing to get a random available port
},
model_discovery: config::ModelDiscoveryConfig {
discovery_server_url: "http://localhost:8080/models".to_string(),
inference_router: config::InferenceRouterConfig {
router_url: "http://localhost:8080".to_string(),
api_key: Some("test-key".to_string()),
refresh_interval: 0,
timeout: 5,
inference_timeout: 30 * 60, // 30 minutes
},
logging: config::LoggingConfig {
Expand Down Expand Up @@ -1180,11 +1152,9 @@ mod tests {
host: "127.0.0.1".to_string(),
port: 0,
},
model_discovery: config::ModelDiscoveryConfig {
discovery_server_url: "http://localhost:8080/models".to_string(),
inference_router: config::InferenceRouterConfig {
router_url: "http://localhost:8080".to_string(),
api_key: Some("test-key".to_string()),
refresh_interval: 0,
timeout: 5,
inference_timeout: 30 * 60, // 30 minutes
},
logging: config::LoggingConfig {
Expand Down
14 changes: 7 additions & 7 deletions crates/api/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,20 +141,20 @@ async fn perform_coordinated_shutdown(
let mut coordinator = ShutdownCoordinator::new(Duration::from_secs(30));
coordinator.start();

tracing::info!("=== SHUTDOWN PHASE: CANCEL BACKGROUND TASKS ===");
tracing::info!("Cancelling all periodic background tasks");
tracing::info!("=== SHUTDOWN PHASE: CLEANUP RESOURCES ===");
tracing::info!("Cleaning up inference provider pool resources");

// Stage 1: Cancel background tasks (should be quick, 5-10 seconds)
// Stage 1: Clean up inference provider pool (should be quick, < 5 seconds)
let (status, remaining) = coordinator
.execute_stage(
ShutdownStage {
name: "Cancel Background Tasks",
timeout: Duration::from_secs(10),
name: "Cleanup Inference Resources",
timeout: Duration::from_secs(5),
},
|| async {
tracing::info!("Step 1.1: Cancelling model discovery refresh task");
tracing::info!("Step 1.1: Clearing inference provider pool mappings");
inference_provider_pool.shutdown().await;
tracing::debug!("All background tasks cancelled");
tracing::debug!("Inference provider pool cleanup complete");
},
)
.await;
Expand Down
3 changes: 3 additions & 0 deletions crates/api/src/routes/attestation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,16 @@ pub struct AttestationResponse {
pub gateway_attestation: DstackCpuQuote,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub model_attestations: Vec<serde_json::Map<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub router_attestation: Option<serde_json::Map<String, serde_json::Value>>,
}

impl From<services::attestation::models::AttestationReport> for AttestationResponse {
fn from(report: services::attestation::models::AttestationReport) -> Self {
Self {
gateway_attestation: report.gateway_attestation.into(),
model_attestations: report.model_attestations,
router_attestation: report.router_attestation,
}
}
}
Expand Down
16 changes: 4 additions & 12 deletions crates/api/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,12 @@ pub fn test_config() -> ApiConfig {
.and_then(|p| p.parse().ok())
.unwrap_or(0), // Use port 0 to get a random available port
},
model_discovery: config::ModelDiscoveryConfig {
discovery_server_url: std::env::var("MODEL_DISCOVERY_SERVER_URL")
.unwrap_or_else(|_| "http://localhost:8080/models".to_string()),
api_key: std::env::var("MODEL_DISCOVERY_API_KEY")
inference_router: config::InferenceRouterConfig {
router_url: std::env::var("INFERENCE_ROUTER_URL")
.unwrap_or_else(|_| "http://localhost:8080".to_string()),
api_key: std::env::var("INFERENCE_ROUTER_API_KEY")
.ok()
.or(Some("test_api_key".to_string())),
refresh_interval: std::env::var("MODEL_DISCOVERY_REFRESH_INTERVAL")
.ok()
.and_then(|i| i.parse().ok())
.unwrap_or(3600), // 1 hour - large value to avoid refresh during tests
timeout: std::env::var("MODEL_DISCOVERY_TIMEOUT")
.ok()
.and_then(|t| t.parse().ok())
.unwrap_or(5),
inference_timeout: std::env::var("MODEL_INFERENCE_TIMEOUT")
.ok()
.and_then(|t| t.parse().ok())
Expand Down
46 changes: 14 additions & 32 deletions crates/config/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{collections::HashMap, env};
#[derive(Debug, Clone)]
pub struct ApiConfig {
pub server: ServerConfig,
pub model_discovery: ModelDiscoveryConfig,
pub inference_router: InferenceRouterConfig,
pub logging: LoggingConfig,
pub dstack_client: DstackClientConfig,
pub auth: AuthConfig,
Expand All @@ -17,7 +17,7 @@ impl ApiConfig {
pub fn from_env() -> Result<Self, String> {
Ok(Self {
server: ServerConfig::from_env()?,
model_discovery: ModelDiscoveryConfig::from_env()?,
inference_router: InferenceRouterConfig::from_env()?,
logging: LoggingConfig::from_env()?,
dstack_client: DstackClientConfig::from_env()?,
auth: AuthConfig::from_env()?,
Expand Down Expand Up @@ -115,29 +115,19 @@ impl ServerConfig {
}

#[derive(Debug, Clone)]
pub struct ModelDiscoveryConfig {
pub discovery_server_url: String,
pub struct InferenceRouterConfig {
pub router_url: String,
pub api_key: Option<String>,
pub refresh_interval: i64, // seconds
pub timeout: i64, // seconds (for discovery requests)
pub inference_timeout: i64, // seconds (for model inference requests)
}

impl ModelDiscoveryConfig {
impl InferenceRouterConfig {
/// Load from environment variables
pub fn from_env() -> Result<Self, String> {
Ok(Self {
discovery_server_url: env::var("MODEL_DISCOVERY_SERVER_URL")
.map_err(|_| "MODEL_DISCOVERY_SERVER_URL not set")?,
api_key: env::var("MODEL_DISCOVERY_API_KEY").ok(),
refresh_interval: env::var("MODEL_DISCOVERY_REFRESH_INTERVAL")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(300), // 5 minutes
timeout: env::var("MODEL_DISCOVERY_TIMEOUT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(30), // 30 seconds
router_url: env::var("INFERENCE_ROUTER_URL")
.map_err(|_| "INFERENCE_ROUTER_URL not set")?,
api_key: env::var("INFERENCE_ROUTER_API_KEY").ok(),
inference_timeout: env::var("MODEL_INFERENCE_TIMEOUT")
.ok()
.and_then(|s| s.parse().ok())
Expand All @@ -146,20 +136,12 @@ impl ModelDiscoveryConfig {
}
}

impl Default for ModelDiscoveryConfig {
impl Default for InferenceRouterConfig {
fn default() -> Self {
Self {
discovery_server_url: env::var("MODEL_DISCOVERY_SERVER_URL")
.expect("MODEL_DISCOVERY_SERVER_URL environment variable is required"),
api_key: env::var("MODEL_DISCOVERY_API_KEY").ok(),
refresh_interval: env::var("MODEL_DISCOVERY_REFRESH_INTERVAL")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(300), // 5 minutes
timeout: env::var("MODEL_DISCOVERY_TIMEOUT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(30), // 30 seconds
router_url: env::var("INFERENCE_ROUTER_URL")
.expect("INFERENCE_ROUTER_URL environment variable is required"),
api_key: env::var("INFERENCE_ROUTER_API_KEY").ok(),
inference_timeout: env::var("MODEL_INFERENCE_TIMEOUT")
.ok()
.and_then(|s| s.parse().ok())
Expand Down Expand Up @@ -232,7 +214,7 @@ impl DstackClientConfig {
// Domain-specific configuration types that will be used by domain layer
#[derive(Debug, Clone)]
pub struct DomainConfig {
pub model_discovery: ModelDiscoveryConfig,
pub inference_router: InferenceRouterConfig,
pub dstack_client: DstackClientConfig,
pub auth: AuthConfig,
}
Expand Down Expand Up @@ -366,7 +348,7 @@ impl From<GoogleOAuthConfig> for OAuthProviderConfig {
impl From<ApiConfig> for DomainConfig {
fn from(api_config: ApiConfig) -> Self {
Self {
model_discovery: api_config.model_discovery,
inference_router: api_config.inference_router,
dstack_client: api_config.dstack_client,
auth: api_config.auth,
}
Expand Down
20 changes: 18 additions & 2 deletions crates/services/src/attestation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,8 @@ impl ports::AttestationServiceTrait for AttestationService {
signing_address: Option<String>,
) -> Result<AttestationReport, AttestationError> {
// Resolve model name (could be an alias) and get model details
let mut model_attestations = vec![];
let mut model_attestations = Vec::new();
let mut router_attestation = None;
// Create a nonce if none was provided
let nonce = match nonce {
Some(n) => n,
Expand Down Expand Up @@ -540,7 +541,7 @@ impl ports::AttestationServiceTrait for AttestationService {
);
}

model_attestations = self
let mut provider_attestation = self
.inference_provider_pool
.get_attestation_report(
canonical_name.clone(),
Expand All @@ -550,6 +551,20 @@ impl ports::AttestationServiceTrait for AttestationService {
)
.await
.map_err(|e| AttestationError::ProviderError(e.to_string()))?;

// Extract model attestations from "all_attestations" field if present
if let Some(all_attestations_value) = provider_attestation.remove("all_attestations") {
if let Some(all_attestations_array) = all_attestations_value.as_array() {
for attestation in all_attestations_array {
if let Some(obj) = attestation.as_object() {
model_attestations.push(obj.clone());
}
}
}
}

// The remaining provider_attestation is the router attestation
router_attestation = Some(provider_attestation);
}

// Use VPC info loaded at initialization
Expand Down Expand Up @@ -643,6 +658,7 @@ impl ports::AttestationServiceTrait for AttestationService {
Ok(AttestationReport {
gateway_attestation,
model_attestations,
router_attestation,
})
}

Expand Down
1 change: 1 addition & 0 deletions crates/services/src/attestation/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ impl DstackCpuQuote {
pub struct AttestationReport {
pub gateway_attestation: DstackCpuQuote,
pub model_attestations: Vec<serde_json::Map<String, serde_json::Value>>,
pub router_attestation: Option<serde_json::Map<String, serde_json::Value>>,
}

pub type DstackAppInfo = dstack_sdk::dstack_client::InfoResponse;
Loading