Fix replace parameter for BigQueryToPostgresOperator #40278

Merged (1 commit, Jun 20, 2024)
57 changes: 56 additions & 1 deletion airflow/providers/google/cloud/transfers/bigquery_to_postgres.py
@@ -19,31 +19,86 @@

from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
from airflow.providers.google.cloud.utils.bigquery_get_data import bigquery_get_data
from airflow.providers.postgres.hooks.postgres import PostgresHook

if TYPE_CHECKING:
from airflow.utils.context import Context


class BigQueryToPostgresOperator(BigQueryToSqlBaseOperator):
"""
Fetch data from a BigQuery table (alternatively fetch selected columns) and insert it into a PostgreSQL table.

Due to constraints of PostgreSQL's ON CONFLICT clause, both the `selected_fields` and `replace_index`
parameters need to be specified when using the operator with `replace=True`.
In effect, to run this operator with `replace=True` your target table MUST already have a unique
index on the relevant column(s); otherwise the INSERT command will fail with an error.
See https://www.postgresql.org/docs/current/sql-insert.html for details.

Please note that most of the clauses that can be used with PostgreSQL's INSERT
command, such as ON CONSTRAINT, WHERE, DEFAULT, etc., are currently not supported by this operator.
If you need those clauses in your queries, `SQLExecuteQueryOperator` is a more suitable option.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:BigQueryToPostgresOperator`

:param target_table_name: target Postgres table (templated)
:param postgres_conn_id: Reference to :ref:`postgres connection id <howto/connection:postgres>`.
:param replace: Whether to replace rows instead of inserting them.
:param selected_fields: List of fields to return (comma-separated). If
unspecified, all fields are returned. Must be specified if `replace` is True.
:param replace_index: The column or list of columns to act as the
index for the ON CONFLICT clause. Must be specified if `replace` is True.
"""

def __init__(
self,
*,
target_table_name: str,
postgres_conn_id: str = "postgres_default",
replace: bool = False,
selected_fields: list[str] | str | None = None,
replace_index: list[str] | str | None = None,
**kwargs,
) -> None:
super().__init__(target_table_name=target_table_name, **kwargs)
if replace and not (selected_fields and replace_index):
raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names and a unique index.")
super().__init__(
target_table_name=target_table_name, replace=replace, selected_fields=selected_fields, **kwargs
)
self.postgres_conn_id = postgres_conn_id
self.replace_index = replace_index

def get_sql_hook(self) -> PostgresHook:
return PostgresHook(schema=self.database, postgres_conn_id=self.postgres_conn_id)

def execute(self, context: Context) -> None:
big_query_hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
location=self.location,
impersonation_chain=self.impersonation_chain,
)
self.persist_links(context)
sql_hook: PostgresHook = self.get_sql_hook()
for rows in bigquery_get_data(
self.log,
self.dataset_id,
self.table_id,
big_query_hook,
self.batch_size,
self.selected_fields,
):
sql_hook.insert_rows(
table=self.target_table_name,
rows=rows,
target_fields=self.selected_fields,
replace=self.replace,
commit_every=self.batch_size,
replace_index=self.replace_index,
)
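
For illustration, a minimal sketch of how the fixed operator can be wired up for an upsert. The dataset, table, and connection names below are hypothetical and not part of this PR; the trailing comment describes the usual PostgresHook upsert behaviour, not verbatim output.

from airflow.providers.google.cloud.transfers.bigquery_to_postgres import (
    BigQueryToPostgresOperator,
)

# Hypothetical upsert task. With replace=True, both selected_fields and
# replace_index must be provided, and the target PostgreSQL table must already
# have a unique index covering the replace_index column(s).
upsert_salaries = BigQueryToPostgresOperator(
    task_id="upsert_salaries",
    dataset_table="my_dataset.salaries",  # BigQuery source as "dataset.table"
    target_table_name="salaries",  # existing PostgreSQL table
    postgres_conn_id="postgres_default",
    replace=True,
    selected_fields=["emp_name", "salary"],
    replace_index=["emp_name"],
)
# For each batch fetched from BigQuery, PostgresHook.insert_rows() issues
# roughly: INSERT INTO salaries (emp_name, salary) VALUES (...)
#          ON CONFLICT (emp_name) DO UPDATE SET salary = excluded.salary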
How-to guide for BigQueryToPostgresOperator (docs, .rst):
@@ -44,6 +44,10 @@ to define values dynamically.

You may use the parameter ``selected_fields`` to limit the fields to be copied (all fields by default),
as well as the parameter ``replace`` to overwrite the destination table instead of appending to it.
If the ``replace`` parameter is used, both the ``selected_fields`` and ``replace_index`` parameters must
also be specified, due to constraints of PostgreSQL's ON CONFLICT clause in the underlying INSERT
command.
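
In practice this also means the target PostgreSQL table must already have a unique index covering the
``replace_index`` column(s). A minimal sketch of preparing such an index before the transfer, assuming a
``SQLExecuteQueryOperator`` and illustrative table and column names:

.. code-block:: python

    from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

    # Illustrative: create the unique index required by the ON CONFLICT clause
    # before running BigQueryToPostgresOperator with replace=True.
    create_unique_index = SQLExecuteQueryOperator(
        task_id="create_unique_index",
        conn_id="postgres_default",
        sql="CREATE UNIQUE INDEX IF NOT EXISTS emp_name_uq ON my_table (emp_name);",
    )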

For more information, please refer to the links above.

Transferring data
@@ -57,6 +61,14 @@ The following Operator copies data from a BigQuery table to PostgreSQL.
:start-after: [START howto_operator_bigquery_to_postgres]
:end-before: [END howto_operator_bigquery_to_postgres]

The Operator can also replace data in a PostgreSQL table with matching data from a BigQuery table.

.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_to_postgres.py
:language: python
:dedent: 4
:start-after: [START howto_operator_bigquery_to_postgres_upsert]
:end-before: [END howto_operator_bigquery_to_postgres_upsert]
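
Under the hood the rows are written with ``PostgresHook.insert_rows``. A rough sketch of the statement
shape issued per batch when ``replace=True`` (illustrative column names that do not mirror the system
test above, and not the verbatim SQL the hook generates):

.. code-block:: python

    # Sketch only: approximate upsert shape for selected_fields=["emp_name", "salary"]
    # and replace_index=["emp_name"].
    upsert_sql = """
        INSERT INTO target_table (emp_name, salary)
        VALUES (%s, %s)
        ON CONFLICT (emp_name) DO UPDATE SET salary = excluded.salary
    """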


Reference
^^^^^^^^^
Unit tests for BigQueryToPostgresOperator:
@@ -19,22 +19,24 @@

from unittest import mock

import pytest

from airflow.providers.google.cloud.transfers.bigquery_to_postgres import BigQueryToPostgresOperator

TASK_ID = "test-bq-create-table-operator"
TEST_DATASET = "test-dataset"
TEST_TABLE_ID = "test-table-id"
TEST_DAG_ID = "test-bigquery-operators"
TEST_DESTINATION_TABLE = "table"


class TestBigQueryToPostgresOperator:
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.BigQueryHook")
def test_execute_good_request_to_bq(self, mock_hook):
destination_table = "table"
operator = BigQueryToPostgresOperator(
task_id=TASK_ID,
dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
target_table_name=destination_table,
target_table_name=TEST_DESTINATION_TABLE,
replace=False,
)

@@ -46,3 +48,40 @@ def test_execute_good_request_to_bq(self, mock_hook):
selected_fields=None,
start_index=0,
)

@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.BigQueryHook")
def test_execute_good_request_to_bq__with_replace(self, mock_hook):
operator = BigQueryToPostgresOperator(
task_id=TASK_ID,
dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
target_table_name=TEST_DESTINATION_TABLE,
replace=True,
selected_fields=["col_1", "col_2"],
replace_index=["col_1"],
)

operator.execute(context=mock.MagicMock())
mock_hook.return_value.list_rows.assert_called_once_with(
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
max_results=1000,
selected_fields=["col_1", "col_2"],
start_index=0,
)

@pytest.mark.parametrize(
"selected_fields, replace_index", [(None, None), (["col_1, col_2"], None), (None, ["col_1"])]
)
def test_init_raises_exception_if_replace_is_true_and_missing_params(
self, selected_fields, replace_index
):
error_msg = "PostgreSQL ON CONFLICT upsert syntax requires column names and a unique index."
with pytest.raises(ValueError, match=error_msg):
_ = BigQueryToPostgresOperator(
task_id=TASK_ID,
dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
target_table_name=TEST_DESTINATION_TABLE,
replace=True,
selected_fields=selected_fields,
replace_index=replace_index,
)
System test DAG (tests/system/providers/google/cloud/bigquery/example_bigquery_to_postgres.py):
@@ -53,17 +53,17 @@
from airflow.settings import Session
from airflow.utils.trigger_rule import TriggerRule

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "example-project")
DAG_ID = "example_bigquery_to_postgres"
DAG_ID = "bigquery_to_postgres"

REGION = "us-central1"
ZONE = REGION + "-a"
NETWORK = "default"
CONNECTION_ID = f"connection_{DAG_ID}_{ENV_ID}".replace("-", "_")
CONNECTION_TYPE = "postgres"

BIGQUERY_DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}"
BIGQUERY_DATASET_NAME = f"ds_{DAG_ID}_{ENV_ID}"
BIGQUERY_TABLE = "test_table"
SOURCE_OBJECT_NAME = "gs://airflow-system-tests-resources/bigquery/salaries_1k.csv"
BATCH_SIZE = 500
@@ -106,7 +106,9 @@
"initialize_params": {
"disk_size_gb": "10",
"disk_type": f"zones/{ZONE}/diskTypes/pd-balanced",
"source_image": "projects/debian-cloud/global/images/debian-11-bullseye-v20220621",
# The source image can become outdated and stop being supported by apt software packages.
# In that case the image version will need to be updated.
"source_image": "projects/debian-cloud/global/images/debian-12-bookworm-v20240611",
},
}
],
@@ -162,7 +164,7 @@
schedule="@once", # Override to match your needs
start_date=datetime(2021, 1, 1),
catchup=False,
tags=["example", "bigquery"],
tags=["example", "bigquery", "postgres"],
) as dag:
create_bigquery_dataset = BigQueryCreateEmptyDatasetOperator(
task_id="create_bigquery_dataset",
@@ -244,8 +246,8 @@ def setup_connection(ip_address: str) -> None:

setup_connection_task = setup_connection(get_public_ip_task)

create_sql_table = SQLExecuteQueryOperator(
task_id="create_sql_table",
create_pg_table = SQLExecuteQueryOperator(
task_id="create_pg_table",
conn_id=CONNECTION_ID,
sql=SQL_CREATE_TABLE,
retries=4,
@@ -264,6 +266,38 @@ def setup_connection(ip_address: str) -> None:
)
# [END howto_operator_bigquery_to_postgres]

update_pg_table_data = SQLExecuteQueryOperator(
task_id="update_pg_table_data",
conn_id=CONNECTION_ID,
sql=f"UPDATE {SQL_TABLE} SET salary = salary + 0.5 WHERE salary < 10000.0",
retries=4,
retry_delay=duration(seconds=20),
retry_exponential_backoff=False,
)

create_unique_index_in_pg_table = SQLExecuteQueryOperator(
task_id="create_unique_index_in_pg_table",
conn_id=CONNECTION_ID,
sql=f"CREATE UNIQUE INDEX emp_salary ON {SQL_TABLE}(emp_name, salary);",
retries=4,
retry_delay=duration(seconds=20),
retry_exponential_backoff=False,
show_return_value_in_logs=True,
)

# [START howto_operator_bigquery_to_postgres_upsert]
bigquery_to_postgres_upsert = BigQueryToPostgresOperator(
task_id="bigquery_to_postgres_upsert",
postgres_conn_id=CONNECTION_ID,
dataset_table=f"{BIGQUERY_DATASET_NAME}.{BIGQUERY_TABLE}",
target_table_name=SQL_TABLE,
batch_size=BATCH_SIZE,
replace=True,
selected_fields=["emp_name", "salary"],
replace_index=["emp_name", "salary"],
)
# [END howto_operator_bigquery_to_postgres_upsert]

delete_bigquery_dataset = BigQueryDeleteDatasetOperator(
task_id="delete_bigquery_dataset",
dataset_id=BIGQUERY_DATASET_NAME,
@@ -301,16 +335,19 @@ def setup_connection(ip_address: str) -> None:
create_bigquery_dataset >> create_bigquery_table >> insert_bigquery_data
create_gce_instance >> setup_postgres
create_gce_instance >> get_public_ip_task >> setup_connection_task
[setup_postgres, setup_connection_task, create_firewall_rule] >> create_sql_table
[setup_postgres, setup_connection_task, create_firewall_rule] >> create_pg_table

(
[insert_bigquery_data, create_sql_table]
[insert_bigquery_data, create_pg_table]
# TEST BODY
>> bigquery_to_postgres
>> update_pg_table_data
>> create_unique_index_in_pg_table
>> bigquery_to_postgres_upsert
)

# TEST TEARDOWN
bigquery_to_postgres >> [
bigquery_to_postgres_upsert >> [
delete_bigquery_dataset,
delete_firewall_rule,
delete_gce_instance,