# Review notes (leading text placed before the first "diff --git" header; not part of any hunk).
#
# This is a unified git diff touching five files. What the hunks below show:
#
# src/psycopack/_commands.py
#   - create_sequence() takes a new keyword-only `minvalue: int`; both CREATE SEQUENCE statements
#     gain `MINVALUE {minvalue}`, and the value is formatted via psycopg.sql.Literal.
#   - transfer_pk_sequence_value() drops the `convert_pk_to_bigint` parameter together with the
#     branch that replaced a negative sequence value by 2**31 (and the TODO that followed it).
#   - update_pk_sequence_value() is added: it looks up the table's PK sequence name, reads the
#     sequence value, swaps a negative value for 2**31 and runs `SELECT setval(...)`.
#     NOTE(review): indentation is lost in this copy; the setval call appears to run even when the
#     value was not negative -- confirm against the real file.
#     NOTE(review): 2**31 is one past the largest signed 32-bit integer; the inline comment's
#     "first bigint value" presumably means that -- confirm the wording.
#
# src/psycopack/_introspect.py
#   - get_pk_sequence_min_value() is added: selects min_value from pg_sequences filtered by
#     schemaname and sequencename, asserts that a row came back and that the value is an int.
#
# src/psycopack/_repack.py
#   - swap() and revert_swap() no longer pass convert_pk_to_bigint to transfer_pk_sequence_value().
#   - clean_up() calls command.update_pk_sequence_value() when convert_pk_to_bigint is set and
#     get_pk_sequence_name() returns a truthy value for the table.
#   - _create_copy_table() binds the original sequence name with `:=` and passes
#     get_pk_sequence_min_value(seq=seq) as `minvalue` to create_sequence().
#
# tests/factories.py
#   - When pk_type contains "serial", the sequence `{table_name}_{pk_name}_seq` is altered to
#     MINVALUE {pk_start} RESTART WITH {pk_start}.
#
# tests/test_repack.py
#   - _assert_repack(): the expectation for a non-positive starting sequence value changes from
#     `== 2**31` to `is None or >= 2**31`.
#   - _do_writes() / _query_row() helpers are added: insert, update and delete one row and, when
#     check_table is given, compare that row between the two tables.
#   - test_with_writes_when_table_has_negative_pk_values() is added, parametrized over integer,
#     serial, smallint and smallserial, using pk_start=-200 and convert_pk_to_bigint=True.
#     NOTE(review): the helpers filter on a column literally named `id`; presumably that is the
#     factory's default pk name -- verify against create_table_for_repacking().
#
diff --git a/src/psycopack/_commands.py b/src/psycopack/_commands.py index 87f84b9..eec0e25 100644 --- a/src/psycopack/_commands.py +++ b/src/psycopack/_commands.py @@ -74,17 +74,18 @@ def drop_sequence_if_exists(self, *, seq: str) -> None: .as_string(self.conn) ) - def create_sequence(self, *, seq: str, bigint: bool) -> None: + def create_sequence(self, *, seq: str, bigint: bool, minvalue: int) -> None: if bigint: - sql = "CREATE SEQUENCE {schema}.{seq} AS BIGINT;" + sql = "CREATE SEQUENCE {schema}.{seq} AS BIGINT MINVALUE {minvalue};" else: - sql = "CREATE SEQUENCE {schema}.{seq};" + sql = "CREATE SEQUENCE {schema}.{seq} MINVALUE {minvalue};" self.cur.execute( psycopg.sql.SQL(sql) .format( seq=psycopg.sql.Identifier(seq), schema=psycopg.sql.Identifier(self.schema), + minvalue=psycopg.sql.Literal(minvalue), ) .as_string(self.conn) ) @@ -475,20 +476,11 @@ def swap_pk_sequence_name(self, *, first_table: str, second_table: str) -> None: self.rename_sequence(seq_from=second_seq, seq_to=first_seq) self.rename_sequence(seq_from=temp_seq, seq_to=second_seq) - def transfer_pk_sequence_value( - self, *, source_table: str, dest_table: str, convert_pk_to_bigint: bool - ) -> None: + def transfer_pk_sequence_value(self, *, source_table: str, dest_table: str) -> None: source_seq = self.introspector.get_pk_sequence_name(table=source_table) dest_seq = self.introspector.get_pk_sequence_name(table=dest_table) value = self.introspector.get_pk_sequence_value(seq=source_seq) - if convert_pk_to_bigint and value < 0: - # special case handling where negative PK values were used before bigint conversion - value = 2**31 # reset to positive, specifically the first bigint value - - # TODO: try to correctly restore a negative PK sequence value if we revert swap - # while doing a bigint conversion - self.cur.execute( psycopg.sql.SQL("SELECT setval('{schema}.{sequence}', {value});") .format( @@ -499,6 +491,27 @@ def transfer_pk_sequence_value( .as_string(self.conn) ) + def
update_pk_sequence_value(self, *, table: str) -> None: + """ + Update the sequence value if it was negative (for use in bigint conversions). + """ + seq = self.introspector.get_pk_sequence_name(table=table) + value = self.introspector.get_pk_sequence_value(seq=seq) + + if value < 0: + # special case handling where negative PK values were used before bigint conversion + value = 2**31 # reset to positive, specifically the first bigint value + + self.cur.execute( + psycopg.sql.SQL("SELECT setval('{schema}.{sequence}', {value});") + .format( + schema=psycopg.sql.Identifier(self.schema), + sequence=psycopg.sql.Identifier(seq), + value=psycopg.sql.SQL(str(value)), + ) + .as_string(self.conn) + ) + def acquire_access_exclusive_lock(self, *, table: str) -> None: self.cur.execute( psycopg.sql.SQL("LOCK TABLE {schema}.{table} IN ACCESS EXCLUSIVE MODE;") diff --git a/src/psycopack/_introspect.py b/src/psycopack/_introspect.py index da3de1c..109acf1 100644 --- a/src/psycopack/_introspect.py +++ b/src/psycopack/_introspect.py @@ -628,6 +628,26 @@ def get_pk_sequence_value(self, *, seq: str) -> int: assert isinstance(value, int) return value + def get_pk_sequence_min_value(self, *, seq: str) -> int: + self.cur.execute( + psycopg.sql.SQL( + dedent(""" + SELECT min_value FROM pg_sequences + WHERE schemaname = {schema} AND sequencename = {sequence}; + """) + ) + .format( + schema=psycopg.sql.Literal(self.schema), + sequence=psycopg.sql.Literal(seq), + ) + .as_string(self.conn) + ) + result = self.cur.fetchone() + assert result is not None + value = result[0] + assert isinstance(value, int) + return value + def get_backfill_batch(self, *, table: str) -> BackfillBatch | None: self.cur.execute( psycopg.sql.SQL( diff --git a/src/psycopack/_repack.py b/src/psycopack/_repack.py index 712be06..c512b78 100644 --- a/src/psycopack/_repack.py +++ b/src/psycopack/_repack.py @@ -403,7 +403,6 @@ def swap(self) -> None: self.command.transfer_pk_sequence_value( source_table=self.table,
dest_table=self.copy_table, - convert_pk_to_bigint=self.convert_pk_to_bigint, ) self.command.rename_table( table_from=self.table, table_to=self.repacked_name @@ -461,7 +460,6 @@ def revert_swap(self) -> None: self.command.transfer_pk_sequence_value( source_table=self.table, dest_table=self.repacked_name, - convert_pk_to_bigint=self.convert_pk_to_bigint, ) self.command.rename_table(table_from=self.table, table_to=self.copy_table) @@ -507,6 +505,11 @@ def clean_up(self) -> None: ) self.command.drop_function_if_exists(function=self.repacked_function) + if self.convert_pk_to_bigint and self.introspector.get_pk_sequence_name( + table=self.table + ): + self.command.update_pk_sequence_value(table=self.table) + for idx_sql in indexes: for index_data in indexes[idx_sql]: self.command.rename_index( @@ -599,7 +602,7 @@ def _create_copy_table(self) -> None: always=(pk_info.identity_type == "a"), pk_column=self.pk_column, ) - elif self.introspector.get_pk_sequence_name(table=self.table): + elif seq := self.introspector.get_pk_sequence_name(table=self.table): # Create a new sequence for the copied table's id column so that it # does not depend on the original's one.
Otherwise, we wouldn't be # able to delete the original table after the repack process is @@ -608,6 +611,7 @@ def _create_copy_table(self) -> None: self.command.create_sequence( seq=self.id_seq, bigint=("big" in pk_info.data_types[0].lower()), + minvalue=self.introspector.get_pk_sequence_min_value(seq=seq), ) self.command.set_table_id_seq( table=self.copy_table, diff --git a/tests/factories.py b/tests/factories.py index e167045..b871ec8 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -70,6 +70,12 @@ def create_table_for_repacking( ); """) ) + if "serial" in pk_type.lower(): + seq = f"{table_name}_{pk_name}_seq" + cur.execute( + f"ALTER SEQUENCE {schema}.{seq} MINVALUE {pk_start} RESTART WITH {pk_start};" + ) + cur.execute(f"CREATE INDEX btree_idx ON {schema}.{table_name} (var_with_btree);") cur.execute( f"CREATE INDEX pattern_ops_idx ON {schema}.{table_name} (var_with_pattern_ops varchar_pattern_ops);" diff --git a/tests/test_repack.py b/tests/test_repack.py index b780926..2ebbe7a 100644 --- a/tests/test_repack.py +++ b/tests/test_repack.py @@ -1,5 +1,6 @@ import dataclasses from textwrap import dedent +from typing import Tuple, Union from unittest import mock import pytest @@ -137,7 +138,7 @@ def _assert_repack( if table_before.pk_seq_val is None or table_before.pk_seq_val > 0: assert table_before.pk_seq_val == table_after.pk_seq_val else: - assert table_after.pk_seq_val == 2**31 + assert table_after.pk_seq_val is None or table_after.pk_seq_val >= 2**31 # All functions and triggers are removed. trigger_info = _get_trigger_info(repack, cur) @@ -166,6 +167,78 @@ def _assert_reset(repack: Psycopack, cur: _cur.Cursor) -> None: assert repack.introspector.get_table_oid(table=repack.tracker.tracker_table) is None +def _do_writes( + table: str, + cur: _cur.Cursor, + schema: str = "public", + check_table: str | None = None, +) -> None: + """ + Do some writes (insert, update, delete) to check that the copy function works.
+ """ + cur.execute( + dedent(f""" + INSERT INTO {schema}.{table} ( + var_with_btree, + var_with_pattern_ops, + int_with_check, + int_with_not_valid_check, + int_with_long_index_name, + var_with_unique_idx, + var_with_unique_const, + valid_fk, + not_valid_fk, + {table}, + var_maybe_with_exclusion, + var_with_multiple_idx + ) + VALUES ( + substring(md5(random()::text), 1, 10), + substring(md5(random()::text), 1, 10), + (floor(random() * 10) + 1)::int, + (floor(random() * 10) + 1)::int, + (floor(random() * 10) + 1)::int, + substring(md5(random()::text), 1, 10), + substring(md5(random()::text), 1, 10), + (floor(random() * 10) + 1)::int, + (floor(random() * 10) + 1)::int, + (floor(random() * 10) + 1)::int, + substring(md5(random()::text), 1, 10), + substring(md5(random()::text), 1, 10) + ) + RETURNING id; + """) + ) + result = cur.fetchone() + assert result is not None + id_ = result[0] + if check_table is not None: + assert _query_row(table=table, id_=id_, cur=cur, schema=schema) == _query_row( + table=check_table, id_=id_, cur=cur, schema=schema + ) + + cur.execute(f"UPDATE {schema}.{table} SET var_with_btree = 'foo' WHERE id = {id_};") + if check_table is not None: + assert _query_row(table=table, id_=id_, cur=cur, schema=schema) == _query_row( + table=check_table, id_=id_, cur=cur, schema=schema + ) + + cur.execute(f"DELETE FROM {schema}.{table} WHERE id = {id_};") + assert _query_row(table=table, id_=id_, cur=cur, schema=schema) is None + if check_table is not None: + assert _query_row(table=check_table, id_=id_, cur=cur, schema=schema) is None + + +def _query_row( + table: str, + id_: int, + cur: _cur.Cursor, + schema: str = "public", +) -> Tuple[Union[int, str], ...]
| None: + cur.execute(f"SELECT * FROM {schema}.{table} WHERE id = {id_};") + return cur.fetchone() + + @pytest.mark.parametrize( "pk_type", ("bigint", "bigserial", "integer", "serial", "smallint", "smallserial"), @@ -1324,6 +1397,56 @@ def test_when_table_has_negative_pk_values( ) + +@pytest.mark.parametrize( + "initial_pk_type", + ( + "integer", + "serial", + "smallint", + "smallserial", + ), +) +def test_with_writes_when_table_has_negative_pk_values( + connection: _psycopg.Connection, initial_pk_type: str +) -> None: + with _cur.get_cursor(connection, logged=True) as cur: + factories.create_table_for_repacking( + connection=connection, + cur=cur, + table_name="to_repack", + rows=100, + pk_type=initial_pk_type, + pk_start=-200, + ) + table_before = _collect_table_info(table="to_repack", connection=connection) + + repack = Psycopack( + table="to_repack", + batch_size=1, + conn=connection, + cur=cur, + convert_pk_to_bigint=True, + ) + repack.pre_validate() + repack.setup_repacking() + repack.backfill() + _do_writes(table="to_repack", cur=cur, check_table=repack.copy_table) + repack.sync_schemas() + _do_writes(table="to_repack", cur=cur, check_table=repack.copy_table) + repack.swap() + _do_writes(table="to_repack", cur=cur, check_table=repack.repacked_name) + repack.clean_up() + _do_writes(table="to_repack", cur=cur) + + table_after = _collect_table_info(table="to_repack", connection=connection) + _assert_repack( + table_before=table_before, + table_after=table_after, + repack=repack, + cur=cur, + ) + + def test_when_table_has_large_value_being_inserted( connection: _psycopg.Connection, ) -> None: