Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion memoria/crates/memoria-api/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ fn default_strategy() -> String {
// ── Helpers ───────────────────────────────────────────────────────────────────

pub fn parse_memory_type(s: &str) -> Result<MemoryType, String> {
MemoryType::from_str(s).map_err(|e| e.to_string())
Ok(s.parse::<MemoryType>().unwrap())
}

pub fn parse_trust_tier(s: &str) -> Result<TrustTier, String> {
Expand Down
5 changes: 1 addition & 4 deletions memoria/crates/memoria-api/src/routes/governance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ pub async fn reflect(
continue;
}
let mt_str = item["type"].as_str().unwrap_or("semantic");
let mt = memoria_core::MemoryType::from_str(mt_str)
.unwrap_or(memoria_core::MemoryType::Semantic);
let mt = mt_str.parse::<memoria_core::MemoryType>().unwrap();
let _ = state
.service
.store_memory(
Expand Down Expand Up @@ -370,5 +369,3 @@ pub async fn get_entities(
"entities": entities.iter().map(|(n, t)| json!({"name": n, "entity_type": t})).collect::<Vec<_>>()
})))
}

use std::str::FromStr;
89 changes: 62 additions & 27 deletions memoria/crates/memoria-core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,45 @@ use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Memory type — must have exactly 6 variants matching Python implementation.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
/// Memory type classification.
///
/// The built-in variants cover standard agent memory categories.
/// `Custom(String)` allows downstream applications to define their own
/// domain-specific types (e.g. `brand_theme`, `layout_catalog`) without
/// requiring changes to Memoria itself.
///
/// Custom types are stored as-is in the database `memory_type VARCHAR(64)`
/// column and participate in retrieval filtering just like built-in types.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MemoryType {
Semantic,
Working,
Episodic,
Profile,
ToolResult,
Procedural,
/// Application-defined memory type. The inner string is stored verbatim.
Custom(String),
}

impl MemoryType {
/// Returns `true` for the six built-in variants, `false` for `Custom`.
pub fn is_builtin(&self) -> bool {
!matches!(self, MemoryType::Custom(_))
}
}

impl Serialize for MemoryType {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&self.to_string())
}
}

impl<'de> Deserialize<'de> for MemoryType {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
Ok(s.parse().expect("MemoryType::from_str is infallible"))
}
}

impl std::fmt::Display for MemoryType {
Expand All @@ -23,23 +52,24 @@ impl std::fmt::Display for MemoryType {
MemoryType::Profile => "profile",
MemoryType::ToolResult => "tool_result",
MemoryType::Procedural => "procedural",
MemoryType::Custom(name) => name.as_str(),
};
write!(f, "{s}")
}
}

impl std::str::FromStr for MemoryType {
type Err = crate::MemoriaError;
type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"semantic" => Ok(MemoryType::Semantic),
"working" => Ok(MemoryType::Working),
"episodic" => Ok(MemoryType::Episodic),
"profile" => Ok(MemoryType::Profile),
"tool_result" => Ok(MemoryType::ToolResult),
"procedural" => Ok(MemoryType::Procedural),
other => Err(crate::MemoriaError::InvalidMemoryType(other.to_string())),
}
Ok(match s {
"semantic" => MemoryType::Semantic,
"working" => MemoryType::Working,
"episodic" => MemoryType::Episodic,
"profile" => MemoryType::Profile,
"tool_result" => MemoryType::ToolResult,
"procedural" => MemoryType::Procedural,
other => MemoryType::Custom(other.to_string()),
})
}
}

Expand Down Expand Up @@ -142,20 +172,7 @@ mod tests {
use super::*;

#[test]
fn test_memory_type_has_six_variants() {
let types = [
MemoryType::Semantic,
MemoryType::Working,
MemoryType::Episodic,
MemoryType::Profile,
MemoryType::ToolResult,
MemoryType::Procedural,
];
assert_eq!(types.len(), 6);
}

#[test]
fn test_memory_type_roundtrip() {
fn test_builtin_types_roundtrip() {
for (s, expected) in [
("semantic", MemoryType::Semantic),
("working", MemoryType::Working),
Expand All @@ -167,9 +184,27 @@ mod tests {
let parsed: MemoryType = s.parse().unwrap();
assert_eq!(parsed, expected);
assert_eq!(parsed.to_string(), s);
assert!(parsed.is_builtin());
}
}

#[test]
fn test_custom_type_roundtrip() {
let parsed: MemoryType = "brand_theme".parse().unwrap();
assert_eq!(parsed, MemoryType::Custom("brand_theme".to_string()));
assert_eq!(parsed.to_string(), "brand_theme");
assert!(!parsed.is_builtin());
}

#[test]
fn test_custom_type_serde_roundtrip() {
let mt = MemoryType::Custom("layout_catalog".to_string());
let json = serde_json::to_string(&mt).unwrap();
assert_eq!(json, "\"layout_catalog\"");
let back: MemoryType = serde_json::from_str(&json).unwrap();
assert_eq!(back, mt);
}

#[test]
fn test_trust_tier_roundtrip() {
for (s, expected) in [
Expand Down
4 changes: 2 additions & 2 deletions memoria/crates/memoria-mcp/src/tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ pub async fn call(
.transpose()
.ok()
.flatten();
let mt = MemoryType::from_str(memory_type).unwrap_or(MemoryType::Semantic);
let mt = memory_type.parse::<MemoryType>().unwrap();
let m = match service
.store_memory(
user_id,
Expand Down Expand Up @@ -641,7 +641,7 @@ pub async fn call(
continue;
}
let mt_str = item["type"].as_str().unwrap_or("semantic");
let mt = MemoryType::from_str(mt_str).unwrap_or(MemoryType::Semantic);
let mt = mt_str.parse::<MemoryType>().unwrap();
let confidence = item["confidence"].as_f64().unwrap_or(0.5) as f32;
// Store as T4 (unverified insight from reflection)
let _ = service
Expand Down
4 changes: 2 additions & 2 deletions memoria/crates/memoria-storage/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ impl SqlMemoryStore {
r#"CREATE TABLE IF NOT EXISTS mem_memories (
memory_id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
memory_type VARCHAR(20) NOT NULL,
memory_type VARCHAR(64) NOT NULL,
content TEXT NOT NULL,
embedding vecf32({dim}),
session_id VARCHAR(64),
Expand Down Expand Up @@ -3011,7 +3011,7 @@ fn row_to_memory(row: &sqlx::mysql::MySqlRow) -> Result<Memory, MemoriaError> {
Ok(Memory {
memory_id: row.try_get("memory_id").map_err(db_err)?,
user_id: row.try_get("user_id").map_err(db_err)?,
memory_type: MemoryType::from_str(&memory_type_str)?,
memory_type: memory_type_str.parse::<MemoryType>().unwrap(),
content: row.try_get("content").map_err(db_err)?,
initial_confidence: row
.try_get::<f32, _>("initial_confidence")
Expand Down