diff --git a/Cargo.lock b/Cargo.lock index c8104495..b6cb3fa2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1398,6 +1398,7 @@ dependencies = [ "fx-improve", "fx-kernel", "fx-llm", + "fx-marketplace", "fx-memory", "fx-ripcord", "fx-session", @@ -1544,6 +1545,7 @@ dependencies = [ "fx-marketplace", "fx-memory", "fx-propose", + "fx-python", "fx-ripcord", "fx-scratchpad", "fx-security", @@ -1580,6 +1582,20 @@ dependencies = [ "zeroize", ] +[[package]] +name = "fx-cloud-gpu" +version = "0.1.0" +dependencies = [ + "async-trait", + "fx-kernel", + "fx-llm", + "fx-loadable", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", +] + [[package]] name = "fx-config" version = "0.1.0" @@ -1808,6 +1824,7 @@ dependencies = [ "fx-session", "fx-skills", "fx-transactions", + "libc", "notify", "serde", "serde_json", @@ -1828,6 +1845,7 @@ dependencies = [ "serde", "serde_json", "tempfile", + "tracing", "ureq", ] @@ -1875,6 +1893,21 @@ dependencies = [ "tempfile", ] +[[package]] +name = "fx-python" +version = "0.1.0" +dependencies = [ + "async-trait", + "fx-kernel", + "fx-llm", + "fx-loadable", + "libc", + "serde", + "serde_json", + "tempfile", + "tokio", +] + [[package]] name = "fx-ripcord" version = "0.1.0" @@ -2035,6 +2068,7 @@ dependencies = [ "fx-loadable", "fx-memory", "fx-propose", + "fx-ripcord", "fx-session", "fx-storage", "fx-subagent", diff --git a/Cargo.toml b/Cargo.toml index 0292770b..cd72b701 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,8 @@ members = [ "engine/crates/fx-kernel", "engine/crates/fx-auth", "engine/crates/fx-loadable", + "engine/crates/fx-cloud-gpu", + "engine/crates/fx-python", "engine/crates/fx-agent", "engine/crates/fx-llm", "engine/crates/fx-security", @@ -121,6 +123,7 @@ fx-auth = { path = "engine/crates/fx-auth" } fx-cli = { path = "engine/crates/fx-cli" } fx-conversation = { path = "engine/crates/fx-conversation" } fx-loadable = { path = "engine/crates/fx-loadable" } +fx-python = { path = "engine/crates/fx-python" } fx-agent = { path = 
"engine/crates/fx-agent" } fx-security = { path = "engine/crates/fx-security" } fx-skills = { path = "engine/crates/fx-skills" } diff --git a/engine/crates/fx-api/Cargo.toml b/engine/crates/fx-api/Cargo.toml index d3065ab1..a4ddda97 100644 --- a/engine/crates/fx-api/Cargo.toml +++ b/engine/crates/fx-api/Cargo.toml @@ -25,6 +25,7 @@ fx-kernel.workspace = true fx-improve.workspace = true fx-llm.workspace = true fx-memory.workspace = true +fx-marketplace = { path = "../fx-marketplace" } fx-ripcord = { path = "../fx-ripcord" } fx-session.workspace = true fx-storage.workspace = true diff --git a/engine/crates/fx-api/src/bundle.rs b/engine/crates/fx-api/src/bundle.rs index ad6979d0..2459a9e9 100644 --- a/engine/crates/fx-api/src/bundle.rs +++ b/engine/crates/fx-api/src/bundle.rs @@ -60,9 +60,9 @@ mod tests { #[test] fn find_bundle_root_finds_nested_app() { - let path = Path::new("/Users/joe/Desktop/Fawx.app/Contents/MacOS/fawx-server"); + let path = Path::new("/Applications/Fawx.app/Contents/MacOS/fawx-server"); let root = find_bundle_root(path); - assert_eq!(root, Some(PathBuf::from("/Users/joe/Desktop/Fawx.app"))); + assert_eq!(root, Some(PathBuf::from("/Applications/Fawx.app"))); } #[test] diff --git a/engine/crates/fx-api/src/devices.rs b/engine/crates/fx-api/src/devices.rs index 9c178710..2d9d960c 100644 --- a/engine/crates/fx-api/src/devices.rs +++ b/engine/crates/fx-api/src/devices.rs @@ -220,7 +220,7 @@ mod tests { #[test] fn create_device_returns_hashed_token() { let mut store = DeviceStore::new(); - let (raw_token, device) = store.create_device("My MacBook"); + let (raw_token, device) = store.create_device("Example MacBook"); assert!(raw_token.starts_with(DEVICE_TOKEN_PREFIX)); assert_eq!( @@ -235,18 +235,18 @@ mod tests { #[test] fn list_device_info_excludes_token_hash() { let mut store = DeviceStore::new(); - let _ = store.create_device("My MacBook"); + let _ = store.create_device("Example MacBook"); let json = 
serde_json::to_value(store.list_device_info()).expect("serialize device info"); assert!(json[0].get("token_hash").is_none()); - assert_eq!(json[0]["device_name"], "My MacBook"); + assert_eq!(json[0]["device_name"], "Example MacBook"); } #[test] fn authenticate_works() { let mut store = DeviceStore::new(); - let (raw_token, device) = store.create_device("My MacBook"); + let (raw_token, device) = store.create_device("Example MacBook"); store.list_devices_mut()[0].last_used_at = 0; assert_eq!(store.authenticate(&raw_token), Some(device.id)); @@ -257,7 +257,7 @@ mod tests { #[test] fn revoke_invalidates_device() { let mut store = DeviceStore::new(); - let (raw_token, device) = store.create_device("My MacBook"); + let (raw_token, device) = store.create_device("Example MacBook"); assert_eq!(store.revoke(&device.id), Some(device.clone())); assert!(store.revoke(&device.id).is_none()); @@ -269,7 +269,7 @@ mod tests { let temp = tempdir().expect("tempdir"); let path = temp.path().join("devices.json"); let mut store = DeviceStore::new(); - let (raw_token, _) = store.create_device("My MacBook"); + let (raw_token, _) = store.create_device("Example MacBook"); store.save(&path).expect("save device store"); let mut loaded = DeviceStore::load(&path); @@ -286,7 +286,7 @@ mod tests { let temp = tempdir().expect("tempdir"); let path = temp.path().join("devices.json"); let mut store = DeviceStore::new(); - let _ = store.create_device("My MacBook"); + let _ = store.create_device("Example MacBook"); store.save(&path).expect("save device store"); let mode = fs::metadata(&path).expect("metadata").permissions().mode() & 0o777; @@ -302,7 +302,7 @@ mod tests { devices: vec![DeviceToken { id: "dev-123".to_string(), token_hash: "hash".to_string(), - device_name: "My MacBook".to_string(), + device_name: "Example MacBook".to_string(), created_at: 1_700_000_000_000, last_used_at: 1_700_000_005_000, }], diff --git a/engine/crates/fx-api/src/handlers/fleet.rs 
b/engine/crates/fx-api/src/handlers/fleet.rs index 2b56c6e9..e0a8bfed 100644 --- a/engine/crates/fx-api/src/handlers/fleet.rs +++ b/engine/crates/fx-api/src/handlers/fleet.rs @@ -157,7 +157,7 @@ mod tests { let temp_dir = tempfile::TempDir::new().expect("tempdir should create"); let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); let token = manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("node should add"); TestFleet { _temp_dir: temp_dir, @@ -168,7 +168,7 @@ mod tests { fn registration_request(token: &str) -> FleetRegistrationRequest { FleetRegistrationRequest { - node_name: "node-alpha".to_string(), + node_name: "node-a".to_string(), bearer_token: token.to_string(), capabilities: vec!["agentic_loop".to_string(), "macos-aarch64".to_string()], rust_version: Some("1.85.0".to_string()), diff --git a/engine/crates/fx-api/src/handlers/fleet_dashboard.rs b/engine/crates/fx-api/src/handlers/fleet_dashboard.rs index c2d17297..10252c53 100644 --- a/engine/crates/fx-api/src/handlers/fleet_dashboard.rs +++ b/engine/crates/fx-api/src/handlers/fleet_dashboard.rs @@ -360,7 +360,7 @@ mod tests { fn node_dto_serializes() { let response = FleetNodeDto { id: "node-1".to_string(), - name: "Node Alpha".to_string(), + name: "Worker Node A".to_string(), status: "healthy".to_string(), last_seen_at: 1_742_000_100, active_tasks: 0, @@ -373,7 +373,7 @@ mod tests { json, json!({ "id": "node-1", - "name": "Node Alpha", + "name": "Worker Node A", "status": "healthy", "last_seen_at": 1_742_000_100, "active_tasks": 0, @@ -433,7 +433,7 @@ mod tests { fn effective_status_marks_old_busy_nodes_degraded() { let node = NodeInfo { node_id: "node-1".to_string(), - name: "Node Alpha".to_string(), + name: "Worker Node A".to_string(), endpoint: "https://127.0.0.1:8400".to_string(), auth_token: None, capabilities: vec![NodeCapability::AgenticLoop], diff --git a/engine/crates/fx-api/src/handlers/git.rs 
b/engine/crates/fx-api/src/handlers/git.rs index 0d538436..5b4da36b 100644 --- a/engine/crates/fx-api/src/handlers/git.rs +++ b/engine/crates/fx-api/src/handlers/git.rs @@ -725,7 +725,7 @@ mod tests { hash: "abcdef123456".to_string(), short_hash: "abcdef1".to_string(), message: "feat: add git api".to_string(), - author: "Alice".to_string(), + author: "Example Author".to_string(), timestamp: "2026-03-15T20:00:00Z".to_string(), }], }; @@ -733,7 +733,7 @@ mod tests { let json = serde_json::to_value(response).unwrap(); assert_eq!(json["commits"][0]["hash"], "abcdef123456"); - assert_eq!(json["commits"][0]["author"], "Alice"); + assert_eq!(json["commits"][0]["author"], "Example Author"); } #[test] @@ -779,14 +779,14 @@ mod tests { #[test] fn parse_log_line() { let commit = super::parse_log_line( - "abcdef123456|abcdef1|feat: support pipes | in messages|Alice|2026-03-15T20:00:00Z", + "abcdef123456|abcdef1|feat: support pipes | in messages|Example Author|2026-03-15T20:00:00Z", ) .unwrap(); assert_eq!(commit.hash, "abcdef123456"); assert_eq!(commit.short_hash, "abcdef1"); assert_eq!(commit.message, "feat: support pipes | in messages"); - assert_eq!(commit.author, "Alice"); + assert_eq!(commit.author, "Example Author"); } #[test] diff --git a/engine/crates/fx-api/src/handlers/marketplace.rs b/engine/crates/fx-api/src/handlers/marketplace.rs index e9a66746..fd1632ff 100644 --- a/engine/crates/fx-api/src/handlers/marketplace.rs +++ b/engine/crates/fx-api/src/handlers/marketplace.rs @@ -1,13 +1,15 @@ +use std::path::{Path, PathBuf}; + use crate::skill_manifests::{update_skill_capabilities, SkillManifestError}; use crate::state::HttpState; use crate::types::ErrorBody; -use axum::extract::{Json, Path, Query, State}; +use axum::extract::{Json, Path as AxumPath, Query, State}; use axum::http::StatusCode; +use fx_marketplace::{InstallResult, MarketplaceError, SkillEntry}; use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; const MARKETPLACE_NOT_CONNECTED_MESSAGE: 
&str = "Marketplace not yet connected"; -const MARKETPLACE_UNAVAILABLE_MESSAGE: &str = - "Marketplace not yet available. Install skills via CLI: fawx skills install "; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct MarketplaceSkillSummary { @@ -35,10 +37,16 @@ pub struct SearchQuery { #[derive(Debug, Deserialize)] pub struct InstallSkillRequest { - /// Skill name to install. Currently unused (stub), but validated by - /// deserialization to ensure the request shape is correct. - #[allow(dead_code)] + /// Skill name to install. + pub name: String, +} + +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +pub struct InstallSkillResponse { pub name: String, + pub version: String, + pub size_bytes: u64, + pub installed: bool, } #[derive(Debug, Deserialize)] @@ -55,29 +63,38 @@ pub struct UpdateSkillPermissionsResponse { pub restart_required: bool, } -pub async fn handle_search_skills(Query(params): Query) -> Json { - Json(SkillSearchResponse { - query: params.q, - skills: vec![], - total: 0, - marketplace_available: false, - message: MARKETPLACE_NOT_CONNECTED_MESSAGE.to_string(), - }) +pub async fn handle_search_skills( + State(state): State, + Query(params): Query, +) -> Json { + Json(search_skills_response(state.data_dir.clone(), params.q, search_marketplace).await) } pub async fn handle_install_skill( - Json(_request): Json, -) -> (StatusCode, Json) { - marketplace_unavailable() + State(state): State, + Json(request): Json, +) -> Result, (StatusCode, Json)> { + install_skill_response( + state.data_dir.clone(), + request.name, + install_marketplace_skill, + ) + .await + .map(Json) } -pub async fn handle_remove_skill(Path(name): Path) -> (StatusCode, Json) { - skill_not_found(name) +pub async fn handle_remove_skill( + State(state): State, + AxumPath(name): AxumPath, +) -> Result, (StatusCode, Json)> { + remove_skill_response(state.data_dir.clone(), name) + .await + .map(Json) } pub async fn handle_update_skill_permissions( State(state): 
State, - Path(name): Path, + AxumPath(name): AxumPath, Json(request): Json, ) -> Result, (StatusCode, Json)> { let capabilities = @@ -92,15 +109,253 @@ pub async fn handle_update_skill_permissions( })) } -fn marketplace_unavailable() -> (StatusCode, Json) { +async fn search_skills_response( + data_dir: PathBuf, + query: String, + search_fn: F, +) -> SkillSearchResponse +where + F: FnOnce(&Path, &str) -> Result, MarketplaceError> + Send + 'static, +{ + let query_for_error = query.clone(); + match tokio::task::spawn_blocking(move || { + let entries = search_fn(&data_dir, &query)?; + Ok::(build_search_response(query, entries)) + }) + .await + { + Ok(Ok(response)) => response, + Ok(Err(error)) => { + tracing::error!(error = %error, "Marketplace search failed"); + unavailable_search_response(query_for_error, error.to_string()) + } + Err(error) => { + tracing::error!(error = %error, "Marketplace search task failed"); + unavailable_search_response(query_for_error, error.to_string()) + } + } +} + +fn search_marketplace(data_dir: &Path, query: &str) -> Result, MarketplaceError> { + let config = fx_marketplace::default_config(data_dir)?; + fx_marketplace::search(&config, query) +} + +fn build_search_response(query: String, entries: Vec) -> SkillSearchResponse { + let skills: Vec<_> = entries.into_iter().map(map_skill_entry).collect(); + SkillSearchResponse { + query, + total: skills.len(), + skills, + marketplace_available: true, + message: String::new(), + } +} + +fn map_skill_entry(entry: SkillEntry) -> MarketplaceSkillSummary { + MarketplaceSkillSummary { + title: title_case_skill_name(&entry.name), + name: entry.name, + description: entry.description, + publisher: entry.author, + signed: true, + } +} + +fn title_case_skill_name(name: &str) -> String { + let mut chars = name.chars(); + let Some(first) = chars.next() else { + return String::new(); + }; + + let mut title = String::new(); + title.extend(first.to_uppercase()); + title.push_str(&chars.as_str().to_lowercase()); 
+ title +} + +fn unavailable_search_response(query: String, message: String) -> SkillSearchResponse { + let message = if message.is_empty() { + MARKETPLACE_NOT_CONNECTED_MESSAGE.to_string() + } else { + message + }; + + SkillSearchResponse { + query, + skills: vec![], + total: 0, + marketplace_available: false, + message, + } +} + +async fn install_skill_response( + data_dir: PathBuf, + name: String, + install_fn: F, +) -> Result)> +where + F: FnOnce(&Path, &str) -> Result + Send + 'static, +{ + match tokio::task::spawn_blocking(move || { + let result = install_fn(&data_dir, &name)?; + Ok::(InstallSkillResponse::from(result)) + }) + .await + { + Ok(Ok(response)) => Ok(response), + Ok(Err(error)) => { + tracing::error!(error = %error, "Marketplace install failed"); + Err(marketplace_error(error)) + } + Err(error) => { + tracing::error!(error = %error, "Marketplace install task failed"); + Err(internal_error(error.to_string())) + } + } +} + +fn install_marketplace_skill( + data_dir: &Path, + name: &str, +) -> Result { + let config = fx_marketplace::default_config(data_dir)?; + fx_marketplace::install(&config, name) +} + +impl From for InstallSkillResponse { + fn from(result: InstallResult) -> Self { + Self { + name: result.name, + version: result.version, + size_bytes: result.size_bytes, + installed: true, + } + } +} + +async fn remove_skill_response( + data_dir: PathBuf, + name: String, +) -> Result)> { + match tokio::task::spawn_blocking(move || remove_skill_directory(&data_dir, &name)).await { + Ok(Ok(response)) => Ok(response), + Ok(Err(error)) => { + let Json(body) = &error.1; + tracing::error!( + status = %error.0, + error = %body.error, + "Marketplace remove failed" + ); + Err(error) + } + Err(error) => { + tracing::error!(error = %error, "Marketplace remove task failed"); + Err(internal_error(error.to_string())) + } + } +} + +fn remove_skill_directory( + data_dir: &Path, + name: &str, +) -> Result)> { + fx_marketplace::validate_skill_name(name) + 
.map_err(|error| validation_error(error.to_string()))?; + + let skills_dir = data_dir.join("skills"); + let skill_dir = skills_dir.join(name); + ensure_skill_exists(&skill_dir, name)?; + ensure_skill_dir_within_skills_dir(&skills_dir, &skill_dir)?; + std::fs::remove_dir_all(&skill_dir) + .map_err(|error| internal_error(format!("failed to remove skill '{name}': {error}")))?; + + Ok(json!({ "removed": true, "name": name })) +} + +fn ensure_skill_exists(skill_dir: &Path, name: &str) -> Result<(), (StatusCode, Json)> { + if skill_dir + .try_exists() + .map_err(|error| internal_error(format!("failed to access skill '{name}': {error}")))? + { + return Ok(()); + } + Err(skill_not_found(name.to_string())) +} + +fn ensure_skill_dir_within_skills_dir( + skills_dir: &Path, + skill_dir: &Path, +) -> Result<(), (StatusCode, Json)> { + let canonical_skill_dir = std::fs::canonicalize(skill_dir).map_err(|error| { + tracing::error!( + error = %error, + skill_dir = %skill_dir.display(), + "Failed to resolve skill directory" + ); + invalid_skill_directory() + })?; + let canonical_skills_dir = std::fs::canonicalize(skills_dir).map_err(|error| { + tracing::error!( + error = %error, + skills_dir = %skills_dir.display(), + "Failed to resolve skills directory" + ); + invalid_skill_directory() + })?; + if canonical_skill_dir.starts_with(&canonical_skills_dir) { + return Ok(()); + } + + tracing::error!( + skill_dir = %skill_dir.display(), + skills_dir = %skills_dir.display(), + canonical_skill_dir = %canonical_skill_dir.display(), + canonical_skills_dir = %canonical_skills_dir.display(), + "Skill directory outside allowed path" + ); + Err(skill_directory_outside_allowed_path()) +} + +fn marketplace_error(error: MarketplaceError) -> (StatusCode, Json) { + let status = match &error { + MarketplaceError::SkillNotFound(_) => StatusCode::NOT_FOUND, + MarketplaceError::SignatureInvalid(_) | MarketplaceError::ManifestInvalid(_) => { + StatusCode::UNPROCESSABLE_ENTITY + } + 
MarketplaceError::InvalidIndex(_) | MarketplaceError::NetworkError(_) => { + StatusCode::BAD_GATEWAY + } + MarketplaceError::InstallError(_) | MarketplaceError::InsecureRegistry(_) => { + StatusCode::INTERNAL_SERVER_ERROR + } + }; + ( - StatusCode::SERVICE_UNAVAILABLE, + status, Json(ErrorBody { - error: MARKETPLACE_UNAVAILABLE_MESSAGE.to_string(), + error: error.to_string(), }), ) } +fn validation_error(error: String) -> (StatusCode, Json) { + (StatusCode::BAD_REQUEST, Json(ErrorBody { error })) +} + +fn internal_error(error: String) -> (StatusCode, Json) { + (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorBody { error })) +} + +fn invalid_skill_directory() -> (StatusCode, Json) { + internal_error("Invalid skill directory".to_string()) +} + +fn skill_directory_outside_allowed_path() -> (StatusCode, Json) { + internal_error("Skill directory outside allowed path".to_string()) +} + fn skill_not_found(name: String) -> (StatusCode, Json) { ( StatusCode::NOT_FOUND, @@ -128,6 +383,35 @@ fn skill_manifest_error(error: SkillManifestError) -> (StatusCode, Json SkillEntry { + SkillEntry { + name: name.to_string(), + version: "1.0.0".to_string(), + description: format!("{name} description"), + author: "Fawx Team".to_string(), + capabilities: vec!["network".to_string()], + size_bytes: Some(1024), + } + } + + fn sample_install_result() -> InstallResult { + InstallResult { + name: "weather".to_string(), + version: "1.2.3".to_string(), + size_bytes: 4096, + install_path: PathBuf::from("/tmp/fawx/skills/weather"), + } + } #[test] fn search_response_serializes() { @@ -170,35 +454,190 @@ mod tests { assert_eq!(request.name, "portfolio-tracker"); } + #[test] + fn install_response_serializes_without_install_path() { + let json = + serde_json::to_value(InstallSkillResponse::from(sample_install_result())).unwrap(); + + assert_eq!(json["name"], "weather"); + assert_eq!(json["installed"], true); + assert_eq!(json.get("install_path"), None); + } + #[tokio::test] - async fn 
search_handler_returns_empty_results() { - let params = SearchQuery { - q: "portfolio".into(), - }; - let response = handle_search_skills(Query(params)).await; + async fn search_response_maps_marketplace_entries() { + let response = search_skills_response(PathBuf::new(), String::new(), |_, query| { + assert!(query.is_empty()); + Ok(vec![ + sample_skill_entry("weather"), + sample_skill_entry("web-fetch"), + ]) + }) + .await; + + assert!(response.marketplace_available); + assert_eq!(response.total, 2); + assert_eq!(response.skills[0].title, "Weather"); + assert_eq!(response.skills[1].title, "Web-fetch"); + assert!(response.skills.iter().all(|skill| skill.signed)); + assert!(response.message.is_empty()); + } - assert_eq!(response.0.query, "portfolio"); - assert!(response.0.skills.is_empty()); - assert!(!response.0.marketplace_available); + #[tokio::test] + async fn search_response_returns_error_message_when_marketplace_fails() { + let response = search_skills_response(PathBuf::new(), "weather".into(), |_, _| { + Err(MarketplaceError::NetworkError( + "registry unavailable".into(), + )) + }) + .await; + + assert_eq!(response.query, "weather"); + assert!(response.skills.is_empty()); + assert!(!response.marketplace_available); + assert_eq!(response.message, "network error: registry unavailable"); } #[tokio::test] - async fn install_handler_returns_service_unavailable() { - let request = InstallSkillRequest { - name: "portfolio-tracker".into(), - }; - let (status, body) = handle_install_skill(Json(request)).await; + async fn search_response_handles_blocking_task_panics() { + let response = search_skills_response(PathBuf::new(), "weather".into(), |_, _| { + panic!("boom from search") + }) + .await; + + assert_eq!(response.query, "weather"); + assert!(!response.marketplace_available); + assert!(response.message.contains("panic") || !response.message.is_empty()); + } - assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE); - assert_eq!(body.0.error, 
MARKETPLACE_UNAVAILABLE_MESSAGE); + #[tokio::test] + async fn install_response_maps_success() { + let response = install_skill_response(PathBuf::new(), "weather".into(), |_, name| { + assert_eq!(name, "weather"); + Ok(sample_install_result()) + }) + .await + .expect("install should succeed"); + + assert_eq!(response.name, "weather"); + assert_eq!(response.version, "1.2.3"); + assert_eq!(response.size_bytes, 4096); + assert!(response.installed); } #[tokio::test] - async fn remove_handler_returns_not_found() { - let (status, body) = handle_remove_skill(Path(String::from("portfolio-tracker"))).await; + async fn install_response_returns_status_when_marketplace_fails() { + let error = install_skill_response(PathBuf::new(), "weather".into(), |_, _| { + Err(MarketplaceError::InstallError("disk full".into())) + }) + .await + .expect_err("install should fail"); + + let Json(body) = &error.1; - assert_eq!(status, StatusCode::NOT_FOUND); - assert_eq!(body.0.error, "Skill 'portfolio-tracker' not found"); + assert_eq!(error.0, StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(body.error, "install error: disk full"); + } + + #[test] + fn marketplace_error_maps_all_variants() { + let cases = vec![ + ( + MarketplaceError::SkillNotFound("missing".into()), + StatusCode::NOT_FOUND, + ), + ( + MarketplaceError::SignatureInvalid("bad sig".into()), + StatusCode::UNPROCESSABLE_ENTITY, + ), + ( + MarketplaceError::ManifestInvalid("bad manifest".into()), + StatusCode::UNPROCESSABLE_ENTITY, + ), + ( + MarketplaceError::InvalidIndex("bad index".into()), + StatusCode::BAD_GATEWAY, + ), + ( + MarketplaceError::NetworkError("offline".into()), + StatusCode::BAD_GATEWAY, + ), + ( + MarketplaceError::InstallError("disk full".into()), + StatusCode::INTERNAL_SERVER_ERROR, + ), + ( + MarketplaceError::InsecureRegistry("http://example.com".into()), + StatusCode::INTERNAL_SERVER_ERROR, + ), + ]; + + for (error, expected_status) in cases { + let (status, body) = marketplace_error(error); + 
assert_eq!(status, expected_status); + assert!(!body.0.error.is_empty()); + } + } + + #[tokio::test] + async fn remove_response_rejects_invalid_names() { + let temp = TempDir::new().expect("tempdir"); + let error = remove_skill_response(temp.path().to_path_buf(), "../escape".into()) + .await + .expect_err("invalid name should fail"); + + assert_eq!(error.0, StatusCode::BAD_REQUEST); + assert!(error.1 .0.error.contains("forbidden characters")); + } + + #[tokio::test] + async fn remove_response_returns_not_found_before_canonicalize() { + let temp = TempDir::new().expect("tempdir"); + let error = remove_skill_response(temp.path().to_path_buf(), "weather".into()) + .await + .expect_err("missing skill should fail"); + + assert_eq!(error.0, StatusCode::NOT_FOUND); + assert_eq!(error.1 .0.error, "Skill 'weather' not found"); + } + + #[cfg(unix)] + #[tokio::test] + async fn remove_response_rejects_symlink_escape_without_leaking_paths() { + let temp = TempDir::new().expect("tempdir"); + let outside_dir = temp.path().join("outside-weather"); + let skills_dir = temp.path().join("skills"); + fs::create_dir_all(&outside_dir).expect("mkdir outside"); + fs::create_dir_all(&skills_dir).expect("mkdir skills"); + create_dir_symlink(&outside_dir, &skills_dir.join("weather")); + + let error = remove_skill_response(temp.path().to_path_buf(), "weather".into()) + .await + .expect_err("symlink escape should fail"); + let Json(body) = &error.1; + let client_error = body.error.as_str(); + + assert_eq!(error.0, StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(client_error, "Skill directory outside allowed path"); + assert!(!client_error.contains(outside_dir.to_string_lossy().as_ref())); + assert!(!client_error.contains(skills_dir.to_string_lossy().as_ref())); + assert!(outside_dir.exists()); + } + + #[tokio::test] + async fn remove_response_deletes_existing_skill_directory() { + let temp = TempDir::new().expect("tempdir"); + let skill_dir = temp.path().join("skills").join("weather"); + 
fs::create_dir_all(&skill_dir).expect("mkdir skill"); + fs::write(skill_dir.join("manifest.toml"), "name = \"weather\"").expect("manifest"); + + let response = remove_skill_response(temp.path().to_path_buf(), "weather".into()) + .await + .expect("remove should succeed"); + + assert_eq!(response["removed"], true); + assert_eq!(response["name"], "weather"); + assert!(!skill_dir.exists()); } #[test] diff --git a/engine/crates/fx-api/src/handlers/pairing.rs b/engine/crates/fx-api/src/handlers/pairing.rs index 250beb83..b7e21bca 100644 --- a/engine/crates/fx-api/src/handlers/pairing.rs +++ b/engine/crates/fx-api/src/handlers/pairing.rs @@ -396,16 +396,14 @@ mod phase4_tests { let response = qr_pairing_response( &test_runtime(false), &QrTailscaleStatus { - hostname: Some("joes-mac.tail1234.ts.net".to_string()), + hostname: Some("node.example.ts.net".to_string()), cert_ready: true, }, ); - assert_eq!(response.display_host, "joes-mac.tail1234.ts.net"); + assert_eq!(response.display_host, "node.example.ts.net"); assert_eq!(response.transport, "tailscale_https"); assert!(!response.same_network_only); - assert!(response - .scheme_url - .contains("host=joes-mac.tail1234.ts.net")); + assert!(response.scheme_url.contains("host=node.example.ts.net")); } #[test] diff --git a/engine/crates/fx-api/src/handlers/phase4.rs b/engine/crates/fx-api/src/handlers/phase4.rs index 937652c9..87dcfb1d 100644 --- a/engine/crates/fx-api/src/handlers/phase4.rs +++ b/engine/crates/fx-api/src/handlers/phase4.rs @@ -256,7 +256,7 @@ mod tests { installed: true, running: true, logged_in: true, - hostname: Some("joes-mac.tail1234.ts.net".to_string()), + hostname: Some("node.example.ts.net".to_string()), cert_ready: true, }, }; @@ -268,7 +268,7 @@ mod tests { assert_eq!(json["launchagent"]["loaded"], true); assert_eq!(json["local_server"]["port"], 8400); assert_eq!(json["auth"]["providers_configured"][0], "anthropic"); - assert_eq!(json["tailscale"]["hostname"], "joes-mac.tail1234.ts.net"); + 
assert_eq!(json["tailscale"]["hostname"], "node.example.ts.net"); } #[test] diff --git a/engine/crates/fx-api/src/handlers/sessions.rs b/engine/crates/fx-api/src/handlers/sessions.rs index 324ef3c1..470df757 100644 --- a/engine/crates/fx-api/src/handlers/sessions.rs +++ b/engine/crates/fx-api/src/handlers/sessions.rs @@ -27,6 +27,7 @@ use fx_session::{ SessionMessage, SessionRegistry, SessionStatus, }; use serde::{Deserialize, Serialize}; +use std::borrow::Cow; use std::collections::HashSet; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -34,8 +35,16 @@ use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::mpsc; use uuid::Uuid; -const SESSION_MEMORY_MAX_ITEMS: usize = 20; -const SESSION_MEMORY_MAX_TOKENS: usize = 2_000; +// Coarse API validation gate. Session-level dynamic caps enforce the real limit. +const SESSION_MEMORY_MAX_ITEMS: usize = 80; +const SESSION_MEMORY_MAX_TOKENS: usize = 8_000; + +struct TurnInput<'a> { + message: Cow<'a, str>, + images: Cow<'a, [EncodedImage]>, + documents: Cow<'a, [EncodedDocument]>, + context: Vec, +} #[derive(Debug, Deserialize)] pub struct CreateSessionRequest { @@ -87,10 +96,7 @@ struct StreamingSessionMessageTask { state: HttpState, registry: SessionRegistry, key: SessionKey, - message: String, - images: Vec, - documents: Vec, - context: Vec, + input: TurnInput<'static>, sender: mpsc::Sender, disconnected: Arc, } @@ -313,10 +319,12 @@ pub(crate) async fn handle_send_message_for_session( state, registry, key, - request.message, - images, - documents, - context, + TurnInput { + message: Cow::Owned(request.message), + images: Cow::Owned(images), + documents: Cow::Owned(documents), + context, + }, ) .await); } @@ -325,10 +333,12 @@ pub(crate) async fn handle_send_message_for_session( &state, ®istry, &key, - &request.message, - &images, - &documents, - context, + TurnInput { + message: Cow::Borrowed(request.message.as_str()), + images: Cow::Borrowed(&images), + documents: Cow::Borrowed(&documents), + context, + }, ) 
.await .map_err(internal_error)?; @@ -397,10 +407,7 @@ async fn stream_session_message_response( state: HttpState, registry: SessionRegistry, key: SessionKey, - message: String, - images: Vec, - documents: Vec, - context: Vec, + input: TurnInput<'static>, ) -> Response { let (sender, receiver) = mpsc::channel(SSE_CHANNEL_CAPACITY); let disconnected = Arc::new(AtomicBool::new(false)); @@ -409,10 +416,7 @@ async fn stream_session_message_response( state, registry, key, - message, - images, - documents, - context, + input, sender, disconnected, }, @@ -421,35 +425,33 @@ async fn stream_session_message_response( } async fn run_streaming_session_message_task(task: StreamingSessionMessageTask) { - let callback = stream_callback(task.sender.clone(), Arc::clone(&task.disconnected)); - let result = execute_session_turn( - &task.state, - &task.registry, - &task.key, - &task.message, - &task.images, - &task.documents, - task.context, - Some(callback), - ) - .await; + let StreamingSessionMessageTask { + state, + registry, + key, + input, + sender, + disconnected, + } = task; + let callback = stream_callback(sender.clone(), Arc::clone(&disconnected)); + let result = execute_session_turn(&state, ®istry, &key, input, Some(callback)).await; match result { Ok((_result, session_messages, session_memory)) => { if let Err(error) = - persist_session_turn(&task.registry, &task.key, session_messages, session_memory) + persist_session_turn(®istry, &key, session_messages, session_memory) { let _ = send_sse_frame( - &task.sender, - &task.disconnected, + &sender, + &disconnected, error_stream_frame(&error.to_string()), ); } } Err(error) => { let _ = send_sse_frame( - &task.sender, - &task.disconnected, + &sender, + &disconnected, error_stream_frame(&error.to_string()), ); } @@ -460,15 +462,10 @@ async fn process_and_route_session_message( state: &HttpState, registry: &SessionRegistry, key: &SessionKey, - message: &str, - images: &[EncodedImage], - documents: &[EncodedDocument], - context: 
Vec, + input: TurnInput<'_>, ) -> Result<(CycleResult, String, Vec, SessionMemory), anyhow::Error> { - let (result, session_messages, session_memory) = execute_session_turn( - state, registry, key, message, images, documents, context, None, - ) - .await?; + let (result, session_messages, session_memory) = + execute_session_turn(state, registry, key, input, None).await?; state .channels @@ -487,15 +484,11 @@ async fn process_and_route_session_message( Ok((result, response, session_messages, session_memory)) } -#[allow(clippy::too_many_arguments)] async fn execute_session_turn( state: &HttpState, registry: &SessionRegistry, key: &SessionKey, - message: &str, - images: &[EncodedImage], - documents: &[EncodedDocument], - context: Vec, + input: TurnInput<'_>, callback: Option, ) -> Result<(CycleResult, Vec, SessionMemory), anyhow::Error> { let loaded_memory = registry.memory(key).map_err(anyhow::Error::new)?; @@ -503,10 +496,10 @@ async fn execute_session_turn( let previous_memory = app.replace_session_memory(loaded_memory); let outcome = app .process_message_with_context( - message, - encoded_images_to_attachments(images), - encoded_documents_to_attachments(documents), - context, + input.message.as_ref(), + encoded_images_to_attachments(input.images.as_ref()), + encoded_documents_to_attachments(input.documents.as_ref()), + input.context, InputSource::Http, callback, ) @@ -754,14 +747,38 @@ mod tests { })); } + #[test] + fn validate_session_memory_accepts_maximum_dynamic_item_cap() { + let mut memory = SessionMemory::default(); + memory.active_files = (0..SESSION_MEMORY_MAX_ITEMS) + .map(|index| format!("file-{index}.rs")) + .collect(); + + let validated = validate_session_memory(memory).expect("validation should pass"); + + assert_eq!(validated.active_files.len(), SESSION_MEMORY_MAX_ITEMS); + } + + #[test] + fn validate_session_memory_accepts_maximum_dynamic_token_cap() { + let mut memory = SessionMemory::default(); + memory.project = Some("a 
".repeat(7_900).trim_end().to_string()); + + let estimated_tokens = memory.estimated_tokens(); + assert!(estimated_tokens > 4_000); + assert!(estimated_tokens <= SESSION_MEMORY_MAX_TOKENS); + + let validated = validate_session_memory(memory).expect("validation should pass"); + + assert_eq!(validated.estimated_tokens(), estimated_tokens); + } + #[test] fn validate_session_memory_rejects_too_many_active_files() { - let memory = SessionMemory { - active_files: (0..=SESSION_MEMORY_MAX_ITEMS) - .map(|index| format!("file-{index}.rs")) - .collect(), - ..SessionMemory::default() - }; + let mut memory = SessionMemory::default(); + memory.active_files = (0..=SESSION_MEMORY_MAX_ITEMS) + .map(|index| format!("file-{index}.rs")) + .collect(); let error = validate_session_memory(memory).expect_err("validation should fail"); diff --git a/engine/crates/fx-api/src/sse.rs b/engine/crates/fx-api/src/sse.rs index f822e936..1d24e21f 100644 --- a/engine/crates/fx-api/src/sse.rs +++ b/engine/crates/fx-api/src/sse.rs @@ -67,6 +67,13 @@ pub fn serialize_stream_event(event: StreamEvent) -> Option { "is_error": is_error, }), ), + StreamEvent::ToolError { tool_name, error } => sse_frame( + "tool_error", + serde_json::json!({ + "tool_name": tool_name, + "error": error, + }), + ), StreamEvent::PermissionPrompt(prompt) => sse_frame( "permission_prompt", serde_json::json!({ @@ -288,6 +295,20 @@ mod tests { ); } + #[test] + fn tool_error_event_serializes() { + let frame = serialize_stream_event(StreamEvent::ToolError { + tool_name: "read_file".to_string(), + error: "permission denied".to_string(), + }) + .expect("tool error frame"); + + assert_eq!( + frame, + "event: tool_error\ndata: {\"error\":\"permission denied\",\"tool_name\":\"read_file\"}\n\n" + ); + } + #[test] fn context_compacted_event_serializes() { let frame = serialize_stream_event(StreamEvent::ContextCompacted { diff --git a/engine/crates/fx-api/src/tailscale.rs b/engine/crates/fx-api/src/tailscale.rs index f254a75a..dd1f143c 100644 
--- a/engine/crates/fx-api/src/tailscale.rs +++ b/engine/crates/fx-api/src/tailscale.rs @@ -108,11 +108,11 @@ mod tests { #[test] fn parse_tailscale_cli_output_returns_cgnat_ip() { - let stdout = b"100.100.1.1\n"; + let stdout = b"100.64.0.42\n"; assert_eq!( parse_tailscale_cli_output(stdout), - Some(IpAddr::V4(Ipv4Addr::new(100, 100, 1, 1))) + Some(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 42))) ); } @@ -132,17 +132,17 @@ mod tests { #[test] fn parse_macos_ifconfig_line_extracts_cgnat_ip() { - let line = "inet 100.100.2.1 --> 100.100.2.1 netmask 0xffffffff"; + let line = "inet 100.64.0.43 --> 100.64.0.43 netmask 0xffffffff"; assert_eq!( extract_ip_from_line(line), - Some(IpAddr::V4(Ipv4Addr::new(100, 100, 2, 1))) + Some(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 43))) ); } #[test] fn parse_macos_ifconfig_line_without_inet_prefix_returns_none() { - let line = "10.0.0.5 --> 10.0.0.5 netmask 0xffffffff"; + let line = "100.64.0.43 --> 100.64.0.43 netmask 0xffffffff"; assert_eq!(extract_ip_from_line(line), None); } @@ -157,11 +157,11 @@ mod tests { #[test] fn linux_ip_output_still_parsed_correctly() { - let text = "7: tailscale0 inet 100.100.1.1/32 brd 100.100.1.1 scope global tailscale0"; + let text = "7: tailscale0 inet 100.64.0.42/32 brd 100.64.0.42 scope global tailscale0"; assert_eq!( find_cgnat_ip(text), - Some(IpAddr::V4(Ipv4Addr::new(100, 100, 1, 1))) + Some(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 42))) ); } } diff --git a/engine/crates/fx-api/src/tests.rs b/engine/crates/fx-api/src/tests.rs index c75c060a..81233894 100644 --- a/engine/crates/fx-api/src/tests.rs +++ b/engine/crates/fx-api/src/tests.rs @@ -380,7 +380,7 @@ async fn mock_status() -> Json { model: "test-model".to_string(), skills: vec!["skill-a".to_string()], memory_entries: 10, - tailscale_ip: Some("100.64.0.1".to_string()), + tailscale_ip: Some("100.64.0.30".to_string()), config: None, }) } @@ -425,7 +425,7 @@ fn tailscale_ip_accepts_valid_range() { Ipv4Addr::new(100, 127, 255, 255) ))); 
assert!(crate::tailscale::is_tailscale_ip(&IpAddr::V4( - Ipv4Addr::new(100, 100, 1, 1) + Ipv4Addr::new(100, 64, 0, 42) ))); } @@ -450,14 +450,14 @@ fn tailscale_ip_rejects_ipv6() { #[test] fn listen_targets_bind_localhost_and_tailscale() { - let plan = listen_targets(8400, Some(IpAddr::V4(Ipv4Addr::new(100, 100, 1, 1)))); + let plan = listen_targets(8400, Some(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 42)))); let tailscale = plan.tailscale.expect("tailscale target"); assert_eq!(plan.local.addr, SocketAddr::from(([127, 0, 0, 1], 8400))); assert_eq!(plan.local.label, "local"); assert_eq!( tailscale.addr, - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(100, 100, 1, 1)), 8400) + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 42)), 8400) ); assert_eq!(tailscale.label, "Tailscale"); } @@ -515,7 +515,7 @@ fn startup_target_lines_use_https_for_tailscale_when_enabled() { label: "local", }, Some(ListenTarget { - addr: SocketAddr::from(([100, 100, 1, 1], 8400)), + addr: SocketAddr::from(([100, 64, 0, 42], 8400)), label: "Tailscale", }), true, @@ -523,7 +523,7 @@ fn startup_target_lines_use_https_for_tailscale_when_enabled() { assert_eq!(lines[0], "Fawx API listening on:"); assert_eq!(lines[1], " http://127.0.0.1:8400 (local)"); - assert_eq!(lines[2], " https://100.100.1.1:8400 (Tailscale)"); + assert_eq!(lines[2], " https://100.64.0.42:8400 (Tailscale)"); } #[test] @@ -534,14 +534,14 @@ fn startup_target_lines_use_http_for_tailscale_when_tls_disabled() { label: "local", }, Some(ListenTarget { - addr: SocketAddr::from(([100, 100, 1, 1], 8400)), + addr: SocketAddr::from(([100, 64, 0, 42], 8400)), label: "Tailscale", }), false, ); assert_eq!(lines[0], "Fawx HTTP API listening on:"); - assert_eq!(lines[2], " http://100.100.1.1:8400 (Tailscale)"); + assert_eq!(lines[2], " http://100.64.0.42:8400 (Tailscale)"); } #[tokio::test] @@ -555,7 +555,7 @@ async fn tailscale_bind_failure_falls_back_to_localhost_server() { .expect("bind localhost"); let local_addr = 
local_listener.local_addr().expect("local addr"); let tailscale_target = ListenTarget { - addr: SocketAddr::from(([100, 100, 1, 1], 8400)), + addr: SocketAddr::from(([100, 64, 0, 42], 8400)), label: "Tailscale", }; let listeners = BoundListeners { @@ -615,9 +615,9 @@ async fn wait_for_server_pair_shuts_down_peer_when_one_server_exits() { #[test] fn extract_ip_parses_ip_addr_output() { - let line = "4: tailscale0 inet 100.100.1.1/32 scope global tailscale0"; + let line = "4: tailscale0 inet 100.64.0.42/32 scope global tailscale0"; let ip = crate::tailscale::extract_ip_from_line(line); - assert_eq!(ip, Some(IpAddr::V4(Ipv4Addr::new(100, 100, 1, 1)))); + assert_eq!(ip, Some(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 42)))); } #[test] @@ -715,13 +715,13 @@ fn status_response_has_expected_fields() { model: "claude-3".to_string(), skills: vec!["read_file".to_string()], memory_entries: 42, - tailscale_ip: Some("100.100.1.1".to_string()), + tailscale_ip: Some("100.64.0.20".to_string()), config: None, }; let json: serde_json::Value = serde_json::from_str(&serde_json::to_string(&response).expect("serialize")).expect("parse"); assert_eq!(json["status"], "ok"); - assert_eq!(json["tailscale_ip"], "100.100.1.1"); + assert_eq!(json["tailscale_ip"], "100.64.0.20"); assert_eq!(json["memory_entries"], 42); assert!(json["skills"].is_array()); } @@ -805,7 +805,7 @@ async fn status_endpoint_returns_ok() { let body = resp.into_body().collect().await.expect("body").to_bytes(); let json: serde_json::Value = serde_json::from_slice(&body).expect("json"); assert_eq!(json["status"], "ok"); - assert_eq!(json["tailscale_ip"], "100.64.0.1"); + assert_eq!(json["tailscale_ip"], "100.64.0.30"); assert!(json["skills"].is_array()); } @@ -1052,6 +1052,19 @@ fn serialize_stream_event_serializes_error_event_payload() { assert!(frame.contains("\"recoverable\":true")); } +#[test] +fn serialize_stream_event_serializes_tool_error_payload() { + let frame = serialize_stream_event(StreamEvent::ToolError { + 
tool_name: "read_file".to_string(), + error: "permission denied".to_string(), + }) + .expect("tool error frame"); + + assert!(frame.contains("event: tool_error")); + assert!(frame.contains("\"tool_name\":\"read_file\"")); + assert!(frame.contains("\"error\":\"permission denied\"")); +} + #[test] fn send_sse_frame_stops_when_receiver_is_closed() { let (sender, receiver) = mpsc::channel(1); @@ -1963,8 +1976,8 @@ mod routing_and_status { #[tokio::test] async fn get_devices_returns_device_list() { let mut devices = DeviceStore::new(); - let (_, first) = devices.create_device("My MacBook"); - let (_, second) = devices.create_device("My iPhone"); + let (_, first) = devices.create_device("Example MacBook"); + let (_, second) = devices.create_device("Example iPhone"); let app = build_router(test_state_with_devices(devices), None); let response = app @@ -1983,7 +1996,7 @@ mod routing_and_status { #[tokio::test] async fn get_devices_excludes_token_hash() { let mut devices = DeviceStore::new(); - let _ = devices.create_device("My MacBook"); + let _ = devices.create_device("Example MacBook"); let app = build_router(test_state_with_devices(devices), None); let response = app @@ -1998,7 +2011,7 @@ mod routing_and_status { #[tokio::test] async fn delete_device_revokes_token() { let mut devices = DeviceStore::new(); - let (raw_token, device) = devices.create_device("My MacBook"); + let (raw_token, device) = devices.create_device("Example MacBook"); let app = build_router(test_state_with_devices(devices), None); let before_delete = Request::builder() @@ -2430,7 +2443,7 @@ allowed_chat_ids = [123] let temp = TempDir::new().expect("tempdir"); let mut manager = FleetManager::init(temp.path()).expect("fleet init"); let token = manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("node should add"); let app = build_router( test_state(None, Vec::new()), @@ -2442,7 +2455,7 @@ allowed_chat_ids = [123] .header("content-type", 
"application/json") .body(Body::from( serde_json::to_vec(&fx_fleet::FleetRegistrationRequest { - node_name: "node-alpha".to_string(), + node_name: "node-a".to_string(), bearer_token: token.secret, capabilities: vec!["agentic_loop".to_string()], rust_version: None, @@ -3033,14 +3046,13 @@ allowed_chat_ids = [123] async fn get_session_memory_returns_stored_memory() { let registry = make_session_registry(); let key = seed_session(®istry, "sess-memory-get"); - let memory = SessionMemory { - project: Some("Phase 6".to_string()), - current_state: Some("Reviewing compaction UX".to_string()), - key_decisions: vec!["Use a subtle banner".to_string()], - active_files: vec!["app/Fawx/ViewModels/ChatViewModel.swift".to_string()], - custom_context: vec!["Keep session memory user-editable".to_string()], - last_updated: 1_742_000_000, - }; + let mut memory = SessionMemory::default(); + memory.project = Some("Phase 6".to_string()); + memory.current_state = Some("Reviewing compaction UX".to_string()); + memory.key_decisions = vec!["Use a subtle banner".to_string()]; + memory.active_files = vec!["app/Fawx/ViewModels/ChatViewModel.swift".to_string()]; + memory.custom_context = vec!["Keep session memory user-editable".to_string()]; + memory.last_updated = 1_742_000_000; registry .record_turn(&key, Vec::new(), memory.clone()) .expect("seed memory"); @@ -3084,10 +3096,8 @@ allowed_chat_ids = [123] async fn put_session_memory_persists_and_updates_loaded_session_memory() { let registry = make_session_registry(); let key = seed_session(®istry, "sess-memory-put"); - let initial_loaded_memory = SessionMemory { - project: Some("Old loaded memory".to_string()), - ..SessionMemory::default() - }; + let mut initial_loaded_memory = SessionMemory::default(); + initial_loaded_memory.project = Some("Old loaded memory".to_string()); let (app, app_state) = session_memory_test_router(registry.clone(), initial_loaded_memory, Some(key.clone())); @@ -3141,16 +3151,14 @@ allowed_chat_ids = [123] async fn 
put_session_memory_rejects_payloads_that_exceed_token_cap() { let registry = make_session_registry(); let key = seed_session(®istry, "sess-memory-too-large"); - let seeded_memory = SessionMemory { - project: Some("Existing memory".to_string()), - ..SessionMemory::default() - }; + let mut seeded_memory = SessionMemory::default(); + seeded_memory.project = Some("Existing memory".to_string()); registry .record_turn(&key, Vec::new(), seeded_memory.clone()) .expect("seed memory"); let app = build_router(test_state_with_sessions(registry.clone()), None); - let oversized_project = "memory ".repeat(2_200); + let oversized_project = "a ".repeat(8_100); let request_body = serde_json::json!({ "project": oversized_project, "last_updated": 0 @@ -3179,14 +3187,10 @@ allowed_chat_ids = [123] async fn session_message_persists_updated_session_memory() { let registry = make_session_registry(); let key = seed_session(®istry, "sess-memory-persist"); - let seeded_memory = SessionMemory { - project: Some("persistent project".to_string()), - ..SessionMemory::default() - }; - let restored_memory = SessionMemory { - project: Some("shared app memory".to_string()), - ..SessionMemory::default() - }; + let mut seeded_memory = SessionMemory::default(); + seeded_memory.project = Some("persistent project".to_string()); + let mut restored_memory = SessionMemory::default(); + restored_memory.project = Some("shared app memory".to_string()); registry .record_turn(&key, Vec::new(), seeded_memory.clone()) .expect("seed memory"); @@ -3223,14 +3227,10 @@ allowed_chat_ids = [123] async fn session_message_stream_persists_updated_session_memory() { let registry = make_session_registry(); let key = seed_session(®istry, "sess-memory-stream-persist"); - let seeded_memory = SessionMemory { - project: Some("persistent project".to_string()), - ..SessionMemory::default() - }; - let restored_memory = SessionMemory { - project: Some("shared app memory".to_string()), - ..SessionMemory::default() - }; + let mut 
seeded_memory = SessionMemory::default(); + seeded_memory.project = Some("persistent project".to_string()); + let mut restored_memory = SessionMemory::default(); + restored_memory.project = Some("shared app memory".to_string()); registry .record_turn(&key, Vec::new(), seeded_memory.clone()) .expect("seed memory"); diff --git a/engine/crates/fx-channel-telegram/src/lib.rs b/engine/crates/fx-channel-telegram/src/lib.rs index 4d6b5ac9..33e39261 100644 --- a/engine/crates/fx-channel-telegram/src/lib.rs +++ b/engine/crates/fx-channel-telegram/src/lib.rs @@ -1043,7 +1043,7 @@ mod tests { "message": {{ "message_id": 42, "chat": {{ "id": {chat_id} }}, - "from": {{ "first_name": "Alice" }}, + "from": {{ "first_name": "Example" }}, "text": "{text}" }} }}"# @@ -1065,7 +1065,7 @@ mod tests { assert_eq!(result.chat_id, 12345); assert_eq!(result.text, "hello bot"); assert_eq!(result.message_id, 42); - assert_eq!(result.from_name.as_deref(), Some("Alice")); + assert_eq!(result.from_name.as_deref(), Some("Example")); } #[test] @@ -1164,7 +1164,7 @@ mod tests { "message": { "message_id": 44, "chat": { "id": 12345 }, - "from": { "first_name": "Alice" }, + "from": { "first_name": "Example" }, "photo": [ {"file_id": "thumb", "width": 90, "height": 90}, {"file_id": "medium", "width": 320, "height": 240}, diff --git a/engine/crates/fx-cli/Cargo.toml b/engine/crates/fx-cli/Cargo.toml index f999f446..800b0209 100644 --- a/engine/crates/fx-cli/Cargo.toml +++ b/engine/crates/fx-cli/Cargo.toml @@ -34,6 +34,7 @@ fx-auth.workspace = true fx-conversation.workspace = true fx-llm = { path = "../fx-llm" } fx-loadable = { path = "../fx-loadable" } +fx-python.workspace = true fx-memory.workspace = true fx-embeddings.workspace = true fx-analysis.workspace = true @@ -93,6 +94,7 @@ syn = { version = "2", features = ["full"] } fx-config = { workspace = true, features = ["test-support"] } fx-consensus = { workspace = true, features = ["test-support"] } fx-embeddings = { workspace = true, features = 
["test-support"] } +fx-loadable = { path = "../fx-loadable", features = ["test-support"] } fx-subagent = { workspace = true, features = ["test-support"] } tempfile = "3.14" tower = { version = "0.5", features = ["util"] } diff --git a/engine/crates/fx-cli/src/commands/devices.rs b/engine/crates/fx-cli/src/commands/devices.rs index 01455b47..f236adb2 100644 --- a/engine/crates/fx-cli/src/commands/devices.rs +++ b/engine/crates/fx-cli/src/commands/devices.rs @@ -277,7 +277,7 @@ mod tests { let response = DevicesResponse { devices: vec![DeviceInfo { id: "dev-a1b2c3".to_string(), - device_name: "My MacBook".to_string(), + device_name: "Example MacBook".to_string(), created_at: 1_773_400_000, last_used_at: 1_773_435_000, }], @@ -289,7 +289,7 @@ mod tests { .expect("device JSON should parse"); assert_eq!(json["devices"][0]["id"], "dev-a1b2c3"); - assert_eq!(json["devices"][0]["device_name"], "My MacBook"); + assert_eq!(json["devices"][0]["device_name"], "Example MacBook"); assert_eq!(json["devices"][0]["created_at"], 1_773_400_000); assert_eq!(json["devices"][0]["last_used_at"], 1_773_435_000); } @@ -299,7 +299,7 @@ mod tests { let response = DevicesResponse { devices: vec![DeviceInfo { id: "dev-a1b2c3".to_string(), - device_name: "My MacBook".to_string(), + device_name: "Example MacBook".to_string(), created_at: 1_700_000_000, last_used_at: 1_700_000_300, }], diff --git a/engine/crates/fx-cli/src/commands/fleet.rs b/engine/crates/fx-cli/src/commands/fleet.rs index deb25d05..6a6b70df 100644 --- a/engine/crates/fx-cli/src/commands/fleet.rs +++ b/engine/crates/fx-cli/src/commands/fleet.rs @@ -24,7 +24,7 @@ pub enum FleetCommands { Init, /// Add a worker node to the fleet Add { - /// Node name (e.g., "node-alpha") + /// Node name (e.g., "node-a") name: String, /// Tailscale IP address #[arg(long)] @@ -35,7 +35,7 @@ pub enum FleetCommands { }, /// Join a fleet as a worker node Join { - /// Primary node endpoint (e.g., 10.0.0.1:8400) + /// Primary node endpoint (e.g., 
203.0.113.20:8400) primary: String, /// Bearer token from `fawx fleet add` #[arg(long)] @@ -505,7 +505,7 @@ mod tests { #[test] fn parsed_hostname_trims_trailing_newline() { - assert_eq!(parsed_hostname(b"macmini\n"), Some("macmini".to_string())); + assert_eq!(parsed_hostname(b"node-a\n"), Some("node-a".to_string())); } #[test] @@ -549,8 +549,8 @@ mod tests { let mut output = Vec::new(); execute_fleet_command( &FleetCommands::Add { - name: "node-alpha".to_string(), - ip: "10.0.0.2".to_string(), + name: "node-a".to_string(), + ip: "203.0.113.10".to_string(), port: 8400, }, &fleet_dir, @@ -563,11 +563,11 @@ mod tests { let tokens = read_tokens(&fleet_dir); let token = tokens.first().expect("token should exist"); - assert!(output.contains("✓ Node \"node-alpha\" registered")); + assert!(output.contains("✓ Node \"node-a\" registered")); assert!(output.contains("✓ Token generated")); assert!(output.contains("Join command (run on the worker):")); assert!(output.contains(&format!( - "fawx fleet join 10.0.0.2:8400 --token {}", + "fawx fleet join 203.0.113.10:8400 --token {}", token.secret ))); } @@ -583,8 +583,8 @@ mod tests { let mut first_output = Vec::new(); execute_fleet_command( &FleetCommands::Add { - name: "node-alpha".to_string(), - ip: "10.0.0.2".to_string(), + name: "node-a".to_string(), + ip: "203.0.113.10".to_string(), port: 8400, }, &fleet_dir, @@ -595,8 +595,8 @@ mod tests { let result = execute_fleet_command( &FleetCommands::Add { - name: "node-alpha".to_string(), - ip: "10.0.0.3".to_string(), + name: "node-a".to_string(), + ip: "203.0.113.11".to_string(), port: 8400, }, &fleet_dir, @@ -612,7 +612,7 @@ mod tests { let mut server = TestRegisterServer::spawn(TestRegisterResponse { status: StatusCode::OK, body: FleetRegistrationResponse { - node_id: "macmini-a1b2c3".to_string(), + node_id: "node-a-a1b2c3".to_string(), accepted: true, message: "registered".to_string(), }, @@ -650,12 +650,12 @@ mod tests { .json .capabilities 
.contains(&"agentic_loop".to_string())); - assert_eq!(identity.node_id, "macmini-a1b2c3"); + assert_eq!(identity.node_id, "node-a-a1b2c3"); assert_eq!(identity.primary_endpoint, server.base_url); assert_eq!(identity.bearer_token, token); assert!(identity.registered_at_ms > 0); assert!(output.contains("✓ Connected to primary at")); - assert!(output.contains("✓ Registered as node \"macmini-a1b2c3\"")); + assert!(output.contains("✓ Registered as node \"node-a-a1b2c3\"")); assert!(output.contains("✓ Identity saved to")); } @@ -665,13 +665,13 @@ mod tests { let fleet_dir = temp_dir.path().join("fleet"); let mut manager = FleetManager::init(&fleet_dir).expect("fleet should initialize"); let token = manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("node should add"); let mut output = Vec::new(); execute_fleet_command( &FleetCommands::Remove { - name: "node-alpha".to_string(), + name: "node-a".to_string(), }, &fleet_dir, &mut output, @@ -682,7 +682,7 @@ mod tests { let reloaded_manager = FleetManager::load(&fleet_dir).expect("fleet should load"); let output = String::from_utf8(output).expect("utf8"); - assert!(output.contains("✓ Node \"node-alpha\" removed and token revoked")); + assert!(output.contains("✓ Node \"node-a\" removed and token revoked")); assert_eq!(reloaded_manager.verify_bearer(&token.secret), None); assert!(reloaded_manager.list_nodes().is_empty()); } @@ -738,16 +738,16 @@ mod tests { let mut manager = FleetManager::load(&fleet_dir).expect("fleet should load"); manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("first node should add"); manager - .add_node("node-beta", "10.0.0.3", 8400) + .add_node("node-b", "203.0.113.11", 8400) .expect("second node should add"); let now_ms = current_time_ms(); let mut nodes = read_nodes(&fleet_dir); for node in &mut nodes { - if node.name == "node-beta" { + if node.name == "node-b" { node.status = 
NodeStatus::Online; node.last_heartbeat_ms = now_ms.saturating_sub(65_000); } @@ -761,10 +761,10 @@ mod tests { let output = String::from_utf8(output).expect("utf8"); assert!(output.contains("Fleet Nodes:")); - assert!(output.contains("node-beta")); - assert!(output.contains("node-alpha")); - assert!(output.contains("10.0.0.2:8400")); - assert!(output.contains("10.0.0.3:8400")); + assert!(output.contains("node-b")); + assert!(output.contains("node-a")); + assert!(output.contains("203.0.113.10:8400")); + assert!(output.contains("203.0.113.11:8400")); assert!(output.contains("online")); assert!(output.contains("offline")); assert!(output.contains("1m ago")); @@ -777,13 +777,13 @@ mod tests { let fleet_dir = temp_dir.path().join("fleet"); let mut manager = FleetManager::init(&fleet_dir).expect("fleet should initialize"); let token = manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("node should add"); let nodes = manager.list_nodes(); let output = render_list_output(&nodes, current_time_ms()); - assert!(output.contains("node-alpha")); + assert!(output.contains("node-a")); assert!(!output.contains(&token.secret)); } diff --git a/engine/crates/fx-cli/src/commands/marketplace.rs b/engine/crates/fx-cli/src/commands/marketplace.rs index 4165184f..285b04bd 100644 --- a/engine/crates/fx-cli/src/commands/marketplace.rs +++ b/engine/crates/fx-cli/src/commands/marketplace.rs @@ -1,111 +1,93 @@ //! CLI commands for the skill marketplace (search, install, list). -use std::path::{Path, PathBuf}; +use std::path::PathBuf; -use fx_marketplace::{InstalledSkill, RegistryConfig, SkillEntry}; +use crate::startup; +use fx_marketplace::{InstalledSkill, SkillEntry}; -/// Default registry URL (raw GitHub content). -const DEFAULT_REGISTRY: &str = "https://raw.githubusercontent.com/fawxai/registry/main"; - -/// Official fawxai publisher Ed25519 public key (32 bytes). 
-const FAWXAI_PUBLIC_KEY: [u8; 32] = [ - 62, 38, 70, 230, 12, 59, 226, 179, 11, 150, 52, 48, 238, 181, 159, 188, 106, 55, 109, 208, 1, - 191, 157, 233, 161, 111, 154, 212, 209, 133, 28, 68, -]; - -/// Resolve the Fawx data directory (`~/.fawx`). -fn data_dir() -> anyhow::Result { - let home = dirs::home_dir().ok_or_else(|| anyhow::anyhow!("cannot determine home dir"))?; - Ok(home.join(".fawx")) -} - -/// Load trusted keys from `~/.fawx/trusted_keys/`. -fn load_trusted_keys(data: &Path) -> anyhow::Result>> { - let mut keys = vec![FAWXAI_PUBLIC_KEY.to_vec()]; - let keys_dir = data.join("trusted_keys"); - if keys_dir.exists() { - for entry in std::fs::read_dir(&keys_dir)? { - let path = entry?.path(); - if path.is_file() { - keys.push(std::fs::read(&path)?); - } - } - } - Ok(keys) +/// Resolve the Fawx data directory. +fn data_dir() -> PathBuf { + startup::fawx_data_dir() } /// Build a `RegistryConfig` from defaults. -fn build_config() -> anyhow::Result { - let data = data_dir()?; - let trusted_keys = load_trusted_keys(&data)?; - Ok(RegistryConfig { - registry_url: DEFAULT_REGISTRY.to_string(), - data_dir: data, - trusted_keys, - }) +fn build_config() -> anyhow::Result { + let data = data_dir(); + Ok(fx_marketplace::default_config(&data)?) } -/// Print a list of skill entries from search results. -fn print_search_results(entries: &[SkillEntry]) { +/// Render a list of skill entries from search results. 
+fn render_search_results(entries: &[SkillEntry]) -> String { if entries.is_empty() { - println!("No skills found."); - return; + return "No skills found.".to_string(); } - for e in entries { - let size = e + + let mut lines = Vec::new(); + for entry in entries { + let size = entry .size_bytes - .map(|b| format!("{} KB", b / 1024)) + .map(|bytes| format!("{} KB", bytes / 1024)) .unwrap_or_else(|| "unknown".to_string()); - let caps = e.capabilities.join(", "); - println!(" {} v{} — {}", e.name, e.version, e.description); - println!(" by {} | capabilities: {} | {}", e.author, caps, size); + let capabilities = entry.capabilities.join(", "); + lines.push(format!( + " {} v{}: {}", + entry.name, entry.version, entry.description + )); + lines.push(format!( + " by {} | capabilities: {} | {}", + entry.author, capabilities, size + )); } - let n = entries.len(); - let noun = if n == 1 { "skill" } else { "skills" }; - println!("\n{n} {noun} found."); + + let count = entries.len(); + let noun = if count == 1 { "skill" } else { "skills" }; + lines.push(String::new()); + lines.push(format!("{count} {noun} found.")); + lines.join("\n") } -/// Print a list of installed skills. -fn print_installed(skills: &[InstalledSkill]) { +/// Render a list of installed skills. 
+fn render_installed(skills: &[InstalledSkill]) -> String { if skills.is_empty() { - println!("No installed skills."); - return; + return "No installed skills.".to_string(); } - println!("Installed skills:"); - for s in skills { - let caps = if s.capabilities.is_empty() { + + let mut lines = vec!["Installed skills:".to_string()]; + for skill in skills { + let capabilities = if skill.capabilities.is_empty() { String::new() } else { - format!(" ({})", s.capabilities.join(", ")) + format!(" ({})", skill.capabilities.join(", ")) }; - println!(" {:16} v{}{}", s.name, s.version, caps); + lines.push(format!( + " {:16} v{}{}", + skill.name, skill.version, capabilities + )); } + lines.join("\n") } -/// `fawx search ` -pub fn search_cmd(query: &str) -> anyhow::Result<()> { +pub fn search_output(query: &str) -> anyhow::Result { let config = build_config()?; - println!("Registry: fawxai/fawx-skills\n"); let results = fx_marketplace::search(&config, query)?; - print_search_results(&results); - Ok(()) + Ok(format!( + "Registry: fawxai/registry\n\n{}", + render_search_results(&results) + )) } -/// `fawx install ` -pub fn install_cmd(name: &str) -> anyhow::Result<()> { +pub fn install_output(name: &str) -> anyhow::Result { let config = build_config()?; - println!("Installing {name}..."); let result = fx_marketplace::install(&config, name)?; - println!(" Downloaded: {} KB", result.size_bytes / 1024); - println!(" Signature: verified ✓"); - println!(" Installed to: {}", result.install_path.display()); - Ok(()) + Ok(format!( + "Installing {name}...\n Downloaded: {} KB\n Signature: verified ✓\n Installed to: {}", + result.size_bytes / 1024, + result.install_path.display() + )) } -/// `fawx list` -pub fn list_cmd() -> anyhow::Result<()> { - let data = data_dir()?; +pub fn list_output() -> anyhow::Result { + let data = data_dir(); let skills = fx_marketplace::list_installed(&data)?; - print_installed(&skills); - Ok(()) + Ok(render_installed(&skills)) } diff --git 
a/engine/crates/fx-cli/src/commands/reset.rs b/engine/crates/fx-cli/src/commands/reset.rs index 24198fc6..48e89b51 100644 --- a/engine/crates/fx-cli/src/commands/reset.rs +++ b/engine/crates/fx-cli/src/commands/reset.rs @@ -641,7 +641,7 @@ mod tests { #[test] fn all_reset_preserves_credentials_while_resetting_the_rest() { let fixture = ResetFixture::new( - "[http]\nbearer_token = \"keep-me\"\n\n[telegram]\nbot_token = \"keep-bot\"\n\n[[fleet.nodes]]\nid = \"node-1\"\nname = \"Node One\"\nendpoint = \"https://node.example\"\nauth_token = \"keep-token\"\nssh_key = \"~/.ssh/node-1\"\ncapabilities = [\"agentic_loop\"]\naddress = \"100.64.0.1\"\nuser = \"deploy\"\n", + "[http]\nbearer_token = \"keep-me\"\n\n[telegram]\nbot_token = \"keep-bot\"\n\n[[fleet.nodes]]\nid = \"node-1\"\nname = \"Node One\"\nendpoint = \"https://node.example\"\nauth_token = \"keep-token\"\nssh_key = \"~/.ssh/node-1\"\ncapabilities = [\"agentic_loop\"]\naddress = \"203.0.113.30\"\nuser = \"deploy\"\n", ); write_dir_file(&fixture.layout.data_dir.join("memory"), "memory.json"); write_dir_file(&fixture.layout.embedding_model_dir, "index.bin"); diff --git a/engine/crates/fx-cli/src/commands/setup.rs b/engine/crates/fx-cli/src/commands/setup.rs index 32859771..f26c2aac 100644 --- a/engine/crates/fx-cli/src/commands/setup.rs +++ b/engine/crates/fx-cli/src/commands/setup.rs @@ -36,7 +36,7 @@ pub async fn run(force: bool) -> anyhow::Result { println!("🦊 Welcome to Fawx setup!\n"); - let mut wizard = SetupWizard::new(force)?; + let mut wizard = SetupWizard::new(force, None)?; wizard.print_system_check(); wizard.run_tailscale_phase(); if !wizard.confirm_existing_config()? 
{ @@ -182,8 +182,8 @@ static SETUP_SKILLS: [SetupSkill; 7] = [ ]; impl SetupWizard { - fn new(force: bool) -> anyhow::Result { - let data_dir = fawx_data_dir(); + fn new(force: bool, data_dir_override: Option) -> anyhow::Result { + let data_dir = data_dir_override.unwrap_or_else(fawx_data_dir); fs::create_dir_all(&data_dir) .with_context(|| format!("failed to create {}", data_dir.display()))?; let config_path = data_dir.join("config.toml"); @@ -1508,36 +1508,6 @@ mod tests { use super::*; use tempfile::TempDir; - static TEST_HOME_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); - - struct HomeGuard { - original_home: Option, - } - - impl HomeGuard { - fn set(temp_home: &TempDir) -> Self { - let original_home = std::env::var("HOME").ok(); - unsafe { - std::env::set_var("HOME", temp_home.path()); - } - Self { original_home } - } - } - - impl Drop for HomeGuard { - fn drop(&mut self) { - if let Some(home) = &self.original_home { - unsafe { - std::env::set_var("HOME", home); - } - } else { - unsafe { - std::env::remove_var("HOME"); - } - } - } - } - #[test] fn parse_chat_ids_accepts_blank_input() { assert!(parse_chat_ids(" ").expect("blank input").is_empty()); @@ -1738,21 +1708,14 @@ mod tests { #[test] fn setup_wizard_stores_skill_credentials_without_reopening_database() { - let _home_lock = TEST_HOME_LOCK - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - let temp_home = TempDir::new().expect("temp home"); - let _home = HomeGuard::set(&temp_home); - let mut wizard = SetupWizard::new(false).expect("setup wizard"); - let data_dir = temp_home.path().join(".fawx"); + let temp_dir = TempDir::new().expect("temp dir"); + let mut wizard = + SetupWizard::new(false, Some(temp_dir.path().to_path_buf())).expect("setup wizard"); - { - // Verify the wizard did not eagerly open credentials.db during construction. - // The handle is scoped so the redb file lock is released before the wizard - // lazily opens the same database below. 
- EncryptedFileCredentialStore::open(&data_dir) - .expect("wizard should not lock credential store during construction"); - } + assert!( + wizard.skill_credential_store.is_none(), + "wizard should not eagerly open the skill credential store" + ); wizard .store_skill_credential("brave_api_key", "brv-test") diff --git a/engine/crates/fx-cli/src/commands/skills.rs b/engine/crates/fx-cli/src/commands/skills.rs index 309ebdcf..6587fa28 100644 --- a/engine/crates/fx-cli/src/commands/skills.rs +++ b/engine/crates/fx-cli/src/commands/skills.rs @@ -678,7 +678,7 @@ mod tests { assert_eq!( error.to_string(), - "unknown capability 'flying', valid: network, storage, notifications, sensors, phone_actions" + "unknown capability 'flying', valid: network, storage, shell, filesystem, notifications, sensors, phone_actions" ); } diff --git a/engine/crates/fx-cli/src/commands/slash.rs b/engine/crates/fx-cli/src/commands/slash.rs index 8076f84f..12450265 100644 --- a/engine/crates/fx-cli/src/commands/slash.rs +++ b/engine/crates/fx-cli/src/commands/slash.rs @@ -69,6 +69,15 @@ pub trait CommandHost { fn handle_sign(&self, _target: Option<&str>, _has_extra_args: bool) -> Result { Ok("WASM signing is not available in this mode.".to_string()) } + fn list_skills(&self) -> Result { + Ok("Skill listing is not available in this mode.".to_string()) + } + fn install_skill(&self, _name: &str) -> Result { + Ok("Skill installation is not available in this mode.".to_string()) + } + fn search_skills(&self, _query: &str) -> Result { + Ok("Skill search is not available in this mode.".to_string()) + } } pub struct CommandContext<'a, H: CommandHost> { @@ -98,6 +107,13 @@ pub enum ParsedCommand { target: Option, has_extra_args: bool, }, + Skills, + Install { + name: Option, + }, + Search { + query: Option, + }, Budget, Loop, Status, @@ -158,6 +174,20 @@ pub fn parse_command(value: &str) -> ParsedCommand { target: parts.next().map(ToString::to_string), has_extra_args: parts.next().is_some(), }, + "skills" 
=> ParsedCommand::Skills, + "install" => ParsedCommand::Install { + name: parts.next().map(ToString::to_string), + }, + "search" => ParsedCommand::Search { + query: { + let rest: Vec<&str> = parts.collect(); + if rest.is_empty() { + None + } else { + Some(rest.join(" ")) + } + }, + }, "budget" => ParsedCommand::Budget, "loop" => ParsedCommand::Loop, "status" => ParsedCommand::Status, @@ -224,6 +254,16 @@ pub fn execute_command( app.handle_sign(target.as_deref(), *has_extra_args) .map(response) }), + ParsedCommand::Skills => { + execute_embedded_only(ctx.app, |app| app.list_skills().map(response)) + } + ParsedCommand::Install { name } => { + execute_embedded_only(ctx.app, |app| execute_install(app, name.as_deref())) + } + ParsedCommand::Search { query } => execute_embedded_only(ctx.app, |app| { + app.search_skills(query.as_deref().unwrap_or("")) + .map(response) + }), ParsedCommand::Budget => Some(Ok(response(ctx.app.show_budget_status()))), ParsedCommand::Loop => { execute_embedded_only(ctx.app, |app| app.show_loop_status().map(response)) @@ -436,9 +476,12 @@ pub fn help_text() -> &'static str { " /keys list List trusted public keys\n", " /keys trust \n", " /keys revoke \n", - " /sign Sign one WASM skill\n", - " /sign --all Sign all installed WASM skills\n", - " /status Show model, tokens, budget summary\n", + " /sign Sign one WASM skill\n", + " /sign --all Sign all installed WASM skills\n", + " /skills List installed skills\n", + " /install Install a skill from the marketplace\n", + " /search [query] Search the skill marketplace\n", + " /status Show model, tokens, budget summary\n", " /budget Show detailed budget usage\n", " /loop Show loop iteration details\n", " /signals Show condensed signal summary for last turn\n", @@ -569,6 +612,13 @@ fn execute_keys( .map(response) } +fn execute_install(app: &mut H, name: Option<&str>) -> Result { + match name { + Some(name) => app.install_skill(name).map(response), + None => Ok(response("Usage: /install ".to_string())), + 
} +} + fn execute_config(app: &mut H, action: Option<&str>) -> Result { match action { None => app.show_config().map(response), @@ -658,6 +708,7 @@ mod tests { use fx_core::signals::{LoopStep, Signal, SignalKind}; use fx_kernel::budget::BudgetRemaining; use fx_kernel::loop_engine::LoopStatus; + use std::cell::RefCell; use std::sync::{Arc, Mutex}; use tempfile::tempdir; @@ -672,7 +723,12 @@ mod tests { thinking: String, init: String, reload: String, + skills: String, + installed_skill: String, + search_results: String, last_model: Option, + last_installed_skill: RefCell>, + last_search_query: RefCell>, thinking_level: Option, } @@ -736,6 +792,20 @@ mod tests { self.thinking_level = level.map(ToString::to_string); Ok(self.thinking.clone()) } + + fn list_skills(&self) -> Result { + Ok(self.skills.clone()) + } + + fn install_skill(&self, name: &str) -> Result { + self.last_installed_skill.replace(Some(name.to_string())); + Ok(self.installed_skill.clone()) + } + + fn search_skills(&self, query: &str) -> Result { + self.last_search_query.replace(Some(query.to_string())); + Ok(self.search_results.clone()) + } } #[test] @@ -847,6 +917,57 @@ mod tests { ); } + #[test] + fn parse_skills_command() { + assert_eq!(parse_command("/skills"), ParsedCommand::Skills); + } + + #[test] + fn parse_install_command_with_name() { + assert_eq!( + parse_command("/install weather"), + ParsedCommand::Install { + name: Some("weather".to_string()), + } + ); + } + + #[test] + fn parse_install_command_without_name() { + assert_eq!( + parse_command("/install"), + ParsedCommand::Install { name: None } + ); + } + + #[test] + fn parse_search_command_without_query() { + assert_eq!( + parse_command("/search"), + ParsedCommand::Search { query: None } + ); + } + + #[test] + fn parse_search_command_with_query() { + assert_eq!( + parse_command("/search weather"), + ParsedCommand::Search { + query: Some("weather".to_string()), + } + ); + } + + #[test] + fn parse_search_command_with_multi_word_query() { + 
assert_eq!( + parse_command("/search weather api"), + ParsedCommand::Search { + query: Some("weather api".to_string()), + } + ); + } + #[test] fn parse_proposals_accepts_optional_id() { assert_eq!( @@ -933,6 +1054,83 @@ mod tests { ); } + #[test] + fn execute_skills_command_lists_installed() { + let mut host = StubHost { + skills: "Installed skills".to_string(), + ..StubHost::default() + }; + let mut context = CommandContext { app: &mut host }; + + let result = execute_command(&mut context, &ParsedCommand::Skills) + .expect("server-side") + .expect("ok"); + + assert_eq!(result.response, "Installed skills"); + } + + #[test] + fn execute_install_command_with_name() { + let mut host = StubHost { + installed_skill: "Installed github".to_string(), + ..StubHost::default() + }; + let result = { + let mut context = CommandContext { app: &mut host }; + execute_command( + &mut context, + &ParsedCommand::Install { + name: Some("github".to_string()), + }, + ) + .expect("server-side") + .expect("ok") + }; + + assert_eq!(result.response, "Installed github"); + assert_eq!( + host.last_installed_skill.borrow().as_deref(), + Some("github") + ); + } + + #[test] + fn execute_install_command_without_name_shows_usage() { + let mut host = StubHost::default(); + + let result = { + let mut context = CommandContext { app: &mut host }; + execute_command(&mut context, &ParsedCommand::Install { name: None }) + .expect("server-side") + .expect("ok") + }; + + assert_eq!(result.response, "Usage: /install "); + assert!(host.last_installed_skill.borrow().is_none()); + } + + #[test] + fn execute_search_command_routes_query() { + let mut host = StubHost { + search_results: "Search results".to_string(), + ..StubHost::default() + }; + let result = { + let mut context = CommandContext { app: &mut host }; + execute_command( + &mut context, + &ParsedCommand::Search { + query: Some("weather".to_string()), + }, + ) + .expect("server-side") + .expect("ok") + }; + + assert_eq!(result.response, "Search 
results"); + assert_eq!(host.last_search_query.borrow().as_deref(), Some("weather")); + } + #[test] fn execute_command_formats_model_switch_response() { let mut host = StubHost::default(); diff --git a/engine/crates/fx-cli/src/commands/tailscale.rs b/engine/crates/fx-cli/src/commands/tailscale.rs index 7832908a..2961d874 100644 --- a/engine/crates/fx-cli/src/commands/tailscale.rs +++ b/engine/crates/fx-cli/src/commands/tailscale.rs @@ -106,10 +106,10 @@ mod tests { #[test] fn parse_dns_name_trims_trailing_dot() { - let hostname = parse_dns_name(br#"{"Self":{"DNSName":"fawx.tail123.ts.net."}}"#) + let hostname = parse_dns_name(br#"{"Self":{"DNSName":"node.example.ts.net."}}"#) .expect("hostname should parse"); - assert_eq!(hostname, "fawx.tail123.ts.net"); + assert_eq!(hostname, "node.example.ts.net"); } #[test] diff --git a/engine/crates/fx-cli/src/headless.rs b/engine/crates/fx-cli/src/headless.rs index d835b2bf..1989ef4f 100644 --- a/engine/crates/fx-cli/src/headless.rs +++ b/engine/crates/fx-cli/src/headless.rs @@ -536,7 +536,8 @@ impl SessionTurnCollector { StreamEvent::Done { .. } | StreamEvent::Error { .. } => { self.flush_pending_tool_results(); } - StreamEvent::TextDelta { .. } + StreamEvent::ToolError { .. } + | StreamEvent::TextDelta { .. } | StreamEvent::Notification { .. } | StreamEvent::PermissionPrompt(_) | StreamEvent::PhaseChange { .. 
} @@ -1205,6 +1206,7 @@ impl HeadlessApp { .await } + #[cfg(test)] #[allow(dead_code)] pub async fn process_message_with_images( &mut self, @@ -1796,7 +1798,7 @@ impl HeadlessApp { } } - #[allow(dead_code)] + #[cfg(test)] fn build_perception_snapshot(&self, input: &str, source: &InputSource) -> PerceptionSnapshot { self.build_perception_snapshot_with_attachments(input, source, &[], &[]) } @@ -2204,6 +2206,18 @@ impl CommandHost for HeadlessApp { fn handle_sign(&self, _target: Option<&str>, _has_extra_args: bool) -> anyhow::Result { Ok("Use `fawx sign ` CLI to sign WASM packages.".to_string()) } + + fn list_skills(&self) -> anyhow::Result { + crate::commands::marketplace::list_output() + } + + fn install_skill(&self, name: &str) -> anyhow::Result { + crate::commands::marketplace::install_output(name) + } + + fn search_skills(&self, query: &str) -> anyhow::Result { + crate::commands::marketplace::search_output(query) + } } fn preferred_supported_budget(levels: &[String]) -> ThinkingBudget { diff --git a/engine/crates/fx-cli/src/lib.rs b/engine/crates/fx-cli/src/lib.rs index 47e2b1b9..6c02f639 100644 --- a/engine/crates/fx-cli/src/lib.rs +++ b/engine/crates/fx-cli/src/lib.rs @@ -34,6 +34,8 @@ mod fleet_command { include!("commands/fleet.rs"); } } +#[path = "commands/marketplace.rs"] +pub(crate) mod marketplace_commands; #[cfg(test)] #[allow(dead_code)] mod repo_root; @@ -48,6 +50,7 @@ mod start_stop_command { include!("commands/start_stop.rs"); } mod commands { + pub(crate) use super::marketplace_commands as marketplace; pub(crate) use super::slash_commands as slash; } mod config_bridge; diff --git a/engine/crates/fx-cli/src/main.rs b/engine/crates/fx-cli/src/main.rs index c644baec..93032a2e 100644 --- a/engine/crates/fx-cli/src/main.rs +++ b/engine/crates/fx-cli/src/main.rs @@ -301,10 +301,16 @@ enum SkillCommands { /// List installed skills List, + /// Search the skill registry + Search { + /// Search query. Leave empty to show all available skills. 
+ query: Option, + }, + /// Install a skill Install { - /// Path to skill WASM file - path: String, + /// Skill name or path to WASM file + name_or_path: String, }, /// Remove a skill @@ -905,16 +911,33 @@ async fn dispatch_audit(command: AuditCommands) -> anyhow::Result { } } +fn looks_like_local_skill_path(name_or_path: &str) -> bool { + name_or_path.contains('/') || name_or_path.contains('\\') || name_or_path.ends_with(".wasm") +} + +async fn dispatch_skill_install(name_or_path: &str) -> anyhow::Result { + if looks_like_local_skill_path(name_or_path) { + commands::skills::install(name_or_path).await?; + } else { + println!("{}", commands::marketplace::install_output(name_or_path)?); + } + Ok(0) +} + async fn dispatch_skill(command: SkillCommands) -> anyhow::Result { match command { SkillCommands::List => { - commands::skills::list().await?; + println!("{}", commands::marketplace::list_output()?); Ok(0) } - SkillCommands::Install { path } => { - commands::skills::install(&path).await?; + SkillCommands::Search { query } => { + println!( + "{}", + commands::marketplace::search_output(query.as_deref().unwrap_or(""))? 
+ ); Ok(0) } + SkillCommands::Install { name_or_path } => dispatch_skill_install(&name_or_path).await, SkillCommands::Remove { name } => { commands::skills::remove(&name).await?; Ok(0) @@ -1055,15 +1078,15 @@ async fn dispatch_command(command: Commands) -> anyhow::Result { Commands::Audit { command } => dispatch_audit(command).await, Commands::Skill { command } => dispatch_skill(command).await, Commands::Search { query } => { - commands::marketplace::search_cmd(&query)?; + println!("{}", commands::marketplace::search_output(&query)?); Ok(0) } Commands::Install { name } => { - commands::marketplace::install_cmd(&name)?; + println!("{}", commands::marketplace::install_output(&name)?); Ok(0) } Commands::List => { - commands::marketplace::list_cmd()?; + println!("{}", commands::marketplace::list_output()?); Ok(0) } #[cfg(not(feature = "oauth-bridge"))] @@ -1186,9 +1209,9 @@ mod tests { use super::{build_telegram_channel, telegram_webhook_secret_from_credential_store}; use super::{ cleanup_stale_pid_file_at, dispatch_command, ensure_headless_chat_model_available, - fawx_tui_binary_name, find_fawx_tui_binary_from, resolve_ripcord_path_with, - ripcord_binary_name, Cli, Commands, SessionsCommands, SkillCommands, - FAWX_TUI_NOT_FOUND_MESSAGE, + fawx_tui_binary_name, find_fawx_tui_binary_from, looks_like_local_skill_path, + resolve_ripcord_path_with, ripcord_binary_name, Cli, Commands, SessionsCommands, + SkillCommands, FAWX_TUI_NOT_FOUND_MESSAGE, }; use crate::auth_store::AuthStore; use crate::restart; @@ -1365,6 +1388,50 @@ mod tests { )); } + #[test] + fn cli_parses_skill_install_command() { + let cli = Cli::parse_from(["fawx", "skill", "install", "github"]); + assert!(matches!( + cli.command, + Some(Commands::Skill { + command: SkillCommands::Install { name_or_path } + }) if name_or_path == "github" + )); + } + + #[test] + fn cli_parses_skill_search_without_query() { + let cli = Cli::parse_from(["fawx", "skill", "search"]); + assert!(matches!( + cli.command, + 
Some(Commands::Skill { + command: SkillCommands::Search { query: None } + }) + )); + } + + #[test] + fn cli_parses_skill_search_with_query() { + let cli = Cli::parse_from(["fawx", "skill", "search", "weather"]); + assert!(matches!( + cli.command, + Some(Commands::Skill { + command: SkillCommands::Search { query: Some(query) } + }) if query == "weather" + )); + } + + #[test] + fn looks_like_local_skill_path_detects_marketplace_names_and_paths() { + assert!(!looks_like_local_skill_path("github")); + assert!(!looks_like_local_skill_path("weather")); + assert!(!looks_like_local_skill_path("web.fetch")); + assert!(looks_like_local_skill_path("weather.wasm")); + assert!(looks_like_local_skill_path("./weather")); + assert!(looks_like_local_skill_path("/tmp/skill.wasm")); + assert!(looks_like_local_skill_path("skills\\weather")); + } + #[test] fn cli_parses_completions_command() { let cli = Cli::parse_from(["fawx", "completions", "bash"]); diff --git a/engine/crates/fx-cli/src/startup.rs b/engine/crates/fx-cli/src/startup.rs index 6ed139be..e7b71a59 100644 --- a/engine/crates/fx-cli/src/startup.rs +++ b/engine/crates/fx-cli/src/startup.rs @@ -36,12 +36,14 @@ use fx_kernel::{ use fx_llm::{ AnthropicProvider, CompletionRequest, ModelRouter, OpenAiProvider, OpenAiResponsesProvider, }; +use fx_loadable::watcher::{ReloadEvent, SkillWatcher}; use fx_loadable::{ NotificationSender, NotifySkill, SessionMemorySkill, SignaturePolicy, SkillRegistry, TransactionSkill, }; use fx_memory::embedding_index::EmbeddingIndex; use fx_memory::{JsonFileMemory, JsonMemoryConfig, SignalStore}; +use fx_python::PythonSkill; use fx_ripcord::{resolve_tripwires, RipcordJournal, TripwireEvaluator}; use fx_scratchpad::skill::ScratchpadSkill; use fx_scratchpad::Scratchpad; @@ -58,6 +60,7 @@ use std::io::IsTerminal; use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex, RwLock}; use std::time::{SystemTime, UNIX_EPOCH}; +use tokio::sync::mpsc; use tracing::Dispatch; use 
tracing_appender::non_blocking::WorkerGuard; use tracing_appender::rolling::{RollingFileAppender, Rotation}; @@ -934,7 +937,8 @@ fn build_skill_registry( })); ProcessRegistry::spawn_cleanup_task(&process_registry); let mut startup_warnings = Vec::new(); - let executor = build_tool_executor(&options, tool_config, process_registry); + let executor = build_tool_executor(&options, tool_config, process_registry) + .with_protected_branches(config.git.protected_branches.clone()); let (mut executor, memory, embedding_index_persistence, snapshot_text, memory_enabled) = attach_memory_if_enabled( executor, @@ -1005,6 +1009,8 @@ fn build_skill_registry( let session_memory = Arc::new(Mutex::new(fx_session::SessionMemory::default())); let session_memory_skill = SessionMemorySkill::new(Arc::clone(&session_memory)); registry.register(Arc::new(session_memory_skill)); + let python_skill = PythonSkill::new(data_dir); + registry.register(Arc::new(python_skill)); if let Some(session_registry) = options.session_registry.clone() { let session_skill = SessionToolsSkill::new(session_registry); @@ -1053,7 +1059,8 @@ fn build_skill_registry( std::sync::Arc::new(move || cp.get_credential("github_token")) as std::sync::Arc Option> + Send + Sync> }); - let git_skill = GitSkill::new(options.working_dir.clone(), sm, github_token_fn); + let git_skill = GitSkill::new(options.working_dir.clone(), sm, github_token_fn) + .with_protected_branches(config.git.protected_branches.clone()); registry.register(Arc::new(git_skill)); // Load WASM skills from ~/.fawx/skills/ @@ -1102,6 +1109,14 @@ fn build_skill_registry( }; apply_skill_summaries(&runtime_info, registry.as_ref()); + let skills_dir = data_dir.join("skills"); + start_skill_watcher( + skills_dir, + Arc::clone(®istry), + Arc::clone(&runtime_info), + credential_provider.clone(), + signature_policy.clone(), + ); SkillRegistryBundle { registry, @@ -1121,6 +1136,80 @@ fn build_skill_registry( } } +fn start_skill_watcher( + skills_dir: PathBuf, + 
registry: Arc, + runtime_info: Arc>, + credential_provider: Option>, + signature_policy: SignaturePolicy, +) { + if let Err(error) = fs::create_dir_all(&skills_dir) { + tracing::warn!(path = %skills_dir.display(), error = %error, "failed to create skills directory for watcher"); + return; + } + + let Ok(handle) = tokio::runtime::Handle::try_current() else { + tracing::debug!(path = %skills_dir.display(), "skipping skill watcher startup without active tokio runtime"); + return; + }; + + let (reload_event_tx, reload_event_rx) = mpsc::channel(32); + let mut skill_watcher = SkillWatcher::new( + skills_dir, + Arc::clone(®istry), + reload_event_tx, + credential_provider, + signature_policy, + ); + skill_watcher.initialize_hashes(); + handle.spawn(handle_skill_reload_events( + reload_event_rx, + runtime_info, + registry, + )); + handle.spawn(async move { + if let Err(error) = skill_watcher.run().await { + tracing::error!(error = %error, "skill watcher exited with error"); + } + }); +} + +async fn handle_skill_reload_events( + mut reload_event_rx: mpsc::Receiver, + runtime_info: Arc>, + registry: Arc, +) { + while let Some(event) = reload_event_rx.recv().await { + log_skill_reload_event(&event); + apply_skill_summaries(&runtime_info, registry.as_ref()); + } +} + +fn log_skill_reload_event(event: &ReloadEvent) { + match event { + ReloadEvent::Loaded { + skill_name, + version, + } => tracing::info!(skill = %skill_name, version = %version, "skill hot-loaded"), + ReloadEvent::Updated { + skill_name, + old_version, + new_version, + } => tracing::info!( + skill = %skill_name, + old_version = %old_version, + new_version = %new_version, + "skill hot-reloaded" + ), + ReloadEvent::Removed { skill_name } => { + tracing::info!(skill = %skill_name, "skill removed") + } + ReloadEvent::Error { skill_name, error } => { + tracing::warn!(skill = %skill_name, error = %error, "skill reload failed") + } + } +} + fn build_tool_executor( options: &SkillRegistryBuildOptions, tool_config: 
ToolConfig, @@ -1910,6 +1999,7 @@ mod tests { use fx_config::manager::ConfigManager; use fx_core::memory::MemoryProvider; use fx_embeddings::test_support::create_test_model_dir; + use fx_loadable::test_support::write_test_skill; use fx_subagent::test_support::StubSubagentControl; use std::cell::Cell; use std::io; @@ -1917,6 +2007,7 @@ mod tests { use std::path::Path; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; + use std::time::Duration; use tracing::Level; use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::fmt::writer::MakeWriter; @@ -1930,15 +2021,46 @@ mod tests { (config, temp_dir) } + fn registry_has_skill(bundle: &LoopEngineBundle, name: &str) -> bool { + bundle + .skill_registry + .skill_summaries() + .iter() + .any(|(skill_name, _, _, _)| skill_name == name) + } + + fn runtime_info_has_skill(bundle: &LoopEngineBundle, name: &str) -> bool { + bundle + .runtime_info + .read() + .expect("runtime info") + .skills + .iter() + .any(|skill| skill.name == name) + } + + async fn wait_for_skill_registration(bundle: &LoopEngineBundle, name: &str) { + tokio::time::timeout(Duration::from_secs(10), async { + loop { + if registry_has_skill(bundle, name) && runtime_info_has_skill(bundle, name) { + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + }) + .await + .expect("skill watcher should register the new skill"); + } + fn test_fleet_node_config() -> fx_config::NodeConfig { fx_config::NodeConfig { id: "mac-mini".to_string(), - name: "Node Alpha".to_string(), + name: "Worker Node A".to_string(), endpoint: Some("https://10.0.0.5:8400".to_string()), auth_token: Some("token".to_string()), capabilities: vec!["agentic_loop".to_string(), "test".to_string()], address: Some("10.0.0.5".to_string()), - user: Some("admin".to_string()), + user: Some("builder".to_string()), ssh_key: Some("~/.ssh/id_ed25519".to_string()), } } @@ -2557,6 +2679,29 @@ mod tests { assert!(!names.contains(&"node_run".to_string())); } + 
#[tokio::test] + async fn headless_bundle_starts_skill_watcher_for_runtime_installs() { + let (config, _temp_dir) = test_config_with_temp_dir(); + let skills_dir = config + .general + .data_dir + .clone() + .expect("data dir") + .join("skills"); + let bundle = + build_headless_loop_engine_bundle(&config, None, HeadlessLoopBuildOptions::default()) + .expect("bundle should build"); + + assert!( + skills_dir.exists(), + "startup should create the skills directory" + ); + + tokio::time::sleep(Duration::from_millis(500)).await; + write_test_skill(&skills_dir, "runtimeinstallwatcher").expect("write test skill"); + wait_for_skill_registration(&bundle, "runtimeinstallwatcher").await; + } + #[cfg(feature = "http")] #[test] fn build_tool_executor_attaches_experiment_registrar_when_registry_supplied() { diff --git a/engine/crates/fx-cloud-gpu/Cargo.toml b/engine/crates/fx-cloud-gpu/Cargo.toml new file mode 100644 index 00000000..af65e57a --- /dev/null +++ b/engine/crates/fx-cloud-gpu/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "fx-cloud-gpu" +version = "0.1.0" +edition = "2021" + +authors.workspace = true +license.workspace = true + +description = "Cloud GPU provider trait and skill adapters for Fawx." 
+ +[dependencies] +async-trait = "0.1" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "2" +fx-kernel = { path = "../fx-kernel" } +fx-loadable = { path = "../fx-loadable" } +fx-llm = { path = "../fx-llm" } + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/engine/crates/fx-cloud-gpu/src/lib.rs b/engine/crates/fx-cloud-gpu/src/lib.rs new file mode 100644 index 00000000..fe8dbd0d --- /dev/null +++ b/engine/crates/fx-cloud-gpu/src/lib.rs @@ -0,0 +1,111 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +mod provider; +mod skill; + +pub use provider::CloudGpuProvider; +pub use skill::CloudGpuSkill; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PodConfig { + pub name: String, + pub gpu: GpuType, + pub gpu_count: u32, + pub image: String, + pub disk_gb: u32, + #[serde(default)] + pub env: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum GpuType { + Rtx3090, + Rtx4090, + #[serde(rename = "A100_80gb")] + A100_80Gb, + #[serde(rename = "H100_80gb")] + H100_80Gb, + Custom(String), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Pod { + pub id: String, + pub status: PodStatus, + pub ssh_host: String, + pub ssh_port: u16, + pub gpu: GpuType, + pub cost_per_hour: f64, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub enum PodStatus { + Creating, + Running, + Stopped, + Terminated, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExecResult { + pub stdout: String, + pub stderr: String, + pub exit_code: i32, + pub duration_ms: u64, +} + +#[derive(Debug, thiserror::Error)] +pub enum GpuError { + #[error("provider error: {0}")] + Provider(String), + #[error("authentication failed: {0}")] + Authentication(String), + #[error("pod not found: {0}")] + PodNotFound(String), + #[error("rate limited: retry after {retry_after_seconds}s")] + RateLimited { retry_after_seconds: 
u32 }, + #[error("timeout after {0}s")] + Timeout(u32), + #[error("ssh error: {0}")] + Ssh(String), + #[error("transfer error: {0}")] + Transfer(String), +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn pod_config_defaults_env_when_omitted() { + let config: PodConfig = serde_json::from_value(json!({ + "name": "trainer", + "gpu": "Rtx4090", + "gpu_count": 1, + "image": "nvidia/cuda:12.0.0-runtime-ubuntu22.04", + "disk_gb": 200, + })) + .expect("pod config without env should deserialize"); + + assert!(config.env.is_empty()); + } + + #[test] + fn gpu_type_uses_legacy_wire_names_for_80gb_variants() { + let a100 = serde_json::to_string(&GpuType::A100_80Gb).expect("serialize A100"); + let h100 = serde_json::to_string(&GpuType::H100_80Gb).expect("serialize H100"); + + assert_eq!(a100, "\"A100_80gb\""); + assert_eq!(h100, "\"H100_80gb\""); + assert!(matches!( + serde_json::from_str::("\"A100_80gb\""), + Ok(GpuType::A100_80Gb) + )); + assert!(matches!( + serde_json::from_str::("\"H100_80gb\""), + Ok(GpuType::H100_80Gb) + )); + } +} diff --git a/engine/crates/fx-cloud-gpu/src/provider.rs b/engine/crates/fx-cloud-gpu/src/provider.rs new file mode 100644 index 00000000..edabeff7 --- /dev/null +++ b/engine/crates/fx-cloud-gpu/src/provider.rs @@ -0,0 +1,39 @@ +use crate::{ExecResult, GpuError, Pod, PodConfig}; +use async_trait::async_trait; +use fx_kernel::cancellation::CancellationToken; +use std::path::Path; + +#[async_trait] +pub trait CloudGpuProvider: Send + Sync + std::fmt::Debug { + fn provider_name(&self) -> &str; + + async fn create_pod(&self, config: PodConfig) -> Result; + async fn list_pods(&self) -> Result, GpuError>; + async fn pod_status(&self, pod_id: &str) -> Result; + async fn stop_pod(&self, pod_id: &str) -> Result<(), GpuError>; + async fn destroy_pod(&self, pod_id: &str) -> Result<(), GpuError>; + + async fn exec( + &self, + pod_id: &str, + command: &str, + timeout_seconds: u32, + cancel: Option<&CancellationToken>, + 
) -> Result; + + async fn upload( + &self, + pod_id: &str, + local_path: &Path, + remote_path: &str, + cancel: Option<&CancellationToken>, + ) -> Result<(), GpuError>; + + async fn download( + &self, + pod_id: &str, + remote_path: &str, + local_path: &Path, + cancel: Option<&CancellationToken>, + ) -> Result<(), GpuError>; +} diff --git a/engine/crates/fx-cloud-gpu/src/skill.rs b/engine/crates/fx-cloud-gpu/src/skill.rs new file mode 100644 index 00000000..37da5e3b --- /dev/null +++ b/engine/crates/fx-cloud-gpu/src/skill.rs @@ -0,0 +1,879 @@ +use crate::{CloudGpuProvider, ExecResult, GpuError, Pod, PodConfig}; +use async_trait::async_trait; +use fx_kernel::act::ToolCacheability; +use fx_kernel::cancellation::CancellationToken; +use fx_llm::ToolDefinition; +use fx_loadable::{Skill, SkillError}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::path::{Path, PathBuf}; + +const GPU_CREATE_TOOL: &str = "gpu_create"; +const GPU_LIST_TOOL: &str = "gpu_list"; +const GPU_STATUS_TOOL: &str = "gpu_status"; +const GPU_STOP_TOOL: &str = "gpu_stop"; +const GPU_DESTROY_TOOL: &str = "gpu_destroy"; +const GPU_EXEC_TOOL: &str = "gpu_exec"; +const GPU_UPLOAD_TOOL: &str = "gpu_upload"; +const GPU_DOWNLOAD_TOOL: &str = "gpu_download"; + +#[derive(Debug)] +pub struct CloudGpuSkill { + provider: Box, +} + +impl CloudGpuSkill { + #[must_use] + pub fn new(provider: Box) -> Self { + Self { provider } + } + + fn handles_tool(tool_name: &str) -> bool { + matches!( + tool_name, + GPU_CREATE_TOOL + | GPU_LIST_TOOL + | GPU_STATUS_TOOL + | GPU_STOP_TOOL + | GPU_DESTROY_TOOL + | GPU_EXEC_TOOL + | GPU_UPLOAD_TOOL + | GPU_DOWNLOAD_TOOL + ) + } + + async fn execute_tool( + &self, + tool_name: &str, + arguments: &str, + cancel: Option<&CancellationToken>, + ) -> Result { + match tool_name { + GPU_CREATE_TOOL => self.handle_create(arguments).await, + GPU_LIST_TOOL => self.handle_list(arguments).await, + GPU_STATUS_TOOL => 
self.handle_status(arguments).await, + GPU_STOP_TOOL => self.handle_stop(arguments).await, + GPU_DESTROY_TOOL => self.handle_destroy(arguments).await, + GPU_EXEC_TOOL => self.handle_exec(arguments, cancel).await, + GPU_UPLOAD_TOOL => self.handle_upload(arguments, cancel).await, + GPU_DOWNLOAD_TOOL => self.handle_download(arguments, cancel).await, + _ => Err(format!("unknown cloud gpu tool: {tool_name}")), + } + } + + async fn handle_create(&self, arguments: &str) -> Result { + let request: GpuCreateRequest = parse_request(arguments)?; + let pod = self + .provider + .create_pod(request.config) + .await + .map_err(serialize_gpu_error)?; + serialize_response(&GpuCreateResponse { pod }) + } + + async fn handle_list(&self, arguments: &str) -> Result { + let _: GpuListRequest = parse_request(arguments)?; + let pods = self + .provider + .list_pods() + .await + .map_err(serialize_gpu_error)?; + serialize_response(&GpuListResponse { pods }) + } + + async fn handle_status(&self, arguments: &str) -> Result { + let request: GpuStatusRequest = parse_request(arguments)?; + let pod = self + .provider + .pod_status(&request.pod_id) + .await + .map_err(serialize_gpu_error)?; + serialize_response(&GpuStatusResponse { pod }) + } + + async fn handle_stop(&self, arguments: &str) -> Result { + let request: GpuStopRequest = parse_request(arguments)?; + self.provider + .stop_pod(&request.pod_id) + .await + .map_err(serialize_gpu_error)?; + serialize_response(&GpuStopResponse { + pod_id: request.pod_id, + stopped: true, + }) + } + + async fn handle_destroy(&self, arguments: &str) -> Result { + let request: GpuDestroyRequest = parse_request(arguments)?; + self.provider + .destroy_pod(&request.pod_id) + .await + .map_err(serialize_gpu_error)?; + serialize_response(&GpuDestroyResponse { + pod_id: request.pod_id, + destroyed: true, + }) + } + + async fn handle_exec( + &self, + arguments: &str, + cancel: Option<&CancellationToken>, + ) -> Result { + let request: GpuExecRequest = 
parse_request(arguments)?; + let result = self + .provider + .exec( + &request.pod_id, + &request.command, + request.timeout_seconds, + cancel, + ) + .await + .map_err(serialize_gpu_error)?; + serialize_response(&GpuExecResponse { result }) + } + + async fn handle_upload( + &self, + arguments: &str, + cancel: Option<&CancellationToken>, + ) -> Result { + let request: GpuUploadRequest = parse_request(arguments)?; + self.provider + .upload( + &request.pod_id, + &request.local_path, + &request.remote_path, + cancel, + ) + .await + .map_err(serialize_gpu_error)?; + serialize_response(&GpuUploadResponse { + pod_id: request.pod_id, + local_path: path_to_string(&request.local_path), + remote_path: request.remote_path, + uploaded: true, + }) + } + + async fn handle_download( + &self, + arguments: &str, + cancel: Option<&CancellationToken>, + ) -> Result { + let request: GpuDownloadRequest = parse_request(arguments)?; + self.provider + .download( + &request.pod_id, + &request.remote_path, + &request.local_path, + cancel, + ) + .await + .map_err(serialize_gpu_error)?; + serialize_response(&GpuDownloadResponse { + pod_id: request.pod_id, + remote_path: request.remote_path, + local_path: path_to_string(&request.local_path), + downloaded: true, + }) + } +} + +#[async_trait] +impl Skill for CloudGpuSkill { + fn name(&self) -> &str { + "cloud_gpu" + } + + fn description(&self) -> &str { + "Manage cloud GPU pods through a configured provider." 
+ } + + fn tool_definitions(&self) -> Vec { + cloud_gpu_tool_definitions() + } + + fn cacheability(&self, tool_name: &str) -> ToolCacheability { + match tool_name { + GPU_CREATE_TOOL | GPU_STOP_TOOL | GPU_DESTROY_TOOL | GPU_EXEC_TOOL + | GPU_UPLOAD_TOOL | GPU_DOWNLOAD_TOOL => ToolCacheability::SideEffect, + GPU_LIST_TOOL | GPU_STATUS_TOOL => ToolCacheability::NeverCache, + _ => ToolCacheability::NeverCache, + } + } + + async fn execute( + &self, + tool_name: &str, + arguments: &str, + cancel: Option<&CancellationToken>, + ) -> Option> { + if !Self::handles_tool(tool_name) { + return None; + } + Some(self.execute_tool(tool_name, arguments, cancel).await) + } +} + +#[derive(Debug, Deserialize)] +struct GpuCreateRequest { + config: PodConfig, +} + +#[derive(Debug, Serialize, Deserialize)] +struct GpuCreateResponse { + pod: Pod, +} + +#[derive(Debug, Default, Deserialize)] +struct GpuListRequest {} + +#[derive(Debug, Serialize, Deserialize)] +struct GpuListResponse { + pods: Vec, +} + +#[derive(Debug, Deserialize)] +struct GpuStatusRequest { + pod_id: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct GpuStatusResponse { + pod: Pod, +} + +#[derive(Debug, Deserialize)] +struct GpuStopRequest { + pod_id: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct GpuStopResponse { + pod_id: String, + stopped: bool, +} + +#[derive(Debug, Deserialize)] +struct GpuDestroyRequest { + pod_id: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct GpuDestroyResponse { + pod_id: String, + destroyed: bool, +} + +#[derive(Debug, Deserialize)] +struct GpuExecRequest { + pod_id: String, + command: String, + timeout_seconds: u32, +} + +#[derive(Debug, Serialize, Deserialize)] +struct GpuExecResponse { + result: ExecResult, +} + +#[derive(Debug, Deserialize)] +struct GpuUploadRequest { + pod_id: String, + local_path: PathBuf, + remote_path: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct GpuUploadResponse { + pod_id: String, + local_path: String, 
+ remote_path: String, + uploaded: bool, +} + +#[derive(Debug, Deserialize)] +struct GpuDownloadRequest { + pod_id: String, + remote_path: String, + local_path: PathBuf, +} + +#[derive(Debug, Serialize, Deserialize)] +struct GpuDownloadResponse { + pod_id: String, + remote_path: String, + local_path: String, + downloaded: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ErrorResponse { + error: String, +} + +fn parse_request(arguments: &str) -> Result +where + T: DeserializeOwned, +{ + serde_json::from_str(arguments).map_err(|error| format!("invalid arguments: {error}")) +} + +fn serialize_response(response: &T) -> Result +where + T: Serialize, +{ + serde_json::to_string(response).map_err(|error| format!("serialization failed: {error}")) +} + +fn serialize_gpu_error(error: GpuError) -> SkillError { + serialize_error_message(error.to_string()) +} + +fn serialize_error_message(message: String) -> SkillError { + let response = ErrorResponse { error: message }; + match serde_json::to_string(&response) { + Ok(json) => json, + Err(error) => format!( + "serialization failed: {error}; original error: {}", + response.error + ), + } +} + +fn path_to_string(path: &Path) -> String { + path.display().to_string() +} + +fn cloud_gpu_tool_definitions() -> Vec { + vec![ + gpu_create_definition(), + gpu_list_definition(), + gpu_status_definition(), + gpu_stop_definition(), + gpu_destroy_definition(), + gpu_exec_definition(), + gpu_upload_definition(), + gpu_download_definition(), + ] +} + +fn gpu_create_definition() -> ToolDefinition { + tool_definition( + GPU_CREATE_TOOL, + "Create a new cloud GPU pod from a pod configuration.", + json!({ + "type": "object", + "properties": { + "config": pod_config_schema() + }, + "required": ["config"] + }), + ) +} + +fn gpu_list_definition() -> ToolDefinition { + tool_definition( + GPU_LIST_TOOL, + "List all cloud GPU pods for the configured provider.", + empty_object_schema(), + ) +} + +fn gpu_status_definition() -> ToolDefinition { 
+ tool_definition( + GPU_STATUS_TOOL, + "Get the current status and connection details for a pod.", + pod_id_schema(), + ) +} + +fn gpu_stop_definition() -> ToolDefinition { + tool_definition( + GPU_STOP_TOOL, + "Stop a running cloud GPU pod.", + pod_id_schema(), + ) +} + +fn gpu_destroy_definition() -> ToolDefinition { + tool_definition( + GPU_DESTROY_TOOL, + "Destroy a cloud GPU pod permanently.", + pod_id_schema(), + ) +} + +fn gpu_exec_definition() -> ToolDefinition { + tool_definition( + GPU_EXEC_TOOL, + "Execute a command inside a cloud GPU pod.", + json!({ + "type": "object", + "properties": { + "pod_id": { "type": "string", "description": "Pod identifier" }, + "command": { "type": "string", "description": "Command to execute" }, + "timeout_seconds": { + "type": "integer", + "description": "Command timeout in seconds" + } + }, + "required": ["pod_id", "command", "timeout_seconds"] + }), + ) +} + +fn gpu_upload_definition() -> ToolDefinition { + tool_definition( + GPU_UPLOAD_TOOL, + "Upload a local file to a cloud GPU pod.", + json!({ + "type": "object", + "properties": { + "pod_id": { "type": "string", "description": "Pod identifier" }, + "local_path": { "type": "string", "description": "Source file path" }, + "remote_path": { "type": "string", "description": "Destination path on the pod" } + }, + "required": ["pod_id", "local_path", "remote_path"] + }), + ) +} + +fn gpu_download_definition() -> ToolDefinition { + tool_definition( + GPU_DOWNLOAD_TOOL, + "Download a file from a cloud GPU pod to the local machine.", + json!({ + "type": "object", + "properties": { + "pod_id": { "type": "string", "description": "Pod identifier" }, + "remote_path": { "type": "string", "description": "Source path on the pod" }, + "local_path": { "type": "string", "description": "Destination file path" } + }, + "required": ["pod_id", "remote_path", "local_path"] + }), + ) +} + +fn tool_definition(name: &str, description: &str, parameters: serde_json::Value) -> ToolDefinition { + 
ToolDefinition { + name: name.to_string(), + description: description.to_string(), + parameters, + } +} + +fn empty_object_schema() -> serde_json::Value { + json!({ + "type": "object", + "properties": {}, + "required": [] + }) +} + +fn pod_id_schema() -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "pod_id": { "type": "string", "description": "Pod identifier" } + }, + "required": ["pod_id"] + }) +} + +fn pod_config_schema() -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "gpu": gpu_type_schema(), + "gpu_count": { "type": "integer" }, + "image": { "type": "string" }, + "disk_gb": { "type": "integer" }, + "env": { + "type": "object", + "additionalProperties": { "type": "string" } + } + }, + "required": ["name", "gpu", "gpu_count", "image", "disk_gb"] + }) +} + +fn gpu_type_schema() -> serde_json::Value { + json!({ + "oneOf": [ + { + "type": "string", + "enum": ["Rtx3090", "Rtx4090", "A100_80gb", "H100_80gb"] + }, + { + "type": "object", + "properties": { + "Custom": { "type": "string" } + }, + "required": ["Custom"] + } + ] + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{GpuType, PodStatus}; + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + + #[derive(Debug)] + struct MockProvider { + calls: Arc>>, + destroy_missing: bool, + } + + impl MockProvider { + fn new(destroy_missing: bool) -> (Self, Arc>>) { + let calls = Arc::new(Mutex::new(Vec::new())); + ( + Self { + calls: Arc::clone(&calls), + destroy_missing, + }, + calls, + ) + } + + fn record_call(&self, call: impl Into) { + self.calls.lock().unwrap().push(call.into()); + } + + fn cancel_suffix(cancel: Option<&CancellationToken>) -> String { + cancel.map_or_else(String::new, |token| { + format!(":cancelled:{}", token.is_cancelled()) + }) + } + } + + #[async_trait] + impl CloudGpuProvider for MockProvider { + fn provider_name(&self) -> &str { + "mock" + } + + async fn create_pod(&self, config: 
PodConfig) -> Result { + self.record_call(format!("create:{}", config.name)); + Ok(sample_pod()) + } + + async fn list_pods(&self) -> Result, GpuError> { + self.record_call("list"); + Ok(vec![sample_pod()]) + } + + async fn pod_status(&self, pod_id: &str) -> Result { + self.record_call(format!("status:{pod_id}")); + Ok(sample_pod()) + } + + async fn stop_pod(&self, pod_id: &str) -> Result<(), GpuError> { + self.record_call(format!("stop:{pod_id}")); + Ok(()) + } + + async fn destroy_pod(&self, pod_id: &str) -> Result<(), GpuError> { + self.record_call(format!("destroy:{pod_id}")); + if self.destroy_missing { + Err(GpuError::PodNotFound(pod_id.to_string())) + } else { + Ok(()) + } + } + + async fn exec( + &self, + pod_id: &str, + command: &str, + timeout_seconds: u32, + cancel: Option<&CancellationToken>, + ) -> Result { + self.record_call(format!( + "exec:{pod_id}:{command}:{timeout_seconds}{}", + Self::cancel_suffix(cancel) + )); + Ok(sample_exec_result()) + } + + async fn upload( + &self, + pod_id: &str, + local_path: &std::path::Path, + remote_path: &str, + cancel: Option<&CancellationToken>, + ) -> Result<(), GpuError> { + self.record_call(format!( + "upload:{pod_id}:{}:{remote_path}{}", + local_path.display(), + Self::cancel_suffix(cancel) + )); + Ok(()) + } + + async fn download( + &self, + pod_id: &str, + remote_path: &str, + local_path: &std::path::Path, + cancel: Option<&CancellationToken>, + ) -> Result<(), GpuError> { + self.record_call(format!( + "download:{pod_id}:{remote_path}:{}{}", + local_path.display(), + Self::cancel_suffix(cancel) + )); + Ok(()) + } + } + + fn sample_config() -> PodConfig { + let mut env = HashMap::new(); + env.insert("TOKEN".to_string(), "abc123".to_string()); + PodConfig { + name: "trainer".to_string(), + gpu: GpuType::Rtx4090, + gpu_count: 1, + image: "nvidia/cuda:12.0.0-runtime-ubuntu22.04".to_string(), + disk_gb: 200, + env, + } + } + + fn sample_pod() -> Pod { + Pod { + id: "pod-1".to_string(), + status: 
PodStatus::Running, + ssh_host: "gpu.example.com".to_string(), + ssh_port: 22, + gpu: GpuType::Rtx4090, + cost_per_hour: 1.5, + } + } + + fn sample_exec_result() -> ExecResult { + ExecResult { + stdout: "GPU ready".to_string(), + stderr: String::new(), + exit_code: 0, + duration_ms: 125, + } + } + + fn test_skill(destroy_missing: bool) -> (CloudGpuSkill, Arc>>) { + let (provider, calls) = MockProvider::new(destroy_missing); + (CloudGpuSkill::new(Box::new(provider)), calls) + } + + async fn execute_tool( + skill: &CloudGpuSkill, + tool_name: &str, + arguments: serde_json::Value, + ) -> Result { + execute_tool_with_cancel(skill, tool_name, arguments, None).await + } + + async fn execute_tool_with_cancel( + skill: &CloudGpuSkill, + tool_name: &str, + arguments: serde_json::Value, + cancel: Option<&CancellationToken>, + ) -> Result { + skill + .execute(tool_name, &arguments.to_string(), cancel) + .await + .expect("tool should be handled") + } + + #[tokio::test] + async fn mock_provider_returns_canned_responses() { + let provider = MockProvider::new(false).0; + let pod = provider + .create_pod(sample_config()) + .await + .expect("create pod"); + let pods = provider.list_pods().await.expect("list pods"); + let exec = provider + .exec("pod-1", "nvidia-smi", 30, None) + .await + .expect("exec command"); + + assert_eq!(pod.id, "pod-1"); + assert_eq!(pods.len(), 1); + assert_eq!(exec.stdout, "GPU ready"); + } + + async fn exercise_lifecycle_tools(skill: &CloudGpuSkill) { + let _ = execute_tool(skill, GPU_CREATE_TOOL, json!({ "config": sample_config() })).await; + let _ = execute_tool(skill, GPU_LIST_TOOL, json!({})).await; + let _ = execute_tool(skill, GPU_STATUS_TOOL, json!({ "pod_id": "pod-1" })).await; + let _ = execute_tool(skill, GPU_STOP_TOOL, json!({ "pod_id": "pod-1" })).await; + let _ = execute_tool(skill, GPU_DESTROY_TOOL, json!({ "pod_id": "pod-1" })).await; + let _ = execute_tool( + skill, + GPU_EXEC_TOOL, + json!({ "pod_id": "pod-1", "command": "nvidia-smi", 
"timeout_seconds": 30 }), + ) + .await; + } + + async fn exercise_transfer_tools(skill: &CloudGpuSkill) { + let _ = execute_tool( + skill, + GPU_UPLOAD_TOOL, + json!({ + "pod_id": "pod-1", + "local_path": "/tmp/input.txt", + "remote_path": "/workspace/input.txt" + }), + ) + .await; + let _ = execute_tool( + skill, + GPU_DOWNLOAD_TOOL, + json!({ + "pod_id": "pod-1", + "remote_path": "/workspace/output.txt", + "local_path": "/tmp/output.txt" + }), + ) + .await; + } + + fn expected_calls() -> Vec<&'static str> { + vec![ + "create:trainer", + "list", + "status:pod-1", + "stop:pod-1", + "destroy:pod-1", + "exec:pod-1:nvidia-smi:30", + "upload:pod-1:/tmp/input.txt:/workspace/input.txt", + "download:pod-1:/workspace/output.txt:/tmp/output.txt", + ] + } + + #[tokio::test] + async fn cloud_gpu_skill_routes_to_correct_provider_method() { + let (skill, calls) = test_skill(false); + + exercise_lifecycle_tools(&skill).await; + exercise_transfer_tools(&skill).await; + + assert_eq!(calls.lock().unwrap().clone(), expected_calls()); + } + + #[tokio::test] + async fn exec_upload_and_download_forward_cancellation_token() { + let (skill, calls) = test_skill(false); + let cancel = CancellationToken::new(); + cancel.cancel(); + + let _ = execute_tool_with_cancel( + &skill, + GPU_EXEC_TOOL, + json!({ "pod_id": "pod-1", "command": "nvidia-smi", "timeout_seconds": 30 }), + Some(&cancel), + ) + .await; + let _ = execute_tool_with_cancel( + &skill, + GPU_UPLOAD_TOOL, + json!({ + "pod_id": "pod-1", + "local_path": "/tmp/input.txt", + "remote_path": "/workspace/input.txt" + }), + Some(&cancel), + ) + .await; + let _ = execute_tool_with_cancel( + &skill, + GPU_DOWNLOAD_TOOL, + json!({ + "pod_id": "pod-1", + "remote_path": "/workspace/output.txt", + "local_path": "/tmp/output.txt" + }), + Some(&cancel), + ) + .await; + + assert_eq!( + calls.lock().unwrap().clone(), + vec![ + "exec:pod-1:nvidia-smi:30:cancelled:true", + "upload:pod-1:/tmp/input.txt:/workspace/input.txt:cancelled:true", + 
"download:pod-1:/workspace/output.txt:/tmp/output.txt:cancelled:true", + ] + ); + } + + #[tokio::test] + async fn unknown_tool_returns_none() { + let (skill, _) = test_skill(false); + let result = skill.execute("gpu_unknown", "{}", None).await; + + assert!(result.is_none()); + } + + #[tokio::test] + async fn error_serialization_returns_json_error_payload() { + let (skill, _) = test_skill(true); + let result = execute_tool(&skill, GPU_DESTROY_TOOL, json!({ "pod_id": "pod-404" })) + .await + .expect_err("destroy should fail"); + let payload: ErrorResponse = serde_json::from_str(&result).expect("error json"); + + assert_eq!(payload.error, "pod not found: pod-404"); + } + + #[test] + fn gpu_create_schema_allows_omitting_env() { + let schema = gpu_create_definition().parameters; + let required = schema["properties"]["config"]["required"] + .as_array() + .expect("config schema should list required fields"); + let required_fields: Vec<&str> = required + .iter() + .map(|field| field.as_str().expect("required field should be a string")) + .collect(); + + assert!(!required_fields.contains(&"env")); + } + + #[test] + fn tool_definitions_match_expected_count_and_names() { + let (skill, _) = test_skill(false); + let definitions = skill.tool_definitions(); + let names: Vec<&str> = definitions + .iter() + .map(|definition| definition.name.as_str()) + .collect(); + + assert_eq!(definitions.len(), 8); + assert_eq!(skill.name(), "cloud_gpu"); + assert_eq!( + names, + vec![ + GPU_CREATE_TOOL, + GPU_LIST_TOOL, + GPU_STATUS_TOOL, + GPU_STOP_TOOL, + GPU_DESTROY_TOOL, + GPU_EXEC_TOOL, + GPU_UPLOAD_TOOL, + GPU_DOWNLOAD_TOOL, + ] + ); + } +} diff --git a/engine/crates/fx-config/src/lib.rs b/engine/crates/fx-config/src/lib.rs index 4b5a3030..ba81d696 100644 --- a/engine/crates/fx-config/src/lib.rs +++ b/engine/crates/fx-config/src/lib.rs @@ -67,6 +67,9 @@ pub const DEFAULT_CONFIG_TEMPLATE: &str = r#"# Fawx Configuration # search_exclude = ["vendor", "dist"] # max_read_size = 1048576 +[git] 
+# protected_branches = ["main", "staging"] + [memory] # max_entries = 1000 # max_value_size = 10240 @@ -134,6 +137,8 @@ pub struct FawxConfig { pub model: ModelConfig, pub logging: LoggingConfig, pub tools: ToolsConfig, + #[serde(default)] + pub git: GitConfig, pub memory: MemoryConfig, pub security: SecurityConfig, pub self_modify: SelfModifyCliConfig, @@ -197,6 +202,14 @@ pub struct WorkspaceConfig { pub root: Option, } +/// Git policy configuration for protected branch enforcement. +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +#[serde(default)] +pub struct GitConfig { + #[serde(default)] + pub protected_branches: Vec, +} + /// Permission presets that define default agent autonomy levels. #[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] @@ -1392,6 +1405,8 @@ log_dir = "~/.fawx/custom-logs" assert!(DEFAULT_CONFIG_TEMPLATE.contains("# verbosity = \"normal\"")); assert!(DEFAULT_CONFIG_TEMPLATE.contains("[workspace]")); assert!(DEFAULT_CONFIG_TEMPLATE.contains("# root = \".\"")); + assert!(DEFAULT_CONFIG_TEMPLATE.contains("[git]")); + assert!(DEFAULT_CONFIG_TEMPLATE.contains("# protected_branches = [\"main\", \"staging\"]")); assert!(DEFAULT_CONFIG_TEMPLATE.contains("[permissions]")); assert!(DEFAULT_CONFIG_TEMPLATE.contains("# preset = \"power\"")); assert!(DEFAULT_CONFIG_TEMPLATE.contains("[budget]")); @@ -1420,6 +1435,7 @@ log_dir = "~/.fawx/custom-logs" assert_eq!(defaults.agent.behavior.verbosity, "normal"); assert_eq!(defaults.logging, LoggingConfig::default()); assert_eq!(defaults.tools.max_read_size, 1024 * 1024); + assert!(defaults.git.protected_branches.is_empty()); assert_eq!(defaults.memory.max_entries, 1000); assert_eq!(defaults.memory.max_value_size, 10240); assert_eq!(defaults.memory.max_snapshot_chars, 2000); @@ -1427,6 +1443,42 @@ log_dir = "~/.fawx/custom-logs" assert!(defaults.memory.embeddings_enabled); } + #[test] + fn 
config_parses_git_protected_branches() { + let config: FawxConfig = toml::from_str( + r#"[git] +protected_branches = ["main", "staging"] +"#, + ) + .expect("deserialize git config"); + + assert_eq!(config.git.protected_branches, vec!["main", "staging"]); + } + + #[test] + fn config_parses_empty_git_protected_branches() { + let config: FawxConfig = toml::from_str( + r#"[git] +protected_branches = [] +"#, + ) + .expect("deserialize empty git config"); + + assert!(config.git.protected_branches.is_empty()); + } + + #[test] + fn git_config_serde_round_trip() { + let original = GitConfig { + protected_branches: vec!["main".to_string(), "staging".to_string()], + }; + + let encoded = toml::to_string(&original).expect("serialize git config"); + let decoded: GitConfig = toml::from_str(&encoded).expect("deserialize git config"); + + assert_eq!(decoded, original); + } + #[test] fn config_fields_roundtrip() { let original = FawxConfig { @@ -1462,6 +1514,9 @@ log_dir = "~/.fawx/custom-logs" search_exclude: vec!["vendor".to_string()], max_read_size: 2048, }, + git: GitConfig { + protected_branches: vec!["main".to_string(), "staging".to_string()], + }, memory: MemoryConfig { max_entries: 4, max_value_size: 5, @@ -2002,9 +2057,9 @@ max_iterations = 10 #[test] fn tilde_expansion_does_not_expand_tilde_user() { - let path = PathBuf::from("~joe/.config"); + let path = PathBuf::from("~user/.config"); let expanded = expand_tilde(&path); - assert_eq!(expanded, PathBuf::from("~joe/.config")); + assert_eq!(expanded, PathBuf::from("~user/.config")); } #[test] @@ -2327,6 +2382,7 @@ working_dir = "/tmp/work" let config: FawxConfig = toml::from_str("[general]\nmax_iterations = 12\n").expect("deserialize old config"); assert_eq!(config.workspace, WorkspaceConfig::default()); + assert_eq!(config.git, GitConfig::default()); assert_eq!(config.budget, BudgetConfig::default()); assert_eq!(config.sandbox, SandboxConfig::default()); assert_eq!(config.proposals, ProposalConfig::default()); diff --git 
a/engine/crates/fx-consensus/src/remote_workspace.rs b/engine/crates/fx-consensus/src/remote_workspace.rs index 3bea25b3..47e42322 100644 --- a/engine/crates/fx-consensus/src/remote_workspace.rs +++ b/engine/crates/fx-consensus/src/remote_workspace.rs @@ -407,9 +407,9 @@ mod tests { #[test] fn remote_eval_target_parses_user_host_and_path() { - let target: RemoteEvalTarget = "user@example.com:/srv/fawx".parse().expect("target"); + let target: RemoteEvalTarget = "builder@example.com:/srv/fawx".parse().expect("target"); - assert_eq!(target.ssh_user, "user"); + assert_eq!(target.ssh_user, "builder"); assert_eq!(target.ssh_host, "example.com"); assert_eq!(target.remote_project_dir, "/srv/fawx"); } @@ -425,7 +425,7 @@ mod tests { #[test] fn ssh_command_format_builds_expected_args() { - let spec = ssh_command_spec("user", "10.0.0.1", "cd '/srv/fawx' && cargo test"); + let spec = ssh_command_spec("builder", "203.0.113.20", "cd '/srv/fawx' && cargo test"); assert_eq!(spec.program, "ssh"); assert_eq!( @@ -441,7 +441,7 @@ mod tests { "ServerAliveInterval=15", "-o", "ServerAliveCountMax=3", - "user@10.0.0.1", + "builder@203.0.113.20", "cd '/srv/fawx' && cargo test", ] ); @@ -450,8 +450,8 @@ mod tests { #[test] fn patch_application_builds_scp_and_git_apply_commands() { let scp = scp_command_spec( - "user", - "10.0.0.1", + "builder", + "203.0.113.20", Path::new("/tmp/local.patch"), "/tmp/remote.patch", ); @@ -468,7 +468,7 @@ mod tests { "-o", "ConnectTimeout=30", "/tmp/local.patch", - "user@10.0.0.1:/tmp/remote.patch", + "builder@203.0.113.20:/tmp/remote.patch", ] ); assert_eq!( diff --git a/engine/crates/fx-fleet/src/http.rs b/engine/crates/fx-fleet/src/http.rs index 425d39bb..953c8f56 100644 --- a/engine/crates/fx-fleet/src/http.rs +++ b/engine/crates/fx-fleet/src/http.rs @@ -527,7 +527,7 @@ mod tests { fn sample_registration_request() -> FleetRegistrationRequest { FleetRegistrationRequest { - node_name: "macmini-01".to_string(), + node_name: "node-a-01".to_string(), 
bearer_token: "node-secret".to_string(), capabilities: vec!["generate".to_string(), "evaluate".to_string()], rust_version: Some("1.86.0".to_string()), @@ -539,7 +539,7 @@ mod tests { fn sample_heartbeat() -> FleetHeartbeat { FleetHeartbeat { - node_id: "macmini-01".to_string(), + node_id: "node-a-01".to_string(), status: WorkerState::Idle, current_task: None, } @@ -547,7 +547,7 @@ mod tests { fn sample_worker_status() -> FleetWorkerStatus { FleetWorkerStatus { - node_id: "macmini-01".to_string(), + node_id: "node-a-01".to_string(), status: WorkerState::Busy, current_task: Some("exp-001".to_string()), uptime_seconds: 42, @@ -670,7 +670,7 @@ mod tests { #[test] fn registration_debug_redacts_bearer_token() { let request = FleetRegistrationRequest { - node_name: "macmini-01".to_string(), + node_name: "node-a-01".to_string(), bearer_token: "node-secret".to_string(), capabilities: vec!["generate".to_string()], rust_version: None, diff --git a/engine/crates/fx-fleet/src/identity.rs b/engine/crates/fx-fleet/src/identity.rs index 3ca64aac..3c912854 100644 --- a/engine/crates/fx-fleet/src/identity.rs +++ b/engine/crates/fx-fleet/src/identity.rs @@ -45,8 +45,8 @@ mod tests { fn sample_identity() -> FleetIdentity { FleetIdentity { - node_id: "macmini-a1b2c3".to_string(), - primary_endpoint: "http://10.0.0.1:8400".to_string(), + node_id: "node-a-a1b2c3".to_string(), + primary_endpoint: "http://203.0.113.20:8400".to_string(), bearer_token: "tok_secret_123".to_string(), registered_at_ms: 12345, } diff --git a/engine/crates/fx-fleet/src/lib.rs b/engine/crates/fx-fleet/src/lib.rs index 8f0cf5fb..856dec22 100644 --- a/engine/crates/fx-fleet/src/lib.rs +++ b/engine/crates/fx-fleet/src/lib.rs @@ -29,7 +29,7 @@ pub struct NodeInfo { pub node_id: String, /// Human-readable name. pub name: String, - /// HTTP API endpoint (e.g., "https://100.64.1.5:8400"). + /// HTTP API endpoint (e.g., "https://203.0.113.5:8400"). pub endpoint: String, /// Bearer token for authenticating with this node. 
pub auth_token: Option, @@ -436,19 +436,19 @@ mod tests { fn node_info_from_config_maps_fleet_fields() { let config = NodeConfig { id: "mac-mini".to_string(), - name: "Node Alpha".to_string(), + name: "Worker Node A".to_string(), endpoint: Some("https://10.0.0.5:8400".to_string()), auth_token: Some("token".to_string()), capabilities: vec!["agentic_loop".to_string(), "test".to_string()], address: Some("10.0.0.5".to_string()), - user: Some("admin".to_string()), + user: Some("builder".to_string()), ssh_key: Some("~/.ssh/id_ed25519".to_string()), }; let node = NodeInfo::from(&config); assert_eq!(node.node_id, "mac-mini"); - assert_eq!(node.name, "Node Alpha"); + assert_eq!(node.name, "Worker Node A"); assert_eq!(node.endpoint, "https://10.0.0.5:8400"); assert_eq!(node.auth_token.as_deref(), Some("token")); assert_eq!( @@ -462,7 +462,7 @@ mod tests { assert_eq!(node.last_heartbeat_ms, 0); assert!(node.registered_at_ms > 0); assert_eq!(node.address.as_deref(), Some("10.0.0.5")); - assert_eq!(node.ssh_user.as_deref(), Some("admin")); + assert_eq!(node.ssh_user.as_deref(), Some("builder")); assert_eq!(node.ssh_key.as_deref(), Some("~/.ssh/id_ed25519")); } diff --git a/engine/crates/fx-fleet/src/manager.rs b/engine/crates/fx-fleet/src/manager.rs index 2b7ca9cc..5e2c2360 100644 --- a/engine/crates/fx-fleet/src/manager.rs +++ b/engine/crates/fx-fleet/src/manager.rs @@ -448,7 +448,7 @@ mod tests { let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); let token = manager - .add_node("Node Alpha", "10.0.0.2", 8400) + .add_node("Worker Node A", "203.0.113.10", 8400) .expect("node should add"); let node = manager .list_nodes() @@ -457,10 +457,10 @@ mod tests { assert_eq!(token.node_id, node.node_id); assert_ne!(token.node_id, node.name); - assert!(token.node_id.starts_with("node-alpha-")); - assert_eq!(node.name, "Node Alpha"); - assert_eq!(node.endpoint, "https://10.0.0.2:8400"); - assert_eq!(node.address.as_deref(), Some("10.0.0.2")); + 
assert!(token.node_id.starts_with("worker-node-a-")); + assert_eq!(node.name, "Worker Node A"); + assert_eq!(node.endpoint, "https://203.0.113.10:8400"); + assert_eq!(node.address.as_deref(), Some("203.0.113.10")); assert_eq!(node.status, NodeStatus::Offline); } @@ -470,9 +470,9 @@ mod tests { let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("first node should add"); - let result = manager.add_node("node-alpha", "10.0.0.3", 8400); + let result = manager.add_node("node-a", "203.0.113.11", 8400); assert!(matches!(result, Err(FleetError::DuplicateNode))); } @@ -482,11 +482,11 @@ mod tests { let temp_dir = TempDir::new().expect("tempdir should create"); let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); let token = manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("node should add"); manager - .remove_node("node-alpha") + .remove_node("node-a") .expect("node should remove cleanly"); assert!(manager.list_nodes().is_empty()); @@ -510,7 +510,7 @@ mod tests { let temp_dir = TempDir::new().expect("tempdir should create"); let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); let token = manager - .add_node("Node Alpha", "10.0.0.2", 8400) + .add_node("Worker Node A", "203.0.113.10", 8400) .expect("node should add"); let verified = manager.verify_bearer(&token.secret); @@ -523,10 +523,10 @@ mod tests { let temp_dir = TempDir::new().expect("tempdir should create"); let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); let token = manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("node should add"); manager - .remove_node("node-alpha") + .remove_node("node-a") .expect("node should remove cleanly"); let verified = 
manager.verify_bearer(&token.secret); @@ -547,7 +547,7 @@ mod tests { let temp_dir = TempDir::new().expect("tempdir should create"); let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); let token = manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("node should add"); let node = manager @@ -571,7 +571,7 @@ mod tests { let temp_dir = TempDir::new().expect("tempdir should create"); let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); let token = manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("node should add"); manager @@ -592,7 +592,7 @@ mod tests { let temp_dir = TempDir::new().expect("tempdir should create"); let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); let token = manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("node should add"); manager .record_worker_heartbeat(&token.node_id, NodeStatus::Busy, 100) @@ -617,19 +617,19 @@ mod tests { let fleet_dir = temp_dir.path().join("fleet"); let mut manager = FleetManager::init(&fleet_dir).expect("fleet should initialize"); let active = manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("first node should add"); let revoked = manager - .add_node("node-beta", "10.0.0.3", 8401) + .add_node("node-b", "203.0.113.11", 8401) .expect("second node should add"); manager - .remove_node("node-beta") + .remove_node("node-b") .expect("node should remove cleanly"); let loaded = FleetManager::load(&fleet_dir).expect("fleet should load"); let node_names = sorted_node_names(loaded.list_nodes()); - assert_eq!(node_names, vec!["node-alpha".to_string()]); + assert_eq!(node_names, vec!["node-a".to_string()]); assert_eq!( loaded.verify_bearer(&active.secret).as_deref(), Some(active.node_id.as_str()) @@ -648,7 +648,7 
@@ mod tests { let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); manager - .add_node("Node Alpha", "10.0.0.2", 8400) + .add_node("Worker Node A", "203.0.113.10", 8400) .expect("node should add"); assert_private_permissions(&nodes_path(temp_dir.path())); @@ -682,18 +682,15 @@ mod tests { let temp_dir = TempDir::new().expect("tempdir should create"); let mut manager = FleetManager::init(temp_dir.path()).expect("fleet should initialize"); manager - .add_node("node-alpha", "10.0.0.2", 8400) + .add_node("node-a", "203.0.113.10", 8400) .expect("first node should add"); manager - .add_node("node-beta", "10.0.0.3", 8401) + .add_node("node-b", "203.0.113.11", 8401) .expect("second node should add"); let names = sorted_node_names(manager.list_nodes()); - assert_eq!( - names, - vec!["node-alpha".to_string(), "node-beta".to_string()] - ); + assert_eq!(names, vec!["node-a".to_string(), "node-b".to_string()]); } fn sorted_node_names(nodes: Vec<&NodeInfo>) -> Vec { diff --git a/engine/crates/fx-kernel/src/budget.rs b/engine/crates/fx-kernel/src/budget.rs index 59b94cb6..351f044f 100644 --- a/engine/crates/fx-kernel/src/budget.rs +++ b/engine/crates/fx-kernel/src/budget.rs @@ -152,13 +152,13 @@ pub struct BudgetConfig { #[serde(default = "default_max_tool_retries")] pub max_tool_retries: u8, /// Controls graceful termination behavior when budget limits fire and - /// how tool-only turn runs are handled. + /// how tool-turn runs are handled. #[serde(default)] pub termination: TerminationConfig, } -/// Controls how the loop exits when a budget limit fires and how consecutive -/// tool-only turns are managed (nudge → strip → synthesize). +/// Controls how the loop exits when a budget limit fires and how tool-use +/// runs are managed across cycles and within continuation rounds. 
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct TerminationConfig { /// When true, make one final LLM call with tools stripped to force a @@ -166,17 +166,27 @@ pub struct TerminationConfig { #[serde(default = "default_synthesize_on_exhaustion")] pub synthesize_on_exhaustion: bool, - /// Consecutive tool-only turns before injecting a nudge message telling - /// the agent to respond to the user. 0 disables the nudge. + /// Consecutive tool turns before injecting a nudge message telling the + /// agent to respond to the user. 0 disables the nudge. #[serde(default = "default_nudge_after_tool_turns")] pub nudge_after_tool_turns: u16, - /// Additional consecutive tool-only turns *after the nudge fires* before - /// tools are stripped entirely, forcing a text response. 0 means strip + /// Additional consecutive tool turns *after the nudge fires* before tools + /// are stripped entirely, forcing a text response. 0 means strip /// immediately when the nudge threshold is reached. Set to `u16::MAX` /// to disable stripping while keeping the nudge. #[serde(default = "default_strip_tools_after_nudge")] pub strip_tools_after_nudge: u16, + + /// Tool continuation rounds before injecting a progress nudge. 0 disables + /// both the nudge and the follow-up strip enforcement. + #[serde(default = "default_tool_round_nudge_after")] + pub tool_round_nudge_after: u16, + + /// Additional continuation rounds after the nudge before tools are + /// stripped, forcing a text response. 
+ #[serde(default = "default_tool_round_strip_after_nudge")] + pub tool_round_strip_after_nudge: u16, } fn default_synthesize_on_exhaustion() -> bool { @@ -188,13 +198,21 @@ fn default_nudge_after_tool_turns() -> u16 { fn default_strip_tools_after_nudge() -> u16 { 3 } +fn default_tool_round_nudge_after() -> u16 { + 4 +} +fn default_tool_round_strip_after_nudge() -> u16 { + 2 +} impl Default for TerminationConfig { fn default() -> Self { Self { - synthesize_on_exhaustion: true, - nudge_after_tool_turns: 6, - strip_tools_after_nudge: 3, + synthesize_on_exhaustion: default_synthesize_on_exhaustion(), + nudge_after_tool_turns: default_nudge_after_tool_turns(), + strip_tools_after_nudge: default_strip_tools_after_nudge(), + tool_round_nudge_after: default_tool_round_nudge_after(), + tool_round_strip_after_nudge: default_tool_round_strip_after_nudge(), } } } diff --git a/engine/crates/fx-kernel/src/context_manager.rs b/engine/crates/fx-kernel/src/context_manager.rs index 8d21f5bc..211803df 100644 --- a/engine/crates/fx-kernel/src/context_manager.rs +++ b/engine/crates/fx-kernel/src/context_manager.rs @@ -405,7 +405,7 @@ mod tests { version: 1, }], identity_context: IdentityContext { - user_name: Some("Alice".to_owned()), + user_name: Some("Example User".to_owned()), preferences, personality_traits: vec!["focused".to_owned(), "concise".to_owned()], }, diff --git a/engine/crates/fx-kernel/src/conversation_compactor.rs b/engine/crates/fx-kernel/src/conversation_compactor.rs index 53260819..71762b72 100644 --- a/engine/crates/fx-kernel/src/conversation_compactor.rs +++ b/engine/crates/fx-kernel/src/conversation_compactor.rs @@ -4,7 +4,6 @@ use fx_llm::{ContentBlock, Message, MessageRole}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::error::Error; -use std::sync::Arc; const COMPACTION_MARKER_PREFIX: &str = "[context compacted:"; const SUMMARY_MARKER_PREFIX: &str = "[context summary]"; @@ -56,7 +55,9 @@ fn text_blocks(message: 
&Message) -> impl Iterator { } fn message_contains_marker(message: &Message) -> bool { - text_blocks(message).any(|text| text.starts_with(COMPACTION_MARKER_PREFIX)) + text_blocks(message).any(|text| { + text.starts_with(COMPACTION_MARKER_PREFIX) || text.starts_with(SUMMARY_MARKER_PREFIX) + }) } fn message_is_system_like(message: &Message) -> bool { @@ -69,7 +70,7 @@ fn message_is_system_like(message: &Message) -> bool { /// This can create adjacent assistant-role messages in the compacted window, /// which is acceptable in the current integration because message ordering is /// preserved and no role alternation invariant is enforced by providers. -fn summary_message(summary: &str) -> Message { +pub(crate) fn summary_message(summary: &str) -> Message { Message::assistant(format!("{SUMMARY_MARKER_PREFIX}\n{summary}")) } @@ -469,6 +470,7 @@ fn sliding_compaction_result( compacted_count: 0, estimated_tokens: before_tokens, used_summarization: false, + summary: None, evicted_indices: Vec::new(), }; debug_assert_tool_pair_integrity(&result.messages); @@ -498,6 +500,7 @@ fn sliding_compaction_result( messages: compacted_messages, compacted_count, used_summarization: false, + summary: None, evicted_indices, }; debug_assert_tool_pair_integrity(&result.messages); @@ -517,6 +520,7 @@ pub fn emergency_compact(messages: &[Message], preserve_recent_turns: usize) -> compacted_count: 0, estimated_tokens: ConversationBudget::estimate_tokens(messages), used_summarization: false, + summary: None, evicted_indices, }; } @@ -529,6 +533,7 @@ pub fn emergency_compact(messages: &[Message], preserve_recent_turns: usize) -> messages: result_messages, compacted_count, used_summarization: false, + summary: None, evicted_indices, } } @@ -749,22 +754,6 @@ impl ConversationBudget { pub fn compaction_target(&self) -> usize { self.conversation_budget() / 2 } - - /// Target token count for summarizing compaction (Tier 3). 
- /// Returns 40% of conversation budget for headroom below summarize_threshold (80%). - pub fn summarize_target(&self) -> usize { - self.conversation_budget().saturating_mul(2) / 5 - } -} - -/// Strategy for compacting an oversized conversation history. -#[async_trait] -pub trait CompactionStrategy: Send + Sync + std::fmt::Debug { - async fn compact( - &self, - messages: &[Message], - target_tokens: usize, - ) -> Result; } /// Persists evicted message content before compaction drops them. @@ -802,6 +791,7 @@ pub struct CompactionResult { pub(crate) compacted_count: usize, pub(crate) estimated_tokens: usize, pub(crate) used_summarization: bool, + pub(crate) summary: Option, /// Indices into the original message slice for messages evicted by compaction. pub(crate) evicted_indices: Vec, } @@ -818,11 +808,8 @@ impl SlidingWindowCompactor { preserve_recent_turns, } } -} -#[async_trait] -impl CompactionStrategy for SlidingWindowCompactor { - async fn compact( + pub async fn compact( &self, messages: &[Message], target_tokens: usize, @@ -831,87 +818,124 @@ impl CompactionStrategy for SlidingWindowCompactor { } } -/// Summarizes older turns into structured context using an LLM call. 
-#[derive(Debug)] -pub struct SummarizingCompactor { - llm: Arc, - preserve_recent_turns: usize, - max_summary_tokens: usize, -} +pub const DEFAULT_MAX_SUMMARY_TOKENS: usize = 1_024; -impl SummarizingCompactor { - pub const DEFAULT_MAX_SUMMARY_TOKENS: usize = 1_024; +#[derive(Debug, Clone)] +pub(crate) struct SlideSummarizationPlan { + bounds: ZoneBounds, + protected_middle: HashSet, + pub(crate) evicted_indices: Vec, + pub(crate) evicted_messages: Vec, +} - pub fn new(llm: Arc, preserve_recent_turns: usize) -> Self { - Self::with_max_summary_tokens(llm, preserve_recent_turns, Self::DEFAULT_MAX_SUMMARY_TOKENS) +fn summarizable_indices( + bounds: &ZoneBounds, + protected_middle: &HashSet, +) -> Result, CompactionError> { + let indices = summarizable_middle_indices(bounds, protected_middle); + if indices.is_empty() { + return Err(CompactionError::AllMessagesProtected); } + Ok(indices) +} - pub fn with_max_summary_tokens( - llm: Arc, - preserve_recent_turns: usize, - max_summary_tokens: usize, - ) -> Self { - Self { - llm, - preserve_recent_turns, - max_summary_tokens, - } - } +pub(crate) fn slide_summarization_plan( + messages: &[Message], + preserve_recent_turns: usize, +) -> Result { + let bounds = zone_bounds(messages, preserve_recent_turns); + let protected_middle = protected_middle_indices(messages, &bounds); + let evicted_indices = summarizable_indices(&bounds, &protected_middle)?; + let evicted_messages = cloned_messages_at_indices(messages, &evicted_indices); + Ok(SlideSummarizationPlan { + bounds, + protected_middle, + evicted_indices, + evicted_messages, + }) +} - fn summarizable_indices( - &self, - bounds: &ZoneBounds, - protected_middle: &HashSet, - ) -> Result, CompactionError> { - let indices = summarizable_middle_indices(bounds, protected_middle); - if indices.is_empty() { - return Err(CompactionError::AllMessagesProtected); - } - Ok(indices) +pub async fn summarized_compaction_result( + llm: &dyn LlmProvider, + messages: &[Message], + 
preserve_recent_turns: usize, + max_summary_tokens: usize, + target_tokens: usize, +) -> Result { + let before_tokens = ConversationBudget::estimate_tokens(messages); + if before_tokens <= target_tokens { + let result = CompactionResult { + messages: messages.to_vec(), + compacted_count: 0, + estimated_tokens: before_tokens, + used_summarization: false, + summary: None, + evicted_indices: Vec::new(), + }; + debug_assert_tool_pair_integrity(&result.messages); + return Ok(result); } - async fn generate_summary( - &self, - summarizable_messages: &[Message], - ) -> Result { - let prompt = Self::summary_prompt(summarizable_messages); - self.llm - .generate(&prompt, self.max_summary_tokens as u32) - .await - .map_err(|source| CompactionError::SummarizationFailed { - source: Box::new(source), - }) + let plan = slide_summarization_plan(messages, preserve_recent_turns)?; + let summary = generate_summary(llm, &plan.evicted_messages, max_summary_tokens).await?; + let compacted_messages = assemble_summarized_messages(messages, &plan, &summary); + let estimated_tokens = ConversationBudget::estimate_tokens(&compacted_messages); + if estimated_tokens > target_tokens { + return Err(CompactionError::SummaryExceededTarget); } - fn assemble_summarized_messages( - &self, - messages: &[Message], - bounds: &ZoneBounds, - protected_middle: &HashSet, - summary: &str, - ) -> Vec { - let mut compacted_messages = Vec::new(); - compacted_messages.extend_from_slice(&messages[..bounds.prefix_end]); - append_protected_middle_messages( - &mut compacted_messages, - messages, - bounds, - protected_middle, - ); - compacted_messages.push(summary_message(summary)); - compacted_messages.extend_from_slice(&messages[bounds.tail_start..]); - compacted_messages - } + let result = CompactionResult { + messages: compacted_messages, + compacted_count: plan.evicted_messages.len(), + estimated_tokens, + used_summarization: true, + summary: Some(summary), + evicted_indices: plan.evicted_indices, + }; + 
debug_assert_tool_pair_integrity(&result.messages); + Ok(result) +} - fn summary_prompt(messages: &[Message]) -> String { - let conversation = messages - .iter() - .map(message_to_summary_line) - .collect::>() - .join("\n"); +pub(crate) async fn generate_summary( + llm: &dyn LlmProvider, + summarizable_messages: &[Message], + max_summary_tokens: usize, +) -> Result { + let prompt = summary_prompt(summarizable_messages); + llm.generate(&prompt, max_summary_tokens as u32) + .await + .map_err(|source| CompactionError::SummarizationFailed { + source: Box::new(source), + }) +} + +pub(crate) fn assemble_summarized_messages( + messages: &[Message], + plan: &SlideSummarizationPlan, + summary: &str, +) -> Vec { + let mut compacted_messages = Vec::new(); + compacted_messages.extend_from_slice(&messages[..plan.bounds.prefix_end]); + append_protected_middle_messages( + &mut compacted_messages, + messages, + &plan.bounds, + &plan.protected_middle, + ); + compacted_messages.push(summary_message(summary)); + compacted_messages.extend_from_slice(&messages[plan.bounds.tail_start..]); + compacted_messages +} + +pub(crate) fn summary_prompt(messages: &[Message]) -> String { + let conversation = messages + .iter() + .map(message_to_summary_line) + .collect::>() + .join("\n"); - format!( - "Summarize the following conversation history.\n\ + format!( + "Summarize the following conversation history.\n\ Keep the summary factual and grounded in provided content only.\n\ \nSections (required):\n\ 1. Decisions\n\ @@ -919,8 +943,7 @@ Keep the summary factual and grounded in provided content only.\n\ 3. Task state\n\ 4. 
Key context\n\ \nConversation:\n{conversation}" - ) - } + ) } fn message_to_summary_line(message: &Message) -> String { @@ -961,62 +984,16 @@ fn message_to_summary_line(message: &Message) -> String { format!("- {role}: {text}") } -#[async_trait] -impl CompactionStrategy for SummarizingCompactor { - async fn compact( - &self, - messages: &[Message], - target_tokens: usize, - ) -> Result { - let before_tokens = ConversationBudget::estimate_tokens(messages); - if before_tokens <= target_tokens { - let result = CompactionResult { - messages: messages.to_vec(), - compacted_count: 0, - estimated_tokens: before_tokens, - used_summarization: false, - evicted_indices: Vec::new(), - }; - debug_assert_tool_pair_integrity(&result.messages); - return Ok(result); - } - - let bounds = zone_bounds(messages, self.preserve_recent_turns); - let protected_middle = protected_middle_indices(messages, &bounds); - let summarizable_indices = self.summarizable_indices(&bounds, &protected_middle)?; - let summarizable_messages = cloned_messages_at_indices(messages, &summarizable_indices); - let summary = self.generate_summary(&summarizable_messages).await?; - let compacted_messages = - self.assemble_summarized_messages(messages, &bounds, &protected_middle, &summary); - - let estimated_tokens = ConversationBudget::estimate_tokens(&compacted_messages); - if estimated_tokens > target_tokens { - return Err(CompactionError::SummaryExceededTarget); - } - - let result = CompactionResult { - messages: compacted_messages, - compacted_count: summarizable_messages.len(), - estimated_tokens, - used_summarization: true, - evicted_indices: summarizable_indices, - }; - debug_assert_tool_pair_integrity(&result.messages); - Ok(result) - } -} - #[derive(Debug, Clone, thiserror::Error)] pub enum CompactionConfigError { #[error("threshold must be in (0.0, 1.0], got {0}")] InvalidThreshold(f32), #[error( - "thresholds must be strictly increasing: prune ({prune}) < slide ({slide}) < summarize ({summarize}) < 
emergency ({emergency})" + "thresholds must be strictly increasing: prune ({prune}) < slide ({slide}) < emergency ({emergency})" )] ThresholdsNotMonotonic { prune: f32, slide: f32, - summarize: f32, emergency: f32, }, #[error("model_context_limit must be > 0")] @@ -1046,7 +1023,9 @@ pub struct CompactionConfig { #[serde(alias = "compaction_threshold")] pub(crate) slide_threshold: f32, pub(crate) prune_threshold: f32, - pub(crate) summarize_threshold: f32, + /// Legacy field retained only for backward-compatible deserialization. + #[serde(alias = "summarize_threshold", default, skip_serializing)] + pub(crate) _legacy_summarize_threshold: f32, pub(crate) emergency_threshold: f32, pub(crate) preserve_recent_turns: usize, pub(crate) model_context_limit: usize, @@ -1078,15 +1057,13 @@ impl CompactionConfig { for threshold in [ self.prune_threshold, self.slide_threshold, - self.summarize_threshold, self.emergency_threshold, ] { validate_threshold(threshold)?; } if self.prune_threshold < self.slide_threshold - && self.slide_threshold < self.summarize_threshold - && self.summarize_threshold < self.emergency_threshold + && self.slide_threshold < self.emergency_threshold { return Ok(()); } @@ -1094,7 +1071,6 @@ impl CompactionConfig { Err(CompactionConfigError::ThresholdsNotMonotonic { prune: self.prune_threshold, slide: self.slide_threshold, - summarize: self.summarize_threshold, emergency: self.emergency_threshold, }) } @@ -1135,24 +1111,6 @@ impl CompactionConfig { } Ok(()) } - - pub fn build_strategy(&self, llm: Option>) -> Box { - if self.use_summarization { - if let Some(provider) = llm { - return Box::new(SummarizingCompactor::with_max_summary_tokens( - provider, - self.preserve_recent_turns, - self.max_summary_tokens, - )); - } - - tracing::info!( - "use_summarization=true but no llm provider available; falling back to SlidingWindowCompactor" - ); - } - - Box::new(SlidingWindowCompactor::new(self.preserve_recent_turns)) - } } impl Default for CompactionConfig { @@ 
-1160,14 +1118,14 @@ impl Default for CompactionConfig { Self { slide_threshold: 0.60, prune_threshold: 0.40, - summarize_threshold: 0.80, + _legacy_summarize_threshold: 0.80, emergency_threshold: 0.95, - preserve_recent_turns: 6, + preserve_recent_turns: 12, model_context_limit: 128_000, reserved_system_tokens: 2_000, recompact_cooldown_turns: 2, use_summarization: true, - max_summary_tokens: SummarizingCompactor::DEFAULT_MAX_SUMMARY_TOKENS, + max_summary_tokens: DEFAULT_MAX_SUMMARY_TOKENS, prune_tool_blocks: true, tool_block_summary_max_chars: 100, } @@ -1181,7 +1139,7 @@ mod tests { use fx_core::error::LlmError as CoreLlmError; use fx_llm::{CompletionRequest, CompletionResponse, ProviderError, ToolCall}; use std::collections::VecDeque; - use std::sync::Mutex; + use std::sync::{Arc, Mutex}; fn words(count: usize) -> String { std::iter::repeat_n("a", count) @@ -1326,9 +1284,8 @@ mod tests { let config = CompactionConfig::default(); assert_eq!(config.prune_threshold, 0.40); assert_eq!(config.slide_threshold, 0.60); - assert_eq!(config.summarize_threshold, 0.80); assert_eq!(config.emergency_threshold, 0.95); - assert_eq!(config.preserve_recent_turns, 6); + assert_eq!(config.preserve_recent_turns, 12); assert_eq!(config.model_context_limit, 128_000); assert_eq!(config.reserved_system_tokens, 2_000); assert_eq!(config.recompact_cooldown_turns, 2); @@ -1357,15 +1314,6 @@ mod tests { assert!(budget.at_tier(&[user(453)], 0.5)); } - #[test] - fn summarize_target_returns_two_fifths_of_budget() { - let budget = ConversationBudget::new(16_384, 0.8, 2_000); - assert_eq!( - budget.summarize_target(), - budget.conversation_budget() * 2 / 5 - ); - } - #[test] fn needs_compaction_returns_false_below_threshold() { let budget = ConversationBudget::new(5_000, 0.50, 0); @@ -1817,22 +1765,86 @@ mod tests { assert!(has_tool_result(&result.messages, "keep")); } - // 5.3 SummarizingCompactor tests + // 5.3 Summarization function tests + + #[test] + fn 
summary_prompt_requests_required_sections() { + let prompt = summary_prompt(&[ + user(20), + assistant(20), + tool_use("call-1"), + tool_result("call-1", 20), + ]); + + assert!(prompt.contains("Sections (required):")); + assert!(prompt.contains("1. Decisions")); + assert!(prompt.contains("2. Files modified")); + assert!(prompt.contains("3. Task state")); + assert!(prompt.contains("4. Key context")); + assert!(prompt.contains("- assistant: [tool_use:read]")); + assert!(prompt.contains("- tool: [tool_result:call-1]")); + } #[tokio::test] - async fn summarize_produces_structured_output() { + async fn generate_summary_uses_llm_and_records_prompt() { let llm = Arc::new(MockSummaryLlm::new(vec![Ok( "Decisions:\n- keep\nFiles modified:\n- src/lib.rs\nTask state:\n- in progress\nKey context:\n- tests failing" .to_string(), )])); - let compactor = SummarizingCompactor::new(llm, 2); - let messages = vec![user(40), assistant(40), user(30), assistant(30), user(20)]; + let messages = vec![user(40), assistant(40), user(30)]; - let result = compactor.compact(&messages, 120).await.expect("compact"); - assert!(result.used_summarization); + let summary = generate_summary(llm.as_ref(), &messages, 256) + .await + .expect("summary"); + + assert!(summary.contains("Decisions:")); + let prompts = llm.prompts(); + assert_eq!(prompts.len(), 1); + assert!(prompts[0].contains("Sections (required):")); + } + + #[tokio::test] + async fn generate_summary_returns_summarization_failed_on_llm_error() { + let llm = Arc::new(MockSummaryLlm::new(vec![Err(CoreLlmError::Inference( + "boom".to_string(), + ))])); + let messages = vec![user(40), assistant(40), user(30)]; + + let error = generate_summary(llm.as_ref(), &messages, 256) + .await + .expect_err("error"); + assert!(matches!(error, CompactionError::SummarizationFailed { .. 
})); + } + + #[tokio::test] + async fn generate_summary_returns_summarization_failed_on_timeout() { + let llm = Arc::new(MockSummaryLlm::new(vec![Err(CoreLlmError::ApiRequest( + "timeout".to_string(), + ))])); + let messages = vec![user(40), assistant(40), user(30)]; + + let error = generate_summary(llm.as_ref(), &messages, 256) + .await + .expect_err("error"); + assert!(matches!(error, CompactionError::SummarizationFailed { .. })); + } + + #[test] + fn assemble_summarized_messages_inserts_single_summary_marker() { + let messages = vec![ + Message::system("system"), + user(40), + assistant(40), + user(30), + assistant(30), + user(20), + ]; + let plan = slide_summarization_plan(&messages, 2).expect("summary plan"); + let compacted = assemble_summarized_messages(&messages, &plan, "Decisions:\n- keep"); + + assert_eq!(compacted.first(), Some(&messages[0])); assert_eq!( - result - .messages + compacted .iter() .filter(|message| { text_blocks(message).any(|text| text.starts_with(SUMMARY_MARKER_PREFIX)) @@ -1840,15 +1852,22 @@ mod tests { .count(), 1 ); + assert_eq!(&compacted[compacted.len() - 2..], &messages[4..]); + } + + #[test] + fn summary_markers_are_treated_as_system_like() { + assert!(message_is_system_like(&summary_message( + "Decisions:\n- keep" + ))); } #[tokio::test] - async fn evicted_indices_populated_for_summarizing() { + async fn summarized_compaction_populates_evicted_indices() { let llm = Arc::new(MockSummaryLlm::new(vec![Ok( "Decisions:\n- keep\nFiles modified:\n- src/lib.rs\nTask state:\n- in progress\nKey context:\n- tests failing" .to_string(), )])); - let compactor = SummarizingCompactor::new(llm, 2); let messages = vec![ Message::system("system"), user(40), @@ -1858,90 +1877,39 @@ mod tests { user(20), ]; - let result = compactor.compact(&messages, 120).await.expect("compact"); + let result = summarized_compaction_result(llm.as_ref(), &messages, 2, 256, 120) + .await + .expect("compact"); + assert!(result.used_summarization); 
assert_eq!(result.evicted_indices, vec![1, 2, 3]); } #[tokio::test] - async fn summarize_returns_summarization_failed_on_llm_error() { - let llm = Arc::new(MockSummaryLlm::new(vec![Err(CoreLlmError::Inference( - "boom".to_string(), - ))])); - let compactor = SummarizingCompactor::new(llm, 2); - let messages = vec![user(40), assistant(40), user(30), assistant(30), user(20)]; - - let error = compactor.compact(&messages, 120).await.expect_err("error"); - assert!(matches!(error, CompactionError::SummarizationFailed { .. })); - } - - #[tokio::test] - async fn summarize_returns_summarization_failed_on_timeout() { - let llm = Arc::new(MockSummaryLlm::new(vec![Err(CoreLlmError::ApiRequest( - "timeout".to_string(), - ))])); - let compactor = SummarizingCompactor::new(llm, 2); - let messages = vec![user(40), assistant(40), user(30), assistant(30), user(20)]; - - let error = compactor.compact(&messages, 120).await.expect_err("error"); - assert!(matches!(error, CompactionError::SummarizationFailed { .. 
})); - } - - #[tokio::test] - async fn summarize_returns_summary_exceeded_target_when_summary_too_large() { + async fn summarized_compaction_returns_summary_exceeded_target_when_summary_too_large() { let llm = Arc::new(MockSummaryLlm::new(vec![Ok(words(500))])); - let compactor = SummarizingCompactor::new(llm, 2); let messages = vec![user(40), assistant(40), user(30), assistant(30), user(20)]; - let error = compactor.compact(&messages, 120).await.expect_err("error"); + let error = summarized_compaction_result(llm.as_ref(), &messages, 2, 256, 120) + .await + .expect_err("error"); assert!(matches!(error, CompactionError::SummaryExceededTarget)); } #[tokio::test] - async fn summarize_respects_target_budget() { + async fn summarized_compaction_respects_target_budget() { let llm = Arc::new(MockSummaryLlm::new(vec![Ok( "Decisions:\n- x\nFiles modified:\n- y\nTask state:\n- z\nKey context:\n- q" .to_string(), )])); - let compactor = SummarizingCompactor::new(llm, 2); let messages = vec![user(30), assistant(30), user(30), assistant(30), user(20)]; - let result = compactor.compact(&messages, 110).await.expect("compact"); + let result = summarized_compaction_result(llm.as_ref(), &messages, 2, 256, 110) + .await + .expect("compact"); assert!(result.estimated_tokens <= 110); } - #[tokio::test] - async fn summary_preserves_key_context_categories() { - let llm = Arc::new(MockSummaryLlm::new(vec![Ok( - "Decisions:\n- keep\nFiles modified:\n- src/main.rs\nTask state:\n- done\nKey context:\n- regression fixed" - .to_string(), - )])); - let provider: Arc = llm.clone(); - let compactor = SummarizingCompactor::new(provider, 2); - let messages = vec![user(35), assistant(35), user(30), assistant(30), user(20)]; - - let result = compactor.compact(&messages, 120).await.expect("compact"); - let summary_text = text_blocks( - result - .messages - .iter() - .find(|message| { - text_blocks(message).any(|text| text.starts_with(SUMMARY_MARKER_PREFIX)) - }) - .expect("summary"), - ) - 
.collect::>() - .join("\n"); - - assert!(summary_text.contains("Decisions:")); - assert!(summary_text.contains("Files modified:")); - assert!(summary_text.contains("Task state:")); - assert!(summary_text.contains("Key context:")); - - let prompts = llm.prompts(); - assert!(prompts[0].contains("Sections (required):")); - } - // 5.6 CompactionConfig validation tests #[test] @@ -1977,13 +1945,23 @@ mod tests { #[test] fn config_rejects_non_monotonic_thresholds() { let mut config = CompactionConfig::default(); - config.summarize_threshold = config.slide_threshold; + config.emergency_threshold = config.slide_threshold; assert!(matches!( config.validate(), Err(CompactionConfigError::ThresholdsNotMonotonic { .. }) )); } + #[test] + fn legacy_summarize_threshold_is_ignored_during_validation() { + let mut config = CompactionConfig::default(); + config._legacy_summarize_threshold = 0.01; + + config + .validate() + .expect("legacy summarize threshold should be ignored"); + } + #[test] fn config_accepts_valid_thresholds() { CompactionConfig::default() @@ -2003,16 +1981,20 @@ mod tests { config.prune_threshold, CompactionConfig::default().prune_threshold ); - assert_eq!( - config.summarize_threshold, - CompactionConfig::default().summarize_threshold - ); assert_eq!( config.emergency_threshold, CompactionConfig::default().emergency_threshold ); } + #[test] + fn legacy_summarize_threshold_is_not_serialized() { + let serialized = + serde_json::to_value(CompactionConfig::default()).expect("config should serialize"); + + assert!(serialized.get("summarize_threshold").is_none()); + } + #[test] fn config_rejects_zero_context_limit() { let mut config = CompactionConfig::default(); @@ -2406,7 +2388,7 @@ mod tests { let config = CompactionConfig { slide_threshold: 0.80, prune_threshold: 0.40, - summarize_threshold: 0.90, + _legacy_summarize_threshold: 0.90, emergency_threshold: 0.95, preserve_recent_turns: 2, model_context_limit: 16_000, @@ -2422,7 +2404,7 @@ mod tests { 
config.slide_threshold, config.reserved_system_tokens, ); - let strategy = config.build_strategy(None); + let compactor = SlidingWindowCompactor::new(config.preserve_recent_turns); assert_eq!(budget.compaction_target(), budget.conversation_budget() / 2); // Build messages: a massive tool result in old window pushes tokens @@ -2455,9 +2437,8 @@ mod tests { "pruned messages should be below slide threshold (tokens: {after_tokens})" ); - // Verify the compaction strategy is never invoked (we'd get the pruned - // messages back without the compaction marker). - let result = strategy.compact(&pruned, budget.compaction_target()).await; + // Verify sliding compaction is a no-op once pruning has already reduced usage. + let result = compactor.compact(&pruned, budget.compaction_target()).await; match result { Ok(r) => assert_eq!(r.compacted_count, 0, "compaction should be a no-op"), Err(_) => panic!("compact should succeed on already-below-threshold messages"), diff --git a/engine/crates/fx-kernel/src/loop_engine.rs b/engine/crates/fx-kernel/src/loop_engine.rs index 25d0ec38..7b5b731c 100644 --- a/engine/crates/fx-kernel/src/loop_engine.rs +++ b/engine/crates/fx-kernel/src/loop_engine.rs @@ -12,9 +12,11 @@ use crate::channels::ChannelRegistry; use crate::context_manager::ContextCompactor; use crate::conversation_compactor::{ - debug_assert_tool_pair_integrity, emergency_compact, estimate_text_tokens, has_prunable_blocks, - prune_tool_blocks, CompactionConfig, CompactionError, CompactionMemoryFlush, CompactionResult, - CompactionStrategy, ConversationBudget, SlidingWindowCompactor, + assemble_summarized_messages, debug_assert_tool_pair_integrity, emergency_compact, + estimate_text_tokens, generate_summary, has_prunable_blocks, prune_tool_blocks, + slide_summarization_plan, summary_message, CompactionConfig, CompactionError, + CompactionMemoryFlush, CompactionResult, ConversationBudget, SlideSummarizationPlan, + SlidingWindowCompactor, }; use crate::decide::Decision; use 
crate::input::{LoopCommand, LoopInputChannel}; @@ -218,6 +220,13 @@ impl<'a> CycleStream<'a> { }); } + fn tool_error(self, tool_name: &str, error: &str) { + self.emit(StreamEvent::ToolError { + tool_name: tool_name.to_string(), + error: error.to_string(), + }); + } + fn notification(self, title: impl Into, body: impl Into) { self.emit(StreamEvent::Notification { title: title.into(), @@ -401,7 +410,6 @@ impl std::fmt::Display for CompactionScope { enum CompactionTier { Prune, Slide, - Summarize, Emergency, } @@ -410,7 +418,6 @@ impl CompactionTier { match self { Self::Prune => "prune", Self::Slide => "slide", - Self::Summarize => "summarize", Self::Emergency => "emergency", } } @@ -418,10 +425,9 @@ impl CompactionTier { /// Core orchestrator for the 7-step agentic loop. /// -/// Note: `LoopEngine` previously derived `Clone`, but Phases 1-3 -/// (context window compaction) introduced two non-`Clone` fields: -/// `conversation_compactor: Box` and -/// `compaction_last_iteration: Mutex>`. +/// Note: `LoopEngine` previously derived `Clone`, but context compaction +/// introduced a non-`Clone` cooldown tracker +/// (`compaction_last_iteration: Mutex>`). /// `LoopInputChannel` also contains an `mpsc::Receiver`, which remains /// non-`Clone`. No existing code clones `LoopEngine`, so this is a safe change. pub struct LoopEngine { @@ -442,7 +448,6 @@ pub struct LoopEngine { event_bus: Option, compaction_config: CompactionConfig, conversation_budget: ConversationBudget, - conversation_compactor: Box, /// LLM for compaction-time memory extraction. compaction_llm: Option>, memory_flush: Option>, @@ -450,10 +455,10 @@ pub struct LoopEngine { /// Guards performance signal to fire only on the Normal→Low transition, /// not on every `perceive()` call while the budget stays Low. budget_low_signaled: bool, - /// Consecutive iterations that used tools without producing user-facing text. + /// Consecutive iterations that included tool calls. 
/// Stored on `LoopEngine` because `perceive()` only has `&mut self`. /// Cycle-scoped; `prepare_cycle()` resets it, so child cycles start fresh. - consecutive_tool_only_turns: u16, + consecutive_tool_turns: u16, /// Latest reasoning input messages for graceful budget-exhausted synthesis. /// Stored on `LoopEngine` because `perceive()` only has `&mut self`. last_reasoning_messages: Vec, @@ -486,10 +491,7 @@ impl std::fmt::Debug for LoopEngine { .field("scratchpad_context", &self.scratchpad_context) .field("compaction_config", &self.compaction_config) .field("budget_low_signaled", &self.budget_low_signaled) - .field( - "consecutive_tool_only_turns", - &self.consecutive_tool_only_turns, - ) + .field("consecutive_tool_turns", &self.consecutive_tool_turns) .field("tool_retry_tracker", &self.tool_retry_tracker) .field("notify_called_this_cycle", &self.notify_called_this_cycle) .field( @@ -858,8 +860,12 @@ impl LoopEngineBuilder { let synthesis_instruction = required_builder_field(self.synthesis_instruction, "synthesis_instruction")?; let compaction_llm_for_extraction = self.compaction_llm.as_ref().map(Arc::clone); - let (compaction_config, conversation_budget, conversation_compactor) = - build_compaction_components(self.compaction_config, self.compaction_llm)?; + let (compaction_config, conversation_budget) = + build_compaction_components(self.compaction_config)?; + let session_memory = self + .session_memory + .unwrap_or_else(|| default_session_memory(compaction_config.model_context_limit)); + configure_session_memory(&session_memory, compaction_config.model_context_limit); Ok(LoopEngine { budget, @@ -869,7 +875,7 @@ impl LoopEngineBuilder { iteration_count: 0, synthesis_instruction, memory_context: self.memory_context, - session_memory: self.session_memory.unwrap_or_else(default_session_memory), + session_memory, scratchpad_context: self.scratchpad_context, signals: SignalCollector::default(), cancel_token: self.cancel_token, @@ -879,12 +885,11 @@ impl 
LoopEngineBuilder { event_bus: self.event_bus, compaction_config, conversation_budget, - conversation_compactor, compaction_llm: compaction_llm_for_extraction, memory_flush: self.memory_flush, compaction_last_iteration: Mutex::new(HashMap::new()), budget_low_signaled: false, - consecutive_tool_only_turns: 0, + consecutive_tool_turns: 0, last_reasoning_messages: Vec::new(), tool_retry_tracker: ToolRetryTracker::default(), notify_called_this_cycle: false, @@ -901,15 +906,7 @@ impl LoopEngineBuilder { fn build_compaction_components( config: Option, - llm: Option>, -) -> Result< - ( - CompactionConfig, - ConversationBudget, - Box, - ), - LoopError, -> { +) -> Result<(CompactionConfig, ConversationBudget), LoopError> { let compaction_config = config.unwrap_or_default(); compaction_config.validate().map_err(|error| { loop_error( @@ -924,8 +921,7 @@ fn build_compaction_components( compaction_config.slide_threshold, compaction_config.reserved_system_tokens, ); - let strategy = compaction_config.build_strategy(llm); - Ok((compaction_config, conversation_budget, strategy)) + Ok((compaction_config, conversation_budget)) } fn build_extraction_prompt(messages: &[Message]) -> String { @@ -1023,6 +1019,122 @@ fn parse_extraction_response(response: &str) -> Option { None } +#[derive(Clone, Copy)] +enum SummarySection { + Decisions, + FilesModified, + TaskState, + KeyContext, +} + +#[derive(Default)] +struct ParsedSummarySections { + decisions: Vec, + files_modified: Vec, + task_state: Vec, + key_context: Vec, +} + +fn parse_summary_memory_update(summary: &str) -> Option { + let sections = parse_summary_sections(summary); + let update = SessionMemoryUpdate { + project: None, + current_state: joined_summary_section(§ions.task_state), + key_decisions: optional_summary_items(sections.decisions), + active_files: optional_summary_items(sections.files_modified), + custom_context: optional_summary_items(sections.key_context), + }; + has_memory_update_fields(&update).then_some(update) +} + 
+fn parse_summary_sections(summary: &str) -> ParsedSummarySections { + let mut sections = ParsedSummarySections::default(); + let mut current = None; + for line in summary + .lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + { + if let Some((section, inline)) = summary_section_header(line) { + current = Some(section); + if let Some(text) = inline { + push_summary_section_line(&mut sections, section, text); + } + continue; + } + if let Some(section) = current { + push_summary_section_line(&mut sections, section, line); + } + } + sections +} + +fn summary_section_header(line: &str) -> Option<(SummarySection, Option<&str>)> { + let (heading, remainder) = line.split_once(':')?; + let section = match strip_summary_section_numbering(heading) { + text if text.eq_ignore_ascii_case("Decisions") => SummarySection::Decisions, + text if text.eq_ignore_ascii_case("Files modified") => SummarySection::FilesModified, + text if text.eq_ignore_ascii_case("Task state") => SummarySection::TaskState, + text if text.eq_ignore_ascii_case("Key context") => SummarySection::KeyContext, + _ => return None, + }; + let inline = (!remainder.trim().is_empty()).then_some(remainder.trim()); + Some((section, inline)) +} + +fn strip_summary_section_numbering(heading: &str) -> &str { + let trimmed = heading.trim(); + let digits_len = trimmed + .as_bytes() + .iter() + .take_while(|byte| byte.is_ascii_digit()) + .count(); + if digits_len == 0 { + return trimmed; + } + trimmed[digits_len..] 
+ .strip_prefix('.') + .map_or(trimmed, |remainder| remainder.trim_start()) +} + +fn push_summary_section_line( + sections: &mut ParsedSummarySections, + section: SummarySection, + line: &str, +) { + let trimmed = line.trim(); + let item = trimmed + .strip_prefix("- ") + .or_else(|| trimmed.strip_prefix("* ")) + .unwrap_or(trimmed) + .trim(); + if item.is_empty() { + return; + } + match section { + SummarySection::Decisions => sections.decisions.push(item.to_string()), + SummarySection::FilesModified => sections.files_modified.push(item.to_string()), + SummarySection::TaskState => sections.task_state.push(item.to_string()), + SummarySection::KeyContext => sections.key_context.push(item.to_string()), + } +} + +fn joined_summary_section(items: &[String]) -> Option { + (!items.is_empty()).then(|| items.join("; ")) +} + +fn optional_summary_items(items: Vec) -> Option> { + (!items.is_empty()).then_some(items) +} + +fn has_memory_update_fields(update: &SessionMemoryUpdate) -> bool { + update.project.is_some() + || update.current_state.is_some() + || update.key_decisions.is_some() + || update.active_files.is_some() + || update.custom_context.is_some() +} + fn extract_json_object(text: &str) -> Option<&str> { let start = text.find('{')?; let end = text.rfind('}')?; @@ -1041,8 +1153,15 @@ fn normalize_memory_context(memory_context: String) -> Option { } } -fn default_session_memory() -> Arc> { - Arc::new(Mutex::new(SessionMemory::default())) +fn default_session_memory(context_limit: usize) -> Arc> { + Arc::new(Mutex::new(SessionMemory::with_context_limit(context_limit))) +} + +fn configure_session_memory(memory: &Arc>, context_limit: usize) { + let mut memory = memory + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + memory.set_context_limit(context_limit); } #[derive(Debug, Default, Clone)] @@ -1260,7 +1379,19 @@ Do not call any tools. Do not decompose. \ Summarize what you have accomplished and what remains undone. 
Be concise."; const BUDGET_EXHAUSTED_SYNTHESIS_DIRECTIVE: &str = "\n\nYour tool budget is exhausted. Provide a final response summarizing what you've found and accomplished."; const BUDGET_EXHAUSTED_FALLBACK_RESPONSE: &str = "I reached my iteration limit."; -const TOOL_ONLY_TURN_NUDGE: &str = "You've been working for several steps without responding. Share your progress with the user before continuing."; +const TOOL_TURN_NUDGE: &str = "You've been working for several steps without responding. Share your progress with the user before continuing."; +const TOOL_ROUND_PROGRESS_NUDGE: &str = "You've been calling tools for several rounds without providing a response. Share your progress with the user now. If you have enough information to answer, do so immediately instead of calling more tools."; +const TOOL_ERROR_RELAY_PREFIX: &str = "The following tools failed. Report these errors to the user before continuing with additional tool calls:"; + +fn tool_error_relay_directive(failed_tools: &[(&str, &str)]) -> String { + let details: Vec = failed_tools + .iter() + .map(|(name, error)| format!("- Tool '{}' failed with: {}", name, error)) + .collect(); + format!("{}\n{}", TOOL_ERROR_RELAY_PREFIX, details.join("\n")) +} +/// Maximum time to wait for a best-effort summary during emergency compaction. +const EMERGENCY_SUMMARY_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(500); impl LoopEngine { /// Create a loop engine builder. 
@@ -1303,11 +1434,13 @@ impl LoopEngine { } pub fn replace_session_memory(&self, memory: SessionMemory) -> SessionMemory { + let mut replacement = memory; + replacement.set_context_limit(self.compaction_config.model_context_limit); let mut stored = match self.session_memory.lock() { Ok(guard) => guard, Err(poisoned) => poisoned.into_inner(), }; - std::mem::replace(&mut *stored, memory) + std::mem::replace(&mut *stored, replacement) } pub fn session_memory_snapshot(&self) -> SessionMemory { @@ -1353,6 +1486,7 @@ impl LoopEngine { self.compaction_config.slide_threshold, self.compaction_config.reserved_system_tokens, ); + configure_session_memory(&self.session_memory, new_limit); } /// Synchronise the shared iteration counter and refresh scratchpad context. @@ -1537,7 +1671,7 @@ impl LoopEngine { } state.tokens.accumulate(action.tokens_used); - self.update_tool_only_turns(&action); + self.update_tool_turns(&action); if let Some(result) = self.check_cancellation(Some(action.response_text.clone())) { return Ok(self.finish_streaming_result(result, stream)); @@ -1703,7 +1837,7 @@ impl LoopEngine { self.user_stop_requested = false; self.pending_steer = None; self.budget_low_signaled = false; - self.consecutive_tool_only_turns = 0; + self.consecutive_tool_turns = 0; self.last_reasoning_messages.clear(); self.tool_retry_tracker.clear(); self.notify_called_this_cycle = false; @@ -1714,11 +1848,39 @@ impl LoopEngine { self.tool_executor.clear_cache(); } - fn update_tool_only_turns(&mut self, action: &ActionResult) { - if !action.tool_results.is_empty() && action.response_text.trim().is_empty() { - self.consecutive_tool_only_turns = self.consecutive_tool_only_turns.saturating_add(1); - } else if !action.response_text.trim().is_empty() { - self.consecutive_tool_only_turns = 0; + fn update_tool_turns(&mut self, action: &ActionResult) { + if !action.tool_results.is_empty() { + self.consecutive_tool_turns = self.consecutive_tool_turns.saturating_add(1); + } else { + 
self.consecutive_tool_turns = 0; + } + } + + /// Apply nudge/strip policy for the current tool continuation round. + /// + /// Mutates `continuation_messages` by appending a progress nudge at the + /// nudge threshold round. Returns the tool definitions to use: either the + /// full set (normal) or an empty vec (tools stripped at strip threshold). + fn apply_tool_round_progress_policy( + &self, + round: u32, + continuation_messages: &mut Vec, + ) -> Vec { + let tc = &self.budget.config().termination; + let nudge_threshold = u32::from(tc.tool_round_nudge_after); + let strip_threshold = + nudge_threshold.saturating_add(u32::from(tc.tool_round_strip_after_nudge)); + + // Fire nudge exactly once (at the threshold round) to avoid stacking + // duplicate nudge messages in continuation_messages across rounds. + if nudge_threshold > 0 && round == nudge_threshold { + continuation_messages.push(Message::system(TOOL_ROUND_PROGRESS_NUDGE.to_string())); + } + + if nudge_threshold > 0 && round >= strip_threshold { + Vec::new() + } else { + self.tool_executor.tool_definitions() } } @@ -1885,8 +2047,8 @@ impl LoopEngine { } let nudge_at = self.budget.config().termination.nudge_after_tool_turns; - if nudge_at > 0 && self.consecutive_tool_only_turns >= nudge_at { - context_window.push(Message::system(TOOL_ONLY_TURN_NUDGE.to_string())); + if nudge_at > 0 && self.consecutive_tool_turns >= nudge_at { + context_window.push(Message::system(TOOL_TURN_NUDGE.to_string())); } let processed = ProcessedPerception { @@ -1920,13 +2082,13 @@ impl LoopEngine { ) -> Result { let tc = &self.budget.config().termination; let should_strip_tools = tc.nudge_after_tool_turns > 0 - && self.consecutive_tool_only_turns + && self.consecutive_tool_turns >= tc .nudge_after_tool_turns .saturating_add(tc.strip_tools_after_nudge); let tools = if should_strip_tools { tracing::info!( - turns = self.consecutive_tool_only_turns, + turns = self.consecutive_tool_turns, "stripping tools: agent exceeded nudge + grace 
threshold" ); vec![] @@ -2991,6 +3153,15 @@ impl LoopEngine { } } + fn emit_tool_errors(&self, results: &[ToolResult], stream: CycleStream<'_>) -> bool { + let mut has_errors = false; + for result in results.iter().filter(|result| !result.success) { + has_errors = true; + stream.tool_error(&result.tool_name, &result.output); + } + has_errors + } + fn publish_tool_result(&mut self, result: &ToolResult) { if result.success && result.tool_name == NOTIFY_TOOL_NAME { self.notify_called_this_cycle = true; @@ -3078,13 +3249,6 @@ impl LoopEngine { { return Some(CompactionTier::Emergency); } - if self.compaction_config.use_summarization - && self - .conversation_budget - .at_tier(messages, self.compaction_config.summarize_threshold) - { - return Some(CompactionTier::Summarize); - } if self .conversation_budget .at_tier(messages, self.compaction_config.slide_threshold) @@ -3147,6 +3311,30 @@ impl LoopEngine { ); } + fn collect_evicted_messages( + &self, + messages: &[Message], + evicted_indices: &[usize], + ) -> Vec { + evicted_indices + .iter() + .filter_map(|&index| messages.get(index).cloned()) + .collect() + } + + fn apply_session_memory_update(&self, update: SessionMemoryUpdate) { + let mut memory = self + .session_memory + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + if let Err(err) = memory.apply_update(update) { + tracing::warn!( + error = %err, + "auto-extracted memory update rejected (token cap)" + ); + } + } + async fn flush_evicted( &self, messages: &[Message], @@ -3157,17 +3345,19 @@ impl LoopEngine { return; } - let evicted: Vec = result - .evicted_indices - .iter() - .filter_map(|&index| messages.get(index).cloned()) - .collect(); - if evicted.is_empty() { - return; - } - + let evicted = self.collect_evicted_messages(messages, &result.evicted_indices); if let Some(flush) = &self.memory_flush { - if let Err(err) = flush.flush(&evicted, scope.as_str()).await { + let flush_result = if let Some(summary) = result.summary.as_deref() { + let summary 
= summary_message(summary); + flush + .flush(std::slice::from_ref(&summary), scope.as_str()) + .await + } else if evicted.is_empty() { + Ok(()) + } else { + flush.flush(&evicted, scope.as_str()).await + }; + if let Err(err) = flush_result { tracing::warn!( scope = scope.as_str(), error = %err, @@ -3182,10 +3372,21 @@ impl LoopEngine { } } - self.extract_memory_from_evicted(&evicted).await; + self.extract_memory_from_evicted(&evicted, result.summary.as_deref()) + .await; + } + + async fn extract_memory_from_evicted(&self, evicted: &[Message], summary: Option<&str>) { + if let Some(summary) = summary { + if let Some(update) = parse_summary_memory_update(summary) { + self.apply_session_memory_update(update); + return; + } + } + self.extract_memory_with_llm(evicted).await; } - async fn extract_memory_from_evicted(&self, evicted: &[Message]) { + async fn extract_memory_with_llm(&self, evicted: &[Message]) { let Some(llm) = &self.compaction_llm else { return; }; @@ -3197,16 +3398,7 @@ impl LoopEngine { match llm.generate(&prompt, 512).await { Ok(response) => { if let Some(update) = parse_extraction_response(&response) { - let mut memory = self - .session_memory - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - if let Err(err) = memory.apply_update(update) { - tracing::warn!( - error = %err, - "auto-extracted memory update rejected (token cap)" - ); - } + self.apply_session_memory_update(update); } } Err(err) => { @@ -3218,6 +3410,82 @@ impl LoopEngine { } } + async fn generate_eviction_summary( + &self, + messages: &[Message], + ) -> Result { + let llm = + self.compaction_llm + .as_ref() + .ok_or_else(|| CompactionError::SummarizationFailed { + source: Box::new(std::io::Error::other("no compaction LLM")), + })?; + generate_summary( + llm.as_ref(), + messages, + self.compaction_config.max_summary_tokens, + ) + .await + } + + fn summarized_compaction_result( + &self, + messages: &[Message], + plan: &SlideSummarizationPlan, + summary: String, + ) -> 
CompactionResult { + let compacted_messages = assemble_summarized_messages(messages, plan, &summary); + CompactionResult { + estimated_tokens: ConversationBudget::estimate_tokens(&compacted_messages), + messages: compacted_messages, + compacted_count: plan.evicted_messages.len(), + used_summarization: true, + summary: Some(summary), + evicted_indices: plan.evicted_indices.clone(), + } + } + + async fn apply_follow_up_slide( + &self, + result: CompactionResult, + target_tokens: usize, + scope: CompactionScope, + ) -> CompactionResult { + if result.estimated_tokens <= target_tokens { + return result; + } + + match self + .run_sliding_compaction(&result.messages, scope, target_tokens) + .await + { + Ok(follow_up) => Self::merge_summarized_follow_up(result, follow_up), + Err(error) => { + tracing::warn!( + scope = scope.as_str(), + tier = CompactionTier::Slide.as_str(), + error = ?error, + "follow-up slide after summarization failed; keeping summary result" + ); + result + } + } + } + + fn merge_summarized_follow_up( + base: CompactionResult, + follow_up: CompactionResult, + ) -> CompactionResult { + CompactionResult { + messages: follow_up.messages, + compacted_count: base.compacted_count + follow_up.compacted_count, + estimated_tokens: follow_up.estimated_tokens, + used_summarization: true, + summary: base.summary, + evicted_indices: base.evicted_indices, + } + } + async fn finish_tier<'a>( &self, tier: CompactionTier, @@ -3264,53 +3532,90 @@ impl LoopEngine { current } - async fn apply_slide_tier<'a>( + fn can_summarize_eviction(&self) -> bool { + self.compaction_config.use_summarization && self.compaction_llm.is_some() + } + + async fn summarize_before_slide( &self, - current: Cow<'a, [Message]>, + messages: &[Message], + target_tokens: usize, scope: CompactionScope, - iteration: u32, - ) -> Result, LoopError> { - let target_tokens = self.conversation_budget.compaction_target(); - match self - .run_sliding_compaction(current.as_ref(), scope, target_tokens) - .await 
- { - Ok(result) => Ok(self - .finish_tier( - CompactionTier::Slide, - current, - result, - scope, - Some(iteration), - target_tokens, - ) - .await), + ) -> Result { + let plan = slide_summarization_plan(messages, self.compaction_config.preserve_recent_turns) + .map_err(|error| compaction_failed_error(scope, error))?; + match self.generate_eviction_summary(&plan.evicted_messages).await { + Ok(summary) => { + let result = self.summarized_compaction_result(messages, &plan, summary); + Ok(self + .apply_follow_up_slide(result, target_tokens, scope) + .await) + } Err(error) => { tracing::warn!( scope = scope.as_str(), tier = CompactionTier::Slide.as_str(), - error = ?error, - "conversation compaction tier failed; continuing" + error = %error, + "pre-slide summarization failed; falling back to lossy slide" ); - Ok(current) + self.run_sliding_compaction(messages, scope, target_tokens) + .await + } + } + } + + async fn best_effort_emergency_summary( + &self, + messages: &[Message], + scope: CompactionScope, + ) -> Option { + let plan = slide_summarization_plan(messages, self.compaction_config.preserve_recent_turns) + .ok()?; + match tokio::time::timeout( + EMERGENCY_SUMMARY_TIMEOUT, + self.generate_eviction_summary(&plan.evicted_messages), + ) + .await + { + Ok(Ok(summary)) => Some(self.summarized_compaction_result(messages, &plan, summary)), + Ok(Err(error)) => { + tracing::warn!( + scope = scope.as_str(), + tier = CompactionTier::Emergency.as_str(), + error = %error, + "emergency summarization failed; falling back to mechanical emergency compaction" + ); + None + } + Err(_) => { + tracing::warn!( + scope = scope.as_str(), + tier = CompactionTier::Emergency.as_str(), + "emergency summarization timed out; falling back to mechanical emergency compaction" + ); + None } } } - async fn apply_summarize_tier<'a>( + async fn apply_slide_tier<'a>( &self, current: Cow<'a, [Message]>, scope: CompactionScope, iteration: u32, ) -> Result, LoopError> { - let target_tokens = 
self.conversation_budget.summarize_target(); - match self - .run_compaction_strategy(scope, current.as_ref(), target_tokens) - .await - { + let target_tokens = self.conversation_budget.compaction_target(); + let result = if self.can_summarize_eviction() { + self.summarize_before_slide(current.as_ref(), target_tokens, scope) + .await + } else { + self.run_sliding_compaction(current.as_ref(), scope, target_tokens) + .await + }; + match result { Ok(result) => Ok(self .finish_tier( - CompactionTier::Summarize, + CompactionTier::Slide, current, result, scope, @@ -3321,7 +3626,7 @@ impl LoopEngine { Err(error) => { tracing::warn!( scope = scope.as_str(), - tier = CompactionTier::Summarize.as_str(), + tier = CompactionTier::Slide.as_str(), error = ?error, "conversation compaction tier failed; continuing" ); @@ -3335,10 +3640,21 @@ impl LoopEngine { current: Cow<'a, [Message]>, scope: CompactionScope, ) -> Result, LoopError> { - let result = emergency_compact( - current.as_ref(), - self.compaction_config.preserve_recent_turns, - ); + let result = if self.can_summarize_eviction() { + self.best_effort_emergency_summary(current.as_ref(), scope) + .await + .unwrap_or_else(|| { + emergency_compact( + current.as_ref(), + self.compaction_config.preserve_recent_turns, + ) + }) + } else { + emergency_compact( + current.as_ref(), + self.compaction_config.preserve_recent_turns, + ) + }; Ok(self .finish_tier(CompactionTier::Emergency, current, result, scope, None, 0) .await) @@ -3354,11 +3670,9 @@ impl LoopEngine { let current = self.apply_prune_tier(current, scope); let current = match self.highest_compaction_tier(current.as_ref()) { Some(CompactionTier::Emergency) => self.apply_emergency_tier(current, scope).await?, - Some(tier @ (CompactionTier::Summarize | CompactionTier::Slide)) => { + Some(tier @ CompactionTier::Slide) => { if self.should_skip_compaction(scope, iteration, tier) { current - } else if matches!(tier, CompactionTier::Summarize) { - self.apply_summarize_tier(current, 
scope, iteration).await? } else { self.apply_slide_tier(current, scope, iteration).await? } @@ -3427,39 +3741,6 @@ impl LoopEngine { .map_err(|error| compaction_failed_error(scope, error)) } - async fn run_compaction_strategy( - &self, - scope: CompactionScope, - messages: &[Message], - target_tokens: usize, - ) -> Result { - match self - .conversation_compactor - .compact(messages, target_tokens) - .await - { - Ok(result) => Ok(result), - Err(CompactionError::SummarizationFailed { source }) => { - tracing::warn!( - error = %source, - scope = scope.as_str(), - "summarization compaction failed; trying sliding fallback" - ); - self.run_sliding_compaction(messages, scope, target_tokens) - .await - } - Err(CompactionError::SummaryExceededTarget) => { - tracing::warn!( - scope = scope.as_str(), - "summary exceeded compaction target; trying sliding fallback" - ); - self.run_sliding_compaction(messages, scope, target_tokens) - .await - } - Err(error) => Err(compaction_failed_error(scope, error)), - } - } - fn ensure_within_hard_limit( &self, scope: CompactionScope, @@ -3581,8 +3862,11 @@ impl LoopEngine { break; } + let continuation_tools = + self.apply_tool_round_progress_policy(round, &mut state.continuation_messages); + match self - .execute_tool_round(round + 1, llm, &mut state, stream) + .execute_tool_round(round + 1, llm, &mut state, continuation_tools, stream) .await? 
{ ToolRoundOutcome::Cancelled => { @@ -3747,6 +4031,7 @@ impl LoopEngine { round: u32, llm: &dyn LlmProvider, state: &mut ToolRoundState, + continuation_tools: Vec, stream: CycleStream<'_>, ) -> Result { let round_started = current_time_ms(); @@ -3755,6 +4040,7 @@ impl LoopEngine { .execute_tool_calls_with_stream(&state.current_calls, stream) .await?; self.publish_tool_results(&results, stream); + let has_tool_errors = self.emit_tool_errors(&results, stream); self.record_tool_execution_cost(results.len()); let round_result_bytes: usize = results.iter().map(|r| r.output.len()).sum(); @@ -3766,6 +4052,16 @@ impl LoopEngine { &self.tool_call_provider_ids, &results, )?; + if has_tool_errors { + let failed: Vec<(&str, &str)> = results + .iter() + .filter(|result| !result.success) + .map(|result| (result.tool_name.as_str(), result.output.as_str())) + .collect(); + state + .continuation_messages + .push(Message::system(tool_error_relay_directive(&failed))); + } state.all_tool_results.extend(results); self.compact_tool_continuation(round, &mut state.continuation_messages) @@ -3786,6 +4082,7 @@ impl LoopEngine { .request_tool_continuation( llm, &state.continuation_messages, + continuation_tools, &mut state.tokens_used, stream, ) @@ -3913,13 +4210,14 @@ impl LoopEngine { &mut self, llm: &dyn LlmProvider, context_messages: &[Message], + continuation_tools: Vec, tokens_used: &mut TokenUsage, stream: CycleStream<'_>, ) -> Result { let request = build_continuation_request_with_notify_guidance( context_messages, llm.model_name(), - self.tool_executor.tool_definitions(), + continuation_tools, self.memory_context.as_deref(), self.scratchpad_context.as_deref(), self.thinking_config.clone(), @@ -4767,6 +5065,14 @@ fn unmatched_tool_call_id_error(result: &ToolResult) -> LoopError { ) } +fn completion_request_tools(tool_definitions: Vec) -> Vec { + if tool_definitions.is_empty() { + Vec::new() + } else { + tool_definitions_with_decompose(tool_definitions) + } +} + fn 
tool_definitions_with_decompose( mut tool_definitions: Vec, ) -> Vec { @@ -4841,7 +5147,7 @@ fn build_continuation_request_with_notify_guidance( thinking: Option, notify_tool_guidance_enabled: bool, ) -> CompletionRequest { - let tools = tool_definitions_with_decompose(tool_definitions); + let tools = completion_request_tools(tool_definitions); let system_prompt = build_tool_continuation_system_prompt_with_notify_guidance( memory_context, scratchpad_context, @@ -4892,7 +5198,7 @@ fn build_truncation_continuation_request_with_notify_guidance( thinking: Option, notify_tool_guidance_enabled: bool, ) -> CompletionRequest { - let tools = tool_definitions_with_decompose(tool_definitions); + let tools = completion_request_tools(tool_definitions); // Intentional: truncation continuations resume a cut-off response after context // overflow. They are not the post-tool-result path, so they keep the plain // reasoning prompt instead of the tool continuation directive. @@ -5423,13 +5729,7 @@ fn build_reasoning_request_with_notify_guidance( thinking: Option, notify_tool_guidance_enabled: bool, ) -> CompletionRequest { - // When tools are explicitly empty (e.g., stripped for tool-only turn enforcement), - // skip adding the decompose tool — the intent is to force a text-only response. 
- let tools = if tool_definitions.is_empty() { - vec![] - } else { - tool_definitions_with_decompose(tool_definitions) - }; + let tools = completion_request_tools(tool_definitions); let system_prompt = build_reasoning_system_prompt_with_notify_guidance( memory_context, scratchpad_context, @@ -7456,6 +7756,7 @@ mod phase2_tests { #[cfg(test)] mod phase4_tests { use super::*; + use crate::budget::{BudgetConfig, BudgetTracker, TerminationConfig}; use crate::cancellation::CancellationToken; use crate::input::{loop_input_channel, LoopCommand}; use async_trait::async_trait; @@ -7652,32 +7953,52 @@ mod phase4_tests { } fn p4_engine() -> LoopEngine { + p4_engine_with_config(BudgetConfig::default(), 3) + } + + fn p4_engine_with_config(config: BudgetConfig, max_iterations: u32) -> LoopEngine { LoopEngine::builder() - .budget(BudgetTracker::new( - crate::budget::BudgetConfig::default(), - 0, - 0, - )) + .budget(BudgetTracker::new(config, 0, 0)) .context(ContextCompactor::new(2048, 256)) - .max_iterations(3) + .max_iterations(max_iterations) .tool_executor(Arc::new(Phase4StubToolExecutor)) .synthesis_instruction("Summarize tool output".to_string()) .build() .expect("test engine build") } - fn p4_snapshot(text: &str) -> PerceptionSnapshot { - PerceptionSnapshot { - timestamp_ms: 1, - screen: ScreenState { - current_app: "terminal".to_string(), - elements: Vec::new(), - text_content: text.to_string(), - }, - notifications: Vec::new(), - active_app: "terminal".to_string(), - user_input: Some(UserInput { - text: text.to_string(), + fn has_tool_round_progress_nudge(messages: &[Message]) -> bool { + messages.iter().any(|message| { + message.content.iter().any(|block| match block { + ContentBlock::Text { text } => text.contains(TOOL_ROUND_PROGRESS_NUDGE), + _ => false, + }) + }) + } + + fn tool_round_budget_config(nudge_after: u16, strip_after_nudge: u16) -> BudgetConfig { + BudgetConfig { + termination: TerminationConfig { + tool_round_nudge_after: nudge_after, + 
tool_round_strip_after_nudge: strip_after_nudge, + ..TerminationConfig::default() + }, + ..BudgetConfig::default() + } + } + + fn p4_snapshot(text: &str) -> PerceptionSnapshot { + PerceptionSnapshot { + timestamp_ms: 1, + screen: ScreenState { + current_app: "terminal".to_string(), + elements: Vec::new(), + text_content: text.to_string(), + }, + notifications: Vec::new(), + active_app: "terminal".to_string(), + user_input: Some(UserInput { + text: text.to_string(), source: InputSource::Text, timestamp: 1, context_id: None, @@ -7874,6 +8195,196 @@ mod phase4_tests { ); } + #[tokio::test] + async fn act_with_tools_nudges_after_threshold() { + let config = tool_round_budget_config(1, 10); + let mut engine = p4_engine_with_config(config, 3); + let decision = Decision::UseTools(vec![read_file_call("call-1", "a.txt")]); + let llm = Phase4MockLlm::new(vec![ + tool_use_response(vec![read_file_call("call-2", "b.txt")]), + text_response("done after nudge"), + ]); + let context_messages = vec![Message::user("read files")]; + + let _action = engine + .act_with_tools( + &decision, + calls_from_decision(&decision), + &llm, + &context_messages, + CycleStream::disabled(), + ) + .await + .expect("act_with_tools"); + + let requests = llm.requests(); + assert_eq!(requests.len(), 2); + assert!(!has_tool_round_progress_nudge(&requests[0].messages)); + assert!(has_tool_round_progress_nudge(&requests[1].messages)); + } + + #[tokio::test] + async fn act_with_tools_strips_tools_after_threshold() { + let config = tool_round_budget_config(1, 1); + let mut engine = p4_engine_with_config(config, 4); + let decision = Decision::UseTools(vec![read_file_call("call-1", "a.txt")]); + let llm = Phase4MockLlm::new(vec![ + tool_use_response(vec![read_file_call("call-2", "b.txt")]), + tool_use_response(vec![read_file_call("call-3", "c.txt")]), + text_response("done after strip"), + ]); + let context_messages = vec![Message::user("read files")]; + + let _action = engine + .act_with_tools( + &decision, + 
calls_from_decision(&decision), + &llm, + &context_messages, + CycleStream::disabled(), + ) + .await + .expect("act_with_tools"); + + let requests = llm.requests(); + assert_eq!(requests.len(), 3); + assert!(!requests[1].tools.is_empty()); + assert!(requests[2].tools.is_empty()); + } + + #[tokio::test] + async fn act_with_tools_no_nudge_when_disabled() { + let config = tool_round_budget_config(0, 2); + let mut engine = p4_engine_with_config(config, 4); + let decision = Decision::UseTools(vec![read_file_call("call-1", "a.txt")]); + let llm = Phase4MockLlm::new(vec![ + tool_use_response(vec![read_file_call("call-2", "b.txt")]), + tool_use_response(vec![read_file_call("call-3", "c.txt")]), + text_response("done without nudge"), + ]); + let context_messages = vec![Message::user("read files")]; + + let _action = engine + .act_with_tools( + &decision, + calls_from_decision(&decision), + &llm, + &context_messages, + CycleStream::disabled(), + ) + .await + .expect("act_with_tools"); + + let requests = llm.requests(); + assert!(requests.iter().all(|request| { + !has_tool_round_progress_nudge(&request.messages) && !request.tools.is_empty() + })); + } + + #[tokio::test] + async fn act_with_tools_aggressive_config() { + let config = tool_round_budget_config(1, 0); + let mut engine = p4_engine_with_config(config, 3); + let decision = Decision::UseTools(vec![read_file_call("call-1", "a.txt")]); + let llm = Phase4MockLlm::new(vec![ + tool_use_response(vec![read_file_call("call-2", "b.txt")]), + text_response("done after aggressive strip"), + ]); + let context_messages = vec![Message::user("read files")]; + + let _action = engine + .act_with_tools( + &decision, + calls_from_decision(&decision), + &llm, + &context_messages, + CycleStream::disabled(), + ) + .await + .expect("act_with_tools"); + + let requests = llm.requests(); + assert_eq!(requests.len(), 2); + assert!(has_tool_round_progress_nudge(&requests[1].messages)); + assert!(requests[1].tools.is_empty()); + } + + 
#[tokio::test] + async fn act_with_tools_no_nudge_before_threshold() { + let config = tool_round_budget_config(2, 2); + let mut engine = p4_engine_with_config(config, 3); + let decision = Decision::UseTools(vec![read_file_call("call-1", "a.txt")]); + let llm = Phase4MockLlm::new(vec![ + tool_use_response(vec![read_file_call("call-2", "b.txt")]), + text_response("done before threshold"), + ]); + let context_messages = vec![Message::user("read files")]; + + let _action = engine + .act_with_tools( + &decision, + calls_from_decision(&decision), + &llm, + &context_messages, + CycleStream::disabled(), + ) + .await + .expect("act_with_tools"); + + let requests = llm.requests(); + assert_eq!(requests.len(), 2); + assert!(!has_tool_round_progress_nudge(&requests[1].messages)); + } + + #[tokio::test] + async fn act_with_tools_nudge_fires_exactly_once() { + // With nudge_after=1 and strip_after=3, the model runs 3 rounds past + // the nudge threshold. Verify the nudge message appears exactly once + // (not stacked on every round). + let config = tool_round_budget_config(1, 3); + let mut engine = p4_engine_with_config(config, 5); + let decision = Decision::UseTools(vec![read_file_call("call-1", "a.txt")]); + let llm = Phase4MockLlm::new(vec![ + tool_use_response(vec![read_file_call("call-2", "b.txt")]), + tool_use_response(vec![read_file_call("call-3", "c.txt")]), + tool_use_response(vec![read_file_call("call-4", "d.txt")]), + text_response("done after strip"), + ]); + let context_messages = vec![Message::user("read files")]; + + let _action = engine + .act_with_tools( + &decision, + calls_from_decision(&decision), + &llm, + &context_messages, + CycleStream::disabled(), + ) + .await + .expect("act_with_tools"); + + let requests = llm.requests(); + // The last request has the full continuation_messages history. + // Count nudge messages in it — should be exactly 1 (not stacked). 
+ let last_request = requests.last().expect("should have requests"); + let nudge_count = last_request + .messages + .iter() + .filter(|m| { + m.content.iter().any(|block| { + matches!( + block, + ContentBlock::Text { text } if text.contains(TOOL_ROUND_PROGRESS_NUDGE) + ) + }) + }) + .count(); + assert_eq!( + nudge_count, 1, + "nudge should appear exactly once, not stack" + ); + } + #[tokio::test] async fn act_with_tools_falls_back_to_synthesis_on_max_iterations() { let mut engine = LoopEngine::builder() @@ -11559,7 +12070,7 @@ mod context_compaction_tests { CompactionConfig { slide_threshold: 0.60, prune_threshold: 0.40, - summarize_threshold: 0.80, + _legacy_summarize_threshold: 0.80, emergency_threshold: 0.95, preserve_recent_turns: 2, model_context_limit: 5_096, @@ -11735,7 +12246,7 @@ mod context_compaction_tests { CompactionConfig { slide_threshold: 0.2, prune_threshold: 0.1, - summarize_threshold: 0.8, + _legacy_summarize_threshold: 0.8, emergency_threshold: 0.95, preserve_recent_turns: 2, model_context_limit: 5_000, @@ -11858,10 +12369,6 @@ mod context_compaction_tests { engine.compaction_config.prune_threshold, defaults.prune_threshold ); - assert_eq!( - engine.compaction_config.summarize_threshold, - defaults.summarize_threshold - ); assert_eq!( engine.compaction_config.emergency_threshold, defaults.emergency_threshold @@ -11897,13 +12404,47 @@ mod context_compaction_tests { assert!(engine.session_memory_snapshot().is_empty()); } + #[test] + fn builder_applies_context_scaled_session_memory_caps() { + let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); + let config = CompactionConfig { + model_context_limit: 200_000, + ..CompactionConfig::default() + }; + let memory = Arc::new(Mutex::new(SessionMemory::default())); + let engine = LoopEngine::builder() + .budget(BudgetTracker::new( + crate::budget::BudgetConfig::default(), + current_time_ms(), + 0, + )) + .context(ContextCompactor::new(2_048, 256)) + .max_iterations(4) + 
.tool_executor(executor) + .synthesis_instruction("synthesize".to_string()) + .compaction_config(config.clone()) + .session_memory(Arc::clone(&memory)) + .build() + .expect("test engine build"); + + let stored = engine.session_memory_snapshot(); + assert_eq!( + stored.token_cap(), + fx_session::max_memory_tokens(config.model_context_limit) + ); + assert_eq!( + stored.item_cap(), + fx_session::max_memory_items(config.model_context_limit) + ); + } + #[test] fn builder_full_config() { let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); let config = CompactionConfig { slide_threshold: 0.3, prune_threshold: 0.2, - summarize_threshold: 0.4, + _legacy_summarize_threshold: 0.4, emergency_threshold: 0.9, preserve_recent_turns: 3, model_context_limit: 5_200, @@ -11976,14 +12517,12 @@ mod context_compaction_tests { } #[test] - fn build_compaction_components_default_to_valid_budget_and_strategy() { - let (config, budget, _strategy) = - build_compaction_components(None, None).expect("components should build"); + fn build_compaction_components_default_to_valid_budget() { + let (config, budget) = build_compaction_components(None).expect("components should build"); let defaults = CompactionConfig::default(); assert_eq!(config.slide_threshold, defaults.slide_threshold); assert_eq!(config.prune_threshold, defaults.prune_threshold); - assert_eq!(config.summarize_threshold, defaults.summarize_threshold); assert_eq!(config.emergency_threshold, defaults.emergency_threshold); assert_eq!(config.preserve_recent_turns, defaults.preserve_recent_turns); assert_eq!( @@ -11999,8 +12538,7 @@ mod context_compaction_tests { let mut config = CompactionConfig::default(); config.recompact_cooldown_turns = 0; - let error = - build_compaction_components(Some(config), None).expect_err("invalid config rejected"); + let error = build_compaction_components(Some(config)).expect_err("invalid config rejected"); assert_eq!(error.stage, "init"); 
assert!(error.reason.contains("invalid_compaction_config")); } @@ -12012,13 +12550,22 @@ mod context_compaction_tests { struct ExtractionLlm { responses: Mutex>>, prompts: Mutex>, + delay: Option, } impl ExtractionLlm { fn new(responses: Vec>) -> Self { + Self::with_delay(responses, None) + } + + fn with_delay( + responses: Vec>, + delay: Option, + ) -> Self { Self { responses: Mutex::new(VecDeque::from(responses)), prompts: Mutex::new(Vec::new()), + delay, } } @@ -12034,6 +12581,9 @@ mod context_compaction_tests { .lock() .expect("prompts lock") .push(prompt.to_string()); + if let Some(delay) = self.delay { + tokio::time::sleep(delay).await; + } self.responses .lock() .expect("responses lock") @@ -12147,6 +12697,36 @@ mod context_compaction_tests { } } + #[derive(Debug, Default)] + struct FailingToolRoundExecutor; + + #[async_trait] + impl ToolExecutor for FailingToolRoundExecutor { + async fn execute_tools( + &self, + calls: &[ToolCall], + _cancel: Option<&CancellationToken>, + ) -> Result, crate::act::ToolExecutorError> { + Ok(calls + .iter() + .map(|call| ToolResult { + tool_call_id: call.id.clone(), + tool_name: call.name.clone(), + success: false, + output: "permission denied".to_string(), + }) + .collect()) + } + + fn tool_definitions(&self) -> Vec { + vec![ToolDefinition { + name: "read_file".to_string(), + description: "read file".to_string(), + parameters: serde_json::json!({"type":"object"}), + }] + } + } + #[tokio::test] async fn long_conversation_triggers_compaction_in_perceive() { let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); @@ -12175,8 +12755,9 @@ mod context_compaction_tests { let calls = vec![read_call("call-1")]; let mut state = ToolRoundState::new(&calls, &large_history(12, 70)); + let tools = engine.tool_executor.tool_definitions(); let _ = engine - .execute_tool_round(1, &llm, &mut state, CycleStream::disabled()) + .execute_tool_round(1, &llm, &mut state, tools, CycleStream::disabled()) .await .expect("tool round"); @@ 
-12195,8 +12776,9 @@ mod context_compaction_tests { let calls = vec![read_call("call-1")]; let mut state = ToolRoundState::new(&calls, &large_history(12, 70)); + let tools = engine.tool_executor.tool_definitions(); engine - .execute_tool_round(1, &llm, &mut state, CycleStream::disabled()) + .execute_tool_round(1, &llm, &mut state, tools, CycleStream::disabled()) .await .expect("tool round"); @@ -12204,28 +12786,129 @@ mod context_compaction_tests { assert_eq!(engine.last_reasoning_messages, state.continuation_messages); } + fn stream_recorder() -> (StreamCallback, Arc>>) { + let events: Arc>> = Arc::new(Mutex::new(Vec::new())); + let captured = Arc::clone(&events); + let callback: StreamCallback = Arc::new(move |event| { + captured.lock().expect("lock").push(event); + }); + (callback, events) + } + #[tokio::test] - async fn decompose_child_receives_compacted_context() { - let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); - let engine = engine_with( + async fn tool_error_event_emitted_on_failure() { + let executor: Arc = Arc::new(FailingToolRoundExecutor); + let mut engine = engine_with( ContextCompactor::new(2_048, 256), executor, compaction_config(), ); - let llm = RecordingLlm::new(vec![Ok(text_response("child done"))]); - let goal = SubGoal { - description: "child task".to_string(), - required_tools: Vec::new(), - expected_output: None, - complexity_hint: None, - }; - let child_budget = BudgetConfig::default(); + let llm = RecordingLlm::ok(vec![text_response("done")]); + let calls = vec![read_call("call-1")]; + let mut state = ToolRoundState::new(&calls, &[Message::user("read file")]); + let (callback, events) = stream_recorder(); - let _execution = engine - .run_sub_goal(&goal, child_budget, &llm, &large_history(10, 60)) - .await; + engine + .execute_tool_round( + 1, + &llm, + &mut state, + Vec::new(), + CycleStream::enabled(&callback), + ) + .await + .expect("tool round"); - let requests = llm.requests(); + let events = 
events.lock().expect("lock").clone(); + assert!(events.contains(&StreamEvent::ToolError { + tool_name: "read_file".to_string(), + error: "permission denied".to_string(), + })); + } + + #[tokio::test] + async fn tool_error_directive_injected_on_failure() { + let executor: Arc = Arc::new(FailingToolRoundExecutor); + let mut engine = engine_with( + ContextCompactor::new(2_048, 256), + executor, + compaction_config(), + ); + let llm = RecordingLlm::ok(vec![text_response("done")]); + let calls = vec![read_call("call-1")]; + let mut state = ToolRoundState::new(&calls, &[Message::user("read file")]); + + engine + .execute_tool_round(1, &llm, &mut state, Vec::new(), CycleStream::disabled()) + .await + .expect("tool round"); + + let relay_message = state + .continuation_messages + .iter() + .map(message_to_text) + .find(|text| text.contains(TOOL_ERROR_RELAY_PREFIX)) + .expect("tool error relay message"); + assert!(relay_message.contains("- Tool 'read_file' failed with: permission denied")); + } + + #[tokio::test] + async fn no_tool_error_on_success() { + let executor: Arc = Arc::new(SizedToolExecutor { output_words: 5 }); + let mut engine = engine_with( + ContextCompactor::new(2_048, 256), + executor, + compaction_config(), + ); + let llm = RecordingLlm::ok(vec![text_response("done")]); + let calls = vec![read_call("call-1")]; + let mut state = ToolRoundState::new(&calls, &[Message::user("read file")]); + let (callback, events) = stream_recorder(); + + engine + .execute_tool_round( + 1, + &llm, + &mut state, + Vec::new(), + CycleStream::enabled(&callback), + ) + .await + .expect("tool round"); + + let events = events.lock().expect("lock").clone(); + assert!(!events + .iter() + .any(|event| matches!(event, StreamEvent::ToolError { .. 
}))); + assert!(!state + .continuation_messages + .iter() + .map(message_to_text) + .any(|text| text.contains(TOOL_ERROR_RELAY_PREFIX))); + } + + #[tokio::test] + async fn decompose_child_receives_compacted_context() { + let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); + let engine = engine_with( + ContextCompactor::new(2_048, 256), + executor, + compaction_config(), + ); + let llm = RecordingLlm::new(vec![Ok(text_response("child done"))]); + let goal = SubGoal { + description: "child task".to_string(), + required_tools: Vec::new(), + expected_output: None, + complexity_hint: None, + }; + let child_budget = BudgetConfig::default(); + + let _execution = engine + .run_sub_goal(&goal, child_budget, &llm, &large_history(10, 60)) + .await; + + let requests = llm.requests(); assert!(!requests.is_empty()); assert!(has_compaction_marker(&requests[0].messages)); } @@ -12285,11 +12968,10 @@ mod context_compaction_tests { #[tokio::test] async fn session_memory_injected_in_context() { let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); - let memory = Arc::new(Mutex::new(SessionMemory { - project: Some("Phase 3".to_string()), - current_state: Some("testing injection".to_string()), - ..SessionMemory::default() - })); + let mut stored_memory = SessionMemory::default(); + stored_memory.project = Some("Phase 3".to_string()); + stored_memory.current_state = Some("testing injection".to_string()); + let memory = Arc::new(Mutex::new(stored_memory)); let mut engine = LoopEngine::builder() .budget(BudgetTracker::new( crate::budget::BudgetConfig::default(), @@ -12539,7 +13221,7 @@ mod context_compaction_tests { Message::assistant("LoopEngine needs automatic extraction."), ]; - engine.extract_memory_from_evicted(&evicted).await; + engine.extract_memory_from_evicted(&evicted, None).await; let memory = engine.session_memory_snapshot(); assert_eq!(memory.project.as_deref(), Some("Phase 5")); @@ -12566,7 +13248,7 @@ mod context_compaction_tests { ); 
engine - .extract_memory_from_evicted(&[Message::user("remember this")]) + .extract_memory_from_evicted(&[Message::user("remember this")], None) .await; assert!(engine.session_memory_snapshot().is_empty()); @@ -12586,7 +13268,7 @@ mod context_compaction_tests { ); engine - .extract_memory_from_evicted(&[Message::user("remember this")]) + .extract_memory_from_evicted(&[Message::user("remember this")], None) .await; assert!(engine.session_memory_snapshot().is_empty()); @@ -12604,7 +13286,7 @@ mod context_compaction_tests { ); engine - .extract_memory_from_evicted(&[Message::user("remember this")]) + .extract_memory_from_evicted(&[Message::user("remember this")], None) .await; assert!(engine.session_memory_snapshot().is_empty()); @@ -12624,12 +13306,83 @@ mod context_compaction_tests { ); engine - .extract_memory_from_evicted(&[Message::user("remember this")]) + .extract_memory_from_evicted(&[Message::user("remember this")], None) .await; assert!(engine.session_memory_snapshot().is_empty()); } + #[tokio::test] + async fn extract_memory_from_summary_falls_back_to_llm_when_parsing_fails() { + let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); + let llm = Arc::new(ExtractionLlm::new(vec![Ok(serde_json::json!({ + "project": "Phase 2", + "current_state": "LLM fallback after malformed summary" + }) + .to_string())])); + let engine = engine_with_compaction_llm( + ContextCompactor::new(2_048, 256), + executor, + compaction_config(), + Arc::clone(&llm) as Arc, + ); + + engine + .extract_memory_from_evicted( + &[Message::user("remember this")], + Some("freeform summary without section headers"), + ) + .await; + + let memory = engine.session_memory_snapshot(); + assert_eq!(memory.project.as_deref(), Some("Phase 2")); + assert_eq!( + memory.current_state.as_deref(), + Some("LLM fallback after malformed summary") + ); + assert_eq!(llm.prompts().len(), 1); + assert!(llm.prompts()[0].contains("Conversation:")); + } + + #[tokio::test] + async fn 
extract_memory_from_numbered_summary_skips_llm_fallback() { + let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); + let llm = Arc::new(ExtractionLlm::new(vec![Ok("{}".to_string())])); + let engine = engine_with_compaction_llm( + ContextCompactor::new(2_048, 256), + executor, + compaction_config(), + Arc::clone(&llm) as Arc, + ); + let summary = concat!( + "1. Decisions:\n", + "- summarize before slide\n", + "2. Files modified:\n", + "- engine/crates/fx-kernel/src/loop_engine.rs\n", + "3. Task state:\n", + "- preserving summary context\n", + "4. Key context:\n", + "- no second LLM call needed" + ); + + engine + .extract_memory_from_evicted(&[Message::user("remember this")], Some(summary)) + .await; + + let memory = engine.session_memory_snapshot(); + assert_eq!( + memory.current_state.as_deref(), + Some("preserving summary context") + ); + assert_eq!(memory.key_decisions, vec!["summarize before slide"]); + assert_eq!( + memory.active_files, + vec!["engine/crates/fx-kernel/src/loop_engine.rs"] + ); + assert_eq!(memory.custom_context, vec!["no second LLM call needed"]); + assert!(llm.prompts().is_empty()); + } + #[test] fn build_extraction_prompt_formats_messages() { let prompt = build_extraction_prompt(&[ @@ -12669,6 +13422,82 @@ mod context_compaction_tests { assert!(parse_extraction_response("definitely not json").is_none()); } + #[test] + fn parse_summary_memory_update_extracts_sections() { + let summary = concat!( + "Decisions:\n", + "- Use summarize-before-slide\n", + "Files modified:\n", + "- engine/crates/fx-kernel/src/loop_engine.rs\n", + "Task state:\n", + "- Implementing Phase 2\n", + "Key context:\n", + "- Preserve summary markers during follow-up slide" + ); + + let update = parse_summary_memory_update(summary).expect("summary parse"); + + assert_eq!(update.project, None); + assert_eq!( + update.current_state.as_deref(), + Some("Implementing Phase 2") + ); + assert_eq!( + update.key_decisions, + Some(vec!["Use 
summarize-before-slide".to_string()]) + ); + assert_eq!( + update.active_files, + Some(vec![ + "engine/crates/fx-kernel/src/loop_engine.rs".to_string() + ]) + ); + assert_eq!( + update.custom_context, + Some(vec![ + "Preserve summary markers during follow-up slide".to_string() + ]) + ); + } + + #[test] + fn parse_summary_memory_update_extracts_numbered_sections() { + let summary = concat!( + "1. Decisions:\n", + "- Use summarize-before-slide\n", + "2. Files modified:\n", + "- engine/crates/fx-kernel/src/loop_engine.rs\n", + "3. Task state:\n", + "- Implementing Phase 2\n", + "4. Key context:\n", + "- Preserve summary markers during follow-up slide" + ); + + let update = parse_summary_memory_update(summary).expect("summary parse"); + + assert_eq!(update.project, None); + assert_eq!( + update.current_state.as_deref(), + Some("Implementing Phase 2") + ); + assert_eq!( + update.key_decisions, + Some(vec!["Use summarize-before-slide".to_string()]) + ); + assert_eq!( + update.active_files, + Some(vec![ + "engine/crates/fx-kernel/src/loop_engine.rs".to_string() + ]) + ); + assert_eq!( + update.custom_context, + Some(vec![ + "Preserve summary markers during follow-up slide".to_string() + ]) + ); + } + #[tokio::test] async fn flush_evicted_triggers_extraction() { let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); @@ -12713,6 +13542,74 @@ mod context_compaction_tests { assert_eq!(llm.prompts().len(), 1); } + #[tokio::test] + async fn flush_evicted_uses_summary_for_flush_and_memory_extraction() { + let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); + let flush = Arc::new(RecordingMemoryFlush::default()); + let summary = concat!( + "Decisions:\n", + "- summarize before slide\n", + "Files modified:\n", + "- engine/crates/fx-kernel/src/loop_engine.rs\n", + "Task state:\n", + "- preserving old context\n", + "Key context:\n", + "- summary markers stay protected" + ); + let llm = Arc::new(ExtractionLlm::new(vec![Ok(summary.to_string())])); + 
let mut config = tiered_compaction_config(true); + config.prune_tool_blocks = false; + let engine = LoopEngine::builder() + .budget(BudgetTracker::new( + crate::budget::BudgetConfig::default(), + current_time_ms(), + 0, + )) + .context(ContextCompactor::new(2_048, 256)) + .max_iterations(4) + .tool_executor(executor) + .synthesis_instruction("synthesize".to_string()) + .compaction_config(config) + .compaction_llm(Arc::clone(&llm) as Arc) + .memory_flush(Arc::clone(&flush) as Arc) + .build() + .expect("test engine build"); + let messages = vec![ + Message::user(format!("older decision {}", words(199))), + Message::assistant(format!("older file change {}", words(199))), + Message::user(format!("recent state {}", words(124))), + Message::assistant(format!("recent context {}", words(124))), + ]; + + let compacted = engine + .compact_if_needed(&messages, CompactionScope::Perceive, 1) + .await + .expect("compaction should succeed"); + + assert!(has_conversation_summary_marker(compacted.as_ref())); + assert!(!has_compaction_marker(compacted.as_ref())); + let calls = flush.calls(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].scope, "perceive"); + assert_eq!(calls[0].evicted.len(), 1); + assert!(message_to_text(&calls[0].evicted[0]).contains("[context summary]")); + let memory = engine.session_memory_snapshot(); + assert_eq!( + memory.current_state.as_deref(), + Some("preserving old context") + ); + assert_eq!(memory.key_decisions, vec!["summarize before slide"]); + assert_eq!( + memory.active_files, + vec!["engine/crates/fx-kernel/src/loop_engine.rs"] + ); + assert_eq!( + memory.custom_context, + vec!["summary markers stay protected"] + ); + assert_eq!(llm.prompts().len(), 1); + } + #[tokio::test] async fn tiered_compaction_prune_only() { let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); @@ -12749,7 +13646,7 @@ mod context_compaction_tests { let messages = vec![user(200), assistant(200), user(125), assistant(125)]; let usage = 
budget.usage_ratio(&messages); - assert!(usage > 0.60 && usage < 0.80, "usage ratio was {usage}"); + assert!(usage > 0.60 && usage < 0.95, "usage ratio was {usage}"); let compacted = engine .compact_if_needed(&messages, CompactionScope::Perceive, 10) @@ -12762,15 +13659,61 @@ mod context_compaction_tests { } #[tokio::test] - async fn tiered_compaction_summarize_when_slide_insufficient() { + async fn slide_tier_summarizes_before_eviction_when_llm_available() { let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); - let config = tiered_compaction_config(true); + let summary = concat!( + "Decisions:\n", + "- preserve older context\n", + "Files modified:\n", + "- engine/crates/fx-kernel/src/loop_engine.rs\n", + "Task state:\n", + "- summary inserted before slide\n", + "Key context:\n", + "- older messages remain recoverable" + ); + let llm = Arc::new(ExtractionLlm::new(vec![Ok(summary.to_string())])); + let mut config = tiered_compaction_config(true); + config.prune_tool_blocks = false; + let budget = tiered_budget(&config); + let engine = engine_with_compaction_llm( + ContextCompactor::new(2_048, 256), + executor, + config, + Arc::clone(&llm) as Arc, + ); + let messages = vec![ + Message::user(format!("older plan {}", words(199))), + Message::assistant(format!("older file {}", words(199))), + Message::user(format!("recent state {}", words(124))), + Message::assistant(format!("recent context {}", words(124))), + ]; + + let usage = budget.usage_ratio(&messages); + assert!(usage > 0.60 && usage < 0.95, "usage ratio was {usage}"); + + let compacted = engine + .compact_if_needed(&messages, CompactionScope::Perceive, 10) + .await + .expect("slide compaction"); + + assert!(has_conversation_summary_marker(compacted.as_ref())); + assert!(!has_compaction_marker(compacted.as_ref())); + let prompts = llm.prompts(); + assert_eq!(prompts.len(), 1); + assert!(prompts[0].contains("older plan")); + assert!(prompts[0].contains("older file")); + } + + #[tokio::test] + 
async fn slide_tier_falls_back_to_lossy_slide_when_summary_fails() { + let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); + let llm = Arc::new(ExtractionLlm::new(vec![ + Err(CoreLlmError::ApiRequest("boom".to_string())), + Err(CoreLlmError::ApiRequest("boom".to_string())), + ])); + let mut config = tiered_compaction_config(true); + config.prune_tool_blocks = false; let budget = tiered_budget(&config); - let llm: Arc = Arc::new(RecordingLlm::with_generated_summary( - Vec::new(), - "Decisions:\n- keep\nFiles modified:\n- none\nTask state:\n- active\nKey context:\n- summarized" - .to_string(), - )); let engine = engine_with_compaction_llm(ContextCompactor::new(2_048, 256), executor, config, llm); let messages = vec![user(250), assistant(250), user(175), assistant(175)]; @@ -12781,10 +13724,31 @@ mod context_compaction_tests { let compacted = engine .compact_if_needed(&messages, CompactionScope::Perceive, 10) .await - .expect("summarizing compaction"); + .expect("slide compaction"); - assert!(has_conversation_summary_marker(compacted.as_ref())); - assert!(!has_compaction_marker(compacted.as_ref())); + assert!(has_compaction_marker(compacted.as_ref())); + assert!(!has_conversation_summary_marker(compacted.as_ref())); + assert!(!has_emergency_compaction_marker(compacted.as_ref())); + } + + #[tokio::test] + async fn slide_tier_falls_back_to_lossy_slide_without_compaction_llm() { + let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); + let config = tiered_compaction_config(true); + let budget = tiered_budget(&config); + let engine = engine_with(ContextCompactor::new(2_048, 256), executor, config); + let messages = vec![user(250), assistant(250), user(175), assistant(175)]; + + let usage = budget.usage_ratio(&messages); + assert!(usage > 0.80 && usage < 0.95, "usage ratio was {usage}"); + + let compacted = engine + .compact_if_needed(&messages, CompactionScope::Perceive, 10) + .await + .expect("slide compaction"); + + 
assert!(has_compaction_marker(compacted.as_ref())); + assert!(!has_conversation_summary_marker(compacted.as_ref())); assert!(!has_emergency_compaction_marker(compacted.as_ref())); } @@ -12808,6 +13772,87 @@ mod context_compaction_tests { assert!(!has_conversation_summary_marker(compacted.as_ref())); } + #[tokio::test] + async fn emergency_tier_uses_summary_when_llm_is_fast_enough() { + let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); + let summary = concat!( + "Decisions:\n", + "- capture emergency context\n", + "Files modified:\n", + "- engine/crates/fx-kernel/src/loop_engine.rs\n", + "Task state:\n", + "- emergency summary completed\n", + "Key context:\n", + "- fallback count marker avoided" + ); + let llm = Arc::new(ExtractionLlm::new(vec![Ok(summary.to_string())])); + let mut config = tiered_compaction_config(true); + config.prune_tool_blocks = false; + let budget = tiered_budget(&config); + let engine = engine_with_compaction_llm( + ContextCompactor::new(2_048, 256), + executor, + config, + Arc::clone(&llm) as Arc, + ); + let messages = vec![user(250), assistant(250), user(230), assistant(230)]; + + let usage = budget.usage_ratio(&messages); + assert!(usage > 0.95, "usage ratio was {usage}"); + + let compacted = engine + .compact_if_needed(&messages, CompactionScope::Perceive, 10) + .await + .expect("emergency compaction"); + + assert!(has_conversation_summary_marker(compacted.as_ref())); + assert!(!has_emergency_compaction_marker(compacted.as_ref())); + assert_eq!(llm.prompts().len(), 1); + } + + #[tokio::test] + async fn emergency_tier_attempts_best_effort_summary_before_fallback() { + let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); + let summary = concat!( + "Decisions:\n", + "- capture emergency context\n", + "Files modified:\n", + "- engine/crates/fx-kernel/src/loop_engine.rs\n", + "Task state:\n", + "- timeout fallback\n", + "Key context:\n", + "- summary was too slow" + ); + let llm = 
Arc::new(ExtractionLlm::with_delay( + vec![Ok(summary.to_string()), Ok("{}".to_string())], + Some(EMERGENCY_SUMMARY_TIMEOUT + std::time::Duration::from_millis(10)), + )); + let mut config = tiered_compaction_config(true); + config.prune_tool_blocks = false; + let budget = tiered_budget(&config); + let engine = engine_with_compaction_llm( + ContextCompactor::new(2_048, 256), + executor, + config, + Arc::clone(&llm) as Arc, + ); + let messages = vec![user(250), assistant(250), user(230), assistant(230)]; + + let usage = budget.usage_ratio(&messages); + assert!(usage > 0.95, "usage ratio was {usage}"); + + let compacted = engine + .compact_if_needed(&messages, CompactionScope::Perceive, 10) + .await + .expect("emergency compaction"); + + assert!(has_emergency_compaction_marker(compacted.as_ref())); + assert!(!has_conversation_summary_marker(compacted.as_ref())); + let prompts = llm.prompts(); + assert!(!prompts.is_empty()); + assert!(prompts[0].contains("Sections (required):")); + } + #[tokio::test] async fn compact_if_needed_emergency_tier_preserves_tool_pairs() { let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); @@ -12849,16 +13894,10 @@ mod context_compaction_tests { } #[tokio::test] - async fn cooldown_skips_slide_and_summarize() { + async fn cooldown_skips_slide_but_allows_emergency() { let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); let config = tiered_compaction_config(true); - let llm: Arc = Arc::new(RecordingLlm::with_generated_summary( - Vec::new(), - "Decisions:\n- keep\nFiles modified:\n- none\nTask state:\n- active\nKey context:\n- summarized" - .to_string(), - )); - let engine = - engine_with_compaction_llm(ContextCompactor::new(2_048, 256), executor, config, llm); + let engine = engine_with(ContextCompactor::new(2_048, 256), executor, config); let slide_input = vec![user(200), assistant(200), user(125), assistant(125)]; let first = engine @@ -12871,11 +13910,6 @@ mod context_compaction_tests { 11, 
CompactionTier::Slide )); - assert!(engine.should_skip_compaction( - CompactionScope::Perceive, - 11, - CompactionTier::Summarize - )); let emergency_input = vec![user(250), assistant(250), user(230), assistant(230)]; let second = engine @@ -12960,40 +13994,24 @@ mod context_compaction_tests { } #[tokio::test] - async fn summary_exceeded_target_falls_back_to_sliding_compactor() { + async fn legacy_summarize_threshold_does_not_trigger_compaction_below_slide_threshold() { let executor: Arc = Arc::new(SizedToolExecutor { output_words: 20 }); - let mut config = compaction_config(); - config.use_summarization = true; - config.summarize_threshold = 0.3; - let llm: Arc = Arc::new(RecordingLlm::with_generated_summary( - Vec::new(), - words(2_000), - )); + let mut config = tiered_compaction_config(true); + config.slide_threshold = 0.80; + config._legacy_summarize_threshold = 0.30; + let budget = tiered_budget(&config); + let engine = engine_with(ContextCompactor::new(2_048, 256), executor, config); + let messages = vec![user(125), assistant(125), user(125), assistant(125)]; - let engine = LoopEngine::builder() - .budget(BudgetTracker::new( - crate::budget::BudgetConfig::default(), - current_time_ms(), - 0, - )) - .context(ContextCompactor::new(2_048, 256)) - .max_iterations(4) - .tool_executor(executor) - .synthesis_instruction("synthesize".to_string()) - .compaction_config(config) - .compaction_llm(llm) - .build() - .expect("test engine build"); + let usage = budget.usage_ratio(&messages); + assert!(usage > 0.30 && usage < 0.80, "usage ratio was {usage}"); - let history = large_history(10, 70); let compacted = engine - .compact_if_needed(&history, CompactionScope::Perceive, 1) + .compact_if_needed(&messages, CompactionScope::Perceive, 1) .await - .expect("compaction should fall back to sliding window"); + .expect("legacy summarize threshold should be ignored"); - assert!(has_compaction_marker(compacted.as_ref())); - 
assert!(!has_conversation_summary_marker(compacted.as_ref())); - assert!(!has_emergency_compaction_marker(compacted.as_ref())); + assert_eq!(compacted.as_ref(), messages.as_slice()); } #[tokio::test] @@ -14293,7 +15311,7 @@ mod loop_resilience_tests { #[tokio::test] async fn tool_only_turn_nudge_injected_at_threshold() { let mut engine = high_budget_engine(); - engine.consecutive_tool_only_turns = 6; + engine.consecutive_tool_turns = 6; let processed = engine .perceive(&test_snapshot("hello")) @@ -14312,7 +15330,7 @@ mod loop_resilience_tests { #[tokio::test] async fn tool_only_turn_nudge_not_injected_below_threshold() { let mut engine = high_budget_engine(); - engine.consecutive_tool_only_turns = 6 - 1; + engine.consecutive_tool_turns = 6 - 1; let processed = engine .perceive(&test_snapshot("hello")) @@ -14345,7 +15363,7 @@ mod loop_resilience_tests { .synthesis_instruction("Summarize".to_string()) .build() .expect("build"); - engine.consecutive_tool_only_turns = 4; + engine.consecutive_tool_turns = 4; let processed = engine .perceive(&test_snapshot("hello")) @@ -14378,7 +15396,7 @@ mod loop_resilience_tests { .synthesis_instruction("Summarize".to_string()) .build() .expect("build"); - engine.consecutive_tool_only_turns = 100; + engine.consecutive_tool_turns = 100; let processed = engine .perceive(&test_snapshot("hello")) @@ -14405,7 +15423,7 @@ mod loop_resilience_tests { ..BudgetConfig::default() }; let mut engine = engine_with_budget(config); - engine.consecutive_tool_only_turns = 3; + engine.consecutive_tool_turns = 3; let llm = RecordingLlm::ok(vec![CompletionResponse { content: vec![ContentBlock::Text { text: "Here is my summary.".to_string(), @@ -14446,7 +15464,7 @@ mod loop_resilience_tests { .build() .expect("build"); // At turn 5 (3 nudge + 2 grace), tools should be stripped - engine.consecutive_tool_only_turns = 5; + engine.consecutive_tool_turns = 5; let llm = RecordingLlm::ok(vec![CompletionResponse { content: vec![ContentBlock::Text { @@ -14495,7 
+15513,7 @@ mod loop_resilience_tests { .build() .expect("build"); // At turn 4 (below 3+2=5), tools should NOT be stripped - engine.consecutive_tool_only_turns = 4; + engine.consecutive_tool_turns = 4; let llm = RecordingLlm::ok(vec![CompletionResponse { content: vec![ContentBlock::Text { @@ -14556,18 +15574,8 @@ mod loop_resilience_tests { assert!(llm.requests().is_empty()); } - #[test] - fn default_termination_config_matches_current_behavior() { - let config = TerminationConfig::default(); - assert!(config.synthesize_on_exhaustion); - assert_eq!(config.nudge_after_tool_turns, 6); - assert_eq!(config.strip_tools_after_nudge, 3); - } - - #[test] - fn tool_only_turn_counter_tracks_and_resets() { - let mut engine = high_budget_engine(); - let tool_action = ActionResult { + fn tool_action(response_text: &str) -> ActionResult { + ActionResult { decision: Decision::UseTools(Vec::new()), tool_results: vec![ToolResult { tool_call_id: "call-1".to_string(), @@ -14575,20 +15583,66 @@ mod loop_resilience_tests { success: true, output: "ok".to_string(), }], - response_text: String::new(), + response_text: response_text.to_string(), tokens_used: TokenUsage::default(), - }; - engine.update_tool_only_turns(&tool_action); - assert_eq!(engine.consecutive_tool_only_turns, 1); + } + } - let text_action = ActionResult { - decision: Decision::Respond("done".to_string()), + fn text_only_action(response_text: &str) -> ActionResult { + ActionResult { + decision: Decision::Respond(response_text.to_string()), tool_results: Vec::new(), - response_text: "done".to_string(), + response_text: response_text.to_string(), tokens_used: TokenUsage::default(), - }; - engine.update_tool_only_turns(&text_action); - assert_eq!(engine.consecutive_tool_only_turns, 0); + } + } + + #[test] + fn default_termination_config_matches_current_behavior() { + let config = TerminationConfig::default(); + assert!(config.synthesize_on_exhaustion); + assert_eq!(config.nudge_after_tool_turns, 6); + 
assert_eq!(config.strip_tools_after_nudge, 3); + assert_eq!(config.tool_round_nudge_after, 4); + assert_eq!(config.tool_round_strip_after_nudge, 2); + } + + #[test] + fn update_tool_turns_increments_on_tools_with_text() { + let mut engine = high_budget_engine(); + + engine.update_tool_turns(&tool_action("still working")); + + assert_eq!(engine.consecutive_tool_turns, 1); + } + + #[test] + fn update_tool_turns_resets_on_text_only() { + let mut engine = high_budget_engine(); + engine.consecutive_tool_turns = 2; + + engine.update_tool_turns(&text_only_action("done")); + + assert_eq!(engine.consecutive_tool_turns, 0); + } + + #[test] + fn update_tool_turns_increments_on_tools_only() { + let mut engine = high_budget_engine(); + + engine.update_tool_turns(&tool_action("")); + + assert_eq!(engine.consecutive_tool_turns, 1); + } + + #[test] + fn update_tool_turns_saturating_add() { + let mut engine = high_budget_engine(); + engine.consecutive_tool_turns = u16::MAX; + + engine.update_tool_turns(&tool_action("still working")); + + assert_eq!(engine.consecutive_tool_turns, u16::MAX); } // --- Test 9: 3 tool calls with cap=4 → all 3 execute --- diff --git a/engine/crates/fx-kernel/src/perceive.rs b/engine/crates/fx-kernel/src/perceive.rs index f36d9798..1c0b2d58 100644 --- a/engine/crates/fx-kernel/src/perceive.rs +++ b/engine/crates/fx-kernel/src/perceive.rs @@ -597,7 +597,7 @@ mod tests { preferences.insert("tone".to_owned(), "concise".to_owned()); IdentityContext { - user_name: Some("Alice".to_owned()), + user_name: Some("Example User".to_owned()), preferences, personality_traits: vec!["helpful".to_owned()], } diff --git a/engine/crates/fx-kernel/src/reason.rs b/engine/crates/fx-kernel/src/reason.rs index 285406fc..84c40f1f 100644 --- a/engine/crates/fx-kernel/src/reason.rs +++ b/engine/crates/fx-kernel/src/reason.rs @@ -468,7 +468,7 @@ mod tests { version: 2, }], identity_context: IdentityContext { - user_name: Some("Alice".to_owned()), + user_name: Some("Example 
User".to_owned()), preferences, personality_traits: vec!["helpful".to_owned()], }, @@ -497,7 +497,9 @@ mod tests { .contains("Goal: Draft and send a reply")); assert!(prompt.messages[0].content.contains("last_contact = Alex")); assert!(prompt.messages[0].content.contains("Identity context:")); - assert!(prompt.messages[0].content.contains("User name: Alice")); + assert!(prompt.messages[0] + .content + .contains("User name: Example User")); assert!(prompt.messages[0].content.contains("tone: direct")); assert!(prompt.messages[0] .content diff --git a/engine/crates/fx-kernel/src/streaming.rs b/engine/crates/fx-kernel/src/streaming.rs index ddec9f07..af062985 100644 --- a/engine/crates/fx-kernel/src/streaming.rs +++ b/engine/crates/fx-kernel/src/streaming.rs @@ -68,6 +68,10 @@ pub enum StreamEvent { output: String, is_error: bool, }, + ToolError { + tool_name: String, + error: String, + }, PermissionPrompt(crate::permission_prompt::PermissionPrompt), PhaseChange { phase: Phase, @@ -130,6 +134,18 @@ mod tests { assert_eq!(event, deserialized); } + #[test] + fn tool_error_event_serializes_correctly() { + let event = StreamEvent::ToolError { + tool_name: "read_file".to_string(), + error: "permission denied".to_string(), + }; + + let json = serde_json::to_string(&event).unwrap(); + let deserialized: StreamEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(event, deserialized); + } + #[test] fn context_compacted_event_serializes_correctly() { let event = StreamEvent::ContextCompacted { diff --git a/engine/crates/fx-kernel/src/system_prompt.rs b/engine/crates/fx-kernel/src/system_prompt.rs index a7d3326a..b0402017 100644 --- a/engine/crates/fx-kernel/src/system_prompt.rs +++ b/engine/crates/fx-kernel/src/system_prompt.rs @@ -423,7 +423,7 @@ mod tests { restricted: vec!["kernel_modify".to_string()], working_dir: "/workspace".to_string(), }) - .user_context("Alice prefers short answers.") + .user_context("Prefers short answers.") .surface(Surface::HeadlessApi) 
.session(SessionContext { is_new: false, @@ -437,7 +437,7 @@ mod tests { "Behavioral:\nKeep answers grounded in evidence.", "Capabilities:\n- web_fetch: Fetch a web page", "Security:\n- Mode: capability\n- Restricted: kernel_modify\n- Working directory: /workspace", - "User context:\nAlice prefers short answers.", + "User context:\nPrefers short answers.", "Surface: Headless API. Return plain content without UI-specific references.", "Session:\n- State: continuing\n- Message count: 3\n- Recent summary: Reviewed deployment notes.", "Directives:\n- Return machine-readable content when asked.", diff --git a/engine/crates/fx-kernel/src/types.rs b/engine/crates/fx-kernel/src/types.rs index b0e43aad..040b08e4 100644 --- a/engine/crates/fx-kernel/src/types.rs +++ b/engine/crates/fx-kernel/src/types.rs @@ -429,7 +429,7 @@ mod tests { relevant_semantic: vec![], active_procedures: vec![], identity_context: IdentityContext { - user_name: Some("Alice".to_owned()), + user_name: Some("Example User".to_owned()), preferences: parent_preferences, personality_traits: vec!["concise".to_owned()], }, @@ -465,7 +465,7 @@ mod tests { relevant_semantic: vec![], active_procedures: vec![], identity_context: IdentityContext { - user_name: Some("Alice".to_owned()), + user_name: Some("Example User".to_owned()), preferences: child_preferences, personality_traits: vec!["focused".to_owned()], }, @@ -603,7 +603,7 @@ mod tests { preferences.insert("lang".to_owned(), "en".to_owned()); let identity = IdentityContext { - user_name: Some("Alice".to_owned()), + user_name: Some("Example User".to_owned()), preferences, personality_traits: vec!["friendly".to_owned()], }; diff --git a/engine/crates/fx-loadable/Cargo.toml b/engine/crates/fx-loadable/Cargo.toml index cabee045..7bd5e245 100644 --- a/engine/crates/fx-loadable/Cargo.toml +++ b/engine/crates/fx-loadable/Cargo.toml @@ -6,6 +6,9 @@ authors.workspace = true license.workspace = true description = "Fawx loadable layer — skill registry, plugin loading, 
built-in tool adapters. Hot-swappable at runtime." +[features] +test-support = [] + [dependencies] fx-core = { workspace = true } fx-kernel = { workspace = true } @@ -23,6 +26,7 @@ tokio = { workspace = true, features = ["sync", "time"] } tracing.workspace = true dirs = "6.0" futures.workspace = true +libc = "0.2" notify = { version = "7", features = ["macos_fsevent"] } sha2 = "0.10" diff --git a/engine/crates/fx-loadable/src/lib.rs b/engine/crates/fx-loadable/src/lib.rs index 503fca55..bc9ca51b 100644 --- a/engine/crates/fx-loadable/src/lib.rs +++ b/engine/crates/fx-loadable/src/lib.rs @@ -25,6 +25,8 @@ pub mod notify_skill; pub mod registry; pub mod session_memory_skill; pub mod skill; +#[cfg(any(test, feature = "test-support"))] +pub mod test_support; pub mod transaction_skill; pub mod wasm_host; pub mod wasm_skill; diff --git a/engine/crates/fx-loadable/src/session_memory_skill.rs b/engine/crates/fx-loadable/src/session_memory_skill.rs index 5bc4beda..48a8baa6 100644 --- a/engine/crates/fx-loadable/src/session_memory_skill.rs +++ b/engine/crates/fx-loadable/src/session_memory_skill.rs @@ -83,17 +83,17 @@ fn tool_definition() -> ToolDefinition { "key_decisions": { "type": "array", "items": { "type": "string" }, - "description": "Key decisions to remember (appended, max 20)" + "description": "Key decisions to remember (appended, capped by session memory budget)" }, "active_files": { "type": "array", "items": { "type": "string" }, - "description": "Files actively being worked on (replaces list)" + "description": "Files actively being worked on (replaces list, capped by session memory budget)" }, "custom_context": { "type": "array", "items": { "type": "string" }, - "description": "Any other context to remember (appended, max 20)" + "description": "Any other context to remember (appended, capped by session memory budget)" } } }), diff --git a/engine/crates/fx-loadable/src/test_support.rs b/engine/crates/fx-loadable/src/test_support.rs new file mode 100644 index 
00000000..d1256e26 --- /dev/null +++ b/engine/crates/fx-loadable/src/test_support.rs @@ -0,0 +1,56 @@ +use std::fs; +use std::io; +use std::path::Path; + +pub fn test_manifest_toml(name: &str) -> String { + versioned_manifest_toml(name, "1.0.0") +} + +pub fn versioned_manifest_toml(name: &str, version: &str) -> String { + format!( + r#"name = "{name}" +version = "{version}" +description = "{name} skill" +author = "Test" +api_version = "host_api_v1" +entry_point = "run" +"# + ) +} + +pub fn invocable_wasm_bytes() -> Vec { + let wat = r#" + (module + (import "host_api_v1" "log" (func $log (param i32 i32 i32))) + (import "host_api_v1" "kv_get" (func $kv_get (param i32 i32) (result i32))) + (import "host_api_v1" "kv_set" (func $kv_set (param i32 i32 i32 i32))) + (import "host_api_v1" "get_input" (func $get_input (result i32))) + (import "host_api_v1" "set_output" (func $set_output (param i32 i32))) + (memory (export "memory") 1) + (func (export "run") + (i32.store8 (i32.const 0) (i32.const 111)) + (i32.store8 (i32.const 1) (i32.const 107)) + (call $set_output (i32.const 0) (i32.const 2)) + ) + ) + "#; + wat.as_bytes().to_vec() +} + +pub fn write_test_skill(skills_dir: &Path, name: &str) -> io::Result<()> { + write_versioned_test_skill(skills_dir, name, "1.0.0") +} + +pub fn write_versioned_test_skill(skills_dir: &Path, name: &str, version: &str) -> io::Result<()> { + let skill_dir = skills_dir.join(name); + fs::create_dir_all(&skill_dir)?; + fs::write( + skill_dir.join("manifest.toml"), + versioned_manifest_toml(name, version), + )?; + fs::write( + skill_dir.join(format!("{name}.wasm")), + invocable_wasm_bytes(), + )?; + Ok(()) +} diff --git a/engine/crates/fx-loadable/src/wasm_host.rs b/engine/crates/fx-loadable/src/wasm_host.rs index 18a3fac2..471a8a4d 100644 --- a/engine/crates/fx-loadable/src/wasm_host.rs +++ b/engine/crates/fx-loadable/src/wasm_host.rs @@ -8,10 +8,29 @@ use fx_skills::host_api::HostApi; use fx_skills::live_host_api::{execute_http_request, 
CredentialProvider}; use fx_skills::manifest::Capability; use fx_skills::storage::SkillStorage; +use serde::Serialize; +use std::io::Read; +#[cfg(unix)] +use std::os::unix::process::CommandExt; +use std::path::Path; +use std::process::{Child, ChildStderr, ChildStdout, Command, Stdio}; use std::sync::{Arc, Mutex}; +use std::thread::{self, JoinHandle}; +use std::time::{Duration, Instant}; /// Default storage quota per skill: 64 KiB. const DEFAULT_STORAGE_QUOTA: usize = 64 * 1024; +const COMMAND_OUTPUT_LIMIT_BYTES: usize = 512 * 1024; +const COMMAND_TIMEOUT_EXIT_CODE: i32 = -1; +const COMMAND_FAILURE_EXIT_CODE: i32 = -2; +const COMMAND_POLL_INTERVAL_MS: u64 = 10; + +#[derive(Serialize)] +struct ShellCommandResult { + stdout: String, + stderr: String, + exit_code: i32, +} /// Live host API backed by real runtime services. /// @@ -78,6 +97,10 @@ impl LiveHostApi { .unwrap_or_else(|poisoned| poisoned.into_inner()), ) } + + fn has_capability(&self, capability: Capability) -> bool { + self.capabilities.contains(&capability) + } } impl HostApi for LiveHostApi { @@ -132,6 +155,30 @@ impl HostApi for LiveHostApi { execute_http_request(method, url, headers, body) } + fn exec_command(&self, command: &str, timeout_ms: u32) -> Option { + if !self.has_capability(Capability::Shell) { + tracing::error!("exec_command denied: Shell capability not declared"); + return None; + } + execute_shell_command(command, timeout_ms) + } + + fn read_file(&self, path: &str) -> Option { + if !self.has_capability(Capability::Filesystem) { + tracing::error!("read_file denied: Filesystem capability not declared"); + return None; + } + read_utf8_file(path) + } + + fn write_file(&self, path: &str, content: &str) -> bool { + if !self.has_capability(Capability::Filesystem) { + tracing::error!("write_file denied: Filesystem capability not declared"); + return false; + } + write_utf8_file(path, content) + } + fn get_output(&self) -> String { self.output .lock() @@ -144,6 +191,197 @@ impl HostApi for 
LiveHostApi { } } +fn execute_shell_command(command: &str, timeout_ms: u32) -> Option { + let mut child = spawn_shell(command) + .map_err(|error| { + tracing::error!("exec_command failed to spawn '{command}': {error}"); + error + }) + .ok()?; + let (stdout, stderr) = take_command_pipes(&mut child)?; + let stdout_handle = spawn_output_reader(stdout); + let stderr_handle = spawn_output_reader(stderr); + let exit_code = wait_for_command(&mut child, timeout_ms); + let stdout = join_output_reader(stdout_handle, "stdout")?; + let stderr = add_timeout_message( + join_output_reader(stderr_handle, "stderr")?, + exit_code, + timeout_ms, + ); + serialize_command_result(stdout, stderr, exit_code) +} + +fn spawn_shell(command: &str) -> std::io::Result { + let mut cmd = Command::new("sh"); + cmd.args(["-c", command]) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + configure_shell_process_group(&mut cmd); + cmd.spawn() +} + +#[cfg(unix)] +fn configure_shell_process_group(cmd: &mut Command) { + unsafe { + cmd.pre_exec(create_process_group); + } +} + +#[cfg(not(unix))] +fn configure_shell_process_group(_cmd: &mut Command) {} + +#[cfg(unix)] +fn create_process_group() -> std::io::Result<()> { + if unsafe { libc::setpgid(0, 0) } == -1 { + return Err(std::io::Error::last_os_error()); + } + Ok(()) +} + +fn take_command_pipes(child: &mut Child) -> Option<(ChildStdout, ChildStderr)> { + let stdout = child.stdout.take().or_else(|| { + tracing::error!("exec_command: stdout pipe missing"); + None + })?; + let stderr = child.stderr.take().or_else(|| { + tracing::error!("exec_command: stderr pipe missing"); + None + })?; + Some((stdout, stderr)) +} + +fn spawn_output_reader(reader: R) -> JoinHandle> +where + R: Read + Send + 'static, +{ + thread::spawn(move || read_capped_output(reader)) +} + +fn read_capped_output(reader: R) -> Vec +where + R: Read, +{ + let mut output = Vec::new(); + let mut limited_reader = reader.take(COMMAND_OUTPUT_LIMIT_BYTES as u64); + if 
let Err(error) = limited_reader.read_to_end(&mut output) { + tracing::error!("exec_command: failed to read process output: {error}"); + } + output +} + +fn wait_for_command(child: &mut Child, timeout_ms: u32) -> i32 { + let deadline = Instant::now() + Duration::from_millis(u64::from(timeout_ms)); + loop { + match child.try_wait() { + Ok(Some(status)) => return status.code().unwrap_or(COMMAND_FAILURE_EXIT_CODE), + Ok(None) if Instant::now() < deadline => { + thread::sleep(Duration::from_millis(COMMAND_POLL_INTERVAL_MS)); + } + Ok(None) => return kill_timed_out_child(child), + Err(error) => { + tracing::error!("exec_command: failed to wait on child: {error}"); + return COMMAND_FAILURE_EXIT_CODE; + } + } + } +} + +fn kill_timed_out_child(child: &mut Child) -> i32 { + if let Err(error) = terminate_timed_out_child(child) { + tracing::error!("exec_command: failed to kill timed out child: {error}"); + return COMMAND_FAILURE_EXIT_CODE; + } + if let Err(error) = child.wait() { + tracing::error!("exec_command: failed to reap timed out child: {error}"); + return COMMAND_FAILURE_EXIT_CODE; + } + COMMAND_TIMEOUT_EXIT_CODE +} + +#[cfg(unix)] +fn terminate_timed_out_child(child: &mut Child) -> std::io::Result<()> { + if unsafe { libc::killpg(child.id() as i32, libc::SIGKILL) } == 0 { + return Ok(()); + } + let error = std::io::Error::last_os_error(); + if error.raw_os_error() == Some(libc::ESRCH) { + return Ok(()); + } + Err(error) +} + +#[cfg(not(unix))] +fn terminate_timed_out_child(child: &mut Child) -> std::io::Result<()> { + child.kill() +} + +fn join_output_reader(handle: JoinHandle>, stream_name: &str) -> Option { + match handle.join() { + Ok(bytes) => Some(String::from_utf8_lossy(&bytes).into_owned()), + Err(_) => { + tracing::error!("exec_command: {stream_name} reader thread panicked"); + None + } + } +} + +fn add_timeout_message(mut stderr: String, exit_code: i32, timeout_ms: u32) -> String { + if exit_code == COMMAND_TIMEOUT_EXIT_CODE { + if !stderr.is_empty() { + 
stderr.push('\n'); + } + stderr.push_str(&format!("command timed out after {timeout_ms}ms")); + } + stderr +} + +fn serialize_command_result(stdout: String, stderr: String, exit_code: i32) -> Option { + serde_json::to_string(&ShellCommandResult { + stdout, + stderr, + exit_code, + }) + .map_err(|error| { + tracing::error!("exec_command: failed to serialize result: {error}"); + error + }) + .ok() +} + +fn read_utf8_file(path: &str) -> Option { + std::fs::read_to_string(path) + .map_err(|error| { + tracing::error!("read_file failed for '{}': {}", path, error); + error + }) + .ok() +} + +fn write_utf8_file(path: &str, content: &str) -> bool { + let path = Path::new(path); + if let Err(error) = ensure_parent_directory(path) { + tracing::error!( + "write_file failed to create parent for '{}': {}", + path.display(), + error + ); + return false; + } + if let Err(error) = std::fs::write(path, content) { + tracing::error!("write_file failed for '{}': {}", path.display(), error); + return false; + } + true +} + +fn ensure_parent_directory(path: &Path) -> std::io::Result<()> { + match path.parent() { + Some(parent) if !parent.as_os_str().is_empty() => std::fs::create_dir_all(parent), + _ => Ok(()), + } +} + fn is_network_allowed(url: &str, capabilities: &[Capability]) -> bool { for cap in capabilities { match cap { @@ -188,7 +426,11 @@ fn extract_host(url: &str) -> Option<&str> { #[cfg(test)] mod tests { use super::*; + use serde_json::Value; use std::collections::HashMap; + use std::env; + use std::io::Write; + use tempfile::TempDir; use zeroize::Zeroizing; fn make_config(input: &str) -> LiveHostApiConfig<'_> { @@ -201,10 +443,34 @@ mod tests { } } + const STDIN_HELPER_ENV: &str = "FX_LOADABLE_STDIN_HELPER"; + fn make_api(input: &str) -> LiveHostApi { LiveHostApi::new(make_config(input)) } + fn make_shell_api() -> LiveHostApi { + LiveHostApi::new(LiveHostApiConfig { + skill_name: "test", + input: String::new(), + storage_quota: None, + capabilities: vec![Capability::Shell], 
+ credential_provider: None, + }) + } + + fn parse_command_result(json: &str) -> Value { + serde_json::from_str(json).expect("parse command result") + } + + fn run_shell_command(command: &str, timeout_ms: u32) -> Value { + let api = make_shell_api(); + let json = api + .exec_command(command, timeout_ms) + .expect("shell command result"); + parse_command_result(&json) + } + /// Mock credential provider for testing. struct MockCredentialProvider { credentials: HashMap, @@ -229,6 +495,57 @@ mod tests { } } + struct FailAfterLimitReader { + bytes_remaining: usize, + chunk_size: usize, + } + + impl Read for FailAfterLimitReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + if self.bytes_remaining == 0 { + panic!("reader was polled after hitting the output cap"); + } + let bytes_to_read = self.bytes_remaining.min(self.chunk_size).min(buf.len()); + buf[..bytes_to_read].fill(b'x'); + self.bytes_remaining -= bytes_to_read; + Ok(bytes_to_read) + } + } + + #[cfg(unix)] + fn read_background_pid(pid_file: &Path) -> i32 { + let deadline = Instant::now() + Duration::from_millis(500); + while Instant::now() < deadline { + if let Ok(pid) = std::fs::read_to_string(pid_file) { + return pid.trim().parse().expect("background pid"); + } + thread::sleep(Duration::from_millis(10)); + } + panic!( + "timed out waiting for background pid at {}", + pid_file.display() + ); + } + + #[cfg(unix)] + fn process_exists(pid: i32) -> bool { + if unsafe { libc::kill(pid, 0) } == 0 { + return true; + } + std::io::Error::last_os_error().raw_os_error() == Some(libc::EPERM) + } + + #[cfg(unix)] + fn wait_for_process_exit(pid: i32) { + let deadline = Instant::now() + Duration::from_millis(500); + while Instant::now() < deadline { + if !process_exists(pid) { + return; + } + thread::sleep(Duration::from_millis(10)); + } + } + #[test] fn input_output_round_trip() { let mut api = make_api("hello world"); @@ -339,6 +656,135 @@ mod tests { ); } + #[test] + fn 
exec_command_denied_without_shell_capability() { + let api = make_api(""); + assert_eq!(api.exec_command("printf hello", 1_000), None); + } + + #[test] + fn exec_command_allowed_with_shell_capability() { + let result = run_shell_command("printf hello", 1_000); + assert_eq!(result["stdout"], "hello"); + assert_eq!(result["stderr"], ""); + assert_eq!(result["exit_code"], 0); + } + + #[test] + fn read_capped_output_stops_after_limit() { + let reader = FailAfterLimitReader { + bytes_remaining: COMMAND_OUTPUT_LIMIT_BYTES, + chunk_size: 8192, + }; + + let output = read_capped_output(reader); + + assert_eq!(output.len(), COMMAND_OUTPUT_LIMIT_BYTES); + } + + #[test] + #[cfg(unix)] + fn exec_command_stdin_helper_reads_eof() { + if env::var_os(STDIN_HELPER_ENV).is_none() { + return; + } + let result = run_shell_command( + r#"if read value; then printf '%s' "$value"; else printf eof; fi"#, + 1_000, + ); + assert_eq!(result["stdout"], "eof"); + assert_eq!(result["stderr"], ""); + assert_eq!(result["exit_code"], 0); + } + + #[test] + #[cfg(unix)] + fn exec_command_does_not_inherit_parent_stdin() { + let exe = env::current_exe().expect("test binary path"); + let mut child = Command::new(exe) + .arg("--exact") + .arg("exec_command_stdin_helper_reads_eof") + .env(STDIN_HELPER_ENV, "1") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("spawn stdin helper"); + child + .stdin + .take() + .expect("helper stdin") + .write_all(b"from-parent\n") + .expect("write helper stdin"); + let output = child.wait_with_output().expect("helper output"); + assert!( + output.status.success(), + "stdin helper failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + + #[test] + #[cfg(unix)] + fn exec_command_timeout_kills_background_process_group() { + let temp_dir = TempDir::new().expect("temp dir"); + let pid_file = temp_dir.path().join("background.pid"); + let command = 
format!("sleep 1 & echo $! > '{}' && wait", pid_file.display()); + let started_at = Instant::now(); + + let result = run_shell_command(&command, 50); + let elapsed = started_at.elapsed(); + let background_pid = read_background_pid(&pid_file); + wait_for_process_exit(background_pid); + if process_exists(background_pid) { + let _ = unsafe { libc::kill(background_pid, libc::SIGKILL) }; + } + + assert_eq!(result["exit_code"], COMMAND_TIMEOUT_EXIT_CODE); + assert!(elapsed < Duration::from_millis(500)); + assert!(!process_exists(background_pid)); + } + + #[test] + fn read_file_denied_without_filesystem_capability() { + let temp_dir = TempDir::new().expect("temp dir"); + let path = temp_dir.path().join("data.txt"); + std::fs::write(&path, "hello").expect("write fixture"); + let api = make_api(""); + + assert_eq!(api.read_file(path.to_str().expect("utf-8 path")), None); + } + + #[test] + fn write_file_denied_without_filesystem_capability() { + let temp_dir = TempDir::new().expect("temp dir"); + let path = temp_dir.path().join("data.txt"); + let api = make_api(""); + + assert!(!api.write_file(path.to_str().expect("utf-8 path"), "hello")); + assert!(!path.exists()); + } + + #[test] + fn read_write_file_round_trip() { + let temp_dir = TempDir::new().expect("temp dir"); + let path = temp_dir.path().join("nested").join("data.txt"); + let api = LiveHostApi::new(LiveHostApiConfig { + skill_name: "test", + input: String::new(), + storage_quota: None, + capabilities: vec![Capability::Filesystem], + credential_provider: None, + }); + + assert!(api.write_file(path.to_str().expect("utf-8 path"), "hello")); + assert_eq!( + api.read_file(path.to_str().expect("utf-8 path")), + Some("hello".to_string()) + ); + } + #[test] fn network_allowed_unrestricted() { assert!(is_network_allowed( diff --git a/engine/crates/fx-loadable/src/wasm_skill.rs b/engine/crates/fx-loadable/src/wasm_skill.rs index 0428477b..13cfeab2 100644 --- a/engine/crates/fx-loadable/src/wasm_skill.rs +++ 
b/engine/crates/fx-loadable/src/wasm_skill.rs @@ -445,6 +445,7 @@ fn read_skill_directories(skills_dir: &Path) -> Result, S #[cfg(test)] mod tests { use super::*; + use crate::test_support::invocable_wasm_bytes; use fx_skills::loader::SkillLoader; use fx_skills::manifest::SkillManifest; @@ -460,25 +461,6 @@ mod tests { } } - fn invocable_wasm_bytes() -> Vec { - let wat = r#" - (module - (import "host_api_v1" "log" (func $log (param i32 i32 i32))) - (import "host_api_v1" "kv_get" (func $kv_get (param i32 i32) (result i32))) - (import "host_api_v1" "kv_set" (func $kv_set (param i32 i32 i32 i32))) - (import "host_api_v1" "get_input" (func $get_input (result i32))) - (import "host_api_v1" "set_output" (func $set_output (param i32 i32))) - (memory (export "memory") 1) - (func (export "run") - (i32.store8 (i32.const 0) (i32.const 111)) - (i32.store8 (i32.const 1) (i32.const 107)) - (call $set_output (i32.const 0) (i32.const 2)) - ) - ) - "#; - wat.as_bytes().to_vec() - } - fn load_test_skill(name: &str) -> LoadedSkill { let loader = SkillLoader::new(vec![]); let manifest = test_manifest(name); diff --git a/engine/crates/fx-loadable/src/watcher.rs b/engine/crates/fx-loadable/src/watcher.rs index fbee99ab..5c417488 100644 --- a/engine/crates/fx-loadable/src/watcher.rs +++ b/engine/crates/fx-loadable/src/watcher.rs @@ -399,79 +399,14 @@ fn collect_expired(pending: &mut HashMap) -> Vec { #[cfg(test)] mod tests { use super::*; + use crate::test_support::{ + invocable_wasm_bytes, test_manifest_toml, versioned_manifest_toml, write_test_skill, + write_versioned_test_skill, + }; use crate::wasm_skill::compute_wasm_hash; use std::fs; use tempfile::TempDir; - fn test_manifest_toml(name: &str) -> String { - format!( - r#"name = "{name}" -version = "1.0.0" -description = "{name} skill" -author = "Test" -api_version = "host_api_v1" -entry_point = "run" -"# - ) - } - - fn versioned_manifest_toml(name: &str, version: &str) -> String { - format!( - r#"name = "{name}" -version = 
"{version}" -description = "{name} skill" -author = "Test" -api_version = "host_api_v1" -entry_point = "run" -"# - ) - } - - fn invocable_wasm_bytes() -> Vec { - let wat = r#" - (module - (import "host_api_v1" "log" (func $log (param i32 i32 i32))) - (import "host_api_v1" "kv_get" (func $kv_get (param i32 i32) (result i32))) - (import "host_api_v1" "kv_set" (func $kv_set (param i32 i32 i32 i32))) - (import "host_api_v1" "get_input" (func $get_input (result i32))) - (import "host_api_v1" "set_output" (func $set_output (param i32 i32))) - (memory (export "memory") 1) - (func (export "run") - (i32.store8 (i32.const 0) (i32.const 111)) - (i32.store8 (i32.const 1) (i32.const 107)) - (call $set_output (i32.const 0) (i32.const 2)) - ) - ) - "#; - wat.as_bytes().to_vec() - } - - fn setup_skill_dir(skills_dir: &Path, name: &str) { - let skill_dir = skills_dir.join(name); - fs::create_dir_all(&skill_dir).unwrap(); - fs::write(skill_dir.join("manifest.toml"), test_manifest_toml(name)).unwrap(); - fs::write( - skill_dir.join(format!("{name}.wasm")), - invocable_wasm_bytes(), - ) - .unwrap(); - } - - fn setup_versioned_skill_dir(skills_dir: &Path, name: &str, version: &str) { - let skill_dir = skills_dir.join(name); - fs::create_dir_all(&skill_dir).unwrap(); - fs::write( - skill_dir.join("manifest.toml"), - versioned_manifest_toml(name, version), - ) - .unwrap(); - fs::write( - skill_dir.join(format!("{name}.wasm")), - invocable_wasm_bytes(), - ) - .unwrap(); - } - #[test] fn reload_event_is_debug_and_clone() { let event = ReloadEvent::Loaded { @@ -511,7 +446,7 @@ entry_point = "run" fn skill_dir_is_valid_with_all_files() { let tmp = TempDir::new().unwrap(); let name = "test_skill"; - setup_skill_dir(tmp.path(), name); + write_test_skill(tmp.path(), name).unwrap(); assert!(skill_dir_is_valid(&tmp.path().join(name), name)); } @@ -548,8 +483,8 @@ entry_point = "run" #[test] fn initialize_hashes_populates_from_existing_skills() { let tmp = TempDir::new().unwrap(); - 
setup_skill_dir(tmp.path(), "alpha"); - setup_skill_dir(tmp.path(), "beta"); + write_test_skill(tmp.path(), "alpha").unwrap(); + write_test_skill(tmp.path(), "beta").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, _rx) = mpsc::channel(16); @@ -570,7 +505,7 @@ entry_point = "run" #[test] fn initialize_hashes_correct_hash_value() { let tmp = TempDir::new().unwrap(); - setup_skill_dir(tmp.path(), "test_hash"); + write_test_skill(tmp.path(), "test_hash").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, _rx) = mpsc::channel(16); @@ -591,7 +526,7 @@ entry_point = "run" #[test] fn initialize_hashes_stores_version() { let tmp = TempDir::new().unwrap(); - setup_versioned_skill_dir(tmp.path(), "versioned", "2.5.0"); + write_versioned_test_skill(tmp.path(), "versioned", "2.5.0").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, _rx) = mpsc::channel(16); @@ -645,7 +580,7 @@ entry_point = "run" #[tokio::test] async fn process_skill_change_loads_new_skill() { let tmp = TempDir::new().unwrap(); - setup_skill_dir(tmp.path(), "newskill"); + write_test_skill(tmp.path(), "newskill").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, mut rx) = mpsc::channel(16); @@ -675,7 +610,7 @@ entry_point = "run" #[tokio::test] async fn process_skill_change_loads_with_correct_version() { let tmp = TempDir::new().unwrap(); - setup_versioned_skill_dir(tmp.path(), "verskill", "3.1.0"); + write_versioned_test_skill(tmp.path(), "verskill", "3.1.0").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, mut rx) = mpsc::channel(16); @@ -720,7 +655,7 @@ entry_point = "run" #[tokio::test] async fn process_skill_change_updates_existing_skill() { let tmp = TempDir::new().unwrap(); - setup_skill_dir(tmp.path(), "updskill"); + write_test_skill(tmp.path(), "updskill").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, mut rx) = mpsc::channel(16); @@ -750,7 +685,7 @@ entry_point = "run" #[tokio::test] async fn 
process_skill_change_update_reports_old_version() { let tmp = TempDir::new().unwrap(); - setup_versioned_skill_dir(tmp.path(), "upver", "1.0.0"); + write_versioned_test_skill(tmp.path(), "upver", "1.0.0").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, mut rx) = mpsc::channel(16); @@ -794,7 +729,7 @@ entry_point = "run" #[tokio::test] async fn process_skill_change_same_hash_no_reload() { let tmp = TempDir::new().unwrap(); - setup_skill_dir(tmp.path(), "sameskill"); + write_test_skill(tmp.path(), "sameskill").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, mut rx) = mpsc::channel(16); @@ -820,7 +755,7 @@ entry_point = "run" #[tokio::test] async fn process_skill_change_removal() { let tmp = TempDir::new().unwrap(); - setup_skill_dir(tmp.path(), "rmskill"); + write_test_skill(tmp.path(), "rmskill").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, mut rx) = mpsc::channel(16); @@ -851,7 +786,7 @@ entry_point = "run" #[tokio::test] async fn process_skill_change_error_keeps_existing() { let tmp = TempDir::new().unwrap(); - setup_skill_dir(tmp.path(), "errskill"); + write_test_skill(tmp.path(), "errskill").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, mut rx) = mpsc::channel(16); @@ -917,7 +852,7 @@ entry_point = "run" #[tokio::test] async fn debounce_multiple_events_single_reload() { let tmp = TempDir::new().unwrap(); - setup_skill_dir(tmp.path(), "debounce"); + write_test_skill(tmp.path(), "debounce").unwrap(); let registry = Arc::new(SkillRegistry::new()); let (tx, mut rx) = mpsc::channel(16); @@ -981,7 +916,7 @@ entry_point = "run" #[tokio::test] async fn handle_removal_uses_try_send() { let tmp = TempDir::new().unwrap(); - setup_skill_dir(tmp.path(), "trysend"); + write_test_skill(tmp.path(), "trysend").unwrap(); let registry = Arc::new(SkillRegistry::new()); // Channel with capacity 1 — fill it to verify try_send doesn't block diff --git a/engine/crates/fx-marketplace/Cargo.toml 
b/engine/crates/fx-marketplace/Cargo.toml index 535fc06f..dee78d08 100644 --- a/engine/crates/fx-marketplace/Cargo.toml +++ b/engine/crates/fx-marketplace/Cargo.toml @@ -9,6 +9,7 @@ license.workspace = true fx-skills = { path = "../fx-skills" } serde = { workspace = true } serde_json = { workspace = true } +tracing.workspace = true ureq = { workspace = true } [dev-dependencies] diff --git a/engine/crates/fx-marketplace/src/lib.rs b/engine/crates/fx-marketplace/src/lib.rs index 83e99458..9047cadd 100644 --- a/engine/crates/fx-marketplace/src/lib.rs +++ b/engine/crates/fx-marketplace/src/lib.rs @@ -11,6 +11,15 @@ use std::path::{Path, PathBuf}; use serde::{Deserialize, Serialize}; +/// Official fawxai publisher Ed25519 public key (32 bytes). +pub const FAWXAI_PUBLIC_KEY: [u8; 32] = [ + 62, 38, 70, 230, 12, 59, 226, 179, 11, 150, 52, 48, 238, 181, 159, 188, 106, 55, 109, 208, 1, + 191, 157, 233, 161, 111, 154, 212, 209, 133, 28, 68, +]; + +/// Default registry URL (raw GitHub content). +pub const DEFAULT_REGISTRY_URL: &str = "https://raw.githubusercontent.com/fawxai/registry/main"; + // --------------------------------------------------------------------------- // Error types // --------------------------------------------------------------------------- @@ -119,6 +128,62 @@ pub fn parse_index(json: &str) -> Result, MarketplaceError> { Ok(index.skills) } +/// Load trusted keys: builtin fawxai key + any user-added keys from +/// `{data_dir}/trusted_keys/`. +pub fn load_trusted_keys(data_dir: &Path) -> Result>, MarketplaceError> { + let mut keys = vec![FAWXAI_PUBLIC_KEY.to_vec()]; + let keys_dir = data_dir.join("trusted_keys"); + if !keys_dir.exists() { + return Ok(keys); + } + + for path in trusted_key_paths(&keys_dir)? { + if let Some(key) = read_trusted_key(&path)? { + keys.push(key); + } + } + Ok(keys) +} + +/// Build a default `RegistryConfig` for the given data directory. 
+pub fn default_config(data_dir: &Path) -> Result { + Ok(RegistryConfig { + registry_url: DEFAULT_REGISTRY_URL.to_string(), + data_dir: data_dir.to_path_buf(), + trusted_keys: load_trusted_keys(data_dir)?, + }) +} + +fn trusted_key_paths(keys_dir: &Path) -> Result, MarketplaceError> { + let mut paths = Vec::new(); + let entries = fs::read_dir(keys_dir) + .map_err(|e| MarketplaceError::InstallError(format!("read trusted_keys: {e}")))?; + for entry in entries { + let path = entry + .map_err(|e| MarketplaceError::InstallError(format!("read entry: {e}")))? + .path(); + if path.is_file() { + paths.push(path); + } + } + paths.sort(); + Ok(paths) +} + +fn read_trusted_key(path: &Path) -> Result>, MarketplaceError> { + let key_bytes = + fs::read(path).map_err(|e| MarketplaceError::InstallError(format!("read key: {e}")))?; + if key_bytes.len() != 32 { + tracing::warn!( + path = %path.display(), + size = key_bytes.len(), + "Skipping invalid trusted key file" + ); + return Ok(None); + } + Ok(Some(key_bytes)) +} + /// Validate that a skill name contains only safe characters. 
/// /// Rejects names that contain path separators, `..`, or any characters @@ -443,6 +508,37 @@ capabilities = ["network"] fs::write(dir.join("manifest.toml"), manifest).expect("write manifest"); } + #[test] + fn load_trusted_keys_returns_builtin_and_valid_user_keys() { + let tmp = TempDir::new().expect("tempdir"); + let keys_dir = tmp.path().join("trusted_keys"); + fs::create_dir_all(&keys_dir).expect("mkdir trusted_keys"); + fs::write(keys_dir.join("a.key"), vec![7_u8; 32]).expect("write valid key"); + fs::write(keys_dir.join("b.key"), vec![9_u8; 31]).expect("write invalid key"); + + let keys = load_trusted_keys(tmp.path()).expect("load trusted keys"); + + assert_eq!(keys.len(), 2); + assert_eq!(keys[0], FAWXAI_PUBLIC_KEY.to_vec()); + assert_eq!(keys[1], vec![7_u8; 32]); + } + + #[test] + fn default_config_uses_default_registry_and_loaded_keys() { + let tmp = TempDir::new().expect("tempdir"); + let keys_dir = tmp.path().join("trusted_keys"); + fs::create_dir_all(&keys_dir).expect("mkdir trusted_keys"); + fs::write(keys_dir.join("publisher.key"), vec![5_u8; 32]).expect("write key"); + + let config = default_config(tmp.path()).expect("default config"); + + assert_eq!(config.registry_url, DEFAULT_REGISTRY_URL); + assert_eq!(config.data_dir, tmp.path()); + assert_eq!(config.trusted_keys.len(), 2); + assert_eq!(config.trusted_keys[0], FAWXAI_PUBLIC_KEY.to_vec()); + assert_eq!(config.trusted_keys[1], vec![5_u8; 32]); + } + // 1. 
search_filters_by_name #[test] fn search_filters_by_name() { diff --git a/engine/crates/fx-python/Cargo.toml b/engine/crates/fx-python/Cargo.toml new file mode 100644 index 00000000..adaa3eca --- /dev/null +++ b/engine/crates/fx-python/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "fx-python" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-trait = "0.1" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +tokio = { version = "1", features = ["process", "time", "fs", "io-util", "rt"] } +libc = "0.2" +fx-kernel = { path = "../fx-kernel" } +fx-loadable = { path = "../fx-loadable" } +fx-llm = { path = "../fx-llm" } + +[dev-dependencies] +tempfile = "3" +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/engine/crates/fx-python/src/installer.rs b/engine/crates/fx-python/src/installer.rs new file mode 100644 index 00000000..9e136304 --- /dev/null +++ b/engine/crates/fx-python/src/installer.rs @@ -0,0 +1,488 @@ +use crate::process::{ + elapsed_millis, format_process_detail, run_command, CapturedProcess, ProcessStatus, + MAX_TIMEOUT_SECONDS, +}; +use crate::venv::{PackageInfo, VenvManager}; +use fx_kernel::cancellation::CancellationToken; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; +use std::process::Stdio; +use std::time::{Duration, Instant}; +use tokio::fs; +use tokio::process::Command; + +const DEFAULT_INSTALL_TIMEOUT_SECONDS: u64 = 600; + +#[derive(Debug, Clone)] +pub struct PythonInstaller { + manager: VenvManager, + experiments_root: PathBuf, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct PythonInstallArgs { + #[serde(default)] + pub packages: Vec, + pub venv: String, + #[serde(default)] + pub requirements_file: Option, + #[serde(default = "default_install_timeout_seconds")] + pub timeout_seconds: u64, +} + +#[derive(Debug, Serialize)] +pub(crate) struct InstallResult { + pub installed: Vec, + pub duration_ms: u64, +} + +struct InstallPlan { + requirements_file: 
Option, + timeout_seconds: u64, +} + +impl PythonInstaller { + #[must_use] + pub fn new(manager: VenvManager, experiments_root: PathBuf) -> Self { + Self { + manager, + experiments_root, + } + } + + pub async fn install( + &self, + args: PythonInstallArgs, + cancel: Option<&CancellationToken>, + ) -> Result { + validate_install_request(&args)?; + self.manager.ensure_venv(&args.venv).await?; + let plan = self.plan_install(&args).await?; + + let started = Instant::now(); + let output = self.run_pip_install(&args, &plan, cancel).await?; + let installed = self + .resolve_installed_packages(&args, &output.stdout) + .await?; + + Ok(InstallResult { + installed, + duration_ms: elapsed_millis(started.elapsed()), + }) + } + + async fn plan_install(&self, args: &PythonInstallArgs) -> Result { + Ok(InstallPlan { + requirements_file: self.resolve_requirements_file(args).await?, + timeout_seconds: clamp_timeout_seconds(args.timeout_seconds), + }) + } + + async fn run_pip_install( + &self, + args: &PythonInstallArgs, + plan: &InstallPlan, + cancel: Option<&CancellationToken>, + ) -> Result { + let mut command = Command::new(self.manager.pip_path(&args.venv)); + configure_install_command(&mut command, args, plan); + command.stdout(Stdio::piped()).stderr(Stdio::piped()); + + let output = run_command( + command, + "pip install", + Duration::from_secs(plan.timeout_seconds), + cancel, + ) + .await?; + require_success(output, plan.timeout_seconds) + } + + async fn resolve_installed_packages( + &self, + args: &PythonInstallArgs, + stdout: &str, + ) -> Result, String> { + let parsed = parse_pip_output(stdout); + if !parsed.is_empty() || args.requirements_file.is_some() { + return Ok(parsed); + } + + let installed = self.manager.info(&args.venv).await?; + Ok(match_requested_packages(&installed, &args.packages)) + } + + async fn resolve_requirements_file( + &self, + args: &PythonInstallArgs, + ) -> Result, String> { + let Some(requirements_file) = &args.requirements_file else { + return 
Ok(None); + }; + + let experiment_dir = self.experiments_root.join(&args.venv); + let base = canonicalize_dir(&experiment_dir).await?; + let candidate = requirements_path(&base, requirements_file); + let path = canonicalize_file(&candidate).await?; + ensure_path_within(&base, &path)?; + Ok(Some(path)) + } +} + +fn validate_install_request(args: &PythonInstallArgs) -> Result<(), String> { + if args.packages.is_empty() && args.requirements_file.is_none() { + return Err("python_install requires 'packages' or 'requirements_file'".to_string()); + } + + Ok(()) +} + +fn configure_install_command(command: &mut Command, args: &PythonInstallArgs, plan: &InstallPlan) { + command.arg("install").arg("--no-cache-dir"); + if let Some(requirements_file) = &plan.requirements_file { + command.arg("-r").arg(requirements_file); + return; + } + + command.args(args.packages.iter().map(String::as_str)); +} + +fn require_success( + output: CapturedProcess, + timeout_seconds: u64, +) -> Result { + match output.status { + ProcessStatus::Exited(0) => Ok(output), + ProcessStatus::Exited(exit_code) => Err(format!( + "pip install failed: {}", + format_process_detail(Some(exit_code), &output.stdout, &output.stderr) + )), + ProcessStatus::TimedOut => Err(timeout_message(&output, timeout_seconds)), + ProcessStatus::Cancelled => Err("pip install cancelled".to_string()), + } +} + +fn timeout_message(output: &CapturedProcess, timeout_seconds: u64) -> String { + let detail = format_process_detail(None, &output.stdout, &output.stderr); + format!("pip install timed out after {timeout_seconds} seconds: {detail}") +} + +async fn canonicalize_dir(path: &Path) -> Result { + fs::canonicalize(path).await.map_err(|error| { + format!( + "experiment dir '{}' is unavailable: {error}", + path.display() + ) + }) +} + +async fn canonicalize_file(path: &Path) -> Result { + let canonical = fs::canonicalize(path).await.map_err(|error| { + format!( + "requirements file '{}' is unavailable: {error}", + path.display() + ) 
+ })?; + let metadata = fs::metadata(&canonical) + .await + .map_err(|error| format!("failed to inspect '{}': {error}", canonical.display()))?; + if metadata.is_file() { + return Ok(canonical); + } + + Err(format!( + "requirements file '{}' must be a file", + canonical.display() + )) +} + +fn requirements_path(base: &Path, requirements_file: &str) -> PathBuf { + let path = Path::new(requirements_file); + if path.is_absolute() { + return path.to_path_buf(); + } + base.join(path) +} + +fn ensure_path_within(base: &Path, path: &Path) -> Result<(), String> { + if path.starts_with(base) { + return Ok(()); + } + + Err(format!( + "requirements file '{}' must stay inside experiment dir '{}'", + path.display(), + base.display() + )) +} + +pub(crate) fn parse_pip_output(stdout: &str) -> Vec { + let mut installed = Vec::new(); + for line in stdout.lines() { + if let Some(packages) = line.trim().strip_prefix("Successfully installed ") { + for token in packages.split_whitespace() { + if let Some(package) = parse_installed_token(token) { + installed.push(package); + } + } + } + } + installed +} + +fn parse_installed_token(token: &str) -> Option { + let cleaned = token.trim_matches(|ch: char| ch == ',' || ch == ';'); + let (name, version) = cleaned.rsplit_once('-')?; + if name.is_empty() || version.is_empty() { + return None; + } + + Some(format!("{name}=={version}")) +} + +fn match_requested_packages(installed: &[PackageInfo], requested: &[String]) -> Vec { + let mut matches = Vec::new(); + for spec in requested { + if let Some(name) = requested_name(spec) { + if let Some(package) = find_package(installed, &name) { + matches.push(format!("{}=={}", package.name, package.version)); + } + } + } + matches +} + +fn find_package<'a>(installed: &'a [PackageInfo], name: &str) -> Option<&'a PackageInfo> { + installed + .iter() + .find(|package| normalize_name(&package.name) == name) +} + +fn requested_name(spec: &str) -> Option { + let name: String = spec + .chars() + .take_while(|ch| 
ch.is_ascii_alphanumeric() || *ch == '-' || *ch == '_') + .collect(); + if name.is_empty() { + return None; + } + + Some(normalize_name(&name)) +} + +fn normalize_name(name: &str) -> String { + name.to_ascii_lowercase().replace('_', "-") +} + +fn clamp_timeout_seconds(timeout_seconds: u64) -> u64 { + timeout_seconds.min(MAX_TIMEOUT_SECONDS) +} + +fn default_install_timeout_seconds() -> u64 { + DEFAULT_INSTALL_TIMEOUT_SECONDS +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::os::unix::fs::PermissionsExt; + use tempfile::TempDir; + + #[test] + fn parse_pip_output_extracts_installed_packages() { + let stdout = "Collecting requests\nSuccessfully installed requests-2.32.3 urllib3-2.2.2 my-package-1.0.0\n"; + + let installed = parse_pip_output(stdout); + + assert_eq!( + installed, + vec![ + "requests==2.32.3".to_string(), + "urllib3==2.2.2".to_string(), + "my-package==1.0.0".to_string(), + ] + ); + } + + #[tokio::test] + async fn install_rejects_requirements_file_outside_experiment_dir() { + let temp_dir = TempDir::new().expect("tempdir"); + let installer = installer_with_fake_pip(&temp_dir, success_pip_script("unused")); + let experiment_dir = temp_dir.path().join("experiments/test"); + fs::create_dir_all(&experiment_dir).expect("experiment dir"); + let outside = temp_dir.path().join("outside.txt"); + fs::write(&outside, "requests==2.32.3\n").expect("outside requirements"); + + let error = installer + .install( + PythonInstallArgs { + packages: Vec::new(), + venv: "test".to_string(), + requirements_file: Some(outside.to_string_lossy().into_owned()), + timeout_seconds: 600, + }, + None, + ) + .await + .expect_err("outside requirements should fail"); + + assert!(error.contains("must stay inside experiment dir")); + } + + #[tokio::test] + async fn install_adds_no_cache_dir_and_sandboxed_requirements_file() { + let temp_dir = TempDir::new().expect("tempdir"); + let args_file = temp_dir.path().join("pip-args.txt"); + let installer = 
installer_with_fake_pip( + &temp_dir, + success_pip_script(args_file.to_string_lossy().as_ref()), + ); + let experiment_dir = temp_dir.path().join("experiments/test"); + fs::create_dir_all(&experiment_dir).expect("experiment dir"); + fs::write( + experiment_dir.join("requirements.txt"), + "requests==2.32.3\n", + ) + .expect("requirements file"); + + installer + .install( + PythonInstallArgs { + packages: Vec::new(), + venv: "test".to_string(), + requirements_file: Some("requirements.txt".to_string()), + timeout_seconds: 600, + }, + None, + ) + .await + .expect("install should succeed"); + + let args = fs::read_to_string(&args_file).expect("pip args"); + let lines: Vec<_> = args.lines().collect(); + let canonical_experiment_dir = + std::fs::canonicalize(&experiment_dir).expect("canonicalize experiment dir"); + assert_eq!(lines[0], "install"); + assert_eq!(lines[1], "--no-cache-dir"); + assert_eq!(lines[2], "-r"); + assert!(lines[3].starts_with(canonical_experiment_dir.to_string_lossy().as_ref())); + } + + #[tokio::test] + async fn install_times_out() { + let temp_dir = TempDir::new().expect("tempdir"); + let installer = installer_with_fake_pip(&temp_dir, sleeping_pip_script()); + + let error = installer + .install( + PythonInstallArgs { + packages: vec!["demo".to_string()], + venv: "test".to_string(), + requirements_file: None, + timeout_seconds: 1, + }, + None, + ) + .await + .expect_err("sleeping pip should time out"); + + assert!(error.contains("timed out")); + } + + #[tokio::test] + async fn cancellation_stops_pip_install() { + let temp_dir = TempDir::new().expect("tempdir"); + let installer = installer_with_fake_pip(&temp_dir, sleeping_pip_script()); + + let token = CancellationToken::new(); + let cancel = token.clone(); + + // Cancel after 200ms — well before the sleeping pip would finish + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(200)).await; + cancel.cancel(); + }); + + let error = installer + .install( + PythonInstallArgs { + 
packages: vec!["demo".to_string()], + venv: "test".to_string(), + requirements_file: None, + timeout_seconds: 600, + }, + Some(&token), + ) + .await + .expect_err("cancelled pip should fail"); + + assert!( + error.contains("cancelled"), + "expected cancelled, got: {error}" + ); + } + + #[tokio::test] + async fn install_adds_no_cache_dir_for_package_installs() { + let temp_dir = TempDir::new().expect("tempdir"); + let args_file = temp_dir.path().join("package-args.txt"); + let installer = installer_with_fake_pip( + &temp_dir, + success_pip_script(args_file.to_string_lossy().as_ref()), + ); + + installer + .install( + PythonInstallArgs { + packages: vec!["demo".to_string()], + venv: "test".to_string(), + requirements_file: None, + timeout_seconds: 600, + }, + None, + ) + .await + .expect("install should succeed"); + + let args = fs::read_to_string(&args_file).expect("pip args"); + let lines: Vec<_> = args.lines().collect(); + assert_eq!(lines[0], "install"); + assert_eq!(lines[1], "--no-cache-dir"); + assert_eq!(lines[2], "demo"); + } + + #[test] + fn timeout_seconds_are_clamped() { + assert_eq!( + clamp_timeout_seconds(MAX_TIMEOUT_SECONDS + 1), + MAX_TIMEOUT_SECONDS + ); + } + + fn installer_with_fake_pip(temp_dir: &TempDir, pip_script: String) -> PythonInstaller { + let venv_root = temp_dir.path().join("venvs"); + let manager = VenvManager::new(&venv_root); + let bin_dir = manager.venv_path("test").join("bin"); + fs::create_dir_all(&bin_dir).expect("bin dir"); + fs::write(bin_dir.join("python"), "#!/usr/bin/env sh\nexit 0\n").expect("python shim"); + write_executable(&bin_dir.join("pip"), &pip_script); + PythonInstaller::new(manager, temp_dir.path().join("experiments")) + } + + fn write_executable(path: &Path, content: &str) { + fs::write(path, content).expect("write executable"); + let permissions = fs::Permissions::from_mode(0o755); + fs::set_permissions(path, permissions).expect("chmod executable"); + } + + fn success_pip_script(args_file: &str) -> String { + 
format!( + "#!/usr/bin/env sh\nprintf '%s\\n' \"$@\" > '{args_file}'\nprintf 'Successfully installed demo-1.0.0\\n'\n" + ) + } + + fn sleeping_pip_script() -> String { + "#!/usr/bin/env sh\nsleep 2\nprintf 'Successfully installed demo-1.0.0\\n'\n".to_string() + } +} diff --git a/engine/crates/fx-python/src/lib.rs b/engine/crates/fx-python/src/lib.rs new file mode 100644 index 00000000..855c2b7e --- /dev/null +++ b/engine/crates/fx-python/src/lib.rs @@ -0,0 +1,306 @@ +mod installer; +mod process; +mod runner; +mod venv; + +use async_trait::async_trait; +use fx_kernel::act::ToolCacheability; +use fx_kernel::cancellation::CancellationToken; +use fx_llm::ToolDefinition; +use fx_loadable::skill::{Skill, SkillError}; +use installer::{PythonInstallArgs, PythonInstaller}; +use runner::{PythonRunArgs, PythonRunner}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use std::path::Path; +use venv::{PackageInfo, VenvManager}; + +#[derive(Debug, Deserialize)] +struct PythonVenvsArgs { + action: String, + #[serde(default)] + name: Option, +} + +#[derive(Debug, Serialize)] +struct VenvListResponse { + venvs: Vec, +} + +#[derive(Debug, Serialize)] +struct VenvDeleteResponse { + deleted: String, +} + +#[derive(Debug, Serialize)] +struct VenvInfoResponse { + name: String, + packages: Vec, +} + +#[derive(Debug, Clone)] +pub struct PythonSkill { + venv_manager: VenvManager, + runner: PythonRunner, + installer: PythonInstaller, +} + +impl PythonSkill { + #[must_use] + pub fn new(data_dir: &Path) -> Self { + let venv_root = data_dir.join("venvs"); + let experiments_root = data_dir.join("experiments"); + let venv_manager = VenvManager::new(&venv_root); + let runner = PythonRunner::new(venv_manager.clone(), experiments_root.clone()); + let installer = PythonInstaller::new(venv_manager.clone(), experiments_root); + + Self { + venv_manager, + runner, + installer, + } + } + + async fn handle_run( + &self, + arguments: &str, + cancel: Option<&CancellationToken>, + ) -> 
Result { + let args: PythonRunArgs = parse_arguments(arguments)?; + let result = self.runner.run(args, cancel).await?; + serialize_response(&result) + } + + async fn handle_install( + &self, + arguments: &str, + cancel: Option<&CancellationToken>, + ) -> Result { + let args: PythonInstallArgs = parse_arguments(arguments)?; + let result = self.installer.install(args, cancel).await?; + serialize_response(&result) + } + + async fn handle_venvs(&self, arguments: &str) -> Result { + let args: PythonVenvsArgs = parse_arguments(arguments)?; + match args.action.as_str() { + "list" => self.handle_list_venvs().await, + "delete" => self.handle_delete_venv(&args).await, + "info" => self.handle_info_venv(&args).await, + _ => Err(format!("unknown python_venvs action: {}", args.action)), + } + } + + async fn handle_list_venvs(&self) -> Result { + let response = VenvListResponse { + venvs: self.venv_manager.list_venvs().await?, + }; + serialize_response(&response) + } + + async fn handle_delete_venv(&self, args: &PythonVenvsArgs) -> Result { + let name = required_venv_name(args)?; + self.venv_manager.delete_venv(name).await?; + serialize_response(&VenvDeleteResponse { + deleted: name.to_string(), + }) + } + + async fn handle_info_venv(&self, args: &PythonVenvsArgs) -> Result { + let name = required_venv_name(args)?; + let packages = self.venv_manager.info(name).await?; + serialize_response(&VenvInfoResponse { + name: name.to_string(), + packages, + }) + } +} + +#[async_trait] +impl Skill for PythonSkill { + fn name(&self) -> &str { + "python" + } + + fn description(&self) -> &str { + "Execute Python code and manage Python virtual environments." 
+ } + + fn tool_definitions(&self) -> Vec { + vec![ + python_run_definition(), + python_install_definition(), + python_venvs_definition(), + ] + } + + fn cacheability(&self, _tool_name: &str) -> ToolCacheability { + ToolCacheability::SideEffect + } + + async fn execute( + &self, + tool_name: &str, + arguments: &str, + cancel: Option<&CancellationToken>, + ) -> Option> { + match tool_name { + "python_run" => Some(self.handle_run(arguments, cancel).await), + "python_install" => Some(self.handle_install(arguments, cancel).await), + "python_venvs" => Some(self.handle_venvs(arguments).await), + _ => None, + } + } +} + +fn python_run_definition() -> ToolDefinition { + ToolDefinition { + name: "python_run".to_string(), + description: "Run Python code inside a named virtual environment and report output, exit code, and generated artifacts.".to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python source code to execute" + }, + "venv": { + "type": "string", + "description": "Virtual environment name" + }, + "timeout_seconds": { + "type": "integer", + "description": "Execution timeout in seconds", + "default": 300 + } + }, + "required": ["code", "venv"] + }), + } +} + +fn python_install_definition() -> ToolDefinition { + ToolDefinition { + name: "python_install".to_string(), + description: "Install Python packages into a named virtual environment using pip." 
+ .to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "packages": { + "type": "array", + "items": { "type": "string" }, + "description": "Packages to install" + }, + "venv": { + "type": "string", + "description": "Virtual environment name" + }, + "requirements_file": { + "type": ["string", "null"], + "description": "Optional requirements file path inside the experiment directory" + }, + "timeout_seconds": { + "type": "integer", + "description": "Pip install timeout in seconds", + "default": 600 + } + }, + "required": ["venv"] + }), + } +} + +fn python_venvs_definition() -> ToolDefinition { + ToolDefinition { + name: "python_venvs".to_string(), + description: "List, inspect, or delete managed Python virtual environments.".to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["list", "delete", "info"], + "description": "Virtual environment action to perform" + }, + "name": { + "type": "string", + "description": "Virtual environment name for delete or info" + } + }, + "required": ["action"] + }), + } +} + +fn required_venv_name(args: &PythonVenvsArgs) -> Result<&str, SkillError> { + args.name + .as_deref() + .ok_or_else(|| "python_venvs action requires 'name'".to_string()) +} + +fn parse_arguments(arguments: &str) -> Result +where + T: DeserializeOwned, +{ + serde_json::from_str(arguments).map_err(|error| format!("invalid arguments: {error}")) +} + +fn serialize_response(value: &T) -> Result +where + T: Serialize, +{ + serde_json::to_string(value).map_err(|error| format!("serialize response: {error}")) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn skill_in_tempdir(temp_dir: &TempDir) -> PythonSkill { + PythonSkill::new(temp_dir.path()) + } + + #[test] + fn tool_definitions_count() { + let temp_dir = TempDir::new().expect("tempdir"); + let skill = skill_in_tempdir(&temp_dir); + + assert_eq!(skill.tool_definitions().len(), 3); 
+ } + + #[tokio::test] + async fn unknown_tool_returns_none() { + let temp_dir = TempDir::new().expect("tempdir"); + let skill = skill_in_tempdir(&temp_dir); + + let result = skill.execute("unknown_tool", "{}", None).await; + + assert!(result.is_none()); + } + + #[tokio::test] + async fn python_run_uses_cancellation_token() { + let temp_dir = TempDir::new().expect("tempdir"); + let skill = skill_in_tempdir(&temp_dir); + let token = CancellationToken::new(); + let cancel = token.clone(); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + cancel.cancel(); + }); + + let result = skill + .execute( + "python_run", + r#"{"code":"import time\ntime.sleep(5)\n","venv":"test"}"#, + Some(&token), + ) + .await + .expect("known tool") + .expect_err("run should be cancelled"); + + assert!(result.contains("cancelled")); + } +} diff --git a/engine/crates/fx-python/src/process.rs b/engine/crates/fx-python/src/process.rs new file mode 100644 index 00000000..a51053aa --- /dev/null +++ b/engine/crates/fx-python/src/process.rs @@ -0,0 +1,233 @@ +use fx_kernel::cancellation::CancellationToken; +use std::process::Output; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::process::{Child, Command}; +use tokio::task::JoinHandle; + +pub(crate) const MAX_TIMEOUT_SECONDS: u64 = 3_600; +const MAX_CAPTURE_BYTES: u64 = 512 * 1024; + +pub(crate) struct CapturedProcess { + pub stdout: String, + pub stderr: String, + pub status: ProcessStatus, +} + +pub(crate) enum ProcessStatus { + Exited(i32), + TimedOut, + Cancelled, +} + +struct CaptureTasks { + stdout: JoinHandle>, + stderr: JoinHandle>, +} + +enum StopReason { + Timeout, + Cancellation, +} + +impl StopReason { + fn as_str(&self) -> &'static str { + match self { + Self::Timeout => "timeout", + Self::Cancellation => "cancellation", + } + } + + fn status(&self) -> ProcessStatus { + match self { + Self::Timeout => ProcessStatus::TimedOut, + Self::Cancellation => 
ProcessStatus::Cancelled, + } + } +} + +pub(crate) async fn run_command( + mut command: Command, + action: &str, + timeout: Duration, + cancel: Option<&CancellationToken>, +) -> Result { + configure_process_group(&mut command); + let mut child = command + .spawn() + .map_err(|error| format!("failed to start {action}: {error}"))?; + let captures = spawn_capture_tasks(&mut child)?; + let status = wait_for_child(&mut child, action, timeout, cancel).await?; + let stdout = join_capture(captures.stdout, action, "stdout").await?; + let stderr = join_capture(captures.stderr, action, "stderr").await?; + + Ok(CapturedProcess { + stdout, + stderr, + status, + }) +} + +pub(crate) fn elapsed_millis(duration: Duration) -> u64 { + duration.as_millis().min(u128::from(u64::MAX)) as u64 +} + +pub(crate) fn format_process_output(output: &Output) -> String { + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + format_process_detail(output.status.code(), &stdout, &stderr) +} + +pub(crate) fn format_process_detail( + status_code: Option, + stdout: &str, + stderr: &str, +) -> String { + let detail = stderr.trim(); + let detail = if detail.is_empty() { + stdout.trim() + } else { + detail + }; + format!("status {status_code:?}; {detail}") +} + +async fn wait_for_child( + child: &mut Child, + action: &str, + timeout: Duration, + cancel: Option<&CancellationToken>, +) -> Result { + if let Some(token) = cancel { + tokio::select! 
{ + result = tokio::time::timeout(timeout, child.wait()) => { + handle_wait_result(child, action, result).await + } + _ = token.cancelled() => stop_child(child, action, StopReason::Cancellation).await, + } + } else { + let result = tokio::time::timeout(timeout, child.wait()).await; + handle_wait_result(child, action, result).await + } +} + +async fn handle_wait_result( + child: &mut Child, + action: &str, + result: Result, tokio::time::error::Elapsed>, +) -> Result { + match result { + Ok(wait) => { + let status = wait.map_err(|error| format!("failed to wait for {action}: {error}"))?; + Ok(ProcessStatus::Exited(status.code().unwrap_or(-1))) + } + Err(_) => stop_child(child, action, StopReason::Timeout).await, + } +} + +async fn stop_child( + child: &mut Child, + action: &str, + reason: StopReason, +) -> Result { + signal_child(child, action, &reason).await?; + child + .wait() + .await + .map_err(|error| format!("failed to wait after stopping {action}: {error}"))?; + Ok(reason.status()) +} + +#[cfg(unix)] +async fn signal_child(child: &mut Child, action: &str, reason: &StopReason) -> Result<(), String> { + if let Some(pid) = child.id() { + // SAFETY: killpg targets the child process group created in configure_process_group. 
+ let result = unsafe { libc::killpg(pid as i32, libc::SIGKILL) }; + if result == 0 || std::io::Error::last_os_error().raw_os_error() == Some(libc::ESRCH) { + return Ok(()); + } + let error = std::io::Error::last_os_error(); + return Err(format!( + "failed to kill {action} process group after {}: {error}", + reason.as_str() + )); + } + + child + .kill() + .await + .map_err(|error| format!("failed to kill {action} after {}: {error}", reason.as_str())) +} + +#[cfg(not(unix))] +async fn signal_child(child: &mut Child, action: &str, reason: &StopReason) -> Result<(), String> { + child + .kill() + .await + .map_err(|error| format!("failed to kill {action} after {}: {error}", reason.as_str())) +} + +#[cfg(unix)] +fn configure_process_group(command: &mut Command) { + // SAFETY: pre_exec runs in the child just before exec; the closure only + // calls async-signal-safe libc::setpgid to move that child into its own + // process group so timeout/cancel can terminate the full tree. + unsafe { + command.pre_exec(|| { + // SAFETY: setpgid(0, 0) only affects the current child process. 
+ if libc::setpgid(0, 0) != 0 { + return Err(std::io::Error::last_os_error()); + } + Ok(()) + }); + } +} + +#[cfg(not(unix))] +fn configure_process_group(_command: &mut Command) {} + +fn spawn_capture_tasks(child: &mut Child) -> Result { + let stdout = child + .stdout + .take() + .ok_or_else(|| "stdout pipe unavailable".to_string())?; + let stderr = child + .stderr + .take() + .ok_or_else(|| "stderr pipe unavailable".to_string())?; + + Ok(CaptureTasks { + stdout: spawn_capture_task(stdout), + stderr: spawn_capture_task(stderr), + }) +} + +fn spawn_capture_task(stream: R) -> JoinHandle> +where + R: AsyncRead + Unpin + Send + 'static, +{ + tokio::spawn(async move { capture_stream(stream).await }) +} + +async fn capture_stream(stream: R) -> Result +where + R: AsyncRead + Unpin, +{ + let mut reader = stream.take(MAX_CAPTURE_BYTES); + let mut bytes = Vec::new(); + reader + .read_to_end(&mut bytes) + .await + .map_err(|error| format!("failed to read process output: {error}"))?; + Ok(String::from_utf8_lossy(&bytes).to_string()) +} + +async fn join_capture( + handle: JoinHandle>, + action: &str, + stream_name: &str, +) -> Result { + handle + .await + .map_err(|error| format!("failed to join {action} {stream_name} capture: {error}"))? 
+} diff --git a/engine/crates/fx-python/src/runner.rs b/engine/crates/fx-python/src/runner.rs new file mode 100644 index 00000000..b611ca27 --- /dev/null +++ b/engine/crates/fx-python/src/runner.rs @@ -0,0 +1,421 @@ +use crate::process::{ + elapsed_millis, run_command, CapturedProcess, ProcessStatus, MAX_TIMEOUT_SECONDS, +}; +use crate::venv::VenvManager; +use fx_kernel::cancellation::CancellationToken; +use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; +use std::path::{Path, PathBuf}; +use std::process::Stdio; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +use tokio::fs; +use tokio::process::Command; + +const TIMEOUT_EXIT_CODE: i32 = -1; +const DEFAULT_TIMEOUT_SECONDS: u64 = 300; +static SCRIPT_FILE_COUNTER: AtomicU64 = AtomicU64::new(0); + +type FileSnapshot = BTreeMap; + +#[derive(Debug, Clone)] +pub struct PythonRunner { + venv_manager: VenvManager, + experiments_root: PathBuf, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct PythonRunArgs { + pub code: String, + pub venv: String, + #[serde(default = "default_timeout_seconds")] + pub timeout_seconds: u64, +} + +#[derive(Debug, Serialize)] +pub(crate) struct RunResult { + pub stdout: String, + pub stderr: String, + pub exit_code: i32, + pub artifacts: Vec, + pub duration_ms: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct FileState { + modified: SystemTime, + size: u64, +} + +struct ExecutionRequest { + python_path: PathBuf, + script_path: PathBuf, + work_dir: PathBuf, + timeout_seconds: u64, +} + +struct CommandOutput { + stdout: String, + stderr: String, + exit_code: i32, + timed_out: bool, +} + +impl PythonRunner { + #[must_use] + pub fn new(venv_manager: VenvManager, experiments_root: PathBuf) -> Self { + Self { + venv_manager, + experiments_root, + } + } + + pub async fn run( + &self, + args: PythonRunArgs, + cancel: Option<&CancellationToken>, + ) -> Result { + self.venv_manager.ensure_venv(&args.venv).await?; 
+ let work_dir = self.experiments_root.join(&args.venv); + fs::create_dir_all(&work_dir).await.map_err(|error| { + format!( + "failed to create experiment dir '{}': {error}", + work_dir.display() + ) + })?; + + let timeout_seconds = clamp_timeout_seconds(args.timeout_seconds); + let script_path = work_dir.join(script_file_name()); + fs::write(&script_path, args.code) + .await + .map_err(|error| format!("failed to write '{}': {error}", script_path.display()))?; + + let before = snapshot_files(&work_dir).await?; + let started = Instant::now(); + let output = execute_script( + ExecutionRequest { + python_path: self.venv_manager.python_path(&args.venv), + script_path: script_path.clone(), + work_dir: work_dir.clone(), + timeout_seconds, + }, + cancel, + ) + .await?; + cleanup_script_after_success(&script_path, output.exit_code).await?; + let after = snapshot_files(&work_dir).await?; + + Ok(RunResult { + stdout: output.stdout, + stderr: finalize_stderr(output.stderr, timeout_seconds, output.timed_out), + exit_code: output.exit_code, + artifacts: detect_artifacts(&before, &after), + duration_ms: elapsed_millis(started.elapsed()), + }) + } +} + +async fn execute_script( + request: ExecutionRequest, + cancel: Option<&CancellationToken>, +) -> Result { + let mut command = Command::new(&request.python_path); + command + .arg(&request.script_path) + .current_dir(&request.work_dir) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + let output = run_command( + command, + "python", + Duration::from_secs(request.timeout_seconds), + cancel, + ) + .await?; + to_command_output(output) +} + +fn to_command_output(output: CapturedProcess) -> Result { + match output.status { + ProcessStatus::Exited(exit_code) => Ok(CommandOutput { + stdout: output.stdout, + stderr: output.stderr, + exit_code, + timed_out: false, + }), + ProcessStatus::TimedOut => Ok(CommandOutput { + stdout: output.stdout, + stderr: output.stderr, + exit_code: TIMEOUT_EXIT_CODE, + timed_out: true, + }), + 
ProcessStatus::Cancelled => Err("python execution cancelled".to_string()), + } +} + +fn finalize_stderr(stderr: String, timeout_seconds: u64, timed_out: bool) -> String { + if !timed_out { + return stderr; + } + + let suffix = format!("Execution timed out after {timeout_seconds} seconds"); + if stderr.trim().is_empty() { + return suffix; + } + + format!("{stderr}\n{suffix}") +} + +async fn snapshot_files(dir: &Path) -> Result { + let exists = fs::try_exists(dir) + .await + .map_err(|error| format!("failed to inspect '{}': {error}", dir.display()))?; + if !exists { + return Ok(BTreeMap::new()); + } + + let mut snapshot = BTreeMap::new(); + collect_snapshot(dir, &mut snapshot).await?; + Ok(snapshot) +} + +async fn collect_snapshot(base: &Path, snapshot: &mut FileSnapshot) -> Result<(), String> { + let mut pending = vec![base.to_path_buf()]; + while let Some(dir) = pending.pop() { + let mut entries = fs::read_dir(&dir) + .await + .map_err(|error| format!("failed to read '{}': {error}", dir.display()))?; + while let Some(entry) = entries + .next_entry() + .await + .map_err(|error| format!("failed to inspect directory entry: {error}"))? 
+ { + let path = entry.path(); + let metadata = entry.metadata().await.map_err(|error| { + format!("failed to read metadata for '{}': {error}", path.display()) + })?; + if metadata.is_dir() { + pending.push(path); + continue; + } + if metadata.is_file() { + insert_snapshot_entry(base, &path, metadata, snapshot)?; + } + } + } + Ok(()) +} + +fn insert_snapshot_entry( + base: &Path, + path: &Path, + metadata: std::fs::Metadata, + snapshot: &mut FileSnapshot, +) -> Result<(), String> { + let relative = path + .strip_prefix(base) + .map_err(|error| format!("failed to strip base path: {error}"))?; + snapshot.insert( + relative.to_path_buf(), + FileState { + modified: metadata.modified().map_err(|error| { + format!("failed to read mtime for '{}': {error}", path.display()) + })?, + size: metadata.len(), + }, + ); + Ok(()) +} + +fn detect_artifacts(before: &FileSnapshot, after: &FileSnapshot) -> Vec { + after + .iter() + .filter_map(|(path, state)| match before.get(path) { + Some(previous) if previous == state => None, + _ => Some(path.to_string_lossy().into_owned()), + }) + .collect() +} + +async fn cleanup_script_after_success(script_path: &Path, exit_code: i32) -> Result<(), String> { + if exit_code != 0 { + return Ok(()); + } + + match fs::remove_file(script_path).await { + Ok(()) => Ok(()), + Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(()), + Err(error) => Err(format!( + "failed to clean up script '{}': {error}", + script_path.display() + )), + } +} + +fn script_file_name() -> String { + let timestamp = unix_timestamp_millis(); + let counter = SCRIPT_FILE_COUNTER.fetch_add(1, Ordering::Relaxed); + format!("run_{timestamp}_{counter}.py") +} + +fn unix_timestamp_millis() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_millis()) + .unwrap_or_default() +} + +fn clamp_timeout_seconds(timeout_seconds: u64) -> u64 { + timeout_seconds.min(MAX_TIMEOUT_SECONDS) +} + +fn default_timeout_seconds() -> u64 { + 
DEFAULT_TIMEOUT_SECONDS +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::BTreeSet; + use tempfile::TempDir; + + fn runner_in_tempdir(temp_dir: &TempDir) -> PythonRunner { + let venv_root = temp_dir.path().join("venvs"); + let experiments_root = temp_dir.path().join("experiments"); + let manager = VenvManager::new(&venv_root); + PythonRunner::new(manager, experiments_root) + } + + fn run_args(code: &str, timeout_seconds: u64) -> PythonRunArgs { + PythonRunArgs { + code: code.to_string(), + venv: "test".to_string(), + timeout_seconds, + } + } + + #[tokio::test] + async fn run_simple_code() { + let temp_dir = TempDir::new().expect("tempdir"); + let runner = runner_in_tempdir(&temp_dir); + + let result = runner + .run(run_args("print(1 + 1)\n", 300), None) + .await + .expect("python ran"); + + assert_eq!(result.stdout.trim(), "2"); + assert_eq!(result.exit_code, 0); + } + + #[tokio::test] + async fn timeout_kills() { + let temp_dir = TempDir::new().expect("tempdir"); + let runner = runner_in_tempdir(&temp_dir); + let code = "import time\ntime.sleep(2)\n"; + + let result = runner + .run(run_args(code, 1), None) + .await + .expect("python ran"); + + assert_eq!(result.exit_code, TIMEOUT_EXIT_CODE); + assert!(result.stderr.contains("timed out")); + } + + #[tokio::test] + async fn artifact_detection() { + let temp_dir = TempDir::new().expect("tempdir"); + let runner = runner_in_tempdir(&temp_dir); + let code = "from pathlib import Path\nPath('artifact.txt').write_text('hello')\n"; + + let result = runner + .run(run_args(code, 300), None) + .await + .expect("python ran"); + + assert!(result.artifacts.contains(&"artifact.txt".to_string())); + } + + #[tokio::test] + async fn cancellation_stops_python_run() { + let temp_dir = TempDir::new().expect("tempdir"); + let runner = runner_in_tempdir(&temp_dir); + let token = CancellationToken::new(); + let cancel = token.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(100)).await; 
+ cancel.cancel(); + }); + + let error = runner + .run(run_args("import time\ntime.sleep(5)\n", 300), Some(&token)) + .await + .expect_err("run should be cancelled"); + + assert!(error.contains("cancelled")); + } + + #[tokio::test] + async fn successful_run_cleans_up_script_file() { + let temp_dir = TempDir::new().expect("tempdir"); + let runner = runner_in_tempdir(&temp_dir); + + runner + .run(run_args("print('ok')\n", 300), None) + .await + .expect("python ran"); + + let experiment_dir = temp_dir.path().join("experiments/test"); + let mut entries = fs::read_dir(&experiment_dir) + .await + .expect("read experiment dir"); + while let Some(entry) = entries.next_entry().await.expect("read entry") { + let file_name = entry.file_name().to_string_lossy().into_owned(); + assert!( + !file_name.starts_with("run_"), + "unexpected script file: {file_name}" + ); + } + } + + #[cfg(unix)] + #[tokio::test] + async fn timeout_kills_process_group_children() { + let temp_dir = TempDir::new().expect("tempdir"); + let runner = runner_in_tempdir(&temp_dir); + let code = concat!( + "import subprocess, sys, time\n", + "subprocess.Popen([\n", + " sys.executable,\n", + " '-c',\n", + " \"import time; from pathlib import Path; time.sleep(2); Path('orphan.txt').write_text('still-running')\",\n", + "])\n", + "time.sleep(10)\n" + ); + + let result = runner + .run(run_args(code, 1), None) + .await + .expect("python ran"); + + assert_eq!(result.exit_code, TIMEOUT_EXIT_CODE); + tokio::time::sleep(Duration::from_secs(3)).await; + assert!(!temp_dir.path().join("experiments/test/orphan.txt").exists()); + } + + #[test] + fn timeout_seconds_are_clamped() { + assert_eq!( + clamp_timeout_seconds(MAX_TIMEOUT_SECONDS + 99), + MAX_TIMEOUT_SECONDS + ); + } + + #[test] + fn script_file_names_are_unique() { + let names: BTreeSet<_> = (0..64).map(|_| script_file_name()).collect(); + assert_eq!(names.len(), 64); + } +} diff --git a/engine/crates/fx-python/src/venv.rs b/engine/crates/fx-python/src/venv.rs new 
file mode 100644 index 00000000..8c3cf13c --- /dev/null +++ b/engine/crates/fx-python/src/venv.rs @@ -0,0 +1,327 @@ +use crate::process::format_process_output; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; +use std::process::Output; +use tokio::fs; +use tokio::process::Command; + +#[derive(Debug, Clone)] +pub struct VenvManager { + root: PathBuf, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct PackageInfo { + pub name: String, + pub version: String, +} + +impl VenvManager { + #[must_use] + pub fn new(root: &Path) -> Self { + Self { + root: root.to_path_buf(), + } + } + + #[must_use] + pub fn venv_path(&self, name: &str) -> PathBuf { + self.root.join(name) + } + + #[must_use] + pub fn python_path(&self, name: &str) -> PathBuf { + self.venv_path(name).join("bin").join("python") + } + + #[must_use] + pub fn pip_path(&self, name: &str) -> PathBuf { + self.venv_path(name).join("bin").join("pip") + } + + pub async fn ensure_venv(&self, name: &str) -> Result { + validate_venv_name(name)?; + self.ensure_root().await?; + + let venv_path = self.venv_path(name); + if path_exists(&self.python_path(name)).await? { + self.ensure_pip_entrypoint(name).await?; + return Ok(venv_path); + } + + let python = detect_python_binary().await?; + create_venv(&python, &venv_path, name).await?; + self.ensure_pip_entrypoint(name).await?; + Ok(venv_path) + } + + pub async fn list_venvs(&self) -> Result, String> { + if !path_exists(&self.root).await? { + return Ok(Vec::new()); + } + + let mut entries = fs::read_dir(&self.root) + .await + .map_err(|error| format!("failed to read '{}': {error}", self.root.display()))?; + let mut names = Vec::new(); + + while let Some(entry) = entries + .next_entry() + .await + .map_err(|error| format!("failed to read venv entry: {error}"))? + { + if entry + .file_type() + .await + .map_err(|error| format!("failed to inspect venv entry: {error}"))? 
+ .is_dir() + { + names.push(entry.file_name().to_string_lossy().into_owned()); + } + } + + names.sort(); + Ok(names) + } + + pub async fn delete_venv(&self, name: &str) -> Result<(), String> { + validate_venv_name(name)?; + let path = self.venv_path(name); + if !path_exists(&path).await? { + return Ok(()); + } + + fs::remove_dir_all(&path) + .await + .map_err(|error| format!("failed to delete venv '{name}': {error}")) + } + + pub async fn info(&self, name: &str) -> Result, String> { + validate_venv_name(name)?; + ensure_existing_venv(self, name).await?; + + let output = Command::new(self.pip_path(name)) + .args(["list", "--format=json"]) + .output() + .await + .map_err(|error| format!("failed to inspect venv '{name}': {error}"))?; + let output = require_success(output, &format!("inspect venv '{name}'"))?; + parse_package_list(&output.stdout) + } + + async fn ensure_root(&self) -> Result<(), String> { + fs::create_dir_all(&self.root).await.map_err(|error| { + format!( + "failed to create venv root '{}': {error}", + self.root.display() + ) + }) + } + + async fn ensure_pip_entrypoint(&self, name: &str) -> Result<(), String> { + let pip_path = self.pip_path(name); + if path_exists(&pip_path).await? { + return Ok(()); + } + + write_pip_shim(&self.python_path(name), &pip_path).await + } +} + +async fn ensure_existing_venv(manager: &VenvManager, name: &str) -> Result<(), String> { + if path_exists(&manager.python_path(name)).await? { + return Ok(()); + } + + Err(format!("venv '{name}' does not exist")) +} + +fn validate_venv_name(name: &str) -> Result<(), String> { + let valid = !name.is_empty() + && name + .chars() + .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_'); + if valid { + return Ok(()); + } + + Err("venv names must use only letters, numbers, '-' or '_'".to_string()) +} + +async fn detect_python_binary() -> Result { + for candidate in ["python3", "python"] { + if supports_python_three(candidate).await? 
{ + return Ok(candidate.to_string()); + } + } + + Err("python 3 interpreter not found; tried python3 and python".to_string()) +} + +async fn supports_python_three(candidate: &str) -> Result { + match Command::new(candidate).arg("--version").output().await { + Ok(output) => Ok(matches!(parse_python_major_version(&output), Some(major) if major >= 3)), + Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(false), + Err(error) => Err(format!("failed to inspect {candidate}: {error}")), + } +} + +fn parse_python_major_version(output: &Output) -> Option { + let version = format!( + "{}{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + let line = version.lines().next()?.trim(); + let version = line.strip_prefix("Python ")?; + version.split('.').next()?.parse().ok() +} + +async fn create_venv(python: &str, venv_path: &Path, name: &str) -> Result<(), String> { + let output = run_venv_command(python, venv_path, false).await?; + if output.status.success() { + return Ok(()); + } + if ensurepip_missing(&output) { + let fallback = run_venv_command(python, venv_path, true).await?; + let _ = require_success(fallback, &format!("create venv '{name}'"))?; + return Ok(()); + } + + Err(format!( + "create venv '{name}' failed: {}", + format_process_output(&output) + )) +} + +async fn run_venv_command( + python: &str, + venv_path: &Path, + without_pip: bool, +) -> Result { + let mut command = Command::new(python); + command.args(["-m", "venv"]); + if without_pip { + command.arg("--without-pip"); + } + command.arg(venv_path); + command + .output() + .await + .map_err(|error| format!("failed to create venv '{}': {error}", venv_path.display())) +} + +fn ensurepip_missing(output: &Output) -> bool { + let stderr = String::from_utf8_lossy(&output.stderr).to_ascii_lowercase(); + let stdout = String::from_utf8_lossy(&output.stdout).to_ascii_lowercase(); + stderr.contains("ensurepip") || stdout.contains("ensurepip") +} + +async fn 
write_pip_shim(python_path: &Path, pip_path: &Path) -> Result<(), String> { + let content = format!( + "#!/usr/bin/env sh\n\"{}\" -m pip \"$@\"\n", + python_path.display() + ); + fs::write(pip_path, content) + .await + .map_err(|error| format!("failed to write pip shim '{}': {error}", pip_path.display()))?; + set_executable_permissions(pip_path).await +} + +#[cfg(unix)] +async fn set_executable_permissions(path: &Path) -> Result<(), String> { + use std::os::unix::fs::PermissionsExt; + + fs::set_permissions(path, std::fs::Permissions::from_mode(0o755)) + .await + .map_err(|error| format!("failed to chmod '{}': {error}", path.display())) +} + +#[cfg(not(unix))] +async fn set_executable_permissions(_path: &Path) -> Result<(), String> { + Ok(()) +} + +fn require_success(output: Output, action: &str) -> Result { + if output.status.success() { + return Ok(output); + } + + Err(format!( + "{action} failed: {}", + format_process_output(&output) + )) +} + +fn parse_package_list(stdout: &[u8]) -> Result, String> { + serde_json::from_slice(stdout) + .map_err(|error| format!("failed to parse pip list output: {error}")) +} + +async fn path_exists(path: &Path) -> Result { + fs::try_exists(path) + .await + .map_err(|error| format!("failed to inspect '{}': {error}", path.display())) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn create_venv() { + let temp_dir = TempDir::new().expect("tempdir"); + let manager = VenvManager::new(temp_dir.path()); + + let path = manager.ensure_venv("alpha").await.expect("venv created"); + + assert!(path.exists()); + assert!(manager.python_path("alpha").exists()); + } + + #[tokio::test] + async fn list_venvs() { + let temp_dir = TempDir::new().expect("tempdir"); + let manager = VenvManager::new(temp_dir.path()); + manager.ensure_venv("alpha").await.expect("alpha created"); + manager.ensure_venv("beta").await.expect("beta created"); + + let venvs = manager.list_venvs().await.expect("venvs listed"); 
+
+        assert_eq!(venvs, vec!["alpha".to_string(), "beta".to_string()]);
+    }
+
+    #[tokio::test]
+    async fn delete_venv() {
+        let temp_dir = TempDir::new().expect("tempdir");
+        let manager = VenvManager::new(temp_dir.path());
+        let path = manager.ensure_venv("alpha").await.expect("venv created");
+
+        manager.delete_venv("alpha").await.expect("venv deleted");
+
+        assert!(!path.exists());
+    }
+
+    #[cfg(unix)]
+    #[test]
+    fn parse_python_major_version_accepts_python_three() {
+        let output = Output {
+            status: <std::process::ExitStatus as std::os::unix::process::ExitStatusExt>::from_raw(0),
+            stdout: b"Python 3.12.2\n".to_vec(),
+            stderr: Vec::new(),
+        };
+        assert_eq!(parse_python_major_version(&output), Some(3));
+    }
+
+    #[cfg(unix)]
+    #[test]
+    fn parse_python_major_version_rejects_python_two() {
+        let output = Output {
+            status: <std::process::ExitStatus as std::os::unix::process::ExitStatusExt>::from_raw(0),
+            stdout: Vec::new(),
+            stderr: b"Python 2.7.18\n".to_vec(),
+        };
+        assert_eq!(parse_python_major_version(&output), Some(2));
+    }
+}
diff --git a/engine/crates/fx-ripcord/src/git_guard.rs b/engine/crates/fx-ripcord/src/git_guard.rs
new file mode 100644
index 00000000..3d4086b2
--- /dev/null
+++ b/engine/crates/fx-ripcord/src/git_guard.rs
@@ -0,0 +1,338 @@
+const ALL_BRANCHES_TARGET: &str = "*";
+
+/// Check whether any push targets are protected branches.
+/// Returns Err with a user-facing message listing blocked branches.
+#[must_use = "the push guard result must be checked before running git push"]
+pub fn check_push_allowed(targets: &[String], protected_branches: &[String]) -> Result<(), String> {
+    let blocked = blocked_branches(targets, protected_branches);
+    if blocked.is_empty() {
+        return Ok(());
+    }
+    Err(format_blocked_push(&blocked))
+}
+
+/// Extract target branches from a shell command string.
+/// Returns empty vec if the command is not a git push or targets can't be determined.
+pub fn extract_push_targets(command: &str) -> Vec { + let tokens: Vec<&str> = command.split_whitespace().collect(); + match push_refspecs(&tokens) { + Some(PushTargets::Refspecs(refspecs)) => { + refspecs.into_iter().filter_map(normalize_target).collect() + } + Some(PushTargets::AllBranches) => vec![ALL_BRANCHES_TARGET.to_string()], + Some(PushTargets::NoBranchTargets) | None => Vec::new(), + } +} + +fn blocked_branches(targets: &[String], protected_branches: &[String]) -> Vec { + if targets.is_empty() || protected_branches.is_empty() { + return Vec::new(); + } + if targets.iter().any(|target| target == ALL_BRANCHES_TARGET) { + return unique_branches(protected_branches); + } + + let mut blocked = Vec::new(); + for target in targets { + if protected_branches.iter().any(|branch| branch == target) + && !blocked.iter().any(|branch| branch == target) + { + blocked.push(target.clone()); + } + } + blocked +} + +fn unique_branches(branches: &[String]) -> Vec { + let mut unique = Vec::new(); + for branch in branches { + if !unique.iter().any(|existing| existing == branch) { + unique.push(branch.clone()); + } + } + unique +} + +fn format_blocked_push(blocked: &[String]) -> String { + let branches = blocked + .iter() + .map(|branch| format!("'{branch}'")) + .collect::>() + .join(", "); + format!( + "Blocked: push to protected branch(es) {branches}. Protected branches can only be updated through pull requests." 
+ ) +} + +enum PushTargets<'a> { + Refspecs(Vec<&'a str>), + AllBranches, + NoBranchTargets, +} + +#[derive(Clone, Copy)] +struct SkipFlag { + tokens: usize, + repository_from_flag: bool, +} + +impl SkipFlag { + const fn new(tokens: usize, repository_from_flag: bool) -> Self { + Self { + tokens, + repository_from_flag, + } + } +} + +fn push_refspecs<'a>(tokens: &'a [&'a str]) -> Option> { + if !is_git_push(tokens) { + return None; + } + + let mut positionals = Vec::new(); + let mut saw_tags = false; + let mut repository_from_flag = false; + let mut index = 2; + + while index < tokens.len() { + let token = tokens[index]; + if token == "--delete" { + return None; + } + if token == "--all" || token == "--mirror" { + return Some(PushTargets::AllBranches); + } + if token == "--tags" { + saw_tags = true; + index += 1; + continue; + } + if let Some(skip_flag) = should_skip_flag(token) { + repository_from_flag |= skip_flag.repository_from_flag; + index = skip_tokens(tokens, index, skip_flag.tokens)?; + continue; + } + if token.starts_with('-') { + return None; + } + positionals.push(token); + index += 1; + } + + classify_push_targets(positionals, saw_tags, repository_from_flag) +} + +fn is_git_push(tokens: &[&str]) -> bool { + tokens.len() >= 2 && tokens[0] == "git" && tokens[1] == "push" +} + +fn skip_tokens(tokens: &[&str], index: usize, count: usize) -> Option { + let next = index + count; + (next <= tokens.len()).then_some(next) +} + +fn classify_push_targets<'a>( + positionals: Vec<&'a str>, + saw_tags: bool, + repository_from_flag: bool, +) -> Option> { + if repository_from_flag { + if !positionals.is_empty() { + return Some(PushTargets::Refspecs(positionals)); + } + } else if positionals.len() >= 2 { + return Some(PushTargets::Refspecs( + positionals.into_iter().skip(1).collect(), + )); + } + + if saw_tags { + return Some(PushTargets::NoBranchTargets); + } + None +} + +fn should_skip_flag(flag: &str) -> Option { + if matches!( + flag, + "-f" | "--force" + | 
"--no-verify" + | "-u" + | "--set-upstream" + | "--force-with-lease" + | "--quiet" + | "-q" + | "--verbose" + | "-v" + | "--dry-run" + | "-n" + ) || flag.starts_with("--force-with-lease=") + || flag.starts_with("--push-option=") + || flag.starts_with("--repo=") + || flag.starts_with("--receive-pack=") + { + let repository_from_flag = flag.starts_with("--repo="); + return Some(SkipFlag::new(1, repository_from_flag)); + } + + match flag { + "-o" | "--push-option" | "--receive-pack" => Some(SkipFlag::new(2, false)), + "--repo" => Some(SkipFlag::new(2, true)), + _ => None, + } +} + +fn normalize_target(refspec: &str) -> Option { + let target = refspec_target(refspec)?; + let target = target.strip_prefix("refs/heads/").unwrap_or(target); + if target.is_empty() || target == "HEAD" || target.starts_with("refs/") { + return None; + } + Some(target.to_string()) +} + +fn refspec_target(refspec: &str) -> Option<&str> { + let cleaned = refspec.strip_prefix('+').unwrap_or(refspec); + match cleaned.split_once(':') { + Some((source, destination)) if source.is_empty() || destination.is_empty() => None, + Some((_, destination)) => Some(destination), + None => Some(cleaned), + } +} + +#[cfg(test)] +mod tests { + use super::{check_push_allowed, extract_push_targets}; + + #[test] + fn check_push_allowed_blocks_protected_branches() { + let targets = vec!["main".to_string(), "dev".to_string(), "staging".to_string()]; + let protected = vec!["main".to_string(), "staging".to_string()]; + + let error = check_push_allowed(&targets, &protected).expect_err("push should be blocked"); + + assert!(error.contains("'main'")); + assert!(error.contains("'staging'")); + assert!(error.contains("pull requests")); + } + + #[test] + fn check_push_allowed_allows_unprotected_branches() { + let targets = vec!["dev".to_string(), "feature/ripcord".to_string()]; + let protected = vec!["main".to_string(), "staging".to_string()]; + + assert!(check_push_allowed(&targets, &protected).is_ok()); + } + + #[test] + fn 
check_push_allowed_allows_empty_protected_branches() { + let targets = vec!["main".to_string()]; + let protected = Vec::new(); + + assert!(check_push_allowed(&targets, &protected).is_ok()); + } + + #[test] + fn check_push_allowed_allows_empty_targets() { + let targets = Vec::new(); + let protected = vec!["main".to_string()]; + + assert!(check_push_allowed(&targets, &protected).is_ok()); + } + + #[test] + fn extract_push_targets_handles_supported_push_forms() { + let cases = [ + ("git push origin main", vec!["main"]), + ("git push origin main staging", vec!["main", "staging"]), + ("git push origin HEAD:main", vec!["main"]), + ("git push origin +main", vec!["main"]), + ("git push origin refs/heads/main", vec!["main"]), + ("git push -f origin main", vec!["main"]), + ("git push --force origin main", vec!["main"]), + ("git push --no-verify origin main", vec!["main"]), + ( + "git push origin +HEAD:refs/heads/main refs/heads/staging", + vec!["main", "staging"], + ), + ("git push -o ci.skip origin main", vec!["main"]), + ("git push --push-option ci.skip origin main", vec!["main"]), + ("git push --push-option=ci.skip origin main", vec!["main"]), + ( + "git push --receive-pack git-receive-pack origin main", + vec!["main"], + ), + ( + "git push --receive-pack=git-receive-pack origin main", + vec!["main"], + ), + ( + "git push --repo ssh://example.com/repo.git main", + vec!["main"], + ), + ( + "git push --repo=ssh://example.com/repo.git main", + vec!["main"], + ), + ]; + + for (command, expected) in cases { + let actual = extract_push_targets(command); + assert_eq!(actual, expected, "command: {command}"); + } + } + + #[test] + fn extract_push_targets_blocks_all_branches_push_when_protected_branches_exist() { + let targets = extract_push_targets("git push --all origin"); + let protected = vec!["main".to_string(), "staging".to_string()]; + + let error = check_push_allowed(&targets, &protected).expect_err("push should be blocked"); + + assert!(error.contains("'main'")); + 
assert!(error.contains("'staging'")); + } + + #[test] + fn extract_push_targets_blocks_mirror_push_when_protected_branches_exist() { + let targets = extract_push_targets("git push --mirror origin"); + let protected = vec!["main".to_string(), "staging".to_string()]; + + let error = check_push_allowed(&targets, &protected).expect_err("push should be blocked"); + + assert!(error.contains("'main'")); + assert!(error.contains("'staging'")); + } + + #[test] + fn extract_push_targets_allows_all_branches_push_when_no_branches_are_protected() { + let targets = extract_push_targets("git push --all origin"); + let protected = Vec::new(); + + assert!(check_push_allowed(&targets, &protected).is_ok()); + } + + #[test] + fn extract_push_targets_ignores_tag_only_pushes() { + assert!(extract_push_targets("git push --tags origin").is_empty()); + } + + #[test] + fn extract_push_targets_returns_empty_when_target_cannot_be_determined() { + let cases = [ + "git push", + "git status", + "git push --delete origin main", + "git push origin :main", + "git push --unknown origin main", + ]; + + for command in cases { + assert!( + extract_push_targets(command).is_empty(), + "command: {command}" + ); + } + } +} diff --git a/engine/crates/fx-ripcord/src/lib.rs b/engine/crates/fx-ripcord/src/lib.rs index f19598be..b14d056e 100644 --- a/engine/crates/fx-ripcord/src/lib.rs +++ b/engine/crates/fx-ripcord/src/lib.rs @@ -6,6 +6,7 @@ pub mod config; pub mod evaluator; +pub mod git_guard; pub mod journal; pub mod revert; pub mod snapshot; diff --git a/engine/crates/fx-security/src/policy/tests.rs b/engine/crates/fx-security/src/policy/tests.rs index 35220de2..36b2700b 100644 --- a/engine/crates/fx-security/src/policy/tests.rs +++ b/engine/crates/fx-security/src/policy/tests.rs @@ -519,10 +519,10 @@ decision = "deny" #[test] fn test_condition_contact_target_partial_eq() { let cond1 = Condition::ContactTarget { - contact: "joe".to_string(), + contact: "owner".to_string(), }; let cond2 = 
Condition::ContactTarget { - contact: "joe".to_string(), + contact: "owner".to_string(), }; let cond3 = Condition::ContactTarget { contact: "alice".to_string(), diff --git a/engine/crates/fx-session/src/lib.rs b/engine/crates/fx-session/src/lib.rs index d73deba8..efcdcf3b 100644 --- a/engine/crates/fx-session/src/lib.rs +++ b/engine/crates/fx-session/src/lib.rs @@ -11,8 +11,9 @@ pub mod types; pub use registry::{SessionError, SessionRegistry}; pub use session::{ - render_content_blocks, render_content_blocks_with_options, ContentRenderOptions, Session, - SessionContentBlock, SessionMemory, SessionMemoryUpdate, SessionMessage, + max_memory_items, max_memory_tokens, render_content_blocks, render_content_blocks_with_options, + ContentRenderOptions, Session, SessionContentBlock, SessionMemory, SessionMemoryUpdate, + SessionMessage, }; pub use store::SessionStore; pub use types::{ diff --git a/engine/crates/fx-session/src/registry.rs b/engine/crates/fx-session/src/registry.rs index 5bd5f898..128a57d0 100644 --- a/engine/crates/fx-session/src/registry.rs +++ b/engine/crates/fx-session/src/registry.rs @@ -478,11 +478,9 @@ mod tests { .expect("create"); let messages = vec![SessionMessage::text(MessageRole::Assistant, "saved", 7)]; - let memory = SessionMemory { - project: Some("session memory".to_string()), - current_state: Some("testing".to_string()), - ..SessionMemory::default() - }; + let mut memory = SessionMemory::default(); + memory.project = Some("session memory".to_string()); + memory.current_state = Some("testing".to_string()); reg.record_turn(&key, messages, memory.clone()) .expect("record turn"); diff --git a/engine/crates/fx-session/src/session.rs b/engine/crates/fx-session/src/session.rs index 7f06159a..026b6dc0 100644 --- a/engine/crates/fx-session/src/session.rs +++ b/engine/crates/fx-session/src/session.rs @@ -228,12 +228,32 @@ pub struct ContentRenderOptions { pub include_tool_use_id: bool, } -const SESSION_MEMORY_MAX_ITEMS: usize = 20; -const 
SESSION_MEMORY_MAX_TOKENS: usize = 2_000; +const DEFAULT_SESSION_MEMORY_MAX_ITEMS: usize = 40; +const DEFAULT_SESSION_MEMORY_MAX_TOKENS: usize = 4_000; + +/// Compute the session memory token cap for a given model context window. +#[must_use] +pub fn max_memory_tokens(context_limit: usize) -> usize { + (context_limit / 50).clamp(2_000, 8_000) +} + +/// Compute the per-list session memory item cap for a given model context window. +#[must_use] +pub fn max_memory_items(context_limit: usize) -> usize { + (context_limit / 5_000).clamp(20, 80) +} + +fn default_session_memory_max_items() -> usize { + DEFAULT_SESSION_MEMORY_MAX_ITEMS +} + +fn default_session_memory_max_tokens() -> usize { + DEFAULT_SESSION_MEMORY_MAX_TOKENS +} /// Persistent session memory that survives conversation compaction. /// Contains key facts the agent extracted about the session's purpose and state. -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct SessionMemory { /// What this session is about. #[serde(default, skip_serializing_if = "Option::is_none")] @@ -241,21 +261,82 @@ pub struct SessionMemory { /// Current state of work. #[serde(default, skip_serializing_if = "Option::is_none")] pub current_state: Option, - /// Key decisions made during this session (max 20). + /// Key decisions made during this session. #[serde(default)] pub key_decisions: Vec, /// Files actively being worked on. #[serde(default)] pub active_files: Vec, - /// Custom context the agent wants to remember (max 20). + /// Custom context the agent wants to remember. #[serde(default)] pub custom_context: Vec, /// Unix epoch seconds of last update. #[serde(default)] pub last_updated: u64, + /// Runtime-only list cap applied to tracked memory collections. + #[serde(skip, default = "default_session_memory_max_items")] + max_items: usize, + /// Runtime-only token cap applied to rendered session memory. 
+ #[serde(skip, default = "default_session_memory_max_tokens")] + max_tokens: usize, } +impl Default for SessionMemory { + fn default() -> Self { + Self { + project: None, + current_state: None, + key_decisions: Vec::new(), + active_files: Vec::new(), + custom_context: Vec::new(), + last_updated: 0, + max_items: DEFAULT_SESSION_MEMORY_MAX_ITEMS, + max_tokens: DEFAULT_SESSION_MEMORY_MAX_TOKENS, + } + } +} + +impl PartialEq for SessionMemory { + fn eq(&self, other: &Self) -> bool { + self.project == other.project + && self.current_state == other.current_state + && self.key_decisions == other.key_decisions + && self.active_files == other.active_files + && self.custom_context == other.custom_context + && self.last_updated == other.last_updated + } +} + +impl Eq for SessionMemory {} + impl SessionMemory { + /// Create empty session memory configured for a specific model context window. + #[must_use] + pub fn with_context_limit(context_limit: usize) -> Self { + let mut memory = Self::default(); + memory.set_context_limit(context_limit); + memory + } + + /// Recompute runtime caps for a new model context window. + pub fn set_context_limit(&mut self, context_limit: usize) { + self.max_items = max_memory_items(context_limit); + self.max_tokens = max_memory_tokens(context_limit); + self.trim_to_item_cap(); + } + + /// Return the current per-list item cap. + #[must_use] + pub fn item_cap(&self) -> usize { + self.max_items + } + + /// Return the current token cap. 
+ #[must_use] + pub fn token_cap(&self) -> usize { + self.max_tokens + } + #[must_use] pub fn is_empty(&self) -> bool { self.project.is_none() @@ -291,19 +372,20 @@ impl SessionMemory { candidate.current_state = Some(state); } if let Some(decisions) = update.key_decisions { - append_capped_items(&mut candidate.key_decisions, decisions); + append_capped_items(&mut candidate.key_decisions, decisions, candidate.max_items); } if let Some(files) = update.active_files { candidate.active_files = files; + trim_oldest_items(&mut candidate.active_files, candidate.max_items); } if let Some(context) = update.custom_context { - append_capped_items(&mut candidate.custom_context, context); + append_capped_items(&mut candidate.custom_context, context, candidate.max_items); } let estimated_tokens = candidate.estimated_tokens(); - if estimated_tokens > SESSION_MEMORY_MAX_TOKENS { + if estimated_tokens > candidate.max_tokens { return Err(format!( "Session memory would exceed {} token cap ({} estimated). Be more concise.", - SESSION_MEMORY_MAX_TOKENS, estimated_tokens + candidate.max_tokens, estimated_tokens )); } candidate.last_updated = current_epoch_secs(); @@ -328,6 +410,12 @@ impl SessionMemory { push_session_memory_items(&mut lines, "Context:", &self.custom_context); lines.join("\n") } + + fn trim_to_item_cap(&mut self) { + trim_oldest_items(&mut self.key_decisions, self.max_items); + trim_oldest_items(&mut self.active_files, self.max_items); + trim_oldest_items(&mut self.custom_context, self.max_items); + } } /// Partial update to session memory from the agent's tool call. 
@@ -340,13 +428,13 @@ pub struct SessionMemoryUpdate { pub custom_context: Option>, } -fn append_capped_items(items: &mut Vec, incoming: Vec) { +fn append_capped_items(items: &mut Vec, incoming: Vec, max_items: usize) { items.extend(incoming); - trim_oldest_items(items); + trim_oldest_items(items, max_items); } -fn trim_oldest_items(items: &mut Vec) { - let excess = items.len().saturating_sub(SESSION_MEMORY_MAX_ITEMS); +fn trim_oldest_items(items: &mut Vec, max_items: usize) { + let excess = items.len().saturating_sub(max_items); if excess > 0 { items.drain(..excess); } @@ -648,8 +736,33 @@ mod tests { } #[test] - fn session_memory_default_is_empty() { - assert!(SessionMemory::default().is_empty()); + fn session_memory_default_is_empty_and_uses_default_caps() { + let memory = SessionMemory::default(); + + assert!(memory.is_empty()); + assert_eq!(memory.token_cap(), DEFAULT_SESSION_MEMORY_MAX_TOKENS); + assert_eq!(memory.item_cap(), DEFAULT_SESSION_MEMORY_MAX_ITEMS); + } + + #[test] + fn max_memory_tokens_scales_with_context_limit() { + assert_eq!(max_memory_tokens(32_000), 2_000); + assert_eq!(max_memory_tokens(200_000), 4_000); + assert_eq!(max_memory_tokens(300_000), 6_000); + } + + #[test] + fn max_memory_items_scales_with_context_limit() { + assert_eq!(max_memory_items(32_000), 20); + assert_eq!(max_memory_items(200_000), 40); + assert_eq!(max_memory_items(300_000), 60); + assert_eq!(max_memory_items(500_000), 80); + } + + #[test] + fn max_memory_tokens_clamps_at_boundaries() { + assert_eq!(max_memory_tokens(16_000), 2_000); + assert_eq!(max_memory_tokens(500_000), 8_000); } #[test] @@ -661,6 +774,7 @@ mod tests { active_files: vec!["engine/crates/fx-session/src/session.rs".to_string()], custom_context: vec!["keep it concise".to_string()], last_updated: 123, + ..SessionMemory::default() }; let json = serde_json::to_string(&memory).expect("serialize memory"); @@ -700,6 +814,11 @@ mod tests { let restored: Session = 
serde_json::from_value(value).expect("deserialize session"); assert!(restored.memory.is_empty()); + assert_eq!( + restored.memory.token_cap(), + DEFAULT_SESSION_MEMORY_MAX_TOKENS + ); + assert_eq!(restored.memory.item_cap(), DEFAULT_SESSION_MEMORY_MAX_ITEMS); } #[test] @@ -759,19 +878,25 @@ mod tests { } #[test] - fn apply_update_caps_lists_at_twenty_items() { - let mut memory = SessionMemory::default(); + fn apply_update_caps_lists_for_context_limit() { + let mut memory = SessionMemory::with_context_limit(32_000); let mut update = memory_update(); update.key_decisions = Some((0..25).map(|i| format!("decision-{i}")).collect()); + update.active_files = Some((0..22).map(|i| format!("file-{i}.rs")).collect()); update.custom_context = Some((0..22).map(|i| format!("context-{i}")).collect()); memory.apply_update(update).expect("capped update"); - assert_eq!(memory.key_decisions.len(), SESSION_MEMORY_MAX_ITEMS); - assert_eq!(memory.custom_context.len(), SESSION_MEMORY_MAX_ITEMS); + assert_eq!(memory.key_decisions.len(), 20); + assert_eq!(memory.active_files.len(), 20); + assert_eq!(memory.custom_context.len(), 20); assert_eq!( memory.key_decisions.first().map(String::as_str), Some("decision-5") ); + assert_eq!( + memory.active_files.first().map(String::as_str), + Some("file-2.rs") + ); assert_eq!( memory.custom_context.first().map(String::as_str), Some("context-2") @@ -792,7 +917,7 @@ mod tests { fn session_memory_rejects_oversized_updates() { let mut memory = SessionMemory::default(); let mut update = memory_update(); - update.project = Some("x".repeat(SESSION_MEMORY_MAX_TOKENS * 8)); + update.project = Some("x".repeat(DEFAULT_SESSION_MEMORY_MAX_TOKENS * 8)); let error = memory .apply_update(update) diff --git a/engine/crates/fx-skills/src/host_api.rs b/engine/crates/fx-skills/src/host_api.rs index 04fba5d7..5021add8 100644 --- a/engine/crates/fx-skills/src/host_api.rs +++ b/engine/crates/fx-skills/src/host_api.rs @@ -33,6 +33,25 @@ pub trait HostApi: Send + Sync { /// - 
`body`: Request body (empty string for no body) fn http_request(&self, method: &str, url: &str, headers: &str, body: &str) -> Option; + /// Execute a shell command. Returns JSON: {"stdout": "...", "stderr": "...", "exit_code": N} + /// Returns None if Shell capability is not granted. + fn exec_command(&self, command: &str, timeout_ms: u32) -> Option { + let _ = (command, timeout_ms); + None + } + + /// Read a file's contents as UTF-8. Returns None if Filesystem capability is not granted. + fn read_file(&self, path: &str) -> Option { + let _ = path; + None + } + + /// Write content to a file. Returns true on success. + fn write_file(&self, path: &str, content: &str) -> bool { + let _ = (path, content); + false + } + /// Get the output that was set by the skill. fn get_output(&self) -> String; diff --git a/engine/crates/fx-skills/src/manifest.rs b/engine/crates/fx-skills/src/manifest.rs index b6283d12..d6f3e7de 100644 --- a/engine/crates/fx-skills/src/manifest.rs +++ b/engine/crates/fx-skills/src/manifest.rs @@ -3,6 +3,7 @@ use fx_core::error::SkillError; use semver::Version; use serde::{Deserialize, Serialize}; +use std::str::FromStr; /// Capability a skill can request. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -14,6 +15,10 @@ pub enum Capability { NetworkRestricted { allowed_domains: Vec }, /// Persistent key-value storage Storage, + /// Execute shell commands + Shell, + /// Read and write local files + Filesystem, /// Send notifications Notifications, /// Read sensor data @@ -22,9 +27,11 @@ pub enum Capability { PhoneActions, } -pub const ALL_CAPABILITIES: [Capability; 5] = [ +pub const ALL_CAPABILITIES: [Capability; 7] = [ Capability::Network, Capability::Storage, + Capability::Shell, + Capability::Filesystem, Capability::Notifications, Capability::Sensors, Capability::PhoneActions, @@ -36,6 +43,8 @@ impl Capability { Capability::Network => "network", Capability::NetworkRestricted { .. 
} => "network_restricted", Capability::Storage => "storage", + Capability::Shell => "shell", + Capability::Filesystem => "filesystem", Capability::Notifications => "notifications", Capability::Sensors => "sensors", Capability::PhoneActions => "phone_actions", @@ -43,9 +52,18 @@ impl Capability { } pub fn parse(value: &str) -> Option { + value.parse().ok() + } +} + +impl FromStr for Capability { + type Err = SkillError; + + fn from_str(value: &str) -> Result { ALL_CAPABILITIES .into_iter() .find(|capability| capability.as_str() == value) + .ok_or_else(|| SkillError::InvalidManifest(format!("Unknown capability '{}'", value))) } } @@ -256,13 +274,15 @@ version = "1.0.0" description = "Test skill" author = "Fawx" api_version = "host_api_v1" -capabilities = ["network", "storage", "notifications", "sensors", "phone_actions"] +capabilities = ["network", "storage", "shell", "filesystem", "notifications", "sensors", "phone_actions"] "#; let manifest = parse_manifest(toml).expect("Should parse"); - assert_eq!(manifest.capabilities.len(), 5); + assert_eq!(manifest.capabilities.len(), 7); assert!(manifest.capabilities.contains(&Capability::Network)); assert!(manifest.capabilities.contains(&Capability::Storage)); + assert!(manifest.capabilities.contains(&Capability::Shell)); + assert!(manifest.capabilities.contains(&Capability::Filesystem)); assert!(manifest.capabilities.contains(&Capability::Notifications)); assert!(manifest.capabilities.contains(&Capability::Sensors)); assert!(manifest.capabilities.contains(&Capability::PhoneActions)); @@ -325,11 +345,29 @@ capabilities = ["network", "storage", "notifications", "sensors", "phone_actions "network_restricted" ); assert_eq!(format!("{}", Capability::Storage), "storage"); + assert_eq!(format!("{}", Capability::Shell), "shell"); + assert_eq!(format!("{}", Capability::Filesystem), "filesystem"); assert_eq!(format!("{}", Capability::Notifications), "notifications"); assert_eq!(format!("{}", Capability::Sensors), "sensors"); 
assert_eq!(format!("{}", Capability::PhoneActions), "phone_actions"); } + #[test] + fn shell_capability_serializes() { + let json = serde_json::to_string(&Capability::Shell).expect("serialize shell"); + assert_eq!(json, "\"shell\""); + let parsed: Capability = serde_json::from_str(&json).expect("deserialize shell"); + assert_eq!(parsed, Capability::Shell); + } + + #[test] + fn filesystem_capability_serializes() { + let json = serde_json::to_string(&Capability::Filesystem).expect("serialize filesystem"); + assert_eq!(json, "\"filesystem\""); + let parsed: Capability = serde_json::from_str(&json).expect("deserialize filesystem"); + assert_eq!(parsed, Capability::Filesystem); + } + #[test] fn test_capability_display_in_error_message() { let cap = Capability::Network; diff --git a/engine/crates/fx-skills/src/runtime.rs b/engine/crates/fx-skills/src/runtime.rs index f4135121..f383bdd1 100644 --- a/engine/crates/fx-skills/src/runtime.rs +++ b/engine/crates/fx-skills/src/runtime.rs @@ -326,6 +326,9 @@ impl SkillRuntime { .map_err(|e| SkillError::Execution(format!("Failed to link set_output: {}", e)))?; Self::link_http_request(linker)?; + Self::link_exec_command(linker)?; + Self::link_read_file(linker)?; + Self::link_write_file(linker)?; Self::link_v2_host_functions(linker)?; Ok(()) @@ -352,87 +355,178 @@ impl SkillRuntime { body_ptr: u32, body_len: u32| -> u32 { - // Check Network capability if !caller.data().has_capability(&Capability::Network) { tracing::warn!("http_request denied: skill lacks Network capability"); return 0; } - - // Read all string parameters from WASM memory - let method = match caller.data().read_string(&caller, method_ptr, method_len) { - Ok(s) => s, - Err(e) => { - tracing::error!("http_request: failed to read method: {}", e); - return 0; - } + let Some(method) = Self::read_host_string( + &caller, + method_ptr, + method_len, + "http_request method", + ) else { + return 0; }; - - let url = match caller.data().read_string(&caller, url_ptr, url_len) { 
- Ok(s) => s, - Err(e) => { - tracing::error!("http_request: failed to read url: {}", e); - return 0; - } + let Some(url) = + Self::read_host_string(&caller, url_ptr, url_len, "http_request url") + else { + return 0; }; - - let headers = match caller.data().read_string(&caller, headers_ptr, headers_len) - { - Ok(s) => s, - Err(e) => { - tracing::error!("http_request: failed to read headers: {}", e); - return 0; - } + let Some(headers) = Self::read_host_string( + &caller, + headers_ptr, + headers_len, + "http_request headers", + ) else { + return 0; }; - - let body = match caller.data().read_string(&caller, body_ptr, body_len) { - Ok(s) => s, - Err(e) => { - tracing::error!("http_request: failed to read body: {}", e); - return 0; - } + let Some(body) = + Self::read_host_string(&caller, body_ptr, body_len, "http_request body") + else { + return 0; }; - - // Call the host API - let response = match caller + let Some(response) = caller .data() .api .http_request(&method, &url, &headers, &body) - { - Some(r) => r, - None => return 0, + else { + return 0; }; + Self::write_host_string(&mut caller, &response, "http_request") + }, + ) + .map_err(|e| SkillError::Execution(format!("Failed to link http_request: {}", e)))?; - // Write response to WASM memory - let memory = match caller.data().memory { - Some(m) => m, - None => { - tracing::error!("Memory not initialized for http_request"); - return 0; - } + Ok(()) + } + + fn link_exec_command(linker: &mut Linker) -> Result<(), SkillError> { + linker + .func_wrap( + "host_api_v1", + "exec_command", + |mut caller: Caller<'_, HostState>, + command_ptr: u32, + command_len: u32, + timeout_ms: u32| + -> u32 { + if !caller.data().has_capability(&Capability::Shell) { + tracing::warn!("exec_command denied: skill lacks Shell capability"); + return 0; + } + let Some(command) = + Self::read_host_string(&caller, command_ptr, command_len, "exec_command") + else { + return 0; }; - let mut alloc_offset = caller.data().alloc_offset; - match 
HostState::write_to_memory( - memory, - &mut caller, - &response, - &mut alloc_offset, - ) { - Ok(ptr) => { - caller.data_mut().alloc_offset = alloc_offset; - ptr - } - Err(e) => { - tracing::error!("Failed to write http_request response: {}", e); - 0 - } + let Some(response) = caller.data().api.exec_command(&command, timeout_ms) + else { + return 0; + }; + Self::write_host_string(&mut caller, &response, "exec_command") + }, + ) + .map_err(|e| SkillError::Execution(format!("Failed to link exec_command: {}", e)))?; + Ok(()) + } + + fn link_read_file(linker: &mut Linker) -> Result<(), SkillError> { + linker + .func_wrap( + "host_api_v1", + "read_file", + |mut caller: Caller<'_, HostState>, path_ptr: u32, path_len: u32| -> u32 { + if !caller.data().has_capability(&Capability::Filesystem) { + tracing::warn!("read_file denied: skill lacks Filesystem capability"); + return 0; } + let Some(path) = + Self::read_host_string(&caller, path_ptr, path_len, "read_file") + else { + return 0; + }; + let Some(contents) = caller.data().api.read_file(&path) else { + return 0; + }; + Self::write_host_string(&mut caller, &contents, "read_file") }, ) - .map_err(|e| SkillError::Execution(format!("Failed to link http_request: {}", e)))?; + .map_err(|e| SkillError::Execution(format!("Failed to link read_file: {}", e)))?; + Ok(()) + } + fn link_write_file(linker: &mut Linker) -> Result<(), SkillError> { + linker + .func_wrap( + "host_api_v1", + "write_file", + |caller: Caller<'_, HostState>, + path_ptr: u32, + path_len: u32, + content_ptr: u32, + content_len: u32| + -> i32 { + if !caller.data().has_capability(&Capability::Filesystem) { + tracing::warn!("write_file denied: skill lacks Filesystem capability"); + return 0; + } + let Some(path) = + Self::read_host_string(&caller, path_ptr, path_len, "write_file path") + else { + return 0; + }; + let Some(content) = Self::read_host_string( + &caller, + content_ptr, + content_len, + "write_file content", + ) else { + return 0; + }; + 
i32::from(caller.data().api.write_file(&path, &content)) + }, + ) + .map_err(|e| SkillError::Execution(format!("Failed to link write_file: {}", e)))?; Ok(()) } + fn read_host_string( + caller: &Caller<'_, HostState>, + ptr: u32, + len: u32, + context: &str, + ) -> Option { + caller + .data() + .read_string(caller, ptr, len) + .map_err(|e| { + tracing::error!("{context}: failed to read string: {e}"); + e + }) + .ok() + } + + fn write_host_string(caller: &mut Caller<'_, HostState>, value: &str, context: &str) -> u32 { + let memory = match caller.data().memory { + Some(memory) => memory, + None => { + tracing::error!("Memory not initialized for {context}"); + return 0; + } + }; + let mut alloc_offset = caller.data().alloc_offset; + match HostState::write_to_memory(memory, caller, value, &mut alloc_offset) { + Ok(ptr) => { + caller.data_mut().alloc_offset = alloc_offset; + ptr + } + Err(e) => { + tracing::error!("Failed to write {context} response: {e}"); + 0 + } + } + } + /// Link host_api_v2 functions to the WASM linker. 
fn link_v2_host_functions(linker: &mut Linker) -> Result<(), SkillError> { Self::link_v2_get_context(linker)?; @@ -769,6 +863,70 @@ mod tests { wat.as_bytes().to_vec() } + #[derive(Debug)] + struct CapabilityRuntimeHostApi { + base: crate::host_api::HostApiBase, + } + + impl CapabilityRuntimeHostApi { + fn new() -> Self { + Self { + base: crate::host_api::HostApiBase::new("input"), + } + } + } + + impl HostApi for CapabilityRuntimeHostApi { + fn log(&self, _level: u32, _message: &str) {} + + fn kv_get(&self, key: &str) -> Option { + self.base.kv_get(key) + } + + fn kv_set(&mut self, key: &str, value: &str) -> Result<(), SkillError> { + self.base.kv_set(key, value); + Ok(()) + } + + fn get_input(&self) -> String { + self.base.get_input() + } + + fn set_output(&mut self, text: &str) { + self.base.set_output(text); + } + + fn http_request( + &self, + _method: &str, + _url: &str, + _headers: &str, + _body: &str, + ) -> Option { + None + } + + fn exec_command(&self, _command: &str, _timeout_ms: u32) -> Option { + Some("exec_result".to_string()) + } + + fn read_file(&self, _path: &str) -> Option { + Some("file_result".to_string()) + } + + fn write_file(&self, path: &str, content: &str) -> bool { + path == "out.txt" && content == "hello" + } + + fn get_output(&self) -> String { + self.base.get_output() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + } + #[test] fn test_register_and_list_skills() { let mut runtime = SkillRuntime::new().expect("Should create runtime"); @@ -1078,6 +1236,76 @@ mod tests { wat.as_bytes().to_vec() } + fn create_exec_command_wasm() -> Vec { + let wat = r#" + (module + (import "host_api_v1" "set_output" (func $set_output (param i32 i32))) + (import "host_api_v1" "exec_command" + (func $exec_command (param i32 i32 i32) (result i32))) + (memory (export "memory") 1) + (data (i32.const 0) "echo ok") + (data (i32.const 100) "no_response") + + (func (export "run") + (local $resp_ptr i32) + (local.set $resp_ptr + (call $exec_command 
(i32.const 0) (i32.const 7) (i32.const 1000))) + (if (i32.eqz (local.get $resp_ptr)) + (then (call $set_output (i32.const 100) (i32.const 11))) + (else (call $set_output (local.get $resp_ptr) (i32.const 11)))) + ) + ) + "#; + wat.as_bytes().to_vec() + } + + fn create_read_file_wasm() -> Vec { + let wat = r#" + (module + (import "host_api_v1" "set_output" (func $set_output (param i32 i32))) + (import "host_api_v1" "read_file" + (func $read_file (param i32 i32) (result i32))) + (memory (export "memory") 1) + (data (i32.const 0) "/tmp/input.txt") + (data (i32.const 100) "no_response") + + (func (export "run") + (local $resp_ptr i32) + (local.set $resp_ptr (call $read_file (i32.const 0) (i32.const 14))) + (if (i32.eqz (local.get $resp_ptr)) + (then (call $set_output (i32.const 100) (i32.const 11))) + (else (call $set_output (local.get $resp_ptr) (i32.const 11)))) + ) + ) + "#; + wat.as_bytes().to_vec() + } + + fn create_write_file_wasm() -> Vec { + let wat = r#" + (module + (import "host_api_v1" "set_output" (func $set_output (param i32 i32))) + (import "host_api_v1" "write_file" + (func $write_file (param i32 i32 i32 i32) (result i32))) + (memory (export "memory") 1) + (data (i32.const 0) "out.txt") + (data (i32.const 32) "hello") + (data (i32.const 100) "written") + (data (i32.const 120) "failed") + + (func (export "run") + (if (i32.eq (call $write_file + (i32.const 0) (i32.const 7) + (i32.const 32) (i32.const 5)) + (i32.const 1)) + (then (call $set_output (i32.const 100) (i32.const 7))) + (else (call $set_output (i32.const 120) (i32.const 6)))) + ) + ) + "#; + wat.as_bytes().to_vec() + } + fn create_http_manifest(name: &str, with_network: bool) -> SkillManifest { SkillManifest { name: name.to_string(), @@ -1154,6 +1382,60 @@ mod tests { assert_eq!(output, "no_response"); } + #[test] + fn test_wasm_exec_command_with_shell_capability() { + let mut runtime = SkillRuntime::new().expect("Should create runtime"); + let loader = 
SkillLoader::with_engine(runtime.engine().clone(), vec![]); + + let mut manifest = create_test_manifest("exec_command"); + manifest.capabilities = vec![Capability::Shell]; + let skill = loader + .load(&create_exec_command_wasm(), &manifest, None) + .expect("Should load"); + runtime.register_skill(skill).expect("Should register"); + + let output = runtime + .invoke_with_api("exec_command", Box::new(CapabilityRuntimeHostApi::new())) + .expect("Should invoke"); + assert_eq!(output, "exec_result"); + } + + #[test] + fn test_wasm_read_file_with_filesystem_capability() { + let mut runtime = SkillRuntime::new().expect("Should create runtime"); + let loader = SkillLoader::with_engine(runtime.engine().clone(), vec![]); + + let mut manifest = create_test_manifest("read_file"); + manifest.capabilities = vec![Capability::Filesystem]; + let skill = loader + .load(&create_read_file_wasm(), &manifest, None) + .expect("Should load"); + runtime.register_skill(skill).expect("Should register"); + + let output = runtime + .invoke_with_api("read_file", Box::new(CapabilityRuntimeHostApi::new())) + .expect("Should invoke"); + assert_eq!(output, "file_result"); + } + + #[test] + fn test_wasm_write_file_with_filesystem_capability() { + let mut runtime = SkillRuntime::new().expect("Should create runtime"); + let loader = SkillLoader::with_engine(runtime.engine().clone(), vec![]); + + let mut manifest = create_test_manifest("write_file"); + manifest.capabilities = vec![Capability::Filesystem]; + let skill = loader + .load(&create_write_file_wasm(), &manifest, None) + .expect("Should load"); + runtime.register_skill(skill).expect("Should register"); + + let output = runtime + .invoke_with_api("write_file", Box::new(CapabilityRuntimeHostApi::new())) + .expect("Should invoke"); + assert_eq!(output, "written"); + } + #[test] fn v1_skill_loads_with_v2_host() { let mut runtime = SkillRuntime::new().expect("create runtime"); diff --git a/engine/crates/fx-tools/Cargo.toml 
b/engine/crates/fx-tools/Cargo.toml index 50b9cec6..ba6bbbcd 100644 --- a/engine/crates/fx-tools/Cargo.toml +++ b/engine/crates/fx-tools/Cargo.toml @@ -20,6 +20,7 @@ fx-llm = { path = "../fx-llm" } fx-loadable.workspace = true fx-config.workspace = true fx-auth.workspace = true +fx-ripcord = { path = "../fx-ripcord" } fx-consensus.workspace = true fx-memory.workspace = true fx-subagent = { path = "../fx-subagent" } diff --git a/engine/crates/fx-tools/src/git_skill.rs b/engine/crates/fx-tools/src/git_skill.rs index 30713c9b..d12f065c 100644 --- a/engine/crates/fx-tools/src/git_skill.rs +++ b/engine/crates/fx-tools/src/git_skill.rs @@ -3,6 +3,7 @@ use fx_core::self_modify::{classify_path, format_tier_violation, SelfModifyConfi use fx_kernel::cancellation::CancellationToken; use fx_llm::ToolDefinition; use fx_loadable::{Skill, SkillError}; +use fx_ripcord::git_guard::check_push_allowed; use serde::Deserialize; use std::path::{Path, PathBuf}; use std::process::{Output, Stdio}; @@ -31,6 +32,7 @@ pub struct GitSkill { working_dir: PathBuf, self_modify: Option, github_token: Option, + protected_branches: Vec, } impl std::fmt::Debug for GitSkill { @@ -39,6 +41,7 @@ impl std::fmt::Debug for GitSkill { .field("working_dir", &self.working_dir) .field("self_modify", &self.self_modify) .field("github_token", &self.github_token.is_some()) + .field("protected_branches", &self.protected_branches) .finish() } } @@ -107,9 +110,16 @@ impl GitSkill { working_dir, self_modify, github_token, + protected_branches: Vec::new(), } } + #[must_use] + pub fn with_protected_branches(mut self, protected_branches: Vec) -> Self { + self.protected_branches = protected_branches; + self + } + async fn run_git(&self, args: &[&str]) -> Result { self.run_git_with_timeout(args, STATUS_TIMEOUT).await } @@ -278,6 +288,7 @@ impl GitSkill { }; validate_remote_name(remote)?; validate_branch_name(&branch)?; + self.ensure_push_allowed(&branch)?; let token = self.require_github_token()?; 
self.run_git_with_token_auth(&["push", remote, &branch], &token, PUSH_TIMEOUT) .await @@ -321,6 +332,11 @@ impl GitSkill { Ok(branch) } + fn ensure_push_allowed(&self, branch: &str) -> Result<(), String> { + let target = branch.to_string(); + check_push_allowed(std::slice::from_ref(&target), &self.protected_branches) + } + fn require_github_token(&self) -> Result, String> { let provider = self.github_token.as_ref().ok_or_else(|| { "GitHub token not configured. Set up GitHub auth via `fawx setup` or configure a PAT." @@ -824,6 +840,22 @@ mod tests { ); } + fn fake_token_provider() -> Option { + Some(Arc::new(|| Some(Zeroizing::new("ghp_fake".to_string())))) + } + + fn init_push_remote(repo: &TempDir) -> TempDir { + let remote = TempDir::new().expect("remote tempdir"); + let remote_path = remote.path().to_str().expect("utf8 remote path"); + let output = StdCommand::new("git") + .args(["init", "--bare", remote_path]) + .output() + .expect("init bare remote"); + assert!(output.status.success(), "bare remote init should succeed"); + run_git_ok(repo, &["remote", "add", "origin", remote_path]); + remote + } + #[test] fn git_skill_provides_ten_tool_definitions() { let skill = GitSkill::new(PathBuf::from("."), None, None); @@ -1533,6 +1565,55 @@ mod tests { assert!(validate_remote_name("origin").is_ok()); } + #[tokio::test] + async fn git_push_blocks_protected_branch() { + let repo = init_test_repo(); + seed_initial_commit(&repo, "f.txt", "data\n"); + let skill = GitSkill::new(repo.path().to_path_buf(), None, None) + .with_protected_branches(vec!["main".to_string()]); + + let error = run_tool( + &skill, + "git_push", + serde_json::json!({"remote": "origin", "branch": "main"}), + ) + .await + .expect_err("push to protected branch should fail"); + + assert!(error.contains("protected branch(es) 'main'")); + } + + #[tokio::test] + async fn git_push_allows_unprotected_branch() { + let repo = init_test_repo(); + seed_initial_commit(&repo, "f.txt", "data\n"); + run_git_ok(&repo, 
&["checkout", "-b", "dev"]); + let remote = init_push_remote(&repo); + let skill = GitSkill::new(repo.path().to_path_buf(), None, fake_token_provider()) + .with_protected_branches(vec!["main".to_string()]); + + run_tool( + &skill, + "git_push", + serde_json::json!({"remote": "origin", "branch": "dev"}), + ) + .await + .expect("push to unprotected branch should succeed"); + + let remote_path = remote.path().to_str().expect("utf8 remote path"); + let output = StdCommand::new("git") + .args([ + "--git-dir", + remote_path, + "show-ref", + "--verify", + "refs/heads/dev", + ]) + .output() + .expect("verify remote ref"); + assert!(output.status.success(), "remote dev branch should exist"); + } + #[tokio::test] async fn git_push_requires_github_token() { let repo = init_test_repo(); diff --git a/engine/crates/fx-tools/src/node_run.rs b/engine/crates/fx-tools/src/node_run.rs index 457678cb..a14f0461 100644 --- a/engine/crates/fx-tools/src/node_run.rs +++ b/engine/crates/fx-tools/src/node_run.rs @@ -264,11 +264,11 @@ mod tests { #[tokio::test] async fn resolves_node_by_name() { let transport = Arc::new(MockTransport::succeeding("ok\n")); - let state = make_state(vec![make_node("n1", "Node Alpha")], transport.clone()); + let state = make_state(vec![make_node("n1", "Worker Node A")], transport.clone()); let result = handle_node_run( &state, - &serde_json::json!({"node": "Node Alpha", "command": "ls"}), + &serde_json::json!({"node": "Worker Node A", "command": "ls"}), ) .await .expect("should resolve by name"); @@ -281,11 +281,11 @@ mod tests { #[tokio::test] async fn resolves_node_name_case_insensitive() { let transport = Arc::new(MockTransport::succeeding("ok\n")); - let state = make_state(vec![make_node("n1", "MacBook Pro")], transport.clone()); + let state = make_state(vec![make_node("n1", "Worker Node B")], transport.clone()); let result = handle_node_run( &state, - &serde_json::json!({"node": "macbook pro", "command": "ls"}), + &serde_json::json!({"node": "worker node b", 
"command": "ls"}), ) .await; diff --git a/engine/crates/fx-tools/src/tools.rs b/engine/crates/fx-tools/src/tools.rs index 20b5f913..fb91067e 100644 --- a/engine/crates/fx-tools/src/tools.rs +++ b/engine/crates/fx-tools/src/tools.rs @@ -19,6 +19,7 @@ use fx_kernel::{ListEntry, ProcessConfig, ProcessRegistry, SpawnResult, StatusRe use fx_llm::{ToolCall, ToolDefinition}; use fx_memory::embedding_index::EmbeddingIndex; use fx_propose::{build_proposal_content, current_file_hash, Proposal, ProposalWriter}; +use fx_ripcord::git_guard::{check_push_allowed, extract_push_targets}; use fx_subagent::{ SpawnConfig, SpawnMode, SubagentControl, SubagentHandle, SubagentId, SubagentStatus, }; @@ -82,6 +83,7 @@ pub struct FawxToolExecutor { self_modify: Option, concurrency_policy: ConcurrencyPolicy, config_manager: Option>>, + protected_branches: Vec, kernel_budget: KernelBudgetConfig, start_time: std::time::Instant, subagent_control: Option>, @@ -133,6 +135,7 @@ impl FawxToolExecutor { self_modify: None, concurrency_policy: ConcurrencyPolicy::default(), config_manager: None, + protected_branches: Vec::new(), kernel_budget: KernelBudgetConfig::default(), start_time: std::time::Instant::now(), subagent_control: None, @@ -183,6 +186,12 @@ impl FawxToolExecutor { self } + #[must_use] + pub fn with_protected_branches(mut self, protected_branches: Vec) -> Self { + self.protected_branches = protected_branches; + self + } + /// Attach the active kernel budget configuration. pub fn with_kernel_budget(mut self, budget: KernelBudgetConfig) -> Self { self.kernel_budget = budget; @@ -530,6 +539,7 @@ impl FawxToolExecutor { return Err("command cannot be empty".to_string()); } let working_dir = self.resolve_command_dir(parsed.working_dir.as_deref())?; + self.guard_push_command(command)?; let child = build_command(command, parsed.shell.unwrap_or(false), &working_dir)? 
.stdout(Stdio::piped()) .stderr(Stdio::piped()) @@ -542,6 +552,7 @@ impl FawxToolExecutor { fn handle_exec_background(&self, args: &serde_json::Value) -> Result { let parsed: ExecBackgroundArgs = parse_args(args)?; let working_dir = self.resolve_command_dir(parsed.working_dir.as_deref())?; + self.guard_push_command(&parsed.command)?; let result = self .process_registry .spawn(parsed.command, working_dir, parsed.label)?; @@ -570,6 +581,14 @@ impl FawxToolExecutor { })) } + fn guard_push_command(&self, command: &str) -> Result<(), String> { + let targets = extract_push_targets(command); + if targets.is_empty() { + return Ok(()); + } + check_push_allowed(&targets, &self.protected_branches) + } + fn resolve_command_dir(&self, requested: Option<&str>) -> Result { let desired = requested.unwrap_or_else(|| self.working_dir.to_str().unwrap_or(".")); if !self.config.jail_to_working_dir { @@ -2626,6 +2645,45 @@ mod tests { serde_json::from_str(output).expect("valid json output") } + fn executor_with_protected_branches(root: &Path, branches: &[&str]) -> FawxToolExecutor { + test_executor(root).with_protected_branches( + branches + .iter() + .map(|branch| (*branch).to_string()) + .collect(), + ) + } + + fn run_git_ok(repo: &Path, args: &[&str]) { + let output = std::process::Command::new("git") + .args(args) + .current_dir(repo) + .output() + .expect("git command should run in tests"); + assert!( + output.status.success(), + "git {:?} failed: {}", + args, + String::from_utf8_lossy(&output.stderr) + ); + } + + fn init_push_repo() -> (TempDir, TempDir) { + let repo = TempDir::new().expect("repo tempdir"); + let remote = TempDir::new().expect("remote tempdir"); + let remote_path = remote.path().to_str().expect("utf8 remote path"); + run_git_ok(repo.path(), &["init"]); + run_git_ok(repo.path(), &["config", "user.email", "test@test.com"]); + run_git_ok(repo.path(), &["config", "user.name", "Test"]); + run_git_ok(remote.path(), &["init", "--bare"]); + run_git_ok(repo.path(), 
&["remote", "add", "origin", remote_path]); + fs::write(repo.path().join("file.txt"), "data\n").expect("write seed file"); + run_git_ok(repo.path(), &["add", "file.txt"]); + run_git_ok(repo.path(), &["commit", "-m", "initial"]); + run_git_ok(repo.path(), &["checkout", "-b", "dev"]); + (repo, remote) + } + fn test_executor_with_subagents(root: &Path) -> FawxToolExecutor { test_executor_with_control(root, Arc::new(StubSubagentControl::new())) } @@ -3415,6 +3473,47 @@ three assert!(output.is_err()); } + #[tokio::test] + async fn run_command_blocks_push_to_protected_branch() { + let temp = TempDir::new().expect("temp"); + let executor = executor_with_protected_branches(temp.path(), &["main"]); + + let error = executor + .handle_run_command(&serde_json::json!({"command": "git push origin main"})) + .await + .expect_err("protected push should be blocked"); + + assert!(error.contains("protected branch(es) 'main'")); + } + + #[tokio::test] + async fn run_command_allows_push_to_unprotected_branch() { + let (repo, remote) = init_push_repo(); + let executor = executor_with_protected_branches(repo.path(), &["main"]); + + let output = executor + .handle_run_command(&serde_json::json!({"command": "git push origin dev"})) + .await + .expect("unprotected push should execute"); + + assert!( + output.contains("exit_code: 0"), + "unexpected output: {output}" + ); + let remote_path = remote.path().to_str().expect("utf8 remote path"); + let status = std::process::Command::new("git") + .args([ + "--git-dir", + remote_path, + "show-ref", + "--verify", + "refs/heads/dev", + ]) + .output() + .expect("verify remote ref"); + assert!(status.status.success(), "remote dev branch should exist"); + } + #[tokio::test] async fn exec_background_returns_session_id_and_status() { let temp = TempDir::new().expect("temp"); @@ -3504,6 +3603,18 @@ three assert!(!error.is_empty()); } + #[test] + fn exec_background_blocks_push_to_protected_branch() { + let temp = TempDir::new().expect("temp"); + let 
executor = executor_with_protected_branches(temp.path(), &["main"]); + + let error = executor + .handle_exec_background(&serde_json::json!({"command": "git push origin main"})) + .expect_err("protected push should be blocked"); + + assert!(error.contains("protected branch(es) 'main'")); + } + #[test] fn search_text_finds_pattern_with_file_and_line() { let temp = TempDir::new().expect("temp"); diff --git a/tui/src/embedded_backend.rs b/tui/src/embedded_backend.rs index 0010e187..5758b731 100644 --- a/tui/src/embedded_backend.rs +++ b/tui/src/embedded_backend.rs @@ -151,7 +151,13 @@ fn handle_stream_event( is_error, } => { complete_experiment_tool(active_experiments, experiment_panel, &id); - send_tool_result(tx, output, is_error); + if !is_error { + send_tool_result(tx, None, output, true); + } + } + StreamEvent::ToolError { tool_name, error } => { + tracing::warn!(tool = %tool_name, "tool error in embedded mode: {error}"); + send_tool_result(tx, Some(tool_name), error, false); } StreamEvent::Error { message, .. 
} => { tracing::warn!("stream error in embedded mode: {message}"); @@ -238,12 +244,17 @@ fn send_tool_call_complete(tx: &UnboundedSender, name: String, arg ); } -fn send_tool_result(tx: &UnboundedSender, output: String, is_error: bool) { +fn send_tool_result( + tx: &UnboundedSender, + name: Option, + output: String, + success: bool, +) { try_send( tx, BackendEvent::ToolResult { - name: None, - success: !is_error, + name, + success, content: output, }, ); @@ -615,7 +626,7 @@ mod tests { } #[tokio::test] - async fn handle_stream_event_maps_tool_result_to_backend_tool_result() { + async fn handle_stream_event_maps_successful_tool_result_to_backend_tool_result() { let (tx, mut rx) = unbounded_channel(); let saw_text_delta = AtomicBool::new(false); let experiment_panel = test_experiment_panel(); @@ -628,8 +639,8 @@ mod tests { &active_experiments, StreamEvent::ToolResult { id: "call-1".to_string(), - output: "denied".to_string(), - is_error: true, + output: "file contents".to_string(), + is_error: false, }, ); @@ -640,6 +651,38 @@ mod tests { content, } => { assert!(name.is_none()); + assert!(success); + assert_eq!(content, "file contents"); + } + other => panic!("unexpected event: {other:?}"), + } + } + + #[tokio::test] + async fn handle_stream_event_maps_tool_error_to_backend_tool_result() { + let (tx, mut rx) = unbounded_channel(); + let saw_text_delta = AtomicBool::new(false); + let experiment_panel = test_experiment_panel(); + let active_experiments = test_active_experiments(); + + handle_stream_event( + &tx, + &saw_text_delta, + &experiment_panel, + &active_experiments, + StreamEvent::ToolError { + tool_name: "read".to_string(), + error: "denied".to_string(), + }, + ); + + match recv_event(&mut rx).await { + BackendEvent::ToolResult { + name, + success, + content, + } => { + assert_eq!(name.as_deref(), Some("read")); assert!(!success); assert_eq!(content, "denied"); } diff --git a/tui/src/fawx_backend.rs b/tui/src/fawx_backend.rs index 0d5ffd3a..7db3857b 100644 
--- a/tui/src/fawx_backend.rs +++ b/tui/src/fawx_backend.rs @@ -113,6 +113,13 @@ struct ToolResultData { is_error: bool, } +/// Data payload for `tool_error` events. +#[derive(Deserialize)] +struct ToolErrorData { + tool_name: String, + error: String, +} + /// Data payload for `done` events. #[derive(Deserialize)] struct DoneData { @@ -534,13 +541,29 @@ fn handle_tool_call_complete(data: &str, tx: &UnboundedSender) -> Ok(()) } +fn handle_tool_error(data: &str, tx: &UnboundedSender) -> anyhow::Result<()> { + let d: ToolErrorData = serde_json::from_str(data).context("decode tool_error")?; + try_send( + tx, + BackendEvent::ToolResult { + name: Some(d.tool_name), + success: false, + content: d.error, + }, + ); + Ok(()) +} + fn handle_tool_result(data: &str, tx: &UnboundedSender) -> anyhow::Result<()> { let d: ToolResultData = serde_json::from_str(data).context("decode tool_result")?; + if d.is_error { + return Ok(()); + } try_send( tx, BackendEvent::ToolResult { name: d.id, - success: !d.is_error, + success: true, content: d.output.unwrap_or_default(), }, ); @@ -600,6 +623,7 @@ fn dispatch_sse_frame( "tool_call_start" => handle_tool_call_start(&sse.data, tx)?, "tool_call_complete" => handle_tool_call_complete(&sse.data, tx)?, "tool_result" => handle_tool_result(&sse.data, tx)?, + "tool_error" => handle_tool_error(&sse.data, tx)?, "done" => handle_done(&sse.data, tx, *saw_text_delta)?, "phase" => { /* Phase changes are informational; TUI doesn't need them yet. */ } "error" => handle_error(&sse.data, tx)?, @@ -860,7 +884,7 @@ model = "gpt-4" } #[test] - fn dispatch_tool_result_error_maps_is_error_to_success_false() { + fn dispatch_tool_result_error_is_ignored_in_favor_of_tool_error_event() { let (tx, mut rx) = unbounded_channel(); let mut saw = false; dispatch_sse_frame( @@ -869,10 +893,10 @@ model = "gpt-4" &mut saw, ) .expect("should decode"); - match rx.try_recv().expect("event") { - BackendEvent::ToolResult { success, .. 
} => assert!(!success), - other => panic!("unexpected: {other:?}"), - } + assert!( + rx.try_recv().is_err(), + "error tool_result should be skipped" + ); } #[test]