diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index 1fbe4d46..27e82f31 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -11,7 +11,7 @@ import logging import math import time -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Tuple from openviking.models.embedder.base import EmbedResult @@ -36,6 +36,18 @@ logger = get_logger(__name__) +def _is_expired(expires_at: datetime) -> bool: + """Return True if ``expires_at`` is in the past.""" + dt = ( + parse_iso_datetime(expires_at) + if isinstance(expires_at, str) + else expires_at.replace(tzinfo=timezone.utc) + if isinstance(expires_at, datetime) and expires_at.tzinfo is None + else expires_at + ) + return isinstance(dt, datetime) and dt <= datetime.now(timezone.utc) + + class RetrieverMode(str): THINKING = "thinking" QUICK = "quick" @@ -205,6 +217,9 @@ async def retrieve( matched = await self._convert_to_matched_contexts(candidates, ctx=ctx) final = matched[:limit] + final = [ + m for m in final if not getattr(m, "expires_at", None) or not _is_expired(m.expires_at) + ] # Record retrieval stats for the observer. elapsed_ms = (time.monotonic() - t0) * 1000 @@ -272,7 +287,7 @@ def _rerank_scores( return fallback_scores normalized_scores: List[float] = [] - for score, fallback in zip(scores, fallback_scores): + for score, fallback in zip(scores, fallback_scores, strict=False): if isinstance(score, (int, float)): normalized_scores.append(float(score)) else: @@ -343,7 +358,7 @@ def _prepare_initial_candidates( else: query_scores = default_scores - for candidate, score in zip(initial_candidates, query_scores): + for candidate, score in zip(initial_candidates, query_scores, strict=False): candidate["_score"] = score return initial_candidates @@ -444,7 +459,7 @@ def passes_threshold(score: float) -> bool: documents = [str(r.get("abstract", "")) for r in results] query_scores = self._rerank_scores(query, documents, query_scores) - for r, score in zip(results, query_scores): + for r, score in zip(results, query_scores, strict=False): uri = r.get("uri", "") final_score = ( alpha * score + (1 - alpha) * current_score if current_score else score @@ -553,19 +568,22 @@ async def _convert_to_matched_contexts( level = c.get("level", 2) display_uri = self._append_level_suffix(c.get("uri", ""), level) - results.append( - MatchedContext( - uri=display_uri, - context_type=ContextType(c["context_type"]) - if c.get("context_type") - else ContextType.RESOURCE, - level=level, - abstract=c.get("abstract", ""), - category=c.get("category", ""), - score=final_score, - relations=relations, - ) + matched = MatchedContext( + uri=display_uri, + context_type=ContextType(c["context_type"]) + if c.get("context_type") + else ContextType.RESOURCE, + level=level, + abstract=c.get("abstract", ""), + category=c.get("category", ""), + score=final_score, + relations=relations, ) + expires_at = c.get("expires_at") + if expires_at is None and isinstance(c.get("meta"), dict): + expires_at = c["meta"].get("expires_at") + matched.expires_at = expires_at # type: ignore[attr-defined] + results.append(matched) # Re-sort by blended score so hotness boost can change ranking results.sort(key=lambda x: x.score, reverse=True) diff --git a/openviking/session/memory/dataclass.py b/openviking/session/memory/dataclass.py index f4738069..8e187a6a 100644 --- a/openviking/session/memory/dataclass.py +++ b/openviking/session/memory/dataclass.py @@ -5,7 +5,7 @@ """ import json -from datetime import datetime +from datetime import datetime, timezone from typing import ( Any, Dict, @@ -73,6 +73,9 @@ class MemoryData(BaseModel): tags: List[str] = Field(default_factory=list, description="Tags") created_at: Optional[datetime] = Field(None, description="Created time") updated_at: Optional[datetime] = Field(None, description="Updated time") + expires_at: Optional[datetime] = Field( + None, description="Expiration time (None = never expires)" + ) def get_field(self, field_name: str) -> Any: """Get field value.""" @@ -82,6 +85,17 @@ def set_field(self, field_name: str, value: Any) -> None: """Set field value.""" self.fields[field_name] = value + def is_expired(self) -> bool: + """Return True if memory has expired.""" + if self.expires_at is None: + return False + expires_at = ( + self.expires_at.replace(tzinfo=timezone.utc) + if self.expires_at.tzinfo is None + else self.expires_at + ) + return expires_at <= datetime.now(timezone.utc) + # ============================================================================ # Fault Tolerant Base Model (参考 vikingdb BaseModelCompat) @@ -114,7 +128,7 @@ def get_origin_type(cls, annotation) -> type: origin = get_origin(annotation) if origin is Union: args = get_args(annotation) - if len(args) == 2 and args[1] == type(None): + if len(args) == 2 and args[1] is type(None): return cls.get_origin_type(args[0]) elif origin is list: return list @@ -126,7 +140,7 @@ def get_arg_type(cls, annotation) -> type: origin = get_origin(annotation) if origin is Union: args = get_args(annotation) - if len(args) == 2 and args[1] == type(None): + if len(args) == 2 and args[1] is type(None): return cls.get_arg_type(args[0]) elif origin is list: args = get_args(annotation) diff --git a/openviking/session/memory_extractor.py b/openviking/session/memory_extractor.py index 82011163..705a4056 100644 --- a/openviking/session/memory_extractor.py +++ b/openviking/session/memory_extractor.py @@ -10,6 +10,7 @@ import re from dataclasses import dataclass +from datetime import datetime, timezone from enum import Enum from typing import List, Optional from uuid import uuid4 @@ -23,6 +24,7 @@ from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils import get_logger from openviking_cli.utils.config import get_openviking_config +from openviking_cli.utils.config.memory_config import MemoryConfig logger = get_logger(__name__) @@ -453,6 +455,23 @@ async def create_memory( owner_space = self._get_owner_space(candidate.category, ctx) + memory_cfg = get_openviking_config().memory + candidate_metadata = getattr(candidate, "metadata", None) + override_ttl = ( + candidate_metadata.get("ttl") if isinstance(candidate_metadata, dict) else None + ) + ttl_str = override_ttl or memory_cfg.ttl_by_type.get(candidate.category.value) + if ttl_str is None: + ttl_str = memory_cfg.default_ttl + try: + ttl_delta = MemoryConfig.parse_ttl(ttl_str) + except ValueError: + logger.warning( + "Invalid memory TTL %r for category=%s", ttl_str, candidate.category.value + ) + ttl_delta = None + expires_at = datetime.now(timezone.utc) + ttl_delta if ttl_delta else None + # Special handling for profile: append to profile.md if candidate.category == MemoryCategory.PROFILE: payload = await self._append_to_profile(candidate, viking_fs, ctx=ctx) @@ -472,6 +491,8 @@ async def create_memory( account_id=ctx.account_id, owner_space=owner_space, ) + if expires_at is not None: + memory.meta["expires_at"] = expires_at logger.info(f"uri {memory_uri} abstract: {payload.abstract} content: {payload.content}") memory.set_vectorize(Vectorize(text=payload.content)) return memory @@ -512,6 +533,8 @@ async def create_memory( account_id=ctx.account_id, owner_space=owner_space, ) + if expires_at is not None: + memory.meta["expires_at"] = expires_at logger.info(f"uri {memory_uri} abstract: {candidate.abstract} content: {candidate.content}") memory.set_vectorize(Vectorize(text=candidate.content)) return memory diff --git a/openviking_cli/utils/config/memory_config.py b/openviking_cli/utils/config/memory_config.py index b6889684..adf73914 100644 --- a/openviking_cli/utils/config/memory_config.py +++ b/openviking_cli/utils/config/memory_config.py @@ -1,6 +1,8 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: AGPL-3.0 -from typing import Any, Dict +import re +from datetime import timedelta +from typing import Any, Dict, Optional from pydantic import BaseModel, Field, field_validator @@ -24,6 +26,13 @@ class MemoryConfig(BaseModel): default="", description="Custom memory templates directory. If set, templates from this directory will be loaded in addition to built-in templates", ) + default_ttl: Optional[str] = Field( + default=None, + description="Default TTL for new memories (e.g. '7d', '24h'). None = no expiration", + ) + ttl_by_type: Dict[str, Optional[str]] = Field( + default_factory=dict, description="Per memory-type TTL overrides" + ) model_config = {"extra": "forbid"} @@ -34,6 +43,26 @@ def validate_agent_scope_mode(cls, value: str) -> str: raise ValueError("memory.agent_scope_mode must be 'user+agent' or 'agent'") return value + @staticmethod + def parse_ttl(ttl_str: Optional[str]) -> Optional[timedelta]: + """Parse TTL values like '7d', '24h', '30m' into timedelta.""" + if ttl_str is None: + return None + match = re.fullmatch(r"\s*(\d+)\s*([smhdw])\s*", ttl_str.lower()) + if not match: + raise ValueError(f"Invalid TTL format: {ttl_str}") + value = int(match.group(1)) + unit = match.group(2) + if unit == "s": + return timedelta(seconds=value) + if unit == "m": + return timedelta(minutes=value) + if unit == "h": + return timedelta(hours=value) + if unit == "d": + return timedelta(days=value) + return timedelta(weeks=value) + @classmethod def from_dict(cls, config: Dict[str, Any]) -> "MemoryConfig": """Create configuration from dictionary.""" diff --git a/tests/unit/session/test_memory_ttl.py b/tests/unit/session/test_memory_ttl.py new file mode 100644 index 00000000..d33e6ce2 --- /dev/null +++ b/tests/unit/session/test_memory_ttl.py @@ -0,0 +1,51 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: AGPL-3.0 + +from datetime import datetime, timedelta, timezone + +import pytest + +from openviking.session.memory.dataclass import MemoryData +from openviking_cli.utils.config.memory_config import MemoryConfig + + +def test_memory_data_is_expired_with_no_expires_at() -> None: + memory = MemoryData(memory_type="events", fields={"k": "v"}) + assert memory.is_expired() is False + + +def test_memory_data_is_expired_with_future_expires_at() -> None: + memory = MemoryData( + memory_type="events", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + ) + assert memory.is_expired() is False + + +def test_memory_data_is_expired_with_past_expires_at() -> None: + memory = MemoryData( + memory_type="events", + expires_at=datetime.now(timezone.utc) - timedelta(hours=1), + ) + assert memory.is_expired() is True + + +def test_memory_config_parse_ttl() -> None: + assert MemoryConfig.parse_ttl(None) is None + assert MemoryConfig.parse_ttl("7d") == timedelta(days=7) + assert MemoryConfig.parse_ttl("24h") == timedelta(hours=24) + assert MemoryConfig.parse_ttl("30m") == timedelta(minutes=30) + + +@pytest.mark.parametrize("invalid_ttl", ["", "abc", "10x"]) +def test_memory_config_parse_ttl_invalid(invalid_ttl: str) -> None: + with pytest.raises(ValueError): + MemoryConfig.parse_ttl(invalid_ttl) + + +def test_memory_data_backward_compatibility_without_expires_at() -> None: + memory = MemoryData(memory_type="preferences", abstract="pref") + dumped = memory.model_dump() + assert "expires_at" in dumped + assert dumped["expires_at"] is None + assert memory.abstract == "pref"