Remove Airflow <> 2.0.0 check #334

Merged
merged 1 commit on Jan 4, 2025
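This PR removes the runtime guards that compared the installed Airflow version against 2.0.0 before enabling Airflow 2-only behaviour (task groups, the on_execute_callback default arg). For context, here is a minimal sketch of the pattern being dropped, assuming `version` is `packaging.version` as used by dagbuilder.py; the AIRFLOW_VERSION value and the feature function are hypothetical placeholders for illustration:

from packaging import version

AIRFLOW_VERSION = "2.7.3"  # hypothetical installed version string, for illustration only


def build_airflow2_only_features() -> None:
    """Hypothetical stand-in for the code that used to be gated (task groups, callbacks)."""


# Old pattern: Airflow 2-only code paths ran only after a runtime version check.
if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
    build_airflow2_only_features()

# New pattern: dag-factory assumes Airflow >= 2.0.0, so the same code runs unconditionally.
build_airflow2_only_features()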
13 changes: 5 additions & 8 deletions dagfactory/dagbuilder.py
@@ -171,9 +171,7 @@ def get_dag_params(self) -> Dict[str, Any]:
                     dag_params["default_args"]["sla_miss_callback"]
                 )
 
-        if utils.check_dict_key(dag_params["default_args"], "on_execute_callback") and version.parse(
-            AIRFLOW_VERSION
-        ) >= version.parse("2.0.0"):
+        if utils.check_dict_key(dag_params["default_args"], "on_execute_callback"):
             if isinstance(dag_params["default_args"]["on_execute_callback"], str):
                 dag_params["default_args"]["on_execute_callback"] = import_string(
                     dag_params["default_args"]["on_execute_callback"]
@@ -488,11 +486,10 @@ def make_task_groups(task_groups: Dict[str, Any], dag: DAG) -> Dict[str, "TaskGroup"]:
         :param dag: DAG instance that task groups to be added.
         """
         task_groups_dict: Dict[str, "TaskGroup"] = {}
-        if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
-            for task_group_name, task_group_conf in task_groups.items():
-                DagBuilder.make_nested_task_groups(
-                    task_group_name, task_group_conf, task_groups_dict, task_groups, None, dag
-                )
+        for task_group_name, task_group_conf in task_groups.items():
+            DagBuilder.make_nested_task_groups(
+                task_group_name, task_group_conf, task_groups_dict, task_groups, None, dag
+            )
 
         return task_groups_dict
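With the version gate removed, make_task_groups always builds and nests TaskGroup objects via make_nested_task_groups. A minimal usage sketch, assuming Airflow 2.x and dag-factory are installed; the config shape mirrors the tooltip-based fixtures in the tests below, and the DAG details are illustrative:

import datetime

from airflow import DAG
from dagfactory import dagbuilder

dag = DAG(dag_id="example_dag", start_date=datetime.datetime(2025, 1, 1))

# One entry per task group name, as in the test fixtures.
task_groups_conf = {"task_group_1": {"tooltip": "this is a task group"}}

# No Airflow version check any more: TaskGroup objects are always created,
# keyed by group name, with nesting resolved by make_nested_task_groups.
groups = dagbuilder.DagBuilder.make_task_groups(task_groups_conf, dag)
assert set(groups) == {"task_group_1"}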

91 changes: 38 additions & 53 deletions tests/test_dagbuilder.py
@@ -687,38 +687,30 @@ def test_get_dag_params_dag_with_task_group():
         "dag_id": "test_dag",
         "dagrun_timeout": datetime.timedelta(seconds=600),
     }
-    if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
-        error_message = "`task_groups` key can only be used with Airflow 2.x.x"
-        with pytest.raises(Exception, match=error_message):
-            td.get_dag_params()
-    else:
-        assert td.get_dag_params() == expected
+
+    assert td.get_dag_params() == expected
 
 
 def test_build_task_groups():
     td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_TASK_GROUP, DEFAULT_CONFIG)
-    if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
-        error_message = "`task_groups` key can only be used with Airflow 2.x.x"
-        with pytest.raises(Exception, match=error_message):
-            td.build()
-    else:
-        actual = td.build()
-        task_group_1 = {t for t in actual["dag"].task_dict if t.startswith("task_group_1")}
-        task_group_2 = {t for t in actual["dag"].task_dict if t.startswith("task_group_2")}
-        assert actual["dag_id"] == "test_dag"
-        assert isinstance(actual["dag"], DAG)
-        assert len(actual["dag"].tasks) == 6
-        assert actual["dag"].task_dict["task_1"].downstream_task_ids == {"task_group_1.task_2"}
-        assert actual["dag"].task_dict["task_group_1.task_2"].downstream_task_ids == {"task_group_1.task_3"}
-        assert actual["dag"].task_dict["task_group_1.task_3"].downstream_task_ids == {
-            "task_4",
-            "task_group_2.task_5",
-        }
-        assert actual["dag"].task_dict["task_group_2.task_5"].downstream_task_ids == {
-            "task_group_2.task_6",
-        }
-        assert {"task_group_1.task_2", "task_group_1.task_3"} == task_group_1
-        assert {"task_group_2.task_5", "task_group_2.task_6"} == task_group_2
+
+    actual = td.build()
+    task_group_1 = {t for t in actual["dag"].task_dict if t.startswith("task_group_1")}
+    task_group_2 = {t for t in actual["dag"].task_dict if t.startswith("task_group_2")}
+    assert actual["dag_id"] == "test_dag"
+    assert isinstance(actual["dag"], DAG)
+    assert len(actual["dag"].tasks) == 6
+    assert actual["dag"].task_dict["task_1"].downstream_task_ids == {"task_group_1.task_2"}
+    assert actual["dag"].task_dict["task_group_1.task_2"].downstream_task_ids == {"task_group_1.task_3"}
+    assert actual["dag"].task_dict["task_group_1.task_3"].downstream_task_ids == {
+        "task_4",
+        "task_group_2.task_5",
+    }
+    assert actual["dag"].task_dict["task_group_2.task_5"].downstream_task_ids == {
+        "task_group_2.task_6",
+    }
+    assert {"task_group_1.task_2", "task_group_1.task_3"} == task_group_1
+    assert {"task_group_2.task_5", "task_group_2.task_6"} == task_group_2
 
 
 def test_build_task_groups_with_callbacks():
@@ -755,10 +747,8 @@ def test_make_task_groups():
     dag = "dag"
     task_groups = dagbuilder.DagBuilder.make_task_groups(task_group_dict, dag)
     expected = MockTaskGroup(tooltip="this is a task group", group_id="task_group", dag=dag)
-    if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
-        assert task_groups == {}
-    else:
-        assert task_groups["task_group"].__dict__ == expected.__dict__
+
+    assert task_groups["task_group"].__dict__ == expected.__dict__
 
 
 def test_make_task_groups_empty():
@@ -794,8 +784,7 @@ def test_make_task_with_callback():
     assert isinstance(actual, PythonOperator)
     assert callable(actual.on_failure_callback)
     assert callable(actual.on_success_callback)
-    if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
-        assert callable(actual.on_execute_callback)
+    assert callable(actual.on_execute_callback)
     assert callable(actual.on_retry_callback)
 
 
@@ -842,17 +831,16 @@ def test_dag_with_callback_name_and_file_default_args():
 
 
 def test_make_timetable():
-    if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
-        td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG)
-        timetable = "airflow.timetables.interval.CronDataIntervalTimetable"
-        timetable_params = {"cron": "0 8,16 * * 1-5", "timezone": "UTC"}
-        actual = td.make_timetable(timetable, timetable_params)
-        assert actual.periodic
-        try:
-            assert actual.can_run
-        except AttributeError:
-            # can_run attribute was removed and replaced with can_be_scheduled in later versions of Airflow.
-            assert actual.can_be_scheduled
+    td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG)
+    timetable = "airflow.timetables.interval.CronDataIntervalTimetable"
+    timetable_params = {"cron": "0 8,16 * * 1-5", "timezone": "UTC"}
+    actual = td.make_timetable(timetable, timetable_params)
+    assert actual.periodic
+    try:
+        assert actual.can_run
+    except AttributeError:
+        # can_run attribute was removed and replaced with can_be_scheduled in later versions of Airflow.
+        assert actual.can_be_scheduled
 
 
 def test_make_dag_with_callback():
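For reference on what the now-unconditional test above constructs: make_timetable imports the dotted class path and instantiates it with the given params, which should be roughly equivalent to building the Airflow timetable directly. A sketch assuming Airflow 2.x; the keyword arguments follow the cron/timezone params used in the test, and passing a pendulum timezone object is an assumption about the constructor:

import pendulum
from airflow.timetables.interval import CronDataIntervalTimetable

# Roughly what make_timetable("airflow.timetables.interval.CronDataIntervalTimetable",
# {"cron": "0 8,16 * * 1-5", "timezone": "UTC"}) is expected to return.
timetable = CronDataIntervalTimetable(cron="0 8,16 * * 1-5", timezone=pendulum.timezone("UTC"))

assert timetable.periodic  # cron-based timetables schedule runs periodically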
@@ -1059,14 +1047,11 @@ def test_make_nested_task_groups():
         "sub_task_group": MockTaskGroup(tooltip="this is a sub task group", group_id="sub_task_group", dag=dag),
     }
 
-    if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
-        assert task_groups == {}
-    else:
-        sub_task_group = task_groups["sub_task_group"].__dict__
-        assert sub_task_group["parent_group"]
-        del sub_task_group["parent_group"]
-        assert task_groups["task_group"].__dict__ == expected["task_group"].__dict__
-        assert sub_task_group == expected["sub_task_group"].__dict__
+    sub_task_group = task_groups["sub_task_group"].__dict__
+    assert sub_task_group["parent_group"]
+    del sub_task_group["parent_group"]
+    assert task_groups["task_group"].__dict__ == expected["task_group"].__dict__
+    assert sub_task_group == expected["sub_task_group"].__dict__
 
 
 class TestTopologicalSortTasks: