Refactor datapipeline operators and fix system tests
VladaZakharova authored and potiuk committed Jun 14, 2024
1 parent bffb7b0 commit 7aafd3f
Showing 21 changed files with 1,224 additions and 4,711 deletions.
8 changes: 8 additions & 0 deletions airflow/contrib/operators/__init__.py
@@ -174,6 +174,14 @@
"airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator"
),
},
"datapipeline_operator": {
"CreateDataPipelineOperator": (
"airflow.providers.google.cloud.operators.datapipeline.CreateDataPipelineOperator"
),
"RunDataPipelineOperator": (
"airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator"
),
},
"dataproc_operator": {
"DataprocCreateClusterOperator": (
"airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator"
131 changes: 131 additions & 0 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -582,6 +582,11 @@ def get_conn(self) -> Resource:
http_authorized = self._authorize()
return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)

def get_pipelines_conn(self) -> build:
"""Return a Google Cloud Data Pipelines service object."""
http_authorized = self._authorize()
return build("datapipelines", "v1", http=http_authorized, cache_discovery=False)

@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
@@ -1351,6 +1356,132 @@ def is_job_done(self, location: str, project_id: str, job_id: str) -> bool:

return job_controller._check_dataflow_job_state(job)

@GoogleBaseHook.fallback_to_default_project_id
def create_data_pipeline(
self,
body: dict,
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> dict:
"""
Create a new Dataflow Data Pipelines instance.
:param body: The request body (contains an instance of Pipeline). See:
https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/create#request-body
:param project_id: The ID of the GCP project that owns the job.
:param location: The location to direct the Data Pipelines instance to (for example us-central1).
Returns the created Data Pipelines instance in JSON representation.
"""
parent = self.build_parent_name(project_id, location)
service = self.get_pipelines_conn()
request = (
service.projects()
.locations()
.pipelines()
.create(
parent=parent,
body=body,
)
)
response = request.execute(num_retries=self.num_retries)
return response
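
A minimal usage sketch for this method (not part of the diff): the project, bucket, pipeline and template names below are placeholders, and the body fields follow the Data Pipelines REST Pipeline resource linked in the docstring.

from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")
# Create a batch pipeline backed by a Flex Template launch request.
pipeline = hook.create_data_pipeline(
    body={
        "name": "projects/my-project/locations/us-central1/pipelines/my-pipeline",
        "type": "PIPELINE_TYPE_BATCH",
        "workload": {
            "dataflowFlexTemplateRequest": {
                "projectId": "my-project",
                "location": "us-central1",
                "launchParameter": {
                    "jobName": "my-job",
                    "containerSpecGcsPath": "gs://my-bucket/templates/my-template.json",
                },
            }
        },
    },
    project_id="my-project",
    location="us-central1",
)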

@GoogleBaseHook.fallback_to_default_project_id
def get_data_pipeline(
self,
pipeline_name: str,
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> dict:
"""
Retrieve a Dataflow Data Pipelines instance.
:param pipeline_name: The display name of the pipeline. For example, in
projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
:param project_id: The ID of the GCP project that owns the job.
:param location: The location to direct the Data Pipelines instance to (for example us-central1).
Returns the retrieved Data Pipelines instance in JSON representation.
"""
parent = self.build_parent_name(project_id, location)
service = self.get_pipelines_conn()
request = (
service.projects()
.locations()
.pipelines()
.get(
name=f"{parent}/pipelines/{pipeline_name}",
)
)
response = request.execute(num_retries=self.num_retries)
return response

@GoogleBaseHook.fallback_to_default_project_id
def run_data_pipeline(
self,
pipeline_name: str,
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> dict:
"""
Run a Dataflow Data Pipelines instance.
:param pipeline_name: The display name of the pipeline. For example, in
projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
:param project_id: The ID of the GCP project that owns the job.
:param location: The location to direct the Data Pipelines instance to (for example us-central1).
Returns the created Job in JSON representation.
"""
parent = self.build_parent_name(project_id, location)
service = self.get_pipelines_conn()
request = (
service.projects()
.locations()
.pipelines()
.run(
name=f"{parent}/pipelines/{pipeline_name}",
body={},
)
)
response = request.execute(num_retries=self.num_retries)
return response
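
A hedged sketch of running a pipeline by its PIPELINE_ID (placeholders throughout). The "job" and "id" keys are assumed from the Data Pipelines pipelines.run response and the Dataflow Job resource; they are not defined in this diff.

from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")
response = hook.run_data_pipeline(
    pipeline_name="my-pipeline",
    project_id="my-project",
    location="us-central1",
)
# The run call returns the launched Dataflow Job; its id can be used to monitor the job.
job_id = response.get("job", {}).get("id")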

@GoogleBaseHook.fallback_to_default_project_id
def delete_data_pipeline(
self,
pipeline_name: str,
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> dict | None:
"""
Delete a Dataflow Data Pipelines instance.
:param pipeline_name: The display name of the pipeline. For example, in
projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
:param project_id: The ID of the GCP project that owns the job.
:param location: The location to direct the Data Pipelines instance to (for example us-central1).
Returns the response from the delete request in JSON representation.
"""
parent = self.build_parent_name(project_id, location)
service = self.get_pipelines_conn()
request = (
service.projects()
.locations()
.pipelines()
.delete(
name=f"{parent}/pipelines/{pipeline_name}",
)
)
response = request.execute(num_retries=self.num_retries)
return response

@staticmethod
def build_parent_name(project_id: str, location: str):
return f"projects/{project_id}/locations/{location}"
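
For illustration (placeholder values): build_parent_name("my-project", "us-central1") returns "projects/my-project/locations/us-central1", and the get/delete methods append "/pipelines/<PIPELINE_ID>" to that prefix.

from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")
# Both calls resolve to projects/my-project/locations/us-central1/pipelines/my-pipeline.
pipeline = hook.get_data_pipeline(pipeline_name="my-pipeline", project_id="my-project", location="us-central1")
hook.delete_data_pipeline(pipeline_name="my-pipeline", project_id="my-project", location="us-central1")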


class AsyncDataflowHook(GoogleBaseAsyncHook):
"""Async hook class for dataflow service."""
95 changes: 22 additions & 73 deletions airflow/providers/google/cloud/hooks/datapipeline.py
@@ -19,103 +19,52 @@

from __future__ import annotations

from typing import Sequence
from typing import TYPE_CHECKING

from googleapiclient.discovery import build
from deprecated import deprecated

from airflow.providers.google.common.hooks.base_google import (
GoogleBaseHook,
)

DEFAULT_DATAPIPELINE_LOCATION = "us-central1"
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.dataflow import DataflowHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook

if TYPE_CHECKING:
from googleapiclient.discovery import build

class DataPipelineHook(GoogleBaseHook):
"""
Hook for Google Data Pipelines.
DEFAULT_DATAPIPELINE_LOCATION = "us-central1"

All the methods in the hook where project_id is used must be called with
keyword arguments rather than positional.
"""

def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)
@deprecated(
reason="This hook is deprecated and will be removed after 01.12.2024. Please use `DataflowHook`.",
category=AirflowProviderDeprecationWarning,
)
class DataPipelineHook(DataflowHook):
"""Hook for Google Data Pipelines."""

def get_conn(self) -> build:
"""Return a Google Cloud Data Pipelines service object."""
http_authorized = self._authorize()
return build("datapipelines", "v1", http=http_authorized, cache_discovery=False)
return super().get_pipelines_conn()

@GoogleBaseHook.fallback_to_default_project_id
def create_data_pipeline(
self,
body: dict,
project_id: str,
location: str = DEFAULT_DATAPIPELINE_LOCATION,
) -> None:
"""
Create a new Data Pipelines instance from the Data Pipelines API.
:param body: The request body (contains instance of Pipeline). See:
https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/create#request-body
:param project_id: The ID of the GCP project that owns the job.
:param location: The location to direct the Data Pipelines instance to (for example us-central1).
Returns the created Data Pipelines instance in JSON representation.
"""
parent = self.build_parent_name(project_id, location)
service = self.get_conn()
self.log.info(dir(service.projects().locations()))
request = (
service.projects()
.locations()
.pipelines()
.create(
parent=parent,
body=body,
)
)
response = request.execute(num_retries=self.num_retries)
return response
) -> dict:
"""Create a new Data Pipelines instance from the Data Pipelines API."""
return super().create_data_pipeline(body=body, project_id=project_id, location=location)

@GoogleBaseHook.fallback_to_default_project_id
def run_data_pipeline(
self,
data_pipeline_name: str,
project_id: str,
location: str = DEFAULT_DATAPIPELINE_LOCATION,
) -> None:
"""
Run a Data Pipelines Instance using the Data Pipelines API.
:param data_pipeline_name: The display name of the pipeline. In example
projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
:param project_id: The ID of the GCP project that owns the job.
:param location: The location to direct the Data Pipelines instance to (for example us-central1).
Returns the created Job in JSON representation.
"""
parent = self.build_parent_name(project_id, location)
service = self.get_conn()
request = (
service.projects()
.locations()
.pipelines()
.run(
name=f"{parent}/pipelines/{data_pipeline_name}",
body={},
)
) -> dict:
"""Run a Data Pipelines Instance using the Data Pipelines API."""
return super().run_data_pipeline(
pipeline_name=data_pipeline_name, project_id=project_id, location=location
)
response = request.execute(num_retries=self.num_retries)
return response

@staticmethod
def build_parent_name(project_id: str, location: str):
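A hedged sketch of the resulting deprecation behaviour (connection and pipeline names are placeholders): the legacy hook keeps working, but instantiating it emits AirflowProviderDeprecationWarning and every method forwards to DataflowHook.

from airflow.providers.google.cloud.hooks.datapipeline import DataPipelineHook

# Emits AirflowProviderDeprecationWarning; use DataflowHook in new code.
hook = DataPipelineHook(gcp_conn_id="google_cloud_default")
hook.run_data_pipeline(
    data_pipeline_name="my-pipeline",
    project_id="my-project",
    location="us-central1",
)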
25 changes: 25 additions & 0 deletions airflow/providers/google/cloud/links/dataflow.py
@@ -30,6 +30,9 @@
DATAFLOW_BASE_LINK = "/dataflow/jobs"
DATAFLOW_JOB_LINK = DATAFLOW_BASE_LINK + "/{region}/{job_id}?project={project_id}"

DATAFLOW_PIPELINE_BASE_LINK = "/dataflow/pipelines"
DATAFLOW_PIPELINE_LINK = DATAFLOW_PIPELINE_BASE_LINK + "/{location}/{pipeline_name}?project={project_id}"


class DataflowJobLink(BaseGoogleLink):
"""Helper class for constructing Dataflow Job Link."""
@@ -51,3 +54,25 @@ def persist(
key=DataflowJobLink.key,
value={"project_id": project_id, "region": region, "job_id": job_id},
)


class DataflowPipelineLink(BaseGoogleLink):
"""Helper class for constructing Dataflow Pipeline Link."""

name = "Dataflow Pipeline"
key = "dataflow_pipeline_config"
format_str = DATAFLOW_PIPELINE_LINK

@staticmethod
def persist(
operator_instance: BaseOperator,
context: Context,
project_id: str | None,
location: str | None,
pipeline_name: str | None,
):
operator_instance.xcom_push(
context,
key=DataflowPipelineLink.key,
value={"project_id": project_id, "location": location, "pipeline_name": pipeline_name},
)
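
A sketch of how an operator could attach the new link (the operator class, task values, and pipeline names below are assumptions, not part of this diff). BaseGoogleLink then renders a console URL of roughly the form https://console.cloud.google.com/dataflow/pipelines/us-central1/my-pipeline?project=my-project.

from airflow.models import BaseOperator
from airflow.providers.google.cloud.links.dataflow import DataflowPipelineLink


class MyPipelineOperator(BaseOperator):
    # Exposing the link on the operator makes it appear next to the task in the UI.
    operator_extra_links = (DataflowPipelineLink(),)

    def execute(self, context):
        # ... create or run the pipeline here, then persist the link for the UI ...
        DataflowPipelineLink.persist(
            operator_instance=self,
            context=context,
            project_id="my-project",
            location="us-central1",
            pipeline_name="my-pipeline",
        )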
