diff --git a/alembic/autogenerate/rewriter.py b/alembic/autogenerate/rewriter.py index 68a93dd0..eae01acb 100644 --- a/alembic/autogenerate/rewriter.py +++ b/alembic/autogenerate/rewriter.py @@ -4,7 +4,7 @@ from typing import Callable from typing import Iterator from typing import List -from typing import Optional +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import Union @@ -23,6 +23,10 @@ from ..runtime.environment import _GetRevArg from ..runtime.migration import MigrationContext + ProcessRevisionDirectiveFn = Callable[ + [MigrationContext, _GetRevArg, List["MigrationScript"]], None + ] + class Rewriter: """A helper object that allows easy 'rewriting' of ops streams. @@ -52,15 +56,21 @@ class Rewriter: _traverse = util.Dispatcher() - _chained: Optional[Rewriter] = None + _chained: Tuple[Union[ProcessRevisionDirectiveFn, Rewriter], ...] = () def __init__(self) -> None: self.dispatch = util.Dispatcher() - def chain(self, other: Rewriter) -> Rewriter: + def chain( + self, + other: Union[ + ProcessRevisionDirectiveFn, + Rewriter, + ], + ) -> Rewriter: """Produce a "chain" of this :class:`.Rewriter` to another. - This allows two rewriters to operate serially on a stream, + This allows two or more rewriters to operate serially on a stream, e.g.:: writer1 = autogenerate.Rewriter() @@ -89,7 +99,7 @@ def add_column_idx(context, revision, op): """ wr = self.__class__.__new__(self.__class__) wr.__dict__.update(self.__dict__) - wr._chained = other + wr._chained += (other,) return wr def rewrites( @@ -146,8 +156,8 @@ def __call__( directives: List[MigrationScript], ) -> None: self.process_revision_directives(context, revision, directives) - if self._chained: - self._chained(context, revision, directives) + for process_revision_directives in self._chained: + process_revision_directives(context, revision, directives) @_traverse.dispatch_for(ops.MigrationScript) def _traverse_script( diff --git a/docs/build/unreleased/1337.rst b/docs/build/unreleased/1337.rst new file mode 100644 index 00000000..2660e831 --- /dev/null +++ b/docs/build/unreleased/1337.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, autogenerate + :tickets: 1337 + + Fixes `autogenerate.Rewriter` so that more than two instances could be + chained together correctly, and `process_revision_directives` callable + could also be chained. diff --git a/tests/test_script_production.py b/tests/test_script_production.py index 3b5a6f60..7b7db814 100644 --- a/tests/test_script_production.py +++ b/tests/test_script_production.py @@ -933,6 +933,11 @@ def add_column_idx(context, revision, op): idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name]) return [op, idx_op] + def process_revision_directives(context, revision, generate_revisions): + generate_revisions[0].downgrade_ops = ops.DowngradeOps( + ops=[ops.DropColumnOp("t1", "x")] + ) + directives = [ ops.MigrationScript( util.rev_id(), @@ -956,7 +961,8 @@ def add_column_idx(context, revision, op): ] ctx, rev = mock.Mock(), mock.Mock() - writer1.chain(writer2)(ctx, rev, directives) + writer = writer1.chain(process_revision_directives).chain(writer2) + writer(ctx, rev, directives) eq_( autogenerate.render_python_code(directives[0].upgrade_ops), @@ -970,6 +976,13 @@ def add_column_idx(context, revision, op): " # ### end Alembic commands ###", ) + eq_( + autogenerate.render_python_code(directives[0].downgrade_ops), + "# ### commands auto generated by Alembic - please adjust! ###\n" + " op.drop_column('t1', 'x')\n" + " # ### end Alembic commands ###", + ) + def test_no_needless_pass(self): writer1 = autogenerate.Rewriter()