Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 93 additions & 86 deletions bot/eval/locomo/import_to_ov.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
import json
import sys
import time
from datetime import datetime
import traceback
from datetime import datetime, timedelta
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional

import openviking as ov
from openviking.message.part import TextPart


def _get_session_number(session_key: str) -> int:
Expand Down Expand Up @@ -58,32 +58,6 @@ def parse_test_file(path: str) -> List[Dict[str, Any]]:
return sessions


def format_locomo_message(msg: Dict[str, Any]) -> str:
    """Format a single LoCoMo message into a natural chat-style string.

    Output format:
        Speaker: text here
        image_url: caption

    Args:
        msg: Message dict. Reads the keys ``speaker``, ``text``, ``img_url``
            (either a single URL string or a list of URLs) and
            ``blip_caption``.

    Returns:
        The ``speaker: text`` line, followed by one line per image URL
        (suffixed with ``: caption`` when a caption is present), or by a
        single ``(caption)`` line when there is a caption but no URL.
    """
    speaker = msg.get("speaker", "unknown")

    # An explicit null text is treated as missing so the literal "None"
    # never leaks into the formatted line.
    text = msg.get("text")
    if text is None:
        text = ""

    blip = msg.get("blip_caption") or ""

    # A null or empty-string img_url means "no images"; a bare string is
    # promoted to a one-element list.
    img_urls = msg.get("img_url") or []
    if isinstance(img_urls, str):
        img_urls = [img_urls]

    lines = [f"{speaker}: {text}"]
    if img_urls:
        # The caption suffix is identical for every URL, so build it once.
        caption = f": {blip}" if blip else ""
        lines.extend(f"{url}{caption}" for url in img_urls)
    elif blip:
        lines.append(f"({blip})")

    return "\n".join(lines)


