Skip to content

Commit

Permalink
remove regex solution, replaced with logic that handles undefined val…
Browse files Browse the repository at this point in the history
…ues. Pull main.
  • Loading branch information
nadolskit committed Sep 16, 2024
2 parents 533fe9a + 846f136 commit 145af72
Show file tree
Hide file tree
Showing 11 changed files with 687 additions and 308 deletions.
111 changes: 6 additions & 105 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import logging
import os
from collections.abc import Awaitable, Callable
from pydoc import locate
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

from aviary.message import MalformedMessageError, Message
from aviary.tools import (
Expand All @@ -13,36 +12,16 @@
ToolSelector,
ToolSelectorLedger,
)
from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel
from rich.console import Console
from tenacity import (
Retrying,
before_sleep_log,
retry_if_exception_type,
stop_after_attempt,
)

try:
from ldp.agent import (
Agent,
HTTPAgentClient,
MemoryAgent,
ReActAgent,
SimpleAgent,
SimpleAgentState,
)
from ldp.graph.memory import Memory, UIndexMemoryModel
from ldp.graph.op_utils import set_training_mode
from ldp.llms import EmbeddingModel

_Memories = TypeAdapter(dict[int, Memory] | list[Memory]) # type: ignore[var-annotated]

HAS_LDP_INSTALLED = True
except ImportError:
HAS_LDP_INSTALLED = False
from rich.console import Console

from paperqa.docs import Docs
from paperqa.settings import Settings
from paperqa.types import Answer
from paperqa.utils import pqa_directory

Expand All @@ -53,6 +32,7 @@
from .tools import EnvironmentState, GatherEvidence, GenerateAnswer, PaperSearch

