
Use google cloud credentials when executing beam command in subprocess (
SamWheating authored Oct 17, 2021
1 parent 80b5e65 commit a418fd9
Showing 4 changed files with 52 additions and 31 deletions.
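
What this change does: the Beam operators launch the pipeline (via beam_hook.start_python_pipeline / start_java_pipeline) in a subprocess, and before this commit that subprocess ran with whatever ambient credentials the worker environment supplied rather than the operator's Google Cloud connection. The fix wraps each Dataflow launch in the hook's provide_authorized_gcloud() context manager so the gcloud CLI is authorized with the connection's credentials while the subprocess runs. Below is a minimal standalone sketch of that idea, not the provider's implementation: the key path and pipeline command are assumptions, and the real hook also isolates gcloud state rather than mutating global config.

# Minimal sketch of the provide_authorized_gcloud() idea (illustrative only):
# authorize the gcloud CLI with a service-account key so a Beam pipeline
# launched in a subprocess inherits those credentials.
import subprocess
from contextlib import contextmanager

KEY_FILE = "/path/to/service-account.json"  # assumption: key from the Airflow connection

@contextmanager
def provide_authorized_gcloud(key_file: str = KEY_FILE):
    # Credentials activated here apply to any gcloud-aware subprocess started
    # inside the block; the real hook additionally sandboxes gcloud config state.
    subprocess.run(
        ["gcloud", "auth", "activate-service-account", f"--key-file={key_file}"],
        check=True,
    )
    yield

with provide_authorized_gcloud():
    # e.g. the kind of launch the operators below run as a subprocess
    subprocess.run(["python", "my_pipeline.py", "--runner=DataflowRunner"], check=True)
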
45 changes: 29 additions & 16 deletions airflow/providers/apache/beam/operators/beam.py
@@ -237,24 +237,36 @@ def execute(self, context):
                 tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file))
                 self.py_file = tmp_gcs_file.name
 
-            self.beam_hook.start_python_pipeline(
-                variables=formatted_pipeline_options,
-                py_file=self.py_file,
-                py_options=self.py_options,
-                py_interpreter=self.py_interpreter,
-                py_requirements=self.py_requirements,
-                py_system_site_packages=self.py_system_site_packages,
-                process_line_callback=process_line_callback,
-            )
-
             if is_dataflow:
+                with self.dataflow_hook.provide_authorized_gcloud():
+                    self.beam_hook.start_python_pipeline(
+                        variables=formatted_pipeline_options,
+                        py_file=self.py_file,
+                        py_options=self.py_options,
+                        py_interpreter=self.py_interpreter,
+                        py_requirements=self.py_requirements,
+                        py_system_site_packages=self.py_system_site_packages,
+                        process_line_callback=process_line_callback,
+                    )
+
                 self.dataflow_hook.wait_for_done(
                     job_name=dataflow_job_name,
                     location=self.dataflow_config.location,
                     job_id=self.dataflow_job_id,
                     multiple_jobs=False,
                 )
 
+            else:
+                self.beam_hook.start_python_pipeline(
+                    variables=formatted_pipeline_options,
+                    py_file=self.py_file,
+                    py_options=self.py_options,
+                    py_interpreter=self.py_interpreter,
+                    py_requirements=self.py_requirements,
+                    py_system_site_packages=self.py_system_site_packages,
+                    process_line_callback=process_line_callback,
+                )
+
         return {"dataflow_job_id": self.dataflow_job_id}
 
     def on_kill(self) -> None:
@@ -418,12 +430,13 @@ def execute(self, context):
                     )
                 if not is_running:
                     pipeline_options["jobName"] = dataflow_job_name
-                    self.beam_hook.start_java_pipeline(
-                        variables=pipeline_options,
-                        jar=self.jar,
-                        job_class=self.job_class,
-                        process_line_callback=process_line_callback,
-                    )
+                    with self.dataflow_hook.provide_authorized_gcloud():
+                        self.beam_hook.start_java_pipeline(
+                            variables=pipeline_options,
+                            jar=self.jar,
+                            job_class=self.job_class,
+                            process_line_callback=process_line_callback,
+                        )
                     self.dataflow_hook.wait_for_done(
                         job_name=dataflow_job_name,
                         location=self.dataflow_config.location,
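
A note on the restructuring in this file: start_python_pipeline is now duplicated across the two branches because provide_authorized_gcloud() lives on the DataflowHook, which only exists when the runner is Dataflow. An equivalent shape, sketched with contextlib.nullcontext (illustrative, not the committed code):

# Equivalent control flow via nullcontext: the authorized-gcloud context
# wraps the launch only when running on Dataflow.
from contextlib import nullcontext

def start_python_pipeline_with_auth(beam_hook, dataflow_hook, is_dataflow, **kwargs):
    cm = dataflow_hook.provide_authorized_gcloud() if is_dataflow else nullcontext()
    with cm:
        beam_hook.start_python_pipeline(**kwargs)
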
32 changes: 17 additions & 15 deletions airflow/providers/google/cloud/operators/dataflow.py
@@ -448,12 +448,13 @@ def set_current_job_id(job_id):
                 )
             if not is_running:
                 pipeline_options["jobName"] = job_name
-                self.beam_hook.start_java_pipeline(
-                    variables=pipeline_options,
-                    jar=self.jar,
-                    job_class=self.job_class,
-                    process_line_callback=process_line_callback,
-                )
+                with self.dataflow_hook.provide_authorized_gcloud():
+                    self.beam_hook.start_java_pipeline(
+                        variables=pipeline_options,
+                        jar=self.jar,
+                        job_class=self.job_class,
+                        process_line_callback=process_line_callback,
+                    )
                 self.dataflow_hook.wait_for_done(
                     job_name=job_name,
                     location=self.location,
@@ -1142,15 +1143,16 @@ def set_current_job_id(job_id):
                 tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file))
                 self.py_file = tmp_gcs_file.name
 
-            self.beam_hook.start_python_pipeline(
-                variables=formatted_pipeline_options,
-                py_file=self.py_file,
-                py_options=self.py_options,
-                py_interpreter=self.py_interpreter,
-                py_requirements=self.py_requirements,
-                py_system_site_packages=self.py_system_site_packages,
-                process_line_callback=process_line_callback,
-            )
+            with self.dataflow_hook.provide_authorized_gcloud():
+                self.beam_hook.start_python_pipeline(
+                    variables=formatted_pipeline_options,
+                    py_file=self.py_file,
+                    py_options=self.py_options,
+                    py_interpreter=self.py_interpreter,
+                    py_requirements=self.py_requirements,
+                    py_system_site_packages=self.py_system_site_packages,
+                    process_line_callback=process_line_callback,
+                )
 
             self.dataflow_hook.wait_for_done(
                 job_name=job_name,
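
The dataflow.py changes apply the same wrapping to the older DataflowCreateJavaJobOperator and DataflowCreatePythonJobOperator. In every call site, wait_for_done stays outside the with block: only the subprocess launch needs the activated gcloud credentials, while job polling goes through the hook's own API client. A hypothetical helper capturing that shared shape (launch_and_wait and its parameters are illustrative, not part of the providers):

# Hypothetical helper (not in the commit): launch under authorized gcloud,
# then poll for completion outside the credential context.
def launch_and_wait(beam_hook, dataflow_hook, job_name, location, job_id, **launch_kwargs):
    with dataflow_hook.provide_authorized_gcloud():
        # the subprocess launch inherits the activated credentials
        beam_hook.start_java_pipeline(**launch_kwargs)
    # API polling; no subprocess involved, so no gcloud state is needed
    dataflow_hook.wait_for_done(
        job_name=job_name, location=location, job_id=job_id, multiple_jobs=False
    )
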
1 change: 1 addition & 0 deletions tests/providers/apache/beam/operators/test_beam.py
@@ -139,6 +139,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
             location='us-central1',
             multiple_jobs=False,
         )
+        dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
 
     @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
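
The test updates assert that the mocked hook's context manager was requested exactly once. A self-contained sketch of that pattern with toy names (not the provider's test code); MagicMock implements the context-manager protocol, so the with statement works against the mock:

# Standalone illustration of the assertion pattern used in these tests.
from unittest import mock

def run(hook):
    # code under test: enter the hook-provided context around the launch
    with hook.provide_authorized_gcloud():
        return "launched"

hook = mock.MagicMock()
assert run(hook) == "launched"
hook.provide_authorized_gcloud.assert_called_once_with()
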
5 changes: 5 additions & 0 deletions tests/providers/google/cloud/operators/test_dataflow.py
@@ -129,6 +129,7 @@ def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_
         """
         start_python_mock = beam_hook_mock.return_value.start_python_pipeline
+        provide_gcloud_mock = dataflow_hook_mock.return_value.provide_authorized_gcloud
         gcs_provide_file = gcs_hook.return_value.provide_file
         job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
         self.dataflow.execute(None)
@@ -169,6 +170,7 @@ def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_
             multiple_jobs=False,
         )
         assert self.dataflow.py_file.startswith('/tmp/dataflow')
+        provide_gcloud_mock.assert_called_once_with()
 
 
 class TestDataflowJavaOperator(unittest.TestCase):
@@ -210,6 +212,7 @@ def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_
         start_java_mock = beam_hook_mock.return_value.start_java_pipeline
         gcs_provide_file = gcs_hook.return_value.provide_file
         job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
+        provide_gcloud_mock = dataflow_hook_mock.return_value.provide_authorized_gcloud
         self.dataflow.check_if_running = CheckJobRunning.IgnoreJob
 
         self.dataflow.execute(None)
@@ -238,6 +241,8 @@ def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_
             multiple_jobs=None,
         )
 
+        provide_gcloud_mock.assert_called_once_with()
+
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