def load_locomo_data(
path: str,
sample_index: Optional[int] = None,
Expand All @@ -103,9 +77,10 @@ def build_session_messages(
item: Dict[str, Any],
session_range: Optional[Tuple[int, int]] = None,
) -> List[Dict[str, Any]]:
"""Build bundled session messages for one LoCoMo sample.
"""Build session messages for one LoCoMo sample.

Returns list of dicts with keys: message, meta.
Returns list of dicts with keys: messages, meta.
Each dict represents a session with multiple messages (user/assistant role).
"""
conv = item["conversation"]
speakers = f"{conv['speaker_a']} & {conv['speaker_b']}"
Expand All @@ -126,13 +101,20 @@ def build_session_messages(
dt_key = f"{sk}_date_time"
date_time = conv.get(dt_key, "")

parts = [f"[group chat conversation: {date_time}]"]
for msg in conv[sk]:
parts.append(format_locomo_message(msg))
combined = "\n\n".join(parts)
# Extract messages with all as user role, including speaker in content
messages = []
for idx, msg in enumerate(conv[sk]):
speaker = msg.get("speaker", "unknown")
text = msg.get("text", "")
messages.append({
"role": "user",
"text": f"[{speaker}]: {text}",
"speaker": speaker,
"index": idx
})

sessions.append({
"message": combined,
"messages": messages,
"meta": {
"sample_id": item["sample_id"],
"session_key": sk,
Expand All @@ -148,7 +130,7 @@ def build_session_messages(
# Ingest record helpers (avoid duplicate ingestion)
# ---------------------------------------------------------------------------

def load_success_csv(csv_path: str = "import_success.csv") -> set:
def load_success_csv(csv_path: str = "./result/import_success.csv") -> set:
"""加载成功导入的CSV记录,返回已成功的键集合"""
success_keys = set()
if Path(csv_path).exists():
Expand All @@ -160,7 +142,7 @@ def load_success_csv(csv_path: str = "import_success.csv") -> set:
return success_keys


def write_success_record(record: Dict[str, Any], csv_path: str = "import_success.csv") -> None:
def write_success_record(record: Dict[str, Any], csv_path: str = "./result/import_success.csv") -> None:
"""写入成功记录到CSV文件"""
file_exists = Path(csv_path).exists()
fieldnames = ["timestamp", "sample_id", "session", "date_time", "speakers",
Expand All @@ -186,7 +168,7 @@ def write_success_record(record: Dict[str, Any], csv_path: str = "import_success
})


def write_error_record(record: Dict[str, Any], error_path: str = "import_errors.log") -> None:
def write_error_record(record: Dict[str, Any], error_path: str = "./result/import_errors.log") -> None:
"""写入错误记录到日志文件"""
with open(error_path, "a", encoding="utf-8") as f:
timestamp = record["timestamp"]
Expand Down Expand Up @@ -242,22 +224,42 @@ def mark_ingested(
# ---------------------------------------------------------------------------
# OpenViking import
# ---------------------------------------------------------------------------
def _parse_token_usage(token_data: Dict[str, Any]) -> Dict[str, int]:
"""解析Token使用数据(仅支持新版token_usage格式)"""
usage = token_data["token_usage"]
def _parse_token_usage(commit_result: Dict[str, Any]) -> Dict[str, int]:
"""解析Token使用数据(从commit返回的telemetry中提取)"""
telemetry = commit_result.get("telemetry", {}).get("summary", {})
tokens = telemetry.get("tokens", {})
return {
"embedding": usage["embedding"]["total_tokens"],
"vlm": usage["llm"]["total_tokens"],
"llm_input": usage["llm"]["prompt_tokens"],
"llm_output": usage["llm"]["completion_tokens"],
"total": usage["total"]["total_tokens"]
"embedding": tokens.get("embedding", {}).get("total", 0),
"vlm": tokens.get("llm", {}).get("total", 0),
"llm_input": tokens.get("llm", {}).get("input", 0),
"llm_output": tokens.get("llm", {}).get("output", 0),
"total": tokens.get("total", 0)
}


async def viking_ingest(msg: str, openviking_url: str, semaphore: asyncio.Semaphore) -> Dict[str, int]:
"""Save a message to OpenViking via OpenViking SDK client.
async def viking_ingest(
messages: List[Dict[str, Any]],
openviking_url: str,
semaphore: asyncio.Semaphore,
session_time: Optional[str] = None
) -> Dict[str, int]:
"""Save messages to OpenViking via OpenViking SDK client.
Returns token usage dict with embedding and vlm token counts.

Args:
messages: List of message dicts with role and text
openviking_url: OpenViking service URL
semaphore: Async semaphore for concurrency control
session_time: Session time string (e.g., "9:36 am on 2 April, 2023")
"""
# 解析 session_time - 为每条消息计算递增的时间戳
base_datetime = None
if session_time:
try:
base_datetime = datetime.strptime(session_time, "%I:%M %p on %d %B, %Y")
except ValueError:
print(f"Warning: Failed to parse session_time: {session_time}", file=sys.stderr)

# 使用信号量控制并发
async with semaphore:
# Create client
Expand All @@ -268,49 +270,41 @@ async def viking_ingest(msg: str, openviking_url: str, semaphore: asyncio.Semaph
# Create session
create_res = await client.create_session()
session_id = create_res["session_id"]
session = client.session(session_id)

# Add message
await session.add_message(
role="user",
parts=[TextPart(text=msg)]
)
# Add messages one by one with created_at
for idx, msg in enumerate(messages):
msg_created_at = None
if base_datetime:
# 每条消息递增1秒,确保时间顺序
msg_dt = base_datetime + timedelta(seconds=idx)
msg_created_at = msg_dt.isoformat()

await client.add_message(
session_id=session_id,
role=msg["role"],
parts=[{"type": "text", "text": msg["text"]}],
created_at=msg_created_at
)

# Commit
result = await session.commit(telemetry=True)
result = await client.commit_session(session_id, telemetry=True)

if not (result.get("status") == "accepted" and result.get("task_id")):
if result.get("status") != "committed":
raise RuntimeError(f"Commit failed: {result}")

# 轮询等待异步任务完成
task_id = result["task_id"]
max_wait = 1200 # 最多等待20分钟
waited = 0

while waited < max_wait:
task = await client.get_task(task_id)
if task["status"] == "completed":
token_usage = _parse_token_usage(task["result"])
break
elif task["status"] == "failed":
raise RuntimeError(f"Commit failed: {task.get('error', 'Unknown error')}")

# 指数退避策略,避免频繁请求
await asyncio.sleep(min(1 << (waited // 10), 60))
waited += 1
else:
raise RuntimeError(f"Commit timed out after {max_wait} seconds")
# 直接从commit结果中提取token使用情况
token_usage = _parse_token_usage(result)

return token_usage

finally:
await client.close()


def sync_viking_ingest(messages: List[Dict[str, Any]], openviking_url: str, session_time: Optional[str] = None) -> Dict[str, int]:
    """Synchronous wrapper for viking_ingest to maintain existing API.

    Args:
        messages: Message dicts forwarded unchanged to ``viking_ingest``.
        openviking_url: OpenViking service URL.
        session_time: Optional session time string forwarded to
            ``viking_ingest``.

    Returns:
        Whatever ``viking_ingest`` returns (the token usage dict).

    Raises:
        RuntimeError: If called while an event loop is already running in
            this thread (raised by ``asyncio.run``).
    """
    async def _run() -> Dict[str, int]:
        # Build the semaphore inside the loop started by asyncio.run() so it
        # is never associated with a different event loop. A limit of 1
        # because a synchronous call ingests exactly one session.
        semaphore = asyncio.Semaphore(1)
        return await viking_ingest(messages, openviking_url, semaphore, session_time)

    return asyncio.run(_run())


# ---------------------------------------------------------------------------
Expand All @@ -327,7 +321,7 @@ def parse_session_range(s: str) -> Tuple[int, int]:


async def process_single_session(
msg: str,
messages: List[Dict[str, Any]],
sample_id: str | int,
session_key: str,
meta: Dict[str, Any],
Expand All @@ -338,7 +332,7 @@ async def process_single_session(
) -> Dict[str, Any]:
"""处理单个会话的导入任务"""
try:
token_usage = await viking_ingest(msg, args.openviking_url, semaphore)
token_usage = await viking_ingest(messages, args.openviking_url, semaphore, meta.get("date_time"))
print(f" -> [SUCCESS] [{sample_id}/{session_key}] imported to OpenViking", file=sys.stderr)

# Extract token counts
Expand Down Expand Up @@ -369,6 +363,7 @@ async def process_single_session(

except Exception as e:
print(f" -> [ERROR] [{sample_id}/{session_key}] {e}", file=sys.stderr)
traceback.print_exc(file=sys.stderr)

# Write error record
result = {
Expand Down Expand Up @@ -428,7 +423,7 @@ async def run_import(args: argparse.Namespace) -> None:

for sess in sessions:
meta = sess["meta"]
msg = sess["message"]
messages = sess["messages"]
session_key = meta["session_key"]
label = f"{session_key} ({meta['date_time']})"

Expand All @@ -438,13 +433,14 @@ async def run_import(args: argparse.Namespace) -> None:
skipped_count += 1
continue

preview = msg.replace("\n", " | ")[:80]
print(f" [{label}] {preview}...", file=sys.stderr)
# Preview messages
preview = " | ".join([f"{msg['role']}: {msg['text'][:30]}..." for msg in messages[:3]])
print(f" [{label}] {preview}", file=sys.stderr)

# 创建异步任务
task = asyncio.create_task(
process_single_session(
msg=msg,
messages=messages,
sample_id=sample_id,
session_key=session_key,
meta=meta,
Expand All @@ -471,14 +467,23 @@ async def run_import(args: argparse.Namespace) -> None:
skipped_count += 1
continue

combined_msg = "\n\n".join(session["messages"])
preview = combined_msg.replace("\n", " | ")[:80]
print(f" {preview}...", file=sys.stderr)
# For plain text, all messages as user role
messages = []
for i, text in enumerate(session["messages"]):
messages.append({
"role": "user",
"text": text.strip(),
"speaker": "user",
"index": i
})

preview = " | ".join([f"{msg['role']}: {msg['text'][:30]}..." for msg in messages[:3]])
print(f" {preview}", file=sys.stderr)

# 创建异步任务
task = asyncio.create_task(
process_single_session(
msg=combined_msg,
messages=messages,
sample_id="txt",
session_key=session_key,
meta={"session_index": idx},
Expand All @@ -499,6 +504,8 @@ async def run_import(args: argparse.Namespace) -> None:
if isinstance(result, Exception):
error_count += 1
print(f"[UNEXPECTED ERROR] Task failed with exception: {result}", file=sys.stderr)
if hasattr(result, '__traceback__'):
traceback.print_exception(type(result), result, result.__traceback__, file=sys.stderr)
continue

if result["status"] == "success":
Expand Down
Loading