diff --git a/src/apps/plots/migrations/0002_datarun_expiration_date.py b/src/apps/plots/migrations/0002_datarun_expiration_date.py index 8f1470a..125f947 100644 --- a/src/apps/plots/migrations/0002_datarun_expiration_date.py +++ b/src/apps/plots/migrations/0002_datarun_expiration_date.py @@ -2,7 +2,21 @@ import datetime +from django.conf import settings from django.db import migrations, models +from django.utils import timezone + +from config.instruments import Instruments + + +def set_expiration_date(apps, _): + DataRun = apps.get_model("plots", "DataRun") + for run in DataRun.objects.all(): + if Instruments.has_value(run.instrument.name): + run.expiration_date = run.created_on + datetime.timedelta(days=settings.LIVE_PLOT_EXPIRATION_TIME) + else: + run.expiration_date = timezone.datetime(2100, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + run.save() class Migration(migrations.Migration): @@ -14,9 +28,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="datarun", name="expiration_date", - field=models.DateTimeField( - default=datetime.datetime(2027, 8, 8, 18, 55, 41, 999298, tzinfo=datetime.timezone.utc), - verbose_name="Expires", - ), + field=models.DateTimeField(default=None, blank=True, null=True, verbose_name="Expires"), ), + migrations.RunPython(set_expiration_date), ] diff --git a/src/apps/plots/models.py b/src/apps/plots/models.py index 0f16ac4..37c34b1 100644 --- a/src/apps/plots/models.py +++ b/src/apps/plots/models.py @@ -10,6 +10,8 @@ from django.db import models from django.utils import timezone +from config.instruments import Instruments + DATA_TYPES = {"json": 0, "html": 1, "div": 1} DATA_TYPE_INFO = {0: {"name": "json"}, 1: {"name": "html"}} @@ -41,9 +43,14 @@ class DataRun(models.Model): run_id = models.TextField() instrument = models.ForeignKey(Instrument, on_delete=models.deletion.CASCADE) created_on = models.DateTimeField("Timestamp", auto_now_add=True) - expiration_date = models.DateTimeField( - "Expires", default=timezone.now() + timedelta(days=(settings.LIVE_PLOT_EXPIRATION_TIME)) - ) + expiration_date = models.DateTimeField("Expires", default=None, null=True, blank=True) + + def clean(self): + if self.expiration_date is None: + if Instruments.has_value(self.instrument.name): + self.expiration_date = self.created_on + timedelta(days=settings.LIVE_PLOT_EXPIRATION_TIME) + else: + self.expiration_date = timezone.datetime(2100, 1, 1, 0, 0, 0, tzinfo=timezone.utc) def __str__(self): return f"{self.instrument}_{self.run_number}_{self.run_id}" diff --git a/src/apps/plots/view_util.py b/src/apps/plots/view_util.py index 583bbeb..e9a86af 100644 --- a/src/apps/plots/view_util.py +++ b/src/apps/plots/view_util.py @@ -144,10 +144,10 @@ def store_user_data(user, data_id, data, data_type, expiration_date: Optional[da run_obj.run_number = 0 run_obj.run_id = data_id run_obj.expiration_date = expiration_date + # Save run object to generate id (primary key) run_obj.save() - # Since user data have no run number, force the run number to be the PK, - # which is unique and will allow user to retrieve the data like normal - # instrument data. + # User data has no run number, use the unique id as the run number + # so that the user can retrieve the data like normal instrument data run_obj.run_number = run_obj.id run_obj.save() diff --git a/src/apps/plots/views.py b/src/apps/plots/views.py index c7be7db..965dbdb 100644 --- a/src/apps/plots/views.py +++ b/src/apps/plots/views.py @@ -99,7 +99,10 @@ def _store(request, instrument, run_id=None, as_user=False): data_type_default = PlotData.get_data_type_from_data(raw_data) data_type = request.POST.get("data_type", default=data_type_default) expiration_date = request.POST.get( - "expiration_date", default=timezone.now() + timedelta(days=settings.LIVE_PLOT_EXPIRATION_TIME) + "expiration_date", + default=timezone.datetime(2100, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + if as_user + else timezone.now() + timedelta(days=settings.LIVE_PLOT_EXPIRATION_TIME), ) if as_user: diff --git a/src/config/instruments.py b/src/config/instruments.py new file mode 100644 index 0000000..65f8d3e --- /dev/null +++ b/src/config/instruments.py @@ -0,0 +1,43 @@ +from enum import StrEnum + + +class Instruments(StrEnum): + ARCS = "arcs" + CG2 = "cg2" + CNCs = "cncs" + CORELLI = "corelli" + EQSANS = "eqsans" + HB2A = "hb2a" + HB2B = "hb2b" + HB2C = "hb2c" + HB3A = "hb3a" + HYS = "hys" + MANDI = "mandi" + NOM = "nom" + PG3 = "pg3" + REF_L = "ref_l" + REF_M = "ref_m" + SEQ = "seq" + SNAP = "snap" + TOPAZ = "topaz" + USANS = "usans" + VULCAN = "vulcan" + # instruments that haven't published to livedata yet + BL0 = "bl0" + BSS = "bss" + CG1D = "cg1d" + CG3 = "cg3" + FNPB = "fnpb" + HB3 = "hb3" + NOWB = "nowb" + NOWD = "nowd" + NOWG = "nowg" + NOWV = "nowv" + NOWX = "nowx" + NSE = "nse" + VENUS = "venus" + VIS = "vis" + + @classmethod + def has_value(cls, value): + return any(value == item.value for item in cls) diff --git a/tests/test_enum.py b/tests/test_enum.py new file mode 100644 index 0000000..6846919 --- /dev/null +++ b/tests/test_enum.py @@ -0,0 +1,12 @@ +import sys +from pathlib import Path + +import pytest + +sys.path.append(str(Path(__file__).parents[1])) +from src.config.instruments import Instruments + + +@pytest.mark.parametrize("instrument", ["fake_instrument", "ref_m"]) +def test_enum(instrument): + assert Instruments.has_value(instrument) == (instrument == "ref_m")