-
Notifications
You must be signed in to change notification settings - Fork 553
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Client (external API) Module For Enhanced Metadata (#306)
* first pass at adding clients module * remove settings.py * add client module and holder docdetails type * remove some TODOs and add weird request test * stop pulling bibtex if not requested * better comment on bibtex usage * fix eof error * revert to default email for test cassettes * s2 and crossref module-level api headers * move doi_url into its own field, refactor all model validators into one method, regenerate all cassettes with new fields * add robustness for timeout errors * move exception handling into parent method * add explicit "prefer other" to the __add__ method and use crossref in live tests for stability * move clients/utils into utils * rename text in prompt to citation * use loop in tests * adjust citation prompt, add docstring for populate_bibtex_key_citation, replace bibtex extract w pattern, lower timeout threshold in test * add topological run-order via nested sequence --------- Co-authored-by: Michael Skarlinski <[email protected]>
- Loading branch information
Showing
36 changed files
with
47,507 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ pre-commit | |
pytest | ||
pytest-asyncio | ||
pytest-sugar | ||
pytest-vcr | ||
pytest-timer | ||
types-requests | ||
types-setuptools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
from __future__ import annotations | ||
|
||
import copy | ||
import logging | ||
from typing import Any, Collection, Coroutine, Sequence | ||
|
||
import aiohttp | ||
from pydantic import BaseModel, ConfigDict | ||
|
||
from ..types import Doc, DocDetails | ||
from ..utils import gather_with_concurrency | ||
from .client_models import MetadataPostProcessor, MetadataProvider | ||
from .crossref import CrossrefProvider | ||
from .journal_quality import JournalQualityPostProcessor | ||
from .semantic_scholar import SemanticScholarProvider | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
ALL_CLIENTS: ( | ||
Collection[type[MetadataPostProcessor | MetadataProvider]] | ||
| Sequence[Collection[type[MetadataPostProcessor | MetadataProvider]]] | ||
) = { | ||
CrossrefProvider, | ||
SemanticScholarProvider, | ||
JournalQualityPostProcessor, | ||
} | ||
|
||
|
||
class DocMetadataTask(BaseModel): | ||
"""Holder for provider and processor tasks.""" | ||
|
||
providers: Collection[MetadataProvider] | ||
processors: Collection[MetadataPostProcessor] | ||
|
||
model_config = ConfigDict(arbitrary_types_allowed=True) | ||
|
||
def provider_queries( | ||
self, query: dict | ||
) -> list[Coroutine[Any, Any, DocDetails | None]]: | ||
return [p.query(query) for p in self.providers] | ||
|
||
def processor_queries( | ||
self, doc_details: DocDetails, session: aiohttp.ClientSession | ||
) -> list[Coroutine[Any, Any, DocDetails]]: | ||
return [ | ||
p.process(copy.copy(doc_details), session=session) for p in self.processors | ||
] | ||
|
||
def __repr__(self) -> str: | ||
return ( | ||
f"DocMetadataTask(providers={self.providers}, processors={self.processors})" | ||
) | ||
|
||
|
||
class DocMetadataClient: | ||
def __init__( | ||
self, | ||
session: aiohttp.ClientSession | None = None, | ||
clients: ( | ||
Collection[type[MetadataPostProcessor | MetadataProvider]] | ||
| Sequence[Collection[type[MetadataPostProcessor | MetadataProvider]]] | ||
) = ALL_CLIENTS, | ||
) -> None: | ||
"""Metadata client for querying multiple metadata providers and processors. | ||
Args: | ||
session: outer scope aiohttp session to allow for connection pooling | ||
clients: list of MetadataProvider and MetadataPostProcessor classes to query; | ||
if nested, will query in order looking for termination criteria after each. | ||
Will terminate early if either DocDetails.is_hydration_needed is False OR if | ||
all requested fields are present in the DocDetails object. | ||
""" | ||
self._session = session | ||
self.tasks: list[DocMetadataTask] = [] | ||
|
||
# first see if we are nested; i.e. we want order | ||
if isinstance(clients, Sequence) and all( | ||
isinstance(sub_clients, Collection) for sub_clients in clients | ||
): | ||
for sub_clients in clients: | ||
self.tasks.append( | ||
DocMetadataTask( | ||
providers=[ | ||
c() for c in sub_clients if issubclass(c, MetadataProvider) | ||
], | ||
processors=[ | ||
c() | ||
for c in sub_clients | ||
if issubclass(c, MetadataPostProcessor) | ||
], | ||
) | ||
) | ||
# otherwise, we are a flat collection | ||
if not self.tasks and all(not isinstance(c, Collection) for c in clients): | ||
self.tasks.append( | ||
DocMetadataTask( | ||
providers=[c() for c in clients if issubclass(c, MetadataProvider)], # type: ignore[operator, arg-type] | ||
processors=[ | ||
c() for c in clients if issubclass(c, MetadataPostProcessor) # type: ignore[operator, arg-type] | ||
], | ||
) | ||
) | ||
|
||
if not self.tasks or (self.tasks and not self.tasks[0].providers): | ||
raise ValueError("At least one MetadataProvider must be provided.") | ||
|
||
async def query(self, **kwargs) -> DocDetails | None: | ||
|
||
session = aiohttp.ClientSession() if self._session is None else self._session | ||
|
||
query_args = kwargs if "session" in kwargs else kwargs | {"session": session} | ||
|
||
doc_details: DocDetails | None = None | ||
|
||
for ti, task in enumerate(self.tasks): | ||
logger.debug( | ||
f"Attempting to populate metadata query: {query_args} via {task}" | ||
) | ||
|
||
# first query all client_models and aggregate the results | ||
doc_details = ( | ||
sum( | ||
p | ||
for p in ( | ||
await gather_with_concurrency( | ||
len(task.providers), task.provider_queries(query_args) | ||
) | ||
) | ||
if p | ||
) | ||
or None | ||
) | ||
|
||
# then process and re-aggregate the results | ||
if doc_details and task.processors: | ||
doc_details = sum( | ||
await gather_with_concurrency( | ||
len(task.processors), | ||
task.processor_queries(doc_details, session), | ||
) | ||
) | ||
|
||
if doc_details and not doc_details.is_hydration_needed( | ||
inclusion=kwargs.get("fields", []) | ||
): | ||
logger.debug( | ||
"All requested fields are present in the DocDetails " | ||
f"object{', stopping early.' if ti != len(self.tasks) - 1 else '.'}" | ||
) | ||
break | ||
|
||
return doc_details | ||
|
||
async def bulk_query( | ||
self, queries: Collection[dict[str, Any]], concurrency: int = 10 | ||
) -> list[DocDetails]: | ||
return await gather_with_concurrency( | ||
concurrency, [self.query(**kwargs) for kwargs in queries] | ||
) | ||
|
||
async def upgrade_doc_to_doc_details(self, doc: Doc, **kwargs) -> DocDetails: | ||
if doc_details := await self.query(**kwargs): | ||
if doc.overwrite_fields_from_metadata: | ||
return doc_details | ||
# hard overwrite the details from the prior object | ||
doc_details.dockey = doc.dockey | ||
doc_details.doc_id = doc.dockey | ||
doc_details.docname = doc.docname | ||
doc_details.key = doc.docname | ||
doc_details.citation = doc.citation | ||
return doc_details | ||
|
||
# if we can't get metadata, just return the doc, but don't overwrite any fields | ||
prior_doc = doc.model_dump() | ||
prior_doc["overwrite_fields_from_metadata"] = False | ||
return DocDetails(**prior_doc) |
Oops, something went wrong.