diff --git a/airflow/providers/google/cloud/sensors/bigquery_dts.py b/airflow/providers/google/cloud/sensors/bigquery_dts.py index 9e14845197edd..ef92601c0acf2 100644 --- a/airflow/providers/google/cloud/sensors/bigquery_dts.py +++ b/airflow/providers/google/cloud/sensors/bigquery_dts.py @@ -84,6 +84,7 @@ def __init__( retry: Union[Retry, _MethodDefault] = DEFAULT, request_timeout: Optional[float] = None, metadata: Sequence[Tuple[str, str]] = (), + location: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: @@ -97,6 +98,7 @@ def __init__( self.project_id = project_id self.gcp_cloud_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + self.location = location def _normalize_state_list(self, states) -> Set[TransferState]: states = {states} if isinstance(states, (str, TransferState, int)) else states @@ -122,6 +124,7 @@ def poke(self, context: 'Context') -> bool: hook = BiqQueryDataTransferServiceHook( gcp_conn_id=self.gcp_cloud_conn_id, impersonation_chain=self.impersonation_chain, + location=self.location, ) run = hook.get_transfer_run( run_id=self.run_id, diff --git a/tests/providers/google/cloud/sensors/test_bigquery_dts.py b/tests/providers/google/cloud/sensors/test_bigquery_dts.py index f75d47714dc54..f55d338e2f76b 100644 --- a/tests/providers/google/cloud/sensors/test_bigquery_dts.py +++ b/tests/providers/google/cloud/sensors/test_bigquery_dts.py @@ -30,6 +30,8 @@ TRANSFER_CONFIG_ID = "config_id" RUN_ID = "run_id" PROJECT_ID = "project_id" +LOCATION = "europe" +GCP_CONN_ID = "google_cloud_default" class TestBigQueryDataTransferServiceTransferRunSensor(unittest.TestCase): @@ -48,6 +50,8 @@ def test_poke_returns_false(self, mock_hook): with pytest.raises(AirflowException, match="Transfer"): op.poke({}) + + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, location=None) mock_hook.return_value.get_transfer_run.assert_called_once_with( transfer_config_id=TRANSFER_CONFIG_ID, run_id=RUN_ID, @@ -68,10 +72,15 @@ def test_poke_returns_true(self, mock_hook): task_id="id", project_id=PROJECT_ID, expected_statuses={"SUCCEEDED"}, + location=LOCATION, ) result = op.poke({}) assert result is True + + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, location=LOCATION + ) mock_hook.return_value.get_transfer_run.assert_called_once_with( transfer_config_id=TRANSFER_CONFIG_ID, run_id=RUN_ID,