diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index 31fdcb44..fa1d4fd3 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -10,10 +10,10 @@ Tool, ToolCall, ToolRequestMessage, - ToolResponseMessage, ToolSelector, + ToolSelectorLedger, ) -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, TypeAdapter from tenacity import ( Retrying, before_sleep_log, @@ -268,19 +268,6 @@ async def step(tool: Tool, **call_kwargs) -> None: return env.state.answer, AgentStatus.SUCCESS -class ToolSelectorLedger(BaseModel): - """ - Simple ledger to record tools and messages. - - TODO: remove this after it's upstreamed into aviary. - """ - - tools: list[Tool] = Field(default_factory=list) - messages: list[ToolRequestMessage | ToolResponseMessage | Message] = Field( - default_factory=list - ) - - async def run_aviary_agent( query: QueryRequest, docs: Docs, diff --git a/paperqa/agents/models.py b/paperqa/agents/models.py index 12567393..e35c92fa 100644 --- a/paperqa/agents/models.py +++ b/paperqa/agents/models.py @@ -130,8 +130,8 @@ class SimpleProfiler(BaseModel): # [Profiling] {**name** of timer} | {**elapsed** time of function} | {**__version__** of PaperQA} """ - timers: dict[str, list[float]] = {} - running_timers: dict[str, TimerData] = {} + timers: dict[str, list[float]] = Field(default_factory=dict) + running_timers: dict[str, TimerData] = Field(default_factory=dict) uid: UUID = Field(default_factory=uuid4) @asynccontextmanager diff --git a/paperqa/docs.py b/paperqa/docs.py index ad2fc595..2fb5b466 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -68,15 +68,15 @@ class Docs(BaseModel): model_config = ConfigDict(extra="forbid") id: UUID = Field(default_factory=uuid4) - docs: dict[DocKey, Doc | DocDetails] = {} - texts: list[Text] = [] - docnames: set[str] = set() + docs: dict[DocKey, Doc | DocDetails] = Field(default_factory=dict) + texts: list[Text] = Field(default_factory=list) + docnames: set[str] = Field(default_factory=set) texts_index: VectorStore = Field(default_factory=NumpyVectorStore) name: str = Field(default="default", description="Name of this docs collection") index_path: Path | None = Field( default=PAPERQA_DIR, description="Path to save index", validate_default=True ) - deleted_dockeys: set[DocKey] = set() + deleted_dockeys: set[DocKey] = Field(default_factory=set) @field_validator("index_path") @classmethod diff --git a/paperqa/types.py b/paperqa/types.py index 2aee291e..036a8116 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -155,7 +155,7 @@ class Answer(BaseModel): question: str answer: str = "" context: str = "" - contexts: list[Context] = [] + contexts: list[Context] = Field(default_factory=list) references: str = "" formatted_answer: str = "" cost: float = 0.0