From 50318f8519d7dd2f8215d7e977347739ecd3e408 Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Fri, 19 Jun 2020 10:24:12 +0200 Subject: [PATCH] Use current_app.dag_bag instead of global variable (#9380) * Use current_app.dag_bag instead of global variable * fixup! Use current_app.dag_bag instead of global variable * fixup! fixup! Use current_app.dag_bag instead of global variable --- airflow/www/app.py | 4 ++ airflow/www/extensions/init_dagbag.py | 32 ++++++++++ airflow/www/views.py | 55 ++++++++--------- tests/www/test_views.py | 86 ++++++++------------------- 4 files changed, 86 insertions(+), 91 deletions(-) create mode 100644 airflow/www/extensions/init_dagbag.py diff --git a/airflow/www/app.py b/airflow/www/app.py index e70ab5c0f99f9..dde28175248c1 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. # + from datetime import timedelta from typing import Optional @@ -30,6 +31,7 @@ from airflow.utils.json import AirflowJsonEncoder from airflow.www.extensions.init_appbuilder import init_appbuilder from airflow.www.extensions.init_appbuilder_links import init_appbuilder_links +from airflow.www.extensions.init_dagbag import init_dagbag from airflow.www.extensions.init_jinja_globals import init_jinja_globals from airflow.www.extensions.init_manifest_files import configure_manifest_files from airflow.www.extensions.init_security import init_api_experimental_auth, init_xframe_protection @@ -85,6 +87,8 @@ def create_app(config=None, testing=False, app_name="Airflow"): db.session = settings.Session db.init_app(flask_app) + init_dagbag(flask_app) + init_api_experimental_auth(flask_app) Cache(app=flask_app, config={'CACHE_TYPE': 'filesystem', 'CACHE_DIR': '/tmp'}) diff --git a/airflow/www/extensions/init_dagbag.py b/airflow/www/extensions/init_dagbag.py new file mode 100644 index 0000000000000..89003a78fe4d7 --- /dev/null +++ b/airflow/www/extensions/init_dagbag.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +from airflow.models import DagBag +from airflow.settings import DAGS_FOLDER, STORE_SERIALIZED_DAGS + + +def init_dagbag(app): + """ + Create global DagBag for webserver and API. To access it use + ``flask.current_app.dag_bag``. + """ + if os.environ.get('SKIP_DAGS_PARSING') == 'True': + app.dag_bag = DagBag(os.devnull, include_examples=False) + else: + app.dag_bag = DagBag(DAGS_FOLDER, store_serialized_dags=STORE_SERIALIZED_DAGS) diff --git a/airflow/www/views.py b/airflow/www/views.py index 776aac175a613..7540e9e508814 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -22,7 +22,6 @@ import json import logging import math -import os import socket import traceback from collections import defaultdict @@ -61,7 +60,6 @@ from airflow.models import Connection, DagModel, DagTag, Log, SlaMiss, TaskFail, XCom, errors from airflow.models.dagcode import DagCode from airflow.models.dagrun import DagRun, DagRunType -from airflow.settings import STORE_SERIALIZED_DAGS from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS from airflow.utils import timezone @@ -81,11 +79,6 @@ FILTER_TAGS_COOKIE = 'tags_filter' FILTER_STATUS_COOKIE = 'dag_status_filter' -if os.environ.get('SKIP_DAGS_PARSING') != 'True': - dagbag = models.DagBag(settings.DAGS_FOLDER, store_serialized_dags=STORE_SERIALIZED_DAGS) -else: - dagbag = models.DagBag(os.devnull, include_examples=False) - def get_date_time_num_runs_dag_runs_form_data(request, session, dag): dttm = request.args.get('execution_date') @@ -577,7 +570,7 @@ def code(self, session=None): @provide_session def dag_details(self, session=None): dag_id = request.args.get('dag_id') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) title = "DAG details" root = request.args.get('root', '') @@ -612,7 +605,7 @@ def rendered(self): root = request.args.get('root', '') logging.info("Retrieving rendered templates.") - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) task = copy.copy(dag.get_task(task_id)) ti = models.TaskInstance(task=task, execution_date=dttm) @@ -700,7 +693,7 @@ def _get_logs_with_metadata(try_number, metadata): try: if ti is not None: - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) if dag: ti.task = dag.get_task(ti.task_id) if response_format == 'json': @@ -803,7 +796,7 @@ def task(self): dttm = timezone.parse(execution_date) form = DateTimeForm(data={'execution_date': dttm}) root = request.args.get('root', '') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) if not dag or task_id not in dag.task_ids: flash( @@ -921,7 +914,7 @@ def run(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') origin = request.form.get('origin') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) task = dag.get_task(task_id) execution_date = request.form.get('execution_date') @@ -1052,7 +1045,7 @@ def trigger(self, session=None): conf=conf ) - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) dag.create_dagrun( run_type=DagRunType.MANUAL, execution_date=execution_date, @@ -1117,7 +1110,7 @@ def clear(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') origin = request.form.get('origin') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) execution_date = request.form.get('execution_date') execution_date = timezone.parse(execution_date) @@ -1150,7 +1143,7 @@ def dagrun_clear(self): execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) execution_date = timezone.parse(execution_date) start_date = execution_date end_date = execution_date @@ -1192,7 +1185,7 @@ def blocked(self, session=None): payload = [] for dag_id, active_dag_runs in dags: max_active_runs = 0 - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) if dag: # TODO: Make max_active_runs a column so we can query for it directly max_active_runs = dag.max_active_runs @@ -1209,7 +1202,7 @@ def _mark_dagrun_state_as_failed(self, dag_id, execution_date, confirmed, origin return redirect(origin) execution_date = timezone.parse(execution_date) - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) if not dag: flash('Cannot find DAG: {}'.format(dag_id), 'error') @@ -1237,7 +1230,7 @@ def _mark_dagrun_state_as_success(self, dag_id, execution_date, confirmed, origi return redirect(origin) execution_date = timezone.parse(execution_date) - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) if not dag: flash('Cannot find DAG: {}'.format(dag_id), 'error') @@ -1287,7 +1280,7 @@ def dagrun_success(self): def _mark_task_instance_state(self, dag_id, task_id, origin, execution_date, confirmed, upstream, downstream, future, past, state): - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) task = dag.get_task(task_id) task.dag = dag @@ -1371,7 +1364,7 @@ def success(self): def tree(self): dag_id = request.args.get('dag_id') blur = conf.getboolean('webserver', 'demo_mode') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) if not dag: flash('DAG "{0}" seems to be missing from DagBag.'.format(dag_id), "error") return redirect(url_for('Airflow.index')) @@ -1528,7 +1521,7 @@ def recurse_nodes(task, visited): def graph(self, session=None): dag_id = request.args.get('dag_id') blur = conf.getboolean('webserver', 'demo_mode') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) if not dag: flash('DAG "{0}" seems to be missing.'.format(dag_id), "error") return redirect(url_for('Airflow.index')) @@ -1627,7 +1620,7 @@ class GraphForm(DateTimeWithNumRunsWithDagRunsForm): def duration(self, session=None): default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') dag_id = request.args.get('dag_id') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) base_date = request.args.get('base_date') num_runs = request.args.get('num_runs') num_runs = int(num_runs) if num_runs else default_dag_run @@ -1739,7 +1732,7 @@ def duration(self, session=None): def tries(self, session=None): default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') dag_id = request.args.get('dag_id') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) base_date = request.args.get('base_date') num_runs = request.args.get('num_runs') num_runs = int(num_runs) if num_runs else default_dag_run @@ -1804,7 +1797,7 @@ def tries(self, session=None): def landing_times(self, session=None): default_dag_run = conf.getint('webserver', 'default_dag_run_display_number') dag_id = request.args.get('dag_id') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) base_date = request.args.get('base_date') num_runs = request.args.get('num_runs') num_runs = int(num_runs) if num_runs else default_dag_run @@ -1902,7 +1895,7 @@ def refresh(self, session=None): session.merge(orm_dag) session.commit() - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) # sync dag permission current_app.appbuilder.sm.sync_perm_for_dag(dag_id, dag.access_control) @@ -1916,7 +1909,7 @@ def refresh(self, session=None): @provide_session def gantt(self, session=None): dag_id = request.args.get('dag_id') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) demo_mode = conf.getboolean('webserver', 'demo_mode') root = request.args.get('root') @@ -2029,7 +2022,7 @@ def extra_links(self): execution_date = request.args.get('execution_date') link_name = request.args.get('link_name') dttm = timezone.parse(execution_date) - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) if not dag or task_id not in dag.task_ids: response = jsonify( @@ -2066,7 +2059,7 @@ def extra_links(self): @action_logging def task_instances(self): dag_id = request.args.get('dag_id') - dag = dagbag.get_dag(dag_id) + dag = current_app.dag_bag.get_dag(dag_id) dttm = request.args.get('execution_date') if dttm: @@ -2544,7 +2537,7 @@ def action_set_failed(self, drs, session=None): dirty_ids.append(dr.dag_id) count += 1 altered_tis += \ - set_dag_run_state_to_failed(dagbag.get_dag(dr.dag_id), + set_dag_run_state_to_failed(current_app.dag_bag.get_dag(dr.dag_id), dr.execution_date, commit=True, session=session) @@ -2571,7 +2564,7 @@ def action_set_success(self, drs, session=None): dirty_ids.append(dr.dag_id) count += 1 altered_tis += \ - set_dag_run_state_to_success(dagbag.get_dag(dr.dag_id), + set_dag_run_state_to_success(current_app.dag_bag.get_dag(dr.dag_id), dr.execution_date, commit=True, session=session) @@ -2665,7 +2658,7 @@ def action_clear(self, tis, session=None): dag_to_tis = {} for ti in tis: - dag = dagbag.get_dag(ti.dag_id) + dag = current_app.dag_bag.get_dag(ti.dag_id) tis = dag_to_tis.setdefault(dag, []) tis.append(ti) diff --git a/tests/www/test_views.py b/tests/www/test_views.py index dcb277c79f1f9..d6fb16f38877e 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -395,6 +395,7 @@ class TestAirflowBaseViews(TestBase): def setUpClass(cls): super().setUpClass() cls.dagbag = models.DagBag(include_examples=True) + cls.app.dag_bag = cls.dagbag DAG.bulk_sync_to_db(cls.dagbag.dags.values()) def setUp(self): @@ -614,16 +615,13 @@ def test_dag_details(self): self.check_content_in_response('DAG details', resp) @parameterized.expand(["graph", "tree", "dag_details"]) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_view_uses_existing_dagbag(self, endpoint, mock_get_dag): + def test_view_uses_existing_dagbag(self, endpoint): """ Test that Graph, Tree & Dag Details View uses the DagBag already created in views.py instead of creating a new one. """ - mock_get_dag.return_value = DAG(dag_id='example_bash_operator') url = f'{endpoint}?dag_id=example_bash_operator' resp = self.client.get(url, follow_redirects=True) - mock_get_dag.assert_called_once_with('example_bash_operator') self.check_content_in_response('example_bash_operator', resp) @parameterized.expand([ @@ -1052,7 +1050,7 @@ def setUp(self): settings.configure_orm() self.login() - from airflow.www.views import dagbag + dagbag = self.app.dag_bag dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE) dag.sync_to_db() dag_removed = DAG(self.DAG_ID_REMOVED, start_date=self.DEFAULT_DATE) @@ -1252,7 +1250,7 @@ def __init__(self, test, endpoint): self.runs = [] def setup(self): - from airflow.www.views import dagbag + dagbag = self.test.app.dag_bag dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE) dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag) for run_data in self.RUNS_DATA: @@ -2165,7 +2163,7 @@ def test_start_date_filter(self): class TestRenderedView(TestBase): def setUp(self): - super().setUp() + self.default_date = datetime(2020, 3, 1) self.dag = DAG( "testdag", @@ -2187,20 +2185,18 @@ def setUp(self): with create_session() as session: session.query(RTIF).delete() + self.app.dag_bag = mock.MagicMock(**{'get_dag.return_value': self.dag}) + super().setUp() + def tearDown(self) -> None: super().tearDown() with create_session() as session: session.query(RTIF).delete() - @mock.patch('airflow.www.views.STORE_SERIALIZED_DAGS', True) - @mock.patch('airflow.models.taskinstance.STORE_SERIALIZED_DAGS', True) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_rendered_view(self, get_dag_function): + def test_rendered_view(self): """ Test that the Rendered View contains the values from RenderedTaskInstanceFields """ - get_dag_function.return_value = SerializedDagModel.get(self.dag.dag_id).dag - self.assertEqual(self.task1.bash_command, '{{ task_instance_key_str }}') ti = TaskInstance(self.task1, self.default_date) @@ -2213,16 +2209,11 @@ def test_rendered_view(self, get_dag_function): resp = self.client.get(url, follow_redirects=True) self.check_content_in_response("testdag__task1__20200301", resp) - @mock.patch('airflow.www.views.STORE_SERIALIZED_DAGS', True) - @mock.patch('airflow.models.taskinstance.STORE_SERIALIZED_DAGS', True) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_rendered_view_for_unexecuted_tis(self, get_dag_function): + def test_rendered_view_for_unexecuted_tis(self): """ Test that the Rendered View is able to show rendered values even for TIs that have not yet executed """ - get_dag_function.return_value = SerializedDagModel.get(self.dag.dag_id).dag - self.assertEqual(self.task1.bash_command, '{{ task_instance_key_str }}') url = ('rendered?task_id=task1&dag_id=task1&execution_date={}' @@ -2231,16 +2222,15 @@ def test_rendered_view_for_unexecuted_tis(self, get_dag_function): resp = self.client.get(url, follow_redirects=True) self.check_content_in_response("testdag__task1__20200301", resp) - @mock.patch('airflow.www.views.STORE_SERIALIZED_DAGS', True) @mock.patch('airflow.models.taskinstance.STORE_SERIALIZED_DAGS', True) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_user_defined_filter_and_macros_raise_error(self, get_dag_function): + def test_user_defined_filter_and_macros_raise_error(self): """ Test that the Rendered View is able to show rendered values even for TIs that have not yet executed """ - get_dag_function.return_value = SerializedDagModel.get(self.dag.dag_id).dag - + self.app.dag_bag = mock.MagicMock( + **{'get_dag.return_value': SerializedDagModel.get(self.dag.dag_id).dag} + ) self.assertEqual(self.task2.bash_command, 'echo {{ fullname("Apache", "Airflow") | hello }}') @@ -2326,16 +2316,13 @@ def test_trigger_dag_form(self): self.assertEqual(resp.status_code, 200) self.check_content_in_response('Trigger DAG: {}'.format(test_dag_id), resp) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_trigger_endpoint_uses_existing_dagbag(self, mock_get_dag): + def test_trigger_endpoint_uses_existing_dagbag(self): """ Test that Trigger Endpoint uses the DagBag already created in views.py instead of creating a new one. """ - mock_get_dag.return_value = DAG(dag_id='example_bash_operator') url = 'trigger?dag_id=example_bash_operator' resp = self.client.post(url, data={}, follow_redirects=True) - mock_get_dag.assert_called_once_with('example_bash_operator') self.check_content_in_response('example_bash_operator', resp) @@ -2343,7 +2330,7 @@ class TestExtraLinks(TestBase): def setUp(self): from tests.test_utils.mock_operators import Dummy3TestOperator from tests.test_utils.mock_operators import Dummy2TestOperator - super().setUp() + self.endpoint = "extra_links" self.default_date = datetime(2017, 1, 1) @@ -2386,10 +2373,10 @@ class DummyTestOperator(BaseOperator): self.task_2 = Dummy2TestOperator(task_id="some_dummy_task_2", dag=self.dag) self.task_3 = Dummy3TestOperator(task_id="some_dummy_task_3", dag=self.dag) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_extra_links_works(self, get_dag_function): - get_dag_function.return_value = self.dag + self.app.dag_bag = mock.MagicMock(**{'get_dag.return_value': self.dag}) + super().setUp() + def test_extra_links_works(self): response = self.client.get( "{0}?dag_id={1}&task_id={2}&execution_date={3}&link_name=foo-bar" .format(self.endpoint, self.dag.dag_id, self.task.task_id, self.default_date), @@ -2405,10 +2392,7 @@ def test_extra_links_works(self, get_dag_function): 'error': None }) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_global_extra_links_works(self, get_dag_function): - get_dag_function.return_value = self.dag - + def test_global_extra_links_works(self): response = self.client.get( "{0}?dag_id={1}&task_id={2}&execution_date={3}&link_name=github" .format(self.endpoint, self.dag.dag_id, self.task.task_id, self.default_date), @@ -2423,10 +2407,7 @@ def test_global_extra_links_works(self, get_dag_function): 'error': None }) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_extra_link_in_gantt_view(self, get_dag_function): - get_dag_function.return_value = self.dag - + def test_extra_link_in_gantt_view(self): exec_date = dates.days_ago(2) start_date = datetime(2020, 4, 10, 2, 0, 0) end_date = exec_date + timedelta(seconds=30) @@ -2448,10 +2429,7 @@ def test_extra_link_in_gantt_view(self, get_dag_function): self.assertIn('airflow', extra_links) self.assertIn('github', extra_links) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_operator_extra_link_override_global_extra_link(self, get_dag_function): - get_dag_function.return_value = self.dag - + def test_operator_extra_link_override_global_extra_link(self): response = self.client.get( "{0}?dag_id={1}&task_id={2}&execution_date={3}&link_name=airflow".format( self.endpoint, self.dag.dag_id, self.task.task_id, self.default_date), @@ -2466,10 +2444,7 @@ def test_operator_extra_link_override_global_extra_link(self, get_dag_function): 'error': None }) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_extra_links_error_raised(self, get_dag_function): - get_dag_function.return_value = self.dag - + def test_extra_links_error_raised(self): response = self.client.get( "{0}?dag_id={1}&task_id={2}&execution_date={3}&link_name=raise_error" .format(self.endpoint, self.dag.dag_id, self.task.task_id, self.default_date), @@ -2483,10 +2458,7 @@ def test_extra_links_error_raised(self, get_dag_function): 'url': None, 'error': 'This is an error'}) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_extra_links_no_response(self, get_dag_function): - get_dag_function.return_value = self.dag - + def test_extra_links_no_response(self): response = self.client.get( "{0}?dag_id={1}&task_id={2}&execution_date={3}&link_name=no_response" .format(self.endpoint, self.dag.dag_id, self.task.task_id, self.default_date), @@ -2500,8 +2472,7 @@ def test_extra_links_no_response(self, get_dag_function): 'url': None, 'error': 'No URL found for no_response'}) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_operator_extra_link_override_plugin(self, get_dag_function): + def test_operator_extra_link_override_plugin(self): """ This tests checks if Operator Link (AirflowLink) defined in the Dummy2TestOperator is overriden by Airflow Plugin (AirflowLink2). @@ -2509,8 +2480,6 @@ def test_operator_extra_link_override_plugin(self, get_dag_function): AirflowLink returns 'https://airflow.apache.org/' link AirflowLink2 returns 'https://airflow.apache.org/1.10.5/' link """ - get_dag_function.return_value = self.dag - response = self.client.get( "{0}?dag_id={1}&task_id={2}&execution_date={3}&link_name=airflow".format( self.endpoint, self.dag.dag_id, self.task_2.task_id, self.default_date), @@ -2525,8 +2494,7 @@ def test_operator_extra_link_override_plugin(self, get_dag_function): 'error': None }) - @mock.patch('airflow.www.views.dagbag.get_dag') - def test_operator_extra_link_multiple_operators(self, get_dag_function): + def test_operator_extra_link_multiple_operators(self): """ This tests checks if Operator Link (AirflowLink2) defined in Airflow Plugin (AirflowLink2) is attached to all the list of @@ -2535,8 +2503,6 @@ def test_operator_extra_link_multiple_operators(self, get_dag_function): AirflowLink2 returns 'https://airflow.apache.org/1.10.5/' link GoogleLink returns 'https://www.google.com' """ - get_dag_function.return_value = self.dag - response = self.client.get( "{0}?dag_id={1}&task_id={2}&execution_date={3}&link_name=airflow".format( self.endpoint, self.dag.dag_id, self.task_2.task_id, self.default_date),