From 295efd36eac074578e4b54a69d71c2924984326d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C5=81ukasz=20Wyszomirski?=
Date: Thu, 17 Feb 2022 15:35:11 +0100
Subject: [PATCH] Dataflow Assets (#21639)

---
 .../providers/apache/beam/operators/beam.py   | 28 +++++++-
 .../providers/google/cloud/links/__init__.py  | 16 +++++
 .../providers/google/cloud/links/dataflow.py  | 64 +++++++++++++++++++
 .../google/cloud/operators/dataflow.py        |  5 ++
 airflow/providers/google/provider.yaml        |  1 +
 .../apache/beam/operators/test_beam.py        | 40 ++++++++++--
 6 files changed, 146 insertions(+), 8 deletions(-)
 create mode 100644 airflow/providers/google/cloud/links/__init__.py
 create mode 100644 airflow/providers/google/cloud/links/dataflow.py

diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
index ca976cba72dee..c5e85140d79d5 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -30,6 +30,7 @@
     process_line_and_extract_dataflow_job_id_callback,
 )
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
 from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration
 from airflow.utils.helpers import convert_camel_to_snake
 from airflow.version import version
@@ -236,6 +237,7 @@ class BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
         "dataflow_config",
     )
     template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'}
+    operator_extra_links = (DataflowJobLink(),)
 
     def __init__(
         self,
@@ -301,7 +303,13 @@ def execute(self, context: 'Context'):
                     py_system_site_packages=self.py_system_site_packages,
                     process_line_callback=process_line_callback,
                 )
-
+                DataflowJobLink.persist(
+                    self,
+                    context,
+                    self.dataflow_config.project_id,
+                    self.dataflow_config.location,
+                    self.dataflow_job_id,
+                )
                 if dataflow_job_name and self.dataflow_config.location:
                     self.dataflow_hook.wait_for_done(
                         job_name=dataflow_job_name,
@@ -369,6 +377,8 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
     template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'}
     ui_color = "#0273d4"
 
+    operator_extra_links = (DataflowJobLink(),)
+
     def __init__(
         self,
         *,
@@ -452,6 +462,13 @@ def execute(self, context: 'Context'):
                             if self.dataflow_config.multiple_jobs
                             else False
                         )
+                        DataflowJobLink.persist(
+                            self,
+                            context,
+                            self.dataflow_config.project_id,
+                            self.dataflow_config.location,
+                            self.dataflow_job_id,
+                        )
                         self.dataflow_hook.wait_for_done(
                             job_name=dataflow_job_name,
                             location=self.dataflow_config.location,
@@ -505,6 +522,7 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
         "dataflow_config",
     ]
     template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'}
+    operator_extra_links = (DataflowJobLink(),)
 
     def __init__(
         self,
@@ -565,6 +583,14 @@ def execute(self, context: 'Context'):
                     process_line_callback=process_line_callback,
                     should_init_module=self.should_init_go_module,
                 )
+
+                DataflowJobLink.persist(
+                    self,
+                    context,
+                    self.dataflow_config.project_id,
+                    self.dataflow_config.location,
+                    self.dataflow_job_id,
+                )
                 if dataflow_job_name and self.dataflow_config.location:
                     self.dataflow_hook.wait_for_done(
                         job_name=dataflow_job_name,
diff --git a/airflow/providers/google/cloud/links/__init__.py b/airflow/providers/google/cloud/links/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/google/cloud/links/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/google/cloud/links/dataflow.py b/airflow/providers/google/cloud/links/dataflow.py
new file mode 100644
index 0000000000000..d8728ac1c8a03
--- /dev/null
+++ b/airflow/providers/google/cloud/links/dataflow.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Google Dataflow links.""" +from datetime import datetime +from typing import TYPE_CHECKING, Optional + +from airflow.models import BaseOperator, BaseOperatorLink, XCom + +if TYPE_CHECKING: + from airflow.utils.context import Context + +DATAFLOW_BASE_LINK = "https://pantheon.corp.google.com/dataflow/jobs" +DATAFLOW_JOB_LINK = DATAFLOW_BASE_LINK + "/{region}/{job_id}?project={project_id}" + + +class DataflowJobLink(BaseOperatorLink): + """Helper class for constructing Dataflow Job Link""" + + name = "Dataflow Job" + key = "dataflow_job_config" + + @staticmethod + def persist( + operator_instance: BaseOperator, + context: "Context", + project_id: Optional[str], + region: Optional[str], + job_id: Optional[str], + ): + operator_instance.xcom_push( + context, + key=DataflowJobLink.key, + value={"project_id": project_id, "location": region, "job_id": job_id}, + ) + + def get_link(self, operator: BaseOperator, dttm: datetime) -> str: + conf = XCom.get_one( + key=DataflowJobLink.key, + dag_id=operator.dag.dag_id, + task_id=operator.task_id, + execution_date=dttm, + ) + return ( + DATAFLOW_JOB_LINK.format( + project_id=conf["project_id"], region=conf['region'], job_id=conf['job_id'] + ) + if conf + else "" + ) diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index ead3eae865e0a..5e50521dee82b 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -31,6 +31,7 @@ process_line_and_extract_dataflow_job_id_callback, ) from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.cloud.links.dataflow import DataflowJobLink from airflow.version import version if TYPE_CHECKING: @@ -588,6 +589,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator): "environment", ) ui_color = "#0273d4" + operator_extra_links = (DataflowJobLink(),) def __init__( self, @@ -638,6 +640,7 @@ def execute(self, context: 'Context') -> dict: def set_current_job(current_job): self.job = current_job + DataflowJobLink.persist(self, context, self.project_id, self.location, self.job.get("id")) options = self.dataflow_default_options options.update(self.options) @@ -723,6 +726,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator): """ template_fields: Sequence[str] = ("body", "location", "project_id", "gcp_conn_id") + operator_extra_links = (DataflowJobLink(),) def __init__( self, @@ -760,6 +764,7 @@ def execute(self, context: 'Context'): def set_current_job(current_job): self.job = current_job + DataflowJobLink.persist(self, context, self.project_id, self.location, self.job.get("id")) job = self.hook.start_flex_template( body=self.body, diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index d7dae9f69d9da..873170f345145 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -845,6 +845,7 @@ extra-links: - airflow.providers.google.cloud.operators.vertex_ai.dataset.VertexAIDatasetListLink - airflow.providers.google.cloud.operators.cloud_composer.CloudComposerEnvironmentLink - airflow.providers.google.cloud.operators.cloud_composer.CloudComposerEnvironmentsLink + - airflow.providers.google.cloud.links.dataflow.DataflowJobLink - airflow.providers.google.common.links.storage.StorageLink additional-extras: diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py index 
--- a/tests/providers/apache/beam/operators/test_beam.py
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -96,10 +96,11 @@ def test_exec_direct_runner(self, gcs_hook, beam_hook_mock):
             process_line_callback=None,
         )
 
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist')
     @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
-    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock):
+    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
         """Test DataflowHook is created and the right args are passed to
         start_python_dataflow.
         """
@@ -127,6 +128,13 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
             'region': 'us-central1',
         }
         gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
+        persist_link_mock.assert_called_once_with(
+            self.operator,
+            None,
+            expected_options['project'],
+            expected_options['region'],
+            self.operator.dataflow_job_id,
+        )
         beam_hook_mock.return_value.start_python_pipeline.assert_called_once_with(
             variables=expected_options,
             py_file=gcs_provide_file.return_value.__enter__.return_value.name,
@@ -144,10 +152,11 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
         )
         dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
 
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist')
     @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
-    def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __):
+    def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
         self.operator.runner = "DataflowRunner"
         dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
         self.operator.execute(None)
@@ -205,10 +214,11 @@ def test_exec_direct_runner(self, gcs_hook, beam_hook_mock):
             process_line_callback=None,
         )
 
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist')
     @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
-    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock):
+    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
         """Test DataflowHook is created and the right args are passed to
         start_java_dataflow.
         """
@@ -238,7 +248,13 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
             'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
             'output': 'gs://test/output',
         }
-
+        persist_link_mock.assert_called_once_with(
+            self.operator,
+            None,
+            expected_options['project'],
+            expected_options['region'],
+            self.operator.dataflow_job_id,
+        )
         beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with(
             variables=expected_options,
             jar=gcs_provide_file.return_value.__enter__.return_value.name,
@@ -253,10 +269,11 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
             project_id=dataflow_hook_mock.return_value.project_id,
         )
 
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist')
     @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
-    def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __):
+    def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
         self.operator.runner = "DataflowRunner"
         dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
         dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
@@ -344,6 +361,7 @@ def test_exec_source_on_local_path(self, init_module, beam_hook_mock):
             should_init_module=False,
         )
 
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist')
     @mock.patch(
         "tempfile.TemporaryDirectory",
         return_value=MagicMock(__enter__=MagicMock(return_value='/tmp/apache-beam-go')),
@@ -351,7 +369,7 @@ def test_exec_source_on_local_path(self, init_module, beam_hook_mock):
     @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
-    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, _):
+    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, _, persist_link_mock):
         """Test DataflowHook is created and the right args are passed to
         start_go_dataflow.
         """
@@ -378,6 +396,13 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
             'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
             'region': 'us-central1',
         }
+        persist_link_mock.assert_called_once_with(
+            self.operator,
+            None,
+            expected_options['project'],
+            expected_options['region'],
+            self.operator.dataflow_job_id,
+        )
         gcs_provide_file.assert_called_once_with(object_url=GO_FILE, dir='/tmp/apache-beam-go')
         beam_hook_mock.return_value.start_go_pipeline.assert_called_once_with(
             variables=expected_options,
@@ -393,10 +418,11 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
         )
         dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
 
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist')
     @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
     @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
-    def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __):
+    def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
         self.operator.runner = "DataflowRunner"
         dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
         self.operator.execute(None)