diff --git a/airflow/models/base.py b/airflow/models/base.py index 575eb6e0b1015..9965de7ec7071 100644 --- a/airflow/models/base.py +++ b/airflow/models/base.py @@ -20,7 +20,7 @@ from typing import Any from sqlalchemy import MetaData, String -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import registry from airflow.configuration import conf @@ -45,8 +45,9 @@ def _get_schema(): metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention) +mapper_registry = registry(metadata=metadata) -Base: Any = declarative_base(metadata=metadata) +Base: Any = mapper_registry.generate_base() ID_LEN = 250 diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 015f535af0fd1..0ffc304d49187 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -49,7 +49,6 @@ from airflow.models.base import Base - log = logging.getLogger(__name__) REVISION_HEADS_MAP = { @@ -686,21 +685,28 @@ def create_default_connections(session: Session = NEW_SESSION): ) -def _create_db_from_orm(session): - from alembic import command +def _get_flask_db(sql_database_uri): from flask import Flask from flask_sqlalchemy import SQLAlchemy + from airflow.www.session import AirflowDatabaseSessionInterface + + flask_app = Flask(__name__) + flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri + flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False + db = SQLAlchemy(flask_app) + AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="") + return db + + +def _create_db_from_orm(session): + from alembic import command + from airflow.models.base import Base from airflow.www.fab_security.sqla.models import Model - from airflow.www.session import AirflowDatabaseSessionInterface def _create_flask_session_tbl(sql_database_uri): - flask_app = Flask(__name__) - flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri - flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False - db = SQLAlchemy(flask_app) - AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="") + db = _get_flask_db(sql_database_uri) db.create_all() with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): @@ -1004,15 +1010,16 @@ def reflect_tables(tables: list[Base | str] | None, session): """ import sqlalchemy.schema - metadata = sqlalchemy.schema.MetaData(session.bind) + bind = session.bind + metadata = sqlalchemy.schema.MetaData() if tables is None: - metadata.reflect(resolve_fks=False) + metadata.reflect(bind=bind, resolve_fks=False) else: for tbl in tables: try: table_name = tbl if isinstance(tbl, str) else tbl.__tablename__ - metadata.reflect(only=[table_name], extend_existing=True, resolve_fks=False) + metadata.reflect(bind=bind, only=[table_name], extend_existing=True, resolve_fks=False) except exc.InvalidRequestError: continue return metadata @@ -1633,8 +1640,9 @@ def resetdb(session: Session = NEW_SESSION, skip_init: bool = False): connection = settings.engine.connect() with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): - drop_airflow_models(connection) - drop_airflow_moved_tables(session) + with connection.begin(): + drop_airflow_models(connection) + drop_airflow_moved_tables(connection) if not skip_init: initdb(session=session) @@ -1701,27 +1709,12 @@ def drop_airflow_models(connection): :return: None """ from airflow.models.base import Base - - # Drop connection and chart - those tables have been deleted and in case you - # run resetdb on schema with chart or users table will fail - chart = Table("chart", Base.metadata) - chart.drop(settings.engine, checkfirst=True) - user = Table("user", Base.metadata) - user.drop(settings.engine, checkfirst=True) - users = Table("users", Base.metadata) - users.drop(settings.engine, checkfirst=True) - dag_stats = Table("dag_stats", Base.metadata) - dag_stats.drop(settings.engine, checkfirst=True) - session = Table("session", Base.metadata) - session.drop(settings.engine, checkfirst=True) + from airflow.www.fab_security.sqla.models import Model Base.metadata.drop_all(connection) - # we remove the Tables here so that if resetdb is run metadata does not keep the old tables. - Base.metadata.remove(session) - Base.metadata.remove(dag_stats) - Base.metadata.remove(users) - Base.metadata.remove(user) - Base.metadata.remove(chart) + Model.metadata.drop_all(connection) + db = _get_flask_db(connection.engine.url) + db.drop_all() # alembic adds significant import time, so we import it lazily from alembic.migration import MigrationContext @@ -1731,11 +1724,11 @@ def drop_airflow_models(connection): version.drop(connection) -def drop_airflow_moved_tables(session): +def drop_airflow_moved_tables(connection): from airflow.models.base import Base from airflow.settings import AIRFLOW_MOVED_TABLE_PREFIX - tables = set(inspect(session.get_bind()).get_table_names()) + tables = set(inspect(connection).get_table_names()) to_delete = [Table(x, Base.metadata) for x in tables if x.startswith(AIRFLOW_MOVED_TABLE_PREFIX)] for tbl in to_delete: tbl.drop(settings.engine, checkfirst=False) @@ -1749,7 +1742,7 @@ def check(session: Session = NEW_SESSION): :param session: session of the sqlalchemy """ - session.execute("select 1 as is_alive;") + session.execute(text("select 1 as is_alive;")) log.info("Connection successful.") @@ -1780,10 +1773,10 @@ def create_global_lock( dialect = conn.dialect try: if dialect.name == "postgresql": - conn.execute(text("SET LOCK_TIMEOUT to :timeout"), timeout=lock_timeout) - conn.execute(text("SELECT pg_advisory_lock(:id)"), id=lock.value) + conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout}) + conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value}) elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6): - conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), id=str(lock), timeout=lock_timeout) + conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout}) elif dialect.name == "mssql": # TODO: make locking work for MSSQL pass @@ -1791,12 +1784,12 @@ def create_global_lock( yield finally: if dialect.name == "postgresql": - conn.execute("SET LOCK_TIMEOUT TO DEFAULT") - (unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), id=lock.value).fetchone() + conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT")) + (unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone() if not unlocked: raise RuntimeError("Error releasing DB lock!") elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6): - conn.execute(text("select RELEASE_LOCK(:id)"), id=str(lock)) + conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)}) elif dialect.name == "mssql": # TODO: make locking work for MSSQL pass diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py index 0ac1d7bb8ef0c..79074a1ddb807 100644 --- a/tests/test_utils/db.py +++ b/tests/test_utils/db.py @@ -82,7 +82,7 @@ def drop_tables_with_prefix(prefix): metadata = reflect_tables(None, session) for table_name, table in metadata.tables.items(): if table_name.startswith(prefix): - table.drop() + table.drop(session.bind) def clear_db_serialized_dags(): diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py index dee5118331b78..e17dee6fe6b60 100644 --- a/tests/utils/test_db.py +++ b/tests/utils/test_db.py @@ -230,7 +230,7 @@ def test_resetdb( session_mock = MagicMock() resetdb(session_mock, skip_init=skip_init) mock_drop_airflow.assert_called_once_with(mock_connect.return_value) - mock_drop_moved.assert_called_once_with(session_mock) + mock_drop_moved.assert_called_once_with(mock_connect.return_value) if skip_init: mock_init.assert_not_called() else: