From 846f136088b00cfee276f4d439a9fbda14c6cbc0 Mon Sep 17 00:00:00 2001
From: James Braza
Date: Sun, 15 Sep 2024 21:37:26 -0700
Subject: [PATCH] Promoting agent factories to `Settings` (#407)

---
 paperqa/agents/main.py | 111 +++------------------------------
 paperqa/settings.py    | 114 ++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 119 insertions(+), 106 deletions(-)

diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py
index 7443b15c..d1bd0483 100644
--- a/paperqa/agents/main.py
+++ b/paperqa/agents/main.py
@@ -2,8 +2,7 @@
 import logging
 import os
 from collections.abc import Awaitable, Callable
-from pydoc import locate
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any
 
 from aviary.message import MalformedMessageError, Message
 from aviary.tools import (
@@ -13,7 +12,8 @@
     ToolSelector,
     ToolSelectorLedger,
 )
-from pydantic import BaseModel, TypeAdapter
+from pydantic import BaseModel
+from rich.console import Console
 from tenacity import (
     Retrying,
     before_sleep_log,
@@ -21,28 +21,7 @@
     stop_after_attempt,
 )
 
-try:
-    from ldp.agent import (
-        Agent,
-        HTTPAgentClient,
-        MemoryAgent,
-        ReActAgent,
-        SimpleAgent,
-        SimpleAgentState,
-    )
-    from ldp.graph.memory import Memory, UIndexMemoryModel
-    from ldp.graph.op_utils import set_training_mode
-    from ldp.llms import EmbeddingModel
-
-    _Memories = TypeAdapter(dict[int, Memory] | list[Memory])  # type: ignore[var-annotated]
-
-    HAS_LDP_INSTALLED = True
-except ImportError:
-    HAS_LDP_INSTALLED = False
-from rich.console import Console
-
 from paperqa.docs import Docs
-from paperqa.settings import Settings
 from paperqa.types import Answer
 from paperqa.utils import pqa_directory
 
@@ -53,6 +32,7 @@
 from .tools import EnvironmentState, GatherEvidence, GenerateAnswer, PaperSearch
 
 if TYPE_CHECKING:
+    from ldp.agent import Agent, SimpleAgentState
     from ldp.graph.ops import OpResult
 
 logger = logging.getLogger(__name__)
@@ -96,85 +76,6 @@ async def agent_query(
     return response
 
 
-def to_aviary_tool_selector(
-    agent_type: str | type, settings: Settings
-) -> ToolSelector | None:
-    """Attempt to convert the agent type to an aviary ToolSelector."""
-    if agent_type is ToolSelector or (
-        isinstance(agent_type, str)
-        and (
-            agent_type == ToolSelector.__name__
-            or (
-                agent_type.startswith(ToolSelector.__module__.split(".", maxsplit=1)[0])
-                and locate(agent_type) is ToolSelector
-            )
-        )
-    ):
-        return ToolSelector(
-            model_name=settings.agent.agent_llm,
-            acompletion=settings.get_agent_llm().router.acompletion,
-            **(settings.agent.agent_config or {}),
-        )
-    return None
-
-
-async def to_ldp_agent(
-    agent_type: str | type, settings: Settings
-) -> "Agent[SimpleAgentState] | None":
-    """Attempt to convert the agent type to an ldp Agent."""
-    if not isinstance(agent_type, str):  # Convert to fully qualified name
-        agent_type = f"{agent_type.__module__}.{agent_type.__name__}"
-    if not agent_type.startswith("ldp"):
-        return None
-    if not HAS_LDP_INSTALLED:
-        raise ImportError(
-            "ldp agents requires the 'ldp' extra for 'ldp'. Please:"
-            " `pip install paper-qa[ldp]`."
-        )
-
-    # TODO: support general agents
-    agent_cls = cast(type[Agent], locate(agent_type))
-    agent_settings = settings.agent
-    agent_llm, config = agent_settings.agent_llm, agent_settings.agent_config or {}
-    if issubclass(agent_cls, ReActAgent | MemoryAgent):
-        if (
-            issubclass(agent_cls, MemoryAgent)
-            and "memory_model" in config
-            and "memories" in config
-        ):
-            if "embedding_model" in config["memory_model"]:
-                # Work around EmbeddingModel not yet supporting deserialization
-                config["memory_model"]["embedding_model"] = EmbeddingModel.from_name(
-                    embedding=config["memory_model"].pop("embedding_model")["name"]
-                )
-            config["memory_model"] = UIndexMemoryModel(**config["memory_model"])
-            memories = _Memories.validate_python(config.pop("memories"))
-            await asyncio.gather(
-                *(
-                    config["memory_model"].add_memory(memory)
-                    for memory in (
-                        memories.values() if isinstance(memories, dict) else memories
-                    )
-                )
-            )
-        return agent_cls(
-            llm_model={"model": agent_llm, "temperature": settings.temperature},
-            **config,
-        )
-    if issubclass(agent_cls, SimpleAgent):
-        return agent_cls(
-            llm_model={"model": agent_llm, "temperature": settings.temperature},
-            sys_prompt=agent_settings.agent_system_prompt,
-            **config,
-        )
-    if issubclass(agent_cls, HTTPAgentClient):
-        set_training_mode(False)
-        return HTTPAgentClient[SimpleAgentState](
-            agent_state_type=SimpleAgentState, **config
-        )
-    raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.")
-
-
 async def run_agent(
     docs: Docs,
     query: QueryRequest,
@@ -205,11 +106,11 @@ async def run_agent(
 
     if agent_type == "fake":
         answer, agent_status = await run_fake_agent(query, docs, **runner_kwargs)
-    elif tool_selector_or_none := to_aviary_tool_selector(agent_type, query.settings):
+    elif tool_selector_or_none := query.settings.make_aviary_tool_selector(agent_type):
         answer, agent_status = await run_aviary_agent(
             query, docs, tool_selector_or_none, **runner_kwargs
         )
-    elif ldp_agent_or_none := await to_ldp_agent(agent_type, query.settings):
+    elif ldp_agent_or_none := await query.settings.make_ldp_agent(agent_type):
         answer, agent_status = await run_ldp_agent(
             query, docs, ldp_agent_or_none, **runner_kwargs
         )
diff --git a/paperqa/settings.py b/paperqa/settings.py
index c058c7ec..0720abcc 100644
--- a/paperqa/settings.py
+++ b/paperqa/settings.py
@@ -1,12 +1,41 @@
+import asyncio
 import importlib.resources
 import os
 from enum import StrEnum
 from pathlib import Path
+from pydoc import locate
 from typing import Any, ClassVar, assert_never, cast
 
-from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator
+from aviary.tools import ToolSelector
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    TypeAdapter,
+    computed_field,
+    field_validator,
+)
 from pydantic_settings import BaseSettings, CliSettingsSource, SettingsConfigDict
 
+try:
+    from ldp.agent import (
+        Agent,
+        HTTPAgentClient,
+        MemoryAgent,
+        ReActAgent,
+        SimpleAgent,
+        SimpleAgentState,
+    )
+    from ldp.graph.memory import Memory, UIndexMemoryModel
+    from ldp.graph.op_utils import set_training_mode
+    from ldp.llms import EmbeddingModel as LDPEmbeddingModel
+
+    _Memories = TypeAdapter(dict[int, Memory] | list[Memory])  # type: ignore[var-annotated]
+
+    HAS_LDP_INSTALLED = True
+except ImportError:
+    HAS_LDP_INSTALLED = False
+
 from paperqa.llms import EmbeddingModel, LiteLLMModel, embedding_model_factory
 from paperqa.prompts import (
     citation_prompt,
@@ -520,6 +549,89 @@ def get_agent_llm(self) -> LiteLLMModel:
 
     def get_embedding_model(self) -> EmbeddingModel:
         return embedding_model_factory(self.embedding, **(self.embedding_config or {}))
 
+    def make_aviary_tool_selector(self, agent_type: str | type) -> ToolSelector | None:
+        """Attempt to convert the input agent type to an aviary ToolSelector."""
+        if agent_type is ToolSelector or (
+            isinstance(agent_type, str)
+            and (
+                agent_type == ToolSelector.__name__
+                or (
+                    agent_type.startswith(
+                        ToolSelector.__module__.split(".", maxsplit=1)[0]
+                    )
+                    and locate(agent_type) is ToolSelector
+                )
+            )
+        ):
+            return ToolSelector(
+                model_name=self.agent.agent_llm,
+                acompletion=self.get_agent_llm().router.acompletion,
+                **(self.agent.agent_config or {}),
+            )
+        return None
+
+    async def make_ldp_agent(
+        self, agent_type: str | type
+    ) -> "Agent[SimpleAgentState] | None":
+        """Attempt to convert the input agent type to an ldp Agent."""
+        if not isinstance(agent_type, str):  # Convert to fully qualified name
+            agent_type = f"{agent_type.__module__}.{agent_type.__name__}"
+        if not agent_type.startswith("ldp"):
+            return None
+        if not HAS_LDP_INSTALLED:
+            raise ImportError(
+                "ldp agents requires the 'ldp' extra for 'ldp'. Please:"
+                " `pip install paper-qa[ldp]`."
+            )
+
+        # TODO: support general agents
+        agent_cls = cast(type[Agent], locate(agent_type))
+        agent_settings = self.agent
+        agent_llm, config = agent_settings.agent_llm, agent_settings.agent_config or {}
+        if issubclass(agent_cls, ReActAgent | MemoryAgent):
+            if (
+                issubclass(agent_cls, MemoryAgent)
+                and "memory_model" in config
+                and "memories" in config
+            ):
+                if "embedding_model" in config["memory_model"]:
+                    # Work around LDPEmbeddingModel not yet supporting deserialization
+                    config["memory_model"]["embedding_model"] = (
+                        LDPEmbeddingModel.from_name(
+                            embedding=config["memory_model"].pop("embedding_model")[
+                                "name"
+                            ]
+                        )
+                    )
+                config["memory_model"] = UIndexMemoryModel(**config["memory_model"])
+                memories = _Memories.validate_python(config.pop("memories"))
+                await asyncio.gather(
+                    *(
+                        config["memory_model"].add_memory(memory)
+                        for memory in (
+                            memories.values()
+                            if isinstance(memories, dict)
+                            else memories
+                        )
+                    )
+                )
+            return agent_cls(
+                llm_model={"model": agent_llm, "temperature": self.temperature},
+                **config,
+            )
+        if issubclass(agent_cls, SimpleAgent):
+            return agent_cls(
+                llm_model={"model": agent_llm, "temperature": self.temperature},
+                sys_prompt=agent_settings.agent_system_prompt,
+                **config,
+            )
+        if issubclass(agent_cls, HTTPAgentClient):
+            set_training_mode(False)
+            return HTTPAgentClient[SimpleAgentState](
+                agent_state_type=SimpleAgentState, **config
+            )
+        raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.")
+
 
 MaybeSettings = Settings | str | None
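
Usage note (not part of the patch): a minimal sketch of how the promoted factory methods are
expected to be called after this change. It assumes a default `Settings()` is constructible in
your environment, and that the optional `ldp` extra is installed for the `make_ldp_agent` path
(otherwise that call raises ImportError); the agent types passed below are illustrative, and
both factories return None for agent types they do not recognize.

    import asyncio

    from aviary.tools import ToolSelector
    from paperqa.settings import Settings


    async def main() -> None:
        settings = Settings()  # default agent/LLM configuration

        # Aviary path: returns a configured ToolSelector when given the ToolSelector
        # class (or its fully qualified name), else None.
        selector = settings.make_aviary_tool_selector(ToolSelector)

        # ldp path: only resolves fully qualified "ldp.*" agent types, else None.
        # Requires `pip install paper-qa[ldp]`.
        ldp_agent = await settings.make_ldp_agent("ldp.agent.SimpleAgent")

        print(type(selector), type(ldp_agent))


    if __name__ == "__main__":
        asyncio.run(main())

This mirrors the dispatch in `run_agent` above, which now calls
`query.settings.make_aviary_tool_selector(agent_type)` and
`await query.settings.make_ldp_agent(agent_type)` instead of the removed module-level helpers.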