76 changes: 24 additions & 52 deletions crates/api/src/lib.rs
@@ -401,30 +401,23 @@ pub async fn init_domain_services_with_pool(
 pub async fn init_inference_providers(
     config: &ApiConfig,
 ) -> Arc<services::inference_provider_pool::InferenceProviderPool> {
-    let discovery_url = config.model_discovery.discovery_server_url.clone();
-    let api_key = config.model_discovery.api_key.clone();
+    let router_url = config.inference_router.router_url.clone();
+    let api_key = config.inference_router.api_key.clone();
+    let inference_timeout = config.inference_router.inference_timeout;

-    // Create pool with discovery URL and API key
-    let pool = Arc::new(
+    tracing::info!(
+        router_url = %router_url,
+        "Initializing inference provider pool with router endpoint"
+    );
+
+    // Create pool with single router endpoint
+    Arc::new(
         services::inference_provider_pool::InferenceProviderPool::new(
-            discovery_url,
+            router_url,
             api_key,
-            config.model_discovery.timeout,
-            config.model_discovery.inference_timeout,
+            inference_timeout,
         ),
-    );
-
-    // Initialize model discovery during startup
-    if pool.initialize().await.is_err() {
-        tracing::warn!("Failed to initialize model discovery during startup");
-        tracing::info!("Models will be discovered on first request");
-    }
-
-    // Start periodic refresh task with handle management
-    let refresh_interval = config.model_discovery.refresh_interval as u64;
-    pool.clone().start_refresh_task(refresh_interval).await;
-
-    pool
+    )
 }

 /// Initialize inference provider pool with mock providers for testing
@@ -439,39 +439,20 @@ pub async fn init_inference_providers_with_mocks(
     use inference_providers::MockProvider;
     use std::sync::Arc;

-    // Create pool with dummy discovery URL (won't be used since we're registering providers directly)
+    // Create a MockProvider that accepts all models (using new_accept_all)
+    let mock_provider = Arc::new(MockProvider::new_accept_all());
+
+    // For the new simplified pool, we need to create a wrapper pool that uses the mock provider
+    // Since the pool now expects a router URL, we'll create it with a dummy URL
+    // but the mock will be used directly in tests via a different mechanism
     let pool = Arc::new(
         services::inference_provider_pool::InferenceProviderPool::new(
-            "http://localhost:8080/models".to_string(),
+            "http://localhost:8080".to_string(),
             None,
-            5,
             30 * 60,
         ),
     );
Comment on lines 435 to 447

Copilot AI Dec 9, 2025


The mock provider is created but never used. The pool is initialized with a real VLlmProvider pointing to "http://localhost:8080", which means tests will attempt real HTTP requests instead of using the mock. This defeats the purpose of having a mock provider.

Consider creating a mechanism to inject the mock provider into the pool, or create a separate test-only pool constructor that accepts a provider directly. For example:

// Add to InferenceProviderPool
pub fn new_with_provider(provider: Arc<dyn InferenceProviderTrait>) -> Self {
    Self {
        router_provider: provider,
        chat_id_mapping: Arc::new(RwLock::new(HashMap::new())),
        signature_hashes: Arc::new(RwLock::new(HashMap::new())),
    }
}

// Then in this function:
let mock_provider_trait: Arc<dyn InferenceProviderTrait> = mock_provider.clone();
let pool = Arc::new(
    services::inference_provider_pool::InferenceProviderPool::new_with_provider(
        mock_provider_trait
    )
);
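
If a test-only constructor along these lines is adopted, one possible refinement is to gate it behind #[cfg(test)] or a dedicated test-utils feature so the provider-injection path cannot be reached from production code; that is a design judgment, not something this diff requires.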


-    // Create a MockProvider that accepts all models (using new_accept_all)
-    let mock_provider = Arc::new(MockProvider::new_accept_all());
-    let mock_provider_trait: Arc<dyn inference_providers::InferenceProvider + Send + Sync> =
-        mock_provider.clone();
-
-    // Register providers for models commonly used in tests
-    let test_models = vec![
-        "Qwen/Qwen3-30B-A3B-Instruct-2507".to_string(),
-        "zai-org/GLM-4.6".to_string(),
-        "nearai/gpt-oss-120b".to_string(),
-        "dphn/Dolphin-Mistral-24B-Venice-Edition".to_string(),
-    ];
-
-    let providers: Vec<(
-        String,
-        Arc<dyn inference_providers::InferenceProvider + Send + Sync>,
-    )> = test_models
-        .into_iter()
-        .map(|model_id| (model_id, mock_provider_trait.clone()))
-        .collect();
-
-    pool.register_providers(providers).await;

     tracing::info!("Initialized inference provider pool with MockProvider for testing");

     (pool, mock_provider)
@@ -1081,10 +1055,9 @@ mod tests {
             host: "127.0.0.1".to_string(),
             port: 0, // Use port 0 for testing to get a random available port
         },
-        model_discovery: config::ModelDiscoveryConfig {
-            discovery_server_url: "http://localhost:8080/models".to_string(),
+        inference_router: config::InferenceRouterConfig {
+            router_url: "http://localhost:8080".to_string(),
             api_key: Some("test-key".to_string()),
-            refresh_interval: 0,
             timeout: 5,
             inference_timeout: 30 * 60, // 30 minutes
         },
@@ -1180,10 +1153,9 @@ mod tests {
             host: "127.0.0.1".to_string(),
             port: 0,
         },
-        model_discovery: config::ModelDiscoveryConfig {
-            discovery_server_url: "http://localhost:8080/models".to_string(),
+        inference_router: config::InferenceRouterConfig {
+            router_url: "http://localhost:8080".to_string(),
             api_key: Some("test-key".to_string()),
-            refresh_interval: 0,
             timeout: 5,
             inference_timeout: 30 * 60, // 30 minutes
         },
14 changes: 7 additions & 7 deletions crates/api/src/main.rs
@@ -141,20 +141,20 @@ async fn perform_coordinated_shutdown(
     let mut coordinator = ShutdownCoordinator::new(Duration::from_secs(30));
     coordinator.start();

-    tracing::info!("=== SHUTDOWN PHASE: CANCEL BACKGROUND TASKS ===");
-    tracing::info!("Cancelling all periodic background tasks");
+    tracing::info!("=== SHUTDOWN PHASE: CLEANUP RESOURCES ===");
+    tracing::info!("Cleaning up inference provider pool resources");

-    // Stage 1: Cancel background tasks (should be quick, 5-10 seconds)
+    // Stage 1: Clean up inference provider pool (should be quick, < 5 seconds)
     let (status, remaining) = coordinator
         .execute_stage(
             ShutdownStage {
-                name: "Cancel Background Tasks",
-                timeout: Duration::from_secs(10),
+                name: "Cleanup Inference Resources",
+                timeout: Duration::from_secs(5),
             },
             || async {
-                tracing::info!("Step 1.1: Cancelling model discovery refresh task");
+                tracing::info!("Step 1.1: Clearing inference provider pool mappings");
                 inference_provider_pool.shutdown().await;
-                tracing::debug!("All background tasks cancelled");
+                tracing::debug!("Inference provider pool cleanup complete");
             },
         )
         .await;
14 changes: 5 additions & 9 deletions crates/api/tests/common/mod.rs
@@ -35,17 +35,13 @@ pub fn test_config() -> ApiConfig {
             .and_then(|p| p.parse().ok())
             .unwrap_or(0), // Use port 0 to get a random available port
         },
-        model_discovery: config::ModelDiscoveryConfig {
-            discovery_server_url: std::env::var("MODEL_DISCOVERY_SERVER_URL")
-                .unwrap_or_else(|_| "http://localhost:8080/models".to_string()),
-            api_key: std::env::var("MODEL_DISCOVERY_API_KEY")
+        inference_router: config::InferenceRouterConfig {
+            router_url: std::env::var("INFERENCE_ROUTER_URL")
+                .unwrap_or_else(|_| "http://localhost:8080".to_string()),
+            api_key: std::env::var("INFERENCE_ROUTER_API_KEY")
                 .ok()
                 .or(Some("test_api_key".to_string())),
-            refresh_interval: std::env::var("MODEL_DISCOVERY_REFRESH_INTERVAL")
-                .ok()
-                .and_then(|i| i.parse().ok())
-                .unwrap_or(3600), // 1 hour - large value to avoid refresh during tests
-            timeout: std::env::var("MODEL_DISCOVERY_TIMEOUT")
+            timeout: std::env::var("INFERENCE_ROUTER_TIMEOUT")
                 .ok()
                 .and_then(|t| t.parse().ok())
                 .unwrap_or(5),
43 changes: 17 additions & 26 deletions crates/config/src/types.rs
@@ -3,7 +3,7 @@ use std::{collections::HashMap, env};
 #[derive(Debug, Clone)]
 pub struct ApiConfig {
     pub server: ServerConfig,
-    pub model_discovery: ModelDiscoveryConfig,
+    pub inference_router: InferenceRouterConfig,
     pub logging: LoggingConfig,
     pub dstack_client: DstackClientConfig,
     pub auth: AuthConfig,
@@ -17,7 +17,7 @@ impl ApiConfig {
     pub fn from_env() -> Result<Self, String> {
         Ok(Self {
             server: ServerConfig::from_env()?,
-            model_discovery: ModelDiscoveryConfig::from_env()?,
+            inference_router: InferenceRouterConfig::from_env()?,
             logging: LoggingConfig::from_env()?,
             dstack_client: DstackClientConfig::from_env()?,
             auth: AuthConfig::from_env()?,
@@ -115,26 +115,21 @@ impl ServerConfig {
 }

 #[derive(Debug, Clone)]
-pub struct ModelDiscoveryConfig {
-    pub discovery_server_url: String,
+pub struct InferenceRouterConfig {
+    pub router_url: String,
     pub api_key: Option<String>,
-    pub refresh_interval: i64, // seconds
-    pub timeout: i64, // seconds (for discovery requests)
+    pub timeout: i64, // seconds (for router requests)

Copilot AI Dec 9, 2025


The timeout field is defined but never used in the new implementation. The pool creation only uses inference_timeout (via config.inference_router.inference_timeout). This appears to be leftover from the previous model discovery implementation where it was used for discovery HTTP requests.

Consider either:

  1. Removing this field entirely if it's no longer needed
  2. Using it for the router's HTTP client timeout if that's the intent (see the sketch after this comment)
  3. Documenting why it exists and when it might be used in the future

If removing, also update:

  • The from_env() method (lines 132-135)
  • The Default implementation (lines 150-153)
  • The corresponding environment variable INFERENCE_ROUTER_TIMEOUT
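
A minimal sketch of option 2, assuming the pool builds its router client with reqwest and has the ApiConfig in scope (neither detail is shown in this diff):

    // Hypothetical: apply the otherwise-unused `timeout` field to the
    // router's HTTP client, so INFERENCE_ROUTER_TIMEOUT takes effect.
    use std::time::Duration;

    let http_client = reqwest::Client::builder()
        // Per-request budget for router calls
        .timeout(Duration::from_secs(config.inference_router.timeout as u64))
        .build()
        .expect("failed to build router HTTP client");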

     pub inference_timeout: i64, // seconds (for model inference requests)
 }

-impl ModelDiscoveryConfig {
+impl InferenceRouterConfig {
     /// Load from environment variables
     pub fn from_env() -> Result<Self, String> {
         Ok(Self {
-            discovery_server_url: env::var("MODEL_DISCOVERY_SERVER_URL")
-                .map_err(|_| "MODEL_DISCOVERY_SERVER_URL not set")?,
-            api_key: env::var("MODEL_DISCOVERY_API_KEY").ok(),
-            refresh_interval: env::var("MODEL_DISCOVERY_REFRESH_INTERVAL")
-                .ok()
-                .and_then(|s| s.parse().ok())
-                .unwrap_or(300), // 5 minutes
-            timeout: env::var("MODEL_DISCOVERY_TIMEOUT")
+            router_url: env::var("INFERENCE_ROUTER_URL")
+                .map_err(|_| "INFERENCE_ROUTER_URL not set")?,
+            api_key: env::var("INFERENCE_ROUTER_API_KEY").ok(),
+            timeout: env::var("INFERENCE_ROUTER_TIMEOUT")
                 .ok()
                 .and_then(|s| s.parse().ok())
                 .unwrap_or(30), // 30 seconds
@@ -146,17 +141,13 @@ impl ModelDiscoveryConfig {
     }
 }

-impl Default for ModelDiscoveryConfig {
+impl Default for InferenceRouterConfig {
     fn default() -> Self {
         Self {
-            discovery_server_url: env::var("MODEL_DISCOVERY_SERVER_URL")
-                .expect("MODEL_DISCOVERY_SERVER_URL environment variable is required"),
-            api_key: env::var("MODEL_DISCOVERY_API_KEY").ok(),
-            refresh_interval: env::var("MODEL_DISCOVERY_REFRESH_INTERVAL")
-                .ok()
-                .and_then(|s| s.parse().ok())
-                .unwrap_or(300), // 5 minutes
-            timeout: env::var("MODEL_DISCOVERY_TIMEOUT")
+            router_url: env::var("INFERENCE_ROUTER_URL")
+                .expect("INFERENCE_ROUTER_URL environment variable is required"),
+            api_key: env::var("INFERENCE_ROUTER_API_KEY").ok(),
+            timeout: env::var("INFERENCE_ROUTER_TIMEOUT")
                 .ok()
                 .and_then(|s| s.parse().ok())
                 .unwrap_or(30), // 30 seconds
@@ -232,7 +223,7 @@ impl DstackClientConfig {
 // Domain-specific configuration types that will be used by domain layer
 #[derive(Debug, Clone)]
 pub struct DomainConfig {
-    pub model_discovery: ModelDiscoveryConfig,
+    pub inference_router: InferenceRouterConfig,
     pub dstack_client: DstackClientConfig,
     pub auth: AuthConfig,
 }
@@ -366,7 +357,7 @@ impl From<GoogleOAuthConfig> for OAuthProviderConfig {
 impl From<ApiConfig> for DomainConfig {
     fn from(api_config: ApiConfig) -> Self {
         Self {
-            model_discovery: api_config.model_discovery,
+            inference_router: api_config.inference_router,
             dstack_client: api_config.dstack_client,
             auth: api_config.auth,
         }