Skip to content

Commit

Permalink
Fix DagRun.conf when using trigger_dag API (apache#9853)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored and scrambldchannel committed Jul 17, 2020
1 parent 0cb7161 commit d4d3433
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 25 deletions.
2 changes: 1 addition & 1 deletion airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
self.execution_date = execution_date
self.start_date = start_date
self.external_trigger = external_trigger
self.conf = conf
self.conf = conf or {}
self.state = state
self.run_type = run_type
super().__init__()
Expand Down
41 changes: 17 additions & 24 deletions tests/api/common/experimental/test_trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,26 @@
# specific language governing permissions and limitations
# under the License.

import json
import unittest
from unittest import mock

from parameterized import parameterized

from airflow.api.common.experimental.trigger_dag import _trigger_dag
from airflow.exceptions import AirflowException
from airflow.models import DAG, DagRun
from airflow.utils import timezone
from tests.test_utils import db


class TestTriggerDag(unittest.TestCase):

def setUp(self) -> None:
db.clear_db_runs()

def tearDown(self) -> None:
db.clear_db_runs()

@mock.patch('airflow.models.DagRun')
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_dag_not_found(self, dag_bag_mock, dag_run_mock):
Expand Down Expand Up @@ -114,25 +122,6 @@ def test_trigger_dag_include_nested_subdags(self, dag_bag_mock, dag_run_mock, da

self.assertEqual(3, len(triggers))

@mock.patch('airflow.models.DagBag')
def test_trigger_dag_with_str_conf(self, dag_bag_mock):
dag_id = "trigger_dag_with_str_conf"
dag = DAG(dag_id)
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag
conf = "{\"foo\": \"bar\"}"
dag_run = DagRun()
triggers = _trigger_dag(
dag_id,
dag_bag_mock,
dag_run,
run_id=None,
conf=conf,
execution_date=None,
replace_microseconds=True)

self.assertEqual(triggers[0].conf, json.loads(conf))

@mock.patch('airflow.models.DagBag')
def test_trigger_dag_with_too_early_start_date(self, dag_bag_mock):
dag_id = "trigger_dag_with_too_early_start_date"
Expand Down Expand Up @@ -173,13 +162,17 @@ def test_trigger_dag_with_valid_start_date(self, dag_bag_mock):

assert len(triggers) == 1

@parameterized.expand([
(None, {}),
({"foo": "bar"}, {"foo": "bar"}),
('{"foo": "bar"}', {"foo": "bar"}),
])
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_with_dict_conf(self, dag_bag_mock):
dag_id = "trigger_dag_with_dict_conf"
def test_trigger_dag_with_conf(self, conf, expected_conf, dag_bag_mock):
dag_id = "trigger_dag_with_conf"
dag = DAG(dag_id)
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag
conf = dict(foo="bar")
dag_run = DagRun()
triggers = _trigger_dag(
dag_id,
Expand All @@ -190,4 +183,4 @@ def test_trigger_dag_with_dict_conf(self, dag_bag_mock):
execution_date=None,
replace_microseconds=True)

self.assertEqual(triggers[0].conf, conf)
self.assertEqual(triggers[0].conf, expected_conf)

0 comments on commit d4d3433

Please sign in to comment.