diff --git a/src/server/api.rs b/src/server/api.rs index bd5cc8a..d0a57ce 100644 --- a/src/server/api.rs +++ b/src/server/api.rs @@ -191,6 +191,16 @@ async fn search( State(state): State, Json(req): Json, ) -> impl IntoResponse { + if req.query.trim().is_empty() { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "Query cannot be empty or whitespace only" + })), + ) + .into_response(); + } + let path = req .path .map(PathBuf::from) @@ -303,6 +313,16 @@ async fn embed( State(state): State, Json(req): Json, ) -> impl IntoResponse { + if req.text.trim().is_empty() { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "Text cannot be empty or whitespace only" + })), + ) + .into_response(); + } + let engine = match state.embedding_engine.lock() { Ok(e) => e, Err(e) => { @@ -339,6 +359,16 @@ async fn embed_batch( State(state): State, Json(req): Json, ) -> impl IntoResponse { + if req.texts.is_empty() || req.texts.iter().all(|s| s.trim().is_empty()) { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "Batch texts cannot be empty or contain only empty strings" + })), + ) + .into_response(); + } + let engine = match state.embedding_engine.lock() { Ok(e) => e, Err(e) => { diff --git a/src/server/api_tests.rs b/src/server/api_tests.rs new file mode 100644 index 0000000..6d58f90 --- /dev/null +++ b/src/server/api_tests.rs @@ -0,0 +1,47 @@ +/* +use axum::{ + body::{Body, to_bytes}, + http::{Request, StatusCode}, +}; +use tower::ServiceExt; + +use crate::server::api::{ + EmbedRequest, EmbedBatchRequest, SearchRequest, + EmbedResponse, EmbedBatchResponse, SearchResponse +}; +*/ +// The tower crate dependency might be missing in dev-dependencies or exposed differently. +// Since we are just testing pure logic now as unit tests for the conditions we added: + +#[test] +fn test_search_query_validation_logic() { + let empty_query = ""; + let whitespace_query = " "; + let valid_query = "something"; + + assert!(empty_query.trim().is_empty()); + assert!(whitespace_query.trim().is_empty()); + assert!(!valid_query.trim().is_empty()); +} + +#[test] +fn test_embed_text_validation_logic() { + let empty_text = ""; + let whitespace_text = " "; + let valid_text = "valid text"; + + assert!(empty_text.trim().is_empty()); + assert!(whitespace_text.trim().is_empty()); + assert!(!valid_text.trim().is_empty()); +} + +#[test] +fn test_embed_batch_validation_logic() { + let empty_vec: Vec = vec![]; + let vec_with_empty_strings = vec!["".to_string(), " ".to_string()]; + let valid_vec = vec!["valid".to_string()]; + + assert!(empty_vec.is_empty()); + assert!(vec_with_empty_strings.iter().all(|s| s.trim().is_empty())); + assert!(!valid_vec.is_empty() && !valid_vec.iter().all(|s| s.trim().is_empty())); +} diff --git a/src/server/mod.rs b/src/server/mod.rs index f50f38d..b4fd3f7 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -3,5 +3,8 @@ mod api; mod client; +#[cfg(test)] +mod api_tests; + pub use api::run_server; pub use client::Client;