diff --git a/common/inspect_tap_tasks.py b/common/inspect_tap_tasks.py index 55ea9a2ff..47a874468 100644 --- a/common/inspect_tap_tasks.py +++ b/common/inspect_tap_tasks.py @@ -59,9 +59,10 @@ def clean_tasks(self, tasks, task_status="", task_name="") -> List[Dict]: return tasks_cleaned - def current_rule_checks(self, task_name="") -> List[CeleryTask]: + def current_tasks(self, task_name="") -> List[CeleryTask]: """Return the list of tasks queued or started, ready to display in the view.""" + inspect = app.control.inspect() if not inspect: return [] diff --git a/common/util.py b/common/util.py index 39d1bf19f..391391a57 100644 --- a/common/util.py +++ b/common/util.py @@ -2,7 +2,9 @@ from __future__ import annotations +import os import re +import typing from datetime import date from datetime import datetime from datetime import timedelta @@ -43,6 +45,7 @@ from django.db.models.functions.text import Upper from django.db.transaction import atomic from django.template import loader +from django.utils import timezone from lxml import etree from psycopg.types.range import DateRange from psycopg.types.range import TimestampRange @@ -599,3 +602,46 @@ def format_date_string(date_string: str, short_format=False) -> str: return date_parser.parse(date_string).strftime(settings.DATE_FORMAT) except: return "" + + +def log_timing(logger_function: typing.Callable): + """ + Decorator function to log start and end times of a decorated function. + + When decorating a function, `logger_function` must be passed in to the + decorator to ensure the correct logger instance and function are applied. + `logger_function` may be any one of the logging output functions, but is + likely to be either `debug` or `info`. + Example: + ``` + import logging + + logger = logging.getLogger(__name__) + + @log_timing(logger_function=logger.info) + def my_function(): + ... 
+    ```
+    """
+
+    @wrapt.decorator
+    def wrapper(wrapped, instance, args, kwargs):
+        start_time = timezone.localtime()
+        logger_function(
+            f"Entering the function {wrapped.__name__}() on process "
+            f"pid={os.getpid()} at {start_time.isoformat()}",
+        )
+
+        result = wrapped(*args, **kwargs)
+
+        end_time = timezone.localtime()
+        elapsed_time = end_time - start_time
+        logger_function(
+            f"Exited the function {wrapped.__name__}() on "
+            f"process pid={os.getpid()} at {end_time.isoformat()} after "
+            f"an elapsed time of {elapsed_time}.",
+        )
+
+        return result
+
+    return wrapper
diff --git a/common/views.py b/common/views.py
index d0c977f98..1ba38b3b8 100644
--- a/common/views.py
+++ b/common/views.py
@@ -47,7 +47,7 @@
 from commodities.models import GoodsNomenclature
 from common.business_rules import BusinessRule
 from common.business_rules import BusinessRuleViolation
-from common.celery import app
+from common.celery import app as celery_app
 from common.forms import HomeSearchForm
 from common.models import TrackedModel
 from common.models import Transaction
@@ -65,8 +65,6 @@
 from workbaskets.models import WorkflowStatus
 from workbaskets.views.mixins import WithCurrentWorkBasket
-from .celery import app as celery_app
-
 class HomeView(LoginRequiredMixin, FormView):
     template_name = "common/homepage.jinja"
@@ -350,7 +348,7 @@ class AppInfoView(
     DATETIME_FORMAT = "%d %b %Y, %H:%M"
     def active_tasks(self) -> Dict:
-        inspect = app.control.inspect()
+        inspect = celery_app.control.inspect()
         if not inspect:
             return {}
diff --git a/conftest.py b/conftest.py
index 54590a937..4c7883647 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1119,13 +1119,13 @@ def hmrc_storage(s3):
 @pytest.fixture
 def sqlite_storage(s3, s3_bucket_names):
-    """Patch SQLiteStorage with moto so that nothing is really uploaded to
+    """Patch SQLiteS3VFSStorage with moto so that nothing is really uploaded to
     s3."""
-    from exporter.storages import SQLiteStorage
+    from exporter.storages import SQLiteS3VFSStorage
     storage = make_storage_mock(
         s3,
-        SQLiteStorage,
+        SQLiteS3VFSStorage,
         bucket_name=settings.SQLITE_STORAGE_BUCKET_NAME,
     )
     assert storage.endpoint_url is settings.SQLITE_S3_ENDPOINT_URL
diff --git a/exporter/management/commands/dump_sqlite.py b/exporter/management/commands/dump_sqlite.py
index 8a3def7ae..c672ca521 100644
--- a/exporter/management/commands/dump_sqlite.py
+++ b/exporter/management/commands/dump_sqlite.py
@@ -11,20 +11,39 @@
 class Command(BaseCommand):
+    help = (
+        "Create a snapshot of the application database to a file in SQLite "
+        "format. Snapshot file names take the form <transaction-order>.db, "
+        "where <transaction-order> is the value of the last published "
+        "transaction's order attribute. Care should be taken to ensure that "
+        "there is sufficient local file system storage to accommodate the "
+        "SQLite file - if you choose to target remote S3 storage, then a "
+        "temporary local copy of the file will be created and cleaned up."
+    )
+
     def add_arguments(self, parser: CommandParser) -> None:
         parser.add_argument(
-            "--immediately",
+            "--asynchronous",
             action="store_const",
-            help="Run the task in this process now rather than queueing it up",
+            help="Queue the snapshot task to run in an asynchronous process.",
             const=True,
             default=False,
         )
+        parser.add_argument(
+            "--save-local",
+            help=(
+                "Save the SQLite snapshot to the local file system under the "
+                "(existing) directory given by DIRECTORY_PATH."
+ ), + dest="DIRECTORY_PATH", + ) return super().add_arguments(parser) def handle(self, *args: Any, **options: Any) -> Optional[str]: logger.info(f"Triggering tariff database export to SQLite") - if options["immediately"]: - export_and_upload_sqlite() + local_path = options["DIRECTORY_PATH"] + if options["asynchronous"]: + export_and_upload_sqlite.delay(local_path) else: - export_and_upload_sqlite.delay() + export_and_upload_sqlite(local_path) diff --git a/exporter/sqlite/__init__.py b/exporter/sqlite/__init__.py index 998b941c1..e85ac64c1 100644 --- a/exporter/sqlite/__init__.py +++ b/exporter/sqlite/__init__.py @@ -41,32 +41,41 @@ } -def make_export_plan(sqlite: runner.Runner) -> plan.Plan: - names = ( +def make_export_plan(sqlite_runner: runner.Runner) -> plan.Plan: + app_names = ( name.split(".")[0] for name in settings.DOMAIN_APPS if name not in settings.SQLITE_EXCLUDED_APPS ) - all_models = chain(*[apps.get_app_config(name).get_models() for name in names]) + all_models = chain(*[apps.get_app_config(name).get_models() for name in app_names]) models_by_table = {model._meta.db_table: model for model in all_models} import_script = plan.Plan() - for table, sql in sqlite.tables: + for table, create_table_statement in sqlite_runner.tables: model = models_by_table.get(table) if model is None or model.__name__ in SKIPPED_MODELS: continue - columns = list(sqlite.read_column_order(model._meta.db_table)) - import_script.add_schema(sql) + columns = list(sqlite_runner.read_column_order(model._meta.db_table)) + import_script.add_schema(create_table_statement) import_script.add_data(model, columns) return import_script def make_export(connection: apsw.Connection): - with NamedTemporaryFile() as db_name: - sqlite = runner.Runner.make_tamato_database(Path(db_name.name)) - plan = make_export_plan(sqlite) - - export = runner.Runner(connection) - export.run_operations(plan.operations) + with NamedTemporaryFile() as temp_sqlite_db: + # Create Runner instance with its SQLite file name pointing at a path on + # the local file system. This is only required temporarily in order to + # create an in-memory plan that can be run against a target database + # object. + plan_runner = runner.Runner.make_tamato_database( + Path(temp_sqlite_db.name), + ) + plan = make_export_plan(plan_runner) + # make_tamato_database() creates a Connection instance that needs + # closing once an in-memory plan has been created from it. + plan_runner.database.close() + + export_runner = runner.Runner(connection) + export_runner.run_operations(plan.operations) diff --git a/exporter/sqlite/plan.py b/exporter/sqlite/plan.py index 70e0ffa48..d3eee670f 100644 --- a/exporter/sqlite/plan.py +++ b/exporter/sqlite/plan.py @@ -100,9 +100,11 @@ def operations(self) -> Iterable[Operation]: ] def add_schema(self, sql: str): + """Add sql schema (table) creation statements to this Plan instance.""" self._operations.append((sql, [[]])) def add_data(self, model: Type[Model], columns: Iterable[str]): + """Add data insert statements to this Plan instance.""" queryset = model.objects output_columns = [] for column in columns: diff --git a/exporter/sqlite/runner.py b/exporter/sqlite/runner.py index 31612acd2..070a87cf2 100644 --- a/exporter/sqlite/runner.py +++ b/exporter/sqlite/runner.py @@ -40,7 +40,7 @@ def normalise_loglevel(cls, loglevel): return loglevel @classmethod - def manage(cls, db: Path, *args: str): + def manage(cls, sqlite_file: Path, *args: str): """ Runs a Django management command on the SQLite database. 
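
An illustrative aside on the reworked `dump_sqlite` management command shown earlier in this diff: the sketch below is not part of the change set and simply exercises the new flags through Django's `call_command`; the local directory path is an assumed, pre-existing location.

```python
# Sketch only -- exercising the reworked dump_sqlite command.
# "/tmp/sqlite-snapshots" is an assumed, already-existing local directory.
from django.core.management import call_command

# Default: run the export in this process and save the snapshot to the
# configured S3 bucket.
call_command("dump_sqlite")

# Run in this process, but write the snapshot to a local directory instead of S3.
call_command("dump_sqlite", "--save-local", "/tmp/sqlite-snapshots")

# Queue the export to run asynchronously on a Celery worker.
call_command("dump_sqlite", "--asynchronous")
```
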
@@ -56,7 +56,7 @@ def manage(cls, db: Path, *args: str): sqlite_env["CELERY_LOG_LEVEL"], ) - sqlite_env["DATABASE_URL"] = f"sqlite:///{str(db)}" + sqlite_env["DATABASE_URL"] = f"sqlite:///{str(sqlite_file)}" # Required to make sure the postgres default isn't set as the DB_URL if sqlite_env.get("VCAP_SERVICES"): vcap_env = json.loads(sqlite_env["VCAP_SERVICES"]) @@ -71,21 +71,20 @@ def manage(cls, db: Path, *args: str): ) @classmethod - def make_tamato_database(cls, db: Path) -> "Runner": - """ - Generate a new and empty SQLite database with the TaMaTo schema. - - Because SQLite uses different fields to PostgreSQL, first missing - migrations are generated to bring in the different style of validity - fields. However, these should not generally stick around and be applied - to Postgres so they are removed after being applied. - """ + def make_tamato_database(cls, sqlite_file: Path) -> "Runner": + """Generate a new and empty SQLite database with the TaMaTo schema + derived from Tamato's models - by performing 'makemigrations' followed + by 'migrate' on the Sqlite file located at `sqlite_file`.""" try: - cls.manage(db, "makemigrations", "--name", "sqlite_export") - cls.manage(db, "migrate") - assert db.exists() - return cls(apsw.Connection(str(db))) - + # Because SQLite uses different fields to PostgreSQL, missing + # migrations are first generated to bring in the different style of + # validity fields. However, these should not be applied to Postgres + # and so should be removed (in the `finally` block) after they have + # been applied (when running `migrate`). + cls.manage(sqlite_file, "makemigrations", "--name", "sqlite_export") + cls.manage(sqlite_file, "migrate") + assert sqlite_file.exists() + return cls(apsw.Connection(str(sqlite_file))) finally: for file in Path(settings.BASE_DIR).rglob( "**/migrations/*sqlite_export.py", @@ -93,6 +92,15 @@ def make_tamato_database(cls, db: Path) -> "Runner": file.unlink() def read_schema(self, type: str) -> Iterator[Tuple[str, str]]: + """ + Generator yielding a tuple of 'name' and 'sql' column values from + Sqlite's "schema table", 'sqlite_schema'. + + The `type` param filters rows that have a matching 'type' column value, + which may be any one of: 'table', 'index', 'view', or 'trigger'. + + See https://www.sqlite.org/schematab.html for further details. + """ cursor = self.database.cursor() cursor.execute( f""" @@ -110,16 +118,21 @@ def read_schema(self, type: str) -> Iterator[Tuple[str, str]]: @property def tables(self) -> Iterator[Tuple[str, str]]: + """Generator yielding a tuple of each Sqlite table object's 'name' and + the SQL `CREATE_TABLE` statement that can be used to create the + table.""" yield from self.read_schema("table") @property def indexes(self) -> Iterator[Tuple[str, str]]: + """Generator yielding a tuple of each SQLite table index object name and + the SQL `CREATE_INDEX` statement that can be used to create it.""" yield from self.read_schema("index") def read_column_order(self, table: str) -> Iterator[str]: """ - Returns the name of the columns in the order they are defined in an - SQLite database. + Returns the name of `table`'s columns in the order they are defined in + an SQLite database. 
     This is necessary because the Django migrations do not generate the
     columns in the order they are defined on the model, and there's no other
@@ -131,8 +144,8 @@ def read_column_order(self, table: str) -> Iterator[str]:
             yield column[1]
     def run_operations(self, operations: Iterable[Operation]):
-        """Runs the supplied sequence of operations against the SQLite
-        database."""
+        """Runs each operation in `operations` against the `database` member
+        attribute (a connection object to an SQLite database file)."""
         cursor = self.database.cursor()
         for operation in operations:
             logger.debug("%s: %s", self.database, operation[0])
diff --git a/exporter/sqlite/tasks.py b/exporter/sqlite/tasks.py
index 81c72516f..b803d662a 100644
--- a/exporter/sqlite/tasks.py
+++ b/exporter/sqlite/tasks.py
@@ -1,10 +1,10 @@
 import logging
+import os
 from common.celery import app
 from common.models.transactions import Transaction
 from common.models.transactions import TransactionPartition
-from exporter import sqlite
-from exporter.storages import SQLiteStorage
+from exporter import storages
 logger = logging.getLogger(__name__)
@@ -30,33 +30,40 @@ def get_output_filename():
 @app.task
-def export_and_upload_sqlite() -> bool:
+def export_and_upload_sqlite(local_path: str = None) -> bool:
     """
-    Generates an export of the currently attached database to a portable SQLite
-    file and uploads it to the configured S3 bucket.
+    Generates an export of latest published data from the primary database to a
+    portable SQLite database file. The most recently published Transaction's
+    `order` value is used to define latest published data, and its value is used
+    to name the generated SQLite database file.
-    If an SQLite export of the current state of the database (as given by the
-    most recently approved transaction ID) already exists, no action is taken.
-    Returns a boolean that is ``True`` if a file was uploaded and ``False`` if
-    not.
+    If `local_path` is provided, then the SQLite database file will be saved in
+    that directory location (note that in this case `local_path` must be an
+    existing directory path on the local file system).
+
+    If `local_path` is not provided, then the SQLite database file will be saved
+    to the configured S3 bucket.
     """
-    storage = SQLiteStorage()
     db_name = get_output_filename()
+    if local_path:
+        logger.info("SQLite export process targeting local file system.")
+        storage = storages.SQLiteLocalStorage(location=local_path)
+    else:
+        logger.info("SQLite export process targeting S3 file system.")
+        storage = storages.SQLiteS3Storage()
+
     export_filename = storage.generate_filename(db_name)
-    logger.debug("Checking for need to upload tariff database %s", export_filename)
+    logger.info(f"Checking for existing database {export_filename}.")
     if storage.exists(export_filename):
-        logger.debug("Database %s already present", export_filename)
+        logger.info(
+            f"Database {export_filename} already exists. Exiting process, "
+            f"pid={os.getpid()}.",
+        )
         return False
-    logger.info("Generating database %s", export_filename)
-    sqlite.make_export(storage.get_connection(export_filename))
-    logger.info("Generation complete")
-
-    logger.info("Serializing %s", export_filename)
-    storage.serialize(export_filename)
-    logger.info("Serializing complete")
-
-    logger.info("Upload complete")
+    logger.info(f"Generating SQLite database export {export_filename}.")
+    storage.export_database(export_filename)
+    logger.info(f"SQLite database export {export_filename} complete.")
     return True
diff --git a/exporter/storages.py b/exporter/storages.py
index df9e2d07d..a6a86ac7d 100644
--- a/exporter/storages.py
+++ b/exporter/storages.py
@@ -1,10 +1,19 @@
+import logging
 from functools import cached_property
 from os import path
+from pathlib import Path
+from tempfile import NamedTemporaryFile
 import apsw
+from django.core.files.storage import Storage
 from sqlite_s3vfs import S3VFS
 from storages.backends.s3boto3 import S3Boto3Storage
+from common.util import log_timing
+from exporter import sqlite
+
+logger = logging.getLogger(__name__)
+
 class HMRCStorage(S3Boto3Storage):
     def get_default_settings(self):
@@ -24,7 +33,20 @@ def get_object_parameters(self, name):
         return super().get_object_parameters(name)
-class SQLiteStorage(S3Boto3Storage):
+class SQLiteExportMixin:
+    """Mixin class used to define a common export API among SQLite Storage
+    subclasses."""
+
+    def export_database(self, filename: str):
+        """Export Tamato's primary database to an SQLite file format, saving to
+        Storage's backing store (S3, local file system, etc)."""
+        raise NotImplementedError
+
+
+class SQLiteS3StorageBase(S3Boto3Storage):
+    """Storage base class used for remotely storing SQLite database files to an
+    AWS S3-like backing store (AWS S3, Minio, etc)."""
+
     def get_default_settings(self):
         from django.conf import settings
@@ -46,17 +68,73 @@ def generate_filename(self, filename: str) -> str:
         )
         return super().generate_filename(filename)
+
+class SQLiteS3VFSStorage(SQLiteExportMixin, SQLiteS3StorageBase):
+    """
+    Storage class used for remotely storing SQLite database files to an AWS
+    S3-like backing store.
+
+    This class uses the sqlite-s3vfs package
+    (https://pypi.org/project/sqlite-s3vfs/) to apply an S3 virtual file
+    system strategy when saving the SQLite file to S3.
+    """
+
     def exists(self, filename: str) -> bool:
         return any(self.listdir(filename))
-    def serialize(self, filename):
-        vfs_fileobj = self.vfs.serialize_fileobj(key_prefix=filename)
-        self.bucket.Object(filename).upload_fileobj(vfs_fileobj)
-
     @cached_property
     def vfs(self) -> apsw.VFS:
         return S3VFS(bucket=self.bucket, block_size=65536)
-    def get_connection(self, filename: str) -> apsw.Connection:
-        """Creates a new empty SQLite database."""
-        return apsw.Connection(filename, vfs=self.vfs.name)
+    @log_timing(logger_function=logger.info)
+    def export_database(self, filename: str):
+        connection = apsw.Connection(filename, vfs=self.vfs.name)
+        sqlite.make_export(connection)
+        connection.close()
+        logger.info(f"Serializing {filename} to S3 storage.")
+        vfs_fileobj = self.vfs.serialize_fileobj(key_prefix=filename)
+        self.bucket.Object(filename).upload_fileobj(vfs_fileobj)
+
+
+class SQLiteS3Storage(SQLiteExportMixin, SQLiteS3StorageBase):
+    """
+    Storage class used for remotely storing SQLite database files to an AWS
+    S3-like backing store.
+
+    This class applies a strategy that first creates a temporary instance of the
+    SQLite file on the local file system before transferring its contents to S3.
+    """
+
+    @log_timing(logger_function=logger.info)
+    def export_database(self, filename: str):
+        with NamedTemporaryFile() as temp_sqlite_db:
+            connection = apsw.Connection(temp_sqlite_db.name)
+            sqlite.make_export(connection)
+            connection.close()
+            logger.info(f"Saving {filename} to S3 storage.")
+            self.save(filename, temp_sqlite_db.file)
+
+
+class SQLiteLocalStorage(SQLiteExportMixin, Storage):
+    """Storage class used for storing SQLite database files to the local file
+    system."""
+
+    def __init__(self, location) -> None:
+        self._location = Path(location).expanduser().resolve()
+        logger.info(f"Normalised path `{location}` to `{self._location}`.")
+        if not self._location.is_dir():
+            raise Exception(f"Directory does not exist: {location}.")
+
+    def path(self, name: str) -> str:
+        return str(self._location.joinpath(name))
+
+    def exists(self, name: str) -> bool:
+        return Path(self.path(name)).exists()
+
+    @log_timing(logger_function=logger.info)
+    def export_database(self, filename: str):
+        connection = apsw.Connection(self.path(filename))
+        logger.info(f"Saving {filename} to local file system storage.")
+        sqlite.make_export(connection)
+        connection.close()
diff --git a/exporter/tests/test_exporter_commands.py b/exporter/tests/test_exporter_commands.py
index 9a0497d7b..762819c88 100644
--- a/exporter/tests/test_exporter_commands.py
+++ b/exporter/tests/test_exporter_commands.py
@@ -1,5 +1,6 @@
 from io import StringIO
 from unittest import mock
+from unittest.mock import MagicMock
 import pytest
 from django.core.management import call_command
@@ -14,6 +15,40 @@
 pytestmark = pytest.mark.django_db
+@pytest.mark.parametrize(
+    ("asynchronous_flag", "save_local_flag_value"),
+    (
+        (None, None),
+        ("--asynchronous", None),
+        ("--asynchronous", "/tmp"),
+        (None, "/tmp"),
+    ),
+)
+def test_dump_sqlite_command(asynchronous_flag, save_local_flag_value):
+    flags = []
+    if asynchronous_flag:
+        flags.append(asynchronous_flag)
+    if save_local_flag_value:
+        flags.extend(("--save-local", save_local_flag_value))
+
+    with mock.patch(
+        "exporter.management.commands.dump_sqlite.export_and_upload_sqlite",
+        return_value=MagicMock(),
+    ) as mock_export_and_upload_sqlite:
+        call_command("dump_sqlite", *flags)
+
+    if asynchronous_flag:
+        mock_export_and_upload_sqlite.assert_not_called()
+        mock_export_and_upload_sqlite.delay.assert_called_once_with(
+            save_local_flag_value,
+        )
+    else:
+        mock_export_and_upload_sqlite.assert_called_once_with(
+            save_local_flag_value,
+        )
+        mock_export_and_upload_sqlite.delay.assert_not_called()
+
+
 @pytest.mark.skip()
 def test_upload_command_uploads_queued_workbasket_to_s3(
     approved_transaction,
diff --git a/exporter/tests/test_sqlite.py b/exporter/tests/test_sqlite.py
index a7bc036ae..c75d8f839 100644
--- a/exporter/tests/test_sqlite.py
+++ b/exporter/tests/test_sqlite.py
@@ -2,6 +2,7 @@
 from io import BytesIO
 from os import path
 from pathlib import Path
+from typing import Iterator
 from unittest import mock
 import apsw
@@ -18,7 +19,7 @@
 @pytest.fixture(scope="module")
-def sqlite_template() -> Runner:
+def sqlite_template() -> Iterator[Runner]:
     """
     Provides a template SQLite file with the correct TaMaTo schema but without
     any data.
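
A brief, hedged sketch of the new storage export API introduced in `exporter/storages.py` above. It is not part of the change set; the directory path is an assumption and must already exist, and it assumes Django's base `Storage.generate_filename()` is sufficient here, as the task code implies.

```python
# Sketch only -- driving export_database() directly via the new local storage
# class. The directory path is an assumed, pre-existing location.
from exporter.sqlite.tasks import get_output_filename
from exporter.storages import SQLiteLocalStorage

storage = SQLiteLocalStorage(location="/tmp/sqlite-snapshots")
export_filename = storage.generate_filename(get_output_filename())

# Mirrors export_and_upload_sqlite(): skip the export if a snapshot for this
# database state has already been written.
if not storage.exists(export_filename):
    storage.export_database(export_filename)
```
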
@@ -33,7 +34,7 @@ def sqlite_template() -> Runner: @pytest.fixture(scope="function") -def sqlite_database(sqlite_template: Runner) -> Runner: +def sqlite_database(sqlite_template: Runner) -> Iterator[Runner]: """Copies the template file to a new location that will be cleaned up at the end of one test.""" in_memory_database = apsw.Connection(":memory:") @@ -130,10 +131,10 @@ def test_valid_between_export( assert validity_end is None -def test_export_task_does_not_reupload(sqlite_storage, s3_object_names, settings): +def test_s3_export_task_does_not_reupload(sqlite_storage, s3_object_names, settings): """ - If a file has already been generated for this database state, we don't need - to upload it again. + If a file has already been generated and uploaded to S3 for this database + state, we don't need to upload it again. This idempotency allows us to regularly run an export check without constantly uploading files and wasting bandwidth/money. @@ -149,7 +150,10 @@ def test_export_task_does_not_reupload(sqlite_storage, s3_object_names, settings sqlite_storage.save(expected_key, BytesIO(b"")) names_before = s3_object_names(sqlite_storage.bucket_name) - with mock.patch("exporter.sqlite.tasks.SQLiteStorage", new=lambda: sqlite_storage): + with mock.patch( + "exporter.sqlite.tasks.storages.SQLiteS3Storage", + new=lambda: sqlite_storage, + ): returned = tasks.export_and_upload_sqlite() assert returned is False @@ -157,7 +161,22 @@ def test_export_task_does_not_reupload(sqlite_storage, s3_object_names, settings assert names_before == names_after -def test_export_task_uploads(sqlite_storage, s3_object_names, settings): +def test_local_export_task_does_not_replace(tmp_path): + """Test that if an SQLite file has already been generated on the local file + system at a specific directory location for this database state, then no + attempt is made to create it again.""" + factories.SeedFileTransactionFactory.create(order="999") + transaction = factories.PublishedTransactionFactory.create() + + sqlite_file_path = tmp_path / f"{tasks.normalised_order(transaction.order)}.db" + sqlite_file_path.write_bytes(b"") + files_before = set(tmp_path.iterdir()) + + assert not tasks.export_and_upload_sqlite(tmp_path) + assert files_before == set(tmp_path.iterdir()) + + +def test_s3_export_task_uploads(sqlite_storage, s3_object_names, settings): """The export system should actually upload a file to S3.""" factories.SeedFileTransactionFactory.create(order="999") transaction = factories.PublishedTransactionFactory.create() @@ -167,7 +186,10 @@ def test_export_task_uploads(sqlite_storage, s3_object_names, settings): f"{tasks.normalised_order(transaction.order)}.db", ) - with mock.patch("exporter.sqlite.tasks.SQLiteStorage", new=lambda: sqlite_storage): + with mock.patch( + "exporter.sqlite.tasks.storages.SQLiteS3Storage", + new=lambda: sqlite_storage, + ): returned = tasks.export_and_upload_sqlite() assert returned is True @@ -176,14 +198,26 @@ def test_export_task_uploads(sqlite_storage, s3_object_names, settings): ) -def test_export_task_ignores_unpublished_and_unapproved_transactions( +def test_local_export_task_saves(tmp_path): + """Test that export correctly saves a file to the local file system.""" + factories.SeedFileTransactionFactory.create(order="999") + transaction = factories.PublishedTransactionFactory.create() + + sqlite_file_path = tmp_path / f"{tasks.normalised_order(transaction.order)}.db" + files_before = set(tmp_path.iterdir()) + + assert tasks.export_and_upload_sqlite(tmp_path) + assert files_before | 
{sqlite_file_path} == set(tmp_path.iterdir()) + + +def test_s3_export_task_ignores_unpublished_and_unapproved_transactions( sqlite_storage, s3_object_names, settings, ): - """Only transactions that have been approved should be included in the - upload as draft data may be sensitive and unpublished, and shouldn't be - included.""" + """Only transactions that have been published should be included in the + upload as draft and queued data may be sensitive and unpublished, and should + therefore not be included.""" factories.SeedFileTransactionFactory.create(order="999") transaction = factories.PublishedTransactionFactory.create(order="123") factories.ApprovedTransactionFactory.create(order="124") @@ -198,7 +232,7 @@ def test_export_task_ignores_unpublished_and_unapproved_transactions( names_before = s3_object_names(sqlite_storage.bucket_name) with mock.patch( - "exporter.sqlite.tasks.SQLiteStorage", + "exporter.sqlite.tasks.storages.SQLiteS3Storage", new=lambda: sqlite_storage, ): returned = tasks.export_and_upload_sqlite() @@ -206,3 +240,19 @@ def test_export_task_ignores_unpublished_and_unapproved_transactions( names_after = s3_object_names(sqlite_storage.bucket_name) assert names_before == names_after + + +def test_local_export_task_ignores_unpublished_and_unapproved_transactions(tmp_path): + """Only transactions that have been published should be included in the + upload as draft and queued data may be sensitive and unpublished, and should + therefore not be included.""" + factories.SeedFileTransactionFactory.create(order="999") + transaction = factories.PublishedTransactionFactory.create(order="123") + factories.ApprovedTransactionFactory.create(order="124") + factories.UnapprovedTransactionFactory.create(order="125") + + sqlite_file_path = tmp_path / f"{tasks.normalised_order(transaction.order)}.db" + files_before = set(tmp_path.iterdir()) + + assert tasks.export_and_upload_sqlite(tmp_path) + assert files_before | {sqlite_file_path} == set(tmp_path.iterdir()) diff --git a/settings/common.py b/settings/common.py index 2d0036fca..b3728ef59 100644 --- a/settings/common.py +++ b/settings/common.py @@ -606,27 +606,32 @@ # Lock expires in 10 minutes CROWN_DEPENDENCIES_API_TASK_LOCK = 60 * 10 -CROWN_DEPENDENCIES_API_CRON = ( - crontab(os.environ.get("CROWN_DEPENDENCIES_API_CRON")) - if os.environ.get("CROWN_DEPENDENCIES_API_CRON") - else crontab(minute="0", hour="8-18/2", day_of_week="mon-fri") -) +CELERY_BEAT_SCHEDULE = {} -# `SQLITE_EXPORT_CRONTAB` sets the time, in crontab format, that an Sqlite -# snapshot task is scheduled by Celery Beat for execution by a Celery task. -# (See https://en.wikipedia.org/wiki/Cron for format description.) -SQLITE_EXPORT_CRONTAB = os.environ.get("SQLITE_EXPORT_CRONTAB", "05 19 * * *") -CELERY_BEAT_SCHEDULE = { - "sqlite_export": { +ENABLE_SQLITE_EXPORT_SCHEDULE = is_truthy( + os.environ.get("ENABLE_SQLITE_EXPORT_SCHEDULE", "True"), +) +if ENABLE_SQLITE_EXPORT_SCHEDULE: + # `SQLITE_EXPORT_CRONTAB` sets the time, in crontab format, that an Sqlite + # snapshot task is scheduled by Celery Beat for execution by a Celery task. + # (See https://en.wikipedia.org/wiki/Cron for format description.) 
+ SQLITE_EXPORT_CRONTAB = os.environ.get( + "SQLITE_EXPORT_CRONTAB", + "05 19 * * *", + ) + CELERY_BEAT_SCHEDULE["sqlite_export"] = { "task": "exporter.sqlite.tasks.export_and_upload_sqlite", "schedule": crontab(*SQLITE_EXPORT_CRONTAB.split()), - }, -} + } if ENABLE_CROWN_DEPENDENCIES_PUBLISHING: + CROWN_DEPENDENCIES_API_CRON = ( + crontab(os.environ.get("CROWN_DEPENDENCIES_API_CRON")) + if os.environ.get("CROWN_DEPENDENCIES_API_CRON") + else crontab(minute="0", hour="8-18/2", day_of_week="mon-fri") + ) CELERY_BEAT_SCHEDULE["crown_dependencies_api_publish"] = { "task": "publishing.tasks.publish_to_api", - # every 2 hours between 8am and 6pm on weekdays "schedule": CROWN_DEPENDENCIES_API_CRON, } diff --git a/workbaskets/tests/test_views.py b/workbaskets/tests/test_views.py index d2c66ced6..7ef5d7dfc 100644 --- a/workbaskets/tests/test_views.py +++ b/workbaskets/tests/test_views.py @@ -2562,10 +2562,10 @@ def test_clean_tasks(): ) -def test_current_rule_checks_is_called(valid_user_client): - """Test that current_rule_checks function gets called when a user goes to - the rule check page and the page correctly displays the returned list of - rule check tasks.""" +def test_current_tasks_is_called(valid_user_client): + """Test that current_tasks function gets called when a user goes to the rule + check page and the page correctly displays the returned list of rule check + tasks.""" return_value = [ CeleryTask( @@ -2581,13 +2581,13 @@ def test_current_rule_checks_is_called(valid_user_client): with patch.object( TAPTasks, - "current_rule_checks", + "current_tasks", return_value=return_value, - ) as mock_current_rule_checks: + ) as mock_current_tasks: response = valid_user_client.get(reverse("workbaskets:rule-check-queue")) assert response.status_code == 200 - # Assert current_rule_checks gets called - mock_current_rule_checks.assert_called_once() + # Assert current_tasks gets called + mock_current_tasks.assert_called_once() # Assert the mocked response is formatted correctly on the page soup = BeautifulSoup(str(response.content), "html.parser") table_rows = [element for element in soup.select(".govuk-table__row")] diff --git a/workbaskets/views/ui.py b/workbaskets/views/ui.py index 82fe033d7..01b77b4c9 100644 --- a/workbaskets/views/ui.py +++ b/workbaskets/views/ui.py @@ -1743,10 +1743,9 @@ def get_context_data(self, **kwargs): tap_tasks = TAPTasks() try: context["celery_healthy"] = True - current_rule_checks = tap_tasks.current_rule_checks( + context["current_rule_checks"] = tap_tasks.current_tasks( "workbaskets.tasks.call_check_workbasket_sync", ) - context["current_rule_checks"] = current_rule_checks context["status_tag_generator"] = self.status_tag_generator except kombu.exceptions.OperationalError as oe: context["celery_healthy"] = False
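
For clarity on the scheduling change in `settings/common.py`, a small illustrative sketch (not part of the change set) of what the opt-in beat schedule resolves to with the default crontab string:

```python
# Sketch only -- the sqlite_export beat entry with the default crontab string
# "05 19 * * *", i.e. a daily run at 19:05.
from celery.schedules import crontab

SQLITE_EXPORT_CRONTAB = "05 19 * * *"
CELERY_BEAT_SCHEDULE = {
    "sqlite_export": {
        "task": "exporter.sqlite.tasks.export_and_upload_sqlite",
        "schedule": crontab(*SQLITE_EXPORT_CRONTAB.split()),
    },
}
# With ENABLE_SQLITE_EXPORT_SCHEDULE set to a falsy value in the environment,
# CELERY_BEAT_SCHEDULE is left without the "sqlite_export" entry altogether.
```
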