Skip to content

Commit 1b5f1c8

Browse files
committed
Add raw SQL tracking for native versioning
1 parent a6ee0ce commit 1b5f1c8

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

sqlalchemy_continuum/dialects/postgresql.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515
UPDATE {version_table_name}
1616
SET {update_values}
1717
WHERE
18-
{transaction_column} = (
19-
SELECT MAX(id) FROM {transaction_table_name}
20-
WHERE native_tx_id = txid_current()
21-
)
18+
{transaction_column} = transaction_id_value
2219
AND
2320
{primary_key_criteria}
2421
RETURNING *
@@ -37,7 +34,17 @@
3734

3835
procedure_sql = """
3936
CREATE OR REPLACE FUNCTION {procedure_name}() RETURNS TRIGGER AS $$
37+
DECLARE transaction_id_value INT;
4038
BEGIN
39+
transaction_id_value = (
40+
SELECT MAX(id) FROM {transaction_table_name}
41+
WHERE native_tx_id = txid_current()
42+
);
43+
IF (transaction_id_value IS NULL) THEN
44+
INSERT INTO transaction (native_tx_id)
45+
VALUES (txid_current()) RETURNING id INTO transaction_id_value;
46+
END IF;
47+
4148
IF (TG_OP = 'INSERT') THEN
4249
{after_insert}
4350
{upsert_insert}
@@ -61,10 +68,7 @@
6168

6269
validity_sql = """
6370
UPDATE {version_table_name}
64-
SET {end_transaction_column} = (
65-
SELECT MAX(id) FROM {transaction_table_name}
66-
WHERE native_tx_id = txid_current()
67-
)
71+
SET {end_transaction_column} = transaction_id_value
6872
WHERE
6973
{transaction_column} = (
7074
SELECT MIN({transaction_column}) FROM {version_table_name}
@@ -351,6 +355,7 @@ def __str__(self):
351355
excluded_columns=', '.join(
352356
"'%s'" % c for c in self.excluded_columns
353357
),
358+
transaction_table_name=self.transaction_table_name,
354359
after_insert=after_insert,
355360
after_update=after_update,
356361
after_delete=after_delete,

tests/test_raw_sql.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
from tests import TestCase, uses_native_versioning
3+
4+
5+
@pytest.mark.skipif('not uses_native_versioning()')
6+
class TestRawSQL(TestCase):
7+
def test_single_statement(self):
8+
self.session.execute(
9+
"INSERT INTO article (name) VALUES ('some article')"
10+
)
11+
assert self.session.execute(
12+
"SELECT COUNT(1) FROM transaction"
13+
).scalar() == 1
14+
15+
def test_multiple_statements(self):
16+
self.session.execute(
17+
"INSERT INTO article (name) VALUES ('some article')"
18+
)
19+
self.session.execute(
20+
"INSERT INTO article (name) VALUES ('some article')"
21+
)
22+
assert self.session.execute(
23+
"SELECT COUNT(1) FROM transaction"
24+
).scalar() == 1

0 commit comments

Comments
 (0)