Skip to content

Commit

Permalink
Moved agent factories to Settings
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Sep 14, 2024
1 parent b1745b0 commit 1cc953a
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 106 deletions.
111 changes: 6 additions & 105 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import logging
import os
from collections.abc import Awaitable, Callable
from pydoc import locate
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

from aviary.message import MalformedMessageError, Message
from aviary.tools import (
Expand All @@ -13,36 +12,16 @@
ToolSelector,
ToolSelectorLedger,
)
from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel
from rich.console import Console
from tenacity import (
Retrying,
before_sleep_log,
retry_if_exception_type,
stop_after_attempt,
)

try:
from ldp.agent import (
Agent,
HTTPAgentClient,
MemoryAgent,
ReActAgent,
SimpleAgent,
SimpleAgentState,
)
from ldp.graph.memory import Memory, UIndexMemoryModel
from ldp.graph.op_utils import set_training_mode
from ldp.llms import EmbeddingModel

_Memories = TypeAdapter(dict[int, Memory] | list[Memory]) # type: ignore[var-annotated]

HAS_LDP_INSTALLED = True
except ImportError:
HAS_LDP_INSTALLED = False
from rich.console import Console

from paperqa.docs import Docs
from paperqa.settings import Settings
from paperqa.types import Answer
from paperqa.utils import pqa_directory

Expand All @@ -53,6 +32,7 @@
from .tools import EnvironmentState, GatherEvidence, GenerateAnswer, PaperSearch

