diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index d61634ae4c524..fa823af439d44 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -103,8 +103,8 @@ def __init__( ) self.use_legacy_sql = use_legacy_sql self.location = location - self.running_job_id = None # type: Optional[str] - self.api_resource_configs = api_resource_configs if api_resource_configs else {} # type Dict + self.running_job_id: str | None = None + self.api_resource_configs: dict = api_resource_configs if api_resource_configs else {} self.labels = labels self.credentials_path = "bigquery_hook_credentials.json" @@ -2313,14 +2313,14 @@ def __init__( self.use_legacy_sql = use_legacy_sql if api_resource_configs: _validate_value("api_resource_configs", api_resource_configs, dict) - self.api_resource_configs = api_resource_configs if api_resource_configs else {} # type Dict + self.api_resource_configs: dict = api_resource_configs if api_resource_configs else {} self.running_job_id = None # type: Optional[str] self.location = location self.num_retries = num_retries self.labels = labels self.hook = hook - def create_empty_table(self, *args, **kwargs) -> None: + def create_empty_table(self, *args, **kwargs): """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table` @@ -2372,7 +2372,7 @@ def delete_dataset(self, *args, **kwargs) -> None: ) return self.hook.delete_dataset(*args, **kwargs) - def create_external_table(self, *args, **kwargs) -> None: + def create_external_table(self, *args, **kwargs): """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_external_table` diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 1819909f0b012..084205f4ed728 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -179,6 +179,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
:param labels: a dictionary containing labels for the table, passed to BigQuery + :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( @@ -199,6 +200,7 @@ def __init__( location: str | None = None, impersonation_chain: str | Sequence[str] | None = None, labels: dict | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(sql=sql, **kwargs) @@ -208,6 +210,59 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain self.labels = labels + self.deferrable = deferrable + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Trigger.""" + configuration = {"query": {"query": self.sql}} + + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=True, + ) + + def execute(self, context: Context): + if not self.deferrable: + super().execute(context=context) + else: + hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + ) + job = self._submit_job(hook, job_id="") + context["ti"].xcom_push(key="job_id", value=job.job_id) + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryCheckTrigger( + conn_id=self.gcp_conn_id, + job_id=job.job_id, + project_id=hook.project_id, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + + records = event["records"] + if not records: + raise AirflowException("The query returned empty results") + elif not all(bool(r) for r in records): + raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}") + self.log.info("Record: %s", event["records"]) + self.log.info("Success.") class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator): @@ -233,6 +288,7 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
:param labels: a dictionary containing labels for the table, passed to BigQuery + :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( @@ -256,6 +312,7 @@ def __init__( location: str | None = None, impersonation_chain: str | Sequence[str] | None = None, labels: dict | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(sql=sql, pass_value=pass_value, tolerance=tolerance, **kwargs) @@ -264,6 +321,65 @@ def __init__( self.use_legacy_sql = use_legacy_sql self.impersonation_chain = impersonation_chain self.labels = labels + self.deferrable = deferrable + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Triggerer.""" + configuration = { + "query": { + "query": self.sql, + "useLegacySql": False, + } + } + if self.use_legacy_sql: + configuration["query"]["useLegacySql"] = self.use_legacy_sql + + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=True, + ) + + def execute(self, context: Context) -> None: # type: ignore[override] + if not self.deferrable: + super().execute(context=context) + else: + hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) + + job = self._submit_job(hook, job_id="") + context["ti"].xcom_push(key="job_id", value=job.job_id) + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryValueCheckTrigger( + conn_id=self.gcp_conn_id, + job_id=job.job_id, + project_id=hook.project_id, + sql=self.sql, + pass_value=self.pass_value, + tolerance=self.tol, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperator): @@ -300,6 +416,7 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
:param labels: a dictionary containing labels for the table, passed to BigQuery + :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( @@ -324,6 +441,7 @@ def __init__( location: str | None = None, impersonation_chain: str | Sequence[str] | None = None, labels: dict | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__( @@ -339,6 +457,67 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain self.labels = labels + self.deferrable = deferrable + + def _submit_job( + self, + hook: BigQueryHook, + sql: str, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Triggerer.""" + configuration = {"query": {"query": sql}} + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=True, + ) + + def execute(self, context: Context): + if not self.deferrable: + super().execute(context) + else: + hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) + self.log.info("Using ratio formula: %s", self.ratio_formula) + + self.log.info("Executing SQL check: %s", self.sql1) + job_1 = self._submit_job(hook, sql=self.sql1, job_id="") + context["ti"].xcom_push(key="job_id", value=job_1.job_id) + + self.log.info("Executing SQL check: %s", self.sql2) + job_2 = self._submit_job(hook, sql=self.sql2, job_id="") + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryIntervalCheckTrigger( + conn_id=self.gcp_conn_id, + first_job_id=job_1.job_id, + second_job_id=job_2.job_id, + project_id=hook.project_id, + table=self.table, + metrics_thresholds=self.metrics_thresholds, + date_filter_column=self.date_filter_column, + days_back=self.days_back, + ratio_formula=self.ratio_formula, + ignore_zero=self.ignore_zero, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) class BigQueryGetDataOperator(BaseOperator): @@ -395,6 +574,7 @@ class BigQueryGetDataOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). 
+ :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( @@ -419,6 +599,7 @@ def __init__( delegate_to: str | None = None, location: str | None = None, impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -432,39 +613,97 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain self.project_id = project_id + self.deferrable = deferrable - def execute(self, context: Context) -> list: - self.log.info( - 'Fetching Data from %s.%s max results: %s', self.dataset_id, self.table_id, self.max_results + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + get_query = self.generate_query() + configuration = {"query": {"query": get_query}} + """Submit a new job and get the job id for polling the status using Triggerer.""" + return hook.insert_job( + configuration=configuration, + location=self.location, + project_id=hook.project_id, + job_id=job_id, + nowait=True, ) + def generate_query(self) -> str: + """ + Generate a select query if selected fields are given or with * + for the given dataset and table id + """ + query = "select " + if self.selected_fields: + query += self.selected_fields + else: + query += "*" + query += f" from {self.dataset_id}.{self.table_id} limit {self.max_results}" + return query + + def execute(self, context: Context): hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) + self.hook = hook + + if not self.deferrable: + self.log.info( + 'Fetching Data from %s.%s max results: %s', self.dataset_id, self.table_id, self.max_results + ) + if not self.selected_fields: + schema: dict[str, list] = hook.get_schema( + dataset_id=self.dataset_id, + table_id=self.table_id, + ) + if "fields" in schema: + self.selected_fields = ','.join([field["name"] for field in schema["fields"]]) - if not self.selected_fields: - schema: dict[str, list] = hook.get_schema( + rows = hook.list_rows( dataset_id=self.dataset_id, table_id=self.table_id, + max_results=self.max_results, + selected_fields=self.selected_fields, + location=self.location, + project_id=self.project_id, ) - if "fields" in schema: - self.selected_fields = ','.join([field["name"] for field in schema["fields"]]) - rows = hook.list_rows( - dataset_id=self.dataset_id, - table_id=self.table_id, - max_results=self.max_results, - selected_fields=self.selected_fields, - location=self.location, - project_id=self.project_id, + self.log.info('Total extracted rows: %s', len(rows)) + + table_data = [row.values() for row in rows] + return table_data + + job = self._submit_job(hook, job_id="") + self.job_id = job.job_id + context["ti"].xcom_push(key="job_id", value=self.job_id) + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryGetDataTrigger( + conn_id=self.gcp_conn_id, + job_id=self.job_id, + dataset_id=self.dataset_id, + table_id=self.table_id, + project_id=hook.project_id, + ), + method_name="execute_complete", ) - self.log.info('Total extracted rows: %s', len(rows)) + def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. 
+ """ + if event["status"] == "error": + raise AirflowException(event["message"]) - table_data = [row.values() for row in rows] - return table_data + self.log.info("Total extracted rows: %s", len(event["records"])) + return event["records"] class BigQueryExecuteQueryOperator(BaseOperator): @@ -2099,6 +2338,7 @@ class BigQueryInsertJobOperator(BaseOperator): :param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called :param result_retry: How to retry the `result` call that retrieves rows :param result_timeout: The number of seconds to wait for `result` method before using `result_retry` + :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( @@ -2129,6 +2369,7 @@ def __init__( cancel_on_kill: bool = True, result_retry: Retry = DEFAULT_RETRY, result_timeout: float | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -2145,6 +2386,7 @@ def __init__( self.result_retry = result_retry self.result_timeout = result_timeout self.hook: BigQueryHook | None = None + self.deferrable = deferrable def prepare_template(self) -> None: # If .json is passed then we have to read the file @@ -2200,7 +2442,11 @@ def execute(self, context: Any): location=self.location, job_id=job_id, ) - if job.state not in self.reattach_states: + if job.state in self.reattach_states: + # We are reattaching to a job + job._begin() + self._handle_job_error(job) + else: # Same job configuration so we need force_rerun raise AirflowException( f"Job with id: {job_id} already exists and is in {job.state} state. If you " @@ -2235,10 +2481,36 @@ def execute(self, context: Any): BigQueryTableLink.persist(**persist_kwargs) self.job_id = job.job_id + context["ti"].xcom_push(key="job_id", value=self.job_id) # Wait for the job to complete - job.result(timeout=self.result_timeout, retry=self.result_retry) - self._handle_job_error(job) + if not self.deferrable: + job.result(timeout=self.result_timeout, retry=self.result_retry) + self._handle_job_error(job) + + return self.job_id + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryInsertJobTrigger( + conn_id=self.gcp_conn_id, + job_id=self.job_id, + project_id=self.project_id, + ), + method_name="execute_complete", + ) + def execute_complete(self, context: Context, event: dict[str, Any]): + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) return self.job_id def on_kill(self) -> None: @@ -2248,460 +2520,3 @@ def on_kill(self) -> None: ) else: self.log.info('Skipping to cancel job: %s:%s.%s', self.project_id, self.location, self.job_id) - - -class BigQueryInsertJobAsyncOperator(BigQueryInsertJobOperator, BaseOperator): - """ - Starts a BigQuery job asynchronously, and returns job id. - This operator works in the following way: - - - it calculates a unique hash of the job using job's configuration or uuid if ``force_rerun`` is True - - creates ``job_id`` in form of - ``[provided_job_id | airflow_{dag_id}_{task_id}_{exec_date}]_{uniqueness_suffix}`` - - submits a BigQuery job using the ``job_id`` - - if job with given id already exists then it tries to reattach to the job if its not done and its - state is in ``reattach_states``. 
If the job is done the operator will raise ``AirflowException``. - - Using ``force_rerun`` will submit a new job every time without attaching to already existing ones. - - For job definition see here: - - https://cloud.google.com/bigquery/docs/reference/v2/jobs - - :param configuration: The configuration parameter maps directly to BigQuery's - configuration field in the job object. For more details see - https://cloud.google.com/bigquery/docs/reference/v2/jobs - :param job_id: The ID of the job. It will be suffixed with hash of job configuration - unless ``force_rerun`` is True. - The ID must contain only letters (a-z, A-Z), numbers (0-9), underscores (_), or - dashes (-). The maximum length is 1,024 characters. If not provided then uuid will - be generated. - :param force_rerun: If True then operator will use hash of uuid as job id suffix - :param reattach_states: Set of BigQuery job's states in case of which we should reattach - to the job. Should be other than final states. - :param project_id: Google Cloud Project where the job is running - :param location: location the job is running - :param gcp_conn_id: The connection ID used to connect to Google Cloud. - :param delegate_to: The account to impersonate using domain-wide delegation of authority, - if any. For this to work, the service account making the request must have - domain-wide delegation enabled. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - :param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called - """ - - def _submit_job(self, hook: BigQueryHook, job_id: str) -> BigQueryJob: # type: ignore[override] - """Submit a new job and get the job id for polling the status using Triggerer.""" - return hook.insert_job( - configuration=self.configuration, - project_id=self.project_id, - location=self.location, - job_id=job_id, - nowait=True, - ) - - def execute(self, context: Any) -> None: - hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) - - self.hook = hook - job_id = self.hook.generate_job_id( - job_id=self.job_id, - dag_id=self.dag_id, - task_id=self.task_id, - logical_date=context["logical_date"], - configuration=self.configuration, - force_rerun=self.force_rerun, - ) - - try: - job = self._submit_job(hook, job_id) - self._handle_job_error(job) - except Conflict: - # If the job already exists retrieve it - job = hook.get_job( - project_id=self.project_id, - location=self.location, - job_id=job_id, - ) - if job.state in self.reattach_states: - # We are reattaching to a job - job._begin() - self._handle_job_error(job) - else: - # Same job configuration so we need force_rerun - raise AirflowException( - f"Job with id: {job_id} already exists and is in {job.state} state. If you " - f"want to force rerun it consider setting `force_rerun=True`." 
- f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" - ) - - self.job_id = job.job_id - context["ti"].xcom_push(key="job_id", value=self.job_id) - self.defer( - timeout=self.execution_timeout, - trigger=BigQueryInsertJobTrigger( - conn_id=self.gcp_conn_id, - job_id=self.job_id, - project_id=self.project_id, - ), - method_name="execute_complete", - ) - - def execute_complete(self, context: Any, event: dict[str, Any]) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] == "error": - raise AirflowException(event["message"]) - self.log.info( - "%s completed with response %s ", - self.task_id, - event["message"], - ) - - -class BigQueryCheckAsyncOperator(BigQueryCheckOperator): - """ - BigQueryCheckAsyncOperator is asynchronous operator, submit the job and check - for the status in async mode by using the job id - """ - - def _submit_job( - self, - hook: BigQueryHook, - job_id: str, - ) -> BigQueryJob: - """Submit a new job and get the job id for polling the status using Trigger.""" - configuration = {"query": {"query": self.sql}} - - return hook.insert_job( - configuration=configuration, - project_id=hook.project_id, - location=self.location, - job_id=job_id, - nowait=True, - ) - - def execute(self, context: Any) -> None: - hook = BigQueryHook( - gcp_conn_id=self.gcp_conn_id, - ) - job = self._submit_job(hook, job_id="") - context["ti"].xcom_push(key="job_id", value=job.job_id) - self.defer( - timeout=self.execution_timeout, - trigger=BigQueryCheckTrigger( - conn_id=self.gcp_conn_id, - job_id=job.job_id, - project_id=hook.project_id, - ), - method_name="execute_complete", - ) - - def execute_complete(self, context: Any, event: dict[str, Any]) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] == "error": - raise AirflowException(event["message"]) - - records = event["records"] - if not records: - raise AirflowException("The query returned None") - elif not all(bool(r) for r in records): - raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}") - self.log.info("Record: %s", event["records"]) - self.log.info("Success.") - - -class BigQueryGetDataAsyncOperator(BigQueryGetDataOperator): - """ - Fetches the data from a BigQuery table (alternatively fetch data for selected columns) - and returns data in a python list. The number of elements in the returned list will - be equal to the number of rows fetched. Each element in the list will again be a list - where element would represent the columns values for that row. - - **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]`` - - .. note:: - If you pass fields to ``selected_fields`` which are in different order than the - order of columns already in - BQ table, the data will still be in the order of BQ table. - For example if the BQ table has 3 columns as - ``[A,B,C]`` and you pass 'B,A' in the ``selected_fields`` - the data would still be of the form ``'A,B'``. - - **Example**: :: - - get_data = BigQueryGetDataOperator( - task_id='get_data_from_bq', - dataset_id='test_dataset', - table_id='Transaction_partitions', - max_results=100, - selected_fields='DATE', - gcp_conn_id='airflow-conn-id' - ) - - :param dataset_id: The dataset ID of the requested table. 
(templated) - :param table_id: The table ID of the requested table. (templated) - :param max_results: The maximum number of records (rows) to be fetched from the table. (templated) - :param selected_fields: List of fields to return (comma-separated). If - unspecified, all fields are returned. - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. - :param delegate_to: The account to impersonate using domain-wide delegation of authority, - if any. For this to work, the service account making the request must have - domain-wide delegation enabled. - :param location: The location used for the operation. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - def _submit_job( - self, - hook: BigQueryHook, - job_id: str, - configuration: dict[str, Any], - ) -> BigQueryJob: - """Submit a new job and get the job id for polling the status using Triggerer.""" - return hook.insert_job( - configuration=configuration, - location=self.location, - project_id=hook.project_id, - job_id=job_id, - nowait=True, - ) - - def generate_query(self) -> str: - """ - Generate a select query if selected fields are given or with * - for the given dataset and table id - """ - selected_fields = self.selected_fields if self.selected_fields else "*" - return f"select {selected_fields} from {self.dataset_id}.{self.table_id} limit {self.max_results}" - - def execute(self, context: Any) -> None: # type: ignore[override] - get_query = self.generate_query() - configuration = {"query": {"query": get_query}} - - hook = BigQueryHook( - gcp_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - self.hook = hook - job = self._submit_job(hook, job_id="", configuration=configuration) - self.job_id = job.job_id - context["ti"].xcom_push(key="job_id", value=self.job_id) - self.defer( - timeout=self.execution_timeout, - trigger=BigQueryGetDataTrigger( - conn_id=self.gcp_conn_id, - job_id=self.job_id, - dataset_id=self.dataset_id, - table_id=self.table_id, - project_id=hook.project_id, - ), - method_name="execute_complete", - ) - - def execute_complete(self, context: Any, event: dict[str, Any]) -> Any: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] == "error": - raise AirflowException(event["message"]) - - self.log.info("Total extracted rows: %s", len(event["records"])) - return event["records"] - - -class BigQueryIntervalCheckAsyncOperator(BigQueryIntervalCheckOperator): - """ - Checks asynchronously that the values of metrics given as SQL expressions are within - a certain tolerance of the ones from days_back before. 
- - This method constructs a query like so :: - SELECT {metrics_threshold_dict_key} FROM {table} - WHERE {date_filter_column}= - - :param table: the table name - :param days_back: number of days between ds and the ds we want to check - against. Defaults to 7 days - :param metrics_thresholds: a dictionary of ratios indexed by metrics, for - example 'COUNT(*)': 1.5 would require a 50 percent or less difference - between the current day, and the prior days_back. - :param use_legacy_sql: Whether to use legacy SQL (true) - or standard SQL (false). - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. - :param location: The geographic location of the job. See details at: - https://cloud.google.com/bigquery/docs/locations#specifying_your_location - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - :param labels: a dictionary containing labels for the table, passed to BigQuery - """ - - def _submit_job( - self, - hook: BigQueryHook, - sql: str, - job_id: str, - ) -> BigQueryJob: - """Submit a new job and get the job id for polling the status using Triggerer.""" - configuration = {"query": {"query": sql}} - return hook.insert_job( - configuration=configuration, - project_id=hook.project_id, - location=self.location, - job_id=job_id, - nowait=True, - ) - - def execute(self, context: Any) -> None: - """Execute the job in sync mode and defers the trigger with job id to poll for the status""" - hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) - self.log.info("Using ratio formula: %s", self.ratio_formula) - - self.log.info("Executing SQL check: %s", self.sql1) - job_1 = self._submit_job(hook, sql=self.sql1, job_id="") - context["ti"].xcom_push(key="job_id", value=job_1.job_id) - - self.log.info("Executing SQL check: %s", self.sql2) - job_2 = self._submit_job(hook, sql=self.sql2, job_id="") - self.defer( - timeout=self.execution_timeout, - trigger=BigQueryIntervalCheckTrigger( - conn_id=self.gcp_conn_id, - first_job_id=job_1.job_id, - second_job_id=job_2.job_id, - project_id=hook.project_id, - table=self.table, - metrics_thresholds=self.metrics_thresholds, - date_filter_column=self.date_filter_column, - days_back=self.days_back, - ratio_formula=self.ratio_formula, - ignore_zero=self.ignore_zero, - ), - method_name="execute_complete", - ) - - def execute_complete(self, context: Any, event: dict[str, Any]) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] == "error": - raise AirflowException(event["message"]) - - self.log.info( - "%s completed with response %s ", - self.task_id, - event["status"], - ) - - -class BigQueryValueCheckAsyncOperator(BigQueryValueCheckOperator): - """ - Performs a simple value check using sql code. - - .. 
seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:BigQueryValueCheckOperator` - - :param sql: the sql to be executed - :param use_legacy_sql: Whether to use legacy SQL (true) - or standard SQL (false). - :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. - :param location: The geographic location of the job. See details at: - https://cloud.google.com/bigquery/docs/locations#specifying_your_location - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - :param labels: a dictionary containing labels for the table, passed to BigQuery - """ - - def _submit_job( - self, - hook: BigQueryHook, - job_id: str, - ) -> BigQueryJob: - """Submit a new job and get the job id for polling the status using Triggerer.""" - configuration = { - "query": { - "query": self.sql, - "useLegacySql": False, - } - } - if self.use_legacy_sql: - configuration["query"]["useLegacySql"] = self.use_legacy_sql - - return hook.insert_job( - configuration=configuration, - project_id=hook.project_id, - location=self.location, - job_id=job_id, - nowait=True, - ) - - def execute(self, context: Any) -> None: - hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) - - job = self._submit_job(hook, job_id="") - context["ti"].xcom_push(key="job_id", value=job.job_id) - self.defer( - timeout=self.execution_timeout, - trigger=BigQueryValueCheckTrigger( - conn_id=self.gcp_conn_id, - job_id=job.job_id, - project_id=hook.project_id, - sql=self.sql, - pass_value=self.pass_value, - tolerance=self.tol, - ), - method_name="execute_complete", - ) - - def execute_complete(self, context: Any, event: dict[str, Any]) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] == "error": - raise AirflowException(event["message"]) - self.log.info( - "%s completed with response %s ", - self.task_id, - event["message"], - ) diff --git a/airflow/providers/google/cloud/sensors/bigquery.py b/airflow/providers/google/cloud/sensors/bigquery.py index 598b556d7e027..e18145f08c622 100644 --- a/airflow/providers/google/cloud/sensors/bigquery.py +++ b/airflow/providers/google/cloud/sensors/bigquery.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module contains a Google Bigquery sensor.""" +"""This module contains Google BigQuery sensors.""" from __future__ import annotations from datetime import timedelta diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 626bc7e7f4318..66f31ca6850dc 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -64,8 +64,8 @@ dependencies: # Introduced breaking changes across the board. 
Those libraries should be upgraded soon # TODO: Upgrade all Google libraries that are limited to <2.0.0 - PyOpenSSL - - asgiref - - gcloud-aio-bigquery + - asgiref>=3.5.2 + - gcloud-aio-bigquery>=6.1.2 - gcloud-aio-storage - google-ads>=15.1.1 - google-api-core>=2.7.0,<3.0.0 diff --git a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst index 6672d2ce58573..548c37ace21f1 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst @@ -1,3 +1,4 @@ + .. Licensed to the Apache Software Foundation (ASF) under one .. Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information @@ -348,6 +349,14 @@ idempotency. If this parameter is not passed then uuid will be used as ``job_id` operator will try to submit a new job with this ``job_id```. If there's already a job with such ``job_id`` then it will reattach to the existing job. +Also for all this action you can use operator in the deferrable mode: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bigquery_insert_job_async] + :end-before: [END howto_operator_bigquery_insert_job_async] + Validate data ^^^^^^^^^^^^^ @@ -370,8 +379,7 @@ return ``False`` the check is failed and errors out. :start-after: [START howto_operator_bigquery_check] :end-before: [END howto_operator_bigquery_check] -Below example shows the usage of :class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCheckAsyncOperator`, -which is the deferrable version of the operator +Also you can use deferrable mode in this operator .. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py :language: python @@ -398,10 +406,7 @@ or numeric value. If numeric, you can also specify ``tolerance``. :start-after: [START howto_operator_bigquery_value_check] :end-before: [END howto_operator_bigquery_value_check] -The below example shows how to use -:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckAsyncOperator`. -Note that this is a deferrable operator which requires the Triggerer to be running on your Airflow -deployment. +Also you can use deferrable mode in this operator .. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py :language: python @@ -425,10 +430,7 @@ tolerance of the ones from ``days_back`` before you can either use :start-after: [START howto_operator_bigquery_interval_check] :end-before: [END howto_operator_bigquery_interval_check] -The below example shows how to use -:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckAsyncOperator`. -Note that this is a deferrable operator which requires the Triggerer to be running on your Airflow -deployment. +Also you can use deferrable mode in this operator .. 
exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py :language: python diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index bc266ec27a71a..bfbb0a463dcd1 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -301,8 +301,8 @@ "PyOpenSSL", "apache-airflow-providers-common-sql>=1.1.0", "apache-airflow>=2.2.0", - "asgiref", - "gcloud-aio-bigquery", + "asgiref>=3.5.2", + "gcloud-aio-bigquery>=6.1.2", "gcloud-aio-storage", "google-ads>=15.1.1", "google-api-core>=2.7.0,<3.0.0", diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index 25888ee96162d..ae642c3583bc6 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -55,13 +55,13 @@ DATASET_ID = "bq_dataset" TABLE_ID = "bq_table" PARTITION_ID = "20200101" -VIEW_ID = 'bq_view' +VIEW_ID = "bq_view" JOB_ID = "1234" -LOCATION = 'europe-north1' +LOCATION = "europe-north1" TABLE_REFERENCE_REPR = { - 'tableId': TABLE_ID, - 'datasetId': DATASET_ID, - 'projectId': PROJECT_ID, + "tableId": TABLE_ID, + "datasetId": DATASET_ID, + "projectId": PROJECT_ID, } TABLE_REFERENCE = TableReference.from_api_repr(TABLE_REFERENCE_REPR) @@ -890,7 +890,7 @@ def test_run_query_with_arg(self, mock_insert): _, kwargs = mock_insert.call_args assert kwargs["configuration"]['labels'] == {'label1': 'test1', 'label2': 'test2'} - @pytest.mark.parametrize('nowait', [True, False]) + @pytest.mark.parametrize("nowait", [True, False]) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.QueryJob") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") def test_insert_job(self, mock_client, mock_query_job, nowait): @@ -913,8 +913,8 @@ def test_insert_job(self, mock_client, mock_query_job, nowait): mock_query_job.from_api_repr.assert_called_once_with( { - 'configuration': job_conf, - 'jobReference': {'jobId': JOB_ID, 'projectId': PROJECT_ID, 'location': LOCATION}, + "configuration": job_conf, + "jobReference": {"jobId": JOB_ID, "projectId": PROJECT_ID, "location": LOCATION}, }, mock_client.return_value, ) diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 57c79156193b8..26159620cd9aa 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -30,7 +30,6 @@ from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCheckAsyncOperator, BigQueryCheckOperator, BigQueryConsoleIndexableLink, BigQueryConsoleLink, @@ -40,20 +39,16 @@ BigQueryDeleteDatasetOperator, BigQueryDeleteTableOperator, BigQueryExecuteQueryOperator, - BigQueryGetDataAsyncOperator, BigQueryGetDataOperator, BigQueryGetDatasetOperator, BigQueryGetDatasetTablesOperator, - BigQueryInsertJobAsyncOperator, BigQueryInsertJobOperator, - BigQueryIntervalCheckAsyncOperator, BigQueryIntervalCheckOperator, BigQueryPatchDatasetOperator, BigQueryUpdateDatasetOperator, BigQueryUpdateTableOperator, BigQueryUpdateTableSchemaOperator, BigQueryUpsertTableOperator, - BigQueryValueCheckAsyncOperator, BigQueryValueCheckOperator, ) from airflow.providers.google.cloud.triggers.bigquery import ( @@ -69,9 +64,9 @@ from tests.test_utils.db import clear_db_dags, clear_db_runs, 
clear_db_serialized_dags, clear_db_xcom TASK_ID = 'test-bq-generic-operator' -TEST_DATASET = 'test-dataset' -TEST_DATASET_LOCATION = 'EU' -TEST_GCP_PROJECT_ID = 'test-project' +TEST_DATASET = "test-dataset" +TEST_DATASET_LOCATION = "EU" +TEST_GCP_PROJECT_ID = "test-project" TEST_DELETE_CONTENTS = True TEST_TABLE_ID = 'test-table-id' TEST_GCS_BUCKET = 'test-bucket' @@ -1142,12 +1137,13 @@ def test_bigquery_insert_job_operator_async(mock_hook): } mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) - op = BigQueryInsertJobAsyncOperator( + op = BigQueryInsertJobOperator( task_id="insert_query_job", configuration=configuration, location=TEST_DATASET_LOCATION, job_id=job_id, project_id=TEST_GCP_PROJECT_ID, + deferrable=True, ) with pytest.raises(TaskDeferred) as exc: @@ -1168,12 +1164,13 @@ def test_bigquery_insert_job_operator_execute_failure(): } job_id = "123456" - operator = BigQueryInsertJobAsyncOperator( + operator = BigQueryInsertJobOperator( task_id="insert_query_job", configuration=configuration, location=TEST_DATASET_LOCATION, job_id=job_id, project_id=TEST_GCP_PROJECT_ID, + deferrable=True, ) with pytest.raises(AirflowException): @@ -1212,12 +1209,13 @@ def test_bigquery_insert_job_operator_execute_complete(): } job_id = "123456" - operator = BigQueryInsertJobAsyncOperator( + operator = BigQueryInsertJobOperator( task_id="insert_query_job", configuration=configuration, location=TEST_DATASET_LOCATION, job_id=job_id, project_id=TEST_GCP_PROJECT_ID, + deferrable=True, ) with mock.patch.object(operator.log, "info") as mock_log_info: operator.execute_complete( @@ -1249,13 +1247,14 @@ def test_bigquery_insert_job_operator_with_job_id_generate(mock_hook): ) mock_hook.return_value.get_job.return_value = job - op = BigQueryInsertJobAsyncOperator( + op = BigQueryInsertJobOperator( task_id="insert_query_job", configuration=configuration, location=TEST_DATASET_LOCATION, job_id=job_id, project_id=TEST_GCP_PROJECT_ID, reattach_states={"PENDING"}, + deferrable=True, ) with pytest.raises(TaskDeferred): @@ -1294,13 +1293,14 @@ def test_execute_reattach(mock_hook): ) mock_hook.return_value.get_job.return_value = job - op = BigQueryInsertJobAsyncOperator( + op = BigQueryInsertJobOperator( task_id="insert_query_job", configuration=configuration, location=TEST_DATASET_LOCATION, job_id=job_id, project_id=TEST_GCP_PROJECT_ID, reattach_states={"PENDING"}, + deferrable=True, ) with pytest.raises(TaskDeferred): @@ -1316,7 +1316,7 @@ def test_execute_reattach(mock_hook): @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") -def test_execute_force_rerun(mock_hook): +def test_execute_force_rerun_async(mock_hook): job_id = "123456" hash_ = "hash" real_job_id = f"{job_id}_{hash_}" @@ -1338,13 +1338,14 @@ def test_execute_force_rerun(mock_hook): ) mock_hook.return_value.get_job.return_value = job - op = BigQueryInsertJobAsyncOperator( + op = BigQueryInsertJobOperator( task_id="insert_query_job", configuration=configuration, location=TEST_DATASET_LOCATION, job_id=job_id, project_id=TEST_GCP_PROJECT_ID, reattach_states={"PENDING"}, + deferrable=True, ) with pytest.raises(AirflowException) as exc: @@ -1377,10 +1378,11 @@ def test_bigquery_check_operator_async(mock_hook): mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) - op = BigQueryCheckAsyncOperator( + op = BigQueryCheckOperator( task_id="bq_check_operator_job", sql="SELECT * FROM any", location=TEST_DATASET_LOCATION, + deferrable=True, ) with 
pytest.raises(TaskDeferred) as exc: @@ -1392,8 +1394,11 @@ def test_bigquery_check_operator_async(mock_hook): def test_bigquery_check_operator_execute_failure(): """Tests that an AirflowException is raised in case of error event""" - operator = BigQueryCheckAsyncOperator( - task_id="bq_check_operator_execute_failure", sql="SELECT * FROM any", location=TEST_DATASET_LOCATION + operator = BigQueryCheckOperator( + task_id="bq_check_operator_execute_failure", + sql="SELECT * FROM any", + location=TEST_DATASET_LOCATION, + deferrable=True, ) with pytest.raises(AirflowException): @@ -1403,14 +1408,17 @@ def test_bigquery_check_operator_execute_failure(): def test_bigquery_check_op_execute_complete_with_no_records(): """Asserts that exception is raised with correct expected exception message""" - operator = BigQueryCheckAsyncOperator( - task_id="bq_check_operator_execute_complete", sql="SELECT * FROM any", location=TEST_DATASET_LOCATION + operator = BigQueryCheckOperator( + task_id="bq_check_operator_execute_complete", + sql="SELECT * FROM any", + location=TEST_DATASET_LOCATION, + deferrable=True, ) with pytest.raises(AirflowException) as exc: operator.execute_complete(context=None, event={"status": "success", "records": None}) - expected_exception_msg = "The query returned None" + expected_exception_msg = "The query returned empty results" assert str(exc.value) == expected_exception_msg @@ -1420,8 +1428,11 @@ def test_bigquery_check_op_execute_complete_with_non_boolean_records(): test_sql = "SELECT * FROM any" - operator = BigQueryCheckAsyncOperator( - task_id="bq_check_operator_execute_complete", sql=test_sql, location=TEST_DATASET_LOCATION + operator = BigQueryCheckOperator( + task_id="bq_check_operator_execute_complete", + sql=test_sql, + location=TEST_DATASET_LOCATION, + deferrable=True, ) expected_exception_msg = f"Test failed.\nQuery:\n{test_sql}\nResults:\n{[20, False]!s}" @@ -1435,8 +1446,11 @@ def test_bigquery_check_op_execute_complete_with_non_boolean_records(): def test_bigquery_check_operator_execute_complete(): """Asserts that logging occurs as expected""" - operator = BigQueryCheckAsyncOperator( - task_id="bq_check_operator_execute_complete", sql="SELECT * FROM any", location=TEST_DATASET_LOCATION + operator = BigQueryCheckOperator( + task_id="bq_check_operator_execute_complete", + sql="SELECT * FROM any", + location=TEST_DATASET_LOCATION, + deferrable=True, ) with mock.patch.object(operator.log, "info") as mock_log_info: @@ -1447,28 +1461,30 @@ def test_bigquery_check_operator_execute_complete(): def test_bigquery_interval_check_operator_execute_complete(): """Asserts that logging occurs as expected""" - operator = BigQueryIntervalCheckAsyncOperator( + operator = BigQueryIntervalCheckOperator( task_id="bq_interval_check_operator_execute_complete", table="test_table", metrics_thresholds={"COUNT(*)": 1.5}, location=TEST_DATASET_LOCATION, + deferrable=True, ) with mock.patch.object(operator.log, "info") as mock_log_info: operator.execute_complete(context=None, event={"status": "success", "message": "Job completed"}) mock_log_info.assert_called_with( - "%s completed with response %s ", "bq_interval_check_operator_execute_complete", "success" + "%s completed with response %s ", "bq_interval_check_operator_execute_complete", "Job completed" ) def test_bigquery_interval_check_operator_execute_failure(): """Tests that an AirflowException is raised in case of error event""" - operator = BigQueryIntervalCheckAsyncOperator( + operator = BigQueryIntervalCheckOperator( 
task_id="bq_interval_check_operator_execute_complete", table="test_table", metrics_thresholds={"COUNT(*)": 1.5}, location=TEST_DATASET_LOCATION, + deferrable=True, ) with pytest.raises(AirflowException): @@ -1487,11 +1503,12 @@ def test_bigquery_interval_check_operator_async(mock_hook): mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) - op = BigQueryIntervalCheckAsyncOperator( + op = BigQueryIntervalCheckOperator( task_id="bq_interval_check_operator_execute_complete", table="test_table", metrics_thresholds={"COUNT(*)": 1.5}, location=TEST_DATASET_LOCATION, + deferrable=True, ) with pytest.raises(TaskDeferred) as exc: @@ -1514,12 +1531,13 @@ def test_bigquery_get_data_operator_async_with_selected_fields(mock_hook): mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) - op = BigQueryGetDataAsyncOperator( + op = BigQueryGetDataOperator( task_id="get_data_from_bq", dataset_id=TEST_DATASET, - table_id=TEST_TABLE, + table_id=TEST_TABLE_ID, max_results=100, selected_fields="value,name", + deferrable=True, ) with pytest.raises(TaskDeferred) as exc: @@ -1540,11 +1558,12 @@ def test_bigquery_get_data_operator_async_without_selected_fields(mock_hook): mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) - op = BigQueryGetDataAsyncOperator( + op = BigQueryGetDataOperator( task_id="get_data_from_bq", dataset_id=TEST_DATASET, - table_id=TEST_TABLE, + table_id=TEST_TABLE_ID, max_results=100, + deferrable=True, ) with pytest.raises(TaskDeferred) as exc: @@ -1556,11 +1575,12 @@ def test_bigquery_get_data_operator_async_without_selected_fields(mock_hook): def test_bigquery_get_data_operator_execute_failure(): """Tests that an AirflowException is raised in case of error event""" - operator = BigQueryGetDataAsyncOperator( + operator = BigQueryGetDataOperator( task_id="get_data_from_bq", dataset_id=TEST_DATASET, table_id="any", max_results=100, + deferrable=True, ) with pytest.raises(AirflowException): @@ -1570,11 +1590,12 @@ def test_bigquery_get_data_operator_execute_failure(): def test_bigquery_get_data_op_execute_complete_with_records(): """Asserts that exception is raised with correct expected exception message""" - operator = BigQueryGetDataAsyncOperator( + operator = BigQueryGetDataOperator( task_id="get_data_from_bq", dataset_id=TEST_DATASET, table_id="any", max_results=100, + deferrable=True, ) with mock.patch.object(operator.log, "info") as mock_log_info: @@ -1583,15 +1604,15 @@ def test_bigquery_get_data_op_execute_complete_with_records(): def _get_value_check_async_operator(use_legacy_sql: bool = False): - """Helper function to initialise BigQueryValueCheckOperatorAsync operator""" query = "SELECT COUNT(*) FROM Any" pass_val = 2 - return BigQueryValueCheckAsyncOperator( + return BigQueryValueCheckOperator( task_id="check_value", sql=query, pass_value=pass_val, use_legacy_sql=use_legacy_sql, + deferrable=True, ) @@ -1642,7 +1663,7 @@ def test_bigquery_value_check_operator_execute_complete_failure(): def test_bigquery_value_check_missing_param(kwargs, expected): """Assert the exception if require param not pass to BigQueryValueCheckOperatorAsync operator""" with pytest.raises(AirflowException) as missing_param: - BigQueryValueCheckAsyncOperator(**kwargs) + BigQueryValueCheckOperator(deferrable=True, **kwargs) assert missing_param.value.args[0] == expected @@ -1653,5 +1674,5 @@ def test_bigquery_value_check_empty(): "missing keyword arguments 'pass_value', 
'sql'", ) with pytest.raises(AirflowException) as missing_param: - BigQueryValueCheckAsyncOperator(kwargs={}) + BigQueryValueCheckOperator(deferrable=True, kwargs={}) assert (missing_param.value.args[0] == expected) or (missing_param.value.args[0] == expected1) diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py index cfe41d0eb74b3..ca9d13871f02a 100644 --- a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py +++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py @@ -25,25 +25,25 @@ from airflow import DAG from airflow.operators.bash import BashOperator -from airflow.operators.empty import EmptyOperator from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCheckAsyncOperator, + BigQueryCheckOperator, BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, - BigQueryGetDataAsyncOperator, - BigQueryInsertJobAsyncOperator, - BigQueryIntervalCheckAsyncOperator, - BigQueryValueCheckAsyncOperator, + BigQueryGetDataOperator, + BigQueryInsertJobOperator, + BigQueryIntervalCheckOperator, + BigQueryValueCheckOperator, ) from airflow.utils.trigger_rule import TriggerRule ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") -PROJECT_ID = os.getenv("SYSTEM_TESTS_GCP_PROJECT") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") + DAG_ID = "bigquery_queries_async" + DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}" LOCATION = "us" -EXECUTION_TIMEOUT = 6 TABLE_1 = "table1" TABLE_2 = "table2" @@ -62,15 +62,44 @@ f"(42, 'fishy fish', '{INSERT_DATE}');" ) + +CONFIGURATION = { + "query": { + "query": f"""DECLARE success BOOL; + DECLARE size_bytes INT64; + DECLARE row_count INT64; + DECLARE DELAY_TIME DATETIME; + DECLARE WAIT STRING; + SET success = FALSE; + + SELECT row_count = (SELECT row_count FROM {DATASET}.__TABLES__ WHERE table_id='NON_EXISTING_TABLE'); + IF row_count > 0 THEN + SELECT 'Table Exists!' 
as message, retry_count as retries; + SET success = TRUE; + ELSE + SELECT 'Table does not exist' as message, row_count; + SET WAIT = 'TRUE'; + SET DELAY_TIME = DATETIME_ADD(CURRENT_DATETIME,INTERVAL 1 MINUTE); + WHILE WAIT = 'TRUE' DO + IF (DELAY_TIME < CURRENT_DATETIME) THEN + SET WAIT = 'FALSE'; + END IF; + END WHILE; + END IF;""", + "useLegacySql": False, + } +} + + default_args = { - "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT), - "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)), - "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))), + "execution_timeout": timedelta(hours=6), + "retries": 2, + "retry_delay": timedelta(seconds=60), } with DAG( - dag_id="example_async_bigquery_queries_async", - schedule=None, + dag_id=DAG_ID, + schedule='@once', start_date=datetime(2022, 1, 1), catchup=False, default_args=default_args, @@ -91,14 +120,15 @@ location=LOCATION, ) - create_dataset >> create_table_1 - delete_dataset = BigQueryDeleteDatasetOperator( - task_id="delete_dataset", dataset_id=DATASET, delete_contents=True, trigger_rule=TriggerRule.ALL_DONE + task_id="delete_dataset", + dataset_id=DATASET, + delete_contents=True, + trigger_rule=TriggerRule.ALL_DONE, ) # [START howto_operator_bigquery_insert_job_async] - insert_query_job = BigQueryInsertJobAsyncOperator( + insert_query_job = BigQueryInsertJobOperator( task_id="insert_query_job", configuration={ "query": { @@ -107,11 +137,12 @@ } }, location=LOCATION, + deferrable=True, ) # [END howto_operator_bigquery_insert_job_async] # [START howto_operator_bigquery_select_job_async] - select_query_job = BigQueryInsertJobAsyncOperator( + select_query_job = BigQueryInsertJobOperator( task_id="select_query_job", configuration={ "query": { @@ -120,32 +151,35 @@ } }, location=LOCATION, + deferrable=True, ) # [END howto_operator_bigquery_select_job_async] # [START howto_operator_bigquery_value_check_async] - check_value = BigQueryValueCheckAsyncOperator( + check_value = BigQueryValueCheckOperator( task_id="check_value", sql=f"SELECT COUNT(*) FROM {DATASET}.{TABLE_1}", pass_value=2, use_legacy_sql=False, location=LOCATION, + deferrable=True, ) # [END howto_operator_bigquery_value_check_async] # [START howto_operator_bigquery_interval_check_async] - check_interval = BigQueryIntervalCheckAsyncOperator( + check_interval = BigQueryIntervalCheckOperator( task_id="check_interval", table=f"{DATASET}.{TABLE_1}", days_back=1, metrics_thresholds={"COUNT(*)": 1.5}, use_legacy_sql=False, location=LOCATION, + deferrable=True, ) # [END howto_operator_bigquery_interval_check_async] # [START howto_operator_bigquery_multi_query_async] - bigquery_execute_multi_query = BigQueryInsertJobAsyncOperator( + bigquery_execute_multi_query = BigQueryInsertJobOperator( task_id="execute_multi_query", configuration={ "query": { @@ -157,17 +191,19 @@ } }, location=LOCATION, + deferrable=True, ) # [END howto_operator_bigquery_multi_query_async] # [START howto_operator_bigquery_get_data_async] - get_data = BigQueryGetDataAsyncOperator( + get_data = BigQueryGetDataOperator( task_id="get_data", dataset_id=DATASET, table_id=TABLE_1, max_results=10, selected_fields="value,name", location=LOCATION, + deferrable=True, ) # [END howto_operator_bigquery_get_data_async] @@ -178,16 +214,17 @@ ) # [START howto_operator_bigquery_check_async] - check_count = BigQueryCheckAsyncOperator( + check_count = BigQueryCheckOperator( task_id="check_count", sql=f"SELECT COUNT(*) FROM {DATASET}.{TABLE_1}", use_legacy_sql=False, location=LOCATION, + deferrable=True, ) 
# [END howto_operator_bigquery_check_async] # [START howto_operator_bigquery_execute_query_save_async] - execute_query_save = BigQueryInsertJobAsyncOperator( + execute_query_save = BigQueryInsertJobOperator( task_id="execute_query_save", configuration={ "query": { @@ -201,48 +238,31 @@ } }, location=LOCATION, + deferrable=True, ) # [END howto_operator_bigquery_execute_query_save_async] - execute_long_running_query = BigQueryInsertJobAsyncOperator( + execute_long_running_query = BigQueryInsertJobOperator( task_id="execute_long_running_query", - configuration={ - "query": { - "query": f"""DECLARE success BOOL; - DECLARE size_bytes INT64; - DECLARE row_count INT64; - DECLARE DELAY_TIME DATETIME; - DECLARE WAIT STRING; - SET success = FALSE; - - SELECT row_count = (SELECT row_count FROM {DATASET}.__TABLES__ WHERE table_id='NON_EXISTING_TABLE'); - IF row_count > 0 THEN - SELECT 'Table Exists!' as message, retry_count as retries; - SET success = TRUE; - ELSE - SELECT 'Table does not exist' as message, row_count; - SET WAIT = 'TRUE'; - SET DELAY_TIME = DATETIME_ADD(CURRENT_DATETIME,INTERVAL 1 MINUTE); - WHILE WAIT = 'TRUE' DO - IF (DELAY_TIME < CURRENT_DATETIME) THEN - SET WAIT = 'FALSE'; - END IF; - END WHILE; - END IF;""", - "useLegacySql": False, - } - }, + configuration=CONFIGURATION, location=LOCATION, + deferrable=True, ) - end = EmptyOperator(task_id="end") - - create_table_1 >> insert_query_job >> select_query_job >> check_count + create_dataset >> create_table_1 >> insert_query_job + insert_query_job >> select_query_job >> check_count insert_query_job >> get_data >> get_data_result insert_query_job >> execute_query_save >> bigquery_execute_multi_query insert_query_job >> execute_long_running_query >> check_value >> check_interval [check_count, check_interval, bigquery_execute_multi_query, get_data_result] >> delete_dataset - [check_count, check_interval, bigquery_execute_multi_query, get_data_result, delete_dataset] >> end + + # ### Everything below this line is not part of example ### + # ### Just for system tests purpose ### + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() from tests.system.utils import get_test_run # noqa: E402
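
A minimal usage sketch of the deferrable mode introduced in this change (the DAG id, dataset, and table names below are placeholders, the default Google Cloud connection is assumed, and a triggerer must be running for the deferred phase):

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.bigquery import (
        BigQueryCheckOperator,
        BigQueryInsertJobOperator,
    )

    with DAG(
        dag_id="example_bigquery_deferrable_usage",  # placeholder DAG id
        schedule=None,
        start_date=datetime(2022, 1, 1),
        catchup=False,
    ) as dag:
        # Submits the query job, pushes job_id to XCom, then defers to
        # BigQueryInsertJobTrigger until the job finishes on the triggerer.
        insert_rows = BigQueryInsertJobOperator(
            task_id="insert_rows",
            configuration={
                "query": {
                    # placeholder dataset/table
                    "query": "INSERT my_dataset.my_table (value, name) VALUES (42, 'fishy fish')",
                    "useLegacySql": False,
                }
            },
            location="US",
            deferrable=True,
        )

        # Runs the check through BigQueryCheckTrigger instead of blocking a worker slot;
        # execute_complete() fails the task if any returned record is falsy.
        check_rows = BigQueryCheckOperator(
            task_id="check_rows",
            sql="SELECT COUNT(*) FROM my_dataset.my_table",  # placeholder dataset/table
            use_legacy_sql=False,
            location="US",
            deferrable=True,
        )

        insert_rows >> check_rows

With deferrable=True each operator submits the job with nowait=True, hands control to the corresponding trigger, and resumes in execute_complete, so no worker slot is held while BigQuery runs the query.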