Refactor sqlalchemy queries to 2.0 style (Part 1) #31569

Merged · 10 commits · May 30, 2023
airflow/api/common/delete_dag.py (26 additions, 21 deletions)
@@ -20,7 +20,7 @@

import logging

from sqlalchemy import and_, or_
from sqlalchemy import and_, delete, or_, select
from sqlalchemy.orm import Session

from airflow import models
@@ -47,25 +47,28 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session =
:return count of deleted dags
"""
log.info("Deleting DAG: %s", dag_id)
running_tis = (
session.query(models.TaskInstance.state)
.filter(models.TaskInstance.dag_id == dag_id)
.filter(models.TaskInstance.state == State.RUNNING)
.first()
running_tis = session.scalar(
select(models.TaskInstance.state)
.where(models.TaskInstance.dag_id == dag_id)
.where(models.TaskInstance.state == State.RUNNING)
.limit(1)
)
if running_tis:
raise AirflowException("TaskInstances still running")
dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id).limit(1))
if dag is None:
raise DagNotFound(f"Dag id {dag_id} not found")

# deleting a DAG should also delete all of its subdags
dags_to_delete_query = session.query(DagModel.dag_id).filter(
or_(
DagModel.dag_id == dag_id,
and_(DagModel.dag_id.like(f"{dag_id}.%"), DagModel.is_subdag),
dags_to_delete_query = session.execute(
select(DagModel.dag_id).where(
or_(
DagModel.dag_id == dag_id,
and_(DagModel.dag_id.like(f"{dag_id}.%"), DagModel.is_subdag),
)
)
)

dags_to_delete = [dag_id for dag_id, in dags_to_delete_query]

# Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
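
For reference while reviewing: a minimal, self-contained sketch of the .first() to session.scalar(select(...).limit(1)) translation used above. Everything below is illustrative (a toy model and an in-memory sqlite engine), not Airflow code:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class TI(Base):  # toy stand-in for models.TaskInstance
    __tablename__ = "ti"
    id: Mapped[int] = mapped_column(primary_key=True)
    dag_id: Mapped[str]
    state: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(TI(dag_id="example", state="running"))
    session.commit()

    # session.scalar() returns the first column of the first row, or None
    # when nothing matches: the same contract Query.first() had for a
    # single-column query.  The explicit .limit(1) keeps the SQL identical.
    state = session.scalar(
        select(TI.state).where(TI.dag_id == "example", TI.state == "running").limit(1)
    )
    assert state == "running"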
@@ -79,22 +82,24 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session =
if hasattr(model, "dag_id"):
if keep_records_in_log and model.__name__ == "Log":
continue
count += (
session.query(model)
.filter(model.dag_id.in_(dags_to_delete))
.delete(synchronize_session="fetch")
)
count += session.execute(
delete(model)
.where(model.dag_id.in_(dags_to_delete))
.execution_options(synchronize_session="fetch")
).rowcount
if dag.is_subdag:
parent_dag_id, task_id = dag_id.rsplit(".", 1)
for model in TaskFail, models.TaskInstance:
count += (
session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete()
)
count += session.execute(
delete(model).where(model.dag_id == parent_dag_id, model.task_id == task_id)
).rowcount

# Delete entries in Import Errors table for a deleted DAG
# This handles the case when the dag_id is changed in the file
session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete(
synchronize_session="fetch"
session.execute(
delete(models.ImportError)
.where(models.ImportError.filename == dag.fileloc)
.execution_options(synchronize_session="fetch")
)

return count
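
The bulk deletes above lean on two 2.0 behaviors: session.execute(delete(...)) returns a result whose .rowcount replaces the return value of Query.delete(), and synchronize_session moves into execution_options(). A runnable sketch with a toy model (illustrative, not Airflow code):

from sqlalchemy import create_engine, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class Log(Base):
    __tablename__ = "log"
    id: Mapped[int] = mapped_column(primary_key=True)
    dag_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Log(dag_id="a"), Log(dag_id="a"), Log(dag_id="b")])
    session.commit()

    # "fetch" selects the matching rows first so that any affected objects
    # already present in the session are synchronized, mirroring the old
    # Query.delete(synchronize_session="fetch").
    result = session.execute(
        delete(Log)
        .where(Log.dag_id == "a")
        .execution_options(synchronize_session="fetch")
    )
    assert result.rowcount == 2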
airflow/api/common/experimental/pool.py (4 additions, 3 deletions)
@@ -19,6 +19,7 @@
from __future__ import annotations

from deprecated import deprecated
from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow.exceptions import AirflowBadRequest, PoolNotFound
@@ -33,7 +34,7 @@ def get_pool(name, session: Session = NEW_SESSION):
if not (name and name.strip()):
raise AirflowBadRequest("Pool name shouldn't be empty")

pool = session.query(Pool).filter_by(pool=name).first()
pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))
if pool is None:
raise PoolNotFound(f"Pool '{name}' doesn't exist")

@@ -65,7 +66,7 @@ def create_pool(name, slots, description, session: Session = NEW_SESSION):
raise AirflowBadRequest(f"Pool name can't be more than {pool_name_length} characters")

session.expire_on_commit = False
pool = session.query(Pool).filter_by(pool=name).first()
pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))
if pool is None:
pool = Pool(pool=name, slots=slots, description=description)
session.add(pool)
@@ -88,7 +89,7 @@ def delete_pool(name, session: Session = NEW_SESSION):
if name == Pool.DEFAULT_POOL_NAME:
raise AirflowBadRequest(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")

pool = session.query(Pool).filter_by(pool=name).first()
pool = session.scalar(select(Pool).filter_by(pool=name).limit(1))
if pool is None:
raise PoolNotFound(f"Pool '{name}' doesn't exist")

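A detail that carries over unchanged in these three lookups: select() accepts the same filter_by() shorthand Query had, resolved against the statement's primary entity. A sketch (toy Pool model, illustrative only):

from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Pool(Base):
    __tablename__ = "pool"
    id: Mapped[int] = mapped_column(primary_key=True)
    pool: Mapped[str]

# filter_by(pool=name) is shorthand for .where(Pool.pool == name),
# resolved against the entity the statement selects from.
stmt = select(Pool).filter_by(pool="my_pool").limit(1)
print(stmt)  # renders roughly: SELECT ... FROM pool WHERE pool.pool = :pool_1 LIMIT :param_1
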
airflow/api/common/mark_tasks.py (30 additions, 26 deletions)
@@ -21,7 +21,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple

from sqlalchemy import or_
from sqlalchemy import or_, select
from sqlalchemy.orm import Session as SASession, lazyload

from airflow.models.dag import DAG
@@ -148,10 +148,10 @@ def set_state(
qry_dag = get_all_dag_task_query(dag, session, state, task_id_map_index_list, dag_run_ids)

if commit:
tis_altered = qry_dag.with_for_update().all()
tis_altered = session.scalars(qry_dag.with_for_update()).all()
if sub_dag_run_ids:
qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
tis_altered += qry_sub_dag.with_for_update().all()
tis_altered += session.scalars(qry_sub_dag.with_for_update()).all()
for task_instance in tis_altered:
# The try_number was decremented when setting to up_for_reschedule and deferred.
# Increment it back when changing the state again
Expand All @@ -160,10 +160,10 @@ def set_state(
task_instance.set_state(state, session=session)
session.flush()
else:
tis_altered = qry_dag.all()
tis_altered = session.scalars(qry_dag).all()
if sub_dag_run_ids:
qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
tis_altered += qry_sub_dag.all()
tis_altered += session.scalars(qry_sub_dag).all()
return tis_altered


Expand All @@ -175,9 +175,9 @@ def all_subdag_tasks_query(
):
"""Get *all* tasks of the sub dags."""
qry_sub_dag = (
session.query(TaskInstance)
.filter(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
select(TaskInstance)
.where(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
.where(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
)
return qry_sub_dag
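
Note the division of labor after this change: the helpers now return an un-executed Select, and the caller in set_state decides how to run it, session.scalars(qry).all() for a plain read or qry.with_for_update() first when rows must be locked. A sketch of that shape (toy model, illustrative only):

from sqlalchemy import create_engine, select
from sqlalchemy.sql import Select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class TI(Base):
    __tablename__ = "ti"
    id: Mapped[int] = mapped_column(primary_key=True)
    state: Mapped[str]

def running_tis_query() -> Select:
    # The helper only builds the statement; nothing hits the database yet.
    return select(TI).where(TI.state == "running")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(TI(state="running"))
    session.commit()

    qry = running_tis_query()
    tis = session.scalars(qry).all()                       # plain read
    locked = session.scalars(qry.with_for_update()).all()  # SELECT ... FOR UPDATE
    assert len(tis) == len(locked) == 1                    # (sqlite ignores FOR UPDATE)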

@@ -190,13 +190,13 @@ def get_all_dag_task_query(
run_ids: Iterable[str],
):
"""Get all tasks of the main dag that will be affected by a state change."""
qry_dag = session.query(TaskInstance).filter(
qry_dag = select(TaskInstance).where(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id.in_(run_ids),
TaskInstance.ti_selector_condition(task_ids),
)

qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
qry_dag = qry_dag.where(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
lazyload(TaskInstance.dag_run)
)
return qry_dag
@@ -324,11 +324,8 @@ def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASess
"""Return DAG executions' run_ids."""
last_dagrun = dag.get_last_dagrun(include_externally_triggered=True, session=session)
current_dagrun = dag.get_dagrun(run_id=run_id, session=session)
first_dagrun = (
session.query(DagRun)
.filter(DagRun.dag_id == dag.dag_id)
.order_by(DagRun.execution_date.asc())
.first()
first_dagrun = session.scalar(
select(DagRun).filter(DagRun.dag_id == dag.dag_id).order_by(DagRun.execution_date.asc()).limit(1)
)

if last_dagrun is None:
@@ -361,7 +358,9 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA
:param state: target state
:param session: database session
"""
dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
dag_run = session.execute(
select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
).scalar_one()
Comment on lines +361 to +363 (Member):

Should be equivalent to

Suggested change, from:

    dag_run = session.execute(
        select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
    ).scalar_one()

to:

    dag_run = session.scalars(
        select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
    ).one()

but I don’t think this is better. So just for reference in case anyone wonders.

dag_run.state = state
if state == State.RUNNING:
dag_run.start_date = timezone.utcnow()
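
To expand on the review comment above: both spellings fetch exactly one ORM instance and raise if zero or several rows match. A runnable sketch of the equivalence (toy model, illustrative only):

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class Run(Base):
    __tablename__ = "run"
    id: Mapped[int] = mapped_column(primary_key=True)
    run_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Run(run_id="manual__1"))
    session.commit()

    stmt = select(Run).where(Run.run_id == "manual__1")
    a = session.execute(stmt).scalar_one()  # first column of the only row
    b = session.scalars(stmt).one()         # same thing, scalar-ized up front
    assert a is b  # identity map: both return the same instance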
@@ -464,12 +463,15 @@ def set_dag_run_state_to_failed(

# Mark only RUNNING task instances.
task_ids = [task.task_id for task in dag.tasks]
tis = session.query(TaskInstance).filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id.in_(task_ids),
TaskInstance.state.in_([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
tis = session.scalars(
select(TaskInstance).where(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id.in_(task_ids),
TaskInstance.state.in_([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
)
)

task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]

tasks = []
Expand All @@ -480,11 +482,13 @@ def set_dag_run_state_to_failed(
tasks.append(task)

# Mark non-finished tasks as SKIPPED.
tis = session.query(TaskInstance).filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.state.not_in(State.finished),
TaskInstance.state.not_in([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
tis = session.scalars(
select(TaskInstance).filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.state.not_in(State.finished),
TaskInstance.state.not_in([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
)
)

tis = [ti for ti in tis]
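
One behavioral difference hiding in this hunk: Query.all() returned a list, while session.scalars() returns a ScalarResult that is consumed as it is iterated, hence the materialization into a list right above. A sketch (toy model, illustrative only):

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class TI(Base):
    __tablename__ = "ti"
    id: Mapped[int] = mapped_column(primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([TI(), TI()])
    session.commit()

    tis = session.scalars(select(TI))  # ScalarResult: a one-shot iterator
    tis = [ti for ti in tis]           # materialize before reusing or len()
    assert len(tis) == 2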
airflow/api_connexion/endpoints/dag_run_endpoint.py (5 additions, 2 deletions)
@@ -23,7 +23,7 @@
from flask import g
from flask_login import current_user
from marshmallow import ValidationError
from sqlalchemy import or_
from sqlalchemy import delete, or_
from sqlalchemy.orm import Query, Session

from airflow.api.common.mark_tasks import (
@@ -74,7 +74,10 @@
@provide_session
def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Delete a DAG Run."""
if session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).delete() == 0:
deleted_count = session.execute(
delete(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)
).rowcount
if deleted_count == 0:
raise NotFound(detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found")
return NoContent, HTTPStatus.NO_CONTENT

airflow/api_connexion/endpoints/pool_endpoint.py (3 additions, 2 deletions)
@@ -20,7 +20,7 @@

from flask import Response
from marshmallow import ValidationError
from sqlalchemy import func
from sqlalchemy import delete, func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

@@ -41,7 +41,8 @@ def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIRespons
"""Delete a pool."""
if pool_name == "default_pool":
raise BadRequest(detail="Default Pool can't be deleted")
affected_count = session.query(Pool).filter(Pool.pool == pool_name).delete()
affected_count = session.execute(delete(Pool).where(Pool.pool == pool_name)).rowcount

if affected_count == 0:
raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
return Response(status=HTTPStatus.NO_CONTENT)
airflow/cli/commands/dag_command.py (2 additions, 1 deletion)
@@ -28,6 +28,7 @@
import warnings

from graphviz.dot import Dot
from sqlalchemy import delete
from sqlalchemy.orm import Session

from airflow import settings
@@ -507,7 +508,7 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
@cli_utils.action_cli
def dag_reserialize(args, session: Session = NEW_SESSION) -> None:
"""Serialize a DAG instance."""
session.query(SerializedDagModel).delete(synchronize_session=False)
session.execute(delete(SerializedDagModel).execution_options(synchronize_session=False))

if not args.clear_only:
dagbag = DagBag(process_subdir(args.subdir))
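
For the table-wide wipe here, synchronize_session=False skips the session bookkeeping entirely, a reasonable choice when no affected objects are loaded in the session. A sketch (toy model, illustrative only):

from sqlalchemy import create_engine, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class SerializedDag(Base):
    __tablename__ = "serialized_dag"
    id: Mapped[int] = mapped_column(primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([SerializedDag(), SerializedDag()])
    session.commit()
    session.expunge_all()  # nothing loaded, so synchronization is wasted work

    # synchronize_session=False issues the DELETE without reconciling the
    # identity map, the cheapest option when no matching objects are loaded.
    result = session.execute(
        delete(SerializedDag).execution_options(synchronize_session=False)
    )
    assert result.rowcount == 2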
airflow/dag_processing/processor.py (6 additions, 4 deletions)
@@ -30,7 +30,7 @@
from typing import TYPE_CHECKING, Iterable, Iterator

from setproctitle import setproctitle
from sqlalchemy import exc, func, or_
from sqlalchemy import delete, exc, func, or_
from sqlalchemy.orm.session import Session

from airflow import settings
@@ -610,9 +610,11 @@ def update_import_errors(
# Clear the errors of the processed files
# that no longer have errors
for dagbag_file in files_without_error:
session.query(errors.ImportError).filter(
errors.ImportError.filename.startswith(dagbag_file)
).delete(synchronize_session="fetch")
session.execute(
delete(errors.ImportError)
.where(errors.ImportError.filename.startswith(dagbag_file))
.execution_options(synchronize_session="fetch")
)

# files that still have errors
existing_import_error_files = [x.filename for x in session.query(errors.ImportError.filename).all()]
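
startswith() composes into delete() the same way it composed into Query.filter(), rendering a LIKE on the prefix. A sketch (toy stand-in for the model the diff calls errors.ImportError; illustrative only):

from sqlalchemy import delete
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class ParseImportError(Base):  # toy name chosen to avoid shadowing the builtin
    __tablename__ = "import_error"
    id: Mapped[int] = mapped_column(primary_key=True)
    filename: Mapped[str]

stmt = delete(ParseImportError).where(ParseImportError.filename.startswith("/dags/example"))
print(stmt)  # renders roughly: DELETE FROM import_error WHERE filename LIKE :filename_1 || '%'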