Skip to content

Commit

Permalink
refactor: Use transactional DB tests
Browse files Browse the repository at this point in the history
Make use of DB transactions to rollback tests back to the original state
instead of recreating DB on each test.
  • Loading branch information
gbdlin committed Jan 13, 2024
1 parent adfb798 commit 3393fd5
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 62 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ deps/upgrade:
poetry lock

test:
pytest ./tests
SQLALCHEMY_WARN_20=1 pytest ./tests
4 changes: 0 additions & 4 deletions docker-compose.local.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ services:
redis_db:
image: redis:latest
restart: unless-stopped
ports:
- "6379:6379"
environment:
- REDIS_PORT=6379

Expand All @@ -32,7 +30,5 @@ services:
- POSTGRES_DB=decky
- POSTGRES_USER=decky
- POSTGRES_PASSWORD=decky
ports:
- '5432:5432'
volumes:
- ../store-postgres:/var/lib/postgresql/data
71 changes: 26 additions & 45 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datetime import datetime, UTC
from os import getenv
from pathlib import Path
from typing import TYPE_CHECKING
Expand All @@ -7,14 +6,18 @@
import pytest_asyncio
from httpx import AsyncClient
from pytest_mock import MockFixture
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import sessionmaker

import main
from api import database as db_dependency
from database.database import Database
from database.models import Base
from db_helpers import FakePluginGenerator
from db_helpers import (
create_test_db_engine,
create_test_db_sessionmaker,
prepare_test_db,
prepare_transactioned_db_session,
)

if TYPE_CHECKING:
from typing import AsyncIterator
Expand All @@ -33,7 +36,8 @@ def mock_external_services(session_mocker: "MockFixture"):
"cdn.fetch_image",
return_value=((DUMMY_DATA_PATH / "plugin-image.png").read_bytes(), "image/png"),
)
session_mocker.patch("discord.AsyncDiscordWebhook", new=session_mocker.AsyncMock)
discord_mock = session_mocker.patch("discord.AsyncDiscordWebhook", new=session_mocker.AsyncMock)
discord_mock.add_embed = session_mocker.Mock()


@pytest.fixture(scope="session", autouse=True)
Expand All @@ -45,8 +49,7 @@ def mock_constants(session_mocker: "MockFixture"):


@pytest.fixture()
def plugin_store(db: "Database") -> "FastAPI":
main.app.dependency_overrides[db_dependency] = lambda: db
def plugin_store() -> "FastAPI":
return main.app


Expand All @@ -65,52 +68,30 @@ async def client_auth(client_unauth: "AsyncClient") -> "AsyncClient":
return client_unauth


@pytest.fixture()
def db_engine():
return create_async_engine(
getenv("DB_URL"),
pool_pre_ping=True,
# echo=settings.ECHO_SQL,
)


@pytest.fixture()
def db_sessionmaker(db_engine):
return sessionmaker(bind=db_engine, autoflush=False, future=True, expire_on_commit=False, class_=AsyncSession)
@pytest_asyncio.fixture(scope="session")
async def seed_db_engine() -> tuple["AsyncEngine", "sessionmaker"]:
engine = create_test_db_engine()
db_sessionmaker = create_test_db_sessionmaker(engine)
await prepare_test_db(engine, db_sessionmaker, True)
return engine, db_sessionmaker


@pytest_asyncio.fixture()
async def _migrate_db(db_engine):
async with db_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
@pytest.fixture(scope="session")
def seed_db_sessionmaker(seed_db_engine: tuple["AsyncEngine", "sessionmaker"]) -> "sessionmaker":
return seed_db_engine[1]


@pytest_asyncio.fixture()
async def db(_migrate_db: None, db_sessionmaker: sessionmaker, mocker: "MockFixture") -> "Database":
return Database(db_sessionmaker(), lock=mocker.MagicMock())
async def seed_db_session(seed_db_engine: tuple["AsyncEngine", "sessionmaker"]) -> "AsyncSession":
async with prepare_transactioned_db_session(*seed_db_engine) as session:
yield session


