Skip to content

Commit

Permalink
Sync with remote
Browse files Browse the repository at this point in the history
  • Loading branch information
taabishm2 committed Sep 14, 2024
2 parents 5814c04 + be173ca commit bb47be5
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 22 deletions.
17 changes: 2 additions & 15 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
Tool,
ToolCall,
ToolRequestMessage,
ToolResponseMessage,
ToolSelector,
ToolSelectorLedger,
)
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import BaseModel, TypeAdapter
from tenacity import (
Retrying,
before_sleep_log,
Expand Down Expand Up @@ -268,19 +268,6 @@ async def step(tool: Tool, **call_kwargs) -> None:
return env.state.answer, AgentStatus.SUCCESS


class ToolSelectorLedger(BaseModel):
"""
Simple ledger to record tools and messages.
TODO: remove this after it's upstreamed into aviary.
"""

tools: list[Tool] = Field(default_factory=list)
messages: list[ToolRequestMessage | ToolResponseMessage | Message] = Field(
default_factory=list
)


async def run_aviary_agent(
query: QueryRequest,
docs: Docs,
Expand Down
4 changes: 2 additions & 2 deletions paperqa/agents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ class SimpleProfiler(BaseModel):
# [Profiling] {**name** of timer} | {**elapsed** time of function} | {**__version__** of PaperQA}
"""

timers: dict[str, list[float]] = {}
running_timers: dict[str, TimerData] = {}
timers: dict[str, list[float]] = Field(default_factory=dict)
running_timers: dict[str, TimerData] = Field(default_factory=dict)
uid: UUID = Field(default_factory=uuid4)

@asynccontextmanager
Expand Down
8 changes: 4 additions & 4 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ class Docs(BaseModel):
model_config = ConfigDict(extra="forbid")

id: UUID = Field(default_factory=uuid4)
docs: dict[DocKey, Doc | DocDetails] = {}
texts: list[Text] = []
docnames: set[str] = set()
docs: dict[DocKey, Doc | DocDetails] = Field(default_factory=dict)
texts: list[Text] = Field(default_factory=list)
docnames: set[str] = Field(default_factory=set)
texts_index: VectorStore = Field(default_factory=NumpyVectorStore)
name: str = Field(default="default", description="Name of this docs collection")
index_path: Path | None = Field(
default=PAPERQA_DIR, description="Path to save index", validate_default=True
)
deleted_dockeys: set[DocKey] = set()
deleted_dockeys: set[DocKey] = Field(default_factory=set)

@field_validator("index_path")
@classmethod
Expand Down
2 changes: 1 addition & 1 deletion paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class Answer(BaseModel):
question: str
answer: str = ""
context: str = ""
contexts: list[Context] = []
contexts: list[Context] = Field(default_factory=list)
references: str = ""
formatted_answer: str = ""
cost: float = 0.0
Expand Down

0 comments on commit bb47be5

Please sign in to comment.