Skip to content

Commit

Permalink
Add Client (external API) Module For Enhanced Metadata (#306)
Browse files Browse the repository at this point in the history
* first pass at adding clients module

* remove settings.py

* add client module and holder docdetails type

* remove some TODOs and add weird request test

* stop pulling bibtex if not requested

* better comment on bibtex usage

* fix eof error

* revert to default email for test cassettes

* s2 and crossref module-level api headers

* move doi_url into its own field, refactor all model validators into one method, regenerate all cassettes with new fields

* add robustness for timeout errors

* move exception handling into parent method

* add explicit "prefer other" to the __add__ method and use crossref in live tests for stability

* move clients/utils into utils

* rename text in prompt to citation

* use loop in tests

* adjust citation prompt, add docstring for populate_bibtex_key_citation, replace bibtex extract with pattern, lower timeout threshold in test

* add topological run-order via nested sequence

---------

Co-authored-by: Michael Skarlinski <[email protected]>
  • Loading branch information
mskarlin and Michael Skarlinski committed Aug 14, 2024
1 parent 9fc10a7 commit eced4f3
Show file tree
Hide file tree
Showing 36 changed files with 47,507 additions and 6 deletions.
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ repos:
rev: v4.6.0
hooks:
- id: check-added-large-files
exclude: |
(?x)^(
paperqa/clients/client_data.*
)$
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
Expand Down Expand Up @@ -50,6 +54,11 @@ repos:
hooks:
- id: codespell
additional_dependencies: [".[toml]"]
exclude: |
(?x)^(
tests/cassettes.*|
paperqa/clients/client_data.*
)$
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.18
hooks:
Expand Down
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pre-commit
pytest
pytest-asyncio
pytest-sugar
pytest-vcr
pytest-timer
types-requests
types-setuptools
2 changes: 2 additions & 0 deletions paperqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
llm_model_factory,
vector_store_factory,
)
from .types import DocDetails
from .version import __version__

__all__ = [
"Answer",
"AnthropicLLMModel",
"Context",
"Doc",
"DocDetails",
"Docs",
"EmbeddingModel",
"HybridEmbeddingModel",
Expand Down
177 changes: 177 additions & 0 deletions paperqa/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from __future__ import annotations

import copy
import logging
from typing import Any, Collection, Coroutine, Sequence

import aiohttp
from pydantic import BaseModel, ConfigDict

from ..types import Doc, DocDetails
from ..utils import gather_with_concurrency
from .client_models import MetadataPostProcessor, MetadataProvider
from .crossref import CrossrefProvider
from .journal_quality import JournalQualityPostProcessor
from .semantic_scholar import SemanticScholarProvider

logger = logging.getLogger(__name__)

# Default client set: two metadata providers (Crossref, Semantic Scholar) plus
# one post-processor (journal quality). Declared as an unordered flat
# collection; callers may instead pass a nested Sequence of Collections to
# DocMetadataClient to enforce a staged run-order between client groups.
ALL_CLIENTS: (
    Collection[type[MetadataPostProcessor | MetadataProvider]]
    | Sequence[Collection[type[MetadataPostProcessor | MetadataProvider]]]
) = {
    CrossrefProvider,
    SemanticScholarProvider,
    JournalQualityPostProcessor,
}


class DocMetadataTask(BaseModel):
    """One stage of metadata work: providers to query and processors to apply."""

    providers: Collection[MetadataProvider]
    processors: Collection[MetadataPostProcessor]

    # Provider/processor instances are plain (non-pydantic) objects.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def provider_queries(
        self, query: dict
    ) -> list[Coroutine[Any, Any, DocDetails | None]]:
        """Build one (unawaited) query coroutine per configured provider."""
        return [provider.query(query) for provider in self.providers]

    def processor_queries(
        self, doc_details: DocDetails, session: aiohttp.ClientSession
    ) -> list[Coroutine[Any, Any, DocDetails]]:
        """Build one (unawaited) process coroutine per configured processor.

        Each processor gets a shallow copy of *doc_details* so that concurrent
        processors do not mutate one shared input object.
        """
        return [
            processor.process(copy.copy(doc_details), session=session)
            for processor in self.processors
        ]

    def __repr__(self) -> str:
        contents = f"providers={self.providers}, processors={self.processors}"
        return f"DocMetadataTask({contents})"


