diff --git a/article/migrations/0014_articleformat_status_alter_articleformat_article_and_more.py b/article/migrations/0014_articleformat_status_alter_articleformat_article_and_more.py
new file mode 100644
index 00000000..ad3705fd
--- /dev/null
+++ b/article/migrations/0014_articleformat_status_alter_articleformat_article_and_more.py
@@ -0,0 +1,56 @@
+# Generated by Django 5.0.3 on 2024-08-23 18:12
+
+import django.db.models.deletion
+import modelcluster.fields
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    # Schema migration for ArticleFormat: adds the `status` field and
+    # redefines the `article` and `format_name` fields (see operations below).
+
+ dependencies = [
+ ("article", "0013_article_article_license"),
+ ]
+
+ operations = [
+        # New nullable one-character `status` column with choices E / S / A.
+ migrations.AddField(
+ model_name="articleformat",
+ name="status",
+ field=models.CharField(
+ blank=True,
+ choices=[
+ ("E", "Error occurred during export format creation"),
+ ("S", "Export format created successfully"),
+ ("A", "Export format available on external site"),
+ ],
+ max_length=1,
+ null=True,
+ ),
+ ),
+        # `article` is declared as a nullable ParentalKey (SET_NULL) whose
+        # reverse accessor on Article is "article_format".
+ migrations.AlterField(
+ model_name="articleformat",
+ name="article",
+ field=modelcluster.fields.ParentalKey(
+ blank=True,
+ null=True,
+ on_delete=django.db.models.deletion.SET_NULL,
+ related_name="article_format",
+ to="article.article",
+ ),
+ ),
+        # `format_name` is declared with an explicit list of choices.
+ migrations.AlterField(
+ model_name="articleformat",
+ name="format_name",
+ field=models.CharField(
+ blank=True,
+ choices=[
+ ("crossref", "Crossref"),
+ ("pubmed", "PubMed"),
+ ("pmc", "PubMed PMC"),
+ ("doaj", "DOAJ"),
+ ],
+ max_length=20,
+ null=True,
+ verbose_name="Article Format",
+ ),
+ ),
+ ]
diff --git a/article/models.py b/article/models.py
index 0798d5f5..4ad38138 100755
--- a/article/models.py
+++ b/article/models.py
@@ -1,7 +1,7 @@
import os
import sys
from datetime import datetime
-
+from lxml import etree
from django.core.files.base import ContentFile
from django.db import IntegrityError, models
from django.db.utils import DataError
@@ -635,18 +635,29 @@ def article_directory_path(instance, filename):
except AttributeError:
return os.path.join(instance.article.pid_v3, instance.format_name, filename)
+# Choices for ArticleFormat.status: one-letter stored code paired with its label.
+STATUS_EXPORT_FILE = [
+ ("E", "Error occurred during export format creation"),
+ ("S", "Export format created successfully"),
+ ("A", "Export format available on external site"),
+]
-class ArticleFormat(CommonControlField):
+# Choices for ArticleFormat.format_name: stored key paired with its display label.
+TYPE_OF_FORMAT = [
+ ("crossref", "Crossref"),
+ ("pubmed", "PubMed"),
+ ("pmc", "PubMed PMC"),
+ ("doaj", "DOAJ"),
+]
+class ArticleFormat(CommonControlField):
article = ParentalKey(
Article,
null=True,
blank=True,
on_delete=models.SET_NULL,
- related_name="format",
+ related_name="article_format",
)
format_name = models.CharField(
- _("Article Format"), max_length=20, null=True, blank=True
+ _("Article Format"), max_length=20, null=True, blank=True, choices=TYPE_OF_FORMAT
)
version = models.PositiveIntegerField(null=True, blank=True)
file = models.FileField(
@@ -657,6 +668,12 @@ class ArticleFormat(CommonControlField):
)
report = models.JSONField(null=True, blank=True)
valid = models.BooleanField(default=None, null=True, blank=True)
+ status = models.CharField(
+ blank=True,
+ null=True,
+ max_length=1,
+ choices=STATUS_EXPORT_FILE
+ )
finger_print = models.CharField(max_length=64, null=True, blank=True)
base_form_class = CoreAdminModelForm
@@ -664,6 +681,7 @@ class ArticleFormat(CommonControlField):
FieldPanel("file"),
FieldPanel("format_name"),
FieldPanel("version"),
+ FieldPanel("status"),
FieldPanel("report"),
]
@@ -712,7 +730,7 @@ def get(cls, article, format_name=None, version=None):
@classmethod
def create(cls, user, article, format_name=None, version=None):
- if article or format_name or version:
+ if article and format_name or version:
try:
obj = cls()
obj.article = article
@@ -731,15 +749,12 @@ def create(cls, user, article, format_name=None, version=None):
def create_or_update(cls, user, article, format_name=None, version=None):
+    # Return the ArticleFormat that cls.get finds for (article, format_name,
+    # version); when cls.DoesNotExist is raised, build one with cls.create.
+    # An existing record is returned untouched: the updated_by / format_name /
+    # version rewrite and the save() that used to follow the lookup are removed.
try:
obj = cls.get(article, format_name=format_name, version=version)
- obj.updated_by = user
- obj.format_name = format_name or obj.format_name
- obj.version = version or obj.version
- obj.save()
except cls.DoesNotExist:
obj = cls.create(user, article, format_name, version)
return obj
def save_file(self, filename, content):
+ content = etree.tostring(content)
finger_print = generate_finger_print(content)
if finger_print != self.finger_print:
try:
@@ -750,33 +765,24 @@ def save_file(self, filename, content):
self.finger_print = finger_print
self.save()
- @classmethod
- def generate(
- cls,
- user,
- article,
- format_name,
+ def save_format_xml(
+ self,
+ format_xml,
filename,
- function_generate_format,
+ status,
+ report=None,
indexed_check=False,
- data=None,
version=None,
):
- if indexed_check and not article.is_indexed_at(format_name):
- return
+ if indexed_check and not self.article.is_indexed_at(self.format_name):
+ return
try:
- version = version or 1
- obj = None
- obj = cls.create_or_update(user, article, format_name, version)
- xmltree = article.xmltree
- if data is not None:
- content = function_generate_format(xmltree, data=data)
- else:
- content = function_generate_format(xmltree)
- obj.save_file(filename, content)
- obj.report = None
- obj.save()
- return obj
+ if filename and len(format_xml):
+ self.save_file(filename=filename, content=format_xml)
+ self.version = version or 1
+ self.report = report
+ self.status = status
+ self.save()
except Exception as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
unexpected_event = UnexpectedEvent.create(
@@ -784,47 +790,9 @@ def generate(
exc_traceback=exc_traceback,
detail=dict(
function="article.models.ArticleFormat.generate",
- format_name=format_name,
- article_pid_v3=article.pid_v3,
- sps_pkg_name=article.sps_pkg_name,
+ format_name=self.format_name,
+ article_pid_v3=self.article.pid_v3,
+ sps_pkg_name=self.article.sps_pkg_name,
),
)
- if obj:
- obj.report = unexpected_event.data
- obj.valid = False
- obj.save()
- @classmethod
- def generate_formats(cls, user, article):
- for doi in article.doi.all():
- if not doi.value:
- break
- try:
- prefix = doi.value.split("/")[0]
- crossref_data = CrossRefConfiguration.get_data(prefix)
- cls.generate(
- user,
- article,
- "crossref",
- article.sps_pkg_name + ".xml",
- crossref.pipeline_crossref,
- data=crossref_data,
- )
- except CrossRefConfiguration.DoesNotExist:
- break
- cls.generate(
- user,
- article,
- "pubmed",
- article.sps_pkg_name + ".xml",
- pubmed.pipeline_pubmed,
- indexed_check=False,
- )
- cls.generate(
- user,
- article,
- "pmc",
- article.sps_pkg_name + ".xml",
- pmc.pipeline_pmc,
- indexed_check=False,
- )
diff --git a/article/tasks.py b/article/tasks.py
index 1fc6eaa0..cce70837 100644
--- a/article/tasks.py
+++ b/article/tasks.py
@@ -5,11 +5,14 @@
from django.db.models import Q, Count
from django.contrib.auth import get_user_model
from django.utils.translation import gettext as _
+from packtools.sps.formats import pubmed, pmc, crossref
from article.models import Article, ArticleFormat
from article.sources import xmlsps
from article.sources.preprint import harvest_preprints
from config import celery_app
+from doi_manager.models import CrossRefConfiguration
+from journal.models import Journal
from pid_provider.models import PidProviderXML
from pid_provider.provider import PidProvider
from tracker.models import UnexpectedEvent
@@ -51,11 +54,13 @@ def _items_to_load_article(from_date, force_update):
if not from_date:
# obtém a última atualização de Article
try:
- article = Article.objects.filter(
- ~Q(valid=True)
- ).order_by("-updated").first()
+ article = (
+ Article.objects.filter(~Q(valid=True)).order_by("-updated").first()
+ )
if not article:
- article = Article.objects.filter(valid=True).order_by("-updated").first()
+ article = (
+ Article.objects.filter(valid=True).order_by("-updated").first()
+ )
if article:
from_date = article.updated
except Article.DoesNotExist:
@@ -127,34 +132,54 @@ def load_preprint(self, user_id, oai_pmh_preprint_uri):
harvest_preprints(oai_pmh_preprint_uri, user)
+def get_function_format_xml(format_name):
+    """Return the conversion pipeline registered for ``format_name``.
+
+    The known names are "pmc", "pubmed" and "crossref"; any other name
+    yields None.
+    """
+    pipelines_by_format = {
+        "crossref": crossref.pipeline_crossref,
+        "pubmed": pubmed.pipeline_pubmed,
+        "pmc": pmc.pipeline_pmc,
+    }
+    if format_name in pipelines_by_format:
+        return pipelines_by_format[format_name]
+    return None
+
+
+def handler_formatting_error(article_format, message):
+    """Record a failed export on ``article_format``.
+
+    Calls ``save_format_xml`` with no filename and no XML, status "E", and a
+    report that stores ``message`` under the "exception_msg" key.
+    """
+    error_report = {"exception_msg": message}
+    article_format.save_format_xml(
+        format_xml=None,
+        filename=None,
+        status="E",
+        report=error_report,
+    )
+
+
+def get_article_format(user, pid_v3, format_name):
+    """Return the ArticleFormat of ``format_name`` for the article ``pid_v3``.
+
+    Returns None, after logging, when no Article has that ``pid_v3``. When the
+    article has no ArticleFormat for ``format_name`` yet, one is obtained
+    through ``ArticleFormat.create_or_update`` with version 1.
+    """
+    try:
+        article = Article.objects.get(pid_v3=pid_v3)
+    except Article.DoesNotExist:
+        logging.info(f"Unable to convert article {pid_v3} to the specified format")
+        return None
+
+    try:
+        # Match on the requested format so that each export format gets its
+        # own record; a fixed "pmc" here would hand the PMC record to the
+        # pubmed and crossref conversions.
+        return ArticleFormat.objects.get(article=article, format_name=format_name)
+    except ArticleFormat.DoesNotExist:
+        return ArticleFormat.create_or_update(
+            user=user, article=article, format_name=format_name, version=1
+        )
+
+
@celery_app.task(bind=True)
def task_convert_xml_to_other_formats_for_articles(
- self, user_id=None, username=None, from_date=None, force_update=False
+    self, format_name, user_id=None, username=None, force_update=False
):
+    """Queue one conversion task per article of the journals indexed at ``format_name``.
+
+    Args:
+        format_name: "pubmed", "pmc" or "crossref"; it selects both the journals
+            (``indexed_at__acronym``) and the per-article task to enqueue.
+        user_id, username: forwarded unchanged to the per-article task.
+        force_update: when false, articles that already have an ArticleFormat
+            for ``format_name`` are skipped.
+
+    An unsupported ``format_name`` or a failure to enqueue an article is
+    recorded as an UnexpectedEvent instead of being raised.
+    """
- try:
- user = _get_user(self.request, username, user_id)
+    journals = Journal.objects.filter(indexed_at__acronym=format_name)
+    articles = Article.objects.filter(journal__in=journals)
- for item in Article.objects.filter(sps_pkg_name__isnull=False).iterator():
- logging.info(item.pid_v3)
- try:
- convert_xml_to_other_formats.apply_async(
- kwargs={
- "user_id": user.id,
- "username": user.username,
- "item_id": item.id,
- "force_update": force_update,
- }
- )
- except Exception as exception:
- exc_type, exc_value, exc_traceback = sys.exc_info()
- UnexpectedEvent.create(
- exception=exception,
- exc_traceback=exc_traceback,
- detail={
- "task": "article.tasks.task_convert_xml_to_other_formats_for_articles",
- "item": str(item),
- },
- )
+    if not force_update:
+        # Skip only the articles already exported to this format; an article
+        # that holds other formats must still be converted to this one.
+        articles = articles.exclude(article_format__format_name=format_name)
+
+    try:
+        task_function_dict = {
+            "pubmed": convert_xml_to_pubmed_or_pmc_formats,
+            "pmc": convert_xml_to_pubmed_or_pmc_formats,
+            "crossref": convert_xml_to_crossref_format,
+        }
+        task_function = task_function_dict[format_name]
except Exception as exception:
exc_type, exc_value, exc_traceback = sys.exc_info()
UnexpectedEvent.create(
@@ -162,34 +187,85 @@ def task_convert_xml_to_other_formats_for_articles(
exc_traceback=exc_traceback,
detail={
"task": "article.tasks.task_convert_xml_to_other_formats_for_articles",
+                # No article is bound at this point, so report the format that
+                # could not be dispatched.
+                "format_name": str(format_name),
},
)
+        return
+
+    for article in articles:
+        try:
+            # Task arguments must travel in ``kwargs``: keywords passed
+            # directly to apply_async are execution options, not task
+            # arguments. ``pid_v3`` is required by both conversion tasks.
+            task_function.apply_async(
+                kwargs={
+                    "pid_v3": article.pid_v3,
+                    "format_name": format_name,
+                    "user_id": user_id,
+                    "username": username,
+                }
+            )
+        except Exception as exception:
+            exc_type, exc_value, exc_traceback = sys.exc_info()
+            UnexpectedEvent.create(
+                exception=exception,
+                exc_traceback=exc_traceback,
+                detail={
+                    "task": "article.tasks.task_convert_xml_to_other_formats_for_articles",
+                    "item": str(article),
+                },
+            )
@celery_app.task(bind=True)
-def convert_xml_to_other_formats(
- self, user_id=None, username=None, item_id=None, force_update=None
+def convert_xml_to_pubmed_or_pmc_formats(
+    self, pid_v3, format_name, user_id=None, username=None
):
+    """Build the ``format_name`` XML ("pubmed" or "pmc") for one article and store it.
+
+    The pipeline returned by ``get_function_format_xml`` is applied to the
+    article's ``xmltree`` and the result is saved with status "S" under
+    ``<sps_pkg_name>.xml``. Nothing is done when no Article matches ``pid_v3``.
+    """
- user = _get_user(self.request, username, user_id)
+    user = _get_user(request=self.request, username=username, user_id=user_id)
- try:
- article = Article.objects.get(pk=item_id)
- except Article.DoesNotExist:
- logging.info(f"Not found {item_id}")
+    article_format = get_article_format(
+        pid_v3=pid_v3, format_name=format_name, user=user
+    )
+    if article_format is None:
+        # get_article_format returns None (and logs) when the article is
+        # missing; there is nothing to convert or to attach a report to.
+        return
+
+    function_format = get_function_format_xml(format_name=format_name)
+
+    content = function_format(article_format.article.xmltree)
+    article_format.save_format_xml(
+        format_xml=content,
+        filename=article_format.article.sps_pkg_name + ".xml",
+        status="S",
+    )
+
+
+@celery_app.task(bind=True)
+def convert_xml_to_crossref_format(
+    self, pid_v3, format_name, user_id=None, username=None
+):
+    """Build the Crossref XML for one article and store it.
+
+    The Crossref data is looked up by the prefix of the article's first DOI.
+    A missing DOI (or an empty DOI value) and a missing CrossRefConfiguration
+    are recorded on the ArticleFormat with status "E". Nothing is done when no
+    Article matches ``pid_v3``.
+    """
+    user = _get_user(request=self.request, username=username, user_id=user_id)
+
+    article_format = get_article_format(
+        pid_v3=pid_v3, format_name=format_name, user=user
+    )
+    if article_format is None:
+        # get_article_format returns None (and logs) when the article is
+        # missing; there is nothing to convert or to attach a report to.
+        return
+
+    doi = article_format.article.doi.first()
+    # A DOI row with an empty value cannot provide a prefix either.
+    if not doi or not doi.value:
+        handler_formatting_error(
+            article_format=article_format,
+            message=f"Unable to format because the article {pid_v3} has no DOI associated with it",
+        )
return
- done = False
+    prefix = doi.value.split("/")[0]
try:
- article_format = ArticleFormat.objects.get(article=article)
- done = True
- except ArticleFormat.MultipleObjectsReturned:
- done = True
- except ArticleFormat.DoesNotExist:
- done = False
- logging.info(f"Done {done}")
+        data = CrossRefConfiguration.get_data(prefix)
+    except CrossRefConfiguration.DoesNotExist:
+        handler_formatting_error(
+            article_format=article_format,
+            message=f"Unable to convert article {pid_v3} to crossref format. CrossrefConfiguration missing",
+        )
+        return
- if not done or force_update:
- ArticleFormat.generate_formats(user, article=article)
+    function_format = get_function_format_xml(format_name=format_name)
+    content = function_format(article_format.article.xmltree, data)
+    article_format.save_format_xml(
+        format_xml=content,
+        filename=article_format.article.sps_pkg_name + ".xml",
+        status="S",
+    )
@celery_app.task(bind=True)
@@ -243,26 +319,32 @@ def article_complete_data(
pass
-
@celery_app.task(bind=True)
-def transfer_license_statements_fk_to_article_license(self, user_id=None, username=None):
+def transfer_license_statements_fk_to_article_license(
+ self, user_id=None, username=None
+):
user = _get_user(self.request, username, user_id)
articles_to_update = []
for instance in Article.objects.filter(article_license__isnull=True):
new_license = None
- if instance.license_statements.exists() and instance.license_statements.first().url:
+ if (
+ instance.license_statements.exists()
+ and instance.license_statements.first().url
+ ):
new_license = instance.license_statements.first().url
elif instance.license and instance.license.license_type:
new_license = instance.license.license_type
-
+
if new_license:
instance.article_license = new_license
instance.updated_by = user
articles_to_update.append(instance)
if articles_to_update:
- Article.objects.bulk_update(articles_to_update, ['article_license', 'updated_by'])
+ Article.objects.bulk_update(
+ articles_to_update, ["article_license", "updated_by"]
+ )
logging.info("The article_license of model Articles have been updated")
@@ -270,15 +352,26 @@ def remove_duplicate_articles(pid_v3=None):
ids_to_exclude = []
try:
if pid_v3:
- duplicates = Article.objects.filter(pid_v3=pid_v3).values("pid_v3").annotate(pid_v3_count=Count("pid_v3")).filter(pid_v3_count__gt=1)
+ duplicates = (
+ Article.objects.filter(pid_v3=pid_v3)
+ .values("pid_v3")
+ .annotate(pid_v3_count=Count("pid_v3"))
+ .filter(pid_v3_count__gt=1)
+ )
else:
- duplicates = Article.objects.values("pid_v3").annotate(pid_v3_count=Count("pid_v3")).filter(pid_v3_count__gt=1)
+ duplicates = (
+ Article.objects.values("pid_v3")
+ .annotate(pid_v3_count=Count("pid_v3"))
+ .filter(pid_v3_count__gt=1)
+ )
for duplicate in duplicates:
- article_ids = Article.objects.filter(
- pid_v3=duplicate["pid_v3"]
- ).order_by("created")[1:].values_list("id", flat=True)
+ article_ids = (
+ Article.objects.filter(pid_v3=duplicate["pid_v3"])
+ .order_by("created")[1:]
+ .values_list("id", flat=True)
+ )
ids_to_exclude.extend(article_ids)
-
+
if ids_to_exclude:
Article.objects.filter(id__in=ids_to_exclude).delete()
except Exception as exception:
@@ -291,7 +384,7 @@ def remove_duplicate_articles(pid_v3=None):
},
)
+
@celery_app.task(bind=True)
def remove_duplicate_articles_task(self, user_id=None, username=None, pid_v3=None):
remove_duplicate_articles(pid_v3)
-
diff --git a/article/tests.py b/article/tests.py
index f97d8943..a18a83eb 100755
--- a/article/tests.py
+++ b/article/tests.py
@@ -1,37 +1,30 @@
+from lxml import etree
from freezegun import freeze_time
from django.test import TestCase
-from django_test_migrations.migrator import Migrator
+from django.core.files.uploadedfile import SimpleUploadedFile
from datetime import datetime
from django.utils.timezone import make_aware
+from unittest.mock import patch, PropertyMock
-from article.models import Article
-from article.tasks import remove_duplicate_articles
-
-
-class TestArticleMigration(TestCase):
- def test_migration_0013_article_article_license(self):
- migrator = Migrator(database='default')
- old_state = migrator.apply_initial_migration(('article', '0012_alter_article_publisher'))
- Article = old_state.apps.get_model('article', 'Article')
- LicenseStatement = old_state.apps.get_model('core', 'LicenseStatement')
- article = Article.objects.create()
- license_statement = LicenseStatement.objects.create(url="https://www.teste.com.br")
- article.license_statements.add(license_statement)
-
- new_state = migrator.apply_tested_migration(('article', '0013_article_article_license'))
-
- Article = new_state.apps.get_model('article', 'Article')
-
- article = Article.objects.first()
- self.assertEqual(article.article_license, 'https://www.teste.com.br')
- migrator.reset()
+from article.models import Article, ArticleFormat
+from article.tasks import (
+ remove_duplicate_articles,
+ convert_xml_to_pubmed_or_pmc_formats,
+ convert_xml_to_crossref_format,
+)
+from core.users.models import User
+from doi.models import DOI
+from doi_manager.models import CrossRefConfiguration
class RemoveDuplicateArticlesTest(TestCase):
def create_article_at_time(self, dt, v3):
@freeze_time(dt)
def create_article():
- Article.objects.create(pid_v3=v3, created=make_aware(datetime.strptime(dt, "%Y-%m-%d")))
+ Article.objects.create(
+ pid_v3=v3, created=make_aware(datetime.strptime(dt, "%Y-%m-%d"))
+ )
+
create_article()
def test_remove_duplicates_keeps_earliest_article(self):
@@ -40,13 +33,17 @@ def test_remove_duplicates_keeps_earliest_article(self):
self.create_article_at_time("2023-01-03", "pid1")
remove_duplicate_articles()
self.assertEqual(Article.objects.all().count(), 1)
- self.assertEqual(Article.objects.all()[0].created, make_aware(datetime(2023, 1, 1)))
+ self.assertEqual(
+ Article.objects.all()[0].created, make_aware(datetime(2023, 1, 1))
+ )
def test_no_removal_if_only_one_article(self):
self.create_article_at_time("2023-01-01", "pid1")
remove_duplicate_articles()
self.assertEqual(Article.objects.all().count(), 1)
- self.assertEqual(Article.objects.all()[0].created, make_aware(datetime(2023, 1, 1)))
+ self.assertEqual(
+ Article.objects.all()[0].created, make_aware(datetime(2023, 1, 1))
+ )
def test_remove_duplicates_for_multiple_pids(self):
self.create_article_at_time("2022-06-03", "pid2")
@@ -56,5 +53,235 @@ def test_remove_duplicates_for_multiple_pids(self):
remove_duplicate_articles()
self.assertEqual(Article.objects.filter(pid_v3="pid2").count(), 1)
self.assertEqual(Article.objects.filter(pid_v3="pid3").count(), 1)
- self.assertEqual(Article.objects.get(pid_v3="pid2").created, make_aware(datetime(2022, 6, 3)))
- self.assertEqual(Article.objects.get(pid_v3="pid3").created, make_aware(datetime(2022, 6, 14)))
+ self.assertEqual(
+ Article.objects.get(pid_v3="pid2").created, make_aware(datetime(2022, 6, 3))
+ )
+ self.assertEqual(
+ Article.objects.get(pid_v3="pid3").created,
+ make_aware(datetime(2022, 6, 14)),
+ )
+
+
+class ArticleFormatModelTest(TestCase):
+    # Exercises ArticleFormat.get / create / create_or_update / save_file /
+    # save_format_xml against an article that already owns a "pmc" format
+    # (version 1, status "S") created in setUp.
+ def setUp(self):
+ self.user = User.objects.create(
+ name="admin",
+ )
+ self.article = Article.objects.create(
+ pid_v3="P3swRmPHQfy37r9xRbLCw8G",
+ sps_pkg_name="0001-3714-rm-30-04-299",
+ )
+
+        # NOTE(review): b"Test" / b"Test2" are not well-formed XML as written,
+        # yet the save_file tests below parse these payloads with
+        # etree.fromstring - confirm the intended fixture content.
+ self.test_file = SimpleUploadedFile(
+ "test_file.xml",
+ b"Test",
+ content_type="application/xml",
+ )
+ self.test_file2 = SimpleUploadedFile(
+ "test_file2.xml",
+ b"Test2",
+ content_type="application/xml",
+ )
+ self.article_format = ArticleFormat.objects.create(
+ article=self.article,
+ format_name="pmc",
+ version=1,
+ file=self.test_file,
+ valid=True,
+ status="S",
+ )
+
+    # Shared assertions: article and version are always compared, report must
+    # be None, format_name and status only when given. `file` is accepted but
+    # never checked.
+ def verify_fields_model_article_format(
+ self, article_format, version, format_name=None, status=None, file=None
+ ):
+ self.assertEqual(article_format.article, self.article)
+ if format_name:
+ self.assertEqual(article_format.format_name, format_name)
+ self.assertEqual(article_format.version, version)
+ self.assertEqual(article_format.report, None)
+ if status:
+ self.assertEqual(article_format.status, status)
+
+ def test_get_method(self):
+ article_format = ArticleFormat.get(self.article, format_name="pmc", version=1)
+ self.verify_fields_model_article_format(
+ article_format=article_format, format_name="pmc", version=1
+ )
+
+ def test_create_classmethod(self):
+ article_format = ArticleFormat.create(
+ user=self.user, article=self.article, format_name="pubmed", version=1
+ )
+ self.verify_fields_model_article_format(
+ article_format=article_format, format_name="pubmed", version=1
+ )
+
+    # Calling get without a version is expected to raise ValueError with the
+    # exact message asserted below.
+ def test_get_method_raises_value_error(self):
+ with self.assertRaises(ValueError) as context:
+ ArticleFormat.get(self.article, format_name="pubmed")
+
+ self.assertEqual(
+ str(context.exception),
+ "ArticleFormat.get requires article and format_name and version",
+ )
+
+ def test_create_or_update_classmethod(self):
+ article_format = ArticleFormat.create_or_update(
+ user=self.user,
+ article=self.article,
+ format_name="pmc",
+ version=1,
+ )
+ self.verify_fields_model_article_format(
+ article_format=article_format, format_name="pmc", version=1
+ )
+
+    # Helper: saves `xml_content` through save_file, reads the stored file
+    # back and compares both sides after serialising them with lxml.
+ def convert_and_compare_xml(self, article_format, filename, xml_content):
+ article_format.save_file(filename=filename, content=xml_content)
+ with article_format.file.open("rb") as f:
+ saved_content = f.read()
+ self.assertEqual(
+ etree.tostring(etree.fromstring(saved_content), encoding="utf-8"),
+ etree.tostring(xml_content, encoding="utf-8"),
+ )
+
+ def test_save_file_method(self):
+ filename = "0034-7094-rba-69-03-0227.xml"
+ article_format = ArticleFormat.get(self.article, format_name="pmc", version=1)
+ self.test_file.seek(0)
+ input_xml = etree.fromstring(self.test_file.read())
+ self.convert_and_compare_xml(
+ article_format=article_format, filename=filename, xml_content=input_xml
+ )
+
+    # Saves a first payload and then a second one under the same filename,
+    # checking the stored content after each save.
+ def test_update_xml_in_save_file_method(self):
+ filename = "0034-7094-rba-69-03-0227.xml"
+ self.test_file.seek(0)
+ input_xml = etree.fromstring(self.test_file.read())
+
+ article_format = ArticleFormat.get(self.article, format_name="pmc", version=1)
+ self.convert_and_compare_xml(
+ article_format=article_format, filename=filename, xml_content=input_xml
+ )
+
+ self.test_file2.seek(0)
+ update_xml = etree.fromstring(self.test_file2.read())
+ self.convert_and_compare_xml(
+ article_format=article_format, filename=filename, xml_content=update_xml
+ )
+
+ def test_save_format_xml_method(self):
+ article_format = ArticleFormat.get(self.article, format_name="pmc", version=1)
+        # NOTE(review): "Test" is not well-formed XML as written - confirm
+        # the literal that is meant to be parsed here.
+ input_xml = etree.fromstring("Test")
+ filename = article_format.article.sps_pkg_name + ".xml"
+ article_format.save_format_xml(
+ format_xml=input_xml, filename=filename, status="S"
+ )
+ with article_format.file.open("rb") as f:
+ saved_content = f.read()
+ self.assertEqual(saved_content, etree.tostring(input_xml, encoding="utf-8"))
+ self.verify_fields_model_article_format(
+ article_format=article_format, status="S", version=1
+ )
+
+
+class TasksConvertXmlFormatsTest(TestCase):
+    # Calls the conversion tasks directly (no apply_async) with the format
+    # pipelines and Article.xmltree patched, then inspects the single
+    # ArticleFormat row that results.
+ def setUp(self):
+ self.doi = DOI.objects.create(value="10.1000.10/123456")
+ self.article = Article.objects.create(
+ pid_v3="P3swRmPHQfy37r9xRbLCw8G",
+ sps_pkg_name="0001-3714-rm-30-04-299",
+ )
+ self.user = User.objects.create(
+ username="admin",
+ )
+
+        # NOTE(review): "Original PMC" / "Modified PMC" are not well-formed
+        # XML as written - confirm the literals meant to be parsed here.
+ self.input_xml = etree.fromstring(
+ "Original PMC"
+ )
+ self.modified_xml = etree.fromstring(
+ "Modified PMC"
+ )
+
+    # Shared assertions: exactly one ArticleFormat exists and its pid_v3,
+    # status and version match. `report` is compared only when given. With
+    # file_exists the stored bytes must equal the serialised modified_xml,
+    # otherwise the file field must be empty.
+ def verify_article_format(
+ self, status, version, pid_v3=None, report=None, file_exists=True
+ ):
+ self.assertEqual(ArticleFormat.objects.count(), 1)
+ article_format = ArticleFormat.objects.first()
+ self.assertEqual(article_format.article.pid_v3, pid_v3 or self.article.pid_v3)
+ self.assertEqual(article_format.status, status)
+ self.assertEqual(article_format.version, version)
+ if report:
+ self.assertEqual(article_format.report, report)
+ if file_exists:
+ with article_format.file.open("rb") as f:
+ content = f.read()
+ self.assertEqual(
+ content, etree.tostring(self.modified_xml, encoding="utf-8")
+ )
+ else:
+ self.assertFalse(article_format.file)
+
+ @patch("article.models.Article.xmltree", new_callable=PropertyMock)
+ @patch("article.tasks.pmc.pipeline_pmc")
+ def test_convert_xml_to_pmc_formats(self, mock_pipeline_pmc, mock_property_xmltree):
+ mock_property_xmltree.return_value = self.input_xml
+ mock_pipeline_pmc.return_value = self.modified_xml
+
+ convert_xml_to_pubmed_or_pmc_formats(
+ pid_v3=self.article.pid_v3, format_name="pmc", username="admin"
+ )
+ mock_pipeline_pmc.assert_called_once_with(mock_property_xmltree.return_value)
+ self.verify_article_format(status="S", version=1)
+
+ @patch("article.models.Article.xmltree", new_callable=PropertyMock)
+ @patch("article.tasks.pubmed.pipeline_pubmed")
+ def test_convert_xml_to_pubmed_formats(
+ self, mock_pipeline_pubmed, mock_property_xmltree
+ ):
+ mock_property_xmltree.return_value = self.input_xml
+ mock_pipeline_pubmed.return_value = self.modified_xml
+
+ convert_xml_to_pubmed_or_pmc_formats(
+ pid_v3=self.article.pid_v3, format_name="pubmed", username="admin"
+ )
+ mock_pipeline_pubmed.assert_called_once_with(mock_property_xmltree.return_value)
+ self.verify_article_format(status="S", version=1)
+
+ @patch("doi_manager.models.CrossRefConfiguration.get_data", return_value=dict())
+ @patch("article.models.Article.xmltree", new_callable=PropertyMock)
+ @patch("article.tasks.crossref.pipeline_crossref")
+ def test_convert_xml_to_crossref_formats(
+ self, mock_pipeline_crossref, mock_property_xmltree, mock_get_data
+ ):
+ self.article.doi.add(self.doi)
+ mock_property_xmltree.return_value = self.input_xml
+ mock_pipeline_crossref.return_value = self.modified_xml
+
+ convert_xml_to_crossref_format(
+ pid_v3=self.article.pid_v3, format_name="crossref", username="admin"
+ )
+ self.verify_article_format(status="S", version=1)
+
+    # Without a DOI the task must leave an "E" record carrying the exact
+    # message below and no file.
+ def test_convert_xml_to_crossref_formats_without_doi(self):
+ convert_xml_to_crossref_format(
+ pid_v3=self.article.pid_v3, format_name="crossref", username="admin"
+ )
+ expected_msg = {
+ "exception_msg": f"Unable to format because the article {self.article.pid_v3} has no DOI associated with it"
+ }
+ self.verify_article_format(
+ status="E", version=1, file_exists=False, report=expected_msg
+ )
+
+    # NOTE(review): this test only asserts that get_data was called with the
+    # DOI prefix; the resulting ArticleFormat (status / report) is not checked.
+ @patch("doi_manager.models.CrossRefConfiguration.get_data")
+ def test_convert_xml_to_crossref_formats_missing_crossref_configuration(
+ self, mock_get_data
+ ):
+ self.article.doi.add(self.doi)
+ mock_get_data.side_effect = CrossRefConfiguration.DoesNotExist
+ convert_xml_to_crossref_format(
+ pid_v3=self.article.pid_v3, format_name="crossref", username="admin"
+ )
+ expected_prefix = "10.1000.10"
+ mock_get_data.assert_called_once_with(expected_prefix)
diff --git a/doi_manager/models.py b/doi_manager/models.py
index dd5c88a4..1f1d982e 100644
--- a/doi_manager/models.py
+++ b/doi_manager/models.py
@@ -30,7 +30,5 @@ def data(self):
@classmethod
def get_data(cls, prefix):
+    # Return the `data` of the configuration stored for `prefix`.
+    # The fallback to an unsaved instance's data is removed: a missing prefix
+    # now lets the lookup's cls.DoesNotExist reach the caller.
- try:
- return cls.objects.get(prefix=prefix).data
- except cls.DoesNotExist:
- return cls().data
+ return cls.objects.get(prefix=prefix).data
+
+