Skip to content

Commit

Permalink
Use current_app.dag_bag instead of global variable (#9380)
Browse files Browse the repository at this point in the history
* Use current_app.dag_bag instead of global variable

* fixup! Use current_app.dag_bag instead of global variable

* fixup! fixup! Use current_app.dag_bag instead of global variable
  • Loading branch information
turbaszek authored Jun 19, 2020
1 parent c7e5bce commit 50318f8
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 91 deletions.
4 changes: 4 additions & 0 deletions airflow/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#

from datetime import timedelta
from typing import Optional

Expand All @@ -30,6 +31,7 @@
from airflow.utils.json import AirflowJsonEncoder
from airflow.www.extensions.init_appbuilder import init_appbuilder
from airflow.www.extensions.init_appbuilder_links import init_appbuilder_links
from airflow.www.extensions.init_dagbag import init_dagbag
from airflow.www.extensions.init_jinja_globals import init_jinja_globals
from airflow.www.extensions.init_manifest_files import configure_manifest_files
from airflow.www.extensions.init_security import init_api_experimental_auth, init_xframe_protection
Expand Down Expand Up @@ -85,6 +87,8 @@ def create_app(config=None, testing=False, app_name="Airflow"):
db.session = settings.Session
db.init_app(flask_app)

init_dagbag(flask_app)

init_api_experimental_auth(flask_app)

Cache(app=flask_app, config={'CACHE_TYPE': 'filesystem', 'CACHE_DIR': '/tmp'})
Expand Down
32 changes: 32 additions & 0 deletions airflow/www/extensions/init_dagbag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os

from airflow.models import DagBag
from airflow.settings import DAGS_FOLDER, STORE_SERIALIZED_DAGS


def init_dagbag(app):
    """
    Build the DagBag shared by the webserver and the API and attach it to
    the given Flask application as ``app.dag_bag``.

    Views should read it through ``flask.current_app.dag_bag`` rather than
    through a module-level global.

    When the ``SKIP_DAGS_PARSING`` environment variable is exactly the
    string ``'True'``, the bag is pointed at ``os.devnull`` with example
    DAGs disabled. Otherwise it is built from ``DAGS_FOLDER`` using the
    ``STORE_SERIALIZED_DAGS`` setting.

    :param app: the Flask application that will carry the ``dag_bag``
        attribute
    """
    skip_parsing = os.environ.get('SKIP_DAGS_PARSING') == 'True'

    if skip_parsing:
        bag = DagBag(os.devnull, include_examples=False)
    else:
        bag = DagBag(DAGS_FOLDER, store_serialized_dags=STORE_SERIALIZED_DAGS)

    # Single assignment point: the bag lives on the app object, not in a global.
    app.dag_bag = bag
55 changes: 24 additions & 31 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import json
import logging
import math
import os
import socket
import traceback
from collections import defaultdict
Expand Down Expand Up @@ -61,7 +60,6 @@
from airflow.models import Connection, DagModel, DagTag, Log, SlaMiss, TaskFail, XCom, errors
from airflow.models.dagcode import DagCode
from airflow.models.dagrun import DagRun, DagRunType
from airflow.settings import STORE_SERIALIZED_DAGS
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS
from airflow.utils import timezone
Expand All @@ -81,11 +79,6 @@
FILTER_TAGS_COOKIE = 'tags_filter'
FILTER_STATUS_COOKIE = 'dag_status_filter'

if os.environ.get('SKIP_DAGS_PARSING') != 'True':
dagbag = models.DagBag(settings.DAGS_FOLDER, store_serialized_dags=STORE_SERIALIZED_DAGS)
else:
dagbag = models.DagBag(os.devnull, include_examples=False)


def get_date_time_num_runs_dag_runs_form_data(request, session, dag):
dttm = request.args.get('execution_date')
Expand Down Expand Up @@ -577,7 +570,7 @@ def code(self, session=None):
@provide_session
def dag_details(self, session=None):
dag_id = request.args.get('dag_id')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
title = "DAG details"
root = request.args.get('root', '')

Expand Down Expand Up @@ -612,7 +605,7 @@ def rendered(self):
root = request.args.get('root', '')

logging.info("Retrieving rendered templates.")
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)

task = copy.copy(dag.get_task(task_id))
ti = models.TaskInstance(task=task, execution_date=dttm)
Expand Down Expand Up @@ -700,7 +693,7 @@ def _get_logs_with_metadata(try_number, metadata):

try:
if ti is not None:
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
if dag:
ti.task = dag.get_task(ti.task_id)
if response_format == 'json':
Expand Down Expand Up @@ -803,7 +796,7 @@ def task(self):
dttm = timezone.parse(execution_date)
form = DateTimeForm(data={'execution_date': dttm})
root = request.args.get('root', '')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)