if TYPE_CHECKING:
from ldp.agent import Agent, SimpleAgentState
from ldp.graph.ops import OpResult

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -96,85 +76,6 @@ async def agent_query(
return response


def to_aviary_tool_selector(
agent_type: str | type, settings: Settings
) -> ToolSelector | None:
"""Attempt to convert the agent type to an aviary ToolSelector."""
if agent_type is ToolSelector or (
isinstance(agent_type, str)
and (
agent_type == ToolSelector.__name__
or (
agent_type.startswith(ToolSelector.__module__.split(".", maxsplit=1)[0])
and locate(agent_type) is ToolSelector
)
)
):
return ToolSelector(
model_name=settings.agent.agent_llm,
acompletion=settings.get_agent_llm().router.acompletion,
**(settings.agent.agent_config or {}),
)
return None


async def to_ldp_agent(
agent_type: str | type, settings: Settings
) -> "Agent[SimpleAgentState] | None":
"""Attempt to convert the agent type to an ldp Agent."""
if not isinstance(agent_type, str): # Convert to fully qualified name
agent_type = f"{agent_type.__module__}.{agent_type.__name__}"
if not agent_type.startswith("ldp"):
return None
if not HAS_LDP_INSTALLED:
raise ImportError(
"ldp agents requires the 'ldp' extra for 'ldp'. Please:"
" `pip install paper-qa[ldp]`."
)

# TODO: support general agents
agent_cls = cast(type[Agent], locate(agent_type))
agent_settings = settings.agent
agent_llm, config = agent_settings.agent_llm, agent_settings.agent_config or {}
if issubclass(agent_cls, ReActAgent | MemoryAgent):
if (
issubclass(agent_cls, MemoryAgent)
and "memory_model" in config
and "memories" in config
):
if "embedding_model" in config["memory_model"]:
# Work around EmbeddingModel not yet supporting deserialization
config["memory_model"]["embedding_model"] = EmbeddingModel.from_name(
embedding=config["memory_model"].pop("embedding_model")["name"]
)
config["memory_model"] = UIndexMemoryModel(**config["memory_model"])
memories = _Memories.validate_python(config.pop("memories"))
await asyncio.gather(
*(
config["memory_model"].add_memory(memory)
for memory in (
memories.values() if isinstance(memories, dict) else memories
)
)
)
return agent_cls(
llm_model={"model": agent_llm, "temperature": settings.temperature},
**config,
)
if issubclass(agent_cls, SimpleAgent):
return agent_cls(
llm_model={"model": agent_llm, "temperature": settings.temperature},
sys_prompt=agent_settings.agent_system_prompt,
**config,
)
if issubclass(agent_cls, HTTPAgentClient):
set_training_mode(False)
return HTTPAgentClient[SimpleAgentState](
agent_state_type=SimpleAgentState, **config
)
raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.")


async def run_agent(
docs: Docs,
query: QueryRequest,
Expand Down Expand Up @@ -205,11 +106,11 @@ async def run_agent(

if agent_type == "fake":
answer, agent_status = await run_fake_agent(query, docs, **runner_kwargs)
elif tool_selector_or_none := to_aviary_tool_selector(agent_type, query.settings):
elif tool_selector_or_none := query.settings.make_aviary_tool_selector(agent_type):
answer, agent_status = await run_aviary_agent(
query, docs, tool_selector_or_none, **runner_kwargs
)
elif ldp_agent_or_none := await to_ldp_agent(agent_type, query.settings):
elif ldp_agent_or_none := await query.settings.make_ldp_agent(agent_type):
answer, agent_status = await run_ldp_agent(
query, docs, ldp_agent_or_none, **runner_kwargs
)
Expand Down
114 changes: 113 additions & 1 deletion paperqa/settings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
import asyncio
import importlib.resources
import os
from enum import StrEnum
from pathlib import Path
from pydoc import locate
from typing import Any, ClassVar, assert_never, cast

from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator
from aviary.tools import ToolSelector
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
computed_field,
field_validator,
)
from pydantic_settings import BaseSettings, CliSettingsSource, SettingsConfigDict

try:
from ldp.agent import (
Agent,
HTTPAgentClient,
MemoryAgent,
ReActAgent,
SimpleAgent,
SimpleAgentState,
)
from ldp.graph.memory import Memory, UIndexMemoryModel
from ldp.graph.op_utils import set_training_mode
from ldp.llms import EmbeddingModel as LDPEmbeddingModel

_Memories = TypeAdapter(dict[int, Memory] | list[Memory]) # type: ignore[var-annotated]

HAS_LDP_INSTALLED = True
except ImportError:
HAS_LDP_INSTALLED = False

from paperqa.llms import EmbeddingModel, LiteLLMModel, embedding_model_factory
from paperqa.prompts import (
citation_prompt,
Expand Down Expand Up @@ -520,6 +549,89 @@ def get_agent_llm(self) -> LiteLLMModel:
def get_embedding_model(self) -> EmbeddingModel:
return embedding_model_factory(self.embedding, **(self.embedding_config or {}))

def make_aviary_tool_selector(self, agent_type: str | type) -> ToolSelector | None:
"""Attempt to convert the input agent type to an aviary ToolSelector."""
if agent_type is ToolSelector or (
isinstance(agent_type, str)
and (
agent_type == ToolSelector.__name__
or (
agent_type.startswith(
ToolSelector.__module__.split(".", maxsplit=1)[0]
)
and locate(agent_type) is ToolSelector
)
)
):
return ToolSelector(
model_name=self.agent.agent_llm,
acompletion=self.get_agent_llm().router.acompletion,
**(self.agent.agent_config or {}),
)
return None

async def make_ldp_agent(
self, agent_type: str | type
) -> "Agent[SimpleAgentState] | None":
"""Attempt to convert the input agent type to an ldp Agent."""
if not isinstance(agent_type, str): # Convert to fully qualified name
agent_type = f"{agent_type.__module__}.{agent_type.__name__}"
if not agent_type.startswith("ldp"):
return None
if not HAS_LDP_INSTALLED:
raise ImportError(
"ldp agents requires the 'ldp' extra for 'ldp'. Please:"
" `pip install paper-qa[ldp]`."
)

# TODO: support general agents
agent_cls = cast(type[Agent], locate(agent_type))
agent_settings = self.agent
agent_llm, config = agent_settings.agent_llm, agent_settings.agent_config or {}
if issubclass(agent_cls, ReActAgent | MemoryAgent):
if (
issubclass(agent_cls, MemoryAgent)
and "memory_model" in config
and "memories" in config
):
if "embedding_model" in config["memory_model"]:
# Work around LDPEmbeddingModel not yet supporting deserialization
config["memory_model"]["embedding_model"] = (
LDPEmbeddingModel.from_name(
embedding=config["memory_model"].pop("embedding_model")[
"name"
]
)
)
config["memory_model"] = UIndexMemoryModel(**config["memory_model"])
memories = _Memories.validate_python(config.pop("memories"))
await asyncio.gather(
*(
config["memory_model"].add_memory(memory)
for memory in (
memories.values()
if isinstance(memories, dict)
else memories
)
)
)
return agent_cls(
llm_model={"model": agent_llm, "temperature": self.temperature},
**config,
)
if issubclass(agent_cls, SimpleAgent):
return agent_cls(
llm_model={"model": agent_llm, "temperature": self.temperature},
sys_prompt=agent_settings.agent_system_prompt,
**config,
)
if issubclass(agent_cls, HTTPAgentClient):
set_training_mode(False)
return HTTPAgentClient[SimpleAgentState](
agent_state_type=SimpleAgentState, **config
)
raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.")


MaybeSettings = Settings | str | None

Expand Down

0 comments on commit 1cc953a

Please sign in to comment.