39 changes: 26 additions & 13 deletions src/psycopack/_commands.py
@@ -74,17 +74,18 @@ def drop_sequence_if_exists(self, *, seq: str) -> None:
             .as_string(self.conn)
         )

-    def create_sequence(self, *, seq: str, bigint: bool) -> None:
+    def create_sequence(self, *, seq: str, bigint: bool, minvalue: int) -> None:
         if bigint:
-            sql = "CREATE SEQUENCE {schema}.{seq} AS BIGINT;"
+            sql = "CREATE SEQUENCE {schema}.{seq} AS BIGINT MINVALUE {minvalue};"
         else:
-            sql = "CREATE SEQUENCE {schema}.{seq};"
+            sql = "CREATE SEQUENCE {schema}.{seq} MINVALUE {minvalue};"

         self.cur.execute(
             psycopg.sql.SQL(sql)
             .format(
                 seq=psycopg.sql.Identifier(seq),
                 schema=psycopg.sql.Identifier(self.schema),
+                minvalue=psycopg.sql.Literal(minvalue),
             )
             .as_string(self.conn)
         )
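Note for context: an ascending sequence created without an explicit `MINVALUE` defaults to a floor of 1, so `setval()` rejects negative values outright. Threading `minvalue` through here is what lets the copy table's sequence accept the negative PK values this PR now preserves. A minimal psql sketch of the difference (sequence names illustrative, not from this PR):

```sql
CREATE SEQUENCE demo_seq AS BIGINT;
SELECT setval('demo_seq', -200);
-- ERROR:  setval: value -200 is out of bounds for sequence "demo_seq" (1..9223372036854775807)

CREATE SEQUENCE demo_seq2 AS BIGINT MINVALUE -200;
SELECT setval('demo_seq2', -200);  -- OK; nextval() then returns -199
```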
@@ -475,20 +476,11 @@ def swap_pk_sequence_name(self, *, first_table: str, second_table: str) -> None:
         self.rename_sequence(seq_from=second_seq, seq_to=first_seq)
         self.rename_sequence(seq_from=temp_seq, seq_to=second_seq)

-    def transfer_pk_sequence_value(
-        self, *, source_table: str, dest_table: str, convert_pk_to_bigint: bool
-    ) -> None:
+    def transfer_pk_sequence_value(self, *, source_table: str, dest_table: str) -> None:
         source_seq = self.introspector.get_pk_sequence_name(table=source_table)
         dest_seq = self.introspector.get_pk_sequence_name(table=dest_table)
         value = self.introspector.get_pk_sequence_value(seq=source_seq)

-        if convert_pk_to_bigint and value < 0:
-            # special case handling where negative PK values were used before bigint conversion
-            value = 2**31  # reset to positive, specifically the first bigint value
-
-        # TODO: try to correctly restore a negative PK sequence value if we revert swap
-        # while doing a bigint conversion
-
         self.cur.execute(
             psycopg.sql.SQL("SELECT setval('{schema}.{sequence}', {value});")
             .format(
@@ -499,6 +491,27 @@ def transfer_pk_sequence_value(self, *, source_table: str, dest_table: str) -> None:
             .as_string(self.conn)
         )

+    def update_pk_sequence_value(self, *, table: str) -> None:
+        """
+        Update the sequence value if it was negative (for use in bigint conversions).
+        """
+        seq = self.introspector.get_pk_sequence_name(table=table)
+        value = self.introspector.get_pk_sequence_value(seq=seq)
+
+        if value < 0:
+            # special case handling where negative PK values were used before bigint conversion
+            value = 2**31  # reset to positive, specifically the first bigint value
+
+        self.cur.execute(
+            psycopg.sql.SQL("SELECT setval('{schema}.{sequence}', {value});")
+            .format(
+                schema=psycopg.sql.Identifier(self.schema),
+                sequence=psycopg.sql.Identifier(seq),
+                value=psycopg.sql.SQL(str(value)),
+            )
+            .as_string(self.conn)
+        )
+
     def acquire_access_exclusive_lock(self, *, table: str) -> None:
         self.cur.execute(
             psycopg.sql.SQL("LOCK TABLE {schema}.{table} IN ACCESS EXCLUSIVE MODE;")
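On the `2**31` constant: 2,147,483,648 is the first integer above the `int4` range (whose maximum is 2,147,483,647), so after a bigint conversion, restarting the sequence there presumably guarantees newly issued ids cannot collide with any id that fit in the old 4-byte column. A quick psql check of the boundary (illustrative, not from this PR):

```sql
SELECT 2147483647::int;     -- int4 max, 2^31 - 1
SELECT 2147483648::int;     -- ERROR:  integer out of range
SELECT 2147483648::bigint;  -- 2^31: the first value that needs the bigint column
```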
20 changes: 20 additions & 0 deletions src/psycopack/_introspect.py
@@ -628,6 +628,26 @@ def get_pk_sequence_value(self, *, seq: str) -> int:
         assert isinstance(value, int)
         return value

+    def get_pk_sequence_min_value(self, *, seq: str) -> int:
+        self.cur.execute(
+            psycopg.sql.SQL(
+                dedent("""
+                    SELECT min_value FROM pg_sequences
+                    WHERE schemaname = {schema} AND sequencename = {sequence};
+                """)
+            )
+            .format(
+                schema=psycopg.sql.Literal(self.schema),
+                sequence=psycopg.sql.Literal(seq),
+            )
+            .as_string(self.conn)
+        )
+        result = self.cur.fetchone()
+        assert result is not None
+        value = result[0]
+        assert isinstance(value, int)
+        return value
+
     def get_backfill_batch(self, *, table: str) -> BackfillBatch | None:
         self.cur.execute(
             psycopg.sql.SQL(
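For reviewers who haven't used it: `pg_sequences` is a system view (PostgreSQL 10+) with one row per sequence, which is where `min_value` comes from. Roughly what the generated query looks like once the literals are bound (schema and sequence names illustrative):

```sql
SELECT min_value FROM pg_sequences
WHERE schemaname = 'public' AND sequencename = 'to_repack_id_seq';

-- Other columns on the same view, should they ever be needed:
-- data_type, start_value, max_value, increment_by, cycle, cache_size, last_value
```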
10 changes: 7 additions & 3 deletions src/psycopack/_repack.py
@@ -403,7 +403,6 @@ def swap(self) -> None:
         self.command.transfer_pk_sequence_value(
             source_table=self.table,
             dest_table=self.copy_table,
-            convert_pk_to_bigint=self.convert_pk_to_bigint,
         )
         self.command.rename_table(
             table_from=self.table, table_to=self.repacked_name
@@ -461,7 +460,6 @@ def revert_swap(self) -> None:
         self.command.transfer_pk_sequence_value(
             source_table=self.table,
             dest_table=self.repacked_name,
-            convert_pk_to_bigint=self.convert_pk_to_bigint,
         )

         self.command.rename_table(table_from=self.table, table_to=self.copy_table)
@@ -507,6 +505,11 @@ def clean_up(self) -> None:
         )
         self.command.drop_function_if_exists(function=self.repacked_function)

+        if self.convert_pk_to_bigint and self.introspector.get_pk_sequence_name(
+            table=self.table
+        ):
+            self.command.update_pk_sequence_value(table=self.table)
+
         for idx_sql in indexes:
             for index_data in indexes[idx_sql]:
                 self.command.rename_index(
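Two things worth noting here. First, moving the positive-value bump out of `transfer_pk_sequence_value` and into `clean_up` means swap and revert now carry the sequence value over verbatim, negatives included, which is what let the old TODO about restoring a negative value on revert be deleted: the jump to `2**31` only happens once the repack is irreversible. Second, the guard matters because not every PK has a backing sequence: a plain `integer`/`bigint` PK without a default owns none, and `update_pk_sequence_value` assumes one exists. The distinction in psql (table names illustrative):

```sql
CREATE TABLE no_seq (id bigint PRIMARY KEY);
SELECT pg_get_serial_sequence('no_seq', 'id');    -- NULL: nothing to bump

CREATE TABLE with_seq (id serial PRIMARY KEY);
SELECT pg_get_serial_sequence('with_seq', 'id');  -- public.with_seq_id_seq
```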
Expand Down Expand Up @@ -599,7 +602,7 @@ def _create_copy_table(self) -> None:
always=(pk_info.identity_type == "a"),
pk_column=self.pk_column,
)
elif self.introspector.get_pk_sequence_name(table=self.table):
elif seq := self.introspector.get_pk_sequence_name(table=self.table):
# Create a new sequence for the copied table's id column so that it
# does not depend on the original's one. Otherwise, we wouldn't be
# able to delete the original table after the repack process is
@@ -608,6 +611,7 @@
             self.command.create_sequence(
                 seq=self.id_seq,
                 bigint=("big" in pk_info.data_types[0].lower()),
+                minvalue=self.introspector.get_pk_sequence_min_value(seq=seq),
             )
             self.command.set_table_id_seq(
                 table=self.copy_table,
6 changes: 6 additions & 0 deletions tests/factories.py
@@ -70,6 +70,12 @@ def create_table_for_repacking(
         );
         """)
     )
+    if "serial" in pk_type.lower():
+        seq = f"{table_name}_{pk_name}_seq"
+        cur.execute(
+            f"ALTER SEQUENCE {schema}.{seq} MINVALUE {pk_start} RESTART WITH {pk_start};"
+        )
+
     cur.execute(f"CREATE INDEX btree_idx ON {schema}.{table_name} (var_with_btree);")
     cur.execute(
         f"CREATE INDEX pattern_ops_idx ON {schema}.{table_name} (var_with_pattern_ops varchar_pattern_ops);"
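Why only the `serial` family needs this: `serial`/`smallserial`/`bigserial` create an owned sequence with the default floor of 1, so a negative `pk_start` would make the first `nextval()` fail, whereas plain `integer`/`bigint` PKs in this factory presumably get their values written explicitly. A sketch of what the new statement does (names illustrative):

```sql
CREATE TABLE t (id serial PRIMARY KEY);  -- implicitly creates t_id_seq with MINVALUE 1
ALTER SEQUENCE public.t_id_seq MINVALUE -200 RESTART WITH -200;
SELECT nextval('public.t_id_seq');  -- -200: test rows can now start in the negatives
```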
125 changes: 124 additions & 1 deletion tests/test_repack.py
@@ -1,5 +1,6 @@
 import dataclasses
 from textwrap import dedent
+from typing import Tuple, Union
 from unittest import mock

 import pytest
@@ -137,7 +138,7 @@ def _assert_repack(
     if table_before.pk_seq_val is None or table_before.pk_seq_val > 0:
         assert table_before.pk_seq_val == table_after.pk_seq_val
     else:
-        assert table_after.pk_seq_val == 2**31
+        assert table_after.pk_seq_val is None or table_after.pk_seq_val >= 2**31

     # All functions and triggers are removed.
     trigger_info = _get_trigger_info(repack, cur)
@@ -166,6 +167,78 @@ def _assert_reset(repack: Psycopack, cur: _cur.Cursor) -> None:
     assert repack.introspector.get_table_oid(table=repack.tracker.tracker_table) is None


+def _do_writes(
+    table: str,
+    cur: _cur.Cursor,
+    schema: str = "public",
+    check_table: str | None = None,
+) -> None:
+    """
+    Do some writes (insert, update, delete) to check that the copy function works.
+    """
+    cur.execute(
+        dedent(f"""
+            INSERT INTO {schema}.{table} (
+                var_with_btree,
+                var_with_pattern_ops,
+                int_with_check,
+                int_with_not_valid_check,
+                int_with_long_index_name,
+                var_with_unique_idx,
+                var_with_unique_const,
+                valid_fk,
+                not_valid_fk,
+                {table},
+                var_maybe_with_exclusion,
+                var_with_multiple_idx
+            )
+            VALUES (
+                substring(md5(random()::text), 1, 10),
+                substring(md5(random()::text), 1, 10),
+                (floor(random() * 10) + 1)::int,
+                (floor(random() * 10) + 1)::int,
+                (floor(random() * 10) + 1)::int,
+                substring(md5(random()::text), 1, 10),
+                substring(md5(random()::text), 1, 10),
+                (floor(random() * 10) + 1)::int,
+                (floor(random() * 10) + 1)::int,
+                (floor(random() * 10) + 1)::int,
+                substring(md5(random()::text), 1, 10),
+                substring(md5(random()::text), 1, 10)
+            )
+            RETURNING id;
+        """)
+    )
+    result = cur.fetchone()
+    assert result is not None
+    id_ = result[0]
+    if check_table is not None:
+        assert _query_row(table=table, id_=id_, cur=cur, schema=schema) == _query_row(
+            table=check_table, id_=id_, cur=cur, schema=schema
+        )
+
+    cur.execute(f"UPDATE {schema}.{table} SET var_with_btree = 'foo' WHERE id = {id_};")
+    if check_table is not None:
+        assert _query_row(table=table, id_=id_, cur=cur, schema=schema) == _query_row(
+            table=check_table, id_=id_, cur=cur, schema=schema
+        )
+
+    cur.execute(f"DELETE FROM {schema}.{table} WHERE id = {id_};")
+    assert _query_row(table=table, id_=id_, cur=cur, schema=schema) is None
+    if check_table is not None:
+        assert _query_row(table=check_table, id_=id_, cur=cur, schema=schema) is None
+
+
+def _query_row(
+    table: str,
+    id_: int,
+    cur: _cur.Cursor,
+    schema: str = "public",
+) -> Tuple[Union[int, str], ...] | None:
+    cur.execute(f"SELECT * FROM {schema}.{table} WHERE id = {id_};")
+    return cur.fetchone()
+
+
 @pytest.mark.parametrize(
     "pk_type",
     ("bigint", "bigserial", "integer", "serial", "smallint", "smallserial"),
@@ -1324,6 +1397,56 @@ def test_when_table_has_negative_pk_values(
     )


+@pytest.mark.parametrize(
+    "initial_pk_type",
+    (
+        "integer",
+        "serial",
+        "smallint",
+        "smallserial",
+    ),
+)
+def test_with_writes_when_table_has_negative_pk_values(
+    connection: _psycopg.Connection, initial_pk_type: str
+) -> None:
+    with _cur.get_cursor(connection, logged=True) as cur:
+        factories.create_table_for_repacking(
+            connection=connection,
+            cur=cur,
+            table_name="to_repack",
+            rows=100,
+            pk_type=initial_pk_type,
+            pk_start=-200,
+        )
+        table_before = _collect_table_info(table="to_repack", connection=connection)
+
+        repack = Psycopack(
+            table="to_repack",
+            batch_size=1,
+            conn=connection,
+            cur=cur,
+            convert_pk_to_bigint=True,
+        )
+        repack.pre_validate()
+        repack.setup_repacking()
+        repack.backfill()
+        _do_writes(table="to_repack", cur=cur, check_table=repack.copy_table)
+        repack.sync_schemas()
+        _do_writes(table="to_repack", cur=cur, check_table=repack.copy_table)
+        repack.swap()
+        _do_writes(table="to_repack", cur=cur, check_table=repack.repacked_name)
+        repack.clean_up()
+        _do_writes(table="to_repack", cur=cur)
+
+        table_after = _collect_table_info(table="to_repack", connection=connection)
+        _assert_repack(
+            table_before=table_before,
+            table_after=table_after,
+            repack=repack,
+            cur=cur,
+        )
+
+
 def test_when_table_has_large_value_being_inserted(
     connection: _psycopg.Connection,
 ) -> None: