Skip to content

Commit

Permalink
Remove Airflow <> 2.0.0 check (#334)
Browse files Browse the repository at this point in the history
This PR removes remaining Airflow version 2.0.0 check
  • Loading branch information
pankajastro authored Jan 4, 2025
1 parent 5588811 commit c53c093
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 61 deletions.
13 changes: 5 additions & 8 deletions dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,7 @@ def get_dag_params(self) -> Dict[str, Any]:
dag_params["default_args"]["sla_miss_callback"]
)

if utils.check_dict_key(dag_params["default_args"], "on_execute_callback") and version.parse(
AIRFLOW_VERSION
) >= version.parse("2.0.0"):
if utils.check_dict_key(dag_params["default_args"], "on_execute_callback"):
if isinstance(dag_params["default_args"]["on_execute_callback"], str):
dag_params["default_args"]["on_execute_callback"] = import_string(
dag_params["default_args"]["on_execute_callback"]
Expand Down Expand Up @@ -488,11 +486,10 @@ def make_task_groups(task_groups: Dict[str, Any], dag: DAG) -> Dict[str, "TaskGr
:param dag: DAG instance that task groups to be added.
"""
task_groups_dict: Dict[str, "TaskGroup"] = {}
if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
for task_group_name, task_group_conf in task_groups.items():
DagBuilder.make_nested_task_groups(
task_group_name, task_group_conf, task_groups_dict, task_groups, None, dag
)
for task_group_name, task_group_conf in task_groups.items():
DagBuilder.make_nested_task_groups(
task_group_name, task_group_conf, task_groups_dict, task_groups, None, dag
)

return task_groups_dict

Expand Down
91 changes: 38 additions & 53 deletions tests/test_dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,38 +687,30 @@ def test_get_dag_params_dag_with_task_group():
"dag_id": "test_dag",
"dagrun_timeout": datetime.timedelta(seconds=600),
}
if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
error_message = "`task_groups` key can only be used with Airflow 2.x.x"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()
else:
assert td.get_dag_params() == expected

assert td.get_dag_params() == expected


def test_build_task_groups():
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_TASK_GROUP, DEFAULT_CONFIG)
if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
error_message = "`task_groups` key can only be used with Airflow 2.x.x"
with pytest.raises(Exception, match=error_message):
td.build()
else:
actual = td.build()
task_group_1 = {t for t in actual["dag"].task_dict if t.startswith("task_group_1")}
task_group_2 = {t for t in actual["dag"].task_dict if t.startswith("task_group_2")}
assert actual["dag_id"] == "test_dag"
assert isinstance(actual["dag"], DAG)
assert len(actual["dag"].tasks) == 6
assert actual["dag"].task_dict["task_1"].downstream_task_ids == {"task_group_1.task_2"}
assert actual["dag"].task_dict["task_group_1.task_2"].downstream_task_ids == {"task_group_1.task_3"}
assert actual["dag"].task_dict["task_group_1.task_3"].downstream_task_ids == {
"task_4",
"task_group_2.task_5",
}
assert actual["dag"].task_dict["task_group_2.task_5"].downstream_task_ids == {
"task_group_2.task_6",
}
assert {"task_group_1.task_2", "task_group_1.task_3"} == task_group_1
assert {"task_group_2.task_5", "task_group_2.task_6"} == task_group_2

actual = td.build()
task_group_1 = {t for t in actual["dag"].task_dict if t.startswith("task_group_1")}
task_group_2 = {t for t in actual["dag"].task_dict if t.startswith("task_group_2")}
assert actual["dag_id"] == "test_dag"
assert isinstance(actual["dag"], DAG)
assert len(actual["dag"].tasks) == 6
assert actual["dag"].task_dict["task_1"].downstream_task_ids == {"task_group_1.task_2"}
assert actual["dag"].task_dict["task_group_1.task_2"].downstream_task_ids == {"task_group_1.task_3"}
assert actual["dag"].task_dict["task_group_1.task_3"].downstream_task_ids == {
"task_4",
"task_group_2.task_5",
}
assert actual["dag"].task_dict["task_group_2.task_5"].downstream_task_ids == {
"task_group_2.task_6",
}
assert {"task_group_1.task_2", "task_group_1.task_3"} == task_group_1
assert {"task_group_2.task_5", "task_group_2.task_6"} == task_group_2


def test_build_task_groups_with_callbacks():
Expand Down Expand Up @@ -755,10 +747,8 @@ def test_make_task_groups():
dag = "dag"
task_groups = dagbuilder.DagBuilder.make_task_groups(task_group_dict, dag)
expected = MockTaskGroup(tooltip="this is a task group", group_id="task_group", dag=dag)
if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
assert task_groups == {}
else:
assert task_groups["task_group"].__dict__ == expected.__dict__

assert task_groups["task_group"].__dict__ == expected.__dict__


def test_make_task_groups_empty():
Expand Down Expand Up @@ -794,8 +784,7 @@ def test_make_task_with_callback():
assert isinstance(actual, PythonOperator)
assert callable(actual.on_failure_callback)
assert callable(actual.on_success_callback)
if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
assert callable(actual.on_execute_callback)
assert callable(actual.on_execute_callback)
assert callable(actual.on_retry_callback)


Expand Down Expand Up @@ -842,17 +831,16 @@ def test_dag_with_callback_name_and_file_default_args():


def test_make_timetable():
if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG)
timetable = "airflow.timetables.interval.CronDataIntervalTimetable"
timetable_params = {"cron": "0 8,16 * * 1-5", "timezone": "UTC"}
actual = td.make_timetable(timetable, timetable_params)
assert actual.periodic
try:
assert actual.can_run
except AttributeError:
# can_run attribute was removed and replaced with can_be_scheduled in later versions of Airflow.
assert actual.can_be_scheduled
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG)
timetable = "airflow.timetables.interval.CronDataIntervalTimetable"
timetable_params = {"cron": "0 8,16 * * 1-5", "timezone": "UTC"}
actual = td.make_timetable(timetable, timetable_params)
assert actual.periodic
try:
assert actual.can_run
except AttributeError:
# can_run attribute was removed and replaced with can_be_scheduled in later versions of Airflow.
assert actual.can_be_scheduled


def test_make_dag_with_callback():
Expand Down Expand Up @@ -1059,14 +1047,11 @@ def test_make_nested_task_groups():
"sub_task_group": MockTaskGroup(tooltip="this is a sub task group", group_id="sub_task_group", dag=dag),
}

if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
assert task_groups == {}
else:
sub_task_group = task_groups["sub_task_group"].__dict__
assert sub_task_group["parent_group"]
del sub_task_group["parent_group"]
assert task_groups["task_group"].__dict__ == expected["task_group"].__dict__
assert sub_task_group == expected["sub_task_group"].__dict__
sub_task_group = task_groups["sub_task_group"].__dict__
assert sub_task_group["parent_group"]
del sub_task_group["parent_group"]
assert task_groups["task_group"].__dict__ == expected["task_group"].__dict__
assert sub_task_group == expected["sub_task_group"].__dict__


class TestTopologicalSortTasks:
Expand Down

0 comments on commit c53c093

Please sign in to comment.