Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-42411: Update 8c57494cabcc.py to support offline mode. #33

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
Create Date: 2023-12-08 11:25:33.984476

"""
import contextlib
import logging
import pprint
from collections.abc import Iterator
from typing import NamedTuple

import sqlalchemy
Expand Down Expand Up @@ -48,6 +50,28 @@
}


@contextlib.contextmanager
def _reflection_bind() -> Iterator[sqlalchemy.engine.Connection]:
    """Return database connection to be used for reflection.

    In online mode this yields the connection instantiated by Alembic. In
    offline mode it creates a new engine using the configured URL, yields a
    connection from it, and disposes of that engine when the context exits.

    Yields
    ------
    connection : `sqlalchemy.engine.Connection`
        Actual connection to database to use for reflection.

    Raises
    ------
    ValueError
        Raised in offline mode if ``sqlalchemy.url`` is missing from the
        configuration.
    """
    if not context.is_offline_mode():
        # Online mode: reuse the existing bind, nothing to clean up here.
        yield op.get_bind()
        return

    url = context.config.get_main_option("sqlalchemy.url")
    if url is None:
        raise ValueError("sqlalchemy.url is missing from config")
    engine = sqlalchemy.create_engine(url)
    try:
        with engine.connect() as connection:
            yield connection
    finally:
        # A fresh engine is created on every call, so release its connection
        # pool explicitly instead of leaving connections open until garbage
        # collection.
        engine.dispose()


class TableInfo(NamedTuple):
"""Info about table reflected from database before migration."""

Expand Down Expand Up @@ -83,7 +107,8 @@ def upgrade() -> None:

# reflect schema from database
metadata = sqlalchemy.schema.MetaData(schema=schema)
metadata.reflect(bind)
with _reflection_bind() as conn:
metadata.reflect(conn)

all_tables = _all_tables(schema)
dynamic_tables = [table for table in all_tables if table not in STATIC_TABLES]
Expand Down Expand Up @@ -147,8 +172,9 @@ def upgrade() -> None:
op.drop_column("run", "name", schema=schema)

# Update metadata to see new columns.
metadata = sqlalchemy.schema.MetaData(schema=schema)
metadata.reflect(bind)
if not context.is_offline_mode():
metadata = sqlalchemy.schema.MetaData(schema=schema)
metadata.reflect(bind)

# Change PKs, need to drop old PK from collection table first.
_make_pk("collection", table_infos["collection"], schema=schema, drop_existing=True)
Expand All @@ -172,12 +198,13 @@ def _all_tables(schema: str) -> list[str]:

Returned tables are ordered based on their FK (in CREATE order).
"""
inspector = sqlalchemy.inspect(op.get_bind())
tables = [
table
for table, _ in inspector.get_sorted_table_and_fkc_names(schema)
if table and (table in STATIC_TABLES or table.startswith(DYNAMIC_TABLES_PREFIX))
]
with _reflection_bind() as conn:
inspector = sqlalchemy.inspect(conn)
tables = [
table
for table, _ in inspector.get_sorted_table_and_fkc_names(schema)
if table and (table in STATIC_TABLES or table.startswith(DYNAMIC_TABLES_PREFIX))
]
_LOG.debug("_all_tables: %s", tables)
return tables

Expand Down Expand Up @@ -211,19 +238,20 @@ def _reflect_tables(schema: str, table_names: list[str]) -> dict[str, TableInfo]
infos : `dict` [`str`, `TableInfo`]
Reflected information for the tables.
"""
inspector = sqlalchemy.inspect(op.get_bind())

table_infos: dict[str, TableInfo] = {}
for table in table_names:
pk = inspector.get_pk_constraint(table, schema)
fks = inspector.get_foreign_keys(table, schema)
uniques = inspector.get_unique_constraints(table, schema)
indices = inspector.get_indexes(table, schema)

table_infos[table] = TableInfo(
primary_key=pk, foreign_keys=fks, unique_constraints=uniques, indices=indices
)
_LOG.debug("TableInfo for %r: %s", table, table_infos[table])
with _reflection_bind() as conn:
inspector = sqlalchemy.inspect(conn)

for table in table_names:
pk = inspector.get_pk_constraint(table, schema)
fks = inspector.get_foreign_keys(table, schema)
uniques = inspector.get_unique_constraints(table, schema)
indices = inspector.get_indexes(table, schema)

table_infos[table] = TableInfo(
primary_key=pk, foreign_keys=fks, unique_constraints=uniques, indices=indices
)
_LOG.debug("TableInfo for %r: %s", table, table_infos[table])

return table_infos

Expand Down Expand Up @@ -426,12 +454,10 @@ def _extend_collection_table(metadata: sqlalchemy.schema.MetaData) -> None:
op.add_column(table_name, new_column, schema=schema)
_report_table_size("after adding collection_id", table_name, schema)

# Update metadata.
metadata.reflect(bind, extend_existing=True)

# Fill collection_id.
_LOG.info("Filling column %s.%s with IDs", table_name, column_name)
table = _get_table(metadata, table_name)
table.append_column(sqlalchemy.schema.Column(column_name, sqlalchemy.BigInteger, nullable=True))
update = table.update().values(collection_id=sequence.next_value())
op.execute(update)
_report_table_size("after filling collection_id", table_name, schema)
Expand Down Expand Up @@ -472,22 +498,19 @@ def _add_id_column(
parent_id_column : `str`
Name of the column in parent table holding collection IDs.
"""
bind = op.get_bind()

_LOG.info("Add column %s to table %s", id_column, table_name)
_report_table_size("before adding column", table_name, schema)

new_column = sqlalchemy.schema.Column(id_column, sqlalchemy.BigInteger, nullable=True)
op.add_column(table_name, new_column, schema=schema)
_report_table_size("after adding column", table_name, schema)

# Update metadata.
metadata.reflect(bind, extend_existing=True)

# Fill collection_id.
table = _get_table(metadata, table_name)
parent = _get_table(metadata, parent_table)

table.append_column(sqlalchemy.schema.Column(id_column, sqlalchemy.BigInteger, nullable=True))

# Correlated subquery to select collection id from parent table.
subq = sqlalchemy.sql.select(parent.columns[parent_id_column]).where(
parent.columns[parent_name_column] == table.columns[name_column]
Expand Down Expand Up @@ -516,6 +539,9 @@ def _get_table(metadata: sqlalchemy.schema.MetaData, name: str) -> sqlalchemy.sc

def _report_table_size(message: str, table: str, schema: str | None = None) -> None:
"""Print information about table sizes."""
if context.is_offline_mode():
# It's not worth the trouble to do it in offline mode.
return
if schema:
query = (
"select pg_table_size(quote_ident(:schema) || '.' || quote_ident(:table)), "
Expand Down
14 changes: 11 additions & 3 deletions python/lsst/daf/butler_migrate/butler_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
self._connection = connection
metadata = sqlalchemy.schema.MetaData(schema=schema)
self._table = sqlalchemy.schema.Table(
"butler_attributes", metadata, autoload_with=connection, schema=schema
"butler_attributes",
metadata,
sqlalchemy.schema.Column("name", sqlalchemy.Text, primary_key=True),
sqlalchemy.schema.Column("value", sqlalchemy.Text, nullable=False),
schema=schema,
)

def get(self, name: str) -> str | None:
Expand Down Expand Up @@ -108,7 +112,9 @@
"""
# update version
sql = self._table.update().where(self._table.columns.name == name).values(value=value)
return self._connection.execute(sql).rowcount
result = self._connection.execute(sql)
# result may be None in offline mode, assume that we updated something
return 1 if result is None else result.rowcount

def update_manager_version(self, manager: str, version: str) -> None:
"""Update version for the specified manager.
Expand Down Expand Up @@ -145,7 +151,9 @@
otherwise.
"""
sql = self._table.delete().where(self._table.columns.name == name)
return self._connection.execute(sql).rowcount
result = self._connection.execute(sql)

Check warning on line 154 in python/lsst/daf/butler_migrate/butler_attributes.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler_migrate/butler_attributes.py#L154

Added line #L154 was not covered by tests
# result may be None in offline mode, assume that we deleted something
return 1 if result is None else result.rowcount

Check warning on line 156 in python/lsst/daf/butler_migrate/butler_attributes.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler_migrate/butler_attributes.py#L156

Added line #L156 was not covered by tests

def get_dimensions_json(self) -> dict[str, Any]:
"""Return dimensions configuration from dimensions.json.
Expand Down
Loading