Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions openviking/retrieve/hierarchical_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
import math
import time
from datetime import datetime
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple

from openviking.models.embedder.base import EmbedResult
Expand All @@ -36,6 +36,18 @@
logger = get_logger(__name__)


def _is_expired(expires_at: datetime) -> bool:
"""Return True if ``expires_at`` is in the past."""
dt = (
parse_iso_datetime(expires_at)
if isinstance(expires_at, str)
else expires_at.replace(tzinfo=timezone.utc)
if isinstance(expires_at, datetime) and expires_at.tzinfo is None
else expires_at
)
return isinstance(dt, datetime) and dt <= datetime.now(timezone.utc)


class RetrieverMode(str):
THINKING = "thinking"
QUICK = "quick"
Expand Down Expand Up @@ -205,6 +217,9 @@ async def retrieve(
matched = await self._convert_to_matched_contexts(candidates, ctx=ctx)

final = matched[:limit]
final = [
m for m in final if not getattr(m, "expires_at", None) or not _is_expired(m.expires_at)
]

# Record retrieval stats for the observer.
elapsed_ms = (time.monotonic() - t0) * 1000
Expand Down Expand Up @@ -272,7 +287,7 @@ def _rerank_scores(
return fallback_scores

normalized_scores: List[float] = []
for score, fallback in zip(scores, fallback_scores):
for score, fallback in zip(scores, fallback_scores, strict=False):
if isinstance(score, (int, float)):
normalized_scores.append(float(score))
else:
Expand Down Expand Up @@ -343,7 +358,7 @@ def _prepare_initial_candidates(
else:
query_scores = default_scores

for candidate, score in zip(initial_candidates, query_scores):
for candidate, score in zip(initial_candidates, query_scores, strict=False):
candidate["_score"] = score

return initial_candidates
Expand Down Expand Up @@ -444,7 +459,7 @@ def passes_threshold(score: float) -> bool:
documents = [str(r.get("abstract", "")) for r in results]
query_scores = self._rerank_scores(query, documents, query_scores)

for r, score in zip(results, query_scores):
for r, score in zip(results, query_scores, strict=False):
uri = r.get("uri", "")
final_score = (
alpha * score + (1 - alpha) * current_score if current_score else score
Expand Down Expand Up @@ -553,19 +568,22 @@ async def _convert_to_matched_contexts(
level = c.get("level", 2)
display_uri = self._append_level_suffix(c.get("uri", ""), level)

results.append(
MatchedContext(
uri=display_uri,
context_type=ContextType(c["context_type"])
if c.get("context_type")
else ContextType.RESOURCE,
level=level,
abstract=c.get("abstract", ""),
category=c.get("category", ""),
score=final_score,
relations=relations,
)
matched = MatchedContext(
uri=display_uri,
context_type=ContextType(c["context_type"])
if c.get("context_type")
else ContextType.RESOURCE,
level=level,
abstract=c.get("abstract", ""),
category=c.get("category", ""),
score=final_score,
relations=relations,
)
expires_at = c.get("expires_at")
if expires_at is None and isinstance(c.get("meta"), dict):
expires_at = c["meta"].get("expires_at")
matched.expires_at = expires_at # type: ignore[attr-defined]
results.append(matched)

# Re-sort by blended score so hotness boost can change ranking
results.sort(key=lambda x: x.score, reverse=True)
Expand Down
20 changes: 17 additions & 3 deletions openviking/session/memory/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import json
from datetime import datetime
from datetime import datetime, timezone
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -73,6 +73,9 @@ class MemoryData(BaseModel):
tags: List[str] = Field(default_factory=list, description="Tags")
created_at: Optional[datetime] = Field(None, description="Created time")
updated_at: Optional[datetime] = Field(None, description="Updated time")
expires_at: Optional[datetime] = Field(
None, description="Expiration time (None = never expires)"
)

def get_field(self, field_name: str) -> Any:
"""Get field value."""
Expand All @@ -82,6 +85,17 @@ def set_field(self, field_name: str, value: Any) -> None:
"""Set field value."""
self.fields[field_name] = value

def is_expired(self) -> bool:
    """Report whether this memory's expiration time has been reached.

    A missing ``expires_at`` means the memory never expires. A naive
    ``expires_at`` is interpreted as UTC before it is compared with the
    current UTC time.
    """
    deadline = self.expires_at
    if deadline is None:
        return False
    if deadline.tzinfo is None:
        # Attach UTC so the comparison below is aware-vs-aware.
        deadline = deadline.replace(tzinfo=timezone.utc)
    now_utc = datetime.now(timezone.utc)
    return deadline <= now_utc


# ============================================================================
# Fault Tolerant Base Model (参考 vikingdb BaseModelCompat)
Expand Down Expand Up @@ -114,7 +128,7 @@ def get_origin_type(cls, annotation) -> type:
origin = get_origin(annotation)
if origin is Union:
args = get_args(annotation)
if len(args) == 2 and args[1] == type(None):
if len(args) == 2 and args[1] is type(None):
return cls.get_origin_type(args[0])
elif origin is list:
return list
Expand All @@ -126,7 +140,7 @@ def get_arg_type(cls, annotation) -> type:
origin = get_origin(annotation)
if origin is Union:
args = get_args(annotation)
if len(args) == 2 and args[1] == type(None):
if len(args) == 2 and args[1] is type(None):
return cls.get_arg_type(args[0])
elif origin is list:
args = get_args(annotation)
Expand Down
23 changes: 23 additions & 0 deletions openviking/session/memory_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import re
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import List, Optional
from uuid import uuid4
Expand All @@ -23,6 +24,7 @@
from openviking_cli.session.user_id import UserIdentifier
from openviking_cli.utils import get_logger
from openviking_cli.utils.config import get_openviking_config
from openviking_cli.utils.config.memory_config import MemoryConfig

logger = get_logger(__name__)

Expand Down Expand Up @@ -453,6 +455,23 @@ async def create_memory(

owner_space = self._get_owner_space(candidate.category, ctx)

memory_cfg = get_openviking_config().memory
candidate_metadata = getattr(candidate, "metadata", None)
override_ttl = (
candidate_metadata.get("ttl") if isinstance(candidate_metadata, dict) else None
)
ttl_str = override_ttl or memory_cfg.ttl_by_type.get(candidate.category.value)
if ttl_str is None:
ttl_str = memory_cfg.default_ttl
try:
ttl_delta = MemoryConfig.parse_ttl(ttl_str)
except ValueError:
logger.warning(
"Invalid memory TTL %r for category=%s", ttl_str, candidate.category.value
)
ttl_delta = None
expires_at = datetime.now(timezone.utc) + ttl_delta if ttl_delta else None

# Special handling for profile: append to profile.md
if candidate.category == MemoryCategory.PROFILE:
payload = await self._append_to_profile(candidate, viking_fs, ctx=ctx)
Expand All @@ -472,6 +491,8 @@ async def create_memory(
account_id=ctx.account_id,
owner_space=owner_space,
)
if expires_at is not None:
memory.meta["expires_at"] = expires_at
logger.info(f"uri {memory_uri} abstract: {payload.abstract} content: {payload.content}")
memory.set_vectorize(Vectorize(text=payload.content))
return memory
Expand Down Expand Up @@ -512,6 +533,8 @@ async def create_memory(
account_id=ctx.account_id,
owner_space=owner_space,
)
if expires_at is not None:
memory.meta["expires_at"] = expires_at
logger.info(f"uri {memory_uri} abstract: {candidate.abstract} content: {candidate.content}")
memory.set_vectorize(Vectorize(text=candidate.content))
return memory
Expand Down
31 changes: 30 additions & 1 deletion openviking_cli/utils/config/memory_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: AGPL-3.0
from typing import Any, Dict
import re
from datetime import timedelta
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, field_validator

Expand All @@ -24,6 +26,13 @@ class MemoryConfig(BaseModel):
default="",
description="Custom memory templates directory. If set, templates from this directory will be loaded in addition to built-in templates",
)
default_ttl: Optional[str] = Field(
default=None,
description="Default TTL for new memories (e.g. '7d', '24h'). None = no expiration",
)
ttl_by_type: Dict[str, Optional[str]] = Field(
default_factory=dict, description="Per memory-type TTL overrides"
)

model_config = {"extra": "forbid"}

Expand All @@ -34,6 +43,26 @@ def validate_agent_scope_mode(cls, value: str) -> str:
raise ValueError("memory.agent_scope_mode must be 'user+agent' or 'agent'")
return value

@staticmethod
def parse_ttl(ttl_str: Optional[str]) -> Optional[timedelta]:
"""Parse TTL values like '7d', '24h', '30m' into timedelta."""
if ttl_str is None:
return None
match = re.fullmatch(r"\s*(\d+)\s*([smhdw])\s*", ttl_str.lower())
if not match:
raise ValueError(f"Invalid TTL format: {ttl_str}")
value = int(match.group(1))
unit = match.group(2)
if unit == "s":
return timedelta(seconds=value)
if unit == "m":
return timedelta(minutes=value)
if unit == "h":
return timedelta(hours=value)
if unit == "d":
return timedelta(days=value)
return timedelta(weeks=value)

@classmethod
def from_dict(cls, config: Dict[str, Any]) -> "MemoryConfig":
"""Create configuration from dictionary."""
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/session/test_memory_ttl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: AGPL-3.0

from datetime import datetime, timedelta, timezone

import pytest

from openviking.session.memory.dataclass import MemoryData
from openviking_cli.utils.config.memory_config import MemoryConfig


def test_memory_data_is_expired_with_no_expires_at() -> None:
    """A memory built without ``expires_at`` reports itself as not expired."""
    record = MemoryData(memory_type="events", fields={"k": "v"})
    expired = record.is_expired()
    assert expired is False


def test_memory_data_is_expired_with_future_expires_at() -> None:
    """An expiration one hour ahead of now is not yet expired."""
    one_hour_ahead = datetime.now(timezone.utc) + timedelta(hours=1)
    record = MemoryData(memory_type="events", expires_at=one_hour_ahead)
    expired = record.is_expired()
    assert expired is False


def test_memory_data_is_expired_with_past_expires_at() -> None:
    """An expiration one hour behind now is reported as expired."""
    one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1)
    record = MemoryData(memory_type="events", expires_at=one_hour_ago)
    expired = record.is_expired()
    assert expired is True


def test_memory_config_parse_ttl() -> None:
    """``parse_ttl`` passes None through and converts day/hour/minute strings."""
    parse = MemoryConfig.parse_ttl
    assert parse(None) is None
    # Checked in the same order as a sequence of individual assertions would be.
    expectations = (
        ("7d", timedelta(days=7)),
        ("24h", timedelta(hours=24)),
        ("30m", timedelta(minutes=30)),
    )
    for text, expected in expectations:
        assert parse(text) == expected


@pytest.mark.parametrize("invalid_ttl", ["", "abc", "10x"])
def test_memory_config_parse_ttl_invalid(invalid_ttl: str) -> None:
    """Malformed TTL strings are rejected with ``ValueError``."""
    # Callable form of pytest.raises: invokes parse_ttl(invalid_ttl) and
    # fails the test unless ValueError is raised.
    pytest.raises(ValueError, MemoryConfig.parse_ttl, invalid_ttl)


def test_memory_data_backward_compatibility_without_expires_at() -> None:
    """Omitting ``expires_at`` still yields a dump with that key set to None."""
    legacy = MemoryData(memory_type="preferences", abstract="pref")
    snapshot = legacy.model_dump()
    has_key = "expires_at" in snapshot
    assert has_key
    assert snapshot["expires_at"] is None
    assert legacy.abstract == "pref"
Loading