diff --git a/article/migrations/0014_articleformat_status_alter_articleformat_article_and_more.py b/article/migrations/0014_articleformat_status_alter_articleformat_article_and_more.py new file mode 100644 index 00000000..ad3705fd --- /dev/null +++ b/article/migrations/0014_articleformat_status_alter_articleformat_article_and_more.py @@ -0,0 +1,56 @@ +# Generated by Django 5.0.3 on 2024-08-23 18:12 + +import django.db.models.deletion +import modelcluster.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("article", "0013_article_article_license"), + ] + + operations = [ + migrations.AddField( + model_name="articleformat", + name="status", + field=models.CharField( + blank=True, + choices=[ + ("E", "Error occurred during export format creation"), + ("S", "Export format created successfully"), + ("A", "Export format available on external site"), + ], + max_length=1, + null=True, + ), + ), + migrations.AlterField( + model_name="articleformat", + name="article", + field=modelcluster.fields.ParentalKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="article_format", + to="article.article", + ), + ), + migrations.AlterField( + model_name="articleformat", + name="format_name", + field=models.CharField( + blank=True, + choices=[ + ("crossref", "Crossref"), + ("pubmed", "PubMed"), + ("pmc", "PubMed PMC"), + ("doaj", "DOAJ"), + ], + max_length=20, + null=True, + verbose_name="Article Format", + ), + ), + ] diff --git a/article/models.py b/article/models.py index 0798d5f5..4ad38138 100755 --- a/article/models.py +++ b/article/models.py @@ -1,7 +1,7 @@ import os import sys from datetime import datetime - +from lxml import etree from django.core.files.base import ContentFile from django.db import IntegrityError, models from django.db.utils import DataError @@ -635,18 +635,29 @@ def article_directory_path(instance, filename): except AttributeError: return os.path.join(instance.article.pid_v3, instance.format_name, filename) +STATUS_EXPORT_FILE = [ + ("E", "Error occurred during export format creation"), + ("S", "Export format created successfully"), + ("A", "Export format available on external site"), +] -class ArticleFormat(CommonControlField): +TYPE_OF_FORMAT = [ + ("crossref", "Crossref"), + ("pubmed", "PubMed"), + ("pmc", "PubMed PMC"), + ("doaj", "DOAJ"), +] +class ArticleFormat(CommonControlField): article = ParentalKey( Article, null=True, blank=True, on_delete=models.SET_NULL, - related_name="format", + related_name="article_format", ) format_name = models.CharField( - _("Article Format"), max_length=20, null=True, blank=True + _("Article Format"), max_length=20, null=True, blank=True, choices=TYPE_OF_FORMAT ) version = models.PositiveIntegerField(null=True, blank=True) file = models.FileField( @@ -657,6 +668,12 @@ class ArticleFormat(CommonControlField): ) report = models.JSONField(null=True, blank=True) valid = models.BooleanField(default=None, null=True, blank=True) + status = models.CharField( + blank=True, + null=True, + max_length=1, + choices=STATUS_EXPORT_FILE + ) finger_print = models.CharField(max_length=64, null=True, blank=True) base_form_class = CoreAdminModelForm @@ -664,6 +681,7 @@ class ArticleFormat(CommonControlField): FieldPanel("file"), FieldPanel("format_name"), FieldPanel("version"), + FieldPanel("status"), FieldPanel("report"), ] @@ -712,7 +730,7 @@ def get(cls, article, format_name=None, version=None): @classmethod def create(cls, user, article, format_name=None, version=None): - if article or format_name or version: + if article and format_name or version: try: obj = cls() obj.article = article @@ -731,15 +749,12 @@ def create(cls, user, article, format_name=None, version=None): def create_or_update(cls, user, article, format_name=None, version=None): try: obj = cls.get(article, format_name=format_name, version=version) - obj.updated_by = user - obj.format_name = format_name or obj.format_name - obj.version = version or obj.version - obj.save() except cls.DoesNotExist: obj = cls.create(user, article, format_name, version) return obj def save_file(self, filename, content): + content = etree.tostring(content) finger_print = generate_finger_print(content) if finger_print != self.finger_print: try: @@ -750,33 +765,24 @@ def save_file(self, filename, content): self.finger_print = finger_print self.save() - @classmethod - def generate( - cls, - user, - article, - format_name, + def save_format_xml( + self, + format_xml, filename, - function_generate_format, + status, + report=None, indexed_check=False, - data=None, version=None, ): - if indexed_check and not article.is_indexed_at(format_name): - return + if indexed_check and not self.article.is_indexed_at(self.format_name): + return try: - version = version or 1 - obj = None - obj = cls.create_or_update(user, article, format_name, version) - xmltree = article.xmltree - if data is not None: - content = function_generate_format(xmltree, data=data) - else: - content = function_generate_format(xmltree) - obj.save_file(filename, content) - obj.report = None - obj.save() - return obj + if filename and len(format_xml): + self.save_file(filename=filename, content=format_xml) + self.version = version or 1 + self.report = report + self.status = status + self.save() except Exception as e: exc_type, exc_value, exc_traceback = sys.exc_info() unexpected_event = UnexpectedEvent.create( @@ -784,47 +790,9 @@ def generate( exc_traceback=exc_traceback, detail=dict( function="article.models.ArticleFormat.generate", - format_name=format_name, - article_pid_v3=article.pid_v3, - sps_pkg_name=article.sps_pkg_name, + format_name=self.format_name, + article_pid_v3=self.article.pid_v3, + sps_pkg_name=self.article.sps_pkg_name, ), ) - if obj: - obj.report = unexpected_event.data - obj.valid = False - obj.save() - @classmethod - def generate_formats(cls, user, article): - for doi in article.doi.all(): - if not doi.value: - break - try: - prefix = doi.value.split("/")[0] - crossref_data = CrossRefConfiguration.get_data(prefix) - cls.generate( - user, - article, - "crossref", - article.sps_pkg_name + ".xml", - crossref.pipeline_crossref, - data=crossref_data, - ) - except CrossRefConfiguration.DoesNotExist: - break - cls.generate( - user, - article, - "pubmed", - article.sps_pkg_name + ".xml", - pubmed.pipeline_pubmed, - indexed_check=False, - ) - cls.generate( - user, - article, - "pmc", - article.sps_pkg_name + ".xml", - pmc.pipeline_pmc, - indexed_check=False, - ) diff --git a/article/tasks.py b/article/tasks.py index 1fc6eaa0..cce70837 100644 --- a/article/tasks.py +++ b/article/tasks.py @@ -5,11 +5,14 @@ from django.db.models import Q, Count from django.contrib.auth import get_user_model from django.utils.translation import gettext as _ +from packtools.sps.formats import pubmed, pmc, crossref from article.models import Article, ArticleFormat from article.sources import xmlsps from article.sources.preprint import harvest_preprints from config import celery_app +from doi_manager.models import CrossRefConfiguration +from journal.models import Journal from pid_provider.models import PidProviderXML from pid_provider.provider import PidProvider from tracker.models import UnexpectedEvent @@ -51,11 +54,13 @@ def _items_to_load_article(from_date, force_update): if not from_date: # obtém a última atualização de Article try: - article = Article.objects.filter( - ~Q(valid=True) - ).order_by("-updated").first() + article = ( + Article.objects.filter(~Q(valid=True)).order_by("-updated").first() + ) if not article: - article = Article.objects.filter(valid=True).order_by("-updated").first() + article = ( + Article.objects.filter(valid=True).order_by("-updated").first() + ) if article: from_date = article.updated except Article.DoesNotExist: @@ -127,34 +132,54 @@ def load_preprint(self, user_id, oai_pmh_preprint_uri): harvest_preprints(oai_pmh_preprint_uri, user) +def get_function_format_xml(format_name): + dict_functions_formats = { + "pmc": pmc.pipeline_pmc, + "pubmed": pubmed.pipeline_pubmed, + "crossref": crossref.pipeline_crossref, + } + return dict_functions_formats.get(format_name) + + +def handler_formatting_error(article_format, message): + article_format.save_format_xml( + filename=None, format_xml=None, status="E", report={"exception_msg": message} + ) + + +def get_article_format(user, pid_v3, format_name): + try: + article = Article.objects.get(pid_v3=pid_v3) + except Article.DoesNotExist: + logging.info(f"Unable to convert article {pid_v3} to the specified format") + return + + try: + article_format = ArticleFormat.objects.get(article=article, format_name="pmc") + except ArticleFormat.DoesNotExist: + article_format = ArticleFormat.create_or_update( + user=user, article=article, format_name=format_name, version=1 + ) + return article_format + + @celery_app.task(bind=True) def task_convert_xml_to_other_formats_for_articles( - self, user_id=None, username=None, from_date=None, force_update=False + self, format_name, user_id=None, username=None, force_update=False ): - try: - user = _get_user(self.request, username, user_id) + journals = Journal.objects.filter(indexed_at__acronym=format_name) + articles = Article.objects.filter(journal__in=journals) - for item in Article.objects.filter(sps_pkg_name__isnull=False).iterator(): - logging.info(item.pid_v3) - try: - convert_xml_to_other_formats.apply_async( - kwargs={ - "user_id": user.id, - "username": user.username, - "item_id": item.id, - "force_update": force_update, - } - ) - except Exception as exception: - exc_type, exc_value, exc_traceback = sys.exc_info() - UnexpectedEvent.create( - exception=exception, - exc_traceback=exc_traceback, - detail={ - "task": "article.tasks.task_convert_xml_to_other_formats_for_articles", - "item": str(item), - }, - ) + if not force_update: + articles = articles.filter(article_format__isnull=True) + + try: + task_function_dict = { + "pubmed": convert_xml_to_pubmed_or_pmc_formats, + "pmc": convert_xml_to_pubmed_or_pmc_formats, + "crossref": convert_xml_to_crossref_format, + } + task_function = task_function_dict[format_name] except Exception as exception: exc_type, exc_value, exc_traceback = sys.exc_info() UnexpectedEvent.create( @@ -162,34 +187,85 @@ def task_convert_xml_to_other_formats_for_articles( exc_traceback=exc_traceback, detail={ "task": "article.tasks.task_convert_xml_to_other_formats_for_articles", + "item": str(article), }, ) + return + + for article in articles: + try: + task_function.apply_async( + user_id=user_id, + username=username, + format_name=format_name, + ) + except Exception as exception: + exc_type, exc_value, exc_traceback = sys.exc_info() + UnexpectedEvent.create( + exception=exception, + exc_traceback=exc_traceback, + detail={ + "task": "article.tasks.task_convert_xml_to_other_formats_for_articles", + "item": str(article), + }, + ) @celery_app.task(bind=True) -def convert_xml_to_other_formats( - self, user_id=None, username=None, item_id=None, force_update=None +def convert_xml_to_pubmed_or_pmc_formats( + self, pid_v3, format_name, user_id=None, username=None ): - user = _get_user(self.request, username, user_id) + user = _get_user(request=self.request, username=username, user_id=user_id) - try: - article = Article.objects.get(pk=item_id) - except Article.DoesNotExist: - logging.info(f"Not found {item_id}") + article_format = get_article_format( + pid_v3=pid_v3, format_name=format_name, user=user + ) + + function_format = get_function_format_xml(format_name=format_name) + + content = function_format(article_format.article.xmltree) + article_format.save_format_xml( + format_xml=content, + filename=article_format.article.sps_pkg_name + ".xml", + status="S", + ) + + +@celery_app.task(bind=True) +def convert_xml_to_crossref_format( + self, pid_v3, format_name, user_id=None, username=None +): + user = _get_user(request=self.request, username=username, user_id=user_id) + + article_format = get_article_format( + pid_v3=pid_v3, format_name=format_name, user=user + ) + + doi = article_format.article.doi.first() + if not doi: + handler_formatting_error( + article_format=article_format, + message=f"Unable to format because the article {pid_v3} has no DOI associated with it", + ) return - done = False + prefix = doi.value.split("/")[0] try: - article_format = ArticleFormat.objects.get(article=article) - done = True - except ArticleFormat.MultipleObjectsReturned: - done = True - except ArticleFormat.DoesNotExist: - done = False - logging.info(f"Done {done}") + data = CrossRefConfiguration.get_data(prefix) + except CrossRefConfiguration.DoesNotExist: + handler_formatting_error( + article_format=article_format, + message=f"Unable to convert article {pid_v3} to crossref format. CrossrefConfiguration missing", + ) + return - if not done or force_update: - ArticleFormat.generate_formats(user, article=article) + function_format = get_function_format_xml(format_name=format_name) + content = function_format(article_format.article.xmltree, data) + article_format.save_format_xml( + format_xml=content, + filename=article_format.article.sps_pkg_name + ".xml", + status="S", + ) @celery_app.task(bind=True) @@ -243,26 +319,32 @@ def article_complete_data( pass - @celery_app.task(bind=True) -def transfer_license_statements_fk_to_article_license(self, user_id=None, username=None): +def transfer_license_statements_fk_to_article_license( + self, user_id=None, username=None +): user = _get_user(self.request, username, user_id) articles_to_update = [] for instance in Article.objects.filter(article_license__isnull=True): new_license = None - if instance.license_statements.exists() and instance.license_statements.first().url: + if ( + instance.license_statements.exists() + and instance.license_statements.first().url + ): new_license = instance.license_statements.first().url elif instance.license and instance.license.license_type: new_license = instance.license.license_type - + if new_license: instance.article_license = new_license instance.updated_by = user articles_to_update.append(instance) if articles_to_update: - Article.objects.bulk_update(articles_to_update, ['article_license', 'updated_by']) + Article.objects.bulk_update( + articles_to_update, ["article_license", "updated_by"] + ) logging.info("The article_license of model Articles have been updated") @@ -270,15 +352,26 @@ def remove_duplicate_articles(pid_v3=None): ids_to_exclude = [] try: if pid_v3: - duplicates = Article.objects.filter(pid_v3=pid_v3).values("pid_v3").annotate(pid_v3_count=Count("pid_v3")).filter(pid_v3_count__gt=1) + duplicates = ( + Article.objects.filter(pid_v3=pid_v3) + .values("pid_v3") + .annotate(pid_v3_count=Count("pid_v3")) + .filter(pid_v3_count__gt=1) + ) else: - duplicates = Article.objects.values("pid_v3").annotate(pid_v3_count=Count("pid_v3")).filter(pid_v3_count__gt=1) + duplicates = ( + Article.objects.values("pid_v3") + .annotate(pid_v3_count=Count("pid_v3")) + .filter(pid_v3_count__gt=1) + ) for duplicate in duplicates: - article_ids = Article.objects.filter( - pid_v3=duplicate["pid_v3"] - ).order_by("created")[1:].values_list("id", flat=True) + article_ids = ( + Article.objects.filter(pid_v3=duplicate["pid_v3"]) + .order_by("created")[1:] + .values_list("id", flat=True) + ) ids_to_exclude.extend(article_ids) - + if ids_to_exclude: Article.objects.filter(id__in=ids_to_exclude).delete() except Exception as exception: @@ -291,7 +384,7 @@ def remove_duplicate_articles(pid_v3=None): }, ) + @celery_app.task(bind=True) def remove_duplicate_articles_task(self, user_id=None, username=None, pid_v3=None): remove_duplicate_articles(pid_v3) - diff --git a/article/tests.py b/article/tests.py index f97d8943..a18a83eb 100755 --- a/article/tests.py +++ b/article/tests.py @@ -1,37 +1,30 @@ +from lxml import etree from freezegun import freeze_time from django.test import TestCase -from django_test_migrations.migrator import Migrator +from django.core.files.uploadedfile import SimpleUploadedFile from datetime import datetime from django.utils.timezone import make_aware +from unittest.mock import patch, PropertyMock -from article.models import Article -from article.tasks import remove_duplicate_articles - - -class TestArticleMigration(TestCase): - def test_migration_0013_article_article_license(self): - migrator = Migrator(database='default') - old_state = migrator.apply_initial_migration(('article', '0012_alter_article_publisher')) - Article = old_state.apps.get_model('article', 'Article') - LicenseStatement = old_state.apps.get_model('core', 'LicenseStatement') - article = Article.objects.create() - license_statement = LicenseStatement.objects.create(url="https://www.teste.com.br") - article.license_statements.add(license_statement) - - new_state = migrator.apply_tested_migration(('article', '0013_article_article_license')) - - Article = new_state.apps.get_model('article', 'Article') - - article = Article.objects.first() - self.assertEqual(article.article_license, 'https://www.teste.com.br') - migrator.reset() +from article.models import Article, ArticleFormat +from article.tasks import ( + remove_duplicate_articles, + convert_xml_to_pubmed_or_pmc_formats, + convert_xml_to_crossref_format, +) +from core.users.models import User +from doi.models import DOI +from doi_manager.models import CrossRefConfiguration class RemoveDuplicateArticlesTest(TestCase): def create_article_at_time(self, dt, v3): @freeze_time(dt) def create_article(): - Article.objects.create(pid_v3=v3, created=make_aware(datetime.strptime(dt, "%Y-%m-%d"))) + Article.objects.create( + pid_v3=v3, created=make_aware(datetime.strptime(dt, "%Y-%m-%d")) + ) + create_article() def test_remove_duplicates_keeps_earliest_article(self): @@ -40,13 +33,17 @@ def test_remove_duplicates_keeps_earliest_article(self): self.create_article_at_time("2023-01-03", "pid1") remove_duplicate_articles() self.assertEqual(Article.objects.all().count(), 1) - self.assertEqual(Article.objects.all()[0].created, make_aware(datetime(2023, 1, 1))) + self.assertEqual( + Article.objects.all()[0].created, make_aware(datetime(2023, 1, 1)) + ) def test_no_removal_if_only_one_article(self): self.create_article_at_time("2023-01-01", "pid1") remove_duplicate_articles() self.assertEqual(Article.objects.all().count(), 1) - self.assertEqual(Article.objects.all()[0].created, make_aware(datetime(2023, 1, 1))) + self.assertEqual( + Article.objects.all()[0].created, make_aware(datetime(2023, 1, 1)) + ) def test_remove_duplicates_for_multiple_pids(self): self.create_article_at_time("2022-06-03", "pid2") @@ -56,5 +53,235 @@ def test_remove_duplicates_for_multiple_pids(self): remove_duplicate_articles() self.assertEqual(Article.objects.filter(pid_v3="pid2").count(), 1) self.assertEqual(Article.objects.filter(pid_v3="pid3").count(), 1) - self.assertEqual(Article.objects.get(pid_v3="pid2").created, make_aware(datetime(2022, 6, 3))) - self.assertEqual(Article.objects.get(pid_v3="pid3").created, make_aware(datetime(2022, 6, 14))) + self.assertEqual( + Article.objects.get(pid_v3="pid2").created, make_aware(datetime(2022, 6, 3)) + ) + self.assertEqual( + Article.objects.get(pid_v3="pid3").created, + make_aware(datetime(2022, 6, 14)), + ) + + +class ArticleFormatModelTest(TestCase): + def setUp(self): + self.user = User.objects.create( + name="admin", + ) + self.article = Article.objects.create( + pid_v3="P3swRmPHQfy37r9xRbLCw8G", + sps_pkg_name="0001-3714-rm-30-04-299", + ) + + self.test_file = SimpleUploadedFile( + "test_file.xml", + b"
Test
", + content_type="application/xml", + ) + self.test_file2 = SimpleUploadedFile( + "test_file2.xml", + b"
Test2
", + content_type="application/xml", + ) + self.article_format = ArticleFormat.objects.create( + article=self.article, + format_name="pmc", + version=1, + file=self.test_file, + valid=True, + status="S", + ) + + def verify_fields_model_article_format( + self, article_format, version, format_name=None, status=None, file=None + ): + self.assertEqual(article_format.article, self.article) + if format_name: + self.assertEqual(article_format.format_name, format_name) + self.assertEqual(article_format.version, version) + self.assertEqual(article_format.report, None) + if status: + self.assertEqual(article_format.status, status) + + def test_get_method(self): + article_format = ArticleFormat.get(self.article, format_name="pmc", version=1) + self.verify_fields_model_article_format( + article_format=article_format, format_name="pmc", version=1 + ) + + def test_create_classmethod(self): + article_format = ArticleFormat.create( + user=self.user, article=self.article, format_name="pubmed", version=1 + ) + self.verify_fields_model_article_format( + article_format=article_format, format_name="pubmed", version=1 + ) + + def test_get_method_raises_value_error(self): + with self.assertRaises(ValueError) as context: + ArticleFormat.get(self.article, format_name="pubmed") + + self.assertEqual( + str(context.exception), + "ArticleFormat.get requires article and format_name and version", + ) + + def test_create_or_update_classmethod(self): + article_format = ArticleFormat.create_or_update( + user=self.user, + article=self.article, + format_name="pmc", + version=1, + ) + self.verify_fields_model_article_format( + article_format=article_format, format_name="pmc", version=1 + ) + + def convert_and_compare_xml(self, article_format, filename, xml_content): + article_format.save_file(filename=filename, content=xml_content) + with article_format.file.open("rb") as f: + saved_content = f.read() + self.assertEqual( + etree.tostring(etree.fromstring(saved_content), encoding="utf-8"), + etree.tostring(xml_content, encoding="utf-8"), + ) + + def test_save_file_method(self): + filename = "0034-7094-rba-69-03-0227.xml" + article_format = ArticleFormat.get(self.article, format_name="pmc", version=1) + self.test_file.seek(0) + input_xml = etree.fromstring(self.test_file.read()) + self.convert_and_compare_xml( + article_format=article_format, filename=filename, xml_content=input_xml + ) + + def test_update_xml_in_save_file_method(self): + filename = "0034-7094-rba-69-03-0227.xml" + self.test_file.seek(0) + input_xml = etree.fromstring(self.test_file.read()) + + article_format = ArticleFormat.get(self.article, format_name="pmc", version=1) + self.convert_and_compare_xml( + article_format=article_format, filename=filename, xml_content=input_xml + ) + + self.test_file2.seek(0) + update_xml = etree.fromstring(self.test_file2.read()) + self.convert_and_compare_xml( + article_format=article_format, filename=filename, xml_content=update_xml + ) + + def test_save_format_xml_method(self): + article_format = ArticleFormat.get(self.article, format_name="pmc", version=1) + input_xml = etree.fromstring("
Test
") + filename = article_format.article.sps_pkg_name + ".xml" + article_format.save_format_xml( + format_xml=input_xml, filename=filename, status="S" + ) + with article_format.file.open("rb") as f: + saved_content = f.read() + self.assertEqual(saved_content, etree.tostring(input_xml, encoding="utf-8")) + self.verify_fields_model_article_format( + article_format=article_format, status="S", version=1 + ) + + +class TasksConvertXmlFormatsTest(TestCase): + def setUp(self): + self.doi = DOI.objects.create(value="10.1000.10/123456") + self.article = Article.objects.create( + pid_v3="P3swRmPHQfy37r9xRbLCw8G", + sps_pkg_name="0001-3714-rm-30-04-299", + ) + self.user = User.objects.create( + username="admin", + ) + + self.input_xml = etree.fromstring( + "
Original PMC
" + ) + self.modified_xml = etree.fromstring( + "
Modified PMC
" + ) + + def verify_article_format( + self, status, version, pid_v3=None, report=None, file_exists=True + ): + self.assertEqual(ArticleFormat.objects.count(), 1) + article_format = ArticleFormat.objects.first() + self.assertEqual(article_format.article.pid_v3, pid_v3 or self.article.pid_v3) + self.assertEqual(article_format.status, status) + self.assertEqual(article_format.version, version) + if report: + self.assertEqual(article_format.report, report) + if file_exists: + with article_format.file.open("rb") as f: + content = f.read() + self.assertEqual( + content, etree.tostring(self.modified_xml, encoding="utf-8") + ) + else: + self.assertFalse(article_format.file) + + @patch("article.models.Article.xmltree", new_callable=PropertyMock) + @patch("article.tasks.pmc.pipeline_pmc") + def test_convert_xml_to_pmc_formats(self, mock_pipeline_pmc, mock_property_xmltree): + mock_property_xmltree.return_value = self.input_xml + mock_pipeline_pmc.return_value = self.modified_xml + + convert_xml_to_pubmed_or_pmc_formats( + pid_v3=self.article.pid_v3, format_name="pmc", username="admin" + ) + mock_pipeline_pmc.assert_called_once_with(mock_property_xmltree.return_value) + self.verify_article_format(status="S", version=1) + + @patch("article.models.Article.xmltree", new_callable=PropertyMock) + @patch("article.tasks.pubmed.pipeline_pubmed") + def test_convert_xml_to_pubmed_formats( + self, mock_pipeline_pubmed, mock_property_xmltree + ): + mock_property_xmltree.return_value = self.input_xml + mock_pipeline_pubmed.return_value = self.modified_xml + + convert_xml_to_pubmed_or_pmc_formats( + pid_v3=self.article.pid_v3, format_name="pubmed", username="admin" + ) + mock_pipeline_pubmed.assert_called_once_with(mock_property_xmltree.return_value) + self.verify_article_format(status="S", version=1) + + @patch("doi_manager.models.CrossRefConfiguration.get_data", return_value=dict()) + @patch("article.models.Article.xmltree", new_callable=PropertyMock) + @patch("article.tasks.crossref.pipeline_crossref") + def test_convert_xml_to_crossref_formats( + self, mock_pipeline_crossref, mock_property_xmltree, mock_get_data + ): + self.article.doi.add(self.doi) + mock_property_xmltree.return_value = self.input_xml + mock_pipeline_crossref.return_value = self.modified_xml + + convert_xml_to_crossref_format( + pid_v3=self.article.pid_v3, format_name="crossref", username="admin" + ) + self.verify_article_format(status="S", version=1) + + def test_convert_xml_to_crossref_formats_without_doi(self): + convert_xml_to_crossref_format( + pid_v3=self.article.pid_v3, format_name="crossref", username="admin" + ) + expected_msg = { + "exception_msg": f"Unable to format because the article {self.article.pid_v3} has no DOI associated with it" + } + self.verify_article_format( + status="E", version=1, file_exists=False, report=expected_msg + ) + + @patch("doi_manager.models.CrossRefConfiguration.get_data") + def test_convert_xml_to_crossref_formats_missing_crossref_configuration( + self, mock_get_data + ): + self.article.doi.add(self.doi) + mock_get_data.side_effect = CrossRefConfiguration.DoesNotExist + convert_xml_to_crossref_format( + pid_v3=self.article.pid_v3, format_name="crossref", username="admin" + ) + expected_prefix = "10.1000.10" + mock_get_data.assert_called_once_with(expected_prefix) diff --git a/doi_manager/models.py b/doi_manager/models.py index dd5c88a4..1f1d982e 100644 --- a/doi_manager/models.py +++ b/doi_manager/models.py @@ -30,7 +30,5 @@ def data(self): @classmethod def get_data(cls, prefix): - try: - return cls.objects.get(prefix=prefix).data - except cls.DoesNotExist: - return cls().data + return cls.objects.get(prefix=prefix).data +