diff --git a/.gitignore b/.gitignore index c3620dd..063311f 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,26 @@ model.onnx tokenizer.json # Config file (use eidos.toml.example as template) -eidos.toml \ No newline at end of file +eidos.toml + +# IDE files +.idea/ +.vscode/ +*.swp +*.swo +*~ +.DS_Store +Thumbs.db + +# Sensitive files +.env +.env.local +credentials.json +*_secret* + +# Logs +*.log +logs/ + +# Generated docs +/target/doc/ \ No newline at end of file diff --git a/lib_chat/src/api.rs b/lib_chat/src/api.rs index 22aef4c..78ee919 100644 --- a/lib_chat/src/api.rs +++ b/lib_chat/src/api.rs @@ -6,6 +6,10 @@ use serde::{Deserialize, Serialize}; use std::env; use std::time::Duration; +// Default timeouts (can be overridden via environment variables) +const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30; +const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10; + #[derive(Debug, Clone)] pub enum ApiProvider { OpenAI { @@ -108,20 +112,31 @@ pub struct ApiClient { } impl ApiClient { - pub fn new(provider: ApiProvider) -> Self { - // Create HTTP client with timeout to prevent hanging requests + pub fn new(provider: ApiProvider) -> Result { + // Get timeout values from environment variables or use defaults + let request_timeout = env::var("HTTP_REQUEST_TIMEOUT_SECS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS); + + let connect_timeout = env::var("HTTP_CONNECT_TIMEOUT_SECS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_CONNECT_TIMEOUT_SECS); + + // Create HTTP client with configurable timeouts to prevent hanging requests let client = Client::builder() - .timeout(Duration::from_secs(30)) // 30 second timeout - .connect_timeout(Duration::from_secs(10)) // 10 second connection timeout + .timeout(Duration::from_secs(request_timeout)) + .connect_timeout(Duration::from_secs(connect_timeout)) .build() - .expect("Failed to build HTTP client"); + .map_err(|e| ChatError::ApiError(format!("Failed to build HTTP client: {}", e)))?; - Self { provider, client } + Ok(Self { provider, client }) } pub fn from_env() -> Result { let provider = ApiProvider::from_env()?; - Ok(Self::new(provider)) + Self::new(provider) } pub async fn send_message( diff --git a/lib_chat/src/error.rs b/lib_chat/src/error.rs index 4efcbbb..f777d0d 100644 --- a/lib_chat/src/error.rs +++ b/lib_chat/src/error.rs @@ -26,6 +26,9 @@ pub enum ChatError { #[error("Environment variable not set: {0}")] EnvError(String), + + #[error("Invalid input: {0}")] + InvalidInput(String), } pub type Result = std::result::Result; diff --git a/lib_chat/src/history.rs b/lib_chat/src/history.rs index 2f7c324..ed1f24c 100644 --- a/lib_chat/src/history.rs +++ b/lib_chat/src/history.rs @@ -42,36 +42,77 @@ impl Message { pub struct ConversationHistory { messages: Vec, max_messages: usize, + max_bytes_total: usize, // Max total memory for all messages + max_bytes_per_message: usize, // Max size for a single message } impl ConversationHistory { pub fn new(max_messages: usize) -> Self { + Self::new_with_limits( + max_messages, + 10 * 1024 * 1024, // 10MB total by default + 1 * 1024 * 1024, // 1MB per message by default + ) + } + + pub fn new_with_limits( + max_messages: usize, + max_bytes_total: usize, + max_bytes_per_message: usize, + ) -> Self { Self { messages: Vec::new(), max_messages, + max_bytes_total, + max_bytes_per_message, } } - pub fn add_message(&mut self, message: Message) { + /// Calculate total byte size of all messages + fn total_bytes(&self) -> usize { + self.messages + .iter() + .map(|m| m.content.len()) + .sum() + } + + pub fn add_message(&mut self, message: Message) -> Result<(), String> { + // Check individual message size + let message_bytes = message.content.len(); + if message_bytes > self.max_bytes_per_message { + return Err(format!( + "Message too large: {} bytes (max {} bytes)", + message_bytes, self.max_bytes_per_message + )); + } + self.messages.push(message); - // Keep only the most recent messages + // Keep only the most recent messages by count if self.messages.len() > self.max_messages { let start = self.messages.len() - self.max_messages; self.messages.drain(0..start); } + + // Keep only the most recent messages by total size + while self.total_bytes() > self.max_bytes_total && self.messages.len() > 1 { + // Remove oldest message + self.messages.remove(0); + } + + Ok(()) } - pub fn add_user_message(&mut self, content: impl Into) { - self.add_message(Message::user(content)); + pub fn add_user_message(&mut self, content: impl Into) -> Result<(), String> { + self.add_message(Message::user(content)) } - pub fn add_assistant_message(&mut self, content: impl Into) { - self.add_message(Message::assistant(content)); + pub fn add_assistant_message(&mut self, content: impl Into) -> Result<(), String> { + self.add_message(Message::assistant(content)) } - pub fn add_system_message(&mut self, content: impl Into) { - self.add_message(Message::system(content)); + pub fn add_system_message(&mut self, content: impl Into) -> Result<(), String> { + self.add_message(Message::system(content)) } pub fn messages(&self) -> &[Message] { @@ -116,14 +157,14 @@ mod tests { fn test_conversation_history() { let mut history = ConversationHistory::new(3); - history.add_user_message("Message 1"); - history.add_assistant_message("Response 1"); - history.add_user_message("Message 2"); + history.add_user_message("Message 1").unwrap(); + history.add_assistant_message("Response 1").unwrap(); + history.add_user_message("Message 2").unwrap(); assert_eq!(history.len(), 3); // Adding more messages should drop oldest - history.add_assistant_message("Response 2"); + history.add_assistant_message("Response 2").unwrap(); assert_eq!(history.len(), 3); assert_eq!(history.messages()[0].content, "Response 1"); } @@ -131,10 +172,36 @@ mod tests { #[test] fn test_clear_history() { let mut history = ConversationHistory::new(10); - history.add_user_message("Test"); + history.add_user_message("Test").unwrap(); assert!(!history.is_empty()); history.clear(); assert!(history.is_empty()); } + + #[test] + fn test_message_size_limit() { + let mut history = ConversationHistory::new_with_limits(10, 1000, 100); + + // Message within limit should succeed + assert!(history.add_user_message("x".repeat(50)).is_ok()); + + // Message exceeding limit should fail + let result = history.add_user_message("x".repeat(150)); + assert!(result.is_err()); + } + + #[test] + fn test_total_size_limit() { + let mut history = ConversationHistory::new_with_limits(10, 200, 100); + + // Add messages that together exceed total limit + history.add_user_message("x".repeat(80)).unwrap(); + history.add_user_message("x".repeat(80)).unwrap(); + history.add_user_message("x".repeat(80)).unwrap(); + + // Should have dropped old messages to stay under total limit + assert!(history.total_bytes() <= 200); + assert!(history.len() < 3); + } } diff --git a/lib_chat/src/lib.rs b/lib_chat/src/lib.rs index aee2673..9ebf33e 100644 --- a/lib_chat/src/lib.rs +++ b/lib_chat/src/lib.rs @@ -12,8 +12,17 @@ use tokio::runtime::Runtime; /// /// Creating a new Runtime on every request is expensive (~10-50ms overhead). /// This static runtime is created once and reused for all chat operations. -static RUNTIME: Lazy = - Lazy::new(|| Runtime::new().expect("Failed to create tokio runtime")); +/// +/// # Panics +/// Will panic if the tokio runtime cannot be created. This is a critical failure +/// that indicates system resource exhaustion or misconfiguration. +static RUNTIME: Lazy = Lazy::new(|| { + Runtime::new().expect( + "FATAL: Failed to create tokio runtime. \ + This likely indicates system resource exhaustion. \ + Check available memory and file descriptors.", + ) +}); pub struct Chat { client: Option, @@ -34,11 +43,11 @@ impl Chat { } /// Create a Chat instance with a specific provider - pub fn with_provider(provider: ApiProvider) -> Self { - Self { - client: Some(ApiClient::new(provider)), + pub fn with_provider(provider: ApiProvider) -> Result { + Ok(Self { + client: Some(ApiClient::new(provider)?), history: ConversationHistory::default(), - } + }) } /// Send a message and get a response (async) @@ -49,7 +58,9 @@ impl Chat { .ok_or_else(|| error::ChatError::NoProviderError)?; // Add user message to history - self.history.add_user_message(message); + self.history + .add_user_message(message) + .map_err(|e| error::ChatError::InvalidInput(e))?; // Send to API with full conversation history let response = client @@ -57,7 +68,9 @@ impl Chat { .await?; // Add assistant response to history - self.history.add_assistant_message(&response); + self.history + .add_assistant_message(&response) + .map_err(|e| error::ChatError::InvalidInput(e))?; Ok(response) } @@ -73,8 +86,10 @@ impl Chat { } /// Add a system message to guide the conversation - pub fn set_system_prompt(&mut self, prompt: &str) { - self.history.add_system_message(prompt); + pub fn set_system_prompt(&mut self, prompt: &str) -> Result<()> { + self.history + .add_system_message(prompt) + .map_err(|e| error::ChatError::InvalidInput(e)) } /// Clear conversation history diff --git a/lib_core/src/tract_llm.rs b/lib_core/src/tract_llm.rs index e6f6fe1..3ba5792 100644 --- a/lib_core/src/tract_llm.rs +++ b/lib_core/src/tract_llm.rs @@ -46,13 +46,51 @@ impl Core { pub fn is_safe_command(&self, command: &str) -> bool { is_safe_command(command) } + + /// Generates an explanation for what a command does + /// + /// This helps users understand generated commands before executing them. + /// The explanation describes the command's purpose, flags used, and potential side effects. + /// + /// # Example + /// ```ignore + /// let explanation = core.explain_command("ls -la")?; + /// // Returns: "Lists all files in long format, including hidden files" + /// ``` + pub fn explain_command(&self, command: &str) -> TractResult { + let prompt = format!("Explain what this command does: {}", command); + + let encoding = self.tokenizer.encode(prompt.as_str(), true).map_err(|e| anyhow!(e))?; + let input_ids: Vec = encoding.get_ids().iter().map(|&id| id as i64).collect(); + let input_tensor = arr1(&input_ids).into_dyn().into_tensor(); + + let result = self.model.run(tvec!(input_tensor.into()))?; + + let output_tensor = result[0].to_array_view::()?; + let output_ids: Vec = output_tensor.iter().map(|&id| id as u32).collect(); + + let explanation = self + .tokenizer + .decode(&output_ids, true) + .map_err(|e| anyhow!(e))?; + + Ok(explanation) + } } impl Default for Core { + /// Create Core with default paths + /// + /// # Panics + /// Panics if the default model files ("model.onnx", "tokenizer.json") cannot be loaded. + /// This is intentional for Default trait - use Core::new() directly for error handling. fn default() -> Self { let model_path = "model.onnx"; let tokenizer_path = "tokenizer.json"; - Core::new(model_path, tokenizer_path).expect("Failed to create Core instance") + Core::new(model_path, tokenizer_path).expect( + "FATAL: Failed to load default Core model. \ + Ensure 'model.onnx' and 'tokenizer.json' exist in the current directory.", + ) } } diff --git a/lib_core/src/validation.rs b/lib_core/src/validation.rs index eb9f3ab..97bfd1c 100644 --- a/lib_core/src/validation.rs +++ b/lib_core/src/validation.rs @@ -37,11 +37,11 @@ /// - `tests/` for comprehensive security test suite pub fn is_safe_command(command: &str) -> bool { // Whitelist of safe base commands that are read-only and don't modify system state. - // DO NOT add write commands. See SAFETY.md for rationale. + // DO NOT add write commands (including touch/mkdir). See SAFETY.md for rationale. + // Even "safe" write operations are excluded to maintain strict read-only policy. let allowed_commands = [ "ls", "pwd", "echo", "cat", "head", "tail", "grep", "find", "wc", "date", "whoami", "hostname", "uname", "df", "du", "free", "top", "ps", "which", "whereis", "file", "stat", - "touch", "mkdir", ]; // Dangerous patterns that should never be allowed @@ -116,8 +116,8 @@ pub fn is_safe_command(command: &str) -> bool { return false; } - // Check if command starts with an allowed command - let first_word = cmd_trimmed.split_whitespace().next().unwrap_or(""); + // Check if command starts with an allowed command (case-insensitive) + let first_word = cmd_lower.split_whitespace().next().unwrap_or(""); if !allowed_commands.contains(&first_word) { return false; } diff --git a/lib_translate/src/error.rs b/lib_translate/src/error.rs index 81f71ff..dc7a2c7 100644 --- a/lib_translate/src/error.rs +++ b/lib_translate/src/error.rs @@ -23,6 +23,9 @@ pub enum TranslateError { #[error("No translator configured")] NoTranslatorError, + + #[error("Configuration error: {0}")] + ConfigError(String), } pub type Result = std::result::Result; diff --git a/lib_translate/src/lib.rs b/lib_translate/src/lib.rs index 0b561fd..dc70a90 100644 --- a/lib_translate/src/lib.rs +++ b/lib_translate/src/lib.rs @@ -12,8 +12,17 @@ use tokio::runtime::Runtime; /// /// Creating a new Runtime on every request is expensive (~10-50ms overhead). /// This static runtime is created once and reused for all translation operations. -static RUNTIME: Lazy = - Lazy::new(|| Runtime::new().expect("Failed to create tokio runtime")); +/// +/// # Panics +/// Will panic if the tokio runtime cannot be created. This is a critical failure +/// that indicates system resource exhaustion or misconfiguration. +static RUNTIME: Lazy = Lazy::new(|| { + Runtime::new().expect( + "FATAL: Failed to create tokio runtime. \ + This likely indicates system resource exhaustion. \ + Check available memory and file descriptors.", + ) +}); pub struct Translate { translator: Option, @@ -29,17 +38,17 @@ impl Translate { ); // Use mock translator as fallback return Self { - translator: Some(Translator::new(TranslatorProvider::Mock)), + translator: Translator::new(TranslatorProvider::Mock).ok(), }; } Self { translator } } /// Create a Translate instance with a specific provider - pub fn with_provider(provider: TranslatorProvider) -> Self { - Self { - translator: Some(Translator::new(provider)), - } + pub fn with_provider(provider: TranslatorProvider) -> Result { + Ok(Self { + translator: Some(Translator::new(provider)?), + }) } /// Detect language and translate if needed diff --git a/lib_translate/src/translator.rs b/lib_translate/src/translator.rs index d9d5f28..509962b 100644 --- a/lib_translate/src/translator.rs +++ b/lib_translate/src/translator.rs @@ -5,6 +5,10 @@ use serde::{Deserialize, Serialize}; use std::env; use std::time::Duration; +// Default timeouts (can be overridden via environment variables) +const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30; +const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10; + #[derive(Debug, Clone)] pub enum TranslatorProvider { LibreTranslate { @@ -17,17 +21,20 @@ pub enum TranslatorProvider { impl TranslatorProvider { /// Load translator from environment variables pub fn from_env() -> Result { - // Check for LibreTranslate configuration - if let Ok(url) = env::var("LIBRETRANSLATE_URL") { - let api_key = env::var("LIBRETRANSLATE_API_KEY").ok(); - return Ok(TranslatorProvider::LibreTranslate { url, api_key }); - } - - // Default to public LibreTranslate instance (with limitations) - Ok(TranslatorProvider::LibreTranslate { - url: "https://libretranslate.com".to_string(), - api_key: None, - }) + // Require explicit LibreTranslate configuration for security + let url = env::var("LIBRETRANSLATE_URL").map_err(|_| { + TranslateError::ConfigError( + "Translation service not configured. Set LIBRETRANSLATE_URL environment variable.\n\ + Options:\n\ + 1. Self-hosted: export LIBRETRANSLATE_URL=http://localhost:5000\n\ + 2. Public API: export LIBRETRANSLATE_URL=https://libretranslate.com\n\ + (Note: Public API has rate limits and may require an API key)\n\ + 3. With API key: export LIBRETRANSLATE_API_KEY=your_api_key".to_string(), + ) + })?; + + let api_key = env::var("LIBRETRANSLATE_API_KEY").ok(); + Ok(TranslatorProvider::LibreTranslate { url, api_key }) } } @@ -59,20 +66,31 @@ pub struct Translator { } impl Translator { - pub fn new(provider: TranslatorProvider) -> Self { - // Create HTTP client with timeout to prevent hanging requests + pub fn new(provider: TranslatorProvider) -> Result { + // Get timeout values from environment variables or use defaults + let request_timeout = env::var("HTTP_REQUEST_TIMEOUT_SECS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS); + + let connect_timeout = env::var("HTTP_CONNECT_TIMEOUT_SECS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_CONNECT_TIMEOUT_SECS); + + // Create HTTP client with configurable timeouts to prevent hanging requests let client = Client::builder() - .timeout(Duration::from_secs(30)) // 30 second timeout - .connect_timeout(Duration::from_secs(10)) // 10 second connection timeout + .timeout(Duration::from_secs(request_timeout)) + .connect_timeout(Duration::from_secs(connect_timeout)) .build() - .expect("Failed to build HTTP client"); + .map_err(|e| TranslateError::ApiError(format!("Failed to build HTTP client: {}", e)))?; - Self { provider, client } + Ok(Self { provider, client }) } pub fn from_env() -> Result { let provider = TranslatorProvider::from_env()?; - Ok(Self::new(provider)) + Self::new(provider) } pub async fn translate( @@ -170,7 +188,7 @@ mod tests { #[tokio::test] async fn test_mock_translator() { - let translator = Translator::new(TranslatorProvider::Mock); + let translator = Translator::new(TranslatorProvider::Mock).unwrap(); let result = translator.translate("Hello", "en", "es").await.unwrap(); assert!(result.contains("Hello")); assert!(result.contains("en")); @@ -179,7 +197,7 @@ mod tests { #[tokio::test] async fn test_translate_to_english_same_language() { - let translator = Translator::new(TranslatorProvider::Mock); + let translator = Translator::new(TranslatorProvider::Mock).unwrap(); let result = translator .translate_to_english("Hello", "en") .await diff --git a/src/config.rs b/src/config.rs index 53372e6..1068879 100644 --- a/src/config.rs +++ b/src/config.rs @@ -69,22 +69,100 @@ impl Config { }) } - /// Validate that the configured paths exist + /// Validate that the configured paths exist and are safe to use pub fn validate(&self) -> Result<(), String> { - if !self.model_path.exists() { + // Validate model path + Self::validate_file_path(&self.model_path, "Model", 2 * 1024 * 1024 * 1024)?; // 2GB max + + // Validate tokenizer path + Self::validate_file_path(&self.tokenizer_path, "Tokenizer", 100 * 1024 * 1024)?; // 100MB max + + Ok(()) + } + + /// Validate a file path for security and safety + fn validate_file_path(path: &PathBuf, file_type: &str, max_size: u64) -> Result<(), String> { + // Check if file exists + if !path.exists() { + return Err(format!("{} file not found: {}", file_type, path.display())); + } + + // Canonicalize path to resolve symlinks and check for path traversal + let canonical_path = path.canonicalize().map_err(|e| { + format!( + "Failed to resolve {} path {}: {}", + file_type, + path.display(), + e + ) + })?; + + // Check if path contains suspicious patterns (after canonicalization) + let path_str = canonical_path.to_string_lossy(); + if path_str.contains("..") { return Err(format!( - "Model file not found: {}", - self.model_path.display() + "{} path contains suspicious patterns: {}", + file_type, + path.display() )); } - if !self.tokenizer_path.exists() { + // Get file metadata + let metadata = fs::metadata(&canonical_path).map_err(|e| { + format!( + "Failed to read {} file metadata: {}", + file_type, e + ) + })?; + + // Check if it's a regular file (not directory or other special file) + if !metadata.is_file() { return Err(format!( - "Tokenizer file not found: {}", - self.tokenizer_path.display() + "{} path is not a regular file: {}", + file_type, + path.display() )); } + // Check file size is reasonable + let file_size = metadata.len(); + if file_size > max_size { + return Err(format!( + "{} file too large: {} bytes (max {} bytes)", + file_type, file_size, max_size + )); + } + + if file_size == 0 { + return Err(format!("{} file is empty: {}", file_type, path.display())); + } + + // Check file permissions (must be readable) + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let permissions = metadata.permissions(); + let mode = permissions.mode(); + + // Check if file is readable by user (owner) + if mode & 0o400 == 0 { + return Err(format!( + "{} file is not readable: {}", + file_type, + path.display() + )); + } + + // Warn if file is world-readable with write permissions + if mode & 0o002 != 0 { + eprintln!( + "⚠️ Warning: {} file is world-writable: {}", + file_type, + path.display() + ); + } + } + Ok(()) } } diff --git a/src/constants.rs b/src/constants.rs index 53ce69d..3323d69 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,33 +1,7 @@ // Global constants for Eidos CLI // Centralizes magic numbers and configuration values for easier maintenance -/// Input validation limits +/// Input validation limits (actively used) pub const MAX_CHAT_INPUT_LENGTH: usize = 10_000; pub const MAX_CORE_PROMPT_LENGTH: usize = 1_000; pub const MAX_TRANSLATE_INPUT_LENGTH: usize = 5_000; - -/// HTTP client timeouts -pub const API_REQUEST_TIMEOUT_SECS: u64 = 30; -pub const API_CONNECT_TIMEOUT_SECS: u64 = 10; - -/// Chat history configuration -pub const DEFAULT_MAX_CONVERSATION_MESSAGES: usize = 50; - -/// Language detection configuration -pub const LANGUAGE_DETECTION_CONFIDENCE_THRESHOLD: f64 = 0.25; - -/// Model inference configuration -pub const SEED_FOR_REPRODUCIBILITY: u64 = 299792458; // Speed of light in m/s - -/// Application metadata -pub const APP_VERSION: &str = "0.2.0-beta"; -pub const APP_NAME: &str = "Eidos"; -pub const APP_DESCRIPTION: &str = "AI-powered CLI for Linux - Natural language to shell commands"; - -/// Cache configuration (for future use) -pub const DEFAULT_CACHE_SIZE: usize = 1000; -pub const DEFAULT_CACHE_TTL_HOURS: u64 = 24; - -/// Performance tuning -pub const VALIDATION_PATTERNS_CAPACITY: usize = 64; -pub const HISTORY_BUFFER_CAPACITY: usize = 100; diff --git a/src/main.rs b/src/main.rs index 1ee2391..b76278d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ mod config; mod constants; mod error; -mod output; use crate::config::Config; use crate::constants::*; @@ -104,9 +103,6 @@ struct Cli { #[clap(short, long, global = true, help = "Enable debug logging")] debug: bool, - - #[clap(short = 'o', long, global = true, value_name = "FORMAT", help = "Output format: text (default) or json")] - output_format: Option, } #[derive(Subcommand, Debug)] @@ -134,6 +130,23 @@ enum Commands { }, } +/// Sanitize sensitive text for logging by truncating and masking +/// +/// This prevents sensitive information from being exposed in debug logs. +/// Only logs first 50 characters and masks the rest. +fn sanitize_for_logging(text: &str, max_chars: usize) -> String { + let char_count = text.chars().count(); + if char_count <= max_chars { + format!("{}... ({} chars)", text.chars().take(max_chars).collect::(), char_count) + } else { + format!( + "{}... [TRUNCATED] ({} chars total)", + text.chars().take(max_chars).collect::(), + char_count + ) + } +} + /// Validate input text for safety and sanity fn validate_input(text: &str, max_length: usize) -> std::result::Result<(), String> { // Check for empty input @@ -189,7 +202,7 @@ fn setup_bridge() -> Bridge { Request::Chat, Box::new(|text: &str| { info!("Processing chat request"); - debug!("Chat input: {}", text); + debug!("Chat input: {}", sanitize_for_logging(text, 50)); let mut chat = Chat::new(); match chat.run(text) { @@ -217,7 +230,7 @@ fn setup_bridge() -> Bridge { Request::Core, Box::new(|prompt: &str| { info!("Processing core command generation request"); - debug!("Prompt: {}", prompt); + debug!("Prompt: {}", sanitize_for_logging(prompt, 50)); // Load configuration debug!("Loading configuration"); @@ -301,7 +314,7 @@ fn setup_bridge() -> Bridge { Request::Translate, Box::new(|text: &str| { info!("Processing translation request"); - debug!("Translation input: {}", text); + debug!("Translation input: {}", sanitize_for_logging(text, 50)); let translate = Translate::new(); match translate.run(text) { @@ -361,7 +374,11 @@ fn main() -> Result<()> { crate::error::AppError::InvalidInput(e) }) } - Commands::Core { ref prompt, alternatives: _, explain: _ } => { + Commands::Core { + ref prompt, + alternatives, + explain, + } => { // Validate input (max 1000 chars for prompts) if let Err(e) = validate_input(prompt, MAX_CORE_PROMPT_LENGTH) { error!("Input validation failed: {}", e); @@ -369,11 +386,137 @@ fn main() -> Result<()> { return Err(crate::error::AppError::InvalidInput(e)); } - debug!("Routing to core handler"); - bridge.route(Request::Core, prompt).map_err(|e| { - error!("Core routing failed: {}", e); + // Handle Core command generation with alternatives and explain support + info!("Processing core command generation request"); + debug!("Prompt: {}", sanitize_for_logging(prompt, 50)); + debug!("Alternatives: {}, Explain: {}", alternatives, explain); + + // Load configuration + debug!("Loading configuration"); + let config = Config::load().map_err(|e| { + error!("Configuration loading failed: {}", e); + crate::error::AppError::InvalidInput(format!("Config error: {}", e)) + })?; + + // Validate configuration + config.validate().map_err(|e| { + error!("Configuration validation failed: {}", e); + eprintln!("❌ Configuration Error: {}", e); + eprintln!(); + eprintln!("To configure Eidos, choose one of:"); + eprintln!(" 1. Environment variables:"); + eprintln!(" export EIDOS_MODEL_PATH=/path/to/model.onnx"); + eprintln!(" export EIDOS_TOKENIZER_PATH=/path/to/tokenizer.json"); + eprintln!(); + eprintln!(" 2. Config file (./eidos.toml or ~/.config/eidos/eidos.toml):"); + eprintln!(" model_path = \"/path/to/model.onnx\""); + eprintln!(" tokenizer_path = \"/path/to/tokenizer.json\""); + eprintln!(); + eprintln!(" 3. See docs/MODEL_GUIDE.md for training your own model"); + crate::error::AppError::InvalidInput(e.to_string()) + })?; + + debug!("Configuration valid, loading model"); + + // Get Core instance from cache (or load if not cached) + let model_path_str = config + .model_path + .to_str() + .ok_or_else(|| { + crate::error::AppError::InvalidInput( + "Invalid model path encoding".to_string(), + ) + })?; + let tokenizer_path_str = config + .tokenizer_path + .to_str() + .ok_or_else(|| { + crate::error::AppError::InvalidInput( + "Invalid tokenizer path encoding".to_string(), + ) + })?; + + let core = get_or_load_model(model_path_str, tokenizer_path_str).map_err(|e| { + error!("Model loading failed: {}", e); crate::error::AppError::InvalidInput(e) - }) + })?; + + // Generate alternatives if requested + if alternatives > 1 { + info!("Generating {} alternative commands", alternatives); + match core.generate_alternatives(prompt, alternatives) { + Ok(commands) => { + println!("Generated {} alternatives:", commands.len()); + for (i, cmd) in commands.iter().enumerate() { + if core.is_safe_command(cmd) { + println!(" {}. {}", i + 1, cmd); + if explain { + if let Ok(explanation) = core.explain_command(cmd) { + println!(" → {}", explanation); + } + } + } else { + warn!("Alternative {} failed safety check: {}", i + 1, cmd); + } + } + info!("Alternatives generated successfully"); + Ok(()) + } + Err(e) => { + error!("Alternative generation failed: {}", e); + eprintln!("❌ Error: {}", e); + Err(crate::error::AppError::InvalidInput(e.to_string())) + } + } + } else { + // Generate single command + match core.generate_command(prompt) { + Ok(command) => { + // Validate that generated command is safe + if core.is_safe_command(&command) { + info!("Command generated and validated successfully"); + debug!("Generated command: {}", command); + println!("{}", command); + + // Add explanation if requested + if explain { + match core.explain_command(&command) { + Ok(explanation) => { + println!("\nExplanation: {}", explanation); + } + Err(e) => { + warn!("Failed to generate explanation: {}", e); + } + } + } + + Ok(()) + } else { + error!("Generated command failed safety validation"); + eprintln!("❌ Safety Error: Generated command is not safe to execute"); + eprintln!("Generated: {}", command); + eprintln!(); + eprintln!( + "The model generated a command that contains dangerous patterns." + ); + eprintln!("This is a safety feature to prevent harmful commands."); + Err(crate::error::AppError::InvalidInput( + "Generated command failed safety validation".to_string(), + )) + } + } + Err(e) => { + error!("Inference failed: {}", e); + eprintln!("❌ Error: {}", e); + eprintln!(); + eprintln!("This could be due to:"); + eprintln!(" - Invalid or corrupted model file"); + eprintln!(" - Incompatible model format"); + eprintln!(" - Prompt too long or malformed"); + Err(crate::error::AppError::InvalidInput(e.to_string())) + } + } + } } Commands::Translate { ref text } => { // Validate input (max 5000 chars for translation) diff --git a/src/output.rs b/src/output.rs deleted file mode 100644 index e9b755c..0000000 --- a/src/output.rs +++ /dev/null @@ -1,136 +0,0 @@ -// Output formatting module -use serde::Serialize; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum OutputFormat { - Text, - Json, -} - -impl OutputFormat { - pub fn from_str(s: &str) -> Option { - match s.to_lowercase().as_str() { - "text" | "plain" => Some(Self::Text), - "json" => Some(Self::Json), - _ => None, - } - } -} - -#[derive(Debug, Serialize)] -pub struct CommandResult { - pub prompt: String, - pub command: String, - pub safety_level: String, - pub is_safe: bool, - #[serde(skip_serializing_if = "Option::is_none")] - pub explanation: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub alternatives: Option>, -} - -impl CommandResult { - pub fn new(prompt: impl Into, command: impl Into, is_safe: bool) -> Self { - let is_safe = is_safe; - Self { - prompt: prompt.into(), - command: command.into(), - safety_level: if is_safe { "SAFE".to_string() } else { "UNSAFE".to_string() }, - is_safe, - explanation: None, - alternatives: None, - } - } - - pub fn with_explanation(mut self, explanation: impl Into) -> Self { - self.explanation = Some(explanation.into()); - self - } - - pub fn with_alternatives(mut self, alternatives: Vec) -> Self { - self.alternatives = Some(alternatives); - self - } - - pub fn to_json(&self) -> Result { - serde_json::to_string_pretty(self) - } - - pub fn to_text(&self) -> String { - let mut output = String::new(); - - if self.is_safe { - output.push_str(&format!("✅ {}\n", self.command)); - } else { - output.push_str(&format!("❌ {} (UNSAFE)\n", self.command)); - } - - if let Some(ref explanation) = self.explanation { - output.push_str(&format!("\nExplanation: {}\n", explanation)); - } - - if let Some(ref alternatives) = self.alternatives { - if !alternatives.is_empty() { - output.push_str("\nAlternatives:\n"); - for (i, alt) in alternatives.iter().enumerate() { - output.push_str(&format!(" {}. {}\n", i + 1, alt)); - } - } - } - - output - } -} - -#[derive(Debug, Serialize)] -pub struct ChatResult { - pub user_message: String, - pub assistant_message: String, -} - -impl ChatResult { - pub fn new(user_message: impl Into, assistant_message: impl Into) -> Self { - Self { - user_message: user_message.into(), - assistant_message: assistant_message.into(), - } - } - - pub fn to_json(&self) -> Result { - serde_json::to_string_pretty(self) - } - - pub fn to_text(&self) -> String { - format!("Assistant: {}", self.assistant_message) - } -} - -#[derive(Debug, Serialize)] -pub struct TranslationResultOutput { - pub detected_language: String, - pub target_language: String, - pub original_text: String, - pub translated_text: String, - pub was_translated: bool, -} - -impl TranslationResultOutput { - pub fn to_json(&self) -> Result { - serde_json::to_string_pretty(self) - } - - pub fn to_text(&self) -> String { - let mut output = String::new(); - output.push_str(&format!("Detected language: {}\n", self.detected_language)); - - if self.was_translated { - output.push_str(&format!("Original ({}): {}\n", self.detected_language, self.original_text)); - output.push_str(&format!("Translated ({}): {}\n", self.target_language, self.translated_text)); - } else { - output.push_str(&format!("Text is already in {}\n", self.target_language)); - output.push_str(&format!("Text: {}\n", self.original_text)); - } - - output - } -}