if not dag or task_id not in dag.task_ids:
flash(
Expand Down Expand Up @@ -921,7 +914,7 @@ def run(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
task = dag.get_task(task_id)

execution_date = request.form.get('execution_date')
Expand Down Expand Up @@ -1052,7 +1045,7 @@ def trigger(self, session=None):
conf=conf
)

dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=execution_date,
Expand Down Expand Up @@ -1117,7 +1110,7 @@ def clear(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)

execution_date = request.form.get('execution_date')
execution_date = timezone.parse(execution_date)
Expand Down Expand Up @@ -1150,7 +1143,7 @@ def dagrun_clear(self):
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"

dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
execution_date = timezone.parse(execution_date)
start_date = execution_date
end_date = execution_date
Expand Down Expand Up @@ -1192,7 +1185,7 @@ def blocked(self, session=None):
payload = []
for dag_id, active_dag_runs in dags:
max_active_runs = 0
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
if dag:
# TODO: Make max_active_runs a column so we can query for it directly
max_active_runs = dag.max_active_runs
Expand All @@ -1209,7 +1202,7 @@ def _mark_dagrun_state_as_failed(self, dag_id, execution_date, confirmed, origin
return redirect(origin)

execution_date = timezone.parse(execution_date)
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)

if not dag:
flash('Cannot find DAG: {}'.format(dag_id), 'error')
Expand Down Expand Up @@ -1237,7 +1230,7 @@ def _mark_dagrun_state_as_success(self, dag_id, execution_date, confirmed, origi
return redirect(origin)

execution_date = timezone.parse(execution_date)
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)

if not dag:
flash('Cannot find DAG: {}'.format(dag_id), 'error')
Expand Down Expand Up @@ -1287,7 +1280,7 @@ def dagrun_success(self):
def _mark_task_instance_state(self, dag_id, task_id, origin, execution_date,
confirmed, upstream, downstream,
future, past, state):
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
task = dag.get_task(task_id)
task.dag = dag

Expand Down Expand Up @@ -1371,7 +1364,7 @@ def success(self):
def tree(self):
dag_id = request.args.get('dag_id')
blur = conf.getboolean('webserver', 'demo_mode')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
if not dag:
flash('DAG "{0}" seems to be missing from DagBag.'.format(dag_id), "error")
return redirect(url_for('Airflow.index'))
Expand Down Expand Up @@ -1528,7 +1521,7 @@ def recurse_nodes(task, visited):
def graph(self, session=None):
dag_id = request.args.get('dag_id')
blur = conf.getboolean('webserver', 'demo_mode')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
if not dag:
flash('DAG "{0}" seems to be missing.'.format(dag_id), "error")
return redirect(url_for('Airflow.index'))
Expand Down Expand Up @@ -1627,7 +1620,7 @@ class GraphForm(DateTimeWithNumRunsWithDagRunsForm):
def duration(self, session=None):
default_dag_run = conf.getint('webserver', 'default_dag_run_display_number')
dag_id = request.args.get('dag_id')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
base_date = request.args.get('base_date')
num_runs = request.args.get('num_runs')
num_runs = int(num_runs) if num_runs else default_dag_run
Expand Down Expand Up @@ -1739,7 +1732,7 @@ def duration(self, session=None):
def tries(self, session=None):
default_dag_run = conf.getint('webserver', 'default_dag_run_display_number')
dag_id = request.args.get('dag_id')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
base_date = request.args.get('base_date')
num_runs = request.args.get('num_runs')
num_runs = int(num_runs) if num_runs else default_dag_run
Expand Down Expand Up @@ -1804,7 +1797,7 @@ def tries(self, session=None):
def landing_times(self, session=None):
default_dag_run = conf.getint('webserver', 'default_dag_run_display_number')
dag_id = request.args.get('dag_id')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
base_date = request.args.get('base_date')
num_runs = request.args.get('num_runs')
num_runs = int(num_runs) if num_runs else default_dag_run
Expand Down Expand Up @@ -1902,7 +1895,7 @@ def refresh(self, session=None):
session.merge(orm_dag)
session.commit()

dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
# sync dag permission
current_app.appbuilder.sm.sync_perm_for_dag(dag_id, dag.access_control)

Expand All @@ -1916,7 +1909,7 @@ def refresh(self, session=None):
@provide_session
def gantt(self, session=None):
dag_id = request.args.get('dag_id')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)
demo_mode = conf.getboolean('webserver', 'demo_mode')

root = request.args.get('root')
Expand Down Expand Up @@ -2029,7 +2022,7 @@ def extra_links(self):
execution_date = request.args.get('execution_date')
link_name = request.args.get('link_name')
dttm = timezone.parse(execution_date)
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)

if not dag or task_id not in dag.task_ids:
response = jsonify(
Expand Down Expand Up @@ -2066,7 +2059,7 @@ def extra_links(self):
@action_logging
def task_instances(self):
dag_id = request.args.get('dag_id')
dag = dagbag.get_dag(dag_id)
dag = current_app.dag_bag.get_dag(dag_id)

dttm = request.args.get('execution_date')
if dttm:
Expand Down Expand Up @@ -2544,7 +2537,7 @@ def action_set_failed(self, drs, session=None):
dirty_ids.append(dr.dag_id)
count += 1
altered_tis += \
set_dag_run_state_to_failed(dagbag.get_dag(dr.dag_id),
set_dag_run_state_to_failed(current_app.dag_bag.get_dag(dr.dag_id),
dr.execution_date,
commit=True,
session=session)
Expand All @@ -2571,7 +2564,7 @@ def action_set_success(self, drs, session=None):
dirty_ids.append(dr.dag_id)
count += 1
altered_tis += \
set_dag_run_state_to_success(dagbag.get_dag(dr.dag_id),
set_dag_run_state_to_success(current_app.dag_bag.get_dag(dr.dag_id),
dr.execution_date,
commit=True,
session=session)
Expand Down Expand Up @@ -2665,7 +2658,7 @@ def action_clear(self, tis, session=None):
dag_to_tis = {}

for ti in tis:
dag = dagbag.get_dag(ti.dag_id)
dag = current_app.dag_bag.get_dag(ti.dag_id)
tis = dag_to_tis.setdefault(dag, [])
tis.append(ti)

Expand Down
Loading

0 comments on commit 50318f8

Please sign in to comment.