class DocMetadataClient:
    """Query metadata providers/processors, in stages, to build a DocDetails."""

    def __init__(
        self,
        session: aiohttp.ClientSession | None = None,
        clients: (
            Collection[type[MetadataPostProcessor | MetadataProvider]]
            | Sequence[Collection[type[MetadataPostProcessor | MetadataProvider]]]
        ) = ALL_CLIENTS,
    ) -> None:
        """Metadata client for querying multiple metadata providers and processors.

        Args:
            session: outer scope aiohttp session to allow for connection pooling;
                if None, a short-lived session is created (and closed) per query.
            clients: list of MetadataProvider and MetadataPostProcessor classes to query;
                if nested, will query in order looking for termination criteria after each.
                Will terminate early if either DocDetails.is_hydration_needed is False OR if
                all requested fields are present in the DocDetails object.

        Raises:
            ValueError: if no MetadataProvider ends up in the first task stage.
        """
        self._session = session
        self.tasks: list[DocMetadataTask] = []

        # Nested sequence of collections -> ordered stages (topological run-order).
        if isinstance(clients, Sequence) and all(
            isinstance(sub_clients, Collection) for sub_clients in clients
        ):
            for sub_clients in clients:
                self.tasks.append(
                    DocMetadataTask(
                        providers=[
                            c() for c in sub_clients if issubclass(c, MetadataProvider)
                        ],
                        processors=[
                            c()
                            for c in sub_clients
                            if issubclass(c, MetadataPostProcessor)
                        ],
                    )
                )
        # Otherwise a flat collection of classes -> one single stage.
        if not self.tasks and all(not isinstance(c, Collection) for c in clients):
            self.tasks.append(
                DocMetadataTask(
                    providers=[c() for c in clients if issubclass(c, MetadataProvider)],  # type: ignore[operator, arg-type]
                    processors=[
                        c() for c in clients if issubclass(c, MetadataPostProcessor)  # type: ignore[operator, arg-type]
                    ],
                )
            )

        if not self.tasks or not self.tasks[0].providers:
            raise ValueError("At least one MetadataProvider must be provided.")

    async def query(self, **kwargs) -> DocDetails | None:
        """Run each task stage in order, aggregating provider/processor results.

        Returns:
            The aggregated DocDetails, or None if no provider produced a result.
        """
        session = aiohttp.ClientSession() if self._session is None else self._session
        # Close the session only if we created it here; previously an ad-hoc
        # session was never closed, leaking a connection pool per query.
        try:
            query_args = (
                kwargs if "session" in kwargs else kwargs | {"session": session}
            )

            doc_details: DocDetails | None = None

            for ti, task in enumerate(self.tasks):
                logger.debug(
                    f"Attempting to populate metadata query: {query_args} via {task}"
                )

                # First query all providers concurrently and merge the non-empty
                # results via DocDetails addition (sum relies on DocDetails'
                # __add__/"prefer other" semantics).
                doc_details = (
                    sum(
                        p
                        for p in (
                            await gather_with_concurrency(
                                len(task.providers), task.provider_queries(query_args)
                            )
                        )
                        if p
                    )
                    or None
                )

                # Then post-process and re-aggregate the merged result.
                if doc_details and task.processors:
                    doc_details = sum(
                        await gather_with_concurrency(
                            len(task.processors),
                            task.processor_queries(doc_details, session),
                        )
                    )

                # Stop early once every requested field has been populated.
                if doc_details and not doc_details.is_hydration_needed(
                    inclusion=kwargs.get("fields", [])
                ):
                    logger.debug(
                        "All requested fields are present in the DocDetails "
                        f"object{', stopping early.' if ti != len(self.tasks) - 1 else '.'}"
                    )
                    break

            return doc_details
        finally:
            if self._session is None:
                await session.close()

    async def bulk_query(
        self, queries: Collection[dict[str, Any]], concurrency: int = 10
    ) -> list[DocDetails | None]:
        """Run many queries concurrently; entries are None when a query found nothing."""
        return await gather_with_concurrency(
            concurrency, [self.query(**kwargs) for kwargs in queries]
        )

    async def upgrade_doc_to_doc_details(self, doc: Doc, **kwargs) -> DocDetails:
        """Promote a Doc to DocDetails, preserving the prior Doc's identity fields.

        If metadata lookup fails entirely, the original Doc's fields are wrapped
        in a DocDetails unchanged (with metadata overwriting disabled).
        """
        if doc_details := await self.query(**kwargs):
            if doc.overwrite_fields_from_metadata:
                return doc_details
            # hard overwrite the details from the prior object
            # NOTE(review): doc_id is set from dockey here — presumably they are
            # aliases on DocDetails; confirm against the DocDetails model.
            doc_details.dockey = doc.dockey
            doc_details.doc_id = doc.dockey
            doc_details.docname = doc.docname
            doc_details.key = doc.docname
            doc_details.citation = doc.citation
            return doc_details

        # if we can't get metadata, just return the doc, but don't overwrite any fields
        prior_doc = doc.model_dump()
        prior_doc["overwrite_fields_from_metadata"] = False
        return DocDetails(**prior_doc)
Loading

0 comments on commit eced4f3

Please sign in to comment.