Skip to content

Commit fd1cf5e

Browse files
Tapan Chughclaude
andcommitted
Add app_oracle API prediction mode for AppWorld benchmarks
Implements a new intermediate API prediction mode that uses oracle data to identify required services, then exposes all APIs from those services. Changes: - Add app_oracle mode: Uses ground truth to identify apps (e.g., spotify, venmo), then loads all APIs from those apps. System apps (supervisor) only include ground truth APIs. - Refactor: Split appworld_helpers.py into api_predictor.py (API prediction) and prompts.py (prompt management) for better separation of concerns - Fix: Remove 20-API limit for "all" mode (now returns all 473 APIs) - Fix: Eliminate duplicate Task loading in predict_apis() API count comparison for typical task: - ground_truth: 6 APIs (exact oracle) - app_oracle: 95 APIs (3 supervisor + 92 spotify) - all: 473 APIs (no limit) Usage: pytest tests/benchmarks/appworld/test_appworld.py --api-mode app_oracle \ --dataset train --limit 5 --model gpt-4o 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent cb1fbf8 commit fd1cf5e

File tree

4 files changed

+99
-53
lines changed

4 files changed

+99
-53
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""API prediction for AppWorld tasks."""
2+
3+
from pathlib import Path
4+
5+
import appworld_experiments
6+
from appworld.task import Task
7+
from appworld_experiments.code.common.api_predictor import APIPredictor
8+
9+
EXPERIMENTS_PATH = Path(appworld_experiments.__file__).parent
10+
SYSTEM_APP_NAME = "supervisor"
11+
12+
13+
def _get_ground_truth_apis(task: Task) -> list[str]:
14+
"""Get exact API list from ground truth using APIPredictor."""
15+
prompt_path = EXPERIMENTS_PATH / "prompts" / "api_predictor.txt"
16+
predictor = APIPredictor(
17+
prompt_file_path=str(prompt_path),
18+
demo_task_ids=[],
19+
app_api_separator="__",
20+
mode="ground_truth",
21+
)
22+
return predictor.non_predicted_apis(task)
23+
24+
25+
def _predict_apis_using_model(task: Task, model_name: str) -> list[str]:
26+
raise NotImplementedError(
27+
"Predicted mode requires language model configuration. "
28+
"Use mode='ground_truth' (train/dev only) or mode='all' instead."
29+
)
30+
31+
32+
def predict_apis(
33+
task_id: str,
34+
mode: str = "predicted",
35+
model_name: str = "gpt-4o-mini",
36+
) -> list[str]:
37+
"""
38+
Predict which APIs are needed for a task.
39+
40+
Args:
41+
task_id: AppWorld task ID
42+
mode: predicted/ground_truth/app_oracle/all
43+
model_name: Model for prediction (only used if mode="predicted")
44+
45+
Returns:
46+
List of API names in format "app__method"
47+
- ground_truth: ~6-10 specific APIs from oracle
48+
- app_oracle: ~50-100 APIs from oracle-identified apps
49+
- all: All available APIs (no limit)
50+
"""
51+
needs_ground_truth = mode in ("ground_truth", "app_oracle")
52+
task = Task.load(
53+
task_id=task_id,
54+
storage_type="memory",
55+
load_ground_truth=needs_ground_truth,
56+
ground_truth_mode="full" if needs_ground_truth else "minimal",
57+
)
58+
59+
if mode == "ground_truth":
60+
return _get_ground_truth_apis(task)
61+
62+
elif mode == "predicted":
63+
return _predict_apis_using_model(task, model_name)
64+
65+
elif mode == "app_oracle":
66+
ground_truth_apis_list = _get_ground_truth_apis(task)
67+
required_apps = {api.split("__", 1)[0] for api in ground_truth_apis_list}
68+
69+
result_apis: list[str] = []
70+
for app_name, api_docs in task.api_docs.items():
71+
if app_name in required_apps:
72+
if app_name == SYSTEM_APP_NAME:
73+
system_apis = [api for api in ground_truth_apis_list if api.startswith(f"{app_name}__")]
74+
result_apis.extend(system_apis)
75+
else:
76+
result_apis.extend(f"{app_name}__{api_name}" for api_name in api_docs.keys())
77+
78+
return result_apis
79+
80+
elif mode == "all":
81+
return [
82+
f"{app_name}__{api_name}" for app_name, api_docs in task.api_docs.items() for api_name in api_docs.keys()
83+
]
84+
85+
raise ValueError(f"Invalid mode: {mode}")

tests/benchmarks/appworld/conftest.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@ def pytest_addoption(parser: pytest.Parser) -> None:
2020
parser.addoption(
2121
"--api-mode",
2222
default="ground_truth",
23-
choices=["predicted", "ground_truth", "all"],
24-
help="API prediction mode: predicted (LLM), ground_truth (oracle), all (default: ground_truth)",
23+
choices=["predicted", "ground_truth", "app_oracle", "all"],
24+
help=(
25+
"API prediction mode: predicted (LLM), ground_truth (API-level oracle), "
26+
"app_oracle (app-level oracle), all (default: ground_truth)"
27+
),
2528
)
2629

2730

@@ -45,7 +48,8 @@ def api_mode(request: pytest.FixtureRequest) -> str:
4548
4649
Returns:
4750
"predicted": Use LLM to predict APIs (costs 1 extra call per task)
48-
"ground_truth": Use oracle APIs from task data (train/dev only)
49-
"all": Use all available APIs (limited to 20)
51+
"ground_truth": Use oracle APIs from task data (API-level oracle, train/dev only)
52+
"app_oracle": Use oracle to identify apps, load all APIs from those apps (app-level oracle)
53+
"all": Use all available APIs (no limit)
5054
"""
5155
return str(request.config.getoption("--api-mode"))

