2 changes: 1 addition & 1 deletion refact-agent/engine/src/at_commands/at_knowledge.rs
@@ -38,7 +38,7 @@ impl AtCommand for AtLoadKnowledge {
let search_key = args.iter().map(|x| x.text.clone()).join(" ");
let gcx = ccx.lock().await.global_context.clone();

- let memories = memories_search(gcx, &search_key, 5).await?;
+ let memories = memories_search(gcx, &search_key, 5, 0).await?;
let mut seen_memids = HashSet::new();
let unique_memories: Vec<_> = memories.into_iter()
.filter(|m| seen_memids.insert(m.memid.clone()))
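Note, not part of the diff: the extra 0 passed to memories_search above corresponds to a new fourth argument, which the knowledge_enrichment module later in this PR fills with TRAJECTORY_TOP_N. A rough sketch of the assumed signature follows; the parameter names and return type are placeholders, not the actual definition in src/memories.rs.

// Assumed shape of the updated search API, inferred from the two call sites
// in this PR (the "5, 0" here and "KNOWLEDGE_TOP_N, TRAJECTORY_TOP_N" below).
// Names and return type are placeholders.
pub async fn memories_search(
    gcx: Arc<ARwLock<GlobalContext>>,
    query: &str,
    knowledge_top_n: usize,   // how many knowledge memories to return
    trajectory_top_n: usize,  // how many trajectory memos to return; 0 disables them
) -> Result<Vec<MemoRecord>, String> {
    // ... vector search over knowledge entries and trajectory memos ...
    unimplemented!()
}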
1 change: 1 addition & 0 deletions refact-agent/engine/src/background_tasks.rs
@@ -48,6 +48,7 @@ pub async fn start_background_tasks(gcx: Arc<ARwLock<GlobalContext>>, _config_di
tokio::spawn(crate::integrations::sessions::remove_expired_sessions_background_task(gcx.clone())),
tokio::spawn(crate::git::cleanup::git_shadow_cleanup_background_task(gcx.clone())),
tokio::spawn(crate::knowledge_graph::knowledge_cleanup_background_task(gcx.clone())),
+ tokio::spawn(crate::trajectory_memos::trajectory_memos_background_task(gcx.clone())),
]);
let ast = gcx.clone().read().await.ast_service.clone();
if let Some(ast_service) = ast {
4 changes: 2 additions & 2 deletions refact-agent/engine/src/file_filter.rs
@@ -6,9 +6,9 @@ use std::path::PathBuf;
const LARGE_FILE_SIZE_THRESHOLD: u64 = 4096*1024; // 4Mb files
const SMALL_FILE_SIZE_THRESHOLD: u64 = 5; // 5 Bytes

- pub const KNOWLEDGE_FOLDER_NAME: &str = ".refact_knowledge";
+ pub const KNOWLEDGE_FOLDER_NAME: &str = ".refact/knowledge";

- const ALLOWED_HIDDEN_FOLDERS: &[&str] = &[KNOWLEDGE_FOLDER_NAME];
+ const ALLOWED_HIDDEN_FOLDERS: &[&str] = &[".refact"];

pub const SOURCE_FILE_EXTENSIONS: &[&str] = &[
"c", "cpp", "cc", "h", "hpp", "cs", "java", "py", "rb", "go", "rs", "swift",
1 change: 1 addition & 0 deletions refact-agent/engine/src/http/routers/v1.rs
@@ -75,6 +75,7 @@ mod v1_integrations;
pub mod vecdb;
mod workspace;
mod knowledge_graph;
+ mod knowledge_enrichment;
pub mod trajectories;

pub fn make_v1_router() -> Router {
29 changes: 10 additions & 19 deletions refact-agent/engine/src/http/routers/v1/chat.rs
@@ -6,7 +6,7 @@ use axum::Extension;
use axum::response::Result;
use hyper::{Body, Response, StatusCode};

- use crate::call_validation::{ChatContent, ChatMessage, ChatPost};
+ use crate::call_validation::{ChatContent, ChatMessage, ChatPost, ChatMode};
use crate::caps::resolve_chat_model;
use crate::custom_error::ScratchError;
use crate::at_commands::at_commands::AtCommandsContext;
@@ -17,6 +17,8 @@ use crate::integrations::docker::docker_container_manager::docker_container_chec
use crate::tools::tools_description::ToolDesc;
use crate::tools::tools_list::get_available_tools_by_chat_mode;

+ use super::knowledge_enrichment::enrich_messages_with_knowledge;
+
pub const CHAT_TOP_N: usize = 12;

pub async fn handle_v1_chat_completions(
@@ -198,10 +200,11 @@ async fn _chat(
}
}

- // SYSTEM PROMPT WAS HERE
-
-
- // chat_post.stream = Some(false); // for debugging 400 errors that are hard to debug with streaming (because "data: " is not present and the error message is ignored by the library)
+ let mut pre_stream_messages: Option<Vec<serde_json::Value>> = None;
+ let last_is_user = messages.last().map(|m| m.role == "user").unwrap_or(false);
+ if chat_post.meta.chat_mode == ChatMode::AGENT && last_is_user {
+ pre_stream_messages = enrich_messages_with_knowledge(gcx.clone(), &mut messages).await;
+ }
let mut scratchpad = crate::scratchpads::create_chat_scratchpad(
gcx.clone(),
&mut chat_post,
@@ -213,19 +216,6 @@
).await.map_err(|e|
ScratchError::new(StatusCode::BAD_REQUEST, e)
)?;
- // if !chat_post.chat_id.is_empty() {
- // let cache_dir = {
- // let gcx_locked = gcx.read().await;
- // gcx_locked.cache_dir.clone()
- // };
- // let notes_dir_path = cache_dir.join("chats");
- // let _ = std::fs::create_dir_all(&notes_dir_path);
- // let notes_path = notes_dir_path.join(format!("chat{}_{}.json",
- // chrono::Local::now().format("%Y%m%d"),
- // chat_post.chat_id,
- // ));
- // let _ = std::fs::write(&notes_path, serde_json::to_string_pretty(&chat_post.messages).unwrap());
- // }
let mut ccx = AtCommandsContext::new(
gcx.clone(),
effective_n_ctx,
@@ -258,7 +248,8 @@
model_rec.base.clone(),
chat_post.parameters.clone(),
chat_post.only_deterministic_messages,
- meta
+ meta,
+ pre_stream_messages,
).await
}
}
2 changes: 1 addition & 1 deletion refact-agent/engine/src/http/routers/v1/code_completion.rs
@@ -79,7 +79,7 @@ pub async fn handle_v1_code_completion(
if !code_completion_post.stream {
crate::restream::scratchpad_interaction_not_stream(ccx.clone(), &mut scratchpad, "completion".to_string(), &model_rec.base, &mut code_completion_post.parameters, false, None).await
} else {
- crate::restream::scratchpad_interaction_stream(ccx.clone(), scratchpad, "completion-stream".to_string(), model_rec.base.clone(), code_completion_post.parameters.clone(), false, None).await
+ crate::restream::scratchpad_interaction_stream(ccx.clone(), scratchpad, "completion-stream".to_string(), model_rec.base.clone(), code_completion_post.parameters.clone(), false, None, None).await
}
}

266 changes: 266 additions & 0 deletions refact-agent/engine/src/http/routers/v1/knowledge_enrichment.rs
@@ -0,0 +1,266 @@
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock as ARwLock;
use regex::Regex;

use crate::call_validation::{ChatContent, ChatMessage, ContextFile};
use crate::global_context::GlobalContext;
use crate::memories::memories_search;

const KNOWLEDGE_TOP_N: usize = 3;
const TRAJECTORY_TOP_N: usize = 2;
const KNOWLEDGE_SCORE_THRESHOLD: f32 = 0.75;
const KNOWLEDGE_ENRICHMENT_MARKER: &str = "knowledge_enrichment";
const MAX_QUERY_LENGTH: usize = 2000;

pub async fn enrich_messages_with_knowledge(
gcx: Arc<ARwLock<GlobalContext>>,
messages: &mut Vec<ChatMessage>,
) -> Option<Vec<serde_json::Value>> {
let last_user_idx = messages.iter().rposition(|m| m.role == "user")?;
let query_raw = messages[last_user_idx].content.content_text_only();

if has_knowledge_enrichment_near(messages, last_user_idx) {
return None;
}

let query_normalized = normalize_query(&query_raw);

if !should_enrich(messages, &query_raw, &query_normalized) {
return None;
}

let existing_paths = get_existing_context_file_paths(messages);

if let Some((knowledge_context, ui_context)) = create_knowledge_context(gcx, &query_normalized, &existing_paths).await {
messages.insert(last_user_idx, knowledge_context);
tracing::info!("Injected knowledge context before user message at position {}", last_user_idx);
return Some(vec![ui_context]);
}

None
}

fn normalize_query(query: &str) -> String {
let code_fence_re = Regex::new(r"```[\s\S]*?```").unwrap();
let normalized = code_fence_re.replace_all(query, " [code] ").to_string();
let normalized = normalized.trim();
if normalized.len() > MAX_QUERY_LENGTH {
normalized.chars().take(MAX_QUERY_LENGTH).collect()
} else {
normalized.to_string()
}
}

fn should_enrich(messages: &[ChatMessage], query_raw: &str, query_normalized: &str) -> bool {
let trimmed = query_raw.trim();

// Guardrail: empty query
if trimmed.is_empty() {
return false;
}

// Guardrail: command-like messages
if trimmed.starts_with('@') || trimmed.starts_with('/') {
return false;
}

// Rule 1: Always enrich first user message
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
if user_message_count == 1 {
tracing::info!("Knowledge enrichment: first user message");
return true;
}

// Rule 2: Signal-based for subsequent messages
let strong = count_strong_signals(query_raw);
let weak = count_weak_signals(query_raw, query_normalized);

if strong >= 1 {
tracing::info!("Knowledge enrichment: {} strong signal(s)", strong);
return true;
}

if weak >= 2 && query_normalized.len() >= 20 {
tracing::info!("Knowledge enrichment: {} weak signal(s)", weak);
return true;
}

false
}

fn count_strong_signals(query: &str) -> usize {
let query_lower = query.to_lowercase();
let mut count = 0;

// Error/debug keywords
let error_keywords = [
"error", "panic", "exception", "traceback", "stack trace",
"segfault", "failed", "unable to", "cannot", "doesn't work",
"does not work", "broken", "bug", "crash"
];
if error_keywords.iter().any(|kw| query_lower.contains(kw)) {
count += 1;
}

// File references
let file_extensions = [".rs", ".ts", ".tsx", ".js", ".jsx", ".py", ".go", ".java", ".cpp", ".c", ".h"];
let config_files = ["cargo.toml", "package.json", "tsconfig", "pyproject", ".yaml", ".yml", ".toml"];
if file_extensions.iter().any(|ext| query_lower.contains(ext))
|| config_files.iter().any(|f| query_lower.contains(f)) {
count += 1;
}

// Path-like pattern
let path_re = Regex::new(r"\b[\w-]+/[\w-]+(?:/[\w.-]+)*\b").unwrap();
if path_re.is_match(query) {
count += 1;
}

// Code symbols
if query.contains("::") || query.contains("->") || query.contains("`") {
count += 1;
}

// Explicit retrieval intent
let retrieval_phrases = [
"search", "find", "where is", "which file", "look up",
"in this repo", "in the codebase", "in the project"
];
if retrieval_phrases.iter().any(|p| query_lower.contains(p)) {
count += 1;
}

count
}

fn count_weak_signals(query_raw: &str, query_normalized: &str) -> usize {
let mut count = 0;

// Has question mark
if query_raw.contains('?') {
count += 1;
}

// Starts with question word
let query_lower = query_raw.trim().to_lowercase();
let question_starters = ["how", "why", "what", "where", "when", "can", "should", "could", "would", "is there", "are there"];
if question_starters.iter().any(|s| query_lower.starts_with(s)) {
count += 1;
}

// Long enough natural language (after stripping code)
if query_normalized.len() >= 80 {
count += 1;
}

count
}

async fn create_knowledge_context(
gcx: Arc<ARwLock<GlobalContext>>,
query_text: &str,
existing_paths: &HashSet<String>,
) -> Option<(ChatMessage, serde_json::Value)> {

let memories = memories_search(gcx.clone(), &query_text, KNOWLEDGE_TOP_N, TRAJECTORY_TOP_N).await.ok()?;

let high_score_memories: Vec<_> = memories
.into_iter()
.filter(|m| m.score.unwrap_or(0.0) >= KNOWLEDGE_SCORE_THRESHOLD)
.filter(|m| {
if let Some(path) = &m.file_path {
!existing_paths.contains(&path.to_string_lossy().to_string())
} else {
true
}
})
.collect();

if high_score_memories.is_empty() {
return None;
}

tracing::info!("Knowledge enrichment: {} memories passed threshold {}", high_score_memories.len(), KNOWLEDGE_SCORE_THRESHOLD);

let context_files_for_llm: Vec<ContextFile> = high_score_memories
.iter()
.filter_map(|memo| {
let file_path = memo.file_path.as_ref()?;
let (line1, line2) = memo.line_range.unwrap_or((1, 50));
Some(ContextFile {
file_name: file_path.to_string_lossy().to_string(),
file_content: String::new(),
line1: line1 as usize,
line2: line2 as usize,
symbols: vec![],
gradient_type: -1,
usefulness: 80.0 + (memo.score.unwrap_or(0.75) * 20.0),
skip_pp: false,
})
})
.collect();

if context_files_for_llm.is_empty() {
return None;
}

let context_files_for_ui: Vec<serde_json::Value> = high_score_memories
.iter()
.filter_map(|memo| {
let file_path = memo.file_path.as_ref()?;
let (line1, line2) = memo.line_range.unwrap_or((1, 50));
Some(serde_json::json!({
"file_name": file_path.to_string_lossy().to_string(),
"file_content": memo.content.clone(),
"line1": line1,
"line2": line2,
}))
})
.collect();

let content = serde_json::to_string(&context_files_for_llm).ok()?;
let chat_message = ChatMessage {
role: "context_file".to_string(),
content: ChatContent::SimpleText(content),
tool_call_id: KNOWLEDGE_ENRICHMENT_MARKER.to_string(),
..Default::default()
};

let ui_content_str = serde_json::to_string(&context_files_for_ui).unwrap_or_default();
let ui_message = serde_json::json!({
"role": "context_file",
"content": ui_content_str,
"tool_call_id": KNOWLEDGE_ENRICHMENT_MARKER,
});

Some((chat_message, ui_message))
}

fn has_knowledge_enrichment_near(messages: &[ChatMessage], user_idx: usize) -> bool {
let search_start = user_idx.saturating_sub(2);
let search_end = (user_idx + 2).min(messages.len());

for i in search_start..search_end {
if messages[i].role == "context_file" && messages[i].tool_call_id == KNOWLEDGE_ENRICHMENT_MARKER {
tracing::info!("Skipping enrichment - already enriched at position {}", i);
return true;
}
}
false
}

fn get_existing_context_file_paths(messages: &[ChatMessage]) -> HashSet<String> {
let mut paths = HashSet::new();
for msg in messages {
if msg.role == "context_file" {
let content = msg.content.content_text_only();
if let Ok(files) = serde_json::from_str::<Vec<ContextFile>>(&content) {
for file in files {
paths.insert(file.file_name.clone());
}
}
}
}
paths
}
2 changes: 1 addition & 1 deletion refact-agent/engine/src/knowledge_graph/kg_builder.rs
@@ -37,7 +37,7 @@ pub async fn build_knowledge_graph(gcx: Arc<ARwLock<GlobalContext>>) -> Knowledg
.collect();

if knowledge_dirs.is_empty() {
- info!("knowledge_graph: no .refact_knowledge directories found");
+ info!("knowledge_graph: no .refact/knowledge directories found");
return graph;
}

1 change: 1 addition & 0 deletions refact-agent/engine/src/main.rs
@@ -66,6 +66,7 @@ mod agentic;
mod memories;
mod files_correction_cache;
mod knowledge_graph;
+ mod trajectory_memos;
pub mod constants;

#[tokio::main]