if TYPE_CHECKING:
from ldp.agent import Agent, SimpleAgentState
from ldp.graph.ops import OpResult

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -96,85 +76,6 @@ async def agent_query(
return response


def to_aviary_tool_selector(
agent_type: str | type, settings: Settings
) -> ToolSelector | None:
"""Attempt to convert the agent type to an aviary ToolSelector."""
if agent_type is ToolSelector or (
isinstance(agent_type, str)
and (
agent_type == ToolSelector.__name__
or (
agent_type.startswith(ToolSelector.__module__.split(".", maxsplit=1)[0])
and locate(agent_type) is ToolSelector
)
)
):
return ToolSelector(
model_name=settings.agent.agent_llm,
acompletion=settings.get_agent_llm().router.acompletion,
**(settings.agent.agent_config or {}),
)
return None


async def to_ldp_agent(
agent_type: str | type, settings: Settings
) -> "Agent[SimpleAgentState] | None":
"""Attempt to convert the agent type to an ldp Agent."""
if not isinstance(agent_type, str): # Convert to fully qualified name
agent_type = f"{agent_type.__module__}.{agent_type.__name__}"
if not agent_type.startswith("ldp"):
return None
if not HAS_LDP_INSTALLED:
raise ImportError(
"ldp agents requires the 'ldp' extra for 'ldp'. Please:"
" `pip install paper-qa[ldp]`."
)

# TODO: support general agents
agent_cls = cast(type[Agent], locate(agent_type))
agent_settings = settings.agent
agent_llm, config = agent_settings.agent_llm, agent_settings.agent_config or {}
if issubclass(agent_cls, ReActAgent | MemoryAgent):
if (
issubclass(agent_cls, MemoryAgent)
and "memory_model" in config
and "memories" in config
):
if "embedding_model" in config["memory_model"]:
# Work around EmbeddingModel not yet supporting deserialization
config["memory_model"]["embedding_model"] = EmbeddingModel.from_name(
embedding=config["memory_model"].pop("embedding_model")["name"]
)
config["memory_model"] = UIndexMemoryModel(**config["memory_model"])
memories = _Memories.validate_python(config.pop("memories"))
await asyncio.gather(
*(
config["memory_model"].add_memory(memory)
for memory in (
memories.values() if isinstance(memories, dict) else memories
)
)
)
return agent_cls(
llm_model={"model": agent_llm, "temperature": settings.temperature},
**config,
)
if issubclass(agent_cls, SimpleAgent):
return agent_cls(
llm_model={"model": agent_llm, "temperature": settings.temperature},
sys_prompt=agent_settings.agent_system_prompt,
**config,
)
if issubclass(agent_cls, HTTPAgentClient):
set_training_mode(False)
return HTTPAgentClient[SimpleAgentState](
agent_state_type=SimpleAgentState, **config
)
raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.")


async def run_agent(
docs: Docs,
query: QueryRequest,
Expand Down Expand Up @@ -205,11 +106,11 @@ async def run_agent(

if agent_type == "fake":
answer, agent_status = await run_fake_agent(query, docs, **runner_kwargs)
elif tool_selector_or_none := to_aviary_tool_selector(agent_type, query.settings):
elif tool_selector_or_none := query.settings.make_aviary_tool_selector(agent_type):
answer, agent_status = await run_aviary_agent(
query, docs, tool_selector_or_none, **runner_kwargs
)
elif ldp_agent_or_none := await to_ldp_agent(agent_type, query.settings):
elif ldp_agent_or_none := await query.settings.make_ldp_agent(agent_type):
answer, agent_status = await run_ldp_agent(
query, docs, ldp_agent_or_none, **runner_kwargs
)
Expand Down
31 changes: 24 additions & 7 deletions paperqa/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .client_models import MetadataPostProcessor, MetadataProvider
from .crossref import CrossrefProvider
from .journal_quality import JournalQualityPostProcessor
from .retractions import RetrationDataPostProcessor
from .retractions import RetractionDataPostProcessor
from .semantic_scholar import SemanticScholarProvider
from .unpaywall import UnpaywallProvider

Expand All @@ -29,7 +29,7 @@
ALL_CLIENTS: Collection[type[MetadataPostProcessor | MetadataProvider]] = {
*DEFAULT_CLIENTS,
UnpaywallProvider,
RetrationDataPostProcessor,
RetractionDataPostProcessor,
}


Expand Down Expand Up @@ -89,22 +89,39 @@ def __init__( # pylint: disable=dangerous-default-value
self.tasks.append(
DocMetadataTask(
providers=[
c() for c in sub_clients if issubclass(c, MetadataProvider)
c if isinstance(c, MetadataProvider) else c()
for c in sub_clients
if (isinstance(c, type) and issubclass(c, MetadataProvider))
or isinstance(c, MetadataProvider)
],
processors=[
c()
c if isinstance(c, MetadataPostProcessor) else c()
for c in sub_clients
if issubclass(c, MetadataPostProcessor)
if (
isinstance(c, type)
and issubclass(c, MetadataPostProcessor)
)
or isinstance(c, MetadataPostProcessor)
],
)
)
# otherwise, we are a flat collection
if not self.tasks and all(not isinstance(c, Collection) for c in clients):
self.tasks.append(
DocMetadataTask(
providers=[c() for c in clients if issubclass(c, MetadataProvider)], # type: ignore[operator, arg-type]
providers=[
c if isinstance(c, MetadataProvider) else c() # type: ignore[redundant-expr]
for c in clients
if (isinstance(c, type) and issubclass(c, MetadataProvider))
or isinstance(c, MetadataProvider)
],
processors=[
c() for c in clients if issubclass(c, MetadataPostProcessor) # type: ignore[operator, arg-type]
c if isinstance(c, MetadataPostProcessor) else c() # type: ignore[redundant-expr]
for c in clients
if (
isinstance(c, type) and issubclass(c, MetadataPostProcessor)
)
or isinstance(c, MetadataPostProcessor)
],
)
)
Expand Down
2 changes: 1 addition & 1 deletion paperqa/clients/retractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
logger = logging.getLogger(__name__)


class RetrationDataPostProcessor(MetadataPostProcessor[DOIQuery]):
class RetractionDataPostProcessor(MetadataPostProcessor[DOIQuery]):
def __init__(self, retraction_data_path: os.PathLike | str | None = None) -> None:

if retraction_data_path is None:
Expand Down
114 changes: 113 additions & 1 deletion paperqa/settings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
import asyncio
import importlib.resources
import os
from enum import StrEnum
from pathlib import Path
from pydoc import locate
from typing import Any, ClassVar, assert_never, cast

from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator
from aviary.tools import ToolSelector
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
computed_field,
field_validator,
)
from pydantic_settings import BaseSettings, CliSettingsSource, SettingsConfigDict

try:
from ldp.agent import (
Agent,
HTTPAgentClient,
MemoryAgent,
ReActAgent,
SimpleAgent,
SimpleAgentState,
)
from ldp.graph.memory import Memory, UIndexMemoryModel
from ldp.graph.op_utils import set_training_mode
from ldp.llms import EmbeddingModel as LDPEmbeddingModel

_Memories = TypeAdapter(dict[int, Memory] | list[Memory]) # type: ignore[var-annotated]

HAS_LDP_INSTALLED = True
except ImportError:
HAS_LDP_INSTALLED = False

from paperqa.llms import EmbeddingModel, LiteLLMModel, embedding_model_factory
from paperqa.prompts import (
citation_prompt,
Expand Down Expand Up @@ -520,6 +549,89 @@ def get_agent_llm(self) -> LiteLLMModel:
def get_embedding_model(self) -> EmbeddingModel:
return embedding_model_factory(self.embedding, **(self.embedding_config or {}))

def make_aviary_tool_selector(self, agent_type: str | type) -> ToolSelector | None:
"""Attempt to convert the input agent type to an aviary ToolSelector."""
if agent_type is ToolSelector or (
isinstance(agent_type, str)
and (
agent_type == ToolSelector.__name__
or (
agent_type.startswith(
ToolSelector.__module__.split(".", maxsplit=1)[0]
)
and locate(agent_type) is ToolSelector
)
)
):
return ToolSelector(
model_name=self.agent.agent_llm,
acompletion=self.get_agent_llm().router.acompletion,
**(self.agent.agent_config or {}),
)
return None

async def make_ldp_agent(
self, agent_type: str | type
) -> "Agent[SimpleAgentState] | None":
"""Attempt to convert the input agent type to an ldp Agent."""
if not isinstance(agent_type, str): # Convert to fully qualified name
agent_type = f"{agent_type.__module__}.{agent_type.__name__}"
if not agent_type.startswith("ldp"):
return None
if not HAS_LDP_INSTALLED:
raise ImportError(
"ldp agents requires the 'ldp' extra for 'ldp'. Please:"
" `pip install paper-qa[ldp]`."
)

# TODO: support general agents
agent_cls = cast(type[Agent], locate(agent_type))
agent_settings = self.agent
agent_llm, config = agent_settings.agent_llm, agent_settings.agent_config or {}
if issubclass(agent_cls, ReActAgent | MemoryAgent):
if (
issubclass(agent_cls, MemoryAgent)
and "memory_model" in config
and "memories" in config
):
if "embedding_model" in config["memory_model"]:
# Work around LDPEmbeddingModel not yet supporting deserialization
config["memory_model"]["embedding_model"] = (
LDPEmbeddingModel.from_name(
embedding=config["memory_model"].pop("embedding_model")[
"name"
]
)
)
config["memory_model"] = UIndexMemoryModel(**config["memory_model"])
memories = _Memories.validate_python(config.pop("memories"))
await asyncio.gather(
*(
config["memory_model"].add_memory(memory)
for memory in (
memories.values()
if isinstance(memories, dict)
else memories
)
)
)
return agent_cls(
llm_model={"model": agent_llm, "temperature": self.temperature},
**config,
)
if issubclass(agent_cls, SimpleAgent):
return agent_cls(
llm_model={"model": agent_llm, "temperature": self.temperature},
sys_prompt=agent_settings.agent_system_prompt,
**config,
)
if issubclass(agent_cls, HTTPAgentClient):
set_training_mode(False)
return HTTPAgentClient[SimpleAgentState](
agent_state_type=SimpleAgentState, **config
)
raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.")


MaybeSettings = Settings | str | None

Expand Down
Loading

0 comments on commit 145af72

Please sign in to comment.