-
Notifications
You must be signed in to change notification settings - Fork 0
Migrate cloud-api to use a single router endpoint. #240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -401,30 +401,23 @@ pub async fn init_domain_services_with_pool( | |
| pub async fn init_inference_providers( | ||
| config: &ApiConfig, | ||
| ) -> Arc<services::inference_provider_pool::InferenceProviderPool> { | ||
| let discovery_url = config.model_discovery.discovery_server_url.clone(); | ||
| let api_key = config.model_discovery.api_key.clone(); | ||
| let router_url = config.inference_router.router_url.clone(); | ||
| let api_key = config.inference_router.api_key.clone(); | ||
| let inference_timeout = config.inference_router.inference_timeout; | ||
|
|
||
| // Create pool with discovery URL and API key | ||
| let pool = Arc::new( | ||
| tracing::info!( | ||
| router_url = %router_url, | ||
| "Initializing inference provider pool with router endpoint" | ||
| ); | ||
|
|
||
| // Create pool with single router endpoint | ||
| Arc::new( | ||
| services::inference_provider_pool::InferenceProviderPool::new( | ||
| discovery_url, | ||
| router_url, | ||
| api_key, | ||
| config.model_discovery.timeout, | ||
| config.model_discovery.inference_timeout, | ||
| inference_timeout, | ||
| ), | ||
| ); | ||
|
|
||
| // Initialize model discovery during startup | ||
| if pool.initialize().await.is_err() { | ||
| tracing::warn!("Failed to initialize model discovery during startup"); | ||
| tracing::info!("Models will be discovered on first request"); | ||
| } | ||
|
|
||
| // Start periodic refresh task with handle management | ||
| let refresh_interval = config.model_discovery.refresh_interval as u64; | ||
| pool.clone().start_refresh_task(refresh_interval).await; | ||
|
|
||
| pool | ||
| ) | ||
| } | ||
|
|
||
| /// Initialize inference provider pool with mock providers for testing | ||
|
|
@@ -439,39 +432,20 @@ pub async fn init_inference_providers_with_mocks( | |
| use inference_providers::MockProvider; | ||
| use std::sync::Arc; | ||
|
|
||
| // Create pool with dummy discovery URL (won't be used since we're registering providers directly) | ||
| // Create a MockProvider that accepts all models (using new_accept_all) | ||
| let mock_provider = Arc::new(MockProvider::new_accept_all()); | ||
|
|
||
| // For the new simplified pool, we need to create a wrapper pool that uses the mock provider | ||
| // Since the pool now expects a router URL, we'll create it with a dummy URL | ||
| // but the mock will be used directly in tests via a different mechanism | ||
| let pool = Arc::new( | ||
| services::inference_provider_pool::InferenceProviderPool::new( | ||
| "http://localhost:8080/models".to_string(), | ||
| "http://localhost:8080".to_string(), | ||
cursor[bot] marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| None, | ||
| 5, | ||
| 30 * 60, | ||
| ), | ||
| ); | ||
|
Comment on lines
435
to
447
|
||
|
|
||
| // Create a MockProvider that accepts all models (using new_accept_all) | ||
| let mock_provider = Arc::new(MockProvider::new_accept_all()); | ||
| let mock_provider_trait: Arc<dyn inference_providers::InferenceProvider + Send + Sync> = | ||
| mock_provider.clone(); | ||
|
|
||
| // Register providers for models commonly used in tests | ||
| let test_models = vec![ | ||
| "Qwen/Qwen3-30B-A3B-Instruct-2507".to_string(), | ||
| "zai-org/GLM-4.6".to_string(), | ||
| "nearai/gpt-oss-120b".to_string(), | ||
| "dphn/Dolphin-Mistral-24B-Venice-Edition".to_string(), | ||
| ]; | ||
|
|
||
| let providers: Vec<( | ||
| String, | ||
| Arc<dyn inference_providers::InferenceProvider + Send + Sync>, | ||
| )> = test_models | ||
| .into_iter() | ||
| .map(|model_id| (model_id, mock_provider_trait.clone())) | ||
| .collect(); | ||
|
|
||
| pool.register_providers(providers).await; | ||
|
|
||
| tracing::info!("Initialized inference provider pool with MockProvider for testing"); | ||
|
|
||
| (pool, mock_provider) | ||
|
|
@@ -1081,10 +1055,9 @@ mod tests { | |
| host: "127.0.0.1".to_string(), | ||
| port: 0, // Use port 0 for testing to get a random available port | ||
| }, | ||
| model_discovery: config::ModelDiscoveryConfig { | ||
| discovery_server_url: "http://localhost:8080/models".to_string(), | ||
| inference_router: config::InferenceRouterConfig { | ||
| router_url: "http://localhost:8080".to_string(), | ||
| api_key: Some("test-key".to_string()), | ||
| refresh_interval: 0, | ||
| timeout: 5, | ||
| inference_timeout: 30 * 60, // 30 minutes | ||
| }, | ||
|
|
@@ -1180,10 +1153,9 @@ mod tests { | |
| host: "127.0.0.1".to_string(), | ||
| port: 0, | ||
| }, | ||
| model_discovery: config::ModelDiscoveryConfig { | ||
| discovery_server_url: "http://localhost:8080/models".to_string(), | ||
| inference_router: config::InferenceRouterConfig { | ||
| router_url: "http://localhost:8080".to_string(), | ||
| api_key: Some("test-key".to_string()), | ||
| refresh_interval: 0, | ||
| timeout: 5, | ||
| inference_timeout: 30 * 60, // 30 minutes | ||
| }, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ use std::{collections::HashMap, env}; | |
| #[derive(Debug, Clone)] | ||
| pub struct ApiConfig { | ||
| pub server: ServerConfig, | ||
| pub model_discovery: ModelDiscoveryConfig, | ||
| pub inference_router: InferenceRouterConfig, | ||
| pub logging: LoggingConfig, | ||
| pub dstack_client: DstackClientConfig, | ||
| pub auth: AuthConfig, | ||
|
|
@@ -17,7 +17,7 @@ impl ApiConfig { | |
| pub fn from_env() -> Result<Self, String> { | ||
| Ok(Self { | ||
| server: ServerConfig::from_env()?, | ||
| model_discovery: ModelDiscoveryConfig::from_env()?, | ||
| inference_router: InferenceRouterConfig::from_env()?, | ||
| logging: LoggingConfig::from_env()?, | ||
| dstack_client: DstackClientConfig::from_env()?, | ||
| auth: AuthConfig::from_env()?, | ||
|
|
@@ -115,26 +115,21 @@ impl ServerConfig { | |
| } | ||
|
|
||
| #[derive(Debug, Clone)] | ||
| pub struct ModelDiscoveryConfig { | ||
| pub discovery_server_url: String, | ||
| pub struct InferenceRouterConfig { | ||
| pub router_url: String, | ||
| pub api_key: Option<String>, | ||
| pub refresh_interval: i64, // seconds | ||
| pub timeout: i64, // seconds (for discovery requests) | ||
| pub timeout: i64, // seconds (for router requests) | ||
|
||
| pub inference_timeout: i64, // seconds (for model inference requests) | ||
| } | ||
|
|
||
| impl ModelDiscoveryConfig { | ||
| impl InferenceRouterConfig { | ||
| /// Load from environment variables | ||
| pub fn from_env() -> Result<Self, String> { | ||
| Ok(Self { | ||
| discovery_server_url: env::var("MODEL_DISCOVERY_SERVER_URL") | ||
| .map_err(|_| "MODEL_DISCOVERY_SERVER_URL not set")?, | ||
| api_key: env::var("MODEL_DISCOVERY_API_KEY").ok(), | ||
| refresh_interval: env::var("MODEL_DISCOVERY_REFRESH_INTERVAL") | ||
| .ok() | ||
| .and_then(|s| s.parse().ok()) | ||
| .unwrap_or(300), // 5 minutes | ||
| timeout: env::var("MODEL_DISCOVERY_TIMEOUT") | ||
| router_url: env::var("INFERENCE_ROUTER_URL") | ||
| .map_err(|_| "INFERENCE_ROUTER_URL not set")?, | ||
| api_key: env::var("INFERENCE_ROUTER_API_KEY").ok(), | ||
| timeout: env::var("INFERENCE_ROUTER_TIMEOUT") | ||
| .ok() | ||
| .and_then(|s| s.parse().ok()) | ||
| .unwrap_or(30), // 30 seconds | ||
|
|
@@ -146,17 +141,13 @@ impl ModelDiscoveryConfig { | |
| } | ||
| } | ||
|
|
||
| impl Default for ModelDiscoveryConfig { | ||
| impl Default for InferenceRouterConfig { | ||
| fn default() -> Self { | ||
| Self { | ||
| discovery_server_url: env::var("MODEL_DISCOVERY_SERVER_URL") | ||
| .expect("MODEL_DISCOVERY_SERVER_URL environment variable is required"), | ||
| api_key: env::var("MODEL_DISCOVERY_API_KEY").ok(), | ||
| refresh_interval: env::var("MODEL_DISCOVERY_REFRESH_INTERVAL") | ||
| .ok() | ||
| .and_then(|s| s.parse().ok()) | ||
| .unwrap_or(300), // 5 minutes | ||
| timeout: env::var("MODEL_DISCOVERY_TIMEOUT") | ||
| router_url: env::var("INFERENCE_ROUTER_URL") | ||
| .expect("INFERENCE_ROUTER_URL environment variable is required"), | ||
| api_key: env::var("INFERENCE_ROUTER_API_KEY").ok(), | ||
| timeout: env::var("INFERENCE_ROUTER_TIMEOUT") | ||
| .ok() | ||
| .and_then(|s| s.parse().ok()) | ||
| .unwrap_or(30), // 30 seconds | ||
|
|
@@ -232,7 +223,7 @@ impl DstackClientConfig { | |
| // Domain-specific configuration types that will be used by domain layer | ||
| #[derive(Debug, Clone)] | ||
| pub struct DomainConfig { | ||
| pub model_discovery: ModelDiscoveryConfig, | ||
| pub inference_router: InferenceRouterConfig, | ||
| pub dstack_client: DstackClientConfig, | ||
| pub auth: AuthConfig, | ||
| } | ||
|
|
@@ -366,7 +357,7 @@ impl From<GoogleOAuthConfig> for OAuthProviderConfig { | |
| impl From<ApiConfig> for DomainConfig { | ||
| fn from(api_config: ApiConfig) -> Self { | ||
| Self { | ||
| model_discovery: api_config.model_discovery, | ||
| inference_router: api_config.inference_router, | ||
| dstack_client: api_config.dstack_client, | ||
| auth: api_config.auth, | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.