Skip to content

Commit

Permalink
Update code style for airflow db commands to SQLAlchemy 2.0 style (#…
Browse files Browse the repository at this point in the history
…31486)

* Update code style for `airflow db` commands to SQLAlchemy 2.0 style

This commit introduces changes to the code styles of `airflow db` commands to remove 'RemovedIn20Warning'
and ensure compatibility with SQLAlchemy 2.0.

To see these warnings, you need to set SQLALCHEMY_WARN_20=True when using the db commands

* fixup! Update code style for `airflow db` commands to SQLAlchemy 2.0 style

* fixup! fixup! Update code style for `airflow db` commands to SQLAlchemy 2.0 style

* Use connection instead of session.get_bind()

* remove metadata.bind=bind
  • Loading branch information
ephraimbuddy authored May 26, 2023
1 parent e86f688 commit afa9ead
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 46 deletions.
5 changes: 3 additions & 2 deletions airflow/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Any

from sqlalchemy import MetaData, String
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import registry

from airflow.configuration import conf

Expand All @@ -45,8 +45,9 @@ def _get_schema():


metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention)
mapper_registry = registry(metadata=metadata)

Base: Any = declarative_base(metadata=metadata)
Base: Any = mapper_registry.generate_base()

ID_LEN = 250

Expand Down
77 changes: 35 additions & 42 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@

from airflow.models.base import Base


log = logging.getLogger(__name__)

REVISION_HEADS_MAP = {
Expand Down Expand Up @@ -686,21 +685,28 @@ def create_default_connections(session: Session = NEW_SESSION):
)


def _create_db_from_orm(session):
from alembic import command
def _get_flask_db(sql_database_uri):
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

from airflow.www.session import AirflowDatabaseSessionInterface

flask_app = Flask(__name__)
flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(flask_app)
AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
return db


def _create_db_from_orm(session):
from alembic import command

from airflow.models.base import Base
from airflow.www.fab_security.sqla.models import Model
from airflow.www.session import AirflowDatabaseSessionInterface

def _create_flask_session_tbl(sql_database_uri):
flask_app = Flask(__name__)
flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(flask_app)
AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
db = _get_flask_db(sql_database_uri)
db.create_all()

with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
Expand Down Expand Up @@ -1004,15 +1010,16 @@ def reflect_tables(tables: list[Base | str] | None, session):
"""
import sqlalchemy.schema

metadata = sqlalchemy.schema.MetaData(session.bind)
bind = session.bind
metadata = sqlalchemy.schema.MetaData()

if tables is None:
metadata.reflect(resolve_fks=False)
metadata.reflect(bind=bind, resolve_fks=False)
else:
for tbl in tables:
try:
table_name = tbl if isinstance(tbl, str) else tbl.__tablename__
metadata.reflect(only=[table_name], extend_existing=True, resolve_fks=False)
metadata.reflect(bind=bind, only=[table_name], extend_existing=True, resolve_fks=False)
except exc.InvalidRequestError:
continue
return metadata
Expand Down Expand Up @@ -1633,8 +1640,9 @@ def resetdb(session: Session = NEW_SESSION, skip_init: bool = False):
connection = settings.engine.connect()

with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
drop_airflow_models(connection)
drop_airflow_moved_tables(session)
with connection.begin():
drop_airflow_models(connection)
drop_airflow_moved_tables(connection)

if not skip_init:
initdb(session=session)
Expand Down Expand Up @@ -1701,27 +1709,12 @@ def drop_airflow_models(connection):
:return: None
"""
from airflow.models.base import Base

# Drop connection and chart - those tables have been deleted and in case you
# run resetdb on schema with chart or users table will fail
chart = Table("chart", Base.metadata)
chart.drop(settings.engine, checkfirst=True)
user = Table("user", Base.metadata)
user.drop(settings.engine, checkfirst=True)
users = Table("users", Base.metadata)
users.drop(settings.engine, checkfirst=True)
dag_stats = Table("dag_stats", Base.metadata)
dag_stats.drop(settings.engine, checkfirst=True)
session = Table("session", Base.metadata)
session.drop(settings.engine, checkfirst=True)
from airflow.www.fab_security.sqla.models import Model

Base.metadata.drop_all(connection)
# we remove the Tables here so that if resetdb is run metadata does not keep the old tables.
Base.metadata.remove(session)
Base.metadata.remove(dag_stats)
Base.metadata.remove(users)
Base.metadata.remove(user)
Base.metadata.remove(chart)
Model.metadata.drop_all(connection)
db = _get_flask_db(connection.engine.url)
db.drop_all()
# alembic adds significant import time, so we import it lazily
from alembic.migration import MigrationContext

Expand All @@ -1731,11 +1724,11 @@ def drop_airflow_models(connection):
version.drop(connection)


def drop_airflow_moved_tables(session):
def drop_airflow_moved_tables(connection):
from airflow.models.base import Base
from airflow.settings import AIRFLOW_MOVED_TABLE_PREFIX

tables = set(inspect(session.get_bind()).get_table_names())
tables = set(inspect(connection).get_table_names())
to_delete = [Table(x, Base.metadata) for x in tables if x.startswith(AIRFLOW_MOVED_TABLE_PREFIX)]
for tbl in to_delete:
tbl.drop(settings.engine, checkfirst=False)
Expand All @@ -1749,7 +1742,7 @@ def check(session: Session = NEW_SESSION):
:param session: session of the sqlalchemy
"""
session.execute("select 1 as is_alive;")
session.execute(text("select 1 as is_alive;"))
log.info("Connection successful.")


Expand Down Expand Up @@ -1780,23 +1773,23 @@ def create_global_lock(
dialect = conn.dialect
try:
if dialect.name == "postgresql":
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), timeout=lock_timeout)
conn.execute(text("SELECT pg_advisory_lock(:id)"), id=lock.value)
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), id=str(lock), timeout=lock_timeout)
conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout})
elif dialect.name == "mssql":
# TODO: make locking work for MSSQL
pass

yield
finally:
if dialect.name == "postgresql":
conn.execute("SET LOCK_TIMEOUT TO DEFAULT")
(unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), id=lock.value).fetchone()
conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
(unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
if not unlocked:
raise RuntimeError("Error releasing DB lock!")
elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
conn.execute(text("select RELEASE_LOCK(:id)"), id=str(lock))
conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)})
elif dialect.name == "mssql":
# TODO: make locking work for MSSQL
pass
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def drop_tables_with_prefix(prefix):
metadata = reflect_tables(None, session)
for table_name, table in metadata.tables.items():
if table_name.startswith(prefix):
table.drop()
table.drop(session.bind)


def clear_db_serialized_dags():
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_resetdb(
session_mock = MagicMock()
resetdb(session_mock, skip_init=skip_init)
mock_drop_airflow.assert_called_once_with(mock_connect.return_value)
mock_drop_moved.assert_called_once_with(session_mock)
mock_drop_moved.assert_called_once_with(mock_connect.return_value)
if skip_init:
mock_init.assert_not_called()
else:
Expand Down

0 comments on commit afa9ead

Please sign in to comment.