From 6256550945d502bbc13115e2218ffd181dcab8bb Mon Sep 17 00:00:00 2001 From: Geemi Wellawatte <49410838+geemi725@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:08:40 -0700 Subject: [PATCH] Updates to retraction status checker (#370) Co-authored-by: James Braza Co-authored-by: Andrew White --- .gitignore | 5 +++ paperqa/clients/crossref.py | 62 ++++++++++++++++++++++++++++++---- paperqa/clients/retractions.py | 40 ++-------------------- paperqa/types.py | 11 +++++- tests/.gitignore | 3 -- 5 files changed, 73 insertions(+), 48 deletions(-) delete mode 100644 tests/.gitignore diff --git a/.gitignore b/.gitignore index 304d979c..29de9389 100644 --- a/.gitignore +++ b/.gitignore @@ -303,3 +303,8 @@ cython_debug/ tests/*txt tests/*html tests/test_index/* +tests/example.* +tests/example2.* + +# Client data +paperqa/clients/client_data/retractions.csv diff --git a/paperqa/clients/crossref.py b/paperqa/clients/crossref.py index f7882cd2..77a62c46 100644 --- a/paperqa/clients/crossref.py +++ b/paperqa/clients/crossref.py @@ -11,6 +11,8 @@ from urllib.parse import quote import aiohttp +from anyio import open_file +from tenacity import retry, stop_after_attempt, wait_exponential from paperqa.types import CITATION_FALLBACK_DATA, DocDetails from paperqa.utils import ( @@ -104,6 +106,20 @@ def crossref_headers() -> dict[str, str]: return {} +def get_crossref_mailto() -> str: + """Crossref mailto if available, otherwise a default.""" + MAILTO = os.getenv("CROSSREF_MAILTO") + + if not MAILTO: + logger.warning( + "CROSSREF_MAILTO environment variable not set. Crossref API rate limits may" + " apply." + ) + return "test@example.com" + + return MAILTO + + async def doi_to_bibtex( doi: str, session: aiohttp.ClientSession, @@ -251,12 +267,7 @@ async def get_doc_details_from_crossref( # noqa: PLR0912 inputs_msg = f"DOI {doi}" if doi is not None else f"title {title}" - if not (CROSSREF_MAILTO := os.getenv("CROSSREF_MAILTO")): - logger.warning( - "CROSSREF_MAILTO environment variable not set. Crossref API rate limits may" - " apply." - ) - CROSSREF_MAILTO = "test@example.com" + CROSSREF_MAILTO = get_crossref_mailto() quoted_doi = f"/{quote(doi, safe='')}" if doi else "" url = f"{CROSSREF_BASE_URL}/works{quoted_doi}" params = {"mailto": CROSSREF_MAILTO} @@ -335,6 +346,45 @@ async def get_doc_details_from_crossref( # noqa: PLR0912 return await parse_crossref_to_doc_details(message, session, query_bibtex) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=5, min=5), + reraise=True, +) +async def download_retracted_dataset( + retraction_data_path: os.PathLike | str, +) -> None: + """ + Download the retraction dataset from Crossref. + + Saves the retraction dataset to `retraction_data_path`. + """ + url = f"https://api.labs.crossref.org/data/retractionwatch?{get_crossref_mailto()}" + + async with ( + aiohttp.ClientSession() as session, + session.get( + url, + timeout=aiohttp.ClientTimeout(total=300), + ) as response, + ): + response.raise_for_status() + + logger.info( + f"Retraction data was not cashed. Downloading retraction data from {url}..." + ) + + async with await open_file(str(retraction_data_path), "wb") as f: + while True: + chunk = await response.content.read(1024) + if not chunk: + break + await f.write(chunk) + + if os.path.getsize(str(retraction_data_path)) == 0: + raise RuntimeError("Retraction data is empty") + + class CrossrefProvider(DOIOrTitleBasedProvider): async def _query(self, query: TitleAuthorQuery | DOIQuery) -> DocDetails | None: if isinstance(query, DOIQuery): diff --git a/paperqa/clients/retractions.py b/paperqa/clients/retractions.py index ffceb4dd..40163581 100644 --- a/paperqa/clients/retractions.py +++ b/paperqa/clients/retractions.py @@ -5,14 +5,12 @@ import logging import os -import aiohttp -from anyio import open_file from pydantic import ValidationError -from tenacity import retry, stop_after_attempt, wait_exponential from paperqa.types import DocDetails from .client_models import DOIQuery, MetadataPostProcessor +from .crossref import download_retracted_dataset logger = logging.getLogger(__name__) @@ -52,40 +50,6 @@ def _has_cache_expired(self) -> bool: def _is_csv_cached(self) -> bool: return os.path.exists(self.retraction_data_path) - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=5, min=5), - reraise=True, - ) - async def _download_retracted_dataset(self) -> None: - - if not (CROSSREF_MAILTO := os.getenv("CROSSREF_MAILTO")): - CROSSREF_MAILTO = "test@example.com" - url = f"https://api.labs.crossref.org/data/retractionwatch?{CROSSREF_MAILTO}" - - async with ( - aiohttp.ClientSession() as session, - session.get( - url, - timeout=aiohttp.ClientTimeout(total=300), - ) as response, - ): - response.raise_for_status() - - logger.info( - f"Retraction data was not cashed. Downloading retraction data from {url}..." - ) - - async with await open_file(self.retraction_data_path, "wb") as f: - while True: - chunk = await response.content.read(1024) - if not chunk: - break - await f.write(chunk) - - if os.path.getsize(self.retraction_data_path) == 0: - raise RuntimeError("Retraction data is empty") - def _filter_dois(self) -> None: with open(self.retraction_data_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) @@ -96,7 +60,7 @@ def _filter_dois(self) -> None: async def load_data(self) -> None: if not self._is_csv_cached() or self._has_cache_expired(): - await self._download_retracted_dataset() + await download_retracted_dataset(self.retraction_data_path) self._filter_dois() diff --git a/paperqa/types.py b/paperqa/types.py index 9db707a4..2aee291e 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -326,6 +326,7 @@ class DocDetails(Doc): " quality and None means it needs to be hydrated." ), ) + is_retracted: bool | None = Field( default=None, description="Flag for whether the paper is retracted." ) @@ -550,7 +551,14 @@ def __getitem__(self, item: str): def formatted_citation(self) -> str: if self.is_retracted: - return f"**RETRACTED ARTICLE** Citation: {self.citation} Retrieved from http://retractiondatabase.org/." + base_message = "**RETRACTED ARTICLE**" + retract_info = "Retrieved from http://retractiondatabase.org/." + citation_message = ( + f"Citation: {self.citation}" + if self.citation + else f"Original DOI: {self.doi}" + ) + return f"{base_message} {citation_message} {retract_info}" if ( self.citation is None # type: ignore[redundant-expr] @@ -561,6 +569,7 @@ def formatted_citation(self) -> str: "Citation, citationCount, and sourceQuality are not set -- do you need" " to call `hydrate`?" ) + quality = ( SOURCE_QUALITY_MESSAGES[self.source_quality] if self.source_quality >= 0 diff --git a/tests/.gitignore b/tests/.gitignore deleted file mode 100644 index a8d8e0c7..00000000 --- a/tests/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -# ignore test-generated files -example.* -example2.*