Added account for cost info #360

Merged
merged 5 commits into from
Sep 11, 2024
Merged
Changes from all commits
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -75,7 +75,7 @@ repos:
         args: [--pretty, --ignore-missing-imports]
         additional_dependencies:
           - aiohttp
-          - fhaviary[llm]>=0.5  # Match pyproject.toml
+          - fhaviary[llm]>=0.6  # Match pyproject.toml
           - ldp>=0.4  # Match pyproject.toml
           - html2text
           - httpx
13 changes: 12 additions & 1 deletion paperqa/agents/env.py
@@ -9,7 +9,7 @@
 from paperqa.docs import Docs
 from paperqa.llms import EmbeddingModel, LiteLLMModel
 from paperqa.settings import Settings
-from paperqa.types import Answer
+from paperqa.types import Answer, LLMResult
 from paperqa.utils import get_year

 from .models import QueryRequest
@@ -151,6 +151,17 @@ def export_frame(self) -> Frame:
     async def step(
         self, action: ToolRequestMessage
     ) -> tuple[list[Message], float, bool, bool]:
+
+        # add usage for action if it has usage
+        info = action.info
+        if info and "usage" in info and "model" in info:
+            r = LLMResult(
+                model=info["model"],
+                prompt_count=info["usage"][0],
+                completion_count=info["usage"][1],
+            )
+            self.state.answer.add_tokens(r)
+
         # If the action has empty tool_calls, the agent can later take that into account
         msgs = cast(
             list[Message],
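For orientation, here is what add_tokens is doing with that LLMResult: a minimal sketch of per-model token bookkeeping, assuming token_counts maps a model name to a [prompt_tokens, completion_tokens] pair (the shape the updated tests assert against). This illustrates the accounting pattern, not paperqa's exact implementation.

from pydantic import BaseModel, Field


class LLMResult(BaseModel):
    model: str
    prompt_count: int = 0
    completion_count: int = 0


class Answer(BaseModel):
    token_counts: dict[str, list[int]] = Field(default_factory=dict)

    def add_tokens(self, result: LLMResult) -> None:
        # Accumulate [prompt, completion] counts keyed by model name
        if result.model not in self.token_counts:
            self.token_counts[result.model] = [
                result.prompt_count,
                result.completion_count,
            ]
        else:
            self.token_counts[result.model][0] += result.prompt_count
            self.token_counts[result.model][1] += result.completion_count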
5 changes: 3 additions & 2 deletions paperqa/agents/main.py
@@ -109,7 +109,8 @@ def to_aviary_tool_selector(
         )
     ):
         return ToolSelector(
-            model=query.settings.agent.agent_llm,
+            model_name=query.settings.agent.agent_llm,
+            acompletion=query.settings.get_agent_llm().router.acompletion,
             **(query.settings.agent.agent_config or {}),
         )
     return None
@@ -220,7 +221,7 @@ async def run_agent(
         f"Finished agent {agent_type!r} run with question {query.query!r} and status"
         f" {agent_status}."
     )
-    return AnswerResponse(answer=answer, usage=answer.token_counts, status=agent_status)
+    return AnswerResponse(answer=answer, status=agent_status)


 async def run_fake_agent(
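One way to read the ToolSelector change: instead of letting the selector construct its own LLM client from a model name, the PR injects the already-configured router's acompletion coroutine, so every agent completion flows through the same router that records usage. A minimal sketch of that dependency-injection pattern with toy names (this is not the fhaviary API, just the shape of the idea):

import asyncio
from typing import Any, Awaitable, Callable

AcompletionFn = Callable[..., Awaitable[Any]]


class ToySelector:
    """Selects tools by delegating completions to an injected coroutine."""

    def __init__(self, model_name: str, acompletion: AcompletionFn):
        self.model_name = model_name
        self.acompletion = acompletion  # shared client, so usage is tracked centrally

    async def select(self, messages: list[dict]) -> Any:
        # All traffic goes through the caller-provided completion function
        return await self.acompletion(model=self.model_name, messages=messages)


async def fake_acompletion(**kwargs: Any) -> dict:
    # Stand-in for router.acompletion; a real router would record token usage here
    return {"choices": [], "usage": {"prompt_tokens": 0, "completion_tokens": 0}}


selector = ToySelector("gpt-4o-2024-08-06", fake_acompletion)
asyncio.run(selector.select([{"role": "user", "content": "hi"}]))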
1 change: 0 additions & 1 deletion paperqa/agents/models.py
@@ -85,7 +85,6 @@ def set_docs_name(self, docs_name: str) -> None:

 class AnswerResponse(BaseModel):
     answer: Answer
-    usage: dict[str, list[int]]
     bibtex: dict[str, str] | None = None
     status: AgentStatus
     timing_info: dict[str, dict[str, float]] | None = None
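With the usage field dropped from AnswerResponse, per-model token counts are read from the Answer itself (the same dict[str, list[int]] of [prompt, completion] pairs that add_tokens maintains, per the tests below). A caller-side sketch, where response is a completed AnswerResponse and the model key is illustrative:

# Before this PR, usage came from the response wrapper:
#   prompt, completion = response.usage["gpt-4o-2024-08-06"]
# Now the Answer carries its own accounting:
prompt, completion = response.answer.token_counts["gpt-4o-2024-08-06"]
print(f"prompt={prompt}, completion={completion}, cost=${response.answer.cost:.4f}")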
15 changes: 14 additions & 1 deletion paperqa/settings.py
@@ -239,13 +239,19 @@ class AgentSettings(BaseModel):
         default="gpt-4o-2024-08-06",
         description="Model to use for agent",
     )
+
+    agent_llm_config: dict | None = Field(
+        default=None,
+        description="Optional kwargs for LLM constructor",
+    )
+
     agent_type: str = Field(
         default="fake",
         description="Type of agent to use",
     )
     agent_config: dict[str, Any] | None = Field(
         default=None,
-        description="Optional keyword argument configuration for the agent.",
+        description="Optional kwarg for AGENT constructor",
     )

     agent_system_prompt: str | None = Field(

Review thread on agent_llm_config:

Collaborator: A few lines below this there is an agent_config that is basically kwargs for the Agent/ToolSelector. Do you mind adjusting their names and/or descriptions so it's clear what the difference between agent_llm_config and agent_config is? They're similar enough in naming right now that I think we should fix this.

Collaborator (author): 👍 Tried to make the descriptions clearer about the distinction.

Collaborator: Thanks for your efforts! Being honest, it's still not quite intuitive enough for my tastes, but for now it's good. I think we can make this clearer by moving to_aviary_tool_selector and to_ldp_agent to be methods of AgentSettings, since all their information is derived from there.

@@ -500,6 +506,13 @@ def get_summary_llm(self) -> LiteLLMModel:
             or self._default_litellm_router_settings(self.summary_llm),
         )

+    def get_agent_llm(self) -> LiteLLMModel:
+        return LiteLLMModel(
+            name=self.agent.agent_llm,
+            config=self.agent.agent_llm_config
+            or self._default_litellm_router_settings(self.agent.agent_llm),
+        )
+
     def get_embedding_model(self) -> EmbeddingModel:
         return embedding_model_factory(self.embedding, **(self.embedding_config or {}))
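To make the two knobs concrete: agent_llm_config is passed to the LiteLLMModel (router) construction in get_agent_llm(), while agent_config is splatted into the agent/ToolSelector constructor in to_aviary_tool_selector above. A hedged usage sketch; the model_list shape follows litellm's Router config, and the timeout kwarg is hypothetical:

from paperqa.settings import Settings

settings = Settings()

# Kwargs for constructing the agent's LLM (a litellm Router under the hood)
settings.agent.agent_llm_config = {
    "model_list": [
        {
            "model_name": settings.agent.agent_llm,
            "litellm_params": {"model": settings.agent.agent_llm, "temperature": 0.1},
        }
    ]
}

# Kwargs for constructing the agent/ToolSelector itself, not its LLM
settings.agent.agent_config = {"timeout": 300.0}  # hypothetical selector kwarg

agent_model = settings.get_agent_llm()  # falls back to router defaults if config is None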
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
     "PyCryptodome",
     "aiohttp",  # TODO: remove in favor of httpx
     "anyio",
-    "fhaviary[llm]>=0.5",  # For ToolSelector
+    "fhaviary[llm]>=0.6",  # For info on Message
     "html2text",  # TODO: evaluate moving to an opt-in dependency
     "httpx",
     "litellm",
24 changes: 20 additions & 4 deletions tests/test_agents.py
@@ -75,16 +75,32 @@ async def test_agent_types(
 ) -> None:
     question = "How can you use XAI for chemical property prediction?"

+    # make sure agent_llm is different from default, so we can correctly track tokens
+    # for agent
+    agent_test_settings.agent.agent_llm = "gpt-4o-2024-08-06"
+    agent_test_settings.llm = "gpt-4o-mini"
+    agent_test_settings.summary_llm = "gpt-4o-mini"
     agent_test_settings.agent.agent_prompt += (
         "\n\n Call each tool once in appropriate order and "
         " accept the answer for now, as we're in debug mode."
     )
     request = QueryRequest(query=question, settings=agent_test_settings)
     response = await agent_query(request, agent_type=agent_type)
     assert response.answer.answer, "Answer not generated"
     assert response.answer.answer != "I cannot answer", "Answer not generated"
     assert response.answer.context, "No contexts were found"
     assert response.answer.question == question
     agent_llm = request.settings.agent.agent_llm
-    assert response.usage[agent_llm][0] > 5000, "Expected many prompt tokens"
-    assert response.usage[agent_llm][1] > 250, "Expected many completion tokens"
-    assert response.answer.cost > 0, "Expected nonzero cost"
+    # TODO: once LDP can track tokens, we can remove this check
+    if agent_type not in {"fake", SimpleAgent}:
+        print(response.answer.token_counts)
+        assert (
+            response.answer.token_counts[agent_llm][0] > 1000
+        ), "Expected many prompt tokens"
+        assert (
+            response.answer.token_counts[agent_llm][1] > 50
+        ), "Expected many completion tokens"
+    assert response.answer.cost > 0, "Expected nonzero cost"


@pytest.mark.asyncio
@@ -356,7 +372,7 @@ def test_answers_are_striped() -> None:
             )
         ],
     )
-    response = AnswerResponse(answer=answer, usage={}, bibtex={}, status="success")
+    response = AnswerResponse(answer=answer, bibtex={}, status="success")

     assert response.answer.contexts[0].text.embedding is None
     assert response.answer.contexts[0].text.text == ""  # type: ignore[unreachable,unused-ignore]