
Remove tuple_in_condition helpers (#45201)
uranusjr authored Dec 26, 2024
1 parent 2bb13c7 commit 3873230
Showing 8 changed files with 44 additions and 115 deletions.
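The change is mechanical across the changed files: the Airflow-specific tuple_in_condition / tuple_not_in_condition helpers give way to SQLAlchemy's built-in tuple_() construct combined with .in_() / .not_in(). The sketch below is illustrative only (a throwaway table and invented values, not Airflow code), showing the before/after shape under that assumption:

from sqlalchemy import Column, Integer, MetaData, String, Table, select, tuple_

metadata = MetaData()
# Throwaway table standing in for TaskInstance-like rows.
ti = Table(
    "ti_demo",
    metadata,
    Column("dag_id", String, primary_key=True),
    Column("task_id", String, primary_key=True),
    Column("map_index", Integer, primary_key=True, default=-1),
)

pairs = [("dag_a", "task_1"), ("dag_b", "task_2")]

# Before (removed helper):   tuple_in_condition((ti.c.dag_id, ti.c.task_id), pairs)
# After (native SQLAlchemy): renders as (dag_id, task_id) IN ((...), (...))
stmt = select(ti).where(tuple_(ti.c.dag_id, ti.c.task_id).in_(pairs))

The native form needs no import from airflow.utils.sqlalchemy, which is why most files in this commit also shrink their import blocks.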
35 changes: 13 additions & 22 deletions airflow/jobs/scheduler_job_runner.py
@@ -34,7 +34,7 @@
from typing import TYPE_CHECKING, Any, Callable

from deprecated import deprecated
from sqlalchemy import and_, delete, exists, func, not_, select, text, update
from sqlalchemy import and_, delete, exists, func, select, text, tuple_, update
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import lazyload, load_only, make_transient, selectinload
from sqlalchemy.sql import expression
@@ -77,12 +77,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import (
is_lock_not_available_error,
prohibit_commit,
tuple_in_condition,
with_row_locks,
)
from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, with_row_locks
from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

@@ -357,28 +352,25 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
.join(TI.dag_run)
.where(DR.state == DagRunState.RUNNING)
.join(TI.dag_model)
.where(not_(DM.is_paused))
.where(~DM.is_paused)
.where(TI.state == TaskInstanceState.SCHEDULED)
.options(selectinload(TI.dag_model))
.order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
)

if starved_pools:
query = query.where(not_(TI.pool.in_(starved_pools)))
query = query.where(TI.pool.not_in(starved_pools))

if starved_dags:
query = query.where(not_(TI.dag_id.in_(starved_dags)))
query = query.where(TI.dag_id.not_in(starved_dags))

if starved_tasks:
task_filter = tuple_in_condition((TI.dag_id, TI.task_id), starved_tasks)
query = query.where(not_(task_filter))
query = query.where(tuple_(TI.dag_id, TI.task_id).not_in(starved_tasks))

if starved_tasks_task_dagrun_concurrency:
task_filter = tuple_in_condition(
(TI.dag_id, TI.run_id, TI.task_id),
starved_tasks_task_dagrun_concurrency,
query = query.where(
tuple_(TI.dag_id, TI.run_id, TI.task_id).not_in(starved_tasks_task_dagrun_concurrency)
)
query = query.where(not_(task_filter))

query = query.limit(max_tis)

@@ -1314,9 +1306,8 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
existing_dagruns = (
session.execute(
select(DagRun.dag_id, DagRun.logical_date).where(
tuple_in_condition(
(DagRun.dag_id, DagRun.logical_date),
((dm.dag_id, dm.next_dagrun) for dm in dag_models),
tuple_(DagRun.dag_id, DagRun.logical_date).in_(
(dm.dag_id, dm.next_dagrun) for dm in dag_models
),
)
)
@@ -1402,7 +1393,7 @@ def _create_dag_runs_asset_triggered(
existing_dagruns: set[tuple[str, timezone.DateTime]] = set(
session.execute(
select(DagRun.dag_id, DagRun.logical_date).where(
tuple_in_condition((DagRun.dag_id, DagRun.logical_date), logical_dates.items())
tuple_(DagRun.dag_id, DagRun.logical_date).in_(logical_dates.items())
)
)
)
@@ -2188,7 +2179,7 @@ def _orphan_unreferenced_assets(assets: Collection[AssetModel], *, session: Sess
if assets:
session.execute(
delete(AssetActive).where(
tuple_in_condition((AssetActive.name, AssetActive.uri), ((a.name, a.uri) for a in assets))
tuple_(AssetActive.name, AssetActive.uri).in_((a.name, a.uri) for a in assets)
)
)
Stats.gauge("asset.orphaned", len(assets))
@@ -2201,7 +2192,7 @@ def _activate_referenced_assets(assets: Collection[AssetModel], *, session: Sess
active_assets = set(
session.execute(
select(AssetActive.name, AssetActive.uri).where(
tuple_in_condition((AssetActive.name, AssetActive.uri), ((a.name, a.uri) for a in assets))
tuple_(AssetActive.name, AssetActive.uri).in_((a.name, a.uri) for a in assets)
)
)
)
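Alongside the helper removal, the scheduler query also drops not_() in favour of the equivalent operator forms. A short sketch of the three spellings now in use (illustrative table and values, not the real models):

from sqlalchemy import Boolean, Column, MetaData, String, Table, select, tuple_

md = MetaData()
demo = Table(
    "ti_neg_demo",
    md,
    Column("dag_id", String),
    Column("task_id", String),
    Column("pool", String),
    Column("is_paused", Boolean),
)

stmt = (
    select(demo)
    # not_(col.in_(values))          ->  col.not_in(values)
    .where(demo.c.pool.not_in(["starved_pool"]))
    # not_(tuple_in_condition(...))  ->  tuple_(...).not_in(pairs)
    .where(tuple_(demo.c.dag_id, demo.c.task_id).not_in([("dag_a", "task_1")]))
    # not_(bool_col)                 ->  ~bool_col
    .where(~demo.c.is_paused)
)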
8 changes: 4 additions & 4 deletions airflow/models/dag.py
@@ -58,9 +58,9 @@
and_,
case,
func,
not_,
or_,
select,
tuple_,
update,
)
from sqlalchemy.ext.associationproxy import association_proxy
@@ -108,7 +108,7 @@
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, tuple_in_condition, with_row_locks
from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

@@ -1081,7 +1081,7 @@ def _get_task_instances(
tis = tis.where(TaskInstance.state.in_(state))

if exclude_run_ids:
tis = tis.where(not_(TaskInstance.run_id.in_(exclude_run_ids)))
tis = tis.where(TaskInstance.run_id.not_in(exclude_run_ids))

if include_dependent_dags:
# Recursively find external tasks indicated by ExternalTaskMarker
@@ -1192,7 +1192,7 @@ def _get_task_instances(
elif isinstance(next(iter(exclude_task_ids), None), str):
tis = tis.where(TI.task_id.notin_(exclude_task_ids))
else:
tis = tis.where(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
tis = tis.where(tuple_(TI.task_id, TI.map_index).not_in(exclude_task_ids))

return tis

7 changes: 4 additions & 3 deletions airflow/models/dagrun.py
@@ -42,6 +42,7 @@
not_,
or_,
text,
tuple_,
update,
)
from sqlalchemy.exc import IntegrityError
@@ -74,7 +75,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import retry_db_transaction
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, tuple_in_condition, with_row_locks
from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import NOTSET, DagRunTriggeredByType, DagRunType

@@ -1644,7 +1645,7 @@ def schedule_tis(
.where(
TI.dag_id == self.dag_id,
TI.run_id == self.run_id,
tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk),
tuple_(TI.task_id, TI.map_index).in_(schedulable_ti_ids_chunk),
)
.values(
state=TaskInstanceState.SCHEDULED,
@@ -1668,7 +1669,7 @@
.where(
TI.dag_id == self.dag_id,
TI.run_id == self.run_id,
tuple_in_condition((TI.task_id, TI.map_index), dummy_ti_ids_chunk),
tuple_(TI.task_id, TI.map_index).in_(dummy_ti_ids_chunk),
)
.values(
state=TaskInstanceState.SUCCESS,
7 changes: 3 additions & 4 deletions airflow/models/skipmixin.py
@@ -21,18 +21,17 @@
from types import GeneratorType
from typing import TYPE_CHECKING

from sqlalchemy import update
from sqlalchemy import tuple_, update

from airflow.exceptions import AirflowException
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from sqlalchemy import Session
from sqlalchemy.orm import Session

from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
@@ -74,7 +73,7 @@ def _set_state_to_skipped(
.where(
TaskInstance.dag_id == dag_run.dag_id,
TaskInstance.run_id == dag_run.run_id,
tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), tasks),
tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(tasks),
)
.values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
.execution_options(synchronize_session=False)
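skipmixin.py (and dagrun.py above) use the same construct inside bulk UPDATE statements. A minimal sketch of that shape, assuming a 2.0-style session and a hypothetical mapped model TIDemo with the relevant columns:

from sqlalchemy import tuple_, update

def set_skipped(session, TIDemo, dag_id, run_id, task_keys, now):
    # task_keys: iterable of (task_id, map_index) pairs to mark as skipped.
    session.execute(
        update(TIDemo)
        .where(
            TIDemo.dag_id == dag_id,
            TIDemo.run_id == run_id,
            tuple_(TIDemo.task_id, TIDemo.map_index).in_(task_keys),
        )
        .values(state="skipped", start_date=now, end_date=now)
        .execution_options(synchronize_session=False)
    )

synchronize_session=False mirrors the real code: the bulk update is issued directly and rows already loaded in the session are not reconciled.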
16 changes: 6 additions & 10 deletions airflow/models/taskinstance.py
@@ -55,12 +55,14 @@
Text,
UniqueConstraint,
and_,
case,
delete,
extract,
false,
func,
inspect,
or_,
select,
text,
tuple_,
update,
@@ -71,7 +73,6 @@
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import lazyload, reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.sql.expression import case, select
from sqlalchemy_utils import UUIDType

from airflow import settings
@@ -131,12 +132,7 @@
from airflow.utils.platform import getuser
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import (
ExecutorConfigType,
ExtendedJSON,
UtcDateTime,
tuple_in_condition,
)
from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.task_instance_session import set_current_task_instance_session
@@ -3497,7 +3493,7 @@ def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> Colum
if task_id_only:
filters.append(cls.task_id.in_(task_id_only))
if with_map_index:
filters.append(tuple_in_condition((cls.task_id, cls.map_index), with_map_index))
filters.append(tuple_(cls.task_id, cls.map_index).in_(with_map_index))

if not filters:
return false()
@@ -3675,8 +3671,8 @@ def _get_inactive_asset_unique_keys(
AssetUniqueKey(name, uri)
for name, uri in session.execute(
select(AssetActive.name, AssetActive.uri).where(
tuple_in_condition(
(AssetActive.name, AssetActive.uri), [attrs.astuple(key) for key in asset_unique_keys]
tuple_(AssetActive.name, AssetActive.uri).in_(
attrs.astuple(key) for key in asset_unique_keys
)
)
)
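ti_selector_condition above mixes the two IN forms: plain task ids go through a scalar IN, (task_id, map_index) pairs go through a tuple IN, and the pieces are then combined. An illustrative re-creation of that selector logic against a throwaway table (the splitting of vals and the OR-combination are assumptions about surrounding code the diff does not show in full):

from sqlalchemy import Column, Integer, MetaData, String, Table, false, or_, select, tuple_

md = MetaData()
demo = Table(
    "ti_selector_demo",
    md,
    Column("task_id", String),
    Column("map_index", Integer),
)

def ti_selector_condition(vals):
    # vals mixes plain task ids (str) and (task_id, map_index) tuples.
    task_id_only = [v for v in vals if isinstance(v, str)]
    with_map_index = [v for v in vals if not isinstance(v, str)]

    filters = []
    if task_id_only:
        filters.append(demo.c.task_id.in_(task_id_only))
    if with_map_index:
        filters.append(tuple_(demo.c.task_id, demo.c.map_index).in_(with_map_index))

    if not filters:
        return false()
    if len(filters) == 1:
        return filters[0]
    return or_(*filters)

stmt = select(demo).where(ti_selector_condition(["task_a", ("task_b", 0)]))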
54 changes: 3 additions & 51 deletions airflow/utils/sqlalchemy.py
@@ -23,7 +23,7 @@
import logging
from collections.abc import Generator, Iterable
from importlib import metadata
from typing import TYPE_CHECKING, Any, overload
from typing import TYPE_CHECKING, Any

from packaging import version
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_
@@ -438,22 +438,6 @@ def is_lock_not_available_error(error: OperationalError):
return False


@overload
def tuple_in_condition(
columns: tuple[ColumnElement, ...],
collection: Iterable[Any],
) -> ColumnOperators: ...


@overload
def tuple_in_condition(
columns: tuple[ColumnElement, ...],
collection: Select,
*,
session: Session,
) -> ColumnOperators: ...


def tuple_in_condition(
columns: tuple[ColumnElement, ...],
collection: Iterable[Any] | Select,
@@ -463,46 +447,14 @@ def tuple_in_condition(
"""
Generate a tuple-in-collection operator to use in ``.where()``.
For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
clause.
Kept for backward compatibility. Remove when providers drop support for
apache-airflow<3.0.
:meta private:
"""
return tuple_(*columns).in_(collection)


@overload
def tuple_not_in_condition(
columns: tuple[ColumnElement, ...],
collection: Iterable[Any],
) -> ColumnOperators: ...


@overload
def tuple_not_in_condition(
columns: tuple[ColumnElement, ...],
collection: Select,
*,
session: Session,
) -> ColumnOperators: ...


def tuple_not_in_condition(
columns: tuple[ColumnElement, ...],
collection: Iterable[Any] | Select,
*,
session: Session | None = None,
) -> ColumnOperators:
"""
Generate a tuple-not-in-collection operator to use in ``.where()``.
This is similar to ``tuple_in_condition`` except generating ``NOT IN``.
:meta private:
"""
return tuple_(*columns).not_in(collection)


def get_orm_mapper():
"""Get the correct ORM mapper for the installed SQLAlchemy version."""
import sqlalchemy.orm.mapper
10 changes: 2 additions & 8 deletions airflow/www/utils.py
@@ -37,7 +37,7 @@
from markupsafe import Markup
from pygments import highlight, lexers
from pygments.formatters import HtmlFormatter
from sqlalchemy import delete, func, select, types
from sqlalchemy import delete, func, select, tuple_, types
from sqlalchemy.ext.associationproxy import AssociationProxy

from airflow.api_fastapi.app import get_auth_manager
@@ -49,7 +49,6 @@
from airflow.utils.code_utils import get_python_source
from airflow.utils.helpers import alchemy_to_dict
from airflow.utils.json import WebEncoder
from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import State, TaskInstanceState
from airflow.www.forms import DateTimeWithTimezoneField
from airflow.www.widgets import AirflowDateTimePickerWidget
@@ -867,12 +866,7 @@ def delete(self, item: Model, raise_exception: bool = False) -> bool:

def delete_all(self, items: list[Model]) -> bool:
self.session.execute(
delete(TI).where(
tuple_in_condition(
(TI.dag_id, TI.run_id),
((x.dag_id, x.run_id) for x in items),
)
)
delete(TI).where(tuple_(TI.dag_id, TI.run_id).in_((x.dag_id, x.run_id) for x in items))
)
return super().delete_all(items)

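delete_all in the web utils applies the same idea to a bulk DELETE, passing a generator of (dag_id, run_id) pairs straight into .in_(). A minimal sketch, assuming a hypothetical mapped model TIDemo and items exposing dag_id / run_id attributes:

from sqlalchemy import delete, tuple_

def delete_all(session, TIDemo, items):
    # items: objects (e.g. task instance rows) carrying dag_id and run_id.
    session.execute(
        delete(TIDemo).where(
            tuple_(TIDemo.dag_id, TIDemo.run_id).in_((x.dag_id, x.run_id) for x in items)
        )
    )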