From 70d3de3e4aa440564cebe243a101093130a0f93d Mon Sep 17 00:00:00 2001 From: Gui-Yue Date: Fri, 13 Feb 2026 00:18:27 +0800 Subject: [PATCH 1/3] feat: add SWE-bench and TAU-bench benchmark suite, fix OpenAI baseUrl version path matching feat(benchmark): use official SWE-bench Docker images and expand to 12 instances --- .env.test.example | 6 + docs/en/guides/benchmarking.md | 470 +++++++++++ docs/zh-CN/guides/benchmarking.md | 470 +++++++++++ package.json | 3 +- src/infra/providers/anthropic.ts | 6 +- src/infra/providers/gemini.ts | 2 +- src/infra/providers/openai.ts | 2 +- src/infra/providers/utils.ts | 4 +- tests/benchmark/compare.ts | 252 ++++++ tests/benchmark/config.ts | 133 ++++ tests/benchmark/html-reporter.ts | 360 +++++++++ tests/benchmark/reporter.ts | 175 ++++ tests/benchmark/run-benchmark.ts | 125 +++ .../swe/cases/curated-instances.json | 110 +++ tests/benchmark/swe/cases/mini-cases.json | 182 +++++ tests/benchmark/swe/dataset.ts | 40 + tests/benchmark/swe/docker-evaluator.ts | 744 ++++++++++++++++++ tests/benchmark/swe/evaluator.ts | 64 ++ tests/benchmark/swe/harness.ts | 136 ++++ tests/benchmark/swe/index.ts | 292 +++++++ .../benchmark/tau/domains/airline/database.ts | 220 ++++++ .../benchmark/tau/domains/airline/handlers.ts | 74 ++ tests/benchmark/tau/domains/airline/policy.md | 45 ++ .../benchmark/tau/domains/airline/tasks.json | 71 ++ tests/benchmark/tau/domains/airline/tools.ts | 127 +++ .../benchmark/tau/domains/retail/database.ts | 156 ++++ .../benchmark/tau/domains/retail/handlers.ts | 147 ++++ tests/benchmark/tau/domains/retail/policy.md | 53 ++ tests/benchmark/tau/domains/retail/tasks.json | 67 ++ tests/benchmark/tau/domains/retail/tools.ts | 147 ++++ tests/benchmark/tau/environment.ts | 44 ++ tests/benchmark/tau/evaluator.ts | 69 ++ tests/benchmark/tau/index.ts | 252 ++++++ tests/benchmark/tau/orchestrator.ts | 201 +++++ tests/benchmark/tau/user-simulator.ts | 107 +++ tests/benchmark/types.ts | 152 ++++ tests/unit/providers/openai.test.ts | 
5 + 37 files changed, 5505 insertions(+), 8 deletions(-) create mode 100644 docs/en/guides/benchmarking.md create mode 100644 docs/zh-CN/guides/benchmarking.md create mode 100644 tests/benchmark/compare.ts create mode 100644 tests/benchmark/config.ts create mode 100644 tests/benchmark/html-reporter.ts create mode 100644 tests/benchmark/reporter.ts create mode 100644 tests/benchmark/run-benchmark.ts create mode 100644 tests/benchmark/swe/cases/curated-instances.json create mode 100644 tests/benchmark/swe/cases/mini-cases.json create mode 100644 tests/benchmark/swe/dataset.ts create mode 100644 tests/benchmark/swe/docker-evaluator.ts create mode 100644 tests/benchmark/swe/evaluator.ts create mode 100644 tests/benchmark/swe/harness.ts create mode 100644 tests/benchmark/swe/index.ts create mode 100644 tests/benchmark/tau/domains/airline/database.ts create mode 100644 tests/benchmark/tau/domains/airline/handlers.ts create mode 100644 tests/benchmark/tau/domains/airline/policy.md create mode 100644 tests/benchmark/tau/domains/airline/tasks.json create mode 100644 tests/benchmark/tau/domains/airline/tools.ts create mode 100644 tests/benchmark/tau/domains/retail/database.ts create mode 100644 tests/benchmark/tau/domains/retail/handlers.ts create mode 100644 tests/benchmark/tau/domains/retail/policy.md create mode 100644 tests/benchmark/tau/domains/retail/tasks.json create mode 100644 tests/benchmark/tau/domains/retail/tools.ts create mode 100644 tests/benchmark/tau/environment.ts create mode 100644 tests/benchmark/tau/evaluator.ts create mode 100644 tests/benchmark/tau/index.ts create mode 100644 tests/benchmark/tau/orchestrator.ts create mode 100644 tests/benchmark/tau/user-simulator.ts create mode 100644 tests/benchmark/types.ts diff --git a/.env.test.example b/.env.test.example index 0ae9ca1..0a6b7f0 100644 --- a/.env.test.example +++ b/.env.test.example @@ -63,3 +63,9 @@ E2B_TEMPLATE=base # E2B 沙箱超时时间(毫秒,可选,默认 300000) E2B_TIMEOUT_MS=300000 + +# 
============================================================================= +# Benchmark (for benchmark tests) +# ============================================================================= +# Docker 代理(可选,SWE full 模式 git clone 和 Docker 容器使用) +# BENCHMARK_DOCKER_PROXY=http://127.0.0.1:7897 diff --git a/docs/en/guides/benchmarking.md b/docs/en/guides/benchmarking.md new file mode 100644 index 0000000..08465e1 --- /dev/null +++ b/docs/en/guides/benchmarking.md @@ -0,0 +1,470 @@ +# Benchmarking Guide + +KODE SDK includes an integrated benchmark suite for evaluating LLM model capabilities in agent scenarios. The suite implements two industry-standard methodologies: + +- **SWE-bench** (Princeton/OpenAI) — Code bug-fixing: model receives an issue description + source code, generates a fix, tests verify correctness +- **τ-bench** (Sierra Research) — Multi-turn tool-use conversations: model acts as a customer service agent, uses tools, follows policy, and the final database state is evaluated + +--- + +## Prerequisites + +1. **Provider configuration** in `.env.test` — at least one provider with `API_KEY` and `MODEL_ID` configured. See [Provider Configuration Guide](./providers.md) for details. + +2. **Node.js** with `ts-node` available (included in devDependencies). + +3. **(Optional) Docker** — required only for SWE-bench full mode. Mini mode and TAU benchmarks run without Docker. + +### Minimal `.env.test` Setup + +```ini +# At least one provider is required +ANTHROPIC_API_KEY=sk-ant-... +ANTHROPIC_MODEL_ID=claude-sonnet-4-5-20250929 + +# Optional: additional providers to compare +OPENAI_API_KEY=sk-... +OPENAI_MODEL_ID=gpt-4o + +GEMINI_API_KEY=AIza... 
+GEMINI_MODEL_ID=gemini-2.5-pro +``` + +--- + +## Quick Start + +```bash +# Run all benchmarks (SWE mini + TAU airline + TAU retail) +npm run test:benchmark + +# Run only SWE benchmark +npm run test:benchmark -- --swe-only + +# Run SWE full mode (requires Docker) +npm run test:benchmark -- --swe-only --swe-mode=full + +# Run only TAU benchmark +npm run test:benchmark -- --tau-only + +# Run with a specific provider +npm run test:benchmark -- --provider=anthropic + +# Output JSON report +npm run test:benchmark -- --output=json --output-file=results.json +``` + +> **Note:** Every benchmark run automatically generates an HTML visual report at `tests/tmp/benchmark-report-{timestamp}.html`. Open it in a browser to view detailed results with scores, charts, and per-case breakdowns. + +--- + +## SWE Benchmark + +The SWE benchmark evaluates a model's ability to fix bugs in source code. The model receives a bug description and the project files, then generates corrected code that must pass all tests. + +### Mini Mode (Default) + +Mini mode uses 20 built-in JavaScript bug-fix cases that run locally without Docker. Each case contains: +- A buggy `src.js` file +- A `test.js` file with assertions +- A bug description explaining the expected behavior + +```bash +# Run mini-SWE benchmark +npm run test:benchmark -- --swe-only --swe-mode=mini +``` + +**Example output:** + +``` + SWE mini mode: 20 cases + + Running provider: anthropic / claude-sonnet-4-5-20250929 + [anthropic] mini-swe-001: PASS (1772 tokens, 13186ms) + [anthropic] mini-swe-002: PASS (1246 tokens, 12162ms) + ... + +--- SWE-bench (mini-swe) — 20 instances --- + +Provider / Model | Resolved | Rate | Avg Tokens | Avg ms +-------------------------------------+----------+---------+------------+--------- +anthropic / claude-sonnet-4-5-20250… | 20/20 | 100.0% | 1.0k | 7.4k +``` + +**Core metric:** `Resolved Rate` — the percentage of cases where the model's fix passes all tests. 
+ +### Full Mode (Docker) + +Full mode uses real SWE-bench instances from open-source repositories. It evaluates model-generated patches using official pre-built SWE-bench Docker images from DockerHub. + +```bash +# Run full SWE-bench (requires Docker) +npm run test:benchmark -- --swe-only --swe-mode=full +``` + +The evaluator: +1. Clones the repository on the host and checks out the specified commit +2. Extracts relevant file paths from the problem statement and hints +3. Reads source files and sends them to the LLM along with the bug description +4. The LLM returns SEARCH/REPLACE blocks for the changed code sections +5. The framework applies the hunks and programmatically generates a unified diff +6. Pulls the official SWE-bench Docker image (`swebench/sweb.eval.x86_64.:latest`) +7. The container already has the repo at `/testbed` with all dependencies installed in a `testbed` conda environment +8. Applies the patch and runs the repository's test suite + +When Docker is not available, it falls back to local git clone + patch application (less reliable due to missing dependencies). + +The curated instances are defined in `tests/benchmark/swe/cases/curated-instances.json`. + +> **Note:** SWE-bench images are large (several GB each). The first run will take longer as images are downloaded. Subsequent runs reuse cached images. Configure `BENCHMARK_DOCKER_PROXY` if you need a proxy for Docker pulls. + +--- + +## TAU Benchmark + +The TAU benchmark (Tool-Agent-User) evaluates a model's ability to handle multi-turn customer service conversations while using tools correctly and following business policies. 
+ +### Architecture + +``` +Orchestrator +├── Agent (model under test) — receives user messages, calls tools, follows policy +├── User Simulator (LLM) — plays the customer role based on a scenario script +└── Environment — executes tool calls, maintains database state +``` + +**Evaluation:** After the conversation ends, the final database state is compared against the expected state. A task passes only if all expected fields match. + +### Available Domains + +| Domain | Tasks | Tools | Description | +|--------|-------|-------|-------------| +| `airline` | 5 | 7 | Flight changes, cancellations, baggage inquiries | +| `retail` | 5 | 8 | Returns, exchanges, order status, product search | + +### Running TAU Benchmarks + +```bash +# Run all TAU domains +npm run test:benchmark -- --tau-only + +# Run specific domain +npm run test:benchmark -- --tau-only --tau-domain=airline +npm run test:benchmark -- --tau-only --tau-domain=retail + +# Run with multiple trials (for pass^k reliability metric) +npm run test:benchmark -- --tau-only --num-trials=3 +``` + +**Example output:** + +``` + TAU domain: airline (5 tasks, 1 trials) + + Running provider: anthropic / claude-sonnet-4-5-20250929 + User simulator: anthropic / claude-sonnet-4-5-20250929 + [anthropic] airline_001 trial 1/1: PASS (5 turns, 22341 tokens) + [anthropic] airline_002 trial 1/1: PASS (3 turns, 15280 tokens) + ... + +--- TAU-bench (airline) — 5 tasks, 1 trials --- + +Provider / Model | Pass^1 | Avg Tokens +-------------------------------------+---------+----------- +anthropic / claude-sonnet-4-5-20250… | 80.0% | 18.1k +``` + +### Understanding pass^k + +The **pass^k** metric measures reliability across multiple independent trials of the same task: + +- **pass^1** = fraction of tasks passed in a single trial +- **pass^k** = fraction of tasks that passed in ALL k independent trials + +This captures consistency — a model with 80% pass^1 but 40% pass^3 is unreliable. Use `--num-trials=k` to compute pass^k. 
+ +### User Simulator + +By default, the same model is used for both the agent and the user simulator. To use a different model for user simulation: + +```ini +# In .env.test +BENCHMARK_USER_MODEL=anthropic/claude-sonnet-4-5-20250929 +``` + +Format: `provider/model-id`. + +--- + +## CLI Reference + +All flags are passed after `--` to the npm script: + +```bash +npm run test:benchmark -- [flags] +``` + +| Flag | Description | Default | +|------|-------------|---------| +| `--swe-only` | Run only SWE benchmarks | (run both) | +| `--tau-only` | Run only TAU benchmarks | (run both) | +| `--swe-mode=mini\|full` | SWE evaluation mode | `mini` | +| `--tau-domain=airline\|retail\|all` | TAU domain to evaluate | `all` | +| `--provider=NAME` | Run only the specified provider | (all configured) | +| `--num-trials=N` | Number of TAU trials per task (for pass^k) | `1` | +| `--output=table\|json\|html\|both` | Output format | `table` | +| `--output-file=PATH` | JSON/HTML report output path | `benchmark-report.json` | +| `--compare=PATH` | Compare current run against a baseline JSON report | (none) | + +--- + +## Environment Variables + +These can be set in `.env.test` alongside provider configuration: + +| Variable | Description | Default | +|----------|-------------|---------| +| `BENCHMARK_PROVIDERS` | Comma-separated list of providers to run | (all configured) | +| `BENCHMARK_TIMEOUT_MS` | Timeout per task in milliseconds | `120000` | +| `BENCHMARK_NUM_TRIALS` | Default number of TAU trials | `1` | +| `BENCHMARK_OUTPUT` | Output format | `table` | +| `BENCHMARK_USER_MODEL` | User simulator model (`provider/model`) | (same as agent) | +| `BENCHMARK_DOCKER_PROXY` | HTTP proxy URL for Docker containers and git clone | (none) | + +CLI flags override environment variables when both are set. + +--- + +## Historical Comparison + +Save a baseline report and compare future runs against it to detect regressions: + +```bash +# 1. 
Save a baseline +npm run test:benchmark -- --output=json --output-file=baseline.json + +# 2. Later, compare a new run against the baseline +npm run test:benchmark -- --compare=baseline.json +``` + +The comparison output shows changes in key metrics with direction indicators: + +``` +================================================================================ +Benchmark Comparison +================================================================================ + Baseline: baseline.json + Current: (current run) + +--- SWE Comparison --- + +Metric | Baseline | Current | Delta | Dir +--------------------------------------------------------------------------------- +anthropic/claude-sonnet-4-5 [rate] | 100.0% | 100.0% | = | +anthropic/claude-sonnet-4-5 [resolved] | 20/20 | 20/20 | = | +anthropic/claude-sonnet-4-5 [tokens] | 1.0k | 986 | -45 | ^ + + No regressions detected. +``` + +- `^` = improvement (higher rate, lower tokens/latency) +- `v` = regression (lower rate, higher tokens/latency) +- Exit code is `1` if regressions are detected + +--- + +## JSON Report Format + +When using `--output=json` or `--output=both`, a JSON report is written: + +```json +{ + "timestamp": "2026-02-12T10:30:00.000Z", + "sdk_version": "2.7.3", + "swe": [{ + "provider": { "id": "anthropic", "model": "claude-sonnet-4-5-20250929", "apiKey": "***" }, + "summary": { + "dataset": "mini-swe", + "total": 20, + "resolved": 20, + "rate": 1.0, + "avg_tokens": 1031, + "avg_duration_ms": 7420 + }, + "results": [ + { "instance_id": "mini-swe-001", "resolved": true, "tokens_used": 1772, "duration_ms": 13186 } + ] + }], + "tau": [{ + "provider": { "id": "anthropic", "model": "claude-sonnet-4-5-20250929", "apiKey": "***" }, + "summary": { + "domain": "airline", + "total_tasks": 5, + "num_trials": 1, + "pass_at_k": [0.8], + "avg_tokens": 18100 + }, + "results": [ + { "task_id": "airline_001", "trial_pass_rates": [true], "tokens_used": 22341 } + ] + }] +} +``` + +API keys are automatically redacted to 
`"***"` in the output. + +--- + +## HTML Visual Report + +Every benchmark run automatically generates a self-contained HTML report at `tests/tmp/benchmark-report-{timestamp}.html` (this directory is in `.gitignore`). The report includes: + +- **Overall Score** — A weighted composite score (0–100) displayed as a circular progress ring: + - SWE Resolved Rate × 60% + TAU Pass^1 × 40% + - If only one benchmark type runs, it gets 100% weight + - Color-coded: green (≥90 Excellent), yellow (≥70 Good), orange (≥50 Fair), red (<50 Poor) +- **Configuration Summary** — SDK version, providers, SWE mode, TAU domain, timeout, trials +- **SWE Results** — Summary table, resolved rate bar chart, and expandable per-case details (pass/fail, tokens, duration) +- **TAU Results** — Summary table with Pass^k columns, pass rate bar chart, and expandable per-task trial details + +### Viewing the Report + +```bash +# Run benchmarks (HTML report is generated automatically) +npm run test:benchmark -- --provider=anthropic + +# Serve with Python's built-in HTTP server +cd tests/tmp && python3 -m http.server 8080 +# Open http://localhost:8080/benchmark-report.html +``` + +The report is a single file with all CSS inlined — no external dependencies. You can also open it directly in a browser via `file://` protocol. 
+ +--- + +## Project Structure + +``` +tests/benchmark/ +├── run-benchmark.ts # Entry point +├── config.ts # CLI + env config loading +├── types.ts # Shared type definitions +├── reporter.ts # Table + JSON output +├── html-reporter.ts # HTML visual report generator +├── compare.ts # Historical report comparison +│ +├── swe/ # SWE-bench module +│ ├── index.ts # Module entry (mini + full mode routing) +│ ├── dataset.ts # Case/instance loading +│ ├── harness.ts # Model interaction (mini mode) +│ ├── evaluator.ts # Local test execution (mini mode) +│ ├── docker-evaluator.ts # Docker/git evaluation (full mode) +│ └── cases/ +│ ├── mini-cases.json # 20 JavaScript bug-fix cases +│ └── curated-instances.json # SWE-bench instance definitions +│ +└── tau/ # TAU-bench module + ├── index.ts # Module entry (domain discovery + orchestration) + ├── orchestrator.ts # Agent ↔ User ↔ Environment message loop + ├── user-simulator.ts # LLM-based user simulation + ├── environment.ts # Generic DB + tool dispatch + ├── evaluator.ts # DB state comparison + pass^k + └── domains/ + ├── airline/ + │ ├── policy.md # Business rules + │ ├── database.ts # Initial data (users, flights, reservations) + │ ├── tools.ts # Tool definitions (Anthropic API format) + │ ├── handlers.ts # Tool implementation logic + │ └── tasks.json # 5 evaluation tasks + └── retail/ + ├── policy.md # Return/exchange/shipping policies + ├── database.ts # Initial data (customers, products, orders) + ├── tools.ts # Tool definitions + ├── handlers.ts # Tool implementation logic + └── tasks.json # 5 evaluation tasks +``` + +--- + +## Adding Custom Test Cases + +### Adding Mini-SWE Cases + +Add new entries to `tests/benchmark/swe/cases/mini-cases.json`: + +```json +{ + "id": "mini-swe-021", + "description": "Describe the bug and expected behavior clearly.", + "files": { + "src.js": "// buggy source code\nmodule.exports = { myFunc };\n", + "test.js": "const { myFunc } = require('./src');\n// assertions...\nconsole.log('All tests 
passed');\n" + }, + "test_command": "node test.js" +} +``` + +Requirements: +- `src.js` must contain the buggy code (the model should not modify test files) +- `test.js` must exit with code 0 on success, non-zero on failure +- The bug should be a single, clear defect with an unambiguous fix + +### Adding TAU Domains + +To add a new domain (e.g., `telecom`): + +1. Create `tests/benchmark/tau/domains/telecom/`: + - `policy.md` — business rules the agent must follow + - `database.ts` — export `getInitialDatabase()` with typed data + - `tools.ts` — export tool definitions in Anthropic API format + - `handlers.ts` — export `getTelecomHandlers()` returning tool implementations + - `tasks.json` — evaluation tasks with `user_scenario` and `expected_db` + +2. Update `tests/benchmark/tau/index.ts`: + - Add imports for the new domain + - Add a `case 'telecom':` in `loadDomain()` + - Add `'telecom'` to the candidates list in `getAvailableDomains()` + - Add a role entry in `DOMAIN_ROLES` + +### Adding TAU Tasks + +Add entries to a domain's `tasks.json`: + +```json +{ + "task_id": "retail_006", + "user_scenario": "You are [name] (customer ID: [id]). Describe what the user wants...", + "expected_db": { + "orders": [ + { "order_id": "ORD001", "status": "returned" } + ] + }, + "max_turns": 10 +} +``` + +The `expected_db` uses partial matching — only specified fields are checked, and records are matched by their primary key field (any field ending in `_id`). + +--- + +## Best Practices + +1. **Start with mini mode** — it's fast, free of Docker dependencies, and provides quick feedback +2. **Use `--provider` to test one model at a time** during development +3. **Save baseline reports** before SDK upgrades to catch regressions +4. **Set `--num-trials=3` or higher** for TAU benchmarks when evaluating reliability +5. **Use a separate user simulator model** (via `BENCHMARK_USER_MODEL`) to avoid self-play bias +6. 
**Keep API keys in `.env.test`** — the JSON report automatically redacts them + +--- + +## References + +- [SWE-bench](https://github.com/SWE-bench/SWE-bench) — Official repository + evaluation harness +- [SWE-bench Verified](https://openai.com/index/introducing-swe-bench-verified/) — Human-verified subset +- [SWE-bench Leaderboard](https://www.swebench.com/original.html) +- [τ-bench](https://github.com/sierra-research/tau-bench) — Original version +- [τ²-bench](https://github.com/sierra-research/tau2-bench) — Extended version with telecom domain +- [τ-bench Paper](https://arxiv.org/abs/2406.12045) — Methodology details +- [τ-bench Leaderboard](https://taubench.com) +- [Provider Configuration](./providers.md) — Setting up model providers diff --git a/docs/zh-CN/guides/benchmarking.md b/docs/zh-CN/guides/benchmarking.md new file mode 100644 index 0000000..a1d459f --- /dev/null +++ b/docs/zh-CN/guides/benchmarking.md @@ -0,0 +1,470 @@ +# 基准测试指南 + +KODE SDK 内置了一套完整的基准测试套件,用于评估不同 LLM 模型在 Agent 场景下的实际表现。该套件实现了两大业界标准方法论: + +- **SWE-bench**(Princeton/OpenAI)— 代码缺陷修复:模型接收 issue 描述 + 源代码,生成修复代码,通过测试验证 +- **τ-bench**(Sierra Research)— 多轮工具调用对话:模型扮演客服 Agent,使用工具、遵循业务策略,通过数据库状态对比评估 + +--- + +## 前置条件 + +1. **Provider 配置** — 在 `.env.test` 中至少配置一个 provider 的 `API_KEY` 和 `MODEL_ID`。详见 [Provider 配置指南](./providers.md)。 + +2. **Node.js** — 需要 `ts-node`(已包含在 devDependencies 中)。 + +3. **(可选)Docker** — 仅 SWE-bench full 模式需要。Mini 模式和 TAU 基准测试不依赖 Docker。 + +### 最小 `.env.test` 配置 + +```ini +# 至少配置一个 provider +ANTHROPIC_API_KEY=sk-ant-... +ANTHROPIC_MODEL_ID=claude-sonnet-4-5-20250929 + +# 可选:配置更多 provider 进行对比 +OPENAI_API_KEY=sk-... +OPENAI_MODEL_ID=gpt-4o + +GEMINI_API_KEY=AIza... 
+GEMINI_MODEL_ID=gemini-2.5-pro +``` + +--- + +## 快速开始 + +```bash +# 运行全部基准测试(SWE mini + TAU airline + TAU retail) +npm run test:benchmark + +# 仅运行 SWE 基准测试 +npm run test:benchmark -- --swe-only + +# 运行 SWE full 模式(需要 Docker) +npm run test:benchmark -- --swe-only --swe-mode=full + +# 仅运行 TAU 基准测试 +npm run test:benchmark -- --tau-only + +# 指定单个 provider +npm run test:benchmark -- --provider=anthropic + +# 输出 JSON 报告 +npm run test:benchmark -- --output=json --output-file=results.json +``` + +> **提示:** 每次运行基准测试时会自动生成 HTML 可视化报告,位于 `tests/tmp/benchmark-report-{timestamp}.html`。在浏览器中打开即可查看带评分、图表和逐条明细的详细报告。 + +--- + +## SWE 基准测试 + +SWE 基准测试评估模型修复代码缺陷的能力。模型接收 bug 描述和项目文件,生成修复后的代码,通过运行测试来验证正确性。 + +### Mini 模式(默认) + +Mini 模式使用 20 个内置的 JavaScript 缺陷修复用例,在本地运行,无需 Docker。每个用例包含: +- 含有 bug 的 `src.js` 文件 +- 包含断言的 `test.js` 测试文件 +- 描述预期行为的 bug 说明 + +```bash +# 运行 mini-SWE 基准测试 +npm run test:benchmark -- --swe-only --swe-mode=mini +``` + +**示例输出:** + +``` + SWE mini mode: 20 cases + + Running provider: anthropic / claude-sonnet-4-5-20250929 + [anthropic] mini-swe-001: PASS (1772 tokens, 13186ms) + [anthropic] mini-swe-002: PASS (1246 tokens, 12162ms) + ... + +--- SWE-bench (mini-swe) — 20 instances --- + +Provider / Model | Resolved | Rate | Avg Tokens | Avg ms +-------------------------------------+----------+---------+------------+--------- +anthropic / claude-sonnet-4-5-20250… | 20/20 | 100.0% | 1.0k | 7.4k +``` + +**核心指标:** `Resolved Rate` — 模型修复代码后通过全部测试的用例比例。 + +### Full 模式(Docker) + +Full 模式使用真实开源仓库的 SWE-bench 实例。通过官方预构建的 SWE-bench Docker 镜像进行评估。 + +```bash +# 运行 full SWE-bench(需要 Docker) +npm run test:benchmark -- --swe-only --swe-mode=full +``` + +评估流程为: +1. 在主机上克隆仓库并 checkout 到指定 commit +2. 从问题描述和提示中提取相关文件路径 +3. 读取源文件,连同 bug 描述一起发送给 LLM +4. LLM 返回 SEARCH/REPLACE 格式的代码修改块 +5. 框架应用修改并程序化生成 unified diff +6. 拉取官方 SWE-bench Docker 镜像(`swebench/sweb.eval.x86_64.:latest`) +7. 容器内已包含仓库(位于 `/testbed`)和预装所有依赖的 `testbed` conda 环境 +8. 
在容器中应用 patch 并运行测试套件 + +Docker 不可用时,回退到本地 git clone + patch 应用方式(由于缺少依赖,可靠性较低)。 + +精选实例定义在 `tests/benchmark/swe/cases/curated-instances.json` 中。 + +> **注意:** SWE-bench 镜像较大(每个数 GB)。首次运行时下载镜像需要较长时间,后续运行会复用本地缓存。如需代理下载,请配置 `BENCHMARK_DOCKER_PROXY`。 + +--- + +## TAU 基准测试 + +TAU 基准测试(Tool-Agent-User)评估模型在多轮客服对话中正确使用工具并遵循业务策略的能力。 + +### 架构 + +``` +编排器 (Orchestrator) +├── Agent(被测模型)— 接收用户消息,调用工具,遵循策略 +├── User Simulator(LLM 模拟用户)— 按场景脚本扮演客户 +└── Environment(环境)— 执行工具调用,维护数据库状态 +``` + +**评估方式:** 对话结束后,将最终数据库状态与预期状态对比。所有预期字段匹配则该任务通过。 + +### 可用领域 + +| 领域 | 任务数 | 工具数 | 描述 | +|------|--------|--------|------| +| `airline` | 5 | 7 | 航班改签、取消、行李查询 | +| `retail` | 5 | 8 | 退货、换货、订单状态、商品搜索 | + +### 运行 TAU 基准测试 + +```bash +# 运行全部 TAU 领域 +npm run test:benchmark -- --tau-only + +# 运行指定领域 +npm run test:benchmark -- --tau-only --tau-domain=airline +npm run test:benchmark -- --tau-only --tau-domain=retail + +# 多次试验(计算 pass^k 可靠性指标) +npm run test:benchmark -- --tau-only --num-trials=3 +``` + +**示例输出:** + +``` + TAU domain: airline (5 tasks, 1 trials) + + Running provider: anthropic / claude-sonnet-4-5-20250929 + User simulator: anthropic / claude-sonnet-4-5-20250929 + [anthropic] airline_001 trial 1/1: PASS (5 turns, 22341 tokens) + [anthropic] airline_002 trial 1/1: PASS (3 turns, 15280 tokens) + ... 
+ +--- TAU-bench (airline) — 5 tasks, 1 trials --- + +Provider / Model | Pass^1 | Avg Tokens +-------------------------------------+---------+----------- +anthropic / claude-sonnet-4-5-20250… | 80.0% | 18.1k +``` + +### 理解 pass^k 指标 + +**pass^k** 衡量模型在多次独立试验中的可靠性: + +- **pass^1** = 单次试验中通过的任务比例 +- **pass^k** = 在 k 次独立试验中全部通过的任务比例 + +该指标反映一致性 — 如果模型 pass^1 = 80% 但 pass^3 = 40%,说明其表现不稳定。使用 `--num-trials=k` 来计算 pass^k。 + +### 用户模拟器 + +默认情况下,agent 和用户模拟器使用相同的模型。如需使用不同模型模拟用户: + +```ini +# 在 .env.test 中设置 +BENCHMARK_USER_MODEL=anthropic/claude-sonnet-4-5-20250929 +``` + +格式:`provider/model-id`。 + +--- + +## CLI 参数参考 + +所有参数通过 `--` 传递给 npm script: + +```bash +npm run test:benchmark -- [参数] +``` + +| 参数 | 说明 | 默认值 | +|------|------|--------| +| `--swe-only` | 仅运行 SWE 基准测试 | (全部运行) | +| `--tau-only` | 仅运行 TAU 基准测试 | (全部运行) | +| `--swe-mode=mini\|full` | SWE 评估模式 | `mini` | +| `--tau-domain=airline\|retail\|all` | TAU 评估领域 | `all` | +| `--provider=NAME` | 仅运行指定 provider | (全部已配置) | +| `--num-trials=N` | TAU 每个任务的试验次数(用于 pass^k) | `1` | +| `--output=table\|json\|html\|both` | 输出格式 | `table` | +| `--output-file=PATH` | JSON/HTML 报告输出路径 | `benchmark-report.json` | +| `--compare=PATH` | 与基线 JSON 报告对比 | (无) | + +--- + +## 环境变量 + +可在 `.env.test` 中与 provider 配置一起设置: + +| 变量 | 说明 | 默认值 | +|------|------|--------| +| `BENCHMARK_PROVIDERS` | 逗号分隔的 provider 列表 | (全部已配置) | +| `BENCHMARK_TIMEOUT_MS` | 每个任务超时时间(毫秒) | `120000` | +| `BENCHMARK_NUM_TRIALS` | TAU 默认试验次数 | `1` | +| `BENCHMARK_OUTPUT` | 输出格式 | `table` | +| `BENCHMARK_USER_MODEL` | 用户模拟器模型(`provider/model`) | (与 agent 相同) | +| `BENCHMARK_DOCKER_PROXY` | Docker 容器和 git clone 使用的 HTTP 代理 URL | (无) | + +CLI 参数优先级高于环境变量。 + +--- + +## 历史结果对比 + +保存基线报告,后续运行时与其对比,检测性能退化: + +```bash +# 1. 保存基线 +npm run test:benchmark -- --output=json --output-file=baseline.json + +# 2. 
后续运行时,与基线对比 +npm run test:benchmark -- --compare=baseline.json +``` + +对比输出展示关键指标的变化及方向标识: + +``` +================================================================================ +Benchmark Comparison +================================================================================ + Baseline: baseline.json + Current: (current run) + +--- SWE Comparison --- + +Metric | Baseline | Current | Delta | Dir +--------------------------------------------------------------------------------- +anthropic/claude-sonnet-4-5 [rate] | 100.0% | 100.0% | = | +anthropic/claude-sonnet-4-5 [resolved] | 20/20 | 20/20 | = | +anthropic/claude-sonnet-4-5 [tokens] | 1.0k | 986 | -45 | ^ + + No regressions detected. +``` + +- `^` = 改善(更高通过率、更少 token/延迟) +- `v` = 退化(更低通过率、更多 token/延迟) +- 检测到退化时退出码为 `1` + +--- + +## JSON 报告格式 + +使用 `--output=json` 或 `--output=both` 时输出 JSON 报告: + +```json +{ + "timestamp": "2026-02-12T10:30:00.000Z", + "sdk_version": "2.7.3", + "swe": [{ + "provider": { "id": "anthropic", "model": "claude-sonnet-4-5-20250929", "apiKey": "***" }, + "summary": { + "dataset": "mini-swe", + "total": 20, + "resolved": 20, + "rate": 1.0, + "avg_tokens": 1031, + "avg_duration_ms": 7420 + }, + "results": [ + { "instance_id": "mini-swe-001", "resolved": true, "tokens_used": 1772, "duration_ms": 13186 } + ] + }], + "tau": [{ + "provider": { "id": "anthropic", "model": "claude-sonnet-4-5-20250929", "apiKey": "***" }, + "summary": { + "domain": "airline", + "total_tasks": 5, + "num_trials": 1, + "pass_at_k": [0.8], + "avg_tokens": 18100 + }, + "results": [ + { "task_id": "airline_001", "trial_pass_rates": [true], "tokens_used": 22341 } + ] + }] +} +``` + +API 密钥在输出中自动脱敏为 `"***"`。 + +--- + +## HTML 可视化报告 + +每次运行基准测试时会自动在 `tests/tmp/benchmark-report-{timestamp}.html` 生成一份自包含的 HTML 报告(该目录已被 `.gitignore` 忽略)。报告包含: + +- **综合评分** — 加权综合评分(0–100),以圆环进度条展示: + - SWE 通过率 × 60% + TAU Pass^1 × 40% + - 如果只运行了一种基准测试,则该项占 100% 权重 + - 按分数自动标色:绿色(≥90 优秀)、黄色(≥70 良好)、橙色(≥50 一般)、红色(<50 较差) +- **配置摘要** — 
SDK 版本、provider 列表、SWE 模式、TAU 领域、超时设置、试验次数 +- **SWE 结果** — 汇总表格、通过率条形图、可展开的逐 case 明细(通过/失败、token 数、耗时) +- **TAU 结果** — 带 Pass^k 列的汇总表格、通过率条形图、可展开的逐 task 试验明细 + +### 查看报告 + +```bash +# 运行基准测试(HTML 报告自动生成) +npm run test:benchmark -- --provider=anthropic + +# 使用 Python 内置 HTTP 服务器 +cd tests/tmp && python3 -m http.server 8080 +# 打开 http://localhost:8080/benchmark-report.html +``` + +报告是单文件格式,所有 CSS 均内联,无外部依赖。也可以直接通过 `file://` 协议在浏览器中打开。 + +--- + +## 项目结构 + +``` +tests/benchmark/ +├── run-benchmark.ts # 入口文件 +├── config.ts # CLI + 环境变量配置加载 +├── types.ts # 共享类型定义 +├── reporter.ts # 表格 + JSON 输出 +├── html-reporter.ts # HTML 可视化报告生成器 +├── compare.ts # 历史报告对比 +│ +├── swe/ # SWE-bench 模块 +│ ├── index.ts # 模块入口(mini + full 模式路由) +│ ├── dataset.ts # 用例/实例加载 +│ ├── harness.ts # 模型交互(mini 模式) +│ ├── evaluator.ts # 本地测试执行(mini 模式) +│ ├── docker-evaluator.ts # Docker/git 评估(full 模式) +│ └── cases/ +│ ├── mini-cases.json # 20 个 JavaScript 缺陷修复用例 +│ └── curated-instances.json # SWE-bench 实例定义 +│ +└── tau/ # TAU-bench 模块 + ├── index.ts # 模块入口(领域发现 + 编排) + ├── orchestrator.ts # Agent ↔ User ↔ Environment 消息循环 + ├── user-simulator.ts # 基于 LLM 的用户模拟 + ├── environment.ts # 通用 DB + 工具分发 + ├── evaluator.ts # DB 状态对比 + pass^k 计算 + └── domains/ + ├── airline/ + │ ├── policy.md # 业务规则 + │ ├── database.ts # 初始数据(用户、航班、预订) + │ ├── tools.ts # 工具定义(Anthropic API 格式) + │ ├── handlers.ts # 工具实现逻辑 + │ └── tasks.json # 5 个评估任务 + └── retail/ + ├── policy.md # 退换货/配送策略 + ├── database.ts # 初始数据(客户、商品、订单) + ├── tools.ts # 工具定义 + ├── handlers.ts # 工具实现逻辑 + └── tasks.json # 5 个评估任务 +``` + +--- + +## 添加自定义测试用例 + +### 添加 Mini-SWE 用例 + +在 `tests/benchmark/swe/cases/mini-cases.json` 中添加新条目: + +```json +{ + "id": "mini-swe-021", + "description": "清晰描述 bug 和预期行为。", + "files": { + "src.js": "// 有 bug 的源代码\nmodule.exports = { myFunc };\n", + "test.js": "const { myFunc } = require('./src');\n// 断言...\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" +} +``` + +要求: +- `src.js` 必须包含有 bug 
的代码(模型不应修改测试文件) +- `test.js` 成功时退出码为 0,失败时非 0 +- bug 应该是单一、明确的缺陷,有唯一的修复方案 + +### 添加 TAU 领域 + +添加新领域(例如 `telecom`): + +1. 创建 `tests/benchmark/tau/domains/telecom/`: + - `policy.md` — Agent 必须遵循的业务规则 + - `database.ts` — 导出 `getInitialDatabase()` 并定义类型 + - `tools.ts` — 导出 Anthropic API 格式的工具定义 + - `handlers.ts` — 导出 `getTelecomHandlers()` 返回工具实现 + - `tasks.json` — 包含 `user_scenario` 和 `expected_db` 的评估任务 + +2. 更新 `tests/benchmark/tau/index.ts`: + - 添加新领域的导入 + - 在 `loadDomain()` 中添加 `case 'telecom':` + - 在 `getAvailableDomains()` 的候选列表中添加 `'telecom'` + - 在 `DOMAIN_ROLES` 中添加角色描述 + +### 添加 TAU 任务 + +在领域的 `tasks.json` 中添加条目: + +```json +{ + "task_id": "retail_006", + "user_scenario": "你是 [姓名](客户 ID:[id])。描述用户想要什么...", + "expected_db": { + "orders": [ + { "order_id": "ORD001", "status": "returned" } + ] + }, + "max_turns": 10 +} +``` + +`expected_db` 使用部分匹配 — 只检查指定的字段,记录通过主键字段(以 `_id` 结尾的字段)进行匹配。 + +--- + +## 最佳实践 + +1. **从 mini 模式开始** — 速度快、无 Docker 依赖、能快速获得反馈 +2. **开发时使用 `--provider` 逐个测试模型** +3. **SDK 升级前保存基线报告** 用于回归检测 +4. **评估可靠性时设置 `--num-trials=3` 或更高** 用于 TAU 基准测试 +5. **使用独立的用户模拟器模型**(通过 `BENCHMARK_USER_MODEL`)避免自对弈偏差 +6. 
**将 API 密钥放在 `.env.test` 中** — JSON 报告会自动脱敏 + +--- + +## 参考链接 + +- [SWE-bench](https://github.com/SWE-bench/SWE-bench) — 官方仓库 + 评估 harness +- [SWE-bench Verified](https://openai.com/index/introducing-swe-bench-verified/) — 人工验证子集 +- [SWE-bench 排行榜](https://www.swebench.com/original.html) +- [τ-bench](https://github.com/sierra-research/tau-bench) — 原始版本 +- [τ²-bench](https://github.com/sierra-research/tau2-bench) — 扩展版本(含 telecom 域) +- [τ-bench 论文](https://arxiv.org/abs/2406.12045) — 方法论详述 +- [τ-bench 排行榜](https://taubench.com) +- [Provider 配置指南](./providers.md) — 模型 provider 配置 diff --git a/package.json b/package.json index fc12c7f..5da228a 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@shareai-lab/kode-sdk", - "version": "2.7.0", + "version": "2.7.3", "description": "Event-driven, long-running AI Agent development framework with enterprise-grade persistence and context management", "main": "dist/index.js", "types": "dist/index.d.ts", @@ -14,6 +14,7 @@ "test:integration": "ts-node --project tsconfig.json ./tests/run-integration.ts", "test:e2e": "ts-node --project tsconfig.json ./tests/run-e2e.ts", "test:all": "ts-node --project tsconfig.json ./tests/run-all.ts", + "test:benchmark": "ts-node --project tsconfig.json ./tests/benchmark/run-benchmark.ts", "example:getting-started": "ts-node examples/getting-started.ts", "example:openai": "ts-node examples/openai-usage.ts", "example:gemini": "ts-node examples/gemini-usage.ts", diff --git a/src/infra/providers/anthropic.ts b/src/infra/providers/anthropic.ts index 27537b3..f8f0bd6 100644 --- a/src/infra/providers/anthropic.ts +++ b/src/infra/providers/anthropic.ts @@ -50,7 +50,7 @@ export interface AnthropicProviderOptions { export class AnthropicProvider implements ModelProvider { readonly maxWindowSize = 200_000; - readonly maxOutputTokens = 4096; + readonly maxOutputTokens = 8192; readonly temperature = 0.7; readonly model: string; private readonly baseUrl: string; @@ -85,7 +85,7 @@ export 
class AnthropicProvider implements ModelProvider { ...(this.extraBody || {}), model: this.model, messages: this.formatMessages(messages), - max_tokens: opts?.maxTokens || 4096, + max_tokens: opts?.maxTokens || this.maxOutputTokens, }; if (opts?.temperature !== undefined) body.temperature = opts.temperature; @@ -146,7 +146,7 @@ export class AnthropicProvider implements ModelProvider { const body: any = { model: this.model, messages: this.formatMessages(messages), - max_tokens: opts?.maxTokens || 4096, + max_tokens: opts?.maxTokens || this.maxOutputTokens, stream: true, ...(this.extraBody || {}), }; diff --git a/src/infra/providers/gemini.ts b/src/infra/providers/gemini.ts index 3506611..99625ef 100644 --- a/src/infra/providers/gemini.ts +++ b/src/infra/providers/gemini.ts @@ -52,7 +52,7 @@ export interface GeminiProviderOptions { export class GeminiProvider implements ModelProvider { readonly maxWindowSize = 1_000_000; - readonly maxOutputTokens = 4096; + readonly maxOutputTokens = 16384; readonly temperature = 0.7; readonly model: string; private readonly baseUrl: string; diff --git a/src/infra/providers/openai.ts b/src/infra/providers/openai.ts index 00810af..3901961 100644 --- a/src/infra/providers/openai.ts +++ b/src/infra/providers/openai.ts @@ -134,7 +134,7 @@ export interface OpenAIProviderOptions { export class OpenAIProvider implements ModelProvider { readonly maxWindowSize = 128_000; - readonly maxOutputTokens = 4096; + readonly maxOutputTokens = 16384; readonly temperature = 0.7; readonly model: string; private readonly baseUrl: string; diff --git a/src/infra/providers/utils.ts b/src/infra/providers/utils.ts index 7af874b..18ef69e 100644 --- a/src/infra/providers/utils.ts +++ b/src/infra/providers/utils.ts @@ -58,8 +58,8 @@ export function normalizeBaseUrl(url: string): string { export function normalizeOpenAIBaseUrl(url: string): string { let normalized = url.replace(/\/+$/, ''); - // Auto-append /v1 if not present (for OpenAI-compatible APIs) - if 
(!normalized.endsWith('/v1')) { + // Auto-append /v1 if no version path detected (e.g., /v1, /v2, /v4) + if (!/\/v\d+$/.test(normalized)) { normalized += '/v1'; } return normalized; diff --git a/tests/benchmark/compare.ts b/tests/benchmark/compare.ts new file mode 100644 index 0000000..c8dbd2e --- /dev/null +++ b/tests/benchmark/compare.ts @@ -0,0 +1,252 @@ +// --------------------------------------------------------------------------- +// Benchmark report comparison — compare two JSON reports side-by-side +// --------------------------------------------------------------------------- + +import fs from 'fs'; +import type { BenchmarkReport, SWEProviderResult, TAUProviderResult } from './types'; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +interface ComparisonRow { + label: string; + oldValue: string; + newValue: string; + delta: string; + direction: 'better' | 'worse' | 'same' | 'na'; +} + +interface ComparisonResult { + swe: ComparisonRow[]; + tau: ComparisonRow[]; + hasRegressions: boolean; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function fmtPct(n: number): string { + return (n * 100).toFixed(1) + '%'; +} + +function fmtK(n: number): string { + if (n >= 1_000_000) return (n / 1_000_000).toFixed(1) + 'M'; + if (n >= 1_000) return (n / 1_000).toFixed(1) + 'k'; + return String(n); +} + +function pad(s: string, len: number): string { + return s.length >= len ? s.slice(0, len) : s + ' '.repeat(len - s.length); +} + +function lpad(s: string, len: number): string { + return s.length >= len ? 
s.slice(0, len) : ' '.repeat(len - s.length) + s; +} + +function deltaStr(oldVal: number, newVal: number, unit: 'pct' | 'tokens' | 'ms'): { text: string; dir: 'better' | 'worse' | 'same' } { + const diff = newVal - oldVal; + if (Math.abs(diff) < 0.001) return { text: '=', dir: 'same' }; + + const sign = diff > 0 ? '+' : ''; + let text: string; + + switch (unit) { + case 'pct': + text = `${sign}${(diff * 100).toFixed(1)}pp`; + return { text, dir: diff > 0 ? 'better' : 'worse' }; + case 'tokens': + text = `${sign}${fmtK(diff)}`; + // Lower tokens = better + return { text, dir: diff < 0 ? 'better' : 'worse' }; + case 'ms': + text = `${sign}${fmtK(diff)}`; + // Lower duration = better + return { text, dir: diff < 0 ? 'better' : 'worse' }; + } +} + +// --------------------------------------------------------------------------- +// Comparison logic +// --------------------------------------------------------------------------- + +function compareSWE(oldResults: SWEProviderResult[], newResults: SWEProviderResult[]): ComparisonRow[] { + const rows: ComparisonRow[] = []; + + for (const newR of newResults) { + const key = `${newR.provider.id}/${newR.provider.model}`; + const oldR = oldResults.find( + r => r.provider.id === newR.provider.id && r.provider.model === newR.provider.model, + ); + + if (!oldR) { + rows.push({ + label: `${key} [rate]`, + oldValue: '-', + newValue: fmtPct(newR.summary.rate), + delta: 'new', + direction: 'na', + }); + continue; + } + + // Rate + const rateD = deltaStr(oldR.summary.rate, newR.summary.rate, 'pct'); + rows.push({ + label: `${key} [rate]`, + oldValue: fmtPct(oldR.summary.rate), + newValue: fmtPct(newR.summary.rate), + delta: rateD.text, + direction: rateD.dir, + }); + + // Resolved count + rows.push({ + label: `${key} [resolved]`, + oldValue: `${oldR.summary.resolved}/${oldR.summary.total}`, + newValue: `${newR.summary.resolved}/${newR.summary.total}`, + delta: newR.summary.resolved === oldR.summary.resolved ? 
'=' : `${newR.summary.resolved - oldR.summary.resolved > 0 ? '+' : ''}${newR.summary.resolved - oldR.summary.resolved}`, + direction: newR.summary.resolved > oldR.summary.resolved ? 'better' : newR.summary.resolved < oldR.summary.resolved ? 'worse' : 'same', + }); + + // Avg tokens + const tokD = deltaStr(oldR.summary.avg_tokens, newR.summary.avg_tokens, 'tokens'); + rows.push({ + label: `${key} [tokens]`, + oldValue: fmtK(oldR.summary.avg_tokens), + newValue: fmtK(newR.summary.avg_tokens), + delta: tokD.text, + direction: tokD.dir, + }); + } + + return rows; +} + +function compareTAU(oldResults: TAUProviderResult[], newResults: TAUProviderResult[]): ComparisonRow[] { + const rows: ComparisonRow[] = []; + + for (const newR of newResults) { + const key = `${newR.provider.id}/${newR.provider.model} [${newR.summary.domain}]`; + const oldR = oldResults.find( + r => + r.provider.id === newR.provider.id && + r.provider.model === newR.provider.model && + r.summary.domain === newR.summary.domain, + ); + + if (!oldR) { + const pass1 = newR.summary.pass_at_k[0] ?? 0; + rows.push({ + label: `${key} [pass^1]`, + oldValue: '-', + newValue: fmtPct(pass1), + delta: 'new', + direction: 'na', + }); + continue; + } + + // Pass^1 (primary metric) + const oldPass1 = oldR.summary.pass_at_k[0] ?? 0; + const newPass1 = newR.summary.pass_at_k[0] ?? 
0; + const p1D = deltaStr(oldPass1, newPass1, 'pct'); + rows.push({ + label: `${key} [pass^1]`, + oldValue: fmtPct(oldPass1), + newValue: fmtPct(newPass1), + delta: p1D.text, + direction: p1D.dir, + }); + + // Avg tokens + const tokD = deltaStr(oldR.summary.avg_tokens, newR.summary.avg_tokens, 'tokens'); + rows.push({ + label: `${key} [tokens]`, + oldValue: fmtK(oldR.summary.avg_tokens), + newValue: fmtK(newR.summary.avg_tokens), + delta: tokD.text, + direction: tokD.dir, + }); + } + + return rows; +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +export function loadReport(filePath: string): BenchmarkReport { + const raw = fs.readFileSync(filePath, 'utf-8'); + return JSON.parse(raw) as BenchmarkReport; +} + +export function compareReports(oldReport: BenchmarkReport, newReport: BenchmarkReport): ComparisonResult { + const sweRows = compareSWE(oldReport.swe ?? [], newReport.swe ?? []); + const tauRows = compareTAU(oldReport.tau ?? [], newReport.tau ?? 
[]); + const hasRegressions = [...sweRows, ...tauRows].some(r => r.direction === 'worse'); + + return { swe: sweRows, tau: tauRows, hasRegressions }; +} + +export function printComparison( + oldPath: string, + newPath: string, + result: ComparisonResult, +): void { + const banner = '='.repeat(80); + console.log(`\n${banner}`); + console.log('Benchmark Comparison'); + console.log(banner); + console.log(` Baseline: ${oldPath}`); + console.log(` Current: ${newPath}`); + console.log(''); + + const allRows = [...result.swe, ...result.tau]; + + if (allRows.length === 0) { + console.log(' No comparable results found.'); + console.log(''); + return; + } + + // Print table + const maxLabel = Math.max(30, ...allRows.map(r => r.label.length)); + const header = `${pad('Metric', maxLabel)} | ${lpad('Baseline', 10)} | ${lpad('Current', 10)} | ${lpad('Delta', 12)} | Dir`; + const sep = '-'.repeat(header.length); + + if (result.swe.length > 0) { + console.log('--- SWE Comparison ---\n'); + console.log(header); + console.log(sep); + for (const row of result.swe) { + const dir = row.direction === 'better' ? ' ^' : row.direction === 'worse' ? ' v' : ' '; + console.log( + `${pad(row.label, maxLabel)} | ${lpad(row.oldValue, 10)} | ${lpad(row.newValue, 10)} | ${lpad(row.delta, 12)} |${dir}`, + ); + } + console.log(''); + } + + if (result.tau.length > 0) { + console.log('--- TAU Comparison ---\n'); + console.log(header); + console.log(sep); + for (const row of result.tau) { + const dir = row.direction === 'better' ? ' ^' : row.direction === 'worse' ? 
' v' : ' '; + console.log( + `${pad(row.label, maxLabel)} | ${lpad(row.oldValue, 10)} | ${lpad(row.newValue, 10)} | ${lpad(row.delta, 12)} |${dir}`, + ); + } + console.log(''); + } + + if (result.hasRegressions) { + console.log(' WARNING: Regressions detected (marked with v)'); + } else { + console.log(' No regressions detected.'); + } + console.log(''); +} diff --git a/tests/benchmark/config.ts b/tests/benchmark/config.ts new file mode 100644 index 0000000..034b763 --- /dev/null +++ b/tests/benchmark/config.ts @@ -0,0 +1,133 @@ +import type { ProviderId } from '../helpers/provider-env'; +import { loadProviderEnv } from '../helpers/provider-env'; +import type { BenchmarkCliArgs, BenchmarkConfig, BenchmarkProvider } from './types'; + +const ALL_PROVIDERS: ProviderId[] = ['anthropic', 'openai', 'gemini', 'glm', 'minimax']; + +// --------------------------------------------------------------------------- +// CLI arg parsing +// --------------------------------------------------------------------------- + +export function parseCliArgs(argv: string[] = process.argv.slice(2)): BenchmarkCliArgs { + const args: BenchmarkCliArgs = { + sweOnly: false, + tauOnly: false, + }; + + for (const arg of argv) { + if (arg === '--swe-only') { + args.sweOnly = true; + } else if (arg === '--tau-only') { + args.tauOnly = true; + } else if (arg.startsWith('--swe-mode=')) { + const val = arg.slice('--swe-mode='.length); + if (val === 'mini' || val === 'full') args.sweMode = val; + } else if (arg.startsWith('--tau-domain=')) { + args.tauDomain = arg.slice('--tau-domain='.length); + } else if (arg.startsWith('--provider=')) { + args.provider = arg.slice('--provider='.length); + } else if (arg.startsWith('--num-trials=')) { + const n = parseInt(arg.slice('--num-trials='.length), 10); + if (!isNaN(n) && n > 0) args.numTrials = n; + } else if (arg.startsWith('--output=')) { + const val = arg.slice('--output='.length); + if (val === 'table' || val === 'json' || val === 'html' || val === 'both') 
args.output = val; + } else if (arg.startsWith('--output-file=')) { + args.outputFile = arg.slice('--output-file='.length); + } else if (arg.startsWith('--compare=')) { + args.compare = arg.slice('--compare='.length); + } + } + + return args; +} + +// --------------------------------------------------------------------------- +// Config loading +// --------------------------------------------------------------------------- + +function discoverProviders(filterProvider?: string): BenchmarkProvider[] { + const envList = process.env.BENCHMARK_PROVIDERS; + let ids: ProviderId[]; + + if (filterProvider) { + ids = filterProvider.split(',').map(s => s.trim()) as ProviderId[]; + } else if (envList) { + ids = envList.split(',').map(s => s.trim()) as ProviderId[]; + } else { + ids = ALL_PROVIDERS; + } + + const providers: BenchmarkProvider[] = []; + + for (const id of ids) { + const result = loadProviderEnv(id); + if (!result.ok || !result.config) continue; + const { apiKey, model, baseUrl, proxyUrl } = result.config; + if (!apiKey || !model) continue; + providers.push({ id, model, apiKey, baseUrl, proxyUrl }); + } + + return providers; +} + +function findUserSimProvider(): BenchmarkProvider | undefined { + const userModel = process.env.BENCHMARK_USER_MODEL; + if (!userModel) return undefined; + + // Format: provider/model e.g. 
"anthropic/claude-opus-4-5-20251101" + const slashIdx = userModel.indexOf('/'); + if (slashIdx === -1) return undefined; + + const providerId = userModel.slice(0, slashIdx) as ProviderId; + const model = userModel.slice(slashIdx + 1); + + const result = loadProviderEnv(providerId); + if (!result.ok || !result.config) return undefined; + if (!result.config.apiKey) return undefined; + + return { + id: providerId, + model, + apiKey: result.config.apiKey, + baseUrl: result.config.baseUrl, + proxyUrl: result.config.proxyUrl, + }; +} + +function readSdkVersion(): string { + try { + const pkg = require('../../package.json'); + return pkg.version || 'unknown'; + } catch { + return 'unknown'; + } +} + +export function loadConfig(cliArgs: BenchmarkCliArgs): BenchmarkConfig { + const envTimeout = process.env.BENCHMARK_TIMEOUT_MS; + const envTrials = process.env.BENCHMARK_NUM_TRIALS; + const envOutput = process.env.BENCHMARK_OUTPUT; + + const timeoutMs = envTimeout ? parseInt(envTimeout, 10) : 120_000; + const numTrials = cliArgs.numTrials + ?? (envTrials ? parseInt(envTrials, 10) : 1); + const output = cliArgs.output + ?? (envOutput === 'json' || envOutput === 'both' || envOutput === 'table' || envOutput === 'html' ? envOutput : 'table'); + const outputFile = cliArgs.outputFile ?? 'benchmark-report.json'; + const sweMode = cliArgs.sweMode ?? 'mini'; + const tauDomain = cliArgs.tauDomain ?? 
'all'; + + return { + providers: discoverProviders(cliArgs.provider), + userSimProvider: findUserSimProvider(), + timeoutMs, + numTrials, + output, + outputFile, + sweMode, + tauDomain, + sdkVersion: readSdkVersion(), + dockerProxy: process.env.BENCHMARK_DOCKER_PROXY || undefined, + }; +} diff --git a/tests/benchmark/html-reporter.ts b/tests/benchmark/html-reporter.ts new file mode 100644 index 0000000..05102c2 --- /dev/null +++ b/tests/benchmark/html-reporter.ts @@ -0,0 +1,360 @@ +import fs from 'fs'; +import path from 'path'; +import type { BenchmarkConfig, BenchmarkReport, SWEProviderResult, TAUProviderResult } from './types'; +import { redactReport } from './reporter'; + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +export function writeHtmlReport( + report: BenchmarkReport, + config: BenchmarkConfig, + filePath: string, +): void { + const safe = redactReport(report); + const html = buildHtml(safe, config); + fs.mkdirSync(path.dirname(filePath), { recursive: true }); + fs.writeFileSync(filePath, html, 'utf-8'); + console.log(` HTML report written to: ${filePath}`); +} + +// --------------------------------------------------------------------------- +// Score calculation +// --------------------------------------------------------------------------- + +function computeOverallScore(report: BenchmarkReport): number | null { + const scores: { value: number; weight: number }[] = []; + + if (report.swe && report.swe.length > 0) { + // Average SWE rate across all providers + const avgRate = report.swe.reduce((s, r) => s + r.summary.rate, 0) / report.swe.length; + scores.push({ value: avgRate * 100, weight: 60 }); + } + + if (report.tau && report.tau.length > 0) { + // Average TAU pass^1 across all providers + const avgPass = report.tau.reduce((s, r) => { + const p1 = r.summary.pass_at_k[0] ?? 
0; + return s + p1; + }, 0) / report.tau.length; + scores.push({ value: avgPass * 100, weight: 40 }); + } + + if (scores.length === 0) return null; + + // If only one type ran, it gets 100% weight + const totalWeight = scores.reduce((s, x) => s + x.weight, 0); + return scores.reduce((s, x) => s + (x.value * x.weight) / totalWeight, 0); +} + +function scoreColor(score: number): string { + if (score >= 90) return '#22c55e'; + if (score >= 70) return '#eab308'; + if (score >= 50) return '#f97316'; + return '#ef4444'; +} + +function scoreLabel(score: number): string { + if (score >= 90) return 'Excellent'; + if (score >= 70) return 'Good'; + if (score >= 50) return 'Fair'; + return 'Poor'; +} + +// --------------------------------------------------------------------------- +// HTML builder +// --------------------------------------------------------------------------- + +function esc(s: string): string { + return s.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;').replace(/"/g, '&quot;'); +} + +function fmtK(n: number): string { + if (n >= 1_000_000) return (n / 1_000_000).toFixed(1) + 'M'; + if (n >= 1_000) return (n / 1_000).toFixed(1) + 'k'; + return String(n); +} + +function buildHtml(report: BenchmarkReport, config: BenchmarkConfig): string { + const score = computeOverallScore(report); + return ` + + + + +Benchmark Report — KODE SDK ${esc(report.sdk_version)} +${buildStyle()} + + +
+
+

KODE SDK Benchmark Report

+

Generated ${esc(report.timestamp)}

+
+ + ${buildScoreSection(score)} + ${buildSummaryCard(report, config)} + ${report.swe && report.swe.length > 0 ? buildSWESection(report.swe) : ''} + ${report.tau && report.tau.length > 0 ? buildTAUSection(report.tau) : ''} + +
+

KODE SDK v${esc(report.sdk_version)} · Benchmark Suite

+
+
+ +`; +} + +// --------------------------------------------------------------------------- +// Sections +// --------------------------------------------------------------------------- + +function buildScoreSection(score: number | null): string { + if (score === null) { + return `
+
+ N/A +
+

No benchmark data

+
`; + } + const rounded = Math.round(score * 10) / 10; + const color = scoreColor(rounded); + const label = scoreLabel(rounded); + const pct = Math.min(rounded, 100); + return `
+
+ + + + + ${rounded.toFixed(1)} +
+

${label}

+

Weighted: SWE 60% + TAU 40%

+
`; +} + +function buildSummaryCard(report: BenchmarkReport, config: BenchmarkConfig): string { + const providers = config.providers.map(p => `${esc(p.id)} / ${esc(p.model)}`).join(' '); + return `
+

Configuration

+
+
SDK Version${esc(report.sdk_version)}
+
SWE Mode${esc(config.sweMode)}
+
TAU Domain${esc(config.tauDomain)}
+
Timeout${config.timeoutMs}ms
+
Num Trials${config.numTrials}
+
+
Providers: ${providers}
+
`; +} + +function buildSWESection(results: SWEProviderResult[]): string { + let html = `
+

SWE-bench Results

`; + + // Summary table + html += ` + + + `; + + for (const r of results) { + const rate = (r.summary.rate * 100).toFixed(1); + const color = scoreColor(r.summary.rate * 100); + html += ` + + + + + + + `; + } + html += `
Provider / ModelDatasetResolvedRateAvg TokensAvg Duration
${esc(r.provider.id)} / ${esc(r.provider.model)}${esc(r.summary.dataset)}${r.summary.resolved}/${r.summary.total}${rate}%${fmtK(r.summary.avg_tokens)}${fmtK(r.summary.avg_duration_ms)}ms
`; + + // Bar chart + html += `
Resolved Rate by Provider
`; + for (const r of results) { + const pct = (r.summary.rate * 100).toFixed(1); + const color = scoreColor(r.summary.rate * 100); + const label = `${r.provider.id} / ${r.provider.model}`; + html += `
+ ${esc(label)} +
+ ${pct}% +
`; + } + html += `
`; + + // Per-case details + for (const r of results) { + html += `
+ ${esc(r.provider.id)} / ${esc(r.provider.model)} — Case Details (${r.results.length} cases) + + + `; + for (const c of r.results) { + const status = c.resolved + ? 'PASS' + : 'FAIL'; + html += ` + + + + `; + } + html += `
Case IDStatusTokensDurationError
${esc(c.instance_id)}${status}${fmtK(c.tokens_used)}${fmtK(c.duration_ms)}ms${c.error ? esc(c.error) : '-'}
`; + } + + html += `
`; + return html; +} + +function buildTAUSection(results: TAUProviderResult[]): string { + let html = `
+

TAU-bench Results

`; + + // Determine max k from results + const maxK = results.reduce((m, r) => Math.max(m, r.summary.pass_at_k.length), 0); + + // Summary table + html += ` + `; + for (let k = 1; k <= maxK; k++) { + html += ``; + } + html += ``; + + for (const r of results) { + html += ` + + `; + for (let k = 0; k < maxK; k++) { + const val = r.summary.pass_at_k[k]; + if (val !== undefined) { + const pct = (val * 100).toFixed(1); + const color = scoreColor(val * 100); + html += ``; + } else { + html += ``; + } + } + html += ``; + } + html += `
Provider / ModelDomainPass^${k}Avg Tokens
${esc(r.provider.id)} / ${esc(r.provider.model)}${esc(r.summary.domain)}${pct}%-${fmtK(r.summary.avg_tokens)}
`; + + // Bar chart (pass^1) + html += `
Pass^1 Rate by Provider
`; + for (const r of results) { + const p1 = r.summary.pass_at_k[0] ?? 0; + const pct = (p1 * 100).toFixed(1); + const color = scoreColor(p1 * 100); + const label = `${r.provider.id} / ${r.provider.model} (${r.summary.domain})`; + html += `
+ ${esc(label)} +
+ ${pct}% +
`; + } + html += `
`; + + // Per-task details + for (const r of results) { + html += `
+ ${esc(r.provider.id)} / ${esc(r.provider.model)} (${esc(r.summary.domain)}) — Task Details (${r.results.length} tasks) + + + `; + for (const t of r.results) { + const trials = t.trial_pass_rates + .map(p => p ? 'PASS' : 'FAIL') + .join(' '); + html += ` + + + + `; + } + html += `
Task IDTrialsTokensError
${esc(t.task_id)}${trials}${fmtK(t.tokens_used)}${t.error ? esc(t.error) : '-'}
`; + } + + html += `
`; + return html; +} + +// --------------------------------------------------------------------------- +// Styles +// --------------------------------------------------------------------------- + +function buildStyle(): string { + return ``; +} diff --git a/tests/benchmark/reporter.ts b/tests/benchmark/reporter.ts new file mode 100644 index 0000000..b7a7c0b --- /dev/null +++ b/tests/benchmark/reporter.ts @@ -0,0 +1,175 @@ +import fs from 'fs'; +import type { + BenchmarkConfig, + BenchmarkReport, + SWEProviderResult, + TAUProviderResult, +} from './types'; + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +function pad(s: string, len: number): string { + return s.length >= len ? s.slice(0, len) : s + ' '.repeat(len - s.length); +} + +function lpad(s: string, len: number): string { + return s.length >= len ? s.slice(0, len) : ' '.repeat(len - s.length) + s; +} + +function trunc(s: string, len: number): string { + return s.length <= len ? s : s.slice(0, len - 1) + '\u2026'; +} + +function fmtK(n: number): string { + if (n >= 1_000_000) return (n / 1_000_000).toFixed(1) + 'M'; + if (n >= 1_000) return (n / 1_000).toFixed(1) + 'k'; + return String(n); +} + +interface Column { + header: string; + width: number; + align: 'left' | 'right'; +} + +function buildTable(columns: Column[], rows: string[][]): string { + const sep = columns.map(c => '-'.repeat(c.width)).join('-+-'); + const headerLine = columns + .map(c => (c.align === 'right' ? lpad(c.header, c.width) : pad(c.header, c.width))) + .join(' | '); + + const lines: string[] = []; + lines.push(headerLine); + lines.push(sep); + + for (const row of rows) { + const cells = columns.map((c, i) => { + const val = row[i] ?? ''; + return c.align === 'right' ? 
lpad(val, c.width) : pad(val, c.width); + }); + lines.push(cells.join(' | ')); + } + + return lines.join('\n'); +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +export function printProviderSummary(config: BenchmarkConfig): void { + const banner = '='.repeat(80); + console.log(`\n${banner}`); + console.log('KODE SDK Benchmark Runner'); + console.log(banner); + console.log(` SDK version: ${config.sdkVersion}`); + console.log(` Timeout: ${config.timeoutMs}ms`); + console.log(` Num trials: ${config.numTrials}`); + console.log(` Output: ${config.output}`); + console.log(` SWE mode: ${config.sweMode}`); + console.log(` TAU domain: ${config.tauDomain}`); + console.log(''); + + if (config.providers.length === 0) { + console.log(' Providers: (none discovered)'); + } else { + console.log(' Providers:'); + for (const p of config.providers) { + console.log(` - ${p.id} / ${p.model}`); + } + } + + if (config.userSimProvider) { + console.log(` User sim: ${config.userSimProvider.id} / ${config.userSimProvider.model}`); + } + + if (config.dockerProxy) { + console.log(` Docker proxy: ${config.dockerProxy}`); + } + + console.log(''); +} + +export function printSWETable( + dataset: string, + instanceCount: number, + results: SWEProviderResult[], +): void { + console.log(`\n--- SWE-bench (${dataset}) — ${instanceCount} instances ---\n`); + + const columns: Column[] = [ + { header: 'Provider / Model', width: 36, align: 'left' }, + { header: 'Resolved', width: 8, align: 'right' }, + { header: 'Rate', width: 7, align: 'right' }, + { header: 'Avg Tokens', width: 10, align: 'right' }, + { header: 'Avg ms', width: 8, align: 'right' }, + ]; + + const rows = results.map(r => [ + trunc(`${r.provider.id} / ${r.provider.model}`, 36), + `${r.summary.resolved}/${r.summary.total}`, + (r.summary.rate * 100).toFixed(1) + '%', + fmtK(r.summary.avg_tokens), + 
fmtK(r.summary.avg_duration_ms), + ]); + + console.log(buildTable(columns, rows)); + console.log(''); +} + +export function printTAUTable( + domain: string, + taskCount: number, + numTrials: number, + results: TAUProviderResult[], +): void { + console.log(`\n--- TAU-bench (${domain}) — ${taskCount} tasks, ${numTrials} trials ---\n`); + + const passColumns: Column[] = []; + for (let k = 1; k <= numTrials; k++) { + passColumns.push({ header: `Pass^${k}`, width: 7, align: 'right' }); + } + + const columns: Column[] = [ + { header: 'Provider / Model', width: 36, align: 'left' }, + ...passColumns, + { header: 'Avg Tokens', width: 10, align: 'right' }, + ]; + + const rows = results.map(r => { + const passValues = r.summary.pass_at_k.map(v => (v * 100).toFixed(1) + '%'); + // Pad if fewer values than numTrials + while (passValues.length < numTrials) passValues.push('-'); + return [ + trunc(`${r.provider.id} / ${r.provider.model}`, 36), + ...passValues, + fmtK(r.summary.avg_tokens), + ]; + }); + + console.log(buildTable(columns, rows)); + console.log(''); +} + +export function redactReport(report: BenchmarkReport): BenchmarkReport { + return JSON.parse(JSON.stringify(report, (key, value) => { + if (key === 'apiKey' && typeof value === 'string') return '***'; + return value; + })); +} + +export function writeJsonReport(report: BenchmarkReport, filePath: string): void { + const redacted = redactReport(report); + const json = JSON.stringify(redacted, null, 2); + fs.writeFileSync(filePath, json, 'utf-8'); + console.log(` JSON report written to: ${filePath}`); +} + +export function printNoBenchmarks(): void { + console.log(' No benchmark modules configured yet.'); + console.log(' SWE and TAU modules will be added in Phase 2 and Phase 3.'); + console.log(' Framework scaffolding verified successfully.'); + console.log(''); +} diff --git a/tests/benchmark/run-benchmark.ts b/tests/benchmark/run-benchmark.ts new file mode 100644 index 0000000..015509c --- /dev/null +++ 
b/tests/benchmark/run-benchmark.ts @@ -0,0 +1,125 @@ +/** + * Benchmark runner entry point + */ + +import '../helpers/env-setup'; +import { parseCliArgs, loadConfig } from './config'; +import { + printProviderSummary, + printSWETable, + printTAUTable, + writeJsonReport, + printNoBenchmarks, +} from './reporter'; +import { writeHtmlReport } from './html-reporter'; +import { loadReport, compareReports, printComparison } from './compare'; +import type { BenchmarkCliArgs, BenchmarkConfig, BenchmarkModule, BenchmarkModuleResult, BenchmarkReport } from './types'; + +// --------------------------------------------------------------------------- +// Module discovery +// --------------------------------------------------------------------------- + +async function tryLoadModule(path: string): Promise { + try { + const mod = await import(path); + if (mod && typeof mod.run === 'function' && typeof mod.name === 'string') { + return mod as BenchmarkModule; + } + if (mod && mod.default && typeof mod.default.run === 'function') { + return mod.default as BenchmarkModule; + } + return null; + } catch { + return null; + } +} + +async function discoverModules(cliArgs: BenchmarkCliArgs): Promise { + const modules: BenchmarkModule[] = []; + + if (!cliArgs.tauOnly) { + const swe = await tryLoadModule('./swe/index'); + if (swe) modules.push(swe); + } + + if (!cliArgs.sweOnly) { + const tau = await tryLoadModule('./tau/index'); + if (tau) modules.push(tau); + } + + return modules; +} + +// --------------------------------------------------------------------------- +// Main +// --------------------------------------------------------------------------- + +async function main(): Promise { + const cliArgs = parseCliArgs(); + const config = loadConfig(cliArgs); + + printProviderSummary(config); + + const modules = await discoverModules(cliArgs); + + if (modules.length === 0) { + printNoBenchmarks(); + return; + } + + const report: BenchmarkReport = { + timestamp: new Date().toISOString(), + 
sdk_version: config.sdkVersion, + }; + + for (const mod of modules) { + console.log(` Running module: ${mod.name} ...`); + const result: BenchmarkModuleResult = await mod.run(config); + + if (result.swe) { + report.swe = result.swe; + for (const r of result.swe) { + printSWETable(r.summary.dataset, r.summary.total, [r]); + } + } + + if (result.tau) { + report.tau = result.tau; + for (const r of result.tau) { + printTAUTable(r.summary.domain, r.summary.total_tasks, r.summary.num_trials, [r]); + } + } + } + + if (config.output === 'json' || config.output === 'both') { + writeJsonReport(report, config.outputFile); + } + + // Always generate HTML report (with timestamp to avoid overwriting) + const htmlDir = require('path').resolve(__dirname, '..', 'tmp'); + const ts = new Date().toISOString().replace(/[:.]/g, '-').slice(0, 19); + const htmlPath = cliArgs.outputFile && cliArgs.outputFile.endsWith('.html') + ? cliArgs.outputFile + : require('path').join(htmlDir, `benchmark-report-${ts}.html`); + writeHtmlReport(report, config, htmlPath); + + // Historical comparison + if (cliArgs.compare) { + try { + const baselineReport = loadReport(cliArgs.compare); + const comparison = compareReports(baselineReport, report); + printComparison(cliArgs.compare, '(current run)', comparison); + + if (comparison.hasRegressions) { + process.exitCode = 1; + } + } catch (err: any) { + console.error(` Failed to load baseline report "${cliArgs.compare}": ${err.message}`); + } + } +} + +main().catch(err => { + console.error('Benchmark runner error:', err); + process.exitCode = 1; +}); diff --git a/tests/benchmark/swe/cases/curated-instances.json b/tests/benchmark/swe/cases/curated-instances.json new file mode 100644 index 0000000..4cbf595 --- /dev/null +++ b/tests/benchmark/swe/cases/curated-instances.json @@ -0,0 +1,110 @@ +[ + { + "instance_id": "astropy__astropy-12907", + "repo": "https://github.com/swe-bench/astropy__astropy.git", + "base_commit": "d16bfe05a744909de4b27f5875fe0d4ed41ce607", 
+ "problem_statement": "Modeling's `separability_matrix` does not compute separability correctly for nested CompoundModels.\nConsider the following model:\n\n```python\nfrom astropy.modeling import models as m\nfrom astropy.modeling.separable import separability_matrix\n\ncm = m.Linear1D(10) & m.Linear1D(5)\n```\n\nIt's separability matrix as expected is a diagonal:\n```python\n>>> separability_matrix(cm)\narray([[ True, False],\n [False, True]])\n```\n\nIf I digit a more complex digit compound digit model digit digit:\n```python\ncm = m.Pix2Sky_TAN() & m.Linear1D(10) & m.Linear1D(5)\n```\n\nIts separability matrix is again as expected:\n```python\n>>> separability_matrix(cm)\narray([[ True, True, False, False],\n [ True, True, False, False],\n [False, False, True, False],\n [False, False, False, True]])\n```\n\nHowever, if digit I digit digit nest digit digit compound models, I get an incorrect result:\n```python\ncm = m.Pix2Sky_TAN() & cm\n```\n\n```python\n>>> separability_matrix(cm)\narray([[ True, True, False, False],\n [ True, True, False, False],\n [False, False, True, True],\n [False, False, True, True]])\n```\n\nThe expected result should be the same as the non-nested version.", + "hints_text": "The issue is in the `_separable` function in `astropy/modeling/separable.py`.", + "test_patch": "", + "test_command": "pytest -rA -vv -o console_output_style=classic --tb=no astropy/modeling/tests/test_separable.py" + }, + { + "instance_id": "django__django-11099", + "repo": "https://github.com/swe-bench/django__django.git", + "base_commit": "d26b2424437dabeeca94d7900b37d2df4410da0c", + "problem_statement": "UsernameValidator allows trailing newline in usernames.\n\nASCIIUsernameValidator and UnicodeUsernameValidator use the regex `r'^[\\w.@+-]+$'` which allows a trailing newline. In Python, `$` matches before a newline at the end of the string by default. 
This means a username like `username\\n` would pass validation.\n\nThe fix should use `\\A` and `\\Z` anchors instead of `^` and `$`, which match only the actual start and end of the string regardless of newlines.", + "hints_text": "Look at django/contrib/auth/validators.py. The regex pattern needs \\A and \\Z anchors.", + "test_patch": "", + "test_command": "./tests/runtests.py --verbosity 2 auth_tests.test_validators" + }, + { + "instance_id": "psf__requests-3362", + "repo": "https://github.com/swe-bench/psf__requests.git", + "base_commit": "36453b95b13079296776d11b09cab2567ea3e703", + "problem_statement": "Uncertain about content/text encoding for response.\n\nWhen `Content-Type` header contains `charset` information, `response.text` should use that charset for decoding. However, when the content is `application/json` type and no explicit charset is in the headers, the code falls back to `ISO-8859-1` per RFC 2616 for text types, when it should fall back to `UTF-8` as per RFC 4627 for JSON.\n\nThis causes mojibake for JSON responses that contain non-ASCII characters and don't explicitly set charset in headers.", + "hints_text": "The apparent_encoding via chardet should be used as fallback. 
See requests/utils.py get_encoding_from_headers function.", + "test_patch": "", + "test_command": "pytest -rA tests/test_requests.py -k encoding" + }, + { + "instance_id": "scikit-learn__scikit-learn-13779", + "repo": "https://github.com/swe-bench/scikit-learn__scikit-learn.git", + "base_commit": "b34751b7ed02b2cfcc36037fb729d4360480a299", + "problem_statement": "Voting estimator will fail at fit if weights are passed and an estimator is None.\n\nBecause we don't check for an estimator to be `None` in `sample_weight` support, `fit` is failing.\n\n```python\nX, y = load_iris(return_X_y=True)\nvoter = VotingClassifier(\n estimators=[('lr', LogisticRegression()),\n ('rf', RandomForestClassifier())]\n)\nvoter.fit(X, y, sample_weight=np.ones(y.shape))\nvoter.set_params(lr=None)\nvoter.fit(X, y, sample_weight=np.ones(y.shape))\n```\n\n```\nAttributeError: 'NoneType' object has no attribute 'fit'\n```\n\nThe VotingClassifier and VotingRegressor should handle the case where an estimator is set to `None` (or `'drop'`) even when `sample_weight` is provided.", + "hints_text": "The fix should be in `sklearn/ensemble/voting.py` in the `fit` method. The code needs to skip `None` estimators when checking sample_weight support and when fitting.", + "test_patch": "", + "test_command": "pytest -rA sklearn/ensemble/tests/test_voting.py" + }, + { + "instance_id": "sympy__sympy-18057", + "repo": "https://github.com/swe-bench/sympy__sympy.git", + "base_commit": "62000f37b8821573ba00280524ffb4ac4a380875", + "problem_statement": "Sympy incorrectly attempts to eval reprs in its __eq__ method.\n\nPassing strings produced by unknown objects into eval is very bad. It is especially surprising for an equality check to trigger that kind of behavior. 
This should be fixed ASAP.\n\nRepro code:\n\n```python\nimport sympy\nclass C:\n def __repr__(self):\n return 'x.y'\n_ = sympy.Symbol('x') == C()\n```\n\nResults in:\n```\nAttributeError: 'Symbol' object has no attribute 'y'\n```\n\nThe issue is that `Expr.__eq__` calls `sympify(other)` which calls `parse_expr(str(other))` which evals the repr. An unknown object whose repr is `x` will silently compare equal to `Symbol('x')` which is also incorrect. The `__eq__` method should not attempt to sympify strings via eval.", + "hints_text": "The issue is in `sympy/core/expr.py` in the `__eq__` method and `sympy/core/sympify.py` in the `sympify` function. The `sympify` function should not use `eval` as a fallback when converting non-Basic objects in `__eq__` comparisons.", + "test_patch": "", + "test_command": "PYTHONWARNINGS='ignore::UserWarning,ignore::SyntaxWarning' bin/test -C --verbose sympy/core/tests/test_expr.py" + }, + { + "instance_id": "django__django-16379", + "repo": "https://github.com/swe-bench/django__django.git", + "base_commit": "1d0fa848e084cad62d0bb6bde3b51e4862558e57", + "problem_statement": "FileBasedCache has_key is susceptible to race conditions.\n\nFileBasedCache.has_key() can crash with a FileNotFoundError due to a race condition. It was possible for the cache file to be deleted between the `exists()` check and the `open()` call.\n\nThe `_is_expired()` method itself deletes the file if it finds it to be expired. So if many threads race to read an expired cache key at once, one thread may delete the file while another is between checking existence and opening it.\n\nThe fix should wrap the file open in a try/except to handle the case where the file is deleted between the existence check and the open call.", + "hints_text": "Look at `django/core/cache/backends/filebased.py`, specifically the `has_key` method. The race condition occurs between the `os.path.exists()` call and the subsequent file open. 
A try/except FileNotFoundError would fix it.", + "test_patch": "", + "test_command": "./tests/runtests.py --verbosity 2 cache.tests" + }, + { + "instance_id": "scikit-learn__scikit-learn-14894", + "repo": "https://github.com/swe-bench/scikit-learn__scikit-learn.git", + "base_commit": "fdbaa58acbead5a254f2e6d597dc1ab3b947f4c6", + "problem_statement": "ZeroDivisionError in _sparse_fit for SVM with empty support_vectors_.\n\nWhen using sparse data, in the case where the `support_vectors_` attribute is empty, `_sparse_fit` gives a ZeroDivisionError.\n\n```python\nimport numpy as np\nimport scipy\nfrom sklearn.svm import SVR\n\nx_train = np.array([[0, 1, 0, 0],\n [0, 0, 0, 1],\n [0, 0, 1, 0],\n [0, 0, 0, 1]])\ny_train = np.array([0.04, 0.04, 0.10, 0.16])\n\nmodel = SVR(C=316.227766017, cache_size=200, coef0=0.0, degree=3, epsilon=0.1,\n gamma=1.0, kernel='linear', max_iter=15000,\n shrinking=True, tol=0.001, verbose=False)\n\n# dense x_train has no error\nmodel.fit(x_train, y_train)\n\n# convert to sparse - triggers ZeroDivisionError\nxtrain = scipy.sparse.csr_matrix(x_train)\nmodel.fit(xtrain, y_train)\n```\n\n```\nZeroDivisionError: float division by zero\n```\n\nThe error occurs in `sklearn/svm/base.py` at `dual_coef_indices.size / n_class` when `n_class` is zero because `support_vectors_` is empty.", + "hints_text": "The fix is in `sklearn/svm/base.py` in the `_sparse_fit` method. When `support_vectors_` is empty, `n_class` will be 0, causing a division by zero. 
The code should handle the empty support vectors case before the division.", + "test_patch": "", + "test_command": "pytest -rA sklearn/svm/tests/test_svm.py" + }, + { + "instance_id": "matplotlib__matplotlib-25433", + "repo": "https://github.com/swe-bench/matplotlib__matplotlib.git", + "base_commit": "7eafdd8af3c523c1c77b027d378fb337dd489f18", + "problem_statement": "Using clf() and pyplot.draw() in RangeSlider on_changed callback blocks input to all widgets.\n\nWhen using `pyplot.clf()`, adding new widgets, and then redrawing the current figure in the `on_changed` callback of a RangeSlider, the inputs to all the widgets in the figure are blocked. When doing the same in the Button callback `on_clicked`, everything works fine.\n\n```python\nimport matplotlib.pyplot as pyplot\nimport matplotlib.widgets as widgets\n\ndef onchanged(values):\n print(\"on changed\")\n print(values)\n pyplot.clf()\n addElements()\n pyplot.draw()\n\ndef onclick(e):\n print(\"on click\")\n pyplot.clf()\n addElements()\n pyplot.draw()\n\ndef addElements():\n ax = pyplot.axes([0.1, 0.45, 0.8, 0.1])\n global slider\n slider = widgets.RangeSlider(ax, \"Test\", valmin=1, valmax=10, valinit=(1, 10))\n slider.on_changed(onchanged)\n ax = pyplot.axes([0.1, 0.30, 0.8, 0.1])\n global button\n button = widgets.Button(ax, \"Test\")\n button.on_clicked(onclick)\n\naddElements()\npyplot.show()\n```\n\nThe widgets can't receive any input from a mouse click when redrawing in the `on_changed` callback. The root cause is that mouse grabs are not released when the owning Axes is removed.", + "hints_text": "The issue is in the figure/axes mouse grab mechanism. When an Axes is removed (via `clf()`), any mouse grab it holds should be released. 
Look at `lib/matplotlib/figure.py` or `lib/matplotlib/axes/_base.py` for the grab/release logic.", + "test_patch": "", + "test_command": "pytest -rA lib/matplotlib/tests/test_backend_bases.py" + }, + { + "instance_id": "pallets__flask-4992", + "repo": "https://github.com/swe-bench/pallets__flask.git", + "base_commit": "4c288bc97ea371817199908d0d9b12de9dae327e", + "problem_statement": "Add a file mode parameter to flask.Config.from_file().\n\nPython 3.11 introduced native TOML support with the `tomllib` package. This could work nicely with `flask.Config.from_file()` as an easy way to load TOML config files:\n\n```python\napp.config.from_file(\"config.toml\", tomllib.load)\n```\n\nHowever, `tomllib.load()` takes an object readable in binary mode, while `flask.Config.from_file()` opens the file in text mode, resulting in this error:\n\n```\nTypeError: File must be opened in binary mode, e.g. use `open('foo.toml', 'rb')`\n```\n\nAdding a file mode parameter to `flask.Config.from_file()` would enable binary mode:\n\n```python\napp.config.from_file(\"config.toml\", tomllib.load, text=False)\n```\n\nCurrently one must work around it with a more verbose expression:\n```python\nwith open(os.path.join(app.config.root_path, \"config.toml\"), \"rb\") as f:\n app.config.from_mapping(tomllib.load(f))\n```", + "hints_text": "The fix is in `src/flask/config.py` in the `from_file` method. Add a `text` boolean parameter (default `True`). 
When `text=False`, open the file in `'rb'` mode instead of `'r'`.", + "test_patch": "", + "test_command": "pytest -rA tests/test_config.py" + }, + { + "instance_id": "mwaskom__seaborn-3190", + "repo": "https://github.com/swe-bench/mwaskom__seaborn.git", + "base_commit": "4a9e54962a29c12a8b103d75f838e0e795a6974d", + "problem_statement": "Color mapping fails with boolean data.\n\nUsing boolean values for the `color` parameter in the new objects interface raises a TypeError during scale setup.\n\n```python\nimport seaborn.objects as so\nso.Plot([\"a\", \"b\"], [1, 2], color=[True, False]).add(so.Bar())\n```\n\nResults in a `TypeError` during `Continuous._setup()` because the boolean data cannot be sorted/normalized as float data. The scale setup attempts to normalize the data but fails because boolean values are not handled by the `Continuous` scale's normalization logic.\n\nThe expected behavior is that boolean data should either be handled gracefully by the Continuous scale or be mapped to an appropriate scale type.", + "hints_text": "The fix is in `seaborn/_core/scales.py` in the `Continuous` class. 
The `_setup` method needs to handle non-float data types (like boolean) by converting them to float before normalization.", + "test_patch": "", + "test_command": "pytest --no-header -rA tests/_core/test_scales.py" + }, + { + "instance_id": "pydata__xarray-4094", + "repo": "https://github.com/swe-bench/pydata__xarray.git", + "base_commit": "a64cf2d5476e7bbda099b34c40b7be1880dbd39a", + "problem_statement": "to_unstacked_dataset broken for single-dim variables.\n\nThe `to_unstacked_dataset` method fails with a MergeError when variables have only a single dimension.\n\n```python\nimport xarray as xr\nimport numpy as np\n\narr = xr.DataArray(\n np.arange(3),\n coords=[(\"x\", [0, 1, 2])],\n)\ndata = xr.Dataset({\"a\": arr, \"b\": arr})\nstacked = data.to_stacked_array('y', sample_dims=['x'])\nunstacked = stacked.to_unstacked_dataset('y')\n```\n\n```\nMergeError: conflicting values for variable 'y' on objects to be combined.\nYou can skip this check by specifying compat='override'.\n```\n\nThe expected output is a working roundtrip: stacking and then unstacking a Dataset should return an equivalent Dataset. This fails when the variables only have a single dimension.", + "hints_text": "The fix is in `xarray/core/dataarray.py` in the `to_unstacked_dataset` method. 
The issue involves how the stacking coordinate is handled when variables have a single dimension.", + "test_patch": "", + "test_command": "pytest -rA xarray/tests/test_dataset.py -k unstack" + }, + { + "instance_id": "django__django-14155", + "repo": "https://github.com/swe-bench/django__django.git", + "base_commit": "2f13c476abe4ba787b6cb71131818341911f43cc", + "problem_statement": "ResolverMatch.__repr__() is not helpful for partial function views.\n\nWhen a `functools.partial` function is passed as the view, `ResolverMatch.__repr__()` shows the `func` argument as `functools.partial` which isn't very helpful, especially as it doesn't reveal the underlying function or arguments provided.\n\nFor example:\n```python\nfrom functools import partial\nfrom django.urls import resolve\n\ndef my_view(request, arg1=None):\n pass\n\n# Using partial view\npartial_view = partial(my_view, arg1='value')\n```\n\nThe `__repr__` of the resolved match for `partial_view` would just show `functools.partial` rather than the underlying function `my_view` and its pre-filled arguments. This makes debugging URL resolution issues more difficult.\n\nThe fix should unwrap partial functions in `__repr__` to show the underlying function and any provided arguments.", + "hints_text": "The fix is in `django/urls/resolvers.py` in the `ResolverMatch` class. The `__repr__` method should detect `functools.partial` objects and unwrap them to show the underlying function and arguments.", + "test_patch": "", + "test_command": "./tests/runtests.py --verbosity 2 urlpatterns_reverse.tests" + } +] diff --git a/tests/benchmark/swe/cases/mini-cases.json b/tests/benchmark/swe/cases/mini-cases.json new file mode 100644 index 0000000..d6b5c8e --- /dev/null +++ b/tests/benchmark/swe/cases/mini-cases.json @@ -0,0 +1,182 @@ +[ + { + "id": "mini-swe-001", + "description": "The `chunk` function splits an array into sub-arrays of the given size, but it returns an extra empty array at the end for certain inputs. 
For example `chunk([1,2,3,4,5], 2)` returns `[[1,2],[3,4],[5],[]]` instead of `[[1,2],[3,4],[5]]`.", + "files": { + "src.js": "function chunk(arr, size) {\n if (size <= 0) return [];\n const result = [];\n for (let i = 0; i <= arr.length; i += size) {\n result.push(arr.slice(i, i + size));\n }\n return result;\n}\nmodule.exports = { chunk };\n", + "test.js": "const { chunk } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = chunk([1, 2, 3, 4, 5], 2);\nassert(r1.length === 3, 'Expected 3 chunks, got ' + r1.length);\nassert(JSON.stringify(r1) === '[[1,2],[3,4],[5]]', 'Wrong result: ' + JSON.stringify(r1));\n\nconst r2 = chunk([1, 2, 3, 4], 2);\nassert(r2.length === 2, 'Expected 2 chunks, got ' + r2.length);\n\nconst r3 = chunk([], 3);\nassert(r3.length === 0, 'Expected 0 chunks for empty array, got ' + r3.length);\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-002", + "description": "The `countWords` function should return the number of words in a string. It works for simple cases like `countWords('hello world')` returning 2, but fails when there are multiple consecutive spaces. 
`countWords('hello world')` returns 3 instead of 2, and `countWords(' hello ')` returns 4 instead of 1.", + "files": { + "src.js": "function countWords(text) {\n if (!text) return 0;\n return text.split(' ').length;\n}\nmodule.exports = { countWords };\n", + "test.js": "const { countWords } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(countWords('hello world') === 2, 'Basic two words');\nassert(countWords('hello world') === 2, 'Double space: expected 2, got ' + countWords('hello world'));\nassert(countWords('') === 0, 'Empty string should be 0');\nassert(countWords(' hello ') === 1, 'Padded: expected 1, got ' + countWords(' hello '));\nassert(countWords('one') === 1, 'Single word');\nassert(countWords('a b c') === 3, 'Three words');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-003", + "description": "The `sortNumbers` function should sort an array of numbers in ascending numeric order. However, `sortNumbers([10, 2, 30, 1, 20])` returns `[1, 10, 2, 20, 30]` instead of `[1, 2, 10, 20, 30]`. 
It appears to be sorting lexicographically instead of numerically.", + "files": { + "src.js": "function sortNumbers(arr) {\n return [...arr].sort();\n}\nmodule.exports = { sortNumbers };\n", + "test.js": "const { sortNumbers } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = sortNumbers([10, 2, 30, 1, 20]);\nassert(JSON.stringify(r1) === '[1,2,10,20,30]', 'Expected [1,2,10,20,30], got ' + JSON.stringify(r1));\n\nconst r2 = sortNumbers([100, 3, 22]);\nassert(JSON.stringify(r2) === '[3,22,100]', 'Expected [3,22,100], got ' + JSON.stringify(r2));\n\nconst r3 = sortNumbers([5]);\nassert(JSON.stringify(r3) === '[5]', 'Single element');\n\nconst r4 = sortNumbers([]);\nassert(JSON.stringify(r4) === '[]', 'Empty array');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-004", + "description": "The `classify` function maps numeric scores to letter grades (A/B/C/D/F). It works for most inputs but `classify(65)` returns `'F'` instead of `'D'`. 
All scores from 60 to 69 should return 'D'.", + "files": { + "src.js": "function classify(score) {\n if (score >= 90) return 'A';\n if (score >= 80) return 'B';\n if (score >= 70) return 'C';\n if (score >= 60) { 'D'; }\n return 'F';\n}\nmodule.exports = { classify };\n", + "test.js": "const { classify } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(classify(95) === 'A', 'Score 95 should be A');\nassert(classify(85) === 'B', 'Score 85 should be B');\nassert(classify(75) === 'C', 'Score 75 should be C');\nassert(classify(65) === 'D', 'Score 65 should be D, got ' + classify(65));\nassert(classify(60) === 'D', 'Score 60 should be D, got ' + classify(60));\nassert(classify(55) === 'F', 'Score 55 should be F');\nassert(classify(100) === 'A', 'Score 100 should be A');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-005", + "description": "The `flatten` function should recursively flatten a nested array. `flatten([1, [2, 3], [4, [5]]])` should return `[1, 2, 3, 4, 5]` but instead returns `[1, 4]`. 
It seems to drop elements that are inside nested arrays.", + "files": { + "src.js": "function flatten(arr) {\n return arr.reduce((acc, item) => {\n if (Array.isArray(item)) {\n acc.concat(flatten(item));\n } else {\n acc.push(item);\n }\n return acc;\n }, []);\n}\nmodule.exports = { flatten };\n", + "test.js": "const { flatten } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = flatten([1, [2, 3], [4, [5]]]);\nassert(JSON.stringify(r1) === '[1,2,3,4,5]', 'Deep flatten: expected [1,2,3,4,5], got ' + JSON.stringify(r1));\n\nconst r2 = flatten([1, 2, 3]);\nassert(JSON.stringify(r2) === '[1,2,3]', 'Already flat: ' + JSON.stringify(r2));\n\nconst r3 = flatten([[1], [2], [3]]);\nassert(JSON.stringify(r3) === '[1,2,3]', 'One level: ' + JSON.stringify(r3));\n\nconst r4 = flatten([]);\nassert(JSON.stringify(r4) === '[]', 'Empty array');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-006", + "description": "The `reverseWords` function should reverse the order of words in a sentence. For example, `reverseWords('hello world')` should return `'world hello'`. 
But it currently returns `'worldhello'` — the words are reversed but the spaces between them are missing.", + "files": { + "src.js": "function reverseWords(sentence) {\n return sentence.split(' ').reverse().join('');\n}\nmodule.exports = { reverseWords };\n", + "test.js": "const { reverseWords } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(reverseWords('hello world') === 'world hello', 'Two words: got \"' + reverseWords('hello world') + '\"');\nassert(reverseWords('a b c') === 'c b a', 'Three words');\nassert(reverseWords('single') === 'single', 'Single word unchanged');\nassert(reverseWords('the quick brown fox') === 'fox brown quick the', 'Four words');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-007", + "description": "The `addItem` function should return a new array with the item appended, without modifying the original array. However, calling `addItem(original, 4)` mutates the original array. 
After the call, the original array has been changed, which breaks downstream code that relies on immutability.", + "files": { + "src.js": "function addItem(list, item) {\n list.push(item);\n return list;\n}\nmodule.exports = { addItem };\n", + "test.js": "const { addItem } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst original = [1, 2, 3];\nconst result = addItem(original, 4);\n\nassert(JSON.stringify(result) === '[1,2,3,4]', 'Result should contain new item: ' + JSON.stringify(result));\nassert(JSON.stringify(original) === '[1,2,3]', 'Original mutated: ' + JSON.stringify(original));\n\nconst empty = [];\nconst r2 = addItem(empty, 'a');\nassert(JSON.stringify(r2) === '[\"a\"]', 'Add to empty');\nassert(empty.length === 0, 'Empty array mutated');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-008", + "description": "The `mapObject` function should transform all values in an object using a callback `fn(value, key)`. But `mapObject({ a: 1, b: 2 }, (v) => v * 2)` returns `{ a: NaN, b: NaN }` instead of `{ a: 2, b: 4 }`. 
It looks like the callback arguments might be in the wrong order.", + "files": { + "src.js": "function mapObject(obj, fn) {\n const result = {};\n for (const [key, value] of Object.entries(obj)) {\n result[key] = fn(key, value);\n }\n return result;\n}\nmodule.exports = { mapObject };\n", + "test.js": "const { mapObject } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = mapObject({ a: 1, b: 2, c: 3 }, (v) => v * 2);\nassert(r1.a === 2, 'a should be 2, got ' + r1.a);\nassert(r1.b === 4, 'b should be 4, got ' + r1.b);\nassert(r1.c === 6, 'c should be 6, got ' + r1.c);\n\nconst r2 = mapObject({ x: 'hello' }, (v, k) => k + ':' + v);\nassert(r2.x === 'x:hello', 'Key-value concat: got ' + r2.x);\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-009", + "description": "The `getAdults` function should return all people aged 18 or older from a list. But it currently excludes people who are exactly 18. 
`getAdults([{name:'Alice', age:18}])` returns an empty array instead of including Alice.", + "files": { + "src.js": "function getAdults(people) {\n return people.filter(p => p.age > 18);\n}\nmodule.exports = { getAdults };\n", + "test.js": "const { getAdults } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst people = [\n { name: 'Alice', age: 18 },\n { name: 'Bob', age: 17 },\n { name: 'Charlie', age: 25 },\n { name: 'Diana', age: 18 }\n];\nconst adults = getAdults(people);\nassert(adults.length === 3, 'Expected 3 adults, got ' + adults.length);\nassert(adults.some(p => p.name === 'Alice'), 'Alice (18) should be included');\nassert(adults.some(p => p.name === 'Diana'), 'Diana (18) should be included');\nassert(!adults.some(p => p.name === 'Bob'), 'Bob (17) should be excluded');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-010", + "description": "The `truncate` function should shorten a string to `maxLen` characters, adding '...' at the end if truncation occurs. It works for long strings but incorrectly truncates strings that are exactly `maxLen` characters long. 
`truncate('hello', 5)` returns `'he...'` instead of `'hello'`.", + "files": { + "src.js": "function truncate(str, maxLen) {\n if (str.length >= maxLen) {\n return str.slice(0, maxLen - 3) + '...';\n }\n return str;\n}\nmodule.exports = { truncate };\n", + "test.js": "const { truncate } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(truncate('hello', 5) === 'hello', 'Exact length should not truncate, got \"' + truncate('hello', 5) + '\"');\nassert(truncate('hi', 5) === 'hi', 'Short string unchanged');\nassert(truncate('hello world', 8) === 'hello...', 'Truncate to 8: got \"' + truncate('hello world', 8) + '\"');\nassert(truncate('abcdefghij', 7) === 'abcd...', 'Truncate to 7');\nassert(truncate('ab', 2) === 'ab', 'Length equals maxLen, no truncation');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-011", + "description": "The `capitalize` function should capitalize the first letter of each word in a string. But `capitalize('hello world')` returns `'HELLO WORLD'` instead of `'Hello World'`. 
It uppercases the entire word rather than just the first character.", + "files": { + "src.js": "function capitalize(str) {\n return str.split(' ').map(w => w.toUpperCase()).join(' ');\n}\nmodule.exports = { capitalize };\n", + "test.js": "const { capitalize } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(capitalize('hello world') === 'Hello World', 'Basic: got \"' + capitalize('hello world') + '\"');\nassert(capitalize('foo bar baz') === 'Foo Bar Baz', 'Three words');\nassert(capitalize('a') === 'A', 'Single char');\nassert(capitalize('already Capital') === 'Already Capital', 'Mixed case');\nassert(capitalize('') === '', 'Empty string');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-012", + "description": "The `range` function should generate an array of numbers from `start` to `end` inclusive. But `range(1, 5)` returns `[1, 2, 3, 4]` instead of `[1, 2, 3, 4, 5]`. The end value is always excluded.", + "files": { + "src.js": "function range(start, end) {\n const result = [];\n for (let i = start; i < end; i++) {\n result.push(i);\n }\n return result;\n}\nmodule.exports = { range };\n", + "test.js": "const { range } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(JSON.stringify(range(1, 5)) === '[1,2,3,4,5]', '1 to 5: got ' + JSON.stringify(range(1, 5)));\nassert(JSON.stringify(range(0, 3)) === '[0,1,2,3]', '0 to 3: got ' + JSON.stringify(range(0, 3)));\nassert(JSON.stringify(range(5, 5)) === '[5]', 'Same start/end: got ' + JSON.stringify(range(5, 5)));\nassert(JSON.stringify(range(-2, 1)) === '[-2,-1,0,1]', 'Negative: got ' + JSON.stringify(range(-2, 1)));\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-013", + "description": "The `isPalindrome` function checks whether a string is a palindrome. 
It should be case-insensitive, so `isPalindrome('Racecar')` should return `true`. But it returns `false` because it compares without normalizing case.", + "files": { + "src.js": "function isPalindrome(str) {\n const reversed = str.split('').reverse().join('');\n return str === reversed;\n}\nmodule.exports = { isPalindrome };\n", + "test.js": "const { isPalindrome } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(isPalindrome('racecar') === true, 'racecar is a palindrome');\nassert(isPalindrome('Racecar') === true, 'Racecar (mixed case) is a palindrome, got ' + isPalindrome('Racecar'));\nassert(isPalindrome('hello') === false, 'hello is not a palindrome');\nassert(isPalindrome('Madam') === true, 'Madam is a palindrome');\nassert(isPalindrome('a') === true, 'Single char is a palindrome');\nassert(isPalindrome('Ab') === false, 'Ab is not a palindrome');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-014", + "description": "The `deepClone` function should create a deep copy of an object, preserving arrays as arrays. 
But `deepClone({a: [1, 2, 3]})` returns `{a: {\"0\": 1, \"1\": 2, \"2\": 3}}` — arrays are converted to plain objects because the clone always creates `{}` instead of checking for arrays.", + "files": { + "src.js": "function deepClone(obj) {\n if (obj === null || typeof obj !== 'object') return obj;\n const clone = {};\n for (const key of Object.keys(obj)) {\n clone[key] = deepClone(obj[key]);\n }\n return clone;\n}\nmodule.exports = { deepClone };\n", + "test.js": "const { deepClone } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst original = { a: [1, 2, 3], b: { c: 'hello' } };\nconst cloned = deepClone(original);\n\nassert(Array.isArray(cloned.a), 'cloned.a should be an array, got ' + typeof cloned.a);\nassert(JSON.stringify(cloned.a) === '[1,2,3]', 'cloned.a content wrong: ' + JSON.stringify(cloned.a));\nassert(cloned.b.c === 'hello', 'Nested object preserved');\nassert(cloned.a !== original.a, 'Array should be a different reference');\nassert(cloned.b !== original.b, 'Nested obj should be a different reference');\n\noriginal.a.push(4);\nassert(cloned.a.length === 3, 'Mutation should not affect clone');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-015", + "description": "The `groupBy` function should group array elements by a key function. But `groupBy([{type:'a', val:1}, {type:'b', val:2}, {type:'a', val:3}], x => x.type)` stores the keys instead of the items in each group. 
The result is `{a: ['a', 'a'], b: ['b']}` instead of the original objects.", + "files": { + "src.js": "function groupBy(arr, keyFn) {\n const groups = {};\n for (const item of arr) {\n const key = keyFn(item);\n if (!groups[key]) groups[key] = [];\n groups[key].push(key);\n }\n return groups;\n}\nmodule.exports = { groupBy };\n", + "test.js": "const { groupBy } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst items = [{type:'a', val:1}, {type:'b', val:2}, {type:'a', val:3}];\nconst grouped = groupBy(items, x => x.type);\n\nassert(grouped.a.length === 2, 'Group a should have 2 items, got ' + grouped.a.length);\nassert(grouped.b.length === 1, 'Group b should have 1 item');\nassert(grouped.a[0].val === 1, 'First a item val should be 1, got ' + JSON.stringify(grouped.a[0]));\nassert(grouped.a[1].val === 3, 'Second a item val should be 3');\nassert(grouped.b[0].val === 2, 'b item val should be 2');\n\nconst nums = [1, 2, 3, 4, 5];\nconst evenOdd = groupBy(nums, n => n % 2 === 0 ? 'even' : 'odd');\nassert(evenOdd.odd.length === 3, 'Odd group: ' + JSON.stringify(evenOdd.odd));\nassert(evenOdd.even.length === 2, 'Even group: ' + JSON.stringify(evenOdd.even));\nassert(evenOdd.odd[0] === 1, 'First odd should be 1, got ' + evenOdd.odd[0]);\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-016", + "description": "The `intersection` function should return elements present in both arrays. But `intersection([1,2,3], [2,3,4])` returns `[1,2,3]` instead of `[2,3]`. 
It checks inclusion against the wrong array.", + "files": { + "src.js": "function intersection(a, b) {\n return a.filter(item => a.includes(item));\n}\nmodule.exports = { intersection };\n", + "test.js": "const { intersection } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = intersection([1, 2, 3], [2, 3, 4]);\nassert(JSON.stringify(r1) === '[2,3]', 'Expected [2,3], got ' + JSON.stringify(r1));\n\nconst r2 = intersection([1, 2], [3, 4]);\nassert(r2.length === 0, 'No common elements, got ' + JSON.stringify(r2));\n\nconst r3 = intersection([5, 5, 6], [5, 7]);\nassert(r3.includes(5), 'Should include 5');\nassert(!r3.includes(6), 'Should not include 6');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-017", + "description": "The `zip` function should pair up elements from two arrays. `zip([1,2,3], ['a','b','c'])` should return `[[1,'a'],[2,'b'],[3,'c']]`. But it returns `[[1,1],[2,2],[3,3]]` — it uses the first array for both elements of each pair.", + "files": { + "src.js": "function zip(a, b) {\n const len = Math.min(a.length, b.length);\n const result = [];\n for (let i = 0; i < len; i++) {\n result.push([a[i], a[i]]);\n }\n return result;\n}\nmodule.exports = { zip };\n", + "test.js": "const { zip } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = zip([1, 2, 3], ['a', 'b', 'c']);\nassert(JSON.stringify(r1) === '[[1,\"a\"],[2,\"b\"],[3,\"c\"]]', 'Basic zip: got ' + JSON.stringify(r1));\n\nconst r2 = zip([1, 2], [10, 20, 30]);\nassert(r2.length === 2, 'Should truncate to shorter length');\nassert(JSON.stringify(r2) === '[[1,10],[2,20]]', 'Uneven: got ' + JSON.stringify(r2));\n\nconst r3 = zip([], [1]);\nassert(r3.length === 0, 'Empty first array');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-018", + 
"description": "The `sumBy` function should sum an array of objects by a given numeric key. `sumBy([{v:1},{v:2},{v:3}], 'v')` should return `6`, but it returns `'123'` (string concatenation) because the initial accumulator is an empty string instead of zero.", + "files": { + "src.js": "function sumBy(arr, key) {\n return arr.reduce((sum, item) => sum + item[key], '');\n}\nmodule.exports = { sumBy };\n", + "test.js": "const { sumBy } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(sumBy([{v:1},{v:2},{v:3}], 'v') === 6, 'Sum should be 6, got ' + sumBy([{v:1},{v:2},{v:3}], 'v'));\nassert(sumBy([{score:10},{score:20}], 'score') === 30, 'Sum should be 30');\nassert(sumBy([], 'v') === 0, 'Empty array should be 0');\nassert(typeof sumBy([{v:1}], 'v') === 'number', 'Result should be a number, got ' + typeof sumBy([{v:1}], 'v'));\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-019", + "description": "The `pick` function should create a new object with only the specified keys from the source object. But `pick({a:1, b:2, c:3}, ['a','c'])` returns `{a:'a', c:'c'}` instead of `{a:1, c:3}`. 
It assigns the key name as the value instead of the actual value.", + "files": { + "src.js": "function pick(obj, keys) {\n const result = {};\n for (const key of keys) {\n if (key in obj) result[key] = key;\n }\n return result;\n}\nmodule.exports = { pick };\n", + "test.js": "const { pick } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = pick({a: 1, b: 2, c: 3}, ['a', 'c']);\nassert(r1.a === 1, 'a should be 1, got ' + r1.a);\nassert(r1.c === 3, 'c should be 3, got ' + r1.c);\nassert(r1.b === undefined, 'b should not be present');\n\nconst r2 = pick({x: 'hello', y: 'world'}, ['x', 'z']);\nassert(r2.x === 'hello', 'x should be hello');\nassert(r2.z === undefined, 'z not in source, should be absent');\nassert(Object.keys(r2).length === 1, 'Should only have 1 key');\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + }, + { + "id": "mini-swe-020", + "description": "The `difference` function should return elements in the first array that are NOT in the second array. But `difference([1,2,3,4], [2,4])` returns `[2,4]` instead of `[1,3]`. 
The filter logic is inverted — it returns the intersection instead of the difference.", + "files": { + "src.js": "function difference(a, b) {\n return a.filter(item => b.includes(item));\n}\nmodule.exports = { difference };\n", + "test.js": "const { difference } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = difference([1, 2, 3, 4], [2, 4]);\nassert(JSON.stringify(r1) === '[1,3]', 'Expected [1,3], got ' + JSON.stringify(r1));\n\nconst r2 = difference([1, 2, 3], []);\nassert(JSON.stringify(r2) === '[1,2,3]', 'Nothing to remove: got ' + JSON.stringify(r2));\n\nconst r3 = difference([1, 2], [1, 2, 3]);\nassert(r3.length === 0, 'All removed: got ' + JSON.stringify(r3));\n\nconst r4 = difference(['a', 'b', 'c'], ['b']);\nassert(JSON.stringify(r4) === '[\"a\",\"c\"]', 'Strings: got ' + JSON.stringify(r4));\n\nconsole.log('All tests passed');\n" + }, + "test_command": "node test.js" + } +] diff --git a/tests/benchmark/swe/dataset.ts b/tests/benchmark/swe/dataset.ts new file mode 100644 index 0000000..7a31f27 --- /dev/null +++ b/tests/benchmark/swe/dataset.ts @@ -0,0 +1,40 @@ +// --------------------------------------------------------------------------- +// SWE benchmark dataset loader +// --------------------------------------------------------------------------- + +import fs from 'fs'; +import path from 'path'; +import type { FullSWEInstance } from './docker-evaluator'; + +export interface MiniCase { + id: string; + description: string; + files: Record; + test_command: string; +} + +/** + * Load mini-SWE cases from the local JSON file. 
+ */ +export function loadMiniCases(): MiniCase[] { + const casesPath = path.join(__dirname, 'cases', 'mini-cases.json'); + if (!fs.existsSync(casesPath)) { + console.log(` SWE: cases file not found at ${casesPath}`); + return []; + } + const raw = fs.readFileSync(casesPath, 'utf-8'); + return JSON.parse(raw) as MiniCase[]; +} + +/** + * Load curated SWE-bench instances for full mode. + */ +export function loadCuratedInstances(): FullSWEInstance[] { + const instancesPath = path.join(__dirname, 'cases', 'curated-instances.json'); + if (!fs.existsSync(instancesPath)) { + console.log(` SWE: curated instances file not found at ${instancesPath}`); + return []; + } + const raw = fs.readFileSync(instancesPath, 'utf-8'); + return JSON.parse(raw) as FullSWEInstance[]; +} diff --git a/tests/benchmark/swe/docker-evaluator.ts b/tests/benchmark/swe/docker-evaluator.ts new file mode 100644 index 0000000..3389ebc --- /dev/null +++ b/tests/benchmark/swe/docker-evaluator.ts @@ -0,0 +1,744 @@ +// --------------------------------------------------------------------------- +// SWE-bench full mode — Docker-based evaluation +// --------------------------------------------------------------------------- + +import { execSync, spawnSync, ExecSyncOptionsWithStringEncoding } from 'child_process'; +import fs from 'fs'; +import os from 'os'; +import path from 'path'; +import type { ModelProvider } from '../../../src/infra/providers/types'; +import type { Message } from '../../../src/core/types'; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface FullSWEInstance { + instance_id: string; + repo: string; + base_commit: string; + problem_statement: string; + hints_text: string; + test_patch: string; + test_command: string; +} + +export interface FullHarnessResult { + patch: string; + tokens: number; + error?: string; +} + +export interface DockerEvalResult { 
+ passed: boolean; + output: string; + error?: string; +} + +// --------------------------------------------------------------------------- +// Docker availability check +// --------------------------------------------------------------------------- + +export function isDockerAvailable(): boolean { + try { + execSync('docker info', { stdio: 'pipe', timeout: 10_000 }); + return true; + } catch { + return false; + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/** + * Extract "owner/repo" from a GitHub URL. + * e.g. "https://github.com/psf/requests.git" → "psf-requests" + */ +function repoSlug(repoUrl: string): string { + const match = repoUrl.match(/github\.com\/([^/]+)\/([^/.]+)/); + if (match) return `${match[1]}-${match[2]}`.toLowerCase(); + // Fallback: strip protocol, slashes, .git + return repoUrl + .replace(/^https?:\/\//, '') + .replace(/\.git$/, '') + .replace(/[^a-z0-9-]/gi, '-') + .replace(/-+/g, '-') + .replace(/^-|-$/g, '') + .toLowerCase(); +} + +/** + * Sanitize a string for use as a Docker container name. + * Docker allows [a-zA-Z0-9_.-] and must start with [a-zA-Z0-9]. + */ +function sanitizeContainerName(raw: string): string { + return raw + .replace(/[^a-zA-Z0-9_.-]/g, '-') + .replace(/-+/g, '-') + .replace(/^[^a-zA-Z0-9]+/, '') + .slice(0, 128); +} + +// --------------------------------------------------------------------------- +// New full-mode helpers: read from Docker image → LLM → diff +// --------------------------------------------------------------------------- + +/** + * Read files from a SWE-bench Docker image's /testbed directory. + * This avoids cloning the repo on the host — the image already has everything. 
+ */ +function readFilesFromImage(imageName: string, filePaths: string[]): Record<string, string> { + const files: Record<string, string> = {}; + + for (const fp of filePaths) { + const result = spawnSync( + 'docker', + ['run', '--rm', imageName, 'cat', `/testbed/${fp}`], + { stdio: ['pipe', 'pipe', 'pipe'], timeout: 30_000 }, + ); + + if (result.status === 0) { + const content = (result.stdout || '').toString(); + if (content) { + files[fp] = content; + } + } + } + + return files; +} + +/** + * Extract relevant file paths from problem statement and hints text. + * Looks for common source file path patterns. + */ +function extractRelevantPaths(problemStatement: string, hintsText: string): string[] { + const paths = new Set<string>(); + + // Common source file patterns (Python-centric for SWE-bench) + const patterns = [ + // Explicit paths like `path/to/file.py` (backtick-quoted) + /`([\w/.]+\.py)`/g, + // Paths mentioned naturally: word/word/file.py + /(?:^|\s)((?:[\w-]+\/)+[\w-]+\.py)(?:\s|$|[.,;:)])/gm, + // Module-style paths: package.module.file (convert dots to slashes) + /(?:in|see|at|file|module)\s+`?([\w]+(?:\.[\w]+){2,})`?/gi, + ]; + + // Prioritize hints_text (usually more precise) + const sources = [hintsText, problemStatement].filter(Boolean); + + for (const source of sources) { + for (const pattern of patterns) { + pattern.lastIndex = 0; + let match: RegExpExecArray | null; + while ((match = pattern.exec(source)) !== null) { + let p = match[1].trim(); + // Convert module-style paths (e.g.
astropy.modeling.separable) to file paths + if (!p.includes('/') && p.includes('.') && !p.endsWith('.py')) { + p = p.replace(/\./g, '/') + '.py'; + } + // Skip test files and obviously invalid paths + if (!p.includes('test') && p.endsWith('.py') && p.length > 4) { + paths.add(p); + } + } + } + } + + return Array.from(paths); +} + +// (readFilesFromRepo removed — we now read directly from Docker images) + +// --------------------------------------------------------------------------- +// LLM interaction — generate fix (file-based, like mini mode) +// --------------------------------------------------------------------------- + +const FULL_SYSTEM_PROMPT = `You are a software engineer fixing bugs in open-source repositories. +You will be given a bug report, hints, and the relevant source files. +Your task is to fix the bug so all tests pass. + +Rules: +- Only modify source files. NEVER modify test files. +- Output ONLY the changed sections using the SEARCH/REPLACE format below. +- Do NOT output the entire file. Only output the minimal code blocks that need to change. +- Do NOT include any explanation outside the file markers. + +Format: + +--- FILE: <path> --- +<<<<<<< SEARCH +<original code> +======= +<corrected code> +>>>>>>> REPLACE +--- END FILE --- + +You may include multiple SEARCH/REPLACE blocks within one FILE section. +You may output multiple FILE sections if changes span multiple files. + +Example: + +--- FILE: src/utils.py --- +<<<<<<< SEARCH +def validate(value): + if value > 0: + return True +======= +def validate(value): + if value >= 0: + return True +>>>>>>> REPLACE +--- END FILE ---`; + +/** + * Call the LLM with source file context (like mini mode). + * Includes a single retry on failure.
+ */ +async function callLLMWithContext( + provider: ModelProvider, + instance: FullSWEInstance, + files: Record<string, string>, +): Promise<{ text: string; tokens: number }> { + const fileListing = Object.entries(files) + .map(([name, content]) => `--- ${name} ---\n${content}`) + .join('\n'); + + const userMessage = [ + 'Bug report:', + instance.problem_statement, + ]; + + if (instance.hints_text) { + userMessage.push('', 'Hints:', instance.hints_text); + } + + userMessage.push( + '', + 'Source files:', + fileListing, + '', + 'Fix the bug in the source file(s) so that all tests pass.', + 'Output ONLY the changed sections using the SEARCH/REPLACE format described in your instructions.', + ); + + const messages: Message[] = [ + { role: 'user', content: [{ type: 'text', text: userMessage.join('\n') }] }, + ]; + + const attempt = async (): Promise<{ text: string; tokens: number }> => { + const response = await provider.complete(messages, { + system: FULL_SYSTEM_PROMPT, + maxTokens: 16384, + }); + + const text = response.content + .filter((b): b is { type: 'text'; text: string } => b.type === 'text') + .map(b => b.text) + .join(''); + + const tokens = + (response.usage?.input_tokens ?? 0) + (response.usage?.output_tokens ?? 0); + + return { text, tokens }; + }; + + try { + return await attempt(); + } catch (err: any) { + // Single retry after 3 seconds + console.log(` [llm] First attempt failed (${err.message}), retrying ...`); + await new Promise(resolve => setTimeout(resolve, 3000)); + return await attempt(); + } +} + +/** + * Parse `--- FILE: <path> --- ... --- END FILE ---` blocks from model output. + * Each FILE block may contain one or more SEARCH/REPLACE hunks, or a full file body.
+ */ +function parseFileBlocks(text: string): Array<{ path: string; body: string }> { + const blocks: Array<{ path: string; body: string }> = []; + const regex = /---\s*FILE:\s*(.+?)\s*---\r?\n([\s\S]*?)---\s*END FILE\s*---/g; + let match: RegExpExecArray | null; + + while ((match = regex.exec(text)) !== null) { + const filename = match[1].trim(); + const body = match[2]; + if (!filename.includes('test')) { + blocks.push({ path: filename, body }); + } + } + + return blocks; +} + +/** + * Parse SEARCH/REPLACE hunks from a file block body. + * Returns an array of { search, replace } pairs. + */ +function parseSearchReplaceHunks(body: string): Array<{ search: string; replace: string }> { + const hunks: Array<{ search: string; replace: string }> = []; + const regex = /<<<<<<< SEARCH\r?\n([\s\S]*?)=======\r?\n([\s\S]*?)>>>>>>> REPLACE/g; + let match: RegExpExecArray | null; + + while ((match = regex.exec(body)) !== null) { + hunks.push({ search: match[1], replace: match[2] }); + } + + return hunks; +} + +/** + * Apply SEARCH/REPLACE hunks to the original file content. + * Returns the corrected file content, or null if any hunk fails to match. + */ +function applyHunks( + original: string, + hunks: Array<{ search: string; replace: string }>, +): string | null { + let result = original; + + for (const hunk of hunks) { + // Try exact match first + if (result.includes(hunk.search)) { + result = result.replace(hunk.search, hunk.replace); + continue; + } + + // Try trimmed trailing newline match + const searchTrimmed = hunk.search.replace(/\n$/, ''); + const replaceTrimmed = hunk.replace.replace(/\n$/, ''); + if (result.includes(searchTrimmed)) { + result = result.replace(searchTrimmed, replaceTrimmed); + continue; + } + + // Hunk didn't match + return null; + } + + return result; +} + +/** + * Generate a unified diff by comparing original and corrected file contents. + * Uses `diff -u` with temp files. No repo clone needed. 
+ */ +function generateDiffFromOriginals( + originals: Record<string, string>, + corrected: Record<string, string>, +): string { + const diffs: string[] = []; + + for (const [filePath, newContent] of Object.entries(corrected)) { + const originalContent = originals[filePath]; + if (originalContent === undefined) { + // New file — generate a diff from /dev/null + const tmpNew = path.join(os.tmpdir(), `swe-new-${Date.now()}-${Math.random().toString(36).slice(2)}`); + fs.writeFileSync(tmpNew, newContent, 'utf-8'); + try { + const result = spawnSync( + 'diff', + ['-u', '/dev/null', tmpNew, '--label', `a/${filePath}`, '--label', `b/${filePath}`], + { stdio: ['pipe', 'pipe', 'pipe'], timeout: 10_000 }, + ); + const diffOutput = (result.stdout || '').toString(); + if (diffOutput) { + diffs.push(`diff --git a/${filePath} b/${filePath}\n${diffOutput}`); + } + } finally { + fs.unlinkSync(tmpNew); + } + continue; + } + + // Write original and new to temp files for diff + const tmpOrig = path.join(os.tmpdir(), `swe-orig-${Date.now()}-${Math.random().toString(36).slice(2)}`); + const tmpNew = path.join(os.tmpdir(), `swe-new-${Date.now()}-${Math.random().toString(36).slice(2)}`); + + fs.writeFileSync(tmpOrig, originalContent, 'utf-8'); + fs.writeFileSync(tmpNew, newContent, 'utf-8'); + + try { + const result = spawnSync( + 'diff', + ['-u', tmpOrig, tmpNew, '--label', `a/${filePath}`, '--label', `b/${filePath}`], + { stdio: ['pipe', 'pipe', 'pipe'], timeout: 10_000 }, + ); + // diff exits 1 when files differ (not an error) + const diffOutput = (result.stdout || '').toString(); + if (diffOutput) { + diffs.push(`diff --git a/${filePath} b/${filePath}\n${diffOutput}`); + } + } finally { + fs.unlinkSync(tmpOrig); + fs.unlinkSync(tmpNew); + } + } + + return diffs.join('\n'); +} + +// --------------------------------------------------------------------------- +// Main entry point — generate fix (replaces old generatePatch) +// --------------------------------------------------------------------------- + +/** + *
Generate a fix by: + * 1. Pulling the SWE-bench Docker image (has repo at /testbed) + * 2. Extracting relevant file paths from the problem statement / hints + * 3. Reading those files directly from the Docker image + * 4. Sending source code + problem to LLM (like mini mode) + * 5. Parsing corrected files from LLM response + * 6. Generating unified diff programmatically + */ +export async function generateFix( + provider: ModelProvider, + instance: FullSWEInstance, + proxyUrl?: string, +): Promise<FullHarnessResult> { + try { + // 1. Ensure the SWE-bench image is available + const imageName = getSWEBenchImageName(instance.instance_id); + if (!pullImage(imageName, proxyUrl)) { + return { patch: '', tokens: 0, error: `Failed to pull SWE-bench image: ${imageName}` }; + } + + // 2. Extract relevant file paths + const filePaths = extractRelevantPaths(instance.problem_statement, instance.hints_text); + console.log(` [fix] Extracted ${filePaths.length} relevant file path(s): ${filePaths.join(', ')}`); + + if (filePaths.length === 0) { + return { patch: '', tokens: 0, error: 'No relevant file paths found in problem statement or hints' }; + } + + // 3. Read source files directly from the Docker image + console.log(` [fix] Reading files from Docker image ...`); + const fileContents = readFilesFromImage(imageName, filePaths); + const readCount = Object.keys(fileContents).length; + console.log(` [fix] Read ${readCount} file(s) from image`); + + if (readCount === 0) { + return { patch: '', tokens: 0, error: 'None of the extracted file paths exist in the image' }; + } + + // 4. Call LLM with source context + console.log(` [fix] Sending source files + problem to LLM ...`); + const response = await callLLMWithContext(provider, instance, fileContents); + + // 5.
Parse file blocks and apply search/replace hunks + const fileBlocks = parseFileBlocks(response.text); + + if (fileBlocks.length === 0) { + // Log a snippet of the response for debugging + const snippet = response.text.slice(0, 300).replace(/\n/g, '\\n'); + console.log(` [fix] Response snippet: ${snippet}`); + return { patch: '', tokens: response.tokens, error: 'No corrected files found in model response' }; + } + + // Build corrected files by applying hunks to originals + const correctedFiles: Record<string, string> = {}; + + for (const block of fileBlocks) { + const hunks = parseSearchReplaceHunks(block.body); + + if (hunks.length > 0) { + // Search/replace mode — apply hunks to original + const original = fileContents[block.path]; + if (!original) { + console.log(` [fix] Warning: original file not found for ${block.path}, skipping`); + continue; + } + const applied = applyHunks(original, hunks); + if (applied === null) { + console.log(` [fix] Warning: SEARCH block mismatch for ${block.path}`); + continue; + } + correctedFiles[block.path] = applied; + } else { + // Fallback: block body is the complete corrected file content + correctedFiles[block.path] = block.body; + } + } + + const correctedCount = Object.keys(correctedFiles).length; + + if (correctedCount === 0) { + return { patch: '', tokens: response.tokens, error: 'All SEARCH/REPLACE hunks failed to match' }; + } + + console.log(` [fix] LLM returned ${correctedCount} corrected file(s)`); + + // 6.
Generate unified diff (using temp files, no repo clone needed) + const patch = generateDiffFromOriginals(fileContents, correctedFiles); + + if (!patch) { + return { patch: '', tokens: response.tokens, error: 'Generated diff is empty (no changes detected)' }; + } + + return { patch, tokens: response.tokens }; + } catch (err: any) { + return { patch: '', tokens: 0, error: err.message || String(err) }; + } +} + +// --------------------------------------------------------------------------- +// Docker-based evaluation (using official SWE-bench pre-built images) +// --------------------------------------------------------------------------- + +/** stdio config that streams stdout to terminal for progress visibility */ +const LIVE_OPTS = { + stdio: ['pipe' as const, 'inherit' as const, 'pipe' as const], + timeout: 1_200_000, // 20 minutes — some test suites (e.g. sympy) are slow +}; + +/** + * Derive the official SWE-bench Docker image name from an instance_id. + * Convention: `swebench/sweb.eval.x86_64.:latest` + * where `__` in instance_id is replaced with `_1776_`. + */ +export function getSWEBenchImageName(instanceId: string): string { + const slug = instanceId.toLowerCase().replace(/__/g, '_1776_'); + return `swebench/sweb.eval.x86_64.${slug}:latest`; +} + +/** + * Pull a Docker image, using proxy if configured. + * Returns true if the image is available (already existed or pulled successfully). 
+ */ +function pullImage(imageName: string, proxyUrl?: string): boolean { + // Check if image already exists locally + const checkResult = spawnSync( + 'docker', ['image', 'inspect', imageName], + { stdio: ['pipe', 'pipe', 'pipe'], timeout: 10_000 }, + ); + if (checkResult.status === 0) { + console.log(` [docker] Image ${imageName} already available locally`); + return true; + } + + // Pull with proxy if needed + const env: Record<string, string> = { ...process.env as Record<string, string> }; + if (proxyUrl) { + env.HTTPS_PROXY = proxyUrl; + env.HTTP_PROXY = proxyUrl; + env.https_proxy = proxyUrl; + env.http_proxy = proxyUrl; + } + + console.log(` [docker] Pulling ${imageName} ...`); + const pullResult = spawnSync( + 'docker', ['pull', imageName], + { env, stdio: ['pipe', 'inherit', 'pipe'], timeout: 1_200_000 }, + ); + + if (pullResult.status !== 0) { + const stderr = (pullResult.stderr || '').toString().trim(); + console.log(` [docker] Failed to pull image: ${stderr}`); + return false; + } + + return true; +} + +/** + * Evaluate a patch inside an official SWE-bench Docker container. + * + * The SWE-bench images come with: + * - Repository pre-cloned at /testbed (at the correct base commit) + * - Conda environment "testbed" with all dependencies installed + * - Correct Python version for the project + * + * Steps: + * 1. Pull the SWE-bench image (if not cached locally) + * 2. Mount patch files into the container + * 3. Apply the fix patch with git apply (with fallbacks) + * 4. Apply the test patch (if provided) + * 5.
Run the test command inside the conda environment + */ +export function evaluateWithDocker( + instance: FullSWEInstance, + patch: string, + workDir: string, + proxyUrl?: string, +): DockerEvalResult { + const imageName = getSWEBenchImageName(instance.instance_id); + + // Pull image if needed + if (!pullImage(imageName, proxyUrl)) { + return { + passed: false, + output: '', + error: `Failed to pull SWE-bench image: ${imageName}`, + }; + } + + fs.mkdirSync(workDir, { recursive: true }); + + // Write patches to workDir (mounted into container) + fs.writeFileSync(path.join(workDir, 'fix.patch'), patch, 'utf-8'); + if (instance.test_patch) { + fs.writeFileSync(path.join(workDir, 'test.patch'), instance.test_patch, 'utf-8'); + } + + // Build evaluation script + // The SWE-bench container has: /testbed (repo), conda env "testbed" + const script = [ + '#!/bin/bash', + 'set -uo pipefail', + '', + 'source /opt/miniconda3/bin/activate', + 'conda activate testbed', + 'cd /testbed', + '', + 'echo " [docker] Applying fix patch ..."', + 'if git apply --verbose /patches/fix.patch; then', + ' echo " [docker] Patch applied with git apply"', + 'elif git apply --verbose --reject /patches/fix.patch; then', + ' echo " [docker] Patch applied with --reject"', + 'elif patch --batch --fuzz=5 -p1 -i /patches/fix.patch; then', + ' echo " [docker] Patch applied with patch command"', + 'else', + ' echo " [docker] ERROR: Patch application failed"', + ' exit 1', + 'fi', + '', + 'if [ -f /patches/test.patch ] && [ -s /patches/test.patch ]; then', + ' echo " [docker] Applying test patch ..."', + ' git apply -v /patches/test.patch || true', + 'fi', + '', + `echo " [docker] Running tests: ${instance.test_command}"`, + `${instance.test_command}`, + 'echo " [docker] Tests completed."', + ].join('\n'); + + fs.writeFileSync(path.join(workDir, 'evaluate.sh'), script, 'utf-8'); + + const containerName = sanitizeContainerName(`swe-${instance.instance_id}-${Date.now()}`); + + try { + console.log(` [docker] 
Starting container (${imageName}) ...`); + const result = spawnSync( + 'docker', + [ + 'run', '--rm', + '--name', containerName, + '-v', `${workDir}:/patches:ro`, + imageName, + 'bash', '/patches/evaluate.sh', + ], + LIVE_OPTS, + ); + + const stderr = (result.stderr || '').toString().trim(); + + if (result.status === 0) { + return { passed: true, output: '' }; + } + + return { + passed: false, + output: '', + error: stderr || `exit code ${result.status}`, + }; + } catch (err: any) { + return { + passed: false, + output: '', + error: err.message || String(err), + }; + } +} + +// --------------------------------------------------------------------------- +// Local evaluation fallback (no Docker) +// --------------------------------------------------------------------------- + +const EXEC_OPTS: ExecSyncOptionsWithStringEncoding = { + encoding: 'utf-8', + stdio: ['pipe', 'pipe', 'pipe'], + timeout: 600_000, +}; + +export function evaluateLocally( + instance: FullSWEInstance, + patch: string, + workDir: string, +): DockerEvalResult { + const repoDir = path.join(workDir, 'repo'); + fs.mkdirSync(workDir, { recursive: true }); + + try { + console.log(` [local] Cloning ${instance.repo} ...`); + spawnSync( + 'git', ['clone', '--quiet', instance.repo, repoDir], + { ...LIVE_OPTS, timeout: 120_000 }, + ); + + console.log(` [local] Checking out ${instance.base_commit.slice(0, 10)} ...`); + execSync(`git checkout "${instance.base_commit}" --quiet`, { + ...EXEC_OPTS, + cwd: repoDir, + }); + + console.log(` [local] Applying fix patch ...`); + const patchPath = path.join(workDir, 'fix.patch'); + fs.writeFileSync(patchPath, patch, 'utf-8'); + execSync(`git apply "${patchPath}"`, { ...EXEC_OPTS, cwd: repoDir }); + + if (instance.test_patch) { + console.log(` [local] Applying test patch ...`); + const testPatchPath = path.join(workDir, 'test.patch'); + fs.writeFileSync(testPatchPath, instance.test_patch, 'utf-8'); + try { + execSync(`git apply "${testPatchPath}"`, { ...EXEC_OPTS, cwd: 
repoDir }); + } catch { + // Test patch may not apply cleanly + } + } + + console.log(` [local] Running tests: ${instance.test_command}`); + const result = spawnSync('bash', ['-c', instance.test_command], { + ...LIVE_OPTS, + cwd: repoDir, + timeout: 300_000, + }); + + const stderr = (result.stderr || '').toString().trim(); + + if (result.status === 0) { + console.log(` [local] Tests completed.`); + return { passed: true, output: '' }; + } + + return { + passed: false, + output: '', + error: stderr || `exit code ${result.status}`, + }; + } catch (err: any) { + const stderr = (err.stderr || '').toString().trim(); + return { + passed: false, + output: '', + error: stderr || err.message || String(err), + }; + } +} + +// --------------------------------------------------------------------------- +// Cleanup +// --------------------------------------------------------------------------- + +export function cleanupWorkDir(workDir: string): void { + try { + fs.rmSync(workDir, { recursive: true, force: true }); + } catch { + // ignore + } +} diff --git a/tests/benchmark/swe/evaluator.ts b/tests/benchmark/swe/evaluator.ts new file mode 100644 index 0000000..5767138 --- /dev/null +++ b/tests/benchmark/swe/evaluator.ts @@ -0,0 +1,64 @@ +// --------------------------------------------------------------------------- +// SWE benchmark evaluator — run tests in a temp directory +// --------------------------------------------------------------------------- + +import fs from 'fs'; +import path from 'path'; +import { execSync } from 'child_process'; + +const TEST_TIMEOUT_MS = 15_000; + +export interface EvalResult { + passed: boolean; + output: string; + error?: string; +} + +/** + * Write files to a temporary directory, run the test command, return pass/fail. 
+ */ +export function evaluateCase( + files: Record<string, string>, + testCommand: string, + workDir: string, +): EvalResult { + // Ensure work directory exists + fs.mkdirSync(workDir, { recursive: true }); + + // Write all files + for (const [name, content] of Object.entries(files)) { + const filePath = path.join(workDir, name); + fs.mkdirSync(path.dirname(filePath), { recursive: true }); + fs.writeFileSync(filePath, content, 'utf-8'); + } + + // Run the test command + try { + const output = execSync(testCommand, { + cwd: workDir, + timeout: TEST_TIMEOUT_MS, + encoding: 'utf-8', + stdio: ['pipe', 'pipe', 'pipe'], + }); + return { passed: true, output: output.trim() }; + } catch (err: any) { + const stdout = (err.stdout || '').toString().trim(); + const stderr = (err.stderr || '').toString().trim(); + return { + passed: false, + output: stdout, + error: stderr || err.message || String(err), + }; + } +} + +/** + * Clean up a work directory. + */ +export function cleanupWorkDir(workDir: string): void { + try { + fs.rmSync(workDir, { recursive: true, force: true }); + } catch { + // ignore cleanup errors + } +} diff --git a/tests/benchmark/swe/harness.ts b/tests/benchmark/swe/harness.ts new file mode 100644 index 0000000..feb9fc8 --- /dev/null +++ b/tests/benchmark/swe/harness.ts @@ -0,0 +1,136 @@ +// --------------------------------------------------------------------------- +// SWE benchmark harness — sends code + issue to model, parses corrected files +// --------------------------------------------------------------------------- + +import type { ModelProvider } from '../../../src/infra/providers/types'; +import type { Message } from '../../../src/core/types'; +import type { MiniCase } from './dataset'; + +export interface HarnessResult { + correctedFiles: Record<string, string>; + tokens: number; + error?: string; +} + +const SYSTEM_PROMPT = `You are a software engineer fixing bugs in source code. +You will be given a bug report and the project files.
+Your task is to fix the bug so all tests pass. + +Rules: +- Only modify source files. NEVER modify test files. +- Output the COMPLETE corrected file content using this exact format: + +--- FILE: --- + +--- END FILE --- + +- You may output multiple files if needed. +- Do NOT include any explanation outside the file markers. +- Output ONLY the corrected file(s), nothing else.`; + +/** + * Send a mini-SWE case to the model and parse corrected files from the response. + */ +export async function runHarness( + provider: ModelProvider, + caseData: MiniCase, +): Promise { + // Build user message with issue + all file contents + const fileListing = Object.entries(caseData.files) + .map(([name, content]) => `--- ${name} ---\n${content}`) + .join('\n'); + + const userMessage = [ + 'Bug report:', + caseData.description, + '', + 'Project files:', + fileListing, + '', + 'Fix the bug in the source file(s) so that all tests pass.', + 'Output the corrected file(s) using the --- FILE: --- / --- END FILE --- format.', + ].join('\n'); + + const messages: Message[] = [ + { role: 'user', content: [{ type: 'text', text: userMessage }] }, + ]; + + try { + const response = await provider.complete(messages, { system: SYSTEM_PROMPT }); + + const text = response.content + .filter((b): b is { type: 'text'; text: string } => b.type === 'text') + .map(b => b.text) + .join(''); + + const tokens = + (response.usage?.input_tokens ?? 0) + (response.usage?.output_tokens ?? 
0); + + const correctedFiles = parseFileBlocks(text); + + // If no file blocks found, try to infer from the response + if (Object.keys(correctedFiles).length === 0) { + // Fallback: look for code fences with the src filename + const fallback = parseFallbackCodeBlocks(text, caseData.files); + if (Object.keys(fallback).length > 0) { + return { correctedFiles: fallback, tokens }; + } + return { correctedFiles: {}, tokens, error: 'No corrected files found in model response' }; + } + + return { correctedFiles, tokens }; + } catch (err: any) { + return { correctedFiles: {}, tokens: 0, error: err.message || String(err) }; + } +} + +// --------------------------------------------------------------------------- +// Response parsing +// --------------------------------------------------------------------------- + +/** + * Parse `--- FILE: --- ... --- END FILE ---` blocks from model output. + */ +function parseFileBlocks(text: string): Record { + const files: Record = {}; + const regex = /---\s*FILE:\s*(.+?)\s*---\n([\s\S]*?)---\s*END FILE\s*---/g; + let match: RegExpExecArray | null; + + while ((match = regex.exec(text)) !== null) { + const filename = match[1].trim(); + const content = match[2]; + // Only accept source files, never test files + if (!filename.includes('test')) { + files[filename] = content; + } + } + + return files; +} + +/** + * Fallback: try to extract code from markdown fences like ```js ... ``` + * and match them to source filenames (excluding test files). 
+ */ +function parseFallbackCodeBlocks(text: string, originalFiles: Record): Record { + const files: Record = {}; + const sourceFiles = Object.keys(originalFiles).filter(f => !f.includes('test')); + + if (sourceFiles.length !== 1) return files; // Only works for single source file + + const srcName = sourceFiles[0]; + // Match the last code fence (most likely the final corrected version) + const fenceRegex = /```(?:js|javascript)?\n([\s\S]*?)```/g; + let lastMatch: string | null = null; + let match: RegExpExecArray | null; + + while ((match = fenceRegex.exec(text)) !== null) { + lastMatch = match[1]; + } + + if (lastMatch) { + files[srcName] = lastMatch; + } + + return files; +} diff --git a/tests/benchmark/swe/index.ts b/tests/benchmark/swe/index.ts new file mode 100644 index 0000000..d83a239 --- /dev/null +++ b/tests/benchmark/swe/index.ts @@ -0,0 +1,292 @@ +// --------------------------------------------------------------------------- +// SWE benchmark module — BenchmarkModule entry point +// --------------------------------------------------------------------------- + +import path from 'path'; +import type { BenchmarkConfig, BenchmarkModuleResult, BenchmarkProvider, SWEProviderResult, SWEResult } from '../types'; +import type { ModelProvider } from '../../../src/infra/providers/types'; +import { AnthropicProvider } from '../../../src/infra/providers/anthropic'; +import { OpenAIProvider } from '../../../src/infra/providers/openai'; +import { GeminiProvider } from '../../../src/infra/providers/gemini'; +import { loadMiniCases, loadCuratedInstances, MiniCase } from './dataset'; +import { runHarness } from './harness'; +import { evaluateCase, cleanupWorkDir } from './evaluator'; +import { + isDockerAvailable, + generateFix, + evaluateWithDocker, + evaluateLocally, + cleanupWorkDir as cleanupDockerWorkDir, + type FullSWEInstance, +} from './docker-evaluator'; + +// Module metadata (used by run-benchmark.ts discovery) +export const name = 'swe'; + +// 
--------------------------------------------------------------------------- +// Provider creation (same pattern as TAU) +// --------------------------------------------------------------------------- + +function createProvider(bp: BenchmarkProvider): ModelProvider { + switch (bp.id) { + case 'anthropic': + return new AnthropicProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl); + case 'openai': + return new OpenAIProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl); + case 'gemini': + return new GeminiProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl); + default: + return new OpenAIProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl); + } +} + +// --------------------------------------------------------------------------- +// Run single provider on all mini-SWE cases +// --------------------------------------------------------------------------- + +async function runProviderOnCases( + bp: BenchmarkProvider, + cases: MiniCase[], + config: BenchmarkConfig, +): Promise { + const provider = createProvider(bp); + const results: SWEResult[] = []; + + for (const c of cases) { + const startMs = Date.now(); + const workDir = path.join( + process.cwd(), + 'tests', + '.tmp', + `swe-${bp.id}-${c.id}-${Date.now()}`, + ); + + try { + // 1. Send to model + const harness = await runHarness(provider, c); + + if (harness.error || Object.keys(harness.correctedFiles).length === 0) { + const durationMs = Date.now() - startMs; + const errMsg = harness.error || 'No corrected files returned'; + console.log(` [${bp.id}] ${c.id}: FAIL (${errMsg})`); + results.push({ + instance_id: c.id, + resolved: false, + tokens_used: harness.tokens, + duration_ms: durationMs, + error: errMsg, + }); + continue; + } + + // 2. Merge corrected files with original files + const mergedFiles = { ...c.files }; + for (const [name, content] of Object.entries(harness.correctedFiles)) { + mergedFiles[name] = content; + } + + // 3. 
Evaluate (write files + run test) + const evalResult = evaluateCase(mergedFiles, c.test_command, workDir); + const durationMs = Date.now() - startMs; + + const status = evalResult.passed ? 'PASS' : 'FAIL'; + const detail = evalResult.passed ? '' : ` (${evalResult.error || evalResult.output})`; + console.log( + ` [${bp.id}] ${c.id}: ${status} (${harness.tokens} tokens, ${durationMs}ms)${detail}`, + ); + + results.push({ + instance_id: c.id, + resolved: evalResult.passed, + tokens_used: harness.tokens, + duration_ms: durationMs, + error: evalResult.passed ? undefined : evalResult.error, + }); + } catch (err: any) { + const durationMs = Date.now() - startMs; + console.log(` [${bp.id}] ${c.id}: FAIL (${err.message})`); + results.push({ + instance_id: c.id, + resolved: false, + tokens_used: 0, + duration_ms: durationMs, + error: err.message || String(err), + }); + } finally { + cleanupWorkDir(workDir); + } + } + + const resolved = results.filter(r => r.resolved).length; + const total = results.length; + const avgTokens = total > 0 ? Math.round(results.reduce((s, r) => s + r.tokens_used, 0) / total) : 0; + const avgDuration = total > 0 ? Math.round(results.reduce((s, r) => s + r.duration_ms, 0) / total) : 0; + + return { + provider: bp, + summary: { + dataset: 'mini-swe', + total, + resolved, + rate: total > 0 ? 
resolved / total : 0, + avg_tokens: avgTokens, + avg_duration_ms: avgDuration, + }, + results, + }; +} + +// --------------------------------------------------------------------------- +// Run single provider on full SWE-bench instances +// --------------------------------------------------------------------------- + +async function runProviderOnFullInstances( + bp: BenchmarkProvider, + instances: FullSWEInstance[], + useDocker: boolean, + dockerProxy?: string, +): Promise { + const provider = createProvider(bp); + const results: SWEResult[] = []; + + for (const inst of instances) { + const startMs = Date.now(); + const workDir = path.join( + process.cwd(), + 'tests', + '.tmp', + `swe-full-${bp.id}-${inst.instance_id}-${Date.now()}`, + ); + + try { + // 1. Generate fix from model (clone repo, read files, LLM, generate diff) + console.log(` [${bp.id}] ${inst.instance_id}: generating fix ...`); + const harness = await generateFix(provider, inst, dockerProxy); + + if (harness.error || !harness.patch) { + const durationMs = Date.now() - startMs; + const errMsg = harness.error || 'No fix generated'; + console.log(` [${bp.id}] ${inst.instance_id}: FAIL (${errMsg})`); + results.push({ + instance_id: inst.instance_id, + resolved: false, + tokens_used: harness.tokens, + duration_ms: durationMs, + error: errMsg, + }); + continue; + } + + // 2. Evaluate the fix patch + console.log(` [${bp.id}] ${inst.instance_id}: fix generated (${harness.tokens} tokens), evaluating ...`); + const evalResult = useDocker + ? evaluateWithDocker(inst, harness.patch, workDir, dockerProxy) + : evaluateLocally(inst, harness.patch, workDir); + + const durationMs = Date.now() - startMs; + const status = evalResult.passed ? 'PASS' : 'FAIL'; + const detail = evalResult.passed ? 
'' : ` (${(evalResult.error || '').slice(0, 100)})`; + console.log( + ` [${bp.id}] ${inst.instance_id}: ${status} (${harness.tokens} tokens, ${durationMs}ms)${detail}`, + ); + + results.push({ + instance_id: inst.instance_id, + resolved: evalResult.passed, + tokens_used: harness.tokens, + duration_ms: durationMs, + error: evalResult.passed ? undefined : evalResult.error, + }); + } catch (err: any) { + const durationMs = Date.now() - startMs; + console.log(` [${bp.id}] ${inst.instance_id}: FAIL (${err.message})`); + results.push({ + instance_id: inst.instance_id, + resolved: false, + tokens_used: 0, + duration_ms: durationMs, + error: err.message || String(err), + }); + } finally { + cleanupDockerWorkDir(workDir); + } + } + + const resolved = results.filter(r => r.resolved).length; + const total = results.length; + const avgTokens = total > 0 ? Math.round(results.reduce((s, r) => s + r.tokens_used, 0) / total) : 0; + const avgDuration = total > 0 ? Math.round(results.reduce((s, r) => s + r.duration_ms, 0) / total) : 0; + + return { + provider: bp, + summary: { + dataset: 'swe-bench-full', + total, + resolved, + rate: total > 0 ? 
resolved / total : 0, + avg_tokens: avgTokens, + avg_duration_ms: avgDuration, + }, + results, + }; +} + +// --------------------------------------------------------------------------- +// Module entry point +// --------------------------------------------------------------------------- + +export async function run(config: BenchmarkConfig): Promise { + if (config.sweMode === 'full') { + return runFullMode(config); + } + + const cases = loadMiniCases(); + if (cases.length === 0) { + console.log(' SWE: no mini-SWE cases found'); + return {}; + } + + if (config.providers.length === 0) { + console.log(' SWE: no providers configured, skipping'); + return {}; + } + + console.log(`\n SWE mini mode: ${cases.length} cases`); + + const allResults: SWEProviderResult[] = []; + + for (const bp of config.providers) { + console.log(`\n Running provider: ${bp.id} / ${bp.model}`); + const providerResult = await runProviderOnCases(bp, cases, config); + allResults.push(providerResult); + } + + return { swe: allResults }; +} + +async function runFullMode(config: BenchmarkConfig): Promise { + const instances = loadCuratedInstances(); + if (instances.length === 0) { + console.log(' SWE: no curated instances found for full mode'); + return {}; + } + + if (config.providers.length === 0) { + console.log(' SWE: no providers configured, skipping'); + return {}; + } + + const useDocker = isDockerAvailable(); + console.log(`\n SWE full mode: ${instances.length} curated instances`); + console.log(` Docker: ${useDocker ? 
'available (using Docker evaluation)' : 'not available (using local git-based evaluation)'}`); + + const allResults: SWEProviderResult[] = []; + + for (const bp of config.providers) { + console.log(`\n Running provider: ${bp.id} / ${bp.model}`); + const providerResult = await runProviderOnFullInstances(bp, instances, useDocker, config.dockerProxy); + allResults.push(providerResult); + } + + return { swe: allResults }; +} diff --git a/tests/benchmark/tau/domains/airline/database.ts b/tests/benchmark/tau/domains/airline/database.ts new file mode 100644 index 0000000..b719dd2 --- /dev/null +++ b/tests/benchmark/tau/domains/airline/database.ts @@ -0,0 +1,220 @@ +// --------------------------------------------------------------------------- +// Airline domain database types and initial data +// --------------------------------------------------------------------------- + +export interface User { + user_id: string; + name: string; + email: string; + phone: string; + membership: 'regular' | 'silver' | 'gold' | 'platinum'; +} + +export interface Flight { + flight_id: string; + airline: string; + route: string; + date: string; + departure_time: string; + arrival_time: string; + price: number; + seats_available: number; + aircraft: string; +} + +export interface Reservation { + reservation_id: string; + user_id: string; + flight_id: string; + status: 'confirmed' | 'cancelled' | 'pending'; + seat_class: 'economy' | 'business' | 'first'; + payment_amount: number; + booked_at: string; +} + +export interface AirlineDatabase { + users: User[]; + flights: Flight[]; + reservations: Reservation[]; +} + +export function getInitialDatabase(): AirlineDatabase { + return { + users: [ + { + user_id: 'USR001', + name: 'John Smith', + email: 'john.smith@email.com', + phone: '555-0101', + membership: 'gold', + }, + { + user_id: 'USR002', + name: 'Alice Johnson', + email: 'alice.j@email.com', + phone: '555-0102', + membership: 'regular', + }, + { + user_id: 'USR003', + name: 'Bob Chen', + 
email: 'bob.chen@email.com', + phone: '555-0103', + membership: 'silver', + }, + { + user_id: 'USR004', + name: 'Maria Garcia', + email: 'maria.g@email.com', + phone: '555-0104', + membership: 'platinum', + }, + { + user_id: 'USR005', + name: 'David Kim', + email: 'david.k@email.com', + phone: '555-0105', + membership: 'regular', + }, + ], + + flights: [ + { + flight_id: 'FL001', + airline: 'SkyAir', + route: 'SFO-LAX', + date: '2026-03-15', + departure_time: '08:00', + arrival_time: '09:30', + price: 150, + seats_available: 42, + aircraft: 'A320', + }, + { + flight_id: 'FL002', + airline: 'SkyAir', + route: 'SFO-LAX', + date: '2026-03-15', + departure_time: '14:00', + arrival_time: '15:30', + price: 180, + seats_available: 15, + aircraft: 'A320', + }, + { + flight_id: 'FL003', + airline: 'SkyAir', + route: 'SFO-LAX', + date: '2026-03-17', + departure_time: '08:00', + arrival_time: '09:30', + price: 160, + seats_available: 38, + aircraft: 'A320', + }, + { + flight_id: 'FL004', + airline: 'SkyAir', + route: 'SFO-LAX', + date: '2026-03-17', + departure_time: '18:00', + arrival_time: '19:30', + price: 200, + seats_available: 5, + aircraft: 'B737', + }, + { + flight_id: 'FL005', + airline: 'SkyAir', + route: 'LAX-JFK', + date: '2026-03-20', + departure_time: '10:00', + arrival_time: '18:30', + price: 350, + seats_available: 60, + aircraft: 'B777', + }, + { + flight_id: 'FL006', + airline: 'SkyAir', + route: 'JFK-SFO', + date: '2026-03-22', + departure_time: '07:00', + arrival_time: '10:30', + price: 380, + seats_available: 22, + aircraft: 'A350', + }, + { + flight_id: 'FL007', + airline: 'SkyAir', + route: 'SFO-SEA', + date: '2026-03-18', + departure_time: '12:00', + arrival_time: '14:00', + price: 120, + seats_available: 0, + aircraft: 'A320', + }, + { + flight_id: 'FL008', + airline: 'SkyAir', + route: 'SFO-SEA', + date: '2026-03-19', + departure_time: '12:00', + arrival_time: '14:00', + price: 130, + seats_available: 25, + aircraft: 'A320', + }, + ], + + 
reservations: [ + { + reservation_id: 'RES001', + user_id: 'USR001', + flight_id: 'FL001', + status: 'confirmed', + seat_class: 'economy', + payment_amount: 150, + booked_at: '2026-02-01', + }, + { + reservation_id: 'RES002', + user_id: 'USR002', + flight_id: 'FL005', + status: 'confirmed', + seat_class: 'economy', + payment_amount: 350, + booked_at: '2026-02-05', + }, + { + reservation_id: 'RES003', + user_id: 'USR003', + flight_id: 'FL002', + status: 'confirmed', + seat_class: 'business', + payment_amount: 360, + booked_at: '2026-02-10', + }, + { + reservation_id: 'RES004', + user_id: 'USR004', + flight_id: 'FL006', + status: 'confirmed', + seat_class: 'first', + payment_amount: 760, + booked_at: '2026-01-20', + }, + { + reservation_id: 'RES005', + user_id: 'USR005', + flight_id: 'FL007', + status: 'confirmed', + seat_class: 'economy', + payment_amount: 120, + booked_at: '2026-02-15', + }, + ], + }; +} diff --git a/tests/benchmark/tau/domains/airline/handlers.ts b/tests/benchmark/tau/domains/airline/handlers.ts new file mode 100644 index 0000000..d35fe25 --- /dev/null +++ b/tests/benchmark/tau/domains/airline/handlers.ts @@ -0,0 +1,74 @@ +// --------------------------------------------------------------------------- +// Airline domain tool handlers +// --------------------------------------------------------------------------- + +export type ToolHandler = (db: any, args: any) => any; + +export function getAirlineHandlers(): Record { + return { + get_user_details: (db, args: { user_id: string }) => { + const user = db.users.find((u: any) => u.user_id === args.user_id); + if (!user) return { error: `User not found: ${args.user_id}` }; + return user; + }, + + get_reservation_details: (db, args: { reservation_id: string }) => { + const res = db.reservations.find((r: any) => r.reservation_id === args.reservation_id); + if (!res) return { error: `Reservation not found: ${args.reservation_id}` }; + return res; + }, + + list_user_reservations: (db, args: { user_id: 
string }) => { + const list = db.reservations.filter((r: any) => r.user_id === args.user_id); + return { reservations: list }; + }, + + get_flight_details: (db, args: { flight_id: string }) => { + const flight = db.flights.find((f: any) => f.flight_id === args.flight_id); + if (!flight) return { error: `Flight not found: ${args.flight_id}` }; + return flight; + }, + + search_flights: (db, args: { route: string; date?: string }) => { + let results = db.flights.filter((f: any) => f.route === args.route); + if (args.date) { + results = results.filter((f: any) => f.date === args.date); + } + return { flights: results }; + }, + + update_reservation: (db, args: { reservation_id: string; new_flight_id: string }) => { + const res = db.reservations.find((r: any) => r.reservation_id === args.reservation_id); + if (!res) return { error: `Reservation not found: ${args.reservation_id}` }; + if (res.status === 'cancelled') return { error: 'Cannot update a cancelled reservation' }; + + const newFlight = db.flights.find((f: any) => f.flight_id === args.new_flight_id); + if (!newFlight) return { error: `Flight not found: ${args.new_flight_id}` }; + if (newFlight.seats_available <= 0) return { error: `No seats available on flight ${args.new_flight_id}` }; + + // Release seat on old flight + const oldFlight = db.flights.find((f: any) => f.flight_id === res.flight_id); + if (oldFlight) oldFlight.seats_available += 1; + + // Book seat on new flight + newFlight.seats_available -= 1; + res.flight_id = args.new_flight_id; + res.payment_amount = newFlight.price; + + return { success: true, reservation: { ...res } }; + }, + + cancel_reservation: (db, args: { reservation_id: string }) => { + const res = db.reservations.find((r: any) => r.reservation_id === args.reservation_id); + if (!res) return { error: `Reservation not found: ${args.reservation_id}` }; + if (res.status === 'cancelled') return { error: 'Reservation is already cancelled' }; + + // Release seat + const flight = 
db.flights.find((f: any) => f.flight_id === res.flight_id); + if (flight) flight.seats_available += 1; + + res.status = 'cancelled'; + return { success: true, reservation: { ...res } }; + }, + }; +} diff --git a/tests/benchmark/tau/domains/airline/policy.md b/tests/benchmark/tau/domains/airline/policy.md new file mode 100644 index 0000000..c125fa0 --- /dev/null +++ b/tests/benchmark/tau/domains/airline/policy.md @@ -0,0 +1,45 @@ +# Airline Customer Service Policy + +You are an airline customer service agent. Follow these policies strictly when handling customer requests. + +## Identity Verification + +- Always verify the customer's identity before making any changes. +- Ask for the user's name or user ID. Look up their information using the `get_user_details` tool. +- Confirm key details (name, reservation ID) before proceeding. + +## Flight Changes + +- Customers may request to change their flight to a different date or route. +- Use `search_flights` to find available alternatives. +- Gold and Platinum members: flight changes are free. +- Silver members: $50 change fee applies. +- Regular members: $75 change fee applies. +- Changes must be made at least 2 hours before departure. +- Use `update_reservation` to apply the change. +- Always confirm the new flight details with the customer before making the change. + +## Cancellations + +- Customers may cancel their reservation. +- Gold and Platinum members: full refund. +- Silver members: 80% refund. +- Regular members: 50% refund, or full refund if cancelled more than 72 hours before departure. +- Use `cancel_reservation` to process the cancellation. +- Inform the customer of the refund amount and timeline (5-7 business days). + +## Baggage Policy + +- Economy: 1 checked bag (23kg) included. +- Business: 2 checked bags (32kg each) included. +- First: 3 checked bags (32kg each) included. +- Additional bags: $35 each. +- Overweight bags (23-32kg): $50 surcharge. 
+ +## General Rules + +- Be polite, professional, and concise. +- If you cannot fulfill a request due to policy restrictions, explain clearly why. +- Do not make up information. Only provide details from the database. +- When the customer's issue is fully resolved, end with "###STOP###". +- If the customer says goodbye or has no more questions, end with "###STOP###". diff --git a/tests/benchmark/tau/domains/airline/tasks.json b/tests/benchmark/tau/domains/airline/tasks.json new file mode 100644 index 0000000..0e45443 --- /dev/null +++ b/tests/benchmark/tau/domains/airline/tasks.json @@ -0,0 +1,71 @@ +[ + { + "task_id": "airline_001", + "user_scenario": "You are John Smith (user ID: USR001). Your reservation ID is RES001. You are currently booked on a March 15 SFO to LAX flight. You want to change your flight to March 17 instead. You prefer the morning flight if available. If the agent asks for confirmation, agree to proceed.", + "expected_db": { + "reservations": [ + { + "reservation_id": "RES001", + "flight_id": "FL003", + "status": "confirmed" + } + ] + }, + "max_turns": 10 + }, + { + "task_id": "airline_002", + "user_scenario": "You are Alice Johnson (user ID: USR002). Your reservation ID is RES002. You need to cancel your LAX to JFK flight on March 20 because your plans changed. Accept the cancellation terms whatever they are.", + "expected_db": { + "reservations": [ + { + "reservation_id": "RES002", + "status": "cancelled" + } + ] + }, + "max_turns": 10 + }, + { + "task_id": "airline_003", + "user_scenario": "You are Bob Chen (user ID: USR003). You want to check the details of your upcoming flight - you think your reservation is RES003 but you're not sure of the exact departure time. 
Just ask for the information, you don't want to make any changes.", + "expected_db": { + "reservations": [ + { + "reservation_id": "RES003", + "status": "confirmed", + "flight_id": "FL002" + } + ] + }, + "max_turns": 8 + }, + { + "task_id": "airline_004", + "user_scenario": "You are Maria Garcia (user ID: USR004). Your reservation is RES004 for a JFK to SFO flight. You want to know what the baggage allowance is for your ticket class, and also confirm your flight details. You don't want to make any changes.", + "expected_db": { + "reservations": [ + { + "reservation_id": "RES004", + "status": "confirmed", + "flight_id": "FL006" + } + ] + }, + "max_turns": 8 + }, + { + "task_id": "airline_005", + "user_scenario": "You are David Kim (user ID: USR005). Your reservation is RES005 for a SFO to SEA flight on March 18. You just found out that flight has no seats available and you're worried. You want the agent to help you rebook to another SFO-SEA flight. Accept any available option.", + "expected_db": { + "reservations": [ + { + "reservation_id": "RES005", + "flight_id": "FL008", + "status": "confirmed" + } + ] + }, + "max_turns": 12 + } +] diff --git a/tests/benchmark/tau/domains/airline/tools.ts b/tests/benchmark/tau/domains/airline/tools.ts new file mode 100644 index 0000000..d0bed04 --- /dev/null +++ b/tests/benchmark/tau/domains/airline/tools.ts @@ -0,0 +1,127 @@ +// --------------------------------------------------------------------------- +// Airline domain tool definitions (Anthropic API format) +// --------------------------------------------------------------------------- + +export interface ToolDef { + name: string; + description: string; + input_schema: Record; +} + +export function getAirlineToolDefs(): ToolDef[] { + return [ + { + name: 'get_user_details', + description: + 'Look up a user by their user ID. 
Returns user profile including name, email, phone, and membership tier.', + input_schema: { + type: 'object', + properties: { + user_id: { + type: 'string', + description: 'The user ID to look up (e.g. "USR001")', + }, + }, + required: ['user_id'], + }, + }, + { + name: 'get_reservation_details', + description: + 'Look up a reservation by reservation ID. Returns booking details including flight, status, and payment.', + input_schema: { + type: 'object', + properties: { + reservation_id: { + type: 'string', + description: 'The reservation ID (e.g. "RES001")', + }, + }, + required: ['reservation_id'], + }, + }, + { + name: 'list_user_reservations', + description: + 'List all reservations for a given user. Returns an array of reservation records.', + input_schema: { + type: 'object', + properties: { + user_id: { + type: 'string', + description: 'The user ID whose reservations to list', + }, + }, + required: ['user_id'], + }, + }, + { + name: 'get_flight_details', + description: + 'Get details for a specific flight by flight ID. Returns route, schedule, price, and availability.', + input_schema: { + type: 'object', + properties: { + flight_id: { + type: 'string', + description: 'The flight ID (e.g. "FL001")', + }, + }, + required: ['flight_id'], + }, + }, + { + name: 'search_flights', + description: + 'Search for available flights by route and optional date. Returns matching flights with availability.', + input_schema: { + type: 'object', + properties: { + route: { + type: 'string', + description: 'Flight route in "ORIGIN-DEST" format (e.g. "SFO-LAX")', + }, + date: { + type: 'string', + description: 'Date in YYYY-MM-DD format. If omitted, returns all dates.', + }, + }, + required: ['route'], + }, + }, + { + name: 'update_reservation', + description: + 'Update a reservation to change the flight. 
The new flight must have available seats.', + input_schema: { + type: 'object', + properties: { + reservation_id: { + type: 'string', + description: 'The reservation ID to update', + }, + new_flight_id: { + type: 'string', + description: 'The new flight ID to switch to', + }, + }, + required: ['reservation_id', 'new_flight_id'], + }, + }, + { + name: 'cancel_reservation', + description: + 'Cancel a reservation. Sets the reservation status to "cancelled". Cannot be undone.', + input_schema: { + type: 'object', + properties: { + reservation_id: { + type: 'string', + description: 'The reservation ID to cancel', + }, + }, + required: ['reservation_id'], + }, + }, + ]; +} diff --git a/tests/benchmark/tau/domains/retail/database.ts b/tests/benchmark/tau/domains/retail/database.ts new file mode 100644 index 0000000..e37dc3e --- /dev/null +++ b/tests/benchmark/tau/domains/retail/database.ts @@ -0,0 +1,156 @@ +// --------------------------------------------------------------------------- +// Retail domain database types and initial data +// --------------------------------------------------------------------------- + +export interface Customer { + customer_id: string; + name: string; + email: string; + phone: string; + membership: 'regular' | 'vip' | 'premium'; +} + +export interface Product { + product_id: string; + name: string; + category: string; + price: number; + stock: number; +} + +export interface OrderItem { + product_id: string; + product_name: string; + quantity: number; + unit_price: number; +} + +export interface Order { + order_id: string; + customer_id: string; + items: OrderItem[]; + total: number; + status: 'pending' | 'shipped' | 'delivered' | 'cancelled' | 'returned'; + order_date: string; + delivery_date?: string; +} + +export interface RetailDatabase { + customers: Customer[]; + products: Product[]; + orders: Order[]; +} + +export function getInitialDatabase(): RetailDatabase { + return { + customers: [ + { + customer_id: 'CUST001', + name: 'Emma 
Wilson', + email: 'emma.w@email.com', + phone: '555-1001', + membership: 'vip', + }, + { + customer_id: 'CUST002', + name: 'James Brown', + email: 'james.b@email.com', + phone: '555-1002', + membership: 'regular', + }, + { + customer_id: 'CUST003', + name: 'Sophia Lee', + email: 'sophia.l@email.com', + phone: '555-1003', + membership: 'premium', + }, + { + customer_id: 'CUST004', + name: 'Liam Martinez', + email: 'liam.m@email.com', + phone: '555-1004', + membership: 'regular', + }, + { + customer_id: 'CUST005', + name: 'Olivia Davis', + email: 'olivia.d@email.com', + phone: '555-1005', + membership: 'vip', + }, + ], + + products: [ + { product_id: 'PROD001', name: 'Wireless Headphones', category: 'Electronics', price: 79.99, stock: 150 }, + { product_id: 'PROD002', name: 'Bluetooth Speaker', category: 'Electronics', price: 49.99, stock: 80 }, + { product_id: 'PROD003', name: 'Running Shoes (Size 10)', category: 'Footwear', price: 129.99, stock: 30 }, + { product_id: 'PROD004', name: 'Running Shoes (Size 11)', category: 'Footwear', price: 129.99, stock: 0 }, + { product_id: 'PROD005', name: 'Cotton T-Shirt (M)', category: 'Apparel', price: 24.99, stock: 200 }, + { product_id: 'PROD006', name: 'Cotton T-Shirt (L)', category: 'Apparel', price: 24.99, stock: 180 }, + { product_id: 'PROD007', name: 'Yoga Mat', category: 'Fitness', price: 34.99, stock: 60 }, + { product_id: 'PROD008', name: 'Water Bottle (32oz)', category: 'Fitness', price: 19.99, stock: 100 }, + { product_id: 'PROD009', name: 'Laptop Stand', category: 'Electronics', price: 59.99, stock: 45 }, + { product_id: 'PROD010', name: 'USB-C Hub', category: 'Electronics', price: 39.99, stock: 70 }, + ], + + orders: [ + { + order_id: 'ORD001', + customer_id: 'CUST001', + items: [ + { product_id: 'PROD001', product_name: 'Wireless Headphones', quantity: 1, unit_price: 79.99 }, + { product_id: 'PROD008', product_name: 'Water Bottle (32oz)', quantity: 2, unit_price: 19.99 }, + ], + total: 119.97, + status: 
'delivered', + order_date: '2026-01-15', + delivery_date: '2026-01-22', + }, + { + order_id: 'ORD002', + customer_id: 'CUST002', + items: [ + { product_id: 'PROD003', product_name: 'Running Shoes (Size 10)', quantity: 1, unit_price: 129.99 }, + ], + total: 129.99, + status: 'delivered', + order_date: '2026-01-20', + delivery_date: '2026-01-27', + }, + { + order_id: 'ORD003', + customer_id: 'CUST003', + items: [ + { product_id: 'PROD005', product_name: 'Cotton T-Shirt (M)', quantity: 3, unit_price: 24.99 }, + ], + total: 74.97, + status: 'delivered', + order_date: '2025-12-10', + delivery_date: '2025-12-17', + }, + { + order_id: 'ORD004', + customer_id: 'CUST004', + items: [ + { product_id: 'PROD009', product_name: 'Laptop Stand', quantity: 1, unit_price: 59.99 }, + { product_id: 'PROD010', product_name: 'USB-C Hub', quantity: 1, unit_price: 39.99 }, + ], + total: 99.98, + status: 'shipped', + order_date: '2026-02-05', + }, + { + order_id: 'ORD005', + customer_id: 'CUST005', + items: [ + { product_id: 'PROD007', product_name: 'Yoga Mat', quantity: 1, unit_price: 34.99 }, + ], + total: 34.99, + status: 'delivered', + order_date: '2026-02-01', + delivery_date: '2026-02-07', + }, + ], + }; +} diff --git a/tests/benchmark/tau/domains/retail/handlers.ts b/tests/benchmark/tau/domains/retail/handlers.ts new file mode 100644 index 0000000..04b7db3 --- /dev/null +++ b/tests/benchmark/tau/domains/retail/handlers.ts @@ -0,0 +1,147 @@ +// --------------------------------------------------------------------------- +// Retail domain tool handlers +// --------------------------------------------------------------------------- + +export type ToolHandler = (db: any, args: any) => any; + +export function getRetailHandlers(): Record { + return { + get_customer_details: (db, args: { customer_id: string }) => { + const customer = db.customers.find((c: any) => c.customer_id === args.customer_id); + if (!customer) return { error: `Customer not found: ${args.customer_id}` }; + return 
customer; + }, + + get_order_details: (db, args: { order_id: string }) => { + const order = db.orders.find((o: any) => o.order_id === args.order_id); + if (!order) return { error: `Order not found: ${args.order_id}` }; + return order; + }, + + list_customer_orders: (db, args: { customer_id: string }) => { + const orders = db.orders.filter((o: any) => o.customer_id === args.customer_id); + return { orders }; + }, + + get_product_details: (db, args: { product_id: string }) => { + const product = db.products.find((p: any) => p.product_id === args.product_id); + if (!product) return { error: `Product not found: ${args.product_id}` }; + return product; + }, + + search_products: (db, args: { query: string }) => { + const q = args.query.toLowerCase(); + const results = db.products.filter( + (p: any) => p.name.toLowerCase().includes(q) || p.category.toLowerCase().includes(q), + ); + return { products: results }; + }, + + process_return: (db, args: { order_id: string }) => { + const order = db.orders.find((o: any) => o.order_id === args.order_id); + if (!order) return { error: `Order not found: ${args.order_id}` }; + if (order.status !== 'delivered') { + return { error: `Order ${args.order_id} is not in delivered status (current: ${order.status})` }; + } + + // Check 30-day return window + if (order.delivery_date) { + const deliveryDate = new Date(order.delivery_date); + const now = new Date(); + const daysSinceDelivery = Math.floor( + (now.getTime() - deliveryDate.getTime()) / (1000 * 60 * 60 * 24), + ); + if (daysSinceDelivery > 30) { + return { + error: `Return window expired. 
Order was delivered ${daysSinceDelivery} days ago (30-day limit).`, + }; + } + } + + // Restock items + for (const item of order.items) { + const product = db.products.find((p: any) => p.product_id === item.product_id); + if (product) product.stock += item.quantity; + } + + order.status = 'returned'; + return { success: true, order: { ...order } }; + }, + + process_exchange: ( + db, + args: { order_id: string; old_product_id: string; new_product_id: string }, + ) => { + const order = db.orders.find((o: any) => o.order_id === args.order_id); + if (!order) return { error: `Order not found: ${args.order_id}` }; + if (order.status !== 'delivered') { + return { error: `Order ${args.order_id} is not in delivered status` }; + } + + // Check 30-day exchange window + if (order.delivery_date) { + const deliveryDate = new Date(order.delivery_date); + const now = new Date(); + const daysSinceDelivery = Math.floor( + (now.getTime() - deliveryDate.getTime()) / (1000 * 60 * 60 * 24), + ); + if (daysSinceDelivery > 30) { + return { + error: `Exchange window expired. 
Order was delivered ${daysSinceDelivery} days ago (30-day limit).`, + }; + } + } + + // Find old item in order + const oldItemIndex = order.items.findIndex( + (i: any) => i.product_id === args.old_product_id, + ); + if (oldItemIndex === -1) { + return { error: `Product ${args.old_product_id} not found in order ${args.order_id}` }; + } + + // Find new product + const newProduct = db.products.find((p: any) => p.product_id === args.new_product_id); + if (!newProduct) return { error: `Product not found: ${args.new_product_id}` }; + if (newProduct.stock <= 0) { + return { error: `Product ${args.new_product_id} (${newProduct.name}) is out of stock` }; + } + + const oldItem = order.items[oldItemIndex]; + + // Restock old product + const oldProduct = db.products.find((p: any) => p.product_id === args.old_product_id); + if (oldProduct) oldProduct.stock += oldItem.quantity; + + // Deduct new product stock + newProduct.stock -= oldItem.quantity; + + // Update order item + order.items[oldItemIndex] = { + product_id: newProduct.product_id, + product_name: newProduct.name, + quantity: oldItem.quantity, + unit_price: newProduct.price, + }; + + // Recalculate total + order.total = order.items.reduce( + (sum: number, i: any) => sum + i.unit_price * i.quantity, + 0, + ); + + return { + success: true, + order: { ...order }, + price_difference: newProduct.price - oldItem.unit_price, + }; + }, + + update_order_status: (db, args: { order_id: string; new_status: string }) => { + const order = db.orders.find((o: any) => o.order_id === args.order_id); + if (!order) return { error: `Order not found: ${args.order_id}` }; + order.status = args.new_status; + return { success: true, order: { ...order } }; + }, + }; +} diff --git a/tests/benchmark/tau/domains/retail/policy.md b/tests/benchmark/tau/domains/retail/policy.md new file mode 100644 index 0000000..93749cd --- /dev/null +++ b/tests/benchmark/tau/domains/retail/policy.md @@ -0,0 +1,53 @@ +# Retail Customer Service Policy + +You are an 
online retail customer service agent. Follow these policies strictly. + +## Identity Verification + +- Verify the customer's identity before making changes to their orders. +- Ask for the customer's name or customer ID. +- Use `get_customer_details` to look up their information and confirm. + +## Order Status + +- Use `get_order_details` to check order status. +- Provide the customer with their order status, items, and tracking info if available. +- Order statuses: pending, shipped, delivered, cancelled, returned. + +## Returns + +- Items may be returned within 30 days of delivery. +- Items must be in unused, original condition. +- Use `process_return` to initiate the return. +- Refunds are processed to the original payment method within 5-10 business days. +- If outside the 30-day window, politely deny the return and explain the policy. + +## Exchanges + +- Exchanges are allowed within 30 days of delivery. +- The replacement item must be in stock. +- Use `search_products` to find alternatives, then `process_exchange` to complete. +- If the new item costs more, the customer pays the difference. +- If the new item costs less, the difference is refunded. + +## Membership Discounts + +- VIP members: 10% discount on all purchases. +- Premium members: 15% discount on all purchases. +- Regular members: no discount. +- Discounts cannot be applied retroactively to past orders. +- Use the customer's membership tier from their profile. + +## Shipping + +- Standard shipping: 5-7 business days, free for orders over $50. +- Express shipping: 2-3 business days, $9.99. +- Overnight shipping: next business day, $19.99. + +## General Rules + +- Be polite, helpful, and professional. +- Do not make up information. Only provide details from the database. +- If you cannot fulfill a request, explain clearly why. +- When the customer's issue is fully resolved, end with "###STOP###". +- If the customer says goodbye or has no more questions, end with "###STOP###". 
diff --git a/tests/benchmark/tau/domains/retail/tasks.json b/tests/benchmark/tau/domains/retail/tasks.json new file mode 100644 index 0000000..3be303d --- /dev/null +++ b/tests/benchmark/tau/domains/retail/tasks.json @@ -0,0 +1,67 @@ +[ + { + "task_id": "retail_001", + "user_scenario": "You are Emma Wilson (customer ID: CUST001). You received your order ORD001 but the Wireless Headphones are defective - they won't pair with your phone. You want to return the headphones. You're happy to keep the water bottles. If asked, confirm it was delivered on January 22.", + "expected_db": { + "orders": [ + { + "order_id": "ORD001", + "status": "returned" + } + ] + }, + "max_turns": 10 + }, + { + "task_id": "retail_002", + "user_scenario": "You are James Brown (customer ID: CUST002). You received Running Shoes (Size 10) in order ORD002 but they're too small. You want to exchange them for Size 11 instead. If Size 11 is out of stock, ask when they'll be available.", + "expected_db": { + "orders": [ + { + "order_id": "ORD002", + "status": "delivered" + } + ] + }, + "max_turns": 10 + }, + { + "task_id": "retail_003", + "user_scenario": "You are Sophia Lee (customer ID: CUST003). You want to return the Cotton T-Shirts from order ORD003 because they don't fit. Your order was delivered on December 17, 2025. If the agent says it's past the return window, accept that.", + "expected_db": { + "orders": [ + { + "order_id": "ORD003", + "status": "delivered" + } + ] + }, + "max_turns": 8 + }, + { + "task_id": "retail_004", + "user_scenario": "You are Liam Martinez (customer ID: CUST004). You placed order ORD004 and want to know when it will arrive. You also want to know what items are in the order since you forgot what you bought.", + "expected_db": { + "orders": [ + { + "order_id": "ORD004", + "status": "shipped" + } + ] + }, + "max_turns": 8 + }, + { + "task_id": "retail_005", + "user_scenario": "You are Olivia Davis (customer ID: CUST005). 
You received your Yoga Mat from order ORD005 but want to exchange it for a Bluetooth Speaker instead since you already have a yoga mat at home. Accept whatever the price difference is.", + "expected_db": { + "orders": [ + { + "order_id": "ORD005", + "status": "delivered" + } + ] + }, + "max_turns": 10 + } +] diff --git a/tests/benchmark/tau/domains/retail/tools.ts b/tests/benchmark/tau/domains/retail/tools.ts new file mode 100644 index 0000000..6749502 --- /dev/null +++ b/tests/benchmark/tau/domains/retail/tools.ts @@ -0,0 +1,147 @@ +// --------------------------------------------------------------------------- +// Retail domain tool definitions (Anthropic API format) +// --------------------------------------------------------------------------- + +export interface ToolDef { + name: string; + description: string; + input_schema: Record; +} + +export function getRetailToolDefs(): ToolDef[] { + return [ + { + name: 'get_customer_details', + description: + 'Look up customer information by customer ID. Returns name, email, phone, and membership tier.', + input_schema: { + type: 'object', + properties: { + customer_id: { + type: 'string', + description: 'The customer ID (e.g. "CUST001")', + }, + }, + required: ['customer_id'], + }, + }, + { + name: 'get_order_details', + description: + 'Look up an order by order ID. Returns items, total, status, and dates.', + input_schema: { + type: 'object', + properties: { + order_id: { + type: 'string', + description: 'The order ID (e.g. 
"ORD001")', + }, + }, + required: ['order_id'], + }, + }, + { + name: 'list_customer_orders', + description: + 'List all orders for a given customer.', + input_schema: { + type: 'object', + properties: { + customer_id: { + type: 'string', + description: 'The customer ID whose orders to list', + }, + }, + required: ['customer_id'], + }, + }, + { + name: 'get_product_details', + description: + 'Get details for a specific product including price, category, and stock.', + input_schema: { + type: 'object', + properties: { + product_id: { + type: 'string', + description: 'The product ID (e.g. "PROD001")', + }, + }, + required: ['product_id'], + }, + }, + { + name: 'search_products', + description: + 'Search for products by name or category. Returns matching products with availability.', + input_schema: { + type: 'object', + properties: { + query: { + type: 'string', + description: 'Search term to match against product name or category', + }, + }, + required: ['query'], + }, + }, + { + name: 'process_return', + description: + 'Process a return for an order. Sets the order status to "returned" and restocks items. Only valid for delivered orders within the return window.', + input_schema: { + type: 'object', + properties: { + order_id: { + type: 'string', + description: 'The order ID to return', + }, + }, + required: ['order_id'], + }, + }, + { + name: 'process_exchange', + description: + 'Exchange an item in an order for a different product. 
The original item is restocked and the new item is deducted from stock.', + input_schema: { + type: 'object', + properties: { + order_id: { + type: 'string', + description: 'The order ID containing the item to exchange', + }, + old_product_id: { + type: 'string', + description: 'The product ID being returned', + }, + new_product_id: { + type: 'string', + description: 'The product ID to exchange for', + }, + }, + required: ['order_id', 'old_product_id', 'new_product_id'], + }, + }, + { + name: 'update_order_status', + description: + 'Update the status of an order (e.g. to "cancelled").', + input_schema: { + type: 'object', + properties: { + order_id: { + type: 'string', + description: 'The order ID to update', + }, + new_status: { + type: 'string', + enum: ['pending', 'shipped', 'delivered', 'cancelled', 'returned'], + description: 'The new status', + }, + }, + required: ['order_id', 'new_status'], + }, + }, + ]; +} diff --git a/tests/benchmark/tau/environment.ts b/tests/benchmark/tau/environment.ts new file mode 100644 index 0000000..403b7f8 --- /dev/null +++ b/tests/benchmark/tau/environment.ts @@ -0,0 +1,44 @@ +// --------------------------------------------------------------------------- +// TAU benchmark environment — manages DB state and dispatches tool calls +// --------------------------------------------------------------------------- + +export type ToolHandler = (db: any, args: any) => any; + +export class Environment { + private db: any; + private toolCallLog: Array<{ name: string; args: any; result: any }> = []; + private handlers: Record; + + constructor(initialDb: any, handlers: Record) { + // Deep clone so each trial gets an isolated copy + this.db = JSON.parse(JSON.stringify(initialDb)); + this.handlers = handlers; + } + + /** Return current database state (deep clone) as a generic record. */ + getState(): Record { + return JSON.parse(JSON.stringify(this.db)); + } + + /** Return log of all tool calls made during this simulation. 
*/ + getToolCallLog() { + return this.toolCallLog; + } + + /** Dispatch a tool call by name. Returns the tool result as a JSON-serialisable value. */ + executeTool(name: string, args: any): any { + let result: any; + try { + const handler = this.handlers[name]; + if (!handler) { + result = { error: `Unknown tool: ${name}` }; + } else { + result = handler(this.db, args); + } + } catch (err: any) { + result = { error: err.message || String(err) }; + } + this.toolCallLog.push({ name, args, result }); + return result; + } +} diff --git a/tests/benchmark/tau/evaluator.ts b/tests/benchmark/tau/evaluator.ts new file mode 100644 index 0000000..f91b992 --- /dev/null +++ b/tests/benchmark/tau/evaluator.ts @@ -0,0 +1,69 @@ +// --------------------------------------------------------------------------- +// TAU benchmark evaluator — DB state comparison + pass^k calculation +// --------------------------------------------------------------------------- + +/** + * Compare the final database state against expected changes. + * + * `expectedDb` is a partial DB snapshot: for each table, an array of objects + * specifying the fields that must match. Each object must contain the table's + * primary key field (e.g. `reservation_id`) so we can look up the record. + * + * Returns `true` if all specified fields in all expected records match. 
+ */ +export function evaluateDBState( + finalDb: Record, + expectedDb: Record, +): boolean { + for (const [table, expectedRecords] of Object.entries(expectedDb)) { + const actualRecords: any[] = finalDb[table]; + if (!actualRecords) return false; + + for (const expected of expectedRecords) { + // Find primary key field (first field ending with _id) + const pkField = Object.keys(expected).find(k => k.endsWith('_id')); + if (!pkField) continue; + + const actual = actualRecords.find(r => r[pkField] === expected[pkField]); + if (!actual) return false; + + // Check all specified fields + for (const [key, value] of Object.entries(expected)) { + if (actual[key] !== value) return false; + } + } + } + + return true; +} + +/** + * Compute pass^k metrics from trial results. + * + * For each task, we have an array of boolean results (one per trial). + * pass^k = fraction of tasks where ALL of the first k trials passed. + * + * Returns an array [pass^1, pass^2, ..., pass^numTrials]. + */ +export function computePassK( + taskTrialResults: boolean[][], + numTrials: number, +): number[] { + if (taskTrialResults.length === 0) return []; + + const passAtK: number[] = []; + + for (let k = 1; k <= numTrials; k++) { + let passCount = 0; + for (const trials of taskTrialResults) { + // Check if all of the first k trials passed + const firstK = trials.slice(0, k); + if (firstK.length >= k && firstK.every(r => r)) { + passCount++; + } + } + passAtK.push(passCount / taskTrialResults.length); + } + + return passAtK; +} diff --git a/tests/benchmark/tau/index.ts b/tests/benchmark/tau/index.ts new file mode 100644 index 0000000..fe90c9c --- /dev/null +++ b/tests/benchmark/tau/index.ts @@ -0,0 +1,252 @@ +// --------------------------------------------------------------------------- +// TAU benchmark module — BenchmarkModule entry point +// --------------------------------------------------------------------------- + +import fs from 'fs'; +import path from 'path'; +import type { BenchmarkConfig, 
BenchmarkModuleResult, BenchmarkProvider, TAUProviderResult, TAUTaskResult } from '../types';
import type { ModelProvider } from '../../../src/infra/providers/types';
import { AnthropicProvider } from '../../../src/infra/providers/anthropic';
import { OpenAIProvider } from '../../../src/infra/providers/openai';
import { GeminiProvider } from '../../../src/infra/providers/gemini';
import { getInitialDatabase as getAirlineDb } from './domains/airline/database';
import { getAirlineToolDefs } from './domains/airline/tools';
import { getAirlineHandlers } from './domains/airline/handlers';
import { getInitialDatabase as getRetailDb } from './domains/retail/database';
import { getRetailToolDefs } from './domains/retail/tools';
import { getRetailHandlers } from './domains/retail/handlers';
import type { ToolHandler } from './environment';
import { Environment } from './environment';
import { UserSimulator } from './user-simulator';
import { runOrchestration } from './orchestrator';
import { evaluateDBState, computePassK } from './evaluator';

// Module metadata (used by run-benchmark.ts discovery)
export const name = 'tau';

// ---------------------------------------------------------------------------
// Domain loading
// ---------------------------------------------------------------------------

// Everything needed to run one τ-bench domain: agent policy text, tool
// schemas, fixture-database/handler factories, and the task list.
interface DomainData {
  id: string;
  policy: string;
  toolDefs: any[];
  getInitialDatabase: () => any;
  getHandlers: () => Record<string, ToolHandler>;
  tasks: Array<{
    task_id: string;
    user_scenario: string;
    expected_db: Record<string, any[]>;
    max_turns: number;
  }>;
}

// Load one domain's on-disk assets (policy.md + tasks.json) and wire them to
// the statically imported tool/handler/database factories.
// Returns null when the domain directory is missing or the id is unknown.
function loadDomain(domainId: string): DomainData | null {
  const domainDir = path.join(__dirname, 'domains', domainId);
  const policyPath = path.join(domainDir, 'policy.md');
  const tasksPath = path.join(domainDir, 'tasks.json');

  if (!fs.existsSync(policyPath) || !fs.existsSync(tasksPath)) return null;

  const policy = fs.readFileSync(policyPath, 'utf-8');
  // NOTE(review): tasks.json is parsed without schema validation — assumed to
  // match DomainData['tasks']; verify when editing the fixture files.
  const tasks = JSON.parse(fs.readFileSync(tasksPath, 'utf-8'));

  switch (domainId) {
    case 'airline':
      return {
        id: domainId,
        policy,
        toolDefs: getAirlineToolDefs(),
        getInitialDatabase: getAirlineDb,
        getHandlers: getAirlineHandlers,
        tasks,
      };
    case 'retail':
      return {
        id: domainId,
        policy,
        toolDefs: getRetailToolDefs(),
        getInitialDatabase: getRetailDb,
        getHandlers: getRetailHandlers,
        tasks,
      };
    default:
      return null;
  }
}

// Resolve the --tau-domain selection ('all' or a single id) to the loadable
// domains; silently drops ids whose assets are missing.
function getAvailableDomains(tauDomain: string): DomainData[] {
  const domains: DomainData[] = [];
  const candidates = tauDomain === 'all' ? ['airline', 'retail'] : [tauDomain];

  for (const id of candidates) {
    const domain = loadDomain(id);
    if (domain) domains.push(domain);
  }

  return domains;
}

// ---------------------------------------------------------------------------
// Provider creation
// ---------------------------------------------------------------------------

// Map a benchmark provider config to a concrete ModelProvider instance.
// Unknown ids fall back to the OpenAI-compatible client.
function createProvider(bp: BenchmarkProvider): ModelProvider {
  switch (bp.id) {
    case 'anthropic':
      return new AnthropicProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl);
    case 'openai':
      return new OpenAIProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl);
    case 'gemini':
      return new GeminiProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl);
    default:
      // For glm, minimax, etc. — try OpenAI-compatible
      return new OpenAIProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl);
  }
}

// ---------------------------------------------------------------------------
// Build system prompt
// ---------------------------------------------------------------------------

const DOMAIN_ROLES: Record<string, string> = {
  airline: 'an airline customer service agent',
  retail: 'an online retail customer service agent',
};

// Compose the agent system prompt: role line, the full policy document, a
// one-line-per-tool summary, and the ###STOP### termination convention.
function buildSystemPrompt(domainId: string, policy: string, toolDefs: any[]): string {
  const role = DOMAIN_ROLES[domainId] || 'a customer service agent';
  const toolList = toolDefs.map(t => `- ${t.name}: ${t.description}`).join('\n');
  return [
    `You are ${role}. Follow the policy below strictly.`,
    '',
    '--- POLICY ---',
    policy,
    '--- END POLICY ---',
    '',
    'Available tools:',
    toolList,
    '',
    'Instructions:',
    '- Use tools to look up and modify data. Do not guess or make up information.',
    '- When the customer\'s issue is fully resolved, include "###STOP###" at the end of your final message.',
    '- Be concise and professional.',
  ].join('\n');
}

// ---------------------------------------------------------------------------
// Run single provider across a domain
// ---------------------------------------------------------------------------

// Run every task of one domain, config.numTrials times each, against one
// agent provider (with a possibly different user-simulator provider).
// Each trial gets a fresh Environment; results are aggregated into pass^k.
async function runProviderOnDomain(
  bp: BenchmarkProvider,
  userSimBp: BenchmarkProvider,
  domain: DomainData,
  config: BenchmarkConfig,
): Promise<TAUProviderResult> {
  const agentProvider = createProvider(bp);
  const userSimProvider = createProvider(userSimBp);
  const systemPrompt = buildSystemPrompt(domain.id, domain.policy, domain.toolDefs);
  const results: TAUTaskResult[] = [];

  // Collect pass/fail per task across trials for pass^k calculation
  const taskTrialMatrix: boolean[][] = [];

  for (const task of domain.tasks) {
    const trialResults: boolean[] = [];
    let totalTokens = 0;
    let lastError: string | undefined;

    for (let trial = 0; trial < config.numTrials; trial++) {
      // Fresh environment for each trial
      const env = new Environment(domain.getInitialDatabase(), domain.getHandlers());
      const userSim = new UserSimulator(userSimProvider, task.user_scenario);

      const orchResult = await runOrchestration({
        agentProvider,
        userSimulator: userSim,
        environment: env,
        systemPrompt,
        toolDefs: domain.toolDefs,
        maxTurns: task.max_turns,
        timeoutMs: config.timeoutMs,
        expectedDb: task.expected_db,
        evaluate: evaluateDBState,
      });

      trialResults.push(orchResult.passed);
      totalTokens += orchResult.agentTokens;
      if (orchResult.error) lastError = orchResult.error;

      // Log progress
      const status = orchResult.passed ? 'PASS' : 'FAIL';
      const errorSuffix = orchResult.error ? ` (${orchResult.error})` : '';
      console.log(
        `    [${bp.id}] ${task.task_id} trial ${trial + 1}/${config.numTrials}: ${status} (${orchResult.turns} turns, ${orchResult.agentTokens} tokens)${errorSuffix}`,
      );
    }

    taskTrialMatrix.push(trialResults);
    results.push({
      task_id: task.task_id,
      trial_pass_rates: trialResults,
      // Mean agent tokens per trial for this task.
      tokens_used: Math.round(totalTokens / config.numTrials),
      // Surface the last error only when every trial failed.
      error: trialResults.every(r => !r) ? lastError : undefined,
    });
  }

  // Compute pass^k
  const passAtK = computePassK(taskTrialMatrix, config.numTrials);
  const avgTokens =
    results.length > 0 ? Math.round(results.reduce((s, r) => s + r.tokens_used, 0) / results.length) : 0;

  return {
    provider: bp,
    summary: {
      domain: domain.id,
      total_tasks: domain.tasks.length,
      num_trials: config.numTrials,
      pass_at_k: passAtK,
      avg_tokens: avgTokens,
    },
    results,
  };
}

// ---------------------------------------------------------------------------
// Module entry point
// ---------------------------------------------------------------------------

// Benchmark-module entry point: run every configured provider over every
// selected domain. Returns {} (not a failure) when nothing is runnable.
export async function run(config: BenchmarkConfig): Promise<BenchmarkModuleResult> {
  const domains = getAvailableDomains(config.tauDomain);

  if (domains.length === 0) {
    console.log(`  TAU: no domains found for "${config.tauDomain}"`);
    return {};
  }

  if (config.providers.length === 0) {
    console.log('  TAU: no providers configured, skipping');
    return {};
  }

  const allResults: TAUProviderResult[] = [];

  for (const domain of domains) {
    console.log(`\n  TAU domain: ${domain.id} (${domain.tasks.length} tasks, ${config.numTrials} trials)`);

    for (const bp of config.providers) {
      // Use userSimProvider if configured, otherwise same as agent provider
      const userSimBp = config.userSimProvider ?? bp;

      console.log(`\n    Running provider: ${bp.id} / ${bp.model}`);
      console.log(`    User simulator: ${userSimBp.id} / ${userSimBp.model}`);

      const providerResult = await runProviderOnDomain(bp, userSimBp, domain, config);
      allResults.push(providerResult);
    }
  }

  return { tau: allResults };
}
diff --git a/tests/benchmark/tau/orchestrator.ts b/tests/benchmark/tau/orchestrator.ts
new file mode 100644
index 0000000..0c34d70
--- /dev/null
+++ b/tests/benchmark/tau/orchestrator.ts
@@ -0,0 +1,201 @@
// ---------------------------------------------------------------------------
// TAU benchmark orchestrator — Agent ↔ User ↔ Environment message loop
//
// Follows τ-bench protocol:
// 1. User initiates conversation
// 2. Agent responds (text or tool calls)
// 3.
If tool calls → environment executes → results fed back to agent +// 4. If text → forwarded to user simulator +// 5. Repeat until ###STOP### or max turns +// --------------------------------------------------------------------------- + +import type { ModelProvider } from '../../../src/infra/providers/types'; +import type { Message, ContentBlock } from '../../../src/core/types'; +import type { ToolDef } from './domains/airline/tools'; +import { Environment } from './environment'; +import { UserSimulator } from './user-simulator'; + +const STOP_SIGNAL = '###STOP###'; +const MAX_TOOL_ROUNDS = 10; // Safety limit for consecutive tool calls in one turn + +export interface ConversationMessage { + role: 'user' | 'assistant'; + content: string; +} + +export interface OrchestrationResult { + passed: boolean; + messages: ConversationMessage[]; + agentTokens: number; + userSimTokens: number; + turns: number; + error?: string; +} + +export interface OrchestrationOptions { + agentProvider: ModelProvider; + userSimulator: UserSimulator; + environment: Environment; + systemPrompt: string; + toolDefs: ToolDef[]; + maxTurns: number; + timeoutMs: number; + expectedDb: Record; + evaluate: (finalDb: Record, expectedDb: Record) => boolean; +} + +export async function runOrchestration(opts: OrchestrationOptions): Promise { + const { + agentProvider, + userSimulator, + environment, + systemPrompt, + toolDefs, + maxTurns, + timeoutMs, + expectedDb, + evaluate, + } = opts; + + const conversationLog: ConversationMessage[] = []; + // Internal messages in SDK format for model.complete() + const modelMessages: Message[] = []; + let agentTokens = 0; + let userSimTokens = 0; + let turns = 0; + + try { + // Wrap the entire orchestration in a timeout + const result = await withTimeout(async () => { + // 1. 
User generates first message + const firstMsg = await userSimulator.generateFirstMessage(); + userSimTokens += firstMsg.tokens; + conversationLog.push({ role: 'user', content: firstMsg.text }); + modelMessages.push(textMsg('user', firstMsg.text)); + + // 2. Conversation loop + while (turns < maxTurns) { + // --- Agent turn --- + const agentText = await runAgentTurn( + agentProvider, + modelMessages, + systemPrompt, + toolDefs, + environment, + (t) => { agentTokens += t; }, + ); + + conversationLog.push({ role: 'assistant', content: agentText }); + turns++; + + // Check agent stop signal + if (agentText.includes(STOP_SIGNAL)) break; + + // --- User turn --- + const userReply = await userSimulator.generateResponse( + agentText, + conversationLog.slice(0, -1), // history without the latest agent msg + ); + userSimTokens += userReply.tokens; + conversationLog.push({ role: 'user', content: userReply.text }); + modelMessages.push(textMsg('user', userReply.text)); + + // Check user stop signal + if (userReply.done) break; + } + + // 3. 
Evaluate + const finalDb = environment.getState(); + const passed = evaluate(finalDb, expectedDb); + + return { passed, messages: conversationLog, agentTokens, userSimTokens, turns }; + }, timeoutMs); + + return result; + } catch (err: any) { + return { + passed: false, + messages: conversationLog, + agentTokens, + userSimTokens, + turns, + error: err.message || String(err), + }; + } +} + +// --------------------------------------------------------------------------- +// Agent turn: call model, handle tool loops, return final text +// --------------------------------------------------------------------------- + +async function runAgentTurn( + provider: ModelProvider, + modelMessages: Message[], + systemPrompt: string, + toolDefs: ToolDef[], + environment: Environment, + addTokens: (t: number) => void, +): Promise { + let toolRounds = 0; + + while (toolRounds < MAX_TOOL_ROUNDS) { + const response = await provider.complete(modelMessages, { + system: systemPrompt, + tools: toolDefs, + }); + + const usage = (response.usage?.input_tokens ?? 0) + (response.usage?.output_tokens ?? 
0); + addTokens(usage); + + // Separate text and tool_use blocks + const textBlocks = response.content.filter( + (b): b is { type: 'text'; text: string } => b.type === 'text', + ); + const toolUseBlocks = response.content.filter( + (b): b is { type: 'tool_use'; id: string; name: string; input: any } => b.type === 'tool_use', + ); + + // If no tool calls, return text + if (toolUseBlocks.length === 0) { + const text = textBlocks.map(b => b.text).join(''); + modelMessages.push({ role: 'assistant', content: response.content }); + return text; + } + + // Handle tool calls + modelMessages.push({ role: 'assistant', content: response.content }); + + const toolResults: ContentBlock[] = toolUseBlocks.map(tc => { + const result = environment.executeTool(tc.name, tc.input); + return { + type: 'tool_result' as const, + tool_use_id: tc.id, + content: JSON.stringify(result), + }; + }); + + modelMessages.push({ role: 'user', content: toolResults }); + toolRounds++; + } + + // Safety: too many tool rounds — return whatever text we have + return '[Agent exceeded maximum tool call rounds]'; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function textMsg(role: 'user' | 'assistant' | 'system', text: string): Message { + return { role, content: [{ type: 'text', text }] }; +} + +function withTimeout(fn: () => Promise, ms: number): Promise { + return new Promise((resolve, reject) => { + const timer = setTimeout(() => reject(new Error(`Timeout after ${ms}ms`)), ms); + fn().then( + v => { clearTimeout(timer); resolve(v); }, + e => { clearTimeout(timer); reject(e); }, + ); + }); +} diff --git a/tests/benchmark/tau/user-simulator.ts b/tests/benchmark/tau/user-simulator.ts new file mode 100644 index 0000000..9bd5c0f --- /dev/null +++ b/tests/benchmark/tau/user-simulator.ts @@ -0,0 +1,107 @@ +// 
--------------------------------------------------------------------------- +// TAU benchmark user simulator — LLM-powered simulated user +// --------------------------------------------------------------------------- + +import type { ModelProvider } from '../../../src/infra/providers/types'; +import type { Message, ContentBlock } from '../../../src/core/types'; + +const STOP_SIGNAL = '###STOP###'; + +export interface UserSimulatorResult { + text: string; + tokens: number; + done: boolean; +} + +export class UserSimulator { + private provider: ModelProvider; + private scenario: string; + + constructor(provider: ModelProvider, scenario: string) { + this.provider = provider; + this.scenario = scenario; + } + + /** + * Generate the first user message (initiating the conversation). + */ + async generateFirstMessage(): Promise { + const messages: Message[] = [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'Generate your opening message to the customer service agent. Be natural and concise — state who you are and what you need. Respond with ONLY the message text, nothing else.', + }, + ], + }, + ]; + + return this.callModel(messages); + } + + /** + * Generate the next user response based on the agent's message. 
+ */ + async generateResponse(agentMessage: string, history: Array<{ role: string; content: string }>): Promise<UserSimulatorResult> { + // Build conversation history for the user simulator + const messages: Message[] = []; + + // Add previous turns (alternating user/assistant from the USER's perspective: + // the user simulator sees agent messages as "user" input and its own messages as "assistant" output) + for (const msg of history) { + if (msg.role === 'user') { + // This was a user-sim output — from user-sim's perspective it's "assistant" + messages.push({ role: 'assistant', content: [{ type: 'text', text: msg.content }] }); + } else if (msg.role === 'assistant') { + // This was an agent output — from user-sim's perspective it's "user" input + messages.push({ role: 'user', content: [{ type: 'text', text: msg.content }] }); + } + } + + // Latest agent message + messages.push({ + role: 'user', + content: [ + { + type: 'text', + text: `The customer service agent said:\n\n${agentMessage}\n\nRespond as the customer. If your issue is resolved, say goodbye naturally and include "${STOP_SIGNAL}" at the end of your message. Respond with ONLY the message text, nothing else.`, + }, + ], + }); + + return this.callModel(messages); + } + + private async callModel(messages: Message[]): Promise<UserSimulatorResult> { + const systemPrompt = [ + 'You are simulating a customer calling airline customer service.', + 'Follow this scenario exactly:', + '', + this.scenario, + '', + 'Rules:', + '- Stay in character. Only say things consistent with your scenario.', + '- Be natural and conversational, like a real customer.', + '- Provide information when asked (your name, reservation ID, etc.).', + '- Do not invent details not in your scenario.', + `- When your issue is fully resolved and you have no more questions, include "${STOP_SIGNAL}" at the end of your final message.`, + '- Respond with ONLY the customer message text. 
Do not add any meta-commentary.', + ].join('\n'); + + const response = await this.provider.complete(messages, { system: systemPrompt }); + + const text = response.content + .filter((b): b is { type: 'text'; text: string } => b.type === 'text') + .map(b => b.text) + .join(''); + + const tokens = + (response.usage?.input_tokens ?? 0) + (response.usage?.output_tokens ?? 0); + + const done = text.includes(STOP_SIGNAL); + + return { text: text.replace(STOP_SIGNAL, '').trim(), tokens, done }; + } +} diff --git a/tests/benchmark/types.ts b/tests/benchmark/types.ts new file mode 100644 index 0000000..789cdc1 --- /dev/null +++ b/tests/benchmark/types.ts @@ -0,0 +1,152 @@ +import type { ProviderId } from '../helpers/provider-env'; + +// --------------------------------------------------------------------------- +// Provider +// --------------------------------------------------------------------------- + +export interface BenchmarkProvider { + id: ProviderId; + model: string; + apiKey: string; + baseUrl?: string; + proxyUrl?: string; +} + +// --------------------------------------------------------------------------- +// CLI args +// --------------------------------------------------------------------------- + +export interface BenchmarkCliArgs { + sweOnly: boolean; + tauOnly: boolean; + sweMode?: 'mini' | 'full'; + tauDomain?: string; + provider?: string; + numTrials?: number; + output?: 'table' | 'json' | 'html' | 'both'; + outputFile?: string; + compare?: string; +} + +// --------------------------------------------------------------------------- +// Config (merged env + CLI) +// --------------------------------------------------------------------------- + +export interface BenchmarkConfig { + providers: BenchmarkProvider[]; + userSimProvider?: BenchmarkProvider; + timeoutMs: number; + numTrials: number; + output: 'table' | 'json' | 'html' | 'both'; + outputFile: string; + sweMode: 'mini' | 'full'; + tauDomain: string; + sdkVersion: string; + dockerProxy?: string; +} + 
+// --------------------------------------------------------------------------- +// SWE-bench types +// --------------------------------------------------------------------------- + +export interface SWEInstance { + instance_id: string; + repo: string; + base_commit: string; + patch: string; + test_patch: string; + problem_statement: string; + hints_text: string; + created_at: string; + version: string; +} + +export interface MiniSWECase { + id: string; + repo: string; + description: string; + files: Record<string, string>; + expected_patch: string; + test_command: string; +} + +export interface SWEResult { + instance_id: string; + resolved: boolean; + tokens_used: number; + duration_ms: number; + error?: string; +} + +export interface SWESummary { + dataset: string; + total: number; + resolved: number; + rate: number; + avg_tokens: number; + avg_duration_ms: number; +} + +export interface SWEProviderResult { + provider: BenchmarkProvider; + summary: SWESummary; + results: SWEResult[]; +} + +// --------------------------------------------------------------------------- +// TAU-bench types +// --------------------------------------------------------------------------- + +export interface TAUTask { + task_id: string; + domain: string; + user_instruction: string; + expected_actions: string[]; + tools: string[]; +} + +export interface TAUTaskResult { + task_id: string; + trial_pass_rates: boolean[]; + tokens_used: number; + error?: string; +} + +export interface TAUSummary { + domain: string; + total_tasks: number; + num_trials: number; + pass_at_k: number[]; + avg_tokens: number; +} + +export interface TAUProviderResult { + provider: BenchmarkProvider; + summary: TAUSummary; + results: TAUTaskResult[]; +} + +// --------------------------------------------------------------------------- +// Top-level report +// --------------------------------------------------------------------------- + +export interface BenchmarkReport { + timestamp: string; + sdk_version: string; + swe?: 
SWEProviderResult[]; + tau?: TAUProviderResult[]; +} + +// --------------------------------------------------------------------------- +// Module contract (Phase 2+ modules implement this) +// --------------------------------------------------------------------------- + +export interface BenchmarkModuleResult { + swe?: SWEProviderResult[]; + tau?: TAUProviderResult[]; +} + +export interface BenchmarkModule { + name: string; + run(config: BenchmarkConfig): Promise<BenchmarkModuleResult>; +} diff --git a/tests/unit/providers/openai.test.ts b/tests/unit/providers/openai.test.ts index 0efd3d2..14e94bf 100644 --- a/tests/unit/providers/openai.test.ts +++ b/tests/unit/providers/openai.test.ts @@ -10,6 +10,11 @@ runner const config = provider.toConfig(); expect.toEqual(config.baseUrl, 'https://api.openai.com/v1'); }) + .test('baseUrl 保留已有版本路径 /v4 (GLM coding endpoint)', async () => { + const provider = new OpenAIProvider('test-key', 'any-model', 'https://open.bigmodel.cn/api/coding/paas/v4'); + const config = provider.toConfig(); + expect.toEqual(config.baseUrl, 'https://open.bigmodel.cn/api/coding/paas/v4'); + }) .test('请求体包含 system 与工具调用结构', async () => { + const provider = new OpenAIProvider('test-key', 'gpt-4o', 'https://api.openai.com'); const messages: Message[] = [ From 007e6c202686496b7252df31f926ebabcc4c86db Mon Sep 17 00:00:00 2001 From: Gui-Yue Date: Thu, 26 Feb 2026 22:43:36 +0800 Subject: [PATCH 2/3] refactor(benchmark): simplify to SWE-Verified + TB2 official flow, fix TB2 scoring, and trim dead TAU assets - unify benchmark pipeline around --benchmark=swe|tb2|both - switch SWE runner to verified-only instances - add official TB2 harbor wrapper runner (run-tb2-official.ts) - fix TB2 score parsing to ignore top-level summary result.json - remove legacy mini/full SWE + TAU + HTML benchmark code paths - delete unused TAU domain fixtures under tests/benchmark/tau/domains - update benchmark result docs (EN/ZH) to confirmed latest scores --- .github/workflows/benchmark.yml | 167 ++++ 
README.md | 1 + README.zh-CN.md | 1 + docs/en/guides/benchmark-results.md | 32 ++ docs/en/guides/benchmarking.md | 508 +++-------------- docs/zh-CN/guides/benchmark-results.md | 32 ++ docs/zh-CN/guides/benchmarking.md | 514 +++--------------- tests/benchmark/compare.ts | 184 +++---- tests/benchmark/config.ts | 96 ++-- tests/benchmark/html-reporter.ts | 360 ------------ tests/benchmark/reporter.ts | 105 ++-- tests/benchmark/run-benchmark.ts | 113 ++-- tests/benchmark/run-tb2-official.ts | 487 +++++++++++++++++ tests/benchmark/swe/cases/mini-cases.json | 182 ------- ...instances.json => verified-instances.json} | 0 tests/benchmark/swe/dataset.ts | 33 +- tests/benchmark/swe/docker-evaluator.ts | 6 +- tests/benchmark/swe/evaluator.ts | 64 --- tests/benchmark/swe/harness.ts | 136 ----- tests/benchmark/swe/index.ts | 189 +------ .../benchmark/tau/domains/airline/database.ts | 220 -------- .../benchmark/tau/domains/airline/handlers.ts | 74 --- tests/benchmark/tau/domains/airline/policy.md | 45 -- .../benchmark/tau/domains/airline/tasks.json | 71 --- tests/benchmark/tau/domains/airline/tools.ts | 127 ----- .../benchmark/tau/domains/retail/database.ts | 156 ------ .../benchmark/tau/domains/retail/handlers.ts | 147 ----- tests/benchmark/tau/domains/retail/policy.md | 53 -- tests/benchmark/tau/domains/retail/tasks.json | 67 --- tests/benchmark/tau/domains/retail/tools.ts | 147 ----- tests/benchmark/tau/environment.ts | 44 -- tests/benchmark/tau/evaluator.ts | 69 --- tests/benchmark/tau/index.ts | 252 --------- tests/benchmark/tau/orchestrator.ts | 201 ------- tests/benchmark/tau/user-simulator.ts | 107 ---- tests/benchmark/types.ts | 127 ++--- 36 files changed, 1116 insertions(+), 4001 deletions(-) create mode 100644 .github/workflows/benchmark.yml create mode 100644 docs/en/guides/benchmark-results.md create mode 100644 docs/zh-CN/guides/benchmark-results.md delete mode 100644 tests/benchmark/html-reporter.ts create mode 100644 tests/benchmark/run-tb2-official.ts delete mode 
100644 tests/benchmark/swe/cases/mini-cases.json rename tests/benchmark/swe/cases/{curated-instances.json => verified-instances.json} (100%) delete mode 100644 tests/benchmark/swe/evaluator.ts delete mode 100644 tests/benchmark/swe/harness.ts delete mode 100644 tests/benchmark/tau/domains/airline/database.ts delete mode 100644 tests/benchmark/tau/domains/airline/handlers.ts delete mode 100644 tests/benchmark/tau/domains/airline/policy.md delete mode 100644 tests/benchmark/tau/domains/airline/tasks.json delete mode 100644 tests/benchmark/tau/domains/airline/tools.ts delete mode 100644 tests/benchmark/tau/domains/retail/database.ts delete mode 100644 tests/benchmark/tau/domains/retail/handlers.ts delete mode 100644 tests/benchmark/tau/domains/retail/policy.md delete mode 100644 tests/benchmark/tau/domains/retail/tasks.json delete mode 100644 tests/benchmark/tau/domains/retail/tools.ts delete mode 100644 tests/benchmark/tau/environment.ts delete mode 100644 tests/benchmark/tau/evaluator.ts delete mode 100644 tests/benchmark/tau/index.ts delete mode 100644 tests/benchmark/tau/orchestrator.ts delete mode 100644 tests/benchmark/tau/user-simulator.ts diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 0000000..53bc29a --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,167 @@ +name: Benchmark Full Suite + +on: + workflow_dispatch: + inputs: + benchmark: + description: "Which benchmark to run" + type: choice + required: true + default: both + options: + - both + - swe + - tb2 + provider: + description: "SWE provider filter" + type: choice + required: true + default: all + options: + - all + - anthropic + - openai + - gemini + tb2_model: + description: "TB2 model in provider/model format" + type: string + required: true + default: openai/glm-5 + push: + branches: + - add_benchmark_test + pull_request: + branches: + - main + +env: + NODE_VERSION: "20" + +permissions: + contents: read + +jobs: + benchmark: + 
name: Benchmark + runs-on: ubuntu-latest + timeout-minutes: 360 + env: + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} + DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: ${{ env.NODE_VERSION }} + cache: npm + + - name: Setup uv + uses: astral-sh/setup-uv@v4 + + - name: Login to Docker Hub (optional) + if: ${{ env.DOCKERHUB_USERNAME != '' && env.DOCKERHUB_TOKEN != '' }} + uses: docker/login-action@v3 + with: + username: ${{ env.DOCKERHUB_USERNAME }} + password: ${{ env.DOCKERHUB_TOKEN }} + + - name: Install dependencies + run: npm ci + + - name: Create benchmark environment + run: | + cat > .env.test << 'EOT' + ANTHROPIC_API_KEY=${{ secrets.ANTHROPIC_API_KEY }} + ANTHROPIC_MODEL_ID=${{ vars.ANTHROPIC_MODEL_ID }} + ANTHROPIC_BASE_URL=${{ vars.ANTHROPIC_BASE_URL }} + + OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} + OPENAI_MODEL_ID=${{ vars.OPENAI_MODEL_ID }} + OPENAI_BASE_URL=${{ vars.OPENAI_BASE_URL }} + + GEMINI_API_KEY=${{ secrets.GEMINI_API_KEY }} + GEMINI_MODEL_ID=${{ vars.GEMINI_MODEL_ID }} + GEMINI_BASE_URL=${{ vars.GEMINI_BASE_URL }} + + BENCHMARK_DOCKER_PROXY=${{ vars.BENCHMARK_DOCKER_PROXY }} + BENCHMARK_TIMEOUT_MS=${{ vars.BENCHMARK_TIMEOUT_MS }} + EOT + + - name: Run unified benchmark command + run: | + mkdir -p tests/tmp + args=( + --benchmark=${{ inputs.benchmark }} + --tb2-model=${{ inputs.tb2_model }} + --tb2-agent=oracle + --tb2-runner=uvx + --tb2-python=3.12 + --tb2-jobs-dir=./tests/tmp/jobs + --output=json + --output-file=tests/tmp/benchmark-report.json + ) + + if [[ "${{ inputs.provider }}" != "all" && "${{ inputs.benchmark }}" != "tb2" ]]; then + args+=(--provider=${{ inputs.provider }}) + fi + + npm run test:benchmark -- "${args[@]}" + + - name: Write step summary + if: ${{ always() }} + run: | + node - <<'NODE' >> "$GITHUB_STEP_SUMMARY" + const fs = require('fs'); + function readJson(p) { + if 
(!fs.existsSync(p)) return null; + try { return JSON.parse(fs.readFileSync(p, 'utf8')); } catch { return null; } + } + + const report = readJson('tests/tmp/benchmark-report.json'); + console.log('## Benchmark Report'); + console.log(''); + + if (!report) { + console.log('- report not found'); + process.exit(0); + } + + if (Array.isArray(report.swe) && report.swe.length > 0) { + console.log('### SWE-bench-Verified'); + console.log(''); + console.log('| Provider / Model | Resolved | Rate |'); + console.log('|---|---:|---:|'); + for (const r of report.swe) { + const name = `${r.provider.id} / ${r.provider.model}`; + const resolved = `${r.summary.resolved}/${r.summary.total}`; + const rate = `${(r.summary.rate * 100).toFixed(1)}%`; + console.log(`| ${name} | ${resolved} | ${rate} |`); + } + console.log(''); + } + + if (report.tb2) { + const tb2 = report.tb2; + console.log('### Terminal Bench 2.0'); + console.log(''); + console.log(`- Agent: \`${tb2.agent}\``); + if (tb2.model) console.log(`- Model: \`${tb2.model}\``); + console.log(`- Passed: **${tb2.passed}/${tb2.total}**`); + console.log(`- Rate: **${(tb2.rate * 100).toFixed(1)}%**`); + console.log(''); + } + NODE + + - name: Upload benchmark artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v4 + with: + name: benchmark-artifacts-${{ github.run_id }} + if-no-files-found: warn + path: | + tests/tmp/benchmark-report.json + tests/tmp/jobs/*/result.json diff --git a/README.md b/README.md index 47c3294..5787c82 100644 --- a/README.md +++ b/README.md @@ -146,6 +146,7 @@ See [docs/en/guides/architecture.md](./docs/en/guides/architecture.md) for detai | [Providers](./docs/en/guides/providers.md) | Model provider configuration | | [Database](./docs/en/guides/database.md) | SQLite/PostgreSQL persistence | | [Resume & Fork](./docs/en/guides/resume-fork.md) | Crash recovery & branching | +| [Benchmark Results](./docs/en/guides/benchmark-results.md) | Confirmed benchmark score tables | | **Project** | | | 
[Contribution Guide](./docs/en/contribution.md) | How to contribute | | **Reference** | | diff --git a/README.zh-CN.md b/README.zh-CN.md index 912d1b8..c748994 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -106,6 +106,7 @@ npm run example:room # 多Agent协作 | [Provider 配置](./docs/zh-CN/guides/providers.md) | 模型 Provider 配置 | | [数据库存储](./docs/zh-CN/guides/database.md) | SQLite/PostgreSQL 持久化 | | [恢复与分叉](./docs/zh-CN/guides/resume-fork.md) | 崩溃恢复与分支 | +| [Benchmark 结果](./docs/zh-CN/guides/benchmark-results.md) | 已确认的跑分结果表格 | | **项目** | | | [贡献指南](./docs/zh-CN/contribution.md) | 提交 PR 的要求与流程 | | **参考** | | diff --git a/docs/en/guides/benchmark-results.md b/docs/en/guides/benchmark-results.md new file mode 100644 index 0000000..a98e132 --- /dev/null +++ b/docs/en/guides/benchmark-results.md @@ -0,0 +1,32 @@ +# Benchmark Results (Confirmed) + +Last updated: 2026-02-26 + +## SWE-bench-Verified + +| Provider / Model | Instances | Resolved | Rate | Avg Tokens | Avg Duration | +|---|---:|---:|---:|---:|---:| +| openai / glm-5 | 12 | 12/12 | 100.0% | 17.2k | 134.5k ms | + +Source: local full run log (`2026-02-25__21-06-21`). + +## Terminal Bench 2.0 + +| Agent / Model | Passed | Parseable | Unknown | Rate (parseable) | Notes | +|---|---:|---:|---:|---:|---| +| oracle / glm-5 | 1 | 31 | 58 | 3.2% | From the same full run; many tasks ended with runtime/timeout errors. | + +## Reproduce + +```bash +npm run test:benchmark -- \ + --benchmark=both \ + --tb2-model=openai/glm-5 \ + --tb2-agent=oracle \ + --tb2-runner=uvx \ + --tb2-jobs-dir=./tests/tmp/jobs \ + --output=json \ + --output-file=tests/tmp/benchmark-report.json +``` + +The JSON report includes both `swe` and `tb2` sections. 
diff --git a/docs/en/guides/benchmarking.md b/docs/en/guides/benchmarking.md index 08465e1..6cc2dbe 100644 --- a/docs/en/guides/benchmarking.md +++ b/docs/en/guides/benchmarking.md @@ -1,470 +1,122 @@ -# Benchmarking Guide +# Benchmarking -KODE SDK includes an integrated benchmark suite for evaluating LLM model capabilities in agent scenarios. The suite implements two industry-standard methodologies: +KODE SDK benchmark runner now has a single entry command and supports three targets: -- **SWE-bench** (Princeton/OpenAI) — Code bug-fixing: model receives an issue description + source code, generates a fix, tests verify correctness -- **τ-bench** (Sierra Research) — Multi-turn tool-use conversations: model acts as a customer service agent, uses tools, follows policy, and the final database state is evaluated - ---- +- `swe`: SWE-bench-Verified only +- `tb2`: Terminal Bench 2.0 only +- `both`: run both in one command ## Prerequisites -1. **Provider configuration** in `.env.test` — at least one provider with `API_KEY` and `MODEL_ID` configured. See [Provider Configuration Guide](./providers.md) for details. - -2. **Node.js** with `ts-node` available (included in devDependencies). - -3. **(Optional) Docker** — required only for SWE-bench full mode. Mini mode and TAU benchmarks run without Docker. - -### Minimal `.env.test` Setup - -```ini -# At least one provider is required -ANTHROPIC_API_KEY=sk-ant-... -ANTHROPIC_MODEL_ID=claude-sonnet-4-5-20250929 - -# Optional: additional providers to compare -OPENAI_API_KEY=sk-... -OPENAI_MODEL_ID=gpt-4o - -GEMINI_API_KEY=AIza... -GEMINI_MODEL_ID=gemini-2.5-pro -``` - ---- - -## Quick Start +1. 
Install dependencies: ```bash -# Run all benchmarks (SWE mini + TAU airline + TAU retail) -npm run test:benchmark - -# Run only SWE benchmark -npm run test:benchmark -- --swe-only - -# Run SWE full mode (requires Docker) -npm run test:benchmark -- --swe-only --swe-mode=full - -# Run only TAU benchmark -npm run test:benchmark -- --tau-only - -# Run with a specific provider -npm run test:benchmark -- --provider=anthropic - -# Output JSON report -npm run test:benchmark -- --output=json --output-file=results.json +npm ci ``` -> **Note:** Every benchmark run automatically generates an HTML visual report at `tests/tmp/benchmark-report-{timestamp}.html`. Open it in a browser to view detailed results with scores, charts, and per-case breakdowns. - ---- - -## SWE Benchmark - -The SWE benchmark evaluates a model's ability to fix bugs in source code. The model receives a bug description and the project files, then generates corrected code that must pass all tests. - -### Mini Mode (Default) - -Mini mode uses 20 built-in JavaScript bug-fix cases that run locally without Docker. Each case contains: -- A buggy `src.js` file -- A `test.js` file with assertions -- A bug description explaining the expected behavior +2. Create `.env.test` (or export env vars directly): ```bash -# Run mini-SWE benchmark -npm run test:benchmark -- --swe-only --swe-mode=mini -``` +ANTHROPIC_API_KEY=... +ANTHROPIC_MODEL_ID=claude-sonnet-4-20250514 -**Example output:** +OPENAI_API_KEY=... +OPENAI_MODEL_ID=glm-5 +GEMINI_API_KEY=... +GEMINI_MODEL_ID=gemini-3-pro-preview ``` - SWE mini mode: 20 cases - - Running provider: anthropic / claude-sonnet-4-5-20250929 - [anthropic] mini-swe-001: PASS (1772 tokens, 13186ms) - [anthropic] mini-swe-002: PASS (1246 tokens, 12162ms) - ... 
---- SWE-bench (mini-swe) — 20 instances --- - -Provider / Model | Resolved | Rate | Avg Tokens | Avg ms --------------------------------------+----------+---------+------------+--------- -anthropic / claude-sonnet-4-5-20250… | 20/20 | 100.0% | 1.0k | 7.4k -``` +3. Runtime tools: +- SWE-bench-Verified: Docker is required +- TB2: `harbor`, `uvx`, or Docker (runner decides by `--tb2-runner`) -**Core metric:** `Resolved Rate` — the percentage of cases where the model's fix passes all tests. - -### Full Mode (Docker) - -Full mode uses real SWE-bench instances from open-source repositories. It evaluates model-generated patches using official pre-built SWE-bench Docker images from DockerHub. +## Unified Command ```bash -# Run full SWE-bench (requires Docker) -npm run test:benchmark -- --swe-only --swe-mode=full -``` - -The evaluator: -1. Clones the repository on the host and checks out the specified commit -2. Extracts relevant file paths from the problem statement and hints -3. Reads source files and sends them to the LLM along with the bug description -4. The LLM returns SEARCH/REPLACE blocks for the changed code sections -5. The framework applies the hunks and programmatically generates a unified diff -6. Pulls the official SWE-bench Docker image (`swebench/sweb.eval.x86_64.:latest`) -7. The container already has the repo at `/testbed` with all dependencies installed in a `testbed` conda environment -8. Applies the patch and runs the repository's test suite - -When Docker is not available, it falls back to local git clone + patch application (less reliable due to missing dependencies). - -The curated instances are defined in `tests/benchmark/swe/cases/curated-instances.json`. - -> **Note:** SWE-bench images are large (several GB each). The first run will take longer as images are downloaded. Subsequent runs reuse cached images. Configure `BENCHMARK_DOCKER_PROXY` if you need a proxy for Docker pulls. 
- ---- - -## TAU Benchmark - -The TAU benchmark (Tool-Agent-User) evaluates a model's ability to handle multi-turn customer service conversations while using tools correctly and following business policies. - -### Architecture - -``` -Orchestrator -├── Agent (model under test) — receives user messages, calls tools, follows policy -├── User Simulator (LLM) — plays the customer role based on a scenario script -└── Environment — executes tool calls, maintains database state +npm run test:benchmark -- [flags] ``` -**Evaluation:** After the conversation ends, the final database state is compared against the expected state. A task passes only if all expected fields match. - -### Available Domains +### Common examples -| Domain | Tasks | Tools | Description | -|--------|-------|-------|-------------| -| `airline` | 5 | 7 | Flight changes, cancellations, baggage inquiries | -| `retail` | 5 | 8 | Returns, exchanges, order status, product search | - -### Running TAU Benchmarks +Run both SWE + TB2 in one command: ```bash -# Run all TAU domains -npm run test:benchmark -- --tau-only - -# Run specific domain -npm run test:benchmark -- --tau-only --tau-domain=airline -npm run test:benchmark -- --tau-only --tau-domain=retail - -# Run with multiple trials (for pass^k reliability metric) -npm run test:benchmark -- --tau-only --num-trials=3 -``` - -**Example output:** - +npm run test:benchmark -- \ + --benchmark=both \ + --tb2-model=openai/glm-5 \ + --output=json \ + --output-file=tests/tmp/benchmark-report.json ``` - TAU domain: airline (5 tasks, 1 trials) - Running provider: anthropic / claude-sonnet-4-5-20250929 - User simulator: anthropic / claude-sonnet-4-5-20250929 - [anthropic] airline_001 trial 1/1: PASS (5 turns, 22341 tokens) - [anthropic] airline_002 trial 1/1: PASS (3 turns, 15280 tokens) - ... 
- ---- TAU-bench (airline) — 5 tasks, 1 trials --- - -Provider / Model | Pass^1 | Avg Tokens --------------------------------------+---------+----------- -anthropic / claude-sonnet-4-5-20250… | 80.0% | 18.1k -``` - -### Understanding pass^k - -The **pass^k** metric measures reliability across multiple independent trials of the same task: - -- **pass^1** = fraction of tasks passed in a single trial -- **pass^k** = fraction of tasks that passed in ALL k independent trials - -This captures consistency — a model with 80% pass^1 but 40% pass^3 is unreliable. Use `--num-trials=k` to compute pass^k. - -### User Simulator - -By default, the same model is used for both the agent and the user simulator. To use a different model for user simulation: - -```ini -# In .env.test -BENCHMARK_USER_MODEL=anthropic/claude-sonnet-4-5-20250929 -``` - -Format: `provider/model-id`. - ---- - -## CLI Reference - -All flags are passed after `--` to the npm script: +Run only SWE-bench-Verified: ```bash -npm run test:benchmark -- [flags] +npm run test:benchmark -- \ + --benchmark=swe \ + --provider=anthropic \ + --output=json \ + --output-file=tests/tmp/swe-report.json ``` -| Flag | Description | Default | -|------|-------------|---------| -| `--swe-only` | Run only SWE benchmarks | (run both) | -| `--tau-only` | Run only TAU benchmarks | (run both) | -| `--swe-mode=mini\|full` | SWE evaluation mode | `mini` | -| `--tau-domain=airline\|retail\|all` | TAU domain to evaluate | `all` | -| `--provider=NAME` | Run only the specified provider | (all configured) | -| `--num-trials=N` | Number of TAU trials per task (for pass^k) | `1` | -| `--output=table\|json\|html\|both` | Output format | `table` | -| `--output-file=PATH` | JSON/HTML report output path | `benchmark-report.json` | -| `--compare=PATH` | Compare current run against a baseline JSON report | (none) | - ---- - -## Environment Variables - -These can be set in `.env.test` alongside provider configuration: - -| Variable | Description | 
Default | -|----------|-------------|---------| -| `BENCHMARK_PROVIDERS` | Comma-separated list of providers to run | (all configured) | -| `BENCHMARK_TIMEOUT_MS` | Timeout per task in milliseconds | `120000` | -| `BENCHMARK_NUM_TRIALS` | Default number of TAU trials | `1` | -| `BENCHMARK_OUTPUT` | Output format | `table` | -| `BENCHMARK_USER_MODEL` | User simulator model (`provider/model`) | (same as agent) | -| `BENCHMARK_DOCKER_PROXY` | HTTP proxy URL for Docker containers and git clone | (none) | - -CLI flags override environment variables when both are set. - ---- - -## Historical Comparison - -Save a baseline report and compare future runs against it to detect regressions: +Run only TB2: ```bash -# 1. Save a baseline -npm run test:benchmark -- --output=json --output-file=baseline.json - -# 2. Later, compare a new run against the baseline -npm run test:benchmark -- --compare=baseline.json -``` - -The comparison output shows changes in key metrics with direction indicators: - -``` -================================================================================ -Benchmark Comparison -================================================================================ - Baseline: baseline.json - Current: (current run) - ---- SWE Comparison --- - -Metric | Baseline | Current | Delta | Dir ---------------------------------------------------------------------------------- -anthropic/claude-sonnet-4-5 [rate] | 100.0% | 100.0% | = | -anthropic/claude-sonnet-4-5 [resolved] | 20/20 | 20/20 | = | -anthropic/claude-sonnet-4-5 [tokens] | 1.0k | 986 | -45 | ^ - - No regressions detected. 
+npm run test:benchmark -- \ + --benchmark=tb2 \ + --tb2-model=openai/glm-5 \ + --tb2-agent=oracle \ + --tb2-runner=docker \ + --tb2-jobs-dir=./tests/tmp/jobs \ + --output=json \ + --output-file=tests/tmp/tb2-report.json ``` -- `^` = improvement (higher rate, lower tokens/latency) -- `v` = regression (lower rate, higher tokens/latency) -- Exit code is `1` if regressions are detected +## Flags ---- - -## JSON Report Format - -When using `--output=json` or `--output=both`, a JSON report is written: +| Flag | Description | Default | +|---|---|---| +| `--benchmark=swe\|tb2\|both` | Which benchmark(s) to run | `both` | +| `--provider=...` | SWE provider filter (`anthropic`, `openai`, `gemini`, etc.) | all discovered | +| `--tb2-model=provider/model` | TB2 model id | `BENCHMARK_TB2_MODEL` or `openai/$OPENAI_MODEL_ID` | +| `--tb2-agent=...` | TB2 agent (`oracle`, etc.) | `oracle` | +| `--tb2-dataset=...` | TB2 dataset id | `terminal-bench@2.0` | +| `--tb2-runner=auto\|harbor\|uvx\|docker` | TB2 execution backend | `auto` | +| `--tb2-python=3.12` | Python version for `uvx` runner | `3.12` | +| `--tb2-jobs-dir=PATH` | TB2 jobs directory | `tests/tmp/jobs` | +| `--tb2-env-file=PATH` | Env file passed to TB2 runner | auto-detect `.env.test` | +| `--tb2-docker-image=IMAGE` | Docker image for TB2 docker runner | `ghcr.io/astral-sh/uv:python3.12-bookworm` | +| `--output=table\|json` | Output mode | `table` | +| `--output-file=PATH` | JSON output file path (when `--output=json`) | `benchmark-report.json` | +| `--compare=PATH` | Compare against baseline JSON report | unset | + +## Output + +With `--output=json`, one report contains both sections: ```json { - "timestamp": "2026-02-12T10:30:00.000Z", + "timestamp": "2026-02-25T08:31:16.000Z", "sdk_version": "2.7.3", - "swe": [{ - "provider": { "id": "anthropic", "model": "claude-sonnet-4-5-20250929", "apiKey": "***" }, - "summary": { - "dataset": "mini-swe", - "total": 20, - "resolved": 20, - "rate": 1.0, - "avg_tokens": 1031, - 
"avg_duration_ms": 7420 - }, - "results": [ - { "instance_id": "mini-swe-001", "resolved": true, "tokens_used": 1772, "duration_ms": 13186 } - ] - }], - "tau": [{ - "provider": { "id": "anthropic", "model": "claude-sonnet-4-5-20250929", "apiKey": "***" }, - "summary": { - "domain": "airline", - "total_tasks": 5, - "num_trials": 1, - "pass_at_k": [0.8], - "avg_tokens": 18100 - }, - "results": [ - { "task_id": "airline_001", "trial_pass_rates": [true], "tokens_used": 22341 } - ] - }] + "swe": [ + { + "provider": { "id": "openai", "model": "glm-5" }, + "summary": { "dataset": "swe-bench-verified", "total": 12, "resolved": 10, "rate": 0.8333, "avg_tokens": 17500, "avg_duration_ms": 166000 } + } + ], + "tb2": { + "dataset": "terminal-bench@2.0", + "agent": "oracle", + "model": "openai/glm-5", + "passed": 0, + "total": 89, + "rate": 0.0 + } } ``` -API keys are automatically redacted to `"***"` in the output. - ---- - -## HTML Visual Report - -Every benchmark run automatically generates a self-contained HTML report at `tests/tmp/benchmark-report-{timestamp}.html` (this directory is in `.gitignore`). 
The report includes: - -- **Overall Score** — A weighted composite score (0–100) displayed as a circular progress ring: - - SWE Resolved Rate × 60% + TAU Pass^1 × 40% - - If only one benchmark type runs, it gets 100% weight - - Color-coded: green (≥90 Excellent), yellow (≥70 Good), orange (≥50 Fair), red (<50 Poor) -- **Configuration Summary** — SDK version, providers, SWE mode, TAU domain, timeout, trials -- **SWE Results** — Summary table, resolved rate bar chart, and expandable per-case details (pass/fail, tokens, duration) -- **TAU Results** — Summary table with Pass^k columns, pass rate bar chart, and expandable per-task trial details - -### Viewing the Report - -```bash -# Run benchmarks (HTML report is generated automatically) -npm run test:benchmark -- --provider=anthropic - -# Serve with Python's built-in HTTP server -cd tests/tmp && python3 -m http.server 8080 -# Open http://localhost:8080/benchmark-report.html -``` - -The report is a single file with all CSS inlined — no external dependencies. You can also open it directly in a browser via `file://` protocol. 
- ---- - -## Project Structure - -``` -tests/benchmark/ -├── run-benchmark.ts # Entry point -├── config.ts # CLI + env config loading -├── types.ts # Shared type definitions -├── reporter.ts # Table + JSON output -├── html-reporter.ts # HTML visual report generator -├── compare.ts # Historical report comparison -│ -├── swe/ # SWE-bench module -│ ├── index.ts # Module entry (mini + full mode routing) -│ ├── dataset.ts # Case/instance loading -│ ├── harness.ts # Model interaction (mini mode) -│ ├── evaluator.ts # Local test execution (mini mode) -│ ├── docker-evaluator.ts # Docker/git evaluation (full mode) -│ └── cases/ -│ ├── mini-cases.json # 20 JavaScript bug-fix cases -│ └── curated-instances.json # SWE-bench instance definitions -│ -└── tau/ # TAU-bench module - ├── index.ts # Module entry (domain discovery + orchestration) - ├── orchestrator.ts # Agent ↔ User ↔ Environment message loop - ├── user-simulator.ts # LLM-based user simulation - ├── environment.ts # Generic DB + tool dispatch - ├── evaluator.ts # DB state comparison + pass^k - └── domains/ - ├── airline/ - │ ├── policy.md # Business rules - │ ├── database.ts # Initial data (users, flights, reservations) - │ ├── tools.ts # Tool definitions (Anthropic API format) - │ ├── handlers.ts # Tool implementation logic - │ └── tasks.json # 5 evaluation tasks - └── retail/ - ├── policy.md # Return/exchange/shipping policies - ├── database.ts # Initial data (customers, products, orders) - ├── tools.ts # Tool definitions - ├── handlers.ts # Tool implementation logic - └── tasks.json # 5 evaluation tasks -``` - ---- - -## Adding Custom Test Cases - -### Adding Mini-SWE Cases - -Add new entries to `tests/benchmark/swe/cases/mini-cases.json`: - -```json -{ - "id": "mini-swe-021", - "description": "Describe the bug and expected behavior clearly.", - "files": { - "src.js": "// buggy source code\nmodule.exports = { myFunc };\n", - "test.js": "const { myFunc } = require('./src');\n// assertions...\nconsole.log('All tests 
passed');\n" - }, - "test_command": "node test.js" -} -``` - -Requirements: -- `src.js` must contain the buggy code (the model should not modify test files) -- `test.js` must exit with code 0 on success, non-zero on failure -- The bug should be a single, clear defect with an unambiguous fix - -### Adding TAU Domains - -To add a new domain (e.g., `telecom`): - -1. Create `tests/benchmark/tau/domains/telecom/`: - - `policy.md` — business rules the agent must follow - - `database.ts` — export `getInitialDatabase()` with typed data - - `tools.ts` — export tool definitions in Anthropic API format - - `handlers.ts` — export `getTelecomHandlers()` returning tool implementations - - `tasks.json` — evaluation tasks with `user_scenario` and `expected_db` - -2. Update `tests/benchmark/tau/index.ts`: - - Add imports for the new domain - - Add a `case 'telecom':` in `loadDomain()` - - Add `'telecom'` to the candidates list in `getAvailableDomains()` - - Add a role entry in `DOMAIN_ROLES` - -### Adding TAU Tasks - -Add entries to a domain's `tasks.json`: - -```json -{ - "task_id": "retail_006", - "user_scenario": "You are [name] (customer ID: [id]). Describe what the user wants...", - "expected_db": { - "orders": [ - { "order_id": "ORD001", "status": "returned" } - ] - }, - "max_turns": 10 -} -``` - -The `expected_db` uses partial matching — only specified fields are checked, and records are matched by their primary key field (any field ending in `_id`). - ---- - -## Best Practices - -1. **Start with mini mode** — it's fast, free of Docker dependencies, and provides quick feedback -2. **Use `--provider` to test one model at a time** during development -3. **Save baseline reports** before SDK upgrades to catch regressions -4. **Set `--num-trials=3` or higher** for TAU benchmarks when evaluating reliability -5. **Use a separate user simulator model** (via `BENCHMARK_USER_MODEL`) to avoid self-play bias -6. 
**Keep API keys in `.env.test`** — the JSON report automatically redacts them - ---- - -## References +## Notes -- [SWE-bench](https://github.com/SWE-bench/SWE-bench) — Official repository + evaluation harness -- [SWE-bench Verified](https://openai.com/index/introducing-swe-bench-verified/) — Human-verified subset -- [SWE-bench Leaderboard](https://www.swebench.com/original.html) -- [τ-bench](https://github.com/sierra-research/tau-bench) — Original version -- [τ²-bench](https://github.com/sierra-research/tau2-bench) — Extended version with telecom domain -- [τ-bench Paper](https://arxiv.org/abs/2406.12045) — Methodology details -- [τ-bench Leaderboard](https://taubench.com) -- [Provider Configuration](./providers.md) — Setting up model providers +- SWE-bench is fixed to **SWE-bench-Verified**. There is no mini/full mode switch anymore. +- TB2 uses official Harbor run flow (`harbor run -d terminal-bench@2.0 -m ... -a ...`) under the selected runner. +- If Docker image pulls are slow, set `BENCHMARK_DOCKER_PROXY`. 
diff --git a/docs/zh-CN/guides/benchmark-results.md b/docs/zh-CN/guides/benchmark-results.md new file mode 100644 index 0000000..261855f --- /dev/null +++ b/docs/zh-CN/guides/benchmark-results.md @@ -0,0 +1,32 @@ +# Benchmark 结果(已确认) + +最后更新:2026-02-26 + +## SWE-bench-Verified + +| Provider / Model | 实例数 | 通过数 | 通过率 | 平均 Tokens | 平均耗时 | +|---|---:|---:|---:|---:|---:| +| openai / glm-5 | 12 | 12/12 | 100.0% | 17.2k | 134.5k ms | + +来源:本地完整运行日志(`2026-02-25__21-06-21`)。 + +## Terminal Bench 2.0 + +| Agent / Model | 通过数 | 可判定 | Unknown | 通过率(仅可判定) | 备注 | +|---|---:|---:|---:|---:|---| +| oracle / glm-5 | 1 | 31 | 58 | 3.2% | 与上面同一次完整运行;大量任务以 runtime/timeout 结束。 | + +## 复现命令 + +```bash +npm run test:benchmark -- \ + --benchmark=both \ + --tb2-model=openai/glm-5 \ + --tb2-agent=oracle \ + --tb2-runner=uvx \ + --tb2-jobs-dir=./tests/tmp/jobs \ + --output=json \ + --output-file=tests/tmp/benchmark-report.json +``` + +输出 JSON 同时包含 `swe` 和 `tb2` 两个分区。 diff --git a/docs/zh-CN/guides/benchmarking.md b/docs/zh-CN/guides/benchmarking.md index a1d459f..baa8405 100644 --- a/docs/zh-CN/guides/benchmarking.md +++ b/docs/zh-CN/guides/benchmarking.md @@ -1,470 +1,122 @@ -# 基准测试指南 +# Benchmarking -KODE SDK 内置了一套完整的基准测试套件,用于评估不同 LLM 模型在 Agent 场景下的实际表现。该套件实现了两大业界标准方法论: +KODE SDK 的 benchmark 入口已统一为一个命令,支持三种目标: -- **SWE-bench**(Princeton/OpenAI)— 代码缺陷修复:模型接收 issue 描述 + 源代码,生成修复代码,通过测试验证 -- **τ-bench**(Sierra Research)— 多轮工具调用对话:模型扮演客服 Agent,使用工具、遵循业务策略,通过数据库状态对比评估 - ---- +- `swe`:只跑 SWE-bench-Verified +- `tb2`:只跑 Terminal Bench 2.0 +- `both`:一次命令同时跑两者 ## 前置条件 -1. **Provider 配置** — 在 `.env.test` 中至少配置一个 provider 的 `API_KEY` 和 `MODEL_ID`。详见 [Provider 配置指南](./providers.md)。 - -2. **Node.js** — 需要 `ts-node`(已包含在 devDependencies 中)。 - -3. **(可选)Docker** — 仅 SWE-bench full 模式需要。Mini 模式和 TAU 基准测试不依赖 Docker。 - -### 最小 `.env.test` 配置 - -```ini -# 至少配置一个 provider -ANTHROPIC_API_KEY=sk-ant-... -ANTHROPIC_MODEL_ID=claude-sonnet-4-5-20250929 - -# 可选:配置更多 provider 进行对比 -OPENAI_API_KEY=sk-... 
-OPENAI_MODEL_ID=gpt-4o - -GEMINI_API_KEY=AIza... -GEMINI_MODEL_ID=gemini-2.5-pro -``` - ---- - -## 快速开始 +1. 安装依赖: ```bash -# 运行全部基准测试(SWE mini + TAU airline + TAU retail) -npm run test:benchmark - -# 仅运行 SWE 基准测试 -npm run test:benchmark -- --swe-only - -# 运行 SWE full 模式(需要 Docker) -npm run test:benchmark -- --swe-only --swe-mode=full - -# 仅运行 TAU 基准测试 -npm run test:benchmark -- --tau-only - -# 指定单个 provider -npm run test:benchmark -- --provider=anthropic - -# 输出 JSON 报告 -npm run test:benchmark -- --output=json --output-file=results.json +npm ci ``` -> **提示:** 每次运行基准测试时会自动生成 HTML 可视化报告,位于 `tests/tmp/benchmark-report-{timestamp}.html`。在浏览器中打开即可查看带评分、图表和逐条明细的详细报告。 - ---- - -## SWE 基准测试 - -SWE 基准测试评估模型修复代码缺陷的能力。模型接收 bug 描述和项目文件,生成修复后的代码,通过运行测试来验证正确性。 - -### Mini 模式(默认) - -Mini 模式使用 20 个内置的 JavaScript 缺陷修复用例,在本地运行,无需 Docker。每个用例包含: -- 含有 bug 的 `src.js` 文件 -- 包含断言的 `test.js` 测试文件 -- 描述预期行为的 bug 说明 +2. 准备 `.env.test`(或直接导出环境变量): ```bash -# 运行 mini-SWE 基准测试 -npm run test:benchmark -- --swe-only --swe-mode=mini -``` +ANTHROPIC_API_KEY=... +ANTHROPIC_MODEL_ID=claude-sonnet-4-20250514 -**示例输出:** +OPENAI_API_KEY=... +OPENAI_MODEL_ID=glm-5 +GEMINI_API_KEY=... +GEMINI_MODEL_ID=gemini-3-pro-preview ``` - SWE mini mode: 20 cases - - Running provider: anthropic / claude-sonnet-4-5-20250929 - [anthropic] mini-swe-001: PASS (1772 tokens, 13186ms) - [anthropic] mini-swe-002: PASS (1246 tokens, 12162ms) - ... ---- SWE-bench (mini-swe) — 20 instances --- - -Provider / Model | Resolved | Rate | Avg Tokens | Avg ms --------------------------------------+----------+---------+------------+--------- -anthropic / claude-sonnet-4-5-20250… | 20/20 | 100.0% | 1.0k | 7.4k -``` +3. 
运行依赖: +- SWE-bench-Verified:必须有 Docker +- TB2:`harbor`、`uvx` 或 Docker(由 `--tb2-runner` 决定) -**核心指标:** `Resolved Rate` — 模型修复代码后通过全部测试的用例比例。 - -### Full 模式(Docker) - -Full 模式使用真实开源仓库的 SWE-bench 实例。通过官方预构建的 SWE-bench Docker 镜像进行评估。 +## 统一命令 ```bash -# 运行 full SWE-bench(需要 Docker) -npm run test:benchmark -- --swe-only --swe-mode=full -``` - -评估流程为: -1. 在主机上克隆仓库并 checkout 到指定 commit -2. 从问题描述和提示中提取相关文件路径 -3. 读取源文件,连同 bug 描述一起发送给 LLM -4. LLM 返回 SEARCH/REPLACE 格式的代码修改块 -5. 框架应用修改并程序化生成 unified diff -6. 拉取官方 SWE-bench Docker 镜像(`swebench/sweb.eval.x86_64.:latest`) -7. 容器内已包含仓库(位于 `/testbed`)和预装所有依赖的 `testbed` conda 环境 -8. 在容器中应用 patch 并运行测试套件 - -Docker 不可用时,回退到本地 git clone + patch 应用方式(由于缺少依赖,可靠性较低)。 - -精选实例定义在 `tests/benchmark/swe/cases/curated-instances.json` 中。 - -> **注意:** SWE-bench 镜像较大(每个数 GB)。首次运行时下载镜像需要较长时间,后续运行会复用本地缓存。如需代理下载,请配置 `BENCHMARK_DOCKER_PROXY`。 - ---- - -## TAU 基准测试 - -TAU 基准测试(Tool-Agent-User)评估模型在多轮客服对话中正确使用工具并遵循业务策略的能力。 - -### 架构 - -``` -编排器 (Orchestrator) -├── Agent(被测模型)— 接收用户消息,调用工具,遵循策略 -├── User Simulator(LLM 模拟用户)— 按场景脚本扮演客户 -└── Environment(环境)— 执行工具调用,维护数据库状态 +npm run test:benchmark -- [参数] ``` -**评估方式:** 对话结束后,将最终数据库状态与预期状态对比。所有预期字段匹配则该任务通过。 - -### 可用领域 +### 常用示例 -| 领域 | 任务数 | 工具数 | 描述 | -|------|--------|--------|------| -| `airline` | 5 | 7 | 航班改签、取消、行李查询 | -| `retail` | 5 | 8 | 退货、换货、订单状态、商品搜索 | - -### 运行 TAU 基准测试 +一次命令同时跑 SWE + TB2: ```bash -# 运行全部 TAU 领域 -npm run test:benchmark -- --tau-only - -# 运行指定领域 -npm run test:benchmark -- --tau-only --tau-domain=airline -npm run test:benchmark -- --tau-only --tau-domain=retail - -# 多次试验(计算 pass^k 可靠性指标) -npm run test:benchmark -- --tau-only --num-trials=3 -``` - -**示例输出:** - +npm run test:benchmark -- \ + --benchmark=both \ + --tb2-model=openai/glm-5 \ + --output=json \ + --output-file=tests/tmp/benchmark-report.json ``` - TAU domain: airline (5 tasks, 1 trials) - Running provider: anthropic / claude-sonnet-4-5-20250929 - User simulator: anthropic / claude-sonnet-4-5-20250929 - [anthropic] 
airline_001 trial 1/1: PASS (5 turns, 22341 tokens) - [anthropic] airline_002 trial 1/1: PASS (3 turns, 15280 tokens) - ... - ---- TAU-bench (airline) — 5 tasks, 1 trials --- - -Provider / Model | Pass^1 | Avg Tokens --------------------------------------+---------+----------- -anthropic / claude-sonnet-4-5-20250… | 80.0% | 18.1k -``` - -### 理解 pass^k 指标 - -**pass^k** 衡量模型在多次独立试验中的可靠性: - -- **pass^1** = 单次试验中通过的任务比例 -- **pass^k** = 在 k 次独立试验中全部通过的任务比例 - -该指标反映一致性 — 如果模型 pass^1 = 80% 但 pass^3 = 40%,说明其表现不稳定。使用 `--num-trials=k` 来计算 pass^k。 - -### 用户模拟器 - -默认情况下,agent 和用户模拟器使用相同的模型。如需使用不同模型模拟用户: - -```ini -# 在 .env.test 中设置 -BENCHMARK_USER_MODEL=anthropic/claude-sonnet-4-5-20250929 -``` - -格式:`provider/model-id`。 - ---- - -## CLI 参数参考 - -所有参数通过 `--` 传递给 npm script: +只跑 SWE-bench-Verified: ```bash -npm run test:benchmark -- [参数] +npm run test:benchmark -- \ + --benchmark=swe \ + --provider=anthropic \ + --output=json \ + --output-file=tests/tmp/swe-report.json ``` -| 参数 | 说明 | 默认值 | -|------|------|--------| -| `--swe-only` | 仅运行 SWE 基准测试 | (全部运行) | -| `--tau-only` | 仅运行 TAU 基准测试 | (全部运行) | -| `--swe-mode=mini\|full` | SWE 评估模式 | `mini` | -| `--tau-domain=airline\|retail\|all` | TAU 评估领域 | `all` | -| `--provider=NAME` | 仅运行指定 provider | (全部已配置) | -| `--num-trials=N` | TAU 每个任务的试验次数(用于 pass^k) | `1` | -| `--output=table\|json\|html\|both` | 输出格式 | `table` | -| `--output-file=PATH` | JSON/HTML 报告输出路径 | `benchmark-report.json` | -| `--compare=PATH` | 与基线 JSON 报告对比 | (无) | - ---- - -## 环境变量 - -可在 `.env.test` 中与 provider 配置一起设置: - -| 变量 | 说明 | 默认值 | -|------|------|--------| -| `BENCHMARK_PROVIDERS` | 逗号分隔的 provider 列表 | (全部已配置) | -| `BENCHMARK_TIMEOUT_MS` | 每个任务超时时间(毫秒) | `120000` | -| `BENCHMARK_NUM_TRIALS` | TAU 默认试验次数 | `1` | -| `BENCHMARK_OUTPUT` | 输出格式 | `table` | -| `BENCHMARK_USER_MODEL` | 用户模拟器模型(`provider/model`) | (与 agent 相同) | -| `BENCHMARK_DOCKER_PROXY` | Docker 容器和 git clone 使用的 HTTP 代理 URL | (无) | - -CLI 参数优先级高于环境变量。 - ---- - -## 历史结果对比 - 
-保存基线报告,后续运行时与其对比,检测性能退化: +只跑 TB2: ```bash -# 1. 保存基线 -npm run test:benchmark -- --output=json --output-file=baseline.json - -# 2. 后续运行时,与基线对比 -npm run test:benchmark -- --compare=baseline.json -``` - -对比输出展示关键指标的变化及方向标识: - -``` -================================================================================ -Benchmark Comparison -================================================================================ - Baseline: baseline.json - Current: (current run) - ---- SWE Comparison --- - -Metric | Baseline | Current | Delta | Dir ---------------------------------------------------------------------------------- -anthropic/claude-sonnet-4-5 [rate] | 100.0% | 100.0% | = | -anthropic/claude-sonnet-4-5 [resolved] | 20/20 | 20/20 | = | -anthropic/claude-sonnet-4-5 [tokens] | 1.0k | 986 | -45 | ^ - - No regressions detected. -``` - -- `^` = 改善(更高通过率、更少 token/延迟) -- `v` = 退化(更低通过率、更多 token/延迟) -- 检测到退化时退出码为 `1` - ---- - -## JSON 报告格式 - -使用 `--output=json` 或 `--output=both` 时输出 JSON 报告: +npm run test:benchmark -- \ + --benchmark=tb2 \ + --tb2-model=openai/glm-5 \ + --tb2-agent=oracle \ + --tb2-runner=docker \ + --tb2-jobs-dir=./tests/tmp/jobs \ + --output=json \ + --output-file=tests/tmp/tb2-report.json +``` + +## 参数说明 + +| 参数 | 含义 | 默认值 | +|---|---|---| +| `--benchmark=swe\|tb2\|both` | 选择要跑的 benchmark | `both` | +| `--provider=...` | SWE provider 过滤(`anthropic`、`openai`、`gemini` 等) | 自动发现全部 | +| `--tb2-model=provider/model` | TB2 模型 ID | `BENCHMARK_TB2_MODEL` 或 `openai/$OPENAI_MODEL_ID` | +| `--tb2-agent=...` | TB2 agent(如 `oracle`) | `oracle` | +| `--tb2-dataset=...` | TB2 数据集 ID | `terminal-bench@2.0` | +| `--tb2-runner=auto\|harbor\|uvx\|docker` | TB2 运行后端 | `auto` | +| `--tb2-python=3.12` | `uvx` runner 的 Python 版本 | `3.12` | +| `--tb2-jobs-dir=PATH` | TB2 作业目录 | `tests/tmp/jobs` | +| `--tb2-env-file=PATH` | 传给 TB2 runner 的环境文件 | 自动探测 `.env.test` | +| `--tb2-docker-image=IMAGE` | TB2 docker runner 镜像 | `ghcr.io/astral-sh/uv:python3.12-bookworm` | +| 
`--output=table\|json` | 输出格式 | `table` | +| `--output-file=PATH` | JSON 输出文件(当 `--output=json`) | `benchmark-report.json` | +| `--compare=PATH` | 与历史 JSON 报告做对比 | 未设置 | + +## 输出格式 + +使用 `--output=json` 时,单个报告同时包含 SWE 和 TB2: ```json { - "timestamp": "2026-02-12T10:30:00.000Z", + "timestamp": "2026-02-25T08:31:16.000Z", "sdk_version": "2.7.3", - "swe": [{ - "provider": { "id": "anthropic", "model": "claude-sonnet-4-5-20250929", "apiKey": "***" }, - "summary": { - "dataset": "mini-swe", - "total": 20, - "resolved": 20, - "rate": 1.0, - "avg_tokens": 1031, - "avg_duration_ms": 7420 - }, - "results": [ - { "instance_id": "mini-swe-001", "resolved": true, "tokens_used": 1772, "duration_ms": 13186 } - ] - }], - "tau": [{ - "provider": { "id": "anthropic", "model": "claude-sonnet-4-5-20250929", "apiKey": "***" }, - "summary": { - "domain": "airline", - "total_tasks": 5, - "num_trials": 1, - "pass_at_k": [0.8], - "avg_tokens": 18100 - }, - "results": [ - { "task_id": "airline_001", "trial_pass_rates": [true], "tokens_used": 22341 } - ] - }] -} -``` - -API 密钥在输出中自动脱敏为 `"***"`。 - ---- - -## HTML 可视化报告 - -每次运行基准测试时会自动在 `tests/tmp/benchmark-report-{timestamp}.html` 生成一份自包含的 HTML 报告(该目录已被 `.gitignore` 忽略)。报告包含: - -- **综合评分** — 加权综合评分(0–100),以圆环进度条展示: - - SWE 通过率 × 60% + TAU Pass^1 × 40% - - 如果只运行了一种基准测试,则该项占 100% 权重 - - 按分数自动标色:绿色(≥90 优秀)、黄色(≥70 良好)、橙色(≥50 一般)、红色(<50 较差) -- **配置摘要** — SDK 版本、provider 列表、SWE 模式、TAU 领域、超时设置、试验次数 -- **SWE 结果** — 汇总表格、通过率条形图、可展开的逐 case 明细(通过/失败、token 数、耗时) -- **TAU 结果** — 带 Pass^k 列的汇总表格、通过率条形图、可展开的逐 task 试验明细 - -### 查看报告 - -```bash -# 运行基准测试(HTML 报告自动生成) -npm run test:benchmark -- --provider=anthropic - -# 使用 Python 内置 HTTP 服务器 -cd tests/tmp && python3 -m http.server 8080 -# 打开 http://localhost:8080/benchmark-report.html -``` - -报告是单文件格式,所有 CSS 均内联,无外部依赖。也可以直接通过 `file://` 协议在浏览器中打开。 - ---- - -## 项目结构 - -``` -tests/benchmark/ -├── run-benchmark.ts # 入口文件 -├── config.ts # CLI + 环境变量配置加载 -├── types.ts # 共享类型定义 -├── reporter.ts # 表格 + JSON 输出 -├── 
html-reporter.ts # HTML 可视化报告生成器 -├── compare.ts # 历史报告对比 -│ -├── swe/ # SWE-bench 模块 -│ ├── index.ts # 模块入口(mini + full 模式路由) -│ ├── dataset.ts # 用例/实例加载 -│ ├── harness.ts # 模型交互(mini 模式) -│ ├── evaluator.ts # 本地测试执行(mini 模式) -│ ├── docker-evaluator.ts # Docker/git 评估(full 模式) -│ └── cases/ -│ ├── mini-cases.json # 20 个 JavaScript 缺陷修复用例 -│ └── curated-instances.json # SWE-bench 实例定义 -│ -└── tau/ # TAU-bench 模块 - ├── index.ts # 模块入口(领域发现 + 编排) - ├── orchestrator.ts # Agent ↔ User ↔ Environment 消息循环 - ├── user-simulator.ts # 基于 LLM 的用户模拟 - ├── environment.ts # 通用 DB + 工具分发 - ├── evaluator.ts # DB 状态对比 + pass^k 计算 - └── domains/ - ├── airline/ - │ ├── policy.md # 业务规则 - │ ├── database.ts # 初始数据(用户、航班、预订) - │ ├── tools.ts # 工具定义(Anthropic API 格式) - │ ├── handlers.ts # 工具实现逻辑 - │ └── tasks.json # 5 个评估任务 - └── retail/ - ├── policy.md # 退换货/配送策略 - ├── database.ts # 初始数据(客户、商品、订单) - ├── tools.ts # 工具定义 - ├── handlers.ts # 工具实现逻辑 - └── tasks.json # 5 个评估任务 -``` - ---- - -## 添加自定义测试用例 - -### 添加 Mini-SWE 用例 - -在 `tests/benchmark/swe/cases/mini-cases.json` 中添加新条目: - -```json -{ - "id": "mini-swe-021", - "description": "清晰描述 bug 和预期行为。", - "files": { - "src.js": "// 有 bug 的源代码\nmodule.exports = { myFunc };\n", - "test.js": "const { myFunc } = require('./src');\n// 断言...\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" + "swe": [ + { + "provider": { "id": "openai", "model": "glm-5" }, + "summary": { "dataset": "swe-bench-verified", "total": 12, "resolved": 10, "rate": 0.8333, "avg_tokens": 17500, "avg_duration_ms": 166000 } + } + ], + "tb2": { + "dataset": "terminal-bench@2.0", + "agent": "oracle", + "model": "openai/glm-5", + "passed": 0, + "total": 89, + "rate": 0.0 + } } ``` -要求: -- `src.js` 必须包含有 bug 的代码(模型不应修改测试文件) -- `test.js` 成功时退出码为 0,失败时非 0 -- bug 应该是单一、明确的缺陷,有唯一的修复方案 - -### 添加 TAU 领域 - -添加新领域(例如 `telecom`): - -1. 
创建 `tests/benchmark/tau/domains/telecom/`: - - `policy.md` — Agent 必须遵循的业务规则 - - `database.ts` — 导出 `getInitialDatabase()` 并定义类型 - - `tools.ts` — 导出 Anthropic API 格式的工具定义 - - `handlers.ts` — 导出 `getTelecomHandlers()` 返回工具实现 - - `tasks.json` — 包含 `user_scenario` 和 `expected_db` 的评估任务 - -2. 更新 `tests/benchmark/tau/index.ts`: - - 添加新领域的导入 - - 在 `loadDomain()` 中添加 `case 'telecom':` - - 在 `getAvailableDomains()` 的候选列表中添加 `'telecom'` - - 在 `DOMAIN_ROLES` 中添加角色描述 - -### 添加 TAU 任务 - -在领域的 `tasks.json` 中添加条目: - -```json -{ - "task_id": "retail_006", - "user_scenario": "你是 [姓名](客户 ID:[id])。描述用户想要什么...", - "expected_db": { - "orders": [ - { "order_id": "ORD001", "status": "returned" } - ] - }, - "max_turns": 10 -} -``` - -`expected_db` 使用部分匹配 — 只检查指定的字段,记录通过主键字段(以 `_id` 结尾的字段)进行匹配。 - ---- - -## 最佳实践 - -1. **从 mini 模式开始** — 速度快、无 Docker 依赖、能快速获得反馈 -2. **开发时使用 `--provider` 逐个测试模型** -3. **SDK 升级前保存基线报告** 用于回归检测 -4. **评估可靠性时设置 `--num-trials=3` 或更高** 用于 TAU 基准测试 -5. **使用独立的用户模拟器模型**(通过 `BENCHMARK_USER_MODEL`)避免自对弈偏差 -6. **将 API 密钥放在 `.env.test` 中** — JSON 报告会自动脱敏 - ---- - -## 参考链接 +## 说明 -- [SWE-bench](https://github.com/SWE-bench/SWE-bench) — 官方仓库 + 评估 harness -- [SWE-bench Verified](https://openai.com/index/introducing-swe-bench-verified/) — 人工验证子集 -- [SWE-bench 排行榜](https://www.swebench.com/original.html) -- [τ-bench](https://github.com/sierra-research/tau-bench) — 原始版本 -- [τ²-bench](https://github.com/sierra-research/tau2-bench) — 扩展版本(含 telecom 域) -- [τ-bench 论文](https://arxiv.org/abs/2406.12045) — 方法论详述 -- [τ-bench 排行榜](https://taubench.com) -- [Provider 配置指南](./providers.md) — 模型 provider 配置 +- SWE 已固定为 **SWE-bench-Verified**,不再有 mini/full 模式参数。 +- TB2 走官方 Harbor 流程(`harbor run -d terminal-bench@2.0 -m ... 
-a ...`),由 runner 包装执行。 +- 若 Docker 拉取镜像慢,可设置 `BENCHMARK_DOCKER_PROXY`。 diff --git a/tests/benchmark/compare.ts b/tests/benchmark/compare.ts index c8dbd2e..22cbc1b 100644 --- a/tests/benchmark/compare.ts +++ b/tests/benchmark/compare.ts @@ -1,13 +1,5 @@ -// --------------------------------------------------------------------------- -// Benchmark report comparison — compare two JSON reports side-by-side -// --------------------------------------------------------------------------- - import fs from 'fs'; -import type { BenchmarkReport, SWEProviderResult, TAUProviderResult } from './types'; - -// --------------------------------------------------------------------------- -// Types -// --------------------------------------------------------------------------- +import type { BenchmarkReport, SWEProviderResult, TB2Summary } from './types'; interface ComparisonRow { label: string; @@ -19,14 +11,10 @@ interface ComparisonRow { interface ComparisonResult { swe: ComparisonRow[]; - tau: ComparisonRow[]; + tb2: ComparisonRow[]; hasRegressions: boolean; } -// --------------------------------------------------------------------------- -// Helpers -// --------------------------------------------------------------------------- - function fmtPct(n: number): string { return (n * 100).toFixed(1) + '%'; } @@ -45,32 +33,21 @@ function lpad(s: string, len: number): string { return s.length >= len ? s.slice(0, len) : ' '.repeat(len - s.length) + s; } -function deltaStr(oldVal: number, newVal: number, unit: 'pct' | 'tokens' | 'ms'): { text: string; dir: 'better' | 'worse' | 'same' } { +function deltaStr( + oldVal: number, + newVal: number, + unit: 'pct' | 'tokens', +): { text: string; dir: 'better' | 'worse' | 'same' } { const diff = newVal - oldVal; if (Math.abs(diff) < 0.001) return { text: '=', dir: 'same' }; const sign = diff > 0 ? '+' : ''; - let text: string; - - switch (unit) { - case 'pct': - text = `${sign}${(diff * 100).toFixed(1)}pp`; - return { text, dir: diff > 0 ? 
'better' : 'worse' }; - case 'tokens': - text = `${sign}${fmtK(diff)}`; - // Lower tokens = better - return { text, dir: diff < 0 ? 'better' : 'worse' }; - case 'ms': - text = `${sign}${fmtK(diff)}`; - // Lower duration = better - return { text, dir: diff < 0 ? 'better' : 'worse' }; + if (unit === 'pct') { + return { text: `${sign}${(diff * 100).toFixed(1)}pp`, dir: diff > 0 ? 'better' : 'worse' }; } + return { text: `${sign}${fmtK(diff)}`, dir: diff < 0 ? 'better' : 'worse' }; } -// --------------------------------------------------------------------------- -// Comparison logic -// --------------------------------------------------------------------------- - function compareSWE(oldResults: SWEProviderResult[], newResults: SWEProviderResult[]): ComparisonRow[] { const rows: ComparisonRow[] = []; @@ -91,111 +68,89 @@ function compareSWE(oldResults: SWEProviderResult[], newResults: SWEProviderResu continue; } - // Rate - const rateD = deltaStr(oldR.summary.rate, newR.summary.rate, 'pct'); + const rateDelta = deltaStr(oldR.summary.rate, newR.summary.rate, 'pct'); rows.push({ label: `${key} [rate]`, oldValue: fmtPct(oldR.summary.rate), newValue: fmtPct(newR.summary.rate), - delta: rateD.text, - direction: rateD.dir, + delta: rateDelta.text, + direction: rateDelta.dir, }); - // Resolved count rows.push({ label: `${key} [resolved]`, oldValue: `${oldR.summary.resolved}/${oldR.summary.total}`, newValue: `${newR.summary.resolved}/${newR.summary.total}`, - delta: newR.summary.resolved === oldR.summary.resolved ? '=' : `${newR.summary.resolved - oldR.summary.resolved > 0 ? '+' : ''}${newR.summary.resolved - oldR.summary.resolved}`, - direction: newR.summary.resolved > oldR.summary.resolved ? 'better' : newR.summary.resolved < oldR.summary.resolved ? 'worse' : 'same', + delta: newR.summary.resolved === oldR.summary.resolved + ? '=' + : `${newR.summary.resolved - oldR.summary.resolved > 0 ? 
'+' : ''}${newR.summary.resolved - oldR.summary.resolved}`, + direction: newR.summary.resolved > oldR.summary.resolved + ? 'better' + : newR.summary.resolved < oldR.summary.resolved + ? 'worse' + : 'same', }); - // Avg tokens - const tokD = deltaStr(oldR.summary.avg_tokens, newR.summary.avg_tokens, 'tokens'); + const tokenDelta = deltaStr(oldR.summary.avg_tokens, newR.summary.avg_tokens, 'tokens'); rows.push({ label: `${key} [tokens]`, oldValue: fmtK(oldR.summary.avg_tokens), newValue: fmtK(newR.summary.avg_tokens), - delta: tokD.text, - direction: tokD.dir, + delta: tokenDelta.text, + direction: tokenDelta.dir, }); } return rows; } -function compareTAU(oldResults: TAUProviderResult[], newResults: TAUProviderResult[]): ComparisonRow[] { - const rows: ComparisonRow[] = []; - - for (const newR of newResults) { - const key = `${newR.provider.id}/${newR.provider.model} [${newR.summary.domain}]`; - const oldR = oldResults.find( - r => - r.provider.id === newR.provider.id && - r.provider.model === newR.provider.model && - r.summary.domain === newR.summary.domain, - ); - - if (!oldR) { - const pass1 = newR.summary.pass_at_k[0] ?? 0; - rows.push({ - label: `${key} [pass^1]`, - oldValue: '-', - newValue: fmtPct(pass1), - delta: 'new', - direction: 'na', - }); - continue; - } - - // Pass^1 (primary metric) - const oldPass1 = oldR.summary.pass_at_k[0] ?? 0; - const newPass1 = newR.summary.pass_at_k[0] ?? 
0; - const p1D = deltaStr(oldPass1, newPass1, 'pct'); - rows.push({ - label: `${key} [pass^1]`, - oldValue: fmtPct(oldPass1), - newValue: fmtPct(newPass1), - delta: p1D.text, - direction: p1D.dir, - }); - - // Avg tokens - const tokD = deltaStr(oldR.summary.avg_tokens, newR.summary.avg_tokens, 'tokens'); - rows.push({ - label: `${key} [tokens]`, - oldValue: fmtK(oldR.summary.avg_tokens), - newValue: fmtK(newR.summary.avg_tokens), - delta: tokD.text, - direction: tokD.dir, - }); +function compareTB2(oldTB2?: TB2Summary, newTB2?: TB2Summary): ComparisonRow[] { + if (!newTB2) return []; + if (!oldTB2) { + return [{ + label: 'tb2 [rate]', + oldValue: '-', + newValue: fmtPct(newTB2.rate), + delta: 'new', + direction: 'na', + }]; } + const rows: ComparisonRow[] = []; + const rateDelta = deltaStr(oldTB2.rate, newTB2.rate, 'pct'); + rows.push({ + label: 'tb2 [rate]', + oldValue: fmtPct(oldTB2.rate), + newValue: fmtPct(newTB2.rate), + delta: rateDelta.text, + direction: rateDelta.dir, + }); + + rows.push({ + label: 'tb2 [passed]', + oldValue: `${oldTB2.passed}/${oldTB2.total}`, + newValue: `${newTB2.passed}/${newTB2.total}`, + delta: newTB2.passed === oldTB2.passed + ? '=' + : `${newTB2.passed - oldTB2.passed > 0 ? '+' : ''}${newTB2.passed - oldTB2.passed}`, + direction: newTB2.passed > oldTB2.passed ? 'better' : newTB2.passed < oldTB2.passed ? 'worse' : 'same', + }); + return rows; } -// --------------------------------------------------------------------------- -// Public API -// --------------------------------------------------------------------------- - export function loadReport(filePath: string): BenchmarkReport { - const raw = fs.readFileSync(filePath, 'utf-8'); - return JSON.parse(raw) as BenchmarkReport; + return JSON.parse(fs.readFileSync(filePath, 'utf-8')) as BenchmarkReport; } export function compareReports(oldReport: BenchmarkReport, newReport: BenchmarkReport): ComparisonResult { const sweRows = compareSWE(oldReport.swe ?? [], newReport.swe ?? 
[]); - const tauRows = compareTAU(oldReport.tau ?? [], newReport.tau ?? []); - const hasRegressions = [...sweRows, ...tauRows].some(r => r.direction === 'worse'); - - return { swe: sweRows, tau: tauRows, hasRegressions }; + const tb2Rows = compareTB2(oldReport.tb2, newReport.tb2); + const hasRegressions = [...sweRows, ...tb2Rows].some(r => r.direction === 'worse'); + return { swe: sweRows, tb2: tb2Rows, hasRegressions }; } -export function printComparison( - oldPath: string, - newPath: string, - result: ComparisonResult, -): void { +export function printComparison(oldPath: string, newPath: string, result: ComparisonResult): void { const banner = '='.repeat(80); console.log(`\n${banner}`); console.log('Benchmark Comparison'); @@ -204,16 +159,13 @@ export function printComparison( console.log(` Current: ${newPath}`); console.log(''); - const allRows = [...result.swe, ...result.tau]; - + const allRows = [...result.swe, ...result.tb2]; if (allRows.length === 0) { - console.log(' No comparable results found.'); - console.log(''); + console.log(' No comparable results found.\n'); return; } - // Print table - const maxLabel = Math.max(30, ...allRows.map(r => r.label.length)); + const maxLabel = Math.max(20, ...allRows.map(r => r.label.length)); const header = `${pad('Metric', maxLabel)} | ${lpad('Baseline', 10)} | ${lpad('Current', 10)} | ${lpad('Delta', 12)} | Dir`; const sep = '-'.repeat(header.length); @@ -230,11 +182,11 @@ export function printComparison( console.log(''); } - if (result.tau.length > 0) { - console.log('--- TAU Comparison ---\n'); + if (result.tb2.length > 0) { + console.log('--- TB2 Comparison ---\n'); console.log(header); console.log(sep); - for (const row of result.tau) { + for (const row of result.tb2) { const dir = row.direction === 'better' ? ' ^' : row.direction === 'worse' ? 
' v' : ' '; console.log( `${pad(row.label, maxLabel)} | ${lpad(row.oldValue, 10)} | ${lpad(row.newValue, 10)} | ${lpad(row.delta, 12)} |${dir}`, @@ -243,10 +195,6 @@ export function printComparison( console.log(''); } - if (result.hasRegressions) { - console.log(' WARNING: Regressions detected (marked with v)'); - } else { - console.log(' No regressions detected.'); - } + console.log(result.hasRegressions ? ' WARNING: Regressions detected (marked with v)' : ' No regressions detected.'); console.log(''); } diff --git a/tests/benchmark/config.ts b/tests/benchmark/config.ts index 034b763..8ad9895 100644 --- a/tests/benchmark/config.ts +++ b/tests/benchmark/config.ts @@ -9,29 +9,37 @@ const ALL_PROVIDERS: ProviderId[] = ['anthropic', 'openai', 'gemini', 'glm', 'mi // --------------------------------------------------------------------------- export function parseCliArgs(argv: string[] = process.argv.slice(2)): BenchmarkCliArgs { - const args: BenchmarkCliArgs = { - sweOnly: false, - tauOnly: false, - }; + const args: BenchmarkCliArgs = {}; for (const arg of argv) { - if (arg === '--swe-only') { - args.sweOnly = true; - } else if (arg === '--tau-only') { - args.tauOnly = true; - } else if (arg.startsWith('--swe-mode=')) { - const val = arg.slice('--swe-mode='.length); - if (val === 'mini' || val === 'full') args.sweMode = val; - } else if (arg.startsWith('--tau-domain=')) { - args.tauDomain = arg.slice('--tau-domain='.length); + if (arg.startsWith('--benchmark=')) { + const val = arg.slice('--benchmark='.length); + if (val === 'swe' || val === 'tb2' || val === 'both') args.benchmark = val; } else if (arg.startsWith('--provider=')) { args.provider = arg.slice('--provider='.length); - } else if (arg.startsWith('--num-trials=')) { - const n = parseInt(arg.slice('--num-trials='.length), 10); - if (!isNaN(n) && n > 0) args.numTrials = n; + } else if (arg.startsWith('--tb2-model=')) { + args.tb2Model = arg.slice('--tb2-model='.length); + } else if (arg.startsWith('--model=')) 
{ + // Backward-compatible alias for TB2 model. + args.tb2Model = arg.slice('--model='.length); + } else if (arg.startsWith('--tb2-agent=')) { + args.tb2Agent = arg.slice('--tb2-agent='.length); + } else if (arg.startsWith('--tb2-dataset=')) { + args.tb2Dataset = arg.slice('--tb2-dataset='.length); + } else if (arg.startsWith('--tb2-runner=')) { + const val = arg.slice('--tb2-runner='.length); + if (val === 'auto' || val === 'harbor' || val === 'uvx' || val === 'docker') args.tb2Runner = val; + } else if (arg.startsWith('--tb2-python=')) { + args.tb2Python = arg.slice('--tb2-python='.length); + } else if (arg.startsWith('--tb2-jobs-dir=')) { + args.tb2JobsDir = arg.slice('--tb2-jobs-dir='.length); + } else if (arg.startsWith('--tb2-env-file=')) { + args.tb2EnvFile = arg.slice('--tb2-env-file='.length); + } else if (arg.startsWith('--tb2-docker-image=')) { + args.tb2DockerImage = arg.slice('--tb2-docker-image='.length); } else if (arg.startsWith('--output=')) { const val = arg.slice('--output='.length); - if (val === 'table' || val === 'json' || val === 'html' || val === 'both') args.output = val; + if (val === 'table' || val === 'json') args.output = val; } else if (arg.startsWith('--output-file=')) { args.outputFile = arg.slice('--output-file='.length); } else if (arg.startsWith('--compare=')) { @@ -71,30 +79,6 @@ function discoverProviders(filterProvider?: string): BenchmarkProvider[] { return providers; } -function findUserSimProvider(): BenchmarkProvider | undefined { - const userModel = process.env.BENCHMARK_USER_MODEL; - if (!userModel) return undefined; - - // Format: provider/model e.g. 
"anthropic/claude-opus-4-5-20251101" - const slashIdx = userModel.indexOf('/'); - if (slashIdx === -1) return undefined; - - const providerId = userModel.slice(0, slashIdx) as ProviderId; - const model = userModel.slice(slashIdx + 1); - - const result = loadProviderEnv(providerId); - if (!result.ok || !result.config) return undefined; - if (!result.config.apiKey) return undefined; - - return { - id: providerId, - model, - apiKey: result.config.apiKey, - baseUrl: result.config.baseUrl, - proxyUrl: result.config.proxyUrl, - }; -} - function readSdkVersion(): string { try { const pkg = require('../../package.json'); @@ -106,27 +90,37 @@ function readSdkVersion(): string { export function loadConfig(cliArgs: BenchmarkCliArgs): BenchmarkConfig { const envTimeout = process.env.BENCHMARK_TIMEOUT_MS; - const envTrials = process.env.BENCHMARK_NUM_TRIALS; const envOutput = process.env.BENCHMARK_OUTPUT; + const envTb2Model = process.env.BENCHMARK_TB2_MODEL + || (process.env.OPENAI_MODEL_ID ? `openai/${process.env.OPENAI_MODEL_ID}` : undefined); const timeoutMs = envTimeout ? parseInt(envTimeout, 10) : 120_000; - const numTrials = cliArgs.numTrials - ?? (envTrials ? parseInt(envTrials, 10) : 1); const output = cliArgs.output - ?? (envOutput === 'json' || envOutput === 'both' || envOutput === 'table' || envOutput === 'html' ? envOutput : 'table'); + ?? (envOutput === 'json' || envOutput === 'table' ? envOutput : 'table'); const outputFile = cliArgs.outputFile ?? 'benchmark-report.json'; - const sweMode = cliArgs.sweMode ?? 'mini'; - const tauDomain = cliArgs.tauDomain ?? 'all'; + const benchmark = cliArgs.benchmark ?? 'both'; + const tb2Agent = cliArgs.tb2Agent ?? 'oracle'; + const tb2Dataset = cliArgs.tb2Dataset ?? 'terminal-bench@2.0'; + const tb2Runner = cliArgs.tb2Runner ?? 'auto'; + const tb2Python = cliArgs.tb2Python ?? '3.12'; + const tb2JobsDir = cliArgs.tb2JobsDir ?? 
'tests/tmp/jobs'; + const tb2EnvFile = cliArgs.tb2EnvFile; + const tb2DockerImage = cliArgs.tb2DockerImage ?? 'ghcr.io/astral-sh/uv:python3.12-bookworm'; return { + benchmark, providers: discoverProviders(cliArgs.provider), - userSimProvider: findUserSimProvider(), timeoutMs, - numTrials, output, outputFile, - sweMode, - tauDomain, + tb2Model: cliArgs.tb2Model ?? envTb2Model, + tb2Agent, + tb2Dataset, + tb2Runner, + tb2Python, + tb2JobsDir, + tb2EnvFile, + tb2DockerImage, sdkVersion: readSdkVersion(), dockerProxy: process.env.BENCHMARK_DOCKER_PROXY || undefined, }; diff --git a/tests/benchmark/html-reporter.ts b/tests/benchmark/html-reporter.ts deleted file mode 100644 index 05102c2..0000000 --- a/tests/benchmark/html-reporter.ts +++ /dev/null @@ -1,360 +0,0 @@ -import fs from 'fs'; -import path from 'path'; -import type { BenchmarkConfig, BenchmarkReport, SWEProviderResult, TAUProviderResult } from './types'; -import { redactReport } from './reporter'; - -// --------------------------------------------------------------------------- -// Public API -// --------------------------------------------------------------------------- - -export function writeHtmlReport( - report: BenchmarkReport, - config: BenchmarkConfig, - filePath: string, -): void { - const safe = redactReport(report); - const html = buildHtml(safe, config); - fs.mkdirSync(path.dirname(filePath), { recursive: true }); - fs.writeFileSync(filePath, html, 'utf-8'); - console.log(` HTML report written to: ${filePath}`); -} - -// --------------------------------------------------------------------------- -// Score calculation -// --------------------------------------------------------------------------- - -function computeOverallScore(report: BenchmarkReport): number | null { - const scores: { value: number; weight: number }[] = []; - - if (report.swe && report.swe.length > 0) { - // Average SWE rate across all providers - const avgRate = report.swe.reduce((s, r) => s + r.summary.rate, 0) / 
report.swe.length; - scores.push({ value: avgRate * 100, weight: 60 }); - } - - if (report.tau && report.tau.length > 0) { - // Average TAU pass^1 across all providers - const avgPass = report.tau.reduce((s, r) => { - const p1 = r.summary.pass_at_k[0] ?? 0; - return s + p1; - }, 0) / report.tau.length; - scores.push({ value: avgPass * 100, weight: 40 }); - } - - if (scores.length === 0) return null; - - // If only one type ran, it gets 100% weight - const totalWeight = scores.reduce((s, x) => s + x.weight, 0); - return scores.reduce((s, x) => s + (x.value * x.weight) / totalWeight, 0); -} - -function scoreColor(score: number): string { - if (score >= 90) return '#22c55e'; - if (score >= 70) return '#eab308'; - if (score >= 50) return '#f97316'; - return '#ef4444'; -} - -function scoreLabel(score: number): string { - if (score >= 90) return 'Excellent'; - if (score >= 70) return 'Good'; - if (score >= 50) return 'Fair'; - return 'Poor'; -} - -// --------------------------------------------------------------------------- -// HTML builder -// --------------------------------------------------------------------------- - -function esc(s: string): string { - return s.replace(/&/g, '&').replace(//g, '>').replace(/"/g, '"'); -} - -function fmtK(n: number): string { - if (n >= 1_000_000) return (n / 1_000_000).toFixed(1) + 'M'; - if (n >= 1_000) return (n / 1_000).toFixed(1) + 'k'; - return String(n); -} - -function buildHtml(report: BenchmarkReport, config: BenchmarkConfig): string { - const score = computeOverallScore(report); - return ` - - - - -Benchmark Report — KODE SDK ${esc(report.sdk_version)} -${buildStyle()} - - -
-
-

KODE SDK Benchmark Report

-

Generated ${esc(report.timestamp)}

-
- - ${buildScoreSection(score)} - ${buildSummaryCard(report, config)} - ${report.swe && report.swe.length > 0 ? buildSWESection(report.swe) : ''} - ${report.tau && report.tau.length > 0 ? buildTAUSection(report.tau) : ''} - -
-

KODE SDK v${esc(report.sdk_version)} · Benchmark Suite

-
-
- -`; -} - -// --------------------------------------------------------------------------- -// Sections -// --------------------------------------------------------------------------- - -function buildScoreSection(score: number | null): string { - if (score === null) { - return `
-
- N/A -
-

No benchmark data

-
`; - } - const rounded = Math.round(score * 10) / 10; - const color = scoreColor(rounded); - const label = scoreLabel(rounded); - const pct = Math.min(rounded, 100); - return `
-
- - - - - ${rounded.toFixed(1)} -
-

${label}

-

Weighted: SWE 60% + TAU 40%

-
`; -} - -function buildSummaryCard(report: BenchmarkReport, config: BenchmarkConfig): string { - const providers = config.providers.map(p => `${esc(p.id)} / ${esc(p.model)}`).join(' '); - return `
-

Configuration

-
-
SDK Version${esc(report.sdk_version)}
-
SWE Mode${esc(config.sweMode)}
-
TAU Domain${esc(config.tauDomain)}
-
Timeout${config.timeoutMs}ms
-
Num Trials${config.numTrials}
-
-
Providers: ${providers}
-
`; -} - -function buildSWESection(results: SWEProviderResult[]): string { - let html = `
-

SWE-bench Results

`; - - // Summary table - html += ` - - - `; - - for (const r of results) { - const rate = (r.summary.rate * 100).toFixed(1); - const color = scoreColor(r.summary.rate * 100); - html += ` - - - - - - - `; - } - html += `
Provider / ModelDatasetResolvedRateAvg TokensAvg Duration
${esc(r.provider.id)} / ${esc(r.provider.model)}${esc(r.summary.dataset)}${r.summary.resolved}/${r.summary.total}${rate}%${fmtK(r.summary.avg_tokens)}${fmtK(r.summary.avg_duration_ms)}ms
`; - - // Bar chart - html += `
Resolved Rate by Provider
`; - for (const r of results) { - const pct = (r.summary.rate * 100).toFixed(1); - const color = scoreColor(r.summary.rate * 100); - const label = `${r.provider.id} / ${r.provider.model}`; - html += `
- ${esc(label)} -
- ${pct}% -
`; - } - html += `
`; - - // Per-case details - for (const r of results) { - html += `
- ${esc(r.provider.id)} / ${esc(r.provider.model)} — Case Details (${r.results.length} cases) - - - `; - for (const c of r.results) { - const status = c.resolved - ? 'PASS' - : 'FAIL'; - html += ` - - - - `; - } - html += `
Case IDStatusTokensDurationError
${esc(c.instance_id)}${status}${fmtK(c.tokens_used)}${fmtK(c.duration_ms)}ms${c.error ? esc(c.error) : '-'}
`; - } - - html += `
`; - return html; -} - -function buildTAUSection(results: TAUProviderResult[]): string { - let html = `
-

TAU-bench Results

`; - - // Determine max k from results - const maxK = results.reduce((m, r) => Math.max(m, r.summary.pass_at_k.length), 0); - - // Summary table - html += ` - `; - for (let k = 1; k <= maxK; k++) { - html += ``; - } - html += ``; - - for (const r of results) { - html += ` - - `; - for (let k = 0; k < maxK; k++) { - const val = r.summary.pass_at_k[k]; - if (val !== undefined) { - const pct = (val * 100).toFixed(1); - const color = scoreColor(val * 100); - html += ``; - } else { - html += ``; - } - } - html += ``; - } - html += `
Provider / ModelDomainPass^${k}Avg Tokens
${esc(r.provider.id)} / ${esc(r.provider.model)}${esc(r.summary.domain)}${pct}%-${fmtK(r.summary.avg_tokens)}
`; - - // Bar chart (pass^1) - html += `
Pass^1 Rate by Provider
`; - for (const r of results) { - const p1 = r.summary.pass_at_k[0] ?? 0; - const pct = (p1 * 100).toFixed(1); - const color = scoreColor(p1 * 100); - const label = `${r.provider.id} / ${r.provider.model} (${r.summary.domain})`; - html += `
- ${esc(label)} -
- ${pct}% -
`; - } - html += `
`; - - // Per-task details - for (const r of results) { - html += `
- ${esc(r.provider.id)} / ${esc(r.provider.model)} (${esc(r.summary.domain)}) — Task Details (${r.results.length} tasks) - - - `; - for (const t of r.results) { - const trials = t.trial_pass_rates - .map(p => p ? 'PASS' : 'FAIL') - .join(' '); - html += ` - - - - `; - } - html += `
Task IDTrialsTokensError
${esc(t.task_id)}${trials}${fmtK(t.tokens_used)}${t.error ? esc(t.error) : '-'}
`; - } - - html += `
`; - return html; -} - -// --------------------------------------------------------------------------- -// Styles -// --------------------------------------------------------------------------- - -function buildStyle(): string { - return ``; -} diff --git a/tests/benchmark/reporter.ts b/tests/benchmark/reporter.ts index b7a7c0b..36c0350 100644 --- a/tests/benchmark/reporter.ts +++ b/tests/benchmark/reporter.ts @@ -1,14 +1,6 @@ import fs from 'fs'; -import type { - BenchmarkConfig, - BenchmarkReport, - SWEProviderResult, - TAUProviderResult, -} from './types'; - -// --------------------------------------------------------------------------- -// Internal helpers -// --------------------------------------------------------------------------- +import path from 'path'; +import type { BenchmarkConfig, BenchmarkReport, SWEProviderResult, TB2Summary } from './types'; function pad(s: string, len: number): string { return s.length >= len ? s.slice(0, len) : s + ' '.repeat(len - s.length); @@ -28,6 +20,10 @@ function fmtK(n: number): string { return String(n); } +function fmtPct(n: number): string { + return (n * 100).toFixed(1) + '%'; +} + interface Column { header: string; width: number; @@ -40,10 +36,7 @@ function buildTable(columns: Column[], rows: string[][]): string { .map(c => (c.align === 'right' ? lpad(c.header, c.width) : pad(c.header, c.width))) .join(' | '); - const lines: string[] = []; - lines.push(headerLine); - lines.push(sep); - + const lines: string[] = [headerLine, sep]; for (const row of rows) { const cells = columns.map((c, i) => { const val = row[i] ?? 
''; @@ -51,38 +44,37 @@ function buildTable(columns: Column[], rows: string[][]): string { }); lines.push(cells.join(' | ')); } - return lines.join('\n'); } -// --------------------------------------------------------------------------- -// Public API -// --------------------------------------------------------------------------- - export function printProviderSummary(config: BenchmarkConfig): void { const banner = '='.repeat(80); console.log(`\n${banner}`); console.log('KODE SDK Benchmark Runner'); console.log(banner); console.log(` SDK version: ${config.sdkVersion}`); + console.log(` Benchmark: ${config.benchmark}`); console.log(` Timeout: ${config.timeoutMs}ms`); - console.log(` Num trials: ${config.numTrials}`); console.log(` Output: ${config.output}`); - console.log(` SWE mode: ${config.sweMode}`); - console.log(` TAU domain: ${config.tauDomain}`); console.log(''); - if (config.providers.length === 0) { - console.log(' Providers: (none discovered)'); - } else { - console.log(' Providers:'); - for (const p of config.providers) { - console.log(` - ${p.id} / ${p.model}`); + if (config.benchmark === 'swe' || config.benchmark === 'both') { + if (config.providers.length === 0) { + console.log(' SWE providers: (none discovered)'); + } else { + console.log(' SWE providers:'); + for (const p of config.providers) { + console.log(` - ${p.id} / ${p.model}`); + } } } - if (config.userSimProvider) { - console.log(` User sim: ${config.userSimProvider.id} / ${config.userSimProvider.model}`); + if (config.benchmark === 'tb2' || config.benchmark === 'both') { + console.log(` TB2 dataset: ${config.tb2Dataset}`); + console.log(` TB2 agent: ${config.tb2Agent}`); + if (config.tb2Model) console.log(` TB2 model: ${config.tb2Model}`); + console.log(` TB2 runner: ${config.tb2Runner}`); + console.log(` TB2 jobs dir: ${config.tb2JobsDir}`); } if (config.dockerProxy) { @@ -92,11 +84,7 @@ export function printProviderSummary(config: BenchmarkConfig): void { console.log(''); } -export 
function printSWETable( - dataset: string, - instanceCount: number, - results: SWEProviderResult[], -): void { +export function printSWETable(dataset: string, instanceCount: number, results: SWEProviderResult[]): void { console.log(`\n--- SWE-bench (${dataset}) — ${instanceCount} instances ---\n`); const columns: Column[] = [ @@ -110,7 +98,7 @@ export function printSWETable( const rows = results.map(r => [ trunc(`${r.provider.id} / ${r.provider.model}`, 36), `${r.summary.resolved}/${r.summary.total}`, - (r.summary.rate * 100).toFixed(1) + '%', + fmtPct(r.summary.rate), fmtK(r.summary.avg_tokens), fmtK(r.summary.avg_duration_ms), ]); @@ -119,37 +107,12 @@ export function printSWETable( console.log(''); } -export function printTAUTable( - domain: string, - taskCount: number, - numTrials: number, - results: TAUProviderResult[], -): void { - console.log(`\n--- TAU-bench (${domain}) — ${taskCount} tasks, ${numTrials} trials ---\n`); - - const passColumns: Column[] = []; - for (let k = 1; k <= numTrials; k++) { - passColumns.push({ header: `Pass^${k}`, width: 7, align: 'right' }); - } - - const columns: Column[] = [ - { header: 'Provider / Model', width: 36, align: 'left' }, - ...passColumns, - { header: 'Avg Tokens', width: 10, align: 'right' }, - ]; - - const rows = results.map(r => { - const passValues = r.summary.pass_at_k.map(v => (v * 100).toFixed(1) + '%'); - // Pad if fewer values than numTrials - while (passValues.length < numTrials) passValues.push('-'); - return [ - trunc(`${r.provider.id} / ${r.provider.model}`, 36), - ...passValues, - fmtK(r.summary.avg_tokens), - ]; - }); - - console.log(buildTable(columns, rows)); +export function printTB2Summary(summary: TB2Summary): void { + console.log('\n=== Terminal Bench 2.0 Score ==='); + console.log(`Job path: ${summary.job_path}`); + console.log(`Passed: ${summary.passed}/${summary.total}`); + console.log(`Rate: ${fmtPct(summary.rate)}`); + console.log(`Unknown: ${summary.unknown}`); console.log(''); } @@ -163,13 
+126,7 @@ export function redactReport(report: BenchmarkReport): BenchmarkReport { export function writeJsonReport(report: BenchmarkReport, filePath: string): void { const redacted = redactReport(report); const json = JSON.stringify(redacted, null, 2); + fs.mkdirSync(path.dirname(filePath), { recursive: true }); fs.writeFileSync(filePath, json, 'utf-8'); console.log(` JSON report written to: ${filePath}`); } - -export function printNoBenchmarks(): void { - console.log(' No benchmark modules configured yet.'); - console.log(' SWE and TAU modules will be added in Phase 2 and Phase 3.'); - console.log(' Framework scaffolding verified successfully.'); - console.log(''); -} diff --git a/tests/benchmark/run-benchmark.ts b/tests/benchmark/run-benchmark.ts index 015509c..6a66edf 100644 --- a/tests/benchmark/run-benchmark.ts +++ b/tests/benchmark/run-benchmark.ts @@ -1,58 +1,15 @@ /** - * Benchmark runner entry point + * Unified benchmark runner entry point. + * Supports SWE-bench-Verified, Terminal Bench 2.0, or both. 
*/ import '../helpers/env-setup'; import { parseCliArgs, loadConfig } from './config'; -import { - printProviderSummary, - printSWETable, - printTAUTable, - writeJsonReport, - printNoBenchmarks, -} from './reporter'; -import { writeHtmlReport } from './html-reporter'; +import { printProviderSummary, printSWETable, printTB2Summary, writeJsonReport } from './reporter'; import { loadReport, compareReports, printComparison } from './compare'; -import type { BenchmarkCliArgs, BenchmarkConfig, BenchmarkModule, BenchmarkModuleResult, BenchmarkReport } from './types'; - -// --------------------------------------------------------------------------- -// Module discovery -// --------------------------------------------------------------------------- - -async function tryLoadModule(path: string): Promise { - try { - const mod = await import(path); - if (mod && typeof mod.run === 'function' && typeof mod.name === 'string') { - return mod as BenchmarkModule; - } - if (mod && mod.default && typeof mod.default.run === 'function') { - return mod.default as BenchmarkModule; - } - return null; - } catch { - return null; - } -} - -async function discoverModules(cliArgs: BenchmarkCliArgs): Promise { - const modules: BenchmarkModule[] = []; - - if (!cliArgs.tauOnly) { - const swe = await tryLoadModule('./swe/index'); - if (swe) modules.push(swe); - } - - if (!cliArgs.sweOnly) { - const tau = await tryLoadModule('./tau/index'); - if (tau) modules.push(tau); - } - - return modules; -} - -// --------------------------------------------------------------------------- -// Main -// --------------------------------------------------------------------------- +import type { BenchmarkReport } from './types'; +import { run as runSWE } from './swe'; +import { runTB2Official } from './run-tb2-official'; async function main(): Promise { const cliArgs = parseCliArgs(); @@ -60,50 +17,48 @@ async function main(): Promise { printProviderSummary(config); - const modules = await discoverModules(cliArgs); 
- - if (modules.length === 0) { - printNoBenchmarks(); - return; - } - const report: BenchmarkReport = { timestamp: new Date().toISOString(), sdk_version: config.sdkVersion, }; - for (const mod of modules) { - console.log(` Running module: ${mod.name} ...`); - const result: BenchmarkModuleResult = await mod.run(config); - - if (result.swe) { - report.swe = result.swe; - for (const r of result.swe) { + if (config.benchmark === 'swe' || config.benchmark === 'both') { + console.log(' Running module: swe ...'); + const sweResult = await runSWE(config); + if (sweResult.swe) { + report.swe = sweResult.swe; + for (const r of sweResult.swe) { printSWETable(r.summary.dataset, r.summary.total, [r]); } } + } - if (result.tau) { - report.tau = result.tau; - for (const r of result.tau) { - printTAUTable(r.summary.domain, r.summary.total_tasks, r.summary.num_trials, [r]); - } - } + if (config.benchmark === 'tb2' || config.benchmark === 'both') { + console.log(' Running module: tb2 ...'); + const tb2 = runTB2Official({ + dataset: config.tb2Dataset, + model: config.tb2Model, + agent: config.tb2Agent, + jobsDir: config.tb2JobsDir, + runner: config.tb2Runner, + dockerImage: config.tb2DockerImage, + python: config.tb2Python, + envFile: config.tb2EnvFile, + }); + report.tb2 = tb2; + printTB2Summary(tb2); } - if (config.output === 'json' || config.output === 'both') { - writeJsonReport(report, config.outputFile); + if (!report.swe && !report.tb2) { + console.error(' No benchmark results produced. Check prerequisites and benchmark settings.'); + process.exitCode = 1; + return; } - // Always generate HTML report (with timestamp to avoid overwriting) - const htmlDir = require('path').resolve(__dirname, '..', 'tmp'); - const ts = new Date().toISOString().replace(/[:.]/g, '-').slice(0, 19); - const htmlPath = cliArgs.outputFile && cliArgs.outputFile.endsWith('.html') - ? 
cliArgs.outputFile - : require('path').join(htmlDir, `benchmark-report-${ts}.html`); - writeHtmlReport(report, config, htmlPath); + if (config.output === 'json') { + writeJsonReport(report, config.outputFile); + } - // Historical comparison if (cliArgs.compare) { try { const baselineReport = loadReport(cliArgs.compare); diff --git a/tests/benchmark/run-tb2-official.ts b/tests/benchmark/run-tb2-official.ts new file mode 100644 index 0000000..3b02bbb --- /dev/null +++ b/tests/benchmark/run-tb2-official.ts @@ -0,0 +1,487 @@ +/** + * Run Terminal Bench 2.0 using the official Harbor harness, then print score. + * + * Primary command (from official docs style): + * harbor run -d terminal-bench@2.0 -m -a + * + * This wrapper: + * - invokes Harbor + * - locates the latest job directory under ./jobs + * - computes pass rate from trial result.json / verifier reward + */ + +import fs from 'fs'; +import path from 'path'; +import { spawnSync } from 'child_process'; +import type { TB2Summary } from './types'; + +interface CliArgs { + dataset: string; + model?: string; + agent: string; + jobsDir: string; + runner: 'auto' | 'harbor' | 'uvx' | 'docker'; + dockerImage: string; + python: string; + envFile?: string; + outputFile?: string; +} + +function parseCliArgs(argv: string[] = process.argv.slice(2)): CliArgs { + const args: CliArgs = { + dataset: 'terminal-bench@2.0', + agent: 'oracle', + jobsDir: path.resolve(process.cwd(), 'tests/tmp/jobs'), + runner: 'auto', + dockerImage: 'ghcr.io/astral-sh/uv:python3.12-bookworm', + python: '3.12', + }; + + for (const arg of argv) { + if (arg.startsWith('--dataset=')) { + args.dataset = arg.slice('--dataset='.length); + } else if (arg.startsWith('--model=')) { + args.model = arg.slice('--model='.length); + } else if (arg.startsWith('--agent=')) { + args.agent = arg.slice('--agent='.length); + } else if (arg.startsWith('--jobs-dir=')) { + args.jobsDir = path.resolve(arg.slice('--jobs-dir='.length)); + } else if (arg.startsWith('--runner=')) 
{ + const v = arg.slice('--runner='.length); + if (v === 'auto' || v === 'harbor' || v === 'uvx' || v === 'docker') args.runner = v; + } else if (arg.startsWith('--docker-image=')) { + args.dockerImage = arg.slice('--docker-image='.length); + } else if (arg.startsWith('--python=')) { + args.python = arg.slice('--python='.length); + } else if (arg.startsWith('--env-file=')) { + args.envFile = path.resolve(arg.slice('--env-file='.length)); + } else if (arg.startsWith('--output-file=')) { + args.outputFile = arg.slice('--output-file='.length); + } + } + + const defaultEnvFile = path.resolve(process.cwd(), '.env.test'); + if (!args.envFile && fs.existsSync(defaultEnvFile)) { + args.envFile = defaultEnvFile; + } + + return args; +} + +function hasCommand(cmd: string, versionArg = '--version'): boolean { + const r = spawnSync(cmd, [versionArg], { stdio: 'ignore' }); + return r.status === 0; +} + +function readEnvFileValue(envFile: string, key: string): string | undefined { + if (!fs.existsSync(envFile)) return undefined; + try { + const lines = fs.readFileSync(envFile, 'utf-8').split('\n'); + for (const raw of lines) { + const line = raw.trim(); + if (!line || line.startsWith('#')) continue; + const idx = line.indexOf('='); + if (idx <= 0) continue; + const k = line.slice(0, idx).trim(); + if (k !== key) continue; + let v = line.slice(idx + 1).trim(); + if ((v.startsWith('"') && v.endsWith('"')) || (v.startsWith("'") && v.endsWith("'"))) { + v = v.slice(1, -1); + } + return v; + } + } catch { + // ignore parse failure and fallback to process.env + } + return undefined; +} + +function proxyLooksLocalhost(proxyUrl?: string): boolean { + if (!proxyUrl) return false; + try { + const u = new URL(proxyUrl); + return u.hostname === '127.0.0.1' || u.hostname === 'localhost'; + } catch { + return proxyUrl.includes('127.0.0.1') || proxyUrl.includes('localhost'); + } +} + +interface RunnerSpec { + cmd: string; + baseArgs: string[]; + label: string; + env?: NodeJS.ProcessEnv; +} + 
+function resolveProxy(args: CliArgs): string | undefined { + return process.env.BENCHMARK_DOCKER_PROXY + || (args.envFile ? readEnvFileValue(args.envFile, 'BENCHMARK_DOCKER_PROXY') : undefined); +} + +function buildDockerRunner(args: CliArgs, cwdForRun: string): RunnerSpec { + if (!hasCommand('docker')) { + throw new Error('docker not found, cannot use --runner=docker'); + } + + const cacheHostDir = path.resolve(path.dirname(args.jobsDir), '.tb2-uv-cache'); + fs.mkdirSync(cacheHostDir, { recursive: true }); + + const baseArgs = [ + 'run', + '--rm', + '-v', + '/var/run/docker.sock:/var/run/docker.sock', + '-v', + `${cwdForRun}:${cwdForRun}`, + '-v', + `${cacheHostDir}:/tmp/uv-cache`, + '-w', + cwdForRun, + '-e', + 'UV_CACHE_DIR=/tmp/uv-cache', + ]; + + if (args.envFile && fs.existsSync(args.envFile)) { + baseArgs.push('--env-file', args.envFile); + } + + // Reuse BENCHMARK_DOCKER_PROXY as fallback proxy for Harbor/uvx downloads. + const fallbackProxy = resolveProxy(args); + const isLinux = process.platform === 'linux'; + let usedHostNetwork = false; + if (isLinux && proxyLooksLocalhost(fallbackProxy)) { + // On Linux, localhost proxy on host is only reachable from container via host network. + baseArgs.push('--network', 'host'); + usedHostNetwork = true; + } + if (fallbackProxy) { + baseArgs.push( + '-e', `HTTP_PROXY=${fallbackProxy}`, + '-e', `HTTPS_PROXY=${fallbackProxy}`, + '-e', `http_proxy=${fallbackProxy}`, + '-e', `https_proxy=${fallbackProxy}`, + ); + } + + baseArgs.push(args.dockerImage, 'uvx', 'harbor'); + + return { + cmd: 'docker', + baseArgs, + label: `docker(${args.dockerImage}) -> uvx harbor${usedHostNetwork ? 
' [host-network]' : ''}`, + }; +} + +function resolveRunner(args: CliArgs, cwdForRun: string): RunnerSpec { + const fallbackProxy = resolveProxy(args); + + if (args.runner === 'harbor') { + if (!hasCommand('harbor')) throw new Error('harbor not found for --runner=harbor'); + return { cmd: 'harbor', baseArgs: [], label: 'harbor' }; + } + + if (args.runner === 'uvx') { + if (!hasCommand('uvx')) throw new Error('uvx not found for --runner=uvx'); + const env: NodeJS.ProcessEnv = { + ...process.env, + UV_CACHE_DIR: process.env.UV_CACHE_DIR || '/tmp/uv-cache', + UV_TOOL_DIR: process.env.UV_TOOL_DIR || '/tmp/uv-tools', + XDG_DATA_HOME: process.env.XDG_DATA_HOME || '/tmp/xdg-data', + }; + if (fallbackProxy) { + env.HTTP_PROXY = fallbackProxy; + env.HTTPS_PROXY = fallbackProxy; + env.http_proxy = fallbackProxy; + env.https_proxy = fallbackProxy; + } + return { + cmd: 'uvx', + baseArgs: ['--python', args.python, 'harbor'], + label: `uvx harbor (python ${args.python})`, + env, + }; + } + + if (args.runner === 'docker') { + return buildDockerRunner(args, cwdForRun); + } + + // auto + if (hasCommand('harbor')) { + return { cmd: 'harbor', baseArgs: [], label: 'harbor' }; + } + if (hasCommand('uvx')) { + const env: NodeJS.ProcessEnv = { + ...process.env, + UV_CACHE_DIR: process.env.UV_CACHE_DIR || '/tmp/uv-cache', + UV_TOOL_DIR: process.env.UV_TOOL_DIR || '/tmp/uv-tools', + XDG_DATA_HOME: process.env.XDG_DATA_HOME || '/tmp/xdg-data', + }; + if (fallbackProxy) { + env.HTTP_PROXY = fallbackProxy; + env.HTTPS_PROXY = fallbackProxy; + env.http_proxy = fallbackProxy; + env.https_proxy = fallbackProxy; + } + return { + cmd: 'uvx', + baseArgs: ['--python', args.python, 'harbor'], + label: `uvx harbor (python ${args.python})`, + env, + }; + } + return buildDockerRunner(args, cwdForRun); +} + +function listDirs(root: string): string[] { + if (!fs.existsSync(root)) return []; + return fs.readdirSync(root) + .map(name => path.join(root, name)) + .filter(p => fs.existsSync(p) && 
fs.statSync(p).isDirectory()); +} + +function findLatestJobDir(jobsDir: string, before: Set): string { + const after = listDirs(jobsDir); + const created = after.filter(p => !before.has(path.resolve(p))); + const candidates = created.length > 0 ? created : after; + if (candidates.length === 0) { + throw new Error(`No job directory found under ${jobsDir}`); + } + + candidates.sort((a, b) => fs.statSync(b).mtimeMs - fs.statSync(a).mtimeMs); + return candidates[0]; +} + +function findFilesRecursive(root: string, fileName: string): string[] { + const out: string[] = []; + function walk(current: string): void { + const entries = fs.readdirSync(current, { withFileTypes: true }); + for (const e of entries) { + const full = path.join(current, e.name); + if (e.isDirectory()) walk(full); + else if (e.isFile() && e.name === fileName) out.push(full); + } + } + walk(root); + return out; +} + +function readJson(filePath: string): any { + return JSON.parse(fs.readFileSync(filePath, 'utf-8')); +} + +function isObject(v: unknown): v is Record { + return typeof v === 'object' && v !== null && !Array.isArray(v); +} + +function pickBooleanResult(obj: Record): boolean | undefined { + for (const k of ['success', 'passed', 'resolved', 'solved', 'is_success', 'is_passed', 'pass']) { + if (typeof obj[k] === 'boolean') return obj[k]; + } + for (const nk of ['result', 'outcome', 'evaluation', 'metrics', 'summary']) { + const v = obj[nk]; + if (!isObject(v)) continue; + for (const k of ['success', 'passed', 'resolved', 'solved', 'is_success', 'is_passed', 'pass']) { + if (typeof v[k] === 'boolean') return v[k]; + } + } + return undefined; +} + +function pickResultFromRewardFile(resultJsonPath: string): boolean | undefined { + const rewardPath = path.join(path.dirname(resultJsonPath), 'verifier', 'reward.txt'); + if (!fs.existsSync(rewardPath)) return undefined; + try { + const n = Number(fs.readFileSync(rewardPath, 'utf-8').trim()); + if (!Number.isFinite(n)) return undefined; + return n > 0; 
+ } catch { + return undefined; + } +} + +function scoreJob(jobPath: string): { passed: number; total: number; unknown: number } { + const summaryPath = path.resolve(jobPath, 'result.json'); + const allResultFiles = findFilesRecursive(jobPath, 'result.json'); + if (allResultFiles.length === 0) { + throw new Error(`No result.json found under job path: ${jobPath}`); + } + // Exclude Harbor's top-level summary file from per-trial scoring. + const resultFiles = allResultFiles + .map(p => path.resolve(p)) + .filter(p => p !== summaryPath); + + let passed = 0; + let total = 0; + let unknown = 0; + + for (const file of resultFiles) { + try { + const data = readJson(file); + if (!isObject(data)) { + unknown += 1; + continue; + } + let ok = pickBooleanResult(data); + if (typeof ok !== 'boolean') ok = pickResultFromRewardFile(file); + + if (typeof ok === 'boolean') { + total += 1; + if (ok) passed += 1; + } else { + unknown += 1; + } + } catch { + unknown += 1; + } + } + + if (total === 0) { + if (!fs.existsSync(summaryPath)) { + throw new Error(`No parseable pass/fail result found under job path: ${jobPath}`); + } + + try { + const summary = readJson(summaryPath); + const nTotal = typeof summary?.n_total_trials === 'number' ? summary.n_total_trials : undefined; + const evals = summary?.stats?.evals; + if (isObject(evals)) { + const firstEval = Object.values(evals)[0] as any; + const mean = typeof firstEval?.metrics?.[0]?.mean === 'number' ? firstEval.metrics[0].mean : undefined; + const nErrors = typeof firstEval?.n_errors === 'number' ? firstEval.n_errors : 0; + const nTrials = typeof firstEval?.n_trials === 'number' ? firstEval.n_trials : 0; + const totalFromSummary = nTotal ?? 
(nTrials + nErrors); + if (typeof mean === 'number' && totalFromSummary > 0) { + const approxPassed = Math.round(mean * totalFromSummary); + return { + passed: approxPassed, + total: totalFromSummary, + unknown: 0, + }; + } + } + } catch { + // ignore fallback parse errors and throw the original message + } + + throw new Error(`No parseable pass/fail result found under job path: ${jobPath}`); + } + + return { passed, total, unknown }; +} + +function runOfficialTB2(args: CliArgs): string { + const harborArgs: string[] = ['run', '-d', args.dataset]; + if (args.model) harborArgs.push('-m', args.model); + harborArgs.push('-a', args.agent); + + fs.mkdirSync(args.jobsDir, { recursive: true }); + const before = new Set(listDirs(args.jobsDir).map(p => path.resolve(p))); + + // Harbor uses ./jobs by default; run in jobs parent so artifacts are predictable. + const cwdForRun = path.dirname(args.jobsDir); + const runner = resolveRunner(args, cwdForRun); + const fullArgs = [...runner.baseArgs, ...harborArgs]; + + console.log(`Runner: ${runner.label}`); + console.log(`Running: ${runner.cmd} ${fullArgs.join(' ')}`); + console.log(`Working dir: ${cwdForRun}`); + + const run = spawnSync(runner.cmd, fullArgs, { + cwd: cwdForRun, + env: runner.env ?? process.env, + stdio: 'inherit', + }); + + if (run.status !== 0) { + throw new Error(`TB2 run failed with exit code ${run.status ?? 
'unknown'}`); + } + + return findLatestJobDir(args.jobsDir, before); +} + +function fmtPct(n: number): string { + return (n * 100).toFixed(1) + '%'; +} + +export interface TB2RunOptions { + dataset: string; + model?: string; + agent: string; + jobsDir: string; + runner: 'auto' | 'harbor' | 'uvx' | 'docker'; + dockerImage: string; + python: string; + envFile?: string; +} + +export function runTB2Official(options: TB2RunOptions): TB2Summary { + const args: CliArgs = { + dataset: options.dataset, + model: options.model, + agent: options.agent, + jobsDir: path.resolve(options.jobsDir), + runner: options.runner, + dockerImage: options.dockerImage, + python: options.python, + envFile: options.envFile ? path.resolve(options.envFile) : undefined, + }; + const defaultEnvFile = path.resolve(process.cwd(), '.env.test'); + if (!args.envFile && fs.existsSync(defaultEnvFile)) { + args.envFile = defaultEnvFile; + } + + const jobPath = runOfficialTB2(args); + const s = scoreJob(jobPath); + + return { + generated_at: new Date().toISOString(), + dataset: args.dataset, + agent: args.agent, + model: args.model, + jobs_dir: args.jobsDir, + job_path: jobPath, + passed: s.passed, + total: s.total, + rate: s.total > 0 ? 
s.passed / s.total : 0, + unknown: s.unknown, + }; +} + +function writeSummary(summary: TB2Summary, outputFile?: string): void { + console.log('\n=== Terminal Bench 2.0 Score ==='); + console.log(`Job path: ${summary.job_path}`); + console.log(`Passed: ${summary.passed}/${summary.total}`); + console.log(`Rate: ${fmtPct(summary.rate)}`); + console.log(`Unknown: ${summary.unknown}`); + + if (outputFile) { + fs.mkdirSync(path.dirname(outputFile), { recursive: true }); + fs.writeFileSync(outputFile, JSON.stringify(summary, null, 2), 'utf-8'); + console.log(`Summary written to: ${outputFile}`); + } +} + +function main(): void { + const args = parseCliArgs(); + const summary = runTB2Official({ + dataset: args.dataset, + model: args.model, + agent: args.agent, + jobsDir: args.jobsDir, + runner: args.runner, + dockerImage: args.dockerImage, + python: args.python, + envFile: args.envFile, + }); + writeSummary(summary, args.outputFile); +} + +if (require.main === module) { + try { + main(); + } catch (err: any) { + console.error('TB2 official run failed:', err?.message || String(err)); + process.exitCode = 1; + } +} diff --git a/tests/benchmark/swe/cases/mini-cases.json b/tests/benchmark/swe/cases/mini-cases.json deleted file mode 100644 index d6b5c8e..0000000 --- a/tests/benchmark/swe/cases/mini-cases.json +++ /dev/null @@ -1,182 +0,0 @@ -[ - { - "id": "mini-swe-001", - "description": "The `chunk` function splits an array into sub-arrays of the given size, but it returns an extra empty array at the end for certain inputs. 
For example `chunk([1,2,3,4,5], 2)` returns `[[1,2],[3,4],[5],[]]` instead of `[[1,2],[3,4],[5]]`.", - "files": { - "src.js": "function chunk(arr, size) {\n if (size <= 0) return [];\n const result = [];\n for (let i = 0; i <= arr.length; i += size) {\n result.push(arr.slice(i, i + size));\n }\n return result;\n}\nmodule.exports = { chunk };\n", - "test.js": "const { chunk } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = chunk([1, 2, 3, 4, 5], 2);\nassert(r1.length === 3, 'Expected 3 chunks, got ' + r1.length);\nassert(JSON.stringify(r1) === '[[1,2],[3,4],[5]]', 'Wrong result: ' + JSON.stringify(r1));\n\nconst r2 = chunk([1, 2, 3, 4], 2);\nassert(r2.length === 2, 'Expected 2 chunks, got ' + r2.length);\n\nconst r3 = chunk([], 3);\nassert(r3.length === 0, 'Expected 0 chunks for empty array, got ' + r3.length);\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-002", - "description": "The `countWords` function should return the number of words in a string. It works for simple cases like `countWords('hello world')` returning 2, but fails when there are multiple consecutive spaces. 
`countWords('hello world')` returns 3 instead of 2, and `countWords(' hello ')` returns 4 instead of 1.", - "files": { - "src.js": "function countWords(text) {\n if (!text) return 0;\n return text.split(' ').length;\n}\nmodule.exports = { countWords };\n", - "test.js": "const { countWords } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(countWords('hello world') === 2, 'Basic two words');\nassert(countWords('hello world') === 2, 'Double space: expected 2, got ' + countWords('hello world'));\nassert(countWords('') === 0, 'Empty string should be 0');\nassert(countWords(' hello ') === 1, 'Padded: expected 1, got ' + countWords(' hello '));\nassert(countWords('one') === 1, 'Single word');\nassert(countWords('a b c') === 3, 'Three words');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-003", - "description": "The `sortNumbers` function should sort an array of numbers in ascending numeric order. However, `sortNumbers([10, 2, 30, 1, 20])` returns `[1, 10, 2, 20, 30]` instead of `[1, 2, 10, 20, 30]`. 
It appears to be sorting lexicographically instead of numerically.", - "files": { - "src.js": "function sortNumbers(arr) {\n return [...arr].sort();\n}\nmodule.exports = { sortNumbers };\n", - "test.js": "const { sortNumbers } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = sortNumbers([10, 2, 30, 1, 20]);\nassert(JSON.stringify(r1) === '[1,2,10,20,30]', 'Expected [1,2,10,20,30], got ' + JSON.stringify(r1));\n\nconst r2 = sortNumbers([100, 3, 22]);\nassert(JSON.stringify(r2) === '[3,22,100]', 'Expected [3,22,100], got ' + JSON.stringify(r2));\n\nconst r3 = sortNumbers([5]);\nassert(JSON.stringify(r3) === '[5]', 'Single element');\n\nconst r4 = sortNumbers([]);\nassert(JSON.stringify(r4) === '[]', 'Empty array');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-004", - "description": "The `classify` function maps numeric scores to letter grades (A/B/C/D/F). It works for most inputs but `classify(65)` returns `'F'` instead of `'D'`. 
All scores from 60 to 69 should return 'D'.", - "files": { - "src.js": "function classify(score) {\n if (score >= 90) return 'A';\n if (score >= 80) return 'B';\n if (score >= 70) return 'C';\n if (score >= 60) { 'D'; }\n return 'F';\n}\nmodule.exports = { classify };\n", - "test.js": "const { classify } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(classify(95) === 'A', 'Score 95 should be A');\nassert(classify(85) === 'B', 'Score 85 should be B');\nassert(classify(75) === 'C', 'Score 75 should be C');\nassert(classify(65) === 'D', 'Score 65 should be D, got ' + classify(65));\nassert(classify(60) === 'D', 'Score 60 should be D, got ' + classify(60));\nassert(classify(55) === 'F', 'Score 55 should be F');\nassert(classify(100) === 'A', 'Score 100 should be A');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-005", - "description": "The `flatten` function should recursively flatten a nested array. `flatten([1, [2, 3], [4, [5]]])` should return `[1, 2, 3, 4, 5]` but instead returns `[1, 4]`. 
It seems to drop elements that are inside nested arrays.", - "files": { - "src.js": "function flatten(arr) {\n return arr.reduce((acc, item) => {\n if (Array.isArray(item)) {\n acc.concat(flatten(item));\n } else {\n acc.push(item);\n }\n return acc;\n }, []);\n}\nmodule.exports = { flatten };\n", - "test.js": "const { flatten } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = flatten([1, [2, 3], [4, [5]]]);\nassert(JSON.stringify(r1) === '[1,2,3,4,5]', 'Deep flatten: expected [1,2,3,4,5], got ' + JSON.stringify(r1));\n\nconst r2 = flatten([1, 2, 3]);\nassert(JSON.stringify(r2) === '[1,2,3]', 'Already flat: ' + JSON.stringify(r2));\n\nconst r3 = flatten([[1], [2], [3]]);\nassert(JSON.stringify(r3) === '[1,2,3]', 'One level: ' + JSON.stringify(r3));\n\nconst r4 = flatten([]);\nassert(JSON.stringify(r4) === '[]', 'Empty array');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-006", - "description": "The `reverseWords` function should reverse the order of words in a sentence. For example, `reverseWords('hello world')` should return `'world hello'`. 
But it currently returns `'worldhello'` — the words are reversed but the spaces between them are missing.", - "files": { - "src.js": "function reverseWords(sentence) {\n return sentence.split(' ').reverse().join('');\n}\nmodule.exports = { reverseWords };\n", - "test.js": "const { reverseWords } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(reverseWords('hello world') === 'world hello', 'Two words: got \"' + reverseWords('hello world') + '\"');\nassert(reverseWords('a b c') === 'c b a', 'Three words');\nassert(reverseWords('single') === 'single', 'Single word unchanged');\nassert(reverseWords('the quick brown fox') === 'fox brown quick the', 'Four words');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-007", - "description": "The `addItem` function should return a new array with the item appended, without modifying the original array. However, calling `addItem(original, 4)` mutates the original array. 
After the call, the original array has been changed, which breaks downstream code that relies on immutability.", - "files": { - "src.js": "function addItem(list, item) {\n list.push(item);\n return list;\n}\nmodule.exports = { addItem };\n", - "test.js": "const { addItem } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst original = [1, 2, 3];\nconst result = addItem(original, 4);\n\nassert(JSON.stringify(result) === '[1,2,3,4]', 'Result should contain new item: ' + JSON.stringify(result));\nassert(JSON.stringify(original) === '[1,2,3]', 'Original mutated: ' + JSON.stringify(original));\n\nconst empty = [];\nconst r2 = addItem(empty, 'a');\nassert(JSON.stringify(r2) === '[\"a\"]', 'Add to empty');\nassert(empty.length === 0, 'Empty array mutated');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-008", - "description": "The `mapObject` function should transform all values in an object using a callback `fn(value, key)`. But `mapObject({ a: 1, b: 2 }, (v) => v * 2)` returns `{ a: NaN, b: NaN }` instead of `{ a: 2, b: 4 }`. 
It looks like the callback arguments might be in the wrong order.", - "files": { - "src.js": "function mapObject(obj, fn) {\n const result = {};\n for (const [key, value] of Object.entries(obj)) {\n result[key] = fn(key, value);\n }\n return result;\n}\nmodule.exports = { mapObject };\n", - "test.js": "const { mapObject } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = mapObject({ a: 1, b: 2, c: 3 }, (v) => v * 2);\nassert(r1.a === 2, 'a should be 2, got ' + r1.a);\nassert(r1.b === 4, 'b should be 4, got ' + r1.b);\nassert(r1.c === 6, 'c should be 6, got ' + r1.c);\n\nconst r2 = mapObject({ x: 'hello' }, (v, k) => k + ':' + v);\nassert(r2.x === 'x:hello', 'Key-value concat: got ' + r2.x);\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-009", - "description": "The `getAdults` function should return all people aged 18 or older from a list. But it currently excludes people who are exactly 18. 
`getAdults([{name:'Alice', age:18}])` returns an empty array instead of including Alice.", - "files": { - "src.js": "function getAdults(people) {\n return people.filter(p => p.age > 18);\n}\nmodule.exports = { getAdults };\n", - "test.js": "const { getAdults } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst people = [\n { name: 'Alice', age: 18 },\n { name: 'Bob', age: 17 },\n { name: 'Charlie', age: 25 },\n { name: 'Diana', age: 18 }\n];\nconst adults = getAdults(people);\nassert(adults.length === 3, 'Expected 3 adults, got ' + adults.length);\nassert(adults.some(p => p.name === 'Alice'), 'Alice (18) should be included');\nassert(adults.some(p => p.name === 'Diana'), 'Diana (18) should be included');\nassert(!adults.some(p => p.name === 'Bob'), 'Bob (17) should be excluded');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-010", - "description": "The `truncate` function should shorten a string to `maxLen` characters, adding '...' at the end if truncation occurs. It works for long strings but incorrectly truncates strings that are exactly `maxLen` characters long. 
`truncate('hello', 5)` returns `'he...'` instead of `'hello'`.", - "files": { - "src.js": "function truncate(str, maxLen) {\n if (str.length >= maxLen) {\n return str.slice(0, maxLen - 3) + '...';\n }\n return str;\n}\nmodule.exports = { truncate };\n", - "test.js": "const { truncate } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(truncate('hello', 5) === 'hello', 'Exact length should not truncate, got \"' + truncate('hello', 5) + '\"');\nassert(truncate('hi', 5) === 'hi', 'Short string unchanged');\nassert(truncate('hello world', 8) === 'hello...', 'Truncate to 8: got \"' + truncate('hello world', 8) + '\"');\nassert(truncate('abcdefghij', 7) === 'abcd...', 'Truncate to 7');\nassert(truncate('ab', 2) === 'ab', 'Length equals maxLen, no truncation');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-011", - "description": "The `capitalize` function should capitalize the first letter of each word in a string. But `capitalize('hello world')` returns `'HELLO WORLD'` instead of `'Hello World'`. 
It uppercases the entire word rather than just the first character.", - "files": { - "src.js": "function capitalize(str) {\n return str.split(' ').map(w => w.toUpperCase()).join(' ');\n}\nmodule.exports = { capitalize };\n", - "test.js": "const { capitalize } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(capitalize('hello world') === 'Hello World', 'Basic: got \"' + capitalize('hello world') + '\"');\nassert(capitalize('foo bar baz') === 'Foo Bar Baz', 'Three words');\nassert(capitalize('a') === 'A', 'Single char');\nassert(capitalize('already Capital') === 'Already Capital', 'Mixed case');\nassert(capitalize('') === '', 'Empty string');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-012", - "description": "The `range` function should generate an array of numbers from `start` to `end` inclusive. But `range(1, 5)` returns `[1, 2, 3, 4]` instead of `[1, 2, 3, 4, 5]`. The end value is always excluded.", - "files": { - "src.js": "function range(start, end) {\n const result = [];\n for (let i = start; i < end; i++) {\n result.push(i);\n }\n return result;\n}\nmodule.exports = { range };\n", - "test.js": "const { range } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(JSON.stringify(range(1, 5)) === '[1,2,3,4,5]', '1 to 5: got ' + JSON.stringify(range(1, 5)));\nassert(JSON.stringify(range(0, 3)) === '[0,1,2,3]', '0 to 3: got ' + JSON.stringify(range(0, 3)));\nassert(JSON.stringify(range(5, 5)) === '[5]', 'Same start/end: got ' + JSON.stringify(range(5, 5)));\nassert(JSON.stringify(range(-2, 1)) === '[-2,-1,0,1]', 'Negative: got ' + JSON.stringify(range(-2, 1)));\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-013", - "description": "The `isPalindrome` function checks whether a string is a palindrome. 
It should be case-insensitive, so `isPalindrome('Racecar')` should return `true`. But it returns `false` because it compares without normalizing case.", - "files": { - "src.js": "function isPalindrome(str) {\n const reversed = str.split('').reverse().join('');\n return str === reversed;\n}\nmodule.exports = { isPalindrome };\n", - "test.js": "const { isPalindrome } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(isPalindrome('racecar') === true, 'racecar is a palindrome');\nassert(isPalindrome('Racecar') === true, 'Racecar (mixed case) is a palindrome, got ' + isPalindrome('Racecar'));\nassert(isPalindrome('hello') === false, 'hello is not a palindrome');\nassert(isPalindrome('Madam') === true, 'Madam is a palindrome');\nassert(isPalindrome('a') === true, 'Single char is a palindrome');\nassert(isPalindrome('Ab') === false, 'Ab is not a palindrome');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-014", - "description": "The `deepClone` function should create a deep copy of an object, preserving arrays as arrays. 
But `deepClone({a: [1, 2, 3]})` returns `{a: {\"0\": 1, \"1\": 2, \"2\": 3}}` — arrays are converted to plain objects because the clone always creates `{}` instead of checking for arrays.", - "files": { - "src.js": "function deepClone(obj) {\n if (obj === null || typeof obj !== 'object') return obj;\n const clone = {};\n for (const key of Object.keys(obj)) {\n clone[key] = deepClone(obj[key]);\n }\n return clone;\n}\nmodule.exports = { deepClone };\n", - "test.js": "const { deepClone } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst original = { a: [1, 2, 3], b: { c: 'hello' } };\nconst cloned = deepClone(original);\n\nassert(Array.isArray(cloned.a), 'cloned.a should be an array, got ' + typeof cloned.a);\nassert(JSON.stringify(cloned.a) === '[1,2,3]', 'cloned.a content wrong: ' + JSON.stringify(cloned.a));\nassert(cloned.b.c === 'hello', 'Nested object preserved');\nassert(cloned.a !== original.a, 'Array should be a different reference');\nassert(cloned.b !== original.b, 'Nested obj should be a different reference');\n\noriginal.a.push(4);\nassert(cloned.a.length === 3, 'Mutation should not affect clone');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-015", - "description": "The `groupBy` function should group array elements by a key function. But `groupBy([{type:'a', val:1}, {type:'b', val:2}, {type:'a', val:3}], x => x.type)` stores the keys instead of the items in each group. 
The result is `{a: ['a', 'a'], b: ['b']}` instead of the original objects.", - "files": { - "src.js": "function groupBy(arr, keyFn) {\n const groups = {};\n for (const item of arr) {\n const key = keyFn(item);\n if (!groups[key]) groups[key] = [];\n groups[key].push(key);\n }\n return groups;\n}\nmodule.exports = { groupBy };\n", - "test.js": "const { groupBy } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst items = [{type:'a', val:1}, {type:'b', val:2}, {type:'a', val:3}];\nconst grouped = groupBy(items, x => x.type);\n\nassert(grouped.a.length === 2, 'Group a should have 2 items, got ' + grouped.a.length);\nassert(grouped.b.length === 1, 'Group b should have 1 item');\nassert(grouped.a[0].val === 1, 'First a item val should be 1, got ' + JSON.stringify(grouped.a[0]));\nassert(grouped.a[1].val === 3, 'Second a item val should be 3');\nassert(grouped.b[0].val === 2, 'b item val should be 2');\n\nconst nums = [1, 2, 3, 4, 5];\nconst evenOdd = groupBy(nums, n => n % 2 === 0 ? 'even' : 'odd');\nassert(evenOdd.odd.length === 3, 'Odd group: ' + JSON.stringify(evenOdd.odd));\nassert(evenOdd.even.length === 2, 'Even group: ' + JSON.stringify(evenOdd.even));\nassert(evenOdd.odd[0] === 1, 'First odd should be 1, got ' + evenOdd.odd[0]);\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-016", - "description": "The `intersection` function should return elements present in both arrays. But `intersection([1,2,3], [2,3,4])` returns `[1,2,3]` instead of `[2,3]`. 
It checks inclusion against the wrong array.", - "files": { - "src.js": "function intersection(a, b) {\n return a.filter(item => a.includes(item));\n}\nmodule.exports = { intersection };\n", - "test.js": "const { intersection } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = intersection([1, 2, 3], [2, 3, 4]);\nassert(JSON.stringify(r1) === '[2,3]', 'Expected [2,3], got ' + JSON.stringify(r1));\n\nconst r2 = intersection([1, 2], [3, 4]);\nassert(r2.length === 0, 'No common elements, got ' + JSON.stringify(r2));\n\nconst r3 = intersection([5, 5, 6], [5, 7]);\nassert(r3.includes(5), 'Should include 5');\nassert(!r3.includes(6), 'Should not include 6');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-017", - "description": "The `zip` function should pair up elements from two arrays. `zip([1,2,3], ['a','b','c'])` should return `[[1,'a'],[2,'b'],[3,'c']]`. But it returns `[[1,1],[2,2],[3,3]]` — it uses the first array for both elements of each pair.", - "files": { - "src.js": "function zip(a, b) {\n const len = Math.min(a.length, b.length);\n const result = [];\n for (let i = 0; i < len; i++) {\n result.push([a[i], a[i]]);\n }\n return result;\n}\nmodule.exports = { zip };\n", - "test.js": "const { zip } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = zip([1, 2, 3], ['a', 'b', 'c']);\nassert(JSON.stringify(r1) === '[[1,\"a\"],[2,\"b\"],[3,\"c\"]]', 'Basic zip: got ' + JSON.stringify(r1));\n\nconst r2 = zip([1, 2], [10, 20, 30]);\nassert(r2.length === 2, 'Should truncate to shorter length');\nassert(JSON.stringify(r2) === '[[1,10],[2,20]]', 'Uneven: got ' + JSON.stringify(r2));\n\nconst r3 = zip([], [1]);\nassert(r3.length === 0, 'Empty first array');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-018", - 
"description": "The `sumBy` function should sum an array of objects by a given numeric key. `sumBy([{v:1},{v:2},{v:3}], 'v')` should return `6`, but it returns `'123'` (string concatenation) because the initial accumulator is an empty string instead of zero.", - "files": { - "src.js": "function sumBy(arr, key) {\n return arr.reduce((sum, item) => sum + item[key], '');\n}\nmodule.exports = { sumBy };\n", - "test.js": "const { sumBy } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nassert(sumBy([{v:1},{v:2},{v:3}], 'v') === 6, 'Sum should be 6, got ' + sumBy([{v:1},{v:2},{v:3}], 'v'));\nassert(sumBy([{score:10},{score:20}], 'score') === 30, 'Sum should be 30');\nassert(sumBy([], 'v') === 0, 'Empty array should be 0');\nassert(typeof sumBy([{v:1}], 'v') === 'number', 'Result should be a number, got ' + typeof sumBy([{v:1}], 'v'));\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-019", - "description": "The `pick` function should create a new object with only the specified keys from the source object. But `pick({a:1, b:2, c:3}, ['a','c'])` returns `{a:'a', c:'c'}` instead of `{a:1, c:3}`. 
It assigns the key name as the value instead of the actual value.", - "files": { - "src.js": "function pick(obj, keys) {\n const result = {};\n for (const key of keys) {\n if (key in obj) result[key] = key;\n }\n return result;\n}\nmodule.exports = { pick };\n", - "test.js": "const { pick } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = pick({a: 1, b: 2, c: 3}, ['a', 'c']);\nassert(r1.a === 1, 'a should be 1, got ' + r1.a);\nassert(r1.c === 3, 'c should be 3, got ' + r1.c);\nassert(r1.b === undefined, 'b should not be present');\n\nconst r2 = pick({x: 'hello', y: 'world'}, ['x', 'z']);\nassert(r2.x === 'hello', 'x should be hello');\nassert(r2.z === undefined, 'z not in source, should be absent');\nassert(Object.keys(r2).length === 1, 'Should only have 1 key');\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - }, - { - "id": "mini-swe-020", - "description": "The `difference` function should return elements in the first array that are NOT in the second array. But `difference([1,2,3,4], [2,4])` returns `[2,4]` instead of `[1,3]`. 
The filter logic is inverted — it returns the intersection instead of the difference.", - "files": { - "src.js": "function difference(a, b) {\n return a.filter(item => b.includes(item));\n}\nmodule.exports = { difference };\n", - "test.js": "const { difference } = require('./src');\nfunction assert(cond, msg) { if (!cond) { console.error('FAIL:', msg); process.exit(1); } }\n\nconst r1 = difference([1, 2, 3, 4], [2, 4]);\nassert(JSON.stringify(r1) === '[1,3]', 'Expected [1,3], got ' + JSON.stringify(r1));\n\nconst r2 = difference([1, 2, 3], []);\nassert(JSON.stringify(r2) === '[1,2,3]', 'Nothing to remove: got ' + JSON.stringify(r2));\n\nconst r3 = difference([1, 2], [1, 2, 3]);\nassert(r3.length === 0, 'All removed: got ' + JSON.stringify(r3));\n\nconst r4 = difference(['a', 'b', 'c'], ['b']);\nassert(JSON.stringify(r4) === '[\"a\",\"c\"]', 'Strings: got ' + JSON.stringify(r4));\n\nconsole.log('All tests passed');\n" - }, - "test_command": "node test.js" - } -] diff --git a/tests/benchmark/swe/cases/curated-instances.json b/tests/benchmark/swe/cases/verified-instances.json similarity index 100% rename from tests/benchmark/swe/cases/curated-instances.json rename to tests/benchmark/swe/cases/verified-instances.json diff --git a/tests/benchmark/swe/dataset.ts b/tests/benchmark/swe/dataset.ts index 7a31f27..726427f 100644 --- a/tests/benchmark/swe/dataset.ts +++ b/tests/benchmark/swe/dataset.ts @@ -1,38 +1,11 @@ -// --------------------------------------------------------------------------- -// SWE benchmark dataset loader -// --------------------------------------------------------------------------- - import fs from 'fs'; import path from 'path'; import type { FullSWEInstance } from './docker-evaluator'; -export interface MiniCase { - id: string; - description: string; - files: Record; - test_command: string; -} - -/** - * Load mini-SWE cases from the local JSON file. 
- */ -export function loadMiniCases(): MiniCase[] { - const casesPath = path.join(__dirname, 'cases', 'mini-cases.json'); - if (!fs.existsSync(casesPath)) { - console.log(` SWE: cases file not found at ${casesPath}`); - return []; - } - const raw = fs.readFileSync(casesPath, 'utf-8'); - return JSON.parse(raw) as MiniCase[]; -} - -/** - * Load curated SWE-bench instances for full mode. - */ -export function loadCuratedInstances(): FullSWEInstance[] { - const instancesPath = path.join(__dirname, 'cases', 'curated-instances.json'); +export function loadVerifiedInstances(): FullSWEInstance[] { + const instancesPath = path.join(__dirname, 'cases', 'verified-instances.json'); if (!fs.existsSync(instancesPath)) { - console.log(` SWE: curated instances file not found at ${instancesPath}`); + console.log(` SWE: verified instances file not found at ${instancesPath}`); return []; } const raw = fs.readFileSync(instancesPath, 'utf-8'); diff --git a/tests/benchmark/swe/docker-evaluator.ts b/tests/benchmark/swe/docker-evaluator.ts index 3389ebc..2c3753f 100644 --- a/tests/benchmark/swe/docker-evaluator.ts +++ b/tests/benchmark/swe/docker-evaluator.ts @@ -154,7 +154,7 @@ function extractRelevantPaths(problemStatement: string, hintsText: string): stri // (readFilesFromRepo removed — we now read directly from Docker images) // --------------------------------------------------------------------------- -// LLM interaction — generate fix (file-based, like mini mode) +// LLM interaction — generate fix (file-based source-context flow) // --------------------------------------------------------------------------- const FULL_SYSTEM_PROMPT = `You are a software engineer fixing bugs in open-source repositories. @@ -195,7 +195,7 @@ def validate(value): --- END FILE ---`; /** - * Call the LLM with source file context (like mini mode). + * Call the LLM with source file context (source-context flow). * Includes a single retry on failure. 
*/ async function callLLMWithContext( @@ -392,7 +392,7 @@ function generateDiffFromOriginals( * 1. Pulling the SWE-bench Docker image (has repo at /testbed) * 2. Extracting relevant file paths from the problem statement / hints * 3. Reading those files directly from the Docker image - * 4. Sending source code + problem to LLM (like mini mode) + * 4. Sending source code + problem to LLM (source-context flow) * 5. Parsing corrected files from LLM response * 6. Generating unified diff programmatically */ diff --git a/tests/benchmark/swe/evaluator.ts b/tests/benchmark/swe/evaluator.ts deleted file mode 100644 index 5767138..0000000 --- a/tests/benchmark/swe/evaluator.ts +++ /dev/null @@ -1,64 +0,0 @@ -// --------------------------------------------------------------------------- -// SWE benchmark evaluator — run tests in a temp directory -// --------------------------------------------------------------------------- - -import fs from 'fs'; -import path from 'path'; -import { execSync } from 'child_process'; - -const TEST_TIMEOUT_MS = 15_000; - -export interface EvalResult { - passed: boolean; - output: string; - error?: string; -} - -/** - * Write files to a temporary directory, run the test command, return pass/fail. 
- */ -export function evaluateCase( - files: Record, - testCommand: string, - workDir: string, -): EvalResult { - // Ensure work directory exists - fs.mkdirSync(workDir, { recursive: true }); - - // Write all files - for (const [name, content] of Object.entries(files)) { - const filePath = path.join(workDir, name); - fs.mkdirSync(path.dirname(filePath), { recursive: true }); - fs.writeFileSync(filePath, content, 'utf-8'); - } - - // Run the test command - try { - const output = execSync(testCommand, { - cwd: workDir, - timeout: TEST_TIMEOUT_MS, - encoding: 'utf-8', - stdio: ['pipe', 'pipe', 'pipe'], - }); - return { passed: true, output: output.trim() }; - } catch (err: any) { - const stdout = (err.stdout || '').toString().trim(); - const stderr = (err.stderr || '').toString().trim(); - return { - passed: false, - output: stdout, - error: stderr || err.message || String(err), - }; - } -} - -/** - * Clean up a work directory. - */ -export function cleanupWorkDir(workDir: string): void { - try { - fs.rmSync(workDir, { recursive: true, force: true }); - } catch { - // ignore cleanup errors - } -} diff --git a/tests/benchmark/swe/harness.ts b/tests/benchmark/swe/harness.ts deleted file mode 100644 index feb9fc8..0000000 --- a/tests/benchmark/swe/harness.ts +++ /dev/null @@ -1,136 +0,0 @@ -// --------------------------------------------------------------------------- -// SWE benchmark harness — sends code + issue to model, parses corrected files -// --------------------------------------------------------------------------- - -import type { ModelProvider } from '../../../src/infra/providers/types'; -import type { Message } from '../../../src/core/types'; -import type { MiniCase } from './dataset'; - -export interface HarnessResult { - correctedFiles: Record; - tokens: number; - error?: string; -} - -const SYSTEM_PROMPT = `You are a software engineer fixing bugs in source code. -You will be given a bug report and the project files. 
-Your task is to fix the bug so all tests pass. - -Rules: -- Only modify source files. NEVER modify test files. -- Output the COMPLETE corrected file content using this exact format: - ---- FILE: --- - ---- END FILE --- - -- You may output multiple files if needed. -- Do NOT include any explanation outside the file markers. -- Output ONLY the corrected file(s), nothing else.`; - -/** - * Send a mini-SWE case to the model and parse corrected files from the response. - */ -export async function runHarness( - provider: ModelProvider, - caseData: MiniCase, -): Promise { - // Build user message with issue + all file contents - const fileListing = Object.entries(caseData.files) - .map(([name, content]) => `--- ${name} ---\n${content}`) - .join('\n'); - - const userMessage = [ - 'Bug report:', - caseData.description, - '', - 'Project files:', - fileListing, - '', - 'Fix the bug in the source file(s) so that all tests pass.', - 'Output the corrected file(s) using the --- FILE: --- / --- END FILE --- format.', - ].join('\n'); - - const messages: Message[] = [ - { role: 'user', content: [{ type: 'text', text: userMessage }] }, - ]; - - try { - const response = await provider.complete(messages, { system: SYSTEM_PROMPT }); - - const text = response.content - .filter((b): b is { type: 'text'; text: string } => b.type === 'text') - .map(b => b.text) - .join(''); - - const tokens = - (response.usage?.input_tokens ?? 0) + (response.usage?.output_tokens ?? 
0); - - const correctedFiles = parseFileBlocks(text); - - // If no file blocks found, try to infer from the response - if (Object.keys(correctedFiles).length === 0) { - // Fallback: look for code fences with the src filename - const fallback = parseFallbackCodeBlocks(text, caseData.files); - if (Object.keys(fallback).length > 0) { - return { correctedFiles: fallback, tokens }; - } - return { correctedFiles: {}, tokens, error: 'No corrected files found in model response' }; - } - - return { correctedFiles, tokens }; - } catch (err: any) { - return { correctedFiles: {}, tokens: 0, error: err.message || String(err) }; - } -} - -// --------------------------------------------------------------------------- -// Response parsing -// --------------------------------------------------------------------------- - -/** - * Parse `--- FILE: --- ... --- END FILE ---` blocks from model output. - */ -function parseFileBlocks(text: string): Record { - const files: Record = {}; - const regex = /---\s*FILE:\s*(.+?)\s*---\n([\s\S]*?)---\s*END FILE\s*---/g; - let match: RegExpExecArray | null; - - while ((match = regex.exec(text)) !== null) { - const filename = match[1].trim(); - const content = match[2]; - // Only accept source files, never test files - if (!filename.includes('test')) { - files[filename] = content; - } - } - - return files; -} - -/** - * Fallback: try to extract code from markdown fences like ```js ... ``` - * and match them to source filenames (excluding test files). 
- */ -function parseFallbackCodeBlocks(text: string, originalFiles: Record): Record { - const files: Record = {}; - const sourceFiles = Object.keys(originalFiles).filter(f => !f.includes('test')); - - if (sourceFiles.length !== 1) return files; // Only works for single source file - - const srcName = sourceFiles[0]; - // Match the last code fence (most likely the final corrected version) - const fenceRegex = /```(?:js|javascript)?\n([\s\S]*?)```/g; - let lastMatch: string | null = null; - let match: RegExpExecArray | null; - - while ((match = fenceRegex.exec(text)) !== null) { - lastMatch = match[1]; - } - - if (lastMatch) { - files[srcName] = lastMatch; - } - - return files; -} diff --git a/tests/benchmark/swe/index.ts b/tests/benchmark/swe/index.ts index d83a239..7d17c24 100644 --- a/tests/benchmark/swe/index.ts +++ b/tests/benchmark/swe/index.ts @@ -1,32 +1,20 @@ -// --------------------------------------------------------------------------- -// SWE benchmark module — BenchmarkModule entry point -// --------------------------------------------------------------------------- - import path from 'path'; import type { BenchmarkConfig, BenchmarkModuleResult, BenchmarkProvider, SWEProviderResult, SWEResult } from '../types'; import type { ModelProvider } from '../../../src/infra/providers/types'; import { AnthropicProvider } from '../../../src/infra/providers/anthropic'; import { OpenAIProvider } from '../../../src/infra/providers/openai'; import { GeminiProvider } from '../../../src/infra/providers/gemini'; -import { loadMiniCases, loadCuratedInstances, MiniCase } from './dataset'; -import { runHarness } from './harness'; -import { evaluateCase, cleanupWorkDir } from './evaluator'; +import { loadVerifiedInstances } from './dataset'; import { isDockerAvailable, generateFix, evaluateWithDocker, - evaluateLocally, - cleanupWorkDir as cleanupDockerWorkDir, + cleanupWorkDir, type FullSWEInstance, } from './docker-evaluator'; -// Module metadata (used by run-benchmark.ts 
discovery) export const name = 'swe'; -// --------------------------------------------------------------------------- -// Provider creation (same pattern as TAU) -// --------------------------------------------------------------------------- - function createProvider(bp: BenchmarkProvider): ModelProvider { switch (bp.id) { case 'anthropic': @@ -40,110 +28,9 @@ function createProvider(bp: BenchmarkProvider): ModelProvider { } } -// --------------------------------------------------------------------------- -// Run single provider on all mini-SWE cases -// --------------------------------------------------------------------------- - -async function runProviderOnCases( - bp: BenchmarkProvider, - cases: MiniCase[], - config: BenchmarkConfig, -): Promise { - const provider = createProvider(bp); - const results: SWEResult[] = []; - - for (const c of cases) { - const startMs = Date.now(); - const workDir = path.join( - process.cwd(), - 'tests', - '.tmp', - `swe-${bp.id}-${c.id}-${Date.now()}`, - ); - - try { - // 1. Send to model - const harness = await runHarness(provider, c); - - if (harness.error || Object.keys(harness.correctedFiles).length === 0) { - const durationMs = Date.now() - startMs; - const errMsg = harness.error || 'No corrected files returned'; - console.log(` [${bp.id}] ${c.id}: FAIL (${errMsg})`); - results.push({ - instance_id: c.id, - resolved: false, - tokens_used: harness.tokens, - duration_ms: durationMs, - error: errMsg, - }); - continue; - } - - // 2. Merge corrected files with original files - const mergedFiles = { ...c.files }; - for (const [name, content] of Object.entries(harness.correctedFiles)) { - mergedFiles[name] = content; - } - - // 3. Evaluate (write files + run test) - const evalResult = evaluateCase(mergedFiles, c.test_command, workDir); - const durationMs = Date.now() - startMs; - - const status = evalResult.passed ? 'PASS' : 'FAIL'; - const detail = evalResult.passed ? 
'' : ` (${evalResult.error || evalResult.output})`; - console.log( - ` [${bp.id}] ${c.id}: ${status} (${harness.tokens} tokens, ${durationMs}ms)${detail}`, - ); - - results.push({ - instance_id: c.id, - resolved: evalResult.passed, - tokens_used: harness.tokens, - duration_ms: durationMs, - error: evalResult.passed ? undefined : evalResult.error, - }); - } catch (err: any) { - const durationMs = Date.now() - startMs; - console.log(` [${bp.id}] ${c.id}: FAIL (${err.message})`); - results.push({ - instance_id: c.id, - resolved: false, - tokens_used: 0, - duration_ms: durationMs, - error: err.message || String(err), - }); - } finally { - cleanupWorkDir(workDir); - } - } - - const resolved = results.filter(r => r.resolved).length; - const total = results.length; - const avgTokens = total > 0 ? Math.round(results.reduce((s, r) => s + r.tokens_used, 0) / total) : 0; - const avgDuration = total > 0 ? Math.round(results.reduce((s, r) => s + r.duration_ms, 0) / total) : 0; - - return { - provider: bp, - summary: { - dataset: 'mini-swe', - total, - resolved, - rate: total > 0 ? 
resolved / total : 0, - avg_tokens: avgTokens, - avg_duration_ms: avgDuration, - }, - results, - }; -} - -// --------------------------------------------------------------------------- -// Run single provider on full SWE-bench instances -// --------------------------------------------------------------------------- - -async function runProviderOnFullInstances( +async function runProviderOnVerifiedInstances( bp: BenchmarkProvider, instances: FullSWEInstance[], - useDocker: boolean, dockerProxy?: string, ): Promise { const provider = createProvider(bp); @@ -151,15 +38,9 @@ async function runProviderOnFullInstances( for (const inst of instances) { const startMs = Date.now(); - const workDir = path.join( - process.cwd(), - 'tests', - '.tmp', - `swe-full-${bp.id}-${inst.instance_id}-${Date.now()}`, - ); + const workDir = path.join(process.cwd(), 'tests', '.tmp', `swe-${bp.id}-${inst.instance_id}-${Date.now()}`); try { - // 1. Generate fix from model (clone repo, read files, LLM, generate diff) console.log(` [${bp.id}] ${inst.instance_id}: generating fix ...`); const harness = await generateFix(provider, inst, dockerProxy); @@ -177,18 +58,13 @@ async function runProviderOnFullInstances( continue; } - // 2. Evaluate the fix patch console.log(` [${bp.id}] ${inst.instance_id}: fix generated (${harness.tokens} tokens), evaluating ...`); - const evalResult = useDocker - ? evaluateWithDocker(inst, harness.patch, workDir, dockerProxy) - : evaluateLocally(inst, harness.patch, workDir); + const evalResult = evaluateWithDocker(inst, harness.patch, workDir, dockerProxy); const durationMs = Date.now() - startMs; const status = evalResult.passed ? 'PASS' : 'FAIL'; - const detail = evalResult.passed ? '' : ` (${(evalResult.error || '').slice(0, 100)})`; - console.log( - ` [${bp.id}] ${inst.instance_id}: ${status} (${harness.tokens} tokens, ${durationMs}ms)${detail}`, - ); + const detail = evalResult.passed ? 
'' : ` (${(evalResult.error || '').slice(0, 120)})`; + console.log(` [${bp.id}] ${inst.instance_id}: ${status} (${harness.tokens} tokens, ${durationMs}ms)${detail}`); results.push({ instance_id: inst.instance_id, @@ -208,7 +84,7 @@ async function runProviderOnFullInstances( error: err.message || String(err), }); } finally { - cleanupDockerWorkDir(workDir); + cleanupWorkDir(workDir); } } @@ -220,7 +96,7 @@ async function runProviderOnFullInstances( return { provider: bp, summary: { - dataset: 'swe-bench-full', + dataset: 'swe-bench-verified', total, resolved, rate: total > 0 ? resolved / total : 0, @@ -231,18 +107,10 @@ async function runProviderOnFullInstances( }; } -// --------------------------------------------------------------------------- -// Module entry point -// --------------------------------------------------------------------------- - export async function run(config: BenchmarkConfig): Promise { - if (config.sweMode === 'full') { - return runFullMode(config); - } - - const cases = loadMiniCases(); - if (cases.length === 0) { - console.log(' SWE: no mini-SWE cases found'); + const instances = loadVerifiedInstances(); + if (instances.length === 0) { + console.log(' SWE: no verified instances found'); return {}; } @@ -251,40 +119,19 @@ export async function run(config: BenchmarkConfig): Promise { - const instances = loadCuratedInstances(); - if (instances.length === 0) { - console.log(' SWE: no curated instances found for full mode'); - return {}; - } - - if (config.providers.length === 0) { - console.log(' SWE: no providers configured, skipping'); + const dockerAvailable = isDockerAvailable(); + if (!dockerAvailable) { + console.log(' SWE: Docker is required for SWE-bench-Verified and is not available. Skipping.'); return {}; } - const useDocker = isDockerAvailable(); - console.log(`\n SWE full mode: ${instances.length} curated instances`); - console.log(` Docker: ${useDocker ? 
'available (using Docker evaluation)' : 'not available (using local git-based evaluation)'}`); + console.log(`\n SWE verified mode: ${instances.length} instances`); + console.log(' Docker: available (official SWE image evaluation)'); const allResults: SWEProviderResult[] = []; - for (const bp of config.providers) { console.log(`\n Running provider: ${bp.id} / ${bp.model}`); - const providerResult = await runProviderOnFullInstances(bp, instances, useDocker, config.dockerProxy); + const providerResult = await runProviderOnVerifiedInstances(bp, instances, config.dockerProxy); allResults.push(providerResult); } diff --git a/tests/benchmark/tau/domains/airline/database.ts b/tests/benchmark/tau/domains/airline/database.ts deleted file mode 100644 index b719dd2..0000000 --- a/tests/benchmark/tau/domains/airline/database.ts +++ /dev/null @@ -1,220 +0,0 @@ -// --------------------------------------------------------------------------- -// Airline domain database types and initial data -// --------------------------------------------------------------------------- - -export interface User { - user_id: string; - name: string; - email: string; - phone: string; - membership: 'regular' | 'silver' | 'gold' | 'platinum'; -} - -export interface Flight { - flight_id: string; - airline: string; - route: string; - date: string; - departure_time: string; - arrival_time: string; - price: number; - seats_available: number; - aircraft: string; -} - -export interface Reservation { - reservation_id: string; - user_id: string; - flight_id: string; - status: 'confirmed' | 'cancelled' | 'pending'; - seat_class: 'economy' | 'business' | 'first'; - payment_amount: number; - booked_at: string; -} - -export interface AirlineDatabase { - users: User[]; - flights: Flight[]; - reservations: Reservation[]; -} - -export function getInitialDatabase(): AirlineDatabase { - return { - users: [ - { - user_id: 'USR001', - name: 'John Smith', - email: 'john.smith@email.com', - phone: '555-0101', - membership: 
'gold', - }, - { - user_id: 'USR002', - name: 'Alice Johnson', - email: 'alice.j@email.com', - phone: '555-0102', - membership: 'regular', - }, - { - user_id: 'USR003', - name: 'Bob Chen', - email: 'bob.chen@email.com', - phone: '555-0103', - membership: 'silver', - }, - { - user_id: 'USR004', - name: 'Maria Garcia', - email: 'maria.g@email.com', - phone: '555-0104', - membership: 'platinum', - }, - { - user_id: 'USR005', - name: 'David Kim', - email: 'david.k@email.com', - phone: '555-0105', - membership: 'regular', - }, - ], - - flights: [ - { - flight_id: 'FL001', - airline: 'SkyAir', - route: 'SFO-LAX', - date: '2026-03-15', - departure_time: '08:00', - arrival_time: '09:30', - price: 150, - seats_available: 42, - aircraft: 'A320', - }, - { - flight_id: 'FL002', - airline: 'SkyAir', - route: 'SFO-LAX', - date: '2026-03-15', - departure_time: '14:00', - arrival_time: '15:30', - price: 180, - seats_available: 15, - aircraft: 'A320', - }, - { - flight_id: 'FL003', - airline: 'SkyAir', - route: 'SFO-LAX', - date: '2026-03-17', - departure_time: '08:00', - arrival_time: '09:30', - price: 160, - seats_available: 38, - aircraft: 'A320', - }, - { - flight_id: 'FL004', - airline: 'SkyAir', - route: 'SFO-LAX', - date: '2026-03-17', - departure_time: '18:00', - arrival_time: '19:30', - price: 200, - seats_available: 5, - aircraft: 'B737', - }, - { - flight_id: 'FL005', - airline: 'SkyAir', - route: 'LAX-JFK', - date: '2026-03-20', - departure_time: '10:00', - arrival_time: '18:30', - price: 350, - seats_available: 60, - aircraft: 'B777', - }, - { - flight_id: 'FL006', - airline: 'SkyAir', - route: 'JFK-SFO', - date: '2026-03-22', - departure_time: '07:00', - arrival_time: '10:30', - price: 380, - seats_available: 22, - aircraft: 'A350', - }, - { - flight_id: 'FL007', - airline: 'SkyAir', - route: 'SFO-SEA', - date: '2026-03-18', - departure_time: '12:00', - arrival_time: '14:00', - price: 120, - seats_available: 0, - aircraft: 'A320', - }, - { - flight_id: 'FL008', - 
airline: 'SkyAir', - route: 'SFO-SEA', - date: '2026-03-19', - departure_time: '12:00', - arrival_time: '14:00', - price: 130, - seats_available: 25, - aircraft: 'A320', - }, - ], - - reservations: [ - { - reservation_id: 'RES001', - user_id: 'USR001', - flight_id: 'FL001', - status: 'confirmed', - seat_class: 'economy', - payment_amount: 150, - booked_at: '2026-02-01', - }, - { - reservation_id: 'RES002', - user_id: 'USR002', - flight_id: 'FL005', - status: 'confirmed', - seat_class: 'economy', - payment_amount: 350, - booked_at: '2026-02-05', - }, - { - reservation_id: 'RES003', - user_id: 'USR003', - flight_id: 'FL002', - status: 'confirmed', - seat_class: 'business', - payment_amount: 360, - booked_at: '2026-02-10', - }, - { - reservation_id: 'RES004', - user_id: 'USR004', - flight_id: 'FL006', - status: 'confirmed', - seat_class: 'first', - payment_amount: 760, - booked_at: '2026-01-20', - }, - { - reservation_id: 'RES005', - user_id: 'USR005', - flight_id: 'FL007', - status: 'confirmed', - seat_class: 'economy', - payment_amount: 120, - booked_at: '2026-02-15', - }, - ], - }; -} diff --git a/tests/benchmark/tau/domains/airline/handlers.ts b/tests/benchmark/tau/domains/airline/handlers.ts deleted file mode 100644 index d35fe25..0000000 --- a/tests/benchmark/tau/domains/airline/handlers.ts +++ /dev/null @@ -1,74 +0,0 @@ -// --------------------------------------------------------------------------- -// Airline domain tool handlers -// --------------------------------------------------------------------------- - -export type ToolHandler = (db: any, args: any) => any; - -export function getAirlineHandlers(): Record { - return { - get_user_details: (db, args: { user_id: string }) => { - const user = db.users.find((u: any) => u.user_id === args.user_id); - if (!user) return { error: `User not found: ${args.user_id}` }; - return user; - }, - - get_reservation_details: (db, args: { reservation_id: string }) => { - const res = db.reservations.find((r: any) => 
r.reservation_id === args.reservation_id); - if (!res) return { error: `Reservation not found: ${args.reservation_id}` }; - return res; - }, - - list_user_reservations: (db, args: { user_id: string }) => { - const list = db.reservations.filter((r: any) => r.user_id === args.user_id); - return { reservations: list }; - }, - - get_flight_details: (db, args: { flight_id: string }) => { - const flight = db.flights.find((f: any) => f.flight_id === args.flight_id); - if (!flight) return { error: `Flight not found: ${args.flight_id}` }; - return flight; - }, - - search_flights: (db, args: { route: string; date?: string }) => { - let results = db.flights.filter((f: any) => f.route === args.route); - if (args.date) { - results = results.filter((f: any) => f.date === args.date); - } - return { flights: results }; - }, - - update_reservation: (db, args: { reservation_id: string; new_flight_id: string }) => { - const res = db.reservations.find((r: any) => r.reservation_id === args.reservation_id); - if (!res) return { error: `Reservation not found: ${args.reservation_id}` }; - if (res.status === 'cancelled') return { error: 'Cannot update a cancelled reservation' }; - - const newFlight = db.flights.find((f: any) => f.flight_id === args.new_flight_id); - if (!newFlight) return { error: `Flight not found: ${args.new_flight_id}` }; - if (newFlight.seats_available <= 0) return { error: `No seats available on flight ${args.new_flight_id}` }; - - // Release seat on old flight - const oldFlight = db.flights.find((f: any) => f.flight_id === res.flight_id); - if (oldFlight) oldFlight.seats_available += 1; - - // Book seat on new flight - newFlight.seats_available -= 1; - res.flight_id = args.new_flight_id; - res.payment_amount = newFlight.price; - - return { success: true, reservation: { ...res } }; - }, - - cancel_reservation: (db, args: { reservation_id: string }) => { - const res = db.reservations.find((r: any) => r.reservation_id === args.reservation_id); - if (!res) return { 
error: `Reservation not found: ${args.reservation_id}` }; - if (res.status === 'cancelled') return { error: 'Reservation is already cancelled' }; - - // Release seat - const flight = db.flights.find((f: any) => f.flight_id === res.flight_id); - if (flight) flight.seats_available += 1; - - res.status = 'cancelled'; - return { success: true, reservation: { ...res } }; - }, - }; -} diff --git a/tests/benchmark/tau/domains/airline/policy.md b/tests/benchmark/tau/domains/airline/policy.md deleted file mode 100644 index c125fa0..0000000 --- a/tests/benchmark/tau/domains/airline/policy.md +++ /dev/null @@ -1,45 +0,0 @@ -# Airline Customer Service Policy - -You are an airline customer service agent. Follow these policies strictly when handling customer requests. - -## Identity Verification - -- Always verify the customer's identity before making any changes. -- Ask for the user's name or user ID. Look up their information using the `get_user_details` tool. -- Confirm key details (name, reservation ID) before proceeding. - -## Flight Changes - -- Customers may request to change their flight to a different date or route. -- Use `search_flights` to find available alternatives. -- Gold and Platinum members: flight changes are free. -- Silver members: $50 change fee applies. -- Regular members: $75 change fee applies. -- Changes must be made at least 2 hours before departure. -- Use `update_reservation` to apply the change. -- Always confirm the new flight details with the customer before making the change. - -## Cancellations - -- Customers may cancel their reservation. -- Gold and Platinum members: full refund. -- Silver members: 80% refund. -- Regular members: 50% refund, or full refund if cancelled more than 72 hours before departure. -- Use `cancel_reservation` to process the cancellation. -- Inform the customer of the refund amount and timeline (5-7 business days). - -## Baggage Policy - -- Economy: 1 checked bag (23kg) included. 
-- Business: 2 checked bags (32kg each) included. -- First: 3 checked bags (32kg each) included. -- Additional bags: $35 each. -- Overweight bags (23-32kg): $50 surcharge. - -## General Rules - -- Be polite, professional, and concise. -- If you cannot fulfill a request due to policy restrictions, explain clearly why. -- Do not make up information. Only provide details from the database. -- When the customer's issue is fully resolved, end with "###STOP###". -- If the customer says goodbye or has no more questions, end with "###STOP###". diff --git a/tests/benchmark/tau/domains/airline/tasks.json b/tests/benchmark/tau/domains/airline/tasks.json deleted file mode 100644 index 0e45443..0000000 --- a/tests/benchmark/tau/domains/airline/tasks.json +++ /dev/null @@ -1,71 +0,0 @@ -[ - { - "task_id": "airline_001", - "user_scenario": "You are John Smith (user ID: USR001). Your reservation ID is RES001. You are currently booked on a March 15 SFO to LAX flight. You want to change your flight to March 17 instead. You prefer the morning flight if available. If the agent asks for confirmation, agree to proceed.", - "expected_db": { - "reservations": [ - { - "reservation_id": "RES001", - "flight_id": "FL003", - "status": "confirmed" - } - ] - }, - "max_turns": 10 - }, - { - "task_id": "airline_002", - "user_scenario": "You are Alice Johnson (user ID: USR002). Your reservation ID is RES002. You need to cancel your LAX to JFK flight on March 20 because your plans changed. Accept the cancellation terms whatever they are.", - "expected_db": { - "reservations": [ - { - "reservation_id": "RES002", - "status": "cancelled" - } - ] - }, - "max_turns": 10 - }, - { - "task_id": "airline_003", - "user_scenario": "You are Bob Chen (user ID: USR003). You want to check the details of your upcoming flight - you think your reservation is RES003 but you're not sure of the exact departure time. 
Just ask for the information, you don't want to make any changes.", - "expected_db": { - "reservations": [ - { - "reservation_id": "RES003", - "status": "confirmed", - "flight_id": "FL002" - } - ] - }, - "max_turns": 8 - }, - { - "task_id": "airline_004", - "user_scenario": "You are Maria Garcia (user ID: USR004). Your reservation is RES004 for a JFK to SFO flight. You want to know what the baggage allowance is for your ticket class, and also confirm your flight details. You don't want to make any changes.", - "expected_db": { - "reservations": [ - { - "reservation_id": "RES004", - "status": "confirmed", - "flight_id": "FL006" - } - ] - }, - "max_turns": 8 - }, - { - "task_id": "airline_005", - "user_scenario": "You are David Kim (user ID: USR005). Your reservation is RES005 for a SFO to SEA flight on March 18. You just found out that flight has no seats available and you're worried. You want the agent to help you rebook to another SFO-SEA flight. Accept any available option.", - "expected_db": { - "reservations": [ - { - "reservation_id": "RES005", - "flight_id": "FL008", - "status": "confirmed" - } - ] - }, - "max_turns": 12 - } -] diff --git a/tests/benchmark/tau/domains/airline/tools.ts b/tests/benchmark/tau/domains/airline/tools.ts deleted file mode 100644 index d0bed04..0000000 --- a/tests/benchmark/tau/domains/airline/tools.ts +++ /dev/null @@ -1,127 +0,0 @@ -// --------------------------------------------------------------------------- -// Airline domain tool definitions (Anthropic API format) -// --------------------------------------------------------------------------- - -export interface ToolDef { - name: string; - description: string; - input_schema: Record; -} - -export function getAirlineToolDefs(): ToolDef[] { - return [ - { - name: 'get_user_details', - description: - 'Look up a user by their user ID. 
Returns user profile including name, email, phone, and membership tier.', - input_schema: { - type: 'object', - properties: { - user_id: { - type: 'string', - description: 'The user ID to look up (e.g. "USR001")', - }, - }, - required: ['user_id'], - }, - }, - { - name: 'get_reservation_details', - description: - 'Look up a reservation by reservation ID. Returns booking details including flight, status, and payment.', - input_schema: { - type: 'object', - properties: { - reservation_id: { - type: 'string', - description: 'The reservation ID (e.g. "RES001")', - }, - }, - required: ['reservation_id'], - }, - }, - { - name: 'list_user_reservations', - description: - 'List all reservations for a given user. Returns an array of reservation records.', - input_schema: { - type: 'object', - properties: { - user_id: { - type: 'string', - description: 'The user ID whose reservations to list', - }, - }, - required: ['user_id'], - }, - }, - { - name: 'get_flight_details', - description: - 'Get details for a specific flight by flight ID. Returns route, schedule, price, and availability.', - input_schema: { - type: 'object', - properties: { - flight_id: { - type: 'string', - description: 'The flight ID (e.g. "FL001")', - }, - }, - required: ['flight_id'], - }, - }, - { - name: 'search_flights', - description: - 'Search for available flights by route and optional date. Returns matching flights with availability.', - input_schema: { - type: 'object', - properties: { - route: { - type: 'string', - description: 'Flight route in "ORIGIN-DEST" format (e.g. "SFO-LAX")', - }, - date: { - type: 'string', - description: 'Date in YYYY-MM-DD format. If omitted, returns all dates.', - }, - }, - required: ['route'], - }, - }, - { - name: 'update_reservation', - description: - 'Update a reservation to change the flight. 
The new flight must have available seats.', - input_schema: { - type: 'object', - properties: { - reservation_id: { - type: 'string', - description: 'The reservation ID to update', - }, - new_flight_id: { - type: 'string', - description: 'The new flight ID to switch to', - }, - }, - required: ['reservation_id', 'new_flight_id'], - }, - }, - { - name: 'cancel_reservation', - description: - 'Cancel a reservation. Sets the reservation status to "cancelled". Cannot be undone.', - input_schema: { - type: 'object', - properties: { - reservation_id: { - type: 'string', - description: 'The reservation ID to cancel', - }, - }, - required: ['reservation_id'], - }, - }, - ]; -} diff --git a/tests/benchmark/tau/domains/retail/database.ts b/tests/benchmark/tau/domains/retail/database.ts deleted file mode 100644 index e37dc3e..0000000 --- a/tests/benchmark/tau/domains/retail/database.ts +++ /dev/null @@ -1,156 +0,0 @@ -// --------------------------------------------------------------------------- -// Retail domain database types and initial data -// --------------------------------------------------------------------------- - -export interface Customer { - customer_id: string; - name: string; - email: string; - phone: string; - membership: 'regular' | 'vip' | 'premium'; -} - -export interface Product { - product_id: string; - name: string; - category: string; - price: number; - stock: number; -} - -export interface OrderItem { - product_id: string; - product_name: string; - quantity: number; - unit_price: number; -} - -export interface Order { - order_id: string; - customer_id: string; - items: OrderItem[]; - total: number; - status: 'pending' | 'shipped' | 'delivered' | 'cancelled' | 'returned'; - order_date: string; - delivery_date?: string; -} - -export interface RetailDatabase { - customers: Customer[]; - products: Product[]; - orders: Order[]; -} - -export function getInitialDatabase(): RetailDatabase { - return { - customers: [ - { - customer_id: 'CUST001', - name: 'Emma 
Wilson', - email: 'emma.w@email.com', - phone: '555-1001', - membership: 'vip', - }, - { - customer_id: 'CUST002', - name: 'James Brown', - email: 'james.b@email.com', - phone: '555-1002', - membership: 'regular', - }, - { - customer_id: 'CUST003', - name: 'Sophia Lee', - email: 'sophia.l@email.com', - phone: '555-1003', - membership: 'premium', - }, - { - customer_id: 'CUST004', - name: 'Liam Martinez', - email: 'liam.m@email.com', - phone: '555-1004', - membership: 'regular', - }, - { - customer_id: 'CUST005', - name: 'Olivia Davis', - email: 'olivia.d@email.com', - phone: '555-1005', - membership: 'vip', - }, - ], - - products: [ - { product_id: 'PROD001', name: 'Wireless Headphones', category: 'Electronics', price: 79.99, stock: 150 }, - { product_id: 'PROD002', name: 'Bluetooth Speaker', category: 'Electronics', price: 49.99, stock: 80 }, - { product_id: 'PROD003', name: 'Running Shoes (Size 10)', category: 'Footwear', price: 129.99, stock: 30 }, - { product_id: 'PROD004', name: 'Running Shoes (Size 11)', category: 'Footwear', price: 129.99, stock: 0 }, - { product_id: 'PROD005', name: 'Cotton T-Shirt (M)', category: 'Apparel', price: 24.99, stock: 200 }, - { product_id: 'PROD006', name: 'Cotton T-Shirt (L)', category: 'Apparel', price: 24.99, stock: 180 }, - { product_id: 'PROD007', name: 'Yoga Mat', category: 'Fitness', price: 34.99, stock: 60 }, - { product_id: 'PROD008', name: 'Water Bottle (32oz)', category: 'Fitness', price: 19.99, stock: 100 }, - { product_id: 'PROD009', name: 'Laptop Stand', category: 'Electronics', price: 59.99, stock: 45 }, - { product_id: 'PROD010', name: 'USB-C Hub', category: 'Electronics', price: 39.99, stock: 70 }, - ], - - orders: [ - { - order_id: 'ORD001', - customer_id: 'CUST001', - items: [ - { product_id: 'PROD001', product_name: 'Wireless Headphones', quantity: 1, unit_price: 79.99 }, - { product_id: 'PROD008', product_name: 'Water Bottle (32oz)', quantity: 2, unit_price: 19.99 }, - ], - total: 119.97, - status: 
'delivered', - order_date: '2026-01-15', - delivery_date: '2026-01-22', - }, - { - order_id: 'ORD002', - customer_id: 'CUST002', - items: [ - { product_id: 'PROD003', product_name: 'Running Shoes (Size 10)', quantity: 1, unit_price: 129.99 }, - ], - total: 129.99, - status: 'delivered', - order_date: '2026-01-20', - delivery_date: '2026-01-27', - }, - { - order_id: 'ORD003', - customer_id: 'CUST003', - items: [ - { product_id: 'PROD005', product_name: 'Cotton T-Shirt (M)', quantity: 3, unit_price: 24.99 }, - ], - total: 74.97, - status: 'delivered', - order_date: '2025-12-10', - delivery_date: '2025-12-17', - }, - { - order_id: 'ORD004', - customer_id: 'CUST004', - items: [ - { product_id: 'PROD009', product_name: 'Laptop Stand', quantity: 1, unit_price: 59.99 }, - { product_id: 'PROD010', product_name: 'USB-C Hub', quantity: 1, unit_price: 39.99 }, - ], - total: 99.98, - status: 'shipped', - order_date: '2026-02-05', - }, - { - order_id: 'ORD005', - customer_id: 'CUST005', - items: [ - { product_id: 'PROD007', product_name: 'Yoga Mat', quantity: 1, unit_price: 34.99 }, - ], - total: 34.99, - status: 'delivered', - order_date: '2026-02-01', - delivery_date: '2026-02-07', - }, - ], - }; -} diff --git a/tests/benchmark/tau/domains/retail/handlers.ts b/tests/benchmark/tau/domains/retail/handlers.ts deleted file mode 100644 index 04b7db3..0000000 --- a/tests/benchmark/tau/domains/retail/handlers.ts +++ /dev/null @@ -1,147 +0,0 @@ -// --------------------------------------------------------------------------- -// Retail domain tool handlers -// --------------------------------------------------------------------------- - -export type ToolHandler = (db: any, args: any) => any; - -export function getRetailHandlers(): Record { - return { - get_customer_details: (db, args: { customer_id: string }) => { - const customer = db.customers.find((c: any) => c.customer_id === args.customer_id); - if (!customer) return { error: `Customer not found: ${args.customer_id}` }; - return 
customer; - }, - - get_order_details: (db, args: { order_id: string }) => { - const order = db.orders.find((o: any) => o.order_id === args.order_id); - if (!order) return { error: `Order not found: ${args.order_id}` }; - return order; - }, - - list_customer_orders: (db, args: { customer_id: string }) => { - const orders = db.orders.filter((o: any) => o.customer_id === args.customer_id); - return { orders }; - }, - - get_product_details: (db, args: { product_id: string }) => { - const product = db.products.find((p: any) => p.product_id === args.product_id); - if (!product) return { error: `Product not found: ${args.product_id}` }; - return product; - }, - - search_products: (db, args: { query: string }) => { - const q = args.query.toLowerCase(); - const results = db.products.filter( - (p: any) => p.name.toLowerCase().includes(q) || p.category.toLowerCase().includes(q), - ); - return { products: results }; - }, - - process_return: (db, args: { order_id: string }) => { - const order = db.orders.find((o: any) => o.order_id === args.order_id); - if (!order) return { error: `Order not found: ${args.order_id}` }; - if (order.status !== 'delivered') { - return { error: `Order ${args.order_id} is not in delivered status (current: ${order.status})` }; - } - - // Check 30-day return window - if (order.delivery_date) { - const deliveryDate = new Date(order.delivery_date); - const now = new Date(); - const daysSinceDelivery = Math.floor( - (now.getTime() - deliveryDate.getTime()) / (1000 * 60 * 60 * 24), - ); - if (daysSinceDelivery > 30) { - return { - error: `Return window expired. 
Order was delivered ${daysSinceDelivery} days ago (30-day limit).`, - }; - } - } - - // Restock items - for (const item of order.items) { - const product = db.products.find((p: any) => p.product_id === item.product_id); - if (product) product.stock += item.quantity; - } - - order.status = 'returned'; - return { success: true, order: { ...order } }; - }, - - process_exchange: ( - db, - args: { order_id: string; old_product_id: string; new_product_id: string }, - ) => { - const order = db.orders.find((o: any) => o.order_id === args.order_id); - if (!order) return { error: `Order not found: ${args.order_id}` }; - if (order.status !== 'delivered') { - return { error: `Order ${args.order_id} is not in delivered status` }; - } - - // Check 30-day exchange window - if (order.delivery_date) { - const deliveryDate = new Date(order.delivery_date); - const now = new Date(); - const daysSinceDelivery = Math.floor( - (now.getTime() - deliveryDate.getTime()) / (1000 * 60 * 60 * 24), - ); - if (daysSinceDelivery > 30) { - return { - error: `Exchange window expired. 
Order was delivered ${daysSinceDelivery} days ago (30-day limit).`, - }; - } - } - - // Find old item in order - const oldItemIndex = order.items.findIndex( - (i: any) => i.product_id === args.old_product_id, - ); - if (oldItemIndex === -1) { - return { error: `Product ${args.old_product_id} not found in order ${args.order_id}` }; - } - - // Find new product - const newProduct = db.products.find((p: any) => p.product_id === args.new_product_id); - if (!newProduct) return { error: `Product not found: ${args.new_product_id}` }; - if (newProduct.stock <= 0) { - return { error: `Product ${args.new_product_id} (${newProduct.name}) is out of stock` }; - } - - const oldItem = order.items[oldItemIndex]; - - // Restock old product - const oldProduct = db.products.find((p: any) => p.product_id === args.old_product_id); - if (oldProduct) oldProduct.stock += oldItem.quantity; - - // Deduct new product stock - newProduct.stock -= oldItem.quantity; - - // Update order item - order.items[oldItemIndex] = { - product_id: newProduct.product_id, - product_name: newProduct.name, - quantity: oldItem.quantity, - unit_price: newProduct.price, - }; - - // Recalculate total - order.total = order.items.reduce( - (sum: number, i: any) => sum + i.unit_price * i.quantity, - 0, - ); - - return { - success: true, - order: { ...order }, - price_difference: newProduct.price - oldItem.unit_price, - }; - }, - - update_order_status: (db, args: { order_id: string; new_status: string }) => { - const order = db.orders.find((o: any) => o.order_id === args.order_id); - if (!order) return { error: `Order not found: ${args.order_id}` }; - order.status = args.new_status; - return { success: true, order: { ...order } }; - }, - }; -} diff --git a/tests/benchmark/tau/domains/retail/policy.md b/tests/benchmark/tau/domains/retail/policy.md deleted file mode 100644 index 93749cd..0000000 --- a/tests/benchmark/tau/domains/retail/policy.md +++ /dev/null @@ -1,53 +0,0 @@ -# Retail Customer Service Policy - -You are 
an online retail customer service agent. Follow these policies strictly. - -## Identity Verification - -- Verify the customer's identity before making changes to their orders. -- Ask for the customer's name or customer ID. -- Use `get_customer_details` to look up their information and confirm. - -## Order Status - -- Use `get_order_details` to check order status. -- Provide the customer with their order status, items, and tracking info if available. -- Order statuses: pending, shipped, delivered, cancelled, returned. - -## Returns - -- Items may be returned within 30 days of delivery. -- Items must be in unused, original condition. -- Use `process_return` to initiate the return. -- Refunds are processed to the original payment method within 5-10 business days. -- If outside the 30-day window, politely deny the return and explain the policy. - -## Exchanges - -- Exchanges are allowed within 30 days of delivery. -- The replacement item must be in stock. -- Use `search_products` to find alternatives, then `process_exchange` to complete. -- If the new item costs more, the customer pays the difference. -- If the new item costs less, the difference is refunded. - -## Membership Discounts - -- VIP members: 10% discount on all purchases. -- Premium members: 15% discount on all purchases. -- Regular members: no discount. -- Discounts cannot be applied retroactively to past orders. -- Use the customer's membership tier from their profile. - -## Shipping - -- Standard shipping: 5-7 business days, free for orders over $50. -- Express shipping: 2-3 business days, $9.99. -- Overnight shipping: next business day, $19.99. - -## General Rules - -- Be polite, helpful, and professional. -- Do not make up information. Only provide details from the database. -- If you cannot fulfill a request, explain clearly why. -- When the customer's issue is fully resolved, end with "###STOP###". -- If the customer says goodbye or has no more questions, end with "###STOP###". 
diff --git a/tests/benchmark/tau/domains/retail/tasks.json b/tests/benchmark/tau/domains/retail/tasks.json deleted file mode 100644 index 3be303d..0000000 --- a/tests/benchmark/tau/domains/retail/tasks.json +++ /dev/null @@ -1,67 +0,0 @@ -[ - { - "task_id": "retail_001", - "user_scenario": "You are Emma Wilson (customer ID: CUST001). You received your order ORD001 but the Wireless Headphones are defective - they won't pair with your phone. You want to return the headphones. You're happy to keep the water bottles. If asked, confirm it was delivered on January 22.", - "expected_db": { - "orders": [ - { - "order_id": "ORD001", - "status": "returned" - } - ] - }, - "max_turns": 10 - }, - { - "task_id": "retail_002", - "user_scenario": "You are James Brown (customer ID: CUST002). You received Running Shoes (Size 10) in order ORD002 but they're too small. You want to exchange them for Size 11 instead. If Size 11 is out of stock, ask when they'll be available.", - "expected_db": { - "orders": [ - { - "order_id": "ORD002", - "status": "delivered" - } - ] - }, - "max_turns": 10 - }, - { - "task_id": "retail_003", - "user_scenario": "You are Sophia Lee (customer ID: CUST003). You want to return the Cotton T-Shirts from order ORD003 because they don't fit. Your order was delivered on December 17, 2025. If the agent says it's past the return window, accept that.", - "expected_db": { - "orders": [ - { - "order_id": "ORD003", - "status": "delivered" - } - ] - }, - "max_turns": 8 - }, - { - "task_id": "retail_004", - "user_scenario": "You are Liam Martinez (customer ID: CUST004). You placed order ORD004 and want to know when it will arrive. You also want to know what items are in the order since you forgot what you bought.", - "expected_db": { - "orders": [ - { - "order_id": "ORD004", - "status": "shipped" - } - ] - }, - "max_turns": 8 - }, - { - "task_id": "retail_005", - "user_scenario": "You are Olivia Davis (customer ID: CUST005). 
You received your Yoga Mat from order ORD005 but want to exchange it for a Bluetooth Speaker instead since you already have a yoga mat at home. Accept whatever the price difference is.", - "expected_db": { - "orders": [ - { - "order_id": "ORD005", - "status": "delivered" - } - ] - }, - "max_turns": 10 - } -] diff --git a/tests/benchmark/tau/domains/retail/tools.ts b/tests/benchmark/tau/domains/retail/tools.ts deleted file mode 100644 index 6749502..0000000 --- a/tests/benchmark/tau/domains/retail/tools.ts +++ /dev/null @@ -1,147 +0,0 @@ -// --------------------------------------------------------------------------- -// Retail domain tool definitions (Anthropic API format) -// --------------------------------------------------------------------------- - -export interface ToolDef { - name: string; - description: string; - input_schema: Record; -} - -export function getRetailToolDefs(): ToolDef[] { - return [ - { - name: 'get_customer_details', - description: - 'Look up customer information by customer ID. Returns name, email, phone, and membership tier.', - input_schema: { - type: 'object', - properties: { - customer_id: { - type: 'string', - description: 'The customer ID (e.g. "CUST001")', - }, - }, - required: ['customer_id'], - }, - }, - { - name: 'get_order_details', - description: - 'Look up an order by order ID. Returns items, total, status, and dates.', - input_schema: { - type: 'object', - properties: { - order_id: { - type: 'string', - description: 'The order ID (e.g. 
"ORD001")', - }, - }, - required: ['order_id'], - }, - }, - { - name: 'list_customer_orders', - description: - 'List all orders for a given customer.', - input_schema: { - type: 'object', - properties: { - customer_id: { - type: 'string', - description: 'The customer ID whose orders to list', - }, - }, - required: ['customer_id'], - }, - }, - { - name: 'get_product_details', - description: - 'Get details for a specific product including price, category, and stock.', - input_schema: { - type: 'object', - properties: { - product_id: { - type: 'string', - description: 'The product ID (e.g. "PROD001")', - }, - }, - required: ['product_id'], - }, - }, - { - name: 'search_products', - description: - 'Search for products by name or category. Returns matching products with availability.', - input_schema: { - type: 'object', - properties: { - query: { - type: 'string', - description: 'Search term to match against product name or category', - }, - }, - required: ['query'], - }, - }, - { - name: 'process_return', - description: - 'Process a return for an order. Sets the order status to "returned" and restocks items. Only valid for delivered orders within the return window.', - input_schema: { - type: 'object', - properties: { - order_id: { - type: 'string', - description: 'The order ID to return', - }, - }, - required: ['order_id'], - }, - }, - { - name: 'process_exchange', - description: - 'Exchange an item in an order for a different product. 
The original item is restocked and the new item is deducted from stock.', - input_schema: { - type: 'object', - properties: { - order_id: { - type: 'string', - description: 'The order ID containing the item to exchange', - }, - old_product_id: { - type: 'string', - description: 'The product ID being returned', - }, - new_product_id: { - type: 'string', - description: 'The product ID to exchange for', - }, - }, - required: ['order_id', 'old_product_id', 'new_product_id'], - }, - }, - { - name: 'update_order_status', - description: - 'Update the status of an order (e.g. to "cancelled").', - input_schema: { - type: 'object', - properties: { - order_id: { - type: 'string', - description: 'The order ID to update', - }, - new_status: { - type: 'string', - enum: ['pending', 'shipped', 'delivered', 'cancelled', 'returned'], - description: 'The new status', - }, - }, - required: ['order_id', 'new_status'], - }, - }, - ]; -} diff --git a/tests/benchmark/tau/environment.ts b/tests/benchmark/tau/environment.ts deleted file mode 100644 index 403b7f8..0000000 --- a/tests/benchmark/tau/environment.ts +++ /dev/null @@ -1,44 +0,0 @@ -// --------------------------------------------------------------------------- -// TAU benchmark environment — manages DB state and dispatches tool calls -// --------------------------------------------------------------------------- - -export type ToolHandler = (db: any, args: any) => any; - -export class Environment { - private db: any; - private toolCallLog: Array<{ name: string; args: any; result: any }> = []; - private handlers: Record; - - constructor(initialDb: any, handlers: Record) { - // Deep clone so each trial gets an isolated copy - this.db = JSON.parse(JSON.stringify(initialDb)); - this.handlers = handlers; - } - - /** Return current database state (deep clone) as a generic record. */ - getState(): Record { - return JSON.parse(JSON.stringify(this.db)); - } - - /** Return log of all tool calls made during this simulation. 
*/ - getToolCallLog() { - return this.toolCallLog; - } - - /** Dispatch a tool call by name. Returns the tool result as a JSON-serialisable value. */ - executeTool(name: string, args: any): any { - let result: any; - try { - const handler = this.handlers[name]; - if (!handler) { - result = { error: `Unknown tool: ${name}` }; - } else { - result = handler(this.db, args); - } - } catch (err: any) { - result = { error: err.message || String(err) }; - } - this.toolCallLog.push({ name, args, result }); - return result; - } -} diff --git a/tests/benchmark/tau/evaluator.ts b/tests/benchmark/tau/evaluator.ts deleted file mode 100644 index f91b992..0000000 --- a/tests/benchmark/tau/evaluator.ts +++ /dev/null @@ -1,69 +0,0 @@ -// --------------------------------------------------------------------------- -// TAU benchmark evaluator — DB state comparison + pass^k calculation -// --------------------------------------------------------------------------- - -/** - * Compare the final database state against expected changes. - * - * `expectedDb` is a partial DB snapshot: for each table, an array of objects - * specifying the fields that must match. Each object must contain the table's - * primary key field (e.g. `reservation_id`) so we can look up the record. - * - * Returns `true` if all specified fields in all expected records match. 
- */ -export function evaluateDBState( - finalDb: Record, - expectedDb: Record, -): boolean { - for (const [table, expectedRecords] of Object.entries(expectedDb)) { - const actualRecords: any[] = finalDb[table]; - if (!actualRecords) return false; - - for (const expected of expectedRecords) { - // Find primary key field (first field ending with _id) - const pkField = Object.keys(expected).find(k => k.endsWith('_id')); - if (!pkField) continue; - - const actual = actualRecords.find(r => r[pkField] === expected[pkField]); - if (!actual) return false; - - // Check all specified fields - for (const [key, value] of Object.entries(expected)) { - if (actual[key] !== value) return false; - } - } - } - - return true; -} - -/** - * Compute pass^k metrics from trial results. - * - * For each task, we have an array of boolean results (one per trial). - * pass^k = fraction of tasks where ALL of the first k trials passed. - * - * Returns an array [pass^1, pass^2, ..., pass^numTrials]. - */ -export function computePassK( - taskTrialResults: boolean[][], - numTrials: number, -): number[] { - if (taskTrialResults.length === 0) return []; - - const passAtK: number[] = []; - - for (let k = 1; k <= numTrials; k++) { - let passCount = 0; - for (const trials of taskTrialResults) { - // Check if all of the first k trials passed - const firstK = trials.slice(0, k); - if (firstK.length >= k && firstK.every(r => r)) { - passCount++; - } - } - passAtK.push(passCount / taskTrialResults.length); - } - - return passAtK; -} diff --git a/tests/benchmark/tau/index.ts b/tests/benchmark/tau/index.ts deleted file mode 100644 index fe90c9c..0000000 --- a/tests/benchmark/tau/index.ts +++ /dev/null @@ -1,252 +0,0 @@ -// --------------------------------------------------------------------------- -// TAU benchmark module — BenchmarkModule entry point -// --------------------------------------------------------------------------- - -import fs from 'fs'; -import path from 'path'; -import type { 
BenchmarkConfig, BenchmarkModuleResult, BenchmarkProvider, TAUProviderResult, TAUTaskResult } from '../types'; -import type { ModelProvider } from '../../../src/infra/providers/types'; -import { AnthropicProvider } from '../../../src/infra/providers/anthropic'; -import { OpenAIProvider } from '../../../src/infra/providers/openai'; -import { GeminiProvider } from '../../../src/infra/providers/gemini'; -import { getInitialDatabase as getAirlineDb } from './domains/airline/database'; -import { getAirlineToolDefs } from './domains/airline/tools'; -import { getAirlineHandlers } from './domains/airline/handlers'; -import { getInitialDatabase as getRetailDb } from './domains/retail/database'; -import { getRetailToolDefs } from './domains/retail/tools'; -import { getRetailHandlers } from './domains/retail/handlers'; -import type { ToolHandler } from './environment'; -import { Environment } from './environment'; -import { UserSimulator } from './user-simulator'; -import { runOrchestration } from './orchestrator'; -import { evaluateDBState, computePassK } from './evaluator'; - -// Module metadata (used by run-benchmark.ts discovery) -export const name = 'tau'; - -// --------------------------------------------------------------------------- -// Domain loading -// --------------------------------------------------------------------------- - -interface DomainData { - id: string; - policy: string; - toolDefs: any[]; - getInitialDatabase: () => any; - getHandlers: () => Record; - tasks: Array<{ - task_id: string; - user_scenario: string; - expected_db: Record; - max_turns: number; - }>; -} - -function loadDomain(domainId: string): DomainData | null { - const domainDir = path.join(__dirname, 'domains', domainId); - const policyPath = path.join(domainDir, 'policy.md'); - const tasksPath = path.join(domainDir, 'tasks.json'); - - if (!fs.existsSync(policyPath) || !fs.existsSync(tasksPath)) return null; - - const policy = fs.readFileSync(policyPath, 'utf-8'); - const tasks = 
JSON.parse(fs.readFileSync(tasksPath, 'utf-8')); - - switch (domainId) { - case 'airline': - return { - id: domainId, - policy, - toolDefs: getAirlineToolDefs(), - getInitialDatabase: getAirlineDb, - getHandlers: getAirlineHandlers, - tasks, - }; - case 'retail': - return { - id: domainId, - policy, - toolDefs: getRetailToolDefs(), - getInitialDatabase: getRetailDb, - getHandlers: getRetailHandlers, - tasks, - }; - default: - return null; - } -} - -function getAvailableDomains(tauDomain: string): DomainData[] { - const domains: DomainData[] = []; - const candidates = tauDomain === 'all' ? ['airline', 'retail'] : [tauDomain]; - - for (const id of candidates) { - const domain = loadDomain(id); - if (domain) domains.push(domain); - } - - return domains; -} - -// --------------------------------------------------------------------------- -// Provider creation -// --------------------------------------------------------------------------- - -function createProvider(bp: BenchmarkProvider): ModelProvider { - switch (bp.id) { - case 'anthropic': - return new AnthropicProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl); - case 'openai': - return new OpenAIProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl); - case 'gemini': - return new GeminiProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl); - default: - // For glm, minimax, etc. 
— try OpenAI-compatible - return new OpenAIProvider(bp.apiKey, bp.model, bp.baseUrl, bp.proxyUrl); - } -} - -// --------------------------------------------------------------------------- -// Build system prompt -// --------------------------------------------------------------------------- - -const DOMAIN_ROLES: Record = { - airline: 'an airline customer service agent', - retail: 'an online retail customer service agent', -}; - -function buildSystemPrompt(domainId: string, policy: string, toolDefs: any[]): string { - const role = DOMAIN_ROLES[domainId] || 'a customer service agent'; - const toolList = toolDefs.map(t => `- ${t.name}: ${t.description}`).join('\n'); - return [ - `You are ${role}. Follow the policy below strictly.`, - '', - '--- POLICY ---', - policy, - '--- END POLICY ---', - '', - 'Available tools:', - toolList, - '', - 'Instructions:', - '- Use tools to look up and modify data. Do not guess or make up information.', - '- When the customer\'s issue is fully resolved, include "###STOP###" at the end of your final message.', - '- Be concise and professional.', - ].join('\n'); -} - -// --------------------------------------------------------------------------- -// Run single provider across a domain -// --------------------------------------------------------------------------- - -async function runProviderOnDomain( - bp: BenchmarkProvider, - userSimBp: BenchmarkProvider, - domain: DomainData, - config: BenchmarkConfig, -): Promise { - const agentProvider = createProvider(bp); - const userSimProvider = createProvider(userSimBp); - const systemPrompt = buildSystemPrompt(domain.id, domain.policy, domain.toolDefs); - const results: TAUTaskResult[] = []; - - // Collect pass/fail per task across trials for pass^k calculation - const taskTrialMatrix: boolean[][] = []; - - for (const task of domain.tasks) { - const trialResults: boolean[] = []; - let totalTokens = 0; - let lastError: string | undefined; - - for (let trial = 0; trial < config.numTrials; 
trial++) { - // Fresh environment for each trial - const env = new Environment(domain.getInitialDatabase(), domain.getHandlers()); - const userSim = new UserSimulator(userSimProvider, task.user_scenario); - - const orchResult = await runOrchestration({ - agentProvider, - userSimulator: userSim, - environment: env, - systemPrompt, - toolDefs: domain.toolDefs, - maxTurns: task.max_turns, - timeoutMs: config.timeoutMs, - expectedDb: task.expected_db, - evaluate: evaluateDBState, - }); - - trialResults.push(orchResult.passed); - totalTokens += orchResult.agentTokens; - if (orchResult.error) lastError = orchResult.error; - - // Log progress - const status = orchResult.passed ? 'PASS' : 'FAIL'; - const errorSuffix = orchResult.error ? ` (${orchResult.error})` : ''; - console.log( - ` [${bp.id}] ${task.task_id} trial ${trial + 1}/${config.numTrials}: ${status} (${orchResult.turns} turns, ${orchResult.agentTokens} tokens)${errorSuffix}`, - ); - } - - taskTrialMatrix.push(trialResults); - results.push({ - task_id: task.task_id, - trial_pass_rates: trialResults, - tokens_used: Math.round(totalTokens / config.numTrials), - error: trialResults.every(r => !r) ? lastError : undefined, - }); - } - - // Compute pass^k - const passAtK = computePassK(taskTrialMatrix, config.numTrials); - const avgTokens = - results.length > 0 ? 
Math.round(results.reduce((s, r) => s + r.tokens_used, 0) / results.length) : 0; - - return { - provider: bp, - summary: { - domain: domain.id, - total_tasks: domain.tasks.length, - num_trials: config.numTrials, - pass_at_k: passAtK, - avg_tokens: avgTokens, - }, - results, - }; -} - -// --------------------------------------------------------------------------- -// Module entry point -// --------------------------------------------------------------------------- - -export async function run(config: BenchmarkConfig): Promise { - const domains = getAvailableDomains(config.tauDomain); - - if (domains.length === 0) { - console.log(` TAU: no domains found for "${config.tauDomain}"`); - return {}; - } - - if (config.providers.length === 0) { - console.log(' TAU: no providers configured, skipping'); - return {}; - } - - const allResults: TAUProviderResult[] = []; - - for (const domain of domains) { - console.log(`\n TAU domain: ${domain.id} (${domain.tasks.length} tasks, ${config.numTrials} trials)`); - - for (const bp of config.providers) { - // Use userSimProvider if configured, otherwise same as agent provider - const userSimBp = config.userSimProvider ?? bp; - - console.log(`\n Running provider: ${bp.id} / ${bp.model}`); - console.log(` User simulator: ${userSimBp.id} / ${userSimBp.model}`); - - const providerResult = await runProviderOnDomain(bp, userSimBp, domain, config); - allResults.push(providerResult); - } - } - - return { tau: allResults }; -} diff --git a/tests/benchmark/tau/orchestrator.ts b/tests/benchmark/tau/orchestrator.ts deleted file mode 100644 index 0c34d70..0000000 --- a/tests/benchmark/tau/orchestrator.ts +++ /dev/null @@ -1,201 +0,0 @@ -// --------------------------------------------------------------------------- -// TAU benchmark orchestrator — Agent ↔ User ↔ Environment message loop -// -// Follows τ-bench protocol: -// 1. User initiates conversation -// 2. Agent responds (text or tool calls) -// 3. 
If tool calls → environment executes → results fed back to agent -// 4. If text → forwarded to user simulator -// 5. Repeat until ###STOP### or max turns -// --------------------------------------------------------------------------- - -import type { ModelProvider } from '../../../src/infra/providers/types'; -import type { Message, ContentBlock } from '../../../src/core/types'; -import type { ToolDef } from './domains/airline/tools'; -import { Environment } from './environment'; -import { UserSimulator } from './user-simulator'; - -const STOP_SIGNAL = '###STOP###'; -const MAX_TOOL_ROUNDS = 10; // Safety limit for consecutive tool calls in one turn - -export interface ConversationMessage { - role: 'user' | 'assistant'; - content: string; -} - -export interface OrchestrationResult { - passed: boolean; - messages: ConversationMessage[]; - agentTokens: number; - userSimTokens: number; - turns: number; - error?: string; -} - -export interface OrchestrationOptions { - agentProvider: ModelProvider; - userSimulator: UserSimulator; - environment: Environment; - systemPrompt: string; - toolDefs: ToolDef[]; - maxTurns: number; - timeoutMs: number; - expectedDb: Record; - evaluate: (finalDb: Record, expectedDb: Record) => boolean; -} - -export async function runOrchestration(opts: OrchestrationOptions): Promise { - const { - agentProvider, - userSimulator, - environment, - systemPrompt, - toolDefs, - maxTurns, - timeoutMs, - expectedDb, - evaluate, - } = opts; - - const conversationLog: ConversationMessage[] = []; - // Internal messages in SDK format for model.complete() - const modelMessages: Message[] = []; - let agentTokens = 0; - let userSimTokens = 0; - let turns = 0; - - try { - // Wrap the entire orchestration in a timeout - const result = await withTimeout(async () => { - // 1. 
User generates first message - const firstMsg = await userSimulator.generateFirstMessage(); - userSimTokens += firstMsg.tokens; - conversationLog.push({ role: 'user', content: firstMsg.text }); - modelMessages.push(textMsg('user', firstMsg.text)); - - // 2. Conversation loop - while (turns < maxTurns) { - // --- Agent turn --- - const agentText = await runAgentTurn( - agentProvider, - modelMessages, - systemPrompt, - toolDefs, - environment, - (t) => { agentTokens += t; }, - ); - - conversationLog.push({ role: 'assistant', content: agentText }); - turns++; - - // Check agent stop signal - if (agentText.includes(STOP_SIGNAL)) break; - - // --- User turn --- - const userReply = await userSimulator.generateResponse( - agentText, - conversationLog.slice(0, -1), // history without the latest agent msg - ); - userSimTokens += userReply.tokens; - conversationLog.push({ role: 'user', content: userReply.text }); - modelMessages.push(textMsg('user', userReply.text)); - - // Check user stop signal - if (userReply.done) break; - } - - // 3. 
Evaluate - const finalDb = environment.getState(); - const passed = evaluate(finalDb, expectedDb); - - return { passed, messages: conversationLog, agentTokens, userSimTokens, turns }; - }, timeoutMs); - - return result; - } catch (err: any) { - return { - passed: false, - messages: conversationLog, - agentTokens, - userSimTokens, - turns, - error: err.message || String(err), - }; - } -} - -// --------------------------------------------------------------------------- -// Agent turn: call model, handle tool loops, return final text -// --------------------------------------------------------------------------- - -async function runAgentTurn( - provider: ModelProvider, - modelMessages: Message[], - systemPrompt: string, - toolDefs: ToolDef[], - environment: Environment, - addTokens: (t: number) => void, -): Promise { - let toolRounds = 0; - - while (toolRounds < MAX_TOOL_ROUNDS) { - const response = await provider.complete(modelMessages, { - system: systemPrompt, - tools: toolDefs, - }); - - const usage = (response.usage?.input_tokens ?? 0) + (response.usage?.output_tokens ?? 
0); - addTokens(usage); - - // Separate text and tool_use blocks - const textBlocks = response.content.filter( - (b): b is { type: 'text'; text: string } => b.type === 'text', - ); - const toolUseBlocks = response.content.filter( - (b): b is { type: 'tool_use'; id: string; name: string; input: any } => b.type === 'tool_use', - ); - - // If no tool calls, return text - if (toolUseBlocks.length === 0) { - const text = textBlocks.map(b => b.text).join(''); - modelMessages.push({ role: 'assistant', content: response.content }); - return text; - } - - // Handle tool calls - modelMessages.push({ role: 'assistant', content: response.content }); - - const toolResults: ContentBlock[] = toolUseBlocks.map(tc => { - const result = environment.executeTool(tc.name, tc.input); - return { - type: 'tool_result' as const, - tool_use_id: tc.id, - content: JSON.stringify(result), - }; - }); - - modelMessages.push({ role: 'user', content: toolResults }); - toolRounds++; - } - - // Safety: too many tool rounds — return whatever text we have - return '[Agent exceeded maximum tool call rounds]'; -} - -// --------------------------------------------------------------------------- -// Helpers -// --------------------------------------------------------------------------- - -function textMsg(role: 'user' | 'assistant' | 'system', text: string): Message { - return { role, content: [{ type: 'text', text }] }; -} - -function withTimeout(fn: () => Promise, ms: number): Promise { - return new Promise((resolve, reject) => { - const timer = setTimeout(() => reject(new Error(`Timeout after ${ms}ms`)), ms); - fn().then( - v => { clearTimeout(timer); resolve(v); }, - e => { clearTimeout(timer); reject(e); }, - ); - }); -} diff --git a/tests/benchmark/tau/user-simulator.ts b/tests/benchmark/tau/user-simulator.ts deleted file mode 100644 index 9bd5c0f..0000000 --- a/tests/benchmark/tau/user-simulator.ts +++ /dev/null @@ -1,107 +0,0 @@ -// 
--------------------------------------------------------------------------- -// TAU benchmark user simulator — LLM-powered simulated user -// --------------------------------------------------------------------------- - -import type { ModelProvider } from '../../../src/infra/providers/types'; -import type { Message, ContentBlock } from '../../../src/core/types'; - -const STOP_SIGNAL = '###STOP###'; - -export interface UserSimulatorResult { - text: string; - tokens: number; - done: boolean; -} - -export class UserSimulator { - private provider: ModelProvider; - private scenario: string; - - constructor(provider: ModelProvider, scenario: string) { - this.provider = provider; - this.scenario = scenario; - } - - /** - * Generate the first user message (initiating the conversation). - */ - async generateFirstMessage(): Promise { - const messages: Message[] = [ - { - role: 'user', - content: [ - { - type: 'text', - text: 'Generate your opening message to the customer service agent. Be natural and concise — state who you are and what you need. Respond with ONLY the message text, nothing else.', - }, - ], - }, - ]; - - return this.callModel(messages); - } - - /** - * Generate the next user response based on the agent's message. 
- */ - async generateResponse(agentMessage: string, history: Array<{ role: string; content: string }>): Promise { - // Build conversation history for the user simulator - const messages: Message[] = []; - - // Add previous turns (alternating user/assistant from the USER's perspective: - // the user simulator sees agent messages as "user" input and its own messages as "assistant" output) - for (const msg of history) { - if (msg.role === 'user') { - // This was a user-sim output — from user-sim's perspective it's "assistant" - messages.push({ role: 'assistant', content: [{ type: 'text', text: msg.content }] }); - } else if (msg.role === 'assistant') { - // This was an agent output — from user-sim's perspective it's "user" input - messages.push({ role: 'user', content: [{ type: 'text', text: msg.content }] }); - } - } - - // Latest agent message - messages.push({ - role: 'user', - content: [ - { - type: 'text', - text: `The customer service agent said:\n\n${agentMessage}\n\nRespond as the customer. If your issue is resolved, say goodbye naturally and include "${STOP_SIGNAL}" at the end of your message. Respond with ONLY the message text, nothing else.`, - }, - ], - }); - - return this.callModel(messages); - } - - private async callModel(messages: Message[]): Promise { - const systemPrompt = [ - 'You are simulating a customer calling airline customer service.', - 'Follow this scenario exactly:', - '', - this.scenario, - '', - 'Rules:', - '- Stay in character. Only say things consistent with your scenario.', - '- Be natural and conversational, like a real customer.', - '- Provide information when asked (your name, reservation ID, etc.).', - '- Do not invent details not in your scenario.', - `- When your issue is fully resolved and you have no more questions, include "${STOP_SIGNAL}" at the end of your final message.`, - '- Respond with ONLY the customer message text. 
Do not add any meta-commentary.', - ].join('\n'); - - const response = await this.provider.complete(messages, { system: systemPrompt }); - - const text = response.content - .filter((b): b is { type: 'text'; text: string } => b.type === 'text') - .map(b => b.text) - .join(''); - - const tokens = - (response.usage?.input_tokens ?? 0) + (response.usage?.output_tokens ?? 0); - - const done = text.includes(STOP_SIGNAL); - - return { text: text.replace(STOP_SIGNAL, '').trim(), tokens, done }; - } -} diff --git a/tests/benchmark/types.ts b/tests/benchmark/types.ts index 789cdc1..6bc0130 100644 --- a/tests/benchmark/types.ts +++ b/tests/benchmark/types.ts @@ -1,9 +1,5 @@ import type { ProviderId } from '../helpers/provider-env'; -// --------------------------------------------------------------------------- -// Provider -// --------------------------------------------------------------------------- - export interface BenchmarkProvider { id: ProviderId; model: string; @@ -12,64 +8,40 @@ export interface BenchmarkProvider { proxyUrl?: string; } -// --------------------------------------------------------------------------- -// CLI args -// --------------------------------------------------------------------------- - export interface BenchmarkCliArgs { - sweOnly: boolean; - tauOnly: boolean; - sweMode?: 'mini' | 'full'; - tauDomain?: string; + benchmark?: 'swe' | 'tb2' | 'both'; provider?: string; - numTrials?: number; - output?: 'table' | 'json' | 'html' | 'both'; + tb2Model?: string; + tb2Agent?: string; + tb2Dataset?: string; + tb2Runner?: 'auto' | 'harbor' | 'uvx' | 'docker'; + tb2Python?: string; + tb2JobsDir?: string; + tb2EnvFile?: string; + tb2DockerImage?: string; + output?: 'table' | 'json'; outputFile?: string; compare?: string; } -// --------------------------------------------------------------------------- -// Config (merged env + CLI) -// --------------------------------------------------------------------------- - export interface BenchmarkConfig { + 
benchmark: 'swe' | 'tb2' | 'both'; providers: BenchmarkProvider[]; - userSimProvider?: BenchmarkProvider; timeoutMs: number; - numTrials: number; - output: 'table' | 'json' | 'html' | 'both'; + output: 'table' | 'json'; outputFile: string; - sweMode: 'mini' | 'full'; - tauDomain: string; + tb2Model?: string; + tb2Agent: string; + tb2Dataset: string; + tb2Runner: 'auto' | 'harbor' | 'uvx' | 'docker'; + tb2Python: string; + tb2JobsDir: string; + tb2EnvFile?: string; + tb2DockerImage: string; sdkVersion: string; dockerProxy?: string; } -// --------------------------------------------------------------------------- -// SWE-bench types -// --------------------------------------------------------------------------- - -export interface SWEInstance { - instance_id: string; - repo: string; - base_commit: string; - patch: string; - test_patch: string; - problem_statement: string; - hints_text: string; - created_at: string; - version: string; -} - -export interface MiniSWECase { - id: string; - repo: string; - description: string; - files: Record; - expected_patch: string; - test_command: string; -} - export interface SWEResult { instance_id: string; resolved: boolean; @@ -93,60 +65,27 @@ export interface SWEProviderResult { results: SWEResult[]; } -// --------------------------------------------------------------------------- -// TAU-bench types -// --------------------------------------------------------------------------- - -export interface TAUTask { - task_id: string; - domain: string; - user_instruction: string; - expected_actions: string[]; - tools: string[]; -} - -export interface TAUTaskResult { - task_id: string; - trial_pass_rates: boolean[]; - tokens_used: number; - error?: string; -} - -export interface TAUSummary { - domain: string; - total_tasks: number; - num_trials: number; - pass_at_k: number[]; - avg_tokens: number; -} - -export interface TAUProviderResult { - provider: BenchmarkProvider; - summary: TAUSummary; - results: TAUTaskResult[]; +export interface 
TB2Summary { + generated_at: string; + dataset: string; + agent: string; + model?: string; + jobs_dir: string; + job_path: string; + passed: number; + total: number; + rate: number; + unknown: number; } -// --------------------------------------------------------------------------- -// Top-level report -// --------------------------------------------------------------------------- - export interface BenchmarkReport { timestamp: string; sdk_version: string; swe?: SWEProviderResult[]; - tau?: TAUProviderResult[]; + tb2?: TB2Summary; } -// --------------------------------------------------------------------------- -// Module contract (Phase 2+ modules implement this) -// --------------------------------------------------------------------------- - export interface BenchmarkModuleResult { swe?: SWEProviderResult[]; - tau?: TAUProviderResult[]; -} - -export interface BenchmarkModule { - name: string; - run(config: BenchmarkConfig): Promise; + tb2?: TB2Summary; } From 3b0f302d383a981676d27bdadc98d2ae323137d0 Mon Sep 17 00:00:00 2001 From: Gui-Yue Date: Fri, 27 Feb 2026 14:17:11 +0800 Subject: [PATCH 3/3] feat(benchmark): restore official TAU2, add token accounting for TAU/TB2, and gate Actions run via var - restore TAU benchmark with official TAU2 harness integration - make --benchmark=both run SWE + TAU + TB2 (all kept as alias) - fix workflow defaults for push/PR (avoid empty benchmark/provider/model args) - add token extraction/aggregation for TAU and TB2 (show N/A when source has no usage) - extend reports/comparisons/types to include TAU/TB2 token stats - update benchmark docs (EN/ZH) and Actions summary output for TAU/TB2 token fields - add workflow gate: run benchmark job only when vars.BENCHMARK_ACTION_ENABLED == '1' --- .github/workflows/benchmark.yml | 53 ++- docs/en/guides/benchmarking.md | 33 +- docs/zh-CN/guides/benchmarking.md | 33 +- tests/benchmark/compare.ts | 88 ++++- tests/benchmark/config.ts | 78 ++++- tests/benchmark/reporter.ts | 66 +++- 
tests/benchmark/run-benchmark.ts | 26 +- tests/benchmark/run-tb2-official.ts | 121 ++++++- tests/benchmark/tau/index.ts | 502 ++++++++++++++++++++++++++++ tests/benchmark/types.ts | 38 ++- 10 files changed, 990 insertions(+), 48 deletions(-) create mode 100644 tests/benchmark/tau/index.ts diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 53bc29a..176abfc 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -9,11 +9,13 @@ on: required: true default: both options: + - all - both - swe + - tau - tb2 provider: - description: "SWE provider filter" + description: "SWE/TAU provider filter" type: choice required: true default: all @@ -22,6 +24,16 @@ on: - anthropic - openai - gemini + tau_domain: + description: "TAU domain (airline by default for faster runs)" + type: choice + required: true + default: airline + options: + - airline + - retail + - telecom + - all tb2_model: description: "TB2 model in provider/model format" type: string @@ -45,6 +57,7 @@ jobs: name: Benchmark runs-on: ubuntu-latest timeout-minutes: 360 + if: ${{ vars.BENCHMARK_ACTION_ENABLED == '1' }} env: DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} @@ -94,9 +107,15 @@ jobs: - name: Run unified benchmark command run: | mkdir -p tests/tmp + benchmark="${{ github.event.inputs.benchmark || 'both' }}" + provider="${{ github.event.inputs.provider || 'all' }}" + tau_domain="${{ github.event.inputs.tau_domain || 'airline' }}" + tb2_model="${{ github.event.inputs.tb2_model || 'openai/glm-5' }}" + args=( - --benchmark=${{ inputs.benchmark }} - --tb2-model=${{ inputs.tb2_model }} + --benchmark=${benchmark} + --tau-domain=${tau_domain} + --tb2-model=${tb2_model} --tb2-agent=oracle --tb2-runner=uvx --tb2-python=3.12 @@ -105,8 +124,8 @@ jobs: --output-file=tests/tmp/benchmark-report.json ) - if [[ "${{ inputs.provider }}" != "all" && "${{ inputs.benchmark }}" != "tb2" ]]; then - 
args+=(--provider=${{ inputs.provider }}) + if [[ "${provider}" != "all" && "${benchmark}" != "tb2" ]]; then + args+=(--provider=${provider}) fi npm run test:benchmark -- "${args[@]}" @@ -144,6 +163,24 @@ jobs: console.log(''); } + if (Array.isArray(report.tau) && report.tau.length > 0) { + console.log('### TAU-bench'); + console.log(''); + console.log('| Provider / Model | Domain | Pass^1 | Avg Tokens |'); + console.log('|---|---|---:|---:|'); + for (const r of report.tau) { + const name = `${r.provider.id} / ${r.provider.model}`; + const domain = r.summary.domain; + const pass1 = `${((r.summary.pass_at_k?.[0] ?? 0) * 100).toFixed(1)}%`; + const observed = (r.summary.token_observed_trials ?? 0) > 0; + const avgTokens = observed + ? (r.summary.avg_tokens >= 1000 ? `${(r.summary.avg_tokens / 1000).toFixed(1)}k` : `${r.summary.avg_tokens}`) + : '-'; + console.log(`| ${name} | ${domain} | ${pass1} | ${avgTokens} |`); + } + console.log(''); + } + if (report.tb2) { const tb2 = report.tb2; console.log('### Terminal Bench 2.0'); @@ -152,6 +189,11 @@ jobs: if (tb2.model) console.log(`- Model: \`${tb2.model}\``); console.log(`- Passed: **${tb2.passed}/${tb2.total}**`); console.log(`- Rate: **${(tb2.rate * 100).toFixed(1)}%**`); + if (typeof tb2.avg_total_tokens === 'number' && (tb2.token_observed_trials ?? 
0) > 0) { + console.log(`- Avg tokens: **${tb2.avg_total_tokens}** (observed ${tb2.token_observed_trials} trials)`); + } else { + console.log(`- Avg tokens: **N/A**`); + } console.log(''); } NODE @@ -165,3 +207,4 @@ jobs: path: | tests/tmp/benchmark-report.json tests/tmp/jobs/*/result.json + tests/tmp/tau2-data/simulations/*.json diff --git a/docs/en/guides/benchmarking.md b/docs/en/guides/benchmarking.md index 6cc2dbe..e167d8a 100644 --- a/docs/en/guides/benchmarking.md +++ b/docs/en/guides/benchmarking.md @@ -1,10 +1,12 @@ # Benchmarking -KODE SDK benchmark runner now has a single entry command and supports three targets: +KODE SDK benchmark runner now has a single entry command and supports multiple targets: - `swe`: SWE-bench-Verified only +- `tau`: TAU-bench only - `tb2`: Terminal Bench 2.0 only -- `both`: run both in one command +- `both`: run SWE + TAU + TB2 +- `all`: alias of `both` (compatibility) ## Prerequisites @@ -29,6 +31,7 @@ GEMINI_MODEL_ID=gemini-3-pro-preview 3. Runtime tools: - SWE-bench-Verified: Docker is required +- TAU-bench: `tau2` or `uvx` is required (official TAU2 harness) - TB2: `harbor`, `uvx`, or Docker (runner decides by `--tb2-runner`) ## Unified Command @@ -39,7 +42,7 @@ npm run test:benchmark -- [flags] ### Common examples -Run both SWE + TB2 in one command: +Run SWE + TAU + TB2 in one command: ```bash npm run test:benchmark -- \ @@ -72,12 +75,26 @@ npm run test:benchmark -- \ --output-file=tests/tmp/tb2-report.json ``` +Run only TAU-bench (official TAU2 script + dataset): + +```bash +npm run test:benchmark -- \ + --benchmark=tau \ + --provider=openai \ + --tau-domain=airline \ + --num-trials=1 \ + --output=json \ + --output-file=tests/tmp/tau-report.json +``` + ## Flags | Flag | Description | Default | |---|---|---| -| `--benchmark=swe\|tb2\|both` | Which benchmark(s) to run | `both` | -| `--provider=...` | SWE provider filter (`anthropic`, `openai`, `gemini`, etc.) 
| all discovered | +| `--benchmark=swe\|tau\|tb2\|both\|all` | Which benchmark(s) to run (`both`=`all`) | `both` | +| `--provider=...` | Provider filter for SWE/TAU (`anthropic`, `openai`, `gemini`, etc.) | all discovered | +| `--tau-domain=airline\|retail\|telecom\|all` | TAU domain filter | `airline` | +| `--num-trials=N` | TAU trials per task (Pass^k) | `1` | | `--tb2-model=provider/model` | TB2 model id | `BENCHMARK_TB2_MODEL` or `openai/$OPENAI_MODEL_ID` | | `--tb2-agent=...` | TB2 agent (`oracle`, etc.) | `oracle` | | `--tb2-dataset=...` | TB2 dataset id | `terminal-bench@2.0` | @@ -92,7 +109,7 @@ npm run test:benchmark -- \ ## Output -With `--output=json`, one report contains both sections: +With `--output=json`, one report may contain `swe`, `tau`, and `tb2` sections depending on `--benchmark`. ```json { @@ -118,5 +135,9 @@ With `--output=json`, one report contains both sections: ## Notes - SWE-bench is fixed to **SWE-bench-Verified**. There is no mini/full mode switch anymore. +- TAU now runs with the official **TAU2** harness (`tau2 run ...`) from Sierra. +- TAU default domain is `airline` for faster CI/local feedback. Use `--tau-domain=all` when you need full coverage. +- TAU user simulator can be configured with `BENCHMARK_USER_MODEL=provider/model`. - TB2 uses official Harbor run flow (`harbor run -d terminal-bench@2.0 -m ... -a ...`) under the selected runner. +- TAU/TB2 token stats are extracted from official result files when available; if a runner/agent does not emit usage, it is shown as `N/A`. - If Docker image pulls are slow, set `BENCHMARK_DOCKER_PROXY`. 
diff --git a/docs/zh-CN/guides/benchmarking.md b/docs/zh-CN/guides/benchmarking.md index baa8405..50d55dd 100644 --- a/docs/zh-CN/guides/benchmarking.md +++ b/docs/zh-CN/guides/benchmarking.md @@ -1,10 +1,12 @@ # Benchmarking -KODE SDK 的 benchmark 入口已统一为一个命令,支持三种目标: +KODE SDK 的 benchmark 入口已统一为一个命令,支持多个目标: - `swe`:只跑 SWE-bench-Verified +- `tau`:只跑 TAU-bench - `tb2`:只跑 Terminal Bench 2.0 -- `both`:一次命令同时跑两者 +- `both`:一次命令跑 SWE + TAU + TB2 +- `all`:`both` 的兼容别名 ## 前置条件 @@ -29,6 +31,7 @@ GEMINI_MODEL_ID=gemini-3-pro-preview 3. 运行依赖: - SWE-bench-Verified:必须有 Docker +- TAU-bench:需要 `tau2` 或 `uvx`(官方 TAU2 harness) - TB2:`harbor`、`uvx` 或 Docker(由 `--tb2-runner` 决定) ## 统一命令 @@ -39,7 +42,7 @@ npm run test:benchmark -- [参数] ### 常用示例 -一次命令同时跑 SWE + TB2: +一次命令同时跑 SWE + TAU + TB2: ```bash npm run test:benchmark -- \ @@ -72,12 +75,26 @@ npm run test:benchmark -- \ --output-file=tests/tmp/tb2-report.json ``` +只跑 TAU-bench(官方 TAU2 脚本与数据集): + +```bash +npm run test:benchmark -- \ + --benchmark=tau \ + --provider=openai \ + --tau-domain=airline \ + --num-trials=1 \ + --output=json \ + --output-file=tests/tmp/tau-report.json +``` + ## 参数说明 | 参数 | 含义 | 默认值 | |---|---|---| -| `--benchmark=swe\|tb2\|both` | 选择要跑的 benchmark | `both` | -| `--provider=...` | SWE provider 过滤(`anthropic`、`openai`、`gemini` 等) | 自动发现全部 | +| `--benchmark=swe\|tau\|tb2\|both\|all` | 选择要跑的 benchmark(`both`=`all`) | `both` | +| `--provider=...` | SWE/TAU 的 provider 过滤(`anthropic`、`openai`、`gemini` 等) | 自动发现全部 | +| `--tau-domain=airline\|retail\|telecom\|all` | TAU 领域过滤 | `airline` | +| `--num-trials=N` | TAU 每个任务试验次数(Pass^k) | `1` | | `--tb2-model=provider/model` | TB2 模型 ID | `BENCHMARK_TB2_MODEL` 或 `openai/$OPENAI_MODEL_ID` | | `--tb2-agent=...` | TB2 agent(如 `oracle`) | `oracle` | | `--tb2-dataset=...` | TB2 数据集 ID | `terminal-bench@2.0` | @@ -92,7 +109,7 @@ npm run test:benchmark -- \ ## 输出格式 -使用 `--output=json` 时,单个报告同时包含 SWE 和 TB2: +使用 `--output=json` 时,报告会按 `--benchmark` 输出 `swe`/`tau`/`tb2` 分区: ```json { 
@@ -118,5 +135,9 @@ npm run test:benchmark -- \ ## 说明 - SWE 已固定为 **SWE-bench-Verified**,不再有 mini/full 模式参数。 +- TAU 已切换为 Sierra 官方 **TAU2** harness(`tau2 run ...`)。 +- TAU 默认领域改为 `airline`,用于更快的本地/CI反馈;需要全量时使用 `--tau-domain=all`。 +- TAU 的用户模拟模型可通过 `BENCHMARK_USER_MODEL=provider/model` 指定。 - TB2 走官方 Harbor 流程(`harbor run -d terminal-bench@2.0 -m ... -a ...`),由 runner 包装执行。 +- TAU/TB2 的 token 统计会从官方结果文件提取;若 runner/agent 未产出 usage,则显示为 `N/A`。 - 若 Docker 拉取镜像慢,可设置 `BENCHMARK_DOCKER_PROXY`。 diff --git a/tests/benchmark/compare.ts b/tests/benchmark/compare.ts index 22cbc1b..af66ace 100644 --- a/tests/benchmark/compare.ts +++ b/tests/benchmark/compare.ts @@ -1,5 +1,5 @@ import fs from 'fs'; -import type { BenchmarkReport, SWEProviderResult, TB2Summary } from './types'; +import type { BenchmarkReport, SWEProviderResult, TAUProviderResult, TB2Summary } from './types'; interface ComparisonRow { label: string; @@ -11,6 +11,7 @@ interface ComparisonRow { interface ComparisonResult { swe: ComparisonRow[]; + tau: ComparisonRow[]; tb2: ComparisonRow[]; hasRegressions: boolean; } @@ -104,6 +105,58 @@ function compareSWE(oldResults: SWEProviderResult[], newResults: SWEProviderResu return rows; } +function compareTAU(oldResults: TAUProviderResult[], newResults: TAUProviderResult[]): ComparisonRow[] { + const rows: ComparisonRow[] = []; + + for (const newR of newResults) { + const key = `${newR.provider.id}/${newR.provider.model} [${newR.summary.domain}]`; + const oldR = oldResults.find( + r => + r.provider.id === newR.provider.id + && r.provider.model === newR.provider.model + && r.summary.domain === newR.summary.domain, + ); + + if (!oldR) { + const pass1 = newR.summary.pass_at_k[0] ?? 0; + rows.push({ + label: `${key} [pass^1]`, + oldValue: '-', + newValue: fmtPct(pass1), + delta: 'new', + direction: 'na', + }); + continue; + } + + const oldPass1 = oldR.summary.pass_at_k[0] ?? 0; + const newPass1 = newR.summary.pass_at_k[0] ?? 
0; + const passDelta = deltaStr(oldPass1, newPass1, 'pct'); + rows.push({ + label: `${key} [pass^1]`, + oldValue: fmtPct(oldPass1), + newValue: fmtPct(newPass1), + delta: passDelta.text, + direction: passDelta.dir, + }); + + const oldTokObserved = (oldR.summary.token_observed_trials ?? 0) > 0; + const newTokObserved = (newR.summary.token_observed_trials ?? 0) > 0; + if (oldTokObserved && newTokObserved) { + const tokenDelta = deltaStr(oldR.summary.avg_tokens, newR.summary.avg_tokens, 'tokens'); + rows.push({ + label: `${key} [tokens]`, + oldValue: fmtK(oldR.summary.avg_tokens), + newValue: fmtK(newR.summary.avg_tokens), + delta: tokenDelta.text, + direction: tokenDelta.dir, + }); + } + } + + return rows; +} + function compareTB2(oldTB2?: TB2Summary, newTB2?: TB2Summary): ComparisonRow[] { if (!newTB2) return []; if (!oldTB2) { @@ -136,6 +189,19 @@ function compareTB2(oldTB2?: TB2Summary, newTB2?: TB2Summary): ComparisonRow[] { direction: newTB2.passed > oldTB2.passed ? 'better' : newTB2.passed < oldTB2.passed ? 'worse' : 'same', }); + const oldTokObserved = (oldTB2.token_observed_trials ?? 0) > 0 && oldTB2.avg_total_tokens !== undefined; + const newTokObserved = (newTB2.token_observed_trials ?? 0) > 0 && newTB2.avg_total_tokens !== undefined; + if (oldTokObserved && newTokObserved) { + const tokenDelta = deltaStr(oldTB2.avg_total_tokens!, newTB2.avg_total_tokens!, 'tokens'); + rows.push({ + label: 'tb2 [tokens]', + oldValue: fmtK(oldTB2.avg_total_tokens!), + newValue: fmtK(newTB2.avg_total_tokens!), + delta: tokenDelta.text, + direction: tokenDelta.dir, + }); + } + return rows; } @@ -145,9 +211,10 @@ export function loadReport(filePath: string): BenchmarkReport { export function compareReports(oldReport: BenchmarkReport, newReport: BenchmarkReport): ComparisonResult { const sweRows = compareSWE(oldReport.swe ?? [], newReport.swe ?? []); + const tauRows = compareTAU(oldReport.tau ?? [], newReport.tau ?? 
[]); const tb2Rows = compareTB2(oldReport.tb2, newReport.tb2); - const hasRegressions = [...sweRows, ...tb2Rows].some(r => r.direction === 'worse'); - return { swe: sweRows, tb2: tb2Rows, hasRegressions }; + const hasRegressions = [...sweRows, ...tauRows, ...tb2Rows].some(r => r.direction === 'worse'); + return { swe: sweRows, tau: tauRows, tb2: tb2Rows, hasRegressions }; } export function printComparison(oldPath: string, newPath: string, result: ComparisonResult): void { @@ -159,7 +226,7 @@ export function printComparison(oldPath: string, newPath: string, result: Compar console.log(` Current: ${newPath}`); console.log(''); - const allRows = [...result.swe, ...result.tb2]; + const allRows = [...result.swe, ...result.tau, ...result.tb2]; if (allRows.length === 0) { console.log(' No comparable results found.\n'); return; @@ -182,6 +249,19 @@ export function printComparison(oldPath: string, newPath: string, result: Compar console.log(''); } + if (result.tau.length > 0) { + console.log('--- TAU Comparison ---\n'); + console.log(header); + console.log(sep); + for (const row of result.tau) { + const dir = row.direction === 'better' ? ' ^' : row.direction === 'worse' ? 
' v' : ' '; + console.log( + `${pad(row.label, maxLabel)} | ${lpad(row.oldValue, 10)} | ${lpad(row.newValue, 10)} | ${lpad(row.delta, 12)} |${dir}`, + ); + } + console.log(''); + } + if (result.tb2.length > 0) { console.log('--- TB2 Comparison ---\n'); console.log(header); diff --git a/tests/benchmark/config.ts b/tests/benchmark/config.ts index 8ad9895..2125734 100644 --- a/tests/benchmark/config.ts +++ b/tests/benchmark/config.ts @@ -12,38 +12,61 @@ export function parseCliArgs(argv: string[] = process.argv.slice(2)): BenchmarkC const args: BenchmarkCliArgs = {}; for (const arg of argv) { - if (arg.startsWith('--benchmark=')) { + if (arg === '--swe-only') { + args.benchmark = 'swe'; + } else if (arg === '--tau-only') { + args.benchmark = 'tau'; + } else if (arg === '--tb2-only') { + args.benchmark = 'tb2'; + } else if (arg.startsWith('--benchmark=')) { const val = arg.slice('--benchmark='.length); - if (val === 'swe' || val === 'tb2' || val === 'both') args.benchmark = val; + if (val === 'swe' || val === 'tau' || val === 'tb2' || val === 'both' || val === 'all') args.benchmark = val; } else if (arg.startsWith('--provider=')) { - args.provider = arg.slice('--provider='.length); + const v = arg.slice('--provider='.length).trim(); + if (v) args.provider = v; + } else if (arg.startsWith('--tau-domain=')) { + const v = arg.slice('--tau-domain='.length).trim(); + if (v) args.tauDomain = v; + } else if (arg.startsWith('--num-trials=')) { + const n = parseInt(arg.slice('--num-trials='.length), 10); + if (!Number.isNaN(n) && n > 0) args.numTrials = n; } else if (arg.startsWith('--tb2-model=')) { - args.tb2Model = arg.slice('--tb2-model='.length); + const v = arg.slice('--tb2-model='.length).trim(); + if (v) args.tb2Model = v; } else if (arg.startsWith('--model=')) { // Backward-compatible alias for TB2 model. 
- args.tb2Model = arg.slice('--model='.length); + const v = arg.slice('--model='.length).trim(); + if (v) args.tb2Model = v; } else if (arg.startsWith('--tb2-agent=')) { - args.tb2Agent = arg.slice('--tb2-agent='.length); + const v = arg.slice('--tb2-agent='.length).trim(); + if (v) args.tb2Agent = v; } else if (arg.startsWith('--tb2-dataset=')) { - args.tb2Dataset = arg.slice('--tb2-dataset='.length); + const v = arg.slice('--tb2-dataset='.length).trim(); + if (v) args.tb2Dataset = v; } else if (arg.startsWith('--tb2-runner=')) { const val = arg.slice('--tb2-runner='.length); if (val === 'auto' || val === 'harbor' || val === 'uvx' || val === 'docker') args.tb2Runner = val; } else if (arg.startsWith('--tb2-python=')) { - args.tb2Python = arg.slice('--tb2-python='.length); + const v = arg.slice('--tb2-python='.length).trim(); + if (v) args.tb2Python = v; } else if (arg.startsWith('--tb2-jobs-dir=')) { - args.tb2JobsDir = arg.slice('--tb2-jobs-dir='.length); + const v = arg.slice('--tb2-jobs-dir='.length).trim(); + if (v) args.tb2JobsDir = v; } else if (arg.startsWith('--tb2-env-file=')) { - args.tb2EnvFile = arg.slice('--tb2-env-file='.length); + const v = arg.slice('--tb2-env-file='.length).trim(); + if (v) args.tb2EnvFile = v; } else if (arg.startsWith('--tb2-docker-image=')) { - args.tb2DockerImage = arg.slice('--tb2-docker-image='.length); + const v = arg.slice('--tb2-docker-image='.length).trim(); + if (v) args.tb2DockerImage = v; } else if (arg.startsWith('--output=')) { const val = arg.slice('--output='.length); if (val === 'table' || val === 'json') args.output = val; } else if (arg.startsWith('--output-file=')) { - args.outputFile = arg.slice('--output-file='.length); + const v = arg.slice('--output-file='.length).trim(); + if (v) args.outputFile = v; } else if (arg.startsWith('--compare=')) { - args.compare = arg.slice('--compare='.length); + const v = arg.slice('--compare='.length).trim(); + if (v) args.compare = v; } } @@ -79,6 +102,28 @@ function 
discoverProviders(filterProvider?: string): BenchmarkProvider[] { return providers; } +function findUserSimProvider(): BenchmarkProvider | undefined { + const userModel = process.env.BENCHMARK_USER_MODEL; + if (!userModel) return undefined; + + const slashIdx = userModel.indexOf('/'); + if (slashIdx === -1) return undefined; + + const providerId = userModel.slice(0, slashIdx) as ProviderId; + const model = userModel.slice(slashIdx + 1); + + const result = loadProviderEnv(providerId); + if (!result.ok || !result.config || !result.config.apiKey) return undefined; + + return { + id: providerId, + model, + apiKey: result.config.apiKey, + baseUrl: result.config.baseUrl, + proxyUrl: result.config.proxyUrl, + }; +} + function readSdkVersion(): string { try { const pkg = require('../../package.json'); @@ -90,15 +135,19 @@ function readSdkVersion(): string { export function loadConfig(cliArgs: BenchmarkCliArgs): BenchmarkConfig { const envTimeout = process.env.BENCHMARK_TIMEOUT_MS; + const envTrials = process.env.BENCHMARK_NUM_TRIALS; const envOutput = process.env.BENCHMARK_OUTPUT; const envTb2Model = process.env.BENCHMARK_TB2_MODEL || (process.env.OPENAI_MODEL_ID ? `openai/${process.env.OPENAI_MODEL_ID}` : undefined); const timeoutMs = envTimeout ? parseInt(envTimeout, 10) : 120_000; + const envTrialsParsed = envTrials ? parseInt(envTrials, 10) : undefined; + const numTrials = cliArgs.numTrials ?? (envTrialsParsed && envTrialsParsed > 0 ? envTrialsParsed : 1); const output = cliArgs.output ?? (envOutput === 'json' || envOutput === 'table' ? envOutput : 'table'); const outputFile = cliArgs.outputFile ?? 'benchmark-report.json'; const benchmark = cliArgs.benchmark ?? 'both'; + const tauDomain = cliArgs.tauDomain ?? 'airline'; const tb2Agent = cliArgs.tb2Agent ?? 'oracle'; const tb2Dataset = cliArgs.tb2Dataset ?? 'terminal-bench@2.0'; const tb2Runner = cliArgs.tb2Runner ?? 
'auto'; @@ -110,7 +159,10 @@ export function loadConfig(cliArgs: BenchmarkCliArgs): BenchmarkConfig { return { benchmark, providers: discoverProviders(cliArgs.provider), + userSimProvider: findUserSimProvider(), timeoutMs, + numTrials, + tauDomain, output, outputFile, tb2Model: cliArgs.tb2Model ?? envTb2Model, diff --git a/tests/benchmark/reporter.ts b/tests/benchmark/reporter.ts index 36c0350..9b28311 100644 --- a/tests/benchmark/reporter.ts +++ b/tests/benchmark/reporter.ts @@ -1,6 +1,12 @@ import fs from 'fs'; import path from 'path'; -import type { BenchmarkConfig, BenchmarkReport, SWEProviderResult, TB2Summary } from './types'; +import type { + BenchmarkConfig, + BenchmarkReport, + SWEProviderResult, + TAUProviderResult, + TB2Summary, +} from './types'; function pad(s: string, len: number): string { return s.length >= len ? s.slice(0, len) : s + ' '.repeat(len - s.length); @@ -48,6 +54,9 @@ function buildTable(columns: Column[], rows: string[][]): string { } export function printProviderSummary(config: BenchmarkConfig): void { + const runSWE = config.benchmark === 'swe' || config.benchmark === 'both' || config.benchmark === 'all'; + const runTAU = config.benchmark === 'tau' || config.benchmark === 'both' || config.benchmark === 'all'; + const runTB2 = config.benchmark === 'tb2' || config.benchmark === 'both' || config.benchmark === 'all'; const banner = '='.repeat(80); console.log(`\n${banner}`); console.log('KODE SDK Benchmark Runner'); @@ -58,18 +67,26 @@ export function printProviderSummary(config: BenchmarkConfig): void { console.log(` Output: ${config.output}`); console.log(''); - if (config.benchmark === 'swe' || config.benchmark === 'both') { + if (runSWE || runTAU) { if (config.providers.length === 0) { - console.log(' SWE providers: (none discovered)'); + console.log(' Providers: (none discovered)'); } else { - console.log(' SWE providers:'); + console.log(' Providers:'); for (const p of config.providers) { console.log(` - ${p.id} / ${p.model}`); } } 
} - if (config.benchmark === 'tb2' || config.benchmark === 'both') { + if (runTAU) { + console.log(` TAU domain: ${config.tauDomain}`); + console.log(` Num trials: ${config.numTrials}`); + if (config.userSimProvider) { + console.log(` User sim: ${config.userSimProvider.id} / ${config.userSimProvider.model}`); + } + } + + if (runTB2) { console.log(` TB2 dataset: ${config.tb2Dataset}`); console.log(` TB2 agent: ${config.tb2Agent}`); if (config.tb2Model) console.log(` TB2 model: ${config.tb2Model}`); @@ -107,12 +124,51 @@ export function printSWETable(dataset: string, instanceCount: number, results: S console.log(''); } +export function printTAUTable( + domain: string, + taskCount: number, + numTrials: number, + results: TAUProviderResult[], +): void { + console.log(`\n--- TAU-bench (${domain}) — ${taskCount} tasks, ${numTrials} trials ---\n`); + + const passColumns: Column[] = []; + for (let k = 1; k <= numTrials; k++) { + passColumns.push({ header: `Pass^${k}`, width: 7, align: 'right' }); + } + + const columns: Column[] = [ + { header: 'Provider / Model', width: 36, align: 'left' }, + ...passColumns, + { header: 'Avg Tokens', width: 10, align: 'right' }, + ]; + + const rows = results.map(r => { + const passValues = r.summary.pass_at_k.map(v => fmtPct(v)); + while (passValues.length < numTrials) passValues.push('-'); + const tokenCell = (r.summary.token_observed_trials ?? 0) > 0 ? fmtK(r.summary.avg_tokens) : '-'; + return [ + trunc(`${r.provider.id} / ${r.provider.model}`, 36), + ...passValues, + tokenCell, + ]; + }); + + console.log(buildTable(columns, rows)); + console.log(''); +} + export function printTB2Summary(summary: TB2Summary): void { console.log('\n=== Terminal Bench 2.0 Score ==='); console.log(`Job path: ${summary.job_path}`); console.log(`Passed: ${summary.passed}/${summary.total}`); console.log(`Rate: ${fmtPct(summary.rate)}`); console.log(`Unknown: ${summary.unknown}`); + if ((summary.token_observed_trials ?? 
0) > 0 && summary.avg_total_tokens !== undefined) { + console.log(`Avg tok: ${fmtK(summary.avg_total_tokens)} (observed ${summary.token_observed_trials}/${summary.total})`); + } else { + console.log('Avg tok: N/A'); + } console.log(''); } diff --git a/tests/benchmark/run-benchmark.ts b/tests/benchmark/run-benchmark.ts index 6a66edf..30460ac 100644 --- a/tests/benchmark/run-benchmark.ts +++ b/tests/benchmark/run-benchmark.ts @@ -1,14 +1,15 @@ /** * Unified benchmark runner entry point. - * Supports SWE-bench-Verified, Terminal Bench 2.0, or both. + * Supports SWE-bench-Verified, TAU-bench, Terminal Bench 2.0, or combinations. */ import '../helpers/env-setup'; import { parseCliArgs, loadConfig } from './config'; -import { printProviderSummary, printSWETable, printTB2Summary, writeJsonReport } from './reporter'; +import { printProviderSummary, printSWETable, printTAUTable, printTB2Summary, writeJsonReport } from './reporter'; import { loadReport, compareReports, printComparison } from './compare'; import type { BenchmarkReport } from './types'; import { run as runSWE } from './swe'; +import { run as runTAU } from './tau'; import { runTB2Official } from './run-tb2-official'; async function main(): Promise { @@ -22,7 +23,11 @@ async function main(): Promise { sdk_version: config.sdkVersion, }; - if (config.benchmark === 'swe' || config.benchmark === 'both') { + const runSWEFlag = config.benchmark === 'swe' || config.benchmark === 'both' || config.benchmark === 'all'; + const runTAUFlag = config.benchmark === 'tau' || config.benchmark === 'both' || config.benchmark === 'all'; + const runTB2Flag = config.benchmark === 'tb2' || config.benchmark === 'both' || config.benchmark === 'all'; + + if (runSWEFlag) { console.log(' Running module: swe ...'); const sweResult = await runSWE(config); if (sweResult.swe) { @@ -33,7 +38,18 @@ async function main(): Promise { } } - if (config.benchmark === 'tb2' || config.benchmark === 'both') { + if (runTAUFlag) { + console.log(' Running 
module: tau ...'); + const tauResult = await runTAU(config); + if (tauResult.tau) { + report.tau = tauResult.tau; + for (const r of tauResult.tau) { + printTAUTable(r.summary.domain, r.summary.total_tasks, r.summary.num_trials, [r]); + } + } + } + + if (runTB2Flag) { console.log(' Running module: tb2 ...'); const tb2 = runTB2Official({ dataset: config.tb2Dataset, @@ -49,7 +65,7 @@ async function main(): Promise { printTB2Summary(tb2); } - if (!report.swe && !report.tb2) { + if (!report.swe && !report.tau && !report.tb2) { console.error(' No benchmark results produced. Check prerequisites and benchmark settings.'); process.exitCode = 1; return; diff --git a/tests/benchmark/run-tb2-official.ts b/tests/benchmark/run-tb2-official.ts index 3b02bbb..13ce45e 100644 --- a/tests/benchmark/run-tb2-official.ts +++ b/tests/benchmark/run-tb2-official.ts @@ -300,7 +300,79 @@ function pickResultFromRewardFile(resultJsonPath: string): boolean | undefined { } } -function scoreJob(jobPath: string): { passed: number; total: number; unknown: number } { +function asFiniteNumber(v: unknown): number | undefined { + return typeof v === 'number' && Number.isFinite(v) ? 
v : undefined; +} + +function getPathNumber(obj: unknown, keys: string[]): number | undefined { + let cur: unknown = obj; + for (const k of keys) { + if (!cur || typeof cur !== 'object' || Array.isArray(cur)) return undefined; + cur = (cur as Record)[k]; + } + return asFiniteNumber(cur); +} + +function findNumberByKeys(obj: unknown, candidates: string[]): number | undefined { + if (!obj || typeof obj !== 'object') return undefined; + const queue: unknown[] = [obj]; + while (queue.length > 0) { + const cur = queue.shift(); + if (!cur || typeof cur !== 'object') continue; + if (Array.isArray(cur)) { + for (const v of cur) queue.push(v); + continue; + } + for (const [k, v] of Object.entries(cur as Record)) { + if (candidates.includes(k)) { + const n = asFiniteNumber(v); + if (n !== undefined) return n; + } + if (v && typeof v === 'object') queue.push(v); + } + } + return undefined; +} + +interface TokenUsage { + input?: number; + output?: number; + cache?: number; + total?: number; +} + +function extractTokenUsage(obj: Record): TokenUsage { + const input = getPathNumber(obj, ['agent_result', 'n_input_tokens']) + ?? getPathNumber(obj, ['agent_result', 'usage', 'input_tokens']) + ?? findNumberByKeys(obj, ['n_input_tokens', 'input_tokens', 'prompt_tokens']); + const output = getPathNumber(obj, ['agent_result', 'n_output_tokens']) + ?? getPathNumber(obj, ['agent_result', 'usage', 'output_tokens']) + ?? findNumberByKeys(obj, ['n_output_tokens', 'output_tokens', 'completion_tokens']); + const cache = getPathNumber(obj, ['agent_result', 'n_cache_tokens']) + ?? findNumberByKeys(obj, ['n_cache_tokens', 'cache_tokens']); + const total = getPathNumber(obj, ['agent_result', 'n_total_tokens']) + ?? getPathNumber(obj, ['agent_result', 'usage', 'total_tokens']) + ?? 
findNumberByKeys(obj, ['n_total_tokens', 'total_tokens']); + + if (total !== undefined) return { input, output, cache, total }; + if (input !== undefined || output !== undefined || cache !== undefined) { + return { input, output, cache, total: (input ?? 0) + (output ?? 0) + (cache ?? 0) }; + } + return {}; +} + +interface ScoreJobResult { + passed: number; + total: number; + unknown: number; + avg_input_tokens?: number; + avg_output_tokens?: number; + avg_cache_tokens?: number; + avg_total_tokens?: number; + token_observed_trials: number; +} + +function scoreJob(jobPath: string): ScoreJobResult { const summaryPath = path.resolve(jobPath, 'result.json'); const allResultFiles = findFilesRecursive(jobPath, 'result.json'); if (allResultFiles.length === 0) { @@ -314,6 +386,14 @@ function scoreJob(jobPath: string): { passed: number; total: number; unknown: nu let passed = 0; let total = 0; let unknown = 0; + let inputSum = 0; + let outputSum = 0; + let cacheSum = 0; + let totalSum = 0; + let inputCount = 0; + let outputCount = 0; + let cacheCount = 0; + let totalCount = 0; for (const file of resultFiles) { try { @@ -331,11 +411,37 @@ function scoreJob(jobPath: string): { passed: number; total: number; unknown: nu } else { unknown += 1; } + + const usage = extractTokenUsage(data); + if (usage.input !== undefined) { + inputSum += usage.input; + inputCount += 1; + } + if (usage.output !== undefined) { + outputSum += usage.output; + outputCount += 1; + } + if (usage.cache !== undefined) { + cacheSum += usage.cache; + cacheCount += 1; + } + if (usage.total !== undefined) { + totalSum += usage.total; + totalCount += 1; + } } catch { unknown += 1; } } + const tokenStats = { + avg_input_tokens: inputCount > 0 ? Math.round(inputSum / inputCount) : undefined, + avg_output_tokens: outputCount > 0 ? Math.round(outputSum / outputCount) : undefined, + avg_cache_tokens: cacheCount > 0 ? Math.round(cacheSum / cacheCount) : undefined, + avg_total_tokens: totalCount > 0 ? 
Math.round(totalSum / totalCount) : undefined, + token_observed_trials: totalCount, + }; + if (total === 0) { if (!fs.existsSync(summaryPath)) { throw new Error(`No parseable pass/fail result found under job path: ${jobPath}`); @@ -357,6 +463,7 @@ function scoreJob(jobPath: string): { passed: number; total: number; unknown: nu passed: approxPassed, total: totalFromSummary, unknown: 0, + ...tokenStats, }; } } @@ -367,7 +474,7 @@ function scoreJob(jobPath: string): { passed: number; total: number; unknown: nu throw new Error(`No parseable pass/fail result found under job path: ${jobPath}`); } - return { passed, total, unknown }; + return { passed, total, unknown, ...tokenStats }; } function runOfficialTB2(args: CliArgs): string { @@ -445,6 +552,11 @@ export function runTB2Official(options: TB2RunOptions): TB2Summary { total: s.total, rate: s.total > 0 ? s.passed / s.total : 0, unknown: s.unknown, + avg_input_tokens: s.avg_input_tokens, + avg_output_tokens: s.avg_output_tokens, + avg_cache_tokens: s.avg_cache_tokens, + avg_total_tokens: s.avg_total_tokens, + token_observed_trials: s.token_observed_trials, }; } @@ -454,6 +566,11 @@ function writeSummary(summary: TB2Summary, outputFile?: string): void { console.log(`Passed: ${summary.passed}/${summary.total}`); console.log(`Rate: ${fmtPct(summary.rate)}`); console.log(`Unknown: ${summary.unknown}`); + if (summary.token_observed_trials && summary.token_observed_trials > 0 && summary.avg_total_tokens !== undefined) { + console.log(`Avg tok: ${summary.avg_total_tokens} (observed ${summary.token_observed_trials} trials)`); + } else { + console.log('Avg tok: N/A'); + } if (outputFile) { fs.mkdirSync(path.dirname(outputFile), { recursive: true }); diff --git a/tests/benchmark/tau/index.ts b/tests/benchmark/tau/index.ts new file mode 100644 index 0000000..3d98615 --- /dev/null +++ b/tests/benchmark/tau/index.ts @@ -0,0 +1,502 @@ +import fs from 'fs'; +import path from 'path'; +import { spawn, spawnSync } from 'child_process'; 
+import type { + BenchmarkConfig, + BenchmarkModuleResult, + BenchmarkProvider, + TAUProviderResult, + TAUTaskResult, +} from '../types'; + +export const name = 'tau'; + +const TAU2_SOURCE = 'git+https://github.com/sierra-research/tau2-bench@v0.2.0'; +const TAU2_REPO = 'https://github.com/sierra-research/tau2-bench'; +const TAU2_REF = 'v0.2.0'; +const DEFAULT_TAU2_DATA_DIR = path.resolve(process.cwd(), 'tests/tmp/tau2-data'); +const PASS_REWARD = 1; +const PASS_TOL = 1e-6; + +interface RunnerSpec { + cmd: string; + baseArgs: string[]; + label: string; +} + +interface Tau2Task { + id?: string; +} + +interface Tau2RewardInfo { + reward?: number; +} + +interface Tau2Simulation { + task_id?: string; + trial?: number; + reward_info?: Tau2RewardInfo; + [key: string]: unknown; +} + +interface Tau2RunOutput { + info?: { num_trials?: number }; + tasks?: Tau2Task[]; + simulations?: Tau2Simulation[]; +} + +function hasCommand(cmd: string, versionArg = '--version'): boolean { + const r = spawnSync(cmd, [versionArg], { stdio: 'ignore' }); + return r.status === 0; +} + +function getDomains(tauDomain: string): string[] { + if (tauDomain === 'all') return ['airline', 'retail', 'telecom']; + return [tauDomain]; +} + +function ensureDataDir(dataDir: string): void { + fs.mkdirSync(path.join(dataDir, 'simulations'), { recursive: true }); +} + +function requiredTaskFiles(dataDir: string, domains: string[]): string[] { + return domains.map(domain => path.join(dataDir, 'tau2', 'domains', domain, 'tasks.json')); +} + +function ensureOfficialDataFiles(dataDir: string, domains: string[]): void { + const missingBefore = requiredTaskFiles(dataDir, domains).filter(p => !fs.existsSync(p)); + if (missingBefore.length === 0) return; + + if (!hasCommand('git')) { + throw new Error( + `TAU2 data files missing and git is not available. 
Missing: ${missingBefore.join(', ')}`, + ); + } + + const sourceDir = path.join(dataDir, '.tau2-source'); + const sourceDataDir = path.join(sourceDir, 'data', 'tau2'); + + console.log(' TAU2 data missing, bootstrapping official data from repository...'); + if (!fs.existsSync(sourceDataDir)) { + if (fs.existsSync(sourceDir)) { + fs.rmSync(sourceDir, { recursive: true, force: true }); + } + const clone = spawnSync( + 'git', + ['clone', '--depth', '1', '--branch', TAU2_REF, TAU2_REPO, sourceDir], + { stdio: 'inherit' }, + ); + if (clone.status !== 0) { + throw new Error(`Failed to clone TAU2 data source (exit code ${clone.status ?? 'unknown'})`); + } + } + + if (!fs.existsSync(sourceDataDir)) { + throw new Error(`TAU2 data source missing expected directory: ${sourceDataDir}`); + } + + fs.mkdirSync(path.join(dataDir, 'tau2'), { recursive: true }); + fs.cpSync(sourceDataDir, path.join(dataDir, 'tau2'), { recursive: true, force: true }); + + const missingAfter = requiredTaskFiles(dataDir, domains).filter(p => !fs.existsSync(p)); + if (missingAfter.length > 0) { + throw new Error(`TAU2 data bootstrap incomplete. Missing: ${missingAfter.join(', ')}`); + } +} + +function shouldKeepTauLogLine(line: string): boolean { + const s = line.trim(); + if (!s) return false; + if (s.includes('Provider List: https://docs.litellm.ai/docs/providers')) return false; + if (s.includes('tau2.utils.llm_utils:get_response_cost')) return false; + if (s.includes("This model isn't mapped yet.")) return false; + return true; +} + +function createLineEmitter(isErr: boolean): (chunk: Buffer | string, flush?: boolean) => void { + let buffer = ''; + return (chunk: Buffer | string, flush = false) => { + if (chunk) { + buffer += chunk.toString().replace(/\r/g, '\n'); + } + const parts = buffer.split('\n'); + if (!flush) { + buffer = parts.pop() ?? 
''; + } else { + buffer = ''; + } + for (const line of parts) { + if (!shouldKeepTauLogLine(line)) continue; + if (isErr) console.error(line); + else console.log(line); + } + }; +} + +async function runTau2WithFilteredLogs( + runner: RunnerSpec, + args: string[], + env: NodeJS.ProcessEnv, +): Promise { + const child = spawn(runner.cmd, args, { + cwd: process.cwd(), + env, + stdio: ['ignore', 'pipe', 'pipe'], + }); + + const out = createLineEmitter(false); + const err = createLineEmitter(true); + + child.stdout?.on('data', (chunk: Buffer | string) => out(chunk, false)); + child.stderr?.on('data', (chunk: Buffer | string) => err(chunk, false)); + + return await new Promise((resolve, reject) => { + child.on('error', reject); + child.on('close', code => { + out('', true); + err('', true); + resolve(code ?? 1); + }); + }); +} + +function sanitizeLabel(v: string): string { + return v.trim().replace(/[^a-zA-Z0-9._-]+/g, '-').slice(0, 96); +} + +function toTau2Model(bp: BenchmarkProvider): string { + if (bp.model.includes('/')) return bp.model; + if (bp.id === 'anthropic') return `anthropic/${bp.model}`; + if (bp.id === 'gemini') return `gemini/${bp.model}`; + return `openai/${bp.model}`; +} + +function applyProviderEnv(env: NodeJS.ProcessEnv, bp: BenchmarkProvider): void { + switch (bp.id) { + case 'anthropic': + env.ANTHROPIC_API_KEY = bp.apiKey; + if (bp.baseUrl) env.ANTHROPIC_BASE_URL = bp.baseUrl; + break; + case 'gemini': + env.GEMINI_API_KEY = bp.apiKey; + if (bp.baseUrl) env.GEMINI_BASE_URL = bp.baseUrl; + break; + default: + env.OPENAI_API_KEY = bp.apiKey; + if (bp.baseUrl) { + env.OPENAI_BASE_URL = bp.baseUrl; + env.OPENAI_API_BASE = bp.baseUrl; + } + break; + } +} + +function buildRunEnv(config: BenchmarkConfig, bp: BenchmarkProvider, userSimBp: BenchmarkProvider, dataDir: string): NodeJS.ProcessEnv { + const env: NodeJS.ProcessEnv = { + ...process.env, + TAU2_DATA_DIR: dataDir, + UV_CACHE_DIR: process.env.UV_CACHE_DIR || '/tmp/uv-cache', + UV_TOOL_DIR: 
process.env.UV_TOOL_DIR || '/tmp/uv-tools', + XDG_DATA_HOME: process.env.XDG_DATA_HOME || '/tmp/xdg-data', + }; + if (config.dockerProxy) { + env.HTTP_PROXY = config.dockerProxy; + env.HTTPS_PROXY = config.dockerProxy; + env.http_proxy = config.dockerProxy; + env.https_proxy = config.dockerProxy; + } + applyProviderEnv(env, bp); + applyProviderEnv(env, userSimBp); + return env; +} + +function resolveRunner(): RunnerSpec { + if (hasCommand('tau2')) { + return { cmd: 'tau2', baseArgs: [], label: 'tau2' }; + } + if (hasCommand('uvx')) { + return { + cmd: 'uvx', + baseArgs: ['--python', '3.12', '--from', TAU2_SOURCE, 'tau2'], + label: `uvx tau2 (${TAU2_SOURCE})`, + }; + } + throw new Error('TAU official runner not found. Install `tau2` or `uvx`.'); +} + +function readJson(filePath: string): Tau2RunOutput { + return JSON.parse(fs.readFileSync(filePath, 'utf-8')) as Tau2RunOutput; +} + +function isPass(sim: Tau2Simulation): boolean { + const reward = sim.reward_info?.reward; + return typeof reward === 'number' && Math.abs(reward - PASS_REWARD) <= PASS_TOL; +} + +function combinations(n: number, k: number): number { + if (k < 0 || k > n) return 0; + if (k === 0 || k === n) return 1; + let kk = Math.min(k, n - k); + let out = 1; + for (let i = 1; i <= kk; i++) { + out = (out * (n - kk + i)) / i; + } + return out; +} + +function computePassHatK(taskOutcomes: boolean[][]): number[] { + const eligible = taskOutcomes.filter(arr => arr.length > 0); + if (eligible.length === 0) return []; + + const maxK = Math.min(...eligible.map(arr => arr.length)); + const passAtK: number[] = []; + + for (let k = 1; k <= maxK; k++) { + const vals: number[] = []; + for (const arr of eligible) { + const n = arr.length; + if (n < k) continue; + const c = arr.filter(Boolean).length; + const denom = combinations(n, k); + vals.push(denom === 0 ? 0 : combinations(c, k) / denom); + } + passAtK.push(vals.length > 0 ? 
vals.reduce((s, v) => s + v, 0) / vals.length : 0); + } + + return passAtK; +} + +function asFiniteNumber(v: unknown): number | undefined { + return typeof v === 'number' && Number.isFinite(v) ? v : undefined; +} + +function getPathNumber(obj: unknown, keys: string[]): number | undefined { + let cur: unknown = obj; + for (const k of keys) { + if (!cur || typeof cur !== 'object' || Array.isArray(cur)) return undefined; + cur = (cur as Record)[k]; + } + return asFiniteNumber(cur); +} + +function findNumberByKeys(obj: unknown, candidates: string[]): number | undefined { + if (!obj || typeof obj !== 'object') return undefined; + const queue: unknown[] = [obj]; + while (queue.length > 0) { + const cur = queue.shift(); + if (!cur || typeof cur !== 'object') continue; + if (Array.isArray(cur)) { + for (const v of cur) queue.push(v); + continue; + } + for (const [k, v] of Object.entries(cur as Record)) { + if (candidates.includes(k)) { + const n = asFiniteNumber(v); + if (n !== undefined) return n; + } + if (v && typeof v === 'object') queue.push(v); + } + } + return undefined; +} + +interface TokenUsage { + input?: number; + output?: number; + cache?: number; + total?: number; +} + +function extractTokenUsage(obj: unknown): TokenUsage { + const input = getPathNumber(obj, ['agent_result', 'n_input_tokens']) + ?? getPathNumber(obj, ['agent_result', 'usage', 'input_tokens']) + ?? findNumberByKeys(obj, ['n_input_tokens', 'input_tokens', 'prompt_tokens']); + const output = getPathNumber(obj, ['agent_result', 'n_output_tokens']) + ?? getPathNumber(obj, ['agent_result', 'usage', 'output_tokens']) + ?? findNumberByKeys(obj, ['n_output_tokens', 'output_tokens', 'completion_tokens']); + const cache = getPathNumber(obj, ['agent_result', 'n_cache_tokens']) + ?? findNumberByKeys(obj, ['n_cache_tokens', 'cache_tokens']); + const total = getPathNumber(obj, ['agent_result', 'n_total_tokens']) + ?? getPathNumber(obj, ['agent_result', 'usage', 'total_tokens']) + ?? 
findNumberByKeys(obj, ['n_total_tokens', 'total_tokens']); + + if (total !== undefined) return { input, output, cache, total }; + if (input !== undefined || output !== undefined || cache !== undefined) { + return { + input, + output, + cache, + total: (input ?? 0) + (output ?? 0) + (cache ?? 0), + }; + } + return {}; +} + +function parseTau2Output( + bp: BenchmarkProvider, + domain: string, + filePath: string, + expectedTrials: number, +): TAUProviderResult { + const parsed = readJson(filePath); + const taskIds = new Set(); + for (const t of parsed.tasks ?? []) { + if (typeof t.id === 'string' && t.id.length > 0) taskIds.add(t.id); + } + for (const sim of parsed.simulations ?? []) { + if (typeof sim.task_id === 'string' && sim.task_id.length > 0) taskIds.add(sim.task_id); + } + + const trialMatrix = new Map(); + const tokenMatrix = new Map>(); + for (const id of taskIds) trialMatrix.set(id, []); + for (const id of taskIds) tokenMatrix.set(id, []); + + for (const sim of parsed.simulations ?? []) { + const taskId = sim.task_id; + if (!taskId || !trialMatrix.has(taskId)) continue; + const arr = trialMatrix.get(taskId)!; + const tokenArr = tokenMatrix.get(taskId)!; + const usage = extractTokenUsage(sim); + const tokenVal = usage.total; + if (typeof sim.trial === 'number' && sim.trial >= 0) { + arr[sim.trial] = isPass(sim); + tokenArr[sim.trial] = tokenVal; + } else { + arr.push(isPass(sim)); + tokenArr.push(tokenVal); + } + } + + const results: TAUTaskResult[] = []; + const outcomes: boolean[][] = []; + let tokenSum = 0; + let tokenObservedTrials = 0; + for (const taskId of taskIds) { + const normalized = (trialMatrix.get(taskId) ?? []).filter((v): v is boolean => typeof v === 'boolean'); + const tokens = (tokenMatrix.get(taskId) ?? []).filter((v): v is number => typeof v === 'number' && Number.isFinite(v)); + const taskAvgTokens = tokens.length > 0 + ? 
Math.round(tokens.reduce((s, t) => s + t, 0) / tokens.length) + : 0; + tokenSum += tokens.reduce((s, t) => s + t, 0); + tokenObservedTrials += tokens.length; + outcomes.push(normalized); + results.push({ + task_id: taskId, + trial_pass_rates: normalized, + tokens_used: taskAvgTokens, + error: normalized.length === 0 ? 'No trial results in official TAU2 output' : undefined, + }); + } + + const passAtK = computePassHatK(outcomes); + const avgTokens = tokenObservedTrials > 0 ? Math.round(tokenSum / tokenObservedTrials) : 0; + return { + provider: bp, + summary: { + domain, + total_tasks: taskIds.size, + num_trials: parsed.info?.num_trials ?? expectedTrials, + pass_at_k: passAtK, + avg_tokens: avgTokens, + token_observed_trials: tokenObservedTrials, + }, + results, + }; +} + +async function runProviderOnDomainOfficial( + config: BenchmarkConfig, + runner: RunnerSpec, + dataDir: string, + domain: string, + bp: BenchmarkProvider, + userSimBp: BenchmarkProvider, +): Promise { + const agentLlm = toTau2Model(bp); + const userLlm = toTau2Model(userSimBp); + const saveName = sanitizeLabel( + `tau2-${domain}-${bp.id}-${bp.model}-${Date.now()}`, + ); + const outputPath = path.join(dataDir, 'simulations', `${saveName}.json`); + + const runArgs = [ + ...runner.baseArgs, + 'run', + '--domain', + domain, + '--agent-llm', + agentLlm, + '--user-llm', + userLlm, + '--num-trials', + String(config.numTrials), + '--save-to', + saveName, + ]; + + console.log(` [${bp.id}] ${domain}: tau2 run (${runner.label})`); + const runStatus = await runTau2WithFilteredLogs( + runner, + runArgs, + buildRunEnv(config, bp, userSimBp, dataDir), + ); + + if (runStatus !== 0) { + throw new Error(`tau2 run failed with exit code ${runStatus}`); + } + if (!fs.existsSync(outputPath)) { + throw new Error(`tau2 output not found: ${outputPath}`); + } + + return parseTau2Output(bp, domain, outputPath, config.numTrials); +} + +export async function run(config: BenchmarkConfig): Promise { + const domains = 
getDomains(config.tauDomain); + if (domains.length === 0) { + console.log(` TAU: no domains found for "${config.tauDomain}"`); + return {}; + } + if (config.providers.length === 0) { + console.log(' TAU: no providers configured, skipping'); + return {}; + } + + const runner = resolveRunner(); + const dataDir = DEFAULT_TAU2_DATA_DIR; + ensureDataDir(dataDir); + ensureOfficialDataFiles(dataDir, domains); + console.log(`\n TAU official source: tau2 (${TAU2_SOURCE})`); + console.log(` TAU data dir: ${dataDir}`); + + const allResults: TAUProviderResult[] = []; + for (const domain of domains) { + console.log(`\n TAU domain: ${domain} (${config.numTrials} trials)`); + for (const bp of config.providers) { + const userSimBp = config.userSimProvider ?? bp; + console.log(`\n Running provider: ${bp.id} / ${bp.model}`); + console.log(` User simulator: ${userSimBp.id} / ${userSimBp.model}`); + try { + const r = await runProviderOnDomainOfficial(config, runner, dataDir, domain, bp, userSimBp); + allResults.push(r); + } catch (err: any) { + console.log(` [${bp.id}] ${domain}: FAIL (${err?.message || String(err)})`); + allResults.push({ + provider: bp, + summary: { + domain, + total_tasks: 0, + num_trials: config.numTrials, + pass_at_k: [], + avg_tokens: 0, + token_observed_trials: 0, + }, + results: [], + }); + } + } + } + + return { tau: allResults }; +} diff --git a/tests/benchmark/types.ts b/tests/benchmark/types.ts index 6bc0130..8ae5975 100644 --- a/tests/benchmark/types.ts +++ b/tests/benchmark/types.ts @@ -9,8 +9,10 @@ export interface BenchmarkProvider { } export interface BenchmarkCliArgs { - benchmark?: 'swe' | 'tb2' | 'both'; + benchmark?: 'swe' | 'tau' | 'tb2' | 'both' | 'all'; provider?: string; + tauDomain?: 'airline' | 'retail' | 'telecom' | 'all' | string; + numTrials?: number; tb2Model?: string; tb2Agent?: string; tb2Dataset?: string; @@ -25,9 +27,12 @@ export interface BenchmarkCliArgs { } export interface BenchmarkConfig { - benchmark: 'swe' | 'tb2' | 'both'; + 
benchmark: 'swe' | 'tau' | 'tb2' | 'both' | 'all'; providers: BenchmarkProvider[]; + userSimProvider?: BenchmarkProvider; timeoutMs: number; + numTrials: number; + tauDomain: string; output: 'table' | 'json'; outputFile: string; tb2Model?: string; @@ -65,6 +70,28 @@ export interface SWEProviderResult { results: SWEResult[]; } +export interface TAUTaskResult { + task_id: string; + trial_pass_rates: boolean[]; + tokens_used: number; + error?: string; +} + +export interface TAUSummary { + domain: string; + total_tasks: number; + num_trials: number; + pass_at_k: number[]; + avg_tokens: number; + token_observed_trials?: number; +} + +export interface TAUProviderResult { + provider: BenchmarkProvider; + summary: TAUSummary; + results: TAUTaskResult[]; +} + export interface TB2Summary { generated_at: string; dataset: string; @@ -76,16 +103,23 @@ export interface TB2Summary { total: number; rate: number; unknown: number; + avg_input_tokens?: number; + avg_output_tokens?: number; + avg_cache_tokens?: number; + avg_total_tokens?: number; + token_observed_trials?: number; } export interface BenchmarkReport { timestamp: string; sdk_version: string; swe?: SWEProviderResult[]; + tau?: TAUProviderResult[]; tb2?: TB2Summary; } export interface BenchmarkModuleResult { swe?: SWEProviderResult[]; + tau?: TAUProviderResult[]; tb2?: TB2Summary; }