-
Notifications
You must be signed in to change notification settings - Fork 553
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updates to retraction status checker (#370)
Co-authored-by: James Braza <[email protected]> Co-authored-by: Andrew White <[email protected]>
- Loading branch information
1 parent
dcf1331
commit 6256550
Showing
5 changed files
with
73 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,8 @@ | |
from urllib.parse import quote | ||
|
||
import aiohttp | ||
from anyio import open_file | ||
from tenacity import retry, stop_after_attempt, wait_exponential | ||
|
||
from paperqa.types import CITATION_FALLBACK_DATA, DocDetails | ||
from paperqa.utils import ( | ||
|
@@ -104,6 +106,20 @@ def crossref_headers() -> dict[str, str]: | |
return {} | ||
|
||
|
||
def get_crossref_mailto() -> str:
    """Return the contact email to send along with Crossref requests.

    Reads the ``CROSSREF_MAILTO`` environment variable. When it is unset or
    empty, a warning is logged and a placeholder address is returned instead.
    """
    configured = os.getenv("CROSSREF_MAILTO")
    if configured:
        return configured

    # No address configured: warn, then fall back to the placeholder.
    logger.warning(
        "CROSSREF_MAILTO environment variable not set. Crossref API rate limits may"
        " apply."
    )
    return "[email protected]"
|
||
|
||
async def doi_to_bibtex( | ||
doi: str, | ||
session: aiohttp.ClientSession, | ||
|
@@ -251,12 +267,7 @@ async def get_doc_details_from_crossref( # noqa: PLR0912 | |
|
||
inputs_msg = f"DOI {doi}" if doi is not None else f"title {title}" | ||
|
||
if not (CROSSREF_MAILTO := os.getenv("CROSSREF_MAILTO")): | ||
logger.warning( | ||
"CROSSREF_MAILTO environment variable not set. Crossref API rate limits may" | ||
" apply." | ||
) | ||
CROSSREF_MAILTO = "[email protected]" | ||
CROSSREF_MAILTO = get_crossref_mailto() | ||
quoted_doi = f"/{quote(doi, safe='')}" if doi else "" | ||
url = f"{CROSSREF_BASE_URL}/works{quoted_doi}" | ||
params = {"mailto": CROSSREF_MAILTO} | ||
|
@@ -335,6 +346,45 @@ async def get_doc_details_from_crossref( # noqa: PLR0912 | |
return await parse_crossref_to_doc_details(message, session, query_bibtex) | ||
|
||
|
||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=5, min=5),
    reraise=True,
)
async def download_retracted_dataset(
    retraction_data_path: os.PathLike | str,
) -> None:
    """
    Download the retraction dataset from Crossref.

    Saves the retraction dataset to `retraction_data_path`. The body is first
    streamed into a sibling ``<retraction_data_path>.part`` file and only moved
    over the destination once it is complete and non-empty, so a failed or
    empty download never leaves a truncated file at `retraction_data_path`
    (and never clobbers a previously downloaded copy).

    The call is attempted up to three times with exponential backoff; the last
    error is re-raised.

    Raises:
        aiohttp.ClientResponseError: If the server answers with an error status.
        RuntimeError: If the downloaded dataset is empty.
    """
    # The contact address is appended as the bare query string (no key).
    url = f"https://api.labs.crossref.org/data/retractionwatch?{get_crossref_mailto()}"
    destination = str(retraction_data_path)
    partial_path = f"{destination}.part"

    try:
        async with (
            aiohttp.ClientSession() as session,
            session.get(
                url,
                timeout=aiohttp.ClientTimeout(total=300),
            ) as response,
        ):
            response.raise_for_status()

            logger.info(
                f"Retraction data was not cached. Downloading retraction data from {url}..."
            )

            async with await open_file(partial_path, "wb") as f:
                # Stream in 64 KiB pieces instead of holding the body in memory.
                async for chunk in response.content.iter_chunked(64 * 1024):
                    await f.write(chunk)

        if os.path.getsize(partial_path) == 0:
            raise RuntimeError("Retraction data is empty")

        # Atomic on the same filesystem: readers see the old or the new file.
        os.replace(partial_path, destination)
    finally:
        # Drop any leftover partial file; after a successful replace it is gone.
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
|
||
|
||
class CrossrefProvider(DOIOrTitleBasedProvider): | ||
async def _query(self, query: TitleAuthorQuery | DOIQuery) -> DocDetails | None: | ||
if isinstance(query, DOIQuery): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,14 +5,12 @@ | |
import logging | ||
import os | ||
|
||
import aiohttp | ||
from anyio import open_file | ||
from pydantic import ValidationError | ||
from tenacity import retry, stop_after_attempt, wait_exponential | ||
|
||
from paperqa.types import DocDetails | ||
|
||
from .client_models import DOIQuery, MetadataPostProcessor | ||
from .crossref import download_retracted_dataset | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -52,40 +50,6 @@ def _has_cache_expired(self) -> bool: | |
def _is_csv_cached(self) -> bool: | ||
return os.path.exists(self.retraction_data_path) | ||
|
||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=5, min=5),
    reraise=True,
)
async def _download_retracted_dataset(self) -> None:
    """
    Download the retraction dataset into ``self.retraction_data_path``.

    The body is streamed into a sibling ``.part`` file and moved over the
    destination only once it is complete and non-empty, so a failed or empty
    download never leaves a truncated file that an existence check would
    accept as a cached copy.

    The call is attempted up to three times with exponential backoff; the last
    error is re-raised.

    Raises:
        aiohttp.ClientResponseError: If the server answers with an error status.
        RuntimeError: If the downloaded dataset is empty.
    """
    # Fall back to the placeholder address when CROSSREF_MAILTO is unset/empty.
    if not (CROSSREF_MAILTO := os.getenv("CROSSREF_MAILTO")):
        CROSSREF_MAILTO = "[email protected]"
    # The contact address is appended as the bare query string (no key).
    url = f"https://api.labs.crossref.org/data/retractionwatch?{CROSSREF_MAILTO}"
    partial_path = f"{self.retraction_data_path}.part"

    try:
        async with (
            aiohttp.ClientSession() as session,
            session.get(
                url,
                timeout=aiohttp.ClientTimeout(total=300),
            ) as response,
        ):
            response.raise_for_status()

            logger.info(
                f"Retraction data was not cached. Downloading retraction data from {url}..."
            )

            async with await open_file(partial_path, "wb") as f:
                # Stream in 64 KiB pieces instead of holding the body in memory.
                async for chunk in response.content.iter_chunked(64 * 1024):
                    await f.write(chunk)

        if os.path.getsize(partial_path) == 0:
            raise RuntimeError("Retraction data is empty")

        # Atomic on the same filesystem: readers see the old or the new file.
        os.replace(partial_path, self.retraction_data_path)
    finally:
        # Drop any leftover partial file; after a successful replace it is gone.
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
|
||
def _filter_dois(self) -> None: | ||
with open(self.retraction_data_path, newline="", encoding="utf-8") as csvfile: | ||
reader = csv.DictReader(csvfile) | ||
|
@@ -96,7 +60,7 @@ def _filter_dois(self) -> None: | |
|
||
async def load_data(self) -> None: | ||
if not self._is_csv_cached() or self._has_cache_expired(): | ||
await self._download_retracted_dataset() | ||
await download_retracted_dataset(self.retraction_data_path) | ||
|
||
self._filter_dois() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.