@pytest_asyncio.fixture()
async def seed_db(db: "Database", db_sessionmaker: "sessionmaker") -> "Database":
session = db_sessionmaker()
generator = FakePluginGenerator(session, datetime(2022, 2, 25, 0, 0, 0, tzinfo=UTC))
await generator.create(tags=["tag-1", "tag-2"], versions=["0.1.0", "0.2.0", "1.0.0"])
generator.date = datetime(2022, 2, 25, 0, 1, 0, 0, tzinfo=UTC)
await generator.create(image="2.png", tags=["tag-2"], versions=["1.1.0", "2.0.0"])
generator.date = datetime(2022, 2, 25, 0, 2, 0, 0, tzinfo=UTC)
await generator.create("third", tags=["tag-2", "tag-3"], versions=["3.0.0", "3.1.0", "3.2.0"])
generator.date = datetime(2022, 2, 25, 0, 3, 0, 0, tzinfo=UTC)
await generator.create(tags=["tag-1", "tag-3"], versions=["1.0.0", "2.0.0", "3.0.0", "4.0.0"])
generator.date = datetime(2022, 2, 25, 0, 4, 0, 0, tzinfo=UTC)
await generator.create(tags=["tag-1", "tag-2"], versions=["0.1.0", "0.2.0", "1.0.0"], visible=False)
generator.date = datetime(2022, 2, 25, 0, 5, 0, 0, tzinfo=UTC)
await generator.create(image="6.png", tags=["tag-2"], versions=["1.1.0", "2.0.0"], visible=False)
generator.date = datetime(2022, 2, 25, 0, 6, 0, 0, tzinfo=UTC)
await generator.create("seventh", tags=["tag-2", "tag-3"], versions=["3.0.0", "3.1.0", "3.2.0"], visible=False)
generator.date = datetime(2022, 2, 25, 0, 7, 0, 0, tzinfo=UTC)
await generator.create(tags=["tag-1", "tag-3"], versions=["1.0.0", "2.0.0", "3.0.0", "4.0.0"], visible=False)

return db
async def seed_db(plugin_store: "FastAPI", seed_db_session: "AsyncSession", mocker: "MockFixture") -> "Database":
database = Database(seed_db_session, lock=mocker.MagicMock())
main.app.dependency_overrides[db_dependency] = lambda: database
return database


@pytest.fixture()
Expand Down
94 changes: 88 additions & 6 deletions tests/db_helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import random

from contextlib import asynccontextmanager
from datetime import datetime, timedelta, UTC
from hashlib import sha256
from os import getenv

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import select

from database.models import Artifact, Tag, Version
from database.models import Artifact, Base, Tag, Version


