Skip to content

Commit

Permalink
Add ability to configure alembic_version table in DialectImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
maver1ck committed Oct 31, 2024
1 parent bd50ba3 commit 0dc0bda
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 18 deletions.
29 changes: 29 additions & 0 deletions alembic/ddl/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
from typing import Union

from sqlalchemy import cast
from sqlalchemy import Column
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import text

from . import _autogen
Expand Down Expand Up @@ -136,6 +141,30 @@ def static_output(self, text: str) -> None:
self.output_buffer.write(text + "\n\n")
self.output_buffer.flush()

def version_table_impl(self, version_table, version_table_schema, version_table_pk):
"""create the Table object for the version_table.
Provided as part of impl so that third party dialects can override
this.
.. versionadded:: 1.13.4
"""
vt = Table(
version_table,
MetaData(),
Column("version_num", String(32), nullable=False),
schema=version_table_schema,
)
if version_table_pk:
vt.append_constraint(
PrimaryKeyConstraint(
"version_num", name=f"{version_table}_pkc"
)
)

return vt

def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
Expand Down
25 changes: 8 additions & 17 deletions alembic/runtime/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@

from sqlalchemy import Column
from sqlalchemy import literal_column
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.engine import url as sqla_url
from sqlalchemy.engine.strategies import MockEngineStrategy
Expand Down Expand Up @@ -190,18 +187,6 @@ def __init__(
self.version_table_schema = version_table_schema = opts.get(
"version_table_schema", None
)
self._version = Table(
version_table,
MetaData(),
Column("version_num", String(32), nullable=False),
schema=version_table_schema,
)
if opts.get("version_table_pk", True):
self._version.append_constraint(
PrimaryKeyConstraint(
"version_num", name="%s_pkc" % version_table
)
)

self._start_from_rev: Optional[str] = opts.get("starting_rev")
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
Expand All @@ -212,6 +197,12 @@ def __init__(
self.output_buffer,
opts,
)

self._version = self.impl.version_table_impl(
version_table, version_table_schema,
opts.get("version_table_pk", True)
)

log.info("Context impl %s.", self.impl.__class__.__name__)
if self.as_sql:
log.info("Generating static SQL")
Expand Down Expand Up @@ -540,7 +531,7 @@ def get_current_heads(self) -> Tuple[str, ...]:
return ()
assert self.connection is not None
return tuple(
row[0] for row in self.connection.execute(self._version.select())
row[0] for row in self.connection.execute(select(self._version.c.version_num))
)

def _ensure_version_table(self, purge: bool = False) -> None:
Expand Down
46 changes: 45 additions & 1 deletion tests/test_version_table.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from sqlalchemy import Column
from sqlalchemy import Column, PrimaryKeyConstraint
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.dialects import registry

from alembic import migration
from alembic.ddl import impl
from alembic.testing import assert_raises
from alembic.testing import assert_raises_message
from alembic.testing import config
Expand All @@ -20,6 +23,8 @@
)




def _up(from_, to_, branch_presence_changed=False):
return migration.StampStep(from_, to_, True, branch_presence_changed)

Expand All @@ -31,6 +36,7 @@ def _down(from_, to_, branch_presence_changed=False):
class TestMigrationContext(TestBase):
@classmethod
def setup_class(cls):
registry.register("mydialect", "sqlite", "SQLiteDialect")
cls.bind = config.db

def setUp(self):
Expand Down Expand Up @@ -373,3 +379,41 @@ def test_delete_multi_match_no_sane_rowcount(self):
self.connection.dialect, "supports_sane_rowcount", False
):
self.updater.update_to_step(_down("a", None, True))


class CustomVersionTableTest(TestMigrationContext):

class MyDialectImpl(impl.DefaultImpl):

def version_table_impl(self, version_table, version_table_schema, version_table_pk):
vt = Table(
version_table,
MetaData(),
Column("id", Integer, autoincrement=True),
Column("version_num", String(32), nullable=False),
schema=version_table_schema,
)
if version_table_pk:
vt.append_constraint(
PrimaryKeyConstraint(
"id", name=f"{version_table}_pkc"
)
)
return vt


def setUp(self):
# nasty hack to get the sqlite dialect to use our custom dialect implementation
impl._impls["sqlite_bak"] = impl._impls["sqlite"]
impl._impls["sqlite"] = self.MyDialectImpl
super().setUp()

def tearDown(self):
super().tearDown()
impl._impls["sqlite"] = impl._impls["sqlite_bak"]

def test_custom_version_table(self):
context = migration.MigrationContext.configure(
dialect_name="sqlite",
)
eq_(len(context._version.columns), 2)

0 comments on commit 0dc0bda

Please sign in to comment.