diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 3ec149b577a53..d9852218da5e3 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -29,8 +29,10 @@ import warnings from copy import deepcopy from datetime import datetime, timedelta -from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union, cast +from aiohttp import ClientSession as ClientSession +from gcloud.aio.bigquery import Job, Table as Table_async from google.api_core.retry import Retry from google.cloud.bigquery import ( DEFAULT_RETRY, @@ -49,12 +51,13 @@ from pandas import DataFrame from pandas_gbq import read_gbq from pandas_gbq.gbq import GbqConnector # noqa +from requests import Session from sqlalchemy import create_engine from airflow.exceptions import AirflowException from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.google.common.consts import CLIENT_INFO -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook from airflow.utils.helpers import convert_camel_to_snake from airflow.utils.log.logging_mixin import LoggingMixin @@ -2305,7 +2308,6 @@ def __init__( num_retries: int = 5, labels: Optional[Dict] = None, ) -> None: - super().__init__() self.service = service self.project_id = project_id @@ -2870,7 +2872,6 @@ def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, s def split_tablename( table_input: str, default_project_id: str, var_name: Optional[str] = None ) -> Tuple[str, str, str]: - if '.' not in table_input: raise ValueError(f'Expected table name in the format of .. Got: {table_input}') @@ -3010,3 +3011,253 @@ def _format_schema_for_description(schema: Dict) -> List: ) description.append(field_description) return description + + +class BigQueryAsyncHook(GoogleBaseAsyncHook): + """Uses gcloud-aio library to retrieve Job details""" + + sync_hook_class = BigQueryHook + + async def get_job_instance( + self, project_id: Optional[str], job_id: Optional[str], session: ClientSession + ) -> Job: + """Get the specified job resource by job ID and project ID.""" + with await self.service_file_as_context() as f: + return Job(job_id=job_id, project=project_id, service_file=f, session=cast(Session, session)) + + async def get_job_status( + self, + job_id: Optional[str], + project_id: Optional[str] = None, + ) -> Optional[str]: + """ + Polls for job status asynchronously using gcloud-aio. + + Note that an OSError is raised when Job results are still pending. 
+ Exception means that Job finished with errors + """ + async with ClientSession() as s: + try: + self.log.info("Executing get_job_status...") + job_client = await self.get_job_instance(project_id, job_id, s) + job_status_response = await job_client.result(cast(Session, s)) + if job_status_response: + job_status = "success" + except OSError: + job_status = "pending" + except Exception as e: + self.log.info("Query execution finished with errors...") + job_status = str(e) + return job_status + + async def get_job_output( + self, + job_id: Optional[str], + project_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Get the big query job output for the given job id asynchronously using gcloud-aio.""" + async with ClientSession() as session: + self.log.info("Executing get_job_output..") + job_client = await self.get_job_instance(project_id, job_id, session) + job_query_response = await job_client.get_query_results(cast(Session, session)) + return job_query_response + + def get_records(self, query_results: Dict[str, Any]) -> List[Any]: + """ + Given the output query response from gcloud-aio bigquery, convert the response to records. + + :param query_results: the results from a SQL query + """ + buffer = [] + if "rows" in query_results and query_results["rows"]: + rows = query_results["rows"] + for dict_row in rows: + typed_row = [vs["v"] for vs in dict_row["f"]] + buffer.append(typed_row) + return buffer + + def value_check( + self, + sql: str, + pass_value: Any, + records: List[Any], + tolerance: Optional[float] = None, + ) -> None: + """ + Match a single query resulting row and tolerance with pass_value + + :return: If Match fail, we throw an AirflowException. + """ + if not records: + raise AirflowException("The query returned None") + pass_value_conv = self._convert_to_float_if_possible(pass_value) + is_numeric_value_check = isinstance(pass_value_conv, float) + tolerance_pct_str = str(tolerance * 100) + "%" if tolerance else None + + error_msg = ( + "Test failed.\nPass value:{pass_value_conv}\n" + "Tolerance:{tolerance_pct_str}\n" + "Query:\n{sql}\nResults:\n{records!s}" + ).format( + pass_value_conv=pass_value_conv, + tolerance_pct_str=tolerance_pct_str, + sql=sql, + records=records, + ) + + if not is_numeric_value_check: + tests = [str(record) == pass_value_conv for record in records] + else: + try: + numeric_records = [float(record) for record in records] + except (ValueError, TypeError): + raise AirflowException(f"Converting a result to float failed.\n{error_msg}") + tests = self._get_numeric_matches(numeric_records, pass_value_conv, tolerance) + + if not all(tests): + raise AirflowException(error_msg) + + @staticmethod + def _get_numeric_matches( + records: List[float], pass_value: Any, tolerance: Optional[float] = None + ) -> List[bool]: + """ + A helper function to match numeric pass_value, tolerance with records value + + :param records: List of value to match against + :param pass_value: Expected value + :param tolerance: Allowed tolerance for match to succeed + """ + if tolerance: + return [ + pass_value * (1 - tolerance) <= record <= pass_value * (1 + tolerance) for record in records + ] + + return [record == pass_value for record in records] + + @staticmethod + def _convert_to_float_if_possible(s: Any) -> Any: + """ + A small helper function to convert a string to a numeric value if appropriate + + :param s: the string to be converted + """ + try: + return float(s) + except (ValueError, TypeError): + return s + + def interval_check( + self, + row1: Optional[str], + row2: 
Optional[str], + metrics_thresholds: Dict[str, Any], + ignore_zero: bool, + ratio_formula: str, + ) -> None: + """ + Checks that the values of metrics given as SQL expressions are within a certain tolerance + + :param row1: first resulting row of a query execution job for first SQL query + :param row2: first resulting row of a query execution job for second SQL query + :param metrics_thresholds: a dictionary of ratios indexed by metrics, for + example 'COUNT(*)': 1.5 would require a 50 percent or less difference + between the current day, and the prior days_back. + :param ignore_zero: whether we should ignore zero metrics + :param ratio_formula: which formula to use to compute the ratio between + the two metrics. Assuming cur is the metric of today and ref is + the metric to today - days_back. + max_over_min: computes max(cur, ref) / min(cur, ref) + relative_diff: computes abs(cur-ref) / ref + """ + if not row2: + raise AirflowException("The second SQL query returned None") + if not row1: + raise AirflowException("The first SQL query returned None") + + ratio_formulas = { + "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref), + "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref, + } + + metrics_sorted = sorted(metrics_thresholds.keys()) + + current = dict(zip(metrics_sorted, row1)) + reference = dict(zip(metrics_sorted, row2)) + ratios: Dict[str, Any] = {} + test_results: Dict[str, Any] = {} + + for metric in metrics_sorted: + cur = float(current[metric]) + ref = float(reference[metric]) + threshold = float(metrics_thresholds[metric]) + if cur == 0 or ref == 0: + ratios[metric] = None + test_results[metric] = ignore_zero + else: + ratios[metric] = ratio_formulas[ratio_formula]( + float(current[metric]), float(reference[metric]) + ) + test_results[metric] = float(ratios[metric]) < threshold + + self.log.info( + ( + "Current metric for %s: %s\n" + "Past metric for %s: %s\n" + "Ratio for %s: %s\n" + "Threshold: %s\n" + ), + metric, + cur, + metric, + ref, + metric, + ratios[metric], + threshold, + ) + + if not all(test_results.values()): + failed_tests = [metric for metric, value in test_results.items() if not value] + self.log.warning( + "The following %s tests out of %s failed:", + len(failed_tests), + len(metrics_sorted), + ) + for k in failed_tests: + self.log.warning( + "'%s' check failed. %s is above %s", + k, + ratios[k], + metrics_thresholds[k], + ) + raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}") + + self.log.info("All tests have passed") + + +class BigQueryTableAsyncHook(GoogleBaseAsyncHook): + """Class to get async hook for Bigquery Table Async""" + + sync_hook_class = BigQueryHook + + async def get_table_client( + self, dataset: str, table_id: str, project_id: str, session: ClientSession + ) -> Table_async: + """ + Returns a Google Big Query Table object. + + :param dataset: The name of the dataset in which to look for the table storage bucket. + :param table_id: The name of the table to check the existence of. + :param project_id: The Google cloud project in which to look for the table. + The connection supplied to the hook must provide + access to the specified project. 
+ :param session: aiohttp ClientSession + """ + with await self.service_file_as_context() as file: + return Table_async( + dataset_name=dataset, + table_name=table_id, + project=project_id, + service_file=file, + session=cast(Session, session), + ) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 74b9bd0a2eb00..b361378c8c724 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -40,6 +40,13 @@ from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url from airflow.providers.google.cloud.links.bigquery import BigQueryDatasetLink, BigQueryTableLink +from airflow.providers.google.cloud.triggers.bigquery import ( + BigQueryCheckTrigger, + BigQueryGetDataTrigger, + BigQueryInsertJobTrigger, + BigQueryIntervalCheckTrigger, + BigQueryValueCheckTrigger, +) if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstanceKey @@ -2241,3 +2248,460 @@ def on_kill(self) -> None: ) else: self.log.info('Skipping to cancel job: %s:%s.%s', self.project_id, self.location, self.job_id) + + +class BigQueryInsertJobAsyncOperator(BigQueryInsertJobOperator, BaseOperator): + """ + Starts a BigQuery job asynchronously, and returns job id. + This operator works in the following way: + + - it calculates a unique hash of the job using job's configuration or uuid if ``force_rerun`` is True + - creates ``job_id`` in form of + ``[provided_job_id | airflow_{dag_id}_{task_id}_{exec_date}]_{uniqueness_suffix}`` + - submits a BigQuery job using the ``job_id`` + - if job with given id already exists then it tries to reattach to the job if its not done and its + state is in ``reattach_states``. If the job is done the operator will raise ``AirflowException``. + + Using ``force_rerun`` will submit a new job every time without attaching to already existing ones. + + For job definition see here: + + https://cloud.google.com/bigquery/docs/reference/v2/jobs + + :param configuration: The configuration parameter maps directly to BigQuery's + configuration field in the job object. For more details see + https://cloud.google.com/bigquery/docs/reference/v2/jobs + :param job_id: The ID of the job. It will be suffixed with hash of job configuration + unless ``force_rerun`` is True. + The ID must contain only letters (a-z, A-Z), numbers (0-9), underscores (_), or + dashes (-). The maximum length is 1,024 characters. If not provided then uuid will + be generated. + :param force_rerun: If True then operator will use hash of uuid as job id suffix + :param reattach_states: Set of BigQuery job's states in case of which we should reattach + to the job. Should be other than final states. + :param project_id: Google Cloud Project where the job is running + :param location: location the job is running + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called + """ + + def _submit_job(self, hook: BigQueryHook, job_id: str) -> BigQueryJob: # type: ignore[override] + """Submit a new job and get the job id for polling the status using Triggerer.""" + return hook.insert_job( + configuration=self.configuration, + project_id=self.project_id, + location=self.location, + job_id=job_id, + nowait=True, + ) + + def execute(self, context: Any) -> None: + hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) + + self.hook = hook + job_id = self.hook.generate_job_id( + job_id=self.job_id, + dag_id=self.dag_id, + task_id=self.task_id, + logical_date=context["logical_date"], + configuration=self.configuration, + force_rerun=self.force_rerun, + ) + + try: + job = self._submit_job(hook, job_id) + self._handle_job_error(job) + except Conflict: + # If the job already exists retrieve it + job = hook.get_job( + project_id=self.project_id, + location=self.location, + job_id=job_id, + ) + if job.state in self.reattach_states: + # We are reattaching to a job + job._begin() + self._handle_job_error(job) + else: + # Same job configuration so we need force_rerun + raise AirflowException( + f"Job with id: {job_id} already exists and is in {job.state} state. If you " + f"want to force rerun it consider setting `force_rerun=True`." + f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" + ) + + self.job_id = job.job_id + context["ti"].xcom_push(key="job_id", value=self.job_id) + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryInsertJobTrigger( + conn_id=self.gcp_conn_id, + job_id=self.job_id, + project_id=self.project_id, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Any, event: Dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. 
+ """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) + + +class BigQueryCheckAsyncOperator(BigQueryCheckOperator): + """ + BigQueryCheckAsyncOperator is asynchronous operator, submit the job and check + for the status in async mode by using the job id + """ + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Trigger.""" + configuration = {"query": {"query": self.sql}} + + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=True, + ) + + def execute(self, context: Any) -> None: + hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + ) + job = self._submit_job(hook, job_id="") + context["ti"].xcom_push(key="job_id", value=job.job_id) + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryCheckTrigger( + conn_id=self.gcp_conn_id, + job_id=job.job_id, + project_id=hook.project_id, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Any, event: Dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + + records = event["records"] + if not records: + raise AirflowException("The query returned None") + elif not all(bool(r) for r in records): + raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}") + self.log.info("Record: %s", event["records"]) + self.log.info("Success.") + + +class BigQueryGetDataAsyncOperator(BigQueryGetDataOperator): + """ + Fetches the data from a BigQuery table (alternatively fetch data for selected columns) + and returns data in a python list. The number of elements in the returned list will + be equal to the number of rows fetched. Each element in the list will again be a list + where element would represent the columns values for that row. + + **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]`` + + .. note:: + If you pass fields to ``selected_fields`` which are in different order than the + order of columns already in + BQ table, the data will still be in the order of BQ table. + For example if the BQ table has 3 columns as + ``[A,B,C]`` and you pass 'B,A' in the ``selected_fields`` + the data would still be of the form ``'A,B'``. + + **Example**: :: + + get_data = BigQueryGetDataOperator( + task_id='get_data_from_bq', + dataset_id='test_dataset', + table_id='Transaction_partitions', + max_results=100, + selected_fields='DATE', + gcp_conn_id='airflow-conn-id' + ) + + :param dataset_id: The dataset ID of the requested table. (templated) + :param table_id: The table ID of the requested table. (templated) + :param max_results: The maximum number of records (rows) to be fetched from the table. (templated) + :param selected_fields: List of fields to return (comma-separated). If + unspecified, all fields are returned. + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param location: The location used for the operation. 
+ :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + configuration: Dict[str, Any], + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Triggerer.""" + return hook.insert_job( + configuration=configuration, + location=self.location, + project_id=hook.project_id, + job_id=job_id, + nowait=True, + ) + + def generate_query(self) -> str: + """ + Generate a select query if selected fields are given or with * + for the given dataset and table id + """ + selected_fields = self.selected_fields if self.selected_fields else "*" + return f"select {selected_fields} from {self.dataset_id}.{self.table_id} limit {self.max_results}" + + def execute(self, context: Any) -> None: # type: ignore[override] + get_query = self.generate_query() + configuration = {"query": {"query": get_query}} + + hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + self.hook = hook + job = self._submit_job(hook, job_id="", configuration=configuration) + self.job_id = job.job_id + context["ti"].xcom_push(key="job_id", value=self.job_id) + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryGetDataTrigger( + conn_id=self.gcp_conn_id, + job_id=self.job_id, + dataset_id=self.dataset_id, + table_id=self.table_id, + project_id=hook.project_id, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Any, event: Dict[str, Any]) -> Any: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + + self.log.info("Total extracted rows: %s", len(event["records"])) + return event["records"] + + +class BigQueryIntervalCheckAsyncOperator(BigQueryIntervalCheckOperator): + """ + Checks asynchronously that the values of metrics given as SQL expressions are within + a certain tolerance of the ones from days_back before. + + This method constructs a query like so :: + SELECT {metrics_threshold_dict_key} FROM {table} + WHERE {date_filter_column}= + + :param table: the table name + :param days_back: number of days between ds and the ds we want to check + against. Defaults to 7 days + :param metrics_thresholds: a dictionary of ratios indexed by metrics, for + example 'COUNT(*)': 1.5 would require a 50 percent or less difference + between the current day, and the prior days_back. + :param use_legacy_sql: Whether to use legacy SQL (true) + or standard SQL (false). + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :param location: The geographic location of the job. 
See details at: + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param labels: a dictionary containing labels for the table, passed to BigQuery + """ + + def _submit_job( + self, + hook: BigQueryHook, + sql: str, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Triggerer.""" + configuration = {"query": {"query": sql}} + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=True, + ) + + def execute(self, context: Any) -> None: + """Execute the job in sync mode and defers the trigger with job id to poll for the status""" + hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) + self.log.info("Using ratio formula: %s", self.ratio_formula) + + self.log.info("Executing SQL check: %s", self.sql1) + job_1 = self._submit_job(hook, sql=self.sql1, job_id="") + context["ti"].xcom_push(key="job_id", value=job_1.job_id) + + self.log.info("Executing SQL check: %s", self.sql2) + job_2 = self._submit_job(hook, sql=self.sql2, job_id="") + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryIntervalCheckTrigger( + conn_id=self.gcp_conn_id, + first_job_id=job_1.job_id, + second_job_id=job_2.job_id, + project_id=hook.project_id, + table=self.table, + metrics_thresholds=self.metrics_thresholds, + date_filter_column=self.date_filter_column, + days_back=self.days_back, + ratio_formula=self.ratio_formula, + ignore_zero=self.ignore_zero, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Any, event: Dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + + self.log.info( + "%s completed with response %s ", + self.task_id, + event["status"], + ) + + +class BigQueryValueCheckAsyncOperator(BigQueryValueCheckOperator): + """ + Performs a simple value check using sql code. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BigQueryValueCheckOperator` + + :param sql: the sql to be executed + :param use_legacy_sql: Whether to use legacy SQL (true) + or standard SQL (false). + :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :param location: The geographic location of the job. See details at: + https://cloud.google.com/bigquery/docs/locations#specifying_your_location + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param labels: a dictionary containing labels for the table, passed to BigQuery + """ + + def _submit_job( + self, + hook: BigQueryHook, + job_id: str, + ) -> BigQueryJob: + """Submit a new job and get the job id for polling the status using Triggerer.""" + configuration = { + "query": { + "query": self.sql, + "useLegacySql": False, + } + } + if self.use_legacy_sql: + configuration["query"]["useLegacySql"] = self.use_legacy_sql + + return hook.insert_job( + configuration=configuration, + project_id=hook.project_id, + location=self.location, + job_id=job_id, + nowait=True, + ) + + def execute(self, context: Any) -> None: + hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) + + job = self._submit_job(hook, job_id="") + context["ti"].xcom_push(key="job_id", value=job.job_id) + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryValueCheckTrigger( + conn_id=self.gcp_conn_id, + job_id=job.job_id, + project_id=hook.project_id, + sql=self.sql, + pass_value=self.pass_value, + tolerance=self.tol, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Any, event: Dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info( + "%s completed with response %s ", + self.task_id, + event["message"], + ) diff --git a/airflow/providers/google/cloud/sensors/bigquery.py b/airflow/providers/google/cloud/sensors/bigquery.py index f0a9d67f58e5e..ddcd203eb8628 100644 --- a/airflow/providers/google/cloud/sensors/bigquery.py +++ b/airflow/providers/google/cloud/sensors/bigquery.py @@ -16,9 +16,12 @@ # specific language governing permissions and limitations # under the License. """This module contains a Google Bigquery sensor.""" -from typing import TYPE_CHECKING, Optional, Sequence, Union +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.providers.google.cloud.triggers.bigquery import BigQueryTableExistenceTrigger from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -68,7 +71,6 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: - super().__init__(**kwargs) self.project_id = project_id @@ -137,7 +139,6 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: - super().__init__(**kwargs) self.project_id = project_id @@ -162,3 +163,73 @@ def poke(self, context: 'Context') -> bool: table_id=self.table_id, partition_id=self.partition_id, ) + + +class BigQueryTableExistenceAsyncSensor(BigQueryTableExistenceSensor): + """ + Checks for the existence of a table in Google Big Query. + + :param project_id: The Google cloud project in which to look for the table. + The connection supplied to the hook must provide + access to the specified project. + :param dataset_id: The name of the dataset in which to look for the table. + storage bucket. + :param table_id: The name of the table to check the existence of. 
+ :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud. + This parameter has been deprecated. You should pass the gcp_conn_id parameter instead. + :param delegate_to: The account to impersonate using domain-wide delegation of authority, + if any. For this to work, the service account making the request must have + domain-wide delegation enabled. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param polling_interval: The interval in seconds to wait between checks table existence. + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + polling_interval: float = 5.0, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.polling_interval = polling_interval + self.gcp_conn_id = gcp_conn_id + + def execute(self, context: 'Context') -> None: + """Airflow runs this method on the worker and defers using the trigger.""" + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=BigQueryTableExistenceTrigger( + dataset_id=self.dataset_id, + table_id=self.table_id, + project_id=self.project_id, + poll_interval=self.polling_interval, + gcp_conn_id=self.gcp_conn_id, + hook_params={ + "delegate_to": self.delegate_to, + "impersonation_chain": self.impersonation_chain, + }, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, str]] = None) -> str: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}" + self.log.info("Sensor checks existence of table: %s", table_uri) + if event: + if event["status"] == "success": + return event["message"] + raise AirflowException(event["message"]) + raise AirflowException("No event received in trigger callback") diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py new file mode 100644 index 0000000000000..56e985ea1f3d7 --- /dev/null +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -0,0 +1,528 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +from typing import Any, AsyncIterator, Dict, Optional, SupportsAbs, Tuple, Union + +from aiohttp import ClientSession +from aiohttp.client_exceptions import ClientResponseError + +from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class BigQueryInsertJobTrigger(BaseTrigger): + """ + BigQueryInsertJobTrigger run on the trigger worker to perform insert operation + + :param conn_id: Reference to google cloud connection id + :param job_id: The ID of the job. It will be suffixed with hash of job configuration + :param project_id: Google Cloud Project where the job is running + :param dataset_id: The dataset ID of the requested table. (templated) + :param table_id: The table ID of the requested table. (templated) + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + conn_id: str, + job_id: Optional[str], + project_id: Optional[str], + dataset_id: Optional[str] = None, + table_id: Optional[str] = None, + poll_interval: float = 4.0, + ): + super().__init__() + self.log.info("Using the connection %s .", conn_id) + self.conn_id = conn_id + self.job_id = job_id + self._job_conn = None + self.dataset_id = dataset_id + self.project_id = project_id + self.table_id = table_id + self.poll_interval = poll_interval + + def serialize(self) -> Tuple[str, Dict[str, Any]]: + """Serializes BigQueryInsertJobTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger", + { + "conn_id": self.conn_id, + "job_id": self.job_id, + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "table_id": self.table_id, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Gets current job execution status and yields a TriggerEvent""" + hook = self._get_async_hook() + while True: + try: + # Poll for job execution status + response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) + self.log.debug("Response from hook: %s", response_from_hook) + + if response_from_hook == "success": + yield TriggerEvent( + { + "job_id": self.job_id, + "status": "success", + "message": "Job completed", + } + ) + elif response_from_hook == "pending": + self.log.info("Query is still running...") + self.log.info("Sleeping for %s seconds.", self.poll_interval) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent({"status": "error", "message": response_from_hook}) + + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) + + def _get_async_hook(self) -> BigQueryAsyncHook: + return BigQueryAsyncHook(gcp_conn_id=self.conn_id) + + +class BigQueryCheckTrigger(BigQueryInsertJobTrigger): + """BigQueryCheckTrigger run on the trigger worker""" + + def serialize(self) -> Tuple[str, Dict[str, Any]]: + """Serializes BigQueryCheckTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger", + { + "conn_id": self.conn_id, + "job_id": self.job_id, + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "table_id": self.table_id, + "poll_interval": self.poll_interval, + }, + ) + + async def 
run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Gets current job execution status and yields a TriggerEvent""" + hook = self._get_async_hook() + while True: + try: + # Poll for job execution status + response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) + if response_from_hook == "success": + query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id) + + records = hook.get_records(query_results) + + # If empty list, then no records are available + if not records: + yield TriggerEvent( + { + "status": "success", + "records": None, + } + ) + else: + # Extract only first record from the query results + first_record = records.pop(0) + yield TriggerEvent( + { + "status": "success", + "records": first_record, + } + ) + return + + elif response_from_hook == "pending": + self.log.info("Query is still running...") + self.log.info("Sleeping for %s seconds.", self.poll_interval) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent({"status": "error", "message": response_from_hook}) + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) + + +class BigQueryGetDataTrigger(BigQueryInsertJobTrigger): + """BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class""" + + def serialize(self) -> Tuple[str, Dict[str, Any]]: + """Serializes BigQueryInsertJobTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger", + { + "conn_id": self.conn_id, + "job_id": self.job_id, + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "table_id": self.table_id, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Gets current job execution status and yields a TriggerEvent with response data""" + hook = self._get_async_hook() + while True: + try: + # Poll for job execution status + response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) + if response_from_hook == "success": + query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id) + records = hook.get_records(query_results) + self.log.debug("Response from hook: %s", response_from_hook) + yield TriggerEvent( + { + "status": "success", + "message": response_from_hook, + "records": records, + } + ) + return + elif response_from_hook == "pending": + self.log.info("Query is still running...") + self.log.info("Sleeping for %s seconds.", self.poll_interval) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent({"status": "error", "message": response_from_hook}) + return + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + +class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger): + """ + BigQueryIntervalCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class + + :param conn_id: Reference to google cloud connection id + :param first_job_id: The ID of the job 1 performed + :param second_job_id: The ID of the job 2 performed + :param project_id: Google Cloud Project where the job is running + :param dataset_id: The dataset ID of the requested table. 
(templated)
+    :param table: table name
+    :param metrics_thresholds: dictionary of ratios indexed by metrics
+    :param date_filter_column: column name used to filter the data by date
+    :param days_back: number of days between ds and the ds we want to check
+        against
+    :param ratio_formula: ratio formula
+    :param ignore_zero: whether we should ignore zero metrics
+    :param table_id: The table ID of the requested table. (templated)
+    :param poll_interval: polling period in seconds to check for the status
+    """
+
+    def __init__(
+        self,
+        conn_id: str,
+        first_job_id: str,
+        second_job_id: str,
+        project_id: Optional[str],
+        table: str,
+        metrics_thresholds: Dict[str, int],
+        date_filter_column: Optional[str] = "ds",
+        days_back: SupportsAbs[int] = -7,
+        ratio_formula: str = "max_over_min",
+        ignore_zero: bool = True,
+        dataset_id: Optional[str] = None,
+        table_id: Optional[str] = None,
+        poll_interval: float = 4.0,
+    ):
+        super().__init__(
+            conn_id=conn_id,
+            job_id=first_job_id,
+            project_id=project_id,
+            dataset_id=dataset_id,
+            table_id=table_id,
+            poll_interval=poll_interval,
+        )
+        self.conn_id = conn_id
+        self.first_job_id = first_job_id
+        self.second_job_id = second_job_id
+        self.project_id = project_id
+        self.table = table
+        self.metrics_thresholds = metrics_thresholds
+        self.date_filter_column = date_filter_column
+        self.days_back = days_back
+        self.ratio_formula = ratio_formula
+        self.ignore_zero = ignore_zero
+
+    def serialize(self) -> Tuple[str, Dict[str, Any]]:
+        """Serializes BigQueryIntervalCheckTrigger arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger",
+            {
+                "conn_id": self.conn_id,
+                "first_job_id": self.first_job_id,
+                "second_job_id": self.second_job_id,
+                "project_id": self.project_id,
+                "table": self.table,
+                "metrics_thresholds": self.metrics_thresholds,
+                "date_filter_column": self.date_filter_column,
+                "days_back": self.days_back,
+                "ratio_formula": self.ratio_formula,
+                "ignore_zero": self.ignore_zero,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
+        """Gets current job execution status and yields a TriggerEvent"""
+        hook = self._get_async_hook()
+        while True:
+            try:
+                first_job_response_from_hook = await hook.get_job_status(
+                    job_id=self.first_job_id, project_id=self.project_id
+                )
+                second_job_response_from_hook = await hook.get_job_status(
+                    job_id=self.second_job_id, project_id=self.project_id
+                )
+
+                if first_job_response_from_hook == "success" and second_job_response_from_hook == "success":
+                    first_query_results = await hook.get_job_output(
+                        job_id=self.first_job_id, project_id=self.project_id
+                    )
+
+                    second_query_results = await hook.get_job_output(
+                        job_id=self.second_job_id, project_id=self.project_id
+                    )
+
+                    first_records = hook.get_records(first_query_results)
+
+                    second_records = hook.get_records(second_query_results)
+
+                    # If empty list, then no records are available
+                    if not first_records:
+                        first_job_row: Optional[str] = None
+                    else:
+                        # Extract only first record from the query results
+                        first_job_row = first_records.pop(0)
+
+                    # If empty list, then no records are available
+                    if not second_records:
+                        second_job_row: Optional[str] = None
+                    else:
+                        # Extract only first record from the query results
+                        second_job_row = second_records.pop(0)
+
+                    hook.interval_check(
+                        first_job_row,
+                        second_job_row,
+                        self.metrics_thresholds,
+                        self.ignore_zero,
+                        self.ratio_formula,
+                    )
+
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": "Job completed",
+                            "first_row_data":
first_job_row, + "second_row_data": second_job_row, + } + ) + return + elif first_job_response_from_hook == "pending" or second_job_response_from_hook == "pending": + self.log.info("Query is still running...") + self.log.info("Sleeping for %s seconds.", self.poll_interval) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent( + {"status": "error", "message": second_job_response_from_hook, "data": None} + ) + return + + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + +class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger): + """ + BigQueryValueCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class + + :param conn_id: Reference to google cloud connection id + :param sql: the sql to be executed + :param pass_value: pass value + :param job_id: The ID of the job + :param project_id: Google Cloud Project where the job is running + :param tolerance: certain metrics for tolerance + :param dataset_id: The dataset ID of the requested table. (templated) + :param table_id: The table ID of the requested table. (templated) + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + conn_id: str, + sql: str, + pass_value: Union[int, float, str], + job_id: Optional[str], + project_id: Optional[str], + tolerance: Any = None, + dataset_id: Optional[str] = None, + table_id: Optional[str] = None, + poll_interval: float = 4.0, + ): + super().__init__( + conn_id=conn_id, + job_id=job_id, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, + poll_interval=poll_interval, + ) + self.sql = sql + self.pass_value = pass_value + self.tolerance = tolerance + + def serialize(self) -> Tuple[str, Dict[str, Any]]: + """Serializes BigQueryValueCheckTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger", + { + "conn_id": self.conn_id, + "pass_value": self.pass_value, + "job_id": self.job_id, + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "sql": self.sql, + "table_id": self.table_id, + "tolerance": self.tolerance, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Gets current job execution status and yields a TriggerEvent""" + hook = self._get_async_hook() + while True: + try: + # Poll for job execution status + response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) + if response_from_hook == "success": + query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id) + records = hook.get_records(query_results) + records = records.pop(0) if records else None + hook.value_check(self.sql, self.pass_value, records, self.tolerance) + yield TriggerEvent({"status": "success", "message": "Job completed", "records": records}) + return + elif response_from_hook == "pending": + self.log.info("Query is still running...") + self.log.info("Sleeping for %s seconds.", self.poll_interval) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent({"status": "error", "message": response_from_hook, "records": None}) + return + + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + +class BigQueryTableExistenceTrigger(BaseTrigger): 
+ """ + Initialize the BigQuery Table Existence Trigger with needed parameters + + :param project_id: Google Cloud Project where the job is running + :param dataset_id: The dataset ID of the requested table. + :param table_id: The table ID of the requested table. + :param gcp_conn_id: Reference to google cloud connection id + :param hook_params: params for hook + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + project_id: str, + dataset_id: str, + table_id: str, + gcp_conn_id: str, + hook_params: Dict[str, Any], + poll_interval: float = 4.0, + ): + self.dataset_id = dataset_id + self.project_id = project_id + self.table_id = table_id + self.gcp_conn_id: str = gcp_conn_id + self.poll_interval = poll_interval + self.hook_params = hook_params + + def serialize(self) -> Tuple[str, Dict[str, Any]]: + """Serializes BigQueryTableExistenceTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger", + { + "dataset_id": self.dataset_id, + "project_id": self.project_id, + "table_id": self.table_id, + "gcp_conn_id": self.gcp_conn_id, + "poll_interval": self.poll_interval, + "hook_params": self.hook_params, + }, + ) + + def _get_async_hook(self) -> BigQueryTableAsyncHook: + return BigQueryTableAsyncHook(gcp_conn_id=self.gcp_conn_id) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Will run until the table exists in the Google Big Query.""" + while True: + try: + hook = self._get_async_hook() + response = await self._table_exists( + hook=hook, dataset=self.dataset_id, table_id=self.table_id, project_id=self.project_id + ) + if response: + yield TriggerEvent({"status": "success", "message": "success"}) + return + await asyncio.sleep(self.poll_interval) + except Exception as e: + self.log.exception("Exception occurred while checking for Table existence") + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + async def _table_exists( + self, hook: BigQueryTableAsyncHook, dataset: str, table_id: str, project_id: str + ) -> bool: + """ + Create client session and make call to BigQueryTableAsyncHook and check for the table in + Google Big Query. + + :param hook: BigQueryTableAsyncHook Hook class + :param dataset: The name of the dataset in which to look for the table storage bucket. + :param table_id: The name of the table to check the existence of. + :param project_id: The Google cloud project in which to look for the table. + The connection supplied to the hook must provide + access to the specified project. 
+ """ + async with ClientSession() as session: + try: + client = await hook.get_table_client( + dataset=dataset, table_id=table_id, project_id=project_id, session=session + ) + response = await client.get() + return True if response else False + except ClientResponseError as err: + if err.status == 404: + return False + raise err diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 79ed286a3db78..e2179bedfa4e1 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -33,6 +33,7 @@ import google_auth_httplib2 import requests import tenacity +from asgiref.sync import sync_to_async from google.api_core.exceptions import Forbidden, ResourceExhausted, TooManyRequests from google.api_core.gapic_v1.client_info import ClientInfo from google.auth import _cloud_sdk, compute_engine @@ -56,7 +57,6 @@ log = logging.getLogger(__name__) - # Constants used by the mechanism of repeating requests in reaction to exceeding the temporary quota. INVALID_KEYS = [ 'DefaultRequestsPerMinutePerProject', @@ -602,3 +602,26 @@ def test_connection(self): message = str(e) return status, message + + +class GoogleBaseAsyncHook(BaseHook): + """GoogleBaseAsyncHook inherits from BaseHook class, run on the trigger worker""" + + sync_hook_class: Any = None + + def __init__(self, **kwargs: Any): + self._hook_kwargs = kwargs + self._sync_hook = None + + async def get_sync_hook(self) -> Any: + """ + Sync version of the Google Cloud Hooks makes blocking calls in ``__init__`` so we don't inherit + from it. + """ + if not self._sync_hook: + self._sync_hook = await sync_to_async(self.sync_hook_class)(**self._hook_kwargs) + return self._sync_hook + + async def service_file_as_context(self) -> Any: + sync_hook = await self.get_sync_hook() + return await sync_to_async(sync_hook.provide_gcp_credential_file_as_context)() diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 5923805302da7..a3edbf5867758 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -64,6 +64,9 @@ dependencies: # Introduced breaking changes across the board. Those libraries should be upgraded soon # TODO: Upgrade all Google libraries that are limited to <2.0.0 - PyOpenSSL + - asgiref + - gcloud-aio-bigquery + - gcloud-aio-storage - google-ads>=15.1.1 - google-api-core>=2.7.0,<3.0.0 - google-api-python-client>=1.6.0,<2.0.0 diff --git a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst index 57e4d87ff8c03..6672d2ce58573 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst @@ -202,7 +202,8 @@ Fetch data from table """"""""""""""""""""" To fetch data from a BigQuery table you can use -:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataOperator`. +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataOperator` or +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataAsyncOperator` . Alternatively you can fetch data for selected columns if you pass fields to ``selected_fields``. @@ -217,6 +218,17 @@ that row. 
:start-after: [START howto_operator_bigquery_get_data] :end-before: [END howto_operator_bigquery_get_data] +The below example shows how to use +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataAsyncOperator`. +Note that this is a deferrable operator which requires the Triggerer to be running on your Airflow +deployment. + +.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bigquery_get_data_async] + :end-before: [END howto_operator_bigquery_get_data_async] + .. _howto/operator:BigQueryUpsertTableOperator: Upsert table @@ -294,9 +306,10 @@ Let's say you would like to execute the following query. :start-after: [START howto_operator_bigquery_query] :end-before: [END howto_operator_bigquery_query] -To execute the SQL query in a specific BigQuery database you can use -:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryInsertJobOperator` with -proper query job configuration that can be Jinja templated. +To execute the SQL query in a specific BigQuery database you can use either +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryInsertJobOperator` or +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryInsertJobAsyncOperator` +with proper query job configuration that can be Jinja templated. .. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py :language: python @@ -304,6 +317,17 @@ proper query job configuration that can be Jinja templated. :start-after: [START howto_operator_bigquery_insert_job] :end-before: [END howto_operator_bigquery_insert_job] +The below example shows how to use +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryInsertJobAsyncOperator`. +Note that this is a deferrable operator which requires the Triggerer to be running on your Airflow +deployment. + +.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bigquery_insert_job_async] + :end-before: [END howto_operator_bigquery_insert_job_async] + For more information on types of BigQuery job please check `documentation `__. @@ -332,8 +356,9 @@ Validate data Check if query result has data """""""""""""""""""""""""""""" -To perform checks against BigQuery you can use -:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator`. +To perform checks against BigQuery you can use either +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator` or +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCheckAsyncOperator` This operator expects a sql query that will return a single row. Each value on that first row is evaluated using python ``bool`` casting. If any of the values @@ -345,15 +370,25 @@ return ``False`` the check is failed and errors out. :start-after: [START howto_operator_bigquery_check] :end-before: [END howto_operator_bigquery_check] +Below example shows the usage of :class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCheckAsyncOperator`, +which is the deferrable version of the operator + +.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bigquery_check_async] + :end-before: [END howto_operator_bigquery_check_async] + .. 
_howto/operator:BigQueryValueCheckOperator: Compare query result to pass value """""""""""""""""""""""""""""""""" To perform a simple value check using sql code you can use -:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator`. +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator` or +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckAsyncOperator` -This operator expects a sql query that will return a single row. Each value on +These operators expects a sql query that will return a single row. Each value on that first row is evaluated against ``pass_value`` which can be either a string or numeric value. If numeric, you can also specify ``tolerance``. @@ -363,14 +398,26 @@ or numeric value. If numeric, you can also specify ``tolerance``. :start-after: [START howto_operator_bigquery_value_check] :end-before: [END howto_operator_bigquery_value_check] +The below example shows how to use +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckAsyncOperator`. +Note that this is a deferrable operator which requires the Triggerer to be running on your Airflow +deployment. + +.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bigquery_value_check_async] + :end-before: [END howto_operator_bigquery_value_check_async] + .. _howto/operator:BigQueryIntervalCheckOperator: Compare metrics over time """"""""""""""""""""""""" To check that the values of metrics given as SQL expressions are within a certain -tolerance of the ones from ``days_back`` before you can use -:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckOperator`. +tolerance of the ones from ``days_back`` before you can either use +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckOperator` or +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckAsyncOperator` .. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py :language: python @@ -378,6 +425,17 @@ tolerance of the ones from ``days_back`` before you can use :start-after: [START howto_operator_bigquery_interval_check] :end-before: [END howto_operator_bigquery_interval_check] +The below example shows how to use +:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckAsyncOperator`. +Note that this is a deferrable operator which requires the Triggerer to be running on your Airflow +deployment. + +.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_bigquery_interval_check_async] + :end-before: [END howto_operator_bigquery_interval_check_async] + Sensors ^^^^^^^ @@ -396,6 +454,17 @@ use the ``{{ ds_nodash }}`` macro as the table name suffix. :start-after: [START howto_sensor_bigquery_table] :end-before: [END howto_sensor_bigquery_table] +Use the :class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceAsyncSensor` +(deferrable version) if you would like to free up the worker slots while the sensor is running. + +:class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceAsyncSensor`. + +.. 
exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_async_bigquery_table] + :end-before: [END howto_sensor_async_bigquery_table] + Check that a Table Partition exists """"""""""""""""""""""""""""""""""" diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 4d8bcd6eec823..f7539f34ffc02 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -14,6 +14,7 @@ adhoc adls afterall AgentKey +aio Airbnb airbnb Airbyte @@ -71,6 +72,7 @@ asc ascii asciiart asctime +asend asia assertEqualIgnoreMultipleSpaces assigment diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index a4cda45f8995c..bc266ec27a71a 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -301,6 +301,9 @@ "PyOpenSSL", "apache-airflow-providers-common-sql>=1.1.0", "apache-airflow>=2.2.0", + "asgiref", + "gcloud-aio-bigquery", + "gcloud-aio-storage", "google-ads>=15.1.1", "google-api-core>=2.7.0,<3.0.0", "google-api-python-client>=1.6.0,<2.0.0", diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index 6d69431f14d51..c2342dd8fdf6a 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -18,20 +18,23 @@ import re +import sys import unittest from datetime import datetime -from unittest import mock import pytest +from gcloud.aio.bigquery import Job, Table as Table_async from google.cloud.bigquery import DEFAULT_RETRY, DatasetReference, Table, TableReference from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem from google.cloud.exceptions import NotFound from parameterized import parameterized -from airflow import AirflowException +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.bigquery import ( + BigQueryAsyncHook, BigQueryCursor, BigQueryHook, + BigQueryTableAsyncHook, _api_resource_configs_duplication_check, _cleanse_time_partitioning, _format_schema_for_description, @@ -40,6 +43,13 @@ split_tablename, ) +if sys.version_info < (3, 8): + from asynctest import mock + from asynctest.mock import CoroutineMock as AsyncMock +else: + from unittest import mock + from unittest.mock import AsyncMock + PROJECT_ID = "bq-project" CREDENTIALS = "bq-credentials" DATASET_ID = "bq_dataset" @@ -2011,7 +2021,6 @@ def test_deprecation_warning(self, func_name, mock_bq_hook): class TestBigQueryWithLabelsAndDescription(_BigQueryBaseTestClass): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_load_labels(self, mock_insert): - labels = {'label1': 'test1', 'label2': 'test2'} self.hook.run_load( destination_project_dataset_table='my_dataset.my_table', @@ -2025,7 +2034,6 @@ def test_run_load_labels(self, mock_insert): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_load_description(self, mock_insert): - description = "Test Description" self.hook.run_load( destination_project_dataset_table='my_dataset.my_table', @@ -2039,7 +2047,6 @@ def test_run_load_description(self, mock_insert): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table") def test_create_external_table_labels(self, mock_create): - labels = {'label1': 'test1', 'label2': 'test2'} self.hook.create_external_table( 
external_project_dataset_table='my_dataset.my_table', @@ -2053,7 +2060,6 @@ def test_create_external_table_description(self, mock_create): - description = "Test Description" self.hook.create_external_table( external_project_dataset_table='my_dataset.my_table', @@ -2064,3 +2070,204 @@ def test_create_external_table_description(self, mock_create): _, kwargs = mock_create.call_args assert kwargs['table_resource']['description'] is description + + +class _BigQueryBaseAsyncTestClass: + def setup_method(self) -> None: + class MockedBigQueryAsyncHook(BigQueryAsyncHook): + def get_credentials_and_project_id(self): + return CREDENTIALS, PROJECT_ID + + self.hook = MockedBigQueryAsyncHook() + + +class TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass): + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession") + async def test_get_job_instance(self, mock_session): + hook = BigQueryAsyncHook() + result = await hook.get_job_instance(project_id=PROJECT_ID, job_id=JOB_ID, session=mock_session) + assert isinstance(result, Job) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance") + async def test_get_job_status_success(self, mock_job_instance): + hook = BigQueryAsyncHook() + mock_job_client = AsyncMock(Job) + mock_job_instance.return_value = mock_job_client + response = "success" + mock_job_instance.return_value.result.return_value = response + resp = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID) + assert resp == response + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance") + async def test_get_job_status_oserror(self, mock_job_instance): + """Asserts that the BigQueryAsyncHook returns a pending response when OSError is raised""" + mock_job_instance.return_value.result.side_effect = OSError() + hook = BigQueryAsyncHook() + job_status = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID) + assert job_status == "pending" + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance") + async def test_get_job_status_exception(self, mock_job_instance, caplog): + """Asserts that the logging is done correctly when BigQueryAsyncHook raises Exception""" + mock_job_instance.return_value.result.side_effect = Exception() + hook = BigQueryAsyncHook() + await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID) + assert "Query execution finished with errors..." 
in caplog.text + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance") + async def test_get_job_output_assert_once_with(self, mock_job_instance): + hook = BigQueryAsyncHook() + mock_job_client = AsyncMock(Job) + mock_job_instance.return_value = mock_job_client + response = "success" + mock_job_instance.return_value.get_query_results.return_value = response + resp = await hook.get_job_output(job_id=JOB_ID, project_id=PROJECT_ID) + assert resp == response + + def test_interval_check_for_airflow_exception(self): + """ + Assert that interval_check raises an AirflowException + """ + hook = BigQueryAsyncHook() + + row1, row2, metrics_thresholds, ignore_zero, ratio_formula = ( + None, + "0", + {"COUNT(*)": 1.5}, + True, + "max_over_min", + ) + with pytest.raises(AirflowException): + hook.interval_check(row1, row2, metrics_thresholds, ignore_zero, ratio_formula) + + row1, row2, metrics_thresholds, ignore_zero, ratio_formula = ( + "0", + None, + {"COUNT(*)": 1.5}, + True, + "max_over_min", + ) + with pytest.raises(AirflowException): + hook.interval_check(row1, row2, metrics_thresholds, ignore_zero, ratio_formula) + + row1, row2, metrics_thresholds, ignore_zero, ratio_formula = ( + "1", + "1", + {"COUNT(*)": 0}, + True, + "max_over_min", + ) + with pytest.raises(AirflowException): + hook.interval_check(row1, row2, metrics_thresholds, ignore_zero, ratio_formula) + + def test_interval_check_for_success(self): + """ + Assert that interval_check returns None when the check succeeds + """ + hook = BigQueryAsyncHook() + + row1, row2, metrics_thresholds, ignore_zero, ratio_formula = ( + "0", + "0", + {"COUNT(*)": 1.5}, + True, + "max_over_min", + ) + response = hook.interval_check(row1, row2, metrics_thresholds, ignore_zero, ratio_formula) + assert response is None + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance") + async def test_get_job_output(self, mock_job_instance): + """ + Tests that get_job_output returns the query results + for the given job id + """ + response = { + "kind": "bigquery#tableDataList", + "etag": "test_etag", + "schema": {"fields": [{"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"}]}, + "jobReference": { + "projectId": "test_astronomer-airflow-providers", + "jobId": "test_jobid", + "location": "US", + }, + "totalRows": "10", + "rows": [{"f": [{"v": "42"}, {"v": "monthy python"}]}, {"f": [{"v": "42"}, {"v": "fishy fish"}]}], + "totalBytesProcessed": "0", + "jobComplete": True, + "cacheHit": False, + } + hook = BigQueryAsyncHook() + mock_job_client = AsyncMock(Job) + mock_job_instance.return_value = mock_job_client + mock_job_client.get_query_results.return_value = response + resp = await hook.get_job_output(job_id=JOB_ID, project_id=PROJECT_ID) + assert resp == response + + @pytest.mark.parametrize( + "records,pass_value,tolerance", [(["str"], "str", None), ([2], 2, None), ([0], 2, 1), ([4], 2, 1)] + ) + def test_value_check_success(self, records, pass_value, tolerance): + """ + Assert that the value_check method succeeds + """ + hook = BigQueryAsyncHook() + query = "SELECT COUNT(*) from Any" + response = hook.value_check(query, pass_value, records, tolerance) + assert response is None + + @pytest.mark.parametrize( + "records,pass_value,tolerance", + [([], "", None), (["str"], "str1", None), ([2], 21, None), ([5], 2, 1), (["str"], 2, None)], + ) + def test_value_check_fail(self, records, pass_value, tolerance): + """Assert that value_check raises an AirflowException""" + hook = 
BigQueryAsyncHook() + query = "SELECT COUNT(*) from Any" + + with pytest.raises(AirflowException) as ex: + hook.value_check(query, pass_value, records, tolerance) + assert isinstance(ex.value, AirflowException) + + @pytest.mark.parametrize( + "records,pass_value,tolerance, expected", + [ + ([2.0], 2.0, None, [True]), + ([2.0], 2.1, None, [False]), + ([2.0], 2.0, 0.5, [True]), + ([1.0], 2.0, 0.5, [True]), + ([3.0], 2.0, 0.5, [True]), + ([0.9], 2.0, 0.5, [False]), + ([3.1], 2.0, 0.5, [False]), + ], + ) + def test_get_numeric_matches(self, records, pass_value, tolerance, expected): + """Assert that every record matches pass_value within the given tolerance""" + + assert BigQueryAsyncHook._get_numeric_matches(records, pass_value, tolerance) == expected + + @pytest.mark.parametrize("test_input,expected", [(5.0, 5.0), (5, 5.0), ("5", 5), ("str", "str")]) + def test_convert_to_float_if_possible(self, test_input, expected): + """ + Assert that the value is cast to float when possible, + otherwise the original value is returned + """ + + assert BigQueryAsyncHook._convert_to_float_if_possible(test_input) == expected + + @pytest.mark.asyncio + @mock.patch("aiohttp.client.ClientSession") + async def test_get_table_client(self, mock_session): + """Test the get_table_client async method and check that the return value is a + Table instance""" + hook = BigQueryTableAsyncHook() + result = await hook.get_table_client( + dataset=DATASET_ID, project_id=PROJECT_ID, table_id=TABLE_ID, session=mock_session + ) + assert isinstance(result, Table_async) diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 7d53a017c12c6..4aebe77cd4460 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -23,8 +23,12 @@ from google.cloud.bigquery import DEFAULT_RETRY from google.cloud.exceptions import Conflict -from airflow.exceptions import AirflowException, AirflowTaskTimeout +from airflow.exceptions import AirflowException, AirflowTaskTimeout, TaskDeferred +from airflow.models import DAG +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCheckAsyncOperator, BigQueryCheckOperator, BigQueryConsoleIndexableLink, BigQueryConsoleLink, @@ -34,20 +38,32 @@ BigQueryDeleteDatasetOperator, BigQueryDeleteTableOperator, BigQueryExecuteQueryOperator, + BigQueryGetDataAsyncOperator, BigQueryGetDataOperator, BigQueryGetDatasetOperator, BigQueryGetDatasetTablesOperator, + BigQueryInsertJobAsyncOperator, BigQueryInsertJobOperator, + BigQueryIntervalCheckAsyncOperator, BigQueryIntervalCheckOperator, BigQueryPatchDatasetOperator, BigQueryUpdateDatasetOperator, BigQueryUpdateTableOperator, BigQueryUpdateTableSchemaOperator, BigQueryUpsertTableOperator, + BigQueryValueCheckAsyncOperator, BigQueryValueCheckOperator, ) +from airflow.providers.google.cloud.triggers.bigquery import ( + BigQueryCheckTrigger, + BigQueryGetDataTrigger, + BigQueryInsertJobTrigger, + BigQueryIntervalCheckTrigger, + BigQueryValueCheckTrigger, +) from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.timezone import datetime +from airflow.utils.types import DagRunType from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags, clear_db_xcom TASK_ID = 'test-bq-generic-operator' @@ -71,6 +87,7 @@ 'enableRefresh': True, 
'refreshIntervalMs': 2000000, } +TEST_TABLE = "test-table" class TestBigQueryCreateEmptyTableOperator(unittest.TestCase): @@ -1103,3 +1120,536 @@ def test_execute_no_force_rerun(self, mock_hook): # No force rerun with pytest.raises(AirflowException): op.execute(context=MagicMock()) + + +@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") +def test_bigquery_insert_job_operator_async(mock_hook): + """ + Asserts that a task is deferred and a BigQueryInsertJobTrigger will be fired + when the BigQueryInsertJobAsyncOperator is executed. + """ + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + + configuration = { + "query": { + "query": "SELECT * FROM any", + "useLegacySql": False, + } + } + mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) + + op = BigQueryInsertJobAsyncOperator( + task_id="insert_query_job", + configuration=configuration, + location=TEST_DATASET_LOCATION, + job_id=job_id, + project_id=TEST_GCP_PROJECT_ID, + ) + + with pytest.raises(TaskDeferred) as exc: + op.execute(create_context(op)) + + assert isinstance( + exc.value.trigger, BigQueryInsertJobTrigger + ), "Trigger is not a BigQueryInsertJobTrigger" + + +def test_bigquery_insert_job_operator_execute_failure(): + """Tests that an AirflowException is raised in case of error event""" + configuration = { + "query": { + "query": "SELECT * FROM any", + "useLegacySql": False, + } + } + job_id = "123456" + + operator = BigQueryInsertJobAsyncOperator( + task_id="insert_query_job", + configuration=configuration, + location=TEST_DATASET_LOCATION, + job_id=job_id, + project_id=TEST_GCP_PROJECT_ID, + ) + + with pytest.raises(AirflowException): + operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"}) + + +def create_context(task): + dag = DAG(dag_id="dag") + logical_date = datetime(2022, 1, 1, 0, 0, 0) + dag_run = DagRun( + dag_id=dag.dag_id, + execution_date=logical_date, + run_id=DagRun.generate_run_id(DagRunType.MANUAL, logical_date), + ) + task_instance = TaskInstance(task=task) + task_instance.dag_run = dag_run + task_instance.dag_id = dag.dag_id + task_instance.xcom_push = mock.Mock() + return { + "dag": dag, + "run_id": dag_run.run_id, + "task": task, + "ti": task_instance, + "task_instance": task_instance, + "logical_date": logical_date, + } + + +def test_bigquery_insert_job_operator_execute_complete(): + """Asserts that logging occurs as expected""" + configuration = { + "query": { + "query": "SELECT * FROM any", + "useLegacySql": False, + } + } + job_id = "123456" + + operator = BigQueryInsertJobAsyncOperator( + task_id="insert_query_job", + configuration=configuration, + location=TEST_DATASET_LOCATION, + job_id=job_id, + project_id=TEST_GCP_PROJECT_ID, + ) + with mock.patch.object(operator.log, "info") as mock_log_info: + operator.execute_complete( + context=create_context(operator), + event={"status": "success", "message": "Job completed", "job_id": job_id}, + ) + mock_log_info.assert_called_with("%s completed with response %s ", "insert_query_job", "Job completed") + + +@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") +def test_bigquery_insert_job_operator_with_job_id_generate(mock_hook): + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + + configuration = { + "query": { + "query": "SELECT * FROM any", + "useLegacySql": False, + } + } + + mock_hook.return_value.insert_job.side_effect = Conflict("any") + job = MagicMock( + 
job_id=real_job_id, + error_result=False, + state="PENDING", + done=lambda: False, + ) + mock_hook.return_value.get_job.return_value = job + + op = BigQueryInsertJobAsyncOperator( + task_id="insert_query_job", + configuration=configuration, + location=TEST_DATASET_LOCATION, + job_id=job_id, + project_id=TEST_GCP_PROJECT_ID, + reattach_states={"PENDING"}, + ) + + with pytest.raises(TaskDeferred): + op.execute(create_context(op)) + + mock_hook.return_value.generate_job_id.assert_called_once_with( + job_id=job_id, + dag_id="adhoc_airflow", + task_id="insert_query_job", + logical_date=datetime(2022, 1, 1, 0, 0), + configuration=configuration, + force_rerun=True, + ) + + +@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") +def test_execute_reattach(mock_hook): + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + mock_hook.return_value.generate_job_id.return_value = f"{job_id}_{hash_}" + + configuration = { + "query": { + "query": "SELECT * FROM any", + "useLegacySql": False, + } + } + + mock_hook.return_value.insert_job.side_effect = Conflict("any") + job = MagicMock( + job_id=real_job_id, + error_result=False, + state="PENDING", + done=lambda: False, + ) + mock_hook.return_value.get_job.return_value = job + + op = BigQueryInsertJobAsyncOperator( + task_id="insert_query_job", + configuration=configuration, + location=TEST_DATASET_LOCATION, + job_id=job_id, + project_id=TEST_GCP_PROJECT_ID, + reattach_states={"PENDING"}, + ) + + with pytest.raises(TaskDeferred): + op.execute(create_context(op)) + + mock_hook.return_value.get_job.assert_called_once_with( + location=TEST_DATASET_LOCATION, + job_id=real_job_id, + project_id=TEST_GCP_PROJECT_ID, + ) + + job._begin.assert_called_once_with() + + +@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") +def test_execute_force_rerun(mock_hook): + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + mock_hook.return_value.generate_job_id.return_value = f"{job_id}_{hash_}" + + configuration = { + "query": { + "query": "SELECT * FROM any", + "useLegacySql": False, + } + } + + mock_hook.return_value.insert_job.side_effect = Conflict("any") + job = MagicMock( + job_id=real_job_id, + error_result=False, + state="DONE", + done=lambda: False, + ) + mock_hook.return_value.get_job.return_value = job + + op = BigQueryInsertJobAsyncOperator( + task_id="insert_query_job", + configuration=configuration, + location=TEST_DATASET_LOCATION, + job_id=job_id, + project_id=TEST_GCP_PROJECT_ID, + reattach_states={"PENDING"}, + ) + + with pytest.raises(AirflowException) as exc: + op.execute(create_context(op)) + + expected_exception_msg = ( + f"Job with id: {real_job_id} already exists and is in {job.state} state. " + f"If you want to force rerun it consider setting `force_rerun=True`." + f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" + ) + + assert str(exc.value) == expected_exception_msg + + mock_hook.return_value.get_job.assert_called_once_with( + location=TEST_DATASET_LOCATION, + job_id=real_job_id, + project_id=TEST_GCP_PROJECT_ID, + ) + + +@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") +def test_bigquery_check_operator_async(mock_hook): + """ + Asserts that a task is deferred and a BigQueryCheckTrigger will be fired + when the BigQueryCheckAsyncOperator is executed. 
+ """ + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + + mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) + + op = BigQueryCheckAsyncOperator( + task_id="bq_check_operator_job", + sql="SELECT * FROM any", + location=TEST_DATASET_LOCATION, + ) + + with pytest.raises(TaskDeferred) as exc: + op.execute(create_context(op)) + + assert isinstance(exc.value.trigger, BigQueryCheckTrigger), "Trigger is not a BigQueryCheckTrigger" + + +def test_bigquery_check_operator_execute_failure(): + """Tests that an AirflowException is raised in case of error event""" + + operator = BigQueryCheckAsyncOperator( + task_id="bq_check_operator_execute_failure", sql="SELECT * FROM any", location=TEST_DATASET_LOCATION + ) + + with pytest.raises(AirflowException): + operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"}) + + +def test_bigquery_check_op_execute_complete_with_no_records(): + """Asserts that exception is raised with correct expected exception message""" + + operator = BigQueryCheckAsyncOperator( + task_id="bq_check_operator_execute_complete", sql="SELECT * FROM any", location=TEST_DATASET_LOCATION + ) + + with pytest.raises(AirflowException) as exc: + operator.execute_complete(context=None, event={"status": "success", "records": None}) + + expected_exception_msg = "The query returned None" + + assert str(exc.value) == expected_exception_msg + + +def test_bigquery_check_op_execute_complete_with_non_boolean_records(): + """Executing a sql which returns a non-boolean value should raise exception""" + + test_sql = "SELECT * FROM any" + + operator = BigQueryCheckAsyncOperator( + task_id="bq_check_operator_execute_complete", sql=test_sql, location=TEST_DATASET_LOCATION + ) + + expected_exception_msg = f"Test failed.\nQuery:\n{test_sql}\nResults:\n{[20, False]!s}" + + with pytest.raises(AirflowException) as exc: + operator.execute_complete(context=None, event={"status": "success", "records": [20, False]}) + + assert str(exc.value) == expected_exception_msg + + +def test_bigquery_check_operator_execute_complete(): + """Asserts that logging occurs as expected""" + + operator = BigQueryCheckAsyncOperator( + task_id="bq_check_operator_execute_complete", sql="SELECT * FROM any", location=TEST_DATASET_LOCATION + ) + + with mock.patch.object(operator.log, "info") as mock_log_info: + operator.execute_complete(context=None, event={"status": "success", "records": [20]}) + mock_log_info.assert_called_with("Success.") + + +def test_bigquery_interval_check_operator_execute_complete(): + """Asserts that logging occurs as expected""" + + operator = BigQueryIntervalCheckAsyncOperator( + task_id="bq_interval_check_operator_execute_complete", + table="test_table", + metrics_thresholds={"COUNT(*)": 1.5}, + location=TEST_DATASET_LOCATION, + ) + + with mock.patch.object(operator.log, "info") as mock_log_info: + operator.execute_complete(context=None, event={"status": "success", "message": "Job completed"}) + mock_log_info.assert_called_with( + "%s completed with response %s ", "bq_interval_check_operator_execute_complete", "success" + ) + + +def test_bigquery_interval_check_operator_execute_failure(): + """Tests that an AirflowException is raised in case of error event""" + + operator = BigQueryIntervalCheckAsyncOperator( + task_id="bq_interval_check_operator_execute_complete", + table="test_table", + metrics_thresholds={"COUNT(*)": 1.5}, + location=TEST_DATASET_LOCATION, + ) + + with 
pytest.raises(AirflowException): + operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"}) + + +@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") +def test_bigquery_interval_check_operator_async(mock_hook): + """ + Asserts that a task is deferred and a BigQueryIntervalCheckTrigger will be fired + when the BigQueryIntervalCheckAsyncOperator is executed. + """ + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + + mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) + + op = BigQueryIntervalCheckAsyncOperator( + task_id="bq_interval_check_operator_execute_complete", + table="test_table", + metrics_thresholds={"COUNT(*)": 1.5}, + location=TEST_DATASET_LOCATION, + ) + + with pytest.raises(TaskDeferred) as exc: + op.execute(create_context(op)) + + assert isinstance( + exc.value.trigger, BigQueryIntervalCheckTrigger + ), "Trigger is not a BigQueryIntervalCheckTrigger" + + +@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") +def test_bigquery_get_data_operator_async_with_selected_fields(mock_hook): + """ + Asserts that a task is deferred and a BigQuerygetDataTrigger will be fired + when the BigQueryGetDataAsyncOperator is executed. + """ + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + + mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) + + op = BigQueryGetDataAsyncOperator( + task_id="get_data_from_bq", + dataset_id=TEST_DATASET, + table_id=TEST_TABLE, + max_results=100, + selected_fields="value,name", + ) + + with pytest.raises(TaskDeferred) as exc: + op.execute(create_context(op)) + + assert isinstance(exc.value.trigger, BigQueryGetDataTrigger), "Trigger is not a BigQueryGetDataTrigger" + + +@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") +def test_bigquery_get_data_operator_async_without_selected_fields(mock_hook): + """ + Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired + when the BigQueryGetDataAsyncOperator is executed. 
+ """ + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + + mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) + + op = BigQueryGetDataAsyncOperator( + task_id="get_data_from_bq", + dataset_id=TEST_DATASET, + table_id=TEST_TABLE, + max_results=100, + ) + + with pytest.raises(TaskDeferred) as exc: + op.execute(create_context(op)) + + assert isinstance(exc.value.trigger, BigQueryGetDataTrigger), "Trigger is not a BigQueryGetDataTrigger" + + +def test_bigquery_get_data_operator_execute_failure(): + """Tests that an AirflowException is raised in case of error event""" + + operator = BigQueryGetDataAsyncOperator( + task_id="get_data_from_bq", + dataset_id=TEST_DATASET, + table_id="any", + max_results=100, + ) + + with pytest.raises(AirflowException): + operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"}) + + +def test_bigquery_get_data_op_execute_complete_with_records(): + """Asserts that exception is raised with correct expected exception message""" + + operator = BigQueryGetDataAsyncOperator( + task_id="get_data_from_bq", + dataset_id=TEST_DATASET, + table_id="any", + max_results=100, + ) + + with mock.patch.object(operator.log, "info") as mock_log_info: + operator.execute_complete(context=None, event={"status": "success", "records": [20]}) + mock_log_info.assert_called_with("Total extracted rows: %s", 1) + + +def _get_value_check_async_operator(use_legacy_sql: bool = False): + """Helper function to initialise BigQueryValueCheckOperatorAsync operator""" + query = "SELECT COUNT(*) FROM Any" + pass_val = 2 + + return BigQueryValueCheckAsyncOperator( + task_id="check_value", + sql=query, + pass_value=pass_val, + use_legacy_sql=use_legacy_sql, + ) + + +@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") +def test_bigquery_value_check_async(mock_hook): + """ + Asserts that a task is deferred and a BigQueryValueCheckTrigger will be fired + when the BigQueryValueCheckOperatorAsync is executed. 
+ """ + operator = _get_value_check_async_operator(True) + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) + with pytest.raises(TaskDeferred) as exc: + operator.execute(create_context(operator)) + + assert isinstance( + exc.value.trigger, BigQueryValueCheckTrigger + ), "Trigger is not a BigQueryValueCheckTrigger" + + +def test_bigquery_value_check_operator_execute_complete_success(): + """Tests response message in case of success event""" + operator = _get_value_check_async_operator() + + assert ( + operator.execute_complete(context=None, event={"status": "success", "message": "Job completed!"}) + is None + ) + + +def test_bigquery_value_check_operator_execute_complete_failure(): + """Tests that an AirflowException is raised in case of error event""" + operator = _get_value_check_async_operator() + + with pytest.raises(AirflowException): + operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"}) + + +@pytest.mark.parametrize( + "kwargs, expected", + [ + ({"sql": "SELECT COUNT(*) from Any"}, "missing keyword argument 'pass_value'"), + ({"pass_value": "Any"}, "missing keyword argument 'sql'"), + ], +) +def test_bigquery_value_check_missing_param(kwargs, expected): + """Assert the exception if require param not pass to BigQueryValueCheckOperatorAsync operator""" + with pytest.raises(AirflowException) as missing_param: + BigQueryValueCheckAsyncOperator(**kwargs) + assert missing_param.value.args[0] == expected + + +def test_bigquery_value_check_empty(): + """Assert the exception if require param not pass to BigQueryValueCheckOperatorAsync operator""" + expected, expected1 = ( + "missing keyword arguments 'sql', 'pass_value'", + "missing keyword arguments 'pass_value', 'sql'", + ) + with pytest.raises(AirflowException) as missing_param: + BigQueryValueCheckAsyncOperator(kwargs={}) + assert (missing_param.value.args[0] == expected) or (missing_param.value.args[0] == expected1) diff --git a/tests/providers/google/cloud/sensors/test_bigquery.py b/tests/providers/google/cloud/sensors/test_bigquery.py index 87ec3dbacb1ba..5ea3b9b67a289 100644 --- a/tests/providers/google/cloud/sensors/test_bigquery.py +++ b/tests/providers/google/cloud/sensors/test_bigquery.py @@ -17,10 +17,15 @@ from unittest import TestCase, mock +import pytest + +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.google.cloud.sensors.bigquery import ( + BigQueryTableExistenceAsyncSensor, BigQueryTableExistenceSensor, BigQueryTablePartitionExistenceSensor, ) +from airflow.providers.google.cloud.triggers.bigquery import BigQueryTableExistenceTrigger TEST_PROJECT_ID = "test_project" TEST_DATASET_ID = 'test_dataset' @@ -87,3 +92,66 @@ def test_passing_arguments_to_hook(self, mock_hook): table_id=TEST_TABLE_ID, partition_id=TEST_PARTITION_ID, ) + + +@pytest.fixture() +def context(): + """ + Creates an empty context. + """ + context = {} + yield context + + +class TestBigQueryTableExistenceAsyncSensor(TestCase): + def test_big_query_table_existence_sensor_async(self): + """ + Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired + when the BigQueryTableExistenceAsyncSensor is executed. 
+ """ + task = BigQueryTableExistenceAsyncSensor( + task_id="check_table_exists", + project_id=TEST_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + ) + with pytest.raises(TaskDeferred) as exc: + task.execute(context={}) + assert isinstance( + exc.value.trigger, BigQueryTableExistenceTrigger + ), "Trigger is not a BigQueryTableExistenceTrigger" + + def test_big_query_table_existence_sensor_async_execute_failure(self): + """Tests that an AirflowException is raised in case of error event""" + task = BigQueryTableExistenceAsyncSensor( + task_id="task-id", + project_id=TEST_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + ) + with pytest.raises(AirflowException): + task.execute_complete(context={}, event={"status": "error", "message": "test failure message"}) + + def test_big_query_table_existence_sensor_async_execute_complete(self): + """Asserts that logging occurs as expected""" + task = BigQueryTableExistenceAsyncSensor( + task_id="task-id", + project_id=TEST_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + ) + table_uri = f"{TEST_PROJECT_ID}:{TEST_DATASET_ID}.{TEST_TABLE_ID}" + with mock.patch.object(task.log, "info") as mock_log_info: + task.execute_complete(context={}, event={"status": "success", "message": "Job completed"}) + mock_log_info.assert_called_with("Sensor checks existence of table: %s", table_uri) + + def test_big_query_sensor_async_execute_complete_event_none(self): + """Asserts that logging occurs as expected""" + task = BigQueryTableExistenceAsyncSensor( + task_id="task-id", + project_id=TEST_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + ) + with pytest.raises(AirflowException): + task.execute_complete(context={}, event=None) diff --git a/tests/providers/google/cloud/triggers/__init__.py b/tests/providers/google/cloud/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/google/cloud/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py new file mode 100644 index 0000000000000..7a771a124c374 --- /dev/null +++ b/tests/providers/google/cloud/triggers/test_bigquery.py @@ -0,0 +1,1040 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import logging +import sys +from typing import Any, Dict + +import pytest +from aiohttp import ClientResponseError, RequestInfo +from gcloud.aio.bigquery import Table +from multidict import CIMultiDict +from yarl import URL + +from airflow.providers.google.cloud.hooks.bigquery import BigQueryTableAsyncHook +from airflow.providers.google.cloud.triggers.bigquery import ( + BigQueryCheckTrigger, + BigQueryGetDataTrigger, + BigQueryInsertJobTrigger, + BigQueryIntervalCheckTrigger, + BigQueryTableExistenceTrigger, + BigQueryValueCheckTrigger, +) +from airflow.triggers.base import TriggerEvent + +if sys.version_info < (3, 8): + from asynctest import mock + from asynctest.mock import CoroutineMock as AsyncMock +else: + from unittest import mock + from unittest.mock import AsyncMock + +TEST_CONN_ID = "bq_default" +TEST_JOB_ID = "1234" +RUN_ID = "1" +RETRY_LIMIT = 2 +RETRY_DELAY = 1.0 +TEST_GCP_PROJECT_ID = "test-project" +TEST_DATASET_ID = "bq_dataset" +TEST_TABLE_ID = "bq_table" +POLLING_PERIOD_SECONDS = 4.0 +TEST_SQL_QUERY = "SELECT count(*) from Any" +TEST_PASS_VALUE = 2 +TEST_TOLERANCE = 1 +TEST_FIRST_JOB_ID = "5678" +TEST_SECOND_JOB_ID = "6789" +TEST_METRIC_THRESHOLDS: Dict[str, int] = {} +TEST_DATE_FILTER_COLUMN = "ds" +TEST_DAYS_BACK = -7 +TEST_RATIO_FORMULA = "max_over_min" +TEST_IGNORE_ZERO = True +TEST_GCP_CONN_ID = "TEST_GCP_CONN_ID" +TEST_HOOK_PARAMS: Dict[str, Any] = {} + + +def test_bigquery_insert_job_op_trigger_serialization(): + """ + Asserts that the BigQueryInsertJobTrigger correctly serializes its arguments + and classpath. + """ + trigger = BigQueryInsertJobTrigger( + TEST_CONN_ID, + TEST_JOB_ID, + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + POLLING_PERIOD_SECONDS, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger" + assert kwargs == { + "conn_id": TEST_CONN_ID, + "job_id": TEST_JOB_ID, + "project_id": TEST_GCP_PROJECT_ID, + "dataset_id": TEST_DATASET_ID, + "table_id": TEST_TABLE_ID, + "poll_interval": POLLING_PERIOD_SECONDS, + } + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +async def test_bigquery_insert_job_op_trigger_success(mock_job_status): + """ + Tests the BigQueryInsertJobTrigger only fires once the query execution reaches a successful state. + """ + mock_job_status.return_value = "success" + + trigger = BigQueryInsertJobTrigger( + TEST_CONN_ID, + TEST_JOB_ID, + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "success", "message": "Job completed", "job_id": TEST_JOB_ID}) == actual + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance") +async def test_bigquery_insert_job_trigger_running(mock_job_instance, caplog): + """ + Test that BigQuery Triggers do not fire while a query is still running. 
+ """ + + from gcloud.aio.bigquery import Job + + mock_job_client = AsyncMock(Job) + mock_job_instance.return_value = mock_job_client + mock_job_instance.return_value.result.side_effect = OSError + caplog.set_level(logging.INFO) + + trigger = BigQueryInsertJobTrigger( + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + + assert f"Using the connection {TEST_CONN_ID} ." in caplog.text + + assert "Query is still running..." in caplog.text + assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text + + # Prevents error when task is destroyed while in "pending" state + asyncio.get_event_loop().stop() + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance") +async def test_bigquery_get_data_trigger_running(mock_job_instance, caplog): + """ + Test that BigQuery Triggers do not fire while a query is still running. + """ + + from gcloud.aio.bigquery import Job + + mock_job_client = AsyncMock(Job) + mock_job_instance.return_value = mock_job_client + mock_job_instance.return_value.result.side_effect = OSError + caplog.set_level(logging.INFO) + + trigger = BigQueryGetDataTrigger( + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + + assert f"Using the connection {TEST_CONN_ID} ." in caplog.text + + assert "Query is still running..." in caplog.text + assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text + + # Prevents error when task is destroyed while in "pending" state + asyncio.get_event_loop().stop() + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance") +async def test_bigquery_check_trigger_running(mock_job_instance, caplog): + """ + Test that BigQuery Triggers do not fire while a query is still running. + """ + + from gcloud.aio.bigquery import Job + + mock_job_client = AsyncMock(Job) + mock_job_instance.return_value = mock_job_client + mock_job_instance.return_value.result.side_effect = OSError + caplog.set_level(logging.INFO) + + trigger = BigQueryCheckTrigger( + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + + assert f"Using the connection {TEST_CONN_ID} ." in caplog.text + + assert "Query is still running..." in caplog.text + assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text + + # Prevents error when task is destroyed while in "pending" state + asyncio.get_event_loop().stop() + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +async def test_bigquery_op_trigger_terminated(mock_job_status, caplog): + """ + Test that BigQuery Triggers fire the correct event in case of an error. 
+ """ + # Set the status to a value other than success or pending + + mock_job_status.return_value = "error" + + trigger = BigQueryInsertJobTrigger( + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": "error"}) == actual + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +async def test_bigquery_check_trigger_terminated(mock_job_status, caplog): + """ + Test that BigQuery Triggers fire the correct event in case of an error. + """ + # Set the status to a value other than success or pending + + mock_job_status.return_value = "error" + + trigger = BigQueryCheckTrigger( + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": "error"}) == actual + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +async def test_bigquery_get_data_trigger_terminated(mock_job_status, caplog): + """ + Test that BigQuery Triggers fire the correct event in case of an error. + """ + # Set the status to a value other than success or pending + + mock_job_status.return_value = "error" + + trigger = BigQueryGetDataTrigger( + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": "error"}) == actual + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +async def test_bigquery_op_trigger_exception(mock_job_status, caplog): + """ + Test that BigQuery Triggers fire the correct event in case of an error. + """ + mock_job_status.side_effect = Exception("Test exception") + + trigger = BigQueryInsertJobTrigger( + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +async def test_bigquery_check_trigger_exception(mock_job_status, caplog): + """ + Test that BigQuery Triggers fire the correct event in case of an error. 
+ """ + mock_job_status.side_effect = Exception("Test exception") + + trigger = BigQueryCheckTrigger( + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +async def test_bigquery_get_data_trigger_exception(mock_job_status, caplog): + """ + Test that BigQuery Triggers fire the correct event in case of an error. + """ + mock_job_status.side_effect = Exception("Test exception") + + trigger = BigQueryGetDataTrigger( + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual + + +def test_bigquery_check_op_trigger_serialization(): + """ + Asserts that the BigQueryCheckTrigger correctly serializes its arguments + and classpath. + """ + trigger = BigQueryCheckTrigger( + TEST_CONN_ID, + TEST_JOB_ID, + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + POLLING_PERIOD_SECONDS, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger" + assert kwargs == { + "conn_id": TEST_CONN_ID, + "job_id": TEST_JOB_ID, + "dataset_id": TEST_DATASET_ID, + "project_id": TEST_GCP_PROJECT_ID, + "table_id": TEST_TABLE_ID, + "poll_interval": POLLING_PERIOD_SECONDS, + } + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output") +async def test_bigquery_check_op_trigger_success_with_data(mock_job_output, mock_job_status): + """ + Test the BigQueryCheckTrigger only fires once the query execution reaches a successful state. + """ + mock_job_status.return_value = "success" + mock_job_output.return_value = { + "kind": "bigquery#getQueryResultsResponse", + "etag": "test_etag", + "schema": {"fields": [{"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"}]}, + "jobReference": { + "projectId": "test_airflow-providers", + "jobId": "test_jobid", + "location": "US", + }, + "totalRows": "1", + "rows": [{"f": [{"v": "22"}]}], + "totalBytesProcessed": "0", + "jobComplete": True, + "cacheHit": False, + } + + trigger = BigQueryCheckTrigger( + TEST_CONN_ID, + TEST_JOB_ID, + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + + assert TriggerEvent({"status": "success", "records": ["22"]}) == actual + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output") +async def test_bigquery_check_op_trigger_success_without_data(mock_job_output, mock_job_status): + """ + Tests that BigQueryCheckTrigger sends TriggerEvent as { "status": "success", "records": None} + when no rows are available in the query result. 
+ """ + mock_job_status.return_value = "success" + mock_job_output.return_value = { + "kind": "bigquery#getQueryResultsResponse", + "etag": "test_etag", + "schema": { + "fields": [ + {"name": "value", "type": "INTEGER", "mode": "NULLABLE"}, + {"name": "name", "type": "STRING", "mode": "NULLABLE"}, + {"name": "ds", "type": "DATE", "mode": "NULLABLE"}, + ] + }, + "jobReference": { + "projectId": "test_airflow-airflow-providers", + "jobId": "test_jobid", + "location": "US", + }, + "totalRows": "0", + "totalBytesProcessed": "0", + "jobComplete": True, + "cacheHit": False, + } + + trigger = BigQueryCheckTrigger( + TEST_CONN_ID, + TEST_JOB_ID, + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + POLLING_PERIOD_SECONDS, + ) + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "success", "records": None}) == actual + + +def test_bigquery_get_data_trigger_serialization(): + """ + Asserts that the BigQueryGetDataTrigger correctly serializes its arguments + and classpath. + """ + trigger = BigQueryGetDataTrigger( + conn_id=TEST_CONN_ID, + job_id=TEST_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger" + assert kwargs == { + "conn_id": TEST_CONN_ID, + "job_id": TEST_JOB_ID, + "dataset_id": TEST_DATASET_ID, + "project_id": TEST_GCP_PROJECT_ID, + "table_id": TEST_TABLE_ID, + "poll_interval": POLLING_PERIOD_SECONDS, + } + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output") +async def test_bigquery_get_data_trigger_success_with_data(mock_job_output, mock_job_status): + """ + Tests that BigQueryGetDataTrigger only fires once the query execution reaches a successful state. + """ + mock_job_status.return_value = "success" + mock_job_output.return_value = { + "kind": "bigquery#tableDataList", + "etag": "test_etag", + "schema": {"fields": [{"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"}]}, + "jobReference": { + "projectId": "test-airflow-providers", + "jobId": "test_jobid", + "location": "US", + }, + "totalRows": "10", + "rows": [{"f": [{"v": "42"}, {"v": "monthy python"}]}, {"f": [{"v": "42"}, {"v": "fishy fish"}]}], + "totalBytesProcessed": "0", + "jobComplete": True, + "cacheHit": False, + } + + trigger = BigQueryGetDataTrigger( + TEST_CONN_ID, + TEST_JOB_ID, + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + # # The extracted row will be parsed and formatted to retrieve the value from the + # # structure - 'rows":[{"f":[{"v":"42"},{"v":"monthy python"}]},{"f":[{"v":"42"},{"v":"fishy fish"}]}] + + assert ( + TriggerEvent( + { + "status": "success", + "message": "success", + "records": [["42", "monthy python"], ["42", "fishy fish"]], + } + ) + == actual + ) + # Prevents error when task is destroyed while in "pending" state + asyncio.get_event_loop().stop() + + +def test_bigquery_interval_check_trigger_serialization(): + """ + Asserts that the BigQueryIntervalCheckTrigger correctly serializes its arguments + and classpath. 
+ """ + trigger = BigQueryIntervalCheckTrigger( + TEST_CONN_ID, + TEST_FIRST_JOB_ID, + TEST_SECOND_JOB_ID, + TEST_GCP_PROJECT_ID, + TEST_TABLE_ID, + TEST_METRIC_THRESHOLDS, + TEST_DATE_FILTER_COLUMN, + TEST_DAYS_BACK, + TEST_RATIO_FORMULA, + TEST_IGNORE_ZERO, + TEST_DATASET_ID, + TEST_TABLE_ID, + POLLING_PERIOD_SECONDS, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger" + assert kwargs == { + "conn_id": TEST_CONN_ID, + "first_job_id": TEST_FIRST_JOB_ID, + "second_job_id": TEST_SECOND_JOB_ID, + "project_id": TEST_GCP_PROJECT_ID, + "table": TEST_TABLE_ID, + "metrics_thresholds": TEST_METRIC_THRESHOLDS, + "date_filter_column": TEST_DATE_FILTER_COLUMN, + "days_back": TEST_DAYS_BACK, + "ratio_formula": TEST_RATIO_FORMULA, + "ignore_zero": TEST_IGNORE_ZERO, + } + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output") +async def test_bigquery_interval_check_trigger_success(mock_get_job_output, mock_job_status): + """ + Tests the BigQueryIntervalCheckTrigger only fires once the query execution reaches a successful state. + """ + mock_job_status.return_value = "success" + mock_get_job_output.return_value = ["0"] + + trigger = BigQueryIntervalCheckTrigger( + conn_id=TEST_CONN_ID, + first_job_id=TEST_FIRST_JOB_ID, + second_job_id=TEST_SECOND_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + table=TEST_TABLE_ID, + metrics_thresholds=TEST_METRIC_THRESHOLDS, + date_filter_column=TEST_DATE_FILTER_COLUMN, + days_back=TEST_DAYS_BACK, + ratio_formula=TEST_RATIO_FORMULA, + ignore_zero=TEST_IGNORE_ZERO, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + assert actual == TriggerEvent({"status": "error", "message": "The second SQL query returned None"}) + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +async def test_bigquery_interval_check_trigger_pending(mock_job_status, caplog): + """ + Tests that the BigQueryIntervalCheckTrigger do not fire while a query is still running. + """ + mock_job_status.return_value = "pending" + caplog.set_level(logging.INFO) + + trigger = BigQueryIntervalCheckTrigger( + conn_id=TEST_CONN_ID, + first_job_id=TEST_FIRST_JOB_ID, + second_job_id=TEST_SECOND_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + table=TEST_TABLE_ID, + metrics_thresholds=TEST_METRIC_THRESHOLDS, + date_filter_column=TEST_DATE_FILTER_COLUMN, + days_back=TEST_DAYS_BACK, + ratio_formula=TEST_RATIO_FORMULA, + ignore_zero=TEST_IGNORE_ZERO, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + + assert f"Using the connection {TEST_CONN_ID} ." in caplog.text + + assert "Query is still running..." in caplog.text + assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." 
in caplog.text + + # Prevents error when task is destroyed while in "pending" state + asyncio.get_event_loop().stop() + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +async def test_bigquery_interval_check_trigger_terminated(mock_job_status): + """ + Tests the BigQueryIntervalCheckTrigger fires the correct event in case of an error. + """ + # Set the status to a value other than success or pending + mock_job_status.return_value = "error" + trigger = BigQueryIntervalCheckTrigger( + conn_id=TEST_CONN_ID, + first_job_id=TEST_FIRST_JOB_ID, + second_job_id=TEST_SECOND_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + table=TEST_TABLE_ID, + metrics_thresholds=TEST_METRIC_THRESHOLDS, + date_filter_column=TEST_DATE_FILTER_COLUMN, + days_back=TEST_DAYS_BACK, + ratio_formula=TEST_RATIO_FORMULA, + ignore_zero=TEST_IGNORE_ZERO, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + + assert TriggerEvent({"status": "error", "message": "error", "data": None}) == actual + + +@pytest.mark.asyncio +@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") +async def test_bigquery_interval_check_trigger_exception(mock_job_status, caplog): + """ + Tests that the BigQueryIntervalCheckTrigger fires the correct event in case of an error. + """ + mock_job_status.side_effect = Exception("Test exception") + caplog.set_level(logging.DEBUG) + + trigger = BigQueryIntervalCheckTrigger( + conn_id=TEST_CONN_ID, + first_job_id=TEST_FIRST_JOB_ID, + second_job_id=TEST_SECOND_JOB_ID, + project_id=TEST_GCP_PROJECT_ID, + table=TEST_TABLE_ID, + metrics_thresholds=TEST_METRIC_THRESHOLDS, + date_filter_column=TEST_DATE_FILTER_COLUMN, + days_back=TEST_DAYS_BACK, + ratio_formula=TEST_RATIO_FORMULA, + ignore_zero=TEST_IGNORE_ZERO, + dataset_id=TEST_DATASET_ID, + table_id=TEST_TABLE_ID, + poll_interval=POLLING_PERIOD_SECONDS, + ) + + # trigger event is yielded so it creates a generator object + # so i have used async for to get all the values and added it to task + task = [i async for i in trigger.run()] + # since we use return as soon as we yield the trigger event + # at any given point there should be one trigger event returned to the task + # so we validate for length of task to be 1 + + assert len(task) == 1 + assert TriggerEvent({"status": "error", "message": "Test exception"}) in task + + +def test_bigquery_value_check_op_trigger_serialization(): + """ + Asserts that the BigQueryValueCheckTrigger correctly serializes its arguments + and classpath. 
+    """
+
+    trigger = BigQueryValueCheckTrigger(
+        conn_id=TEST_CONN_ID,
+        pass_value=TEST_PASS_VALUE,
+        job_id=TEST_JOB_ID,
+        dataset_id=TEST_DATASET_ID,
+        project_id=TEST_GCP_PROJECT_ID,
+        sql=TEST_SQL_QUERY,
+        table_id=TEST_TABLE_ID,
+        tolerance=TEST_TOLERANCE,
+        poll_interval=POLLING_PERIOD_SECONDS,
+    )
+    classpath, kwargs = trigger.serialize()
+
+    assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger"
+    assert kwargs == {
+        "conn_id": TEST_CONN_ID,
+        "pass_value": TEST_PASS_VALUE,
+        "job_id": TEST_JOB_ID,
+        "dataset_id": TEST_DATASET_ID,
+        "project_id": TEST_GCP_PROJECT_ID,
+        "sql": TEST_SQL_QUERY,
+        "table_id": TEST_TABLE_ID,
+        "tolerance": TEST_TOLERANCE,
+        "poll_interval": POLLING_PERIOD_SECONDS,
+    }
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_records")
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output")
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_value_check_op_trigger_success(mock_job_status, get_job_output, get_records):
+    """
+    Tests that the BigQueryValueCheckTrigger only fires once the query execution reaches a successful state.
+    """
+    mock_job_status.return_value = "success"
+    get_job_output.return_value = {}
+    get_records.return_value = [[2], [4]]
+
+    trigger = BigQueryValueCheckTrigger(
+        conn_id=TEST_CONN_ID,
+        pass_value=TEST_PASS_VALUE,
+        job_id=TEST_JOB_ID,
+        dataset_id=TEST_DATASET_ID,
+        project_id=TEST_GCP_PROJECT_ID,
+        sql=TEST_SQL_QUERY,
+        table_id=TEST_TABLE_ID,
+        tolerance=TEST_TOLERANCE,
+        poll_interval=POLLING_PERIOD_SECONDS,
+    )
+
+    asyncio.create_task(trigger.run().__anext__())
+    await asyncio.sleep(0.5)
+
+    generator = trigger.run()
+    actual = await generator.asend(None)
+    assert actual == TriggerEvent({"status": "success", "message": "Job completed", "records": [4]})
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_value_check_op_trigger_pending(mock_job_status, caplog):
+    """
+    Tests that the BigQueryValueCheckTrigger does not fire while the query is still running.
+    """
+    mock_job_status.return_value = "pending"
+    caplog.set_level(logging.INFO)
+
+    trigger = BigQueryValueCheckTrigger(
+        TEST_CONN_ID,
+        TEST_PASS_VALUE,
+        TEST_JOB_ID,
+        TEST_DATASET_ID,
+        TEST_GCP_PROJECT_ID,
+        TEST_SQL_QUERY,
+        TEST_TABLE_ID,
+        TEST_TOLERANCE,
+        POLLING_PERIOD_SECONDS,
+    )
+
+    task = asyncio.create_task(trigger.run().__anext__())
+    await asyncio.sleep(0.5)
+
+    # TriggerEvent was not returned
+    assert task.done() is False
+
+    assert "Query is still running..." in caplog.text
+
+    assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text
+
+    # Prevents error when task is destroyed while in "pending" state
+    asyncio.get_event_loop().stop()
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_value_check_op_trigger_fail(mock_job_status):
+    """
+    Tests that the BigQueryValueCheckTrigger fires an error event on a non-successful job status.
+    """
+    mock_job_status.return_value = "dummy"
+
+    trigger = BigQueryValueCheckTrigger(
+        TEST_CONN_ID,
+        TEST_PASS_VALUE,
+        TEST_JOB_ID,
+        TEST_DATASET_ID,
+        TEST_GCP_PROJECT_ID,
+        TEST_SQL_QUERY,
+        TEST_TABLE_ID,
+        TEST_TOLERANCE,
+        POLLING_PERIOD_SECONDS,
+    )
+
+    generator = trigger.run()
+    actual = await generator.asend(None)
+    assert TriggerEvent({"status": "error", "message": "dummy", "records": None}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_value_check_trigger_exception(mock_job_status):
+    """
+    Tests that the BigQueryValueCheckTrigger fires an error event if an exception is raised.
+    """
+    mock_job_status.side_effect = Exception("Test exception")
+
+    trigger = BigQueryValueCheckTrigger(
+        conn_id=TEST_CONN_ID,
+        sql=TEST_SQL_QUERY,
+        pass_value=TEST_PASS_VALUE,
+        tolerance=1,
+        job_id=TEST_JOB_ID,
+        project_id=TEST_GCP_PROJECT_ID,
+    )
+
+    # trigger.run() is an async generator, so use ``async for``
+    # to collect every event it yields into a list
+    task = [i async for i in trigger.run()]
+    # the trigger returns as soon as it yields its TriggerEvent,
+    # so exactly one event should have been collected;
+    # validate that the list has a length of 1
+
+    assert len(task) == 1
+    assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+
+
+def test_big_query_table_existence_trigger_serialization():
+    """
+    Asserts that the BigQueryTableExistenceTrigger correctly serializes its arguments
+    and classpath.
+    """
+    trigger = BigQueryTableExistenceTrigger(
+        TEST_GCP_PROJECT_ID,
+        TEST_DATASET_ID,
+        TEST_TABLE_ID,
+        TEST_GCP_CONN_ID,
+        TEST_HOOK_PARAMS,
+        POLLING_PERIOD_SECONDS,
+    )
+    classpath, kwargs = trigger.serialize()
+    assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger"
+    assert kwargs == {
+        "dataset_id": TEST_DATASET_ID,
+        "project_id": TEST_GCP_PROJECT_ID,
+        "table_id": TEST_TABLE_ID,
+        "gcp_conn_id": TEST_GCP_CONN_ID,
+        "poll_interval": POLLING_PERIOD_SECONDS,
+        "hook_params": TEST_HOOK_PARAMS,
+    }
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger._table_exists")
+async def test_big_query_table_existence_trigger_success(mock_table_exists):
+    """
+    Tests the success case for BigQueryTableExistenceTrigger.
+    """
+    mock_table_exists.return_value = True
+
+    trigger = BigQueryTableExistenceTrigger(
+        TEST_GCP_PROJECT_ID,
+        TEST_DATASET_ID,
+        TEST_TABLE_ID,
+        TEST_GCP_CONN_ID,
+        TEST_HOOK_PARAMS,
+        POLLING_PERIOD_SECONDS,
+    )
+
+    generator = trigger.run()
+    actual = await generator.asend(None)
+    assert TriggerEvent({"status": "success", "message": "success"}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger._table_exists")
+async def test_big_query_table_existence_trigger_pending(mock_table_exists):
+    """
+    Test that BigQueryTableExistenceTrigger keeps polling while the table does not exist.
+    """
+    mock_table_exists.return_value = False
+
+    trigger = BigQueryTableExistenceTrigger(
+        TEST_GCP_PROJECT_ID,
+        TEST_DATASET_ID,
+        TEST_TABLE_ID,
+        TEST_GCP_CONN_ID,
+        TEST_HOOK_PARAMS,
+        POLLING_PERIOD_SECONDS,
+    )
+    task = asyncio.create_task(trigger.run().__anext__())
+    await asyncio.sleep(0.5)
+
+    # TriggerEvent was not returned
+    assert task.done() is False
+    asyncio.get_event_loop().stop()
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger._table_exists")
+async def test_big_query_table_existence_trigger_exception(mock_table_exists):
+    """
+    Test that BigQueryTableExistenceTrigger fires an error event if an exception is raised.
+    """
+    mock_table_exists.side_effect = AsyncMock(side_effect=Exception("Test exception"))
+
+    trigger = BigQueryTableExistenceTrigger(
+        TEST_GCP_PROJECT_ID,
+        TEST_DATASET_ID,
+        TEST_TABLE_ID,
+        TEST_GCP_CONN_ID,
+        TEST_HOOK_PARAMS,
+        POLLING_PERIOD_SECONDS,
+    )
+    task = [i async for i in trigger.run()]
+    assert len(task) == 1
+    assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryTableAsyncHook.get_table_client")
+async def test_table_exists(mock_get_table_client):
+    """Test that BigQueryTableExistenceTrigger._table_exists returns True when the mocked
+    table client call succeeds"""
+    hook = BigQueryTableAsyncHook()
+    mock_get_table_client.return_value = AsyncMock(Table)
+    trigger = BigQueryTableExistenceTrigger(
+        TEST_GCP_PROJECT_ID,
+        TEST_DATASET_ID,
+        TEST_TABLE_ID,
+        TEST_GCP_CONN_ID,
+        TEST_HOOK_PARAMS,
+        POLLING_PERIOD_SECONDS,
+    )
+    res = await trigger._table_exists(hook, TEST_DATASET_ID, TEST_TABLE_ID, TEST_GCP_PROJECT_ID)
+    assert res is True
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryTableAsyncHook.get_table_client")
+async def test_table_exists_exception(mock_get_table_client):
+    """Test that BigQueryTableExistenceTrigger._table_exists returns False on a 404 error"""
+    hook = BigQueryTableAsyncHook()
+    mock_get_table_client.side_effect = ClientResponseError(
+        history=(),
+        request_info=RequestInfo(
+            headers=CIMultiDict(),
+            real_url=URL("https://example.com"),
+            method="GET",
+            url=URL("https://example.com"),
+        ),
+        status=404,
+        message="Not Found",
+    )
+    trigger = BigQueryTableExistenceTrigger(
+        TEST_GCP_PROJECT_ID,
+        TEST_DATASET_ID,
+        TEST_TABLE_ID,
+        TEST_GCP_CONN_ID,
+        TEST_HOOK_PARAMS,
+        POLLING_PERIOD_SECONDS,
+    )
+    res = await trigger._table_exists(hook, TEST_DATASET_ID, TEST_TABLE_ID, TEST_GCP_PROJECT_ID)
+    expected_response = False
+    assert res == expected_response
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryTableAsyncHook.get_table_client")
+async def test_table_exists_raise_exception(mock_get_table_client):
+    """Test that BigQueryTableExistenceTrigger._table_exists re-raises client errors other than 404"""
+    hook = BigQueryTableAsyncHook()
+    mock_get_table_client.side_effect = ClientResponseError(
+        history=(),
+        request_info=RequestInfo(
+            headers=CIMultiDict(),
+            real_url=URL("https://example.com"),
+            method="GET",
+            url=URL("https://example.com"),
+        ),
+        status=400,
+        message="Bad Request",
+    )
+    trigger = BigQueryTableExistenceTrigger(
+        TEST_GCP_PROJECT_ID,
+        TEST_DATASET_ID,
+        TEST_TABLE_ID,
+        TEST_GCP_CONN_ID,
+        TEST_HOOK_PARAMS,
+        POLLING_PERIOD_SECONDS,
+    )
+    with pytest.raises(ClientResponseError):
+        await trigger._table_exists(hook, TEST_DATASET_ID, TEST_TABLE_ID, TEST_GCP_PROJECT_ID)
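The three _table_exists tests above pin down the contract the trigger relies on: a successful table client call means the table exists, a 404 ClientResponseError means it does not, and any other client error propagates. A minimal, purely illustrative sketch of that behaviour (not the implementation shipped in this change; the get_table_client keyword arguments and the Table.get() call are assumptions) looks roughly like this:

from aiohttp import ClientResponseError, ClientSession


async def _table_exists_sketch(hook, dataset: str, table_id: str, project_id: str) -> bool:
    """Return True if the table is reachable, False on a 404, and re-raise any other client error."""
    async with ClientSession() as session:
        try:
            # get_table_client is what the tests above mock; the keyword names here are assumed.
            client = await hook.get_table_client(
                dataset=dataset, table_id=table_id, project_id=project_id, session=session
            )
            await client.get()  # assumed: fetching the table metadata succeeds only if the table exists
            return True
        except ClientResponseError as err:
            if err.status == 404:
                return False
            raise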
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py
new file mode 100644
index 0000000000000..36e4844807ab0
--- /dev/null
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py
@@ -0,0 +1,251 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google BigQuery service.
+Uses the async versions of the BigQuery operators.
+
+"""
+import os
+from datetime import datetime, timedelta
+
+from airflow import DAG
+from airflow.operators.bash import BashOperator
+from airflow.operators.empty import EmptyOperator
+from airflow.providers.google.cloud.operators.bigquery import (
+    BigQueryCheckAsyncOperator,
+    BigQueryCreateEmptyDatasetOperator,
+    BigQueryCreateEmptyTableOperator,
+    BigQueryDeleteDatasetOperator,
+    BigQueryGetDataAsyncOperator,
+    BigQueryInsertJobAsyncOperator,
+    BigQueryIntervalCheckAsyncOperator,
+    BigQueryValueCheckAsyncOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.getenv("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "bigquery_queries_async"
+DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}"
+LOCATION = "us"
+EXECUTION_TIMEOUT = 6
+
+TABLE_1 = "table1"
+TABLE_2 = "table2"
+
+SCHEMA = [
+    {"name": "value", "type": "INTEGER", "mode": "REQUIRED"},
+    {"name": "name", "type": "STRING", "mode": "NULLABLE"},
+    {"name": "ds", "type": "STRING", "mode": "NULLABLE"},
+]
+
+DATASET = DATASET_NAME
+INSERT_DATE = datetime.now().strftime("%Y-%m-%d")
+INSERT_ROWS_QUERY = (
+    f"INSERT {DATASET}.{TABLE_1} VALUES "
+    f"(42, 'monthy python', '{INSERT_DATE}'), "
+    f"(42, 'fishy fish', '{INSERT_DATE}');"
+)
+
+default_args = {
+    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
+    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
+    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
+}
+
+with DAG(
+    dag_id="example_async_bigquery_queries_async",
+    schedule=None,
+    start_date=datetime(2022, 1, 1),
+    catchup=False,
+    default_args=default_args,
+    tags=["example", "async", "bigquery"],
+    user_defined_macros={"DATASET": DATASET, "TABLE": TABLE_1},
+) as dag:
+    create_dataset = BigQueryCreateEmptyDatasetOperator(
+        task_id="create_dataset",
+        dataset_id=DATASET,
+        location=LOCATION,
+    )
+
+    create_table_1 = BigQueryCreateEmptyTableOperator(
+        task_id="create_table_1",
+        dataset_id=DATASET,
+        table_id=TABLE_1,
+        schema_fields=SCHEMA,
+        location=LOCATION,
+    )
+
+    create_dataset >> create_table_1
+
+    delete_dataset = BigQueryDeleteDatasetOperator(
+        task_id="delete_dataset", dataset_id=DATASET, delete_contents=True, trigger_rule=TriggerRule.ALL_DONE
+    )
+
+    
# [START howto_operator_bigquery_insert_job_async] + insert_query_job = BigQueryInsertJobAsyncOperator( + task_id="insert_query_job", + configuration={ + "query": { + "query": INSERT_ROWS_QUERY, + "useLegacySql": False, + } + }, + location=LOCATION, + ) + # [END howto_operator_bigquery_insert_job_async] + + # [START howto_operator_bigquery_select_job_async] + select_query_job = BigQueryInsertJobAsyncOperator( + task_id="select_query_job", + configuration={ + "query": { + "query": "{% include 'example_bigquery_query.sql' %}", + "useLegacySql": False, + } + }, + location=LOCATION, + ) + # [END howto_operator_bigquery_select_job_async] + + # [START howto_operator_bigquery_value_check_async] + check_value = BigQueryValueCheckAsyncOperator( + task_id="check_value", + sql=f"SELECT COUNT(*) FROM {DATASET}.{TABLE_1}", + pass_value=2, + use_legacy_sql=False, + location=LOCATION, + ) + # [END howto_operator_bigquery_value_check_async] + + # [START howto_operator_bigquery_interval_check_async] + check_interval = BigQueryIntervalCheckAsyncOperator( + task_id="check_interval", + table=f"{DATASET}.{TABLE_1}", + days_back=1, + metrics_thresholds={"COUNT(*)": 1.5}, + use_legacy_sql=False, + location=LOCATION, + ) + # [END howto_operator_bigquery_interval_check_async] + + # [START howto_operator_bigquery_multi_query_async] + bigquery_execute_multi_query = BigQueryInsertJobAsyncOperator( + task_id="execute_multi_query", + configuration={ + "query": { + "query": [ + f"SELECT * FROM {DATASET}.{TABLE_2}", + f"SELECT COUNT(*) FROM {DATASET}.{TABLE_2}", + ], + "useLegacySql": False, + } + }, + location=LOCATION, + ) + # [END howto_operator_bigquery_multi_query_async] + + # [START howto_operator_bigquery_get_data_async] + get_data = BigQueryGetDataAsyncOperator( + task_id="get_data", + dataset_id=DATASET, + table_id=TABLE_1, + max_results=10, + selected_fields="value,name", + location=LOCATION, + ) + # [END howto_operator_bigquery_get_data_async] + + get_data_result = BashOperator( + task_id="get_data_result", + bash_command=f"echo {get_data.output}", + trigger_rule=TriggerRule.ALL_DONE, + ) + + # [START howto_operator_bigquery_check_async] + check_count = BigQueryCheckAsyncOperator( + task_id="check_count", + sql=f"SELECT COUNT(*) FROM {DATASET}.{TABLE_1}", + use_legacy_sql=False, + location=LOCATION, + ) + # [END howto_operator_bigquery_check_async] + + # [START howto_operator_bigquery_execute_query_save_async] + execute_query_save = BigQueryInsertJobAsyncOperator( + task_id="execute_query_save", + configuration={ + "query": { + "query": f"SELECT * FROM {DATASET}.{TABLE_1}", + "useLegacySql": False, + "destinationTable": { + "projectId": PROJECT_ID, + "datasetId": DATASET, + "tableId": TABLE_2, + }, + } + }, + location=LOCATION, + ) + # [END howto_operator_bigquery_execute_query_save_async] + + execute_long_running_query = BigQueryInsertJobAsyncOperator( + task_id="execute_long_running_query", + configuration={ + "query": { + "query": f"""DECLARE success BOOL; + DECLARE size_bytes INT64; + DECLARE row_count INT64; + DECLARE DELAY_TIME DATETIME; + DECLARE WAIT STRING; + SET success = FALSE; + + SELECT row_count = (SELECT row_count FROM {DATASET}.__TABLES__ WHERE table_id='NON_EXISTING_TABLE'); + IF row_count > 0 THEN + SELECT 'Table Exists!' 
as message, retry_count as retries; + SET success = TRUE; + ELSE + SELECT 'Table does not exist' as message, row_count; + SET WAIT = 'TRUE'; + SET DELAY_TIME = DATETIME_ADD(CURRENT_DATETIME,INTERVAL 1 MINUTE); + WHILE WAIT = 'TRUE' DO + IF (DELAY_TIME < CURRENT_DATETIME) THEN + SET WAIT = 'FALSE'; + END IF; + END WHILE; + END IF;""", + "useLegacySql": False, + } + }, + location=LOCATION, + ) + + end = EmptyOperator(task_id="end") + + create_table_1 >> insert_query_job >> select_query_job >> check_count + insert_query_job >> get_data >> get_data_result + insert_query_job >> execute_query_save >> bigquery_execute_multi_query + insert_query_job >> execute_long_running_query >> check_value >> check_interval + [check_count, check_interval, bigquery_execute_multi_query, get_data_result] >> delete_dataset + [check_count, check_interval, bigquery_execute_multi_query, get_data_result, delete_dataset] >> end + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py index 45e44343a9279..6faea7cbe9ddb 100644 --- a/tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py +++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py @@ -31,6 +31,7 @@ BigQueryInsertJobOperator, ) from airflow.providers.google.cloud.sensors.bigquery import ( + BigQueryTableExistenceAsyncSensor, BigQueryTableExistenceSensor, BigQueryTablePartitionExistenceSensor, ) @@ -86,6 +87,15 @@ ) # [END howto_sensor_bigquery_table] + # [START howto_sensor_async_bigquery_table] + check_table_exists_async = BigQueryTableExistenceAsyncSensor( + task_id="check_table_exists_async", + project_id=PROJECT_ID, + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + ) + # [END howto_sensor_async_bigquery_table] + execute_insert_query: BaseOperator = BigQueryInsertJobOperator( task_id="execute_insert_query", configuration={ @@ -116,7 +126,7 @@ create_dataset >> create_table create_table >> [check_table_exists, execute_insert_query] execute_insert_query >> check_table_partition_exists - [check_table_exists, check_table_partition_exists] >> delete_dataset + [check_table_exists, check_table_exists_async, check_table_partition_exists] >> delete_dataset from tests.system.utils.watcher import watcher
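As a companion to the trigger tests earlier in this change, here is a minimal sketch of how the deferral of the new BigQueryTableExistenceAsyncSensor could be unit-tested, assuming it defers immediately from execute() with a BigQueryTableExistenceTrigger (the usual pattern for deferrable sensors); the placeholder project, dataset, and table IDs are illustrative only:

from unittest import mock

import pytest

from airflow.exceptions import TaskDeferred
from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceAsyncSensor
from airflow.providers.google.cloud.triggers.bigquery import BigQueryTableExistenceTrigger


def test_big_query_table_existence_async_sensor_defers():
    """The sensor is expected to defer itself with a BigQueryTableExistenceTrigger."""
    sensor = BigQueryTableExistenceAsyncSensor(
        task_id="check_table_exists_async",
        project_id="test-project",  # assumed placeholder values
        dataset_id="test_dataset",
        table_id="test_table",
    )
    with pytest.raises(TaskDeferred) as exc_info:
        sensor.execute(context=mock.MagicMock())
    assert isinstance(exc_info.value.trigger, BigQueryTableExistenceTrigger)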