Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,7 @@ jobs:
DEV: "true"

- name: Run integration tests
run: |
for test in crates/api/tests/e2e_*.rs; do
name=$(basename "$test" .rs)
cargo test --test "$name" -p api
done
run: cargo test -p api --tests
env:
POSTGRES_PRIMARY_APP_ID: ${{ secrets.POSTGRES_PRIMARY_APP_ID }}
DATABASE_HOST: localhost
Expand All @@ -94,6 +90,7 @@ jobs:
AUTH_ENCODING_KEY: ${{ secrets.AUTH_ENCODING_KEY }}
AUTH_ADMIN_DOMAINS: ${{ secrets.AUTH_ADMIN_DOMAINS }}
BRAVE_SEARCH_PRO_API_KEY: ${{ secrets.BRAVE_SEARCH_PRO_API_KEY }}
TEST_DATABASE_NAME: platform_api_test
RUST_LOG: debug
DEV: "true"

Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ dotenvy = "0.15.7"
k256 = { version = "0.13", features = ["ecdsa", "arithmetic"] }
sha3 = "0.10"
hmac = "0.12"
tokio-postgres = "0.7"
ed25519-dalek = { version = "2.1", features = ["rand_core"] }
103 changes: 103 additions & 0 deletions crates/api/tests/common/db_setup.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use std::env;
use tokio_postgres::NoTls;
use tracing::{error, info, warn};

/// Get test database name from environment or default
pub fn get_test_db_name() -> String {
env::var("TEST_DATABASE_NAME").unwrap_or_else(|_| "platform_api_test".to_string())
}

/// Get admin database name - try 'postgres' first, fallback to 'template1' if not available
async fn get_admin_db_name(
host: &str,
port: u16,
username: &str,
password: &str,
) -> Result<String, String> {
// Try 'postgres' first (most common)
if can_connect_to_db(host, port, username, password, "postgres").await {
return Ok("postgres".to_string());
}

// Fallback to 'template1' (always exists in PostgreSQL)
if can_connect_to_db(host, port, username, password, "template1").await {
warn!("'postgres' database not found, using 'template1' as admin database");
return Ok("template1".to_string());
}

Err("Neither 'postgres' nor 'template1' database found".to_string())
}

async fn can_connect_to_db(
host: &str,
port: u16,
username: &str,
password: &str,
dbname: &str,
) -> bool {
let conn_string =
format!("host={host} port={port} user={username} password={password} dbname={dbname}");
tokio_postgres::connect(&conn_string, NoTls).await.is_ok()
}

pub async fn reset_test_database(config: &config::DatabaseConfig) -> Result<(), String> {
let test_db_name = get_test_db_name();

// Safety check - only allow resetting test database
if !test_db_name.contains("test") {
panic!("Safety: Can only reset databases with 'test' in the name. Got: {test_db_name}");
}

let host = config
.host
.clone()
.unwrap_or_else(|| "localhost".to_string());
let port = config.port;
let username = config.username.clone();
let password = config.password.clone();

// Find available admin database
let admin_db = get_admin_db_name(&host, port, &username, &password).await?;

let conn_string =
format!("host={host} port={port} user={username} password={password} dbname={admin_db}");

let (client, connection) = tokio_postgres::connect(&conn_string, NoTls)
.await
.map_err(|e| format!("Failed to connect to admin database: {e}"))?;

tokio::spawn(async move {
if let Err(e) = connection.await {
error!("Database connection error: {}", e);
}
});

// Terminate existing connections to allow DROP
let _ = client
.execute(
&format!(
"SELECT pg_terminate_backend(pid) FROM pg_stat_activity
WHERE datname = '{test_db_name}' AND pid <> pg_backend_pid()"
),
&[],
)
.await;

// Drop database if exists
let drop_result = client
.execute(&format!("DROP DATABASE IF EXISTS {test_db_name}"), &[])
.await;

if let Err(e) = drop_result {
warn!("Failed to drop test database (may not exist): {}", e);
}

// Create fresh database
client
.execute(&format!("CREATE DATABASE {test_db_name}"), &[])
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: SQL injection vulnerability via unsanitized database name

The test_db_name value from the DATABASE_TEST_NAME environment variable is directly interpolated into SQL statements (DROP DATABASE IF EXISTS {test_db_name} and CREATE DATABASE {test_db_name}) without proper identifier quoting. The safety check test_db_name.contains("test") can be bypassed by crafting a malicious value like test"; DROP DATABASE production; --. Database identifiers in PostgreSQL need to be quoted with double quotes to prevent injection when used in DDL statements.

Additional Locations (1)

Fix in Cursor Fix in Web

.await
.map_err(|e| format!("Failed to create test database: {e}"))?;
Comment on lines +76 to +99
Copy link

Copilot AI Dec 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SQL injection vulnerability: The database name test_db_name is directly interpolated into SQL queries without proper sanitization or quoting. While there's a safety check that the name contains "test", a malicious database name like test'; DROP TABLE users; -- could bypass this check and execute arbitrary SQL.

Use parameterized queries or properly quote/escape the database name. PostgreSQL identifiers should be wrapped in double quotes and escaped. Consider using the format!() macro with proper identifier escaping or a library function to safely handle database identifiers.

Copilot uses AI. Check for mistakes.

info!("Test database '{}' reset successfully", test_db_name);
Ok(())
}
28 changes: 24 additions & 4 deletions crates/api/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![allow(dead_code)]

mod db_setup;

use api::{build_app_with_config, init_auth_services, models::BatchUpdateModelApiRequest};
use base64::Engine;
use chrono::Utc;
Expand All @@ -21,6 +23,9 @@ use sha3::Keccak256;
// Global once cell to ensure migrations only run once across all tests
static MIGRATIONS_INITIALIZED: OnceCell<()> = OnceCell::const_new();

// Global once cell to ensure database reset happens only once per test run
static RESET_DONE: OnceCell<()> = OnceCell::const_new();

// Constants for mock test data
pub const MOCK_USER_ID: &str = "11111111-1111-1111-1111-111111111111";

Expand Down Expand Up @@ -95,7 +100,7 @@ fn db_config_for_tests() -> config::DatabaseConfig {
gateway_subdomain: "cvm1.near.ai".to_string(),
port: 5432,
host: None,
database: "platform_api".to_string(),
database: db_setup::get_test_db_name(),
username: std::env::var("DATABASE_USERNAME").unwrap_or("postgres".to_string()),
password: std::env::var("DATABASE_PASSWORD").unwrap_or("postgres".to_string()),
max_connections: 2,
Expand Down Expand Up @@ -134,6 +139,15 @@ pub async fn get_access_token_from_refresh_token(

/// Initialize database with migrations running only once
pub async fn init_test_database(config: &config::DatabaseConfig) -> Arc<Database> {
// Reset database once per test run
RESET_DONE
.get_or_init(|| async {
db_setup::reset_test_database(config)
.await
.expect("Failed to reset test database");
})
.await;

let database = Arc::new(
Database::from_config(config)
.await
Expand Down Expand Up @@ -162,11 +176,12 @@ pub async fn setup_test_server() -> axum_test::TestServer {
}

/// Setup a complete test server with all components initialized
/// Returns a tuple of (TestServer, InferenceProviderPool, MockProvider) for advanced testing
/// Returns a tuple of (TestServer, InferenceProviderPool, MockProvider, Database) for advanced testing
pub async fn setup_test_server_with_pool() -> (
axum_test::TestServer,
std::sync::Arc<services::inference_provider_pool::InferenceProviderPool>,
std::sync::Arc<inference_providers::mock::MockProvider>,
Arc<Database>,
) {
let _ = tracing_subscriber::fmt()
.with_test_writer()
Expand Down Expand Up @@ -195,10 +210,15 @@ pub async fn setup_test_server_with_pool() -> (
)
.await;

let app = build_app_with_config(database, auth_components, domain_services, Arc::new(config));
let app = build_app_with_config(
database.clone(),
auth_components,
domain_services,
Arc::new(config),
);
let server = axum_test::TestServer::new(app).unwrap();

(server, inference_provider_pool, mock_provider)
(server, inference_provider_pool, mock_provider, database)
}

/// Create the mock user in the database to satisfy foreign key constraints
Expand Down
8 changes: 7 additions & 1 deletion crates/api/tests/e2e_conversations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ async fn create_response_stream(
#[tokio::test]
async fn test_responses_api() {
let server = setup_test_server().await;
setup_qwen_model(&server).await;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add setup_qwen_model() for the tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test calls /v1/models API and asserts that at least one model is available (assert!(!models.data.is_empty())). Without setup_qwen_model(), no models are registered in the database, causing the endpoint to return an empty list and the test to fail immediately.

let org = setup_org_with_credits(&server, 10000000000i64).await; // $10.00 USD
let api_key = get_api_key_for_org(&server, org.id).await;

Expand Down Expand Up @@ -249,6 +250,7 @@ async fn test_responses_api() {
#[tokio::test]
async fn test_streaming_responses_api() {
let server = setup_test_server().await;
setup_qwen_model(&server).await;
let org = setup_org_with_credits(&server, 10000000000i64).await; // $10.00 USD
let api_key = get_api_key_for_org(&server, org.id).await;

Expand Down Expand Up @@ -1121,6 +1123,7 @@ async fn test_create_conversation_items_different_roles() {
#[tokio::test]
async fn test_conversation_items_pagination() {
let server = setup_test_server().await;
setup_qwen_model(&server).await;
let org = setup_org_with_credits(&server, 10000000000i64).await;
let api_key = get_api_key_for_org(&server, org.id).await;
let models = list_models(&server, api_key.clone()).await;
Expand Down Expand Up @@ -1412,6 +1415,7 @@ async fn test_response_previous_next_relationships() {
#[tokio::test]
async fn test_response_previous_next_relationships_streaming() {
let server = setup_test_server().await;
setup_qwen_model(&server).await;
let org = setup_org_with_credits(&server, 10000000000i64).await; // $10.00 USD
let api_key = get_api_key_for_org(&server, org.id).await;

Expand Down Expand Up @@ -1943,6 +1947,7 @@ async fn test_clone_conversation() {
#[tokio::test]
async fn test_clone_conversation_with_responses_and_items() {
let server = setup_test_server().await;
setup_qwen_model(&server).await;
let org = setup_org_with_credits(&server, 10000000000i64).await; // $10.00 USD
let api_key = get_api_key_for_org(&server, org.id).await;

Expand Down Expand Up @@ -2698,6 +2703,7 @@ async fn test_conversation_items_include_model() {
#[tokio::test]
async fn test_conversation_items_model_with_streaming() {
let server = setup_test_server().await;
setup_qwen_model(&server).await;
let org = setup_org_with_credits(&server, 10000000000i64).await; // $10.00 USD
let api_key = get_api_key_for_org(&server, org.id).await;

Expand Down Expand Up @@ -3035,7 +3041,7 @@ async fn test_batch_get_conversations() {
async fn test_conversation_title_strips_thinking_tags() {
use inference_providers::mock::ResponseTemplate;

let (server, _pool, mock_provider) = setup_test_server_with_pool().await;
let (server, _pool, mock_provider, _db) = setup_test_server_with_pool().await;
let org = setup_org_with_credits(&server, 10000000000i64).await;
let api_key = get_api_key_for_org(&server, org.id).await;

Expand Down
5 changes: 4 additions & 1 deletion crates/api/tests/e2e_files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,8 @@ async fn test_complete_file_lifecycle() {

#[tokio::test]
async fn test_file_in_response_api() {
let (server, _pool, mock) = setup_test_server_with_pool().await;
let (server, _pool, mock, _database) = setup_test_server_with_pool().await;
setup_qwen_model(&server).await;
let org = setup_org_with_credits(&server, 10000000000i64).await; // $10.00 USD
let api_key = get_api_key_for_org(&server, org.id).await;

Expand Down Expand Up @@ -1028,6 +1029,7 @@ async fn test_file_in_response_api() {
#[tokio::test]
async fn test_file_not_found_in_response_api() {
let server = setup_test_server().await;
setup_qwen_model(&server).await;
let org = setup_org_with_credits(&server, 10000000000i64).await; // $10.00 USD
let api_key = get_api_key_for_org(&server, org.id).await;

Expand Down Expand Up @@ -1086,6 +1088,7 @@ async fn test_file_not_found_in_response_api() {
#[tokio::test]
async fn test_multiple_files_in_response_api() {
let server = setup_test_server().await;
setup_qwen_model(&server).await;
let org = setup_org_with_credits(&server, 10000000000i64).await; // $10.00 USD
let api_key = get_api_key_for_org(&server, org.id).await;

Expand Down
1 change: 1 addition & 0 deletions crates/api/tests/e2e_org_system_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ async fn test_system_prompt_integration_with_responses() {
let access_token = get_access_token_from_refresh_token(&server, get_session_id()).await;

setup_glm_model(&server).await;
setup_qwen_model(&server).await;

// Set system prompt
server
Expand Down
44 changes: 11 additions & 33 deletions crates/api/tests/e2e_repositories.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,13 @@
mod common;

use chrono::{Duration, Utc};
use common::*;
use database::{Database, OAuthStateRepository};

// Helper to create database pool for repository testing
async fn create_test_pool() -> database::pool::DbPool {
let config = config::DatabaseConfig {
primary_app_id: "postgres-test".to_string(),
gateway_subdomain: "cvm1.near.ai".to_string(),
port: 5432,
host: None,
database: "platform_api".to_string(),
username: std::env::var("DATABASE_USERNAME").unwrap_or("postgres".to_string()),
password: std::env::var("DATABASE_PASSWORD").unwrap_or("postgres".to_string()),
max_connections: 2,
tls_enabled: false,
tls_ca_cert_path: None,
refresh_interval: 30,
mock: false,
};

Database::from_config(&config).await.unwrap().pool().clone()
use database::OAuthStateRepository;

// Helper to get database pool for repository testing
async fn get_test_pool() -> database::pool::DbPool {
let (_server, _inference_provider_pool, _mock_provider, database) =
common::setup_test_server_with_pool().await;
database.pool().clone()
}

// ============================================
Expand All @@ -32,9 +18,7 @@ async fn create_test_pool() -> database::pool::DbPool {

#[tokio::test]
async fn test_create_and_get_oauth_state() {
let _ = setup_test_server().await; // Initialize DB once via OnceCell

let pool = create_test_pool().await;
let pool = get_test_pool().await;
let repo = OAuthStateRepository::new(pool.clone());

let state = format!("test-state-{}", uuid::Uuid::new_v4());
Expand Down Expand Up @@ -63,9 +47,7 @@ async fn test_create_and_get_oauth_state() {

#[tokio::test]
async fn test_expired_state_not_returned() {
let _ = setup_test_server().await; // Initialize DB once via OnceCell

let pool = create_test_pool().await;
let pool = get_test_pool().await;
let repo = OAuthStateRepository::new(pool.clone());

let state = format!("test-state-{}", uuid::Uuid::new_v4());
Expand All @@ -91,9 +73,7 @@ async fn test_expired_state_not_returned() {

#[tokio::test]
async fn test_google_with_pkce_verifier() {
let _ = setup_test_server().await; // Initialize DB once via OnceCell

let pool = create_test_pool().await;
let pool = get_test_pool().await;
let repo = OAuthStateRepository::new(pool.clone());

let state = format!("test-state-{}", uuid::Uuid::new_v4());
Expand All @@ -114,9 +94,7 @@ async fn test_google_with_pkce_verifier() {

#[tokio::test]
async fn test_state_replay_protection() {
let _ = setup_test_server().await; // Initialize DB once via OnceCell

let pool = create_test_pool().await;
let pool = get_test_pool().await;
let repo = OAuthStateRepository::new(pool.clone());

let state = format!("test-state-{}", uuid::Uuid::new_v4());
Expand Down
1 change: 1 addition & 0 deletions crates/api/tests/e2e_response_signature_verification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use common::*;
#[tokio::test]
async fn test_streaming_response_signature_verification() {
let server = setup_test_server().await;
setup_qwen_model(&server).await;
let org = setup_org_with_credits(&server, 10000000000i64).await; // $10.00 USD
println!("Created organization: {}", org.id);

Expand Down
1 change: 1 addition & 0 deletions crates/api/tests/e2e_signature_verification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use inference_providers::StreamChunk;
#[tokio::test]
async fn test_streaming_chat_completion_signature_verification() {
let server = setup_test_server().await;
setup_qwen_model(&server).await;
let org = setup_org_with_credits(&server, 10000000000i64).await; // $10.00 USD
println!("Created organization: {}", org.id);

Expand Down
Loading