From 80a957f142f260daed262b8e93a4d02c12cfeabc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobiasz=20K=C4=99dzierski?=
Date: Tue, 17 Nov 2020 11:43:13 +0100
Subject: [PATCH] Add Dataflow sensors - job metrics (#12039)

---
 .../cloud/example_dags/example_dataflow.py    | 27 ++++++-
 .../providers/google/cloud/hooks/dataflow.py  | 49 ++++++++++++
 .../google/cloud/operators/dataflow.py        |  2 +-
 .../google/cloud/sensors/dataflow.py          | 74 ++++++++++++++++++-
 .../google/cloud/hooks/test_dataflow.py       | 31 ++++++++
 .../google/cloud/sensors/test_dataflow.py     | 34 ++++++++-
 6 files changed, 213 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py b/airflow/providers/google/cloud/example_dags/example_dataflow.py
index d36c48e752bac..398967486fdf3 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataflow.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py
@@ -20,9 +20,11 @@
 Example Airflow DAG for Google Cloud Dataflow service
 """
 import os
+from typing import Callable, Dict, List
 from urllib.parse import urlparse
 
 from airflow import models
+from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
 from airflow.providers.google.cloud.operators.dataflow import (
     CheckJobRunning,
@@ -30,7 +32,7 @@
     DataflowCreatePythonJobOperator,
     DataflowTemplatedJobStartOperator,
 )
-from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor
+from airflow.providers.google.cloud.sensors.dataflow import DataflowJobMetricsSensor, DataflowJobStatusSensor
 from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
 from airflow.utils.dates import days_ago
 
@@ -159,7 +161,30 @@
         location='europe-west3',
     )
 
+    def check_metric_scalar_gte(metric_name: str, value: int) -> Callable:
+        """Check whether the named metric is greater than or equal to the given value."""
+
+        def callback(metrics: List[Dict]) -> bool:
+            dag_native_python_async.log.info("Looking for '%s' >= %d", metric_name, value)
+            for metric in metrics:
+                context = metric.get("name", {}).get("context", {})
+                original_name = context.get("original_name", "")
+                tentative = context.get("tentative", "")
+                # Compare against the metric the caller asked for, skipping
+                # tentative (not yet committed) values.
+                if original_name == metric_name and not tentative:
+                    return metric["scalar"] >= value
+            raise AirflowException(f"Metric '{metric_name}' not found in metrics")
+
+        return callback
+
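+    # A sketch of a single MetricUpdate entry as the callback sees it (field
+    # names follow the Dataflow MetricUpdate API; the values are illustrative):
+    #   {"name": {"context": {"original_name": "Service-cpu_num_seconds"}},
+    #    "scalar": 4500.0}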
+    wait_for_python_job_async_metric = DataflowJobMetricsSensor(
+        task_id="wait-for-python-job-async-metric",
+        job_id="{{task_instance.xcom_pull('start-python-job-async')['job_id']}}",
+        location='europe-west3',
+        callback=check_metric_scalar_gte(metric_name="Service-cpu_num_seconds", value=100),
+    )
+
     start_python_job_async >> wait_for_python_job_async_done
+    start_python_job_async >> wait_for_python_job_async_metric
 
 
 with models.DAG(
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 3d51db00b43d9..24179a264f625 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -243,6 +243,27 @@ def fetch_job_by_id(self, job_id: str) -> dict:
             .execute(num_retries=self._num_retries)
         )
 
+    def fetch_job_metrics_by_id(self, job_id: str) -> dict:
+        """
+        Helper method to fetch the job metrics for the job with the specified job ID.
+
+        :param job_id: Job ID to get.
+        :type job_id: str
+        :return: the JobMetrics. See:
+            https://cloud.google.com/dataflow/docs/reference/rest/v1b3/JobMetrics
+        :rtype: dict
+        """
+        result = (
+            self._dataflow.projects()
+            .locations()
+            .jobs()
+            .getMetrics(projectId=self._project_number, location=self._job_location, jobId=job_id)
+            .execute(num_retries=self._num_retries)
+        )
+
+        self.log.debug("fetch_job_metrics_by_id %s:\n%s", job_id, result)
+        return result
+
     def _fetch_all_jobs(self) -> List[dict]:
         request = (
             self._dataflow.projects()
@@ -1101,3 +1122,31 @@ def get_job(
             location=location,
         )
         return jobs_controller.fetch_job_by_id(job_id)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def fetch_job_metrics_by_id(
+        self,
+        job_id: str,
+        project_id: str,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> dict:
+        """
+        Gets the job metrics for the job with the specified job ID.
+
+        :param job_id: Job ID to get.
+        :type job_id: str
+        :param project_id: Optional, the Google Cloud project ID in which the job runs.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :type project_id: str
+        :param location: The location of the Dataflow job (for example europe-west1). See:
+            https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+        :type location: str
+        :return: the JobMetrics. See:
+            https://cloud.google.com/dataflow/docs/reference/rest/v1b3/JobMetrics
+        :rtype: dict
+        """
+        jobs_controller = _DataflowJobsController(
+            dataflow=self.get_conn(),
+            project_number=project_id,
+            location=location,
+        )
+        return jobs_controller.fetch_job_metrics_by_id(job_id)
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 710401da836c4..49863dcf4bce5 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -909,7 +909,7 @@ def __init__(  # pylint: disable=too-many-arguments
         self.cancel_timeout = cancel_timeout
         self.wait_until_finished = wait_until_finished
         self.job_id = None
-        self.hook = None
+        self.hook: Optional[DataflowHook] = None
 
     def execute(self, context):
         """Execute the python dataflow job."""
diff --git a/airflow/providers/google/cloud/sensors/dataflow.py b/airflow/providers/google/cloud/sensors/dataflow.py
index df0f96b11504d..d2f77d46056c6 100644
--- a/airflow/providers/google/cloud/sensors/dataflow.py
+++ b/airflow/providers/google/cloud/sensors/dataflow.py
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains a Google Cloud Dataflow sensor."""
-from typing import Optional, Sequence, Set, Union
+from typing import Callable, List, Optional, Sequence, Set, Union
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.dataflow import (
@@ -116,3 +116,75 @@ def poke(self, context: dict) -> bool:
             raise AirflowException(f"Job with id '{self.job_id}' is already in terminal state: {job_status}")
 
         return False
+
+
+class DataflowJobMetricsSensor(BaseSensorOperator):
+    """
+    Checks the metrics of a job in Google Cloud Dataflow.
+
+    :param job_id: ID of the job to be checked.
+    :type job_id: str
+    :param callback: callback which is called with the list of job metrics read from the job.
+        See:
+        https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate
+    :type callback: callable
+    :param project_id: Optional, the Google Cloud project ID in which the job runs.
+        If set to None or missing, the default project_id from the Google Cloud connection is used.
+    :type project_id: str
+    :param location: The location of the Dataflow job (for example europe-west1). See:
+        https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+    :type location: str
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = ['job_id']
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        job_id: str,
+        callback: Callable[[List[dict]], bool],
+        project_id: Optional[str] = None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        gcp_conn_id: str = 'google_cloud_default',
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.job_id = job_id
+        self.project_id = project_id
+        self.callback = callback
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        self.hook: Optional[DataflowHook] = None
+
+    def poke(self, context: dict) -> bool:
+        self.hook = DataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            impersonation_chain=self.impersonation_chain,
+        )
+        result = self.hook.fetch_job_metrics_by_id(
+            job_id=self.job_id,
+            project_id=self.project_id,
+            location=self.location,
+        )
+
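+        # The JobMetrics response is a dict whose "metrics" key holds the list
+        # of MetricUpdate objects; only that list is handed to the callback.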
+        return self.callback(result["metrics"])
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index a5a0089aeeb80..70ccdd189f3ff 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -648,6 +648,37 @@ def test_get_job(self, mock_conn, mock_dataflowjob):
         )
         method_fetch_job_by_id.assert_called_once_with(TEST_JOB_ID)
 
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    def test_fetch_job_metrics_by_id(self, mock_conn, mock_dataflowjob):
+        method_fetch_job_metrics_by_id = mock_dataflowjob.return_value.fetch_job_metrics_by_id
+
+        self.dataflow_hook.fetch_job_metrics_by_id(
+            job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION
+        )
+        mock_conn.assert_called_once()
+        mock_dataflowjob.assert_called_once_with(
+            dataflow=mock_conn.return_value,
+            project_number=TEST_PROJECT_ID,
+            location=TEST_LOCATION,
+        )
+        method_fetch_job_metrics_by_id.assert_called_once_with(TEST_JOB_ID)
+
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    def test_fetch_job_metrics_by_id_controller(self, mock_conn):
+        method_get_metrics = (
+            mock_conn.return_value.projects.return_value.locations.return_value.jobs.return_value.getMetrics
+        )
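+        # The mock attribute chain above mirrors the real client call:
+        # conn.projects().locations().jobs().getMetrics(...).execute(...)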
+        self.dataflow_hook.fetch_job_metrics_by_id(
+            job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION
+        )
+
+        mock_conn.assert_called_once()
+        method_get_metrics.return_value.execute.assert_called_once_with(num_retries=0)
+        method_get_metrics.assert_called_once_with(
+            jobId=TEST_JOB_ID, projectId=TEST_PROJECT_ID, location=TEST_LOCATION
+        )
+
 
 class TestDataflowTemplateHook(unittest.TestCase):
     def setUp(self):
diff --git a/tests/providers/google/cloud/sensors/test_dataflow.py b/tests/providers/google/cloud/sensors/test_dataflow.py
index 9dd6706878342..0d930b9ccb28b 100644
--- a/tests/providers/google/cloud/sensors/test_dataflow.py
+++ b/tests/providers/google/cloud/sensors/test_dataflow.py
@@ -23,7 +23,7 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
-from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor
+from airflow.providers.google.cloud.sensors.dataflow import DataflowJobMetricsSensor, DataflowJobStatusSensor
 
 TEST_TASK_ID = "tesk-id"
 TEST_JOB_ID = "test_job_id"
@@ -98,3 +98,35 @@ def test_poke_raise_exception(self, mock_hook):
         mock_get_job.assert_called_once_with(
             job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION
         )
+
+
+class TestDataflowJobMetricsSensor(unittest.TestCase):
+    @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook")
+    def test_poke(self, mock_hook):
+        mock_fetch_job_metrics_by_id = mock_hook.return_value.fetch_job_metrics_by_id
+        callback = mock.MagicMock()
+
+        task = DataflowJobMetricsSensor(
+            task_id=TEST_TASK_ID,
+            job_id=TEST_JOB_ID,
+            callback=callback,
+            location=TEST_LOCATION,
+            project_id=TEST_PROJECT_ID,
+            gcp_conn_id=TEST_GCP_CONN_ID,
+            delegate_to=TEST_DELEGATE_TO,
+            impersonation_chain=TEST_IMPERSONATION_CHAIN,
+        )
+        results = task.poke(mock.MagicMock())
+
+        self.assertEqual(callback.return_value, results)
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=TEST_GCP_CONN_ID,
+            delegate_to=TEST_DELEGATE_TO,
+            impersonation_chain=TEST_IMPERSONATION_CHAIN,
+        )
+        mock_fetch_job_metrics_by_id.assert_called_once_with(
+            job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION
+        )
+        mock_fetch_job_metrics_by_id.return_value.__getitem__.assert_called_once_with("metrics")
+        callback.assert_called_once_with(mock_fetch_job_metrics_by_id.return_value.__getitem__.return_value)