From 5502c51e59d9d768f40e5a87d19def48397765d7 Mon Sep 17 00:00:00 2001 From: James Braza Date: Fri, 13 Sep 2024 16:16:26 -0700 Subject: [PATCH 1/2] Fixing mutable `BaseModel` defaults and removing extra `BaseModel` (#400) --- paperqa/agents/main.py | 17 ++--------------- paperqa/agents/models.py | 4 ++-- paperqa/docs.py | 8 ++++---- paperqa/types.py | 2 +- 4 files changed, 9 insertions(+), 22 deletions(-) diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index 31fdcb44..fa1d4fd3 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -10,10 +10,10 @@ Tool, ToolCall, ToolRequestMessage, - ToolResponseMessage, ToolSelector, + ToolSelectorLedger, ) -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, TypeAdapter from tenacity import ( Retrying, before_sleep_log, @@ -268,19 +268,6 @@ async def step(tool: Tool, **call_kwargs) -> None: return env.state.answer, AgentStatus.SUCCESS -class ToolSelectorLedger(BaseModel): - """ - Simple ledger to record tools and messages. - - TODO: remove this after it's upstreamed into aviary. - """ - - tools: list[Tool] = Field(default_factory=list) - messages: list[ToolRequestMessage | ToolResponseMessage | Message] = Field( - default_factory=list - ) - - async def run_aviary_agent( query: QueryRequest, docs: Docs, diff --git a/paperqa/agents/models.py b/paperqa/agents/models.py index 12567393..e35c92fa 100644 --- a/paperqa/agents/models.py +++ b/paperqa/agents/models.py @@ -130,8 +130,8 @@ class SimpleProfiler(BaseModel): # [Profiling] {**name** of timer} | {**elapsed** time of function} | {**__version__** of PaperQA} """ - timers: dict[str, list[float]] = {} - running_timers: dict[str, TimerData] = {} + timers: dict[str, list[float]] = Field(default_factory=dict) + running_timers: dict[str, TimerData] = Field(default_factory=dict) uid: UUID = Field(default_factory=uuid4) @asynccontextmanager diff --git a/paperqa/docs.py b/paperqa/docs.py index ad2fc595..2fb5b466 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -68,15 +68,15 @@ class Docs(BaseModel): model_config = ConfigDict(extra="forbid") id: UUID = Field(default_factory=uuid4) - docs: dict[DocKey, Doc | DocDetails] = {} - texts: list[Text] = [] - docnames: set[str] = set() + docs: dict[DocKey, Doc | DocDetails] = Field(default_factory=dict) + texts: list[Text] = Field(default_factory=list) + docnames: set[str] = Field(default_factory=set) texts_index: VectorStore = Field(default_factory=NumpyVectorStore) name: str = Field(default="default", description="Name of this docs collection") index_path: Path | None = Field( default=PAPERQA_DIR, description="Path to save index", validate_default=True ) - deleted_dockeys: set[DocKey] = set() + deleted_dockeys: set[DocKey] = Field(default_factory=set) @field_validator("index_path") @classmethod diff --git a/paperqa/types.py b/paperqa/types.py index 2aee291e..036a8116 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -155,7 +155,7 @@ class Answer(BaseModel): question: str answer: str = "" context: str = "" - contexts: list[Context] = [] + contexts: list[Context] = Field(default_factory=list) references: str = "" formatted_answer: str = "" cost: float = 0.0 From be173caf3feafdef9bb928f3f81e311c30d1d21a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Fri, 13 Sep 2024 23:45:59 +0000 Subject: [PATCH 2/2] [pre-commit.ci lite] apply automatic fixes --- paperqa/readers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paperqa/readers.py b/paperqa/readers.py index faca366f..10498777 100644 --- a/paperqa/readers.py +++ b/paperqa/readers.py @@ -34,6 +34,7 @@ def parse_pdf_to_pages(path: Path) -> ParsedText: ) return ParsedText(content=pages, metadata=metadata) + def parse_docx_to_text(path: Path) -> ParsedText: doc = docx.Document(path) text = "\n".join([para.text for para in doc.paragraphs])