Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use namedtuple for TaskInstanceKeyType #9712

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,7 @@ def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKeyType, EventBuffe
self.event_buffer = {}
else:
for key in list(self.event_buffer.keys()):
dag_id, _, _, _ = key
if dag_id in dag_ids:
if key.dag_id in dag_ids:
cleared_events[key] = self.event_buffer.pop(key)

return cleared_events
Expand Down
4 changes: 2 additions & 2 deletions airflow/executors/kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def _labels_to_key(self, labels: Dict[str, str]) -> Optional[TaskInstanceKeyType
'Found matching task %s-%s (%s) with current state of %s',
task.dag_id, task.task_id, task.execution_date, task.state
)
return (dag_id, task_id, ex_time, try_num)
return TaskInstanceKeyType(dag_id, task_id, ex_time, try_num)
else:
self.log.warning(
'task_id/dag_id are not safe to use as Kubernetes labels. This can cause '
Expand Down Expand Up @@ -649,7 +649,7 @@ def _labels_to_key(self, labels: Dict[str, str]) -> Optional[TaskInstanceKeyType
)
dag_id = task.dag_id
task_id = task.task_id
return dag_id, task_id, ex_time, try_num
return TaskInstanceKeyType(dag_id, task_id, ex_time, try_num)
self.log.warning(
'Failed to find and match task details to a pod; labels: %s',
labels
Expand Down
4 changes: 3 additions & 1 deletion airflow/jobs/backfill_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,9 @@ def _collect_errors(self, ti_status, session=None):
def tabulate_ti_keys_set(set_ti_keys: Set[TaskInstanceKeyType]) -> str:
# Sorting by execution date first
sorted_ti_keys = sorted(
set_ti_keys, key=lambda ti_key: (ti_key[2], ti_key[0], ti_key[1], ti_key[3]))
set_ti_keys, key=lambda ti_key:
(ti_key.execution_date, ti_key.dag_id, ti_key.task_id, ti_key.try_number)
)
return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Execution date", "Try number"])

def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str:
Expand Down
21 changes: 10 additions & 11 deletions airflow/jobs/scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,8 @@ def _schedule_task_instances(

for ti in refreshed_tis:
# Add task to task instance
dag = dagbag.dags[ti.key[0]]
ti.task = dag.get_task(ti.key[1])
dag = dagbag.dags[ti.dag_id]
ti.task = dag.get_task(ti.task_id)

# We check only deps needed to set TI to SCHEDULED state here.
# Deps needed to set TI to QUEUED state will be batch checked later
Expand Down Expand Up @@ -1493,13 +1493,13 @@ def _process_executor_events(self, simple_dag_bag, session=None):
# Report execution
for key, value in event_buffer.items():
state, info = value
dag_id, task_id, execution_date, try_number = key
# We create map (dag_id, task_id, execution_date) -> in-memory try_number
ti_primary_key_to_try_number_map[key[:-1]] = try_number
ti_primary_key_to_try_number_map[key.short] = key.try_number

self.log.info(
"Executor reports execution of %s.%s execution_date=%s "
"exited with status %s for try_number %s",
dag_id, task_id, execution_date, state, try_number
key.dag_id, key.task_id, key.execution_date, state, key.try_number
)
if state in (State.FAILED, State.SUCCESS):
tis_with_right_state.append(key)
Expand All @@ -1512,21 +1512,20 @@ def _process_executor_events(self, simple_dag_bag, session=None):
filter_for_tis = TI.filter_for_tis(tis_with_right_state)
tis = session.query(TI).filter(filter_for_tis).all()
for ti in tis:
# Recreate ti_key (dag_id, task_id, execution_date, try_number) using in-memory try_number
dag_id, task_id, execution_date, _ = ti.key
try_number = ti_primary_key_to_try_number_map[(dag_id, task_id, execution_date)]
buffer_key = (dag_id, task_id, execution_date, try_number)
key = ti.key
try_number = ti_primary_key_to_try_number_map[key.short]
buffer_key = TaskInstanceKeyType(key.dag_id, key.task_id, key.execution_date, try_number)
state, info = event_buffer.pop(buffer_key)

# TODO: should we fail RUNNING as well, as we do in Backfills?
if ti.try_number == try_number and ti.state == State.QUEUED:
if ti.try_number == buffer_key.try_number and ti.state == State.QUEUED:
Stats.incr('scheduler.tasks.killed_externally')
self.log.error(
"Executor reports task instance %s finished (%s) although the task says its %s. "
"(Info: %s) Was the task killed externally?",
ti, state, ti.state, info
)
simple_dag = simple_dag_bag.get_dag(dag_id)
simple_dag = simple_dag_bag.get_dag(ti.dag_id)
self.processor_agent.send_callback_to_execute(
full_filepath=simple_dag.full_filepath,
task_instance=ti,
Expand Down
31 changes: 20 additions & 11 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import time
import warnings
from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
from urllib.parse import quote

import dill
Expand Down Expand Up @@ -126,9 +126,18 @@ def clear_task_instances(tis,
dr.start_date = timezone.utcnow()


# Key used to identify task instance
# Tuple of: dag_id, task_id, execution_date, try_number
TaskInstanceKeyType = Tuple[str, str, datetime, int]
class TaskInstanceKeyType(NamedTuple):
    """
    Key used to identify a task instance uniquely.

    Fields (in order): dag_id, task_id, execution_date, try_number.
    As a NamedTuple it remains tuple-compatible with the previous
    ``Tuple[str, str, datetime, int]`` alias while allowing named access.
    """
    dag_id: str
    task_id: str
    execution_date: datetime
    try_number: int

    @property
    def short(self) -> Tuple[str, str, datetime]:
        """Primary-key portion of the key, without try_number:
        ``(dag_id, task_id, execution_date)``."""
        return self.dag_id, self.task_id, self.execution_date


class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
Expand Down Expand Up @@ -541,7 +550,7 @@ def key(self) -> TaskInstanceKeyType:
"""
Returns a tuple that identifies the task instance uniquely
"""
return self.dag_id, self.task_id, self.execution_date, self.try_number
return TaskInstanceKeyType(self.dag_id, self.task_id, self.execution_date, self.try_number)

@provide_session
def set_state(self, state, session=None, commit=True):
Expand Down Expand Up @@ -1711,11 +1720,11 @@ def filter_for_tis(
TI = TaskInstance
if not tis:
return None
if all(isinstance(t, tuple) for t in tis):
filter_for_tis = ([and_(TI.dag_id == dag_id,
TI.task_id == task_id,
TI.execution_date == execution_date)
for dag_id, task_id, execution_date, _ in tis])
if all(isinstance(t, TaskInstanceKeyType) for t in tis):
filter_for_tis = ([and_(TI.dag_id == ti.dag_id,
TI.task_id == ti.task_id,
TI.execution_date == ti.execution_date)
for ti in tis])
return or_(*filter_for_tis)
if all(isinstance(t, TaskInstance) for t in tis):
filter_for_tis = ([and_(TI.dag_id == ti.dag_id, # type: ignore
Expand All @@ -1724,7 +1733,7 @@ def filter_for_tis(
for ti in tis])
return or_(*filter_for_tis)

raise TypeError("All elements must have the same type: `TaskInstance` or `TaskInstanceKey`.")
raise TypeError("All elements must have the same type: `TaskInstance` or `TaskInstanceKeyType`.")


# State of the task instance.
Expand Down