Fix BigQuery transfer operators to respect project_id arguments #32232

Merged
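In effect, each transfer operator now resolves the project that runs the BigQuery job from its own `project_id` argument, falling back to the hook's (connection's) default project, while the table's project keeps coming from the fully qualified table reference. A minimal sketch of the fallback pattern the diff applies throughout (the helper name is illustrative, not part of the PR):

```python
from __future__ import annotations

# Illustrative helper only: prefer the operator-level project_id,
# then fall back to the hook/connection default project.
def resolve_job_project(operator_project_id: str | None, hook_project_id: str) -> str:
    return operator_project_id or hook_project_id
```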
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
@@ -146,7 +146,7 @@ def _handle_job_error(job: BigQueryJob | UnknownJob) -> None:
def _prepare_configuration(self):
source_project, source_dataset, source_table = self.hook.split_tablename(
table_input=self.source_project_dataset_table,
-            default_project_id=self.project_id or self.hook.project_id,
+            default_project_id=self.hook.project_id,
var_name="source_project_dataset_table",
)

@@ -184,7 +184,7 @@ def _submit_job(

return hook.insert_job(
configuration=configuration,
-            project_id=configuration["extract"]["sourceTable"]["projectId"],
+            project_id=self.project_id or hook.project_id,
location=self.location,
job_id=job_id,
timeout=self.result_timeout,
@@ -255,7 +255,7 @@ def execute(self, context: Context):
trigger=BigQueryInsertJobTrigger(
conn_id=self.gcp_conn_id,
job_id=job_id,
-                    project_id=self.hook.project_id,
+                    project_id=self.project_id or self.hook.project_id,
),
method_name="execute_complete",
)
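As a usage sketch for the change above (bucket, dataset, and project names are made up): the export job is now submitted under the operator's `project_id`, while the source table's project is still parsed from the table reference, so a task can read a table in one project and run the job in another.

```python
from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator

export_to_gcs = BigQueryToGCSOperator(
    task_id="export_to_gcs",
    # The table lives in "data-project"; its project comes from this reference.
    source_project_dataset_table="data-project.my_dataset.my_table",
    destination_cloud_storage_uris=["gs://my-bucket/export/*.csv"],
    # The extract job runs under this project; when omitted, it falls back
    # to the hook's default project.
    project_id="billing-project",
)
```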
31 changes: 14 additions & 17 deletions airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -301,7 +301,7 @@ def _submit_job(
# Submit a new job without waiting for it to complete.
return hook.insert_job(
configuration=self.configuration,
-            project_id=self.project_id,
+            project_id=self.project_id or hook.project_id,
location=self.location,
job_id=job_id,
timeout=self.result_timeout,
@@ -359,7 +359,7 @@ def execute(self, context: Context):

if self.external_table:
self.log.info("Creating a new BigQuery table for storing data...")
-            table_obj_api_repr = self._create_empty_table()
+            table_obj_api_repr = self._create_external_table()

BigQueryTableLink.persist(
context=context,
@@ -381,7 +381,7 @@ def execute(self, context: Context):
except Conflict:
# If the job already exists retrieve it
job = self.hook.get_job(
-                    project_id=self.hook.project_id,
+                    project_id=self.project_id or self.hook.project_id,
location=self.location,
job_id=job_id,
)
@@ -414,12 +414,12 @@ def execute(self, context: Context):
persist_kwargs = {
"context": context,
"task_instance": self,
"project_id": self.hook.project_id,
"table_id": table,
}
if not isinstance(table, str):
persist_kwargs["table_id"] = table["tableId"]
persist_kwargs["dataset_id"] = table["datasetId"]
persist_kwargs["project_id"] = table["projectId"]
BigQueryTableLink.persist(**persist_kwargs)

self.job_id = job.job_id
@@ -430,7 +430,7 @@ def execute(self, context: Context):
trigger=BigQueryInsertJobTrigger(
conn_id=self.gcp_conn_id,
job_id=self.job_id,
-                    project_id=self.hook.project_id,
+                    project_id=self.project_id or self.hook.project_id,
),
method_name="execute_complete",
)
@@ -475,7 +475,9 @@ def _find_max_value_in_column(self):
}
}
try:
-            job_id = hook.insert_job(configuration=self.configuration, project_id=hook.project_id)
+            job_id = hook.insert_job(
+                configuration=self.configuration, project_id=self.project_id or hook.project_id
+            )
rows = list(hook.get_job(job_id=job_id, location=self.location).result())
except BadRequest as e:
if "Unrecognized name:" in e.message:
@@ -498,12 +500,7 @@ def _find_max_value_in_column(self):
else:
raise RuntimeError(f"The {select_command} returned no rows!")

-    def _create_empty_table(self):
-        self.project_id, dataset_id, table_id = self.hook.split_tablename(
-            table_input=self.destination_project_dataset_table,
-            default_project_id=self.project_id or self.hook.project_id,
-        )
-
+    def _create_external_table(self):
external_config_api_repr = {
"autodetect": self.autodetect,
"sourceFormat": self.source_format,
@@ -549,7 +546,7 @@ def _create_empty_table(self):

# build table definition
table = Table(
-            table_ref=TableReference.from_string(self.destination_project_dataset_table, self.project_id)
+            table_ref=TableReference.from_string(self.destination_project_dataset_table, self.hook.project_id)
)
table.external_data_configuration = external_config
if self.labels:
@@ -567,17 +564,17 @@ def _create_empty_table(self):
self.log.info("Creating external table: %s", self.destination_project_dataset_table)
self.hook.create_empty_table(
table_resource=table_obj_api_repr,
-            project_id=self.project_id,
+            project_id=self.project_id or self.hook.project_id,
location=self.location,
exists_ok=True,
)
self.log.info("External table created successfully: %s", self.destination_project_dataset_table)
return table_obj_api_repr

def _use_existing_table(self):
-        self.project_id, destination_dataset, destination_table = self.hook.split_tablename(
+        destination_project_id, destination_dataset, destination_table = self.hook.split_tablename(
table_input=self.destination_project_dataset_table,
-            default_project_id=self.project_id or self.hook.project_id,
+            default_project_id=self.hook.project_id,
var_name="destination_project_dataset_table",
)

@@ -597,7 +594,7 @@ def _use_existing_table(self):
"autodetect": self.autodetect,
"createDisposition": self.create_disposition,
"destinationTable": {
"projectId": self.project_id,
"projectId": destination_project_id,
"datasetId": destination_dataset,
"tableId": destination_table,
},
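The load direction mirrors this. A hypothetical counterpart (names again made up): the destination table's project is parsed from `destination_project_dataset_table` (falling back to the hook's project), while `project_id` only selects the project that runs the load job.

```python
from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

load_to_bq = GCSToBigQueryOperator(
    task_id="load_to_bq",
    bucket="my-bucket",
    source_objects=["export/part-*.csv"],
    # Destination project is parsed from this string, not from project_id.
    destination_project_dataset_table="data-project.my_dataset.my_table",
    # The load job itself runs under this project (hook default if omitted).
    project_id="billing-project",
)
```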
12 changes: 7 additions & 5 deletions tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
@@ -31,6 +31,7 @@
TEST_DATASET = "test-dataset"
TEST_TABLE_ID = "test-table-id"
PROJECT_ID = "test-project-id"
+ JOB_PROJECT_ID = "job-project-id"


class TestBigQueryToGCSOperator:
@@ -66,7 +67,7 @@ def test_execute(self, mock_hook):
mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID)
mock_hook.return_value.generate_job_id.return_value = real_job_id
mock_hook.return_value.insert_job.return_value = MagicMock(job_id="real_job_id", error_result=False)
-        mock_hook.return_value.project_id = PROJECT_ID
+        mock_hook.return_value.project_id = JOB_PROJECT_ID

operator = BigQueryToGCSOperator(
task_id=TASK_ID,
@@ -77,13 +78,14 @@
field_delimiter=field_delimiter,
print_header=print_header,
labels=labels,
+            project_id=JOB_PROJECT_ID,
)
operator.execute(context=mock.MagicMock())

mock_hook.return_value.insert_job.assert_called_once_with(
job_id="123456_hash",
configuration=expected_configuration,
-            project_id=PROJECT_ID,
+            project_id=JOB_PROJECT_ID,
location=None,
timeout=None,
retry=DEFAULT_RETRY,
@@ -122,10 +124,10 @@ def test_execute_deferrable_mode(self, mock_hook):
mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID)
mock_hook.return_value.generate_job_id.return_value = real_job_id
mock_hook.return_value.insert_job.return_value = MagicMock(job_id="real_job_id", error_result=False)
-        mock_hook.return_value.project_id = PROJECT_ID
+        mock_hook.return_value.project_id = JOB_PROJECT_ID

operator = BigQueryToGCSOperator(
-            project_id=PROJECT_ID,
+            project_id=JOB_PROJECT_ID,
task_id=TASK_ID,
source_project_dataset_table=source_project_dataset_table,
destination_cloud_storage_uris=destination_cloud_storage_uris,
@@ -146,7 +148,7 @@
mock_hook.return_value.insert_job.assert_called_once_with(
configuration=expected_configuration,
job_id="123456_hash",
-            project_id=PROJECT_ID,
+            project_id=JOB_PROJECT_ID,
location=None,
timeout=None,
retry=DEFAULT_RETRY,