Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion secator/ai/interactivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,14 @@ def build_pending_prompt(self, question, choices, session_id, prompt_type="follo
# (e.g. a worker that died mid-poll). Expire them BEFORE this doc is
# persisted so only the current prompt stays live (M10).
self._expire_stale_pending(session_id)
# The conversation id rides on `_context.session_id` (auto-stamped from the
# runner context on persist) — the poll + restore + secator-api all key on
# that, so this pending doc needs no top-level session_id field.
return Ai(
content=question,
ai_type=prompt_type,
status="pending",
choices=choices,
session_id=session_id,
extra_data=extra_data,
_timestamp=time.time(),
)
Expand Down
6 changes: 5 additions & 1 deletion secator/output_types/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ class Ai(OutputType):
status: str = field(default='', compare=False)
answer: str = field(default='', compare=False)
choices: list = field(default_factory=list, compare=False)
session_id: str = field(default='', compare=False)
# NOTE: no top-level `session_id` field — the conversation id is carried by
# `_context.session_id`, auto-stamped on every persisted item from the runner
# context (see ai._init_options). restore_history_from_db, the remote answer
# poll, and secator-api all correlate on `_context.session_id`. Don't re-add a
# redundant top-level field.
_source: str = field(default='', repr=True, compare=False)
_type: str = field(default='ai', repr=True)
_timestamp: int = field(default_factory=lambda: time.time(), compare=False)
Expand Down
16 changes: 14 additions & 2 deletions secator/tasks/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _maybe_resume_remote(self):
# Append the new user message that respawned the conversation
if self.prompt:
self.history.add_user(maybe_encrypt(self.prompt, self.encryptor))
yield Ai(content=self.prompt, ai_type="prompt", session_id=self.session_id)
yield Ai(content=self.prompt, ai_type="prompt")

yield Info(message=f"Resumed session from DB ({len(self.history.messages)} messages), model: {self.model}, mode: {self.mode}") # noqa: E501
yield from self._run_loop()
Expand Down Expand Up @@ -391,7 +391,6 @@ def _mark_turn_completed(self):
content="",
ai_type="turn_completed",
status="completed",
session_id=self.session_id,
extra_data={"turn_uuid": turn_uuid},
), print=False)

Expand Down Expand Up @@ -696,6 +695,19 @@ def _init_options(self):
or self.session_name
or str(self.id)
)
# Write the resolved session_id back onto the runner context so it is the
# single source of truth for the conversation id. Every persisted item
# copies `self.context` into its `_context` (Runner._process_item), so this
# stamps `_context.session_id` on ALL `_type:"ai"` docs — including the
# `prompt`/`response` turns yielded directly here, which otherwise carry no
# session_id (they don't go through `_get_result_context` like tool docs do).
# restore_history_from_db + the remote poll both key on `_context.session_id`,
# so without this a locally-resolved session_id (str(self.id)/session_name)
# leaves the transcript turns unqueryable and a resume restores nothing.
# On the platform the dispatcher already supplies session_id in the context,
# so self.session_id equals it and this is an idempotent write.
if self.context is not None:
self.context["session_id"] = self.session_id
self.backend = create_backend(self.interactive, timeout=CONFIG.addons.ai.user_response_timeout)