class FakePluginGenerator:
Expand Down Expand Up @@ -41,7 +46,7 @@ async def _create_plugin(
image_path,
tags,
visible,
id=None,
id_=None,
):
plugin = Artifact(
name=name,
Expand All @@ -51,8 +56,8 @@ async def _create_plugin(
tags=tags,
visible=visible,
)
if id is not None:
plugin.id = id
if id_ is not None:
plugin.id = id_
self.session.add(plugin)
await self.session.commit()
return plugin
Expand Down Expand Up @@ -107,3 +112,80 @@ async def create(
await self._create_versions(plugin, versions)

self.created_plugins_count += 1


def create_test_db_engine() -> "AsyncEngine":
return create_async_engine(
getenv("DB_URL"),
pool_pre_ping=True,
# echo=True,
)


def create_test_db_sessionmaker(engine: "AsyncEngine") -> "sessionmaker":
return sessionmaker(
bind=engine,
autoflush=False,
future=True,
expire_on_commit=False,
autocommit=False,
class_=AsyncSession,
)


async def migrate_test_db(engine: "AsyncEngine") -> None:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)


async def seed_test_db(db_sessionmaker: "sessionmaker") -> None:
session = db_sessionmaker()
generator = FakePluginGenerator(session, datetime(2022, 2, 25, 0, 0, 0, tzinfo=UTC))
await generator.create(tags=["tag-1", "tag-2"], versions=["0.1.0", "0.2.0", "1.0.0"])
generator.date = datetime(2022, 2, 25, 0, 1, 0, 0, tzinfo=UTC)
await generator.create(image="2.png", tags=["tag-2"], versions=["1.1.0", "2.0.0"])
generator.date = datetime(2022, 2, 25, 0, 2, 0, 0, tzinfo=UTC)
await generator.create("third", tags=["tag-2", "tag-3"], versions=["3.0.0", "3.1.0", "3.2.0"])
generator.date = datetime(2022, 2, 25, 0, 3, 0, 0, tzinfo=UTC)
await generator.create(tags=["tag-1", "tag-3"], versions=["1.0.0", "2.0.0", "3.0.0", "4.0.0"])
generator.date = datetime(2022, 2, 25, 0, 4, 0, 0, tzinfo=UTC)
await generator.create(tags=["tag-1", "tag-2"], versions=["0.1.0", "0.2.0", "1.0.0"], visible=False)
generator.date = datetime(2022, 2, 25, 0, 5, 0, 0, tzinfo=UTC)
await generator.create(image="6.png", tags=["tag-2"], versions=["1.1.0", "2.0.0"], visible=False)
generator.date = datetime(2022, 2, 25, 0, 6, 0, 0, tzinfo=UTC)
await generator.create("seventh", tags=["tag-2", "tag-3"], versions=["3.0.0", "3.1.0", "3.2.0"], visible=False)
generator.date = datetime(2022, 2, 25, 0, 7, 0, 0, tzinfo=UTC)
await generator.create(tags=["tag-1", "tag-3"], versions=["1.0.0", "2.0.0", "3.0.0", "4.0.0"], visible=False)


async def prepare_test_db(
engine: "AsyncEngine",
db_sessionmaker: "sessionmaker",
seed: bool = False,
) -> None:
await migrate_test_db(engine)
if seed:
await seed_test_db(db_sessionmaker)


@asynccontextmanager
async def prepare_transactioned_db_session(engine: "AsyncEngine", db_sessionmaker: "sessionmaker") -> "AsyncSession":
connection = await engine.connect()
outer_transaction = await connection.begin()
async_session = db_sessionmaker(bind=connection)
# seems like for sqlite releasing last savepoint commits the whole transaction. This should fix that.
await connection.begin_nested()
nested = await connection.begin_nested()

@event.listens_for(async_session.sync_session, "after_transaction_end")
def end_savepoint(session, transaction):
nonlocal nested

if not nested.is_active:
nested = connection.sync_connection.begin_nested()

yield async_session

await outer_transaction.rollback()
await async_session.close()
await connection.close()
12 changes: 6 additions & 6 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,9 @@ async def test_submit_endpoint_requires_auth(client_unauth: "AsyncClient"):
),
[
(
lazy_fixture("db"),
lazy_fixture("seed_db"),
"new-plugin",
1,
9,
"new-plugin",
status.HTTP_201_CREATED,
[
Expand Down Expand Up @@ -664,9 +664,9 @@ async def test_submit_endpoint(
assert actual.downloads == 0
assert actual.updates == 0

statement = select(Tag).where(Tag.tag == "tag-1").with_only_columns([func.count()]).order_by(None)
statement = select(Tag).where(Tag.tag == "tag-1").with_only_columns(func.count()).order_by(None)
assert (await db_fixture.session.execute(statement)).scalar() == 1
statement = select(Tag).where(Tag.tag == "new-tag-2").with_only_columns([func.count()]).order_by(None)
statement = select(Tag).where(Tag.tag == "new-tag-2").with_only_columns(func.count()).order_by(None)
assert (await db_fixture.session.execute(statement)).scalar() == 1

list_response = await client_auth.get("/plugins")
Expand Down Expand Up @@ -813,9 +813,9 @@ async def test_update_endpoint(
assert actual.hash == expected["hash"]
assert actual.created.isoformat().replace("+00:00", "Z") == expected["created"] # type:ignore[union-attr]

statement = select(Tag).where(Tag.tag == "new-tag-1").with_only_columns([func.count()]).order_by(None)
statement = select(Tag).where(Tag.tag == "new-tag-1").with_only_columns(func.count()).order_by(None)
assert (await seed_db.session.execute(statement)).scalar() == 1
statement = select(Tag).where(Tag.tag == "tag-2").with_only_columns([func.count()]).order_by(None)
statement = select(Tag).where(Tag.tag == "tag-2").with_only_columns(func.count()).order_by(None)
assert (await seed_db.session.execute(statement)).scalar() == 1


Expand Down

0 comments on commit 3393fd5

Please sign in to comment.