tests/benchmarks/appworld/appworld_helpers.py renamed to tests/benchmarks/appworld/prompts.py

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,18 @@
1-
"""AppWorld integration helpers - instructions and API prediction."""
1+
"""AppWorld prompt and instruction management."""
22

33
import json
44
from pathlib import Path
5-
from typing import Any, Literal, cast
5+
from typing import Any
66

77
import appworld_experiments
88
from appworld.common.io import dump_yaml, read_file, read_json
99
from appworld.common.text import render_template
1010
from appworld.task import Task
11-
from appworld_experiments.code.common.api_predictor import APIPredictor
1211

1312
# Path to installed appworld_experiments package
1413
EXPERIMENTS_PATH = Path(appworld_experiments.__file__).parent
1514

1615

17-
def predict_apis(
18-
task_id: str,
19-
mode: str = "predicted",
20-
model_name: str = "gpt-4o-mini",
21-
) -> list[str]:
22-
"""
23-
Predict which APIs are needed for a task using AppWorld's APIPredictor.
24-
25-
Args:
26-
task_id: AppWorld task ID
27-
mode: predicted/ground_truth/all
28-
model_name: Model for prediction (only used if mode="predicted")
29-
30-
Returns:
31-
List of API names (typically 6-20 APIs instead of 400+)
32-
"""
33-
task = Task.load(
34-
task_id=task_id,
35-
storage_type="memory",
36-
load_ground_truth=(mode == "ground_truth"),
37-
ground_truth_mode="full" if mode == "ground_truth" else "minimal",
38-
)
39-
40-
prompt_path = EXPERIMENTS_PATH / "prompts/api_predictor.txt"
41-
42-
predictor = APIPredictor(
43-
prompt_file_path=str(prompt_path),
44-
demo_task_ids=[],
45-
max_predicted_apis=20,
46-
app_api_separator="__",
47-
mode=cast(Literal["ground_truth", "predicted", "all"], mode),
48-
)
49-
50-
if mode == "predicted":
51-
raise NotImplementedError(
52-
"Predicted mode requires language model configuration. "
53-
"Use mode='ground_truth' (train/dev only) or mode='all' instead."
54-
)
55-
56-
return predictor.non_predicted_apis(task)
57-
58-
5916
def load_system_instruction(task: Task, max_steps: int = 40) -> str:
6017
"""
6118
Load and render system instruction from AppWorld's template with demo examples.

tests/benchmarks/appworld/test_appworld.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from appworld.task import Task
1313
from fast_agent import FastAgent
1414

15-
from tests.benchmarks.appworld import appworld_helpers
15+
from tests.benchmarks.appworld import api_predictor, prompts
1616
from tests.utils.fastagent_helpers import MessageSerializer
1717
from tests.utils.logger import StructuredEventLogger
1818

@@ -126,7 +126,7 @@ async def _run_appworld_test(
126126
# Create and run FastAgent
127127
config_path = Path(__file__).parent / "fastagent.config.yaml"
128128
agent = FastAgent("AppWorld Test", config_path=str(config_path), ignore_unknown_args=True)
129-
system_instruction = appworld_helpers.load_system_instruction(task)
129+
system_instruction = prompts.load_system_instruction(task)
130130

131131
@agent.agent(
132132
name="test_agent",
@@ -162,11 +162,11 @@ def _setup_mcp_environment(
162162
"""Configure environment variables for MCP server."""
163163
# Predict which APIs are needed
164164
try:
165-
predicted_apis = appworld_helpers.predict_apis(task_id, mode=api_mode, model_name=model)
165+
predicted_apis = api_predictor.predict_apis(task_id, mode=api_mode, model_name=model)
166166
print(f"API mode: {api_mode}, predicted {len(predicted_apis)} APIs")
167167
except NotImplementedError:
168168
print(f"Warning: {api_mode} mode not supported, falling back to ground_truth")
169-
predicted_apis = appworld_helpers.predict_apis(task_id, mode="ground_truth", model_name=model)
169+
predicted_apis = api_predictor.predict_apis(task_id, mode="ground_truth", model_name=model)
170170

171171
# Set environment variables
172172
os.environ.update(

0 commit comments

Comments
 (0)