diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index dd5143d7..e336f257 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -171,9 +171,7 @@ def get_dag_params(self) -> Dict[str, Any]: dag_params["default_args"]["sla_miss_callback"] ) - if utils.check_dict_key(dag_params["default_args"], "on_execute_callback") and version.parse( - AIRFLOW_VERSION - ) >= version.parse("2.0.0"): + if utils.check_dict_key(dag_params["default_args"], "on_execute_callback"): if isinstance(dag_params["default_args"]["on_execute_callback"], str): dag_params["default_args"]["on_execute_callback"] = import_string( dag_params["default_args"]["on_execute_callback"] @@ -488,11 +486,10 @@ def make_task_groups(task_groups: Dict[str, Any], dag: DAG) -> Dict[str, "TaskGr :param dag: DAG instance that task groups to be added. """ task_groups_dict: Dict[str, "TaskGroup"] = {} - if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"): - for task_group_name, task_group_conf in task_groups.items(): - DagBuilder.make_nested_task_groups( - task_group_name, task_group_conf, task_groups_dict, task_groups, None, dag - ) + for task_group_name, task_group_conf in task_groups.items(): + DagBuilder.make_nested_task_groups( + task_group_name, task_group_conf, task_groups_dict, task_groups, None, dag + ) return task_groups_dict diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py index 90cd97b7..6bf361ea 100644 --- a/tests/test_dagbuilder.py +++ b/tests/test_dagbuilder.py @@ -687,38 +687,30 @@ def test_get_dag_params_dag_with_task_group(): "dag_id": "test_dag", "dagrun_timeout": datetime.timedelta(seconds=600), } - if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"): - error_message = "`task_groups` key can only be used with Airflow 2.x.x" - with pytest.raises(Exception, match=error_message): - td.get_dag_params() - else: - assert td.get_dag_params() == expected + + assert td.get_dag_params() == expected def test_build_task_groups(): td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_TASK_GROUP, DEFAULT_CONFIG) - if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"): - error_message = "`task_groups` key can only be used with Airflow 2.x.x" - with pytest.raises(Exception, match=error_message): - td.build() - else: - actual = td.build() - task_group_1 = {t for t in actual["dag"].task_dict if t.startswith("task_group_1")} - task_group_2 = {t for t in actual["dag"].task_dict if t.startswith("task_group_2")} - assert actual["dag_id"] == "test_dag" - assert isinstance(actual["dag"], DAG) - assert len(actual["dag"].tasks) == 6 - assert actual["dag"].task_dict["task_1"].downstream_task_ids == {"task_group_1.task_2"} - assert actual["dag"].task_dict["task_group_1.task_2"].downstream_task_ids == {"task_group_1.task_3"} - assert actual["dag"].task_dict["task_group_1.task_3"].downstream_task_ids == { - "task_4", - "task_group_2.task_5", - } - assert actual["dag"].task_dict["task_group_2.task_5"].downstream_task_ids == { - "task_group_2.task_6", - } - assert {"task_group_1.task_2", "task_group_1.task_3"} == task_group_1 - assert {"task_group_2.task_5", "task_group_2.task_6"} == task_group_2 + + actual = td.build() + task_group_1 = {t for t in actual["dag"].task_dict if t.startswith("task_group_1")} + task_group_2 = {t for t in actual["dag"].task_dict if t.startswith("task_group_2")} + assert actual["dag_id"] == "test_dag" + assert isinstance(actual["dag"], DAG) + assert len(actual["dag"].tasks) == 6 + assert actual["dag"].task_dict["task_1"].downstream_task_ids == {"task_group_1.task_2"} + assert actual["dag"].task_dict["task_group_1.task_2"].downstream_task_ids == {"task_group_1.task_3"} + assert actual["dag"].task_dict["task_group_1.task_3"].downstream_task_ids == { + "task_4", + "task_group_2.task_5", + } + assert actual["dag"].task_dict["task_group_2.task_5"].downstream_task_ids == { + "task_group_2.task_6", + } + assert {"task_group_1.task_2", "task_group_1.task_3"} == task_group_1 + assert {"task_group_2.task_5", "task_group_2.task_6"} == task_group_2 def test_build_task_groups_with_callbacks(): @@ -755,10 +747,8 @@ def test_make_task_groups(): dag = "dag" task_groups = dagbuilder.DagBuilder.make_task_groups(task_group_dict, dag) expected = MockTaskGroup(tooltip="this is a task group", group_id="task_group", dag=dag) - if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"): - assert task_groups == {} - else: - assert task_groups["task_group"].__dict__ == expected.__dict__ + + assert task_groups["task_group"].__dict__ == expected.__dict__ def test_make_task_groups_empty(): @@ -794,8 +784,7 @@ def test_make_task_with_callback(): assert isinstance(actual, PythonOperator) assert callable(actual.on_failure_callback) assert callable(actual.on_success_callback) - if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"): - assert callable(actual.on_execute_callback) + assert callable(actual.on_execute_callback) assert callable(actual.on_retry_callback) @@ -842,17 +831,16 @@ def test_dag_with_callback_name_and_file_default_args(): def test_make_timetable(): - if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"): - td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG) - timetable = "airflow.timetables.interval.CronDataIntervalTimetable" - timetable_params = {"cron": "0 8,16 * * 1-5", "timezone": "UTC"} - actual = td.make_timetable(timetable, timetable_params) - assert actual.periodic - try: - assert actual.can_run - except AttributeError: - # can_run attribute was removed and replaced with can_be_scheduled in later versions of Airflow. - assert actual.can_be_scheduled + td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG) + timetable = "airflow.timetables.interval.CronDataIntervalTimetable" + timetable_params = {"cron": "0 8,16 * * 1-5", "timezone": "UTC"} + actual = td.make_timetable(timetable, timetable_params) + assert actual.periodic + try: + assert actual.can_run + except AttributeError: + # can_run attribute was removed and replaced with can_be_scheduled in later versions of Airflow. + assert actual.can_be_scheduled def test_make_dag_with_callback(): @@ -1059,14 +1047,11 @@ def test_make_nested_task_groups(): "sub_task_group": MockTaskGroup(tooltip="this is a sub task group", group_id="sub_task_group", dag=dag), } - if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"): - assert task_groups == {} - else: - sub_task_group = task_groups["sub_task_group"].__dict__ - assert sub_task_group["parent_group"] - del sub_task_group["parent_group"] - assert task_groups["task_group"].__dict__ == expected["task_group"].__dict__ - assert sub_task_group == expected["sub_task_group"].__dict__ + sub_task_group = task_groups["sub_task_group"].__dict__ + assert sub_task_group["parent_group"] + del sub_task_group["parent_group"] + assert task_groups["task_group"].__dict__ == expected["task_group"].__dict__ + assert sub_task_group == expected["sub_task_group"].__dict__ class TestTopologicalSortTasks: