Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions tests/benchmarks/appworld/api_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""API prediction for AppWorld tasks."""

from pathlib import Path

import appworld_experiments
from appworld.task import Task
from appworld_experiments.code.common.api_predictor import APIPredictor

# Directory of the installed appworld_experiments package; prompt files are resolved relative to it.
EXPERIMENTS_PATH = Path(appworld_experiments.__file__).parent
# App-name prefix used to separate this app's APIs from the other apps' APIs in app_oracle mode.
SYSTEM_APP_NAME = "supervisor"


def _get_ground_truth_apis(task: Task) -> list[str]:
    """Return the exact API list for ``task`` from its ground truth.

    Builds an ``APIPredictor`` in ``ground_truth`` mode, pointed at the
    ``prompts/api_predictor.txt`` file under ``EXPERIMENTS_PATH``, and
    returns whatever its ``non_predicted_apis`` reports for the task.

    Args:
        task: Task to look up; the caller in this module loads it with
            ground truth enabled before calling this helper.

    Returns:
        API names joined with the ``"__"`` app/API separator.
    """
    prompt_file = str(EXPERIMENTS_PATH / "prompts" / "api_predictor.txt")
    oracle = APIPredictor(
        prompt_file_path=prompt_file,
        demo_task_ids=[],
        app_api_separator="__",
        mode="ground_truth",
    )
    ground_truth_apis = oracle.non_predicted_apis(task)
    return ground_truth_apis


def _predict_apis_using_model(task: Task, model_name: str) -> list[str]:
    """Placeholder for model-based API prediction; currently always raises.

    Args:
        task: Task to predict APIs for (unused by this stub).
        model_name: Model to use for prediction (unused by this stub).

    Raises:
        NotImplementedError: Always, because predicted mode is not configured.
    """
    reason = (
        "Predicted mode requires language model configuration. "
        "Use mode='ground_truth' (train/dev only) or mode='all' instead."
    )
    raise NotImplementedError(reason)


def predict_apis(
    task_id: str,
    mode: str = "predicted",
    model_name: str = "gpt-4o-mini",
) -> list[str]:
    """
    Predict which APIs are needed for a task.

    Args:
        task_id: AppWorld task ID
        mode: predicted/ground_truth/app_oracle/all
        model_name: Model for prediction (only used if mode="predicted")

    Returns:
        List of API names in format "app__method"
        - ground_truth: ~6-10 specific APIs from oracle
        - app_oracle: ~50-100 APIs from oracle-identified apps
        - all: All available APIs (no limit)

    Raises:
        ValueError: If ``mode`` is not one of the supported modes. This is
            checked before the task is loaded.
        NotImplementedError: If ``mode`` is "predicted"; the model-based
            predictor in this module is a stub that always raises.
    """
    # Fail fast: reject an unknown mode before paying for Task.load, so a
    # typo in the mode is never masked by a task-loading error.
    valid_modes = ("predicted", "ground_truth", "app_oracle", "all")
    if mode not in valid_modes:
        raise ValueError(f"Invalid mode: {mode}")

    # Only the oracle-based modes ask Task.load for the full ground truth.
    needs_ground_truth = mode in ("ground_truth", "app_oracle")
    task = Task.load(
        task_id=task_id,
        storage_type="memory",
        load_ground_truth=needs_ground_truth,
        ground_truth_mode="full" if needs_ground_truth else "minimal",
    )

    if mode == "ground_truth":
        return _get_ground_truth_apis(task)

    if mode == "predicted":
        return _predict_apis_using_model(task, model_name)

    if mode == "app_oracle":
        ground_truth_apis_list = _get_ground_truth_apis(task)
        # The app name is everything before the first "__" separator.
        required_apps = {api.split("__", 1)[0] for api in ground_truth_apis_list}

        # System-app APIs stay at API-level granularity (only the ones the
        # oracle lists); every other required app contributes all its APIs.
        system_prefix = f"{SYSTEM_APP_NAME}__"
        system_apis = [api for api in ground_truth_apis_list if api.startswith(system_prefix)]
        domain_apis = [
            f"{app_name}__{api_name}"
            for app_name, api_docs in task.api_docs.items()
            if app_name in required_apps and app_name != SYSTEM_APP_NAME
            for api_name in api_docs.keys()
        ]

        return system_apis + domain_apis

    # mode == "all": every API of every app listed in the task's api_docs.
    return [
        f"{app_name}__{api_name}"
        for app_name, api_docs in task.api_docs.items()
        for api_name in api_docs.keys()
    ]
12 changes: 8 additions & 4 deletions tests/benchmarks/appworld/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption(
"--api-mode",
default="ground_truth",
choices=["predicted", "ground_truth", "all"],
help="API prediction mode: predicted (LLM), ground_truth (oracle), all (default: ground_truth)",
choices=["predicted", "ground_truth", "app_oracle", "all"],
help=(
"API prediction mode: predicted (LLM), ground_truth (API-level oracle), "
"app_oracle (app-level oracle), all (default: ground_truth)"
),
)


Expand All @@ -45,7 +48,8 @@ def api_mode(request: pytest.FixtureRequest) -> str:

Returns:
"predicted": Use LLM to predict APIs (costs 1 extra call per task)
"ground_truth": Use oracle APIs from task data (train/dev only)
"all": Use all available APIs (limited to 20)
"ground_truth": Use oracle APIs from task data (API-level oracle, train/dev only)
"app_oracle": Use oracle to identify apps, load all APIs from those apps (app-level oracle)
"all": Use all available APIs (no limit)
"""
return str(request.config.getoption("--api-mode"))
Original file line number Diff line number Diff line change
@@ -1,61 +1,18 @@
"""AppWorld integration helpers - instructions and API prediction."""
"""AppWorld prompt and instruction management."""

import json
from pathlib import Path
from typing import Any, Literal, cast
from typing import Any

import appworld_experiments
from appworld.common.io import dump_yaml, read_file, read_json
from appworld.common.text import render_template
from appworld.task import Task
from appworld_experiments.code.common.api_predictor import APIPredictor

# Path to installed appworld_experiments package
EXPERIMENTS_PATH = Path(appworld_experiments.__file__).parent


def predict_apis(
    task_id: str,
    mode: str = "predicted",
    model_name: str = "gpt-4o-mini",
) -> list[str]:
    """
    Predict which APIs are needed for a task using AppWorld's APIPredictor.

    Args:
        task_id: AppWorld task ID
        mode: predicted/ground_truth/all
        model_name: Model for prediction (only used if mode="predicted")

    Returns:
        List of API names (typically 6-20 APIs instead of 400+)

    Raises:
        NotImplementedError: If ``mode`` is "predicted". This is raised
            before the task is loaded or the predictor is constructed.
    """
    # Fail fast: predicted mode is unsupported here, so do not load the task
    # or build a predictor just to throw them away.
    if mode == "predicted":
        raise NotImplementedError(
            "Predicted mode requires language model configuration. "
            "Use mode='ground_truth' (train/dev only) or mode='all' instead."
        )

    # Ground truth is only requested from Task.load in ground_truth mode.
    wants_ground_truth = mode == "ground_truth"
    task = Task.load(
        task_id=task_id,
        storage_type="memory",
        load_ground_truth=wants_ground_truth,
        ground_truth_mode="full" if wants_ground_truth else "minimal",
    )

    prompt_path = EXPERIMENTS_PATH / "prompts/api_predictor.txt"

    predictor = APIPredictor(
        prompt_file_path=str(prompt_path),
        demo_task_ids=[],
        max_predicted_apis=20,
        app_api_separator="__",
        # cast only satisfies the type checker; the value is passed unchanged.
        mode=cast(Literal["ground_truth", "predicted", "all"], mode),
    )

    return predictor.non_predicted_apis(task)


def load_system_instruction(task: Task, max_steps: int = 40) -> str:
"""
Load and render system instruction from AppWorld's template with demo examples.
Expand Down
8 changes: 4 additions & 4 deletions tests/benchmarks/appworld/test_appworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from appworld.task import Task
from fast_agent import FastAgent

from tests.benchmarks.appworld import appworld_helpers
from tests.benchmarks.appworld import api_predictor, prompts
from tests.utils.fastagent_helpers import MessageSerializer
from tests.utils.logger import StructuredEventLogger

Expand Down Expand Up @@ -126,7 +126,7 @@ async def _run_appworld_test(
# Create and run FastAgent
config_path = Path(__file__).parent / "fastagent.config.yaml"
agent = FastAgent("AppWorld Test", config_path=str(config_path), ignore_unknown_args=True)
system_instruction = appworld_helpers.load_system_instruction(task)
system_instruction = prompts.load_system_instruction(task)

@agent.agent(
name="test_agent",
Expand Down Expand Up @@ -162,11 +162,11 @@ def _setup_mcp_environment(
"""Configure environment variables for MCP server."""
# Predict which APIs are needed
try:
predicted_apis = appworld_helpers.predict_apis(task_id, mode=api_mode, model_name=model)
predicted_apis = api_predictor.predict_apis(task_id, mode=api_mode, model_name=model)
print(f"API mode: {api_mode}, predicted {len(predicted_apis)} APIs")
except NotImplementedError:
print(f"Warning: {api_mode} mode not supported, falling back to ground_truth")
predicted_apis = appworld_helpers.predict_apis(task_id, mode="ground_truth", model_name=model)
predicted_apis = api_predictor.predict_apis(task_id, mode="ground_truth", model_name=model)

# Set environment variables
os.environ.update(
Expand Down