# Auto-approve workspace targets
Expand Down
71 changes: 67 additions & 4 deletions tests/unit/test_ai_session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for secator.ai.session restore_history_from_db + remote resume branch."""
import contextlib
import tempfile
import unittest
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -26,8 +27,10 @@ def test_rebuilds_order_roles_and_system(self):
history = restore_history_from_db(
"session1", engine, model="gpt-4o", system_prompt="SYSTEM PROMPT")

# Query was scoped to the session
engine.search.assert_called_once_with({"_type": "ai", "session_id": "session1"})
# Query was scoped to the session by the auto-stamped `_context.session_id`
# (the same key the remote poll + resume branch use — NOT the top-level field,
# which prompt/response docs don't carry).
engine.search.assert_called_once_with({"_type": "ai", "_context.session_id": "session1"})

# System prompt set, conversation turns in timestamp order, non-turn docs skipped
self.assertEqual(history.messages, [
Expand Down Expand Up @@ -115,7 +118,8 @@ def _make_task(self, prior_docs, backend_name="mongodb"):
engine.backend.name = backend_name

def _search(query, limit=0):
if query.get("_type") == "ai" and "session_id" in query:
# The resume branch scopes by `_context.session_id` (not the top-level field).
if query.get("_type") == "ai" and any("session_id" in k for k in query):
return prior_docs
return []
engine.search.side_effect = _search
Expand Down Expand Up @@ -283,7 +287,9 @@ def test_mark_turn_completed_persists_marker(self):
self.assertIsInstance(marker, Ai)
self.assertEqual(marker.ai_type, "turn_completed")
self.assertEqual(marker.extra_data.get("turn_uuid"), "turn-abc")
self.assertEqual(marker.session_id, "sess-123")
# The marker carries no top-level session_id; its conversation id is stamped
# onto `_context.session_id` by the runner persist pipeline (from self.context),
# which _turn_completed_marker queries by. That stamping is out of scope here.

# Local channel: no marker persisted (idempotency is a remote concern).
persisted.clear()
Expand Down Expand Up @@ -369,5 +375,62 @@ def test_force_redetects_over_explicit_mode(self):
mock_llm.assert_not_called()


class TestSessionIdStampedOnContext(unittest.TestCase):
"""_init_options writes the resolved session_id back onto self.context.

Every persisted item copies self.context into its `_context` (Runner._process_item),
so this is what makes `prompt`/`response` docs queryable by `_context.session_id`
(restore_history_from_db + the remote poll both key on it). Without the stamp a
locally-resolved session_id (str(self.id)/session_name) leaves the transcript turns
unqueryable and a remote resume restores an empty history.
"""

def _drive_init(self, context, run_opts=None):
from secator.tasks.ai import ai
task = ai.__new__(ai)
task.context = context
task.run_opts = run_opts or {}
task.results = []
task.inputs = []
task._reports_folder = None
task.sync = True
opt_values = {
"resume": False, "subagent": False, "model": "m", "intent_model": "im",
"api_base": None, "api_key": "k", "sensitive": False, "mode": "chat",
"max_tokens_total": 100000, "max_workers": 1, "max_iterations": 10,
"temperature": 0.7, "context_warnings": True, "async_tasks": False,
"dangerous": False, "interactive": "remote",
}
task.get_opt_value = lambda key: opt_values.get(key)
with contextlib.ExitStack() as stack:
stack.enter_context(patch('secator.tasks.ai.PermissionEngine'))
stack.enter_context(patch('secator.tasks.ai.create_backend'))
stack.enter_context(patch('secator.tasks.ai.SensitiveDataEncryptor'))
stack.enter_context(patch.object(ai, '_auto_approve_workspace_targets'))
stack.enter_context(patch.object(type(task), 'reports_folder', property(lambda self: None)))
stack.enter_context(patch.object(type(task), 'id', 'runner-id-42', create=True))
task._init_options()
return task

def test_stamped_when_locally_derived(self):
"""No session_id anywhere -> falls back to str(self.id) AND is written to context."""
task = self._drive_init(context={"workspace_id": "ws1"})
self.assertEqual(task.session_id, "runner-id-42")
self.assertEqual(task.context["session_id"], "runner-id-42")

def test_platform_supplied_session_id_preserved(self):
"""A dispatcher-supplied context session_id is kept and remains the stamped value."""
task = self._drive_init(context={"workspace_id": "ws1", "session_id": "ui-sess-abc"})
self.assertEqual(task.session_id, "ui-sess-abc")
self.assertEqual(task.context["session_id"], "ui-sess-abc")

def test_stamp_matches_restore_query_key(self):
"""The stamped context key is exactly what restore/poll query (`_context.session_id`)."""
task = self._drive_init(context={})
# Simulate the generic per-item context copy (Runner._process_item does self.context.copy()).
item_context = dict(task.context)
self.assertEqual(item_context.get("session_id"), task.session_id)


if __name__ == "__main__":
unittest.main()
Loading