diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 25746889..2609a62d 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -21,7 +21,12 @@ from typing import Union from sqlalchemy import cast +from sqlalchemy import Column +from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema +from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy import text from . import _autogen @@ -43,11 +48,9 @@ from sqlalchemy.sql import Executable from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import quoted_name - from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import ForeignKeyConstraint from sqlalchemy.sql.schema import Index - from sqlalchemy.sql.schema import Table from sqlalchemy.sql.schema import UniqueConstraint from sqlalchemy.sql.selectable import TableClause from sqlalchemy.sql.type_api import TypeEngine @@ -136,6 +139,40 @@ def static_output(self, text: str) -> None: self.output_buffer.write(text + "\n\n") self.output_buffer.flush() + def version_table_impl( + self, + *, + version_table: str, + version_table_schema: Optional[str], + version_table_pk: bool, + **kw: Any, + ) -> Table: + """Generate a :class:`.Table` object which will be used as the + structure for the Alembic version table. + + Third party dialects may override this hook to provide an alternate + structure for this :class:`.Table`; requirements are only that it + be named based on the ``version_table`` parameter and contains + at least a single string-holding column named ``version_num``. + + .. versionadded:: 1.14 + + """ + vt = Table( + version_table, + MetaData(), + Column("version_num", String(32), nullable=False), + schema=version_table_schema, + ) + if version_table_pk: + vt.append_constraint( + PrimaryKeyConstraint( + "version_num", name=f"{version_table}_pkc" + ) + ) + + return vt + def requires_recreate_in_batch( self, batch_op: BatchOperationsImpl ) -> bool: diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 6cfe5e23..28f01c3b 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -24,10 +24,6 @@ from sqlalchemy import Column from sqlalchemy import literal_column -from sqlalchemy import MetaData -from sqlalchemy import PrimaryKeyConstraint -from sqlalchemy import String -from sqlalchemy import Table from sqlalchemy.engine import Engine from sqlalchemy.engine import url as sqla_url from sqlalchemy.engine.strategies import MockEngineStrategy @@ -36,6 +32,7 @@ from .. import util from ..util import sqla_compat from ..util.compat import EncodedIO +from ..util.sqla_compat import _select if TYPE_CHECKING: from sqlalchemy.engine import Dialect @@ -190,18 +187,6 @@ def __init__( self.version_table_schema = version_table_schema = opts.get( "version_table_schema", None ) - self._version = Table( - version_table, - MetaData(), - Column("version_num", String(32), nullable=False), - schema=version_table_schema, - ) - if opts.get("version_table_pk", True): - self._version.append_constraint( - PrimaryKeyConstraint( - "version_num", name="%s_pkc" % version_table - ) - ) self._start_from_rev: Optional[str] = opts.get("starting_rev") self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( @@ -212,6 +197,13 @@ def __init__( self.output_buffer, opts, ) + + self._version = self.impl.version_table_impl( + version_table=version_table, + version_table_schema=version_table_schema, + version_table_pk=opts.get("version_table_pk", True), + ) + log.info("Context impl %s.", self.impl.__class__.__name__) if self.as_sql: log.info("Generating static SQL") @@ -540,7 +532,10 @@ def get_current_heads(self) -> Tuple[str, ...]: return () assert self.connection is not None return tuple( - row[0] for row in self.connection.execute(self._version.select()) + row[0] + for row in self.connection.execute( + _select(self._version.c.version_num) + ) ) def _ensure_version_table(self, purge: bool = False) -> None: diff --git a/docs/build/changelog.rst b/docs/build/changelog.rst index 51a8a5e2..2d33a186 100644 --- a/docs/build/changelog.rst +++ b/docs/build/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== .. changelog:: - :version: 1.13.4 + :version: 1.14.0 :include_notes_from: unreleased .. changelog:: diff --git a/docs/build/unreleased/1560.rst b/docs/build/unreleased/1560.rst new file mode 100644 index 00000000..e808b307 --- /dev/null +++ b/docs/build/unreleased/1560.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: usecase, runtime + :tickets: 1560 + + Added a new hook to the :class:`.DefaultImpl` + :meth:`.DefaultImpl.version_table_impl`. This allows third party dialects + to define the exact structure of the alembic_version table, to include use + cases where the table requires special directives and/or additional columns + so that it may function correctly on a particular backend. This is not + intended as a user-expansion hook, only a dialect implementation hook to + produce a working alembic_version table. Pull request courtesy Maciek + BryƄski. diff --git a/tests/test_version_table.py b/tests/test_version_table.py index 5ad3c21d..ca569366 100644 --- a/tests/test_version_table.py +++ b/tests/test_version_table.py @@ -1,10 +1,15 @@ from sqlalchemy import Column from sqlalchemy import inspect +from sqlalchemy import Integer from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy.dialects import registry +from sqlalchemy.engine import default from alembic import migration +from alembic.ddl import impl from alembic.testing import assert_raises from alembic.testing import assert_raises_message from alembic.testing import config @@ -373,3 +378,44 @@ def test_delete_multi_match_no_sane_rowcount(self): self.connection.dialect, "supports_sane_rowcount", False ): self.updater.update_to_step(_down("a", None, True)) + + +registry.register("custom_version", __name__, "CustomVersionDialect") + + +class CustomVersionDialect(default.DefaultDialect): + name = "custom_version" + + +class CustomVersionTableImpl(impl.DefaultImpl): + __dialect__ = "custom_version" + + def version_table_impl( + self, + *, + version_table, + version_table_schema, + version_table_pk, + **kw, + ): + vt = Table( + version_table, + MetaData(), + Column("id", Integer, autoincrement=True), + Column("version_num", String(32), nullable=False), + schema=version_table_schema, + ) + if version_table_pk: + vt.append_constraint( + PrimaryKeyConstraint("id", name=f"{version_table}_pkc") + ) + return vt + + +class CustomVersionTableTest(TestMigrationContext): + + def test_custom_version_table(self): + context = migration.MigrationContext.configure( + dialect_name="custom_version", + ) + eq_(len(context._version.columns), 2)