diff --git a/common/forms.py b/common/forms.py index 4189ffe0a..1c13c6977 100644 --- a/common/forms.py +++ b/common/forms.py @@ -892,3 +892,72 @@ def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict: representation. """ return {} + + def get_data_form_list(self) -> dict: + """ + Returns a form list based on form_list, conditionally including only + those items as per condition_list and also appearing in data_form_list. + + The list is generated dynamically because conditions in condition_list + may be dynamic. + + Essentially, version of `WizardView.get_form_list()` filtering in only + those list items appearing in `data_form_list`. + """ + data_form_keys = [key for key, form in self.data_form_list] + return { + form_key: form_class + for form_key, form_class in self.get_form_list().items() + if form_key in data_form_keys + } + + def all_serializable_form_data(self) -> Dict: + """ + Returns serializable data for all wizard steps. + + This is a re-implementation of + MeasureCreateWizard.get_all_cleaned_data(), but using self.data after + is_valid() has been successfully run. + """ + + all_data = {} + + for form_key in self.get_data_form_list().keys(): + all_data[form_key] = self.serializable_form_data_for_step(form_key) + + return all_data + + def serializable_form_data_for_step(self, step) -> Dict: + """ + Returns serializable data for a wizard step. + + This is a re-implementation of WizardView.get_cleaned_data_for_step(), + returning the serializable version of data in place of the form's + regular cleaned_data. + """ + + form_obj = self.get_form( + step=step, + data=self.storage.get_step_data(step), + files=self.storage.get_step_files(step), + ) + + return form_obj.serializable_data(remove_key_prefix=step) + + def all_serializable_form_kwargs(self) -> Dict: + """Returns serializable kwargs for all wizard steps.""" + + all_kwargs = {} + + for form_key in self.get_data_form_list().keys(): + all_kwargs[form_key] = self.serializable_form_kwargs_for_step(form_key) + + return all_kwargs + + def serializable_form_kwargs_for_step(self, step) -> Dict: + """Returns serializable kwargs for a wizard step.""" + + form_kwargs = self.get_form_kwargs(step) + form_class = self.form_list[step] + + return form_class.serializable_init_kwargs(form_kwargs) \ No newline at end of file diff --git a/measures/forms/wizard.py b/measures/forms/wizard.py index 2e68915ec..9c00a3858 100644 --- a/measures/forms/wizard.py +++ b/measures/forms/wizard.py @@ -52,6 +52,9 @@ from . import MeasureFootnotesForm from . import MeasureGeoAreaInitialDataMixin +import logging +logger = logging.getLogger(__name__) + class MeasureConditionsWizardStepForm(MeasureConditionsFormMixin): # override methods that use form kwargs @@ -781,7 +784,7 @@ def __init__(self, *args, **kwargs): ) -class MeasureStartDateForm(forms.Form): +class MeasureStartDateForm(forms.Form, SerializableFormMixin): start_date = DateInputFieldFixed( label="Start date", help_text="For example, 27 3 2008", @@ -806,7 +809,6 @@ def __init__(self, *args, **kwargs): def clean(self): cleaned_data = super().clean() - if "start_date" in cleaned_data: start_date = cleaned_data["start_date"] for measure in self.selected_measures: @@ -818,9 +820,33 @@ def clean(self): ) return cleaned_data + + @classmethod + def serializable_init_kwargs(cls, kwargs: Dict) -> Dict: + selected_measures = kwargs.get("selected_measures") + selected_measures_pks = [] + for measure in selected_measures: + selected_measures_pks.append(measure.id) + + serializable_kwargs = { + "selected_measures": selected_measures_pks, + } + + return serializable_kwargs + + @classmethod + def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict: + serialized_selected_measures_pks = form_kwargs.get("selected_measures") + deserialized_selected_measures = models.Measure.objects.filter(pk__in=serialized_selected_measures_pks) + kwargs = { + "selected_measures": deserialized_selected_measures, + } + + return kwargs -class MeasureEndDateForm(forms.Form): + +class MeasureEndDateForm(forms.Form, SerializableFormMixin): end_date = DateInputFieldFixed( label="End date", help_text="For example, 27 3 2008", @@ -860,9 +886,33 @@ def clean(self): cleaned_data["end_date"] = None return cleaned_data + + @classmethod + def serializable_init_kwargs(cls, kwargs: Dict) -> Dict: + selected_measures = kwargs.get("selected_measures") + selected_measures_pks = [] + for measure in selected_measures: + selected_measures_pks.append(measure.id) + + serializable_kwargs = { + "selected_measures": selected_measures_pks, + } + + return serializable_kwargs + + @classmethod + def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict: + serialized_selected_measures_pks = form_kwargs.get("selected_measures") + deserialized_selected_measures = models.Measure.objects.filter(pk__in=serialized_selected_measures_pks) + + kwargs = { + "selected_measures": deserialized_selected_measures, + } + + return kwargs -class MeasureRegulationForm(forms.Form): +class MeasureRegulationForm(forms.Form, SerializableFormMixin): generating_regulation = AutoCompleteField( label="Regulation ID", help_text="Select the regulation which provides the legal basis for the measures.", @@ -887,9 +937,32 @@ def __init__(self, *args, **kwargs): data_prevent_double_click="true", ), ) + @classmethod + def serializable_init_kwargs(cls, kwargs: Dict) -> Dict: + selected_measures = kwargs.get("selected_measures") + selected_measures_pks = [] + for measure in selected_measures: + selected_measures_pks.append(measure.id) + + serializable_kwargs = { + "selected_measures": selected_measures_pks, + } + + return serializable_kwargs + + @classmethod + def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict: + serialized_selected_measures_pks = form_kwargs.get("selected_measures") + deserialized_selected_measures = models.Measure.objects.filter(pk__in=serialized_selected_measures_pks) + + kwargs = { + "selected_measures": deserialized_selected_measures, + } + + return kwargs -class MeasureDutiesForm(forms.Form): +class MeasureDutiesForm(forms.Form, SerializableFormMixin): duties = forms.CharField( label="Duties", help_text="Enter the duty that applies to the measures.", @@ -920,6 +993,30 @@ def __init__(self, *args, **kwargs): data_prevent_double_click="true", ), ) + + @classmethod + def serializable_init_kwargs(cls, kwargs: Dict) -> Dict: + selected_measures = kwargs.get("selected_measures") + selected_measures_pks = [] + for measure in selected_measures: + selected_measures_pks.append(measure.id) + + serializable_kwargs = { + "selected_measures": selected_measures_pks, + } + + return serializable_kwargs + + @classmethod + def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict: + serialized_selected_measures_pks = form_kwargs.get("selected_measures") + deserialized_selected_measures = models.Measure.objects.filter(pk__in=serialized_selected_measures_pks) + + kwargs = { + "selected_measures": deserialized_selected_measures, + } + + return kwargs def clean(self): cleaned_data = super().clean() @@ -965,7 +1062,7 @@ def __init__(self, *args, **kwargs): ) -class MeasureGeographicalAreaExclusionsFormSet(FormSet): +class MeasureGeographicalAreaExclusionsFormSet(FormSet, SerializableFormMixin): """Allows editing the geographical area exclusions of multiple measures in `MeasureEditWizard`.""" diff --git a/measures/migrations/0017_measuresbulkeditor.py b/measures/migrations/0017_measuresbulkeditor.py new file mode 100644 index 000000000..e3adb412a --- /dev/null +++ b/measures/migrations/0017_measuresbulkeditor.py @@ -0,0 +1,84 @@ +# Generated by Django 4.2.14 on 2024-07-29 14:39 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import django_fsm +import measures.models.bulk_processing + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("workbaskets", "0008_datarow_dataupload"), + ("measures", "0016_measuresbulkcreator"), + ] + + operations = [ + migrations.CreateModel( + name="MeasuresBulkEditor", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "task_id", + models.CharField(blank=True, max_length=50, null=True, unique=True), + ), + ( + "processing_state", + django_fsm.FSMField( + choices=[ + ("AWAITING_PROCESSING", "Awaiting processing"), + ("CURRENTLY_PROCESSING", "Currently processing"), + ("SUCCESSFULLY_PROCESSED", "Successfully processed"), + ("FAILED_PROCESSING", "Failed processing"), + ("CANCELLED", "Cancelled"), + ], + db_index=True, + default="AWAITING_PROCESSING", + editable=False, + max_length=50, + protected=True, + ), + ), + ( + "successfully_processed_count", + models.PositiveIntegerField(default=0), + ), + ("form_data", models.JSONField()), + ("form_kwargs", models.JSONField()), + ("selected_measures", models.JSONField()), + ( + "user", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "workbasket", + models.ForeignKey( + editable=False, + null=True, + on_delete=measures.models.bulk_processing.REVOKE_TASKS_AND_SET_NULL, + to="workbaskets.workbasket", + ), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/measures/models/__init__.py b/measures/models/__init__.py index a4c641802..f49a0d12e 100644 --- a/measures/models/__init__.py +++ b/measures/models/__init__.py @@ -1,5 +1,6 @@ from measures.models.bulk_processing import BulkProcessor from measures.models.bulk_processing import MeasuresBulkCreator +from measures.models.bulk_processing import MeasuresBulkEditor from measures.models.bulk_processing import ProcessingState from measures.models.tracked_models import AdditionalCodeTypeMeasureType from measures.models.tracked_models import DutyExpression @@ -23,6 +24,7 @@ # - Classes exported from bulk_processing.py. "BulkProcessor", "MeasuresBulkCreator", + "MeasuresBulkEditor", "ProcessingState", # - Classes exported from tracked_model.py. "AdditionalCodeTypeMeasureType", diff --git a/measures/models/bulk_processing.py b/measures/models/bulk_processing.py index aef1d8763..bef4ef427 100644 --- a/measures/models/bulk_processing.py +++ b/measures/models/bulk_processing.py @@ -17,7 +17,13 @@ from common.celery import app from common.models.mixins import TimestampedMixin from common.models.utils import override_current_transaction +from common.util import TaricDateRange +from common.validators import UpdateType from measures.models.tracked_models import Measure +from measures.utils.edit import update_measure_components +from measures.utils.edit import update_measure_condition_components +from measures.utils.edit import update_measure_excluded_geographical_areas +from measures.utils.edit import update_measure_footnote_associations logger = logging.getLogger(__name__) @@ -414,3 +420,231 @@ def _log_form_errors(self, form_class, form_or_formset) -> None: for form_errors in errors: for error_key, error_values in form_errors.items(): logger.error(f"{error_key}: {error_values}") + + + +class MeasuresBulkEditorManager(models.Manager): + """Model Manager for MeasuresBulkEditor models.""" + + def create( + self, + form_data: Dict, + form_kwargs: Dict, + workbasket, + user, + selected_measures, + **kwargs, + ) -> "MeasuresBulkCreator": + """Create and save an instance of MeasuresBulkEditor.""" + + return super().create( + form_data=form_data, + form_kwargs=form_kwargs, + workbasket=workbasket, + user=user, + selected_measures=selected_measures, + **kwargs, + ) + + +class MeasuresBulkEditor(BulkProcessor): + """ + Model class used to bulk edit Measures instances from serialized form + data. + The stored form data is serialized and deserialized by Forms that subclass + SerializableFormMixin. + """ + + objects = MeasuresBulkEditorManager() + + form_data = models.JSONField() + """Dictionary of all Form.data, used to reconstruct bound Form instances as + if the form data had been sumbitted by the user within the measure wizard + process.""" + + form_kwargs = models.JSONField() + """Dictionary of all form init data, excluding a form's `data` param (which + is preserved via this class's `form_data` attribute).""" + + selected_measures = models.JSONField() + """List of all measures that have been selected for bulk editing.""" + + workbasket = models.ForeignKey( + "workbaskets.WorkBasket", + on_delete=REVOKE_TASKS_AND_SET_NULL, + null=True, + editable=False, + ) + """The workbasket with which created measures are associated.""" + + user = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=SET_NULL, + null=True, + editable=False, + ) + """The user who submitted the task to create measures.""" + + def schedule_task(self) -> AsyncResult: + """Implementation of base class method.""" + + from measures.tasks import bulk_edit_measures + + async_result = bulk_edit_measures.apply_async( + kwargs={ + "measures_bulk_editor_pk": self.pk, + }, + countdown=1, + ) + self.task_id = async_result.id + self.save() + + logger.info( + f"Measure bulk edit scheduled on task with ID {async_result.id}" + f"using MeasuresBulkEditor.pk={self.pk}.", + ) + + return async_result + + @atomic + def edit_measures(self) -> Iterable[Measure]: + logger.info("INSIDE EDIT MEASURES - BULK PROCESSING") + + with override_current_transaction( + transaction=self.workbasket.current_transaction, + ): + cleaned_data = self.get_forms_cleaned_data() + deserialized_selected_measures = Measure.objects.filter(pk__in=self.selected_measures) + + new_start_date = cleaned_data.get("start_date", None) + new_end_date = cleaned_data.get("end_date", False) + new_quota_order_number = cleaned_data.get("order_number", None) + new_generating_regulation = cleaned_data.get("generating_regulation", None) + new_duties = cleaned_data.get("duties", None) + new_exclusions = [ + e["excluded_area"] + for e in cleaned_data.get("formset-geographical_area_exclusions", []) + ] + + if deserialized_selected_measures: + edited_measures = [] + + for measure in deserialized_selected_measures: + new_measure = measure.new_version( + workbasket=self.workbasket, + update_type=UpdateType.UPDATE, + valid_between=TaricDateRange( + lower=( + new_start_date + if new_start_date + else measure.valid_between.lower + ), + upper=( + new_end_date + if new_end_date + else measure.valid_between.upper + ), + ), + order_number=( + new_quota_order_number + if new_quota_order_number + else measure.order_number + ), + generating_regulation=( + new_generating_regulation + if new_generating_regulation + else measure.generating_regulation + ), + ) + logger.info("UPDATE FUNCTIONS STARTING") + logger.info(f"BE - NEW MEASURE: {new_measure.__dict__}") + logger.info(f"BE - NEW DUTIES: {new_duties}") + update_measure_components( + measure=new_measure, + duties=new_duties, + workbasket=self.workbasket, + ) + update_measure_condition_components( + measure=new_measure, + workbasket=self.workbasket, + ) + update_measure_excluded_geographical_areas( + edited="geographical_area_exclusions" + in cleaned_data.get("fields_to_edit", []), + measure=new_measure, + exclusions=new_exclusions, + workbasket=self.workbasket, + ) + update_measure_footnote_associations( + measure=new_measure, + workbasket=self.workbasket, + ) + logger.info(f"NEW MEASURE WITH FUNCTIONS RUN: {new_measure}") + + edited_measures.append(new_measure.id) + logger.info(f"EDITED MEASURES ARRAY ON CLOSE: {edited_measures}") + return edited_measures + + def get_forms_cleaned_data(self) -> Dict: + """ + Returns a merged dictionary of all Form cleaned_data. + + If a Form's data contains a `FormSet`, the key will be prefixed with + "formset-" and contain a list of the formset cleaned_data dictionaries. + + If form validation errors are encountered when constructing cleaned + data, then this function raises Django's `ValidationError` exception. + """ + all_cleaned_data = {} + + from measures.views import MeasureEditWizard + for form_key, form_class in MeasureEditWizard.data_form_list: + + if form_key not in self.form_data: + # Forms are conditionally included during step processing - see + # `MeasureEditWizard.show_step()` for details. + continue + + data = self.form_data[form_key] + kwargs = form_class.deserialize_init_kwargs(self.form_kwargs[form_key]) + + form = form_class(data=data, **kwargs) + + if not form.is_valid(): + self._log_form_errors(form_class=form_class, form_or_formset=form) + raise ValidationError( + f"{form_class.__name__} has {len(form.errors)} errors.", + ) + + if isinstance(form.cleaned_data, (tuple, list)): + all_cleaned_data[f"formset-{form_key}"] = form.cleaned_data + else: + all_cleaned_data.update(form.cleaned_data) + + logger.info(f"RESULT OF ALL CLEANED DATA: {all_cleaned_data}") + return all_cleaned_data + + def _log_form_errors(self, form_class, form_or_formset) -> None: + """Output errors associated with a Form or Formset instance, handling + output for each instance type in a uniform manner.""" + + logger.error( + f"MeasuresBulkEditor.edit_measures() - " + f"{form_class.__name__} has {len(form_or_formset.errors)} errors.", + ) + + # Form.errors is a dictionary of errors, but FormSet.errors is a + # list of dictionaries of Form.errors. Access their errors in + # a uniform manner. + errors = [] + + if isinstance(form_or_formset, BaseFormSet): + errors = [ + {"formset_errors": form_or_formset.non_form_errors()}, + ] + form_or_formset.errors + else: + errors = [form_or_formset.errors] + + for form_errors in errors: + for error_key, error_values in form_errors.items(): + logger.error(f"{error_key}: {error_values}") \ No newline at end of file diff --git a/measures/parsers.py b/measures/parsers.py index e13a1bf11..bd112e85c 100644 --- a/measures/parsers.py +++ b/measures/parsers.py @@ -41,6 +41,11 @@ from measures.models import MonetaryUnit from measures.util import convert_eur_to_gbp + +import logging +logger = logging.getLogger(__name__) + + # Used to represent percentage or currency values. Amount = Decimal @@ -135,7 +140,7 @@ def __init__( duty_expressions: Iterable[DutyExpression], monetary_units: Iterable[MonetaryUnit], permitted_measurements: Iterable[Measurement], - component_output: Type[TrackedModel] = MeasureComponent, + component_output_type: Type[TrackedModel] = MeasureComponent, ): # Decimal numbers are a sequence of digits (without a left-trailing zero) # followed optionally by a decimal point and a number of digits (we have seen @@ -173,6 +178,8 @@ def component(duty_exp: DutyExpression) -> Parser: """Matches a string prefix and returns the associated type id, along with any parsed amounts and units according to their applicability, as a 4-tuple of (id, amount, monetary unit, measurement).""" + + logger.info("INSIDE COMPONENT") prefix = duty_exp.prefix has_amount = duty_exp.duty_amount_applicability_code has_measurement = duty_exp.measurement_unit_applicability_code @@ -211,7 +218,7 @@ def component(duty_exp: DutyExpression) -> Parser: and has_measurement != ApplicabilityCode.NOT_PERMITTED else component ).parsecmap( - lambda exp: component_output( + lambda exp: component_output_type( duty_expression=exp[0], duty_amount=exp[1], monetary_unit=exp[2], diff --git a/measures/tasks.py b/measures/tasks.py index ed61c3f6d..3ebe5e6d4 100644 --- a/measures/tasks.py +++ b/measures/tasks.py @@ -2,6 +2,7 @@ from common.celery import app from measures.models import MeasuresBulkCreator +from measures.models import MeasuresBulkEditor logger = logging.getLogger(__name__) @@ -43,3 +44,37 @@ def bulk_create_measures(measures_bulk_creator_pk: int) -> None: f"succeeded but created no measures in " f"WorkBasket({measures_bulk_creator.workbasket.pk}).", ) + +@app.task +def bulk_edit_measures(measures_bulk_editor_pk: int) -> None: + """Bulk edit measures from serialized measures form data saved within an + instance of MeasuresBulkEditor.""" + + measures_bulk_editor = MeasuresBulkEditor.objects.get(pk=measures_bulk_editor_pk) + measures_bulk_editor.begin_processing() + measures_bulk_editor.save() + + try: + measures = measures_bulk_editor.edit_measures() + except Exception as e: + measures_bulk_editor.processing_failed() + measures_bulk_editor.save() + logger.error( + f"MeasuresBulkCreator({measures_bulk_editor.pk}) task failed " + f"attempting to edit measures in " + f"WorkBasket({measures_bulk_editor.workbasket.pk}).", + ) + raise e + + logger.info(f"MEASURES: {measures}") + + measures_bulk_editor.processing_succeeded() + measures_bulk_editor.successfully_processed_count = len(measures) + measures_bulk_editor.save() + + if measures: + logger.info( + f"MeasuresBulkEditoror({measures_bulk_editor.pk}) task " + f"succeeded in editing {len(measures)} Measures in " + f"WorkBasket({measures_bulk_editor.workbasket.pk}).", + ) \ No newline at end of file diff --git a/measures/util.py b/measures/util.py index bf00efd6a..da706d532 100644 --- a/measures/util.py +++ b/measures/util.py @@ -3,12 +3,16 @@ from math import floor from typing import Type + from common.models import TrackedModel from common.models.transactions import Transaction from common.validators import UpdateType -from measures.models import MeasureComponent -from workbaskets.models import WorkBasket +# from measures import models as measure_models +# from workbaskets import models as workbasket_models + +import logging +logger = logging.getLogger(__name__) def convert_eur_to_gbp(amount: str, eur_gbp_conversion_rate: float) -> str: """Convert EUR amount to GBP and round down to nearest pence.""" @@ -29,9 +33,9 @@ def diff_components( instance, duty_sentence: str, start_date: date, - workbasket: WorkBasket, - transaction: Type[Transaction], - component_output: Type[TrackedModel] = MeasureComponent, + transaction: Transaction, + workbasket: "workbasket_models.Workbasket", + component_output_type: Type = None, reverse_attribute: str = "component_measure", ): """ @@ -47,28 +51,48 @@ def diff_components( of transactions and avoid business rule violations (e.g. ActionRequiresDuty). """ + logger.info("DIFF COMPONENTS CALLED") from measures.parsers import DutySentenceParser + from measures.models import MeasureComponent + # from measures.duty_sentence_parser import DutySentenceParser as LarkDutySentenceParser + # Setting as a default parameter causes a circular import. To work round it, we set the default to none, + # Then reassign once we call the function + component_output_type = MeasureComponent if not component_output_type else component_output_type parser = DutySentenceParser.create( start_date, - component_output=component_output, + component_output=component_output_type, ) - + logger.info(f"DC - DUTY SENTENCE: {duty_sentence}") + logger.info(f"DC - DUTY SENTENCE TYPE: {type(duty_sentence)}") new_components = parser.parse(duty_sentence) old_components = instance.components.approved_up_to_transaction( workbasket.current_transaction, ) + logger.info(f"DC - NEW COMPONENTS: {new_components}") + logger.info(f"DC - OLD COMPONENTS: {old_components}") + new_by_id = {c.duty_expression.id: c for c in new_components} old_by_id = {c.duty_expression.id: c for c in old_components} + + logger.info(f"DC - NEW BY ID: {new_by_id}") + logger.info(f"DC - OLD BY ID: {old_by_id}") + all_ids = set(new_by_id.keys()) | set(old_by_id.keys()) + + logger.info(f"DC - ALL ID: {all_ids}") + update_transaction = transaction if transaction else None for id in all_ids: new = new_by_id.get(id) old = old_by_id.get(id) if new and old: # Component is having amount/unit changed – UPDATE it + logger.info(f"DC IF - NEW: {new}") + logger.info(f"DC IF - OLD: {old}") new.update_type = UpdateType.UPDATE new.version_group = old.version_group + setattr(new, reverse_attribute, instance) if not update_transaction: update_transaction = workbasket.new_transaction() @@ -90,4 +114,4 @@ def diff_components( workbasket, update_type=UpdateType.DELETE, transaction=workbasket.new_transaction(), - ) + ) \ No newline at end of file diff --git a/measures/utils/edit.py b/measures/utils/edit.py new file mode 100644 index 000000000..5b13f67eb --- /dev/null +++ b/measures/utils/edit.py @@ -0,0 +1,119 @@ +from common.models import TrackedModel +from common.validators import UpdateType +from geo_areas.models import GeographicalArea +from geo_areas.utils import get_all_members_of_geo_groups +from measures import models as measure_models +from typing import List +from typing import Type +from measures.util import diff_components +from workbaskets import models as workbasket_models + +import logging +logger = logging.getLogger(__name__) + +def update_measure_components( + duties: str, + measure: Type[TrackedModel] = "measure_models.Measure", + workbasket: Type[TrackedModel] = "workbasket_models.WorkBasket", + ): + """Updates the measure components associated to the measure.""" + + logger.info("UPDATE MEASURE COMPONENT CALLED") + logger.info(f"UMC - MEASURE DUTY SENTENCE: {measure.duty_sentence}") + logger.info(f"UMC - DUTIES: {duties}") + diff_components( + instance=measure, + duty_sentence=duties if duties else measure.duty_sentence, + start_date=measure.valid_between.lower, + workbasket=workbasket, + transaction=workbasket.current_transaction, + ) + + +def update_measure_condition_components( + workbasket: Type[TrackedModel] = "workbasket_models.WorkBasket", + measure: Type[TrackedModel] = "measure_models.Measure", +): + """Updates the measure condition components associated to the + measure.""" + logger.info("UPDATE MEASURE CONDITION CALLED") + conditions = measure.conditions.current() + for condition in conditions: + condition.new_version( + dependent_measure=measure, + workbasket=workbasket, + ) + + +def update_measure_excluded_geographical_areas( + edited: bool, + exclusions: List[GeographicalArea], + workbasket: Type[TrackedModel] = "workbasket_models.WorkBasket", + measure: Type[TrackedModel] = "measure_models.Measure", +): + """Updates the excluded geographical areas associated to the measure.""" + + logger.info("UPDATE MEASURE EXCLUDED GEO AREAS CALLED") + existing_exclusions = measure.exclusions.current() + + # Update any exclusions to new measure version + if not edited: + for exclusion in existing_exclusions: + exclusion.new_version( + modified_measure=measure, + workbasket=workbasket, + ) + return + + new_excluded_areas = get_all_members_of_geo_groups( + validity=measure.valid_between, + geo_areas=exclusions, + ) + + for geo_area in new_excluded_areas: + existing_exclusion = existing_exclusions.filter( + excluded_geographical_area=geo_area, + ).first() + if existing_exclusion: + existing_exclusion.new_version( + modified_measure=measure, + workbasket=workbasket, + ) + else: + measure_models.MeasureExcludedGeographicalArea.objects.create( + modified_measure=measure, + excluded_geographical_area=geo_area, + update_type=UpdateType.CREATE, + transaction=workbasket.new_transaction(), + ) + + removed_excluded_areas = { + e.excluded_geographical_area for e in existing_exclusions + }.difference(set(exclusions)) + + exclusions_to_remove = [ + existing_exclusions.get(excluded_geographical_area__id=geo_area.id) + for geo_area in removed_excluded_areas + ] + + for exclusion in exclusions_to_remove: + exclusion.new_version( + update_type=UpdateType.DELETE, + modified_measure=measure, + workbasket=workbasket, + ) + + +def update_measure_footnote_associations(measure, workbasket): + """Updates the footnotes associated to the measure.""" + logger.info("UPDATE MEASURE FOOTNOTE ASSOSH CALLED") + footnote_associations = ( + measure_models.FootnoteAssociationMeasure.objects.current().filter( + footnoted_measure__sid=measure.sid, + ) + ) + for fa in footnote_associations: + fa.new_version( + footnoted_measure=measure, + workbasket=workbasket, + ) diff --git a/measures/views/wizard.py b/measures/views/wizard.py index dfa156a2b..0bb65eee4 100644 --- a/measures/views/wizard.py +++ b/measures/views/wizard.py @@ -1,7 +1,7 @@ import logging from typing import Dict -from typing import List +from common.forms import SerializableFormMixin from crispy_forms_gds.helper import FormHelper from django.conf import settings from django.contrib.auth.mixins import PermissionRequiredMixin @@ -18,7 +18,6 @@ from geo_areas import constants from geo_areas.models import GeographicalArea from geo_areas.models import GeographicalMembership -from geo_areas.utils import get_all_members_of_geo_groups from geo_areas.validators import AreaCode from measures import forms from measures import models @@ -27,7 +26,10 @@ from measures.constants import START from measures.constants import MeasureEditSteps from measures.creators import MeasuresCreator -from measures.util import diff_components +from measures.utils.edit import update_measure_components +from measures.utils.edit import update_measure_condition_components +from measures.utils.edit import update_measure_excluded_geographical_areas +from measures.utils.edit import update_measure_footnote_associations from workbaskets.models import WorkBasket from workbaskets.views.decorators import require_current_workbasket @@ -41,6 +43,7 @@ class MeasureEditWizard( PermissionRequiredMixin, MeasureSelectionQuerysetMixin, NamedUrlSessionWizardView, + SerializableFormMixin, ): """ Multipart form wizard for editing multiple measures. @@ -51,8 +54,7 @@ class MeasureEditWizard( storage_name = "measures.wizard.MeasureEditSessionStorage" permission_required = ["common.change_trackedmodel"] - form_list = [ - (START, forms.MeasuresEditFieldsForm), + data_form_list = [ (MeasureEditSteps.START_DATE, forms.MeasureStartDateForm), (MeasureEditSteps.END_DATE, forms.MeasureEndDateForm), (MeasureEditSteps.QUOTA_ORDER_NUMBER, forms.MeasureQuotaOrderNumberForm), @@ -63,6 +65,14 @@ class MeasureEditWizard( forms.MeasureGeographicalAreaExclusionsFormSet, ), ] + """Forms in this wizard's steps that collect user data.""" + + form_list = [ + (START, forms.MeasuresEditFieldsForm), + *data_form_list, + ] + """All Forms in this wizard's steps, including both those that collect user + data and those that don't.""" templates = { START: "measures/edit-multiple-start.jinja", @@ -94,6 +104,10 @@ class MeasureEditWizard( }, } + @property + def workbasket(self) -> WorkBasket: + return WorkBasket.current(self.request) + def get_template_names(self): return self.templates.get( self.steps.current, @@ -131,106 +145,39 @@ def get_form_kwargs(self, step): return kwargs - def update_measure_components( - self, - measure: models.Measure, - duties: str, - workbasket: WorkBasket, - ): - """Updates the measure components associated to the measure.""" - diff_components( - instance=measure, - duty_sentence=duties if duties else measure.duty_sentence, - start_date=measure.valid_between.lower, - workbasket=workbasket, - transaction=workbasket.current_transaction, - ) + def done(self, form_list, **kwargs): + if settings.MEASURES_ASYNC_EDIT: + return self.async_done(form_list, **kwargs) + else: + return self.sync_done(form_list, **kwargs) - def update_measure_condition_components( - self, - measure: models.Measure, - workbasket: WorkBasket, - ): - """Updates the measure condition components associated to the - measure.""" - conditions = measure.conditions.current() - for condition in conditions: - condition.new_version( - dependent_measure=measure, - workbasket=workbasket, - ) + def async_done(self, form_list, **kwargs): + logger.info("Editing measures asynchronously.") + serializable_data = self.all_serializable_form_data() + serializable_form_kwargs = self.all_serializable_form_kwargs() - def update_measure_excluded_geographical_areas( - self, - edited: bool, - measure: models.Measure, - exclusions: List[GeographicalArea], - workbasket: WorkBasket, - ): - """Updates the excluded geographical areas associated to the measure.""" - existing_exclusions = measure.exclusions.current() - - # Update any exclusions to new measure version - if not edited: - for exclusion in existing_exclusions: - exclusion.new_version( - modified_measure=measure, - workbasket=workbasket, - ) - return + db_selected_measures = [] + for measure in self.get_queryset(): + db_selected_measures.append(measure.id) - new_excluded_areas = get_all_members_of_geo_groups( - validity=measure.valid_between, - geo_areas=exclusions, + measures_bulk_editor = models.MeasuresBulkEditor.objects.create( + form_data=serializable_data, + form_kwargs=serializable_form_kwargs, + workbasket=self.workbasket, + user=self.request.user, + selected_measures=db_selected_measures, ) + self.session_store.clear() + measures_bulk_editor.schedule_task() - for geo_area in new_excluded_areas: - existing_exclusion = existing_exclusions.filter( - excluded_geographical_area=geo_area, - ).first() - if existing_exclusion: - existing_exclusion.new_version( - modified_measure=measure, - workbasket=workbasket, - ) - else: - models.MeasureExcludedGeographicalArea.objects.create( - modified_measure=measure, - excluded_geographical_area=geo_area, - update_type=UpdateType.CREATE, - transaction=workbasket.new_transaction(), - ) - - removed_excluded_areas = { - e.excluded_geographical_area for e in existing_exclusions - }.difference(set(exclusions)) - - exclusions_to_remove = [ - existing_exclusions.get(excluded_geographical_area__id=geo_area.id) - for geo_area in removed_excluded_areas - ] - - for exclusion in exclusions_to_remove: - exclusion.new_version( - update_type=UpdateType.DELETE, - modified_measure=measure, - workbasket=workbasket, - ) - - def update_measure_footnote_associations(self, measure, workbasket): - """Updates the footnotes associated to the measure.""" - footnote_associations = ( - models.FootnoteAssociationMeasure.objects.current().filter( - footnoted_measure__sid=measure.sid, - ) + return redirect( + reverse( + "workbaskets:workbasket-ui-review-measures", + kwargs={"pk": self.workbasket.pk}, + ), ) - for fa in footnote_associations: - fa.new_version( - footnoted_measure=measure, - workbasket=workbasket, - ) - def done(self, form_list, **kwargs): + def sync_done(self, form_list, **kwargs): cleaned_data = self.get_all_cleaned_data() selected_measures = self.get_queryset() workbasket = WorkBasket.current(self.request) @@ -270,23 +217,23 @@ def done(self, form_list, **kwargs): else measure.generating_regulation ), ) - self.update_measure_components( + update_measure_components( measure=new_measure, duties=new_duties, workbasket=workbasket, ) - self.update_measure_condition_components( + update_measure_condition_components( measure=new_measure, workbasket=workbasket, ) - self.update_measure_excluded_geographical_areas( + update_measure_excluded_geographical_areas( edited="geographical_area_exclusions" in cleaned_data.get("fields_to_edit", []), measure=new_measure, exclusions=new_exclusions, workbasket=workbasket, ) - self.update_measure_footnote_associations( + update_measure_footnote_associations( measure=new_measure, workbasket=workbasket, ) diff --git a/settings/common.py b/settings/common.py index b3728ef59..f398ca328 100644 --- a/settings/common.py +++ b/settings/common.py @@ -663,6 +663,9 @@ "measures.tasks.bulk_create_measures": { "queue": "bulk-create", }, + "measures.tasks.bulk_edit_measures": { + "queue": "bulk-create", + }, } SQLITE_EXCLUDED_APPS = [ @@ -917,3 +920,4 @@ # Asynchronous / background (bulk) object creation and editing config. MEASURES_ASYNC_CREATION = is_truthy(os.environ.get("MEASURES_ASYNC_CREATION", "true")) +MEASURES_ASYNC_EDIT = is_truthy(os.environ.get("MEASURES_ASYNC_EDIT", "true")) \ No newline at end of file