1
- import os
2
- import uuid
3
- from typing import Dict , Generator
1
+ import subprocess
2
+ import tempfile
3
+ from typing import Dict
4
4
5
5
import boto3
6
6
import pytest
9
9
from fastapi .testclient import TestClient
10
10
from moto import mock_s3
11
11
from sqlalchemy import create_engine
12
- from sqlalchemy .engine import Connection
13
12
from sqlalchemy .orm import sessionmaker
14
13
from sqlalchemy_utils import create_database , database_exists , drop_database
15
14
40
39
UNFCCC_ORG_ID = 2
41
40
SUPER_ORG_ID = 50
42
41
42
+ migration_file = None
43
43
44
- def get_test_db_url () -> str :
45
- return SQLALCHEMY_DATABASE_URI + f"_test_{ uuid .uuid4 ()} "
46
44
45
+ def _create_engine_run_migrations (test_db_url : str ):
46
+ test_engine = create_engine (test_db_url )
47
+ run_migrations (test_engine )
48
+ return test_engine
47
49
48
- @pytest .fixture (scope = "function" )
49
- def slow_db (monkeypatch ):
50
- """Create a fresh test database for each test."""
51
50
52
- test_db_url = get_test_db_url ()
51
+ def do_cached_migrations (test_db_url : str ):
52
+
53
+ global migration_file # Note this is scoped to the module, so it will not get recreated.
53
54
54
55
# Create the test database
55
56
if database_exists (test_db_url ):
56
57
drop_database (test_db_url )
57
58
create_database (test_db_url )
58
59
60
+ test_engine = None
61
+
62
+ if not migration_file :
63
+ test_engine = _create_engine_run_migrations (test_db_url )
64
+ migration_file = tempfile .NamedTemporaryFile ().name
65
+ result = subprocess .run (["pg_dump" , "-f" , migration_file ])
66
+ assert result .returncode == 0
67
+ else :
68
+ result = subprocess .run (["psql" , "-f" , migration_file ])
69
+ assert result .returncode == 0
70
+ test_engine = create_engine (test_db_url )
71
+
72
+ return test_engine
73
+
74
+
75
+ @pytest .fixture (scope = "function" )
76
+ def data_db (monkeypatch ):
77
+ """Create a fresh test database for each test."""
78
+
79
+ test_db_url = SQLALCHEMY_DATABASE_URI # Use the same db - cannot parrallelize tests
80
+
81
+ test_engine = do_cached_migrations (test_db_url )
59
82
test_session = None
60
83
connection = None
61
84
try :
62
- test_engine = create_engine (test_db_url )
63
- connection = test_engine .connect ()
64
85
65
- run_migrations ( test_engine )
86
+ connection = test_engine . connect ( )
66
87
test_session_maker = sessionmaker (
67
88
autocommit = False ,
68
89
autoflush = False ,
@@ -86,58 +107,6 @@ def get_test_db():
86
107
drop_database (test_db_url )
87
108
88
109
89
- @pytest .fixture (scope = "session" )
90
- def data_db_connection () -> Generator [Connection , None , None ]:
91
- test_db_url = get_test_db_url ()
92
-
93
- if database_exists (test_db_url ):
94
- drop_database (test_db_url )
95
- create_database (test_db_url )
96
-
97
- saved_db_url = os .environ ["DATABASE_URL" ]
98
- os .environ ["DATABASE_URL" ] = test_db_url
99
-
100
- test_engine = create_engine (test_db_url )
101
-
102
- run_migrations (test_engine )
103
- connection = test_engine .connect ()
104
-
105
- yield connection
106
- connection .close ()
107
-
108
- os .environ ["DATABASE_URL" ] = saved_db_url
109
- drop_database (test_db_url )
110
-
111
-
112
- @pytest .fixture (scope = "function" )
113
- def data_db (slow_db ):
114
- yield slow_db
115
-
116
-
117
- # @pytest.fixture(scope="function")
118
- # def data_db(data_db_connection, monkeypatch):
119
-
120
- # outer = data_db_connection.begin_nested()
121
- # SessionLocal = sessionmaker(
122
- # autocommit=False, autoflush=False, bind=data_db_connection
123
- # )
124
- # session = SessionLocal()
125
-
126
- # def get_test_db():
127
- # return session
128
-
129
- # monkeypatch.setattr(db_session, "get_db", get_test_db)
130
- # yield session
131
- # if not outer.is_active:
132
- # print("Outer transaction already completed.")
133
- # #raise RuntimeError("Outer transaction already completed.")
134
- # else:
135
- # outer.rollback()
136
- # n_cols = data_db_connection.execute("select count(*) from collection")
137
- # if n_cols.scalar() != 0:
138
- # raise RuntimeError("Database not cleaned up properly")
139
-
140
-
141
110
@pytest .fixture
142
111
def client ():
143
112
"""Get a TestClient instance that reads/write to the test database."""
0 commit comments