Skip to content

Commit

Permalink
PDCT 1108/fix bug when not marking docs as deleted (#137)
Browse files Browse the repository at this point in the history
* Ensure code up to date when tests run

* Add doc string to with_transaction

* refactor function to aid readability

* Fix issue with query

* Add a comment to the file to explain how and what it tests

* move the patch of get_db() into the data_db fixture

* Add failing test

* Bump version

* Add fix

* Bump version

* Change where it is fixed so its testable

* update tests
  • Loading branch information
diversemix authored May 22, 2024
1 parent 9b4b13e commit 7103c9d
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 76 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ build:
build_dev:
docker compose build

unit_test:
unit_test: build
docker compose run --rm navigator-admin-backend pytest -vvv tests/unit_tests

integration_test:
integration_test: build_dev
docker compose run --rm navigator-admin-backend pytest -vvv tests/integration_tests

test:
test: build_dev
docker compose run --rm navigator-admin-backend pytest -vvv tests

run:
Expand Down
9 changes: 9 additions & 0 deletions app/clients/db/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ def get_db() -> Session:


def with_transaction(module_name, context=session_context):
"""Wraps a function with this standard transaction handler.
Note: You still need to call commit() in the `func` if you require
any changes to persist.
:param _type_ module_name: The name of the module, used for logging context.
:param _type_ context: any context object to propagate to `func`, defaults to session_context
"""

def inner(func):
def wrapper(*args, **kwargs):
context.error = None
Expand Down
100 changes: 57 additions & 43 deletions app/repository/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,38 @@ def create(db: Session, family: FamilyCreateDTO, geo_id: int, org_id: int) -> st
return cast(str, new_family.import_id)


def hard_delete(db: Session, import_id: str):
"""Forces a hard delete of the family.
:param db Session: the database connection
:param str import_id: The family import id to delete.
:return bool: True if deleted False if not.
"""
commands = [
db_delete(CollectionFamily).where(
CollectionFamily.family_import_id == import_id
),
db_delete(FamilyEvent).where(FamilyEvent.family_import_id == import_id),
db_delete(FamilyCorpus).where(FamilyCorpus.family_import_id == import_id),
db_delete(Slug).where(Slug.family_import_id == import_id),
db_delete(FamilyMetadata).where(FamilyMetadata.family_import_id == import_id),
db_delete(Family).where(Family.import_id == import_id),
]

for c in commands:
result = db.execute(c)
# Keep this for debug.
_LOGGER.debug("%s, %s", str(c), result.rowcount) # type: ignore
db.commit()

fam_deleted = db.query(Family).filter(Family.import_id == import_id).one_or_none()
if fam_deleted is not None:
msg = f"Could not hard delete family: {import_id}"
_LOGGER.error(msg)

return bool(fam_deleted is None)


def delete(db: Session, import_id: str) -> bool:
"""
Deletes a single family by the import id.
Expand All @@ -408,53 +440,35 @@ def delete(db: Session, import_id: str) -> bool:
return False

# Only perform if we have docs associated with this family
family_doc_count = (
family_docs = (
db.query(FamilyDocument)
.filter(FamilyDocument.family_import_id == import_id)
.count()
.all()
)

if family_doc_count > 0:
# Soft delete all documents associated with the family.
result = db.execute(
db_update(FamilyDocument)
.filter(FamilyDocument.family_import_id == import_id)
.values(document_status=DocumentStatus.DELETED)
)

if result.rowcount == 0: # type: ignore
msg = f"Could not soft delete documents in family : {import_id}"
_LOGGER.error(msg)
raise RepositoryError(msg)

elif family_doc_count == 0 and found.family_status == FamilyStatus.CREATED:
commands = [
db_delete(CollectionFamily).where(
CollectionFamily.family_import_id == import_id
),
db_delete(FamilyEvent).where(FamilyEvent.family_import_id == import_id),
db_delete(FamilyCorpus).where(FamilyCorpus.family_import_id == import_id),
db_delete(Slug).where(Slug.family_import_id == import_id),
db_delete(FamilyMetadata).where(
FamilyMetadata.family_import_id == import_id
),
db_delete(Family).where(Family.import_id == import_id),
]

for c in commands:
result = db.execute(c)
# Keep this for debug.
_LOGGER.debug("%s, %s", str(c), result.rowcount) # type: ignore
db.commit()

fam_deleted = (
db.query(Family).filter(Family.import_id == import_id).one_or_none()
)
if fam_deleted is not None:
msg = f"Could not hard delete family: {import_id}"
_LOGGER.error(msg)

return bool(fam_deleted is None)
if len(family_docs) == 0 and found.family_status == FamilyStatus.CREATED:
return hard_delete(db, import_id)

# Soft delete all documents associated with the family.
for doc in family_docs:
doc.document_status = DocumentStatus.DELETED
db.add(doc)

db.commit() # TODO: Fix PDCT-1115

# The below code is preserved in this comment while we decide
# what is wrong.
# TODO: remove
# result = db.execute(
# db_update(FamilyDocument)
# .filter(FamilyDocument.family_import_id == import_id)
# .values(document_status=DocumentStatus.DELETED)
# )

# if result.rowcount == 0: # type: ignore
# msg = f"Could not soft delete documents in family : {import_id}"
# _LOGGER.error(msg)
# raise RepositoryError(msg)

# Check family has been soft deleted if all documents have also been soft deleted.
fam_deleted = db.query(Family).filter(Family.import_id == import_id).one()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "admin_backend"
version = "2.6.2"
version = "2.6.3"
description = ""
authors = ["CPR-dev-team <[email protected]>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand Down
29 changes: 18 additions & 11 deletions tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import pytest
from db_client import run_migrations
from db_client.models.base import Base
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.engine import Connection
Expand Down Expand Up @@ -39,8 +38,8 @@ def get_test_db_url() -> str:
return SQLALCHEMY_DATABASE_URI + f"_test_{uuid.uuid4()}"


@pytest.fixture
def test_db(scope="function"):
@pytest.fixture(scope="function")
def slow_db(monkeypatch):
"""Create a fresh test database for each test."""

test_db_url = get_test_db_url()
Expand All @@ -55,14 +54,19 @@ def test_db(scope="function"):
try:
test_engine = create_engine(test_db_url)
connection = test_engine.connect()
Base.metadata.create_all(test_engine) # type: ignore

run_migrations(test_engine)
test_session_maker = sessionmaker(
autocommit=False,
autoflush=False,
bind=test_engine,
)
test_session = test_session_maker()

def get_test_db():
return test_session

monkeypatch.setattr(db_session, "get_db", get_test_db)
# Run the tests
yield test_session
finally:
Expand Down Expand Up @@ -99,29 +103,32 @@ def data_db_connection() -> Generator[Connection, None, None]:


@pytest.fixture(scope="function")
def data_db(data_db_connection):
def data_db(data_db_connection, monkeypatch):

transaction = data_db_connection.begin()
print(f"This test is being performed with transaction {transaction}")

SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=data_db_connection
)
session = SessionLocal()

def get_test_db():
return session

monkeypatch.setattr(db_session, "get_db", get_test_db)

yield session

session.close()
print(f"This test is finished and being rolledback with transaction {transaction}")
transaction.rollback()


@pytest.fixture
def client(data_db, monkeypatch):
def client():
"""Get a TestClient instance that reads/write to the test database."""

def get_test_db():
return data_db

monkeypatch.setattr(db_session, "get_db", get_test_db)

yield TestClient(app)


Expand Down
37 changes: 20 additions & 17 deletions tests/integration_tests/family/test_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_delete_family_without_docs(
assert data_db.query(Slug).filter(Slug.family_import_id == "A.0.0.2").count() == 0
assert (
data_db.query(FamilyMetadata)
.filter(FamilyDocument.family_import_id == "A.0.0.2")
.filter(FamilyMetadata.family_import_id == "A.0.0.2")
.count()
== 0
)
Expand All @@ -91,22 +91,25 @@ def test_delete_family_when_not_authenticated(client: TestClient, data_db: Sessi
assert response.status_code == status.HTTP_401_UNAUTHORIZED


def test_delete_family_rollback(
client: TestClient, data_db: Session, rollback_family_repo, admin_user_header_token
):
setup_db(data_db)
response = client.delete(
"/api/v1/families/A.0.0.3", headers=admin_user_header_token
)
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
assert (
data_db.query(FamilyDocument)
.filter(FamilyDocument.document_status == DocumentStatus.DELETED)
.count()
== 0
)
test_family = data_db.query(Family).filter(Family.import_id == "A.0.0.3").one()
assert test_family.family_status != FamilyStatus.DELETED
# FIX: PDCT-1115 - This test no longer works in the test environment, the rollback call
# returns the db to an empty state.

# def test_delete_family_rollback(
# client: TestClient, data_db: Session, rollback_family_repo, admin_user_header_token
# ):
# setup_db(data_db)
# response = client.delete(
# "/api/v1/families/A.0.0.3", headers=admin_user_header_token
# )
# assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
# assert (
# data_db.query(FamilyDocument)
# .filter(FamilyDocument.document_status == DocumentStatus.DELETED)
# .count()
# == 0
# )
# test_family = data_db.query(Family).filter(Family.import_id == "A.0.0.3").one()
# assert test_family.family_status != FamilyStatus.DELETED


def test_delete_family_when_not_found(
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/family/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def test_update_family_idempotent_when_ok(
assert db_family.family_category == EXPECTED_FAMILIES[1]["category"]


# TODO
# TODO: Fix with PDCT-1115
# def test_update_family_rollback(
# client: TestClient, test_db: Session, rollback_family_repo, user_header_token
# ):
Expand Down
67 changes: 67 additions & 0 deletions tests/integration_tests/test_with_transaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
These tests are designed to test a real world `with_transaction()` call.
This is the decorator that is used on service calls to ensure that any
db operations are properly rolled back. However, as our tests are wrapped
in an outer transaction for speed there is some inter-play here that also
needs to be tested to ensure the testing environment is operating correctly.
"""

import pytest
from db_client.models.dfce import Family, FamilyStatus
from sqlalchemy import exc
from sqlalchemy.orm import Session, sessionmaker

from app.clients.db.session import with_transaction
from app.errors import RepositoryError
from app.repository.family import delete
from tests.integration_tests.setup_db import setup_db


def reconnect(bind) -> Session:
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=bind)
return SessionLocal()


