Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions memoria/crates/memoria-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,12 @@ async fn call_log_mw(
// which is necessary because JSON-RPC errors return HTTP 200.
if !is_dashboard && !path.starts_with("/v1/mcp") {
if let Some(reporter) = &state.stats_reporter {
reporter.report(
memoria_service::stats_reporter::StatsEvent::ApiCallLogged {
user_id: uid.clone(),
path: path.clone(),
is_mcp: false,
is_success: status_code < 400,
},
);
reporter.report(memoria_service::stats_reporter::StatsEvent::ApiCallLogged {
user_id: uid.clone(),
path: path.clone(),
is_mcp: false,
is_success: status_code < 400,
});
}
}
if let Some(mask) = should_mark_metrics_dirty(&method, &path, status_code) {
Expand Down Expand Up @@ -147,6 +145,7 @@ fn should_mark_metrics_dirty(
} else if path.starts_with("/v1/snapshots/") && path.ends_with("/rollback")
|| path.starts_with("/v1/branches/") && path.ends_with("/checkout")
|| path.starts_with("/v1/branches/") && path.ends_with("/merge")
|| path.starts_with("/v1/branches/") && path.ends_with("/pick")
{
Some(DirtyMask::FULL)
} else if path.starts_with("/v1/sessions/") && path.ends_with("/summary") {
Expand Down Expand Up @@ -283,6 +282,10 @@ pub fn build_router(state: AppState) -> Router {
"/v1/branches/:name/diff",
get(routes::snapshots::diff_branch),
)
.route(
"/v1/branches/:name/pick",
post(routes::snapshots::pick_branch),
)
.route(
"/v1/branches/:name",
delete(routes::snapshots::delete_branch),
Expand Down
78 changes: 78 additions & 0 deletions memoria/crates/memoria-api/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,84 @@ fn default_strategy() -> String {
"accept".to_string()
}

#[derive(Debug, Deserialize, Serialize)]
pub struct PickRequest {
/// Target branch for the selected changes. Defaults to main.
#[serde(default = "default_pick_target")]
pub target: String,
/// Conflict strategy for selected changes: fail | skip | accept. Defaults to fail.
#[serde(default = "default_pick_strategy")]
pub strategy: String,
/// Selector that decides which branch changes are eligible to apply.
pub selector: PickSelector,
/// Optional dry-run preview settings. When present, returns a preview instead of mutating state.
pub dry_run: Option<PickDryRunOptions>,
}

fn default_pick_target() -> String {
"main".to_string()
}

fn default_pick_strategy() -> String {
"fail".to_string()
}

fn default_pick_top_k() -> i64 {
5
}

fn default_pick_preview_limit() -> i64 {
10
}

fn default_pick_include_content_preview() -> bool {
true
}

fn default_pick_include_scores() -> bool {
true
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct PickDryRunOptions {
/// Maximum number of preview candidates to return.
#[serde(default = "default_pick_preview_limit")]
pub limit: i64,
/// Preview pagination offset.
#[serde(default)]
pub offset: i64,
/// Include a short content_preview for each candidate. Embeddings are never returned.
#[serde(default = "default_pick_include_content_preview")]
pub include_content_preview: bool,
/// Include retrieve scores when available. Non-retrieve selectors omit scores.
#[serde(default = "default_pick_include_scores")]
pub include_scores: bool,
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PickSelector {
KeyList {
/// Explicit memory_ids from the source branch to apply.
keys: Vec<String>,
},
SnapshotRange {
/// Start snapshot name in the source branch history.
from_snapshot: String,
/// End snapshot name in the source branch history.
to_snapshot: String,
},
Retrieve {
/// Natural-language query used to rank changed source rows.
query: String,
/// Maximum number of ranked rows eligible for application.
#[serde(default = "default_pick_top_k")]
top_k: i64,
/// Optional retrieve threshold. Rows below this score are excluded.
min_score: Option<f64>,
},
}

// ── Helpers ───────────────────────────────────────────────────────────────────

pub fn parse_memory_type(s: &str) -> Result<MemoryType, String> {
Expand Down
16 changes: 7 additions & 9 deletions memoria/crates/memoria-api/src/routes/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ fn mcp_tool_dirty_mask(tool: &str) -> Option<crate::metrics_summary::DirtyMask>
"memory_snapshot" | "memory_snapshot_delete" => Some(DirtyMask::SNAPSHOT),
"memory_rollback" => Some(DirtyMask::FULL),
"memory_branch" | "memory_branch_delete" => Some(DirtyMask::BRANCH),
"memory_checkout" | "memory_merge" => Some(DirtyMask::FULL),
"memory_checkout" | "memory_merge" | "memory_pick" => Some(DirtyMask::FULL),
_ => None,
}
}
Expand Down Expand Up @@ -133,14 +133,12 @@ pub async fn mcp_handler(
RpcMeta::err($code),
);
if let Some(reporter) = &state.stats_reporter {
reporter.report(
memoria_service::stats_reporter::StatsEvent::ApiCallLogged {
user_id: auth.user_id.clone(),
path: $path.to_string(),
is_mcp: true,
is_success: false,
},
);
reporter.report(memoria_service::stats_reporter::StatsEvent::ApiCallLogged {
user_id: auth.user_id.clone(),
path: $path.to_string(),
is_mcp: true,
is_success: false,
});
}
return Json($body).into_response();
}};
Expand Down
58 changes: 57 additions & 1 deletion memoria/crates/memoria-api/src/routes/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
routes::memory::{api_err, api_err_typed},
state::AppState,
};
use memoria_core::TrustTier;
use memoria_core::{MemoriaError, TrustTier};
use memoria_git::GitForDataService;
use std::sync::Arc;

Expand Down Expand Up @@ -66,6 +66,27 @@ async fn git_call(
Ok(json!({ "result": text }))
}

async fn git_call_pick(
state: &AppState,
user_id: &str,
args: serde_json::Value,
) -> Result<serde_json::Value, (StatusCode, String)> {
let result =
memoria_mcp::git_tools::call("memory_pick", args, &state.git, &state.service, user_id)
.await
.map_err(|e| match e {
MemoriaError::Validation(msg) if msg.starts_with("Conflict:") => {
(StatusCode::CONFLICT, msg)
}
other => api_err_typed(other),
})?;
let text = result["content"][0]["text"]
.as_str()
.unwrap_or("")
.to_string();
Ok(serde_json::from_str(&text).unwrap_or_else(|_| json!({ "result": text })))
}

async fn user_snapshot_store(
state: &AppState,
user_id: &str,
Expand Down Expand Up @@ -687,6 +708,41 @@ pub async fn diff_branch(
Ok(Json(r))
}

/// POST /v1/branches/:name/pick
///
/// Request body:
/// - target: optional target branch, defaults to main
/// - strategy: optional conflict strategy, defaults to fail
/// - selector:
/// - key_list { keys[] }
/// - snapshot_range { from_snapshot, to_snapshot }
/// - retrieve { query, top_k, min_score? }
/// - dry_run: optional preview output shaping { limit, offset, include_content_preview, include_scores }
///
/// Response:
/// - normal execution: { "result": "Picked ..." }
/// - dry_run preview: structured JSON summary + paginated candidates (never includes embeddings)
pub async fn pick_branch(
State(state): State<AppState>,
AuthUser { user_id, .. }: AuthUser,
Path(name): Path<String>,
Json(req): Json<PickRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
let r = git_call_pick(
&state,
&user_id,
json!({
"source": name,
"target": req.target,
"strategy": req.strategy,
"selector": req.selector,
"dry_run": req.dry_run,
}),
)
.await?;
Ok(Json(r))
}

pub async fn delete_branch(
State(state): State<AppState>,
AuthUser { user_id, .. }: AuthUser,
Expand Down
Loading
Loading