diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
index da57feae61de9..cc00e534a2e6d 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -237,17 +237,18 @@ def execute(self, context):
                 tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file))
                 self.py_file = tmp_gcs_file.name
 
-            self.beam_hook.start_python_pipeline(
-                variables=formatted_pipeline_options,
-                py_file=self.py_file,
-                py_options=self.py_options,
-                py_interpreter=self.py_interpreter,
-                py_requirements=self.py_requirements,
-                py_system_site_packages=self.py_system_site_packages,
-                process_line_callback=process_line_callback,
-            )
-
             if is_dataflow:
+                with self.dataflow_hook.provide_authorized_gcloud():
+                    self.beam_hook.start_python_pipeline(
+                        variables=formatted_pipeline_options,
+                        py_file=self.py_file,
+                        py_options=self.py_options,
+                        py_interpreter=self.py_interpreter,
+                        py_requirements=self.py_requirements,
+                        py_system_site_packages=self.py_system_site_packages,
+                        process_line_callback=process_line_callback,
+                    )
+
                 self.dataflow_hook.wait_for_done(
                     job_name=dataflow_job_name,
                     location=self.dataflow_config.location,
@@ -255,6 +256,17 @@ def execute(self, context):
                     multiple_jobs=False,
                 )
 
+            else:
+                self.beam_hook.start_python_pipeline(
+                    variables=formatted_pipeline_options,
+                    py_file=self.py_file,
+                    py_options=self.py_options,
+                    py_interpreter=self.py_interpreter,
+                    py_requirements=self.py_requirements,
+                    py_system_site_packages=self.py_system_site_packages,
+                    process_line_callback=process_line_callback,
+                )
+
         return {"dataflow_job_id": self.dataflow_job_id}
 
     def on_kill(self) -> None:
@@ -418,12 +430,13 @@ def execute(self, context):
             )
             if not is_running:
                 pipeline_options["jobName"] = dataflow_job_name
-                self.beam_hook.start_java_pipeline(
-                    variables=pipeline_options,
-                    jar=self.jar,
-                    job_class=self.job_class,
-                    process_line_callback=process_line_callback,
-                )
+                with self.dataflow_hook.provide_authorized_gcloud():
+                    self.beam_hook.start_java_pipeline(
+                        variables=pipeline_options,
+                        jar=self.jar,
+                        job_class=self.job_class,
+                        process_line_callback=process_line_callback,
+                    )
                 self.dataflow_hook.wait_for_done(
                     job_name=dataflow_job_name,
                     location=self.dataflow_config.location,
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index e9d89c6c02e64..fbe7aaadfff7f 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -448,12 +448,13 @@ def set_current_job_id(job_id):
             )
             if not is_running:
                 pipeline_options["jobName"] = job_name
-                self.beam_hook.start_java_pipeline(
-                    variables=pipeline_options,
-                    jar=self.jar,
-                    job_class=self.job_class,
-                    process_line_callback=process_line_callback,
-                )
+                with self.dataflow_hook.provide_authorized_gcloud():
+                    self.beam_hook.start_java_pipeline(
+                        variables=pipeline_options,
+                        jar=self.jar,
+                        job_class=self.job_class,
+                        process_line_callback=process_line_callback,
+                    )
                 self.dataflow_hook.wait_for_done(
                     job_name=job_name,
                     location=self.location,
@@ -1142,15 +1143,16 @@ def set_current_job_id(job_id):
                 tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file))
                 self.py_file = tmp_gcs_file.name
 
-            self.beam_hook.start_python_pipeline(
-                variables=formatted_pipeline_options,
-                py_file=self.py_file,
-                py_options=self.py_options,
-                py_interpreter=self.py_interpreter,
-                py_requirements=self.py_requirements,
-                py_system_site_packages=self.py_system_site_packages,
-                process_line_callback=process_line_callback,
-            )
+            with self.dataflow_hook.provide_authorized_gcloud():
+                self.beam_hook.start_python_pipeline(
+                    variables=formatted_pipeline_options,
+                    py_file=self.py_file,
+                    py_options=self.py_options,
+                    py_interpreter=self.py_interpreter,
+                    py_requirements=self.py_requirements,
+                    py_system_site_packages=self.py_system_site_packages,
+                    process_line_callback=process_line_callback,
+                )
 
             self.dataflow_hook.wait_for_done(
                 job_name=job_name,
diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py
index ce8f2b5dc07fd..51b3df92712a3 100644
--- a/tests/providers/apache/beam/operators/test_beam.py
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -139,6 +139,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
             location='us-central1',
             multiple_jobs=False,
         )
+        dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
 
     @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 5c605d67c8143..67a4bb926995a 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -129,6 +129,7 @@ def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_
 
         """
         start_python_mock = beam_hook_mock.return_value.start_python_pipeline
+        provide_gcloud_mock = dataflow_hook_mock.return_value.provide_authorized_gcloud
        gcs_provide_file = gcs_hook.return_value.provide_file
         job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
         self.dataflow.execute(None)
@@ -169,6 +170,7 @@ def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_
             multiple_jobs=False,
         )
         assert self.dataflow.py_file.startswith('/tmp/dataflow')
+        provide_gcloud_mock.assert_called_once_with()
 
 
 class TestDataflowJavaOperator(unittest.TestCase):
@@ -210,6 +212,7 @@ def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_
         start_java_mock = beam_hook_mock.return_value.start_java_pipeline
         gcs_provide_file = gcs_hook.return_value.provide_file
         job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
+        provide_gcloud_mock = dataflow_hook_mock.return_value.provide_authorized_gcloud
         self.dataflow.check_if_running = CheckJobRunning.IgnoreJob
 
         self.dataflow.execute(None)
@@ -238,6 +241,8 @@ def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_
             multiple_jobs=None,
         )
 
+        provide_gcloud_mock.assert_called_once_with()
+
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
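
Note on the pattern this diff applies (a minimal standalone sketch, not part of the change itself): DataflowHook inherits provide_authorized_gcloud() from GoogleBaseHook, and entering that context manager configures the gcloud CLI with the hook's own credentials, so any subprocess spawned while the Beam pipeline is being launched runs authorized as the connection's account. The connection id below is just the stock Airflow default, used here for illustration.

import subprocess

from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

# Hypothetical standalone usage; "google_cloud_default" is the stock
# Airflow connection id, not something this diff introduces.
hook = DataflowHook(gcp_conn_id="google_cloud_default")

with hook.provide_authorized_gcloud():
    # Inside the block the gcloud CLI picks up a temporary configuration
    # seeded from the hook's credentials, so any subprocess started here
    # (e.g. a Beam runner shelling out to gcloud) uses that authorization.
    subprocess.run(["gcloud", "config", "list"], check=True)

Wrapping only the Dataflow branch in the operators above keeps plain (non-Dataflow) Beam runs untouched, which is presumably why start_python_pipeline is duplicated into an else branch rather than left before the is_dataflow check.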