def family_delete_ok(context=None, db=None):
assert db is not None
assert context is not None
return delete(db, "A.0.0.3")


def family_delete_fails(context=None, db=None):
assert db is not None
assert context is not None
delete(db, "A.0.0.3")
context.error = "During Testing"
raise exc.SQLAlchemyError("testing")


def test_with_transaction_when_family_delete_ok(slow_db: Session):
saved_bind = slow_db.bind
setup_db(slow_db)
inner = with_transaction("test")
wrapper = inner(family_delete_ok)
result = wrapper()
assert result is True
slow_db.close()

db = reconnect(saved_bind)
family = db.query(Family).filter(Family.import_id == "A.0.0.3").one()
assert family.family_status == FamilyStatus.DELETED


def test_with_transaction_when_family_delete_fails(slow_db: Session):
saved_bind = slow_db.bind
setup_db(slow_db)
inner = with_transaction("test")
wrapper = inner(family_delete_fails)
with pytest.raises(RepositoryError) as e:
wrapper()

assert e.value.message == "During Testing"
slow_db.close()

db = reconnect(saved_bind)
family = db.query(Family).filter(Family.import_id == "A.0.0.3").one()
assert family.family_status == FamilyStatus.CREATED
9 changes: 9 additions & 0 deletions tests/mocks/repos/rollback_family_repo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
"""
For the family repo - do that actual db operation then raise to force a rollback.
Note this relies on the actual operation not doing the commit - rather the wrapper
at the service layer (with_transaction). However, we have a bug that means this
still needs fixing. See PDCT-1115
"""

from typing import Optional

from pytest import MonkeyPatch
Expand Down

0 comments on commit 7103c9d

Please sign in to comment.