diff --git a/.flake8 b/.flake8 index cffaf32f138d9..14de564a3291e 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] max-line-length = 110 -ignore = E231,E731,W504,I001,W503 +ignore = E203,E231,E731,W504,I001,W503 exclude = .svn,CVS,.bzr,.hg,.git,__pycache__,.eggs,*.egg,node_modules format = ${cyan}%(path)s${reset}:${yellow_bold}%(row)d${reset}:${green_bold}%(col)d${reset}: ${red_bold}%(code)s${reset} %(text)s per-file-ignores = diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 31972006dcbcc..461f36d9f16e6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -153,7 +153,8 @@ repos: rev: stable hooks: - id: black - files: api_connexion/.*\.py + files: api_connexion/.*\.py|.*providers.*\.py + exclude: .*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$ args: [--config=./pyproject.toml] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 @@ -190,7 +191,7 @@ repos: name: Run isort to sort imports types: [python] # To keep consistent with the global isort skip config defined in setup.cfg - exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py + exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py|.*providers.*\.py - repo: https://github.com/pycqa/pydocstyle rev: 5.0.2 hooks: diff --git a/airflow/providers/amazon/aws/example_dags/example_datasync_1.py b/airflow/providers/amazon/aws/example_dags/example_datasync_1.py index 5e1127aa84684..8b3e2785dcd14 100644 --- a/airflow/providers/amazon/aws/example_dags/example_datasync_1.py +++ b/airflow/providers/amazon/aws/example_dags/example_datasync_1.py @@ -33,16 +33,13 @@ from airflow.utils.dates import days_ago # [START howto_operator_datasync_1_args_1] -TASK_ARN = getenv( - "TASK_ARN", "my_aws_datasync_task_arn") +TASK_ARN = getenv("TASK_ARN", "my_aws_datasync_task_arn") # [END howto_operator_datasync_1_args_1] # [START howto_operator_datasync_1_args_2] -SOURCE_LOCATION_URI = getenv( - "SOURCE_LOCATION_URI", "smb://hostname/directory/") +SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/") -DESTINATION_LOCATION_URI = getenv( - "DESTINATION_LOCATION_URI", "s3://mybucket/prefix") +DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix") # [END howto_operator_datasync_1_args_2] @@ -55,16 +52,12 @@ # [START howto_operator_datasync_1_1] datasync_task_1 = AWSDataSyncOperator( - aws_conn_id="aws_default", - task_id="datasync_task_1", - task_arn=TASK_ARN + aws_conn_id="aws_default", task_id="datasync_task_1", task_arn=TASK_ARN ) # [END howto_operator_datasync_1_1] with models.DAG( - "example_datasync_1_2", - start_date=days_ago(1), - schedule_interval=None, # Override to match your needs + "example_datasync_1_2", start_date=days_ago(1), schedule_interval=None, # Override to match your needs ) as dag: # [START howto_operator_datasync_1_2] datasync_task_2 = AWSDataSyncOperator( diff --git a/airflow/providers/amazon/aws/example_dags/example_datasync_2.py b/airflow/providers/amazon/aws/example_dags/example_datasync_2.py index c6b8e0eb51917..d4c7091db4fff 100644 --- a/airflow/providers/amazon/aws/example_dags/example_datasync_2.py +++ b/airflow/providers/amazon/aws/example_dags/example_datasync_2.py @@ -42,40 +42,30 @@ from airflow.utils.dates import days_ago # [START howto_operator_datasync_2_args] -SOURCE_LOCATION_URI = getenv( - "SOURCE_LOCATION_URI", "smb://hostname/directory/") +SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/") -DESTINATION_LOCATION_URI = getenv( - 
"DESTINATION_LOCATION_URI", "s3://mybucket/prefix") +DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix") default_create_task_kwargs = '{"Name": "Created by Airflow"}' -CREATE_TASK_KWARGS = json.loads( - getenv("CREATE_TASK_KWARGS", default_create_task_kwargs) -) +CREATE_TASK_KWARGS = json.loads(getenv("CREATE_TASK_KWARGS", default_create_task_kwargs)) default_create_source_location_kwargs = "{}" CREATE_SOURCE_LOCATION_KWARGS = json.loads( - getenv("CREATE_SOURCE_LOCATION_KWARGS", - default_create_source_location_kwargs) + getenv("CREATE_SOURCE_LOCATION_KWARGS", default_create_source_location_kwargs) ) -bucket_access_role_arn = ( - "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role" -) +bucket_access_role_arn = "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role" default_destination_location_kwargs = """\ {"S3BucketArn": "arn:aws:s3:::mybucket", "S3Config": {"BucketAccessRoleArn": "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"} }""" CREATE_DESTINATION_LOCATION_KWARGS = json.loads( - getenv("CREATE_DESTINATION_LOCATION_KWARGS", - re.sub(r"[\s+]", '', default_destination_location_kwargs)) + getenv("CREATE_DESTINATION_LOCATION_KWARGS", re.sub(r"[\s+]", '', default_destination_location_kwargs)) ) default_update_task_kwargs = '{"Name": "Updated by Airflow"}' -UPDATE_TASK_KWARGS = json.loads( - getenv("UPDATE_TASK_KWARGS", default_update_task_kwargs) -) +UPDATE_TASK_KWARGS = json.loads(getenv("UPDATE_TASK_KWARGS", default_update_task_kwargs)) # [END howto_operator_datasync_2_args] @@ -92,13 +82,10 @@ task_id="datasync_task", source_location_uri=SOURCE_LOCATION_URI, destination_location_uri=DESTINATION_LOCATION_URI, - create_task_kwargs=CREATE_TASK_KWARGS, create_source_location_kwargs=CREATE_SOURCE_LOCATION_KWARGS, create_destination_location_kwargs=CREATE_DESTINATION_LOCATION_KWARGS, - update_task_kwargs=UPDATE_TASK_KWARGS, - - delete_task_after_execution=True + delete_task_after_execution=True, ) # [END howto_operator_datasync_2] diff --git a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py index 94cecba7f9269..cef35606c00c9 100644 --- a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py +++ b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py @@ -56,12 +56,7 @@ task_definition="hello-world", launch_type="FARGATE", overrides={ - "containerOverrides": [ - { - "name": "hello-world-container", - "command": ["echo", "hello", "world"], - }, - ], + "containerOverrides": [{"name": "hello-world-container", "command": ["echo", "hello", "world"],},], }, network_configuration={ "awsvpcConfiguration": { diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py index 3c52ffca58f07..3077944a0b7dc 100644 --- a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py +++ b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py @@ -30,7 +30,7 @@ 'depends_on_past': False, 'email': ['airflow@example.com'], 'email_on_failure': False, - 'email_on_retry': False + 'email_on_retry': False, } # [START howto_operator_emr_automatic_steps_config] @@ -40,12 +40,8 @@ 'ActionOnFailure': 'CONTINUE', 'HadoopJarStep': { 'Jar': 'command-runner.jar', - 'Args': [ - '/usr/lib/spark/bin/run-example', - 'SparkPi', - '10' - ] - } + 'Args': 
['/usr/lib/spark/bin/run-example', 'SparkPi', '10'], + }, } ] @@ -85,13 +81,13 @@ task_id='create_job_flow', job_flow_overrides=JOB_FLOW_OVERRIDES, aws_conn_id='aws_default', - emr_conn_id='emr_default' + emr_conn_id='emr_default', ) job_sensor = EmrJobFlowSensor( task_id='check_job_flow', job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}", - aws_conn_id='aws_default' + aws_conn_id='aws_default', ) job_flow_creator >> job_sensor diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py index 0b73bd3366110..1eb857a6a2d29 100644 --- a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py +++ b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py @@ -35,7 +35,7 @@ 'depends_on_past': False, 'email': ['airflow@example.com'], 'email_on_failure': False, - 'email_on_retry': False + 'email_on_retry': False, } SPARK_STEPS = [ @@ -44,12 +44,8 @@ 'ActionOnFailure': 'CONTINUE', 'HadoopJarStep': { 'Jar': 'command-runner.jar', - 'Args': [ - '/usr/lib/spark/bin/run-example', - 'SparkPi', - '10' - ] - } + 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'], + }, } ] @@ -87,27 +83,27 @@ task_id='create_job_flow', job_flow_overrides=JOB_FLOW_OVERRIDES, aws_conn_id='aws_default', - emr_conn_id='emr_default' + emr_conn_id='emr_default', ) step_adder = EmrAddStepsOperator( task_id='add_steps', job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}", aws_conn_id='aws_default', - steps=SPARK_STEPS + steps=SPARK_STEPS, ) step_checker = EmrStepSensor( task_id='watch_step', job_flow_id="{{ task_instance.xcom_pull('create_job_flow', key='return_value') }}", step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}", - aws_conn_id='aws_default' + aws_conn_id='aws_default', ) cluster_remover = EmrTerminateJobFlowOperator( task_id='remove_cluster', job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}", - aws_conn_id='aws_default' + aws_conn_id='aws_default', ) cluster_creator >> step_adder >> step_checker >> cluster_remover diff --git a/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py b/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py index f05c5ae225cde..fc14199bc300d 100644 --- a/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py +++ b/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py @@ -74,7 +74,7 @@ def _check_and_transform_video_ids(xcom_key, task_ids, task_instance, **kwargs): dag_id="example_google_api_to_s3_transfer_advanced", schedule_interval=None, start_date=days_ago(1), - tags=['example'] + tags=['example'], ) as dag: # [START howto_operator_google_api_to_s3_transfer_advanced_task_1] task_video_ids_to_s3 = GoogleApiToS3Operator( @@ -89,21 +89,18 @@ def _check_and_transform_video_ids(xcom_key, task_ids, task_instance, **kwargs): 'publishedAfter': YOUTUBE_VIDEO_PUBLISHED_AFTER, 'publishedBefore': YOUTUBE_VIDEO_PUBLISHED_BEFORE, 'type': 'video', - 'fields': 'items/id/videoId' + 'fields': 'items/id/videoId', }, google_api_response_via_xcom='video_ids_response', s3_destination_key=f'{s3_directory}/youtube_search_{s3_file_name}.json', - task_id='video_ids_to_s3' + task_id='video_ids_to_s3', ) # [END 
howto_operator_google_api_to_s3_transfer_advanced_task_1] # [START howto_operator_google_api_to_s3_transfer_advanced_task_1_1] task_check_and_transform_video_ids = BranchPythonOperator( python_callable=_check_and_transform_video_ids, - op_args=[ - task_video_ids_to_s3.google_api_response_via_xcom, - task_video_ids_to_s3.task_id - ], - task_id='check_and_transform_video_ids' + op_args=[task_video_ids_to_s3.google_api_response_via_xcom, task_video_ids_to_s3.task_id], + task_id='check_and_transform_video_ids', ) # [END howto_operator_google_api_to_s3_transfer_advanced_task_1_1] # [START howto_operator_google_api_to_s3_transfer_advanced_task_2] @@ -115,16 +112,14 @@ def _check_and_transform_video_ids(xcom_key, task_ids, task_instance, **kwargs): google_api_endpoint_params={ 'part': YOUTUBE_VIDEO_PARTS, 'maxResults': 50, - 'fields': YOUTUBE_VIDEO_FIELDS + 'fields': YOUTUBE_VIDEO_FIELDS, }, google_api_endpoint_params_via_xcom='video_ids', s3_destination_key=f'{s3_directory}/youtube_videos_{s3_file_name}.json', - task_id='video_data_to_s3' + task_id='video_data_to_s3', ) # [END howto_operator_google_api_to_s3_transfer_advanced_task_2] # [START howto_operator_google_api_to_s3_transfer_advanced_task_2_1] - task_no_video_ids = DummyOperator( - task_id='no_video_ids' - ) + task_no_video_ids = DummyOperator(task_id='no_video_ids') # [END howto_operator_google_api_to_s3_transfer_advanced_task_2_1] task_video_ids_to_s3 >> task_check_and_transform_video_ids >> [task_video_data_to_s3, task_no_video_ids] diff --git a/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py b/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py index 515a9661bb224..f5c1ec177fca2 100644 --- a/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py +++ b/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py @@ -37,19 +37,16 @@ dag_id="example_google_api_to_s3_transfer_basic", schedule_interval=None, start_date=days_ago(1), - tags=['example'] + tags=['example'], ) as dag: # [START howto_operator_google_api_to_s3_transfer_basic_task_1] task_google_sheets_values_to_s3 = GoogleApiToS3Operator( google_api_service_name='sheets', google_api_service_version='v4', google_api_endpoint_path='sheets.spreadsheets.values.get', - google_api_endpoint_params={ - 'spreadsheetId': GOOGLE_SHEET_ID, - 'range': GOOGLE_SHEET_RANGE - }, + google_api_endpoint_params={'spreadsheetId': GOOGLE_SHEET_ID, 'range': GOOGLE_SHEET_RANGE}, s3_destination_key=S3_DESTINATION_KEY, task_id='google_sheets_values_to_s3', - dag=dag + dag=dag, ) # [END howto_operator_google_api_to_s3_transfer_basic_task_1] diff --git a/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py index 0c308ba1360da..cfb8b955b8ff9 100644 --- a/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py +++ b/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py @@ -34,10 +34,7 @@ # [END howto_operator_imap_attachment_to_s3_env_variables] with DAG( - dag_id="example_imap_attachment_to_s3", - start_date=days_ago(1), - schedule_interval=None, - tags=['example'] + dag_id="example_imap_attachment_to_s3", start_date=days_ago(1), schedule_interval=None, tags=['example'] ) as dag: # [START howto_operator_imap_attachment_to_s3_task_1] task_transfer_imap_attachment_to_s3 = ImapAttachmentToS3Operator( @@ -46,6 +43,6 @@ imap_mail_folder=IMAP_MAIL_FOLDER, 
imap_mail_filter=IMAP_MAIL_FILTER, task_id='transfer_imap_attachment_to_s3', - dag=dag + dag=dag, ) # [END howto_operator_imap_attachment_to_s3_task_1] diff --git a/airflow/providers/amazon/aws/example_dags/example_s3_bucket.py b/airflow/providers/amazon/aws/example_dags/example_s3_bucket.py index 0321cfa53ddbc..591ba0e0a2875 100644 --- a/airflow/providers/amazon/aws/example_dags/example_s3_bucket.py +++ b/airflow/providers/amazon/aws/example_dags/example_s3_bucket.py @@ -31,9 +31,7 @@ def upload_keys(): s3_hook = S3Hook() for i in range(0, 3): s3_hook.load_string( - string_data="input", - key=f"path/data{i}", - bucket_name=BUCKET_NAME, + string_data="input", key=f"path/data{i}", bucket_name=BUCKET_NAME, ) @@ -46,20 +44,15 @@ def upload_keys(): ) as dag: create_bucket = S3CreateBucketOperator( - task_id='s3_bucket_dag_create', - bucket_name=BUCKET_NAME, - region_name='us-east-1', + task_id='s3_bucket_dag_create', bucket_name=BUCKET_NAME, region_name='us-east-1', ) add_keys_to_bucket = PythonOperator( - task_id="s3_bucket_dag_add_keys_to_bucket", - python_callable=upload_keys + task_id="s3_bucket_dag_add_keys_to_bucket", python_callable=upload_keys ) delete_bucket = S3DeleteBucketOperator( - task_id='s3_bucket_dag_delete', - bucket_name=BUCKET_NAME, - force_delete=True, + task_id='s3_bucket_dag_delete', bucket_name=BUCKET_NAME, force_delete=True, ) create_bucket >> add_keys_to_bucket >> delete_bucket diff --git a/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py b/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py index 2ffccbc593736..76c79e521df17 100644 --- a/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py +++ b/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py @@ -47,19 +47,15 @@ def _remove_sample_data_from_s3(): with DAG( - dag_id="example_s3_to_redshift", - start_date=days_ago(1), - schedule_interval=None, - tags=['example'] + dag_id="example_s3_to_redshift", start_date=days_ago(1), schedule_interval=None, tags=['example'] ) as dag: setup__task_add_sample_data_to_s3 = PythonOperator( - python_callable=_add_sample_data_to_s3, - task_id='setup__add_sample_data_to_s3' + python_callable=_add_sample_data_to_s3, task_id='setup__add_sample_data_to_s3' ) setup__task_create_table = PostgresOperator( sql=f'CREATE TABLE IF NOT EXISTS {REDSHIFT_TABLE}(Id int, Name varchar)', postgres_conn_id='redshift_default', - task_id='setup__create_table' + task_id='setup__create_table', ) # [START howto_operator_s3_to_redshift_task_1] task_transfer_s3_to_redshift = S3ToRedshiftOperator( @@ -68,22 +64,18 @@ def _remove_sample_data_from_s3(): schema="PUBLIC", table=REDSHIFT_TABLE, copy_options=['csv'], - task_id='transfer_s3_to_redshift' + task_id='transfer_s3_to_redshift', ) # [END howto_operator_s3_to_redshift_task_1] teardown__task_drop_table = PostgresOperator( sql=f'DROP TABLE IF EXISTS {REDSHIFT_TABLE}', postgres_conn_id='redshift_default', - task_id='teardown__drop_table' + task_id='teardown__drop_table', ) teardown__task_remove_sample_data_from_s3 = PythonOperator( - python_callable=_remove_sample_data_from_s3, - task_id='teardown__remove_sample_data_from_s3' + python_callable=_remove_sample_data_from_s3, task_id='teardown__remove_sample_data_from_s3' ) - [ - setup__task_add_sample_data_to_s3, - setup__task_create_table - ] >> task_transfer_s3_to_redshift >> [ + [setup__task_add_sample_data_to_s3, setup__task_create_table] >> task_transfer_s3_to_redshift >> [ teardown__task_drop_table, - teardown__task_remove_sample_data_from_s3 + 
teardown__task_remove_sample_data_from_s3, ] diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py index 830bb8f3e577c..a7fb9472906bc 100644 --- a/airflow/providers/amazon/aws/hooks/athena.py +++ b/airflow/providers/amazon/aws/hooks/athena.py @@ -41,22 +41,28 @@ class AWSAthenaHook(AwsBaseHook): :type sleep_time: int """ - INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',) - FAILURE_STATES = ('FAILED', 'CANCELLED',) + INTERMEDIATE_STATES = ( + 'QUEUED', + 'RUNNING', + ) + FAILURE_STATES = ( + 'FAILED', + 'CANCELLED', + ) SUCCESS_STATES = ('SUCCEEDED',) - def __init__(self, - *args: Any, - sleep_time: int = 30, - **kwargs: Any) -> None: + def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None: super().__init__(client_type='athena', *args, **kwargs) # type: ignore self.sleep_time = sleep_time - def run_query(self, query: str, - query_context: Dict[str, str], - result_configuration: Dict[str, Any], - client_request_token: Optional[str] = None, - workgroup: str = 'primary') -> str: + def run_query( + self, + query: str, + query_context: Dict[str, str], + result_configuration: Dict[str, Any], + client_request_token: Optional[str] = None, + workgroup: str = 'primary', + ) -> str: """ Run Presto query on athena with provided config and return submitted query_execution_id @@ -76,7 +82,7 @@ def run_query(self, query: str, 'QueryString': query, 'QueryExecutionContext': query_context, 'ResultConfiguration': result_configuration, - 'WorkGroup': workgroup + 'WorkGroup': workgroup, } if client_request_token: params['ClientRequestToken'] = client_request_token @@ -122,9 +128,9 @@ def get_state_change_reason(self, query_execution_id: str) -> Optional[str]: # The error is being absorbed to implement retries. return reason # pylint: disable=lost-exception - def get_query_results(self, query_execution_id: str, - next_token_id: Optional[str] = None, - max_results: int = 1000) -> Optional[dict]: + def get_query_results( + self, query_execution_id: str, next_token_id: Optional[str] = None, max_results: int = 1000 + ) -> Optional[dict]: """ Fetch submitted athena query results. returns none if query is in intermediate state or failed/cancelled state else dict of query output @@ -144,19 +150,18 @@ def get_query_results(self, query_execution_id: str, elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES: self.log.error('Query is in "%s" state. Cannot fetch results', query_state) return None - result_params = { - 'QueryExecutionId': query_execution_id, - 'MaxResults': max_results - } + result_params = {'QueryExecutionId': query_execution_id, 'MaxResults': max_results} if next_token_id: result_params['NextToken'] = next_token_id return self.get_conn().get_query_results(**result_params) - def get_query_results_paginator(self, query_execution_id: str, - max_items: Optional[int] = None, - page_size: Optional[int] = None, - starting_token: Optional[str] = None - ) -> Optional[PageIterator]: + def get_query_results_paginator( + self, + query_execution_id: str, + max_items: Optional[int] = None, + page_size: Optional[int] = None, + starting_token: Optional[str] = None, + ) -> Optional[PageIterator]: """ Fetch submitted athena query results. returns none if query is in intermediate state or failed/cancelled state else a paginator to iterate through pages of results. 
If you @@ -184,15 +189,13 @@ def get_query_results_paginator(self, query_execution_id: str, 'PaginationConfig': { 'MaxItems': max_items, 'PageSize': page_size, - 'StartingToken': starting_token - - } + 'StartingToken': starting_token, + }, } paginator = self.get_conn().get_paginator('get_query_results') return paginator.paginate(**result_params) - def poll_query_status(self, query_execution_id: str, - max_tries: Optional[int] = None) -> Optional[str]: + def poll_query_status(self, query_execution_id: str, max_tries: Optional[int] = None) -> Optional[str]: """ Poll the status of submitted athena query until query state reaches final state. Returns one of the final states diff --git a/airflow/providers/amazon/aws/hooks/aws_dynamodb.py b/airflow/providers/amazon/aws/hooks/aws_dynamodb.py index fd5dde36530cb..f197aa7aece7f 100644 --- a/airflow/providers/amazon/aws/hooks/aws_dynamodb.py +++ b/airflow/providers/amazon/aws/hooks/aws_dynamodb.py @@ -58,7 +58,5 @@ def write_batch_data(self, items): return True except Exception as general_error: raise AirflowException( - 'Failed to insert items in dynamodb, error: {error}'.format( - error=str(general_error) - ) + 'Failed to insert items in dynamodb, error: {error}'.format(error=str(general_error)) ) diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 350ff7effb6e4..c2348d3c93b80 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -189,6 +189,7 @@ def _assume_role_with_saml( def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]): import requests + # requests_gssapi will need paramiko > 2.6 since you'll need # 'gssapi' not 'python-gssapi' from PyPi. # https://github.com/paramiko/paramiko/pull/1311 @@ -269,7 +270,7 @@ def __init__( region_name: Optional[str] = None, client_type: Optional[str] = None, resource_type: Optional[str] = None, - config: Optional[Config] = None + config: Optional[Config] = None, ) -> None: super().__init__() self.aws_conn_id = aws_conn_id @@ -280,9 +281,7 @@ def __init__( self.config = config if not (self.client_type or self.resource_type): - raise AirflowException( - 'Either client_type or resource_type' - ' must be provided.') + raise AirflowException('Either client_type or resource_type' ' must be provided.') def _get_credentials(self, region_name): @@ -302,7 +301,7 @@ def _get_credentials(self, region_name): if "config_kwargs" in extra_config: self.log.info( "Retrieving config_kwargs from Connection.extra_config['config_kwargs']: %s", - extra_config["config_kwargs"] + extra_config["config_kwargs"], ) self.config = Config(**extra_config["config_kwargs"]) @@ -318,8 +317,7 @@ def _get_credentials(self, region_name): # http://boto3.readthedocs.io/en/latest/guide/configuration.html self.log.info( - "Creating session using boto3 credential strategy region_name=%s", - region_name, + "Creating session using boto3 credential strategy region_name=%s", region_name, ) session = boto3.session.Session(region_name=region_name) return session, None @@ -333,9 +331,7 @@ def get_client_type(self, client_type, region_name=None, config=None): if config is None: config = self.config - return session.client( - client_type, endpoint_url=endpoint_url, config=config, verify=self.verify - ) + return session.client(client_type, endpoint_url=endpoint_url, config=config, verify=self.verify) def get_resource_type(self, resource_type, region_name=None, config=None): """Get the underlying boto3 
resource using boto3 session""" @@ -346,9 +342,7 @@ def get_resource_type(self, resource_type, region_name=None, config=None): if config is None: config = self.config - return session.resource( - resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify - ) + return session.resource(resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify) @cached_property def conn(self): diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py index 37a8ce0ccf3c9..d53c6f535d24c 100644 --- a/airflow/providers/amazon/aws/hooks/batch_client.py +++ b/airflow/providers/amazon/aws/hooks/batch_client.py @@ -199,11 +199,7 @@ class AwsBatchClientHook(AwsBaseHook): DEFAULT_DELAY_MAX = 10 def __init__( - self, - *args, - max_retries: Optional[int] = None, - status_retries: Optional[int] = None, - **kwargs + self, *args, max_retries: Optional[int] = None, status_retries: Optional[int] = None, **kwargs ): # https://github.com/python/mypy/issues/6799 hence type: ignore super().__init__(client_type='batch', *args, **kwargs) # type: ignore @@ -211,7 +207,7 @@ def __init__( self.status_retries = status_retries or self.STATUS_RETRIES @property - def client(self) -> Union[AwsBatchProtocol, botocore.client.BaseClient]: # noqa: D402 + def client(self) -> Union[AwsBatchProtocol, botocore.client.BaseClient]: # noqa: D402 """ An AWS API client for batch services, like ``boto3.client('batch')`` @@ -353,9 +349,7 @@ def poll_job_status(self, job_id: str, match_status: List[str]) -> bool: return True if retries >= self.max_retries: - raise AirflowException( - "AWS Batch job ({}) status checks exceed max_retries".format(job_id) - ) + raise AirflowException("AWS Batch job ({}) status checks exceed max_retries".format(job_id)) retries += 1 pause = self.exponential_delay(retries) @@ -391,9 +385,7 @@ def get_job_description(self, job_id: str) -> Dict: if error.get("Code") == "TooManyRequestsException": pass # allow it to retry, if possible else: - raise AirflowException( - "AWS Batch job ({}) description error: {}".format(job_id, err) - ) + raise AirflowException("AWS Batch job ({}) description error: {}".format(job_id, err)) retries += 1 if retries >= self.status_retries: diff --git a/airflow/providers/amazon/aws/hooks/batch_waiters.py b/airflow/providers/amazon/aws/hooks/batch_waiters.py index 75bfb58c1a9dc..d4e91d93c169e 100644 --- a/airflow/providers/amazon/aws/hooks/batch_waiters.py +++ b/airflow/providers/amazon/aws/hooks/batch_waiters.py @@ -102,12 +102,7 @@ class AwsBatchWaitersHook(AwsBatchClientHook): :type region_name: Optional[str] """ - def __init__( - self, - *args, - waiter_config: Optional[Dict] = None, - **kwargs - ): + def __init__(self, *args, waiter_config: Optional[Dict] = None, **kwargs): super().__init__(*args, **kwargs) @@ -183,9 +178,7 @@ def get_waiter(self, waiter_name: str) -> botocore.waiter.Waiter: :return: a waiter object for the named AWS batch service :rtype: botocore.waiter.Waiter """ - return botocore.waiter.create_waiter_with_client( - waiter_name, self.waiter_model, self.client - ) + return botocore.waiter.create_waiter_with_client(waiter_name, self.waiter_model, self.client) def list_waiters(self) -> List[str]: """ diff --git a/airflow/providers/amazon/aws/hooks/datasync.py b/airflow/providers/amazon/aws/hooks/datasync.py index 153a75fa25d2a..b6ef08ed46bb0 100644 --- a/airflow/providers/amazon/aws/hooks/datasync.py +++ b/airflow/providers/amazon/aws/hooks/datasync.py @@ -58,8 +58,7 @@ def 
__init__(self, wait_interval_seconds=5, *args, **kwargs): self.tasks = [] # wait_interval_seconds = 0 is used during unit tests if wait_interval_seconds < 0 or wait_interval_seconds > 15 * 60: - raise ValueError("Invalid wait_interval_seconds %s" % - wait_interval_seconds) + raise ValueError("Invalid wait_interval_seconds %s" % wait_interval_seconds) self.wait_interval_seconds = wait_interval_seconds def create_location(self, location_uri, **create_location_kwargs): @@ -85,9 +84,7 @@ def create_location(self, location_uri, **create_location_kwargs): self._refresh_locations() return location["LocationArn"] - def get_location_arns( - self, location_uri, case_sensitive=False, ignore_trailing_slash=True - ): + def get_location_arns(self, location_uri, case_sensitive=False, ignore_trailing_slash=True): """ Return all LocationArns which match a LocationUri. @@ -133,9 +130,7 @@ def _refresh_locations(self): break next_token = locations["NextToken"] - def create_task( - self, source_location_arn, destination_location_arn, **create_task_kwargs - ): + def create_task(self, source_location_arn, destination_location_arn, **create_task_kwargs): r"""Create a Task between the specified source and destination LocationArns. :param str source_location_arn: Source LocationArn. Must exist already. @@ -147,7 +142,7 @@ def create_task( task = self.get_conn().create_task( SourceLocationArn=source_location_arn, DestinationLocationArn=destination_location_arn, - **create_task_kwargs + **create_task_kwargs, ) self._refresh_tasks() return task["TaskArn"] @@ -181,9 +176,7 @@ def _refresh_tasks(self): break next_token = tasks["NextToken"] - def get_task_arns_for_location_arns( - self, source_location_arns, destination_location_arns - ): + def get_task_arns_for_location_arns(self, source_location_arns, destination_location_arns): """ Return list of TaskArns for which use any one of the specified source LocationArns and any one of the specified destination LocationArns. 
@@ -224,9 +217,7 @@ def start_task_execution(self, task_arn, **kwargs): """ if not task_arn: raise AirflowBadRequest("task_arn not specified") - task_execution = self.get_conn().start_task_execution( - TaskArn=task_arn, **kwargs - ) + task_execution = self.get_conn().start_task_execution(TaskArn=task_arn, **kwargs) return task_execution["TaskExecutionArn"] def cancel_task_execution(self, task_execution_arn): @@ -298,9 +289,7 @@ def wait_for_task_execution(self, task_execution_arn, max_iterations=2 * 180): status = None iterations = max_iterations while status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES: - task_execution = self.get_conn().describe_task_execution( - TaskExecutionArn=task_execution_arn - ) + task_execution = self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn) status = task_execution["Status"] self.log.info("status=%s", status) iterations -= 1 @@ -318,5 +307,4 @@ def wait_for_task_execution(self, task_execution_arn, max_iterations=2 * 180): return False if iterations <= 0: raise AirflowTaskTimeout("Max iterations exceeded!") - raise AirflowException("Unknown status: %s" % - status) # Should never happen + raise AirflowException("Unknown status: %s" % status) # Should never happen diff --git a/airflow/providers/amazon/aws/hooks/ec2.py b/airflow/providers/amazon/aws/hooks/ec2.py index f8120c39ad1f7..34517d752e699 100644 --- a/airflow/providers/amazon/aws/hooks/ec2.py +++ b/airflow/providers/amazon/aws/hooks/ec2.py @@ -33,9 +33,7 @@ class EC2Hook(AwsBaseHook): :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` """ - def __init__(self, - *args, - **kwargs): + def __init__(self, *args, **kwargs): super().__init__(resource_type="ec2", *args, **kwargs) def get_instance(self, instance_id: str): @@ -60,10 +58,7 @@ def get_instance_state(self, instance_id: str) -> str: """ return self.get_instance(instance_id=instance_id).state["Name"] - def wait_for_state(self, - instance_id: str, - target_state: str, - check_interval: float) -> None: + def wait_for_state(self, instance_id: str, target_state: str, check_interval: float) -> None: """ Wait EC2 instance until its state is equal to the target_state. 
@@ -77,12 +72,8 @@ def wait_for_state(self, :return: None :rtype: None """ - instance_state = self.get_instance_state( - instance_id=instance_id - ) + instance_state = self.get_instance_state(instance_id=instance_id) while instance_state != target_state: self.log.info("instance state: %s", instance_state) time.sleep(check_interval) - instance_state = self.get_instance_state( - instance_id=instance_id - ) + instance_state = self.get_instance_state(instance_id=instance_id) diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index 001374e86410e..6bad910f472c9 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -47,9 +47,7 @@ def get_cluster_id_by_name(self, emr_cluster_name, cluster_states): :return: id of the EMR cluster """ - response = self.get_conn().list_clusters( - ClusterStates=cluster_states - ) + response = self.get_conn().list_clusters(ClusterStates=cluster_states) matching_clusters = list( filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters']) diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index 9db925dd8b57d..dde6362f8b145 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -46,19 +46,23 @@ class AwsGlueJobHook(AwsBaseHook): :param iam_role_name: AWS IAM Role for Glue Job :type iam_role_name: Optional[str] """ + JOB_POLL_INTERVAL = 6 # polls job status after every JOB_POLL_INTERVAL seconds - def __init__(self, - s3_bucket: Optional[str] = None, - job_name: Optional[str] = None, - desc: Optional[str] = None, - concurrent_run_limit: int = 1, - script_location: Optional[str] = None, - retry_limit: int = 0, - num_of_dpus: int = 10, - region_name: Optional[str] = None, - iam_role_name: Optional[str] = None, - *args, **kwargs): + def __init__( + self, + s3_bucket: Optional[str] = None, + job_name: Optional[str] = None, + desc: Optional[str] = None, + concurrent_run_limit: int = 1, + script_location: Optional[str] = None, + retry_limit: int = 0, + num_of_dpus: int = 10, + region_name: Optional[str] = None, + iam_role_name: Optional[str] = None, + *args, + **kwargs, + ): self.job_name = job_name self.desc = desc self.concurrent_run_limit = concurrent_run_limit @@ -104,10 +108,7 @@ def initialize_job(self, script_arguments: Optional[List] = None) -> Dict[str, s try: job_name = self.get_or_create_glue_job() - job_run = glue_client.start_job_run( - JobName=job_name, - Arguments=script_arguments - ) + job_run = glue_client.start_job_run(JobName=job_name, Arguments=script_arguments) return job_run except Exception as general_error: self.log.error("Failed to run aws glue job, error: %s", general_error) @@ -124,11 +125,7 @@ def get_job_state(self, job_name: str, run_id: str) -> str: :return: State of the Glue job """ glue_client = self.get_conn() - job_run = glue_client.get_job_run( - JobName=job_name, - RunId=run_id, - PredecessorsIncluded=True - ) + job_run = glue_client.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True) job_run_state = job_run['JobRun']['JobRunState'] return job_run_state @@ -157,8 +154,8 @@ def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]: raise AirflowException(job_error_message) else: self.log.info( - "Polling for AWS Glue Job %s current run state with status %s", - job_name, job_run_state) + "Polling for AWS Glue Job %s current run state with status %s", job_name, job_run_state + ) 
time.sleep(self.JOB_POLL_INTERVAL) def get_or_create_glue_job(self) -> str: @@ -176,8 +173,7 @@ def get_or_create_glue_job(self) -> str: self.log.info("Job doesnt exist. Now creating and running AWS Glue Job") if self.s3_bucket is None: raise AirflowException( - 'Could not initialize glue job, ' - 'error: Specify Parameter `s3_bucket`' + 'Could not initialize glue job, ' 'error: Specify Parameter `s3_bucket`' ) s3_log_path = f's3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}' execution_role = self.get_iam_execution_role() @@ -190,7 +186,7 @@ def get_or_create_glue_job(self) -> str: ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit}, Command={"Name": "glueetl", "ScriptLocation": self.script_location}, MaxRetries=self.retry_limit, - AllocatedCapacity=self.num_of_dpus + AllocatedCapacity=self.num_of_dpus, ) return create_job_response['Name'] except Exception as general_error: diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py index 5a53328a83ca3..27fc7c121b942 100644 --- a/airflow/providers/amazon/aws/hooks/glue_catalog.py +++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py @@ -36,12 +36,7 @@ class AwsGlueCatalogHook(AwsBaseHook): def __init__(self, *args, **kwargs): super().__init__(client_type='glue', *args, **kwargs) - def get_partitions(self, - database_name, - table_name, - expression='', - page_size=None, - max_items=None): + def get_partitions(self, database_name, table_name, expression='', page_size=None, max_items=None): """ Retrieves the partition values for a table. @@ -68,10 +63,7 @@ def get_partitions(self, paginator = self.get_conn().get_paginator('get_partitions') response = paginator.paginate( - DatabaseName=database_name, - TableName=table_name, - Expression=expression, - PaginationConfig=config + DatabaseName=database_name, TableName=table_name, Expression=expression, PaginationConfig=config ) partitions = set() diff --git a/airflow/providers/amazon/aws/hooks/kinesis.py b/airflow/providers/amazon/aws/hooks/kinesis.py index 04a50f71de2b4..1c8480a8e2f79 100644 --- a/airflow/providers/amazon/aws/hooks/kinesis.py +++ b/airflow/providers/amazon/aws/hooks/kinesis.py @@ -45,9 +45,6 @@ def put_records(self, records): Write batch records to Kinesis Firehose """ - response = self.get_conn().put_record_batch( - DeliveryStreamName=self.delivery_stream, - Records=records - ) + response = self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records) return response diff --git a/airflow/providers/amazon/aws/hooks/lambda_function.py b/airflow/providers/amazon/aws/hooks/lambda_function.py index 2656b7e0a6710..a1d9b6142ffdf 100644 --- a/airflow/providers/amazon/aws/hooks/lambda_function.py +++ b/airflow/providers/amazon/aws/hooks/lambda_function.py @@ -42,9 +42,15 @@ class AwsLambdaHook(AwsBaseHook): :type invocation_type: str """ - def __init__(self, function_name, - log_type='None', qualifier='$LATEST', - invocation_type='RequestResponse', *args, **kwargs): + def __init__( + self, + function_name, + log_type='None', + qualifier='$LATEST', + invocation_type='RequestResponse', + *args, + **kwargs, + ): self.function_name = function_name self.log_type = log_type self.invocation_type = invocation_type @@ -61,7 +67,7 @@ def invoke_lambda(self, payload): InvocationType=self.invocation_type, LogType=self.log_type, Payload=payload, - Qualifier=self.qualifier + Qualifier=self.qualifier, ) return response diff --git a/airflow/providers/amazon/aws/hooks/logs.py 
b/airflow/providers/amazon/aws/hooks/logs.py index f8c536a8a3cee..1abb83d1a82dc 100644 --- a/airflow/providers/amazon/aws/hooks/logs.py +++ b/airflow/providers/amazon/aws/hooks/logs.py @@ -71,11 +71,13 @@ def get_log_events(self, log_group, log_stream_name, start_time=0, skip=0, start else: token_arg = {} - response = self.get_conn().get_log_events(logGroupName=log_group, - logStreamName=log_stream_name, - startTime=start_time, - startFromHead=start_from_head, - **token_arg) + response = self.get_conn().get_log_events( + logGroupName=log_group, + logStreamName=log_stream_name, + startTime=start_time, + startFromHead=start_from_head, + **token_arg, + ) events = response['events'] event_count = len(events) diff --git a/airflow/providers/amazon/aws/hooks/redshift.py b/airflow/providers/amazon/aws/hooks/redshift.py index 57f59c86ee85d..065e97573d041 100644 --- a/airflow/providers/amazon/aws/hooks/redshift.py +++ b/airflow/providers/amazon/aws/hooks/redshift.py @@ -47,17 +47,17 @@ def cluster_status(self, cluster_identifier: str) -> str: :type cluster_identifier: str """ try: - response = self.get_conn().describe_clusters( - ClusterIdentifier=cluster_identifier)['Clusters'] + response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)['Clusters'] return response[0]['ClusterStatus'] if response else None except self.get_conn().exceptions.ClusterNotFoundFault: return 'cluster_not_found' def delete_cluster( # pylint: disable=invalid-name - self, - cluster_identifier: str, - skip_final_cluster_snapshot: bool = True, - final_cluster_snapshot_identifier: Optional[str] = None): + self, + cluster_identifier: str, + skip_final_cluster_snapshot: bool = True, + final_cluster_snapshot_identifier: Optional[str] = None, + ): """ Delete a cluster and optionally create a snapshot @@ -73,7 +73,7 @@ def delete_cluster( # pylint: disable=invalid-name response = self.get_conn().delete_cluster( ClusterIdentifier=cluster_identifier, SkipFinalClusterSnapshot=skip_final_cluster_snapshot, - FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier + FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier, ) return response['Cluster'] if response['Cluster'] else None @@ -84,9 +84,7 @@ def describe_cluster_snapshots(self, cluster_identifier: str) -> Optional[List[s :param cluster_identifier: unique identifier of a cluster :type cluster_identifier: str """ - response = self.get_conn().describe_cluster_snapshots( - ClusterIdentifier=cluster_identifier - ) + response = self.get_conn().describe_cluster_snapshots(ClusterIdentifier=cluster_identifier) if 'Snapshots' not in response: return None snapshots = response['Snapshots'] @@ -94,10 +92,7 @@ def describe_cluster_snapshots(self, cluster_identifier: str) -> Optional[List[s snapshots.sort(key=lambda x: x['SnapshotCreateTime'], reverse=True) return snapshots - def restore_from_cluster_snapshot( - self, - cluster_identifier: str, - snapshot_identifier: str) -> str: + def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identifier: str) -> str: """ Restores a cluster from its snapshot @@ -107,15 +102,11 @@ def restore_from_cluster_snapshot( :type snapshot_identifier: str """ response = self.get_conn().restore_from_cluster_snapshot( - ClusterIdentifier=cluster_identifier, - SnapshotIdentifier=snapshot_identifier + ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier ) return response['Cluster'] if response['Cluster'] else None - def create_cluster_snapshot( - self, - snapshot_identifier: str, - 
cluster_identifier: str) -> str: + def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier: str) -> str: """ Creates a snapshot of a cluster @@ -125,7 +116,6 @@ def create_cluster_snapshot( :type cluster_identifier: str """ response = self.get_conn().create_cluster_snapshot( - SnapshotIdentifier=snapshot_identifier, - ClusterIdentifier=cluster_identifier, + SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier, ) return response['Snapshot'] if response['Snapshot'] else None diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 976e7ebd8575b..319c5ea4a345b 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -87,8 +87,9 @@ def get_key_name() -> Optional[str]: key_name = get_key_name() if key_name and 'bucket_name' not in bound_args.arguments: - bound_args.arguments['bucket_name'], bound_args.arguments[key_name] = \ - S3Hook.parse_s3_url(bound_args.arguments[key_name]) + bound_args.arguments['bucket_name'], bound_args.arguments[key_name] = S3Hook.parse_s3_url( + bound_args.arguments[key_name] + ) return func(*bound_args.args, **bound_args.kwargs) @@ -161,9 +162,7 @@ def get_bucket(self, bucket_name: Optional[str] = None) -> str: return s3_resource.Bucket(bucket_name) @provide_bucket_name - def create_bucket(self, - bucket_name: Optional[str] = None, - region_name: Optional[str] = None) -> None: + def create_bucket(self, bucket_name: Optional[str] = None, region_name: Optional[str] = None) -> None: """ Creates an Amazon S3 bucket. @@ -177,16 +176,12 @@ def create_bucket(self, if region_name == 'us-east-1': self.get_conn().create_bucket(Bucket=bucket_name) else: - self.get_conn().create_bucket(Bucket=bucket_name, - CreateBucketConfiguration={ - 'LocationConstraint': region_name - }) + self.get_conn().create_bucket( + Bucket=bucket_name, CreateBucketConfiguration={'LocationConstraint': region_name} + ) @provide_bucket_name - def check_for_prefix(self, - prefix: str, - delimiter: str, - bucket_name: Optional[str] = None) -> bool: + def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: Optional[str] = None) -> bool: """ Checks that a prefix exists in a bucket @@ -206,12 +201,14 @@ def check_for_prefix(self, return False if plist is None else prefix in plist @provide_bucket_name - def list_prefixes(self, - bucket_name: Optional[str] = None, - prefix: Optional[str] = None, - delimiter: Optional[str] = None, - page_size: Optional[int] = None, - max_items: Optional[int] = None) -> Optional[list]: + def list_prefixes( + self, + bucket_name: Optional[str] = None, + prefix: Optional[str] = None, + delimiter: Optional[str] = None, + page_size: Optional[int] = None, + max_items: Optional[int] = None, + ) -> Optional[list]: """ Lists prefixes in a bucket under prefix @@ -236,10 +233,9 @@ def list_prefixes(self, } paginator = self.get_conn().get_paginator('list_objects_v2') - response = paginator.paginate(Bucket=bucket_name, - Prefix=prefix, - Delimiter=delimiter, - PaginationConfig=config) + response = paginator.paginate( + Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config + ) has_results = False prefixes = [] @@ -254,12 +250,14 @@ def list_prefixes(self, return None @provide_bucket_name - def list_keys(self, - bucket_name: Optional[str] = None, - prefix: Optional[str] = None, - delimiter: Optional[str] = None, - page_size: Optional[int] = None, - max_items: Optional[int] = None) -> Optional[list]: + def list_keys( + 
self, + bucket_name: Optional[str] = None, + prefix: Optional[str] = None, + delimiter: Optional[str] = None, + page_size: Optional[int] = None, + max_items: Optional[int] = None, + ) -> Optional[list]: """ Lists keys in a bucket under prefix and not containing delimiter @@ -284,10 +282,9 @@ def list_keys(self, } paginator = self.get_conn().get_paginator('list_objects_v2') - response = paginator.paginate(Bucket=bucket_name, - Prefix=prefix, - Delimiter=delimiter, - PaginationConfig=config) + response = paginator.paginate( + Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config + ) has_results = False keys = [] @@ -359,13 +356,15 @@ def read_key(self, key: str, bucket_name: Optional[str] = None) -> S3Transfer: @provide_bucket_name @unify_bucket_name_and_key - def select_key(self, - key: str, - bucket_name: Optional[str] = None, - expression: Optional[str] = None, - expression_type: Optional[str] = None, - input_serialization: Optional[Dict[str, Any]] = None, - output_serialization: Optional[Dict[str, Any]] = None) -> str: + def select_key( + self, + key: str, + bucket_name: Optional[str] = None, + expression: Optional[str] = None, + expression_type: Optional[str] = None, + input_serialization: Optional[Dict[str, Any]] = None, + output_serialization: Optional[Dict[str, Any]] = None, + ) -> str: """ Reads a key with S3 Select. @@ -402,18 +401,18 @@ def select_key(self, Expression=expression, ExpressionType=expression_type, InputSerialization=input_serialization, - OutputSerialization=output_serialization) + OutputSerialization=output_serialization, + ) - return ''.join(event['Records']['Payload'].decode('utf-8') - for event in response['Payload'] - if 'Records' in event) + return ''.join( + event['Records']['Payload'].decode('utf-8') for event in response['Payload'] if 'Records' in event + ) @provide_bucket_name @unify_bucket_name_and_key - def check_for_wildcard_key(self, - wildcard_key: str, - bucket_name: Optional[str] = None, - delimiter: str = '') -> bool: + def check_for_wildcard_key( + self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = '' + ) -> bool: """ Checks that a key matching a wildcard expression exists in a bucket @@ -426,16 +425,16 @@ def check_for_wildcard_key(self, :return: True if a key exists and False if not. 
:rtype: bool """ - return self.get_wildcard_key(wildcard_key=wildcard_key, - bucket_name=bucket_name, - delimiter=delimiter) is not None + return ( + self.get_wildcard_key(wildcard_key=wildcard_key, bucket_name=bucket_name, delimiter=delimiter) + is not None + ) @provide_bucket_name @unify_bucket_name_and_key - def get_wildcard_key(self, - wildcard_key: str, - bucket_name: Optional[str] = None, - delimiter: str = '') -> S3Transfer: + def get_wildcard_key( + self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = '' + ) -> S3Transfer: """ Returns a boto3.s3.Object object matching the wildcard expression @@ -459,14 +458,16 @@ def get_wildcard_key(self, @provide_bucket_name @unify_bucket_name_and_key - def load_file(self, - filename: str, - key: str, - bucket_name: Optional[str] = None, - replace: bool = False, - encrypt: bool = False, - gzip: bool = False, - acl_policy: Optional[str] = None) -> None: + def load_file( + self, + filename: str, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + gzip: bool = False, + acl_policy: Optional[str] = None, + ) -> None: """ Loads a local file to S3 @@ -511,14 +512,16 @@ def load_file(self, @provide_bucket_name @unify_bucket_name_and_key - def load_string(self, - string_data: str, - key: str, - bucket_name: Optional[str] = None, - replace: bool = False, - encrypt: bool = False, - encoding: Optional[str] = None, - acl_policy: Optional[str] = None) -> None: + def load_string( + self, + string_data: str, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + encoding: Optional[str] = None, + acl_policy: Optional[str] = None, + ) -> None: """ Loads a string to S3 @@ -552,13 +555,15 @@ def load_string(self, @provide_bucket_name @unify_bucket_name_and_key - def load_bytes(self, - bytes_data: bytes, - key: str, - bucket_name: Optional[str] = None, - replace: bool = False, - encrypt: bool = False, - acl_policy: Optional[str] = None) -> None: + def load_bytes( + self, + bytes_data: bytes, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + acl_policy: Optional[str] = None, + ) -> None: """ Loads bytes to S3 @@ -587,13 +592,15 @@ def load_bytes(self, @provide_bucket_name @unify_bucket_name_and_key - def load_file_obj(self, - file_obj: BytesIO, - key: str, - bucket_name: Optional[str] = None, - replace: bool = False, - encrypt: bool = False, - acl_policy: Optional[str] = None) -> None: + def load_file_obj( + self, + file_obj: BytesIO, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + acl_policy: Optional[str] = None, + ) -> None: """ Loads a file object to S3 @@ -615,13 +622,15 @@ def load_file_obj(self, """ self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy) - def _upload_file_obj(self, - file_obj: BytesIO, - key: str, - bucket_name: Optional[str] = None, - replace: bool = False, - encrypt: bool = False, - acl_policy: Optional[str] = None) -> None: + def _upload_file_obj( + self, + file_obj: BytesIO, + key: str, + bucket_name: Optional[str] = None, + replace: bool = False, + encrypt: bool = False, + acl_policy: Optional[str] = None, + ) -> None: if not replace and self.check_for_key(key, bucket_name): raise ValueError("The key {key} already exists.".format(key=key)) @@ -634,13 +643,15 @@ def _upload_file_obj(self, client = self.get_conn() client.upload_fileobj(file_obj, bucket_name, key, ExtraArgs=extra_args) - def 
copy_object(self, - source_bucket_key: str, - dest_bucket_key: str, - source_bucket_name: Optional[str] = None, - dest_bucket_name: Optional[str] = None, - source_version_id: Optional[str] = None, - acl_policy: Optional[str] = None) -> None: + def copy_object( + self, + source_bucket_key: str, + dest_bucket_key: str, + source_bucket_name: Optional[str] = None, + dest_bucket_name: Optional[str] = None, + source_version_id: Optional[str] = None, + acl_policy: Optional[str] = None, + ) -> None: """ Creates a copy of an object that is already stored in S3. @@ -679,26 +690,27 @@ def copy_object(self, else: parsed_url = urlparse(dest_bucket_key) if parsed_url.scheme != '' or parsed_url.netloc != '': - raise AirflowException('If dest_bucket_name is provided, ' + - 'dest_bucket_key should be relative path ' + - 'from root level, rather than a full s3:// url') + raise AirflowException( + 'If dest_bucket_name is provided, ' + + 'dest_bucket_key should be relative path ' + + 'from root level, rather than a full s3:// url' + ) if source_bucket_name is None: source_bucket_name, source_bucket_key = self.parse_s3_url(source_bucket_key) else: parsed_url = urlparse(source_bucket_key) if parsed_url.scheme != '' or parsed_url.netloc != '': - raise AirflowException('If source_bucket_name is provided, ' + - 'source_bucket_key should be relative path ' + - 'from root level, rather than a full s3:// url') - - copy_source = {'Bucket': source_bucket_name, - 'Key': source_bucket_key, - 'VersionId': source_version_id} - response = self.get_conn().copy_object(Bucket=dest_bucket_name, - Key=dest_bucket_key, - CopySource=copy_source, - ACL=acl_policy) + raise AirflowException( + 'If source_bucket_name is provided, ' + + 'source_bucket_key should be relative path ' + + 'from root level, rather than a full s3:// url' + ) + + copy_source = {'Bucket': source_bucket_name, 'Key': source_bucket_key, 'VersionId': source_version_id} + response = self.get_conn().copy_object( + Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, ACL=acl_policy + ) return response @provide_bucket_name @@ -717,9 +729,7 @@ def delete_bucket(self, bucket_name: str, force_delete: bool = False) -> None: bucket_keys = self.list_keys(bucket_name=bucket_name) if bucket_keys: self.delete_objects(bucket=bucket_name, keys=bucket_keys) - self.conn.delete_bucket( - Bucket=bucket_name - ) + self.conn.delete_bucket(Bucket=bucket_name) def delete_objects(self, bucket: str, keys: Union[str, list]) -> None: """ @@ -745,10 +755,7 @@ def delete_objects(self, bucket: str, keys: Union[str, list]) -> None: # For details see: # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects for chunk in chunks(keys, chunk_size=1000): - response = s3.delete_objects( - Bucket=bucket, - Delete={"Objects": [{"Key": k} for k in chunk]} - ) + response = s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}) deleted_keys = [x['Key'] for x in response.get("Deleted", [])] self.log.info("Deleted: %s", deleted_keys) if "Errors" in response: @@ -757,10 +764,9 @@ def delete_objects(self, bucket: str, keys: Union[str, list]) -> None: @provide_bucket_name @unify_bucket_name_and_key - def download_file(self, - key: str, - bucket_name: Optional[str] = None, - local_path: Optional[str] = None) -> str: + def download_file( + self, key: str, bucket_name: Optional[str] = None, local_path: Optional[str] = None + ) -> str: """ Downloads a file from the S3 location to the local file system. 
@@ -786,11 +792,13 @@ def download_file(self, return local_tmp_file.name - def generate_presigned_url(self, - client_method: str, - params: Optional[dict] = None, - expires_in: int = 3600, - http_method: Optional[str] = None) -> Optional[str]: + def generate_presigned_url( + self, + client_method: str, + params: Optional[dict] = None, + expires_in: int = 3600, + http_method: Optional[str] = None, + ) -> Optional[str]: """ Generate a presigned url given a client, its method, and arguments @@ -810,10 +818,9 @@ def generate_presigned_url(self, s3_client = self.get_conn() try: - return s3_client.generate_presigned_url(ClientMethod=client_method, - Params=params, - ExpiresIn=expires_in, - HttpMethod=http_method) + return s3_client.generate_presigned_url( + ClientMethod=client_method, Params=params, ExpiresIn=expires_in, HttpMethod=http_method + ) except ClientError as e: self.log.error(e.response["Error"]["Message"]) diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py index bb65a55f58c7e..fb5aed608e7c8 100644 --- a/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -38,6 +38,7 @@ class LogState: Enum-style class holding all possible states of CloudWatch log streams. https://sagemaker.readthedocs.io/en/stable/session.html#sagemaker.session.LogState """ + STARTING = 1 WAIT_IN_PROGRESS = 2 TAILING = 3 @@ -77,12 +78,16 @@ def secondary_training_status_changed(current_job_description, prev_job_descript if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0: return False - prev_job_secondary_status_transitions = prev_job_description.get('SecondaryStatusTransitions') \ - if prev_job_description is not None else None + prev_job_secondary_status_transitions = ( + prev_job_description.get('SecondaryStatusTransitions') if prev_job_description is not None else None + ) - last_message = prev_job_secondary_status_transitions[-1]['StatusMessage'] \ - if prev_job_secondary_status_transitions is not None \ - and len(prev_job_secondary_status_transitions) > 0 else '' + last_message = ( + prev_job_secondary_status_transitions[-1]['StatusMessage'] + if prev_job_secondary_status_transitions is not None + and len(prev_job_secondary_status_transitions) > 0 + else '' + ) message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage'] @@ -101,18 +106,28 @@ def secondary_training_status_message(job_description, prev_description): :return: Job status string to be printed. 
""" - if job_description is None or job_description.get('SecondaryStatusTransitions') is None\ - or len(job_description.get('SecondaryStatusTransitions')) == 0: + if ( + job_description is None + or job_description.get('SecondaryStatusTransitions') is None + or len(job_description.get('SecondaryStatusTransitions')) == 0 + ): return '' - prev_description_secondary_transitions = prev_description.get('SecondaryStatusTransitions')\ - if prev_description is not None else None - prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])\ - if prev_description_secondary_transitions is not None else 0 + prev_description_secondary_transitions = ( + prev_description.get('SecondaryStatusTransitions') if prev_description is not None else None + ) + prev_transitions_num = ( + len(prev_description['SecondaryStatusTransitions']) + if prev_description_secondary_transitions is not None + else 0 + ) current_transitions = job_description['SecondaryStatusTransitions'] - transitions_to_print = current_transitions[-1:] if len(current_transitions) == prev_transitions_num else \ - current_transitions[prev_transitions_num - len(current_transitions):] + transitions_to_print = ( + current_transitions[-1:] + if len(current_transitions) == prev_transitions_num + else current_transitions[prev_transitions_num - len(current_transitions) :] + ) status_strs = [] for transition in transitions_to_print: @@ -123,7 +138,7 @@ def secondary_training_status_message(job_description, prev_description): return '\n'.join(status_strs) -class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods +class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods """ Interact with Amazon SageMaker. @@ -133,9 +148,9 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods .. 
seealso:: :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` """ + non_terminal_states = {'InProgress', 'Stopping'} - endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating', - 'RollingBack', 'Deleting'} + endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating', 'RollingBack', 'Deleting'} failed_states = {'Failed'} def __init__(self, *args, **kwargs): @@ -183,11 +198,9 @@ def configure_s3_resources(self, config): self.s3_hook.create_bucket(bucket_name=op['Bucket']) for op in upload_ops: if op['Tar']: - self.tar_and_s3_upload(op['Path'], op['Key'], - op['Bucket']) + self.tar_and_s3_upload(op['Path'], op['Key'], op['Bucket']) else: - self.s3_hook.load_file(op['Path'], op['Key'], - op['Bucket']) + self.s3_hook.load_file(op['Path'], op['Key'], op['Bucket']) def check_s3_url(self, s3url): """ @@ -199,17 +212,18 @@ def check_s3_url(self, s3url): """ bucket, key = S3Hook.parse_s3_url(s3url) if not self.s3_hook.check_for_bucket(bucket_name=bucket): - raise AirflowException( - "The input S3 Bucket {} does not exist ".format(bucket)) - if key and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)\ - and not self.s3_hook.check_for_prefix( - prefix=key, bucket_name=bucket, delimiter='/'): + raise AirflowException("The input S3 Bucket {} does not exist ".format(bucket)) + if ( + key + and not self.s3_hook.check_for_key(key=key, bucket_name=bucket) + and not self.s3_hook.check_for_prefix(prefix=key, bucket_name=bucket, delimiter='/') + ): # check if s3 key exists in the case user provides a single file # or if s3 prefix exists in the case user provides multiple files in # a prefix - raise AirflowException("The input S3 Key " - "or Prefix {} does not exist in the Bucket {}" - .format(s3url, bucket)) + raise AirflowException( + "The input S3 Key " "or Prefix {} does not exist in the Bucket {}".format(s3url, bucket) + ) return True def check_training_config(self, training_config): @@ -240,10 +254,12 @@ def get_log_conn(self): This method is deprecated. Please use :py:meth:`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead. """ - warnings.warn("Method `get_log_conn` has been deprecated. " - "Please use `airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead.", - category=DeprecationWarning, - stacklevel=2) + warnings.warn( + "Method `get_log_conn` has been deprecated. " + "Please use `airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead.", + category=DeprecationWarning, + stacklevel=2, + ) return self.logs_hook.get_conn() @@ -253,11 +269,13 @@ def log_stream(self, log_group, stream_name, start_time=0, skip=0): Please use :py:meth:`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead. """ - warnings.warn("Method `log_stream` has been deprecated. " - "Please use " - "`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead.", - category=DeprecationWarning, - stacklevel=2) + warnings.warn( + "Method `log_stream` has been deprecated. " + "Please use " + "`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead.", + category=DeprecationWarning, + stacklevel=2, + ) return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip) @@ -277,8 +295,10 @@ def multi_stream_iter(self, log_group, streams, positions=None): :return: A tuple of (stream number, cloudwatch log event). 
""" positions = positions or {s: Position(timestamp=0, skip=0) for s in streams} - event_iters = [self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip) - for s in streams] + event_iters = [ + self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip) + for s in streams + ] events = [] for event_stream in event_iters: if not event_stream: @@ -297,8 +317,9 @@ def multi_stream_iter(self, log_group, streams, positions=None): except StopIteration: events[i] = None - def create_training_job(self, config, wait_for_completion=True, print_log=True, - check_interval=30, max_ingestion_time=None): + def create_training_job( + self, config, wait_for_completion=True, print_log=True, check_interval=30, max_ingestion_time=None + ): """ Create a training job @@ -320,28 +341,31 @@ def create_training_job(self, config, wait_for_completion=True, print_log=True, response = self.get_conn().create_training_job(**config) if print_log: - self.check_training_status_with_log(config['TrainingJobName'], - self.non_terminal_states, - self.failed_states, - wait_for_completion, - check_interval, max_ingestion_time - ) + self.check_training_status_with_log( + config['TrainingJobName'], + self.non_terminal_states, + self.failed_states, + wait_for_completion, + check_interval, + max_ingestion_time, + ) elif wait_for_completion: - describe_response = self.check_status(config['TrainingJobName'], - 'TrainingJobStatus', - self.describe_training_job, - check_interval, max_ingestion_time - ) - - billable_time = \ - (describe_response['TrainingEndTime'] - describe_response['TrainingStartTime']) * \ - describe_response['ResourceConfig']['InstanceCount'] + describe_response = self.check_status( + config['TrainingJobName'], + 'TrainingJobStatus', + self.describe_training_job, + check_interval, + max_ingestion_time, + ) + + billable_time = ( + describe_response['TrainingEndTime'] - describe_response['TrainingStartTime'] + ) * describe_response['ResourceConfig']['InstanceCount'] self.log.info('Billable seconds: %d', int(billable_time.total_seconds()) + 1) return response - def create_tuning_job(self, config, wait_for_completion=True, - check_interval=30, max_ingestion_time=None): + def create_tuning_job(self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None): """ Create a tuning job @@ -363,15 +387,18 @@ def create_tuning_job(self, config, wait_for_completion=True, response = self.get_conn().create_hyper_parameter_tuning_job(**config) if wait_for_completion: - self.check_status(config['HyperParameterTuningJobName'], - 'HyperParameterTuningJobStatus', - self.describe_tuning_job, - check_interval, max_ingestion_time - ) + self.check_status( + config['HyperParameterTuningJobName'], + 'HyperParameterTuningJobStatus', + self.describe_tuning_job, + check_interval, + max_ingestion_time, + ) return response - def create_transform_job(self, config, wait_for_completion=True, - check_interval=30, max_ingestion_time=None): + def create_transform_job( + self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None + ): """ Create a transform job @@ -393,15 +420,18 @@ def create_transform_job(self, config, wait_for_completion=True, response = self.get_conn().create_transform_job(**config) if wait_for_completion: - self.check_status(config['TransformJobName'], - 'TransformJobStatus', - self.describe_transform_job, - check_interval, max_ingestion_time - ) + self.check_status( + config['TransformJobName'], + 'TransformJobStatus', + 
self.describe_transform_job, + check_interval, + max_ingestion_time, + ) return response - def create_processing_job(self, config, wait_for_completion=True, - check_interval=30, max_ingestion_time=None): + def create_processing_job( + self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None + ): """ Create a processing job @@ -421,11 +451,13 @@ def create_processing_job(self, config, wait_for_completion=True, response = self.get_conn().create_processing_job(**config) if wait_for_completion: - self.check_status(config['ProcessingJobName'], - 'ProcessingJobStatus', - self.describe_processing_job, - check_interval, max_ingestion_time - ) + self.check_status( + config['ProcessingJobName'], + 'ProcessingJobStatus', + self.describe_processing_job, + check_interval, + max_ingestion_time, + ) return response def create_model(self, config): @@ -450,8 +482,7 @@ def create_endpoint_config(self, config): return self.get_conn().create_endpoint_config(**config) - def create_endpoint(self, config, wait_for_completion=True, - check_interval=30, max_ingestion_time=None): + def create_endpoint(self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None): """ Create an endpoint @@ -471,16 +502,17 @@ def create_endpoint(self, config, wait_for_completion=True, response = self.get_conn().create_endpoint(**config) if wait_for_completion: - self.check_status(config['EndpointName'], - 'EndpointStatus', - self.describe_endpoint, - check_interval, max_ingestion_time, - non_terminal_states=self.endpoint_non_terminal_states - ) + self.check_status( + config['EndpointName'], + 'EndpointStatus', + self.describe_endpoint, + check_interval, + max_ingestion_time, + non_terminal_states=self.endpoint_non_terminal_states, + ) return response - def update_endpoint(self, config, wait_for_completion=True, - check_interval=30, max_ingestion_time=None): + def update_endpoint(self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None): """ Update an endpoint @@ -500,12 +532,14 @@ def update_endpoint(self, config, wait_for_completion=True, response = self.get_conn().update_endpoint(**config) if wait_for_completion: - self.check_status(config['EndpointName'], - 'EndpointStatus', - self.describe_endpoint, - check_interval, max_ingestion_time, - non_terminal_states=self.endpoint_non_terminal_states - ) + self.check_status( + config['EndpointName'], + 'EndpointStatus', + self.describe_endpoint, + check_interval, + max_ingestion_time, + non_terminal_states=self.endpoint_non_terminal_states, + ) return response def describe_training_job(self, name): @@ -519,9 +553,16 @@ def describe_training_job(self, name): return self.get_conn().describe_training_job(TrainingJobName=name) - def describe_training_job_with_log(self, job_name, positions, stream_names, - instance_count, state, last_description, - last_describe_job_call): + def describe_training_job_with_log( + self, + job_name, + positions, + stream_names, + instance_count, + state, + last_description, + last_describe_job_call, + ): """ Return the training job info associated with job_name and print CloudWatch logs """ @@ -536,11 +577,12 @@ def describe_training_job_with_log(self, job_name, positions, stream_names, logGroupName=log_group, logStreamNamePrefix=job_name + '/', orderBy='LogStreamName', - limit=instance_count + limit=instance_count, ) stream_names = [s['logStreamName'] for s in streams['logStreams']] - positions.update([(s, Position(timestamp=0, skip=0)) - for s in stream_names if s not in positions]) + 
positions.update( + [(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions] + ) except logs_conn.exceptions.ResourceNotFoundException: # On the very first training job run on an account, there's no log group until # the container starts logging, so ignore any errors thrown about that @@ -638,10 +680,9 @@ def describe_endpoint(self, name): return self.get_conn().describe_endpoint(EndpointName=name) - def check_status(self, job_name, key, - describe_function, check_interval, - max_ingestion_time, - non_terminal_states=None): + def check_status( + self, job_name, key, describe_function, check_interval, max_ingestion_time, non_terminal_states=None + ): """ Check status of a SageMaker job @@ -677,8 +718,7 @@ def check_status(self, job_name, key, try: response = describe_function(job_name) status = response[key] - self.log.info('Job still running for %s seconds... ' - 'current status is %s', sec, status) + self.log.info('Job still running for %s seconds... ' 'current status is %s', sec, status) except KeyError: raise AirflowException('Could not get status of the SageMaker job') except ClientError: @@ -699,8 +739,15 @@ def check_status(self, job_name, key, response = describe_function(job_name) return response - def check_training_status_with_log(self, job_name, non_terminal_states, failed_states, - wait_for_completion, check_interval, max_ingestion_time): + def check_training_status_with_log( + self, + job_name, + non_terminal_states, + failed_states, + wait_for_completion, + check_interval, + max_ingestion_time, + ): """ Display the logs for a given training job, optionally tailing them until the job is complete. @@ -730,7 +777,7 @@ def check_training_status_with_log(self, job_name, non_terminal_states, failed_s status = description['TrainingJobStatus'] stream_names = [] # The list of log streams - positions = {} # The current position in each stream, map of stream name -> position + positions = {} # The current position in each stream, map of stream name -> position job_already_completed = status not in non_terminal_states @@ -763,10 +810,15 @@ def check_training_status_with_log(self, job_name, non_terminal_states, failed_s time.sleep(check_interval) sec += check_interval - state, last_description, last_describe_job_call = \ - self.describe_training_job_with_log(job_name, positions, stream_names, - instance_count, state, last_description, - last_describe_job_call) + state, last_description, last_describe_job_call = self.describe_training_job_with_log( + job_name, + positions, + stream_names, + instance_count, + state, + last_description, + last_describe_job_call, + ) if state == LogState.COMPLETE: break @@ -779,13 +831,14 @@ def check_training_status_with_log(self, job_name, non_terminal_states, failed_s if status in failed_states: reason = last_description.get('FailureReason', '(No reason provided)') raise AirflowException('Error training {}: {} Reason: {}'.format(job_name, status, reason)) - billable_time = (last_description['TrainingEndTime'] - last_description['TrainingStartTime']) \ - * instance_count + billable_time = ( + last_description['TrainingEndTime'] - last_description['TrainingStartTime'] + ) * instance_count self.log.info('Billable seconds: %d', int(billable_time.total_seconds()) + 1) def list_training_jobs( self, name_contains: Optional[str] = None, max_results: Optional[int] = None, **kwargs - ) -> List[Dict]: # noqa: D402 + ) -> List[Dict]: # noqa: D402 """ This method wraps boto3's list_training_jobs(). 
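As a hedged illustration of how the SageMakerHook methods reformatted above are typically called (again not part of the patch; the connection id and job names are hypothetical placeholders):

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id='aws_default')

# List training jobs whose name contains a substring (wraps boto3's
# list_training_jobs, as the docstring above describes).
jobs = hook.list_training_jobs(name_contains='my-model', max_results=10)

# Poll an existing training job every 30 seconds until it leaves its
# non-terminal states; the describe function and status key mirror the ones
# create_training_job uses internally in the hunks above.
hook.check_status(
    job_name='my-model-training-job',
    key='TrainingJobStatus',
    describe_function=hook.describe_training_job,
    check_interval=30,
    max_ingestion_time=None,
)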
The training job name and max results are configurable via arguments. Other arguments are not, and should be provided via kwargs. Note boto3 expects these in diff --git a/airflow/providers/amazon/aws/hooks/ses.py b/airflow/providers/amazon/aws/hooks/ses.py index 2ee81719e091b..3844b711bfcbe 100644 --- a/airflow/providers/amazon/aws/hooks/ses.py +++ b/airflow/providers/amazon/aws/hooks/ses.py @@ -52,7 +52,7 @@ def send_email( # pylint: disable=too-many-arguments mime_charset: str = 'utf-8', reply_to: Optional[str] = None, return_path: Optional[str] = None, - custom_headers: Optional[Dict[str, Any]] = None + custom_headers: Optional[Dict[str, Any]] = None, ) -> dict: """ Send email using Amazon Simple Email Service diff --git a/airflow/providers/amazon/aws/hooks/sns.py b/airflow/providers/amazon/aws/hooks/sns.py index f0b0d5b95e11d..e5045bbfc48e9 100644 --- a/airflow/providers/amazon/aws/hooks/sns.py +++ b/airflow/providers/amazon/aws/hooks/sns.py @@ -33,8 +33,9 @@ def _get_message_attribute(o): return {'DataType': 'Number', 'StringValue': str(o)} if hasattr(o, '__iter__'): return {'DataType': 'String.Array', 'StringValue': json.dumps(o)} - raise TypeError('Values in MessageAttributes must be one of bytes, str, int, float, or iterable; ' - f'got {type(o)}') + raise TypeError( + 'Values in MessageAttributes must be one of bytes, str, int, float, or iterable; ' f'got {type(o)}' + ) class AwsSnsHook(AwsBaseHook): @@ -74,9 +75,7 @@ def publish_to_target(self, target_arn, message, subject=None, message_attribute publish_kwargs = { 'TargetArn': target_arn, 'MessageStructure': 'json', - 'Message': json.dumps({ - 'default': message - }), + 'Message': json.dumps({'default': message}), } # Construct args this way because boto3 distinguishes from missing args and those set to None diff --git a/airflow/providers/amazon/aws/hooks/sqs.py b/airflow/providers/amazon/aws/hooks/sqs.py index 849979b2b7f46..6c43f7f70b8ba 100644 --- a/airflow/providers/amazon/aws/hooks/sqs.py +++ b/airflow/providers/amazon/aws/hooks/sqs.py @@ -70,7 +70,9 @@ def send_message(self, queue_url, message_body, delay_seconds=0, message_attribu For details of the returned value see :py:meth:`botocore.client.SQS.send_message` :rtype: dict """ - return self.get_conn().send_message(QueueUrl=queue_url, - MessageBody=message_body, - DelaySeconds=delay_seconds, - MessageAttributes=message_attributes or {}) + return self.get_conn().send_message( + QueueUrl=queue_url, + MessageBody=message_body, + DelaySeconds=delay_seconds, + MessageAttributes=message_attributes or {}, + ) diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py index f0e10400d95ee..d83d1afa6257a 100644 --- a/airflow/providers/amazon/aws/hooks/step_function.py +++ b/airflow/providers/amazon/aws/hooks/step_function.py @@ -35,8 +35,12 @@ class StepFunctionHook(AwsBaseHook): def __init__(self, region_name=None, *args, **kwargs): super().__init__(client_type='stepfunctions', *args, **kwargs) - def start_execution(self, state_machine_arn: str, name: Optional[str] = None, - state_machine_input: Union[dict, str, None] = None) -> str: + def start_execution( + self, + state_machine_arn: str, + name: Optional[str] = None, + state_machine_input: Union[dict, str, None] = None, + ) -> str: """ Start Execution of the State Machine. 
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution @@ -50,9 +54,7 @@ def start_execution(self, state_machine_arn: str, name: Optional[str] = None, :return: Execution ARN :rtype: str """ - execution_args = { - 'stateMachineArn': state_machine_arn - } + execution_args = {'stateMachineArn': state_machine_arn} if name is not None: execution_args['name'] = name if state_machine_input is not None: diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py index 2d224525aa3dc..7d4e3a0703bde 100644 --- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py @@ -38,6 +38,7 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin): :param filename_template: template for file name (local storage) or log stream name (remote) :type filename_template: str """ + def __init__(self, base_log_folder, log_group_arn, filename_template): super().__init__(base_log_folder, filename_template) split_arn = log_group_arn.split(':') @@ -55,12 +56,14 @@ def hook(self): remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID') try: from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook + return AwsLogsHook(aws_conn_id=remote_conn_id, region_name=self.region_name) except Exception: # pylint: disable=broad-except self.log.error( 'Could not create an AwsLogsHook with connection id "%s". ' 'Please make sure that airflow[aws] is installed and ' - 'the Cloudwatch logs connection exists.', remote_conn_id + 'the Cloudwatch logs connection exists.', + remote_conn_id, ) def _render_filename(self, ti, try_number): @@ -72,7 +75,7 @@ def set_context(self, ti): self.handler = watchtower.CloudWatchLogHandler( log_group=self.log_group, stream_name=self._render_filename(ti, ti.try_number), - boto3_session=self.hook.get_session(self.region_name) + boto3_session=self.hook.get_session(self.region_name), ) def close(self): @@ -93,9 +96,12 @@ def close(self): def _read(self, task_instance, try_number, metadata=None): stream_name = self._render_filename(task_instance, try_number) - return '*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}\n'.format( - self.log_group, stream_name, self.get_cloudwatch_logs(stream_name=stream_name) - ), {'end_of_log': True} + return ( + '*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}\n'.format( + self.log_group, stream_name, self.get_cloudwatch_logs(stream_name=stream_name) + ), + {'end_of_log': True}, + ) def get_cloudwatch_logs(self, stream_name): """ diff --git a/airflow/providers/amazon/aws/log/s3_task_handler.py b/airflow/providers/amazon/aws/log/s3_task_handler.py index b13b7cdaf40eb..00f52d1b28666 100644 --- a/airflow/providers/amazon/aws/log/s3_task_handler.py +++ b/airflow/providers/amazon/aws/log/s3_task_handler.py @@ -30,6 +30,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin): task instance logs. It extends airflow FileTaskHandler and uploads to and reads from S3 remote storage. 
""" + def __init__(self, base_log_folder, s3_log_folder, filename_template): super().__init__(base_log_folder, filename_template) self.remote_base = s3_log_folder @@ -46,12 +47,14 @@ def hook(self): remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID') try: from airflow.providers.amazon.aws.hooks.s3 import S3Hook + return S3Hook(remote_conn_id) except Exception: # pylint: disable=broad-except self.log.exception( 'Could not create an S3Hook with connection id "%s". ' 'Please make sure that airflow[aws] is installed and ' - 'the S3 connection exists.', remote_conn_id + 'the S3 connection exists.', + remote_conn_id, ) def set_context(self, ti): @@ -115,8 +118,7 @@ def _read(self, ti, try_number, metadata=None): # local machine even if there are errors reading remote logs, as # returned remote_log will contain error messages. remote_log = self.s3_read(remote_loc, return_error=True) - log = '*** Reading remote log from {}.\n{}\n'.format( - remote_loc, remote_log) + log = '*** Reading remote log from {}.\n{}\n'.format(remote_loc, remote_log) return log, {'end_of_log': True} else: return super()._read(ti, try_number) diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 4d734d0097ed5..2039fe3015585 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -54,11 +54,12 @@ class AWSAthenaOperator(BaseOperator): ui_color = '#44b5e2' template_fields = ('query', 'database', 'output_location') - template_ext = ('.sql', ) + template_ext = ('.sql',) @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, query: str, database: str, output_location: str, @@ -69,7 +70,7 @@ def __init__( # pylint: disable=too-many-arguments result_configuration: Optional[Dict[str, Any]] = None, sleep_time: int = 30, max_tries: Optional[int] = None, - **kwargs: Any + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.query = query @@ -95,21 +96,29 @@ def execute(self, context: dict) -> Optional[str]: """ self.query_execution_context['Database'] = self.database self.result_configuration['OutputLocation'] = self.output_location - self.query_execution_id = self.hook.run_query(self.query, self.query_execution_context, - self.result_configuration, self.client_request_token, - self.workgroup) + self.query_execution_id = self.hook.run_query( + self.query, + self.query_execution_context, + self.result_configuration, + self.client_request_token, + self.workgroup, + ) query_status = self.hook.poll_query_status(self.query_execution_id, self.max_tries) if query_status in AWSAthenaHook.FAILURE_STATES: error_message = self.hook.get_state_change_reason(self.query_execution_id) raise Exception( - 'Final state of Athena job is {}, query_execution_id is {}. Error: {}' - .format(query_status, self.query_execution_id, error_message)) + 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'.format( + query_status, self.query_execution_id, error_message + ) + ) elif not query_status or query_status in AWSAthenaHook.INTERMEDIATE_STATES: raise Exception( 'Final state of Athena job is {}. ' - 'Max tries of poll status exceeded, query_execution_id is {}.' 
- .format(query_status, self.query_execution_id)) + 'Max tries of poll status exceeded, query_execution_id is {}.'.format( + query_status, self.query_execution_id + ) + ) return self.query_execution_id @@ -119,9 +128,7 @@ def on_kill(self) -> None: """ if self.query_execution_id: self.log.info('⚰️⚰️⚰️ Received a kill Signal. Time to Die') - self.log.info( - 'Stopping Query with executionId - %s', self.query_execution_id - ) + self.log.info('Stopping Query with executionId - %s', self.query_execution_id) response = self.hook.stop_query(self.query_execution_id) http_status_code = None try: diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index c865ade217170..aabe3072815d5 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -99,7 +99,8 @@ class AwsBatchOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, job_name, job_definition, job_queue, @@ -141,9 +142,7 @@ def execute(self, context: Dict): self.monitor_job(context) def on_kill(self): - response = self.hook.client.terminate_job( - jobId=self.job_id, reason="Task killed by the user" - ) + response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user") self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response) def submit_job(self, context: Dict): # pylint: disable=unused-argument @@ -153,9 +152,7 @@ def submit_job(self, context: Dict): # pylint: disable=unused-argument :raises: AirflowException """ self.log.info( - "Running AWS Batch job - job definition: %s - on queue %s", - self.job_definition, - self.job_queue, + "Running AWS Batch job - job definition: %s - on queue %s", self.job_definition, self.job_queue, ) self.log.info("AWS Batch job - container overrides: %s", self.overrides) diff --git a/airflow/providers/amazon/aws/operators/cloud_formation.py b/airflow/providers/amazon/aws/operators/cloud_formation.py index f0dc0c4b17cf0..d6c9bb01dc57a 100644 --- a/airflow/providers/amazon/aws/operators/cloud_formation.py +++ b/airflow/providers/amazon/aws/operators/cloud_formation.py @@ -39,17 +39,13 @@ class CloudFormationCreateStackOperator(BaseOperator): :param aws_conn_id: aws connection to uses :type aws_conn_id: str """ + template_fields: List[str] = ['stack_name'] template_ext = () ui_color = '#6b9659' @apply_defaults - def __init__( - self, *, - stack_name, - params, - aws_conn_id='aws_default', - **kwargs): + def __init__(self, *, stack_name, params, aws_conn_id='aws_default', **kwargs): super().__init__(**kwargs) self.stack_name = stack_name self.params = params @@ -76,18 +72,14 @@ class CloudFormationDeleteStackOperator(BaseOperator): :param aws_conn_id: aws connection to uses :type aws_conn_id: str """ + template_fields: List[str] = ['stack_name'] template_ext = () ui_color = '#1d472b' ui_fgcolor = '#FFF' @apply_defaults - def __init__( - self, *, - stack_name, - params=None, - aws_conn_id='aws_default', - **kwargs): + def __init__(self, *, stack_name, params=None, aws_conn_id='aws_default', **kwargs): super().__init__(**kwargs) self.params = params or {} self.stack_name = stack_name diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py index 944b6e57e6470..681eb9c4821ae 100644 --- a/airflow/providers/amazon/aws/operators/datasync.py +++ b/airflow/providers/amazon/aws/operators/datasync.py @@ -101,13 +101,14 @@ class AWSDataSyncOperator(BaseOperator): 
"create_source_location_kwargs", "create_destination_location_kwargs", "update_task_kwargs", - "task_execution_kwargs" + "task_execution_kwargs", ) ui_color = "#44b5e2" @apply_defaults def __init__( - self, *, + self, + *, aws_conn_id="aws_default", wait_interval_seconds=5, task_arn=None, @@ -121,7 +122,7 @@ def __init__( update_task_kwargs=None, task_execution_kwargs=None, delete_task_after_execution=False, - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -181,8 +182,7 @@ def get_hook(self): """ if not self.hook: self.hook = AWSDataSyncHook( - aws_conn_id=self.aws_conn_id, - wait_interval_seconds=self.wait_interval_seconds, + aws_conn_id=self.aws_conn_id, wait_interval_seconds=self.wait_interval_seconds, ) return self.hook @@ -194,16 +194,14 @@ def execute(self, context): # If some were found, identify which one to run if self.candidate_task_arns: - self.task_arn = self.choose_task( - self.candidate_task_arns) + self.task_arn = self.choose_task(self.candidate_task_arns) # If we couldnt find one then try create one if not self.task_arn and self.create_task_kwargs: self._create_datasync_task() if not self.task_arn: - raise AirflowException( - "DataSync TaskArn could not be identified or created.") + raise AirflowException("DataSync TaskArn could not be identified or created.") self.log.info("Using DataSync TaskArn %s", self.task_arn) @@ -227,13 +225,9 @@ def _get_tasks_and_locations(self): """Find existing DataSync Task based on source and dest Locations.""" hook = self.get_hook() - self.candidate_source_location_arns = self._get_location_arns( - self.source_location_uri - ) + self.candidate_source_location_arns = self._get_location_arns(self.source_location_uri) - self.candidate_destination_location_arns = self._get_location_arns( - self.destination_location_uri - ) + self.candidate_destination_location_arns = self._get_location_arns(self.destination_location_uri) if not self.candidate_source_location_arns: self.log.info("No matching source Locations") @@ -245,11 +239,9 @@ def _get_tasks_and_locations(self): self.log.info("Finding DataSync TaskArns that have these LocationArns") self.candidate_task_arns = hook.get_task_arns_for_location_arns( - self.candidate_source_location_arns, - self.candidate_destination_location_arns, + self.candidate_source_location_arns, self.candidate_destination_location_arns, ) - self.log.info("Found candidate DataSync TaskArns %s", - self.candidate_task_arns) + self.log.info("Found candidate DataSync TaskArns %s", self.candidate_task_arns) def choose_task(self, task_arn_list): """Select 1 DataSync TaskArn from a list""" @@ -263,8 +255,7 @@ def choose_task(self, task_arn_list): # from AWS and might lead to confusion. Rather explicitly # choose a random one return random.choice(task_arn_list) - raise AirflowException( - "Unable to choose a Task from {}".format(task_arn_list)) + raise AirflowException("Unable to choose a Task from {}".format(task_arn_list)) def choose_location(self, location_arn_list): """Select 1 DataSync LocationArn from a list""" @@ -278,16 +269,13 @@ def choose_location(self, location_arn_list): # from AWS and might lead to confusion. 
Rather explicitly # choose a random one return random.choice(location_arn_list) - raise AirflowException( - "Unable to choose a Location from {}".format(location_arn_list)) + raise AirflowException("Unable to choose a Location from {}".format(location_arn_list)) def _create_datasync_task(self): """Create a AWS DataSyncTask.""" hook = self.get_hook() - self.source_location_arn = self.choose_location( - self.candidate_source_location_arns - ) + self.source_location_arn = self.choose_location(self.candidate_source_location_arns) if not self.source_location_arn and self.create_source_location_kwargs: self.log.info('Attempting to create source Location') self.source_location_arn = hook.create_location( @@ -295,12 +283,10 @@ def _create_datasync_task(self): ) if not self.source_location_arn: raise AirflowException( - "Unable to determine source LocationArn." - " Does a suitable DataSync Location exist?") + "Unable to determine source LocationArn." " Does a suitable DataSync Location exist?" + ) - self.destination_location_arn = self.choose_location( - self.candidate_destination_location_arns - ) + self.destination_location_arn = self.choose_location(self.candidate_destination_location_arns) if not self.destination_location_arn and self.create_destination_location_kwargs: self.log.info('Attempting to create destination Location') self.destination_location_arn = hook.create_location( @@ -308,14 +294,12 @@ def _create_datasync_task(self): ) if not self.destination_location_arn: raise AirflowException( - "Unable to determine destination LocationArn." - " Does a suitable DataSync Location exist?") + "Unable to determine destination LocationArn." " Does a suitable DataSync Location exist?" + ) self.log.info("Creating a Task.") self.task_arn = hook.create_task( - self.source_location_arn, - self.destination_location_arn, - **self.create_task_kwargs + self.source_location_arn, self.destination_location_arn, **self.create_task_kwargs ) if not self.task_arn: raise AirflowException("Task could not be created") @@ -336,20 +320,15 @@ def _execute_datasync_task(self): # Create a task execution: self.log.info("Starting execution for TaskArn %s", self.task_arn) - self.task_execution_arn = hook.start_task_execution( - self.task_arn, **self.task_execution_kwargs) + self.task_execution_arn = hook.start_task_execution(self.task_arn, **self.task_execution_kwargs) self.log.info("Started TaskExecutionArn %s", self.task_execution_arn) # Wait for task execution to complete - self.log.info("Waiting for TaskExecutionArn %s", - self.task_execution_arn) + self.log.info("Waiting for TaskExecutionArn %s", self.task_execution_arn) result = hook.wait_for_task_execution(self.task_execution_arn) self.log.info("Completed TaskExecutionArn %s", self.task_execution_arn) - task_execution_description = hook.describe_task_execution( - task_execution_arn=self.task_execution_arn - ) - self.log.info("task_execution_description=%s", - task_execution_description) + task_execution_description = hook.describe_task_execution(task_execution_arn=self.task_execution_arn) + self.log.info("task_execution_description=%s", task_execution_description) # Log some meaningful statuses level = logging.ERROR if not result else logging.INFO @@ -359,21 +338,16 @@ def _execute_datasync_task(self): self.log.log(level, '%s=%s', k, v) if not result: - raise AirflowException( - "Failed TaskExecutionArn %s" % self.task_execution_arn - ) + raise AirflowException("Failed TaskExecutionArn %s" % self.task_execution_arn) return self.task_execution_arn def on_kill(self): 
"""Cancel the submitted DataSync task.""" hook = self.get_hook() if self.task_execution_arn: - self.log.info("Cancelling TaskExecutionArn %s", - self.task_execution_arn) - hook.cancel_task_execution( - task_execution_arn=self.task_execution_arn) - self.log.info("Cancelled TaskExecutionArn %s", - self.task_execution_arn) + self.log.info("Cancelling TaskExecutionArn %s", self.task_execution_arn) + hook.cancel_task_execution(task_execution_arn=self.task_execution_arn) + self.log.info("Cancelled TaskExecutionArn %s", self.task_execution_arn) def _delete_datasync_task(self): """Deletes an AWS DataSync Task.""" @@ -385,10 +359,6 @@ def _delete_datasync_task(self): return self.task_arn def _get_location_arns(self, location_uri): - location_arns = self.get_hook().get_location_arns( - location_uri - ) - self.log.info( - "Found LocationArns %s for LocationUri %s", location_arns, location_uri - ) + location_arns = self.get_hook().get_location_arns(location_uri) + self.log.info("Found LocationArns %s for LocationUri %s", location_arns, location_uri) return location_arns diff --git a/airflow/providers/amazon/aws/operators/ec2_start_instance.py b/airflow/providers/amazon/aws/operators/ec2_start_instance.py index dc657bf3c9ddf..e623de9a9f5db 100644 --- a/airflow/providers/amazon/aws/operators/ec2_start_instance.py +++ b/airflow/providers/amazon/aws/operators/ec2_start_instance.py @@ -44,12 +44,15 @@ class EC2StartInstanceOperator(BaseOperator): ui_fgcolor = "#ffffff" @apply_defaults - def __init__(self, *, - instance_id: str, - aws_conn_id: str = "aws_default", - region_name: Optional[str] = None, - check_interval: float = 15, - **kwargs): + def __init__( + self, + *, + instance_id: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + check_interval: float = 15, + **kwargs, + ): super().__init__(**kwargs) self.instance_id = instance_id self.aws_conn_id = aws_conn_id @@ -57,15 +60,10 @@ def __init__(self, *, self.check_interval = check_interval def execute(self, context): - ec2_hook = EC2Hook( - aws_conn_id=self.aws_conn_id, - region_name=self.region_name - ) + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) self.log.info("Starting EC2 instance %s", self.instance_id) instance = ec2_hook.get_instance(instance_id=self.instance_id) instance.start() ec2_hook.wait_for_state( - instance_id=self.instance_id, - target_state="running", - check_interval=self.check_interval, + instance_id=self.instance_id, target_state="running", check_interval=self.check_interval, ) diff --git a/airflow/providers/amazon/aws/operators/ec2_stop_instance.py b/airflow/providers/amazon/aws/operators/ec2_stop_instance.py index 808284497a92c..0369bdd0b6c36 100644 --- a/airflow/providers/amazon/aws/operators/ec2_stop_instance.py +++ b/airflow/providers/amazon/aws/operators/ec2_stop_instance.py @@ -44,12 +44,15 @@ class EC2StopInstanceOperator(BaseOperator): ui_fgcolor = "#ffffff" @apply_defaults - def __init__(self, *, - instance_id: str, - aws_conn_id: str = "aws_default", - region_name: Optional[str] = None, - check_interval: float = 15, - **kwargs): + def __init__( + self, + *, + instance_id: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + check_interval: float = 15, + **kwargs, + ): super().__init__(**kwargs) self.instance_id = instance_id self.aws_conn_id = aws_conn_id @@ -57,15 +60,10 @@ def __init__(self, *, self.check_interval = check_interval def execute(self, context): - ec2_hook = EC2Hook( - aws_conn_id=self.aws_conn_id, - 
region_name=self.region_name - ) + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) self.log.info("Stopping EC2 instance %s", self.instance_id) instance = ec2_hook.get_instance(instance_id=self.instance_id) instance.stop() ec2_hook.wait_for_state( - instance_id=self.instance_id, - target_state="stopped", - check_interval=self.check_interval, + instance_id=self.instance_id, target_state="stopped", check_interval=self.check_interval, ) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 573b10d17d158..44b72e53ea2fd 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -113,11 +113,26 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes template_fields = ('overrides',) @apply_defaults - def __init__(self, *, task_definition, cluster, overrides, # pylint: disable=too-many-arguments - aws_conn_id=None, region_name=None, launch_type='EC2', - group=None, placement_constraints=None, platform_version='LATEST', - network_configuration=None, tags=None, awslogs_group=None, - awslogs_region=None, awslogs_stream_prefix=None, propagate_tags=None, **kwargs): + def __init__( + self, + *, + task_definition, + cluster, + overrides, # pylint: disable=too-many-arguments + aws_conn_id=None, + region_name=None, + launch_type='EC2', + group=None, + placement_constraints=None, + platform_version='LATEST', + network_configuration=None, + tags=None, + awslogs_group=None, + awslogs_region=None, + awslogs_stream_prefix=None, + propagate_tags=None, + **kwargs, + ): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id @@ -144,8 +159,7 @@ def __init__(self, *, task_definition, cluster, overrides, # pylint: disable=to def execute(self, context): self.log.info( - 'Running ECS Task - Task definition: %s - on cluster %s', - self.task_definition, self.cluster + 'Running ECS Task - Task definition: %s - on cluster %s', self.task_definition, self.cluster ) self.log.info('ECSOperator overrides: %s', self.overrides) @@ -189,16 +203,10 @@ def execute(self, context): def _wait_for_task_ended(self): waiter = self.client.get_waiter('tasks_stopped') waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow - waiter.wait( - cluster=self.cluster, - tasks=[self.arn] - ) + waiter.wait(cluster=self.cluster, tasks=[self.arn]) def _check_success_task(self): - response = self.client.describe_tasks( - cluster=self.cluster, - tasks=[self.arn] - ) + response = self.client.describe_tasks(cluster=self.cluster, tasks=[self.arn]) self.log.info('ECS Task stopped, check status: %s', response) # Get logs from CloudWatch if the awslogs log driver was used @@ -218,44 +226,39 @@ def _check_success_task(self): # successfully finished, but there is no other indication of failure # in the response. # https://docs.aws.amazon.com/AmazonECS/latest/developerguide/stopped-task-errors.html - if re.match(r'Host EC2 \(instance .+?\) (stopped|terminated)\.', - task.get('stoppedReason', '')): + if re.match(r'Host EC2 \(instance .+?\) (stopped|terminated)\.', task.get('stoppedReason', '')): raise AirflowException( - 'The task was stopped because the host instance terminated: {}'. 
- format(task.get('stoppedReason', ''))) + 'The task was stopped because the host instance terminated: {}'.format( + task.get('stoppedReason', '') + ) + ) containers = task['containers'] for container in containers: - if container.get('lastStatus') == 'STOPPED' and \ - container['exitCode'] != 0: - raise AirflowException( - 'This task is not in success state {}'.format(task)) + if container.get('lastStatus') == 'STOPPED' and container['exitCode'] != 0: + raise AirflowException('This task is not in success state {}'.format(task)) elif container.get('lastStatus') == 'PENDING': raise AirflowException('This task is still pending {}'.format(task)) elif 'error' in container.get('reason', '').lower(): raise AirflowException( - 'This containers encounter an error during launching : {}'. - format(container.get('reason', '').lower())) + 'This containers encounter an error during launching : {}'.format( + container.get('reason', '').lower() + ) + ) def get_hook(self): """Create and return an AwsHook.""" if not self.hook: self.hook = AwsBaseHook( - aws_conn_id=self.aws_conn_id, - client_type='ecs', - region_name=self.region_name + aws_conn_id=self.aws_conn_id, client_type='ecs', region_name=self.region_name ) return self.hook def get_logs_hook(self): """Create and return an AwsLogsHook.""" - return AwsLogsHook( - aws_conn_id=self.aws_conn_id, - region_name=self.awslogs_region - ) + return AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.awslogs_region) def on_kill(self): response = self.client.stop_task( - cluster=self.cluster, - task=self.arn, - reason='Task killed by the user') + cluster=self.cluster, task=self.arn, reason='Task killed by the user' + ) self.log.info(response) diff --git a/airflow/providers/amazon/aws/operators/emr_add_steps.py b/airflow/providers/amazon/aws/operators/emr_add_steps.py index 3c1078e6ac438..d046d2e60fb5a 100644 --- a/airflow/providers/amazon/aws/operators/emr_add_steps.py +++ b/airflow/providers/amazon/aws/operators/emr_add_steps.py @@ -44,19 +44,22 @@ class EmrAddStepsOperator(BaseOperator): :param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id. 
:type do_xcom_push: bool """ + template_fields = ['job_flow_id', 'job_flow_name', 'cluster_states', 'steps'] template_ext = ('.json',) ui_color = '#f9c915' @apply_defaults def __init__( - self, *, - job_flow_id=None, - job_flow_name=None, - cluster_states=None, - aws_conn_id='aws_default', - steps=None, - **kwargs): + self, + *, + job_flow_id=None, + job_flow_name=None, + cluster_states=None, + aws_conn_id='aws_default', + steps=None, + **kwargs, + ): if kwargs.get('xcom_push') is not None: raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") if not (job_flow_id is None) ^ (job_flow_name is None): @@ -74,8 +77,9 @@ def execute(self, context): emr = emr_hook.get_conn() - job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(self.job_flow_name, - self.cluster_states) + job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name( + self.job_flow_name, self.cluster_states + ) if not job_flow_id: raise AirflowException(f'No cluster found for name: {self.job_flow_name}') diff --git a/airflow/providers/amazon/aws/operators/emr_create_job_flow.py b/airflow/providers/amazon/aws/operators/emr_create_job_flow.py index b02abf901336b..71e5e09209a40 100644 --- a/airflow/providers/amazon/aws/operators/emr_create_job_flow.py +++ b/airflow/providers/amazon/aws/operators/emr_create_job_flow.py @@ -37,18 +37,21 @@ class EmrCreateJobFlowOperator(BaseOperator): (must be '.json') to override emr_connection extra. (templated) :type job_flow_overrides: dict|str """ + template_fields = ['job_flow_overrides'] template_ext = ('.json',) ui_color = '#f9c915' @apply_defaults def __init__( - self, *, - aws_conn_id='aws_default', - emr_conn_id='emr_default', - job_flow_overrides=None, - region_name=None, - **kwargs): + self, + *, + aws_conn_id='aws_default', + emr_conn_id='emr_default', + job_flow_overrides=None, + region_name=None, + **kwargs, + ): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.emr_conn_id = emr_conn_id @@ -58,13 +61,12 @@ def __init__( self.region_name = region_name def execute(self, context): - emr = EmrHook(aws_conn_id=self.aws_conn_id, - emr_conn_id=self.emr_conn_id, - region_name=self.region_name) + emr = EmrHook( + aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id, region_name=self.region_name + ) self.log.info( - 'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s', - self.aws_conn_id, self.emr_conn_id + 'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s', self.aws_conn_id, self.emr_conn_id ) if isinstance(self.job_flow_overrides, str): diff --git a/airflow/providers/amazon/aws/operators/emr_modify_cluster.py b/airflow/providers/amazon/aws/operators/emr_modify_cluster.py index 87c2296041c15..48692e34a9ddc 100644 --- a/airflow/providers/amazon/aws/operators/emr_modify_cluster.py +++ b/airflow/providers/amazon/aws/operators/emr_modify_cluster.py @@ -33,17 +33,15 @@ class EmrModifyClusterOperator(BaseOperator): :param do_xcom_push: if True, cluster_id is pushed to XCom with key cluster_id. 
:type do_xcom_push: bool """ + template_fields = ['cluster_id', 'step_concurrency_level'] template_ext = () ui_color = '#f9c915' @apply_defaults def __init__( - self, *, - cluster_id: str, - step_concurrency_level: int, - aws_conn_id: str = 'aws_default', - **kwargs): + self, *, cluster_id: str, step_concurrency_level: int, aws_conn_id: str = 'aws_default', **kwargs + ): if kwargs.get('xcom_push') is not None: raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") super().__init__(**kwargs) @@ -60,8 +58,9 @@ def execute(self, context): context['ti'].xcom_push(key='cluster_id', value=self.cluster_id) self.log.info('Modifying cluster %s', self.cluster_id) - response = emr.modify_cluster(ClusterId=self.cluster_id, - StepConcurrencyLevel=self.step_concurrency_level) + response = emr.modify_cluster( + ClusterId=self.cluster_id, StepConcurrencyLevel=self.step_concurrency_level + ) if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException('Modify cluster failed: %s' % response) diff --git a/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py b/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py index c22920e48e9dd..19cbddb4df4ab 100644 --- a/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py +++ b/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py @@ -30,16 +30,13 @@ class EmrTerminateJobFlowOperator(BaseOperator): :param aws_conn_id: aws connection to uses :type aws_conn_id: str """ + template_fields = ['job_flow_id'] template_ext = () ui_color = '#f9c915' @apply_defaults - def __init__( - self, *, - job_flow_id, - aws_conn_id='aws_default', - **kwargs): + def __init__(self, *, job_flow_id, aws_conn_id='aws_default', **kwargs): super().__init__(**kwargs) self.job_flow_id = job_flow_id self.aws_conn_id = aws_conn_id diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index a945f4ecf4710..991135f5e91ef 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -52,25 +52,28 @@ class AwsGlueJobOperator(BaseOperator): :param iam_role_name: AWS IAM Role for Glue Job Execution :type iam_role_name: Optional[str] """ + template_fields = () template_ext = () ui_color = '#ededed' @apply_defaults - def __init__(self, *, - job_name='aws_glue_default_job', - job_desc='AWS Glue Job with Airflow', - script_location=None, - concurrent_run_limit=None, - script_args=None, - retry_limit=None, - num_of_dpus=6, - aws_conn_id='aws_default', - region_name=None, - s3_bucket=None, - iam_role_name=None, - **kwargs - ): # pylint: disable=too-many-arguments + def __init__( + self, + *, + job_name='aws_glue_default_job', + job_desc='AWS Glue Job with Airflow', + script_location=None, + concurrent_run_limit=None, + script_args=None, + retry_limit=None, + num_of_dpus=6, + aws_conn_id='aws_default', + region_name=None, + s3_bucket=None, + iam_role_name=None, + **kwargs, + ): # pylint: disable=too-many-arguments super(AwsGlueJobOperator, self).__init__(**kwargs) self.job_name = job_name self.job_desc = job_desc @@ -96,20 +99,25 @@ def execute(self, context): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) script_name = os.path.basename(self.script_location) s3_hook.load_file(self.script_location, self.s3_bucket, self.s3_artifcats_prefix + script_name) - glue_job = AwsGlueJobHook(job_name=self.job_name, - desc=self.job_desc, - concurrent_run_limit=self.concurrent_run_limit, - script_location=self.script_location, - 
retry_limit=self.retry_limit, - num_of_dpus=self.num_of_dpus, - aws_conn_id=self.aws_conn_id, - region_name=self.region_name, - s3_bucket=self.s3_bucket, - iam_role_name=self.iam_role_name) + glue_job = AwsGlueJobHook( + job_name=self.job_name, + desc=self.job_desc, + concurrent_run_limit=self.concurrent_run_limit, + script_location=self.script_location, + retry_limit=self.retry_limit, + num_of_dpus=self.num_of_dpus, + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + s3_bucket=self.s3_bucket, + iam_role_name=self.iam_role_name, + ) self.log.info("Initializing AWS Glue Job: %s", self.job_name) glue_job_run = glue_job.initialize_job(self.script_args) glue_job_run = glue_job.job_completion(self.job_name, glue_job_run['JobRunId']) self.log.info( "AWS Glue Job: %s status: %s. Run Id: %s", - self.job_name, glue_job_run['JobRunState'], glue_job_run['JobRunId']) + self.job_name, + glue_job_run['JobRunState'], + glue_job_run['JobRunId'], + ) return glue_job_run['JobRunId'] diff --git a/airflow/providers/amazon/aws/operators/s3_bucket.py b/airflow/providers/amazon/aws/operators/s3_bucket.py index f7d9822cf17a6..a2aa06bd3e573 100644 --- a/airflow/providers/amazon/aws/operators/s3_bucket.py +++ b/airflow/providers/amazon/aws/operators/s3_bucket.py @@ -40,12 +40,16 @@ class S3CreateBucketOperator(BaseOperator): :param region_name: AWS region_name. If not specified fetched from connection. :type region_name: Optional[str] """ + @apply_defaults - def __init__(self, *, - bucket_name, - aws_conn_id: Optional[str] = "aws_default", - region_name: Optional[str] = None, - **kwargs) -> None: + def __init__( + self, + *, + bucket_name, + aws_conn_id: Optional[str] = "aws_default", + region_name: Optional[str] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.bucket_name = bucket_name self.region_name = region_name @@ -76,11 +80,14 @@ class S3DeleteBucketOperator(BaseOperator): maintained on each worker node). 
:type aws_conn_id: Optional[str] """ - def __init__(self, - bucket_name, - force_delete: Optional[bool] = False, - aws_conn_id: Optional[str] = "aws_default", - **kwargs) -> None: + + def __init__( + self, + bucket_name, + force_delete: Optional[bool] = False, + aws_conn_id: Optional[str] = "aws_default", + **kwargs, + ) -> None: super().__init__(**kwargs) self.bucket_name = bucket_name self.force_delete = force_delete diff --git a/airflow/providers/amazon/aws/operators/s3_copy_object.py b/airflow/providers/amazon/aws/operators/s3_copy_object.py index 8d1dd9ca8c92a..4b2d2901d2c62 100644 --- a/airflow/providers/amazon/aws/operators/s3_copy_object.py +++ b/airflow/providers/amazon/aws/operators/s3_copy_object.py @@ -64,20 +64,21 @@ class S3CopyObjectOperator(BaseOperator): :type verify: bool or str """ - template_fields = ('source_bucket_key', 'dest_bucket_key', - 'source_bucket_name', 'dest_bucket_name') + template_fields = ('source_bucket_key', 'dest_bucket_key', 'source_bucket_name', 'dest_bucket_name') @apply_defaults def __init__( - self, *, - source_bucket_key, - dest_bucket_key, - source_bucket_name=None, - dest_bucket_name=None, - source_version_id=None, - aws_conn_id='aws_default', - verify=None, - **kwargs): + self, + *, + source_bucket_key, + dest_bucket_key, + source_bucket_name=None, + dest_bucket_name=None, + source_version_id=None, + aws_conn_id='aws_default', + verify=None, + **kwargs, + ): super().__init__(**kwargs) self.source_bucket_key = source_bucket_key @@ -90,6 +91,10 @@ def __init__( def execute(self, context): s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - s3_hook.copy_object(self.source_bucket_key, self.dest_bucket_key, - self.source_bucket_name, self.dest_bucket_name, - self.source_version_id) + s3_hook.copy_object( + self.source_bucket_key, + self.dest_bucket_key, + self.source_bucket_name, + self.dest_bucket_name, + self.source_version_id, + ) diff --git a/airflow/providers/amazon/aws/operators/s3_delete_objects.py b/airflow/providers/amazon/aws/operators/s3_delete_objects.py index d8c4683873acf..b6d267ba767ca 100644 --- a/airflow/providers/amazon/aws/operators/s3_delete_objects.py +++ b/airflow/providers/amazon/aws/operators/s3_delete_objects.py @@ -62,14 +62,7 @@ class S3DeleteObjectsOperator(BaseOperator): template_fields = ('keys', 'bucket', 'prefix') @apply_defaults - def __init__( - self, *, - bucket, - keys=None, - prefix=None, - aws_conn_id='aws_default', - verify=None, - **kwargs): + def __init__(self, *, bucket, keys=None, prefix=None, aws_conn_id='aws_default', verify=None, **kwargs): if not bool(keys) ^ bool(prefix): raise ValueError("Either keys or prefix should be set.") diff --git a/airflow/providers/amazon/aws/operators/s3_file_transform.py b/airflow/providers/amazon/aws/operators/s3_file_transform.py index 4324d204d31c2..e2aa822f282a9 100644 --- a/airflow/providers/amazon/aws/operators/s3_file_transform.py +++ b/airflow/providers/amazon/aws/operators/s3_file_transform.py @@ -84,18 +84,20 @@ class S3FileTransformOperator(BaseOperator): @apply_defaults def __init__( - self, *, - source_s3_key: str, - dest_s3_key: str, - transform_script: Optional[str] = None, - select_expression=None, - script_args: Optional[Sequence[str]] = None, - source_aws_conn_id: str = 'aws_default', - source_verify: Optional[Union[bool, str]] = None, - dest_aws_conn_id: str = 'aws_default', - dest_verify: Optional[Union[bool, str]] = None, - replace: bool = False, - **kwargs) -> None: + self, + *, + source_s3_key: str, + dest_s3_key: str, + 
transform_script: Optional[str] = None, + select_expression=None, + script_args: Optional[Sequence[str]] = None, + source_aws_conn_id: str = 'aws_default', + source_verify: Optional[Union[bool, str]] = None, + dest_aws_conn_id: str = 'aws_default', + dest_verify: Optional[Union[bool, str]] = None, + replace: bool = False, + **kwargs, + ) -> None: # pylint: disable=too-many-arguments super().__init__(**kwargs) self.source_s3_key = source_s3_key @@ -112,29 +114,21 @@ def __init__( def execute(self, context): if self.transform_script is None and self.select_expression is None: - raise AirflowException( - "Either transform_script or select_expression must be specified") + raise AirflowException("Either transform_script or select_expression must be specified") source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id, verify=self.source_verify) dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify) self.log.info("Downloading source S3 file %s", self.source_s3_key) if not source_s3.check_for_key(self.source_s3_key): - raise AirflowException( - "The source key {0} does not exist".format(self.source_s3_key)) + raise AirflowException("The source key {0} does not exist".format(self.source_s3_key)) source_s3_key_object = source_s3.get_key(self.source_s3_key) with NamedTemporaryFile("wb") as f_source, NamedTemporaryFile("wb") as f_dest: - self.log.info( - "Dumping S3 file %s contents to local file %s", - self.source_s3_key, f_source.name - ) + self.log.info("Dumping S3 file %s contents to local file %s", self.source_s3_key, f_source.name) if self.select_expression is not None: - content = source_s3.select_key( - key=self.source_s3_key, - expression=self.select_expression - ) + content = source_s3.select_key(key=self.source_s3_key, expression=self.select_expression) f_source.write(content.encode("utf-8")) else: source_s3_key_object.download_fileobj(Fileobj=f_source) @@ -145,7 +139,7 @@ def execute(self, context): [self.transform_script, f_source.name, f_dest.name, *self.script_args], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - close_fds=True + close_fds=True, ) self.log.info("Output:") @@ -155,13 +149,10 @@ def execute(self, context): process.wait() if process.returncode: - raise AirflowException( - "Transform script failed: {0}".format(process.returncode) - ) + raise AirflowException("Transform script failed: {0}".format(process.returncode)) else: self.log.info( - "Transform script successful. Output temporarily located at %s", - f_dest.name + "Transform script successful. 
Output temporarily located at %s", f_dest.name ) self.log.info("Uploading transformed file to S3") @@ -169,6 +160,6 @@ def execute(self, context): dest_s3.load_file( filename=f_dest.name if self.transform_script else f_source.name, key=self.dest_s3_key, - replace=self.replace + replace=self.replace, ) self.log.info("Upload successful") diff --git a/airflow/providers/amazon/aws/operators/s3_list.py b/airflow/providers/amazon/aws/operators/s3_list.py index 427ff3fcaa929..4c25e99cfe3ce 100644 --- a/airflow/providers/amazon/aws/operators/s3_list.py +++ b/airflow/providers/amazon/aws/operators/s3_list.py @@ -65,17 +65,12 @@ class S3ListOperator(BaseOperator): aws_conn_id='aws_customers_conn' ) """ + template_fields: Iterable[str] = ('bucket', 'prefix', 'delimiter') ui_color = '#ffd700' @apply_defaults - def __init__(self, *, - bucket, - prefix='', - delimiter='', - aws_conn_id='aws_default', - verify=None, - **kwargs): + def __init__(self, *, bucket, prefix='', delimiter='', aws_conn_id='aws_default', verify=None, **kwargs): super().__init__(**kwargs) self.bucket = bucket self.prefix = prefix @@ -88,10 +83,9 @@ def execute(self, context): self.log.info( 'Getting the list of files from bucket: %s in prefix: %s (Delimiter {%s)', - self.bucket, self.prefix, self.delimiter + self.bucket, + self.prefix, + self.delimiter, ) - return hook.list_keys( - bucket_name=self.bucket, - prefix=self.prefix, - delimiter=self.delimiter) + return hook.list_keys(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_base.py b/airflow/providers/amazon/aws/operators/sagemaker_base.py index e5c42ac947782..19fb92187ce76 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_base.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_base.py @@ -41,10 +41,7 @@ class SageMakerBaseOperator(BaseOperator): integer_fields = [] # type: Iterable[Iterable[str]] @apply_defaults - def __init__(self, *, - config, - aws_conn_id='aws_default', - **kwargs): + def __init__(self, *, config, aws_conn_id='aws_default', **kwargs): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id @@ -81,14 +78,12 @@ def parse_config_integers(self): for field in self.integer_fields: self.parse_integer(self.config, field) - def expand_role(self): # noqa: D402 + def expand_role(self): # noqa: D402 """Placeholder for calling boto3's expand_role(), which expands an IAM role name into an ARN.""" def preprocess_config(self): """Process the config into a usable form.""" - self.log.info( - 'Preprocessing the config and doing required s3_operations' - ) + self.log.info('Preprocessing the config and doing required s3_operations') self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id) self.hook.configure_s3_resources(self.config) diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py index aa444fa5d8e54..c7a89f2d2d4a2 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py @@ -71,15 +71,17 @@ class SageMakerEndpointOperator(SageMakerBaseOperator): """ @apply_defaults - def __init__(self, *, - config, - wait_for_completion=True, - check_interval=30, - max_ingestion_time=None, - operation='create', - **kwargs): - super().__init__(config=config, - **kwargs) + def __init__( + self, + *, + config, + wait_for_completion=True, + check_interval=30, + max_ingestion_time=None, + operation='create', + **kwargs, + ): + 
super().__init__(config=config, **kwargs) self.config = config self.wait_for_completion = wait_for_completion @@ -93,9 +95,7 @@ def __init__(self, *, def create_integer_fields(self): """Set fields which should be casted to integers.""" if 'EndpointConfig' in self.config: - self.integer_fields = [ - ['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount'] - ] + self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']] def expand_role(self): if 'Model' not in self.config: @@ -135,7 +135,7 @@ def execute(self, context): endpoint_info, wait_for_completion=self.wait_for_completion, check_interval=self.check_interval, - max_ingestion_time=self.max_ingestion_time + max_ingestion_time=self.max_ingestion_time, ) except ClientError: # Botocore throws a ClientError if the endpoint is already created self.operation = 'update' @@ -145,18 +145,13 @@ def execute(self, context): endpoint_info, wait_for_completion=self.wait_for_completion, check_interval=self.check_interval, - max_ingestion_time=self.max_ingestion_time + max_ingestion_time=self.max_ingestion_time, ) if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException( - 'Sagemaker endpoint creation failed: %s' % response) + raise AirflowException('Sagemaker endpoint creation failed: %s' % response) else: return { - 'EndpointConfig': self.hook.describe_endpoint_config( - endpoint_info['EndpointConfigName'] - ), - 'Endpoint': self.hook.describe_endpoint( - endpoint_info['EndpointName'] - ) + 'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']), + 'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']), } diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py index f1d38bf185325..9bde4514ea320 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py @@ -35,16 +35,11 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator): :type aws_conn_id: str """ - integer_fields = [ - ['ProductionVariants', 'InitialInstanceCount'] - ] + integer_fields = [['ProductionVariants', 'InitialInstanceCount']] @apply_defaults - def __init__(self, *, - config, - **kwargs): - super().__init__(config=config, - **kwargs) + def __init__(self, *, config, **kwargs): + super().__init__(config=config, **kwargs) self.config = config @@ -54,11 +49,6 @@ def execute(self, context): self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName']) response = self.hook.create_endpoint_config(self.config) if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException( - 'Sagemaker endpoint config creation failed: %s' % response) + raise AirflowException('Sagemaker endpoint config creation failed: %s' % response) else: - return { - 'EndpointConfig': self.hook.describe_endpoint_config( - self.config['EndpointConfigName'] - ) - } + return {'EndpointConfig': self.hook.describe_endpoint_config(self.config['EndpointConfigName'])} diff --git a/airflow/providers/amazon/aws/operators/sagemaker_model.py b/airflow/providers/amazon/aws/operators/sagemaker_model.py index 31e2fbd6ae496..122ceeee90854 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_model.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_model.py @@ -37,11 +37,8 @@ class SageMakerModelOperator(SageMakerBaseOperator): """ @apply_defaults - def __init__(self, *, - 
config, - **kwargs): - super().__init__(config=config, - **kwargs) + def __init__(self, *, config, **kwargs): + super().__init__(config=config, **kwargs) self.config = config @@ -58,8 +55,4 @@ def execute(self, context): if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException('Sagemaker model creation failed: %s' % response) else: - return { - 'Model': self.hook.describe_model( - self.config['ModelName'] - ) - } + return {'Model': self.hook.describe_model(self.config['ModelName'])} diff --git a/airflow/providers/amazon/aws/operators/sagemaker_processing.py b/airflow/providers/amazon/aws/operators/sagemaker_processing.py index ef2fd6989a8a8..c1bcac7af4776 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_processing.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_processing.py @@ -52,15 +52,18 @@ class SageMakerProcessingOperator(SageMakerBaseOperator): """ @apply_defaults - def __init__(self, *, - config, - aws_conn_id, - wait_for_completion=True, - print_log=True, - check_interval=30, - max_ingestion_time=None, - action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8 - **kwargs): + def __init__( + self, + *, + config, + aws_conn_id, + wait_for_completion=True, + print_log=True, + check_interval=30, + max_ingestion_time=None, + action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8 + **kwargs, + ): super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) if action_if_job_exists not in ("increment", "fail"): @@ -79,12 +82,10 @@ def _create_integer_fields(self): """Set fields which should be casted to integers.""" self.integer_fields = [ ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], - ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'] + ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'], ] if 'StoppingCondition' in self.config: - self.integer_fields += [ - ['StoppingCondition', 'MaxRuntimeInSeconds'] - ] + self.integer_fields += [['StoppingCondition', 'MaxRuntimeInSeconds']] def expand_role(self): if 'RoleArn' in self.config: @@ -114,12 +115,8 @@ def execute(self, context): self.config, wait_for_completion=self.wait_for_completion, check_interval=self.check_interval, - max_ingestion_time=self.max_ingestion_time + max_ingestion_time=self.max_ingestion_time, ) if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException('Sagemaker Processing Job creation failed: %s' % response) - return { - 'Processing': self.hook.describe_processing_job( - self.config['ProcessingJobName'] - ) - } + return {'Processing': self.hook.describe_processing_job(self.config['ProcessingJobName'])} diff --git a/airflow/providers/amazon/aws/operators/sagemaker_training.py b/airflow/providers/amazon/aws/operators/sagemaker_training.py index 9bdbe56e38efc..6175a615f441a 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_training.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_training.py @@ -54,18 +54,21 @@ class SageMakerTrainingOperator(SageMakerBaseOperator): integer_fields = [ ['ResourceConfig', 'InstanceCount'], ['ResourceConfig', 'VolumeSizeInGB'], - ['StoppingCondition', 'MaxRuntimeInSeconds'] + ['StoppingCondition', 'MaxRuntimeInSeconds'], ] @apply_defaults - def __init__(self, *, - config, - wait_for_completion=True, - print_log=True, - check_interval=30, - max_ingestion_time=None, - action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8 - **kwargs): + def __init__( + self, + *, + 
config, + wait_for_completion=True, + print_log=True, + check_interval=30, + max_ingestion_time=None, + action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8 + **kwargs, + ): super().__init__(config=config, **kwargs) self.wait_for_completion = wait_for_completion @@ -110,13 +113,9 @@ def execute(self, context): wait_for_completion=self.wait_for_completion, print_log=self.print_log, check_interval=self.check_interval, - max_ingestion_time=self.max_ingestion_time + max_ingestion_time=self.max_ingestion_time, ) if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException('Sagemaker Training Job creation failed: %s' % response) else: - return { - 'Training': self.hook.describe_training_job( - self.config['TrainingJobName'] - ) - } + return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])} diff --git a/airflow/providers/amazon/aws/operators/sagemaker_transform.py b/airflow/providers/amazon/aws/operators/sagemaker_transform.py index 221bf82aee0ab..7ae8f3ae1787b 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_transform.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_transform.py @@ -62,14 +62,10 @@ class SageMakerTransformOperator(SageMakerBaseOperator): """ @apply_defaults - def __init__(self, *, - config, - wait_for_completion=True, - check_interval=30, - max_ingestion_time=None, - **kwargs): - super().__init__(config=config, - **kwargs) + def __init__( + self, *, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None, **kwargs + ): + super().__init__(config=config, **kwargs) self.config = config self.wait_for_completion = wait_for_completion self.check_interval = check_interval @@ -81,7 +77,7 @@ def create_integer_fields(self): self.integer_fields = [ ['Transform', 'TransformResources', 'InstanceCount'], ['Transform', 'MaxConcurrentTransforms'], - ['Transform', 'MaxPayloadInMB'] + ['Transform', 'MaxPayloadInMB'], ] if 'Transform' not in self.config: for field in self.integer_fields: @@ -110,15 +106,12 @@ def execute(self, context): transform_config, wait_for_completion=self.wait_for_completion, check_interval=self.check_interval, - max_ingestion_time=self.max_ingestion_time) + max_ingestion_time=self.max_ingestion_time, + ) if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException('Sagemaker transform Job creation failed: %s' % response) else: return { - 'Model': self.hook.describe_model( - transform_config['ModelName'] - ), - 'Transform': self.hook.describe_transform_job( - transform_config['TransformJobName'] - ) + 'Model': self.hook.describe_model(transform_config['ModelName']), + 'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']), } diff --git a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py index 16268865f79f8..483e5416ef3de 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py @@ -51,18 +51,14 @@ class SageMakerTuningOperator(SageMakerBaseOperator): ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'], ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'], ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'], - ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'] + ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'], ] @apply_defaults - def __init__(self, *, - config, - 
wait_for_completion=True, - check_interval=30, - max_ingestion_time=None, - **kwargs): - super().__init__(config=config, - **kwargs) + def __init__( + self, *, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None, **kwargs + ): + super().__init__(config=config, **kwargs) self.config = config self.wait_for_completion = wait_for_completion self.check_interval = check_interval @@ -86,13 +82,9 @@ def execute(self, context): self.config, wait_for_completion=self.wait_for_completion, check_interval=self.check_interval, - max_ingestion_time=self.max_ingestion_time + max_ingestion_time=self.max_ingestion_time, ) if response['ResponseMetadata']['HTTPStatusCode'] != 200: raise AirflowException('Sagemaker Tuning Job creation failed: %s' % response) else: - return { - 'Tuning': self.hook.describe_tuning_job( - self.config['HyperParameterTuningJobName'] - ) - } + return {'Tuning': self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])} diff --git a/airflow/providers/amazon/aws/operators/sns.py b/airflow/providers/amazon/aws/operators/sns.py index 3f24813766600..8917dfe0d38b1 100644 --- a/airflow/providers/amazon/aws/operators/sns.py +++ b/airflow/providers/amazon/aws/operators/sns.py @@ -39,18 +39,21 @@ class SnsPublishOperator(BaseOperator): determined automatically) :type message_attributes: dict """ + template_fields = ['message', 'subject', 'message_attributes'] template_ext = () @apply_defaults def __init__( - self, *, - target_arn, - message, - aws_conn_id='aws_default', - subject=None, - message_attributes=None, - **kwargs): + self, + *, + target_arn, + message, + aws_conn_id='aws_default', + subject=None, + message_attributes=None, + **kwargs, + ): super().__init__(**kwargs) self.target_arn = target_arn self.message = message diff --git a/airflow/providers/amazon/aws/operators/sqs.py b/airflow/providers/amazon/aws/operators/sqs.py index e0edc3fb857ab..00b29dbee0ef3 100644 --- a/airflow/providers/amazon/aws/operators/sqs.py +++ b/airflow/providers/amazon/aws/operators/sqs.py @@ -38,17 +38,21 @@ class SQSPublishOperator(BaseOperator): :param aws_conn_id: AWS connection id (default: aws_default) :type aws_conn_id: str """ + template_fields = ('sqs_queue', 'message_content', 'delay_seconds') ui_color = '#6ad3fa' @apply_defaults - def __init__(self, *, - sqs_queue, - message_content, - message_attributes=None, - delay_seconds=0, - aws_conn_id='aws_default', - **kwargs): + def __init__( + self, + *, + sqs_queue, + message_content, + message_attributes=None, + delay_seconds=0, + aws_conn_id='aws_default', + **kwargs, + ): super().__init__(**kwargs) self.sqs_queue = sqs_queue self.aws_conn_id = aws_conn_id @@ -69,10 +73,12 @@ def execute(self, context): hook = SQSHook(aws_conn_id=self.aws_conn_id) - result = hook.send_message(queue_url=self.sqs_queue, - message_body=self.message_content, - delay_seconds=self.delay_seconds, - message_attributes=self.message_attributes) + result = hook.send_message( + queue_url=self.sqs_queue, + message_body=self.message_content, + delay_seconds=self.delay_seconds, + message_attributes=self.message_attributes, + ) self.log.info('result is send_message is %s', result) diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py index 404ce2416f3de..2eaa2c49c7327 100644 --- a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py +++ 
b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py @@ -36,6 +36,7 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator): :param aws_conn_id: aws connection to use, defaults to 'aws_default' :type aws_conn_id: str """ + template_fields = ['execution_arn'] template_ext = () ui_color = '#f9c915' diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py index 0b22c88afef90..0d8f446cd944e 100644 --- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py +++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py @@ -43,15 +43,22 @@ class StepFunctionStartExecutionOperator(BaseOperator): :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn. :type do_xcom_push: bool """ + template_fields = ['state_machine_arn', 'name', 'input'] template_ext = () ui_color = '#f9c915' @apply_defaults - def __init__(self, *, state_machine_arn: str, name: Optional[str] = None, - state_machine_input: Union[dict, str, None] = None, - aws_conn_id='aws_default', region_name=None, - **kwargs): + def __init__( + self, + *, + state_machine_arn: str, + name: Optional[str] = None, + state_machine_input: Union[dict, str, None] = None, + aws_conn_id='aws_default', + region_name=None, + **kwargs, + ): super().__init__(**kwargs) self.state_machine_arn = state_machine_arn self.name = name diff --git a/airflow/providers/amazon/aws/secrets/secrets_manager.py b/airflow/providers/amazon/aws/secrets/secrets_manager.py index 39dd8a70e3c2e..47a07a985127d 100644 --- a/airflow/providers/amazon/aws/secrets/secrets_manager.py +++ b/airflow/providers/amazon/aws/secrets/secrets_manager.py @@ -70,7 +70,7 @@ def __init__( config_prefix: str = 'airflow/config', profile_name: Optional[str] = None, sep: str = "/", - **kwargs + **kwargs, ): super().__init__() self.connections_prefix = connections_prefix.rstrip("/") @@ -85,9 +85,7 @@ def client(self): """ Create a Secrets Manager client """ - session = boto3.session.Session( - profile_name=self.profile_name, - ) + session = boto3.session.Session(profile_name=self.profile_name,) return session.client(service_name="secretsmanager", **self.kwargs) def get_conn_uri(self, conn_id: str) -> Optional[str]: @@ -128,14 +126,13 @@ def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: """ secrets_path = self.build_path(path_prefix, secret_id, self.sep) try: - response = self.client.get_secret_value( - SecretId=secrets_path, - ) + response = self.client.get_secret_value(SecretId=secrets_path,) return response.get('SecretString') except self.client.exceptions.ResourceNotFoundException: self.log.debug( "An error occurred (ResourceNotFoundException) when calling the " "get_secret_value operation: " - "Secret %s not found.", secrets_path + "Secret %s not found.", + secrets_path, ) return None diff --git a/airflow/providers/amazon/aws/secrets/systems_manager.py b/airflow/providers/amazon/aws/secrets/systems_manager.py index 203be353dd9cc..5e67362d98f9d 100644 --- a/airflow/providers/amazon/aws/secrets/systems_manager.py +++ b/airflow/providers/amazon/aws/secrets/systems_manager.py @@ -57,7 +57,7 @@ def __init__( connections_prefix: str = '/airflow/connections', variables_prefix: str = '/airflow/variables', profile_name: Optional[str] = None, - **kwargs + **kwargs, ): super().__init__() self.connections_prefix = connections_prefix.rstrip("/") @@ -102,14 +102,13 @@ def _get_secret(self, 
path_prefix: str, secret_id: str) -> Optional[str]: """ ssm_path = self.build_path(path_prefix, secret_id) try: - response = self.client.get_parameter( - Name=ssm_path, WithDecryption=True - ) + response = self.client.get_parameter(Name=ssm_path, WithDecryption=True) value = response["Parameter"]["Value"] return value except self.client.exceptions.ParameterNotFound: self.log.info( "An error occurred (ParameterNotFound) when calling the GetParameter operation: " - "Parameter %s not found.", ssm_path + "Parameter %s not found.", + ssm_path, ) return None diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py index 50edc8dd4e702..40c028a41c4c4 100644 --- a/airflow/providers/amazon/aws/sensors/athena.py +++ b/airflow/providers/amazon/aws/sensors/athena.py @@ -42,8 +42,14 @@ class AthenaSensor(BaseSensorOperator): :type sleep_time: int """ - INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',) - FAILURE_STATES = ('FAILED', 'CANCELLED',) + INTERMEDIATE_STATES = ( + 'QUEUED', + 'RUNNING', + ) + FAILURE_STATES = ( + 'FAILED', + 'CANCELLED', + ) SUCCESS_STATES = ('SUCCEEDED',) template_fields = ['query_execution_id'] @@ -51,12 +57,15 @@ class AthenaSensor(BaseSensorOperator): ui_color = '#66c3ff' @apply_defaults - def __init__(self, *, - query_execution_id: str, - max_retries: Optional[int] = None, - aws_conn_id: str = 'aws_default', - sleep_time: int = 10, - **kwargs: Any) -> None: + def __init__( + self, + *, + query_execution_id: str, + max_retries: Optional[int] = None, + aws_conn_id: str = 'aws_default', + sleep_time: int = 10, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.query_execution_id = query_execution_id diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py b/airflow/providers/amazon/aws/sensors/cloud_formation.py index 05f15a394e7d1..739a13331c570 100644 --- a/airflow/providers/amazon/aws/sensors/cloud_formation.py +++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py @@ -40,11 +40,7 @@ class CloudFormationCreateStackSensor(BaseSensorOperator): ui_color = '#C5CAE9' @apply_defaults - def __init__(self, *, - stack_name, - aws_conn_id='aws_default', - region_name=None, - **kwargs): + def __init__(self, *, stack_name, aws_conn_id='aws_default', region_name=None, **kwargs): super().__init__(**kwargs) self.stack_name = stack_name self.hook = AWSCloudFormationHook(aws_conn_id=aws_conn_id, region_name=region_name) @@ -75,11 +71,7 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator): ui_color = '#C5CAE9' @apply_defaults - def __init__(self, *, - stack_name, - aws_conn_id='aws_default', - region_name=None, - **kwargs): + def __init__(self, *, stack_name, aws_conn_id='aws_default', region_name=None, **kwargs): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.region_name = region_name @@ -97,7 +89,5 @@ def poke(self, context): def get_hook(self): """Create and return an AWSCloudFormationHook""" if not self.hook: - self.hook = AWSCloudFormationHook( - aws_conn_id=self.aws_conn_id, - region_name=self.region_name) + self.hook = AWSCloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) return self.hook diff --git a/airflow/providers/amazon/aws/sensors/ec2_instance_state.py b/airflow/providers/amazon/aws/sensors/ec2_instance_state.py index c2a53c891e4b8..7e55d7d9044b3 100644 --- a/airflow/providers/amazon/aws/sensors/ec2_instance_state.py +++ b/airflow/providers/amazon/aws/sensors/ec2_instance_state.py @@ -43,12 +43,15 @@ class 
EC2InstanceStateSensor(BaseSensorOperator): valid_states = ["running", "stopped", "terminated"] @apply_defaults - def __init__(self, *, - target_state: str, - instance_id: str, - aws_conn_id: str = "aws_default", - region_name: Optional[str] = None, - **kwargs): + def __init__( + self, + *, + target_state: str, + instance_id: str, + aws_conn_id: str = "aws_default", + region_name: Optional[str] = None, + **kwargs, + ): if target_state not in self.valid_states: raise ValueError(f"Invalid target_state: {target_state}") super().__init__(**kwargs) @@ -58,12 +61,7 @@ def __init__(self, *, self.region_name = region_name def poke(self, context): - ec2_hook = EC2Hook( - aws_conn_id=self.aws_conn_id, - region_name=self.region_name - ) - instance_state = ec2_hook.get_instance_state( - instance_id=self.instance_id - ) + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + instance_state = ec2_hook.get_instance_state(instance_id=self.instance_id) self.log.info("instance state: %s", instance_state) return instance_state == self.target_state diff --git a/airflow/providers/amazon/aws/sensors/emr_base.py b/airflow/providers/amazon/aws/sensors/emr_base.py index d487af2146e24..f05197bc45fe9 100644 --- a/airflow/providers/amazon/aws/sensors/emr_base.py +++ b/airflow/providers/amazon/aws/sensors/emr_base.py @@ -38,13 +38,11 @@ class EmrBaseSensor(BaseSensorOperator): :param aws_conn_id: aws connection to uses :type aws_conn_id: str """ + ui_color = '#66c3ff' @apply_defaults - def __init__( - self, *, - aws_conn_id='aws_default', - **kwargs): + def __init__(self, *, aws_conn_id='aws_default', **kwargs): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.target_states = None # will be set in subclasses @@ -86,8 +84,7 @@ def get_emr_response(self) -> Dict[str, Any]: :return: response :rtype: dict[str, Any] """ - raise NotImplementedError( - 'Please implement get_emr_response() in subclass') + raise NotImplementedError('Please implement get_emr_response() in subclass') @staticmethod def state_from_response(response: Dict[str, Any]) -> str: @@ -99,8 +96,7 @@ def state_from_response(response: Dict[str, Any]) -> str: :return: state :rtype: str """ - raise NotImplementedError( - 'Please implement state_from_response() in subclass') + raise NotImplementedError('Please implement state_from_response() in subclass') @staticmethod def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: @@ -112,5 +108,4 @@ def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: :return: failure message :rtype: Optional[str] """ - raise NotImplementedError( - 'Please implement failure_message_from_response() in subclass') + raise NotImplementedError('Please implement failure_message_from_response() in subclass') diff --git a/airflow/providers/amazon/aws/sensors/emr_job_flow.py b/airflow/providers/amazon/aws/sensors/emr_job_flow.py index 004b8b8218df1..c08e9db92675a 100644 --- a/airflow/providers/amazon/aws/sensors/emr_job_flow.py +++ b/airflow/providers/amazon/aws/sensors/emr_job_flow.py @@ -46,11 +46,14 @@ class EmrJobFlowSensor(EmrBaseSensor): template_ext = () @apply_defaults - def __init__(self, *, - job_flow_id: str, - target_states: Optional[Iterable[str]] = None, - failed_states: Optional[Iterable[str]] = None, - **kwargs): + def __init__( + self, + *, + job_flow_id: str, + target_states: Optional[Iterable[str]] = None, + failed_states: Optional[Iterable[str]] = None, + **kwargs, + ): super().__init__(**kwargs) self.job_flow_id = job_flow_id 
self.target_states = target_states or ['TERMINATED'] @@ -97,6 +100,6 @@ def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: state_change_reason = cluster_status.get('StateChangeReason') if state_change_reason: return 'for code: {} with message {}'.format( - state_change_reason.get('Code', 'No code'), - state_change_reason.get('Message', 'Unknown')) + state_change_reason.get('Code', 'No code'), state_change_reason.get('Message', 'Unknown') + ) return None diff --git a/airflow/providers/amazon/aws/sensors/emr_step.py b/airflow/providers/amazon/aws/sensors/emr_step.py index 65394c8927ea4..f3c3d593fb73c 100644 --- a/airflow/providers/amazon/aws/sensors/emr_step.py +++ b/airflow/providers/amazon/aws/sensors/emr_step.py @@ -41,23 +41,24 @@ class EmrStepSensor(EmrBaseSensor): :type failed_states: list[str] """ - template_fields = ['job_flow_id', 'step_id', - 'target_states', 'failed_states'] + template_fields = ['job_flow_id', 'step_id', 'target_states', 'failed_states'] template_ext = () @apply_defaults - def __init__(self, *, - job_flow_id: str, - step_id: str, - target_states: Optional[Iterable[str]] = None, - failed_states: Optional[Iterable[str]] = None, - **kwargs): + def __init__( + self, + *, + job_flow_id: str, + step_id: str, + target_states: Optional[Iterable[str]] = None, + failed_states: Optional[Iterable[str]] = None, + **kwargs, + ): super().__init__(**kwargs) self.job_flow_id = job_flow_id self.step_id = step_id self.target_states = target_states or ['COMPLETED'] - self.failed_states = failed_states or ['CANCELLED', 'FAILED', - 'INTERRUPTED'] + self.failed_states = failed_states or ['CANCELLED', 'FAILED', 'INTERRUPTED'] def get_emr_response(self) -> Dict[str, Any]: """ @@ -71,12 +72,8 @@ def get_emr_response(self) -> Dict[str, Any]: """ emr_client = self.get_hook().get_conn() - self.log.info('Poking step %s on cluster %s', - self.step_id, - self.job_flow_id) - return emr_client.describe_step( - ClusterId=self.job_flow_id, - StepId=self.step_id) + self.log.info('Poking step %s on cluster %s', self.step_id, self.job_flow_id) + return emr_client.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id) @staticmethod def state_from_response(response: Dict[str, Any]) -> str: @@ -103,7 +100,6 @@ def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: fail_details = response['Step']['Status'].get('FailureDetails') if fail_details: return 'for reason {} with message {} and log file {}'.format( - fail_details.get('Reason'), - fail_details.get('Message'), - fail_details.get('LogFile')) + fail_details.get('Reason'), fail_details.get('Message'), fail_details.get('LogFile') + ) return None diff --git a/airflow/providers/amazon/aws/sensors/glue.py b/airflow/providers/amazon/aws/sensors/glue.py index 9539761617def..7b2ce30206d7d 100644 --- a/airflow/providers/amazon/aws/sensors/glue.py +++ b/airflow/providers/amazon/aws/sensors/glue.py @@ -32,14 +32,11 @@ class AwsGlueJobSensor(BaseSensorOperator): :param run_id: The AWS Glue current running job identifier :type run_id: str """ + template_fields = ('job_name', 'run_id') @apply_defaults - def __init__(self, *, - job_name, - run_id, - aws_conn_id='aws_default', - **kwargs): + def __init__(self, *, job_name, run_id, aws_conn_id='aws_default', **kwargs): super().__init__(**kwargs) self.job_name = job_name self.run_id = run_id @@ -49,9 +46,7 @@ def __init__(self, *, def poke(self, context): hook = AwsGlueJobHook(aws_conn_id=self.aws_conn_id) - self.log.info( - "Poking for job run status :" - "for 
Glue Job %s and ID %s", self.job_name, self.run_id) + self.log.info("Poking for job run status :" "for Glue Job %s and ID %s", self.job_name, self.run_id) job_state = hook.get_job_state(job_name=self.job_name, run_id=self.run_id) if job_state in self.success_states: self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state) diff --git a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py index 5d900abf436ba..f1df94d3c0663 100644 --- a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +++ b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py @@ -47,19 +47,27 @@ class AwsGlueCatalogPartitionSensor(BaseSensorOperator): between each tries :type poke_interval: int """ - template_fields = ('database_name', 'table_name', 'expression',) + + template_fields = ( + 'database_name', + 'table_name', + 'expression', + ) ui_color = '#C5CAE9' @apply_defaults - def __init__(self, *, - table_name, expression="ds='{{ ds }}'", - aws_conn_id='aws_default', - region_name=None, - database_name='default', - poke_interval=60 * 3, - **kwargs): - super().__init__( - poke_interval=poke_interval, **kwargs) + def __init__( + self, + *, + table_name, + expression="ds='{{ ds }}'", + aws_conn_id='aws_default', + region_name=None, + database_name='default', + poke_interval=60 * 3, + **kwargs, + ): + super().__init__(poke_interval=poke_interval, **kwargs) self.aws_conn_id = aws_conn_id self.region_name = region_name self.table_name = table_name @@ -77,15 +85,12 @@ def poke(self, context): 'Poking for table %s. %s, expression %s', self.database_name, self.table_name, self.expression ) - return self.get_hook().check_for_partition( - self.database_name, self.table_name, self.expression) + return self.get_hook().check_for_partition(self.database_name, self.table_name, self.expression) def get_hook(self): """ Gets the AwsGlueCatalogHook """ if not self.hook: - self.hook = AwsGlueCatalogHook( - aws_conn_id=self.aws_conn_id, - region_name=self.region_name) + self.hook = AwsGlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) return self.hook diff --git a/airflow/providers/amazon/aws/sensors/redshift.py b/airflow/providers/amazon/aws/sensors/redshift.py index 0c893cadb3d63..37f35216dfd60 100644 --- a/airflow/providers/amazon/aws/sensors/redshift.py +++ b/airflow/providers/amazon/aws/sensors/redshift.py @@ -30,14 +30,11 @@ class AwsRedshiftClusterSensor(BaseSensorOperator): :param target_status: The cluster status desired. 
:type target_status: str """ + template_fields = ('cluster_identifier', 'target_status') @apply_defaults - def __init__(self, *, - cluster_identifier, - target_status='available', - aws_conn_id='aws_default', - **kwargs): + def __init__(self, *, cluster_identifier, target_status='available', aws_conn_id='aws_default', **kwargs): super().__init__(**kwargs) self.cluster_identifier = cluster_identifier self.target_status = target_status @@ -45,8 +42,7 @@ def __init__(self, *, self.hook = None def poke(self, context): - self.log.info('Poking for status : %s\nfor cluster %s', - self.target_status, self.cluster_identifier) + self.log.info('Poking for status : %s\nfor cluster %s', self.target_status, self.cluster_identifier) return self.get_hook().cluster_status(self.cluster_identifier) == self.target_status def get_hook(self): diff --git a/airflow/providers/amazon/aws/sensors/s3_key.py b/airflow/providers/amazon/aws/sensors/s3_key.py index 2661daa0dc85d..0c0f6e3e4d014 100644 --- a/airflow/providers/amazon/aws/sensors/s3_key.py +++ b/airflow/providers/amazon/aws/sensors/s3_key.py @@ -55,16 +55,20 @@ class S3KeySensor(BaseSensorOperator): CA cert bundle than the one used by botocore. :type verify: bool or str """ + template_fields = ('bucket_key', 'bucket_name') @apply_defaults - def __init__(self, *, - bucket_key, - bucket_name=None, - wildcard_match=False, - aws_conn_id='aws_default', - verify=None, - **kwargs): + def __init__( + self, + *, + bucket_key, + bucket_name=None, + wildcard_match=False, + aws_conn_id='aws_default', + verify=None, + **kwargs, + ): super().__init__(**kwargs) # Parse if bucket_name is None: @@ -77,9 +81,11 @@ def __init__(self, *, else: parsed_url = urlparse(bucket_key) if parsed_url.scheme != '' or parsed_url.netloc != '': - raise AirflowException('If bucket_name is provided, bucket_key' + - ' should be relative path from root' + - ' level, rather than a full s3:// url') + raise AirflowException( + 'If bucket_name is provided, bucket_key' + + ' should be relative path from root' + + ' level, rather than a full s3:// url' + ) self.bucket_name = bucket_name self.bucket_key = bucket_key self.wildcard_match = wildcard_match @@ -90,9 +96,7 @@ def __init__(self, *, def poke(self, context): self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key) if self.wildcard_match: - return self.get_hook().check_for_wildcard_key( - self.bucket_key, - self.bucket_name) + return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name) return self.get_hook().check_for_key(self.bucket_key, self.bucket_name) def get_hook(self): diff --git a/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py b/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py index 95a2148a755fa..f1f3d4e00bf6c 100644 --- a/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py +++ b/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py @@ -72,16 +72,19 @@ class S3KeysUnchangedSensor(BaseSensorOperator): template_fields = ('bucket_name', 'prefix') @apply_defaults - def __init__(self, *, - bucket_name: str, - prefix: str, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, - inactivity_period: float = 60 * 60, - min_objects: int = 1, - previous_objects: Optional[Set[str]] = None, - allow_delete: bool = True, - **kwargs) -> None: + def __init__( + self, + *, + bucket_name: str, + prefix: str, + aws_conn_id: str = 'aws_default', + verify: Optional[Union[bool, str]] = None, + inactivity_period: float = 60 * 60, + min_objects: int = 1, + 
previous_objects: Optional[Set[str]] = None, + allow_delete: bool = True, + **kwargs, + ) -> None: super().__init__(**kwargs) @@ -117,8 +120,10 @@ def is_keys_unchanged(self, current_objects: Set[str]) -> bool: if current_objects > self.previous_objects: # When new objects arrived, reset the inactivity_seconds # and update previous_objects for the next poke. - self.log.info("New objects found at %s, resetting last_activity_time.", - os.path.join(self.bucket, self.prefix)) + self.log.info( + "New objects found at %s, resetting last_activity_time.", + os.path.join(self.bucket, self.prefix), + ) self.log.debug("New objects: %s", current_objects - self.previous_objects) self.last_activity_time = datetime.now() self.inactivity_seconds = 0 @@ -131,12 +136,17 @@ def is_keys_unchanged(self, current_objects: Set[str]) -> bool: deleted_objects = self.previous_objects - current_objects self.previous_objects = current_objects self.last_activity_time = datetime.now() - self.log.info("Objects were deleted during the last poke interval. Updating the " - "file counter and resetting last_activity_time:\n%s", deleted_objects) + self.log.info( + "Objects were deleted during the last poke interval. Updating the " + "file counter and resetting last_activity_time:\n%s", + deleted_objects, + ) return False - raise AirflowException("Illegal behavior: objects were deleted in %s between pokes." - % os.path.join(self.bucket, self.prefix)) + raise AirflowException( + "Illegal behavior: objects were deleted in %s between pokes." + % os.path.join(self.bucket, self.prefix) + ) if self.last_activity_time: self.inactivity_seconds = int((datetime.now() - self.last_activity_time).total_seconds()) @@ -149,9 +159,13 @@ def is_keys_unchanged(self, current_objects: Set[str]) -> bool: path = os.path.join(self.bucket, self.prefix) if current_num_objects >= self.min_objects: - self.log.info("SUCCESS: \nSensor found %s objects at %s.\n" - "Waited at least %s seconds, with no new objects uploaded.", - current_num_objects, path, self.inactivity_period) + self.log.info( + "SUCCESS: \nSensor found %s objects at %s.\n" + "Waited at least %s seconds, with no new objects uploaded.", + current_num_objects, + path, + self.inactivity_period, + ) return True self.log.error("FAILURE: Inactivity Period passed, not enough objects found in %s", path) diff --git a/airflow/providers/amazon/aws/sensors/s3_prefix.py b/airflow/providers/amazon/aws/sensors/s3_prefix.py index acaf961830c59..4dc49000ce56e 100644 --- a/airflow/providers/amazon/aws/sensors/s3_prefix.py +++ b/airflow/providers/amazon/aws/sensors/s3_prefix.py @@ -51,16 +51,13 @@ class S3PrefixSensor(BaseSensorOperator): CA cert bundle than the one used by botocore. 
:type verify: bool or str """ + template_fields = ('prefix', 'bucket_name') @apply_defaults - def __init__(self, *, - bucket_name, - prefix, - delimiter='/', - aws_conn_id='aws_default', - verify=None, - **kwargs): + def __init__( + self, *, bucket_name, prefix, delimiter='/', aws_conn_id='aws_default', verify=None, **kwargs + ): super().__init__(**kwargs) # Parse self.bucket_name = bucket_name @@ -74,9 +71,8 @@ def __init__(self, *, def poke(self, context): self.log.info('Poking for prefix : %s in bucket s3://%s', self.prefix, self.bucket_name) return self.get_hook().check_for_prefix( - prefix=self.prefix, - delimiter=self.delimiter, - bucket_name=self.bucket_name) + prefix=self.prefix, delimiter=self.delimiter, bucket_name=self.bucket_name + ) def get_hook(self): """Create and return an S3Hook""" diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_base.py b/airflow/providers/amazon/aws/sensors/sagemaker_base.py index b3468df9a1b61..6704b1a35cffc 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker_base.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker_base.py @@ -28,13 +28,11 @@ class SageMakerBaseSensor(BaseSensorOperator): and state_from_response() methods. Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods. """ + ui_color = '#ededed' @apply_defaults - def __init__( - self, *, - aws_conn_id='aws_default', - **kwargs): + def __init__(self, *, aws_conn_id='aws_default', **kwargs): super().__init__(**kwargs) self.aws_conn_id = aws_conn_id self.hook = None @@ -61,8 +59,7 @@ def poke(self, context): if state in self.failed_states(): failed_reason = self.get_failed_reason_from_response(response) - raise AirflowException('Sagemaker job failed for the following reason: %s' - % failed_reason) + raise AirflowException('Sagemaker job failed for the following reason: %s' % failed_reason) return True def non_terminal_states(self): diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py b/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py index b8df5bf644e68..1a1b6f73ca97c 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py @@ -34,9 +34,7 @@ class SageMakerEndpointSensor(SageMakerBaseSensor): template_ext = () @apply_defaults - def __init__(self, *, - endpoint_name, - **kwargs): + def __init__(self, *, endpoint_name, **kwargs): super().__init__(**kwargs) self.endpoint_name = endpoint_name diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_training.py b/airflow/providers/amazon/aws/sensors/sagemaker_training.py index 1695d95a59bea..36403b877ef87 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker_training.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker_training.py @@ -38,10 +38,7 @@ class SageMakerTrainingSensor(SageMakerBaseSensor): template_ext = () @apply_defaults - def __init__(self, *, - job_name, - print_log=True, - **kwargs): + def __init__(self, *, job_name, print_log=True, **kwargs): super().__init__(**kwargs) self.job_name = job_name self.print_log = print_log @@ -75,20 +72,27 @@ def get_sagemaker_response(self): if self.print_log: if not self.log_resource_inited: self.init_log_resource(self.get_hook()) - self.state, self.last_description, self.last_describe_job_call = \ - self.get_hook().describe_training_job_with_log(self.job_name, - self.positions, self.stream_names, - self.instance_count, self.state, - self.last_description, - self.last_describe_job_call) + ( + self.state, + self.last_description, + 
self.last_describe_job_call, + ) = self.get_hook().describe_training_job_with_log( + self.job_name, + self.positions, + self.stream_names, + self.instance_count, + self.state, + self.last_description, + self.last_describe_job_call, + ) else: self.last_description = self.get_hook().describe_training_job(self.job_name) status = self.state_from_response(self.last_description) if status not in self.non_terminal_states() and status not in self.failed_states(): - billable_time = \ - (self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']) * \ - self.last_description['ResourceConfig']['InstanceCount'] + billable_time = ( + self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime'] + ) * self.last_description['ResourceConfig']['InstanceCount'] self.log.info('Billable seconds: %s', int(billable_time.total_seconds()) + 1) return self.last_description diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py index 5a9ffdc9b0350..4108c98d1dc35 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py @@ -35,9 +35,7 @@ class SageMakerTransformSensor(SageMakerBaseSensor): template_ext = () @apply_defaults - def __init__(self, *, - job_name, - **kwargs): + def __init__(self, *, job_name, **kwargs): super().__init__(**kwargs) self.job_name = job_name diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py index 6b97807c7ce47..794695b5c4d63 100644 --- a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py +++ b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py @@ -35,9 +35,7 @@ class SageMakerTuningSensor(SageMakerBaseSensor): template_ext = () @apply_defaults - def __init__(self, *, - job_name, - **kwargs): + def __init__(self, *, job_name, **kwargs): super().__init__(**kwargs) self.job_name = job_name diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py index 573981bdc2a0d..2d1ab54ad1d3d 100644 --- a/airflow/providers/amazon/aws/sensors/sqs.py +++ b/airflow/providers/amazon/aws/sensors/sqs.py @@ -44,12 +44,9 @@ class SQSSensor(BaseSensorOperator): template_fields = ('sqs_queue', 'max_messages') @apply_defaults - def __init__(self, *, - sqs_queue, - aws_conn_id='aws_default', - max_messages=5, - wait_time_seconds=1, - **kwargs): + def __init__( + self, *, sqs_queue, aws_conn_id='aws_default', max_messages=5, wait_time_seconds=1, **kwargs + ): super().__init__(**kwargs) self.sqs_queue = sqs_queue self.aws_conn_id = aws_conn_id @@ -69,25 +66,29 @@ def poke(self, context): self.log.info('SQSSensor checking for message on queue: %s', self.sqs_queue) - messages = sqs_conn.receive_message(QueueUrl=self.sqs_queue, - MaxNumberOfMessages=self.max_messages, - WaitTimeSeconds=self.wait_time_seconds) + messages = sqs_conn.receive_message( + QueueUrl=self.sqs_queue, + MaxNumberOfMessages=self.max_messages, + WaitTimeSeconds=self.wait_time_seconds, + ) self.log.info("received message %s", str(messages)) if 'Messages' in messages and messages['Messages']: - entries = [{'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']} - for message in messages['Messages']] + entries = [ + {'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']} + for message in messages['Messages'] + ] - result = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, - 
Entries=entries) + result = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries) if 'Successful' in result: context['ti'].xcom_push(key='messages', value=messages) return True else: raise AirflowException( - 'Delete SQS Messages failed ' + str(result) + ' for messages ' + str(messages)) + 'Delete SQS Messages failed ' + str(result) + ' for messages ' + str(messages) + ) return False diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py b/airflow/providers/amazon/aws/sensors/step_function_execution.py index a0e640e89b17b..6126670955888 100644 --- a/airflow/providers/amazon/aws/sensors/step_function_execution.py +++ b/airflow/providers/amazon/aws/sensors/step_function_execution.py @@ -39,7 +39,11 @@ class StepFunctionExecutionSensor(BaseSensorOperator): """ INTERMEDIATE_STATES = ('RUNNING',) - FAILURE_STATES = ('FAILED', 'TIMED_OUT', 'ABORTED',) + FAILURE_STATES = ( + 'FAILED', + 'TIMED_OUT', + 'ABORTED', + ) SUCCESS_STATES = ('SUCCEEDED',) template_fields = ['execution_arn'] @@ -47,8 +51,7 @@ class StepFunctionExecutionSensor(BaseSensorOperator): ui_color = '#66c3ff' @apply_defaults - def __init__(self, *, execution_arn: str, aws_conn_id='aws_default', region_name=None, - **kwargs): + def __init__(self, *, execution_arn: str, aws_conn_id='aws_default', region_name=None, **kwargs): super().__init__(**kwargs) self.execution_arn = execution_arn self.aws_conn_id = aws_conn_id diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py index 7f71a5445115d..40bb0264afea4 100644 --- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py @@ -42,9 +42,7 @@ def _upload_file_to_s3(file_obj, bucket_name, s3_key_prefix): s3_client = S3Hook().get_conn() file_obj.seek(0) s3_client.upload_file( - Filename=file_obj.name, - Bucket=bucket_name, - Key=s3_key_prefix + str(uuid4()), + Filename=file_obj.name, Bucket=bucket_name, Key=s3_key_prefix + str(uuid4()), ) @@ -92,14 +90,17 @@ class DynamoDBToS3Operator(BaseOperator): """ @apply_defaults - def __init__(self, *, - dynamodb_table_name: str, - s3_bucket_name: str, - file_size: int, - dynamodb_scan_kwargs: Optional[Dict[str, Any]] = None, - s3_key_prefix: str = '', - process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes, - **kwargs): + def __init__( + self, + *, + dynamodb_table_name: str, + s3_bucket_name: str, + file_size: int, + dynamodb_scan_kwargs: Optional[Dict[str, Any]] = None, + s3_key_prefix: str = '', + process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes, + **kwargs, + ): super().__init__(**kwargs) self.file_size = file_size self.process_func = process_func @@ -139,8 +140,7 @@ def _scan_dynamodb_and_upload_to_s3(self, temp_file, scan_kwargs, table): # Upload the file to S3 if reach file size limit if getsize(temp_file.name) >= self.file_size: - _upload_file_to_s3(temp_file, self.s3_bucket_name, - self.s3_key_prefix) + _upload_file_to_s3(temp_file, self.s3_bucket_name, self.s3_key_prefix) temp_file.close() temp_file = NamedTemporaryFile() return temp_file diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py index 29b04dc168dcd..212c9787195d3 100644 --- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py @@ -83,30 +83,42 @@ class GCSToS3Operator(BaseOperator): account from the list granting this role to 
the originating account (templated). :type google_impersonation_chain: Union[str, Sequence[str]] """ - template_fields: Iterable[str] = ('bucket', 'prefix', 'delimiter', 'dest_s3_key', - 'google_impersonation_chain',) + + template_fields: Iterable[str] = ( + 'bucket', + 'prefix', + 'delimiter', + 'dest_s3_key', + 'google_impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, *, # pylint: disable=too-many-arguments - bucket, - prefix=None, - delimiter=None, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - delegate_to=None, - dest_aws_conn_id=None, - dest_s3_key=None, - dest_verify=None, - replace=False, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs): + def __init__( + self, + *, # pylint: disable=too-many-arguments + bucket, + prefix=None, + delimiter=None, + gcp_conn_id='google_cloud_default', + google_cloud_storage_conn_id=None, + delegate_to=None, + dest_aws_conn_id=None, + dest_s3_key=None, + dest_verify=None, + replace=False, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): super().__init__(**kwargs) if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.bucket = bucket @@ -128,12 +140,14 @@ def execute(self, context): impersonation_chain=self.google_impersonation_chain, ) - self.log.info('Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s', - self.bucket, self.delimiter, self.prefix) + self.log.info( + 'Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s', + self.bucket, + self.delimiter, + self.prefix, + ) - files = hook.list(bucket_name=self.bucket, - prefix=self.prefix, - delimiter=self.delimiter) + files = hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter) s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify) @@ -159,9 +173,7 @@ def execute(self, context): dest_key = self.dest_s3_key + file self.log.info("Saving file to %s", dest_key) - s3_hook.load_bytes(file_bytes, - key=dest_key, - replace=self.replace) + s3_hook.load_bytes(file_bytes, key=dest_key, replace=self.replace) self.log.info("All done, uploaded %d files to S3", len(files)) else: diff --git a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py index 85695deb43361..ca17bed221247 100644 --- a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py @@ -96,7 +96,8 @@ class GoogleApiToS3Operator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, google_api_service_name, google_api_service_version, google_api_endpoint_path, @@ -112,7 +113,7 @@ def __init__( delegate_to=None, aws_conn_id='aws_default', google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.google_api_service_name = google_api_service_name @@ -162,22 +163,20 @@ def _retrieve_data_from_google_api(self): endpoint=self.google_api_endpoint_path, data=self.google_api_endpoint_params, paginate=self.google_api_pagination, - num_retries=self.google_api_num_retries + num_retries=self.google_api_num_retries, ) return google_api_response def _load_data_to_s3(self, data): 
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) s3_hook.load_string( - string_data=json.dumps(data), - key=self.s3_destination_key, - replace=self.s3_overwrite + string_data=json.dumps(data), key=self.s3_destination_key, replace=self.s3_overwrite ) def _update_google_api_endpoint_params_via_xcom(self, task_instance): google_api_endpoint_params = task_instance.xcom_pull( task_ids=self.google_api_endpoint_params_via_xcom_task_ids, - key=self.google_api_endpoint_params_via_xcom + key=self.google_api_endpoint_params_via_xcom, ) self.google_api_endpoint_params.update(google_api_endpoint_params) diff --git a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py index 3eecacaaa3c77..0ca7218cd1b00 100644 --- a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +++ b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py @@ -62,18 +62,20 @@ class HiveToDynamoDBOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, - sql, - table_name, - table_keys, - pre_process=None, - pre_process_args=None, - pre_process_kwargs=None, - region_name=None, - schema='default', - hiveserver2_conn_id='hiveserver2_default', - aws_conn_id='aws_default', - **kwargs): + self, + *, + sql, + table_name, + table_keys, + pre_process=None, + pre_process_args=None, + pre_process_kwargs=None, + region_name=None, + schema='default', + hiveserver2_conn_id='hiveserver2_default', + aws_conn_id='aws_default', + **kwargs, + ): super().__init__(**kwargs) self.sql = sql self.table_name = table_name @@ -93,20 +95,20 @@ def execute(self, context): self.log.info(self.sql) data = hive.get_pandas_df(self.sql, schema=self.schema) - dynamodb = AwsDynamoDBHook(aws_conn_id=self.aws_conn_id, - table_name=self.table_name, - table_keys=self.table_keys, - region_name=self.region_name) + dynamodb = AwsDynamoDBHook( + aws_conn_id=self.aws_conn_id, + table_name=self.table_name, + table_keys=self.table_keys, + region_name=self.region_name, + ) self.log.info('Inserting rows into dynamodb') if self.pre_process is None: - dynamodb.write_batch_data( - json.loads(data.to_json(orient='records'))) + dynamodb.write_batch_data(json.loads(data.to_json(orient='records'))) else: dynamodb.write_batch_data( - self.pre_process(data=data, - args=self.pre_process_args, - kwargs=self.pre_process_kwargs)) + self.pre_process(data=data, args=self.pre_process_args, kwargs=self.pre_process_kwargs) + ) self.log.info('Done.') diff --git a/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py b/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py index 79505f123cc80..bf65b8f56db28 100644 --- a/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py @@ -50,19 +50,23 @@ class ImapAttachmentToS3Operator(BaseOperator): :param s3_conn_id: The reference to the s3 connection details. 
:type s3_conn_id: str """ + template_fields = ('imap_attachment_name', 's3_key', 'imap_mail_filter') @apply_defaults - def __init__(self, *, - imap_attachment_name, - s3_key, - imap_check_regex=False, - imap_mail_folder='INBOX', - imap_mail_filter='All', - s3_overwrite=False, - imap_conn_id='imap_default', - s3_conn_id='aws_default', - **kwargs): + def __init__( + self, + *, + imap_attachment_name, + s3_key, + imap_check_regex=False, + imap_mail_folder='INBOX', + imap_mail_filter='All', + s3_overwrite=False, + imap_conn_id='imap_default', + s3_conn_id='aws_default', + **kwargs, + ): super().__init__(**kwargs) self.imap_attachment_name = imap_attachment_name self.s3_key = s3_key @@ -82,7 +86,8 @@ def execute(self, context): """ self.log.info( 'Transferring mail attachment %s from mail server via imap to s3 key %s...', - self.imap_attachment_name, self.s3_key + self.imap_attachment_name, + self.s3_key, ) with ImapHook(imap_conn_id=self.imap_conn_id) as imap_hook: @@ -95,6 +100,4 @@ def execute(self, context): ) s3_hook = S3Hook(aws_conn_id=self.s3_conn_id) - s3_hook.load_bytes(bytes_data=imap_mail_attachments[0][1], - key=self.s3_key, - replace=self.s3_overwrite) + s3_hook.load_bytes(bytes_data=imap_mail_attachments[0][1], key=self.s3_key, replace=self.s3_overwrite) diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py index 214689c5efac7..b996e10ec76b7 100644 --- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py @@ -41,16 +41,19 @@ class MongoToS3Operator(BaseOperator): # pylint: disable=too-many-instance-attributes @apply_defaults - def __init__(self, *, - mongo_conn_id, - s3_conn_id, - mongo_collection, - mongo_query, - s3_bucket, - s3_key, - mongo_db=None, - replace=False, - **kwargs): + def __init__( + self, + *, + mongo_conn_id, + s3_conn_id, + mongo_collection, + mongo_query, + s3_bucket, + s3_key, + mongo_db=None, + replace=False, + **kwargs, + ): super().__init__(**kwargs) # Conn Ids self.mongo_conn_id = mongo_conn_id @@ -78,14 +81,12 @@ def execute(self, context): results = MongoHook(self.mongo_conn_id).aggregate( mongo_collection=self.mongo_collection, aggregate_query=self.mongo_query, - mongo_db=self.mongo_db + mongo_db=self.mongo_db, ) else: results = MongoHook(self.mongo_conn_id).find( - mongo_collection=self.mongo_collection, - query=self.mongo_query, - mongo_db=self.mongo_db + mongo_collection=self.mongo_collection, query=self.mongo_query, mongo_db=self.mongo_db ) # Performs transform then stringifies the docs results into json format @@ -93,10 +94,7 @@ def execute(self, context): # Load Into S3 s3_conn.load_string( - string_data=docs_str, - key=self.s3_key, - bucket_name=self.s3_bucket, - replace=self.replace + string_data=docs_str, key=self.s3_key, bucket_name=self.s3_bucket, replace=self.replace ) return True @@ -107,9 +105,7 @@ def _stringify(iterable, joinable='\n'): Takes an iterable (pymongo Cursor or Array) containing dictionaries and returns a stringified version using python join """ - return joinable.join( - [json.dumps(doc, default=json_util.default) for doc in iterable] - ) + return joinable.join([json.dumps(doc, default=json_util.default) for doc in iterable]) @staticmethod def transform(docs): diff --git a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py index 249e4b24b208c..7a376f1fa149c 100644 --- a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py +++ 
b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py @@ -63,22 +63,27 @@ class MySQLToS3Operator(BaseOperator): :type header: bool """ - template_fields = ('s3_key', 'query',) + template_fields = ( + 's3_key', + 'query', + ) template_ext = ('.sql',) @apply_defaults def __init__( - self, *, - query: str, - s3_bucket: str, - s3_key: str, - mysql_conn_id: str = 'mysql_default', - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, - pd_csv_kwargs: Optional[dict] = None, - index: Optional[bool] = False, - header: Optional[bool] = False, - **kwargs) -> None: + self, + *, + query: str, + s3_bucket: str, + s3_key: str, + mysql_conn_id: str = 'mysql_default', + aws_conn_id: str = 'aws_default', + verify: Optional[Union[bool, str]] = None, + pd_csv_kwargs: Optional[dict] = None, + index: Optional[bool] = False, + header: Optional[bool] = False, + **kwargs, + ) -> None: super().__init__(**kwargs) self.query = query self.s3_bucket = s3_bucket @@ -116,9 +121,7 @@ def execute(self, context): self._fix_int_dtypes(data_df) with NamedTemporaryFile(mode='r+', suffix='.csv') as tmp_csv: data_df.to_csv(tmp_csv.name, **self.pd_csv_kwargs) - s3_conn.load_file(filename=tmp_csv.name, - key=self.s3_key, - bucket_name=self.s3_bucket) + s3_conn.load_file(filename=tmp_csv.name, key=self.s3_key, bucket_name=self.s3_bucket) if s3_conn.check_for_key(self.s3_key, bucket_name=self.s3_bucket): file_location = os.path.join(self.s3_bucket, self.s3_key) diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py index 9f1b113caf94f..3a3e6c24643a3 100644 --- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py @@ -71,19 +71,21 @@ class RedshiftToS3Operator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, - schema: str, - table: str, - s3_bucket: str, - s3_key: str, - redshift_conn_id: str = 'redshift_default', - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, - unload_options: Optional[List] = None, - autocommit: bool = False, - include_header: bool = False, - table_as_file_name: bool = True, # Set to True by default for not breaking current workflows - **kwargs) -> None: + self, + *, + schema: str, + table: str, + s3_bucket: str, + s3_key: str, + redshift_conn_id: str = 'redshift_default', + aws_conn_id: str = 'aws_default', + verify: Optional[Union[bool, str]] = None, + unload_options: Optional[List] = None, + autocommit: bool = False, + include_header: bool = False, + table_as_file_name: bool = True, # Set to True by default for not breaking current workflows + **kwargs, + ) -> None: super().__init__(**kwargs) self.schema = schema self.table = table @@ -98,7 +100,9 @@ def __init__( # pylint: disable=too-many-arguments self.table_as_file_name = table_as_file_name if self.include_header and 'HEADER' not in [uo.upper().strip() for uo in self.unload_options]: - self.unload_options = list(self.unload_options) + ['HEADER', ] + self.unload_options = list(self.unload_options) + [ + 'HEADER', + ] def execute(self, context): postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id) @@ -114,12 +118,14 @@ def execute(self, context): with credentials 'aws_access_key_id={access_key};aws_secret_access_key={secret_key}' {unload_options}; - """.format(select_query=select_query, - s3_bucket=self.s3_bucket, - s3_key=s3_key, - access_key=credentials.access_key, - 
secret_key=credentials.secret_key, - unload_options=unload_options) + """.format( + select_query=select_query, + s3_bucket=self.s3_bucket, + s3_key=s3_key, + access_key=credentials.access_key, + secret_key=credentials.secret_key, + unload_options=unload_options, + ) self.log.info('Executing UNLOAD command...') postgres_hook.run(unload_query, self.autocommit) diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 1ddbeaeef5681..3b2afd72f652e 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -64,17 +64,19 @@ class S3ToRedshiftOperator(BaseOperator): @apply_defaults def __init__( - self, *, - schema: str, - table: str, - s3_bucket: str, - s3_key: str, - redshift_conn_id: str = 'redshift_default', - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, - copy_options: Optional[List] = None, - autocommit: bool = False, - **kwargs) -> None: + self, + *, + schema: str, + table: str, + s3_bucket: str, + s3_key: str, + redshift_conn_id: str = 'redshift_default', + aws_conn_id: str = 'aws_default', + verify: Optional[Union[bool, str]] = None, + copy_options: Optional[List] = None, + autocommit: bool = False, + **kwargs, + ) -> None: super().__init__(**kwargs) self.schema = schema self.table = table @@ -100,13 +102,15 @@ def execute(self, context): with credentials 'aws_access_key_id={access_key};aws_secret_access_key={secret_key}' {copy_options}; - """.format(schema=self.schema, - table=self.table, - s3_bucket=self.s3_bucket, - s3_key=self.s3_key, - access_key=credentials.access_key, - secret_key=credentials.secret_key, - copy_options=copy_options) + """.format( + schema=self.schema, + table=self.table, + s3_bucket=self.s3_bucket, + s3_key=self.s3_key, + access_key=credentials.access_key, + secret_key=credentials.secret_key, + copy_options=copy_options, + ) self.log.info('Executing COPY command...') self._postgres_hook.run(copy_query, self.autocommit) diff --git a/airflow/providers/amazon/aws/transfers/s3_to_sftp.py b/airflow/providers/amazon/aws/transfers/s3_to_sftp.py index fd9246d416b3b..fe87c69afe0d6 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_sftp.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_sftp.py @@ -49,13 +49,9 @@ class S3ToSFTPOperator(BaseOperator): template_fields = ('s3_key', 'sftp_path') @apply_defaults - def __init__(self, *, - s3_bucket, - s3_key, - sftp_path, - sftp_conn_id='ssh_default', - s3_conn_id='aws_default', - **kwargs): + def __init__( + self, *, s3_bucket, s3_key, sftp_path, sftp_conn_id='ssh_default', s3_conn_id='aws_default', **kwargs + ): super().__init__(**kwargs) self.sftp_conn_id = sftp_conn_id self.sftp_path = sftp_path diff --git a/airflow/providers/amazon/aws/transfers/sftp_to_s3.py b/airflow/providers/amazon/aws/transfers/sftp_to_s3.py index c1b6e65e5b82d..087eb74e8154c 100644 --- a/airflow/providers/amazon/aws/transfers/sftp_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/sftp_to_s3.py @@ -49,13 +49,9 @@ class SFTPToS3Operator(BaseOperator): template_fields = ('s3_key', 'sftp_path') @apply_defaults - def __init__(self, *, - s3_bucket, - s3_key, - sftp_path, - sftp_conn_id='ssh_default', - s3_conn_id='aws_default', - **kwargs): + def __init__( + self, *, s3_bucket, s3_key, sftp_path, sftp_conn_id='ssh_default', s3_conn_id='aws_default', **kwargs + ): super().__init__(**kwargs) self.sftp_conn_id = sftp_conn_id self.sftp_path = sftp_path @@ -80,9 
+76,4 @@ def execute(self, context): with NamedTemporaryFile("w") as f: sftp_client.get(self.sftp_path, f.name) - s3_hook.load_file( - filename=f.name, - key=self.s3_key, - bucket_name=self.s3_bucket, - replace=True - ) + s3_hook.load_file(filename=f.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=True) diff --git a/airflow/providers/apache/cassandra/example_dags/example_cassandra_dag.py b/airflow/providers/apache/cassandra/example_dags/example_cassandra_dag.py index ce736344b9c25..b28fa845d52d5 100644 --- a/airflow/providers/apache/cassandra/example_dags/example_cassandra_dag.py +++ b/airflow/providers/apache/cassandra/example_dags/example_cassandra_dag.py @@ -34,7 +34,7 @@ default_args=args, schedule_interval=None, start_date=days_ago(2), - tags=['example'] + tags=['example'], ) as dag: # [START howto_operator_cassandra_table_sensor] table_sensor = CassandraTableSensor( diff --git a/airflow/providers/apache/cassandra/hooks/cassandra.py b/airflow/providers/apache/cassandra/hooks/cassandra.py index b3608850121dc..71aea789a12a1 100644 --- a/airflow/providers/apache/cassandra/hooks/cassandra.py +++ b/airflow/providers/apache/cassandra/hooks/cassandra.py @@ -25,7 +25,10 @@ from cassandra.auth import PlainTextAuthProvider from cassandra.cluster import Cluster, Session from cassandra.policies import ( - DCAwareRoundRobinPolicy, RoundRobinPolicy, TokenAwarePolicy, WhiteListRoundRobinPolicy, + DCAwareRoundRobinPolicy, + RoundRobinPolicy, + TokenAwarePolicy, + WhiteListRoundRobinPolicy, ) from airflow.hooks.base_hook import BaseHook @@ -81,6 +84,7 @@ class CassandraHook(BaseHook, LoggingMixin): For details of the Cluster config, see cassandra.cluster. """ + def __init__(self, cassandra_conn_id: str = 'cassandra_default'): super().__init__() conn = self.get_connection(cassandra_conn_id) @@ -93,8 +97,7 @@ def __init__(self, cassandra_conn_id: str = 'cassandra_default'): conn_config['port'] = int(conn.port) if conn.login: - conn_config['auth_provider'] = PlainTextAuthProvider( - username=conn.login, password=conn.password) + conn_config['auth_provider'] = PlainTextAuthProvider(username=conn.login, password=conn.password) policy_name = conn.extra_dejson.get('load_balancing_policy', None) policy_args = conn.extra_dejson.get('load_balancing_policy_args', {}) @@ -158,17 +161,17 @@ def get_lb_policy(policy_name: str, policy_args: Dict[str, Any]) -> Policy: return WhiteListRoundRobinPolicy(hosts) if policy_name == 'TokenAwarePolicy': - allowed_child_policies = ('RoundRobinPolicy', - 'DCAwareRoundRobinPolicy', - 'WhiteListRoundRobinPolicy',) - child_policy_name = policy_args.get('child_load_balancing_policy', - 'RoundRobinPolicy') + allowed_child_policies = ( + 'RoundRobinPolicy', + 'DCAwareRoundRobinPolicy', + 'WhiteListRoundRobinPolicy', + ) + child_policy_name = policy_args.get('child_load_balancing_policy', 'RoundRobinPolicy') child_policy_args = policy_args.get('child_load_balancing_policy_args', {}) if child_policy_name not in allowed_child_policies: return TokenAwarePolicy(RoundRobinPolicy()) else: - child_policy = CassandraHook.get_lb_policy(child_policy_name, - child_policy_args) + child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args) return TokenAwarePolicy(child_policy) # Fallback to default RoundRobinPolicy @@ -186,8 +189,7 @@ def table_exists(self, table: str) -> bool: if '.' 
in table: keyspace, table = table.split('.', 1) cluster_metadata = self.get_conn().cluster.metadata - return (keyspace in cluster_metadata.keyspaces and - table in cluster_metadata.keyspaces[keyspace].tables) + return keyspace in cluster_metadata.keyspaces and table in cluster_metadata.keyspaces[keyspace].tables def record_exists(self, table: str, keys: Dict[str, str]) -> bool: """ diff --git a/airflow/providers/apache/cassandra/sensors/record.py b/airflow/providers/apache/cassandra/sensors/record.py index ea67ac5841c57..bc61b2953b7a9 100644 --- a/airflow/providers/apache/cassandra/sensors/record.py +++ b/airflow/providers/apache/cassandra/sensors/record.py @@ -53,6 +53,7 @@ class CassandraRecordSensor(BaseSensorOperator): when connecting to Cassandra cluster :type cassandra_conn_id: str """ + template_fields = ('table', 'keys') @apply_defaults diff --git a/airflow/providers/apache/cassandra/sensors/table.py b/airflow/providers/apache/cassandra/sensors/table.py index 82cd411c0ada4..64129d780a98d 100644 --- a/airflow/providers/apache/cassandra/sensors/table.py +++ b/airflow/providers/apache/cassandra/sensors/table.py @@ -51,6 +51,7 @@ class CassandraTableSensor(BaseSensorOperator): when connecting to Cassandra cluster :type cassandra_conn_id: str """ + template_fields = ('table',) @apply_defaults diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py index 3dbc5b9ae070b..b609c4aafe661 100644 --- a/airflow/providers/apache/druid/hooks/druid.py +++ b/airflow/providers/apache/druid/hooks/druid.py @@ -49,7 +49,7 @@ def __init__( self, druid_ingest_conn_id: str = 'druid_ingest_default', timeout: int = 1, - max_ingestion_time: Optional[int] = None + max_ingestion_time: Optional[int] = None, ) -> None: super().__init__() @@ -71,7 +71,8 @@ def get_conn_url(self) -> str: conn_type = 'http' if not conn.conn_type else conn.conn_type endpoint = conn.extra_dejson.get('endpoint', '') return "{conn_type}://{host}:{port}/{endpoint}".format( - conn_type=conn_type, host=host, port=port, endpoint=endpoint) + conn_type=conn_type, host=host, port=port, endpoint=endpoint + ) def get_auth(self) -> Optional[requests.auth.HTTPBasicAuth]: """ @@ -96,8 +97,7 @@ def submit_indexing_job(self, json_index_spec: Dict[str, Any]) -> None: self.log.info("Druid ingestion spec: %s", json_index_spec) req_index = requests.post(url, data=json_index_spec, headers=self.header, auth=self.get_auth()) if req_index.status_code != 200: - raise AirflowException('Did not get 200 when ' - 'submitting the Druid job to {}'.format(url)) + raise AirflowException('Did not get 200 when ' 'submitting the Druid job to {}'.format(url)) req_json = req_index.json() # Wait until the job is completed @@ -115,8 +115,7 @@ def submit_indexing_job(self, json_index_spec: Dict[str, Any]) -> None: if self.max_ingestion_time and sec > self.max_ingestion_time: # ensure that the job gets killed if the max ingestion time is exceeded requests.post("{0}/{1}/shutdown".format(url, druid_task_id), auth=self.get_auth()) - raise AirflowException('Druid ingestion took more than ' - f'{self.max_ingestion_time} seconds') + raise AirflowException('Druid ingestion took more than ' f'{self.max_ingestion_time} seconds') time.sleep(self.timeout) @@ -128,8 +127,7 @@ def submit_indexing_job(self, json_index_spec: Dict[str, Any]) -> None: elif status == 'SUCCESS': running = False # Great success! 
elif status == 'FAILED': - raise AirflowException('Druid indexing job failed, ' - 'check console for more info') + raise AirflowException('Druid indexing job failed, ' 'check console for more info') else: raise AirflowException(f'Could not get status of the job, got {status}') @@ -143,6 +141,7 @@ class DruidDbApiHook(DbApiHook): This hook is purely for users to query druid broker. For ingestion, please use druidHook. """ + conn_name_attr = 'druid_broker_conn_id' default_conn_name = 'druid_broker_default' supports_autocommit = False @@ -158,7 +157,7 @@ def get_conn(self) -> connect: path=conn.extra_dejson.get('endpoint', '/druid/v2/sql'), scheme=conn.extra_dejson.get('schema', 'http'), user=conn.login, - password=conn.password + password=conn.password, ) self.log.info('Get the connection to druid broker on %s using user %s', conn.host, conn.login) return druid_broker_conn @@ -175,14 +174,18 @@ def get_uri(self) -> str: host += ':{port}'.format(port=conn.port) conn_type = 'druid' if not conn.conn_type else conn.conn_type endpoint = conn.extra_dejson.get('endpoint', 'druid/v2/sql') - return '{conn_type}://{host}/{endpoint}'.format( - conn_type=conn_type, host=host, endpoint=endpoint) + return '{conn_type}://{host}/{endpoint}'.format(conn_type=conn_type, host=host, endpoint=endpoint) def set_autocommit(self, conn: connect, autocommit: bool) -> NotImplemented: raise NotImplementedError() - def insert_rows(self, table: str, rows: Iterable[Tuple[str]], - target_fields: Optional[Iterable[str]] = None, - commit_every: int = 1000, replace: bool = False, - **kwargs: Any) -> NotImplemented: + def insert_rows( + self, + table: str, + rows: Iterable[Tuple[str]], + target_fields: Optional[Iterable[str]] = None, + commit_every: int = 1000, + replace: bool = False, + **kwargs: Any, + ) -> NotImplemented: raise NotImplementedError() diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py index f046ff1552a93..1ad665d8f7df7 100644 --- a/airflow/providers/apache/druid/operators/druid.py +++ b/airflow/providers/apache/druid/operators/druid.py @@ -34,23 +34,25 @@ class DruidOperator(BaseOperator): accepts index jobs :type druid_ingest_conn_id: str """ + template_fields = ('json_index_file',) template_ext = ('.json',) @apply_defaults - def __init__(self, *, json_index_file: str, - druid_ingest_conn_id: str = 'druid_ingest_default', - max_ingestion_time: Optional[int] = None, - **kwargs: Any) -> None: + def __init__( + self, + *, + json_index_file: str, + druid_ingest_conn_id: str = 'druid_ingest_default', + max_ingestion_time: Optional[int] = None, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.json_index_file = json_index_file self.conn_id = druid_ingest_conn_id self.max_ingestion_time = max_ingestion_time def execute(self, context: Dict[Any, Any]) -> None: - hook = DruidHook( - druid_ingest_conn_id=self.conn_id, - max_ingestion_time=self.max_ingestion_time - ) + hook = DruidHook(druid_ingest_conn_id=self.conn_id, max_ingestion_time=self.max_ingestion_time) self.log.info("Submitting %s", self.json_index_file) hook.submit_indexing_job(json.loads(self.json_index_file)) diff --git a/airflow/providers/apache/druid/operators/druid_check.py b/airflow/providers/apache/druid/operators/druid_check.py index 2f6114dfec0f7..12637880ac7cc 100644 --- a/airflow/providers/apache/druid/operators/druid_check.py +++ b/airflow/providers/apache/druid/operators/druid_check.py @@ -58,11 +58,7 @@ class DruidCheckOperator(CheckOperator): @apply_defaults def 
__init__( - self, - *, - sql: str, - druid_broker_conn_id: str = 'druid_broker_default', - **kwargs: Any + self, *, sql: str, druid_broker_conn_id: str = 'druid_broker_default', **kwargs: Any ) -> None: super().__init__(sql=sql, **kwargs) self.druid_broker_conn_id = druid_broker_conn_id diff --git a/airflow/providers/apache/druid/transfers/hive_to_druid.py b/airflow/providers/apache/druid/transfers/hive_to_druid.py index 595db0f5106bd..36d5ff40ced2b 100644 --- a/airflow/providers/apache/druid/transfers/hive_to_druid.py +++ b/airflow/providers/apache/druid/transfers/hive_to_druid.py @@ -84,7 +84,8 @@ class HiveToDruidOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, sql: str, druid_datasource: str, ts_dim: str, @@ -100,7 +101,7 @@ def __init__( # pylint: disable=too-many-arguments segment_granularity: str = "DAY", hive_tblproperties: Optional[Dict[Any, Any]] = None, job_properties: Optional[Dict[Any, Any]] = None, - **kwargs: Any + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.sql = sql @@ -111,9 +112,7 @@ def __init__( # pylint: disable=too-many-arguments self.target_partition_size = target_partition_size self.query_granularity = query_granularity self.segment_granularity = segment_granularity - self.metric_spec = metric_spec or [{ - "name": "count", - "type": "count"}] + self.metric_spec = metric_spec or [{"name": "count", "type": "count"}] self.hive_cli_conn_id = hive_cli_conn_id self.hadoop_dependency_coordinates = hadoop_dependency_coordinates self.druid_ingest_conn_id = druid_ingest_conn_id @@ -126,9 +125,7 @@ def execute(self, context: Dict[str, Any]) -> None: self.log.info("Extracting data from Hive") hive_table = 'druid.' + context['task_instance_key_str'].replace('.', '_') sql = self.sql.strip().strip(';') - tblproperties = ''.join([", '{}' = '{}'" - .format(k, v) - for k, v in self.hive_tblproperties.items()]) + tblproperties = ''.join([", '{}' = '{}'".format(k, v) for k, v in self.hive_tblproperties.items()]) hql = f"""\ SET mapred.output.compress=false; SET hive.exec.compress.output=false; @@ -155,10 +152,7 @@ def execute(self, context: Dict[str, Any]) -> None: druid = DruidHook(druid_ingest_conn_id=self.druid_ingest_conn_id) try: - index_spec = self.construct_ingest_query( - static_path=static_path, - columns=columns, - ) + index_spec = self.construct_ingest_query(static_path=static_path, columns=columns,) self.log.info("Inserting rows into Druid, hdfs path: %s", static_path) @@ -166,15 +160,11 @@ def execute(self, context: Dict[str, Any]) -> None: self.log.info("Load seems to have succeeded!") finally: - self.log.info( - "Cleaning up by dropping the temp Hive table %s", - hive_table - ) + self.log.info("Cleaning up by dropping the temp Hive table %s", hive_table) hql = "DROP TABLE IF EXISTS {}".format(hive_table) hive.run_cli(hql) - def construct_ingest_query(self, static_path: str, - columns: List[str]) -> Dict[str, Any]: + def construct_ingest_query(self, static_path: str, columns: List[str]) -> Dict[str, Any]: """ Builds an ingest query for an HDFS TSV load. 
@@ -219,16 +209,13 @@ def construct_ingest_query(self, static_path: str, "dimensionsSpec": { "dimensionExclusions": [], "dimensions": dimensions, # list of names - "spatialDimensions": [] + "spatialDimensions": [], }, - "timestampSpec": { - "column": self.ts_dim, - "format": "auto" - }, - "format": "tsv" - } + "timestampSpec": {"column": self.ts_dim, "format": "auto"}, + "format": "tsv", + }, }, - "dataSource": self.druid_datasource + "dataSource": self.druid_datasource, }, "tuningConfig": { "type": "hadoop", @@ -243,22 +230,14 @@ def construct_ingest_query(self, static_path: str, "numShards": num_shards, }, }, - "ioConfig": { - "inputSpec": { - "paths": static_path, - "type": "static" - }, - "type": "hadoop" - } - } + "ioConfig": {"inputSpec": {"paths": static_path, "type": "static"}, "type": "hadoop"}, + }, } if self.job_properties: - ingest_query_dict['spec']['tuningConfig']['jobProperties'] \ - .update(self.job_properties) + ingest_query_dict['spec']['tuningConfig']['jobProperties'].update(self.job_properties) if self.hadoop_dependency_coordinates: - ingest_query_dict['hadoopDependencyCoordinates'] \ - = self.hadoop_dependency_coordinates + ingest_query_dict['hadoopDependencyCoordinates'] = self.hadoop_dependency_coordinates return ingest_query_dict diff --git a/airflow/providers/apache/hdfs/hooks/hdfs.py b/airflow/providers/apache/hdfs/hooks/hdfs.py index 61b37722a4dd6..e13a5c7cb5d85 100644 --- a/airflow/providers/apache/hdfs/hooks/hdfs.py +++ b/airflow/providers/apache/hdfs/hooks/hdfs.py @@ -46,18 +46,17 @@ class HDFSHook(BaseHook): :type autoconfig: bool """ - def __init__(self, - hdfs_conn_id: str = 'hdfs_default', - proxy_user: Optional[str] = None, - autoconfig: bool = False - ): + def __init__( + self, hdfs_conn_id: str = 'hdfs_default', proxy_user: Optional[str] = None, autoconfig: bool = False + ): super().__init__() if not snakebite_loaded: raise ImportError( 'This HDFSHook implementation requires snakebite, but ' 'snakebite is not compatible with Python 3 ' '(as of August 2015). Please use Python 2 if you require ' - 'this hook -- or help by submitting a PR!') + 'this hook -- or help by submitting a PR!' 
+ ) self.hdfs_conn_id = hdfs_conn_id self.proxy_user = proxy_user self.autoconfig = autoconfig @@ -78,29 +77,34 @@ def get_conn(self) -> Any: if not effective_user: effective_user = connections[0].login if not autoconfig: - autoconfig = connections[0].extra_dejson.get('autoconfig', - False) - hdfs_namenode_principal = connections[0].extra_dejson.get( - 'hdfs_namenode_principal') + autoconfig = connections[0].extra_dejson.get('autoconfig', False) + hdfs_namenode_principal = connections[0].extra_dejson.get('hdfs_namenode_principal') except AirflowException: if not autoconfig: raise if autoconfig: # will read config info from $HADOOP_HOME conf files - client = AutoConfigClient(effective_user=effective_user, - use_sasl=use_sasl) + client = AutoConfigClient(effective_user=effective_user, use_sasl=use_sasl) elif len(connections) == 1: - client = Client(connections[0].host, connections[0].port, - effective_user=effective_user, use_sasl=use_sasl, - hdfs_namenode_principal=hdfs_namenode_principal) + client = Client( + connections[0].host, + connections[0].port, + effective_user=effective_user, + use_sasl=use_sasl, + hdfs_namenode_principal=hdfs_namenode_principal, + ) elif len(connections) > 1: name_node = [Namenode(conn.host, conn.port) for conn in connections] - client = HAClient(name_node, effective_user=effective_user, - use_sasl=use_sasl, - hdfs_namenode_principal=hdfs_namenode_principal) + client = HAClient( + name_node, + effective_user=effective_user, + use_sasl=use_sasl, + hdfs_namenode_principal=hdfs_namenode_principal, + ) else: - raise HDFSHookException("conn_id doesn't exist in the repository " - "and autoconfig is not specified") + raise HDFSHookException( + "conn_id doesn't exist in the repository " "and autoconfig is not specified" + ) return client diff --git a/airflow/providers/apache/hdfs/hooks/webhdfs.py b/airflow/providers/apache/hdfs/hooks/webhdfs.py index a72c7b0823a5c..bc24601cc1abb 100644 --- a/airflow/providers/apache/hdfs/hooks/webhdfs.py +++ b/airflow/providers/apache/hdfs/hooks/webhdfs.py @@ -52,9 +52,7 @@ class WebHDFSHook(BaseHook): :type proxy_user: str """ - def __init__(self, webhdfs_conn_id: str = 'webhdfs_default', - proxy_user: Optional[str] = None - ): + def __init__(self, webhdfs_conn_id: str = 'webhdfs_default', proxy_user: Optional[str] = None): super().__init__() self.webhdfs_conn_id = webhdfs_conn_id self.proxy_user = proxy_user @@ -88,8 +86,9 @@ def _find_valid_server(self) -> Any: self.log.error("Could not connect to %s:%s", connection.host, connection.port) host_socket.close() except HdfsError as hdfs_error: - self.log.error('Read operation on namenode %s failed with error: %s', - connection.host, hdfs_error) + self.log.error( + 'Read operation on namenode %s failed with error: %s', connection.host, hdfs_error + ) return None def _get_client(self, connection: Connection) -> Any: @@ -117,9 +116,9 @@ def check_for_path(self, hdfs_path: str) -> bool: status = conn.status(hdfs_path, strict=False) return bool(status) - def load_file(self, source: str, destination: str, - overwrite: bool = True, parallelism: int = 1, - **kwargs: Any) -> None: + def load_file( + self, source: str, destination: str, overwrite: bool = True, parallelism: int = 1, **kwargs: Any + ) -> None: r""" Uploads a file to HDFS. 
@@ -140,9 +139,7 @@ def load_file(self, source: str, destination: str, """ conn = self.get_conn() - conn.upload(hdfs_path=destination, - local_path=source, - overwrite=overwrite, - n_threads=parallelism, - **kwargs) + conn.upload( + hdfs_path=destination, local_path=source, overwrite=overwrite, n_threads=parallelism, **kwargs + ) self.log.debug("Uploaded file %s to %s", source, destination) diff --git a/airflow/providers/apache/hdfs/sensors/hdfs.py b/airflow/providers/apache/hdfs/sensors/hdfs.py index 85a8eb1e895c4..d7235dcabd1d2 100644 --- a/airflow/providers/apache/hdfs/sensors/hdfs.py +++ b/airflow/providers/apache/hdfs/sensors/hdfs.py @@ -32,19 +32,22 @@ class HdfsSensor(BaseSensorOperator): """ Waits for a file or folder to land in HDFS """ + template_fields = ('filepath',) ui_color = settings.WEB_COLORS['LIGHTBLUE'] @apply_defaults - def __init__(self, - *, - filepath: str, - hdfs_conn_id: str = 'hdfs_default', - ignored_ext: Optional[List[str]] = None, - ignore_copying: bool = True, - file_size: Optional[int] = None, - hook: Type[HDFSHook] = HDFSHook, - **kwargs: Any) -> None: + def __init__( + self, + *, + filepath: str, + hdfs_conn_id: str = 'hdfs_default', + ignored_ext: Optional[List[str]] = None, + ignore_copying: bool = True, + file_size: Optional[int] = None, + hook: Type[HDFSHook] = HDFSHook, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) if ignored_ext is None: ignored_ext = ['_COPYING_'] @@ -56,10 +59,7 @@ def __init__(self, self.hook = hook @staticmethod - def filter_for_filesize( - result: List[Dict[Any, Any]], - size: Optional[int] = None - ) -> List[Dict[Any, Any]]: + def filter_for_filesize(result: List[Dict[Any, Any]], size: Optional[int] = None) -> List[Dict[Any, Any]]: """ Will test the filepath result and test if its size is at least self.filesize @@ -68,10 +68,7 @@ def filter_for_filesize( :return: (bool) depending on the matching criteria """ if size: - log.debug( - 'Filtering for file size >= %s in files: %s', - size, map(lambda x: x['path'], result) - ) + log.debug('Filtering for file size >= %s in files: %s', size, map(lambda x: x['path'], result)) size *= settings.MEGABYTE result = [x for x in result if x['length'] >= size] log.debug('HdfsSensor.poke: after size filter result is %s', result) @@ -79,9 +76,7 @@ def filter_for_filesize( @staticmethod def filter_for_ignored_ext( - result: List[Dict[Any, Any]], - ignored_ext: List[str], - ignore_copying: bool + result: List[Dict[Any, Any]], ignored_ext: List[str], ignore_copying: bool ) -> List[Dict[Any, Any]]: """ Will filter if instructed to do so the result to remove matching criteria @@ -100,7 +95,8 @@ def filter_for_ignored_ext( ignored_extensions_regex = re.compile(regex_builder) log.debug( 'Filtering result for ignored extensions: %s in files %s', - ignored_extensions_regex.pattern, map(lambda x: x['path'], result) + ignored_extensions_regex.pattern, + map(lambda x: x['path'], result), ) result = [x for x in result if not ignored_extensions_regex.match(x['path'])] log.debug('HdfsSensor.poke: after ext filter result is %s', result) @@ -118,9 +114,7 @@ def poke(self, context: Dict[Any, Any]) -> bool: # here is a quick fix result = sb_client.ls([self.filepath], include_toplevel=False) self.log.debug('HdfsSensor.poke: result is %s', result) - result = self.filter_for_ignored_ext( - result, self.ignored_ext, self.ignore_copying - ) + result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying) result = self.filter_for_filesize(result, self.file_size) return bool(result) except 
Exception: # pylint: disable=broad-except @@ -134,10 +128,7 @@ class HdfsRegexSensor(HdfsSensor): Waits for matching files by matching on regex """ - def __init__(self, - regex: Pattern[str], - *args: Any, - **kwargs: Any) -> None: + def __init__(self, regex: Pattern[str], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.regex = regex @@ -151,11 +142,12 @@ def poke(self, context: Dict[Any, Any]) -> bool: self.log.info( 'Poking for %s to be a directory with files matching %s', self.filepath, self.regex.pattern ) - result = [f for f in sb_client.ls([self.filepath], include_toplevel=False) if - f['file_type'] == 'f' and - self.regex.match(f['path'].replace('%s/' % self.filepath, ''))] - result = self.filter_for_ignored_ext(result, self.ignored_ext, - self.ignore_copying) + result = [ + f + for f in sb_client.ls([self.filepath], include_toplevel=False) + if f['file_type'] == 'f' and self.regex.match(f['path'].replace('%s/' % self.filepath, '')) + ] + result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying) result = self.filter_for_filesize(result, self.file_size) return bool(result) @@ -165,10 +157,7 @@ class HdfsFolderSensor(HdfsSensor): Waits for a non-empty directory """ - def __init__(self, - be_empty: bool = False, - *args: Any, - **kwargs: Any): + def __init__(self, be_empty: bool = False, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.be_empty = be_empty @@ -180,8 +169,7 @@ def poke(self, context: Dict[str, Any]) -> bool: """ sb_client = self.hook(self.hdfs_conn_id).get_conn() result = sb_client.ls([self.filepath], include_toplevel=True) - result = self.filter_for_ignored_ext(result, self.ignored_ext, - self.ignore_copying) + result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying) result = self.filter_for_filesize(result, self.file_size) if self.be_empty: self.log.info('Poking for filepath %s to a empty directory', self.filepath) diff --git a/airflow/providers/apache/hdfs/sensors/web_hdfs.py b/airflow/providers/apache/hdfs/sensors/web_hdfs.py index 8d21b3e84dfd1..edc3c8b560553 100644 --- a/airflow/providers/apache/hdfs/sensors/web_hdfs.py +++ b/airflow/providers/apache/hdfs/sensors/web_hdfs.py @@ -25,20 +25,18 @@ class WebHdfsSensor(BaseSensorOperator): """ Waits for a file or folder to land in HDFS """ + template_fields = ('filepath',) @apply_defaults - def __init__(self, - *, - filepath: str, - webhdfs_conn_id: str = 'webhdfs_default', - **kwargs: Any) -> None: + def __init__(self, *, filepath: str, webhdfs_conn_id: str = 'webhdfs_default', **kwargs: Any) -> None: super().__init__(**kwargs) self.filepath = filepath self.webhdfs_conn_id = webhdfs_conn_id def poke(self, context: Dict[Any, Any]) -> bool: from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook + hook = WebHDFSHook(self.webhdfs_conn_id) self.log.info('Poking for file %s', self.filepath) return hook.check_for_path(hdfs_path=self.filepath) diff --git a/airflow/providers/apache/hive/example_dags/example_twitter_dag.py b/airflow/providers/apache/hive/example_dags/example_twitter_dag.py index 7dc03df7854cf..8c9d1f390f1c8 100644 --- a/airflow/providers/apache/hive/example_dags/example_twitter_dag.py +++ b/airflow/providers/apache/hive/example_dags/example_twitter_dag.py @@ -99,20 +99,14 @@ def transfertodb(): # is direction(from or to)_twitterHandle_date.csv # -------------------------------------------------------------------------------- - fetch_tweets = PythonOperator( - task_id='fetch_tweets', - 
python_callable=fetchtweets - ) + fetch_tweets = PythonOperator(task_id='fetch_tweets', python_callable=fetchtweets) # -------------------------------------------------------------------------------- # Clean the eight files. In this step you can get rid of or cherry pick columns # and different parts of the text # -------------------------------------------------------------------------------- - clean_tweets = PythonOperator( - task_id='clean_tweets', - python_callable=cleantweets - ) + clean_tweets = PythonOperator(task_id='clean_tweets', python_callable=cleantweets) clean_tweets << fetch_tweets @@ -122,10 +116,7 @@ def transfertodb(): # complicated. You can also take a look at Web Services to do such tasks # -------------------------------------------------------------------------------- - analyze_tweets = PythonOperator( - task_id='analyze_tweets', - python_callable=analyzetweets - ) + analyze_tweets = PythonOperator(task_id='analyze_tweets', python_callable=analyzetweets) analyze_tweets << clean_tweets @@ -135,10 +126,7 @@ def transfertodb(): # it to MySQL # -------------------------------------------------------------------------------- - hive_to_mysql = PythonOperator( - task_id='hive_to_mysql', - python_callable=transfertodb - ) + hive_to_mysql = PythonOperator(task_id='hive_to_mysql', python_callable=transfertodb) # -------------------------------------------------------------------------------- # The following tasks are generated using for loop. The first task puts the eight @@ -163,19 +151,21 @@ def transfertodb(): load_to_hdfs = BashOperator( task_id="put_" + channel + "_to_hdfs", - bash_command="HADOOP_USER_NAME=hdfs hadoop fs -put -f " + - local_dir + file_name + - hdfs_dir + channel + "/" + bash_command="HADOOP_USER_NAME=hdfs hadoop fs -put -f " + + local_dir + + file_name + + hdfs_dir + + channel + + "/", ) load_to_hdfs << analyze_tweets load_to_hive = HiveOperator( task_id="load_" + channel + "_to_hive", - hql="LOAD DATA INPATH '" + - hdfs_dir + channel + "/" + file_name + "' " - "INTO TABLE " + channel + " " - "PARTITION(dt='" + dt + "')" + hql="LOAD DATA INPATH '" + hdfs_dir + channel + "/" + file_name + "' " + "INTO TABLE " + channel + " " + "PARTITION(dt='" + dt + "')", ) load_to_hive << load_to_hdfs load_to_hive >> hive_to_mysql @@ -184,19 +174,21 @@ def transfertodb(): file_name = "from_" + channel + "_" + yesterday.strftime("%Y-%m-%d") + ".csv" load_to_hdfs = BashOperator( task_id="put_" + channel + "_to_hdfs", - bash_command="HADOOP_USER_NAME=hdfs hadoop fs -put -f " + - local_dir + file_name + - hdfs_dir + channel + "/" + bash_command="HADOOP_USER_NAME=hdfs hadoop fs -put -f " + + local_dir + + file_name + + hdfs_dir + + channel + + "/", ) load_to_hdfs << analyze_tweets load_to_hive = HiveOperator( task_id="load_" + channel + "_to_hive", - hql="LOAD DATA INPATH '" + - hdfs_dir + channel + "/" + file_name + "' " - "INTO TABLE " + channel + " " - "PARTITION(dt='" + dt + "')" + hql="LOAD DATA INPATH '" + hdfs_dir + channel + "/" + file_name + "' " + "INTO TABLE " + channel + " " + "PARTITION(dt='" + dt + "')", ) load_to_hive << load_to_hdfs diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py index 677ba3fe5e65e..6b7042a905bc1 100644 --- a/airflow/providers/apache/hive/hooks/hive.py +++ b/airflow/providers/apache/hive/hooks/hive.py @@ -46,8 +46,10 @@ def get_context_from_env_var() -> Dict[Any, Any]: :return: The context of interest. 
""" - return {format_map['default']: os.environ.get(format_map['env_var_format'], '') - for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()} + return { + format_map['default']: os.environ.get(format_map['env_var_format'], '') + for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values() + } class HiveCliHook(BaseHook): @@ -82,7 +84,7 @@ def __init__( run_as: Optional[str] = None, mapred_queue: Optional[str] = None, mapred_queue_priority: Optional[str] = None, - mapred_job_name: Optional[str] = None + mapred_job_name: Optional[str] = None, ) -> None: super().__init__() conn = self.get_connection(hive_cli_conn_id) @@ -98,10 +100,10 @@ def __init__( if mapred_queue_priority not in HIVE_QUEUE_PRIORITIES: raise AirflowException( "Invalid Mapred Queue Priority. Valid values are: " - "{}".format(', '.join(HIVE_QUEUE_PRIORITIES))) + "{}".format(', '.join(HIVE_QUEUE_PRIORITIES)) + ) - self.mapred_queue = mapred_queue or conf.get('hive', - 'default_hive_mapred_queue') + self.mapred_queue = mapred_queue or conf.get('hive', 'default_hive_mapred_queue') self.mapred_queue_priority = mapred_queue_priority self.mapred_job_name = mapred_job_name @@ -131,18 +133,18 @@ def _prepare_cli_cmd(self) -> List[Any]: if self.use_beeline: hive_bin = 'beeline' jdbc_url = "jdbc:hive2://{host}:{port}/{schema}".format( - host=conn.host, port=conn.port, schema=conn.schema) + host=conn.host, port=conn.port, schema=conn.schema + ) if conf.get('core', 'security') == 'kerberos': - template = conn.extra_dejson.get( - 'principal', "hive/_HOST@EXAMPLE.COM") + template = conn.extra_dejson.get('principal', "hive/_HOST@EXAMPLE.COM") if "_HOST" in template: - template = utils.replace_hostname_pattern( - utils.get_components(template)) + template = utils.replace_hostname_pattern(utils.get_components(template)) proxy_user = self._get_proxy_user() jdbc_url += ";principal={template};{proxy_user}".format( - template=template, proxy_user=proxy_user) + template=template, proxy_user=proxy_user + ) elif self.auth: jdbc_url += ";auth=" + self.auth @@ -176,17 +178,15 @@ def _prepare_hiveconf(d: Dict[Any, Any]) -> List[Any]: """ if not d: return [] - return as_flattened_list( - zip(["-hiveconf"] * len(d), - ["{}={}".format(k, v) for k, v in d.items()]) - ) + return as_flattened_list(zip(["-hiveconf"] * len(d), ["{}={}".format(k, v) for k, v in d.items()])) - def run_cli(self, - hql: Union[str, Text], - schema: Optional[str] = None, - verbose: Optional[bool] = True, - hive_conf: Optional[Dict[Any, Any]] = None - ) -> Any: + def run_cli( + self, + hql: Union[str, Text], + schema: Optional[str] = None, + verbose: Optional[bool] = True, + hive_conf: Optional[Dict[Any, Any]] = None, + ) -> Any: """ Run an hql statement using the hive cli. 
If hive_conf is specified it should be a dict and the entries will be set as key/value pairs @@ -222,28 +222,23 @@ def run_cli(self, hive_conf_params = self._prepare_hiveconf(env_context) if self.mapred_queue: hive_conf_params.extend( - ['-hiveconf', - 'mapreduce.job.queuename={}' - .format(self.mapred_queue), - '-hiveconf', - 'mapred.job.queue.name={}' - .format(self.mapred_queue), - '-hiveconf', - 'tez.queue.name={}' - .format(self.mapred_queue) - ]) + [ + '-hiveconf', + 'mapreduce.job.queuename={}'.format(self.mapred_queue), + '-hiveconf', + 'mapred.job.queue.name={}'.format(self.mapred_queue), + '-hiveconf', + 'tez.queue.name={}'.format(self.mapred_queue), + ] + ) if self.mapred_queue_priority: hive_conf_params.extend( - ['-hiveconf', - 'mapreduce.job.priority={}' - .format(self.mapred_queue_priority)]) + ['-hiveconf', 'mapreduce.job.priority={}'.format(self.mapred_queue_priority)] + ) if self.mapred_job_name: - hive_conf_params.extend( - ['-hiveconf', - 'mapred.job.name={}' - .format(self.mapred_job_name)]) + hive_conf_params.extend(['-hiveconf', 'mapred.job.name={}'.format(self.mapred_job_name)]) hive_cmd.extend(hive_conf_params) hive_cmd.extend(['-f', f.name]) @@ -251,11 +246,8 @@ def run_cli(self, if verbose: self.log.info("%s", " ".join(hive_cmd)) sub_process: Any = subprocess.Popen( - hive_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - cwd=tmp_dir, - close_fds=True) + hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True + ) self.sub_process = sub_process stdout = '' while True: @@ -284,9 +276,7 @@ def test_hql(self, hql: Union[str, Text]) -> None: if query.startswith('create table'): create.append(query_original) - elif query.startswith(('set ', - 'add jar ', - 'create temporary function')): + elif query.startswith(('set ', 'add jar ', 'create temporary function')): other.append(query_original) elif query.startswith('insert'): insert.append(query_original) @@ -323,7 +313,7 @@ def load_df( delimiter: str = ',', encoding: str = 'utf8', pandas_kwargs: Any = None, - **kwargs: Any + **kwargs: Any, ) -> None: """ Loads a pandas DataFrame into hive. 
@@ -348,9 +338,7 @@ def load_df( :param kwargs: passed to self.load_file """ - def _infer_field_types_from_df( - df: pandas.DataFrame - ) -> Dict[Any, Any]: + def _infer_field_types_from_df(df: pandas.DataFrame) -> Dict[Any, Any]: dtype_kind_hive_type = { 'b': 'BOOLEAN', # boolean 'i': 'BIGINT', # signed integer @@ -361,7 +349,7 @@ def _infer_field_types_from_df( 'O': 'STRING', # object 'S': 'STRING', # (byte-)string 'U': 'STRING', # Unicode - 'V': 'STRING' # void + 'V': 'STRING', # void } order_type = OrderedDict() @@ -377,20 +365,20 @@ def _infer_field_types_from_df( if field_dict is None: field_dict = _infer_field_types_from_df(df) - df.to_csv(path_or_buf=f, - sep=delimiter, - header=False, - index=False, - encoding=encoding, - date_format="%Y-%m-%d %H:%M:%S", - **pandas_kwargs) + df.to_csv( + path_or_buf=f, + sep=delimiter, + header=False, + index=False, + encoding=encoding, + date_format="%Y-%m-%d %H:%M:%S", + **pandas_kwargs, + ) f.flush() - return self.load_file(filepath=f.name, - table=table, - delimiter=delimiter, - field_dict=field_dict, - **kwargs) + return self.load_file( + filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs + ) def load_file( self, @@ -402,7 +390,7 @@ def load_file( overwrite: bool = True, partition: Optional[Dict[str, Any]] = None, recreate: bool = False, - tblproperties: Optional[Dict[str, Any]] = None + tblproperties: Optional[Dict[str, Any]] = None, ) -> None: """ Loads a local file into Hive @@ -444,20 +432,16 @@ def load_file( if create or recreate: if field_dict is None: raise ValueError("Must provide a field dict when creating a table") - fields = ",\n ".join( - ['`{k}` {v}'.format(k=k.strip('`'), v=v) for k, v in field_dict.items()]) - hql += "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n".format( - table=table, fields=fields) + fields = ",\n ".join(['`{k}` {v}'.format(k=k.strip('`'), v=v) for k, v in field_dict.items()]) + hql += "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n".format(table=table, fields=fields) if partition: - pfields = ",\n ".join( - [p + " STRING" for p in partition]) + pfields = ",\n ".join([p + " STRING" for p in partition]) hql += "PARTITIONED BY ({pfields})\n".format(pfields=pfields) hql += "ROW FORMAT DELIMITED\n" hql += "FIELDS TERMINATED BY '{delimiter}'\n".format(delimiter=delimiter) hql += "STORED AS textfile\n" if tblproperties is not None: - tprops = ", ".join( - ["'{0}'='{1}'".format(k, v) for k, v in tblproperties.items()]) + tprops = ", ".join(["'{0}'='{1}'".format(k, v) for k, v in tblproperties.items()]) hql += "TBLPROPERTIES({tprops})\n".format(tprops=tprops) hql += ";" self.log.info(hql) @@ -467,8 +451,7 @@ def load_file( hql += "OVERWRITE " hql += "INTO TABLE {table} ".format(table=table) if partition: - pvals = ", ".join( - ["{0}='{1}'".format(k, v) for k, v in partition.items()]) + pvals = ", ".join(["{0}='{1}'".format(k, v) for k, v in partition.items()]) hql += "PARTITION ({pvals})".format(pvals=pvals) # As a workaround for HIVE-10541, add a newline character @@ -547,6 +530,7 @@ def sasl_factory() -> sasl.Client: return sasl_client from thrift_sasl import TSaslClientTransport + transport = TSaslClientTransport(sasl_factory, "GSSAPI", conn_socket) else: transport = TTransport.TBufferedTransport(conn_socket) @@ -590,16 +574,11 @@ def check_for_partition(self, schema: str, table: str, partition: str) -> bool: True """ with self.metastore as client: - partitions = client.get_partitions_by_filter( - schema, table, partition, 1) + partitions = 
client.get_partitions_by_filter(schema, table, partition, 1) return bool(partitions) - def check_for_named_partition(self, - schema: str, - table: str, - partition_name: str - ) -> Any: + def check_for_named_partition(self, schema: str, table: str, partition_name: str) -> Any: """ Checks whether a partition with a given name exists @@ -651,9 +630,9 @@ def get_databases(self, pattern: str = '*') -> Any: with self.metastore as client: return client.get_databases(pattern) - def get_partitions(self, schema: str, table_name: str, - partition_filter: Optional[str] = None - ) -> List[Any]: + def get_partitions( + self, schema: str, table_name: str, partition_filter: Optional[str] = None + ) -> List[Any]: """ Returns a list of all partitions in a table. Works only for tables with less than 32767 (java short max val). @@ -674,21 +653,23 @@ def get_partitions(self, schema: str, table_name: str, else: if partition_filter: parts = client.get_partitions_by_filter( - db_name=schema, tbl_name=table_name, - filter=partition_filter, max_parts=HiveMetastoreHook.MAX_PART_COUNT) + db_name=schema, + tbl_name=table_name, + filter=partition_filter, + max_parts=HiveMetastoreHook.MAX_PART_COUNT, + ) else: parts = client.get_partitions( - db_name=schema, tbl_name=table_name, - max_parts=HiveMetastoreHook.MAX_PART_COUNT) + db_name=schema, tbl_name=table_name, max_parts=HiveMetastoreHook.MAX_PART_COUNT + ) pnames = [p.name for p in table.partitionKeys] return [dict(zip(pnames, p.values)) for p in parts] @staticmethod - def _get_max_partition_from_part_specs(part_specs: List[Any], - partition_key: Optional[str], - filter_map: Optional[Dict[str, Any]] - ) -> Any: + def _get_max_partition_from_part_specs( + part_specs: List[Any], partition_key: Optional[str], filter_map: Optional[Dict[str, Any]] + ) -> Any: """ Helper method to get max partition of partitions with partition_key from part specs. key:value pair in filter_map will be used to @@ -711,30 +692,36 @@ def _get_max_partition_from_part_specs(part_specs: List[Any], # Assuming all specs have the same keys. 
if partition_key not in part_specs[0].keys(): - raise AirflowException("Provided partition_key {} " - "is not in part_specs.".format(partition_key)) + raise AirflowException("Provided partition_key {} " "is not in part_specs.".format(partition_key)) is_subset = None if filter_map: is_subset = set(filter_map.keys()).issubset(set(part_specs[0].keys())) if filter_map and not is_subset: - raise AirflowException("Keys in provided filter_map {} " - "are not subset of part_spec keys: {}" - .format(', '.join(filter_map.keys()), - ', '.join(part_specs[0].keys()))) + raise AirflowException( + "Keys in provided filter_map {} " + "are not subset of part_spec keys: {}".format( + ', '.join(filter_map.keys()), ', '.join(part_specs[0].keys()) + ) + ) - candidates = [p_dict[partition_key] for p_dict in part_specs - if filter_map is None or - all(item in p_dict.items() for item in filter_map.items())] + candidates = [ + p_dict[partition_key] + for p_dict in part_specs + if filter_map is None or all(item in p_dict.items() for item in filter_map.items()) + ] if not candidates: return None else: return max(candidates) - def max_partition(self, schema: str, table_name: str, - field: Optional[str] = None, - filter_map: Optional[Dict[Any, Any]] = None - ) -> Any: + def max_partition( + self, + schema: str, + table_name: str, + field: Optional[str] = None, + filter_map: Optional[Dict[Any, Any]] = None, + ) -> Any: """ Returns the maximum value for all partitions with given field in a table. If only one partition key exist in the table, the key will be used as field. @@ -763,25 +750,19 @@ def max_partition(self, schema: str, table_name: str, if len(table.partitionKeys) == 1: field = table.partitionKeys[0].name elif not field: - raise AirflowException("Please specify the field you want the max " - "value for.") + raise AirflowException("Please specify the field you want the max " "value for.") elif field not in key_name_set: raise AirflowException("Provided field is not a partition key.") if filter_map and not set(filter_map.keys()).issubset(key_name_set): - raise AirflowException("Provided filter_map contains keys " - "that are not partition key.") + raise AirflowException("Provided filter_map contains keys " "that are not partition key.") - part_names = \ - client.get_partition_names(schema, - table_name, - max_parts=HiveMetastoreHook.MAX_PART_COUNT) - part_specs = [client.partition_name_to_spec(part_name) - for part_name in part_names] + part_names = client.get_partition_names( + schema, table_name, max_parts=HiveMetastoreHook.MAX_PART_COUNT + ) + part_specs = [client.partition_name_to_spec(part_name) for part_name in part_names] - return HiveMetastoreHook._get_max_partition_from_part_specs(part_specs, - field, - filter_map) + return HiveMetastoreHook._get_max_partition_from_part_specs(part_specs, field, filter_map) def table_exists(self, table_name: str, db: str = 'default') -> bool: """ @@ -820,8 +801,9 @@ def drop_partitions(self, table_name, part_vals, delete_data=False, db='default' """ if self.table_exists(table_name, db): with self.metastore as client: - self.log.info("Dropping partition of table %s.%s matching the spec: %s", - db, table_name, part_vals) + self.log.info( + "Dropping partition of table %s.%s matching the spec: %s", db, table_name, part_vals + ) return client.drop_partition(db, table_name, part_vals, delete_data) else: self.log.info("Table %s.%s does not exist!", db, table_name) @@ -839,12 +821,12 @@ class HiveServer2Hook(DbApiHook): are using impala you may need to set it to false in the 
``extra`` of your connection in the UI """ + conn_name_attr = 'hiveserver2_conn_id' default_conn_name = 'hiveserver2_default' supports_autocommit = False - def get_conn(self, schema: Optional[str] = None - ) -> Any: + def get_conn(self, schema: Optional[str] = None) -> Any: """ Returns a Hive connection object. """ @@ -864,13 +846,13 @@ def get_conn(self, schema: Optional[str] = None # pyhive uses GSSAPI instead of KERBEROS as a auth_mechanism identifier if auth_mechanism == 'GSSAPI': self.log.warning( - "Detected deprecated 'GSSAPI' for authMechanism " - "for %s. Please use 'KERBEROS' instead", - self.hiveserver2_conn_id # type: ignore + "Detected deprecated 'GSSAPI' for authMechanism " "for %s. Please use 'KERBEROS' instead", + self.hiveserver2_conn_id, # type: ignore ) auth_mechanism = 'KERBEROS' from pyhive.hive import connect + return connect( host=db.host, port=db.port, @@ -878,14 +860,20 @@ def get_conn(self, schema: Optional[str] = None kerberos_service_name=kerberos_service_name, username=db.login or username, password=db.password, - database=schema or db.schema or 'default') + database=schema or db.schema or 'default', + ) # pylint: enable=no-member - def _get_results(self, hql: Union[str, Text, List[str]], schema: str = 'default', - fetch_size: Optional[int] = None, - hive_conf: Optional[Dict[Any, Any]] = None) -> Any: + def _get_results( + self, + hql: Union[str, Text, List[str]], + schema: str = 'default', + fetch_size: Optional[int] = None, + hive_conf: Optional[Dict[Any, Any]] = None, + ) -> Any: from pyhive.exc import ProgrammingError + if isinstance(hql, str): hql = [hql] previous_description = None @@ -908,17 +896,19 @@ def _get_results(self, hql: Union[str, Text, List[str]], schema: str = 'default' cur.execute(statement) # we only get results of statements that returns lowered_statement = statement.lower().strip() - if (lowered_statement.startswith('select') or - lowered_statement.startswith('with') or - lowered_statement.startswith('show') or - (lowered_statement.startswith('set') and - '=' not in lowered_statement)): + if ( + lowered_statement.startswith('select') + or lowered_statement.startswith('with') + or lowered_statement.startswith('show') + or (lowered_statement.startswith('set') and '=' not in lowered_statement) + ): description = cur.description if previous_description and previous_description != description: message = '''The statements are producing different descriptions: Current: {} - Previous: {}'''.format(repr(description), - repr(previous_description)) + Previous: {}'''.format( + repr(description), repr(previous_description) + ) raise ValueError(message) elif not previous_description: previous_description = description @@ -931,10 +921,13 @@ def _get_results(self, hql: Union[str, Text, List[str]], schema: str = 'default' except ProgrammingError: self.log.debug("get_results returned no records") - def get_results(self, hql: Union[str, Text], schema: str = 'default', - fetch_size: Optional[int] = None, - hive_conf: Optional[Dict[Any, Any]] = None - ) -> Dict[str, Any]: + def get_results( + self, + hql: Union[str, Text], + schema: str = 'default', + fetch_size: Optional[int] = None, + hive_conf: Optional[Dict[Any, Any]] = None, + ) -> Dict[str, Any]: """ Get results of the provided hql in target schema. 
@@ -949,13 +942,9 @@ def get_results(self, hql: Union[str, Text], schema: str = 'default', :return: results of hql execution, dict with data (list of results) and header :rtype: dict """ - results_iter = self._get_results(hql, schema, - fetch_size=fetch_size, hive_conf=hive_conf) + results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf) header = next(results_iter) - results = { - 'data': list(results_iter), - 'header': header - } + results = {'data': list(results_iter), 'header': header} return results def to_csv( @@ -967,7 +956,7 @@ def to_csv( lineterminator: str = '\r\n', output_header: bool = True, fetch_size: int = 1000, - hive_conf: Optional[Dict[Any, Any]] = None + hive_conf: Optional[Dict[Any, Any]] = None, ) -> None: """ Execute hql in target schema and write results to a csv file. @@ -991,17 +980,13 @@ def to_csv( """ - results_iter = self._get_results(hql, schema, - fetch_size=fetch_size, hive_conf=hive_conf) + results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf) header = next(results_iter) message = None i = 0 with open(csv_filepath, 'wb') as file: - writer = csv.writer(file, - delimiter=delimiter, - lineterminator=lineterminator, - encoding='utf-8') + writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator, encoding='utf-8') try: if output_header: self.log.debug('Cursor description is %s', header) @@ -1021,10 +1006,9 @@ def to_csv( self.log.info("Done. Loaded a total of %s rows.", i) - def get_records(self, hql: Union[str, Text], - schema: str = 'default', - hive_conf: Optional[Dict[Any, Any]] = None - ) -> Any: + def get_records( + self, hql: Union[str, Text], schema: str = 'default', hive_conf: Optional[Dict[Any, Any]] = None + ) -> Any: """ Get a set of records from a Hive query. 
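# A minimal sketch for get_records, again assuming the default
# 'hiveserver2_default' connection; it returns only the 'data' part of
# get_results, as a list of row tuples.
from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook

hook = HiveServer2Hook()
for row in hook.get_records("SHOW DATABASES"):
    print(row[0])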
@@ -1044,11 +1028,13 @@ def get_records(self, hql: Union[str, Text], """ return self.get_results(hql, schema=schema, hive_conf=hive_conf)['data'] - def get_pandas_df(self, hql: Union[str, Text], # type: ignore - schema: str = 'default', - hive_conf: Optional[Dict[Any, Any]] = None, - **kwargs - ) -> pandas.DataFrame: + def get_pandas_df( # type: ignore + self, + hql: Union[str, Text], + schema: str = 'default', + hive_conf: Optional[Dict[Any, Any]] = None, + **kwargs, + ) -> pandas.DataFrame: """ Get a pandas dataframe from a Hive query diff --git a/airflow/providers/apache/hive/operators/hive.py b/airflow/providers/apache/hive/operators/hive.py index 48d75742de5ed..1db5e9908a63c 100644 --- a/airflow/providers/apache/hive/operators/hive.py +++ b/airflow/providers/apache/hive/operators/hive.py @@ -62,26 +62,37 @@ class HiveOperator(BaseOperator): :type mapred_job_name: str """ - template_fields = ('hql', 'schema', 'hive_cli_conn_id', 'mapred_queue', - 'hiveconfs', 'mapred_job_name', 'mapred_queue_priority') - template_ext = ('.hql', '.sql',) + template_fields = ( + 'hql', + 'schema', + 'hive_cli_conn_id', + 'mapred_queue', + 'hiveconfs', + 'mapred_job_name', + 'mapred_queue_priority', + ) + template_ext = ( + '.hql', + '.sql', + ) ui_color = '#f0e4ec' # pylint: disable=too-many-arguments @apply_defaults def __init__( - self, *, - hql: str, - hive_cli_conn_id: str = 'hive_cli_default', - schema: str = 'default', - hiveconfs: Optional[Dict[Any, Any]] = None, - hiveconf_jinja_translate: bool = False, - script_begin_tag: Optional[str] = None, - run_as_owner: bool = False, - mapred_queue: Optional[str] = None, - mapred_queue_priority: Optional[str] = None, - mapred_job_name: Optional[str] = None, - **kwargs: Any + self, + *, + hql: str, + hive_cli_conn_id: str = 'hive_cli_default', + schema: str = 'default', + hiveconfs: Optional[Dict[Any, Any]] = None, + hiveconf_jinja_translate: bool = False, + script_begin_tag: Optional[str] = None, + run_as_owner: bool = False, + mapred_queue: Optional[str] = None, + mapred_queue_priority: Optional[str] = None, + mapred_job_name: Optional[str] = None, + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.hql = hql @@ -97,8 +108,10 @@ def __init__( self.mapred_queue_priority = mapred_queue_priority self.mapred_job_name = mapred_job_name self.mapred_job_name_template = conf.get( - 'hive', 'mapred_job_name_template', - fallback="Airflow HiveOperator task for {hostname}.{dag_id}.{task_id}.{execution_date}") + 'hive', + 'mapred_job_name_template', + fallback="Airflow HiveOperator task for {hostname}.{dag_id}.{task_id}.{execution_date}", + ) # assigned lazily - just for consistency we can create the attribute with a # `None` initial value, later it will be populated by the execute method. 
@@ -115,12 +128,12 @@ def get_hook(self) -> HiveCliHook: run_as=self.run_as, mapred_queue=self.mapred_queue, mapred_queue_priority=self.mapred_queue_priority, - mapred_job_name=self.mapred_job_name) + mapred_job_name=self.mapred_job_name, + ) def prepare_template(self) -> None: if self.hiveconf_jinja_translate: - self.hql = re.sub( - r"(\$\{(hiveconf:)?([ a-zA-Z0-9_]*)\})", r"{{ \g<3> }}", self.hql) + self.hql = re.sub(r"(\$\{(hiveconf:)?([ a-zA-Z0-9_]*)\})", r"{{ \g<3> }}", self.hql) if self.script_begin_tag and self.script_begin_tag in self.hql: self.hql = "\n".join(self.hql.split(self.script_begin_tag)[1:]) @@ -131,10 +144,12 @@ def execute(self, context: Dict[str, Any]) -> None: # set the mapred_job_name if it's not set with dag, task, execution time info if not self.mapred_job_name: ti = context['ti'] - self.hook.mapred_job_name = self.mapred_job_name_template\ - .format(dag_id=ti.dag_id, task_id=ti.task_id, - execution_date=ti.execution_date.isoformat(), - hostname=ti.hostname.split('.')[0]) + self.hook.mapred_job_name = self.mapred_job_name_template.format( + dag_id=ti.dag_id, + task_id=ti.task_id, + execution_date=ti.execution_date.isoformat(), + hostname=ti.hostname.split('.')[0], + ) if self.hiveconf_jinja_translate: self.hiveconfs = context_to_airflow_vars(context) @@ -160,6 +175,7 @@ def clear_airflow_vars(self) -> None: """ Reset airflow environment variables to prevent existing ones from impacting behavior. """ - blank_env_vars = {value['env_var_format']: '' for value in - operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()} + blank_env_vars = { + value['env_var_format']: '' for value in operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values() + } os.environ.update(blank_env_vars) diff --git a/airflow/providers/apache/hive/operators/hive_stats.py b/airflow/providers/apache/hive/operators/hive_stats.py index 6fc689e82fc52..4dfef2cebeed8 100644 --- a/airflow/providers/apache/hive/operators/hive_stats.py +++ b/airflow/providers/apache/hive/operators/hive_stats.py @@ -63,24 +63,25 @@ class HiveStatsCollectionOperator(BaseOperator): ui_color = '#aff7a6' @apply_defaults - def __init__(self, *, - table: str, - partition: Any, - extra_exprs: Optional[Dict[str, Any]] = None, - excluded_columns: Optional[List[str]] = None, - assignment_func: Optional[Callable[[str, str], Optional[Dict[Any, Any]]]] = None, - metastore_conn_id: str = 'metastore_default', - presto_conn_id: str = 'presto_default', - mysql_conn_id: str = 'airflow_db', - **kwargs: Any - ) -> None: + def __init__( + self, + *, + table: str, + partition: Any, + extra_exprs: Optional[Dict[str, Any]] = None, + excluded_columns: Optional[List[str]] = None, + assignment_func: Optional[Callable[[str, str], Optional[Dict[Any, Any]]]] = None, + metastore_conn_id: str = 'metastore_default', + presto_conn_id: str = 'presto_default', + mysql_conn_id: str = 'airflow_db', + **kwargs: Any, + ) -> None: if 'col_blacklist' in kwargs: warnings.warn( 'col_blacklist kwarg passed to {c} (task_id: {t}) is deprecated, please rename it to ' - 'excluded_columns instead'.format( - c=self.__class__.__name__, t=kwargs.get('task_id')), + 'excluded_columns instead'.format(c=self.__class__.__name__, t=kwargs.get('task_id')), category=FutureWarning, - stacklevel=2 + stacklevel=2, ) excluded_columns = kwargs.pop('col_blacklist') super().__init__(**kwargs) @@ -121,9 +122,7 @@ def execute(self, context: Optional[Dict[str, Any]] = None) -> None: table = metastore.get_table(table_name=self.table) field_types = {col.name: col.type for col in 
table.sd.cols} - exprs: Any = { - ('', 'count'): 'COUNT(*)' - } + exprs: Any = {('', 'count'): 'COUNT(*)'} for col, col_type in list(field_types.items()): if self.assignment_func: assign_exprs = self.assignment_func(col, col_type) @@ -134,14 +133,13 @@ def execute(self, context: Optional[Dict[str, Any]] = None) -> None: exprs.update(assign_exprs) exprs.update(self.extra_exprs) exprs = OrderedDict(exprs) - exprs_str = ",\n ".join([ - v + " AS " + k[0] + '__' + k[1] - for k, v in exprs.items()]) + exprs_str = ",\n ".join([v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()]) where_clause_ = ["{} = '{}'".format(k, v) for k, v in self.partition.items()] where_clause = " AND\n ".join(where_clause_) sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format( - exprs_str=exprs_str, table=self.table, where_clause=where_clause) + exprs_str=exprs_str, table=self.table, where_clause=where_clause + ) presto = PrestoHook(presto_conn_id=self.presto_conn_id) self.log.info('Executing SQL check: %s', sql) @@ -161,7 +159,9 @@ def execute(self, context: Optional[Dict[str, Any]] = None) -> None: partition_repr='{part_json}' AND dttm='{dttm}' LIMIT 1; - """.format(table=self.table, part_json=part_json, dttm=self.dttm) + """.format( + table=self.table, part_json=part_json, dttm=self.dttm + ) if mysql.get_records(sql): sql = """ DELETE FROM hive_stats @@ -169,22 +169,17 @@ def execute(self, context: Optional[Dict[str, Any]] = None) -> None: table_name='{table}' AND partition_repr='{part_json}' AND dttm='{dttm}'; - """.format(table=self.table, part_json=part_json, dttm=self.dttm) + """.format( + table=self.table, part_json=part_json, dttm=self.dttm + ) mysql.run(sql) self.log.info("Pivoting and loading cells into the Airflow db") - rows = [(self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1]) - for r in zip(exprs, row)] + rows = [ + (self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1]) for r in zip(exprs, row) + ] mysql.insert_rows( table='hive_stats', rows=rows, - target_fields=[ - 'ds', - 'dttm', - 'table_name', - 'partition_repr', - 'col', - 'metric', - 'value', - ] + target_fields=['ds', 'dttm', 'table_name', 'partition_repr', 'col', 'metric', 'value',], ) diff --git a/airflow/providers/apache/hive/sensors/hive_partition.py b/airflow/providers/apache/hive/sensors/hive_partition.py index 8e1b8279da048..15843b8d4addc 100644 --- a/airflow/providers/apache/hive/sensors/hive_partition.py +++ b/airflow/providers/apache/hive/sensors/hive_partition.py @@ -42,19 +42,26 @@ class HivePartitionSensor(BaseSensorOperator): connection id :type metastore_conn_id: str """ - template_fields = ('schema', 'table', 'partition',) + + template_fields = ( + 'schema', + 'table', + 'partition', + ) ui_color = '#C5CAE9' @apply_defaults - def __init__(self, *, - table: str, - partition: Optional[str] = "ds='{{ ds }}'", - metastore_conn_id: str = 'metastore_default', - schema: str = 'default', - poke_interval: int = 60 * 3, - **kwargs: Any): - super().__init__( - poke_interval=poke_interval, **kwargs) + def __init__( + self, + *, + table: str, + partition: Optional[str] = "ds='{{ ds }}'", + metastore_conn_id: str = 'metastore_default', + schema: str = 'default', + poke_interval: int = 60 * 3, + **kwargs: Any, + ): + super().__init__(poke_interval=poke_interval, **kwargs) if not partition: partition = "ds='{{ ds }}'" self.metastore_conn_id = metastore_conn_id @@ -65,11 +72,7 @@ def __init__(self, *, def poke(self, context: Dict[str, Any]) -> bool: if '.' 
in self.table: self.schema, self.table = self.table.split('.') - self.log.info( - 'Poking for table %s.%s, partition %s', self.schema, self.table, self.partition - ) + self.log.info('Poking for table %s.%s, partition %s', self.schema, self.table, self.partition) if not hasattr(self, 'hook'): - hook = HiveMetastoreHook( - metastore_conn_id=self.metastore_conn_id) - return hook.check_for_partition( - self.schema, self.table, self.partition) + hook = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id) + return hook.check_for_partition(self.schema, self.table, self.partition) diff --git a/airflow/providers/apache/hive/sensors/metastore_partition.py b/airflow/providers/apache/hive/sensors/metastore_partition.py index 1e54440f3ada7..31376ad1267a3 100644 --- a/airflow/providers/apache/hive/sensors/metastore_partition.py +++ b/airflow/providers/apache/hive/sensors/metastore_partition.py @@ -41,16 +41,20 @@ class MetastorePartitionSensor(SqlSensor): :param mysql_conn_id: a reference to the MySQL conn_id for the metastore :type mysql_conn_id: str """ + template_fields = ('partition_name', 'table', 'schema') ui_color = '#8da7be' @apply_defaults - def __init__(self, *, - table: str, - partition_name: str, - schema: str = "default", - mysql_conn_id: str = "metastore_mysql", - **kwargs: Any): + def __init__( + self, + *, + table: str, + partition_name: str, + schema: str = "default", + mysql_conn_id: str = "metastore_mysql", + **kwargs: Any, + ): self.partition_name = partition_name self.table = table @@ -78,5 +82,7 @@ def poke(self, context: Dict[str, Any]) -> Any: B0.TBL_NAME = '{self.table}' AND C0.NAME = '{self.schema}' AND A0.PART_NAME = '{self.partition_name}'; - """.format(self=self) + """.format( + self=self + ) return super().poke(context) diff --git a/airflow/providers/apache/hive/sensors/named_hive_partition.py b/airflow/providers/apache/hive/sensors/named_hive_partition.py index f69e2b2d4ea51..23d9466f78968 100644 --- a/airflow/providers/apache/hive/sensors/named_hive_partition.py +++ b/airflow/providers/apache/hive/sensors/named_hive_partition.py @@ -42,14 +42,16 @@ class NamedHivePartitionSensor(BaseSensorOperator): ui_color = '#8d99ae' @apply_defaults - def __init__(self, *, - partition_names: List[str], - metastore_conn_id: str = 'metastore_default', - poke_interval: int = 60 * 3, - hook: Any = None, - **kwargs: Any): - super().__init__( - poke_interval=poke_interval, **kwargs) + def __init__( + self, + *, + partition_names: List[str], + metastore_conn_id: str = 'metastore_default', + poke_interval: int = 60 * 3, + hook: Any = None, + **kwargs: Any, + ): + super().__init__(poke_interval=poke_interval, **kwargs) self.next_index_to_poke = 0 if isinstance(partition_names, str): @@ -74,8 +76,7 @@ def parse_partition_name(partition: str) -> Tuple[Any, ...]: schema, table_partition = first_split second_split = table_partition.split('/', 1) if len(second_split) == 1: - raise ValueError('Could not parse ' + partition + - 'into table, partition') + raise ValueError('Could not parse ' + partition + 'into table, partition') else: table, partition = second_split return schema, table, partition @@ -84,14 +85,13 @@ def poke_partition(self, partition: str) -> Any: """Check for a named partition.""" if not self.hook: from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook - self.hook = HiveMetastoreHook( - metastore_conn_id=self.metastore_conn_id) + + self.hook = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id) schema, table, partition = 
self.parse_partition_name(partition) self.log.info('Poking for %s.%s/%s', schema, table, partition) - return self.hook.check_for_named_partition( - schema, table, partition) + return self.hook.check_for_named_partition(schema, table, partition) def poke(self, context: Dict[str, Any]) -> bool: diff --git a/airflow/providers/apache/hive/transfers/hive_to_mysql.py b/airflow/providers/apache/hive/transfers/hive_to_mysql.py index 724c7919f8e80..11d81ec837f7f 100644 --- a/airflow/providers/apache/hive/transfers/hive_to_mysql.py +++ b/airflow/providers/apache/hive/transfers/hive_to_mysql.py @@ -67,16 +67,19 @@ class HiveToMySqlOperator(BaseOperator): ui_color = '#a0e08c' @apply_defaults - def __init__(self, *, - sql: str, - mysql_table: str, - hiveserver2_conn_id: str = 'hiveserver2_default', - mysql_conn_id: str = 'mysql_default', - mysql_preoperator: Optional[str] = None, - mysql_postoperator: Optional[str] = None, - bulk_load: bool = False, - hive_conf: Optional[Dict] = None, - **kwargs) -> None: + def __init__( + self, + *, + sql: str, + mysql_table: str, + hiveserver2_conn_id: str = 'hiveserver2_default', + mysql_conn_id: str = 'mysql_default', + mysql_preoperator: Optional[str] = None, + mysql_postoperator: Optional[str] = None, + bulk_load: bool = False, + hive_conf: Optional[Dict] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.sql = sql self.mysql_table = mysql_table @@ -96,12 +99,14 @@ def execute(self, context): hive_conf.update(self.hive_conf) if self.bulk_load: tmp_file = NamedTemporaryFile() - hive.to_csv(self.sql, - tmp_file.name, - delimiter='\t', - lineterminator='\n', - output_header=False, - hive_conf=hive_conf) + hive.to_csv( + self.sql, + tmp_file.name, + delimiter='\t', + lineterminator='\n', + output_header=False, + hive_conf=hive_conf, + ) else: hive_results = hive.get_records(self.sql, hive_conf=hive_conf) diff --git a/airflow/providers/apache/hive/transfers/hive_to_samba.py b/airflow/providers/apache/hive/transfers/hive_to_samba.py index 5f08b83d09487..dc93297475f6c 100644 --- a/airflow/providers/apache/hive/transfers/hive_to_samba.py +++ b/airflow/providers/apache/hive/transfers/hive_to_samba.py @@ -45,15 +45,21 @@ class HiveToSambaOperator(BaseOperator): """ template_fields = ('hql', 'destination_filepath') - template_ext = ('.hql', '.sql',) + template_ext = ( + '.hql', + '.sql', + ) @apply_defaults - def __init__(self, *, - hql: str, - destination_filepath: str, - samba_conn_id: str = 'samba_default', - hiveserver2_conn_id: str = 'hiveserver2_default', - **kwargs) -> None: + def __init__( + self, + *, + hql: str, + destination_filepath: str, + samba_conn_id: str = 'samba_default', + hiveserver2_conn_id: str = 'hiveserver2_default', + **kwargs, + ) -> None: super().__init__(**kwargs) self.hiveserver2_conn_id = hiveserver2_conn_id self.samba_conn_id = samba_conn_id diff --git a/airflow/providers/apache/hive/transfers/mssql_to_hive.py b/airflow/providers/apache/hive/transfers/mssql_to_hive.py index 01a932724f9cb..8f32ca28139f7 100644 --- a/airflow/providers/apache/hive/transfers/mssql_to_hive.py +++ b/airflow/providers/apache/hive/transfers/mssql_to_hive.py @@ -76,17 +76,20 @@ class MsSqlToHiveOperator(BaseOperator): ui_color = '#a0e08c' @apply_defaults - def __init__(self, *, - sql: str, - hive_table: str, - create: bool = True, - recreate: bool = False, - partition: Optional[Dict] = None, - delimiter: str = chr(1), - mssql_conn_id: str = 'mssql_default', - hive_cli_conn_id: str = 'hive_cli_default', - tblproperties: Optional[Dict] = None, - **kwargs) 
-> None: + def __init__( + self, + *, + sql: str, + hive_table: str, + create: bool = True, + recreate: bool = False, + partition: Optional[Dict] = None, + delimiter: str = chr(1), + mssql_conn_id: str = 'mssql_default', + hive_cli_conn_id: str = 'hive_cli_default', + tblproperties: Optional[Dict] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.sql = sql self.hive_table = hive_table @@ -138,4 +141,5 @@ def execute(self, context: Dict[str, str]): partition=self.partition, delimiter=self.delimiter, recreate=self.recreate, - tblproperties=self.tblproperties) + tblproperties=self.tblproperties, + ) diff --git a/airflow/providers/apache/hive/transfers/mysql_to_hive.py b/airflow/providers/apache/hive/transfers/mysql_to_hive.py index 99650ec59f024..25aa802ddb080 100644 --- a/airflow/providers/apache/hive/transfers/mysql_to_hive.py +++ b/airflow/providers/apache/hive/transfers/mysql_to_hive.py @@ -85,21 +85,22 @@ class MySqlToHiveOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, - *, - sql: str, - hive_table: str, - create: bool = True, - recreate: bool = False, - partition: Optional[Dict] = None, - delimiter: str = chr(1), - quoting: Optional[str] = None, - quotechar: str = '"', - escapechar: Optional[str] = None, - mysql_conn_id: str = 'mysql_default', - hive_cli_conn_id: str = 'hive_cli_default', - tblproperties: Optional[Dict] = None, - **kwargs) -> None: + self, + *, + sql: str, + hive_table: str, + create: bool = True, + recreate: bool = False, + partition: Optional[Dict] = None, + delimiter: str = chr(1), + quoting: Optional[str] = None, + quotechar: str = '"', + escapechar: Optional[str] = None, + mysql_conn_id: str = 'mysql_default', + hive_cli_conn_id: str = 'hive_cli_default', + tblproperties: Optional[Dict] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.sql = sql self.hive_table = hive_table @@ -146,11 +147,14 @@ def execute(self, context: Dict[str, str]): cursor = conn.cursor() cursor.execute(self.sql) with NamedTemporaryFile("wb") as f: - csv_writer = csv.writer(f, delimiter=self.delimiter, - quoting=self.quoting, - quotechar=self.quotechar, - escapechar=self.escapechar, - encoding="utf-8") + csv_writer = csv.writer( + f, + delimiter=self.delimiter, + quoting=self.quoting, + quotechar=self.quotechar, + escapechar=self.escapechar, + encoding="utf-8", + ) field_dict = OrderedDict() for field in cursor.description: field_dict[field[0]] = self.type_map(field[1]) @@ -167,4 +171,5 @@ def execute(self, context: Dict[str, str]): partition=self.partition, delimiter=self.delimiter, recreate=self.recreate, - tblproperties=self.tblproperties) + tblproperties=self.tblproperties, + ) diff --git a/airflow/providers/apache/hive/transfers/s3_to_hive.py b/airflow/providers/apache/hive/transfers/s3_to_hive.py index 6c730a0a782bd..844777e40d869 100644 --- a/airflow/providers/apache/hive/transfers/s3_to_hive.py +++ b/airflow/providers/apache/hive/transfers/s3_to_hive.py @@ -107,25 +107,26 @@ class S3ToHiveOperator(BaseOperator): # pylint: disable=too-many-instance-attri @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, - *, - s3_key: str, - field_dict: Dict, - hive_table: str, - delimiter: str = ',', - create: bool = True, - recreate: bool = False, - partition: Optional[Dict] = None, - headers: bool = False, - check_headers: bool = False, - wildcard_match: bool = False, - aws_conn_id: str = 'aws_default', - verify: Optional[Union[bool, str]] = None, - hive_cli_conn_id: str = 'hive_cli_default', - 
input_compressed: bool = False, - tblproperties: Optional[Dict] = None, - select_expression: Optional[str] = None, - **kwargs) -> None: + self, + *, + s3_key: str, + field_dict: Dict, + hive_table: str, + delimiter: str = ',', + create: bool = True, + recreate: bool = False, + partition: Optional[Dict] = None, + headers: bool = False, + check_headers: bool = False, + wildcard_match: bool = False, + aws_conn_id: str = 'aws_default', + verify: Optional[Union[bool, str]] = None, + hive_cli_conn_id: str = 'hive_cli_default', + input_compressed: bool = False, + tblproperties: Optional[Dict] = None, + select_expression: Optional[str] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.s3_key = s3_key self.field_dict = field_dict @@ -144,10 +145,8 @@ def __init__( # pylint: disable=too-many-arguments self.tblproperties = tblproperties self.select_expression = select_expression - if (self.check_headers and - not (self.field_dict is not None and self.headers)): - raise AirflowException("To check_headers provide " + - "field_dict and headers") + if self.check_headers and not (self.field_dict is not None and self.headers): + raise AirflowException("To check_headers provide " + "field_dict and headers") def execute(self, context): # Downloading file from S3 @@ -165,18 +164,13 @@ def execute(self, context): s3_key_object = s3_hook.get_key(self.s3_key) _, file_ext = os.path.splitext(s3_key_object.key) - if (self.select_expression and self.input_compressed and - file_ext.lower() != '.gz'): - raise AirflowException("GZIP is the only compression " + - "format Amazon S3 Select supports") + if self.select_expression and self.input_compressed and file_ext.lower() != '.gz': + raise AirflowException("GZIP is the only compression " + "format Amazon S3 Select supports") - with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\ - NamedTemporaryFile(mode="wb", - dir=tmp_dir, - suffix=file_ext) as f: - self.log.info( - "Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name - ) + with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir, NamedTemporaryFile( + mode="wb", dir=tmp_dir, suffix=file_ext + ) as f: + self.log.info("Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name) if self.select_expression: option = {} if self.headers: @@ -192,7 +186,7 @@ def execute(self, context): bucket_name=s3_key_object.bucket_name, key=s3_key_object.key, expression=self.select_expression, - input_serialization=input_serialization + input_serialization=input_serialization, ) f.write(content.encode("utf-8")) else: @@ -209,14 +203,13 @@ def execute(self, context): partition=self.partition, delimiter=self.delimiter, recreate=self.recreate, - tblproperties=self.tblproperties) + tblproperties=self.tblproperties, + ) else: # Decompressing file if self.input_compressed: self.log.info("Uncompressing file %s", f.name) - fn_uncompressed = uncompress_file(f.name, - file_ext, - tmp_dir) + fn_uncompressed = uncompress_file(f.name, file_ext, tmp_dir) self.log.info("Uncompressed to %s", fn_uncompressed) # uncompressed file available now so deleting # compressed file to save disk space @@ -233,20 +226,19 @@ def execute(self, context): # Deleting top header row self.log.info("Removing header from file %s", fn_uncompressed) - headless_file = ( - self._delete_top_row_and_compress(fn_uncompressed, - file_ext, - tmp_dir)) + headless_file = self._delete_top_row_and_compress(fn_uncompressed, file_ext, tmp_dir) self.log.info("Headless file %s", headless_file) self.log.info("Loading file %s into Hive", 
headless_file) - hive_hook.load_file(headless_file, - self.hive_table, - field_dict=self.field_dict, - create=self.create, - partition=self.partition, - delimiter=self.delimiter, - recreate=self.recreate, - tblproperties=self.tblproperties) + hive_hook.load_file( + headless_file, + self.hive_table, + field_dict=self.field_dict, + create=self.create, + partition=self.partition, + delimiter=self.delimiter, + recreate=self.recreate, + tblproperties=self.tblproperties, + ) def _get_top_row_as_list(self, file_name): with open(file_name, 'rt') as file: @@ -263,22 +255,19 @@ def _match_headers(self, header_list): "Headers count mismatch File headers:\n %s\nField names: \n %s\n", header_list, field_names ) return False - test_field_match = [h1.lower() == h2.lower() - for h1, h2 in zip(header_list, field_names)] + test_field_match = [h1.lower() == h2.lower() for h1, h2 in zip(header_list, field_names)] if not all(test_field_match): self.log.warning( "Headers do not match field names File headers:\n %s\nField names: \n %s\n", - header_list, field_names + header_list, + field_names, ) return False else: return True @staticmethod - def _delete_top_row_and_compress( - input_file_name, - output_file_ext, - dest_dir): + def _delete_top_row_and_compress(input_file_name, output_file_ext, dest_dir): # When output_file_ext is not defined, file is not compressed open_fn = open if output_file_ext.lower() == '.gz': diff --git a/airflow/providers/apache/hive/transfers/vertica_to_hive.py b/airflow/providers/apache/hive/transfers/vertica_to_hive.py index 02a4f80ab2c66..66c9790b0588c 100644 --- a/airflow/providers/apache/hive/transfers/vertica_to_hive.py +++ b/airflow/providers/apache/hive/transfers/vertica_to_hive.py @@ -73,17 +73,18 @@ class VerticaToHiveOperator(BaseOperator): @apply_defaults def __init__( - self, - *, - sql, - hive_table, - create=True, - recreate=False, - partition=None, - delimiter=chr(1), - vertica_conn_id='vertica_default', - hive_cli_conn_id='hive_cli_default', - **kwargs): + self, + *, + sql, + hive_table, + create=True, + recreate=False, + partition=None, + delimiter=chr(1), + vertica_conn_id='vertica_default', + hive_cli_conn_id='hive_cli_default', + **kwargs, + ): super().__init__(**kwargs) self.sql = sql self.hive_table = hive_table @@ -127,8 +128,7 @@ def execute(self, context): for field in cursor.description: col_count += 1 col_position = "Column{position}".format(position=col_count) - field_dict[col_position if field[0] == '' else field[0]] = \ - self.type_map(field[1]) + field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1]) csv_writer.writerows(cursor.iterate()) f.flush() cursor.close() @@ -141,4 +141,5 @@ def execute(self, context): create=self.create, partition=self.partition, delimiter=self.delimiter, - recreate=self.recreate) + recreate=self.recreate, + ) diff --git a/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py b/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py index 8ff685abd4961..f7901d4de132e 100644 --- a/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py +++ b/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py @@ -34,7 +34,7 @@ default_args=args, schedule_interval=None, start_date=days_ago(1), - tags=['example'] + tags=['example'], ) @@ -49,11 +49,7 @@ def gen_build_time(**kwargs): ti.xcom_push(key='date_end', value='1325433600000') -gen_build_time_task = PythonOperator( - python_callable=gen_build_time, - task_id='gen_build_time', - dag=dag -) +gen_build_time_task = 
PythonOperator(python_callable=gen_build_time, task_id='gen_build_time', dag=dag) build_task1 = KylinCubeOperator( task_id="kylin_build_1", diff --git a/airflow/providers/apache/kylin/hooks/kylin.py b/airflow/providers/apache/kylin/hooks/kylin.py index 8a880e3e276be..991815d6fc7bc 100644 --- a/airflow/providers/apache/kylin/hooks/kylin.py +++ b/airflow/providers/apache/kylin/hooks/kylin.py @@ -33,11 +33,13 @@ class KylinHook(BaseHook): :param dsn: dsn :type dsn: Optional[str] """ - def __init__(self, - kylin_conn_id: Optional[str] = 'kylin_default', - project: Optional[str] = None, - dsn: Optional[str] = None - ): + + def __init__( + self, + kylin_conn_id: Optional[str] = 'kylin_default', + project: Optional[str] = None, + dsn: Optional[str] = None, + ): super().__init__() self.kylin_conn_id = kylin_conn_id self.project = project @@ -49,9 +51,14 @@ def get_conn(self): return kylinpy.create_kylin(self.dsn) else: self.project = self.project if self.project else conn.schema - return kylinpy.Kylin(conn.host, username=conn.login, - password=conn.password, port=conn.port, - project=self.project, **conn.extra_dejson) + return kylinpy.Kylin( + conn.host, + username=conn.login, + password=conn.password, + port=conn.port, + project=self.project, + **conn.extra_dejson, + ) def cube_run(self, datasource_name, op, **op_args): """ diff --git a/airflow/providers/apache/kylin/operators/kylin_cube.py b/airflow/providers/apache/kylin/operators/kylin_cube.py index 5a8cdbc1199a1..a7326897313e4 100644 --- a/airflow/providers/apache/kylin/operators/kylin_cube.py +++ b/airflow/providers/apache/kylin/operators/kylin_cube.py @@ -87,31 +87,50 @@ class KylinCubeOperator(BaseOperator): :type eager_error_status: tuple """ - template_fields = ('project', 'cube', 'dsn', 'command', 'start_time', 'end_time', - 'segment_name', 'offset_start', 'offset_end') + template_fields = ( + 'project', + 'cube', + 'dsn', + 'command', + 'start_time', + 'end_time', + 'segment_name', + 'offset_start', + 'offset_end', + ) ui_color = '#E79C46' - build_command = {'fullbuild', 'build', 'merge', 'refresh', 'build_streaming', - 'merge_streaming', 'refresh_streaming'} + build_command = { + 'fullbuild', + 'build', + 'merge', + 'refresh', + 'build_streaming', + 'merge_streaming', + 'refresh_streaming', + } jobs_end_status = {"FINISHED", "ERROR", "DISCARDED", "KILLED", "SUICIDAL", "STOPPED"} # pylint: disable=too-many-arguments,inconsistent-return-statements @apply_defaults - def __init__(self, *, - kylin_conn_id: Optional[str] = 'kylin_default', - project: Optional[str] = None, - cube: Optional[str] = None, - dsn: Optional[str] = None, - command: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - offset_start: Optional[str] = None, - offset_end: Optional[str] = None, - segment_name: Optional[str] = None, - is_track_job: Optional[bool] = False, - interval: int = 60, - timeout: int = 60 * 60 * 24, - eager_error_status=("ERROR", "DISCARDED", "KILLED", "SUICIDAL", "STOPPED"), - **kwargs): + def __init__( + self, + *, + kylin_conn_id: Optional[str] = 'kylin_default', + project: Optional[str] = None, + cube: Optional[str] = None, + dsn: Optional[str] = None, + command: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + offset_start: Optional[str] = None, + offset_end: Optional[str] = None, + segment_name: Optional[str] = None, + is_track_job: Optional[bool] = False, + interval: int = 60, + timeout: int = 60 * 60 * 24, + eager_error_status=("ERROR", "DISCARDED", 
"KILLED", "SUICIDAL", "STOPPED"), + **kwargs, + ): super().__init__(**kwargs) self.kylin_conn_id = kylin_conn_id self.project = project @@ -135,15 +154,18 @@ def execute(self, context): _support_invoke_command = kylinpy.CubeSource.support_invoke_command if self.command.lower() not in _support_invoke_command: - raise AirflowException('Kylin:Command {} can not match kylin command list {}'.format( - self.command, _support_invoke_command)) + raise AirflowException( + 'Kylin:Command {} can not match kylin command list {}'.format( + self.command, _support_invoke_command + ) + ) kylinpy_params = { 'start': datetime.fromtimestamp(int(self.start_time) / 1000) if self.start_time else None, 'end': datetime.fromtimestamp(int(self.end_time) / 1000) if self.end_time else None, 'name': self.segment_name, 'offset_start': int(self.offset_start) if self.offset_start else None, - 'offset_end': int(self.offset_end) if self.offset_end else None + 'offset_end': int(self.offset_end) if self.offset_end else None, } rsp_data = _hook.cube_run(self.cube, self.command.lower(), **kylinpy_params) if self.is_track_job and self.command.lower() in self.build_command: @@ -162,8 +184,7 @@ def execute(self, context): job_status = _hook.get_job_status(job_id) self.log.info('Kylin job status is %s ', job_status) if job_status in self.jobs_error_status: - raise AirflowException( - 'Kylin job {} status {} is error '.format(job_id, job_status)) + raise AirflowException('Kylin job {} status {} is error '.format(job_id, job_status)) if self.do_xcom_push: return rsp_data diff --git a/airflow/providers/apache/livy/example_dags/example_livy.py b/airflow/providers/apache/livy/example_dags/example_livy.py index 9e561c9b26606..e8245e2e330ee 100644 --- a/airflow/providers/apache/livy/example_dags/example_livy.py +++ b/airflow/providers/apache/livy/example_dags/example_livy.py @@ -25,17 +25,10 @@ from airflow.providers.apache.livy.operators.livy import LivyOperator from airflow.utils.dates import days_ago -args = { - 'owner': 'airflow', - 'email': ['airflow@example.com'], - 'depends_on_past': False -} +args = {'owner': 'airflow', 'email': ['airflow@example.com'], 'depends_on_past': False} with DAG( - dag_id='example_livy_operator', - default_args=args, - schedule_interval='@daily', - start_date=days_ago(5), + dag_id='example_livy_operator', default_args=args, schedule_interval='@daily', start_date=days_ago(5), ) as dag: livy_java_task = LivyOperator( @@ -45,9 +38,7 @@ file='/spark-examples.jar', args=[10], num_executors=1, - conf={ - 'spark.shuffle.compress': 'false', - }, + conf={'spark.shuffle.compress': 'false',}, class_name='org.apache.spark.examples.SparkPi', ) diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py index f0030872bddfd..1d614c07ec247 100644 --- a/airflow/providers/apache/livy/hooks/livy.py +++ b/airflow/providers/apache/livy/hooks/livy.py @@ -34,6 +34,7 @@ class BatchState(Enum): """ Batch session states """ + NOT_STARTED = 'not_started' STARTING = 'starting' RUNNING = 'running' @@ -65,10 +66,7 @@ class LivyHook(HttpHook, LoggingMixin): BatchState.ERROR, } - _def_headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json' - } + _def_headers = {'Content-Type': 'application/json', 'Accept': 'application/json'} def __init__(self, livy_conn_id: str = 'livy_default') -> None: super(LivyHook, self).__init__(http_conn_id=livy_conn_id) @@ -93,7 +91,7 @@ def run_method( method: str = 'GET', data: Optional[Any] = None, headers: Optional[Dict[str, Any]] = None, 
- extra_options: Optional[Dict[Any, Any]] = None + extra_options: Optional[Dict[Any, Any]] = None, ) -> Any: """ Wrapper for HttpHook, allows to change method on the same HttpHook @@ -138,20 +136,17 @@ def post_batch(self, *args: Any, **kwargs: Any) -> Any: self.get_conn() self.log.info("Submitting job %s to %s", batch_submit_body, self.base_url) - response = self.run_method( - method='POST', - endpoint='/batches', - data=batch_submit_body - ) + response = self.run_method(method='POST', endpoint='/batches', data=batch_submit_body) self.log.debug("Got response: %s", response.text) try: response.raise_for_status() except requests.exceptions.HTTPError as err: - raise AirflowException("Could not submit batch. Status code: {}. Message: '{}'".format( - err.response.status_code, - err.response.text - )) + raise AirflowException( + "Could not submit batch. Status code: {}. Message: '{}'".format( + err.response.status_code, err.response.text + ) + ) batch_id = self._parse_post_response(response.json()) if batch_id is None: @@ -178,10 +173,9 @@ def get_batch(self, session_id: Union[int, str]) -> Any: response.raise_for_status() except requests.exceptions.HTTPError as err: self.log.warning("Got status code %d for session %d", err.response.status_code, session_id) - raise AirflowException("Unable to fetch batch with id: {}. Message: {}".format( - session_id, - err.response.text - )) + raise AirflowException( + "Unable to fetch batch with id: {}. Message: {}".format(session_id, err.response.text) + ) return response.json() @@ -203,10 +197,9 @@ def get_batch_state(self, session_id: Union[int, str]) -> BatchState: response.raise_for_status() except requests.exceptions.HTTPError as err: self.log.warning("Got status code %d for session %d", err.response.status_code, session_id) - raise AirflowException("Unable to fetch batch with id: {}. Message: {}".format( - session_id, - err.response.text - )) + raise AirflowException( + "Unable to fetch batch with id: {}. Message: {}".format(session_id, err.response.text) + ) jresp = response.json() if 'state' not in jresp: @@ -225,19 +218,17 @@ def delete_batch(self, session_id: Union[int, str]) -> Any: self._validate_session_id(session_id) self.log.info("Deleting batch session %d", session_id) - response = self.run_method( - method='DELETE', - endpoint='/batches/{}'.format(session_id) - ) + response = self.run_method(method='DELETE', endpoint='/batches/{}'.format(session_id)) try: response.raise_for_status() except requests.exceptions.HTTPError as err: self.log.warning("Got status code %d for session %d", err.response.status_code, session_id) - raise AirflowException("Could not kill the batch with session id: {}. Message: {}".format( - session_id, - err.response.text - )) + raise AirflowException( + "Could not kill the batch with session id: {}. Message: {}".format( + session_id, err.response.text + ) + ) return response.json() @@ -283,7 +274,7 @@ def build_post_batch_body( num_executors: Optional[Union[int, str]] = None, queue: Optional[str] = None, proxy_user: Optional[str] = None, - conf: Optional[Dict[Any, Any]] = None + conf: Optional[Dict[Any, Any]] = None, ) -> Any: """ Build the post batch request body. 
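# A minimal sketch of submitting and polling a Livy batch through LivyHook,
# assuming a 'livy_default' connection that points at a running Livy server;
# the jar path and class name are placeholders.
from airflow.providers.apache.livy.hooks.livy import LivyHook

livy_hook = LivyHook(livy_conn_id="livy_default")
batch_id = livy_hook.post_batch(
    file="/jars/spark-examples.jar",
    class_name="org.apache.spark.examples.SparkPi",
)
print(batch_id, livy_hook.get_batch_state(batch_id))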
@@ -386,9 +377,11 @@ def _validate_list_of_stringables(vals: Sequence[Union[str, int, float]]) -> boo :return: true if valid :rtype: bool """ - if vals is None or \ - not isinstance(vals, (tuple, list)) or \ - any(1 for val in vals if not isinstance(val, (str, int, float))): + if ( + vals is None + or not isinstance(vals, (tuple, list)) + or any(1 for val in vals if not isinstance(val, (str, int, float))) + ): raise ValueError("List of strings expected") return True diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py index 16be339f3f007..cbaaec2cffbd9 100644 --- a/airflow/providers/apache/livy/operators/livy.py +++ b/airflow/providers/apache/livy/operators/livy.py @@ -74,7 +74,8 @@ class LivyOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, file: str, class_name: Optional[str] = None, args: Optional[Sequence[Union[str, int, float]]] = None, @@ -93,7 +94,7 @@ def __init__( proxy_user: Optional[str] = None, livy_conn_id: str = 'livy_default', polling_interval: int = 0, - **kwargs: Any + **kwargs: Any, ) -> None: # pylint: disable-msg=too-many-arguments @@ -115,7 +116,7 @@ def __init__( 'queue': queue, 'name': name, 'conf': conf, - 'proxy_user': proxy_user + 'proxy_user': proxy_user, } self._livy_conn_id = livy_conn_id diff --git a/airflow/providers/apache/livy/sensors/livy.py b/airflow/providers/apache/livy/sensors/livy.py index b9d0bc429de0b..ba29b7f940344 100644 --- a/airflow/providers/apache/livy/sensors/livy.py +++ b/airflow/providers/apache/livy/sensors/livy.py @@ -39,10 +39,7 @@ class LivySensor(BaseSensorOperator): @apply_defaults def __init__( - self, *, - batch_id: Union[int, str], - livy_conn_id: str = 'livy_default', - **kwargs: Any + self, *, batch_id: Union[int, str], livy_conn_id: str = 'livy_default', **kwargs: Any ) -> None: super().__init__(**kwargs) self._livy_conn_id = livy_conn_id diff --git a/airflow/providers/apache/pig/example_dags/example_pig.py b/airflow/providers/apache/pig/example_dags/example_pig.py index 8917f86cfe01f..368135a3fdd5e 100644 --- a/airflow/providers/apache/pig/example_dags/example_pig.py +++ b/airflow/providers/apache/pig/example_dags/example_pig.py @@ -31,12 +31,7 @@ default_args=args, schedule_interval=None, start_date=days_ago(2), - tags=['example'] + tags=['example'], ) -run_this = PigOperator( - task_id="run_example_pig_script", - pig="ls /;", - pig_opts="-x local", - dag=dag, -) +run_this = PigOperator(task_id="run_example_pig_script", pig="ls /;", pig_opts="-x local", dag=dag,) diff --git a/airflow/providers/apache/pig/hooks/pig.py b/airflow/providers/apache/pig/hooks/pig.py index 8baee6c737174..4152dd2c99f99 100644 --- a/airflow/providers/apache/pig/hooks/pig.py +++ b/airflow/providers/apache/pig/hooks/pig.py @@ -33,17 +33,14 @@ class PigCliHook(BaseHook): """ - def __init__( - self, - pig_cli_conn_id: str = "pig_cli_default") -> None: + def __init__(self, pig_cli_conn_id: str = "pig_cli_default") -> None: super().__init__() conn = self.get_connection(pig_cli_conn_id) self.pig_properties = conn.extra_dejson.get('pig_properties', '') self.conn = conn self.sub_process = None - def run_cli(self, pig: str, pig_opts: Optional[str] = None, - verbose: bool = True) -> Any: + def run_cli(self, pig: str, pig_opts: Optional[str] = None, verbose: bool = True) -> Any: """ Run an pig script using the pig cli @@ -75,11 +72,8 @@ def run_cli(self, pig: str, pig_opts: Optional[str] = None, if verbose: self.log.info("%s", " ".join(pig_cmd)) sub_process: Any = 
subprocess.Popen( - pig_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - cwd=tmp_dir, - close_fds=True) + pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True + ) self.sub_process = sub_process stdout = '' for line in iter(sub_process.stdout.readline, b''): diff --git a/airflow/providers/apache/pig/operators/pig.py b/airflow/providers/apache/pig/operators/pig.py index 3f3c57897a18b..6d0f74e4fcfef 100644 --- a/airflow/providers/apache/pig/operators/pig.py +++ b/airflow/providers/apache/pig/operators/pig.py @@ -42,17 +42,22 @@ class PigOperator(BaseOperator): """ template_fields = ('pig',) - template_ext = ('.pig', '.piglatin',) + template_ext = ( + '.pig', + '.piglatin', + ) ui_color = '#f0e4ec' @apply_defaults def __init__( - self, *, - pig: str, - pig_cli_conn_id: str = 'pig_cli_default', - pigparams_jinja_translate: bool = False, - pig_opts: Optional[str] = None, - **kwargs: Any) -> None: + self, + *, + pig: str, + pig_cli_conn_id: str = 'pig_cli_default', + pigparams_jinja_translate: bool = False, + pig_opts: Optional[str] = None, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.pigparams_jinja_translate = pigparams_jinja_translate @@ -63,8 +68,7 @@ def __init__( def prepare_template(self): if self.pigparams_jinja_translate: - self.pig = re.sub( - r"(\$([a-zA-Z_][a-zA-Z0-9_]*))", r"{{ \g<2> }}", self.pig) + self.pig = re.sub(r"(\$([a-zA-Z_][a-zA-Z0-9_]*))", r"{{ \g<2> }}", self.pig) def execute(self, context): self.log.info('Executing: %s', self.pig) diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py index 24369455f53a5..248c05895aca7 100644 --- a/airflow/providers/apache/pinot/hooks/pinot.py +++ b/airflow/providers/apache/pinot/hooks/pinot.py @@ -54,26 +54,26 @@ class PinotAdminHook(BaseHook): :type pinot_admin_system_exit: bool """ - def __init__(self, - conn_id: str = "pinot_admin_default", - cmd_path: str = "pinot-admin.sh", - pinot_admin_system_exit: bool = False - ) -> None: + def __init__( + self, + conn_id: str = "pinot_admin_default", + cmd_path: str = "pinot-admin.sh", + pinot_admin_system_exit: bool = False, + ) -> None: super().__init__() conn = self.get_connection(conn_id) self.host = conn.host self.port = str(conn.port) self.cmd_path = conn.extra_dejson.get("cmd_path", cmd_path) - self.pinot_admin_system_exit = conn.extra_dejson.get("pinot_admin_system_exit", - pinot_admin_system_exit) + self.pinot_admin_system_exit = conn.extra_dejson.get( + "pinot_admin_system_exit", pinot_admin_system_exit + ) self.conn = conn def get_conn(self) -> Any: return self.conn - def add_schema(self, schema_file: str, - with_exec: Optional[bool] = True - ) -> Any: + def add_schema(self, schema_file: str, with_exec: Optional[bool] = True) -> Any: """ Add Pinot schema by run AddSchema command @@ -90,9 +90,7 @@ def add_schema(self, schema_file: str, cmd += ["-exec"] self.run_cli(cmd) - def add_table(self, file_path: str, - with_exec: Optional[bool] = True - ) -> Any: + def add_table(self, file_path: str, with_exec: Optional[bool] = True) -> Any: """ Add Pinot table with run AddTable command @@ -110,26 +108,27 @@ def add_table(self, file_path: str, self.run_cli(cmd) # pylint: disable=too-many-arguments - def create_segment(self, - generator_config_file: Optional[str] = None, - data_dir: Optional[str] = None, - segment_format: Optional[str] = None, - out_dir: Optional[str] = None, - overwrite: Optional[str] = None, - table_name: Optional[str] = None, - segment_name: Optional[str] = 
None, - time_column_name: Optional[str] = None, - schema_file: Optional[str] = None, - reader_config_file: Optional[str] = None, - enable_star_tree_index: Optional[str] = None, - star_tree_index_spec_file: Optional[str] = None, - hll_size: Optional[str] = None, - hll_columns: Optional[str] = None, - hll_suffix: Optional[str] = None, - num_threads: Optional[str] = None, - post_creation_verification: Optional[str] = None, - retry: Optional[str] = None - ) -> Any: + def create_segment( + self, + generator_config_file: Optional[str] = None, + data_dir: Optional[str] = None, + segment_format: Optional[str] = None, + out_dir: Optional[str] = None, + overwrite: Optional[str] = None, + table_name: Optional[str] = None, + segment_name: Optional[str] = None, + time_column_name: Optional[str] = None, + schema_file: Optional[str] = None, + reader_config_file: Optional[str] = None, + enable_star_tree_index: Optional[str] = None, + star_tree_index_spec_file: Optional[str] = None, + hll_size: Optional[str] = None, + hll_columns: Optional[str] = None, + hll_suffix: Optional[str] = None, + num_threads: Optional[str] = None, + post_creation_verification: Optional[str] = None, + retry: Optional[str] = None, + ) -> Any: """ Create Pinot segment by run CreateSegment command """ @@ -191,8 +190,7 @@ def create_segment(self, self.run_cli(cmd) - def upload_segment(self, segment_dir: str, table_name: Optional[str] = None - ) -> Any: + def upload_segment(self, segment_dir: str, table_name: Optional[str] = None) -> Any: """ Upload Segment with run UploadSegment command @@ -230,11 +228,8 @@ def run_cli(self, cmd: List[str], verbose: Optional[bool] = True) -> str: self.log.info(" ".join(command)) sub_process = subprocess.Popen( - command, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - close_fds=True, - env=env) + command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True, env=env + ) stdout = "" if sub_process.stdout: @@ -248,8 +243,9 @@ def run_cli(self, cmd: List[str], verbose: Optional[bool] = True) -> str: # As of Pinot v0.1.0, either of "Error: ..." or "Exception caught: ..." # is expected to be in the output messages. 
See: # https://github.com/apache/incubator-pinot/blob/release-0.1.0/pinot-tools/src/main/java/org/apache/pinot/tools/admin/PinotAdministrator.java#L98-L101 - if ((self.pinot_admin_system_exit and sub_process.returncode) or - ("Error" in stdout or "Exception" in stdout)): + if (self.pinot_admin_system_exit and sub_process.returncode) or ( + "Error" in stdout or "Exception" in stdout + ): raise AirflowException(stdout) return stdout @@ -259,6 +255,7 @@ class PinotDbApiHook(DbApiHook): """ Connect to pinot db (https://github.com/apache/incubator-pinot) to issue pql """ + conn_name_attr = 'pinot_broker_conn_id' default_conn_name = 'pinot_broker_default' supports_autocommit = False @@ -274,10 +271,9 @@ def get_conn(self) -> Any: host=conn.host, port=conn.port, path=conn.extra_dejson.get('endpoint', '/pql'), - scheme=conn.extra_dejson.get('schema', 'http') + scheme=conn.extra_dejson.get('schema', 'http'), ) - self.log.info('Get the connection to pinot ' - 'broker on %s', conn.host) + self.log.info('Get the connection to pinot ' 'broker on %s', conn.host) return pinot_broker_conn def get_uri(self) -> str: @@ -292,12 +288,9 @@ def get_uri(self) -> str: host += ':{port}'.format(port=conn.port) conn_type = 'http' if not conn.conn_type else conn.conn_type endpoint = conn.extra_dejson.get('endpoint', 'pql') - return '{conn_type}://{host}/{endpoint}'.format( - conn_type=conn_type, host=host, endpoint=endpoint) + return '{conn_type}://{host}/{endpoint}'.format(conn_type=conn_type, host=host, endpoint=endpoint) - def get_records(self, sql: str, - parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None - ) -> Any: + def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any: """ Executes the sql and returns a set of records. @@ -311,9 +304,7 @@ def get_records(self, sql: str, cur.execute(sql) return cur.fetchall() - def get_first(self, sql: str, - parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None - ) -> Any: + def get_first(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any: """ Executes the sql and returns the first resulting row. 
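# A minimal sketch of querying Pinot through PinotDbApiHook, assuming a
# 'pinot_broker_default' connection and a placeholder table name.
from airflow.providers.apache.pinot.hooks.pinot import PinotDbApiHook

pinot_hook = PinotDbApiHook()
first_row = pinot_hook.get_first("SELECT COUNT(*) FROM myTable")
sample_rows = pinot_hook.get_records("SELECT * FROM myTable LIMIT 10")
print(first_row, len(sample_rows))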
@@ -330,9 +321,13 @@ def get_first(self, sql: str, def set_autocommit(self, conn: Connection, autocommit: Any) -> Any: raise NotImplementedError() - def insert_rows(self, table: str, rows: str, - target_fields: Optional[str] = None, - commit_every: int = 1000, - replace: bool = False, - **kwargs: Any) -> Any: + def insert_rows( + self, + table: str, + rows: str, + target_fields: Optional[str] = None, + commit_every: int = 1000, + replace: bool = False, + **kwargs: Any, + ) -> Any: raise NotImplementedError() diff --git a/airflow/providers/apache/spark/example_dags/example_spark_dag.py b/airflow/providers/apache/spark/example_dags/example_spark_dag.py index 5d279e04e2652..982a7732aaf96 100644 --- a/airflow/providers/apache/spark/example_dags/example_spark_dag.py +++ b/airflow/providers/apache/spark/example_dags/example_spark_dag.py @@ -35,12 +35,11 @@ default_args=args, schedule_interval=None, start_date=days_ago(2), - tags=['example'] + tags=['example'], ) as dag: # [START howto_operator_spark_submit] submit_job = SparkSubmitOperator( - application="${SPARK_HOME}/examples/src/main/python/pi.py", - task_id="submit_job" + application="${SPARK_HOME}/examples/src/main/python/pi.py", task_id="submit_job" ) # [END howto_operator_spark_submit] @@ -53,7 +52,7 @@ metastore_table="bar", save_mode="overwrite", save_format="JSON", - task_id="jdbc_to_spark_job" + task_id="jdbc_to_spark_job", ) spark_to_jdbc_job = SparkJDBCOperator( @@ -63,14 +62,10 @@ jdbc_driver="org.postgresql.Driver", metastore_table="bar", save_mode="append", - task_id="spark_to_jdbc_job" + task_id="spark_to_jdbc_job", ) # [END howto_operator_spark_jdbc] # [START howto_operator_spark_sql] - sql_job = SparkSqlOperator( - sql="SELECT * FROM bar", - master="local", - task_id="sql_job" - ) + sql_job = SparkSqlOperator(sql="SELECT * FROM bar", master="local", task_id="sql_job") # [END howto_operator_spark_sql] diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc.py b/airflow/providers/apache/spark/hooks/spark_jdbc.py index 8ec3f4996c169..2100b6432136c 100644 --- a/airflow/providers/apache/spark/hooks/spark_jdbc.py +++ b/airflow/providers/apache/spark/hooks/spark_jdbc.py @@ -113,38 +113,39 @@ class SparkJDBCHook(SparkSubmitHook): """ # pylint: disable=too-many-arguments,too-many-locals - def __init__(self, - spark_app_name: str = 'airflow-spark-jdbc', - spark_conn_id: str = 'spark-default', - spark_conf: Optional[Dict[str, Any]] = None, - spark_py_files: Optional[str] = None, - spark_files: Optional[str] = None, - spark_jars: Optional[str] = None, - num_executors: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - driver_memory: Optional[str] = None, - verbose: bool = False, - principal: Optional[str] = None, - keytab: Optional[str] = None, - cmd_type: str = 'spark_to_jdbc', - jdbc_table: Optional[str] = None, - jdbc_conn_id: str = 'jdbc-default', - jdbc_driver: Optional[str] = None, - metastore_table: Optional[str] = None, - jdbc_truncate: bool = False, - save_mode: Optional[str] = None, - save_format: Optional[str] = None, - batch_size: Optional[int] = None, - fetch_size: Optional[int] = None, - num_partitions: Optional[int] = None, - partition_column: Optional[str] = None, - lower_bound: Optional[str] = None, - upper_bound: Optional[str] = None, - create_table_column_types: Optional[str] = None, - *args: Any, - **kwargs: Any - ): + def __init__( + self, + spark_app_name: str = 'airflow-spark-jdbc', + spark_conn_id: str = 'spark-default', + spark_conf: Optional[Dict[str, 
Any]] = None, + spark_py_files: Optional[str] = None, + spark_files: Optional[str] = None, + spark_jars: Optional[str] = None, + num_executors: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + verbose: bool = False, + principal: Optional[str] = None, + keytab: Optional[str] = None, + cmd_type: str = 'spark_to_jdbc', + jdbc_table: Optional[str] = None, + jdbc_conn_id: str = 'jdbc-default', + jdbc_driver: Optional[str] = None, + metastore_table: Optional[str] = None, + jdbc_truncate: bool = False, + save_mode: Optional[str] = None, + save_format: Optional[str] = None, + batch_size: Optional[int] = None, + fetch_size: Optional[int] = None, + num_partitions: Optional[int] = None, + partition_column: Optional[str] = None, + lower_bound: Optional[str] = None, + upper_bound: Optional[str] = None, + create_table_column_types: Optional[str] = None, + *args: Any, + **kwargs: Any, + ): super().__init__(*args, **kwargs) self._name = spark_app_name self._conn_id = spark_conn_id @@ -177,12 +178,7 @@ def __init__(self, self._jdbc_connection = self._resolve_jdbc_connection() def _resolve_jdbc_connection(self) -> Dict[str, Any]: - conn_data = {'url': '', - 'schema': '', - 'conn_prefix': '', - 'user': '', - 'password': '' - } + conn_data = {'url': '', 'schema': '', 'conn_prefix': '', 'user': '', 'password': ''} try: conn = self.get_connection(self._jdbc_conn_id) if conn.port: @@ -196,8 +192,7 @@ def _resolve_jdbc_connection(self) -> Dict[str, Any]: conn_data['conn_prefix'] = extra.get('conn_prefix', '') except AirflowException: self.log.debug( - "Could not load jdbc connection string %s, defaulting to %s", - self._jdbc_conn_id, "" + "Could not load jdbc connection string %s, defaulting to %s", self._jdbc_conn_id, "" ) return conn_data @@ -205,9 +200,10 @@ def _build_jdbc_application_arguments(self, jdbc_conn: Dict[str, Any]) -> Any: arguments = [] arguments += ["-cmdType", self._cmd_type] if self._jdbc_connection['url']: - arguments += ['-url', "{0}{1}/{2}".format( - jdbc_conn['conn_prefix'], jdbc_conn['url'], jdbc_conn['schema'] - )] + arguments += [ + '-url', + "{0}{1}/{2}".format(jdbc_conn['conn_prefix'], jdbc_conn['url'], jdbc_conn['schema']), + ] if self._jdbc_connection['user']: arguments += ['-user', self._jdbc_connection['user']] if self._jdbc_connection['password']: @@ -226,12 +222,16 @@ def _build_jdbc_application_arguments(self, jdbc_conn: Dict[str, Any]) -> Any: arguments += ['-fetchsize', str(self._fetch_size)] if self._num_partitions: arguments += ['-numPartitions', str(self._num_partitions)] - if (self._partition_column and self._lower_bound and - self._upper_bound and self._num_partitions): + if self._partition_column and self._lower_bound and self._upper_bound and self._num_partitions: # these 3 parameters need to be used all together to take effect. 
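# A minimal sketch of a JDBC-to-Spark transfer with SparkJDBCHook, assuming
# 'spark-default' and 'jdbc-default' connections exist; note that
# partition_column, lower_bound, upper_bound and num_partitions only take
# effect when all of them are supplied together.
from airflow.providers.apache.spark.hooks.spark_jdbc import SparkJDBCHook

jdbc_hook = SparkJDBCHook(
    cmd_type="jdbc_to_spark",
    jdbc_table="public.bar",
    jdbc_driver="org.postgresql.Driver",
    metastore_table="bar",
    num_partitions=4,
    partition_column="id",
    lower_bound="0",
    upper_bound="1000",
    save_mode="overwrite",
    save_format="JSON",
)
jdbc_hook.submit_jdbc_job()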
- arguments += ['-partitionColumn', self._partition_column, - '-lowerBound', self._lower_bound, - '-upperBound', self._upper_bound] + arguments += [ + '-partitionColumn', + self._partition_column, + '-lowerBound', + self._lower_bound, + '-upperBound', + self._upper_bound, + ] if self._save_mode: arguments += ['-saveMode', self._save_mode] if self._save_format: @@ -244,10 +244,8 @@ def submit_jdbc_job(self) -> None: """ Submit Spark JDBC job """ - self._application_args = \ - self._build_jdbc_application_arguments(self._jdbc_connection) - self.submit(application=os.path.dirname(os.path.abspath(__file__)) + - "/spark_jdbc_script.py") + self._application_args = self._build_jdbc_application_arguments(self._jdbc_connection) + self.submit(application=os.path.dirname(os.path.abspath(__file__)) + "/spark_jdbc_script.py") def get_conn(self) -> Any: pass diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py index 3a9f56a24a0fa..ffc9a3e2e0811 100644 --- a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py +++ b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py @@ -25,12 +25,14 @@ SPARK_READ_FROM_JDBC: str = "jdbc_to_spark" -def set_common_options(spark_source: Any, - url: str = 'localhost:5432', - jdbc_table: str = 'default.default', - user: str = 'root', - password: str = 'root', - driver: str = 'driver') -> Any: +def set_common_options( + spark_source: Any, + url: str = 'localhost:5432', + jdbc_table: str = 'default.default', + user: str = 'root', + password: str = 'root', + driver: str = 'driver', +) -> Any: """ Get Spark source from JDBC connection @@ -42,36 +44,36 @@ def set_common_options(spark_source: Any, :param driver: JDBC resource driver """ - spark_source = spark_source \ - .format('jdbc') \ - .option('url', url) \ - .option('dbtable', jdbc_table) \ - .option('user', user) \ - .option('password', password) \ + spark_source = ( + spark_source.format('jdbc') + .option('url', url) + .option('dbtable', jdbc_table) + .option('user', user) + .option('password', password) .option('driver', driver) + ) return spark_source # pylint: disable=too-many-arguments -def spark_write_to_jdbc(spark_session: SparkSession, - url: str, - user: str, - password: str, - metastore_table: str, - jdbc_table: str, - driver: Any, - truncate: bool, - save_mode: str, - batch_size: int, - num_partitions: int, - create_table_column_types: str) -> None: +def spark_write_to_jdbc( + spark_session: SparkSession, + url: str, + user: str, + password: str, + metastore_table: str, + jdbc_table: str, + driver: Any, + truncate: bool, + save_mode: str, + batch_size: int, + num_partitions: int, + create_table_column_types: str, +) -> None: """ Transfer data from Spark to JDBC source """ - writer = spark_session \ - .table(metastore_table) \ - .write \ - + writer = spark_session.table(metastore_table).write # first set common options writer = set_common_options(writer, url, jdbc_table, user, password, driver) @@ -85,26 +87,26 @@ def spark_write_to_jdbc(spark_session: SparkSession, if create_table_column_types: writer = writer.option("createTableColumnTypes", create_table_column_types) - writer \ - .save(mode=save_mode) + writer.save(mode=save_mode) # pylint: disable=too-many-arguments -def spark_read_from_jdbc(spark_session: SparkSession, - url: str, - user: str, - password: str, - metastore_table: str, - jdbc_table: str, - driver: Any, - save_mode: str, - save_format: str, - fetch_size: int, - num_partitions: int, - partition_column: str, - 
lower_bound: str, - upper_bound: str - ) -> None: +def spark_read_from_jdbc( + spark_session: SparkSession, + url: str, + user: str, + password: str, + metastore_table: str, + jdbc_table: str, + driver: Any, + save_mode: str, + save_format: str, + fetch_size: int, + num_partitions: int, + partition_column: str, + lower_bound: str, + upper_bound: str, +) -> None: """ Transfer data from JDBC source to Spark """ @@ -118,15 +120,13 @@ def spark_read_from_jdbc(spark_session: SparkSession, if num_partitions: reader = reader.option('numPartitions', num_partitions) if partition_column and lower_bound and upper_bound: - reader = reader \ - .option('partitionColumn', partition_column) \ - .option('lowerBound', lower_bound) \ + reader = ( + reader.option('partitionColumn', partition_column) + .option('lowerBound', lower_bound) .option('upperBound', upper_bound) + ) - reader \ - .load() \ - .write \ - .saveAsTable(metastore_table, format=save_format, mode=save_mode) + reader.load().write.saveAsTable(metastore_table, format=save_format, mode=save_mode) def _parse_arguments(args: Optional[List[str]] = None) -> Any: @@ -148,16 +148,12 @@ def _parse_arguments(args: Optional[List[str]] = None) -> Any: parser.add_argument('-partitionColumn', dest='partition_column', action='store') parser.add_argument('-lowerBound', dest='lower_bound', action='store') parser.add_argument('-upperBound', dest='upper_bound', action='store') - parser.add_argument('-createTableColumnTypes', - dest='create_table_column_types', action='store') + parser.add_argument('-createTableColumnTypes', dest='create_table_column_types', action='store') return parser.parse_args(args=args) def _create_spark_session(arguments: Any) -> SparkSession: - return SparkSession.builder \ - .appName(arguments.name) \ - .enableHiveSupport() \ - .getOrCreate() + return SparkSession.builder.appName(arguments.name).enableHiveSupport().getOrCreate() def _run_spark(arguments: Any) -> None: @@ -165,33 +161,37 @@ def _run_spark(arguments: Any) -> None: spark = _create_spark_session(arguments) if arguments.cmd_type == SPARK_WRITE_TO_JDBC: - spark_write_to_jdbc(spark, - arguments.url, - arguments.user, - arguments.password, - arguments.metastore_table, - arguments.jdbc_table, - arguments.jdbc_driver, - arguments.truncate, - arguments.save_mode, - arguments.batch_size, - arguments.num_partitions, - arguments.create_table_column_types) + spark_write_to_jdbc( + spark, + arguments.url, + arguments.user, + arguments.password, + arguments.metastore_table, + arguments.jdbc_table, + arguments.jdbc_driver, + arguments.truncate, + arguments.save_mode, + arguments.batch_size, + arguments.num_partitions, + arguments.create_table_column_types, + ) elif arguments.cmd_type == SPARK_READ_FROM_JDBC: - spark_read_from_jdbc(spark, - arguments.url, - arguments.user, - arguments.password, - arguments.metastore_table, - arguments.jdbc_table, - arguments.jdbc_driver, - arguments.save_mode, - arguments.save_format, - arguments.fetch_size, - arguments.num_partitions, - arguments.partition_column, - arguments.lower_bound, - arguments.upper_bound) + spark_read_from_jdbc( + spark, + arguments.url, + arguments.user, + arguments.password, + arguments.metastore_table, + arguments.jdbc_table, + arguments.jdbc_driver, + arguments.save_mode, + arguments.save_format, + arguments.fetch_size, + arguments.num_partitions, + arguments.partition_column, + arguments.lower_bound, + arguments.upper_bound, + ) if __name__ == "__main__": # pragma: no cover diff --git 
a/airflow/providers/apache/spark/hooks/spark_sql.py b/airflow/providers/apache/spark/hooks/spark_sql.py index c0491dde1d60b..cceb2bc1aeb62 100644 --- a/airflow/providers/apache/spark/hooks/spark_sql.py +++ b/airflow/providers/apache/spark/hooks/spark_sql.py @@ -57,21 +57,22 @@ class SparkSqlHook(BaseHook): """ # pylint: disable=too-many-arguments - def __init__(self, - sql: str, - conf: Optional[str] = None, - conn_id: str = 'spark_sql_default', - total_executor_cores: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - keytab: Optional[str] = None, - principal: Optional[str] = None, - master: str = 'yarn', - name: str = 'default-name', - num_executors: Optional[int] = None, - verbose: bool = True, - yarn_queue: str = 'default' - ) -> None: + def __init__( + self, + sql: str, + conf: Optional[str] = None, + conn_id: str = 'spark_sql_default', + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + master: str = 'yarn', + name: str = 'default-name', + num_executors: Optional[int] = None, + verbose: bool = True, + yarn_queue: str = 'default', + ) -> None: super().__init__() self._sql = sql self._conf = conf @@ -152,10 +153,7 @@ def run_query(self, cmd: str = "", **kwargs: Any) -> None: :type kwargs: dict """ spark_sql_cmd = self._prepare_command(cmd) - self._sp = subprocess.Popen(spark_sql_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - **kwargs) + self._sp = subprocess.Popen(spark_sql_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs) for line in iter(self._sp.stdout): # type: ignore self.log.info(line) diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py index 0319f6d57a205..13983e905ff4c 100644 --- a/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/airflow/providers/apache/spark/hooks/spark_submit.py @@ -105,33 +105,34 @@ class SparkSubmitHook(BaseHook, LoggingMixin): """ # pylint: disable=too-many-arguments,too-many-locals,too-many-branches - def __init__(self, - conf: Optional[Dict[str, Any]] = None, - conn_id: str = 'spark_default', - files: Optional[str] = None, - py_files: Optional[str] = None, - archives: Optional[str] = None, - driver_class_path: Optional[str] = None, - jars: Optional[str] = None, - java_class: Optional[str] = None, - packages: Optional[str] = None, - exclude_packages: Optional[str] = None, - repositories: Optional[str] = None, - total_executor_cores: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - driver_memory: Optional[str] = None, - keytab: Optional[str] = None, - principal: Optional[str] = None, - proxy_user: Optional[str] = None, - name: str = 'default-name', - num_executors: Optional[int] = None, - status_poll_interval: int = 1, - application_args: Optional[List[Any]] = None, - env_vars: Optional[Dict[str, Any]] = None, - verbose: bool = False, - spark_binary: Optional[str] = None - ) -> None: + def __init__( + self, + conf: Optional[Dict[str, Any]] = None, + conn_id: str = 'spark_default', + files: Optional[str] = None, + py_files: Optional[str] = None, + archives: Optional[str] = None, + driver_class_path: Optional[str] = None, + jars: Optional[str] = None, + java_class: Optional[str] = None, + packages: Optional[str] = None, + exclude_packages: Optional[str] = None, + repositories: Optional[str] = None, + 
total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + proxy_user: Optional[str] = None, + name: str = 'default-name', + num_executors: Optional[int] = None, + status_poll_interval: int = 1, + application_args: Optional[List[Any]] = None, + env_vars: Optional[Dict[str, Any]] = None, + verbose: bool = False, + spark_binary: Optional[str] = None, + ) -> None: super().__init__() self._conf = conf or {} self._conn_id = conn_id @@ -168,7 +169,9 @@ def __init__(self, if self._is_kubernetes and kube_client is None: raise RuntimeError( "{} specified by kubernetes dependencies are not installed!".format( - self._connection['master'])) + self._connection['master'] + ) + ) self._should_track_driver_status = self._resolve_should_track_driver_status() self._driver_id: Optional[str] = None @@ -182,17 +185,18 @@ def _resolve_should_track_driver_status(self) -> bool: subsequent spark-submit status requests after the initial spark-submit request :return: if the driver status should be tracked """ - return ('spark://' in self._connection['master'] and - self._connection['deploy_mode'] == 'cluster') + return 'spark://' in self._connection['master'] and self._connection['deploy_mode'] == 'cluster' def _resolve_connection(self) -> Dict[str, Any]: # Build from connection master or default to yarn if not available - conn_data = {'master': 'yarn', - 'queue': None, - 'deploy_mode': None, - 'spark_home': None, - 'spark_binary': self._spark_binary or "spark-submit", - 'namespace': None} + conn_data = { + 'master': 'yarn', + 'queue': None, + 'deploy_mode': None, + 'spark_home': None, + 'spark_binary': self._spark_binary or "spark-submit", + 'namespace': None, + } try: # Master can be local, yarn, spark://HOST:PORT, mesos://HOST:PORT and @@ -208,13 +212,11 @@ def _resolve_connection(self) -> Dict[str, Any]: conn_data['queue'] = extra.get('queue', None) conn_data['deploy_mode'] = extra.get('deploy-mode', None) conn_data['spark_home'] = extra.get('spark-home', None) - conn_data['spark_binary'] = self._spark_binary or \ - extra.get('spark-binary', "spark-submit") + conn_data['spark_binary'] = self._spark_binary or extra.get('spark-binary', "spark-submit") conn_data['namespace'] = extra.get('namespace') except AirflowException: self.log.info( - "Could not load connection string %s, defaulting to %s", - self._conn_id, conn_data['master'] + "Could not load connection string %s, defaulting to %s", self._conn_id, conn_data['master'] ) if 'spark.kubernetes.namespace' in self._conf: @@ -230,8 +232,9 @@ def _get_spark_binary_path(self) -> List[str]: # the spark_home; otherwise assume that spark-submit is present in the path to # the executing user if self._connection['spark_home']: - connection_cmd = [os.path.join(self._connection['spark_home'], 'bin', - self._connection['spark_binary'])] + connection_cmd = [ + os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary']) + ] else: connection_cmd = [self._connection['spark_binary']] @@ -242,18 +245,18 @@ def _mask_cmd(self, connection_cmd: Union[str, List[str]]) -> str: # where key contains password (case insensitive), e.g. HivePassword='abc' connection_cmd_masked = re.sub( r"(" - r"\S*?" # Match all non-whitespace characters before... + r"\S*?" # Match all non-whitespace characters before... 
r"(?:secret|password)" # ...literally a "secret" or "password" - # word (not capturing them). - r"\S*?" # All non-whitespace characters before either... - r"(?:=|\s+)" # ...an equal sign or whitespace characters - # (not capturing them). - r"(['\"]?)" # An optional single or double quote. - r")" # This is the end of the first capturing group. - r"(?:(?!\2\s).)*" # All characters between optional quotes - # (matched above); if the value is quoted, - # it may contain whitespace. - r"(\2)", # Optional matching quote. + # word (not capturing them). + r"\S*?" # All non-whitespace characters before either... + r"(?:=|\s+)" # ...an equal sign or whitespace characters + # (not capturing them). + r"(['\"]?)" # An optional single or double quote. + r")" # This is the end of the first capturing group. + r"(?:(?!\2\s).)*" # All characters between optional quotes + # (matched above); if the value is quoted, + # it may contain whitespace. + r"(\2)", # Optional matching quote. r'\1******\3', ' '.join(connection_cmd), flags=re.I, @@ -284,17 +287,16 @@ def _build_spark_submit_command(self, application: str) -> List[str]: else: tmpl = "spark.kubernetes.driverEnv.{}={}" for key in self._env_vars: - connection_cmd += [ - "--conf", - tmpl.format(key, str(self._env_vars[key]))] + connection_cmd += ["--conf", tmpl.format(key, str(self._env_vars[key]))] elif self._env_vars and self._connection['deploy_mode'] != "cluster": self._env = self._env_vars # Do it on Popen of the process elif self._env_vars and self._connection['deploy_mode'] == "cluster": - raise AirflowException( - "SparkSubmitHook env_vars is not supported in standalone-cluster mode.") + raise AirflowException("SparkSubmitHook env_vars is not supported in standalone-cluster mode.") if self._is_kubernetes and self._connection['namespace']: - connection_cmd += ["--conf", "spark.kubernetes.namespace={}".format( - self._connection['namespace'])] + connection_cmd += [ + "--conf", + "spark.kubernetes.namespace={}".format(self._connection['namespace']), + ] if self._files: connection_cmd += ["--files", self._files] if self._py_files: @@ -364,8 +366,9 @@ def _build_track_driver_status_command(self) -> List[str]: "--max-time", str(curl_max_wait_time), "{host}/v1/submissions/status/{submission_id}".format( - host=spark_host, - submission_id=self._driver_id)] + host=spark_host, submission_id=self._driver_id + ), + ] self.log.info(connection_cmd) # The driver id so we can poll for its status @@ -373,8 +376,9 @@ def _build_track_driver_status_command(self) -> List[str]: pass else: raise AirflowException( - "Invalid status: attempted to poll driver " + - "status but no driver id is known. Giving up.") + "Invalid status: attempted to poll driver " + + "status but no driver id is known. Giving up." + ) else: @@ -388,8 +392,9 @@ def _build_track_driver_status_command(self) -> List[str]: connection_cmd += ["--status", self._driver_id] else: raise AirflowException( - "Invalid status: attempted to poll driver " + - "status but no driver id is known. Giving up.") + "Invalid status: attempted to poll driver " + + "status but no driver id is known. Giving up." 
+ ) self.log.debug("Poll driver status cmd: %s", connection_cmd) @@ -410,12 +415,14 @@ def submit(self, application: str = "", **kwargs: Any) -> None: env.update(self._env) kwargs["env"] = env - self._submit_sp = subprocess.Popen(spark_submit_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=-1, - universal_newlines=True, - **kwargs) + self._submit_sp = subprocess.Popen( + spark_submit_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=-1, + universal_newlines=True, + **kwargs, + ) self._process_spark_submit_log(iter(self._submit_sp.stdout)) # type: ignore returncode = self._submit_sp.wait() @@ -442,8 +449,7 @@ def submit(self, application: str = "", **kwargs: Any) -> None: if self._should_track_driver_status: if self._driver_id is None: raise AirflowException( - "No driver id is known: something went wrong when executing " + - "the spark submit command" + "No driver id is known: something went wrong when executing " + "the spark submit command" ) # We start with the SUBMITTED status as initial status @@ -454,8 +460,9 @@ def submit(self, application: str = "", **kwargs: Any) -> None: if self._driver_status != "FINISHED": raise AirflowException( - "ERROR : Driver {} badly exited with status {}" - .format(self._driver_id, self._driver_status) + "ERROR : Driver {} badly exited with status {}".format( + self._driver_id, self._driver_status + ) ) def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: @@ -479,8 +486,7 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: match = re.search('(application[0-9_]+)', line) if match: self._yarn_application_id = match.groups()[0] - self.log.info("Identified spark driver id: %s", - self._yarn_application_id) + self.log.info("Identified spark driver id: %s", self._yarn_application_id) # If we run Kubernetes cluster mode, we want to extract the driver pod id # from the logs so we can kill the application when we stop it unexpectedly @@ -488,8 +494,7 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: match = re.search(r'\s*pod name: ((.+?)-([a-z0-9]+)-driver)', line) if match: self._kubernetes_driver_pod = match.groups()[0] - self.log.info("Identified spark driver pod: %s", - self._kubernetes_driver_pod) + self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod) # Store the Spark Exit code match_exit_code = re.search(r'\s*[eE]xit code: (\d+)', line) @@ -520,8 +525,7 @@ def _process_spark_status_log(self, itr: Iterator[Any]) -> None: # Check if the log line is about the driver status and extract the status. 
if "driverState" in line: - self._driver_status = line.split(' : ')[1] \ - .replace(',', '').replace('\"', '').strip() + self._driver_status = line.split(' : ')[1].replace(',', '').replace('\"', '').strip() driver_found = True self.log.debug("spark driver status log: %s", line) @@ -566,8 +570,7 @@ def _start_driver_status_tracking(self) -> None: max_missed_job_status_reports = 10 # Keep polling as long as the driver is processing - while self._driver_status not in ["FINISHED", "UNKNOWN", - "KILLED", "FAILED", "ERROR"]: + while self._driver_status not in ["FINISHED", "UNKNOWN", "KILLED", "FAILED", "ERROR"]: # Sleep for n seconds as we do not want to spam the cluster time.sleep(self._status_poll_interval) @@ -575,12 +578,13 @@ def _start_driver_status_tracking(self) -> None: self.log.debug("polling status of spark driver with id %s", self._driver_id) poll_drive_status_cmd = self._build_track_driver_status_command() - status_process: Any = subprocess.Popen(poll_drive_status_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=-1, - universal_newlines=True - ) + status_process: Any = subprocess.Popen( + poll_drive_status_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=-1, + universal_newlines=True, + ) self._process_spark_status_log(iter(status_process.stdout)) returncode = status_process.wait() @@ -590,8 +594,9 @@ def _start_driver_status_tracking(self) -> None: missed_job_status_reports += 1 else: raise AirflowException( - "Failed to poll for the driver status {} times: returncode = {}" - .format(max_missed_job_status_reports, returncode) + "Failed to poll for the driver status {} times: returncode = {}".format( + max_missed_job_status_reports, returncode + ) ) def _build_spark_driver_kill_command(self) -> List[str]: @@ -604,9 +609,9 @@ def _build_spark_driver_kill_command(self) -> List[str]: # the spark_home; otherwise assume that spark-submit is present in the path to # the executing user if self._connection['spark_home']: - connection_cmd = [os.path.join(self._connection['spark_home'], - 'bin', - self._connection['spark_binary'])] + connection_cmd = [ + os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary']) + ] else: connection_cmd = [self._connection['spark_binary']] @@ -633,20 +638,18 @@ def on_kill(self) -> None: self.log.info('Killing driver %s on cluster', self._driver_id) kill_cmd = self._build_spark_driver_kill_command() - driver_kill = subprocess.Popen(kill_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + driver_kill = subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - self.log.info("Spark driver %s killed with return code: %s", - self._driver_id, driver_kill.wait()) + self.log.info( + "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait() + ) if self._submit_sp and self._submit_sp.poll() is None: self.log.info('Sending kill signal to %s', self._connection['spark_binary']) self._submit_sp.kill() if self._yarn_application_id: - kill_cmd = "yarn application -kill {}" \ - .format(self._yarn_application_id).split() + kill_cmd = "yarn application -kill {}".format(self._yarn_application_id).split() env = None if self._keytab is not None and self._principal is not None: # we are ignoring renewal failures from renew_from_kt @@ -656,10 +659,9 @@ def on_kill(self) -> None: env = os.environ.copy() env["KRB5CCNAME"] = airflow_conf.get('kerberos', 'ccache') - yarn_kill = subprocess.Popen(kill_cmd, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + 
yarn_kill = subprocess.Popen( + kill_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) self.log.info("YARN app killed with return code: %s", yarn_kill.wait()) @@ -669,12 +671,14 @@ def on_kill(self) -> None: # Currently only instantiate Kubernetes client for killing a spark pod. try: import kubernetes + client = kube_client.get_kube_client() api_response = client.delete_namespaced_pod( self._kubernetes_driver_pod, self._connection['namespace'], body=kubernetes.client.V1DeleteOptions(), - pretty=True) + pretty=True, + ) self.log.info("Spark on K8s killed with response: %s", api_response) diff --git a/airflow/providers/apache/spark/operators/spark_jdbc.py b/airflow/providers/apache/spark/operators/spark_jdbc.py index 8c1c7be494559..a5f701a7542e2 100644 --- a/airflow/providers/apache/spark/operators/spark_jdbc.py +++ b/airflow/providers/apache/spark/operators/spark_jdbc.py @@ -120,36 +120,39 @@ class SparkJDBCOperator(SparkSubmitOperator): # pylint: disable=too-many-arguments,too-many-locals @apply_defaults - def __init__(self, *, - spark_app_name: str = 'airflow-spark-jdbc', - spark_conn_id: str = 'spark-default', - spark_conf: Optional[Dict[str, Any]] = None, - spark_py_files: Optional[str] = None, - spark_files: Optional[str] = None, - spark_jars: Optional[str] = None, - num_executors: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - driver_memory: Optional[str] = None, - verbose: bool = False, - principal: Optional[str] = None, - keytab: Optional[str] = None, - cmd_type: str = 'spark_to_jdbc', - jdbc_table: Optional[str] = None, - jdbc_conn_id: str = 'jdbc-default', - jdbc_driver: Optional[str] = None, - metastore_table: Optional[str] = None, - jdbc_truncate: bool = False, - save_mode: Optional[str] = None, - save_format: Optional[str] = None, - batch_size: Optional[int] = None, - fetch_size: Optional[int] = None, - num_partitions: Optional[int] = None, - partition_column: Optional[str] = None, - lower_bound: Optional[str] = None, - upper_bound: Optional[str] = None, - create_table_column_types: Optional[str] = None, - **kwargs: Any) -> None: + def __init__( + self, + *, + spark_app_name: str = 'airflow-spark-jdbc', + spark_conn_id: str = 'spark-default', + spark_conf: Optional[Dict[str, Any]] = None, + spark_py_files: Optional[str] = None, + spark_files: Optional[str] = None, + spark_jars: Optional[str] = None, + num_executors: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + verbose: bool = False, + principal: Optional[str] = None, + keytab: Optional[str] = None, + cmd_type: str = 'spark_to_jdbc', + jdbc_table: Optional[str] = None, + jdbc_conn_id: str = 'jdbc-default', + jdbc_driver: Optional[str] = None, + metastore_table: Optional[str] = None, + jdbc_truncate: bool = False, + save_mode: Optional[str] = None, + save_format: Optional[str] = None, + batch_size: Optional[int] = None, + fetch_size: Optional[int] = None, + num_partitions: Optional[int] = None, + partition_column: Optional[str] = None, + lower_bound: Optional[str] = None, + upper_bound: Optional[str] = None, + create_table_column_types: Optional[str] = None, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self._spark_app_name = spark_app_name self._spark_conn_id = spark_conn_id @@ -223,5 +226,5 @@ def _get_hook(self) -> SparkJDBCHook: partition_column=self._partition_column, lower_bound=self._lower_bound, upper_bound=self._upper_bound, - 
create_table_column_types=self._create_table_column_types + create_table_column_types=self._create_table_column_types, ) diff --git a/airflow/providers/apache/spark/operators/spark_sql.py b/airflow/providers/apache/spark/operators/spark_sql.py index fec7f6f15d3f5..6d6c69d635528 100644 --- a/airflow/providers/apache/spark/operators/spark_sql.py +++ b/airflow/providers/apache/spark/operators/spark_sql.py @@ -64,21 +64,24 @@ class SparkSqlOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults - def __init__(self, *, - sql: str, - conf: Optional[str] = None, - conn_id: str = 'spark_sql_default', - total_executor_cores: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - keytab: Optional[str] = None, - principal: Optional[str] = None, - master: str = 'yarn', - name: str = 'default-name', - num_executors: Optional[int] = None, - verbose: bool = True, - yarn_queue: str = 'default', - **kwargs: Any) -> None: + def __init__( + self, + *, + sql: str, + conf: Optional[str] = None, + conn_id: str = 'spark_sql_default', + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + master: str = 'yarn', + name: str = 'default-name', + num_executors: Optional[int] = None, + verbose: bool = True, + yarn_queue: str = 'default', + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self._sql = sql self._conf = conf @@ -110,17 +113,18 @@ def on_kill(self) -> None: def _get_hook(self) -> SparkSqlHook: """Get SparkSqlHook""" - return SparkSqlHook(sql=self._sql, - conf=self._conf, - conn_id=self._conn_id, - total_executor_cores=self._total_executor_cores, - executor_cores=self._executor_cores, - executor_memory=self._executor_memory, - keytab=self._keytab, - principal=self._principal, - name=self._name, - num_executors=self._num_executors, - master=self._master, - verbose=self._verbose, - yarn_queue=self._yarn_queue - ) + return SparkSqlHook( + sql=self._sql, + conf=self._conf, + conn_id=self._conn_id, + total_executor_cores=self._total_executor_cores, + executor_cores=self._executor_cores, + executor_memory=self._executor_memory, + keytab=self._keytab, + principal=self._principal, + name=self._name, + num_executors=self._num_executors, + master=self._master, + verbose=self._verbose, + yarn_queue=self._yarn_queue, + ) diff --git a/airflow/providers/apache/spark/operators/spark_submit.py b/airflow/providers/apache/spark/operators/spark_submit.py index 24d684ae140ca..c3c6b65a1abb7 100644 --- a/airflow/providers/apache/spark/operators/spark_submit.py +++ b/airflow/providers/apache/spark/operators/spark_submit.py @@ -95,41 +95,58 @@ class SparkSubmitOperator(BaseOperator): Some distros may use spark2-submit. 
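# Illustrative sketch (not part of the diff): picking a non-default submit binary via
# spark_binary, assuming the worker's distro ships 'spark2-submit' on PATH and that this
# runs inside an existing DAG context; the application path reuses the pi.py example above.
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

submit_job = SparkSubmitOperator(
    task_id='submit_job',
    application='${SPARK_HOME}/examples/src/main/python/pi.py',
    spark_binary='spark2-submit',
)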
:type spark_binary: str """ - template_fields = ('_application', '_conf', '_files', '_py_files', '_jars', '_driver_class_path', - '_packages', '_exclude_packages', '_keytab', '_principal', '_proxy_user', '_name', - '_application_args', '_env_vars') + + template_fields = ( + '_application', + '_conf', + '_files', + '_py_files', + '_jars', + '_driver_class_path', + '_packages', + '_exclude_packages', + '_keytab', + '_principal', + '_proxy_user', + '_name', + '_application_args', + '_env_vars', + ) ui_color = WEB_COLORS['LIGHTORANGE'] # pylint: disable=too-many-arguments,too-many-locals @apply_defaults - def __init__(self, *, - application: str = '', - conf: Optional[Dict[str, Any]] = None, - conn_id: str = 'spark_default', - files: Optional[str] = None, - py_files: Optional[str] = None, - archives: Optional[str] = None, - driver_class_path: Optional[str] = None, - jars: Optional[str] = None, - java_class: Optional[str] = None, - packages: Optional[str] = None, - exclude_packages: Optional[str] = None, - repositories: Optional[str] = None, - total_executor_cores: Optional[int] = None, - executor_cores: Optional[int] = None, - executor_memory: Optional[str] = None, - driver_memory: Optional[str] = None, - keytab: Optional[str] = None, - principal: Optional[str] = None, - proxy_user: Optional[str] = None, - name: str = 'arrow-spark', - num_executors: Optional[int] = None, - status_poll_interval: int = 1, - application_args: Optional[List[Any]] = None, - env_vars: Optional[Dict[str, Any]] = None, - verbose: bool = False, - spark_binary: Optional[str] = None, - **kwargs: Any) -> None: + def __init__( + self, + *, + application: str = '', + conf: Optional[Dict[str, Any]] = None, + conn_id: str = 'spark_default', + files: Optional[str] = None, + py_files: Optional[str] = None, + archives: Optional[str] = None, + driver_class_path: Optional[str] = None, + jars: Optional[str] = None, + java_class: Optional[str] = None, + packages: Optional[str] = None, + exclude_packages: Optional[str] = None, + repositories: Optional[str] = None, + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + proxy_user: Optional[str] = None, + name: str = 'arrow-spark', + num_executors: Optional[int] = None, + status_poll_interval: int = 1, + application_args: Optional[List[Any]] = None, + env_vars: Optional[Dict[str, Any]] = None, + verbose: bool = False, + spark_binary: Optional[str] = None, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self._application = application self._conf = conf @@ -198,5 +215,5 @@ def _get_hook(self) -> SparkSubmitHook: application_args=self._application_args, env_vars=self._env_vars, verbose=self._verbose, - spark_binary=self._spark_binary + spark_binary=self._spark_binary, ) diff --git a/airflow/providers/apache/sqoop/hooks/sqoop.py b/airflow/providers/apache/sqoop/hooks/sqoop.py index 6b849356282bf..7dfa665bb9b56 100644 --- a/airflow/providers/apache/sqoop/hooks/sqoop.py +++ b/airflow/providers/apache/sqoop/hooks/sqoop.py @@ -54,14 +54,15 @@ class SqoopHook(BaseHook): :type properties: dict """ - def __init__(self, - conn_id: str = 'sqoop_default', - verbose: bool = False, - num_mappers: Optional[int] = None, - hcatalog_database: Optional[str] = None, - hcatalog_table: Optional[str] = None, - properties: Optional[Dict[str, Any]] = None - ) -> None: + def __init__( + self, + conn_id: str = 'sqoop_default', + 
verbose: bool = False, + num_mappers: Optional[int] = None, + hcatalog_database: Optional[str] = None, + hcatalog_table: Optional[str] = None, + properties: Optional[Dict[str, Any]] = None, + ) -> None: # No mutable types in the default parameters super().__init__() self.conn = self.get_connection(conn_id) @@ -77,8 +78,7 @@ def __init__(self, self.verbose = verbose self.num_mappers = num_mappers self.properties = properties or {} - self.log.info("Using connection to: %s:%s/%s", - self.conn.host, self.conn.port, self.conn.schema) + self.log.info("Using connection to: %s:%s/%s", self.conn.host, self.conn.port, self.conn.schema) self.sub_process: Any = None def get_conn(self) -> Any: @@ -106,11 +106,7 @@ def popen(self, cmd: List[str], **kwargs: Any) -> None: """ masked_cmd = ' '.join(self.cmd_mask_password(cmd)) self.log.info("Executing command: %s", masked_cmd) - self.sub_process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - **kwargs) + self.sub_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs) for line in iter(self.sub_process.stdout): # type: ignore self.log.info(line.strip()) @@ -174,12 +170,18 @@ def _get_export_format_argument(file_type: str = 'text') -> List[str]: elif file_type == "text": return ["--as-textfile"] else: - raise AirflowException("Argument file_type should be 'avro', " - "'sequence', 'parquet' or 'text'.") - - def _import_cmd(self, target_dir: Optional[str], append: bool, file_type: str, - split_by: Optional[str], direct: Optional[bool], - driver: Any, extra_import_options: Any) -> List[str]: + raise AirflowException("Argument file_type should be 'avro', " "'sequence', 'parquet' or 'text'.") + + def _import_cmd( + self, + target_dir: Optional[str], + append: bool, + file_type: str, + split_by: Optional[str], + direct: Optional[bool], + driver: Any, + extra_import_options: Any, + ) -> List[str]: cmd = self._prepare_command(export=False) @@ -209,18 +211,19 @@ def _import_cmd(self, target_dir: Optional[str], append: bool, file_type: str, return cmd # pylint: disable=too-many-arguments - def import_table(self, - table: str, - target_dir: Optional[str] = None, - append: bool = False, - file_type: str = "text", - columns: Optional[str] = None, - split_by: Optional[str] = None, - where: Optional[str] = None, - direct: bool = False, - driver: Any = None, - extra_import_options: Optional[Dict[str, Any]] = None - ) -> Any: + def import_table( + self, + table: str, + target_dir: Optional[str] = None, + append: bool = False, + file_type: str = "text", + columns: Optional[str] = None, + split_by: Optional[str] = None, + where: Optional[str] = None, + direct: bool = False, + driver: Any = None, + extra_import_options: Optional[Dict[str, Any]] = None, + ) -> Any: """ Imports table from remote location to target dir. Arguments are copies of direct sqoop command line arguments @@ -239,8 +242,7 @@ def import_table(self, If a key doesn't have a value, just pass an empty string to it. Don't include prefix of -- for sqoop options. 
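# Illustrative sketch (not part of the diff): calling the reformatted import_table,
# assuming a working 'sqoop_default' connection; the table, target_dir and the extra
# sqoop options ('compress', 'fetch-size') are placeholders.
from airflow.providers.apache.sqoop.hooks.sqoop import SqoopHook

hook = SqoopHook(conn_id='sqoop_default', num_mappers=4)
hook.import_table(
    table='customers',
    target_dir='/data/customers',
    file_type='avro',
    split_by='id',
    # a value-less option takes an empty string; no leading '--'
    extra_import_options={'compress': '', 'fetch-size': '1000'},
)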
""" - cmd = self._import_cmd(target_dir, append, file_type, split_by, direct, - driver, extra_import_options) + cmd = self._import_cmd(target_dir, append, file_type, split_by, direct, driver, extra_import_options) cmd += ["--table", table] @@ -251,15 +253,17 @@ def import_table(self, self.popen(cmd) - def import_query(self, query: str, - target_dir: Optional[str] = None, - append: bool = False, - file_type: str = "text", - split_by: Optional[str] = None, - direct: Optional[bool] = None, - driver: Optional[Any] = None, - extra_import_options: Optional[Dict[str, Any]] = None - ) -> Any: + def import_query( + self, + query: str, + target_dir: Optional[str] = None, + append: bool = False, + file_type: str = "text", + split_by: Optional[str] = None, + direct: Optional[bool] = None, + driver: Optional[Any] = None, + extra_import_options: Optional[Dict[str, Any]] = None, + ) -> Any: """ Imports a specific query from the rdbms to hdfs @@ -275,28 +279,29 @@ def import_query(self, query: str, If a key doesn't have a value, just pass an empty string to it. Don't include prefix of -- for sqoop options. """ - cmd = self._import_cmd(target_dir, append, file_type, split_by, direct, - driver, extra_import_options) + cmd = self._import_cmd(target_dir, append, file_type, split_by, direct, driver, extra_import_options) cmd += ["--query", query] self.popen(cmd) # pylint: disable=too-many-arguments - def _export_cmd(self, table: str, - export_dir: Optional[str] = None, - input_null_string: Optional[str] = None, - input_null_non_string: Optional[str] = None, - staging_table: Optional[str] = None, - clear_staging_table: bool = False, - enclosed_by: Optional[str] = None, - escaped_by: Optional[str] = None, - input_fields_terminated_by: Optional[str] = None, - input_lines_terminated_by: Optional[str] = None, - input_optionally_enclosed_by: Optional[str] = None, - batch: bool = False, - relaxed_isolation: bool = False, - extra_export_options: Optional[Dict[str, Any]] = None - ) -> List[str]: + def _export_cmd( + self, + table: str, + export_dir: Optional[str] = None, + input_null_string: Optional[str] = None, + input_null_non_string: Optional[str] = None, + staging_table: Optional[str] = None, + clear_staging_table: bool = False, + enclosed_by: Optional[str] = None, + escaped_by: Optional[str] = None, + input_fields_terminated_by: Optional[str] = None, + input_lines_terminated_by: Optional[str] = None, + input_optionally_enclosed_by: Optional[str] = None, + batch: bool = False, + relaxed_isolation: bool = False, + extra_export_options: Optional[Dict[str, Any]] = None, + ) -> List[str]: cmd = self._prepare_command(export=True) @@ -325,8 +330,7 @@ def _export_cmd(self, table: str, cmd += ["--input-lines-terminated-by", input_lines_terminated_by] if input_optionally_enclosed_by: - cmd += ["--input-optionally-enclosed-by", - input_optionally_enclosed_by] + cmd += ["--input-optionally-enclosed-by", input_optionally_enclosed_by] if batch: cmd += ["--batch"] @@ -349,22 +353,23 @@ def _export_cmd(self, table: str, return cmd # pylint: disable=too-many-arguments - def export_table(self, - table: str, - export_dir: Optional[str] = None, - input_null_string: Optional[str] = None, - input_null_non_string: Optional[str] = None, - staging_table: Optional[str] = None, - clear_staging_table: bool = False, - enclosed_by: Optional[str] = None, - escaped_by: Optional[str] = None, - input_fields_terminated_by: Optional[str] = None, - input_lines_terminated_by: Optional[str] = None, - input_optionally_enclosed_by: Optional[str] = None, 
- batch: bool = False, - relaxed_isolation: bool = False, - extra_export_options: Optional[Dict[str, Any]] = None - ) -> None: + def export_table( + self, + table: str, + export_dir: Optional[str] = None, + input_null_string: Optional[str] = None, + input_null_non_string: Optional[str] = None, + staging_table: Optional[str] = None, + clear_staging_table: bool = False, + enclosed_by: Optional[str] = None, + escaped_by: Optional[str] = None, + input_fields_terminated_by: Optional[str] = None, + input_lines_terminated_by: Optional[str] = None, + input_optionally_enclosed_by: Optional[str] = None, + batch: bool = False, + relaxed_isolation: bool = False, + extra_export_options: Optional[Dict[str, Any]] = None, + ) -> None: """ Exports Hive table to remote location. Arguments are copies of direct sqoop command line Arguments @@ -391,12 +396,21 @@ def export_table(self, If a key doesn't have a value, just pass an empty string to it. Don't include prefix of -- for sqoop options. """ - cmd = self._export_cmd(table, export_dir, input_null_string, - input_null_non_string, staging_table, - clear_staging_table, enclosed_by, escaped_by, - input_fields_terminated_by, - input_lines_terminated_by, - input_optionally_enclosed_by, batch, - relaxed_isolation, extra_export_options) + cmd = self._export_cmd( + table, + export_dir, + input_null_string, + input_null_non_string, + staging_table, + clear_staging_table, + enclosed_by, + escaped_by, + input_fields_terminated_by, + input_lines_terminated_by, + input_optionally_enclosed_by, + batch, + relaxed_isolation, + extra_export_options, + ) self.popen(cmd) diff --git a/airflow/providers/apache/sqoop/operators/sqoop.py b/airflow/providers/apache/sqoop/operators/sqoop.py index 3400360fc5d40..514d83fb7fa07 100644 --- a/airflow/providers/apache/sqoop/operators/sqoop.py +++ b/airflow/providers/apache/sqoop/operators/sqoop.py @@ -83,53 +83,74 @@ class SqoopOperator(BaseOperator): If a key doesn't have a value, just pass an empty string to it. Don't include prefix of -- for sqoop options. 
""" - template_fields = ('conn_id', 'cmd_type', 'table', 'query', 'target_dir', - 'file_type', 'columns', 'split_by', - 'where', 'export_dir', 'input_null_string', - 'input_null_non_string', 'staging_table', - 'enclosed_by', 'escaped_by', 'input_fields_terminated_by', - 'input_lines_terminated_by', 'input_optionally_enclosed_by', - 'properties', 'extra_import_options', 'driver', - 'extra_export_options', 'hcatalog_database', 'hcatalog_table',) + + template_fields = ( + 'conn_id', + 'cmd_type', + 'table', + 'query', + 'target_dir', + 'file_type', + 'columns', + 'split_by', + 'where', + 'export_dir', + 'input_null_string', + 'input_null_non_string', + 'staging_table', + 'enclosed_by', + 'escaped_by', + 'input_fields_terminated_by', + 'input_lines_terminated_by', + 'input_optionally_enclosed_by', + 'properties', + 'extra_import_options', + 'driver', + 'extra_export_options', + 'hcatalog_database', + 'hcatalog_table', + ) ui_color = '#7D8CA4' # pylint: disable=too-many-arguments,too-many-locals @apply_defaults - def __init__(self, *, - conn_id: str = 'sqoop_default', - cmd_type: str = 'import', - table: Optional[str] = None, - query: Optional[str] = None, - target_dir: Optional[str] = None, - append: bool = False, - file_type: str = 'text', - columns: Optional[str] = None, - num_mappers: Optional[int] = None, - split_by: Optional[str] = None, - where: Optional[str] = None, - export_dir: Optional[str] = None, - input_null_string: Optional[str] = None, - input_null_non_string: Optional[str] = None, - staging_table: Optional[str] = None, - clear_staging_table: bool = False, - enclosed_by: Optional[str] = None, - escaped_by: Optional[str] = None, - input_fields_terminated_by: Optional[str] = None, - input_lines_terminated_by: Optional[str] = None, - input_optionally_enclosed_by: Optional[str] = None, - batch: bool = False, - direct: bool = False, - driver: Optional[Any] = None, - verbose: bool = False, - relaxed_isolation: bool = False, - properties: Optional[Dict[str, Any]] = None, - hcatalog_database: Optional[str] = None, - hcatalog_table: Optional[str] = None, - create_hcatalog_table: bool = False, - extra_import_options: Optional[Dict[str, Any]] = None, - extra_export_options: Optional[Dict[str, Any]] = None, - **kwargs: Any - ) -> None: + def __init__( + self, + *, + conn_id: str = 'sqoop_default', + cmd_type: str = 'import', + table: Optional[str] = None, + query: Optional[str] = None, + target_dir: Optional[str] = None, + append: bool = False, + file_type: str = 'text', + columns: Optional[str] = None, + num_mappers: Optional[int] = None, + split_by: Optional[str] = None, + where: Optional[str] = None, + export_dir: Optional[str] = None, + input_null_string: Optional[str] = None, + input_null_non_string: Optional[str] = None, + staging_table: Optional[str] = None, + clear_staging_table: bool = False, + enclosed_by: Optional[str] = None, + escaped_by: Optional[str] = None, + input_fields_terminated_by: Optional[str] = None, + input_lines_terminated_by: Optional[str] = None, + input_optionally_enclosed_by: Optional[str] = None, + batch: bool = False, + direct: bool = False, + driver: Optional[Any] = None, + verbose: bool = False, + relaxed_isolation: bool = False, + properties: Optional[Dict[str, Any]] = None, + hcatalog_database: Optional[str] = None, + hcatalog_table: Optional[str] = None, + create_hcatalog_table: bool = False, + extra_import_options: Optional[Dict[str, Any]] = None, + extra_export_options: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: 
super().__init__(**kwargs) self.conn_id = conn_id self.cmd_type = cmd_type @@ -187,7 +208,8 @@ def execute(self, context: Dict[str, Any]) -> None: input_optionally_enclosed_by=self.input_optionally_enclosed_by, batch=self.batch, relaxed_isolation=self.relaxed_isolation, - extra_export_options=self.extra_export_options) + extra_export_options=self.extra_export_options, + ) elif self.cmd_type == 'import': # add create hcatalog table to extra import options if option passed # if new params are added to constructor can pass them in here @@ -196,9 +218,7 @@ def execute(self, context: Dict[str, Any]) -> None: self.extra_import_options['create-hcatalog-table'] = '' if self.table and self.query: - raise AirflowException( - 'Cannot specify query and table together. Need to specify either or.' - ) + raise AirflowException('Cannot specify query and table together. Need to specify either or.') if self.table: self.hook.import_table( @@ -211,7 +231,8 @@ def execute(self, context: Dict[str, Any]) -> None: where=self.where, direct=self.direct, driver=self.driver, - extra_import_options=self.extra_import_options) + extra_import_options=self.extra_import_options, + ) elif self.query: self.hook.import_query( query=self.query, @@ -221,11 +242,10 @@ def execute(self, context: Dict[str, Any]) -> None: split_by=self.split_by, direct=self.direct, driver=self.driver, - extra_import_options=self.extra_import_options) - else: - raise AirflowException( - "Provide query or table parameter to import using Sqoop" + extra_import_options=self.extra_import_options, ) + else: + raise AirflowException("Provide query or table parameter to import using Sqoop") else: raise AirflowException("cmd_type should be 'import' or 'export'") @@ -242,5 +262,5 @@ def _get_hook(self) -> SqoopHook: num_mappers=self.num_mappers, hcatalog_database=self.hcatalog_database, hcatalog_table=self.hcatalog_table, - properties=self.properties + properties=self.properties, ) diff --git a/airflow/providers/celery/sensors/celery_queue.py b/airflow/providers/celery/sensors/celery_queue.py index ff0b466c145f9..d426562f32d24 100644 --- a/airflow/providers/celery/sensors/celery_queue.py +++ b/airflow/providers/celery/sensors/celery_queue.py @@ -35,12 +35,9 @@ class CeleryQueueSensor(BaseSensorOperator): :param target_task_id: Task id for checking :type target_task_id: str """ + @apply_defaults - def __init__( - self, *, - celery_queue: str, - target_task_id: Optional[str] = None, - **kwargs) -> None: + def __init__(self, *, celery_queue: str, target_task_id: Optional[str] = None, **kwargs) -> None: super().__init__(**kwargs) self.celery_queue = celery_queue @@ -76,14 +73,8 @@ def poke(self, context: Dict[str, Any]) -> bool: scheduled = len(scheduled[self.celery_queue]) active = len(active[self.celery_queue]) - self.log.info( - 'Checking if celery queue %s is empty.', self.celery_queue - ) + self.log.info('Checking if celery queue %s is empty.', self.celery_queue) return reserved == 0 and scheduled == 0 and active == 0 except KeyError: - raise KeyError( - 'Could not locate Celery queue {0}'.format( - self.celery_queue - ) - ) + raise KeyError('Could not locate Celery queue {0}'.format(self.celery_queue)) diff --git a/airflow/providers/cloudant/hooks/cloudant.py b/airflow/providers/cloudant/hooks/cloudant.py index 5dcdf25a4f99b..57e126af434b8 100644 --- a/airflow/providers/cloudant/hooks/cloudant.py +++ b/airflow/providers/cloudant/hooks/cloudant.py @@ -60,5 +60,6 @@ def get_conn(self) -> cloudant: def _validate_connection(self, conn: cloudant) -> None: for 
conn_param in ['login', 'password']: if not getattr(conn, conn_param): - raise AirflowException('missing connection parameter {conn_param}'.format( - conn_param=conn_param)) + raise AirflowException( + 'missing connection parameter {conn_param}'.format(conn_param=conn_param) + ) diff --git a/airflow/providers/cncf/kubernetes/example_dags/example_kubernetes.py b/airflow/providers/cncf/kubernetes/example_dags/example_kubernetes.py index 17f8e56970d42..708dd06990419 100644 --- a/airflow/providers/cncf/kubernetes/example_dags/example_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/example_dags/example_kubernetes.py @@ -34,35 +34,19 @@ secret_file = Secret('volume', '/etc/sql_conn', 'airflow-secrets', 'sql_alchemy_conn') secret_env = Secret('env', 'SQL_CONN', 'airflow-secrets', 'sql_alchemy_conn') secret_all_keys = Secret('env', None, 'airflow-secrets-2') -volume_mount = VolumeMount('test-volume', - mount_path='/root/mount_file', - sub_path=None, - read_only=True) +volume_mount = VolumeMount('test-volume', mount_path='/root/mount_file', sub_path=None, read_only=True) configmaps = ['test-configmap-1', 'test-configmap-2'] -volume_config = { - 'persistentVolumeClaim': { - 'claimName': 'test-volume' - } -} +volume_config = {'persistentVolumeClaim': {'claimName': 'test-volume'}} volume = Volume(name='test-volume', configs=volume_config) # [END howto_operator_k8s_cluster_resources] port = Port('http', 80) -init_container_volume_mounts = [k8s.V1VolumeMount( - mount_path='/etc/foo', - name='test-volume', - sub_path=None, - read_only=True -)] +init_container_volume_mounts = [ + k8s.V1VolumeMount(mount_path='/etc/foo', name='test-volume', sub_path=None, read_only=True) +] -init_environments = [k8s.V1EnvVar( - name='key1', - value='value1' -), k8s.V1EnvVar( - name='key2', - value='value2' -)] +init_environments = [k8s.V1EnvVar(name='key1', value='value1'), k8s.V1EnvVar(name='key2', value='value2')] init_container = k8s.V1Container( name="init-container", @@ -70,53 +54,41 @@ env=init_environments, volume_mounts=init_container_volume_mounts, command=["bash", "-cx"], - args=["echo 10"] + args=["echo 10"], ) affinity = { 'nodeAffinity': { - 'preferredDuringSchedulingIgnoredDuringExecution': [{ - "weight": 1, - "preference": { - "matchExpressions": { - "key": "disktype", - "operator": "In", - "values": ["ssd"] - } + 'preferredDuringSchedulingIgnoredDuringExecution': [ + { + "weight": 1, + "preference": {"matchExpressions": {"key": "disktype", "operator": "In", "values": ["ssd"]}}, } - }] + ] }, "podAffinity": { - "requiredDuringSchedulingIgnoredDuringExecution": [{ - "labelSelector": { - "matchExpressions": [{ - "key": "security", - "operator": "In", - "values": ["S1"] - }] - }, - "topologyKey": "failure-domain.beta.kubernetes.io/zone" - }] + "requiredDuringSchedulingIgnoredDuringExecution": [ + { + "labelSelector": { + "matchExpressions": [{"key": "security", "operator": "In", "values": ["S1"]}] + }, + "topologyKey": "failure-domain.beta.kubernetes.io/zone", + } + ] }, "podAntiAffinity": { - "requiredDuringSchedulingIgnoredDuringExecution": [{ - "labelSelector": { - "matchExpressions": [{ - "key": "security", - "operator": "In", - "values": ["S2"] - }] - }, - "topologyKey": "kubernetes.io/hostname" - }] - } + "requiredDuringSchedulingIgnoredDuringExecution": [ + { + "labelSelector": { + "matchExpressions": [{"key": "security", "operator": "In", "values": ["S2"]}] + }, + "topologyKey": "kubernetes.io/hostname", + } + ] + }, } -tolerations = [{ - 'key': "key", - 'operator': 'Equal', - 'value': 'value' -}] 
+tolerations = [{'key': "key", 'operator': 'Equal', 'value': 'value'}] default_args = { @@ -163,7 +135,7 @@ is_delete_operator_pod=True, in_cluster=True, task_id="task-two", - get_logs=True + get_logs=True, ) # [END howto_operator_k8s_private_image] @@ -177,7 +149,7 @@ is_delete_operator_pod=True, in_cluster=True, task_id="write-xcom", - get_logs=True + get_logs=True, ) pod_task_xcom_result = BashOperator( diff --git a/airflow/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py b/airflow/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py index c88b73bf78216..811afdb95c2b0 100644 --- a/airflow/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py @@ -30,6 +30,7 @@ # [START import_module] # The DAG object; we'll need this to instantiate a DAG from airflow import DAG + # Operators; we need this to operate! from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import SparkKubernetesSensor @@ -46,7 +47,7 @@ 'email': ['airflow@example.com'], 'email_on_failure': False, 'email_on_retry': False, - 'max_active_runs': 1 + 'max_active_runs': 1, } # [END default_args] @@ -74,6 +75,6 @@ namespace="default", application_name="{{ task_instance.xcom_pull(task_ids='spark_pi_submit')['metadata']['name'] }}", kubernetes_conn_id="kubernetes_default", - dag=dag + dag=dag, ) t1 >> t2 diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index f95f99ebee16c..9854ef8a4f0ed 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -40,10 +40,7 @@ class KubernetesHook(BaseHook): :type conn_id: str """ - def __init__( - self, - conn_id: str = "kubernetes_default" - ): + def __init__(self, conn_id: str = "kubernetes_default"): super().__init__() self.conn_id = conn_id @@ -67,13 +64,9 @@ def get_conn(self): config.load_kube_config(temp_config.name) return client.ApiClient() - def create_custom_resource_definition(self, - group: str, - version: str, - plural: str, - body: Union[str, dict], - namespace: Optional[str] = None - ): + def create_custom_resource_definition( + self, group: str, version: str, plural: str, body: Union[str, dict], namespace: Optional[str] = None + ): """ Creates custom resource definition object in Kubernetes @@ -95,23 +88,16 @@ def create_custom_resource_definition(self, body = _load_body_to_dict(body) try: response = api.create_namespaced_custom_object( - group=group, - version=version, - namespace=namespace, - plural=plural, - body=body + group=group, version=version, namespace=namespace, plural=plural, body=body ) self.log.debug("Response: %s", response) return response except client.rest.ApiException as e: raise AirflowException("Exception when calling -> create_custom_resource_definition: %s\n" % e) - def get_custom_resource_definition(self, - group: str, - version: str, - plural: str, - name: str, - namespace: Optional[str] = None): + def get_custom_resource_definition( + self, group: str, version: str, plural: str, name: str, namespace: Optional[str] = None + ): """ Get custom resource definition object from Kubernetes @@ -131,11 +117,7 @@ def get_custom_resource_definition(self, namespace = self.get_namespace() try: response = custom_resource_definition_api.get_namespaced_custom_object( - group=group, - version=version, - 
namespace=namespace, - plural=plural, - name=name + group=group, version=version, namespace=namespace, plural=plural, name=name ) return response except client.rest.ApiException as e: diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index 80f980d6459ee..b971e8f2b5b09 100644 --- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -43,11 +43,14 @@ class SparkKubernetesOperator(BaseOperator): ui_color = '#f4a460' @apply_defaults - def __init__(self, *, - application_file: str, - namespace: Optional[str] = None, - kubernetes_conn_id: str = 'kubernetes_default', - **kwargs) -> None: + def __init__( + self, + *, + application_file: str, + namespace: Optional[str] = None, + kubernetes_conn_id: str = 'kubernetes_default', + **kwargs, + ) -> None: super().__init__(**kwargs) self.application_file = application_file self.namespace = namespace @@ -61,5 +64,6 @@ def execute(self, context): version="v1beta2", plural="sparkapplications", body=self.application_file, - namespace=self.namespace) + namespace=self.namespace, + ) return response diff --git a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py index 934be58361cfc..2a62835f5a02f 100644 --- a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py @@ -44,11 +44,14 @@ class SparkKubernetesSensor(BaseSensorOperator): SUCCESS_STATES = ('COMPLETED',) @apply_defaults - def __init__(self, *, - application_name: str, - namespace: Optional[str] = None, - kubernetes_conn_id: str = 'kubernetes_default', - **kwargs): + def __init__( + self, + *, + application_name: str, + namespace: Optional[str] = None, + kubernetes_conn_id: str = 'kubernetes_default', + **kwargs, + ): super().__init__(**kwargs) self.application_name = application_name self.namespace = namespace @@ -62,7 +65,8 @@ def poke(self, context: Dict): version="v1beta2", plural="sparkapplications", name=self.application_name, - namespace=self.namespace) + namespace=self.namespace, + ) try: application_state = response['status']['applicationState']['state'] except KeyError: diff --git a/airflow/providers/databricks/example_dags/example_databricks.py b/airflow/providers/databricks/example_dags/example_databricks.py index 55e5c53e915b3..4bc16013b645a 100644 --- a/airflow/providers/databricks/example_dags/example_databricks.py +++ b/airflow/providers/databricks/example_dags/example_databricks.py @@ -51,37 +51,24 @@ new_cluster = { 'spark_version': '2.1.0-db3-scala2.11', 'node_type_id': 'r3.xlarge', - 'aws_attributes': { - 'availability': 'ON_DEMAND' - }, - 'num_workers': 8 + 'aws_attributes': {'availability': 'ON_DEMAND'}, + 'num_workers': 8, } notebook_task_params = { 'new_cluster': new_cluster, - 'notebook_task': { - 'notebook_path': '/Users/airflow@example.com/PrepareData', - }, + 'notebook_task': {'notebook_path': '/Users/airflow@example.com/PrepareData',}, } # Example of using the JSON parameter to initialize the operator. - notebook_task = DatabricksSubmitRunOperator( - task_id='notebook_task', - json=notebook_task_params - ) + notebook_task = DatabricksSubmitRunOperator(task_id='notebook_task', json=notebook_task_params) # Example of using the named parameters of DatabricksSubmitRunOperator # to initialize the operator. 
spark_jar_task = DatabricksSubmitRunOperator( task_id='spark_jar_task', new_cluster=new_cluster, - spark_jar_task={ - 'main_class_name': 'com.example.ProcessData' - }, - libraries=[ - { - 'jar': 'dbfs:/lib/etl-0.1.jar' - } - ] + spark_jar_task={'main_class_name': 'com.example.ProcessData'}, + libraries=[{'jar': 'dbfs:/lib/etl-0.1.jar'}], ) notebook_task >> spark_jar_task diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index cba85942cf5aa..608d63ad0fb5c 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -48,6 +48,7 @@ class RunState: """ Utility class for the run state concept of Databricks runs. """ + def __init__(self, life_cycle_state, result_state, state_message): self.life_cycle_state = life_cycle_state self.result_state = result_state @@ -58,10 +59,12 @@ def is_terminal(self) -> bool: """True if the current state is a terminal state.""" if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES: raise AirflowException( - ('Unexpected life cycle state: {}: If the state has ' - 'been introduced recently, please check the Databricks user ' - 'guide for troubleshooting information').format( - self.life_cycle_state)) + ( + 'Unexpected life cycle state: {}: If the state has ' + 'been introduced recently, please check the Databricks user ' + 'guide for troubleshooting information' + ).format(self.life_cycle_state) + ) return self.life_cycle_state in ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR') @property @@ -70,9 +73,11 @@ def is_successful(self) -> bool: return self.result_state == 'SUCCESS' def __eq__(self, other): - return self.life_cycle_state == other.life_cycle_state and \ - self.result_state == other.result_state and \ - self.state_message == other.state_message + return ( + self.life_cycle_state == other.life_cycle_state + and self.result_state == other.result_state + and self.state_message == other.state_message + ) def __repr__(self): return str(self.__dict__) @@ -94,8 +99,10 @@ class DatabricksHook(BaseHook): # noqa might be a floating point number). :type retry_delay: float """ - def __init__(self, databricks_conn_id='databricks_default', timeout_seconds=180, retry_limit=3, - retry_delay=1.0): + + def __init__( + self, databricks_conn_id='databricks_default', timeout_seconds=180, retry_limit=3, retry_delay=1.0 + ): super().__init__() self.databricks_conn_id = databricks_conn_id self.databricks_conn = self.get_connection(databricks_conn_id) @@ -156,9 +163,7 @@ def _do_api_call(self, endpoint_info, json): auth = (self.databricks_conn.login, self.databricks_conn.password) host = self.databricks_conn.host - url = 'https://{host}/{endpoint}'.format( - host=self._parse_host(host), - endpoint=endpoint) + url = 'https://{host}/{endpoint}'.format(host=self._parse_host(host), endpoint=endpoint) if method == 'GET': request_func = requests.get @@ -178,30 +183,30 @@ def _do_api_call(self, endpoint_info, json): params=json if method == 'GET' else None, auth=auth, headers=USER_AGENT_HEADER, - timeout=self.timeout_seconds) + timeout=self.timeout_seconds, + ) response.raise_for_status() return response.json() except requests_exceptions.RequestException as e: if not _retryable_error(e): # In this case, the user probably made a mistake. # Don't retry. 
- raise AirflowException('Response: {0}, Status Code: {1}'.format( - e.response.content, e.response.status_code)) + raise AirflowException( + 'Response: {0}, Status Code: {1}'.format(e.response.content, e.response.status_code) + ) self._log_request_error(attempt_num, e) if attempt_num == self.retry_limit: - raise AirflowException(('API requests to Databricks failed {} times. ' + - 'Giving up.').format(self.retry_limit)) + raise AirflowException( + ('API requests to Databricks failed {} times. ' + 'Giving up.').format(self.retry_limit) + ) attempt_num += 1 sleep(self.retry_delay) def _log_request_error(self, attempt_num, error): - self.log.error( - 'Attempt %s API Request to Databricks failed with reason: %s', - attempt_num, error - ) + self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error) def run_now(self, json): """ @@ -301,18 +306,14 @@ def terminate_cluster(self, json: dict) -> None: def _retryable_error(exception): - return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) \ - or exception.response is not None and exception.response.status_code >= 500 + return ( + isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) + or exception.response is not None + and exception.response.status_code >= 500 + ) -RUN_LIFE_CYCLE_STATES = [ - 'PENDING', - 'RUNNING', - 'TERMINATING', - 'TERMINATED', - 'SKIPPED', - 'INTERNAL_ERROR' -] +RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] class _TokenAuth(AuthBase): @@ -320,6 +321,7 @@ class _TokenAuth(AuthBase): Helper class for requests Auth field. AuthBase requires you to implement the __call__ magic function. """ + def __init__(self, token): self.token = token diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 3bc8f5a67d2f9..3e2a2b8422ae1 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -49,12 +49,10 @@ def _deep_string_coerce(content, json_path='json'): elif isinstance(content, (list, tuple)): return [coerce(e, '{0}[{1}]'.format(json_path, i)) for i, e in enumerate(content)] elif isinstance(content, dict): - return {k: coerce(v, '{0}[{1}]'.format(json_path, k)) - for k, v in list(content.items())} + return {k: coerce(v, '{0}[{1}]'.format(json_path, k)) for k, v in list(content.items())} else: param_type = type(content) - msg = 'Type {0} used for parameter {1} is not a number or a string' \ - .format(param_type, json_path) + msg = 'Type {0} used for parameter {1} is not a number or a string'.format(param_type, json_path) raise AirflowException(msg) @@ -81,9 +79,7 @@ def _handle_databricks_operator_execution(operator, hook, log, context): log.info('View run status, Spark UI, and logs at %s', run_page_url) return else: - error_message = '{t} failed with terminal state: {s}'.format( - t=operator.task_id, - s=run_state) + error_message = '{t} failed with terminal state: {s}'.format(t=operator.task_id, s=run_state) raise AirflowException(error_message) else: log.info('%s in run state: %s', operator.task_id, run_state) @@ -236,6 +232,7 @@ class DatabricksSubmitRunOperator(BaseOperator): :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. 
:type do_xcom_push: bool """ + # Used in airflow.models.BaseOperator template_fields = ('json',) # Databricks brand color (blue) under white text @@ -245,23 +242,25 @@ class DatabricksSubmitRunOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults def __init__( - self, *, - json=None, - spark_jar_task=None, - notebook_task=None, - spark_python_task=None, - spark_submit_task=None, - new_cluster=None, - existing_cluster_id=None, - libraries=None, - run_name=None, - timeout_seconds=None, - databricks_conn_id='databricks_default', - polling_period_seconds=30, - databricks_retry_limit=3, - databricks_retry_delay=1, - do_xcom_push=False, - **kwargs): + self, + *, + json=None, + spark_jar_task=None, + notebook_task=None, + spark_python_task=None, + spark_submit_task=None, + new_cluster=None, + existing_cluster_id=None, + libraries=None, + run_name=None, + timeout_seconds=None, + databricks_conn_id='databricks_default', + polling_period_seconds=30, + databricks_retry_limit=3, + databricks_retry_delay=1, + do_xcom_push=False, + **kwargs, + ): """ Creates a new ``DatabricksSubmitRunOperator``. """ @@ -301,7 +300,8 @@ def _get_hook(self): return DatabricksHook( self.databricks_conn_id, retry_limit=self.databricks_retry_limit, - retry_delay=self.databricks_retry_delay) + retry_delay=self.databricks_retry_delay, + ) def execute(self, context): hook = self._get_hook() @@ -311,10 +311,7 @@ def execute(self, context): def on_kill(self): hook = self._get_hook() hook.cancel_run(self.run_id) - self.log.info( - 'Task: %s with run_id: %s was requested to be cancelled.', - self.task_id, self.run_id - ) + self.log.info('Task: %s with run_id: %s was requested to be cancelled.', self.task_id, self.run_id) class DatabricksRunNowOperator(BaseOperator): @@ -448,6 +445,7 @@ class DatabricksRunNowOperator(BaseOperator): :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. :type do_xcom_push: bool """ + # Used in airflow.models.BaseOperator template_fields = ('json',) # Databricks brand color (blue) under white text @@ -457,18 +455,20 @@ class DatabricksRunNowOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults def __init__( - self, *, - job_id=None, - json=None, - notebook_params=None, - python_params=None, - spark_submit_params=None, - databricks_conn_id='databricks_default', - polling_period_seconds=30, - databricks_retry_limit=3, - databricks_retry_delay=1, - do_xcom_push=False, - **kwargs): + self, + *, + job_id=None, + json=None, + notebook_params=None, + python_params=None, + spark_submit_params=None, + databricks_conn_id='databricks_default', + polling_period_seconds=30, + databricks_retry_limit=3, + databricks_retry_delay=1, + do_xcom_push=False, + **kwargs, + ): """ Creates a new ``DatabricksRunNowOperator``. 
""" @@ -497,7 +497,8 @@ def _get_hook(self): return DatabricksHook( self.databricks_conn_id, retry_limit=self.databricks_retry_limit, - retry_delay=self.databricks_retry_delay) + retry_delay=self.databricks_retry_delay, + ) def execute(self, context): hook = self._get_hook() @@ -507,7 +508,4 @@ def execute(self, context): def on_kill(self): hook = self._get_hook() hook.cancel_run(self.run_id) - self.log.info( - 'Task: %s with run_id: %s was requested to be cancelled.', - self.task_id, self.run_id - ) + self.log.info('Task: %s with run_id: %s was requested to be cancelled.', self.task_id, self.run_id) diff --git a/airflow/providers/datadog/hooks/datadog.py b/airflow/providers/datadog/hooks/datadog.py index 7f6163328a297..e81118207321e 100644 --- a/airflow/providers/datadog/hooks/datadog.py +++ b/airflow/providers/datadog/hooks/datadog.py @@ -38,6 +38,7 @@ class DatadogHook(BaseHook, LoggingMixin): :param datadog_conn_id: The connection to datadog, containing metadata for api keys. :param datadog_conn_id: str """ + def __init__(self, datadog_conn_id: str = 'datadog_default') -> None: super().__init__() conn = self.get_connection(datadog_conn_id) @@ -50,8 +51,7 @@ def __init__(self, datadog_conn_id: str = 'datadog_default') -> None: self.host = conn.host if self.api_key is None: - raise AirflowException("api_key must be specified in the " - "Datadog connection details") + raise AirflowException("api_key must be specified in the " "Datadog connection details") self.log.info("Setting up api keys for Datadog") initialize(api_key=self.api_key, app_key=self.app_key) @@ -64,11 +64,14 @@ def validate_response(self, response: Dict[str, Any]) -> None: self.log.error("Datadog returned: %s", response) raise AirflowException("Error status received from Datadog") - def send_metric(self, metric_name: str, - datapoint: Union[float, int], - tags: Optional[List[str]] = None, - type_: Optional[str] = None, - interval: Optional[int] = None) -> Dict[str, Any]: + def send_metric( + self, + metric_name: str, + datapoint: Union[float, int], + tags: Optional[List[str]] = None, + type_: Optional[str] = None, + interval: Optional[int] = None, + ) -> Dict[str, Any]: """ Sends a single datapoint metric to DataDog @@ -84,20 +87,13 @@ def send_metric(self, metric_name: str, :type interval: int """ response = api.Metric.send( - metric=metric_name, - points=datapoint, - host=self.host, - tags=tags, - type=type_, - interval=interval) + metric=metric_name, points=datapoint, host=self.host, tags=tags, type=type_, interval=interval + ) self.validate_response(response) return response - def query_metric(self, - query: str, - from_seconds_ago: int, - to_seconds_ago: int) -> Dict[str, Any]: + def query_metric(self, query: str, from_seconds_ago: int, to_seconds_ago: int) -> Dict[str, Any]: """ Queries datadog for a specific metric, potentially with some function applied to it and returns the results. 
@@ -111,25 +107,25 @@ def query_metric(self, """ now = int(time.time()) - response = api.Metric.query( - start=now - from_seconds_ago, - end=now - to_seconds_ago, - query=query) + response = api.Metric.query(start=now - from_seconds_ago, end=now - to_seconds_ago, query=query) self.validate_response(response) return response # pylint: disable=too-many-arguments - def post_event(self, title: str, - text: str, - aggregation_key: Optional[str] = None, - alert_type: Optional[str] = None, - date_happened: Optional[int] = None, - handle: Optional[str] = None, - priority: Optional[str] = None, - related_event_id: Optional[int] = None, - tags: Optional[List[str]] = None, - device_name: Optional[List[str]] = None) -> Dict[str, Any]: + def post_event( + self, + title: str, + text: str, + aggregation_key: Optional[str] = None, + alert_type: Optional[str] = None, + date_happened: Optional[int] = None, + handle: Optional[str] = None, + priority: Optional[str] = None, + related_event_id: Optional[int] = None, + tags: Optional[List[str]] = None, + device_name: Optional[List[str]] = None, + ) -> Dict[str, Any]: """ Posts an event to datadog (processing finished, potentially alerts, other issues) Think about this as a means to maintain persistence of alerts, rather than @@ -170,7 +166,8 @@ def post_event(self, title: str, tags=tags, host=self.host, device_name=device_name, - source_type_name=self.source_type_name) + source_type_name=self.source_type_name, + ) self.validate_response(response) return response diff --git a/airflow/providers/datadog/sensors/datadog.py b/airflow/providers/datadog/sensors/datadog.py index 29969d977f1cc..ec298950afa3a 100644 --- a/airflow/providers/datadog/sensors/datadog.py +++ b/airflow/providers/datadog/sensors/datadog.py @@ -36,19 +36,22 @@ class DatadogSensor(BaseSensorOperator): :param datadog_conn_id: The connection to datadog, containing metadata for api keys. :param datadog_conn_id: str """ + ui_color = '#66c3dd' @apply_defaults def __init__( - self, *, - datadog_conn_id: str = 'datadog_default', - from_seconds_ago: int = 3600, - up_to_seconds_from_now: int = 0, - priority: Optional[str] = None, - sources: Optional[str] = None, - tags: Optional[List[str]] = None, - response_check: Optional[Callable[[Dict[str, Any]], bool]] = None, - **kwargs) -> None: + self, + *, + datadog_conn_id: str = 'datadog_default', + from_seconds_ago: int = 3600, + up_to_seconds_from_now: int = 0, + priority: Optional[str] = None, + sources: Optional[str] = None, + tags: Optional[List[str]] = None, + response_check: Optional[Callable[[Dict[str, Any]], bool]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.datadog_conn_id = datadog_conn_id self.from_seconds_ago = from_seconds_ago @@ -70,7 +73,8 @@ def poke(self, context: Dict[str, Any]) -> bool: end=self.up_to_seconds_from_now, priority=self.priority, sources=self.sources, - tags=self.tags) + tags=self.tags, + ) if isinstance(response, dict) and response.get('status', 'ok') != 'ok': self.log.error("Unexpected Datadog result: %s", response) diff --git a/airflow/providers/dingding/example_dags/example_dingding.py b/airflow/providers/dingding/example_dags/example_dingding.py index f2b483b7e4626..003d0da76af1c 100644 --- a/airflow/providers/dingding/example_dags/example_dingding.py +++ b/airflow/providers/dingding/example_dags/example_dingding.py @@ -38,13 +38,14 @@ def failure_callback(context): :param context: The context of the executed task. 
:type context: dict """ - message = 'AIRFLOW TASK FAILURE TIPS:\n' \ - 'DAG: {}\n' \ - 'TASKS: {}\n' \ - 'Reason: {}\n' \ - .format(context['task_instance'].dag_id, - context['task_instance'].task_id, - context['exception']) + message = ( + 'AIRFLOW TASK FAILURE TIPS:\n' + 'DAG: {}\n' + 'TASKS: {}\n' + 'Reason: {}\n'.format( + context['task_instance'].dag_id, context['task_instance'].task_id, context['exception'] + ) + ) return DingdingOperator( task_id='dingding_success_callback', dingding_conn_id='dingding_default', @@ -73,7 +74,7 @@ def failure_callback(context): message_type='text', message='Airflow dingding text message remind none', at_mobiles=None, - at_all=False + at_all=False, ) # [END howto_operator_dingding] @@ -83,7 +84,7 @@ def failure_callback(context): message_type='text', message='Airflow dingding text message remind specific users', at_mobiles=['156XXXXXXXX', '130XXXXXXXX'], - at_all=False + at_all=False, ) text_msg_remind_include_invalid = DingdingOperator( @@ -93,7 +94,7 @@ def failure_callback(context): message='Airflow dingding text message remind users including invalid', # 123 is invalid user or user not in the group at_mobiles=['156XXXXXXXX', '123'], - at_all=False + at_all=False, ) # [START howto_operator_dingding_remind_users] @@ -105,7 +106,7 @@ def failure_callback(context): # list of user phone/email here in the group # when at_all is specific will cover at_mobiles at_mobiles=['156XXXXXXXX', '130XXXXXXXX'], - at_all=True + at_all=True, ) # [END howto_operator_dingding_remind_users] @@ -117,8 +118,8 @@ def failure_callback(context): 'title': 'Airflow dingding link message', 'text': 'Airflow official documentation link', 'messageUrl': 'http://airflow.apache.org', - 'picURL': 'http://airflow.apache.org/_images/pin_large.png' - } + 'picURL': 'http://airflow.apache.org/_images/pin_large.png', + }, ) # [START howto_operator_dingding_rich_text] @@ -129,12 +130,12 @@ def failure_callback(context): message={ 'title': 'Airflow dingding markdown message', 'text': '# Markdown message title\n' - 'content content .. \n' - '### sub-title\n' - '![logo](http://airflow.apache.org/_images/pin_large.png)' + 'content content .. 
\n' + '### sub-title\n' + '![logo](http://airflow.apache.org/_images/pin_large.png)', }, at_mobiles=['156XXXXXXXX'], - at_all=False + at_all=False, ) # [END howto_operator_dingding_rich_text] @@ -145,13 +146,13 @@ def failure_callback(context): message={ 'title': 'Airflow dingding single actionCard message', 'text': 'Airflow dingding single actionCard message\n' - '![logo](http://airflow.apache.org/_images/pin_large.png)\n' - 'This is a official logo in Airflow website.', + '![logo](http://airflow.apache.org/_images/pin_large.png)\n' + 'This is a official logo in Airflow website.', 'hideAvatar': '0', 'btnOrientation': '0', 'singleTitle': 'read more', - 'singleURL': 'http://airflow.apache.org' - } + 'singleURL': 'http://airflow.apache.org', + }, ) multi_action_card_msg = DingdingOperator( @@ -161,21 +162,15 @@ def failure_callback(context): message={ 'title': 'Airflow dingding multi actionCard message', 'text': 'Airflow dingding multi actionCard message\n' - '![logo](http://airflow.apache.org/_images/pin_large.png)\n' - 'Airflow documentation and github', + '![logo](http://airflow.apache.org/_images/pin_large.png)\n' + 'Airflow documentation and github', 'hideAvatar': '0', 'btnOrientation': '0', 'btns': [ - { - 'title': 'Airflow Documentation', - 'actionURL': 'http://airflow.apache.org' - }, - { - 'title': 'Airflow Github', - 'actionURL': 'https://github.com/apache/airflow' - } - ] - } + {'title': 'Airflow Documentation', 'actionURL': 'http://airflow.apache.org'}, + {'title': 'Airflow Github', 'actionURL': 'https://github.com/apache/airflow'}, + ], + }, ) feed_card_msg = DingdingOperator( @@ -187,35 +182,35 @@ def failure_callback(context): { "title": "Airflow DAG feed card", "messageURL": "https://airflow.readthedocs.io/en/latest/ui.html", - "picURL": "http://airflow.apache.org/_images/dags.png" + "picURL": "http://airflow.apache.org/_images/dags.png", }, { "title": "Airflow tree feed card", "messageURL": "https://airflow.readthedocs.io/en/latest/ui.html", - "picURL": "http://airflow.apache.org/_images/tree.png" + "picURL": "http://airflow.apache.org/_images/tree.png", }, { "title": "Airflow graph feed card", "messageURL": "https://airflow.readthedocs.io/en/latest/ui.html", - "picURL": "http://airflow.apache.org/_images/graph.png" - } + "picURL": "http://airflow.apache.org/_images/graph.png", + }, ] - } + }, ) msg_failure_callback = DingdingOperator( task_id='msg_failure_callback', dingding_conn_id='dingding_default', message_type='not_support_msg_type', - message="" + message="", ) [ text_msg_remind_none, text_msg_remind_specific, text_msg_remind_include_invalid, - text_msg_remind_all + text_msg_remind_all, ] >> link_msg >> markdown_msg >> [ single_action_card_msg, - multi_action_card_msg + multi_action_card_msg, ] >> feed_card_msg >> msg_failure_callback diff --git a/airflow/providers/dingding/hooks/dingding.py b/airflow/providers/dingding/hooks/dingding.py index 00e8b4e6f4968..c002c408e267e 100644 --- a/airflow/providers/dingding/hooks/dingding.py +++ b/airflow/providers/dingding/hooks/dingding.py @@ -46,15 +46,16 @@ class DingdingHook(HttpHook): :type at_all: bool """ - def __init__(self, - dingding_conn_id='dingding_default', - message_type='text', - message=None, - at_mobiles=None, - at_all=False, - *args, - **kwargs - ): + def __init__( + self, + dingding_conn_id='dingding_default', + message_type='text', + message=None, + at_mobiles=None, + at_all=False, + *args, + **kwargs, + ): super().__init__(http_conn_id=dingding_conn_id, *args, **kwargs) self.message_type = message_type 
self.message = message @@ -68,8 +69,9 @@ def _get_endpoint(self): conn = self.get_connection(self.http_conn_id) token = conn.password if not token: - raise AirflowException('Dingding token is requests but get nothing, ' - 'check you conn_id configuration.') + raise AirflowException( + 'Dingding token is requests but get nothing, ' 'check you conn_id configuration.' + ) return 'robot/send?access_token={}'.format(token) def _build_message(self): @@ -81,19 +83,11 @@ def _build_message(self): if self.message_type in ['text', 'markdown']: data = { 'msgtype': self.message_type, - self.message_type: { - 'content': self.message - } if self.message_type == 'text' else self.message, - 'at': { - 'atMobiles': self.at_mobiles, - 'isAtAll': self.at_all - } + self.message_type: {'content': self.message} if self.message_type == 'text' else self.message, + 'at': {'atMobiles': self.at_mobiles, 'isAtAll': self.at_all}, } else: - data = { - 'msgtype': self.message_type, - self.message_type: self.message - } + data = {'msgtype': self.message_type, self.message_type: self.message} return json.dumps(data) def get_conn(self, headers=None): @@ -117,17 +111,18 @@ def send(self): """ support_type = ['text', 'link', 'markdown', 'actionCard', 'feedCard'] if self.message_type not in support_type: - raise ValueError('DingdingWebhookHook only support {} ' - 'so far, but receive {}'.format(support_type, self.message_type)) + raise ValueError( + 'DingdingWebhookHook only support {} ' + 'so far, but receive {}'.format(support_type, self.message_type) + ) data = self._build_message() self.log.info('Sending Dingding type %s message %s', self.message_type, data) - resp = self.run(endpoint=self._get_endpoint(), - data=data, - headers={'Content-Type': 'application/json'}) + resp = self.run( + endpoint=self._get_endpoint(), data=data, headers={'Content-Type': 'application/json'} + ) # Dingding success send message will with errcode equal to 0 if int(resp.json().get('errcode')) != 0: - raise AirflowException('Send Dingding message failed, receive error ' - f'message {resp.text}') + raise AirflowException('Send Dingding message failed, receive error ' f'message {resp.text}') self.log.info('Success Send Dingding message') diff --git a/airflow/providers/dingding/operators/dingding.py b/airflow/providers/dingding/operators/dingding.py index 0d1ba6687ca94..c25f8b9f11c64 100644 --- a/airflow/providers/dingding/operators/dingding.py +++ b/airflow/providers/dingding/operators/dingding.py @@ -42,17 +42,21 @@ class DingdingOperator(BaseOperator): :param at_all: Remind all people in group or not. 
If True, will overwrite ``at_mobiles`` :type at_all: bool """ + template_fields = ('message',) ui_color = '#4ea4d4' # Dingding icon color @apply_defaults - def __init__(self, *, - dingding_conn_id='dingding_default', - message_type='text', - message=None, - at_mobiles=None, - at_all=False, - **kwargs): + def __init__( + self, + *, + dingding_conn_id='dingding_default', + message_type='text', + message=None, + at_mobiles=None, + at_all=False, + **kwargs, + ): super().__init__(**kwargs) self.dingding_conn_id = dingding_conn_id self.message_type = message_type @@ -63,10 +67,6 @@ def __init__(self, *, def execute(self, context): self.log.info('Sending Dingding message.') hook = DingdingHook( - self.dingding_conn_id, - self.message_type, - self.message, - self.at_mobiles, - self.at_all + self.dingding_conn_id, self.message_type, self.message, self.at_mobiles, self.at_all ) hook.send() diff --git a/airflow/providers/discord/hooks/discord_webhook.py b/airflow/providers/discord/hooks/discord_webhook.py index 1d5f199663ee1..b88efd73e2de8 100644 --- a/airflow/providers/discord/hooks/discord_webhook.py +++ b/airflow/providers/discord/hooks/discord_webhook.py @@ -54,17 +54,18 @@ class DiscordWebhookHook(HttpHook): :type proxy: str """ - def __init__(self, - http_conn_id: Optional[str] = None, - webhook_endpoint: Optional[str] = None, - message: str = "", - username: Optional[str] = None, - avatar_url: Optional[str] = None, - tts: bool = False, - proxy: Optional[str] = None, - *args: Any, - **kwargs: Any - ) -> None: + def __init__( + self, + http_conn_id: Optional[str] = None, + webhook_endpoint: Optional[str] = None, + message: str = "", + username: Optional[str] = None, + avatar_url: Optional[str] = None, + tts: bool = False, + proxy: Optional[str] = None, + *args: Any, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.http_conn_id: Any = http_conn_id self.webhook_endpoint = self._get_webhook_endpoint(http_conn_id, webhook_endpoint) @@ -90,13 +91,15 @@ def _get_webhook_endpoint(self, http_conn_id: Optional[str], webhook_endpoint: O extra = conn.extra_dejson endpoint = extra.get('webhook_endpoint', '') else: - raise AirflowException('Cannot get webhook endpoint: No valid Discord ' - 'webhook endpoint or http_conn_id supplied.') + raise AirflowException( + 'Cannot get webhook endpoint: No valid Discord ' 'webhook endpoint or http_conn_id supplied.' + ) # make sure endpoint matches the expected Discord webhook format if not re.match('^webhooks/[0-9]+/[a-zA-Z0-9_-]+$', endpoint): - raise AirflowException('Expected Discord webhook endpoint in the form ' - 'of "webhooks/{webhook.id}/{webhook.token}".') + raise AirflowException( + 'Expected Discord webhook endpoint in the form ' 'of "webhooks/{webhook.id}/{webhook.token}".' 
+ ) return endpoint @@ -119,8 +122,7 @@ def _build_discord_payload(self) -> str: if len(self.message) <= 2000: payload['content'] = self.message else: - raise AirflowException('Discord message length must be 2000 or fewer ' - 'characters.') + raise AirflowException('Discord message length must be 2000 or fewer ' 'characters.') return json.dumps(payload) @@ -135,7 +137,9 @@ def execute(self) -> None: discord_payload = self._build_discord_payload() - self.run(endpoint=self.webhook_endpoint, - data=discord_payload, - headers={'Content-type': 'application/json'}, - extra_options={'proxies': proxies}) + self.run( + endpoint=self.webhook_endpoint, + data=discord_payload, + headers={'Content-type': 'application/json'}, + extra_options={'proxies': proxies}, + ) diff --git a/airflow/providers/discord/operators/discord_webhook.py b/airflow/providers/discord/operators/discord_webhook.py index 6b97920e2ea20..ea8758156e6e0 100644 --- a/airflow/providers/discord/operators/discord_webhook.py +++ b/airflow/providers/discord/operators/discord_webhook.py @@ -57,19 +57,19 @@ class DiscordWebhookOperator(SimpleHttpOperator): template_fields = ['username', 'message'] @apply_defaults - def __init__(self, *, - http_conn_id: Optional[str] = None, - webhook_endpoint: Optional[str] = None, - message: str = "", - username: Optional[str] = None, - avatar_url: Optional[str] = None, - tts: bool = False, - proxy: Optional[str] = None, - **kwargs) -> None: - super().__init__( - endpoint=webhook_endpoint, - **kwargs - ) + def __init__( + self, + *, + http_conn_id: Optional[str] = None, + webhook_endpoint: Optional[str] = None, + message: str = "", + username: Optional[str] = None, + avatar_url: Optional[str] = None, + tts: bool = False, + proxy: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(endpoint=webhook_endpoint, **kwargs) if not http_conn_id: raise AirflowException('No valid Discord http_conn_id supplied.') @@ -94,6 +94,6 @@ def execute(self, context: Dict) -> None: self.username, self.avatar_url, self.tts, - self.proxy + self.proxy, ) self.hook.execute() diff --git a/airflow/providers/docker/example_dags/example_docker.py b/airflow/providers/docker/example_dags/example_docker.py index 627f6300a8093..a5826bc805e9a 100644 --- a/airflow/providers/docker/example_dags/example_docker.py +++ b/airflow/providers/docker/example_dags/example_docker.py @@ -29,7 +29,7 @@ 'email_on_failure': False, 'email_on_retry': False, 'retries': 1, - 'retry_delay': timedelta(minutes=5) + 'retry_delay': timedelta(minutes=5), } dag = DAG( @@ -39,16 +39,9 @@ start_date=days_ago(2), ) -t1 = BashOperator( - task_id='print_date', - bash_command='date', - dag=dag) +t1 = BashOperator(task_id='print_date', bash_command='date', dag=dag) -t2 = BashOperator( - task_id='sleep', - bash_command='sleep 5', - retries=3, - dag=dag) +t2 = BashOperator(task_id='sleep', bash_command='sleep 5', retries=3, dag=dag) t3 = DockerOperator( api_version='1.19', @@ -57,14 +50,11 @@ image='centos:latest', network_mode='bridge', task_id='docker_op_tester', - dag=dag + dag=dag, ) -t4 = BashOperator( - task_id='print_hello', - bash_command='echo "hello world!!!"', - dag=dag) +t4 = BashOperator(task_id='print_hello', bash_command='echo "hello world!!!"', dag=dag) t1 >> t2 diff --git a/airflow/providers/docker/example_dags/example_docker_swarm.py b/airflow/providers/docker/example_dags/example_docker_swarm.py index 7dc056edad0bb..d27337f4d48b3 100644 --- a/airflow/providers/docker/example_dags/example_docker_swarm.py +++ 
b/airflow/providers/docker/example_dags/example_docker_swarm.py @@ -26,7 +26,7 @@ 'depends_on_past': False, 'email': ['airflow@example.com'], 'email_on_failure': False, - 'email_on_retry': False + 'email_on_retry': False, } dag = DAG( @@ -34,7 +34,7 @@ default_args=default_args, schedule_interval=timedelta(minutes=10), start_date=days_ago(1), - catchup=False + catchup=False, ) with dag as dag: diff --git a/airflow/providers/docker/hooks/docker.py b/airflow/providers/docker/hooks/docker.py index bae0e7f5a046f..f57b273a9f14a 100644 --- a/airflow/providers/docker/hooks/docker.py +++ b/airflow/providers/docker/hooks/docker.py @@ -33,12 +33,14 @@ class DockerHook(BaseHook, LoggingMixin): credentials and extra configuration are stored :type docker_conn_id: str """ - def __init__(self, - docker_conn_id='docker_default', - base_url: Optional[str] = None, - version: Optional[str] = None, - tls: Optional[str] = None - ) -> None: + + def __init__( + self, + docker_conn_id='docker_default', + base_url: Optional[str] = None, + version: Optional[str] = None, + tls: Optional[str] = None, + ) -> None: super().__init__() if not base_url: raise AirflowException('No Docker base URL provided') @@ -65,11 +67,7 @@ def __init__(self, self.__reauth = extra_options.get('reauth') != 'no' def get_conn(self) -> APIClient: - client = APIClient( - base_url=self.__base_url, - version=self.__version, - tls=self.__tls - ) + client = APIClient(base_url=self.__base_url, version=self.__version, tls=self.__tls) self.__login(client) return client @@ -81,7 +79,7 @@ def __login(self, client) -> None: password=self.__password, registry=self.__registry, email=self.__email, - reauth=self.__reauth + reauth=self.__reauth, ) self.log.debug('Login successful') except APIError as docker_error: diff --git a/airflow/providers/docker/operators/docker.py b/airflow/providers/docker/operators/docker.py index 81861ffbe7229..6148f5b367804 100644 --- a/airflow/providers/docker/operators/docker.py +++ b/airflow/providers/docker/operators/docker.py @@ -126,43 +126,49 @@ class DockerOperator(BaseOperator): :param cap_add: Include container capabilities :type cap_add: list[str] """ + template_fields = ('command', 'environment', 'container_name') - template_ext = ('.sh', '.bash',) + template_ext = ( + '.sh', + '.bash', + ) # pylint: disable=too-many-arguments,too-many-locals @apply_defaults def __init__( - self, *, - image: str, - api_version: Optional[str] = None, - command: Optional[Union[str, List[str]]] = None, - container_name: Optional[str] = None, - cpus: float = 1.0, - docker_url: str = 'unix://var/run/docker.sock', - environment: Optional[Dict] = None, - private_environment: Optional[Dict] = None, - force_pull: bool = False, - mem_limit: Optional[Union[float, str]] = None, - host_tmp_dir: Optional[str] = None, - network_mode: Optional[str] = None, - tls_ca_cert: Optional[str] = None, - tls_client_cert: Optional[str] = None, - tls_client_key: Optional[str] = None, - tls_hostname: Optional[Union[str, bool]] = None, - tls_ssl_version: Optional[str] = None, - tmp_dir: str = '/tmp/airflow', - user: Optional[Union[str, int]] = None, - volumes: Optional[List[str]] = None, - working_dir: Optional[str] = None, - xcom_all: bool = False, - docker_conn_id: Optional[str] = None, - dns: Optional[List[str]] = None, - dns_search: Optional[List[str]] = None, - auto_remove: bool = False, - shm_size: Optional[int] = None, - tty: Optional[bool] = False, - cap_add: Optional[Iterable[str]] = None, - **kwargs) -> None: + self, + *, + image: str, + api_version: 
Optional[str] = None, + command: Optional[Union[str, List[str]]] = None, + container_name: Optional[str] = None, + cpus: float = 1.0, + docker_url: str = 'unix://var/run/docker.sock', + environment: Optional[Dict] = None, + private_environment: Optional[Dict] = None, + force_pull: bool = False, + mem_limit: Optional[Union[float, str]] = None, + host_tmp_dir: Optional[str] = None, + network_mode: Optional[str] = None, + tls_ca_cert: Optional[str] = None, + tls_client_cert: Optional[str] = None, + tls_client_key: Optional[str] = None, + tls_hostname: Optional[Union[str, bool]] = None, + tls_ssl_version: Optional[str] = None, + tmp_dir: str = '/tmp/airflow', + user: Optional[Union[str, int]] = None, + volumes: Optional[List[str]] = None, + working_dir: Optional[str] = None, + xcom_all: bool = False, + docker_conn_id: Optional[str] = None, + dns: Optional[List[str]] = None, + dns_search: Optional[List[str]] = None, + auto_remove: bool = False, + shm_size: Optional[int] = None, + tty: Optional[bool] = False, + cap_add: Optional[Iterable[str]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.api_version = api_version @@ -210,7 +216,7 @@ def get_hook(self) -> DockerHook: docker_conn_id=self.docker_conn_id, base_url=self.docker_url, version=self.api_version, - tls=self.__get_tls_config() + tls=self.__get_tls_config(), ) def _run_image(self) -> Optional[str]: @@ -237,17 +243,15 @@ def _run_image(self) -> Optional[str]: dns_search=self.dns_search, cpu_shares=int(round(self.cpus * 1024)), mem_limit=self.mem_limit, - cap_add=self.cap_add), + cap_add=self.cap_add, + ), image=self.image, user=self.user, working_dir=self.working_dir, tty=self.tty, ) - lines = self.cli.attach(container=self.container['Id'], - stdout=True, - stderr=True, - stream=True) + lines = self.cli.attach(container=self.container['Id'], stdout=True, stderr=True, stream=True) self.cli.start(self.container['Id']) @@ -266,8 +270,7 @@ def _run_image(self) -> Optional[str]: # duplicated conditional logic because of expensive operation ret = None if self.do_xcom_push: - ret = self.cli.logs(container=self.container['Id']) \ - if self.xcom_all else line.encode('utf-8') + ret = self.cli.logs(container=self.container['Id']) if self.xcom_all else line.encode('utf-8') if self.auto_remove: self.cli.remove_container(self.container['Id']) @@ -296,11 +299,7 @@ def _get_cli(self) -> APIClient: return self.get_hook().get_conn() else: tls_config = self.__get_tls_config() - return APIClient( - base_url=self.docker_url, - version=self.api_version, - tls=tls_config - ) + return APIClient(base_url=self.docker_url, version=self.api_version, tls=tls_config) def get_command(self) -> Union[List[str], str]: """ @@ -330,7 +329,7 @@ def __get_tls_config(self) -> Optional[tls.TLSConfig]: client_cert=(self.tls_client_cert, self.tls_client_key), verify=True, ssl_version=self.tls_ssl_version, # noqa - assert_hostname=self.tls_hostname + assert_hostname=self.tls_hostname, ) self.docker_url = self.docker_url.replace('tcp://', 'https://') return tls_config diff --git a/airflow/providers/docker/operators/docker_swarm.py b/airflow/providers/docker/operators/docker_swarm.py index 904fee2bb0fcb..92936ad1bbf3f 100644 --- a/airflow/providers/docker/operators/docker_swarm.py +++ b/airflow/providers/docker/operators/docker_swarm.py @@ -97,12 +97,7 @@ class DockerSwarmOperator(DockerOperator): """ @apply_defaults - def __init__( - self, - *, - image: str, - enable_logging: bool = True, - **kwargs) -> None: + def __init__(self, *, image: str, enable_logging: bool 
= True, **kwargs) -> None: super().__init__(image=image, **kwargs) self.enable_logging = enable_logging @@ -129,10 +124,10 @@ def _run_service(self) -> None: tty=self.tty, ), restart_policy=types.RestartPolicy(condition='none'), - resources=types.Resources(mem_limit=self.mem_limit) + resources=types.Resources(mem_limit=self.mem_limit), ), name='airflow-%s' % get_random_string(), - labels={'name': 'airflow__%s__%s' % (self.dag_id, self.task_id)} + labels={'name': 'airflow__%s__%s' % (self.dag_id, self.task_id)}, ) self.log.info('Service started: %s', str(self.service)) @@ -159,9 +154,7 @@ def _run_service(self) -> None: def _service_status(self) -> Optional[str]: if not self.cli: raise Exception("The 'cli' should be initialized before!") - return self.cli.tasks( - filters={'service': self.service['ID']} - )[0]['Status']['State'] + return self.cli.tasks(filters={'service': self.service['ID']})[0]['Status']['State'] def _has_service_terminated(self) -> bool: status = self._service_status() diff --git a/airflow/providers/elasticsearch/hooks/elasticsearch.py b/airflow/providers/elasticsearch/hooks/elasticsearch.py index 884a0d1414673..b322375664b7b 100644 --- a/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -30,11 +30,7 @@ class ElasticsearchHook(DbApiHook): conn_name_attr = 'elasticsearch_conn_id' default_conn_name = 'elasticsearch_default' - def __init__(self, - schema: str = "http", - connection: Optional[AirflowConnection] = None, - *args, - **kwargs): + def __init__(self, schema: str = "http", connection: Optional[AirflowConnection] = None, *args, **kwargs): super().__init__(*args, **kwargs) self.schema = schema self.connection = connection @@ -51,7 +47,7 @@ def get_conn(self) -> ESConnection: port=conn.port, user=conn.login or None, password=conn.password or None, - scheme=conn.schema or "http" + scheme=conn.schema or "http", ) if conn.extra_dejson.get('http_compress', False): @@ -74,8 +70,7 @@ def get_uri(self) -> str: host = conn.host if conn.port is not None: host += ':{port}'.format(port=conn.port) - uri = '{conn.conn_type}+{conn.schema}://{login}{host}/'.format( - conn=conn, login=login, host=host) + uri = '{conn.conn_type}+{conn.schema}://{login}{host}/'.format(conn=conn, login=login, host=host) extras_length = len(conn.extra_dejson) if not extras_length: @@ -85,8 +80,7 @@ def get_uri(self) -> str: for arg_key, arg_value in conn.extra_dejson.items(): extras_length -= 1 - uri += "{arg_key}={arg_value}".format( - arg_key=arg_key, arg_value=arg_value) + uri += "{arg_key}={arg_value}".format(arg_key=arg_key, arg_value=arg_value) if extras_length: uri += '&' diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py index 6ff2ef37e0802..3a65ea6c112db 100644 --- a/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/airflow/providers/elasticsearch/log/es_task_handler.py @@ -76,12 +76,10 @@ def __init__( # pylint: disable=too-many-arguments :param host: Elasticsearch host name """ es_kwargs = es_kwargs or {} - super().__init__( - base_log_folder, filename_template) + super().__init__(base_log_folder, filename_template) self.closed = False - self.log_id_template, self.log_id_jinja_template = \ - parse_template_string(log_id_template) + self.log_id_template, self.log_id_jinja_template = parse_template_string(log_id_template) self.client = elasticsearch.Elasticsearch([host], **es_kwargs) @@ -104,10 +102,9 @@ def _render_log_id(self, ti: 
TaskInstance, try_number: int) -> str: execution_date = self._clean_execution_date(ti.execution_date) else: execution_date = ti.execution_date.isoformat() - return self.log_id_template.format(dag_id=ti.dag_id, - task_id=ti.task_id, - execution_date=execution_date, - try_number=try_number) + return self.log_id_template.format( + dag_id=ti.dag_id, task_id=ti.task_id, execution_date=execution_date, try_number=try_number + ) @staticmethod def _clean_execution_date(execution_date: datetime) -> str: @@ -120,9 +117,7 @@ def _clean_execution_date(execution_date: datetime) -> str: """ return execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f") - def _read( - self, ti: TaskInstance, try_number: int, metadata: Optional[dict] = None - ) -> Tuple[str, dict]: + def _read(self, ti: TaskInstance, try_number: int, metadata: Optional[dict] = None) -> Tuple[str, dict]: """ Endpoint for streaming log. @@ -151,8 +146,7 @@ def _read( # end_of_log_mark may contain characters like '\n' which is needed to # have the log uploaded but will not be stored in elasticsearch. - metadata['end_of_log'] = False if not logs \ - else logs[-1].message == self.end_of_log_mark.strip() + metadata['end_of_log'] = False if not logs else logs[-1].message == self.end_of_log_mark.strip() cur_ts = pendulum.now() # Assume end of log after not receiving new log for 5 min, @@ -160,8 +154,11 @@ def _read( # delay before Elasticsearch makes the log available. if 'last_log_timestamp' in metadata: last_log_ts = timezone.parse(metadata['last_log_timestamp']) - if cur_ts.diff(last_log_ts).in_minutes() >= 5 or 'max_offset' in metadata \ - and offset >= metadata['max_offset']: + if ( + cur_ts.diff(last_log_ts).in_minutes() >= 5 + or 'max_offset' in metadata + and offset >= metadata['max_offset'] + ): metadata['end_of_log'] = True if offset != next_offset or 'last_log_timestamp' not in metadata: @@ -188,9 +185,7 @@ def es_read(self, log_id: str, offset: str, metadata: dict) -> list: """ # Offset is the unique key for sorting logs given log_id. 
- search = Search(using=self.client) \ - .query('match_phrase', log_id=log_id) \ - .sort('offset') + search = Search(using=self.client).query('match_phrase', log_id=log_id).sort('offset') search = search.filter('range', offset={'gt': int(offset)}) max_log_line = search.count() @@ -207,8 +202,7 @@ def es_read(self, log_id: str, offset: str, metadata: dict) -> list: if max_log_line != 0: try: - logs = search[self.MAX_LINE_PER_PAGE * self.PAGE:self.MAX_LINE_PER_PAGE] \ - .execute() + logs = search[self.MAX_LINE_PER_PAGE * self.PAGE : self.MAX_LINE_PER_PAGE].execute() except Exception as e: # pylint: disable=broad-except self.log.exception('Could not read log with log_id: %s, error: %s', log_id, str(e)) @@ -229,8 +223,9 @@ def set_context(self, ti: TaskInstance) -> None: 'dag_id': str(ti.dag_id), 'task_id': str(ti.task_id), 'execution_date': self._clean_execution_date(ti.execution_date), - 'try_number': str(ti.try_number) - }) + 'try_number': str(ti.try_number), + }, + ) if self.write_stdout: if self.context_set: @@ -299,6 +294,7 @@ def get_external_log_url(self, task_instance: TaskInstance, try_number: int) -> dag_id=task_instance.dag_id, task_id=task_instance.task_id, execution_date=task_instance.execution_date, - try_number=try_number) + try_number=try_number, + ) url = 'https://' + self.frontend.format(log_id=quote(log_id)) return url diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index 27e64a5806ec3..8e555459b8d22 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -34,6 +34,7 @@ class ExasolHook(DbApiHook): `_ for more details. """ + conn_name_attr = 'exasol_conn_id' default_conn_name = 'exasol_default' supports_autocommit = True @@ -49,7 +50,8 @@ def get_conn(self): dsn='%s:%s' % (conn.host, conn.port), user=conn.login, password=conn.password, - schema=self.schema or conn.schema) + schema=self.schema or conn.schema, + ) # check for parameters in conn.extra for arg_name, arg_val in conn.extra_dejson.items(): if arg_name in ['compression', 'encryption', 'json_lib', 'client_name']: @@ -143,9 +145,9 @@ def set_autocommit(self, conn, autocommit): """ if not self.supports_autocommit and autocommit: self.log.warning( - ("%s connection doesn't support " - "autocommit but autocommit activated."), - getattr(self, self.conn_name_attr)) + ("%s connection doesn't support " "autocommit but autocommit activated."), + getattr(self, self.conn_name_attr), + ) conn.set_autocommit(autocommit) def get_autocommit(self, conn): diff --git a/airflow/providers/exasol/operators/exasol.py b/airflow/providers/exasol/operators/exasol.py index e4ac6b6d6976e..f0dc336336be0 100644 --- a/airflow/providers/exasol/operators/exasol.py +++ b/airflow/providers/exasol/operators/exasol.py @@ -47,13 +47,15 @@ class ExasolOperator(BaseOperator): @apply_defaults def __init__( - self, *, - sql: str, - exasol_conn_id: str = 'exasol_default', - autocommit: bool = False, - parameters: Optional[Mapping] = None, - schema: Optional[str] = None, - **kwargs): + self, + *, + sql: str, + exasol_conn_id: str = 'exasol_default', + autocommit: bool = False, + parameters: Optional[Mapping] = None, + schema: Optional[str] = None, + **kwargs, + ): super(ExasolOperator, self).__init__(**kwargs) self.exasol_conn_id = exasol_conn_id self.sql = sql @@ -63,9 +65,5 @@ def __init__( def execute(self, context): self.log.info('Executing: %s', self.sql) - hook = ExasolHook(exasol_conn_id=self.exasol_conn_id, - schema=self.schema) - hook.run( - self.sql, - 
autocommit=self.autocommit, - parameters=self.parameters) + hook = ExasolHook(exasol_conn_id=self.exasol_conn_id, schema=self.schema) + hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) diff --git a/airflow/providers/facebook/ads/hooks/ads.py b/airflow/providers/facebook/ads/hooks/ads.py index 1318cb24a40ba..77192e588686d 100644 --- a/airflow/providers/facebook/ads/hooks/ads.py +++ b/airflow/providers/facebook/ads/hooks/ads.py @@ -36,6 +36,7 @@ class JobStatus(Enum): """ Available options for facebook async task status """ + COMPLETED = 'Job Completed' STARTED = 'Job Started' RUNNING = 'Job Running' @@ -58,27 +59,22 @@ class FacebookAdsReportingHook(BaseHook): """ - def __init__( - self, - facebook_conn_id: str = "facebook_default", - api_version: str = "v6.0", - ) -> None: + def __init__(self, facebook_conn_id: str = "facebook_default", api_version: str = "v6.0",) -> None: super().__init__() self.facebook_conn_id = facebook_conn_id self.api_version = api_version - self.client_required_fields = ["app_id", - "app_secret", - "access_token", - "account_id"] + self.client_required_fields = ["app_id", "app_secret", "access_token", "account_id"] def _get_service(self) -> FacebookAdsApi: """Returns Facebook Ads Client using a service account""" config = self.facebook_ads_config - return FacebookAdsApi.init(app_id=config["app_id"], - app_secret=config["app_secret"], - access_token=config["access_token"], - account_id=config["account_id"], - api_version=self.api_version) + return FacebookAdsApi.init( + app_id=config["app_id"], + app_secret=config["app_secret"], + access_token=config["access_token"], + account_id=config["account_id"], + api_version=self.api_version, + ) @cached_property def facebook_ads_config(self) -> Dict: @@ -96,10 +92,7 @@ def facebook_ads_config(self) -> Dict: return config def bulk_facebook_report( - self, - params: Dict[str, Any], - fields: List[str], - sleep_time: int = 5, + self, params: Dict[str, Any], fields: List[str], sleep_time: int = 5, ) -> List[AdsInsights]: """ Pulls data from the Facebook Ads API diff --git a/airflow/providers/ftp/hooks/ftp.py b/airflow/providers/ftp/hooks/ftp.py index b7c3e06056cdd..466d927c05282 100644 --- a/airflow/providers/ftp/hooks/ftp.py +++ b/airflow/providers/ftp/hooks/ftp.py @@ -113,11 +113,7 @@ def delete_directory(self, path: str) -> None: conn = self.get_conn() conn.rmd(path) - def retrieve_file( - self, - remote_full_path, - local_full_path_or_buffer, - callback=None): + def retrieve_file(self, remote_full_path, local_full_path_or_buffer, callback=None): """ Transfers the remote file to a local location. @@ -267,6 +263,7 @@ class FTPSHook(FTPHook): """ Interact with FTPS. """ + def get_conn(self) -> ftplib.FTP: """ Returns a FTPS connection object. 
@@ -278,9 +275,7 @@ def get_conn(self) -> ftplib.FTP: if params.port: ftplib.FTP_TLS.port = params.port - self.conn = ftplib.FTP_TLS( - params.host, params.login, params.password - ) + self.conn = ftplib.FTP_TLS(params.host, params.login, params.password) self.conn.set_pasv(pasv) return self.conn diff --git a/airflow/providers/ftp/sensors/ftp.py b/airflow/providers/ftp/sensors/ftp.py index 878b7e33b4841..baf2d351336d9 100644 --- a/airflow/providers/ftp/sensors/ftp.py +++ b/airflow/providers/ftp/sensors/ftp.py @@ -45,11 +45,8 @@ class FTPSensor(BaseSensorOperator): @apply_defaults def __init__( - self, *, - path: str, - ftp_conn_id: str = 'ftp_default', - fail_on_transient_errors: bool = True, - **kwargs) -> None: + self, *, path: str, ftp_conn_id: str = 'ftp_default', fail_on_transient_errors: bool = True, **kwargs + ) -> None: super().__init__(**kwargs) self.path = path @@ -77,9 +74,9 @@ def poke(self, context: dict) -> bool: except ftplib.error_perm as e: self.log.error('Ftp error encountered: %s', str(e)) error_code = self._get_error_code(e) - if ((error_code != 550) and - (self.fail_on_transient_errors or - (error_code not in self.transient_errors))): + if (error_code != 550) and ( + self.fail_on_transient_errors or (error_code not in self.transient_errors) + ): raise e return False @@ -89,6 +86,7 @@ def poke(self, context: dict) -> bool: class FTPSSensor(FTPSensor): """Waits for a file or directory to be present on FTP over SSL.""" + def _create_hook(self) -> FTPHook: """Return connection hook.""" return FTPSHook(ftp_conn_id=self.ftp_conn_id) diff --git a/airflow/providers/google/__init__.py b/airflow/providers/google/__init__.py index 7c5a5d4f585a5..9d7b5d8ffb82f 100644 --- a/airflow/providers/google/__init__.py +++ b/airflow/providers/google/__init__.py @@ -18,15 +18,18 @@ from airflow.configuration import conf -PROVIDERS_GOOGLE_VERBOSE_LOGGING: bool = conf.getboolean('providers_google', - 'VERBOSE_LOGGING', fallback=False) +PROVIDERS_GOOGLE_VERBOSE_LOGGING: bool = conf.getboolean( + 'providers_google', 'VERBOSE_LOGGING', fallback=False +) if PROVIDERS_GOOGLE_VERBOSE_LOGGING: for logger_name in ["google_auth_httplib2", "httplib2", "googleapiclient"]: logger = logging.getLogger(logger_name) - logger.handlers += [handler for handler in - logging.getLogger().handlers if handler.name in ["task", "console"]] + logger.handlers += [ + handler for handler in logging.getLogger().handlers if handler.name in ["task", "console"] + ] logger.level = logging.DEBUG logger.propagate = False import httplib2 + httplib2.debuglevel = 4 diff --git a/airflow/providers/google/ads/example_dags/example_ads.py b/airflow/providers/google/ads/example_dags/example_ads.py index 43bdcd34f6c93..e38fd473aae04 100644 --- a/airflow/providers/google/ads/example_dags/example_ads.py +++ b/airflow/providers/google/ads/example_dags/example_ads.py @@ -82,8 +82,6 @@ # [START howto_ads_list_accounts_operator] list_accounts = GoogleAdsListAccountsOperator( - task_id="list_accounts", - bucket=BUCKET, - object_name=GCS_ACCOUNTS_CSV + task_id="list_accounts", bucket=BUCKET, object_name=GCS_ACCOUNTS_CSV ) # [END howto_ads_list_accounts_operator] diff --git a/airflow/providers/google/ads/hooks/ads.py b/airflow/providers/google/ads/hooks/ads.py index 81f8cf0a397d4..12ad1f7cfe2c0 100644 --- a/airflow/providers/google/ads/hooks/ads.py +++ b/airflow/providers/google/ads/hooks/ads.py @@ -156,16 +156,13 @@ def search( """ service = self._get_service iterators = ( - service.search(client_id, query=query, page_size=page_size, **kwargs) - 
for client_id in client_ids + service.search(client_id, query=query, page_size=page_size, **kwargs) for client_id in client_ids ) self.log.info("Fetched Google Ads Iterators") return self._extract_rows(iterators) - def _extract_rows( - self, iterators: Generator[GRPCIterator, None, None] - ) -> List[GoogleAdsRow]: + def _extract_rows(self, iterators: Generator[GRPCIterator, None, None]) -> List[GoogleAdsRow]: """ Convert Google Page Iterator (GRPCIterator) objects to Google Ads Rows @@ -188,9 +185,7 @@ def _extract_rows( self.log.error("\tError with message: %s.", error.message) if error.location: for field_path_element in error.location.field_path_elements: - self.log.error( - "\t\tOn field: %s", field_path_element.field_name - ) + self.log.error("\t\tOn field: %s", field_path_element.field_name) raise def list_accessible_customers(self) -> List[str]: diff --git a/airflow/providers/google/ads/operators/ads.py b/airflow/providers/google/ads/operators/ads.py index c2071b7dfd612..e24523f3396f1 100644 --- a/airflow/providers/google/ads/operators/ads.py +++ b/airflow/providers/google/ads/operators/ads.py @@ -68,11 +68,16 @@ class GoogleAdsListAccountsOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("bucket", "object_name", "impersonation_chain",) + template_fields = ( + "bucket", + "object_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, bucket: str, object_name: str, gcp_conn_id: str = "google_cloud_default", @@ -92,15 +97,9 @@ def __init__( def execute(self, context: Dict): uri = f"gs://{self.bucket}/{self.object_name}" - ads_hook = GoogleAdsHook( - gcp_conn_id=self.gcp_conn_id, - google_ads_conn_id=self.google_ads_conn_id - ) + ads_hook = GoogleAdsHook(gcp_conn_id=self.gcp_conn_id, google_ads_conn_id=self.google_ads_conn_id) - gcs_hook = GCSHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain - ) + gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) with NamedTemporaryFile("w+") as temp_file: # Download accounts accounts = ads_hook.list_accessible_customers() @@ -110,10 +109,7 @@ def execute(self, context: Dict): # Upload to GCS gcs_hook.upload( - bucket_name=self.bucket, - object_name=self.object_name, - gzip=self.gzip, - filename=temp_file.name + bucket_name=self.bucket, object_name=self.object_name, gzip=self.gzip, filename=temp_file.name ) self.log.info("Uploaded %s to %s", len(accounts), uri) diff --git a/airflow/providers/google/ads/transfers/ads_to_gcs.py b/airflow/providers/google/ads/transfers/ads_to_gcs.py index 6b2fb71f37c08..0890ef8e7e1f3 100644 --- a/airflow/providers/google/ads/transfers/ads_to_gcs.py +++ b/airflow/providers/google/ads/transfers/ads_to_gcs.py @@ -69,11 +69,19 @@ class GoogleAdsToGcsOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("client_ids", "query", "attributes", "bucket", "obj", "impersonation_chain",) + template_fields = ( + "client_ids", + "query", + "attributes", + "bucket", + "obj", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, client_ids: List[str], query: str, attributes: List[str], @@ -99,13 +107,8 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - service = GoogleAdsHook( - gcp_conn_id=self.gcp_conn_id, - google_ads_conn_id=self.google_ads_conn_id - ) - rows = service.search( - client_ids=self.client_ids, query=self.query, 
page_size=self.page_size - ) + service = GoogleAdsHook(gcp_conn_id=self.gcp_conn_id, google_ads_conn_id=self.google_ads_conn_id) + rows = service.search(client_ids=self.client_ids, query=self.query, page_size=self.page_size) try: getter = attrgetter(*self.attributes) @@ -119,14 +122,8 @@ def execute(self, context: Dict): writer.writerows(converted_rows) csvfile.flush() - hook = GCSHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain - ) + hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) hook.upload( - bucket_name=self.bucket, - object_name=self.obj, - filename=csvfile.name, - gzip=self.gzip, + bucket_name=self.bucket, object_name=self.obj, filename=csvfile.name, gzip=self.gzip, ) self.log.info("%s uploaded to GCS", self.obj) diff --git a/airflow/providers/google/cloud/_internal_client/secret_manager_client.py b/airflow/providers/google/cloud/_internal_client/secret_manager_client.py index 59714d2195075..0148b1479b32c 100644 --- a/airflow/providers/google/cloud/_internal_client/secret_manager_client.py +++ b/airflow/providers/google/cloud/_internal_client/secret_manager_client.py @@ -40,9 +40,9 @@ class _SecretManagerClient(LoggingMixin): :param credentials: Credentials used to authenticate to GCP :type credentials: google.auth.credentials.Credentials """ + def __init__( - self, - credentials: google.auth.credentials.Credentials, + self, credentials: google.auth.credentials.Credentials, ): super().__init__() self.credentials = credentials @@ -63,15 +63,11 @@ def client(self) -> SecretManagerServiceClient: Create an authenticated KMS client """ _client = SecretManagerServiceClient( - credentials=self.credentials, - client_info=ClientInfo(client_library_version='airflow_v' + version) + credentials=self.credentials, client_info=ClientInfo(client_library_version='airflow_v' + version) ) return _client - def get_secret(self, - secret_id: str, - project_id: str, - secret_version: str = 'latest') -> Optional[str]: + def get_secret(self, secret_id: str, project_id: str, secret_version: str = 'latest') -> Optional[str]: """ Get secret value from the Secret Manager. @@ -88,13 +84,12 @@ def get_secret(self, value = response.payload.data.decode('UTF-8') return value except NotFound: - self.log.error( - "GCP API Call Error (NotFound): Secret ID %s not found.", secret_id - ) + self.log.error("GCP API Call Error (NotFound): Secret ID %s not found.", secret_id) return None except PermissionDenied: self.log.error( """GCP API Call Error (PermissionDenied): No access for Secret ID %s. 
- Did you add 'secretmanager.versions.access' permission?""", secret_id + Did you add 'secretmanager.versions.access' permission?""", + secret_id, ) return None diff --git a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py index 6648b149a902a..1560dda31dbb6 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py @@ -24,16 +24,17 @@ from airflow import models from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( - AutoMLCreateDatasetOperator, AutoMLDeleteDatasetOperator, AutoMLDeleteModelOperator, - AutoMLImportDataOperator, AutoMLTrainModelOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, ) from airflow.utils.dates import days_ago GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") -GCP_AUTOML_TEXT_CLS_BUCKET = os.environ.get( - "GCP_AUTOML_TEXT_CLS_BUCKET", "gs://" -) +GCP_AUTOML_TEXT_CLS_BUCKET = os.environ.get("GCP_AUTOML_TEXT_CLS_BUCKET", "gs://") # Example values DATASET_ID = "" @@ -66,9 +67,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = ( - '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' - ) + dataset_id = '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -79,9 +78,7 @@ MODEL["dataset_id"] = dataset_id - create_model = AutoMLTrainModelOperator( - task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION - ) + create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" @@ -99,5 +96,4 @@ project_id=GCP_PROJECT_ID, ) - create_dataset_task >> import_dataset_task >> create_model >> \ - delete_model_task >> delete_datasets_task + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_extraction.py b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_extraction.py index 0be33e6ed76c4..ac63c02ccc09d 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_extraction.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_extraction.py @@ -24,16 +24,17 @@ from airflow import models from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( - AutoMLCreateDatasetOperator, AutoMLDeleteDatasetOperator, AutoMLDeleteModelOperator, - AutoMLImportDataOperator, AutoMLTrainModelOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, ) from airflow.utils.dates import days_ago GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") -GCP_AUTOML_TEXT_BUCKET = os.environ.get( - "GCP_AUTOML_TEXT_BUCKET", 
"gs://cloud-ml-data/NL-entity/dataset.csv" -) +GCP_AUTOML_TEXT_BUCKET = os.environ.get("GCP_AUTOML_TEXT_BUCKET", "gs://cloud-ml-data/NL-entity/dataset.csv") # Example values DATASET_ID = "" @@ -64,9 +65,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = ( - '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' - ) + dataset_id = '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -77,9 +76,7 @@ MODEL["dataset_id"] = dataset_id - create_model = AutoMLTrainModelOperator( - task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION - ) + create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" @@ -97,5 +94,4 @@ project_id=GCP_PROJECT_ID, ) - create_dataset_task >> import_dataset_task >> create_model >> \ - delete_model_task >> delete_datasets_task + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py index 98a99b3f576ad..8645f0a0f563b 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py @@ -24,16 +24,17 @@ from airflow import models from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( - AutoMLCreateDatasetOperator, AutoMLDeleteDatasetOperator, AutoMLDeleteModelOperator, - AutoMLImportDataOperator, AutoMLTrainModelOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, ) from airflow.utils.dates import days_ago GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") -GCP_AUTOML_SENTIMENT_BUCKET = os.environ.get( - "GCP_AUTOML_SENTIMENT_BUCKET", "gs://" -) +GCP_AUTOML_SENTIMENT_BUCKET = os.environ.get("GCP_AUTOML_SENTIMENT_BUCKET", "gs://") # Example values DATASET_ID = "" @@ -67,9 +68,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = ( - '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' - ) + dataset_id = '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -80,9 +79,7 @@ MODEL["dataset_id"] = dataset_id - create_model = AutoMLTrainModelOperator( - task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION - ) + create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" @@ -100,5 +97,4 @@ project_id=GCP_PROJECT_ID, ) - create_dataset_task >> import_dataset_task >> create_model >> \ - delete_model_task >> delete_datasets_task + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/airflow/providers/google/cloud/example_dags/example_automl_tables.py b/airflow/providers/google/cloud/example_dags/example_automl_tables.py index 
11d4049fffc05..8a60440eadcb9 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_tables.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_tables.py @@ -26,10 +26,19 @@ from airflow import models from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( - AutoMLBatchPredictOperator, AutoMLCreateDatasetOperator, AutoMLDeleteDatasetOperator, - AutoMLDeleteModelOperator, AutoMLDeployModelOperator, AutoMLGetModelOperator, AutoMLImportDataOperator, - AutoMLListDatasetOperator, AutoMLPredictOperator, AutoMLTablesListColumnSpecsOperator, - AutoMLTablesListTableSpecsOperator, AutoMLTablesUpdateDatasetOperator, AutoMLTrainModelOperator, + AutoMLBatchPredictOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLDeployModelOperator, + AutoMLGetModelOperator, + AutoMLImportDataOperator, + AutoMLListDatasetOperator, + AutoMLPredictOperator, + AutoMLTablesListColumnSpecsOperator, + AutoMLTablesListTableSpecsOperator, + AutoMLTablesUpdateDatasetOperator, + AutoMLTrainModelOperator, ) from airflow.utils.dates import days_ago @@ -92,9 +101,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: project_id=GCP_PROJECT_ID, ) - dataset_id = ( - "{{ task_instance.xcom_pull('create_dataset_task', key='dataset_id') }}" - ) + dataset_id = "{{ task_instance.xcom_pull('create_dataset_task', key='dataset_id') }}" # [END howto_operator_automl_create_dataset] MODEL["dataset_id"] = dataset_id @@ -135,18 +142,13 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: ] = "{{ get_target_column_spec(task_instance.xcom_pull('list_columns_spec_task'), target) }}" update_dataset_task = AutoMLTablesUpdateDatasetOperator( - task_id="update_dataset_task", - dataset=update, - location=GCP_AUTOML_LOCATION, + task_id="update_dataset_task", dataset=update, location=GCP_AUTOML_LOCATION, ) # [END howto_operator_automl_update_dataset] # [START howto_operator_automl_create_model] create_model_task = AutoMLTrainModelOperator( - task_id="create_model_task", - model=MODEL, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, + task_id="create_model_task", model=MODEL, location=GCP_AUTOML_LOCATION, project_id=GCP_PROJECT_ID, ) model_id = "{{ task_instance.xcom_pull('create_model_task', key='model_id') }}" @@ -194,9 +196,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: project_id=GCP_PROJECT_ID, ) - dataset_id = ( - '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' - ) + dataset_id = '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -222,9 +222,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: # [START howto_operator_list_dataset] list_datasets_task = AutoMLListDatasetOperator( - task_id="list_datasets_task", - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, + task_id="list_datasets_task", location=GCP_AUTOML_LOCATION, project_id=GCP_PROJECT_ID, ) # [END howto_operator_list_dataset] @@ -254,10 +252,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: ) as get_deploy_dag: # [START howto_operator_get_model] get_model_task = AutoMLGetModelOperator( - task_id="get_model_task", - model_id=MODEL_ID, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, + 
task_id="get_model_task", model_id=MODEL_ID, location=GCP_AUTOML_LOCATION, project_id=GCP_PROJECT_ID, ) # [END howto_operator_get_model] diff --git a/airflow/providers/google/cloud/example_dags/example_automl_translation.py b/airflow/providers/google/cloud/example_dags/example_automl_translation.py index 87e5265dac8ab..9a6e95de81711 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_translation.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_translation.py @@ -24,16 +24,17 @@ from airflow import models from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( - AutoMLCreateDatasetOperator, AutoMLDeleteDatasetOperator, AutoMLDeleteModelOperator, - AutoMLImportDataOperator, AutoMLTrainModelOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, ) from airflow.utils.dates import days_ago GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") -GCP_AUTOML_TRANSLATION_BUCKET = os.environ.get( - "GCP_AUTOML_TRANSLATION_BUCKET", "gs://project-vcm/file" -) +GCP_AUTOML_TRANSLATION_BUCKET = os.environ.get("GCP_AUTOML_TRANSLATION_BUCKET", "gs://project-vcm/file") # Example values DATASET_ID = "TRL123456789" @@ -48,10 +49,7 @@ # Example dataset DATASET = { "display_name": "test_translation_dataset", - "translation_dataset_metadata": { - "source_language_code": "en", - "target_language_code": "es", - }, + "translation_dataset_metadata": {"source_language_code": "en", "target_language_code": "es",}, } IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_TRANSLATION_BUCKET]}} @@ -71,9 +69,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = ( - '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' - ) + dataset_id = '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -84,9 +80,7 @@ MODEL["dataset_id"] = dataset_id - create_model = AutoMLTrainModelOperator( - task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION - ) + create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" @@ -104,5 +98,4 @@ project_id=GCP_PROJECT_ID, ) - create_dataset_task >> import_dataset_task >> create_model >> \ - delete_model_task >> delete_datasets_task + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py index 2f91223749d44..1605046f19b06 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py @@ -24,8 +24,11 @@ from airflow import models from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( - AutoMLCreateDatasetOperator, AutoMLDeleteDatasetOperator, AutoMLDeleteModelOperator, - AutoMLImportDataOperator, AutoMLTrainModelOperator, + 
AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, ) from airflow.utils.dates import days_ago @@ -68,9 +71,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = ( - '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' - ) + dataset_id = '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -81,9 +82,7 @@ MODEL["dataset_id"] = dataset_id - create_model = AutoMLTrainModelOperator( - task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION - ) + create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" @@ -101,5 +100,4 @@ project_id=GCP_PROJECT_ID, ) - create_dataset_task >> import_dataset_task >> create_model >> \ - delete_model_task >> delete_datasets_task + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py index 257a450159918..30b2f54c48d15 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py @@ -24,16 +24,18 @@ from airflow import models from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( - AutoMLCreateDatasetOperator, AutoMLDeleteDatasetOperator, AutoMLDeleteModelOperator, - AutoMLImportDataOperator, AutoMLTrainModelOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, ) from airflow.utils.dates import days_ago GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") GCP_AUTOML_TRACKING_BUCKET = os.environ.get( - "GCP_AUTOML_TRACKING_BUCKET", - "gs://automl-video-datasets/youtube_8m_videos_animal_tiny.csv", + "GCP_AUTOML_TRACKING_BUCKET", "gs://automl-video-datasets/youtube_8m_videos_animal_tiny.csv", ) # Example values @@ -69,9 +71,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = ( - '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' - ) + dataset_id = '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -82,9 +82,7 @@ MODEL["dataset_id"] = dataset_id - create_model = AutoMLTrainModelOperator( - task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION - ) + create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" @@ -102,5 +100,4 @@ project_id=GCP_PROJECT_ID, ) - create_dataset_task >> import_dataset_task >> create_model >> \ - delete_model_task >> delete_datasets_task + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git 
a/airflow/providers/google/cloud/example_dags/example_automl_vision_classification.py b/airflow/providers/google/cloud/example_dags/example_automl_vision_classification.py index 48d14aca03c90..2d16db214852d 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_vision_classification.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_vision_classification.py @@ -24,16 +24,17 @@ from airflow import models from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( - AutoMLCreateDatasetOperator, AutoMLDeleteDatasetOperator, AutoMLDeleteModelOperator, - AutoMLImportDataOperator, AutoMLTrainModelOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, ) from airflow.utils.dates import days_ago GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") -GCP_AUTOML_VISION_BUCKET = os.environ.get( - "GCP_AUTOML_VISION_BUCKET", "gs://your-bucket" -) +GCP_AUTOML_VISION_BUCKET = os.environ.get("GCP_AUTOML_VISION_BUCKET", "gs://your-bucket") # Example values DATASET_ID = "ICN123455678" @@ -68,9 +69,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = ( - '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' - ) + dataset_id = '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -81,9 +80,7 @@ MODEL["dataset_id"] = dataset_id - create_model = AutoMLTrainModelOperator( - task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION - ) + create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" @@ -101,5 +98,4 @@ project_id=GCP_PROJECT_ID, ) - create_dataset_task >> import_dataset_task >> create_model >> \ - delete_model_task >> delete_datasets_task + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py b/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py index 77bac055a6508..3be9eba969749 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py @@ -24,8 +24,11 @@ from airflow import models from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( - AutoMLCreateDatasetOperator, AutoMLDeleteDatasetOperator, AutoMLDeleteModelOperator, - AutoMLImportDataOperator, AutoMLTrainModelOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLImportDataOperator, + AutoMLTrainModelOperator, ) from airflow.utils.dates import days_ago @@ -68,9 +71,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = ( - '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' - ) + dataset_id = '{{ task_instance.xcom_pull("create_dataset_task", key="dataset_id") }}' import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -81,9 +82,7 
@@ MODEL["dataset_id"] = dataset_id - create_model = AutoMLTrainModelOperator( - task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION - ) + create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) model_id = "{{ task_instance.xcom_pull('create_model', key='model_id') }}" @@ -101,5 +100,4 @@ project_id=GCP_PROJECT_ID, ) - create_dataset_task >> import_dataset_task >> create_model >> \ - delete_model_task >> delete_datasets_task + create_dataset_task >> import_dataset_task >> create_model >> delete_model_task >> delete_datasets_task diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py b/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py index 7e39a55f5671d..cacde16aad827 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py @@ -27,16 +27,15 @@ from airflow import models from airflow.providers.google.cloud.operators.bigquery_dts import ( - BigQueryCreateDataTransferOperator, BigQueryDataTransferServiceStartTransferRunsOperator, + BigQueryCreateDataTransferOperator, + BigQueryDataTransferServiceStartTransferRunsOperator, BigQueryDeleteDataTransferConfigOperator, ) from airflow.providers.google.cloud.sensors.bigquery_dts import BigQueryDataTransferServiceTransferRunSensor from airflow.utils.dates import days_ago GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") -BUCKET_URI = os.environ.get( - "GCP_DTS_BUCKET_URI", "gs://cloud-ml-tables-data/bank-marketing.csv" -) +BUCKET_URI = os.environ.get("GCP_DTS_BUCKET_URI", "gs://cloud-ml-tables-data/bank-marketing.csv") GCP_DTS_BQ_DATASET = os.environ.get("GCP_DTS_BQ_DATASET", "test_dts") GCP_DTS_BQ_TABLE = os.environ.get("GCP_DTS_BQ_TABLE", "GCS_Test") @@ -77,14 +76,11 @@ ) as dag: # [START howto_bigquery_create_data_transfer] gcp_bigquery_create_transfer = BigQueryCreateDataTransferOperator( - transfer_config=TRANSFER_CONFIG, - project_id=GCP_PROJECT_ID, - task_id="gcp_bigquery_create_transfer", + transfer_config=TRANSFER_CONFIG, project_id=GCP_PROJECT_ID, task_id="gcp_bigquery_create_transfer", ) transfer_config_id = ( - "{{ task_instance.xcom_pull('gcp_bigquery_create_transfer', " - "key='transfer_config_id') }}" + "{{ task_instance.xcom_pull('gcp_bigquery_create_transfer', key='transfer_config_id') }}" ) # [END howto_bigquery_create_data_transfer] @@ -94,9 +90,7 @@ transfer_config_id=transfer_config_id, requested_run_time={"seconds": int(time.time() + 60)}, ) - run_id = ( - "{{ task_instance.xcom_pull('gcp_bigquery_start_transfer', key='run_id') }}" - ) + run_id = "{{ task_instance.xcom_pull('gcp_bigquery_start_transfer', key='run_id') }}" # [END howto_bigquery_start_transfer] # [START howto_bigquery_dts_sensor] diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py b/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py index 8a110e0924ff8..6e9fd7ef20612 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py @@ -26,9 +26,15 @@ from airflow import models from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryCreateExternalTableOperator, - BigQueryDeleteDatasetOperator, BigQueryDeleteTableOperator, BigQueryGetDatasetOperator, - 
BigQueryGetDatasetTablesOperator, BigQueryPatchDatasetOperator, BigQueryUpdateDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryCreateExternalTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryDeleteTableOperator, + BigQueryGetDatasetOperator, + BigQueryGetDatasetTablesOperator, + BigQueryPatchDatasetOperator, + BigQueryUpdateDatasetOperator, BigQueryUpsertTableOperator, ) from airflow.utils.dates import days_ago @@ -39,8 +45,7 @@ DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_operations") LOCATION_DATASET_NAME = "{}_location".format(DATASET_NAME) DATA_SAMPLE_GCS_URL = os.environ.get( - "GCP_BIGQUERY_DATA_GCS_URL", - "gs://cloud-samples-data/bigquery/us-states/us-states.csv", + "GCP_BIGQUERY_DATA_GCS_URL", "gs://cloud-samples-data/bigquery/us-states/us-states.csv", ) DATA_SAMPLE_GCS_URL_PARTS = urlparse(DATA_SAMPLE_GCS_URL) @@ -68,8 +73,7 @@ # [START howto_operator_bigquery_delete_table] delete_table = BigQueryDeleteTableOperator( - task_id="delete_table", - deletion_dataset_table=f"{PROJECT_ID}.{DATASET_NAME}.test_table", + task_id="delete_table", deletion_dataset_table=f"{PROJECT_ID}.{DATASET_NAME}.test_table", ) # [END howto_operator_bigquery_delete_table] @@ -78,10 +82,7 @@ task_id="create_view", dataset_id=DATASET_NAME, table_id="test_view", - view={ - "query": f"SELECT * FROM `{PROJECT_ID}.{DATASET_NAME}.test_table`", - "useLegacySql": False, - }, + view={"query": f"SELECT * FROM `{PROJECT_ID}.{DATASET_NAME}.test_table`", "useLegacySql": False,}, ) # [END howto_operator_bigquery_create_view] @@ -98,10 +99,7 @@ source_objects=[DATA_SAMPLE_GCS_OBJECT_NAME], destination_project_dataset_table=f"{DATASET_NAME}.external_table", skip_leading_rows=1, - schema_fields=[ - {"name": "name", "type": "STRING"}, - {"name": "post_abbr", "type": "STRING"}, - ], + schema_fields=[{"name": "name", "type": "STRING"}, {"name": "post_abbr", "type": "STRING"},], ) # [END howto_operator_bigquery_create_external_table] @@ -117,9 +115,7 @@ # [END howto_operator_bigquery_upsert_table] # [START howto_operator_bigquery_create_dataset] - create_dataset = BigQueryCreateEmptyDatasetOperator( - task_id="create-dataset", dataset_id=DATASET_NAME - ) + create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create-dataset", dataset_id=DATASET_NAME) # [END howto_operator_bigquery_create_dataset] # [START howto_operator_bigquery_get_dataset_tables] @@ -129,9 +125,7 @@ # [END howto_operator_bigquery_get_dataset_tables] # [START howto_operator_bigquery_get_dataset] - get_dataset = BigQueryGetDatasetOperator( - task_id="get-dataset", dataset_id=DATASET_NAME - ) + get_dataset = BigQueryGetDatasetOperator(task_id="get-dataset", dataset_id=DATASET_NAME) # [END howto_operator_bigquery_get_dataset] get_dataset_result = BashOperator( @@ -143,10 +137,7 @@ patch_dataset = BigQueryPatchDatasetOperator( task_id="patch_dataset", dataset_id=DATASET_NAME, - dataset_resource={ - "friendlyName": "Patched Dataset", - "description": "Patched dataset", - }, + dataset_resource={"friendlyName": "Patched Dataset", "description": "Patched dataset",}, ) # [END howto_operator_bigquery_patch_dataset] @@ -179,9 +170,7 @@ tags=["example"], ): create_dataset_with_location = BigQueryCreateEmptyDatasetOperator( - task_id="create_dataset_with_location", - dataset_id=LOCATION_DATASET_NAME, - location=BQ_LOCATION, + task_id="create_dataset_with_location", dataset_id=LOCATION_DATASET_NAME, location=BQ_LOCATION, ) create_table_with_location = 
BigQueryCreateEmptyTableOperator( @@ -195,8 +184,6 @@ ) delete_dataset_with_location = BigQueryDeleteDatasetOperator( - task_id="delete_dataset_with_location", - dataset_id=LOCATION_DATASET_NAME, - delete_contents=True, + task_id="delete_dataset_with_location", dataset_id=LOCATION_DATASET_NAME, delete_contents=True, ) create_dataset_with_location >> create_table_with_location >> delete_dataset_with_location diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_queries.py b/airflow/providers/google/cloud/example_dags/example_bigquery_queries.py index e0259cb70bdef..90e9e2b597d99 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_queries.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_queries.py @@ -25,9 +25,15 @@ from airflow import models from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCheckOperator, BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, - BigQueryDeleteDatasetOperator, BigQueryExecuteQueryOperator, BigQueryGetDataOperator, - BigQueryInsertJobOperator, BigQueryIntervalCheckOperator, BigQueryValueCheckOperator, + BigQueryCheckOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryExecuteQueryOperator, + BigQueryGetDataOperator, + BigQueryInsertJobOperator, + BigQueryIntervalCheckOperator, + BigQueryValueCheckOperator, ) from airflow.utils.dates import days_ago @@ -42,10 +48,11 @@ INSERT_DATE = datetime.now().strftime("%Y-%m-%d") # [START howto_operator_bigquery_query] -INSERT_ROWS_QUERY = \ - f"INSERT {DATASET_NAME}.{TABLE_1} VALUES " \ - f"(42, 'monthy python', '{INSERT_DATE}'), " \ +INSERT_ROWS_QUERY = ( + f"INSERT {DATASET_NAME}.{TABLE_1} VALUES " + f"(42, 'monthy python', '{INSERT_DATE}'), " f"(42, 'fishy fish', '{INSERT_DATE}');" +) # [END howto_operator_bigquery_query] SCHEMA = [ @@ -62,7 +69,7 @@ schedule_interval=None, # Override to match your needs start_date=days_ago(1), tags=["example"], - user_defined_macros={"DATASET": DATASET_NAME, "TABLE": TABLE_1} + user_defined_macros={"DATASET": DATASET_NAME, "TABLE": TABLE_1}, ) as dag_with_locations: create_dataset = BigQueryCreateEmptyDatasetOperator( task_id="create-dataset", dataset_id=DATASET_NAME, location=location, @@ -93,12 +100,7 @@ # [START howto_operator_bigquery_insert_job] insert_query_job = BigQueryInsertJobOperator( task_id="insert_query_job", - configuration={ - "query": { - "query": INSERT_ROWS_QUERY, - "useLegacySql": "False", - } - }, + configuration={"query": {"query": INSERT_ROWS_QUERY, "useLegacySql": "False",}}, location=location, ) # [END howto_operator_bigquery_insert_job] @@ -107,10 +109,7 @@ select_query_job = BigQueryInsertJobOperator( task_id="select_query_job", configuration={ - "query": { - "query": "{% include 'example_bigquery_query.sql' %}", - "useLegacySql": False, - } + "query": {"query": "{% include 'example_bigquery_query.sql' %}", "useLegacySql": False,} }, location=location, ) @@ -150,8 +149,7 @@ # [END howto_operator_bigquery_get_data] get_data_result = BashOperator( - task_id="get_data_result", - bash_command="echo \"{{ task_instance.xcom_pull('get_data') }}\"", + task_id="get_data_result", bash_command="echo \"{{ task_instance.xcom_pull('get_data') }}\"", ) # [START howto_operator_bigquery_check] @@ -159,7 +157,7 @@ task_id="check_count", sql=f"SELECT COUNT(*) FROM {DATASET_NAME}.{TABLE_1}", use_legacy_sql=False, - location=location + location=location, ) # 
[END howto_operator_bigquery_check] diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_sensors.py b/airflow/providers/google/cloud/example_dags/example_bigquery_sensors.py index 54e38592f7c97..af9078cf6889d 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_sensors.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_sensors.py @@ -24,11 +24,14 @@ from airflow import models from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, BigQueryExecuteQueryOperator, ) from airflow.providers.google.cloud.sensors.bigquery import ( - BigQueryTableExistenceSensor, BigQueryTablePartitionExistenceSensor, + BigQueryTableExistenceSensor, + BigQueryTablePartitionExistenceSensor, ) from airflow.utils.dates import days_ago @@ -40,9 +43,7 @@ PARTITION_NAME = "{{ ds_nodash }}" -INSERT_ROWS_QUERY = \ - f"INSERT {DATASET_NAME}.{TABLE_NAME} VALUES " \ - "(42, '{{ ds }}')" +INSERT_ROWS_QUERY = f"INSERT {DATASET_NAME}.{TABLE_NAME} VALUES " "(42, '{{ ds }}')" SCHEMA = [ {"name": "value", "type": "INTEGER", "mode": "REQUIRED"}, @@ -57,7 +58,7 @@ start_date=days_ago(1), tags=["example"], user_defined_macros={"DATASET": DATASET_NAME, "TABLE": TABLE_NAME}, - default_args={"project_id": PROJECT_ID} + default_args={"project_id": PROJECT_ID}, ) as dag_with_locations: create_dataset = BigQueryCreateEmptyDatasetOperator( task_id="create-dataset", dataset_id=DATASET_NAME, project_id=PROJECT_ID @@ -68,10 +69,7 @@ dataset_id=DATASET_NAME, table_id=TABLE_NAME, schema_fields=SCHEMA, - time_partitioning={ - "type": "DAY", - "field": "ds", - } + time_partitioning={"type": "DAY", "field": "ds",}, ) # [START howto_sensor_bigquery_table] check_table_exists = BigQueryTableExistenceSensor( @@ -85,8 +83,11 @@ # [START howto_sensor_bigquery_table_partition] check_table_partition_exists = BigQueryTablePartitionExistenceSensor( - task_id="check_table_partition_exists", project_id=PROJECT_ID, dataset_id=DATASET_NAME, - table_id=TABLE_NAME, partition_id=PARTITION_NAME + task_id="check_table_partition_exists", + project_id=PROJECT_ID, + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + partition_id=PARTITION_NAME, ) # [END howto_sensor_bigquery_table_partition] diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_to_bigquery.py b/airflow/providers/google/cloud/example_dags/example_bigquery_to_bigquery.py index 46115feda3987..6cd4c648e26e6 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_to_bigquery.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_to_bigquery.py @@ -23,7 +23,9 @@ from airflow import models from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, ) from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator from airflow.utils.dates import days_ago @@ -45,9 +47,7 @@ destination_project_dataset_table=f"{DATASET_NAME}.{TARGET}", ) - create_dataset = BigQueryCreateEmptyDatasetOperator( - task_id="create_dataset", dataset_id=DATASET_NAME - ) + create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", 
dataset_id=DATASET_NAME) for table in [ORIGIN, TARGET]: create_table = BigQueryCreateEmptyTableOperator( diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_bigquery_to_gcs.py index b5b15f016895c..70841573b73d9 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_to_gcs.py @@ -23,16 +23,16 @@ from airflow import models from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, ) from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator from airflow.utils.dates import days_ago PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_transfer") -DATA_EXPORT_BUCKET_NAME = os.environ.get( - "GCP_BIGQUERY_EXPORT_BUCKET_NAME", "test-bigquery-gcs-data" -) +DATA_EXPORT_BUCKET_NAME = os.environ.get("GCP_BIGQUERY_EXPORT_BUCKET_NAME", "test-bigquery-gcs-data") TABLE = "table_42" with models.DAG( @@ -44,14 +44,10 @@ bigquery_to_gcs = BigQueryToGCSOperator( task_id="bigquery_to_gcs", source_project_dataset_table=f"{DATASET_NAME}.{TABLE}", - destination_cloud_storage_uris=[ - f"gs://{DATA_EXPORT_BUCKET_NAME}/export-bigquery.csv" - ], + destination_cloud_storage_uris=[f"gs://{DATA_EXPORT_BUCKET_NAME}/export-bigquery.csv"], ) - create_dataset = BigQueryCreateEmptyDatasetOperator( - task_id="create_dataset", dataset_id=DATASET_NAME - ) + create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME) create_table = BigQueryCreateEmptyTableOperator( task_id="create_table", diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_transfer.py b/airflow/providers/google/cloud/example_dags/example_bigquery_transfer.py index 513195e82a1da..82a06a2723f52 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_transfer.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_transfer.py @@ -23,7 +23,9 @@ from airflow import models from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, ) from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator @@ -31,9 +33,7 @@ PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "test_dataset_transfer") -DATA_EXPORT_BUCKET_NAME = os.environ.get( - "GCP_BIGQUERY_EXPORT_BUCKET_NAME", "test-bigquery-sample-data" -) +DATA_EXPORT_BUCKET_NAME = os.environ.get("GCP_BIGQUERY_EXPORT_BUCKET_NAME", "test-bigquery-sample-data") ORIGIN = "origin" TARGET = "target" @@ -52,14 +52,10 @@ bigquery_to_gcs = BigQueryToGCSOperator( task_id="bigquery_to_gcs", source_project_dataset_table=f"{DATASET_NAME}.{ORIGIN}", - destination_cloud_storage_uris=[ - f"gs://{DATA_EXPORT_BUCKET_NAME}/export-bigquery.csv" - ], + destination_cloud_storage_uris=[f"gs://{DATA_EXPORT_BUCKET_NAME}/export-bigquery.csv"], 
) - create_dataset = BigQueryCreateEmptyDatasetOperator( - task_id="create_dataset", dataset_id=DATASET_NAME - ) + create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME) for table in [ORIGIN, TARGET]: create_table = BigQueryCreateEmptyTableOperator( diff --git a/airflow/providers/google/cloud/example_dags/example_bigtable.py b/airflow/providers/google/cloud/example_dags/example_bigtable.py index 388fb47f19e5d..81a0aa17c35c2 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigtable.py +++ b/airflow/providers/google/cloud/example_dags/example_bigtable.py @@ -49,8 +49,12 @@ from airflow import models from airflow.providers.google.cloud.operators.bigtable import ( - BigtableCreateInstanceOperator, BigtableCreateTableOperator, BigtableDeleteInstanceOperator, - BigtableDeleteTableOperator, BigtableUpdateClusterOperator, BigtableUpdateInstanceOperator, + BigtableCreateInstanceOperator, + BigtableCreateTableOperator, + BigtableDeleteInstanceOperator, + BigtableDeleteTableOperator, + BigtableUpdateClusterOperator, + BigtableUpdateInstanceOperator, ) from airflow.providers.google.cloud.sensors.bigtable import BigtableTableReplicationCompletedSensor from airflow.utils.dates import days_ago @@ -136,27 +140,19 @@ # [START howto_operator_gcp_bigtable_instance_delete] delete_instance_task = BigtableDeleteInstanceOperator( - project_id=GCP_PROJECT_ID, - instance_id=CBT_INSTANCE_ID, - task_id='delete_instance_task', + project_id=GCP_PROJECT_ID, instance_id=CBT_INSTANCE_ID, task_id='delete_instance_task', ) delete_instance_task2 = BigtableDeleteInstanceOperator( - instance_id=CBT_INSTANCE_ID, - task_id='delete_instance_task2', + instance_id=CBT_INSTANCE_ID, task_id='delete_instance_task2', ) # [END howto_operator_gcp_bigtable_instance_delete] # [START howto_operator_gcp_bigtable_table_create] create_table_task = BigtableCreateTableOperator( - project_id=GCP_PROJECT_ID, - instance_id=CBT_INSTANCE_ID, - table_id=CBT_TABLE_ID, - task_id='create_table', + project_id=GCP_PROJECT_ID, instance_id=CBT_INSTANCE_ID, table_id=CBT_TABLE_ID, task_id='create_table', ) create_table_task2 = BigtableCreateTableOperator( - instance_id=CBT_INSTANCE_ID, - table_id=CBT_TABLE_ID, - task_id='create_table_task2', + instance_id=CBT_INSTANCE_ID, table_id=CBT_TABLE_ID, task_id='create_table_task2', ) create_table_task >> create_table_task2 # [END howto_operator_gcp_bigtable_table_create] @@ -187,9 +183,7 @@ task_id='delete_table_task', ) delete_table_task2 = BigtableDeleteTableOperator( - instance_id=CBT_INSTANCE_ID, - table_id=CBT_TABLE_ID, - task_id='delete_table_task2', + instance_id=CBT_INSTANCE_ID, table_id=CBT_TABLE_ID, task_id='delete_table_task2', ) # [END howto_operator_gcp_bigtable_table_delete] @@ -197,16 +191,9 @@ wait_for_table_replication_task2 >> delete_table_task wait_for_table_replication_task >> delete_table_task2 wait_for_table_replication_task2 >> delete_table_task2 - create_instance_task \ - >> create_table_task \ - >> cluster_update_task \ - >> update_instance_task \ - >> delete_table_task - create_instance_task2 \ - >> create_table_task2 \ - >> cluster_update_task2 \ - >> delete_table_task2 + create_instance_task >> create_table_task >> cluster_update_task + cluster_update_task >> update_instance_task >> delete_table_task + create_instance_task2 >> create_table_task2 >> cluster_update_task2 >> delete_table_task2 # Only delete instances after all tables are deleted - [delete_table_task, delete_table_task2] >> \ - delete_instance_task >> 
delete_instance_task2 + [delete_table_task, delete_table_task2] >> delete_instance_task >> delete_instance_task2 diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_build.py b/airflow/providers/google/cloud/example_dags/example_cloud_build.py index 27591d6478984..445fdb07250e5 100644 --- a/airflow/providers/google/cloud/example_dags/example_cloud_build.py +++ b/airflow/providers/google/cloud/example_dags/example_cloud_build.py @@ -104,9 +104,10 @@ # [START howto_operator_gcp_create_build_from_yaml_body] create_build_from_file = CloudBuildCreateBuildOperator( - task_id="create_build_from_file", project_id=GCP_PROJECT_ID, + task_id="create_build_from_file", + project_id=GCP_PROJECT_ID, body=str(CURRENT_FOLDER.joinpath('example_cloud_build.yaml')), - params={'name': 'Airflow'} + params={'name': 'Airflow'}, ) # [END howto_operator_gcp_create_build_from_yaml_body] create_build_from_storage >> create_build_from_storage_result # pylint: disable=pointless-statement diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py b/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py index 5dbfb1805b228..a3d8c2b56f24a 100644 --- a/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py +++ b/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py @@ -26,11 +26,16 @@ from airflow import models from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.cloud_memorystore import ( - CloudMemorystoreCreateInstanceAndImportOperator, CloudMemorystoreCreateInstanceOperator, - CloudMemorystoreDeleteInstanceOperator, CloudMemorystoreExportAndDeleteInstanceOperator, - CloudMemorystoreExportInstanceOperator, CloudMemorystoreFailoverInstanceOperator, - CloudMemorystoreGetInstanceOperator, CloudMemorystoreImportOperator, - CloudMemorystoreListInstancesOperator, CloudMemorystoreScaleInstanceOperator, + CloudMemorystoreCreateInstanceAndImportOperator, + CloudMemorystoreCreateInstanceOperator, + CloudMemorystoreDeleteInstanceOperator, + CloudMemorystoreExportAndDeleteInstanceOperator, + CloudMemorystoreExportInstanceOperator, + CloudMemorystoreFailoverInstanceOperator, + CloudMemorystoreGetInstanceOperator, + CloudMemorystoreImportOperator, + CloudMemorystoreListInstancesOperator, + CloudMemorystoreScaleInstanceOperator, CloudMemorystoreUpdateInstanceOperator, ) from airflow.providers.google.cloud.operators.gcs import GCSBucketCreateAclEntryOperator diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_sql.py b/airflow/providers/google/cloud/example_dags/example_cloud_sql.py index c2f52e5777221..15bc7d0a1c077 100644 --- a/airflow/providers/google/cloud/example_dags/example_cloud_sql.py +++ b/airflow/providers/google/cloud/example_dags/example_cloud_sql.py @@ -32,12 +32,18 @@ from airflow import models from airflow.providers.google.cloud.operators.cloud_sql import ( - CloudSQLCreateInstanceDatabaseOperator, CloudSQLCreateInstanceOperator, - CloudSQLDeleteInstanceDatabaseOperator, CloudSQLDeleteInstanceOperator, CloudSQLExportInstanceOperator, - CloudSQLImportInstanceOperator, CloudSQLInstancePatchOperator, CloudSQLPatchInstanceDatabaseOperator, + CloudSQLCreateInstanceDatabaseOperator, + CloudSQLCreateInstanceOperator, + CloudSQLDeleteInstanceDatabaseOperator, + CloudSQLDeleteInstanceOperator, + CloudSQLExportInstanceOperator, + CloudSQLImportInstanceOperator, + CloudSQLInstancePatchOperator, + CloudSQLPatchInstanceDatabaseOperator, ) from 
airflow.providers.google.cloud.operators.gcs import ( - GCSBucketCreateAclEntryOperator, GCSObjectCreateAclEntryOperator, + GCSBucketCreateAclEntryOperator, + GCSObjectCreateAclEntryOperator, ) from airflow.utils.dates import days_ago @@ -60,38 +66,21 @@ "name": INSTANCE_NAME, "settings": { "tier": "db-n1-standard-1", - "backupConfiguration": { - "binaryLogEnabled": True, - "enabled": True, - "startTime": "05:00" - }, + "backupConfiguration": {"binaryLogEnabled": True, "enabled": True, "startTime": "05:00"}, "activationPolicy": "ALWAYS", "dataDiskSizeGb": 30, "dataDiskType": "PD_SSD", "databaseFlags": [], - "ipConfiguration": { - "ipv4Enabled": True, - "requireSsl": True, - }, - "locationPreference": { - "zone": "europe-west4-a" - }, - "maintenanceWindow": { - "hour": 5, - "day": 7, - "updateTrack": "canary" - }, + "ipConfiguration": {"ipv4Enabled": True, "requireSsl": True,}, + "locationPreference": {"zone": "europe-west4-a"}, + "maintenanceWindow": {"hour": 5, "day": 7, "updateTrack": "canary"}, "pricingPlan": "PER_USE", "replicationType": "ASYNCHRONOUS", "storageAutoResize": True, "storageAutoResizeLimit": 0, - "userLabels": { - "my-key": "my-value" - } - }, - "failoverReplica": { - "name": FAILOVER_REPLICA_NAME + "userLabels": {"my-key": "my-value"}, }, + "failoverReplica": {"name": FAILOVER_REPLICA_NAME}, "databaseVersion": "MYSQL_5_7", "region": "europe-west4", } @@ -99,9 +88,7 @@ body2 = { "name": INSTANCE_NAME2, - "settings": { - "tier": "db-n1-standard-1", - }, + "settings": {"tier": "db-n1-standard-1",}, "databaseVersion": "MYSQL_5_7", "region": "europe-west4", } @@ -109,9 +96,7 @@ # [START howto_operator_cloudsql_create_replica] read_replica_body = { "name": READ_REPLICA_NAME, - "settings": { - "tier": "db-n1-standard-1", - }, + "settings": {"tier": "db-n1-standard-1",}, "databaseVersion": "MYSQL_5_7", "region": "europe-west4", "masterInstanceName": INSTANCE_NAME, @@ -124,48 +109,24 @@ "name": INSTANCE_NAME, "settings": { "dataDiskSizeGb": 35, - "maintenanceWindow": { - "hour": 3, - "day": 6, - "updateTrack": "canary" - }, - "userLabels": { - "my-key-patch": "my-value-patch" - } - } + "maintenanceWindow": {"hour": 3, "day": 6, "updateTrack": "canary"}, + "userLabels": {"my-key-patch": "my-value-patch"}, + }, } # [END howto_operator_cloudsql_patch_body] # [START howto_operator_cloudsql_export_body] export_body = { - "exportContext": { - "fileType": "sql", - "uri": EXPORT_URI, - "sqlExportOptions": { - "schemaOnly": False - } - } + "exportContext": {"fileType": "sql", "uri": EXPORT_URI, "sqlExportOptions": {"schemaOnly": False}} } # [END howto_operator_cloudsql_export_body] # [START howto_operator_cloudsql_import_body] -import_body = { - "importContext": { - "fileType": "sql", - "uri": IMPORT_URI - } -} +import_body = {"importContext": {"fileType": "sql", "uri": IMPORT_URI}} # [END howto_operator_cloudsql_import_body] # [START howto_operator_cloudsql_db_create_body] -db_create_body = { - "instance": INSTANCE_NAME, - "name": DB_NAME, - "project": GCP_PROJECT_ID -} +db_create_body = {"instance": INSTANCE_NAME, "name": DB_NAME, "project": GCP_PROJECT_ID} # [END howto_operator_cloudsql_db_create_body] # [START howto_operator_cloudsql_db_patch_body] -db_patch_body = { - "charset": "utf16", - "collation": "utf16_general_ci" -} +db_patch_body = {"charset": "utf16", "collation": "utf16_general_ci"} # [END howto_operator_cloudsql_db_patch_body] with models.DAG( @@ -180,18 +141,12 @@ # [START howto_operator_cloudsql_create] sql_instance_create_task = CloudSQLCreateInstanceOperator( - 
project_id=GCP_PROJECT_ID, - body=body, - instance=INSTANCE_NAME, - task_id='sql_instance_create_task' + project_id=GCP_PROJECT_ID, body=body, instance=INSTANCE_NAME, task_id='sql_instance_create_task' ) # [END howto_operator_cloudsql_create] sql_instance_create_2_task = CloudSQLCreateInstanceOperator( - project_id=GCP_PROJECT_ID, - body=body2, - instance=INSTANCE_NAME2, - task_id='sql_instance_create_task2' + project_id=GCP_PROJECT_ID, body=body2, instance=INSTANCE_NAME2, task_id='sql_instance_create_task2' ) # [END howto_operator_cloudsql_create] @@ -199,7 +154,7 @@ project_id=GCP_PROJECT_ID, body=read_replica_body, instance=READ_REPLICA_NAME, - task_id='sql_instance_read_replica_create' + task_id='sql_instance_read_replica_create', ) # ############################################## # @@ -208,31 +163,20 @@ # [START howto_operator_cloudsql_patch] sql_instance_patch_task = CloudSQLInstancePatchOperator( - project_id=GCP_PROJECT_ID, - body=patch_body, - instance=INSTANCE_NAME, - task_id='sql_instance_patch_task' + project_id=GCP_PROJECT_ID, body=patch_body, instance=INSTANCE_NAME, task_id='sql_instance_patch_task' ) # [END howto_operator_cloudsql_patch] sql_instance_patch_task2 = CloudSQLInstancePatchOperator( - project_id=GCP_PROJECT_ID, - body=patch_body, - instance=INSTANCE_NAME, - task_id='sql_instance_patch_task2' + project_id=GCP_PROJECT_ID, body=patch_body, instance=INSTANCE_NAME, task_id='sql_instance_patch_task2' ) # [START howto_operator_cloudsql_db_create] sql_db_create_task = CloudSQLCreateInstanceDatabaseOperator( - project_id=GCP_PROJECT_ID, - body=db_create_body, - instance=INSTANCE_NAME, - task_id='sql_db_create_task' + project_id=GCP_PROJECT_ID, body=db_create_body, instance=INSTANCE_NAME, task_id='sql_db_create_task' ) sql_db_create_task2 = CloudSQLCreateInstanceDatabaseOperator( - body=db_create_body, - instance=INSTANCE_NAME, - task_id='sql_db_create_task2' + body=db_create_body, instance=INSTANCE_NAME, task_id='sql_db_create_task2' ) # [END howto_operator_cloudsql_db_create] @@ -242,13 +186,10 @@ body=db_patch_body, instance=INSTANCE_NAME, database=DB_NAME, - task_id='sql_db_patch_task' + task_id='sql_db_patch_task', ) sql_db_patch_task2 = CloudSQLPatchInstanceDatabaseOperator( - body=db_patch_body, - instance=INSTANCE_NAME, - database=DB_NAME, - task_id='sql_db_patch_task2' + body=db_patch_body, instance=INSTANCE_NAME, database=DB_NAME, task_id='sql_db_patch_task2' ) # [END howto_operator_cloudsql_db_patch] @@ -262,25 +203,20 @@ # [START howto_operator_cloudsql_export_gcs_permissions] sql_gcp_add_bucket_permission_task = GCSBucketCreateAclEntryOperator( entity="user-{{ task_instance.xcom_pull(" - "'sql_instance_create_task', key='service_account_email') " - "}}", + "'sql_instance_create_task', key='service_account_email') " + "}}", role="WRITER", bucket=export_url_split[1], # netloc (bucket) - task_id='sql_gcp_add_bucket_permission_task' + task_id='sql_gcp_add_bucket_permission_task', ) # [END howto_operator_cloudsql_export_gcs_permissions] # [START howto_operator_cloudsql_export] sql_export_task = CloudSQLExportInstanceOperator( - project_id=GCP_PROJECT_ID, - body=export_body, - instance=INSTANCE_NAME, - task_id='sql_export_task' + project_id=GCP_PROJECT_ID, body=export_body, instance=INSTANCE_NAME, task_id='sql_export_task' ) sql_export_task2 = CloudSQLExportInstanceOperator( - body=export_body, - instance=INSTANCE_NAME, - task_id='sql_export_task2' + body=export_body, instance=INSTANCE_NAME, task_id='sql_export_task2' ) # [END howto_operator_cloudsql_export] @@ 
-294,8 +230,8 @@ # [START howto_operator_cloudsql_import_gcs_permissions] sql_gcp_add_object_permission_task = GCSObjectCreateAclEntryOperator( entity="user-{{ task_instance.xcom_pull(" - "'sql_instance_create_task2', key='service_account_email')" - " }}", + "'sql_instance_create_task2', key='service_account_email')" + " }}", role="READER", bucket=import_url_split[1], # netloc (bucket) object_name=import_url_split[2][1:], # path (strip first '/') @@ -306,8 +242,8 @@ # write access to the whole bucket!. sql_gcp_add_bucket_permission_2_task = GCSBucketCreateAclEntryOperator( entity="user-{{ task_instance.xcom_pull(" - "'sql_instance_create_task2', key='service_account_email') " - "}}", + "'sql_instance_create_task2', key='service_account_email') " + "}}", role="WRITER", bucket=import_url_split[1], # netloc task_id='sql_gcp_add_bucket_permission_2_task', @@ -316,15 +252,10 @@ # [START howto_operator_cloudsql_import] sql_import_task = CloudSQLImportInstanceOperator( - project_id=GCP_PROJECT_ID, - body=import_body, - instance=INSTANCE_NAME2, - task_id='sql_import_task' + project_id=GCP_PROJECT_ID, body=import_body, instance=INSTANCE_NAME2, task_id='sql_import_task' ) sql_import_task2 = CloudSQLImportInstanceOperator( - body=import_body, - instance=INSTANCE_NAME2, - task_id='sql_import_task2' + body=import_body, instance=INSTANCE_NAME2, task_id='sql_import_task2' ) # [END howto_operator_cloudsql_import] @@ -334,15 +265,10 @@ # [START howto_operator_cloudsql_db_delete] sql_db_delete_task = CloudSQLDeleteInstanceDatabaseOperator( - project_id=GCP_PROJECT_ID, - instance=INSTANCE_NAME, - database=DB_NAME, - task_id='sql_db_delete_task' + project_id=GCP_PROJECT_ID, instance=INSTANCE_NAME, database=DB_NAME, task_id='sql_db_delete_task' ) sql_db_delete_task2 = CloudSQLDeleteInstanceDatabaseOperator( - instance=INSTANCE_NAME, - database=DB_NAME, - task_id='sql_db_delete_task2' + instance=INSTANCE_NAME, database=DB_NAME, task_id='sql_db_delete_task2' ) # [END howto_operator_cloudsql_db_delete] @@ -354,32 +280,25 @@ sql_instance_failover_replica_delete_task = CloudSQLDeleteInstanceOperator( project_id=GCP_PROJECT_ID, instance=FAILOVER_REPLICA_NAME, - task_id='sql_instance_failover_replica_delete_task' + task_id='sql_instance_failover_replica_delete_task', ) sql_instance_read_replica_delete_task = CloudSQLDeleteInstanceOperator( - project_id=GCP_PROJECT_ID, - instance=READ_REPLICA_NAME, - task_id='sql_instance_read_replica_delete_task' + project_id=GCP_PROJECT_ID, instance=READ_REPLICA_NAME, task_id='sql_instance_read_replica_delete_task' ) # [END howto_operator_cloudsql_replicas_delete] # [START howto_operator_cloudsql_delete] sql_instance_delete_task = CloudSQLDeleteInstanceOperator( - project_id=GCP_PROJECT_ID, - instance=INSTANCE_NAME, - task_id='sql_instance_delete_task' + project_id=GCP_PROJECT_ID, instance=INSTANCE_NAME, task_id='sql_instance_delete_task' ) sql_instance_delete_task2 = CloudSQLDeleteInstanceOperator( - instance=INSTANCE_NAME2, - task_id='sql_instance_delete_task2' + instance=INSTANCE_NAME2, task_id='sql_instance_delete_task2' ) # [END howto_operator_cloudsql_delete] sql_instance_delete_2_task = CloudSQLDeleteInstanceOperator( - project_id=GCP_PROJECT_ID, - instance=INSTANCE_NAME2, - task_id='sql_instance_delete_2_task' + project_id=GCP_PROJECT_ID, instance=INSTANCE_NAME2, task_id='sql_instance_delete_2_task' ) ( diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py b/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py index 
94fed1dddec9c..e1611becd17d7 100644 --- a/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +++ b/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py @@ -48,35 +48,29 @@ GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') GCP_REGION = os.environ.get('GCP_REGION', 'europe-west-1b') -GCSQL_POSTGRES_INSTANCE_NAME_QUERY = os.environ.get( - 'GCSQL_POSTGRES_INSTANCE_NAME_QUERY', - 'testpostgres') -GCSQL_POSTGRES_DATABASE_NAME = os.environ.get('GCSQL_POSTGRES_DATABASE_NAME', - 'postgresdb') +GCSQL_POSTGRES_INSTANCE_NAME_QUERY = os.environ.get('GCSQL_POSTGRES_INSTANCE_NAME_QUERY', 'testpostgres') +GCSQL_POSTGRES_DATABASE_NAME = os.environ.get('GCSQL_POSTGRES_DATABASE_NAME', 'postgresdb') GCSQL_POSTGRES_USER = os.environ.get('GCSQL_POSTGRES_USER', 'postgres_user') GCSQL_POSTGRES_PASSWORD = os.environ.get('GCSQL_POSTGRES_PASSWORD', 'password') GCSQL_POSTGRES_PUBLIC_IP = os.environ.get('GCSQL_POSTGRES_PUBLIC_IP', '0.0.0.0') GCSQL_POSTGRES_PUBLIC_PORT = os.environ.get('GCSQL_POSTGRES_PUBLIC_PORT', 5432) -GCSQL_POSTGRES_CLIENT_CERT_FILE = os.environ.get('GCSQL_POSTGRES_CLIENT_CERT_FILE', - ".key/postgres-client-cert.pem") -GCSQL_POSTGRES_CLIENT_KEY_FILE = os.environ.get('GCSQL_POSTGRES_CLIENT_KEY_FILE', - ".key/postgres-client-key.pem") -GCSQL_POSTGRES_SERVER_CA_FILE = os.environ.get('GCSQL_POSTGRES_SERVER_CA_FILE', - ".key/postgres-server-ca.pem") - -GCSQL_MYSQL_INSTANCE_NAME_QUERY = os.environ.get('GCSQL_MYSQL_INSTANCE_NAME_QUERY', - 'testmysql') +GCSQL_POSTGRES_CLIENT_CERT_FILE = os.environ.get( + 'GCSQL_POSTGRES_CLIENT_CERT_FILE', ".key/postgres-client-cert.pem" +) +GCSQL_POSTGRES_CLIENT_KEY_FILE = os.environ.get( + 'GCSQL_POSTGRES_CLIENT_KEY_FILE', ".key/postgres-client-key.pem" +) +GCSQL_POSTGRES_SERVER_CA_FILE = os.environ.get('GCSQL_POSTGRES_SERVER_CA_FILE', ".key/postgres-server-ca.pem") + +GCSQL_MYSQL_INSTANCE_NAME_QUERY = os.environ.get('GCSQL_MYSQL_INSTANCE_NAME_QUERY', 'testmysql') GCSQL_MYSQL_DATABASE_NAME = os.environ.get('GCSQL_MYSQL_DATABASE_NAME', 'mysqldb') GCSQL_MYSQL_USER = os.environ.get('GCSQL_MYSQL_USER', 'mysql_user') GCSQL_MYSQL_PASSWORD = os.environ.get('GCSQL_MYSQL_PASSWORD', 'password') GCSQL_MYSQL_PUBLIC_IP = os.environ.get('GCSQL_MYSQL_PUBLIC_IP', '0.0.0.0') GCSQL_MYSQL_PUBLIC_PORT = os.environ.get('GCSQL_MYSQL_PUBLIC_PORT', 3306) -GCSQL_MYSQL_CLIENT_CERT_FILE = os.environ.get('GCSQL_MYSQL_CLIENT_CERT_FILE', - ".key/mysql-client-cert.pem") -GCSQL_MYSQL_CLIENT_KEY_FILE = os.environ.get('GCSQL_MYSQL_CLIENT_KEY_FILE', - ".key/mysql-client-key.pem") -GCSQL_MYSQL_SERVER_CA_FILE = os.environ.get('GCSQL_MYSQL_SERVER_CA_FILE', - ".key/mysql-server-ca.pem") +GCSQL_MYSQL_CLIENT_CERT_FILE = os.environ.get('GCSQL_MYSQL_CLIENT_CERT_FILE', ".key/mysql-client-cert.pem") +GCSQL_MYSQL_CLIENT_KEY_FILE = os.environ.get('GCSQL_MYSQL_CLIENT_KEY_FILE', ".key/mysql-client-key.pem") +GCSQL_MYSQL_SERVER_CA_FILE = os.environ.get('GCSQL_MYSQL_SERVER_CA_FILE', ".key/mysql-server-ca.pem") SQL = [ 'CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)', @@ -114,7 +108,7 @@ def get_absolute_path(path): database=quote_plus(GCSQL_POSTGRES_DATABASE_NAME), client_cert_file=quote_plus(get_absolute_path(GCSQL_POSTGRES_CLIENT_CERT_FILE)), client_key_file=quote_plus(get_absolute_path(GCSQL_POSTGRES_CLIENT_KEY_FILE)), - server_ca_file=quote_plus(get_absolute_path(GCSQL_POSTGRES_SERVER_CA_FILE)) + server_ca_file=quote_plus(get_absolute_path(GCSQL_POSTGRES_SERVER_CA_FILE)), ) # The connections below are created using one of the standard approaches - via 
environment @@ -122,49 +116,52 @@ def get_absolute_path(path): # of AIRFLOW (using command line or UI). # Postgres: connect via proxy over TCP -os.environ['AIRFLOW_CONN_PROXY_POSTGRES_TCP'] = \ - "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" \ - "database_type=postgres&" \ - "project_id={project_id}&" \ - "location={location}&" \ - "instance={instance}&" \ - "use_proxy=True&" \ +os.environ['AIRFLOW_CONN_PROXY_POSTGRES_TCP'] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=postgres&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=True&" "sql_proxy_use_tcp=True".format(**postgres_kwargs) +) # Postgres: connect via proxy over UNIX socket (specific proxy version) -os.environ['AIRFLOW_CONN_PROXY_POSTGRES_SOCKET'] = \ - "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" \ - "database_type=postgres&" \ - "project_id={project_id}&" \ - "location={location}&" \ - "instance={instance}&" \ - "use_proxy=True&" \ - "sql_proxy_version=v1.13&" \ +os.environ['AIRFLOW_CONN_PROXY_POSTGRES_SOCKET'] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=postgres&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=True&" + "sql_proxy_version=v1.13&" "sql_proxy_use_tcp=False".format(**postgres_kwargs) +) # Postgres: connect directly via TCP (non-SSL) -os.environ['AIRFLOW_CONN_PUBLIC_POSTGRES_TCP'] = \ - "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" \ - "database_type=postgres&" \ - "project_id={project_id}&" \ - "location={location}&" \ - "instance={instance}&" \ - "use_proxy=False&" \ +os.environ['AIRFLOW_CONN_PUBLIC_POSTGRES_TCP'] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=postgres&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" "use_ssl=False".format(**postgres_kwargs) +) # Postgres: connect directly via TCP (SSL) -os.environ['AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL'] = \ - "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" \ - "database_type=postgres&" \ - "project_id={project_id}&" \ - "location={location}&" \ - "instance={instance}&" \ - "use_proxy=False&" \ - "use_ssl=True&" \ - "sslcert={client_cert_file}&" \ - "sslkey={client_key_file}&" \ - "sslrootcert={server_ca_file}"\ - .format(**postgres_kwargs) +os.environ['AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL'] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=postgres&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" + "use_ssl=True&" + "sslcert={client_cert_file}&" + "sslkey={client_key_file}&" + "sslrootcert={server_ca_file}".format(**postgres_kwargs) +) mysql_kwargs = dict( user=quote_plus(GCSQL_MYSQL_USER), @@ -177,74 +174,77 @@ def get_absolute_path(path): database=quote_plus(GCSQL_MYSQL_DATABASE_NAME), client_cert_file=quote_plus(get_absolute_path(GCSQL_MYSQL_CLIENT_CERT_FILE)), client_key_file=quote_plus(get_absolute_path(GCSQL_MYSQL_CLIENT_KEY_FILE)), - server_ca_file=quote_plus(get_absolute_path(GCSQL_MYSQL_SERVER_CA_FILE)) + server_ca_file=quote_plus(get_absolute_path(GCSQL_MYSQL_SERVER_CA_FILE)), ) # MySQL: connect via proxy over TCP (specific proxy version) -os.environ['AIRFLOW_CONN_PROXY_MYSQL_TCP'] = \ - "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" 
\ - "database_type=mysql&" \ - "project_id={project_id}&" \ - "location={location}&" \ - "instance={instance}&" \ - "use_proxy=True&" \ - "sql_proxy_version=v1.13&" \ +os.environ['AIRFLOW_CONN_PROXY_MYSQL_TCP'] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=mysql&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=True&" + "sql_proxy_version=v1.13&" "sql_proxy_use_tcp=True".format(**mysql_kwargs) +) # MySQL: connect via proxy over UNIX socket using pre-downloaded Cloud Sql Proxy binary try: - sql_proxy_binary_path = subprocess.check_output( - ['which', 'cloud_sql_proxy']).decode('utf-8').rstrip() + sql_proxy_binary_path = subprocess.check_output(['which', 'cloud_sql_proxy']).decode('utf-8').rstrip() except subprocess.CalledProcessError: sql_proxy_binary_path = "/tmp/anyhow_download_cloud_sql_proxy" -os.environ['AIRFLOW_CONN_PROXY_MYSQL_SOCKET'] = \ - "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" \ - "database_type=mysql&" \ - "project_id={project_id}&" \ - "location={location}&" \ - "instance={instance}&" \ - "use_proxy=True&" \ - "sql_proxy_binary_path={sql_proxy_binary_path}&" \ - "sql_proxy_use_tcp=False".format( - sql_proxy_binary_path=quote_plus(sql_proxy_binary_path), **mysql_kwargs) +os.environ['AIRFLOW_CONN_PROXY_MYSQL_SOCKET'] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=mysql&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=True&" + "sql_proxy_binary_path={sql_proxy_binary_path}&" + "sql_proxy_use_tcp=False".format(sql_proxy_binary_path=quote_plus(sql_proxy_binary_path), **mysql_kwargs) +) # MySQL: connect directly via TCP (non-SSL) -os.environ['AIRFLOW_CONN_PUBLIC_MYSQL_TCP'] = \ - "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" \ - "database_type=mysql&" \ - "project_id={project_id}&" \ - "location={location}&" \ - "instance={instance}&" \ - "use_proxy=False&" \ +os.environ['AIRFLOW_CONN_PUBLIC_MYSQL_TCP'] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=mysql&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" "use_ssl=False".format(**mysql_kwargs) +) # MySQL: connect directly via TCP (SSL) and with fixed Cloud Sql Proxy binary path -os.environ['AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL'] = \ - "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" \ - "database_type=mysql&" \ - "project_id={project_id}&" \ - "location={location}&" \ - "instance={instance}&" \ - "use_proxy=False&" \ - "use_ssl=True&" \ - "sslcert={client_cert_file}&" \ - "sslkey={client_key_file}&" \ +os.environ['AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL'] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=mysql&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" + "use_ssl=True&" + "sslcert={client_cert_file}&" + "sslkey={client_key_file}&" "sslrootcert={server_ca_file}".format(**mysql_kwargs) +) # Special case: MySQL: connect directly via TCP (SSL) and with fixed Cloud Sql # Proxy binary path AND with missing project_id -os.environ['AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL_NO_PROJECT_ID'] = \ - "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" 
\ - "database_type=mysql&" \ - "location={location}&" \ - "instance={instance}&" \ - "use_proxy=False&" \ - "use_ssl=True&" \ - "sslcert={client_cert_file}&" \ - "sslkey={client_key_file}&" \ +os.environ['AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL_NO_PROJECT_ID'] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=mysql&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" + "use_ssl=True&" + "sslcert={client_cert_file}&" + "sslkey={client_key_file}&" "sslrootcert={server_ca_file}".format(**mysql_kwargs) +) # [END howto_operator_cloudsql_query_connections] @@ -260,25 +260,20 @@ def get_absolute_path(path): "proxy_mysql_socket", "public_mysql_tcp", "public_mysql_tcp_ssl", - "public_mysql_tcp_ssl_no_project_id" + "public_mysql_tcp_ssl_no_project_id", ] tasks = [] with models.DAG( - dag_id='example_gcp_sql_query', - schedule_interval=None, - start_date=days_ago(1), - tags=['example'], + dag_id='example_gcp_sql_query', schedule_interval=None, start_date=days_ago(1), tags=['example'], ) as dag: prev_task = None for connection_name in connection_names: task = CloudSQLExecuteQueryOperator( - gcp_cloudsql_conn_id=connection_name, - task_id="example_gcp_sql_task_" + connection_name, - sql=SQL + gcp_cloudsql_conn_id=connection_name, task_id="example_gcp_sql_task_" + connection_name, sql=SQL ) tasks.append(task) if prev_task: diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py b/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py index a994f6b632210..353aa335f3d93 100644 --- a/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py +++ b/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_aws.py @@ -41,15 +41,32 @@ from airflow import models from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( - ALREADY_EXISTING_IN_SINK, AWS_S3_DATA_SOURCE, BUCKET_NAME, DESCRIPTION, FILTER_JOB_NAMES, - FILTER_PROJECT_ID, GCS_DATA_SINK, JOB_NAME, PROJECT_ID, SCHEDULE, SCHEDULE_END_DATE, SCHEDULE_START_DATE, - START_TIME_OF_DAY, STATUS, TRANSFER_OPTIONS, TRANSFER_SPEC, GcpTransferJobsStatus, + ALREADY_EXISTING_IN_SINK, + AWS_S3_DATA_SOURCE, + BUCKET_NAME, + DESCRIPTION, + FILTER_JOB_NAMES, + FILTER_PROJECT_ID, + GCS_DATA_SINK, + JOB_NAME, + PROJECT_ID, + SCHEDULE, + SCHEDULE_END_DATE, + SCHEDULE_START_DATE, + START_TIME_OF_DAY, + STATUS, + TRANSFER_OPTIONS, + TRANSFER_SPEC, + GcpTransferJobsStatus, GcpTransferOperationStatus, ) from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import ( - CloudDataTransferServiceCancelOperationOperator, CloudDataTransferServiceCreateJobOperator, - CloudDataTransferServiceDeleteJobOperator, CloudDataTransferServiceGetOperationOperator, - CloudDataTransferServiceListOperationsOperator, CloudDataTransferServicePauseOperationOperator, + CloudDataTransferServiceCancelOperationOperator, + CloudDataTransferServiceCreateJobOperator, + CloudDataTransferServiceDeleteJobOperator, + CloudDataTransferServiceGetOperationOperator, + CloudDataTransferServiceListOperationsOperator, + CloudDataTransferServicePauseOperationOperator, CloudDataTransferServiceResumeOperationOperator, ) from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import ( @@ -67,9 +84,7 @@ 'GCP_TRANSFER_FIRST_TARGET_BUCKET', 'gcp-transfer-first-target' ) -GCP_TRANSFER_JOB_NAME = os.environ.get( - 'GCP_TRANSFER_JOB_NAME', 'transferJobs/sampleJob' -) 
+GCP_TRANSFER_JOB_NAME = os.environ.get('GCP_TRANSFER_JOB_NAME', 'transferJobs/sampleJob') # [START howto_operator_gcp_transfer_create_job_body_aws] aws_to_gcs_transfer_body = { @@ -172,6 +187,8 @@ ) # [END howto_operator_gcp_transfer_delete_job] - create_transfer_job_from_aws >> wait_for_operation_to_start >> pause_operation >> \ - list_operations >> get_operation >> resume_operation >> wait_for_operation_to_end >> \ - cancel_operation >> delete_transfer_from_aws_job + # fmt: off + create_transfer_job_from_aws >> wait_for_operation_to_start >> pause_operation + pause_operation >> list_operations >> get_operation >> resume_operation + resume_operation >> wait_for_operation_to_end >> cancel_operation >> delete_transfer_from_aws_job + # fmt: on diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_gcp.py b/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_gcp.py index 1df2541a11e27..afbd04dee43e2 100644 --- a/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_gcp.py +++ b/airflow/providers/google/cloud/example_dags/example_cloud_storage_transfer_service_gcp.py @@ -33,14 +33,31 @@ from airflow import models from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( - ALREADY_EXISTING_IN_SINK, BUCKET_NAME, DESCRIPTION, FILTER_JOB_NAMES, FILTER_PROJECT_ID, GCS_DATA_SINK, - GCS_DATA_SOURCE, PROJECT_ID, SCHEDULE, SCHEDULE_END_DATE, SCHEDULE_START_DATE, START_TIME_OF_DAY, STATUS, - TRANSFER_JOB, TRANSFER_JOB_FIELD_MASK, TRANSFER_OPTIONS, TRANSFER_SPEC, GcpTransferJobsStatus, + ALREADY_EXISTING_IN_SINK, + BUCKET_NAME, + DESCRIPTION, + FILTER_JOB_NAMES, + FILTER_PROJECT_ID, + GCS_DATA_SINK, + GCS_DATA_SOURCE, + PROJECT_ID, + SCHEDULE, + SCHEDULE_END_DATE, + SCHEDULE_START_DATE, + START_TIME_OF_DAY, + STATUS, + TRANSFER_JOB, + TRANSFER_JOB_FIELD_MASK, + TRANSFER_OPTIONS, + TRANSFER_SPEC, + GcpTransferJobsStatus, GcpTransferOperationStatus, ) from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import ( - CloudDataTransferServiceCreateJobOperator, CloudDataTransferServiceDeleteJobOperator, - CloudDataTransferServiceGetOperationOperator, CloudDataTransferServiceListOperationsOperator, + CloudDataTransferServiceCreateJobOperator, + CloudDataTransferServiceDeleteJobOperator, + CloudDataTransferServiceGetOperationOperator, + CloudDataTransferServiceListOperationsOperator, CloudDataTransferServiceUpdateJobOperator, ) from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import ( @@ -112,15 +129,12 @@ task_id="list_operations", request_filter={ FILTER_PROJECT_ID: GCP_PROJECT_ID, - FILTER_JOB_NAMES: [ - "{{task_instance.xcom_pull('create_transfer')['name']}}" - ], + FILTER_JOB_NAMES: ["{{task_instance.xcom_pull('create_transfer')['name']}}"], }, ) get_operation = CloudDataTransferServiceGetOperationOperator( - task_id="get_operation", - operation_name="{{task_instance.xcom_pull('list_operations')[0]['name']}}", + task_id="get_operation", operation_name="{{task_instance.xcom_pull('list_operations')[0]['name']}}", ) delete_transfer = CloudDataTransferServiceDeleteJobOperator( @@ -129,5 +143,5 @@ project_id=GCP_PROJECT_ID, ) - create_transfer >> wait_for_transfer >> update_transfer >> \ - list_operations >> get_operation >> delete_transfer + create_transfer >> wait_for_transfer >> update_transfer >> list_operations >> get_operation + get_operation >> delete_transfer diff --git 
a/airflow/providers/google/cloud/example_dags/example_compute.py b/airflow/providers/google/cloud/example_dags/example_compute.py index 3153715c01e30..93ec29121d233 100644 --- a/airflow/providers/google/cloud/example_dags/example_compute.py +++ b/airflow/providers/google/cloud/example_dags/example_compute.py @@ -33,7 +33,8 @@ from airflow import models from airflow.providers.google.cloud.operators.compute import ( - ComputeEngineSetMachineTypeOperator, ComputeEngineStartInstanceOperator, + ComputeEngineSetMachineTypeOperator, + ComputeEngineStartInstanceOperator, ComputeEngineStopInstanceOperator, ) from airflow.utils.dates import days_ago @@ -56,34 +57,24 @@ ) as dag: # [START howto_operator_gce_start] gce_instance_start = ComputeEngineStartInstanceOperator( - project_id=GCP_PROJECT_ID, - zone=GCE_ZONE, - resource_id=GCE_INSTANCE, - task_id='gcp_compute_start_task' + project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id='gcp_compute_start_task' ) # [END howto_operator_gce_start] # Duplicate start for idempotence testing # [START howto_operator_gce_start_no_project_id] gce_instance_start2 = ComputeEngineStartInstanceOperator( - zone=GCE_ZONE, - resource_id=GCE_INSTANCE, - task_id='gcp_compute_start_task2' + zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id='gcp_compute_start_task2' ) # [END howto_operator_gce_start_no_project_id] # [START howto_operator_gce_stop] gce_instance_stop = ComputeEngineStopInstanceOperator( - project_id=GCP_PROJECT_ID, - zone=GCE_ZONE, - resource_id=GCE_INSTANCE, - task_id='gcp_compute_stop_task' + project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id='gcp_compute_stop_task' ) # [END howto_operator_gce_stop] # Duplicate stop for idempotence testing # [START howto_operator_gce_stop_no_project_id] gce_instance_stop2 = ComputeEngineStopInstanceOperator( - zone=GCE_ZONE, - resource_id=GCE_INSTANCE, - task_id='gcp_compute_stop_task2' + zone=GCE_ZONE, resource_id=GCE_INSTANCE, task_id='gcp_compute_stop_task2' ) # [END howto_operator_gce_stop_no_project_id] # [START howto_operator_gce_set_machine_type] @@ -91,10 +82,8 @@ project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE, - body={ - 'machineType': 'zones/{}/machineTypes/{}'.format(GCE_ZONE, GCE_SHORT_MACHINE_TYPE_NAME) - }, - task_id='gcp_compute_set_machine_type' + body={'machineType': 'zones/{}/machineTypes/{}'.format(GCE_ZONE, GCE_SHORT_MACHINE_TYPE_NAME)}, + task_id='gcp_compute_set_machine_type', ) # [END howto_operator_gce_set_machine_type] # Duplicate set machine type for idempotence testing @@ -102,12 +91,10 @@ gce_set_machine_type2 = ComputeEngineSetMachineTypeOperator( zone=GCE_ZONE, resource_id=GCE_INSTANCE, - body={ - 'machineType': 'zones/{}/machineTypes/{}'.format(GCE_ZONE, GCE_SHORT_MACHINE_TYPE_NAME) - }, - task_id='gcp_compute_set_machine_type2' + body={'machineType': 'zones/{}/machineTypes/{}'.format(GCE_ZONE, GCE_SHORT_MACHINE_TYPE_NAME)}, + task_id='gcp_compute_set_machine_type2', ) # [END howto_operator_gce_set_machine_type_no_project_id] - gce_instance_start >> gce_instance_start2 >> gce_instance_stop >> \ - gce_instance_stop2 >> gce_set_machine_type >> gce_set_machine_type2 + gce_instance_start >> gce_instance_start2 >> gce_instance_stop >> gce_instance_stop2 + gce_instance_stop2 >> gce_set_machine_type >> gce_set_machine_type2 diff --git a/airflow/providers/google/cloud/example_dags/example_compute_igm.py b/airflow/providers/google/cloud/example_dags/example_compute_igm.py index f091af1f90765..d58392cf6807f 100644 --- 
a/airflow/providers/google/cloud/example_dags/example_compute_igm.py +++ b/airflow/providers/google/cloud/example_dags/example_compute_igm.py @@ -42,7 +42,8 @@ from airflow import models from airflow.providers.google.cloud.operators.compute import ( - ComputeEngineCopyInstanceTemplateOperator, ComputeEngineInstanceGroupUpdateManagerTemplateOperator, + ComputeEngineCopyInstanceTemplateOperator, + ComputeEngineInstanceGroupUpdateManagerTemplateOperator, ) from airflow.utils.dates import days_ago @@ -51,39 +52,38 @@ # [START howto_operator_compute_template_copy_args] GCE_TEMPLATE_NAME = os.environ.get('GCE_TEMPLATE_NAME', 'instance-template-test') -GCE_NEW_TEMPLATE_NAME = os.environ.get('GCE_NEW_TEMPLATE_NAME', - 'instance-template-test-new') +GCE_NEW_TEMPLATE_NAME = os.environ.get('GCE_NEW_TEMPLATE_NAME', 'instance-template-test-new') GCE_NEW_DESCRIPTION = os.environ.get('GCE_NEW_DESCRIPTION', 'Test new description') GCE_INSTANCE_TEMPLATE_BODY_UPDATE = { "name": GCE_NEW_TEMPLATE_NAME, "description": GCE_NEW_DESCRIPTION, - "properties": { - "machineType": "n1-standard-2" - } + "properties": {"machineType": "n1-standard-2"}, } # [END howto_operator_compute_template_copy_args] # [START howto_operator_compute_igm_update_template_args] -GCE_INSTANCE_GROUP_MANAGER_NAME = os.environ.get('GCE_INSTANCE_GROUP_MANAGER_NAME', - 'instance-group-test') +GCE_INSTANCE_GROUP_MANAGER_NAME = os.environ.get('GCE_INSTANCE_GROUP_MANAGER_NAME', 'instance-group-test') SOURCE_TEMPLATE_URL = os.environ.get( 'SOURCE_TEMPLATE_URL', - "https://www.googleapis.com/compute/beta/projects/" + GCP_PROJECT_ID + - "/global/instanceTemplates/instance-template-test") + "https://www.googleapis.com/compute/beta/projects/" + + GCP_PROJECT_ID + + "/global/instanceTemplates/instance-template-test", +) DESTINATION_TEMPLATE_URL = os.environ.get( 'DESTINATION_TEMPLATE_URL', - "https://www.googleapis.com/compute/beta/projects/" + GCP_PROJECT_ID + - "/global/instanceTemplates/" + GCE_NEW_TEMPLATE_NAME) + "https://www.googleapis.com/compute/beta/projects/" + + GCP_PROJECT_ID + + "/global/instanceTemplates/" + + GCE_NEW_TEMPLATE_NAME, +) UPDATE_POLICY = { "type": "OPPORTUNISTIC", "minimalAction": "RESTART", - "maxSurge": { - "fixed": 1 - }, - "minReadySec": 1800 + "maxSurge": {"fixed": 1}, + "minReadySec": 1800, } # [END howto_operator_compute_igm_update_template_args] @@ -100,7 +100,7 @@ project_id=GCP_PROJECT_ID, resource_id=GCE_TEMPLATE_NAME, body_patch=GCE_INSTANCE_TEMPLATE_BODY_UPDATE, - task_id='gcp_compute_igm_copy_template_task' + task_id='gcp_compute_igm_copy_template_task', ) # [END howto_operator_gce_igm_copy_template] # Added to check for idempotence @@ -108,33 +108,30 @@ gce_instance_template_copy2 = ComputeEngineCopyInstanceTemplateOperator( resource_id=GCE_TEMPLATE_NAME, body_patch=GCE_INSTANCE_TEMPLATE_BODY_UPDATE, - task_id='gcp_compute_igm_copy_template_task_2' + task_id='gcp_compute_igm_copy_template_task_2', ) # [END howto_operator_gce_igm_copy_template_no_project_id] # [START howto_operator_gce_igm_update_template] - gce_instance_group_manager_update_template = \ - ComputeEngineInstanceGroupUpdateManagerTemplateOperator( - project_id=GCP_PROJECT_ID, - resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, - zone=GCE_ZONE, - source_template=SOURCE_TEMPLATE_URL, - destination_template=DESTINATION_TEMPLATE_URL, - update_policy=UPDATE_POLICY, - task_id='gcp_compute_igm_group_manager_update_template' - ) + gce_instance_group_manager_update_template = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( + project_id=GCP_PROJECT_ID, 
+ resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, + zone=GCE_ZONE, + source_template=SOURCE_TEMPLATE_URL, + destination_template=DESTINATION_TEMPLATE_URL, + update_policy=UPDATE_POLICY, + task_id='gcp_compute_igm_group_manager_update_template', + ) # [END howto_operator_gce_igm_update_template] # Added to check for idempotence (and without UPDATE_POLICY) # [START howto_operator_gce_igm_update_template_no_project_id] - gce_instance_group_manager_update_template2 = \ - ComputeEngineInstanceGroupUpdateManagerTemplateOperator( - resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, - zone=GCE_ZONE, - source_template=SOURCE_TEMPLATE_URL, - destination_template=DESTINATION_TEMPLATE_URL, - task_id='gcp_compute_igm_group_manager_update_template_2' - ) + gce_instance_group_manager_update_template2 = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( + resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, + zone=GCE_ZONE, + source_template=SOURCE_TEMPLATE_URL, + destination_template=DESTINATION_TEMPLATE_URL, + task_id='gcp_compute_igm_group_manager_update_template_2', + ) # [END howto_operator_gce_igm_update_template_no_project_id] - gce_instance_template_copy >> gce_instance_template_copy2 >> \ - gce_instance_group_manager_update_template >> \ - gce_instance_group_manager_update_template2 + gce_instance_template_copy >> gce_instance_template_copy2 >> gce_instance_group_manager_update_template + gce_instance_group_manager_update_template >> gce_instance_group_manager_update_template2 diff --git a/airflow/providers/google/cloud/example_dags/example_datacatalog.py b/airflow/providers/google/cloud/example_dags/example_datacatalog.py index 4a2d13a42cfb8..08f8da278a90a 100644 --- a/airflow/providers/google/cloud/example_dags/example_datacatalog.py +++ b/airflow/providers/google/cloud/example_dags/example_datacatalog.py @@ -24,16 +24,26 @@ from airflow import models from airflow.operators.bash_operator import BashOperator from airflow.providers.google.cloud.operators.datacatalog import ( - CloudDataCatalogCreateEntryGroupOperator, CloudDataCatalogCreateEntryOperator, - CloudDataCatalogCreateTagOperator, CloudDataCatalogCreateTagTemplateFieldOperator, - CloudDataCatalogCreateTagTemplateOperator, CloudDataCatalogDeleteEntryGroupOperator, - CloudDataCatalogDeleteEntryOperator, CloudDataCatalogDeleteTagOperator, - CloudDataCatalogDeleteTagTemplateFieldOperator, CloudDataCatalogDeleteTagTemplateOperator, - CloudDataCatalogGetEntryGroupOperator, CloudDataCatalogGetEntryOperator, - CloudDataCatalogGetTagTemplateOperator, CloudDataCatalogListTagsOperator, - CloudDataCatalogLookupEntryOperator, CloudDataCatalogRenameTagTemplateFieldOperator, - CloudDataCatalogSearchCatalogOperator, CloudDataCatalogUpdateEntryOperator, - CloudDataCatalogUpdateTagOperator, CloudDataCatalogUpdateTagTemplateFieldOperator, + CloudDataCatalogCreateEntryGroupOperator, + CloudDataCatalogCreateEntryOperator, + CloudDataCatalogCreateTagOperator, + CloudDataCatalogCreateTagTemplateFieldOperator, + CloudDataCatalogCreateTagTemplateOperator, + CloudDataCatalogDeleteEntryGroupOperator, + CloudDataCatalogDeleteEntryOperator, + CloudDataCatalogDeleteTagOperator, + CloudDataCatalogDeleteTagTemplateFieldOperator, + CloudDataCatalogDeleteTagTemplateOperator, + CloudDataCatalogGetEntryGroupOperator, + CloudDataCatalogGetEntryOperator, + CloudDataCatalogGetTagTemplateOperator, + CloudDataCatalogListTagsOperator, + CloudDataCatalogLookupEntryOperator, + CloudDataCatalogRenameTagTemplateFieldOperator, + CloudDataCatalogSearchCatalogOperator, + 
CloudDataCatalogUpdateEntryOperator, + CloudDataCatalogUpdateTagOperator, + CloudDataCatalogUpdateTagTemplateFieldOperator, CloudDataCatalogUpdateTagTemplateOperator, ) from airflow.utils.dates import days_ago @@ -288,7 +298,7 @@ task_id="lookup_entry", linked_resource=current_entry_template.format( project_id=PROJECT_ID, location=LOCATION, entry_group=ENTRY_GROUP_ID, entry=ENTRY_ID - ) + ), ) # [END howto_operator_gcp_datacatalog_lookup_entry_linked_resource] diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py b/airflow/providers/google/cloud/example_dags/example_dataflow.py index d5d82d66a8778..8449ae1012d32 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataflow.py +++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py @@ -24,7 +24,9 @@ from airflow import models from airflow.providers.google.cloud.operators.dataflow import ( - CheckJobRunning, DataflowCreateJavaJobOperator, DataflowCreatePythonJobOperator, + CheckJobRunning, + DataflowCreateJavaJobOperator, + DataflowCreatePythonJobOperator, DataflowTemplatedJobStartOperator, ) from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator @@ -40,12 +42,7 @@ GCS_JAR_BUCKET_NAME = GCS_JAR_PARTS.netloc GCS_JAR_OBJECT_NAME = GCS_JAR_PARTS.path[1:] -default_args = { - 'dataflow_default_options': { - 'tempLocation': GCS_TMP, - 'stagingLocation': GCS_STAGING, - } -} +default_args = {'dataflow_default_options': {'tempLocation': GCS_TMP, 'stagingLocation': GCS_STAGING,}} with models.DAG( "example_gcp_dataflow_native_java", @@ -59,13 +56,11 @@ task_id="start-java-job", jar=GCS_JAR, job_name='{{task.task_id}}', - options={ - 'output': GCS_OUTPUT, - }, + options={'output': GCS_OUTPUT,}, poll_sleep=10, job_class='org.apache.beam.examples.WordCount', check_if_running=CheckJobRunning.IgnoreJob, - location='europe-west3' + location='europe-west3', ) # [END howto_operator_start_java_job] @@ -80,9 +75,7 @@ task_id="start-java-job-local", jar="/tmp/dataflow-{{ ds_nodash }}.jar", job_name='{{task.task_id}}', - options={ - 'output': GCS_OUTPUT, - }, + options={'output': GCS_OUTPUT,}, poll_sleep=10, job_class='org.apache.beam.examples.WordCount', check_if_running=CheckJobRunning.WaitForRun, @@ -103,15 +96,11 @@ py_file=GCS_PYTHON, py_options=[], job_name='{{task.task_id}}', - options={ - 'output': GCS_OUTPUT, - }, - py_requirements=[ - 'apache-beam[gcp]==2.21.0' - ], + options={'output': GCS_OUTPUT,}, + py_requirements=['apache-beam[gcp]==2.21.0'], py_interpreter='python3', py_system_site_packages=False, - location='europe-west3' + location='europe-west3', ) # [END howto_operator_start_python_job] @@ -120,14 +109,10 @@ py_file='apache_beam.examples.wordcount', py_options=['-m'], job_name='{{task.task_id}}', - options={ - 'output': GCS_OUTPUT, - }, - py_requirements=[ - 'apache-beam[gcp]==2.14.0' - ], + options={'output': GCS_OUTPUT,}, + py_requirements=['apache-beam[gcp]==2.14.0'], py_interpreter='python3', - py_system_site_packages=False + py_system_site_packages=False, ) with models.DAG( @@ -140,9 +125,6 @@ start_template_job = DataflowTemplatedJobStartOperator( task_id="start-template-job", template='gs://dataflow-templates/latest/Word_Count', - parameters={ - 'inputFile': "gs://dataflow-samples/shakespeare/kinglear.txt", - 'output': GCS_OUTPUT - }, - location='europe-west3' + parameters={'inputFile': "gs://dataflow-samples/shakespeare/kinglear.txt", 'output': GCS_OUTPUT}, + location='europe-west3', ) diff --git 
a/airflow/providers/google/cloud/example_dags/example_datafusion.py b/airflow/providers/google/cloud/example_dags/example_datafusion.py index 60aa8f17bb717..1ee999d4ea0f6 100644 --- a/airflow/providers/google/cloud/example_dags/example_datafusion.py +++ b/airflow/providers/google/cloud/example_dags/example_datafusion.py @@ -22,11 +22,16 @@ from airflow import models from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.datafusion import ( - CloudDataFusionCreateInstanceOperator, CloudDataFusionCreatePipelineOperator, - CloudDataFusionDeleteInstanceOperator, CloudDataFusionDeletePipelineOperator, - CloudDataFusionGetInstanceOperator, CloudDataFusionListPipelinesOperator, - CloudDataFusionRestartInstanceOperator, CloudDataFusionStartPipelineOperator, - CloudDataFusionStopPipelineOperator, CloudDataFusionUpdateInstanceOperator, + CloudDataFusionCreateInstanceOperator, + CloudDataFusionCreatePipelineOperator, + CloudDataFusionDeleteInstanceOperator, + CloudDataFusionDeletePipelineOperator, + CloudDataFusionGetInstanceOperator, + CloudDataFusionListPipelinesOperator, + CloudDataFusionRestartInstanceOperator, + CloudDataFusionStartPipelineOperator, + CloudDataFusionStopPipelineOperator, + CloudDataFusionUpdateInstanceOperator, ) from airflow.utils import dates from airflow.utils.state import State @@ -59,11 +64,7 @@ "name": "GCSFile", "type": "batchsource", "label": "GCS", - "artifact": { - "name": "google-cloud", - "version": "0.14.2", - "scope": "SYSTEM", - }, + "artifact": {"name": "google-cloud", "version": "0.14.2", "scope": "SYSTEM",}, "properties": { "project": "auto-detect", "format": "text", @@ -73,7 +74,7 @@ "recursive": "false", "encrypted": "false", "schema": '{"type":"record","name":"etlSchemaBody","fields":' - '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', + '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', "path": BUCKET1, "referenceName": "foo_bucket", }, @@ -82,7 +83,7 @@ { "name": "etlSchemaBody", "schema": '{"type":"record","name":"etlSchemaBody","fields":' - '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', + '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', } ], }, @@ -92,11 +93,7 @@ "name": "GCS", "type": "batchsink", "label": "GCS2", - "artifact": { - "name": "google-cloud", - "version": "0.14.2", - "scope": "SYSTEM", - }, + "artifact": {"name": "google-cloud", "version": "0.14.2", "scope": "SYSTEM",}, "properties": { "project": "auto-detect", "suffix": "yyyy-MM-dd-HH-mm", @@ -104,7 +101,7 @@ "serviceFilePath": "auto-detect", "location": "us", "schema": '{"type":"record","name":"etlSchemaBody","fields":' - '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', + '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', "referenceName": "bar", "path": BUCKET2, }, @@ -113,14 +110,14 @@ { "name": "etlSchemaBody", "schema": '{"type":"record","name":"etlSchemaBody","fields":' - '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', + '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', } ], "inputSchema": [ { "name": "GCS", "schema": '{"type":"record","name":"etlSchemaBody","fields":' - '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', + '[{"name":"offset","type":"long"},{"name":"body","type":"string"}]}', } ], }, @@ -137,14 +134,11 @@ with models.DAG( "example_data_fusion", schedule_interval=None, # Override to match your needs - start_date=dates.days_ago(1) + 
start_date=dates.days_ago(1), ) as dag: # [START howto_cloud_data_fusion_create_instance_operator] create_instance = CloudDataFusionCreateInstanceOperator( - location=LOCATION, - instance_name=INSTANCE_NAME, - instance=INSTANCE, - task_id="create_instance", + location=LOCATION, instance_name=INSTANCE_NAME, instance=INSTANCE, task_id="create_instance", ) # [END howto_cloud_data_fusion_create_instance_operator] @@ -188,19 +182,13 @@ # [START howto_cloud_data_fusion_start_pipeline] start_pipeline = CloudDataFusionStartPipelineOperator( - location=LOCATION, - pipeline_name=PIPELINE_NAME, - instance_name=INSTANCE_NAME, - task_id="start_pipeline", + location=LOCATION, pipeline_name=PIPELINE_NAME, instance_name=INSTANCE_NAME, task_id="start_pipeline", ) # [END howto_cloud_data_fusion_start_pipeline] # [START howto_cloud_data_fusion_stop_pipeline] stop_pipeline = CloudDataFusionStopPipelineOperator( - location=LOCATION, - pipeline_name=PIPELINE_NAME, - instance_name=INSTANCE_NAME, - task_id="stop_pipeline", + location=LOCATION, pipeline_name=PIPELINE_NAME, instance_name=INSTANCE_NAME, task_id="stop_pipeline", ) # [END howto_cloud_data_fusion_stop_pipeline] diff --git a/airflow/providers/google/cloud/example_dags/example_dataprep.py b/airflow/providers/google/cloud/example_dags/example_dataprep.py index 684b8e5ff95e2..3732ea54f7866 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataprep.py +++ b/airflow/providers/google/cloud/example_dags/example_dataprep.py @@ -25,9 +25,7 @@ JOB_ID = 6269792 with models.DAG( - "example_dataprep", - schedule_interval=None, # Override to match your needs - start_date=dates.days_ago(1) + "example_dataprep", schedule_interval=None, start_date=dates.days_ago(1) # Override to match your needs ) as dag: # [START how_to_dataprep_get_jobs_for_job_group_operator] diff --git a/airflow/providers/google/cloud/example_dags/example_dataproc.py b/airflow/providers/google/cloud/example_dags/example_dataproc.py index 7dedeebced625..fd463dc8c230f 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataproc.py +++ b/airflow/providers/google/cloud/example_dags/example_dataproc.py @@ -24,7 +24,9 @@ from airflow import models from airflow.providers.google.cloud.operators.dataproc import ( - DataprocCreateClusterOperator, DataprocDeleteClusterOperator, DataprocSubmitJobOperator, + DataprocCreateClusterOperator, + DataprocDeleteClusterOperator, + DataprocSubmitJobOperator, DataprocUpdateClusterOperator, ) from airflow.utils.dates import days_ago @@ -66,16 +68,10 @@ # Update options # [START how_to_cloud_dataproc_updatemask_cluster_operator] CLUSTER_UPDATE = { - "config": { - "worker_config": {"num_instances": 3}, - "secondary_worker_config": {"num_instances": 3}, - } + "config": {"worker_config": {"num_instances": 3}, "secondary_worker_config": {"num_instances": 3},} } UPDATE_MASK = { - "paths": [ - "config.worker_config.num_instances", - "config.secondary_worker_config.num_instances", - ] + "paths": ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances",] } # [END how_to_cloud_dataproc_updatemask_cluster_operator] @@ -144,11 +140,7 @@ } # [END how_to_cloud_dataproc_hadoop_config] -with models.DAG( - "example_gcp_dataproc", - start_date=days_ago(1), - schedule_interval=None, -) as dag: +with models.DAG("example_gcp_dataproc", start_date=days_ago(1), schedule_interval=None,) as dag: # [START how_to_cloud_dataproc_create_cluster_operator] create_cluster = DataprocCreateClusterOperator( task_id="create_cluster", 
project_id=PROJECT_ID, cluster=CLUSTER, region=REGION @@ -171,10 +163,7 @@ task_id="pig_task", job=PIG_JOB, location=REGION, project_id=PROJECT_ID ) spark_sql_task = DataprocSubmitJobOperator( - task_id="spark_sql_task", - job=SPARK_SQL_JOB, - location=REGION, - project_id=PROJECT_ID, + task_id="spark_sql_task", job=SPARK_SQL_JOB, location=REGION, project_id=PROJECT_ID, ) spark_task = DataprocSubmitJobOperator( @@ -201,10 +190,7 @@ # [START how_to_cloud_dataproc_delete_cluster_operator] delete_cluster = DataprocDeleteClusterOperator( - task_id="delete_cluster", - project_id=PROJECT_ID, - cluster_name=CLUSTER_NAME, - region=REGION, + task_id="delete_cluster", project_id=PROJECT_ID, cluster_name=CLUSTER_NAME, region=REGION, ) # [END how_to_cloud_dataproc_delete_cluster_operator] diff --git a/airflow/providers/google/cloud/example_dags/example_datastore.py b/airflow/providers/google/cloud/example_dags/example_datastore.py index 4129b53036b10..88af5e5eb52d4 100644 --- a/airflow/providers/google/cloud/example_dags/example_datastore.py +++ b/airflow/providers/google/cloud/example_dags/example_datastore.py @@ -27,9 +27,13 @@ from airflow import models from airflow.providers.google.cloud.operators.datastore import ( - CloudDatastoreAllocateIdsOperator, CloudDatastoreBeginTransactionOperator, CloudDatastoreCommitOperator, - CloudDatastoreExportEntitiesOperator, CloudDatastoreImportEntitiesOperator, - CloudDatastoreRollbackOperator, CloudDatastoreRunQueryOperator, + CloudDatastoreAllocateIdsOperator, + CloudDatastoreBeginTransactionOperator, + CloudDatastoreCommitOperator, + CloudDatastoreExportEntitiesOperator, + CloudDatastoreImportEntitiesOperator, + CloudDatastoreRollbackOperator, + CloudDatastoreRunQueryOperator, ) from airflow.utils import dates @@ -44,10 +48,7 @@ ) as dag: # [START how_to_export_task] export_task = CloudDatastoreExportEntitiesOperator( - task_id="export_task", - bucket=BUCKET, - project_id=GCP_PROJECT_ID, - overwrite_existing=True, + task_id="export_task", bucket=BUCKET, project_id=GCP_PROJECT_ID, overwrite_existing=True, ) # [END how_to_export_task] @@ -63,12 +64,7 @@ export_task >> import_task # [START how_to_keys_def] -KEYS = [ - { - "partitionId": {"projectId": GCP_PROJECT_ID, "namespaceId": ""}, - "path": {"kind": "airflow"}, - } -] +KEYS = [{"partitionId": {"projectId": GCP_PROJECT_ID, "namespaceId": ""}, "path": {"kind": "airflow"},}] # [END how_to_keys_def] # [START how_to_transaction_def] @@ -79,12 +75,7 @@ COMMIT_BODY = { "mode": "TRANSACTIONAL", "mutations": [ - { - "insert": { - "key": KEYS[0], - "properties": {"string": {"stringValue": "airflow is awesome!"}}, - } - } + {"insert": {"key": KEYS[0], "properties": {"string": {"stringValue": "airflow is awesome!"}},}} ], "transaction": "{{ task_instance.xcom_pull('begin_transaction_commit') }}", } @@ -93,9 +84,7 @@ # [START how_to_query_def] QUERY = { "partitionId": {"projectId": GCP_PROJECT_ID, "namespaceId": ""}, - "readOptions": { - "transaction": "{{ task_instance.xcom_pull('begin_transaction_query') }}" - }, + "readOptions": {"transaction": "{{ task_instance.xcom_pull('begin_transaction_query') }}"}, "query": {}, } # [END how_to_query_def] @@ -129,15 +118,11 @@ allocate_ids >> begin_transaction_commit >> commit_task begin_transaction_query = CloudDatastoreBeginTransactionOperator( - task_id="begin_transaction_query", - transaction_options=TRANSACTION_OPTIONS, - project_id=GCP_PROJECT_ID, + task_id="begin_transaction_query", transaction_options=TRANSACTION_OPTIONS, project_id=GCP_PROJECT_ID, ) # [START 
how_to_run_query] - run_query = CloudDatastoreRunQueryOperator( - task_id="run_query", body=QUERY, project_id=GCP_PROJECT_ID - ) + run_query = CloudDatastoreRunQueryOperator(task_id="run_query", body=QUERY, project_id=GCP_PROJECT_ID) # [END how_to_run_query] allocate_ids >> begin_transaction_query >> run_query diff --git a/airflow/providers/google/cloud/example_dags/example_dlp.py b/airflow/providers/google/cloud/example_dags/example_dlp.py index ddc0ca3cd4b5b..a8dce29290050 100644 --- a/airflow/providers/google/cloud/example_dags/example_dlp.py +++ b/airflow/providers/google/cloud/example_dags/example_dlp.py @@ -30,7 +30,8 @@ from airflow import models from airflow.providers.google.cloud.operators.dlp import ( - CloudDLPCreateInspectTemplateOperator, CloudDLPDeleteInspectTemplateOperator, + CloudDLPCreateInspectTemplateOperator, + CloudDLPDeleteInspectTemplateOperator, CloudDLPInspectContentOperator, ) from airflow.utils.dates import days_ago @@ -43,9 +44,7 @@ "rows": [{"values": [{"string_value": "My phone number is (206) 555-0123"}]}], } ) -INSPECT_CONFIG = InspectConfig( - info_types=[{"name": "PHONE_NUMBER"}, {"name": "US_TOLLFREE_PHONE_NUMBER"}] -) +INSPECT_CONFIG = InspectConfig(info_types=[{"name": "PHONE_NUMBER"}, {"name": "US_TOLLFREE_PHONE_NUMBER"}]) INSPECT_TEMPLATE = InspectTemplate(inspect_config=INSPECT_CONFIG) @@ -73,10 +72,7 @@ ) delete_template = CloudDLPDeleteInspectTemplateOperator( - task_id="delete_template", - template_id=TEMPLATE_ID, - project_id=GCP_PROJECT, - dag=dag, + task_id="delete_template", template_id=TEMPLATE_ID, project_id=GCP_PROJECT, dag=dag, ) create_template >> inspect_content >> delete_template diff --git a/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py index 656d73eaddde5..452db7cde34e7 100644 --- a/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py @@ -24,7 +24,9 @@ from airflow import models from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, BigQueryExecuteQueryOperator, ) from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator @@ -49,28 +51,20 @@ AdsInsights.Field.clicks, AdsInsights.Field.impressions, ] -PARAMS = { - 'level': 'ad', - 'date_preset': 'yesterday' -} +PARAMS = {'level': 'ad', 'date_preset': 'yesterday'} # [END howto_FB_ADS_variables] with models.DAG( "example_facebook_ads_to_gcs", schedule_interval=None, # Override to match your needs - start_date=days_ago(1) + start_date=days_ago(1), ) as dag: create_bucket = GCSCreateBucketOperator( - task_id="create_bucket", - bucket_name=GCS_BUCKET, - project_id=GCP_PROJECT_ID, + task_id="create_bucket", bucket_name=GCS_BUCKET, project_id=GCP_PROJECT_ID, ) - create_dataset = BigQueryCreateEmptyDatasetOperator( - task_id="create_dataset", - dataset_id=DATASET_NAME, - ) + create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME,) create_table = BigQueryCreateEmptyTableOperator( task_id="create_table", @@ -103,7 +97,7 @@ bucket=GCS_BUCKET, source_objects=[GCS_OBJ_PATH], destination_project_dataset_table=f"{DATASET_NAME}.{TABLE_NAME}", - write_disposition='WRITE_TRUNCATE' + 
write_disposition='WRITE_TRUNCATE', ) read_data_from_gcs_many_chunks = BigQueryExecuteQueryOperator( @@ -112,16 +106,10 @@ use_legacy_sql=False, ) - delete_bucket = GCSDeleteBucketOperator( - task_id="delete_bucket", - bucket_name=GCS_BUCKET, - ) + delete_bucket = GCSDeleteBucketOperator(task_id="delete_bucket", bucket_name=GCS_BUCKET,) delete_dataset = BigQueryDeleteDatasetOperator( - task_id="delete_dataset", - project_id=GCP_PROJECT_ID, - dataset_id=DATASET_NAME, - delete_contents=True, + task_id="delete_dataset", project_id=GCP_PROJECT_ID, dataset_id=DATASET_NAME, delete_contents=True, ) create_bucket >> create_dataset >> create_table >> run_operator >> load_csv diff --git a/airflow/providers/google/cloud/example_dags/example_functions.py b/airflow/providers/google/cloud/example_dags/example_functions.py index ed3dd4f17c4c2..c98a93fd830b2 100644 --- a/airflow/providers/google/cloud/example_dags/example_functions.py +++ b/airflow/providers/google/cloud/example_dags/example_functions.py @@ -44,51 +44,45 @@ from airflow import models from airflow.providers.google.cloud.operators.functions import ( - CloudFunctionDeleteFunctionOperator, CloudFunctionDeployFunctionOperator, + CloudFunctionDeleteFunctionOperator, + CloudFunctionDeployFunctionOperator, CloudFunctionInvokeFunctionOperator, ) from airflow.utils import dates GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') GCP_LOCATION = os.environ.get('GCP_LOCATION', 'europe-west1') -GCF_SHORT_FUNCTION_NAME = os.environ.get('GCF_SHORT_FUNCTION_NAME', 'hello').\ - replace("-", "_") # make sure there are no dashes in function name (!) -FUNCTION_NAME = 'projects/{}/locations/{}/functions/{}'.format(GCP_PROJECT_ID, - GCP_LOCATION, - GCF_SHORT_FUNCTION_NAME) +GCF_SHORT_FUNCTION_NAME = os.environ.get('GCF_SHORT_FUNCTION_NAME', 'hello').replace( + "-", "_" +) # make sure there are no dashes in function name (!) 
+FUNCTION_NAME = 'projects/{}/locations/{}/functions/{}'.format( + GCP_PROJECT_ID, GCP_LOCATION, GCF_SHORT_FUNCTION_NAME +) GCF_SOURCE_ARCHIVE_URL = os.environ.get('GCF_SOURCE_ARCHIVE_URL', '') GCF_SOURCE_UPLOAD_URL = os.environ.get('GCF_SOURCE_UPLOAD_URL', '') GCF_SOURCE_REPOSITORY = os.environ.get( 'GCF_SOURCE_REPOSITORY', 'https://source.developers.google.com/' - 'projects/{}/repos/hello-world/moveable-aliases/master'.format(GCP_PROJECT_ID)) + 'projects/{}/repos/hello-world/moveable-aliases/master'.format(GCP_PROJECT_ID), +) GCF_ZIP_PATH = os.environ.get('GCF_ZIP_PATH', '') GCF_ENTRYPOINT = os.environ.get('GCF_ENTRYPOINT', 'helloWorld') GCF_RUNTIME = 'nodejs6' GCP_VALIDATE_BODY = os.environ.get('GCP_VALIDATE_BODY', "True") == "True" # [START howto_operator_gcf_deploy_body] -body = { - "name": FUNCTION_NAME, - "entryPoint": GCF_ENTRYPOINT, - "runtime": GCF_RUNTIME, - "httpsTrigger": {} -} +body = {"name": FUNCTION_NAME, "entryPoint": GCF_ENTRYPOINT, "runtime": GCF_RUNTIME, "httpsTrigger": {}} # [END howto_operator_gcf_deploy_body] # [START howto_operator_gcf_default_args] -default_args = { - 'owner': 'airflow' -} +default_args = {'owner': 'airflow'} # [END howto_operator_gcf_default_args] # [START howto_operator_gcf_deploy_variants] if GCF_SOURCE_ARCHIVE_URL: body['sourceArchiveUrl'] = GCF_SOURCE_ARCHIVE_URL elif GCF_SOURCE_REPOSITORY: - body['sourceRepository'] = { - 'url': GCF_SOURCE_REPOSITORY - } + body['sourceRepository'] = {'url': GCF_SOURCE_REPOSITORY} elif GCF_ZIP_PATH: body['sourceUploadUrl'] = '' default_args['zip_path'] = GCF_ZIP_PATH @@ -111,15 +105,12 @@ project_id=GCP_PROJECT_ID, location=GCP_LOCATION, body=body, - validate_body=GCP_VALIDATE_BODY + validate_body=GCP_VALIDATE_BODY, ) # [END howto_operator_gcf_deploy] # [START howto_operator_gcf_deploy_no_project_id] deploy2_task = CloudFunctionDeployFunctionOperator( - task_id="gcf_deploy2_task", - location=GCP_LOCATION, - body=body, - validate_body=GCP_VALIDATE_BODY + task_id="gcf_deploy2_task", location=GCP_LOCATION, body=body, validate_body=GCP_VALIDATE_BODY ) # [END howto_operator_gcf_deploy_no_project_id] # [START howto_operator_gcf_invoke_function] @@ -128,13 +119,10 @@ project_id=GCP_PROJECT_ID, location=GCP_LOCATION, input_data={}, - function_id=GCF_SHORT_FUNCTION_NAME + function_id=GCF_SHORT_FUNCTION_NAME, ) # [END howto_operator_gcf_invoke_function] # [START howto_operator_gcf_delete] - delete_task = CloudFunctionDeleteFunctionOperator( - task_id="gcf_delete_task", - name=FUNCTION_NAME - ) + delete_task = CloudFunctionDeleteFunctionOperator(task_id="gcf_delete_task", name=FUNCTION_NAME) # [END howto_operator_gcf_delete] deploy_task >> deploy2_task >> invoke_task >> delete_task diff --git a/airflow/providers/google/cloud/example_dags/example_gcs.py b/airflow/providers/google/cloud/example_dags/example_gcs.py index d885addcbbcbc..0c0598a3e6561 100644 --- a/airflow/providers/google/cloud/example_dags/example_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_gcs.py @@ -24,8 +24,12 @@ from airflow import models from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.gcs import ( - GCSBucketCreateAclEntryOperator, GCSCreateBucketOperator, GCSDeleteBucketOperator, - GCSDeleteObjectsOperator, GCSFileTransformOperator, GCSListObjectsOperator, + GCSBucketCreateAclEntryOperator, + GCSCreateBucketOperator, + GCSDeleteBucketOperator, + GCSDeleteObjectsOperator, + GCSFileTransformOperator, + GCSListObjectsOperator, GCSObjectCreateAclEntryOperator, ) from 
airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator @@ -42,21 +46,13 @@ BUCKET_2 = os.environ.get("GCP_GCS_BUCKET_1", "test-gcs-example-bucket-2") -PATH_TO_TRANSFORM_SCRIPT = os.environ.get( - 'GCP_GCS_PATH_TO_TRANSFORM_SCRIPT', 'test.py' -) -PATH_TO_UPLOAD_FILE = os.environ.get( - "GCP_GCS_PATH_TO_UPLOAD_FILE", "test-gcs-example.txt" -) -PATH_TO_SAVED_FILE = os.environ.get( - "GCP_GCS_PATH_TO_SAVED_FILE", "test-gcs-example-download.txt" -) +PATH_TO_TRANSFORM_SCRIPT = os.environ.get('GCP_GCS_PATH_TO_TRANSFORM_SCRIPT', 'test.py') +PATH_TO_UPLOAD_FILE = os.environ.get("GCP_GCS_PATH_TO_UPLOAD_FILE", "test-gcs-example.txt") +PATH_TO_SAVED_FILE = os.environ.get("GCP_GCS_PATH_TO_SAVED_FILE", "test-gcs-example-download.txt") BUCKET_FILE_LOCATION = PATH_TO_UPLOAD_FILE.rpartition("/")[-1] -with models.DAG( - "example_gcs", start_date=days_ago(1), schedule_interval=None, tags=['example'], -) as dag: +with models.DAG("example_gcs", start_date=days_ago(1), schedule_interval=None, tags=['example'],) as dag: create_bucket1 = GCSCreateBucketOperator( task_id="create_bucket1", bucket_name=BUCKET_1, project_id=PROJECT_ID ) @@ -65,27 +61,21 @@ task_id="create_bucket2", bucket_name=BUCKET_2, project_id=PROJECT_ID ) - list_buckets = GCSListObjectsOperator( - task_id="list_buckets", bucket=BUCKET_1 - ) + list_buckets = GCSListObjectsOperator(task_id="list_buckets", bucket=BUCKET_1) list_buckets_result = BashOperator( - task_id="list_buckets_result", - bash_command="echo \"{{ task_instance.xcom_pull('list_buckets') }}\"", + task_id="list_buckets_result", bash_command="echo \"{{ task_instance.xcom_pull('list_buckets') }}\"", ) upload_file = LocalFilesystemToGCSOperator( - task_id="upload_file", - src=PATH_TO_UPLOAD_FILE, - dst=BUCKET_FILE_LOCATION, - bucket=BUCKET_1, + task_id="upload_file", src=PATH_TO_UPLOAD_FILE, dst=BUCKET_FILE_LOCATION, bucket=BUCKET_1, ) transform_file = GCSFileTransformOperator( task_id="transform_file", source_bucket=BUCKET_1, source_object=BUCKET_FILE_LOCATION, - transform_script=["python", PATH_TO_TRANSFORM_SCRIPT] + transform_script=["python", PATH_TO_TRANSFORM_SCRIPT], ) # [START howto_operator_gcs_bucket_create_acl_entry_task] gcs_bucket_create_acl_entry_task = GCSBucketCreateAclEntryOperator( diff --git a/airflow/providers/google/cloud/example_dags/example_gcs_to_bigquery.py b/airflow/providers/google/cloud/example_dags/example_gcs_to_bigquery.py index 2ef99a92bf175..1727d9fa80c9e 100644 --- a/airflow/providers/google/cloud/example_dags/example_gcs_to_bigquery.py +++ b/airflow/providers/google/cloud/example_dags/example_gcs_to_bigquery.py @@ -24,7 +24,8 @@ from airflow import models from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, BigQueryDeleteDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryDeleteDatasetOperator, ) from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator from airflow.utils.dates import days_ago @@ -36,12 +37,11 @@ dag_id='example_gcs_to_bigquery_operator', start_date=days_ago(2), schedule_interval=None, - tags=['example']) + tags=['example'], +) create_test_dataset = BigQueryCreateEmptyDatasetOperator( - task_id='create_airflow_test_dataset', - dataset_id=DATASET_NAME, - dag=dag + task_id='create_airflow_test_dataset', dataset_id=DATASET_NAME, dag=dag ) # [START howto_operator_gcs_to_bigquery] @@ -55,14 +55,12 @@ {'name': 'post_abbr', 'type': 'STRING', 'mode': 'NULLABLE'}, ], write_disposition='WRITE_TRUNCATE', - dag=dag) + dag=dag, +) # 
[END howto_operator_gcs_to_bigquery] delete_test_dataset = BigQueryDeleteDatasetOperator( - task_id='delete_airflow_test_dataset', - dataset_id=DATASET_NAME, - delete_contents=True, - dag=dag + task_id='delete_airflow_test_dataset', dataset_id=DATASET_NAME, delete_contents=True, dag=dag ) create_test_dataset >> load_csv >> delete_test_dataset diff --git a/airflow/providers/google/cloud/example_dags/example_gcs_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_gcs_to_gcs.py index 4dbbe72647b98..8244f99f5c07f 100644 --- a/airflow/providers/google/cloud/example_dags/example_gcs_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_gcs_to_gcs.py @@ -43,9 +43,7 @@ ) as dag: # [START howto_synch_bucket] sync_bucket = GCSSynchronizeBucketsOperator( - task_id="sync_bucket", - source_bucket=BUCKET_1_SRC, - destination_bucket=BUCKET_1_DST + task_id="sync_bucket", source_bucket=BUCKET_1_SRC, destination_bucket=BUCKET_1_DST ) # [END howto_synch_bucket] @@ -55,7 +53,7 @@ source_bucket=BUCKET_1_SRC, destination_bucket=BUCKET_1_DST, delete_extra_files=True, - allow_overwrite=True + allow_overwrite=True, ) # [END howto_synch_full_bucket] @@ -64,7 +62,7 @@ task_id="sync_to_subdirectory", source_bucket=BUCKET_1_SRC, destination_bucket=BUCKET_1_DST, - destination_object="subdir/" + destination_object="subdir/", ) # [END howto_synch_to_subdir] @@ -73,7 +71,7 @@ task_id="sync_from_subdirectory", source_bucket=BUCKET_1_SRC, source_object="subdir/", - destination_bucket=BUCKET_1_DST + destination_bucket=BUCKET_1_DST, ) # [END howto_sync_from_subdir] @@ -83,7 +81,7 @@ source_bucket=BUCKET_1_SRC, source_object=OBJECT_1, destination_bucket=BUCKET_1_DST, # If not supplied the source_bucket value will be used - destination_object="backup_" + OBJECT_1 # If not supplied the source_object value will be used + destination_object="backup_" + OBJECT_1, # If not supplied the source_object value will be used ) # [END howto_operator_gcs_to_gcs_single_file] @@ -93,7 +91,7 @@ source_bucket=BUCKET_1_SRC, source_object="data/*.txt", destination_bucket=BUCKET_1_DST, - destination_object="backup/" + destination_object="backup/", ) # [END howto_operator_gcs_to_gcs_wildcard] @@ -104,7 +102,7 @@ source_object="data/", destination_bucket=BUCKET_1_DST, destination_object="backup/", - delimiter='.txt' + delimiter='.txt', ) # [END howto_operator_gcs_to_gcs_delimiter] @@ -114,7 +112,7 @@ source_bucket=BUCKET_1_SRC, source_objects=[OBJECT_1, OBJECT_2], # Instead of files each element could be a wildcard expression destination_bucket=BUCKET_1_DST, - destination_object="backup/" + destination_object="backup/", ) # [END howto_operator_gcs_to_gcs_list] @@ -125,7 +123,7 @@ source_object=OBJECT_1, destination_bucket=BUCKET_1_DST, destination_object="backup_" + OBJECT_1, - move_object=True + move_object=True, ) # [END howto_operator_gcs_to_gcs_single_file_move] @@ -135,6 +133,6 @@ source_bucket=BUCKET_1_SRC, source_objects=[OBJECT_1, OBJECT_2], destination_bucket=BUCKET_1_DST, - destination_object="backup/" + destination_object="backup/", ) # [END howto_operator_gcs_to_gcs_list_move] diff --git a/airflow/providers/google/cloud/example_dags/example_kubernetes_engine.py b/airflow/providers/google/cloud/example_dags/example_kubernetes_engine.py index 2d085b30da928..df96a258845cf 100644 --- a/airflow/providers/google/cloud/example_dags/example_kubernetes_engine.py +++ b/airflow/providers/google/cloud/example_dags/example_kubernetes_engine.py @@ -24,7 +24,9 @@ from airflow import models from airflow.operators.bash import 
BashOperator from airflow.providers.google.cloud.operators.kubernetes_engine import ( - GKECreateClusterOperator, GKEDeleteClusterOperator, GKEStartPodOperator, + GKECreateClusterOperator, + GKEDeleteClusterOperator, + GKEStartPodOperator, ) from airflow.utils.dates import days_ago @@ -44,10 +46,7 @@ ) as dag: # [START howto_operator_gke_create_cluster] create_cluster = GKECreateClusterOperator( - task_id="create_cluster", - project_id=GCP_PROJECT_ID, - location=GCP_LOCATION, - body=CLUSTER, + task_id="create_cluster", project_id=GCP_PROJECT_ID, location=GCP_LOCATION, body=CLUSTER, ) # [END howto_operator_gke_create_cluster] @@ -84,10 +83,7 @@ # [START howto_operator_gke_delete_cluster] delete_cluster = GKEDeleteClusterOperator( - task_id="delete_cluster", - name=CLUSTER_NAME, - project_id=GCP_PROJECT_ID, - location=GCP_LOCATION, + task_id="delete_cluster", name=CLUSTER_NAME, project_id=GCP_PROJECT_ID, location=GCP_LOCATION, ) # [END howto_operator_gke_delete_cluster] diff --git a/airflow/providers/google/cloud/example_dags/example_life_sciences.py b/airflow/providers/google/cloud/example_dags/example_life_sciences.py index ac3d72f02d52c..3829d4a9d1323 100644 --- a/airflow/providers/google/cloud/example_dags/example_life_sciences.py +++ b/airflow/providers/google/cloud/example_dags/example_life_sciences.py @@ -31,18 +31,11 @@ # [START howto_configure_simple_action_pipeline] SIMPLE_ACTION_PIEPELINE = { "pipeline": { - "actions": [ - { - "imageUri": "bash", - "commands": ["-c", "echo Hello, world"] - }, - ], + "actions": [{"imageUri": "bash", "commands": ["-c", "echo Hello, world"]},], "resources": { "regions": ["{}".format(LOCATION)], - "virtualMachine": { - "machineType": "n1-standard-1", - } - } + "virtualMachine": {"machineType": "n1-standard-1",}, + }, }, } # [END howto_configure_simple_action_pipeline] @@ -53,48 +46,45 @@ "actions": [ { "imageUri": "google/cloud-sdk", - "commands": ["gsutil", "cp", "gs://{}/{}".format(BUCKET, FILENAME), "/tmp"] - }, - { - "imageUri": "bash", - "commands": ["-c", "echo Hello, world"] + "commands": ["gsutil", "cp", "gs://{}/{}".format(BUCKET, FILENAME), "/tmp"], }, + {"imageUri": "bash", "commands": ["-c", "echo Hello, world"]}, { "imageUri": "google/cloud-sdk", - "commands": ["gsutil", "cp", "gs://{}/{}".format(BUCKET, FILENAME), - "gs://{}/output.in".format(BUCKET)] + "commands": [ + "gsutil", + "cp", + "gs://{}/{}".format(BUCKET, FILENAME), + "gs://{}/output.in".format(BUCKET), + ], }, ], "resources": { "regions": ["{}".format(LOCATION)], - "virtualMachine": { - "machineType": "n1-standard-1", - } - } + "virtualMachine": {"machineType": "n1-standard-1",}, + }, } } # [END howto_configure_multiple_action_pipeline] -with models.DAG("example_gcp_life_sciences", - default_args=dict(start_date=dates.days_ago(1)), - schedule_interval=None, - tags=['example'], - ) as dag: +with models.DAG( + "example_gcp_life_sciences", + default_args=dict(start_date=dates.days_ago(1)), + schedule_interval=None, + tags=['example'], +) as dag: # [START howto_run_pipeline] simple_life_science_action_pipeline = LifeSciencesRunPipelineOperator( task_id='simple-action-pipeline', body=SIMPLE_ACTION_PIEPELINE, project_id=PROJECT_ID, - location=LOCATION + location=LOCATION, ) # [END howto_run_pipeline] multiple_life_science_action_pipeline = LifeSciencesRunPipelineOperator( - task_id='multi-action-pipeline', - body=MULTI_ACTION_PIPELINE, - project_id=PROJECT_ID, - location=LOCATION + task_id='multi-action-pipeline', body=MULTI_ACTION_PIPELINE, project_id=PROJECT_ID, location=LOCATION 
) simple_life_science_action_pipeline >> multiple_life_science_action_pipeline diff --git a/airflow/providers/google/cloud/example_dags/example_local_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_local_to_gcs.py index eb3a8dfc43600..9de6e09471c89 100644 --- a/airflow/providers/google/cloud/example_dags/example_local_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_local_to_gcs.py @@ -32,13 +32,10 @@ 'example_local_to_gcs', default_args=dict(start_date=dates.days_ago(1)), schedule_interval=None, - tags=['example'] + tags=['example'], ) as dag: # [START howto_operator_local_filesystem_to_gcs] upload_file = LocalFilesystemToGCSOperator( - task_id="upload_file", - src=PATH_TO_UPLOAD_FILE, - dst=DESTINATION_FILE_LOCATION, - bucket=BUCKET_NAME, + task_id="upload_file", src=PATH_TO_UPLOAD_FILE, dst=DESTINATION_FILE_LOCATION, bucket=BUCKET_NAME, ) # [END howto_operator_local_filesystem_to_gcs] diff --git a/airflow/providers/google/cloud/example_dags/example_mlengine.py b/airflow/providers/google/cloud/example_dags/example_mlengine.py index 5db611d989b5f..84871c0a3cfd7 100644 --- a/airflow/providers/google/cloud/example_dags/example_mlengine.py +++ b/airflow/providers/google/cloud/example_dags/example_mlengine.py @@ -25,9 +25,14 @@ from airflow import models from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.mlengine import ( - MLEngineCreateModelOperator, MLEngineCreateVersionOperator, MLEngineDeleteModelOperator, - MLEngineDeleteVersionOperator, MLEngineGetModelOperator, MLEngineListVersionsOperator, - MLEngineSetDefaultVersionOperator, MLEngineStartBatchPredictionJobOperator, + MLEngineCreateModelOperator, + MLEngineCreateVersionOperator, + MLEngineDeleteModelOperator, + MLEngineDeleteVersionOperator, + MLEngineGetModelOperator, + MLEngineListVersionsOperator, + MLEngineSetDefaultVersionOperator, + MLEngineStartBatchPredictionJobOperator, MLEngineStartTrainingJobOperator, ) from airflow.providers.google.cloud.utils import mlengine_operator_utils @@ -39,21 +44,19 @@ SAVED_MODEL_PATH = os.environ.get("GCP_MLENGINE_SAVED_MODEL_PATH", "gs://test-airflow-mlengine/saved-model/") JOB_DIR = os.environ.get("GCP_MLENGINE_JOB_DIR", "gs://test-airflow-mlengine/keras-job-dir") -PREDICTION_INPUT = os.environ.get("GCP_MLENGINE_PREDICTION_INPUT", - "gs://test-airflow-mlengine/prediction_input.json") -PREDICTION_OUTPUT = os.environ.get("GCP_MLENGINE_PREDICTION_OUTPUT", - "gs://test-airflow-mlengine/prediction_output") +PREDICTION_INPUT = os.environ.get( + "GCP_MLENGINE_PREDICTION_INPUT", "gs://test-airflow-mlengine/prediction_input.json" +) +PREDICTION_OUTPUT = os.environ.get( + "GCP_MLENGINE_PREDICTION_OUTPUT", "gs://test-airflow-mlengine/prediction_output" +) TRAINER_URI = os.environ.get("GCP_MLENGINE_TRAINER_URI", "gs://test-airflow-mlengine/trainer.tar.gz") TRAINER_PY_MODULE = os.environ.get("GCP_MLENGINE_TRAINER_TRAINER_PY_MODULE", "trainer.task") SUMMARY_TMP = os.environ.get("GCP_MLENGINE_DATAFLOW_TMP", "gs://test-airflow-mlengine/tmp/") SUMMARY_STAGING = os.environ.get("GCP_MLENGINE_DATAFLOW_STAGING", "gs://test-airflow-mlengine/staging/") -default_args = { - "params": { - "model_name": MODEL_NAME - } -} +default_args = {"params": {"model_name": MODEL_NAME}} with models.DAG( "example_gcp_mlengine", @@ -79,26 +82,17 @@ # [START howto_operator_gcp_mlengine_create_model] create_model = MLEngineCreateModelOperator( - task_id="create-model", - project_id=PROJECT_ID, - model={ - "name": MODEL_NAME, - }, + task_id="create-model", 
project_id=PROJECT_ID, model={"name": MODEL_NAME,}, ) # [END howto_operator_gcp_mlengine_create_model] # [START howto_operator_gcp_mlengine_get_model] - get_model = MLEngineGetModelOperator( - task_id="get-model", - project_id=PROJECT_ID, - model_name=MODEL_NAME, - ) + get_model = MLEngineGetModelOperator(task_id="get-model", project_id=PROJECT_ID, model_name=MODEL_NAME,) # [END howto_operator_gcp_mlengine_get_model] # [START howto_operator_gcp_mlengine_print_model] get_model_result = BashOperator( - bash_command="echo \"{{ task_instance.xcom_pull('get-model') }}\"", - task_id="get-model-result", + bash_command="echo \"{{ task_instance.xcom_pull('get-model') }}\"", task_id="get-model-result", ) # [END howto_operator_gcp_mlengine_print_model] @@ -114,8 +108,8 @@ "runtime_version": "1.15", "machineType": "mls1-c1-m2", "framework": "TENSORFLOW", - "pythonVersion": "3.7" - } + "pythonVersion": "3.7", + }, ) # [END howto_operator_gcp_mlengine_create_version1] @@ -131,32 +125,26 @@ "runtime_version": "1.15", "machineType": "mls1-c1-m2", "framework": "TENSORFLOW", - "pythonVersion": "3.7" - } + "pythonVersion": "3.7", + }, ) # [END howto_operator_gcp_mlengine_create_version2] # [START howto_operator_gcp_mlengine_default_version] set_defaults_version = MLEngineSetDefaultVersionOperator( - task_id="set-default-version", - project_id=PROJECT_ID, - model_name=MODEL_NAME, - version_name="v2", + task_id="set-default-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v2", ) # [END howto_operator_gcp_mlengine_default_version] # [START howto_operator_gcp_mlengine_list_versions] list_version = MLEngineListVersionsOperator( - task_id="list-version", - project_id=PROJECT_ID, - model_name=MODEL_NAME, + task_id="list-version", project_id=PROJECT_ID, model_name=MODEL_NAME, ) # [END howto_operator_gcp_mlengine_list_versions] # [START howto_operator_gcp_mlengine_print_versions] list_version_result = BashOperator( - bash_command="echo \"{{ task_instance.xcom_pull('list-version') }}\"", - task_id="list-version-result", + bash_command="echo \"{{ task_instance.xcom_pull('list-version') }}\"", task_id="list-version-result", ) # [END howto_operator_gcp_mlengine_print_versions] @@ -176,19 +164,13 @@ # [START howto_operator_gcp_mlengine_delete_version] delete_version = MLEngineDeleteVersionOperator( - task_id="delete-version", - project_id=PROJECT_ID, - model_name=MODEL_NAME, - version_name="v1" + task_id="delete-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v1" ) # [END howto_operator_gcp_mlengine_delete_version] # [START howto_operator_gcp_mlengine_delete_model] delete_model = MLEngineDeleteModelOperator( - task_id="delete-model", - project_id=PROJECT_ID, - model_name=MODEL_NAME, - delete_contents=True + task_id="delete-model", project_id=PROJECT_ID, model_name=MODEL_NAME, delete_contents=True ) # [END howto_operator_gcp_mlengine_delete_model] @@ -208,10 +190,13 @@ def get_metric_fn_and_keys(): """ Gets metric function and keys used to generate summary """ + def normalize_value(inst: Dict): val = float(inst['dense_4'][0]) return tuple([val]) # returns a tuple. + return normalize_value, ['val'] # key order must match. 
+ # [END howto_operator_gcp_mlengine_get_metric] # [START howto_operator_gcp_mlengine_validate_error] @@ -226,6 +211,7 @@ def validate_err_and_count(summary: Dict) -> Dict: if summary['count'] != 20: raise ValueError('Invalid value val != 20; summary={}'.format(summary)) return summary + # [END howto_operator_gcp_mlengine_validate_error] # [START howto_operator_gcp_mlengine_evaluate] diff --git a/airflow/providers/google/cloud/example_dags/example_natural_language.py b/airflow/providers/google/cloud/example_dags/example_natural_language.py index e8a26892f3a08..11a9246495ee6 100644 --- a/airflow/providers/google/cloud/example_dags/example_natural_language.py +++ b/airflow/providers/google/cloud/example_dags/example_natural_language.py @@ -25,8 +25,10 @@ from airflow import models from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.natural_language import ( - CloudNaturalLanguageAnalyzeEntitiesOperator, CloudNaturalLanguageAnalyzeEntitySentimentOperator, - CloudNaturalLanguageAnalyzeSentimentOperator, CloudNaturalLanguageClassifyTextOperator, + CloudNaturalLanguageAnalyzeEntitiesOperator, + CloudNaturalLanguageAnalyzeEntitySentimentOperator, + CloudNaturalLanguageAnalyzeSentimentOperator, + CloudNaturalLanguageClassifyTextOperator, ) from airflow.utils.dates import days_ago @@ -50,12 +52,13 @@ with models.DAG( "example_gcp_natural_language", schedule_interval=None, # Override to match your needs - start_date=days_ago(1) + start_date=days_ago(1), ) as dag: # [START howto_operator_gcp_natural_language_analyze_entities] - analyze_entities = \ - CloudNaturalLanguageAnalyzeEntitiesOperator(document=document, task_id="analyze_entities") + analyze_entities = CloudNaturalLanguageAnalyzeEntitiesOperator( + document=document, task_id="analyze_entities" + ) # [END howto_operator_gcp_natural_language_analyze_entities] # [START howto_operator_gcp_natural_language_analyze_entities_result] @@ -79,8 +82,9 @@ # [END howto_operator_gcp_natural_language_analyze_entity_sentiment_result] # [START howto_operator_gcp_natural_language_analyze_sentiment] - analyze_sentiment = \ - CloudNaturalLanguageAnalyzeSentimentOperator(document=document, task_id="analyze_sentiment") + analyze_sentiment = CloudNaturalLanguageAnalyzeSentimentOperator( + document=document, task_id="analyze_sentiment" + ) # [END howto_operator_gcp_natural_language_analyze_sentiment] # [START howto_operator_gcp_natural_language_analyze_sentiment_result] diff --git a/airflow/providers/google/cloud/example_dags/example_postgres_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_postgres_to_gcs.py index 2e61f63ffb129..f215abb8599c7 100644 --- a/airflow/providers/google/cloud/example_dags/example_postgres_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_postgres_to_gcs.py @@ -33,9 +33,5 @@ tags=['example'], ) as dag: upload_data = PostgresToGCSOperator( - task_id="get_data", - sql=SQL_QUERY, - bucket=GCS_BUCKET, - filename=FILENAME, - gzip=False + task_id="get_data", sql=SQL_QUERY, bucket=GCS_BUCKET, filename=FILENAME, gzip=False ) diff --git a/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py index 3256090b04a05..3534b7631964f 100644 --- a/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py @@ -23,7 +23,9 @@ from airflow import models from airflow.providers.google.cloud.operators.bigquery import ( - 
BigQueryCreateEmptyDatasetOperator, BigQueryCreateExternalTableOperator, BigQueryDeleteDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateExternalTableOperator, + BigQueryDeleteDatasetOperator, BigQueryExecuteQueryOperator, ) from airflow.providers.google.cloud.transfers.presto_to_gcs import PrestoToGCSOperator diff --git a/airflow/providers/google/cloud/example_dags/example_pubsub.py b/airflow/providers/google/cloud/example_dags/example_pubsub.py index c631e2b5054cb..a1fc058298e61 100644 --- a/airflow/providers/google/cloud/example_dags/example_pubsub.py +++ b/airflow/providers/google/cloud/example_dags/example_pubsub.py @@ -24,8 +24,12 @@ from airflow import models from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.pubsub import ( - PubSubCreateSubscriptionOperator, PubSubCreateTopicOperator, PubSubDeleteSubscriptionOperator, - PubSubDeleteTopicOperator, PubSubPublishMessageOperator, PubSubPullOperator, + PubSubCreateSubscriptionOperator, + PubSubCreateTopicOperator, + PubSubDeleteSubscriptionOperator, + PubSubDeleteTopicOperator, + PubSubPublishMessageOperator, + PubSubPullOperator, ) from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor from airflow.utils.dates import days_ago @@ -46,7 +50,7 @@ with models.DAG( "example_gcp_pubsub_sensor", schedule_interval=None, # Override to match your needs - start_date=days_ago(1) + start_date=days_ago(1), ) as example_sensor_dag: # [START howto_operator_gcp_pubsub_create_topic] create_topic = PubSubCreateTopicOperator( @@ -64,17 +68,12 @@ subscription = "{{ task_instance.xcom_pull('subscribe_task') }}" pull_messages = PubSubPullSensor( - task_id="pull_messages", - ack_messages=True, - project_id=GCP_PROJECT_ID, - subscription=subscription, + task_id="pull_messages", ack_messages=True, project_id=GCP_PROJECT_ID, subscription=subscription, ) # [END howto_operator_gcp_pubsub_pull_message_with_sensor] # [START howto_operator_gcp_pubsub_pull_messages_result] - pull_messages_result = BashOperator( - task_id="pull_messages_result", bash_command=echo_cmd - ) + pull_messages_result = BashOperator(task_id="pull_messages_result", bash_command=echo_cmd) # [END howto_operator_gcp_pubsub_pull_messages_result] # [START howto_operator_gcp_pubsub_publish] @@ -107,7 +106,7 @@ with models.DAG( "example_gcp_pubsub_operator", schedule_interval=None, # Override to match your needs - start_date=days_ago(1) + start_date=days_ago(1), ) as example_operator_dag: # [START howto_operator_gcp_pubsub_create_topic] create_topic = PubSubCreateTopicOperator( @@ -125,17 +124,12 @@ subscription = "{{ task_instance.xcom_pull('subscribe_task') }}" pull_messages_operaator = PubSubPullOperator( - task_id="pull_messages", - ack_messages=True, - project_id=GCP_PROJECT_ID, - subscription=subscription, + task_id="pull_messages", ack_messages=True, project_id=GCP_PROJECT_ID, subscription=subscription, ) # [END howto_operator_gcp_pubsub_pull_message_with_operator] # [START howto_operator_gcp_pubsub_pull_messages_result] - pull_messages_result = BashOperator( - task_id="pull_messages_result", bash_command=echo_cmd - ) + pull_messages_result = BashOperator(task_id="pull_messages_result", bash_command=echo_cmd) # [END howto_operator_gcp_pubsub_pull_messages_result] # [START howto_operator_gcp_pubsub_publish] @@ -162,6 +156,11 @@ # [END howto_operator_gcp_pubsub_delete_topic] ( - create_topic >> subscribe_task >> publish_task - >> pull_messages_operaator >> pull_messages_result >> unsubscribe_task >> delete_topic 
+ create_topic + >> subscribe_task + >> publish_task + >> pull_messages_operaator + >> pull_messages_result + >> unsubscribe_task + >> delete_topic ) diff --git a/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py index 218fa31fbd74b..ec197b96c1a3c 100644 --- a/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_sftp_to_gcs.py @@ -36,9 +36,7 @@ OBJECT_SRC_3 = "parent-3.txt" -with models.DAG( - "example_sftp_to_gcs", start_date=days_ago(1), schedule_interval=None -) as dag: +with models.DAG("example_sftp_to_gcs", start_date=days_ago(1), schedule_interval=None) as dag: # [START howto_operator_sftp_to_gcs_copy_single_file] copy_file_from_sftp_to_gcs = SFTPToGCSOperator( task_id="file-copy-sftp-to-gcs", diff --git a/airflow/providers/google/cloud/example_dags/example_sheets_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_sheets_to_gcs.py index c09306c275e68..27be10ea5aae5 100644 --- a/airflow/providers/google/cloud/example_dags/example_sheets_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_sheets_to_gcs.py @@ -33,8 +33,6 @@ ) as dag: # [START upload_sheet_to_gcs] upload_sheet_to_gcs = GoogleSheetsToGCSOperator( - task_id="upload_sheet_to_gcs", - destination_bucket=BUCKET, - spreadsheet_id=SPREADSHEET_ID, + task_id="upload_sheet_to_gcs", destination_bucket=BUCKET, spreadsheet_id=SPREADSHEET_ID, ) # [END upload_sheet_to_gcs] diff --git a/airflow/providers/google/cloud/example_dags/example_spanner.py b/airflow/providers/google/cloud/example_dags/example_spanner.py index 51de839e3e087..eb7baa69eba6b 100644 --- a/airflow/providers/google/cloud/example_dags/example_spanner.py +++ b/airflow/providers/google/cloud/example_dags/example_spanner.py @@ -36,17 +36,21 @@ from airflow import models from airflow.providers.google.cloud.operators.spanner import ( - SpannerDeleteDatabaseInstanceOperator, SpannerDeleteInstanceOperator, - SpannerDeployDatabaseInstanceOperator, SpannerDeployInstanceOperator, - SpannerQueryDatabaseInstanceOperator, SpannerUpdateDatabaseInstanceOperator, + SpannerDeleteDatabaseInstanceOperator, + SpannerDeleteInstanceOperator, + SpannerDeployDatabaseInstanceOperator, + SpannerDeployInstanceOperator, + SpannerQueryDatabaseInstanceOperator, + SpannerUpdateDatabaseInstanceOperator, ) from airflow.utils.dates import days_ago GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') GCP_SPANNER_INSTANCE_ID = os.environ.get('GCP_SPANNER_INSTANCE_ID', 'testinstance') GCP_SPANNER_DATABASE_ID = os.environ.get('GCP_SPANNER_DATABASE_ID', 'testdatabase') -GCP_SPANNER_CONFIG_NAME = os.environ.get('GCP_SPANNER_CONFIG_NAME', - 'projects/example-project/instanceConfigs/eur3') +GCP_SPANNER_CONFIG_NAME = os.environ.get( + 'GCP_SPANNER_CONFIG_NAME', 'projects/example-project/instanceConfigs/eur3' +) GCP_SPANNER_NODE_COUNT = os.environ.get('GCP_SPANNER_NODE_COUNT', '1') GCP_SPANNER_DISPLAY_NAME = os.environ.get('GCP_SPANNER_DISPLAY_NAME', 'Test Instance') # OPERATION_ID should be unique per operation @@ -66,14 +70,14 @@ configuration_name=GCP_SPANNER_CONFIG_NAME, node_count=int(GCP_SPANNER_NODE_COUNT), display_name=GCP_SPANNER_DISPLAY_NAME, - task_id='spanner_instance_create_task' + task_id='spanner_instance_create_task', ) spanner_instance_update_task = SpannerDeployInstanceOperator( instance_id=GCP_SPANNER_INSTANCE_ID, configuration_name=GCP_SPANNER_CONFIG_NAME, node_count=int(GCP_SPANNER_NODE_COUNT) 
+ 1, display_name=GCP_SPANNER_DISPLAY_NAME + '_updated', - task_id='spanner_instance_update_task' + task_id='spanner_instance_update_task', ) # [END howto_operator_spanner_deploy] @@ -86,7 +90,7 @@ "CREATE TABLE my_table1 (id INT64, name STRING(MAX)) PRIMARY KEY (id)", "CREATE TABLE my_table2 (id INT64, name STRING(MAX)) PRIMARY KEY (id)", ], - task_id='spanner_database_deploy_task' + task_id='spanner_database_deploy_task', ) spanner_database_deploy_task2 = SpannerDeployDatabaseInstanceOperator( instance_id=GCP_SPANNER_INSTANCE_ID, @@ -95,7 +99,7 @@ "CREATE TABLE my_table1 (id INT64, name STRING(MAX)) PRIMARY KEY (id)", "CREATE TABLE my_table2 (id INT64, name STRING(MAX)) PRIMARY KEY (id)", ], - task_id='spanner_database_deploy_task2' + task_id='spanner_database_deploy_task2', ) # [END howto_operator_spanner_database_deploy] @@ -104,10 +108,8 @@ project_id=GCP_PROJECT_ID, instance_id=GCP_SPANNER_INSTANCE_ID, database_id=GCP_SPANNER_DATABASE_ID, - ddl_statements=[ - "CREATE TABLE my_table3 (id INT64, name STRING(MAX)) PRIMARY KEY (id)", - ], - task_id='spanner_database_update_task' + ddl_statements=["CREATE TABLE my_table3 (id INT64, name STRING(MAX)) PRIMARY KEY (id)",], + task_id='spanner_database_update_task', ) # [END howto_operator_spanner_database_update] @@ -117,19 +119,15 @@ instance_id=GCP_SPANNER_INSTANCE_ID, database_id=GCP_SPANNER_DATABASE_ID, operation_id=OPERATION_ID, - ddl_statements=[ - "CREATE TABLE my_table_unique (id INT64, name STRING(MAX)) PRIMARY KEY (id)", - ], - task_id='spanner_database_update_idempotent1_task' + ddl_statements=["CREATE TABLE my_table_unique (id INT64, name STRING(MAX)) PRIMARY KEY (id)",], + task_id='spanner_database_update_idempotent1_task', ) spanner_database_update_idempotent2_task = SpannerUpdateDatabaseInstanceOperator( instance_id=GCP_SPANNER_INSTANCE_ID, database_id=GCP_SPANNER_DATABASE_ID, operation_id=OPERATION_ID, - ddl_statements=[ - "CREATE TABLE my_table_unique (id INT64, name STRING(MAX)) PRIMARY KEY (id)", - ], - task_id='spanner_database_update_idempotent2_task' + ddl_statements=["CREATE TABLE my_table_unique (id INT64, name STRING(MAX)) PRIMARY KEY (id)",], + task_id='spanner_database_update_idempotent2_task', ) # [END howto_operator_spanner_database_update_idempotent] @@ -139,13 +137,13 @@ instance_id=GCP_SPANNER_INSTANCE_ID, database_id=GCP_SPANNER_DATABASE_ID, query=["DELETE FROM my_table2 WHERE true"], - task_id='spanner_instance_query_task' + task_id='spanner_instance_query_task', ) spanner_instance_query_task2 = SpannerQueryDatabaseInstanceOperator( instance_id=GCP_SPANNER_INSTANCE_ID, database_id=GCP_SPANNER_DATABASE_ID, query=["DELETE FROM my_table2 WHERE true"], - task_id='spanner_instance_query_task2' + task_id='spanner_instance_query_task2', ) # [END howto_operator_spanner_query] @@ -154,37 +152,34 @@ project_id=GCP_PROJECT_ID, instance_id=GCP_SPANNER_INSTANCE_ID, database_id=GCP_SPANNER_DATABASE_ID, - task_id='spanner_database_delete_task' + task_id='spanner_database_delete_task', ) spanner_database_delete_task2 = SpannerDeleteDatabaseInstanceOperator( instance_id=GCP_SPANNER_INSTANCE_ID, database_id=GCP_SPANNER_DATABASE_ID, - task_id='spanner_database_delete_task2' + task_id='spanner_database_delete_task2', ) # [END howto_operator_spanner_database_delete] # [START howto_operator_spanner_delete] spanner_instance_delete_task = SpannerDeleteInstanceOperator( - project_id=GCP_PROJECT_ID, - instance_id=GCP_SPANNER_INSTANCE_ID, - task_id='spanner_instance_delete_task' + project_id=GCP_PROJECT_ID, 
instance_id=GCP_SPANNER_INSTANCE_ID, task_id='spanner_instance_delete_task' ) spanner_instance_delete_task2 = SpannerDeleteInstanceOperator( - instance_id=GCP_SPANNER_INSTANCE_ID, - task_id='spanner_instance_delete_task2' + instance_id=GCP_SPANNER_INSTANCE_ID, task_id='spanner_instance_delete_task2' ) # [END howto_operator_spanner_delete] - spanner_instance_create_task \ - >> spanner_instance_update_task \ - >> spanner_database_deploy_task \ - >> spanner_database_deploy_task2 \ - >> spanner_database_update_task \ - >> spanner_database_update_idempotent1_task \ - >> spanner_database_update_idempotent2_task \ - >> spanner_instance_query_task \ - >> spanner_instance_query_task2 \ - >> spanner_database_delete_task \ - >> spanner_database_delete_task2 \ - >> spanner_instance_delete_task \ + ( + spanner_instance_create_task + >> spanner_instance_update_task + >> spanner_database_deploy_task + >> spanner_database_update_idempotent1_task + >> spanner_database_update_idempotent2_task + >> spanner_instance_query_task + >> spanner_instance_query_task2 + >> spanner_database_delete_task + >> spanner_database_delete_task2 + >> spanner_instance_delete_task >> spanner_instance_delete_task2 + ) diff --git a/airflow/providers/google/cloud/example_dags/example_speech_to_text.py b/airflow/providers/google/cloud/example_dags/example_speech_to_text.py index 1340eea7b02a3..58a5b814fe53a 100644 --- a/airflow/providers/google/cloud/example_dags/example_speech_to_text.py +++ b/airflow/providers/google/cloud/example_dags/example_speech_to_text.py @@ -58,9 +58,7 @@ ) # [START howto_operator_speech_to_text_recognize] speech_to_text_recognize_task2 = CloudSpeechToTextRecognizeSpeechOperator( - config=CONFIG, - audio=AUDIO, - task_id="speech_to_text_recognize_task" + config=CONFIG, audio=AUDIO, task_id="speech_to_text_recognize_task" ) # [END howto_operator_speech_to_text_recognize] diff --git a/airflow/providers/google/cloud/example_dags/example_stackdriver.py b/airflow/providers/google/cloud/example_dags/example_stackdriver.py index 658abe7f86468..d3a642b8418a2 100644 --- a/airflow/providers/google/cloud/example_dags/example_stackdriver.py +++ b/airflow/providers/google/cloud/example_dags/example_stackdriver.py @@ -24,93 +24,76 @@ from airflow import models from airflow.providers.google.cloud.operators.stackdriver import ( - StackdriverDeleteAlertOperator, StackdriverDeleteNotificationChannelOperator, - StackdriverDisableAlertPoliciesOperator, StackdriverDisableNotificationChannelsOperator, - StackdriverEnableAlertPoliciesOperator, StackdriverEnableNotificationChannelsOperator, - StackdriverListAlertPoliciesOperator, StackdriverListNotificationChannelsOperator, - StackdriverUpsertAlertOperator, StackdriverUpsertNotificationChannelOperator, + StackdriverDeleteAlertOperator, + StackdriverDeleteNotificationChannelOperator, + StackdriverDisableAlertPoliciesOperator, + StackdriverDisableNotificationChannelsOperator, + StackdriverEnableAlertPoliciesOperator, + StackdriverEnableNotificationChannelsOperator, + StackdriverListAlertPoliciesOperator, + StackdriverListNotificationChannelsOperator, + StackdriverUpsertAlertOperator, + StackdriverUpsertNotificationChannelOperator, ) from airflow.utils.dates import days_ago TEST_ALERT_POLICY_1 = { "combiner": "OR", "name": "projects/sd-project/alertPolicies/12345", - "creationRecord": { - "mutatedBy": "user123", - "mutateTime": "2020-01-01T00:00:00.000000Z" - }, + "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"}, "enabled": True, "displayName": 
"test alert 1", "conditions": [ { "conditionThreshold": { "comparison": "COMPARISON_GT", - "aggregations": [ - { - "alignmentPeriod": "60s", - "perSeriesAligner": "ALIGN_RATE" - } - ] + "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}], }, "displayName": "Condition display", - "name": "projects/sd-project/alertPolicies/123/conditions/456" + "name": "projects/sd-project/alertPolicies/123/conditions/456", } - ] + ], } TEST_ALERT_POLICY_2 = { "combiner": "OR", "name": "projects/sd-project/alertPolicies/6789", - "creationRecord": { - "mutatedBy": "user123", - "mutateTime": "2020-01-01T00:00:00.000000Z" - }, + "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"}, "enabled": False, "displayName": "test alert 2", "conditions": [ { "conditionThreshold": { "comparison": "COMPARISON_GT", - "aggregations": [ - { - "alignmentPeriod": "60s", - "perSeriesAligner": "ALIGN_RATE" - } - ] + "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}], }, "displayName": "Condition display", - "name": "projects/sd-project/alertPolicies/456/conditions/789" + "name": "projects/sd-project/alertPolicies/456/conditions/789", } - ] + ], } TEST_NOTIFICATION_CHANNEL_1 = { "displayName": "channel1", "enabled": True, - "labels": { - "auth_token": "top-secret", - "channel_name": "#channel" - }, + "labels": {"auth_token": "top-secret", "channel_name": "#channel"}, "name": "projects/sd-project/notificationChannels/12345", - "type": "slack" + "type": "slack", } TEST_NOTIFICATION_CHANNEL_2 = { "displayName": "channel2", "enabled": False, - "labels": { - "auth_token": "top-secret", - "channel_name": "#channel" - }, + "labels": {"auth_token": "top-secret", "channel_name": "#channel"}, "name": "projects/sd-project/notificationChannels/6789", - "type": "slack" + "type": "slack", } with models.DAG( 'example_stackdriver', schedule_interval=None, # Override to match your needs start_date=days_ago(1), - tags=['example'] + tags=['example'], ) as dag: # [START howto_operator_gcp_stackdriver_upsert_notification_channel] create_notification_channel = StackdriverUpsertNotificationChannelOperator( @@ -121,22 +104,19 @@ # [START howto_operator_gcp_stackdriver_enable_notification_channel] enable_notification_channel = StackdriverEnableNotificationChannelsOperator( - task_id='enable-notification-channel', - filter_='type="slack"' + task_id='enable-notification-channel', filter_='type="slack"' ) # [END howto_operator_gcp_stackdriver_enable_notification_channel] # [START howto_operator_gcp_stackdriver_disable_notification_channel] disable_notification_channel = StackdriverDisableNotificationChannelsOperator( - task_id='disable-notification-channel', - filter_='displayName="channel1"' + task_id='disable-notification-channel', filter_='displayName="channel1"' ) # [END howto_operator_gcp_stackdriver_disable_notification_channel] # [START howto_operator_gcp_stackdriver_list_notification_channel] list_notification_channel = StackdriverListNotificationChannelsOperator( - task_id='list-notification-channel', - filter_='type="slack"' + task_id='list-notification-channel', filter_='type="slack"' ) # [END howto_operator_gcp_stackdriver_list_notification_channel] @@ -149,38 +129,31 @@ # [START howto_operator_gcp_stackdriver_enable_alert_policy] enable_alert_policy = StackdriverEnableAlertPoliciesOperator( - task_id='enable-alert-policies', - filter_='(displayName="test alert 1" OR displayName="test alert 2")', + task_id='enable-alert-policies', filter_='(displayName="test 
alert 1" OR displayName="test alert 2")', ) # [END howto_operator_gcp_stackdriver_enable_alert_policy] # [START howto_operator_gcp_stackdriver_disable_alert_policy] disable_alert_policy = StackdriverDisableAlertPoliciesOperator( - task_id='disable-alert-policies', - filter_='displayName="test alert 1"', + task_id='disable-alert-policies', filter_='displayName="test alert 1"', ) # [END howto_operator_gcp_stackdriver_disable_alert_policy] # [START howto_operator_gcp_stackdriver_list_alert_policy] - list_alert_policies = StackdriverListAlertPoliciesOperator( - task_id='list-alert-policies', - ) + list_alert_policies = StackdriverListAlertPoliciesOperator(task_id='list-alert-policies',) # [END howto_operator_gcp_stackdriver_list_alert_policy] # [START howto_operator_gcp_stackdriver_delete_notification_channel] delete_notification_channel = StackdriverDeleteNotificationChannelOperator( - task_id='delete-notification-channel', - name='test-channel', + task_id='delete-notification-channel', name='test-channel', ) # [END howto_operator_gcp_stackdriver_delete_notification_channel] # [START howto_operator_gcp_stackdriver_delete_alert_policy] - delete_alert_policy = StackdriverDeleteAlertOperator( - task_id='delete-alert-polciy', - name='test-alert', - ) + delete_alert_policy = StackdriverDeleteAlertOperator(task_id='delete-alert-polciy', name='test-alert',) # [END howto_operator_gcp_stackdriver_delete_alert_policy] - create_notification_channel >> enable_notification_channel >> disable_notification_channel \ - >> list_notification_channel >> create_alert_policy >> enable_alert_policy >> disable_alert_policy \ - >> list_alert_policies >> delete_notification_channel >> delete_alert_policy + create_notification_channel >> enable_notification_channel >> disable_notification_channel + disable_notification_channel >> list_notification_channel >> create_alert_policy + create_alert_policy >> enable_alert_policy >> disable_alert_policy >> list_alert_policies + list_alert_policies >> delete_notification_channel >> delete_alert_policy diff --git a/airflow/providers/google/cloud/example_dags/example_tasks.py b/airflow/providers/google/cloud/example_dags/example_tasks.py index 6912cf138e7d4..8a97e58bbdf58 100644 --- a/airflow/providers/google/cloud/example_dags/example_tasks.py +++ b/airflow/providers/google/cloud/example_dags/example_tasks.py @@ -31,7 +31,9 @@ from airflow import models from airflow.providers.google.cloud.operators.tasks import ( - CloudTasksQueueCreateOperator, CloudTasksTaskCreateOperator, CloudTasksTaskRunOperator, + CloudTasksQueueCreateOperator, + CloudTasksTaskCreateOperator, + CloudTasksTaskRunOperator, ) from airflow.utils.dates import days_ago diff --git a/airflow/providers/google/cloud/example_dags/example_translate.py b/airflow/providers/google/cloud/example_dags/example_translate.py index 7c954f0ede671..ecae7f540a8fe 100644 --- a/airflow/providers/google/cloud/example_dags/example_translate.py +++ b/airflow/providers/google/cloud/example_dags/example_translate.py @@ -45,8 +45,7 @@ # [END howto_operator_translate_text] # [START howto_operator_translate_access] translation_access = BashOperator( - task_id='access', - bash_command="echo '{{ task_instance.xcom_pull(\"translate\")[0] }}'" + task_id='access', bash_command="echo '{{ task_instance.xcom_pull(\"translate\")[0] }}'" ) product_set_create >> translation_access # [END howto_operator_translate_access] diff --git a/airflow/providers/google/cloud/example_dags/example_translate_speech.py 
b/airflow/providers/google/cloud/example_dags/example_translate_speech.py index 74ce735e3d3e0..61c7579b0fd7a 100644 --- a/airflow/providers/google/cloud/example_dags/example_translate_speech.py +++ b/airflow/providers/google/cloud/example_dags/example_translate_speech.py @@ -70,7 +70,7 @@ format_=FORMAT, source_language=SOURCE_LANGUAGE, model=MODEL, - task_id='translate_speech_task' + task_id='translate_speech_task', ) translate_speech_task2 = CloudTranslateSpeechOperator( audio=AUDIO, @@ -79,7 +79,7 @@ format_=FORMAT, source_language=SOURCE_LANGUAGE, model=MODEL, - task_id='translate_speech_task2' + task_id='translate_speech_task2', ) # [END howto_operator_translate_speech] text_to_speech_synthesize_task >> translate_speech_task >> translate_speech_task2 diff --git a/airflow/providers/google/cloud/example_dags/example_video_intelligence.py b/airflow/providers/google/cloud/example_dags/example_video_intelligence.py index 6098b812f04d1..eced38fb87029 100644 --- a/airflow/providers/google/cloud/example_dags/example_video_intelligence.py +++ b/airflow/providers/google/cloud/example_dags/example_video_intelligence.py @@ -31,15 +31,14 @@ from airflow import models from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.video_intelligence import ( - CloudVideoIntelligenceDetectVideoExplicitContentOperator, CloudVideoIntelligenceDetectVideoLabelsOperator, + CloudVideoIntelligenceDetectVideoExplicitContentOperator, + CloudVideoIntelligenceDetectVideoLabelsOperator, CloudVideoIntelligenceDetectVideoShotsOperator, ) from airflow.utils.dates import days_ago # [START howto_operator_video_intelligence_os_args] -GCP_BUCKET_NAME = os.environ.get( - "GCP_VIDEO_INTELLIGENCE_BUCKET_NAME", "test-bucket-name" -) +GCP_BUCKET_NAME = os.environ.get("GCP_VIDEO_INTELLIGENCE_BUCKET_NAME", "test-bucket-name") # [END howto_operator_video_intelligence_os_args] @@ -57,18 +56,14 @@ # [START howto_operator_video_intelligence_detect_labels] detect_video_label = CloudVideoIntelligenceDetectVideoLabelsOperator( - input_uri=INPUT_URI, - output_uri=None, - video_context=None, - timeout=5, - task_id="detect_video_label", + input_uri=INPUT_URI, output_uri=None, video_context=None, timeout=5, task_id="detect_video_label", ) # [END howto_operator_video_intelligence_detect_labels] # [START howto_operator_video_intelligence_detect_labels_result] detect_video_label_result = BashOperator( bash_command="echo {{ task_instance.xcom_pull('detect_video_label')" - "['annotationResults'][0]['shotLabelAnnotations'][0]['entity']}}", + "['annotationResults'][0]['shotLabelAnnotations'][0]['entity']}}", task_id="detect_video_label_result", ) # [END howto_operator_video_intelligence_detect_labels_result] @@ -87,7 +82,7 @@ # [START howto_operator_video_intelligence_detect_explicit_content_result] detect_video_explicit_content_result = BashOperator( bash_command="echo {{ task_instance.xcom_pull('detect_video_explicit_content')" - "['annotationResults'][0]['explicitAnnotation']['frames'][0]}}", + "['annotationResults'][0]['explicitAnnotation']['frames'][0]}}", task_id="detect_video_explicit_content_result", ) # [END howto_operator_video_intelligence_detect_explicit_content_result] @@ -106,7 +101,7 @@ # [START howto_operator_video_intelligence_detect_video_shots_result] detect_video_shots_result = BashOperator( bash_command="echo {{ task_instance.xcom_pull('detect_video_shots')" - "['annotationResults'][0]['shotAnnotations'][0]}}", + "['annotationResults'][0]['shotAnnotations'][0]}}", 
task_id="detect_video_shots_result", ) # [END howto_operator_video_intelligence_detect_video_shots_result] diff --git a/airflow/providers/google/cloud/example_dags/example_vision.py b/airflow/providers/google/cloud/example_dags/example_vision.py index 3d81776c34ae7..d029226ef508d 100644 --- a/airflow/providers/google/cloud/example_dags/example_vision.py +++ b/airflow/providers/google/cloud/example_dags/example_vision.py @@ -36,31 +36,45 @@ from airflow import models from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.vision import ( - CloudVisionAddProductToProductSetOperator, CloudVisionCreateProductOperator, - CloudVisionCreateProductSetOperator, CloudVisionCreateReferenceImageOperator, - CloudVisionDeleteProductOperator, CloudVisionDeleteProductSetOperator, - CloudVisionDeleteReferenceImageOperator, CloudVisionDetectImageLabelsOperator, - CloudVisionDetectImageSafeSearchOperator, CloudVisionDetectTextOperator, CloudVisionGetProductOperator, - CloudVisionGetProductSetOperator, CloudVisionImageAnnotateOperator, - CloudVisionRemoveProductFromProductSetOperator, CloudVisionTextDetectOperator, - CloudVisionUpdateProductOperator, CloudVisionUpdateProductSetOperator, + CloudVisionAddProductToProductSetOperator, + CloudVisionCreateProductOperator, + CloudVisionCreateProductSetOperator, + CloudVisionCreateReferenceImageOperator, + CloudVisionDeleteProductOperator, + CloudVisionDeleteProductSetOperator, + CloudVisionDeleteReferenceImageOperator, + CloudVisionDetectImageLabelsOperator, + CloudVisionDetectImageSafeSearchOperator, + CloudVisionDetectTextOperator, + CloudVisionGetProductOperator, + CloudVisionGetProductSetOperator, + CloudVisionImageAnnotateOperator, + CloudVisionRemoveProductFromProductSetOperator, + CloudVisionTextDetectOperator, + CloudVisionUpdateProductOperator, + CloudVisionUpdateProductSetOperator, ) from airflow.utils.dates import days_ago # [START howto_operator_vision_retry_import] from google.api_core.retry import Retry # isort:skip pylint: disable=wrong-import-order + # [END howto_operator_vision_retry_import] # [START howto_operator_vision_product_set_import] from google.cloud.vision_v1.types import ProductSet # isort:skip pylint: disable=wrong-import-order + # [END howto_operator_vision_product_set_import] # [START howto_operator_vision_product_import] from google.cloud.vision_v1.types import Product # isort:skip pylint: disable=wrong-import-order + # [END howto_operator_vision_product_import] # [START howto_operator_vision_reference_image_import] from google.cloud.vision_v1.types import ReferenceImage # isort:skip pylint: disable=wrong-import-order + # [END howto_operator_vision_reference_image_import] # [START howto_operator_vision_enums_import] from google.cloud.vision import enums # isort:skip pylint: disable=wrong-import-order + # [END howto_operator_vision_enums_import] diff --git a/airflow/providers/google/cloud/hooks/automl.py b/airflow/providers/google/cloud/hooks/automl.py index 7d26875a40602..5280d47e954d6 100644 --- a/airflow/providers/google/cloud/hooks/automl.py +++ b/airflow/providers/google/cloud/hooks/automl.py @@ -25,8 +25,18 @@ from google.api_core.retry import Retry from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient from google.cloud.automl_v1beta1.types import ( - BatchPredictInputConfig, BatchPredictOutputConfig, ColumnSpec, Dataset, ExamplePayload, FieldMask, - ImageObjectDetectionModelDeploymentMetadata, InputConfig, Model, Operation, PredictResponse, TableSpec, + 
BatchPredictInputConfig, + BatchPredictOutputConfig, + ColumnSpec, + Dataset, + ExamplePayload, + FieldMask, + ImageObjectDetectionModelDeploymentMetadata, + InputConfig, + Model, + Operation, + PredictResponse, + TableSpec, ) from airflow.providers.google.common.hooks.base_google import GoogleBaseHook @@ -47,9 +57,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None # type: Optional[AutoMlClient] @@ -68,9 +76,7 @@ def get_conn(self) -> AutoMlClient: :rtype: google.cloud.automl_v1beta1.AutoMlClient """ if self._client is None: - self._client = AutoMlClient( - credentials=self._get_credentials(), client_info=self.client_info - ) + self._client = AutoMlClient(credentials=self._get_credentials(), client_info=self.client_info) return self._client @cached_property @@ -81,9 +87,7 @@ def prediction_client(self) -> PredictionServiceClient: :return: Google Cloud AutoML PredictionServiceClient client object. :rtype: google.cloud.automl_v1beta1.PredictionServiceClient """ - return PredictionServiceClient( - credentials=self._get_credentials(), client_info=self.client_info - ) + return PredictionServiceClient(credentials=self._get_credentials(), client_info=self.client_info) @GoogleBaseHook.fallback_to_default_project_id def create_model( @@ -229,12 +233,7 @@ def predict( client = self.prediction_client name = client.model_path(project=project_id, location=location, model=model_id) result = client.predict( - name=name, - payload=payload, - params=params, - retry=retry, - timeout=timeout, - metadata=metadata, + name=name, payload=payload, params=params, retry=retry, timeout=timeout, metadata=metadata, ) return result @@ -273,11 +272,7 @@ def create_dataset( client = self.get_conn() parent = client.location_path(project=project_id, location=location) result = client.create_dataset( - parent=parent, - dataset=dataset, - retry=retry, - timeout=timeout, - metadata=metadata, + parent=parent, dataset=dataset, retry=retry, timeout=timeout, metadata=metadata, ) return result @@ -317,15 +312,9 @@ def import_data( :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance """ client = self.get_conn() - name = client.dataset_path( - project=project_id, location=location, dataset=dataset_id - ) + name = client.dataset_path(project=project_id, location=location, dataset=dataset_id) result = client.import_data( - name=name, - input_config=input_config, - retry=retry, - timeout=timeout, - metadata=metadata, + name=name, input_config=input_config, retry=retry, timeout=timeout, metadata=metadata, ) return result @@ -379,10 +368,7 @@ def list_column_specs( # pylint: disable=too-many-arguments """ client = self.get_conn() parent = client.table_spec_path( - project=project_id, - location=location, - dataset=dataset_id, - table_spec=table_spec_id, + project=project_id, location=location, dataset=dataset_id, table_spec=table_spec_id, ) result = client.list_column_specs( parent=parent, @@ -428,9 +414,7 @@ def get_model( """ client = self.get_conn() name = client.model_path(project=project_id, location=location, model=model_id) - result = client.get_model( - name=name, retry=retry, timeout=timeout, metadata=metadata - ) + result = client.get_model(name=name, retry=retry, timeout=timeout, metadata=metadata) return result 
@GoogleBaseHook.fallback_to_default_project_id @@ -466,9 +450,7 @@ def delete_model( """ client = self.get_conn() name = client.model_path(project=project_id, location=location, model=model_id) - result = client.delete_model( - name=name, retry=retry, timeout=timeout, metadata=metadata - ) + result = client.delete_model(name=name, retry=retry, timeout=timeout, metadata=metadata) return result def update_dataset( @@ -501,11 +483,7 @@ def update_dataset( """ client = self.get_conn() result = client.update_dataset( - dataset=dataset, - update_mask=update_mask, - retry=retry, - timeout=timeout, - metadata=metadata, + dataset=dataset, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata, ) return result @@ -515,9 +493,7 @@ def deploy_model( model_id: str, location: str, project_id: str, - image_detection_metadata: Union[ - ImageObjectDetectionModelDeploymentMetadata, dict - ] = None, + image_detection_metadata: Union[ImageObjectDetectionModelDeploymentMetadata, dict] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = None, @@ -607,9 +583,7 @@ def list_table_specs( of the response through the `options` parameter. """ client = self.get_conn() - parent = client.dataset_path( - project=project_id, location=location, dataset=dataset_id - ) + parent = client.dataset_path(project=project_id, location=location, dataset=dataset_id) result = client.list_table_specs( parent=parent, filter_=filter_, @@ -653,9 +627,7 @@ def list_datasets( """ client = self.get_conn() parent = client.location_path(project=project_id, location=location) - result = client.list_datasets( - parent=parent, retry=retry, timeout=timeout, metadata=metadata - ) + result = client.list_datasets(parent=parent, retry=retry, timeout=timeout, metadata=metadata) return result @GoogleBaseHook.fallback_to_default_project_id @@ -690,10 +662,6 @@ def delete_dataset( :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance """ client = self.get_conn() - name = client.dataset_path( - project=project_id, location=location, dataset=dataset_id - ) - result = client.delete_dataset( - name=name, retry=retry, timeout=timeout, metadata=metadata - ) + name = client.dataset_path(project=project_id, location=location, dataset=dataset_id) + result = client.delete_dataset(name=name, retry=retry, timeout=timeout, metadata=metadata) return result diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 6133c11ff9732..752dc729e67ed 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -28,7 +28,14 @@ from google.api_core.retry import Retry from google.cloud.bigquery import ( - DEFAULT_RETRY, Client, CopyJob, ExternalConfig, ExtractJob, LoadJob, QueryJob, SchemaField, + DEFAULT_RETRY, + Client, + CopyJob, + ExternalConfig, + ExtractJob, + LoadJob, + QueryJob, + SchemaField, ) from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference from google.cloud.bigquery.table import EncryptionConfiguration, Row, Table, TableReference @@ -37,7 +44,8 @@ from pandas import DataFrame from pandas_gbq import read_gbq from pandas_gbq.gbq import ( - GbqConnector, _check_google_client_version as gbq_check_google_client_version, + GbqConnector, + _check_google_client_version as gbq_check_google_client_version, _test_google_api_imports as gbq_test_google_api_imports, ) @@ -58,33 +66,36 @@ class 
BigQueryHook(GoogleBaseHook, DbApiHook): Interact with BigQuery. This hook uses the Google Cloud Platform connection. """ + conn_name_attr = 'gcp_conn_id' # type: str - def __init__(self, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - use_legacy_sql: bool = True, - location: Optional[str] = None, - bigquery_conn_id: Optional[str] = None, - api_resource_configs: Optional[Dict] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None,) -> None: + def __init__( + self, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + use_legacy_sql: bool = True, + location: Optional[str] = None, + bigquery_conn_id: Optional[str] = None, + api_resource_configs: Optional[Dict] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + ) -> None: # To preserve backward compatibility # TODO: remove one day if bigquery_conn_id: warnings.warn( "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=2) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=2, + ) gcp_conn_id = bigquery_conn_id super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.use_legacy_sql = use_legacy_sql self.location = location self.running_job_id = None # type: Optional[str] - self.api_resource_configs = api_resource_configs \ - if api_resource_configs else {} # type Dict + self.api_resource_configs = api_resource_configs if api_resource_configs else {} # type Dict def get_conn(self) -> "BigQueryConnection": """ @@ -97,7 +108,7 @@ def get_conn(self) -> "BigQueryConnection": use_legacy_sql=self.use_legacy_sql, location=self.location, num_retries=self.num_retries, - hook=self + hook=self, ) def get_service(self) -> Any: @@ -105,12 +116,10 @@ def get_service(self) -> Any: Returns a BigQuery service object. """ warnings.warn( - "This method will be deprecated. Please use `BigQueryHook.get_client` method", - DeprecationWarning + "This method will be deprecated. 
Please use `BigQueryHook.get_client` method", DeprecationWarning ) http_authorized = self._authorize() - return build( - 'bigquery', 'v2', http=http_authorized, cache_discovery=False) + return build('bigquery', 'v2', http=http_authorized, cache_discovery=False) def get_client(self, project_id: Optional[str] = None, location: Optional[str] = None) -> Client: """ @@ -126,7 +135,7 @@ def get_client(self, project_id: Optional[str] = None, location: Optional[str] = client_info=self.client_info, project=project_id, location=location, - credentials=self._get_credentials() + credentials=self._get_credentials(), ) @staticmethod @@ -142,11 +151,7 @@ def _resolve_table_reference( except KeyError: # Something is wrong so we try to build the reference table_resource["tableReference"] = table_resource.get("tableReference", {}) - values = [ - ("projectId", project_id), - ("tableId", table_id), - ("datasetId", dataset_id) - ] + values = [("projectId", project_id), ("tableId", table_id), ("datasetId", dataset_id)] for key, value in values: # Check if value is already present if no use the provided one resolved_value = table_resource["tableReference"].get(key, value) @@ -159,8 +164,13 @@ def _resolve_table_reference( return table_resource def insert_rows( - self, table: Any, rows: Any, target_fields: Any = None, commit_every: Any = 1000, - replace: Any = False, **kwargs + self, + table: Any, + rows: Any, + target_fields: Any = None, + commit_every: Any = 1000, + replace: Any = False, + **kwargs, ) -> NoReturn: """ Insertion is currently unsupported. Theoretically, you could use @@ -170,8 +180,11 @@ def insert_rows( raise NotImplementedError() def get_pandas_df( - self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None, dialect: Optional[str] = None, - **kwargs + self, + sql: str, + parameters: Optional[Union[Iterable, Mapping]] = None, + dialect: Optional[str] = None, + **kwargs, ) -> DataFrame: """ Returns a Pandas DataFrame for the results produced by a BigQuery @@ -197,12 +210,9 @@ def get_pandas_df( credentials, project_id = self._get_credentials_and_project_id() - return read_gbq(sql, - project_id=project_id, - dialect=dialect, - verbose=False, - credentials=credentials, - **kwargs) + return read_gbq( + sql, project_id=project_id, dialect=dialect, verbose=False, credentials=credentials, **kwargs + ) @GoogleBaseHook.fallback_to_default_project_id def table_exists(self, dataset_id: str, table_id: str, project_id: str) -> bool: @@ -228,11 +238,7 @@ def table_exists(self, dataset_id: str, table_id: str, project_id: str) -> bool: @GoogleBaseHook.fallback_to_default_project_id def table_partition_exists( - self, - dataset_id: str, - table_id: str, - partition_id: str, - project_id: str + self, dataset_id: str, table_id: str, partition_id: str, project_id: str ) -> bool: """ Checks for the existence of a partition in a table in Google BigQuery. @@ -271,7 +277,7 @@ def create_empty_table( # pylint: disable=too-many-arguments retry: Optional[Retry] = DEFAULT_RETRY, num_retries: Optional[int] = None, location: Optional[str] = None, - exists_ok: bool = True + exists_ok: bool = True, ) -> Table: """ Creates a new, empty table in the dataset. 
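(Editorial aside, not part of the patch.) The hunks above only re-wrap BigQueryHook; its behaviour is unchanged. A minimal usage sketch built solely from the signatures visible in this diff, with the connection id left at its default and the SQL, dataset, table, and project names as hypothetical placeholders:

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

# gcp_conn_id defaults to 'google_cloud_default' per the reformatted __init__ above.
hook = BigQueryHook(gcp_conn_id='google_cloud_default', use_legacy_sql=False)

# get_pandas_df() forwards the query to pandas_gbq.read_gbq() with the hook's credentials.
df = hook.get_pandas_df(sql='SELECT 1 AS x', dialect='standard')

# table_exists() returns a bool; project_id could also be omitted because of the
# GoogleBaseHook.fallback_to_default_project_id decorator shown above.
exists = hook.table_exists(dataset_id='my_dataset', table_id='my_table', project_id='my-project')
print(len(df), exists)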
@@ -351,9 +357,7 @@ def create_empty_table( # pylint: disable=too-many-arguments _table_resource['timePartitioning'] = time_partitioning if cluster_fields: - _table_resource['clustering'] = { - 'fields': cluster_fields - } + _table_resource['clustering'] = {'fields': cluster_fields} if labels: _table_resource['labels'] = labels @@ -366,16 +370,11 @@ def create_empty_table( # pylint: disable=too-many-arguments table_resource = table_resource or _table_resource table_resource = self._resolve_table_reference( - table_resource=table_resource, - project_id=project_id, - dataset_id=dataset_id, - table_id=table_id, + table_resource=table_resource, project_id=project_id, dataset_id=dataset_id, table_id=table_id, ) table = Table.from_api_repr(table_resource) return self.get_client(project_id=project_id, location=location).create_table( - table=table, - exists_ok=exists_ok, - retry=retry + table=table, exists_ok=exists_ok, retry=retry ) @GoogleBaseHook.fallback_to_default_project_id @@ -415,7 +414,8 @@ def create_empty_dataset( self.log.info( "`%s` was provided in both `dataset_reference` and as `%s`. " "Using value from `dataset_reference`", - param, convert_camel_to_snake(param) + param, + convert_camel_to_snake(param), ) continue # use specified value if not value: @@ -425,8 +425,7 @@ def create_empty_dataset( ) # dataset_reference has no param but we can fallback to default value self.log.info( - "%s was not specified in `dataset_reference`. Will use default value %s.", - param, value + "%s was not specified in `dataset_reference`. Will use default value %s.", param, value ) dataset_reference["datasetReference"][param] = value @@ -499,31 +498,32 @@ def delete_dataset( dataset=DatasetReference(project=project_id, dataset_id=dataset_id), delete_contents=delete_contents, retry=retry, - not_found_ok=True + not_found_ok=True, ) @GoogleBaseHook.fallback_to_default_project_id - def create_external_table(self, # pylint: disable=too-many-locals,too-many-arguments - external_project_dataset_table: str, - schema_fields: List, - source_uris: List, - source_format: str = 'CSV', - autodetect: bool = False, - compression: str = 'NONE', - ignore_unknown_values: bool = False, - max_bad_records: int = 0, - skip_leading_rows: int = 0, - field_delimiter: str = ',', - quote_character: Optional[str] = None, - allow_quoted_newlines: bool = False, - allow_jagged_rows: bool = False, - encoding: str = "UTF-8", - src_fmt_configs: Optional[Dict] = None, - labels: Optional[Dict] = None, - encryption_configuration: Optional[Dict] = None, - location: Optional[str] = None, - project_id: Optional[str] = None, - ) -> None: + def create_external_table( # pylint: disable=too-many-locals,too-many-arguments + self, + external_project_dataset_table: str, + schema_fields: List, + source_uris: List, + source_format: str = 'CSV', + autodetect: bool = False, + compression: str = 'NONE', + ignore_unknown_values: bool = False, + max_bad_records: int = 0, + skip_leading_rows: int = 0, + field_delimiter: str = ',', + quote_character: Optional[str] = None, + allow_quoted_newlines: bool = False, + allow_jagged_rows: bool = False, + encoding: str = "UTF-8", + src_fmt_configs: Optional[Dict] = None, + labels: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None, + location: Optional[str] = None, + project_id: Optional[str] = None, + ) -> None: """ Creates a new external table in the dataset with the data from Google Cloud Storage. 
See here: @@ -614,7 +614,7 @@ def create_external_table(self, # pylint: disable=too-many-locals,too-many-argu 'sourceFormat': source_format, 'sourceUris': source_uris, 'compression': compression, - 'ignoreUnknownValues': ignore_unknown_values + 'ignoreUnknownValues': ignore_unknown_values, } # if following fields are not specified in src_fmt_configs, @@ -625,24 +625,25 @@ def create_external_table(self, # pylint: disable=too-many-locals,too-many-argu 'quote': quote_character, 'allowQuotedNewlines': allow_quoted_newlines, 'allowJaggedRows': allow_jagged_rows, - 'encoding': encoding - } - src_fmt_to_param_mapping = { - 'CSV': 'csvOptions', - 'GOOGLE_SHEETS': 'googleSheetsOptions' + 'encoding': encoding, } + src_fmt_to_param_mapping = {'CSV': 'csvOptions', 'GOOGLE_SHEETS': 'googleSheetsOptions'} src_fmt_to_configs_mapping = { 'csvOptions': [ - 'allowJaggedRows', 'allowQuotedNewlines', - 'fieldDelimiter', 'skipLeadingRows', - 'quote', 'encoding' + 'allowJaggedRows', + 'allowQuotedNewlines', + 'fieldDelimiter', + 'skipLeadingRows', + 'quote', + 'encoding', ], - 'googleSheetsOptions': ['skipLeadingRows'] + 'googleSheetsOptions': ['skipLeadingRows'], } if source_format in src_fmt_to_param_mapping.keys(): valid_configs = src_fmt_to_configs_mapping[src_fmt_to_param_mapping[source_format]] - src_fmt_configs = _validate_src_fmt_configs(source_format, src_fmt_configs, valid_configs, - backward_compatibility_configs) + src_fmt_configs = _validate_src_fmt_configs( + source_format, src_fmt_configs, valid_configs, backward_compatibility_configs + ) external_config_api_repr[src_fmt_to_param_mapping[source_format]] = src_fmt_configs # build external config @@ -653,9 +654,7 @@ def create_external_table(self, # pylint: disable=too-many-locals,too-many-argu external_config.max_bad_records = max_bad_records # build table definition - table = Table( - table_ref=TableReference.from_string(external_project_dataset_table, project_id) - ) + table = Table(table_ref=TableReference.from_string(external_project_dataset_table, project_id)) table.external_data_configuration = external_config if labels: table.labels = labels @@ -665,10 +664,7 @@ def create_external_table(self, # pylint: disable=too-many-locals,too-many-argu self.log.info('Creating external table: %s', external_project_dataset_table) self.create_empty_table( - table_resource=table.to_api_repr(), - project_id=project_id, - location=location, - exists_ok=True + table_resource=table.to_api_repr(), project_id=project_id, location=location, exists_ok=True ) self.log.info('External table created successfully: %s', external_project_dataset_table) @@ -711,10 +707,7 @@ def update_table( """ fields = fields or list(table_resource.keys()) table_resource = self._resolve_table_reference( - table_resource=table_resource, - project_id=project_id, - dataset_id=dataset_id, - table_id=table_id + table_resource=table_resource, project_id=project_id, dataset_id=dataset_id, table_id=table_id ) table = Table.from_api_repr(table_resource) @@ -724,20 +717,22 @@ def update_table( return table_object.to_api_repr() @GoogleBaseHook.fallback_to_default_project_id - def patch_table(self, # pylint: disable=too-many-arguments - dataset_id: str, - table_id: str, - project_id: Optional[str] = None, - description: Optional[str] = None, - expiration_time: Optional[int] = None, - external_data_configuration: Optional[Dict] = None, - friendly_name: Optional[str] = None, - labels: Optional[Dict] = None, - schema: Optional[List] = None, - time_partitioning: Optional[Dict] = None, - view: 
Optional[Dict] = None, - require_partition_filter: Optional[bool] = None, - encryption_configuration: Optional[Dict] = None) -> None: + def patch_table( # pylint: disable=too-many-arguments + self, + dataset_id: str, + table_id: str, + project_id: Optional[str] = None, + description: Optional[str] = None, + expiration_time: Optional[int] = None, + external_data_configuration: Optional[Dict] = None, + friendly_name: Optional[str] = None, + labels: Optional[Dict] = None, + schema: Optional[List] = None, + time_partitioning: Optional[Dict] = None, + view: Optional[Dict] = None, + require_partition_filter: Optional[bool] = None, + encryption_configuration: Optional[Dict] = None, + ) -> None: """ Patch information in an existing table. It only updates fields that are provided in the request object. @@ -799,8 +794,7 @@ def patch_table(self, # pylint: disable=too-many-arguments """ warnings.warn( - "This method is deprecated, please use ``BigQueryHook.update_table`` method.", - DeprecationWarning, + "This method is deprecated, please use ``BigQueryHook.update_table`` method.", DeprecationWarning, ) table_resource: Dict[str, Any] = {} @@ -830,7 +824,7 @@ def patch_table(self, # pylint: disable=too-many-arguments fields=list(table_resource.keys()), project_id=project_id, dataset_id=dataset_id, - table_id=table_id + table_id=table_id, ) @GoogleBaseHook.fallback_to_default_project_id @@ -842,7 +836,7 @@ def insert_all( rows: List, ignore_unknown_values: bool = False, skip_invalid_rows: bool = False, - fail_on_error: bool = False + fail_on_error: bool = False, ) -> None: """ Method to stream data into BigQuery one record at a time without needing @@ -877,9 +871,7 @@ def insert_all( even if any insertion errors occur. :type fail_on_error: bool """ - self.log.info( - 'Inserting %s row(s) into table %s:%s.%s', len(rows), project_id, dataset_id, table_id - ) + self.log.info('Inserting %s row(s) into table %s:%s.%s', len(rows), project_id, dataset_id, table_id) table = self._resolve_table_reference( table_resource={}, project_id=project_id, dataset_id=dataset_id, table_id=table_id @@ -888,7 +880,7 @@ def insert_all( table=Table.from_api_repr(table), rows=rows, ignore_unknown_values=ignore_unknown_values, - skip_invalid_rows=skip_invalid_rows + skip_invalid_rows=skip_invalid_rows, ) if errors: error_msg = f"{len(errors)} insert error(s) occurred. Details: {errors}" @@ -896,10 +888,7 @@ def insert_all( if fail_on_error: raise AirflowException(f'BigQuery job failed. Error was: {error_msg}') else: - self.log.info( - 'All row(s) inserted successfully: %s:%s.%s', - project_id, dataset_id, table_id - ) + self.log.info('All row(s) inserted successfully: %s:%s.%s', project_id, dataset_id, table_id) @GoogleBaseHook.fallback_to_default_project_id def update_dataset( @@ -946,9 +935,7 @@ def update_dataset( self.log.info('Start updating dataset') dataset = self.get_client(project_id=project_id).update_dataset( - dataset=Dataset.from_api_repr(dataset_resource), - fields=fields, - retry=retry, + dataset=Dataset.from_api_repr(dataset_resource), fields=fields, retry=retry, ) self.log.info("Dataset successfully updated: %s", dataset) return dataset @@ -974,10 +961,7 @@ def patch_dataset( https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource """ - warnings.warn( - "This method is deprecated. Please use ``update_dataset``.", - DeprecationWarning - ) + warnings.warn("This method is deprecated. 
Please use ``update_dataset``.", DeprecationWarning) project_id = project_id or self.project_id if not dataset_id or not isinstance(dataset_id, str): raise ValueError( @@ -991,11 +975,7 @@ def patch_dataset( self.log.info('Start patching dataset: %s:%s', dataset_project_id, dataset_id) dataset = ( service.datasets() # pylint: disable=no-member - .patch( - datasetId=dataset_id, - projectId=dataset_project_id, - body=dataset_resource, - ) + .patch(datasetId=dataset_id, projectId=dataset_project_id, body=dataset_resource,) .execute(num_retries=self.num_retries) ) self.log.info("Dataset successfully patched: %s", dataset) @@ -1007,7 +987,7 @@ def get_dataset_tables_list( dataset_id: str, project_id: Optional[str] = None, table_prefix: Optional[str] = None, - max_results: Optional[int] = None + max_results: Optional[int] = None, ) -> List[Dict[str, Any]]: """ Method returns tables list of a BigQuery tables. If table prefix is specified, @@ -1027,14 +1007,10 @@ def get_dataset_tables_list( :type max_results: int :return: List of tables associated with the dataset """ - warnings.warn( - "This method is deprecated. Please use ``get_dataset_tables``.", - DeprecationWarning - ) + warnings.warn("This method is deprecated. Please use ``get_dataset_tables``.", DeprecationWarning) project_id = project_id or self.project_id tables = self.get_client().list_tables( - dataset=DatasetReference(project=project_id, dataset_id=dataset_id), - max_results=max_results, + dataset=DatasetReference(project=project_id, dataset_id=dataset_id), max_results=max_results, ) if table_prefix: @@ -1143,18 +1119,13 @@ def run_grant_dataset_view_access( if source_project: project_id = source_project warnings.warn( - "Parameter ``source_project`` is deprecated. Use ``project_id``.", - DeprecationWarning, + "Parameter ``source_project`` is deprecated. Use ``project_id``.", DeprecationWarning, ) view_project = view_project or project_id view_access = AccessEntry( role=None, entity_type="view", - entity_id={ - 'projectId': view_project, - 'datasetId': view_dataset, - 'tableId': view_table - } + entity_id={'projectId': view_project, 'datasetId': view_dataset, 'tableId': view_table}, ) dataset = self.get_dataset(project_id=project_id, dataset_id=source_dataset) @@ -1163,27 +1134,30 @@ def run_grant_dataset_view_access( if view_access not in dataset.access_entries: self.log.info( 'Granting table %s:%s.%s authorized view access to %s:%s dataset.', - view_project, view_dataset, view_table, project_id, source_dataset + view_project, + view_dataset, + view_table, + project_id, + source_dataset, ) dataset.access_entries += [view_access] dataset = self.update_dataset( - fields=["access"], - dataset_resource=dataset.to_api_repr(), - project_id=project_id + fields=["access"], dataset_resource=dataset.to_api_repr(), project_id=project_id ) else: self.log.info( 'Table %s:%s.%s already has authorized view access to %s:%s dataset.', - view_project, view_dataset, view_table, project_id, source_dataset + view_project, + view_dataset, + view_table, + project_id, + source_dataset, ) return dataset.to_api_repr() @GoogleBaseHook.fallback_to_default_project_id def run_table_upsert( - self, - dataset_id: str, - table_resource: Dict[str, Any], - project_id: Optional[str] = None + self, dataset_id: str, table_resource: Dict[str, Any], project_id: Optional[str] = None ) -> Dict[str, Any]: """ If the table already exists, update the existing table if not create new. 
@@ -1201,10 +1175,7 @@ def run_table_upsert( """ table_id = table_resource['tableReference']['tableId'] table_resource = self._resolve_table_reference( - table_resource=table_resource, - project_id=project_id, - dataset_id=dataset_id, - table_id=table_id + table_resource=table_resource, project_id=project_id, dataset_id=dataset_id, table_id=table_id ) tables_list_resp = self.get_dataset_tables(dataset_id=dataset_id, project_id=project_id) @@ -1238,10 +1209,7 @@ def run_table_delete(self, deletion_dataset_table: str, ignore_if_missing: bool @GoogleBaseHook.fallback_to_default_project_id def delete_table( - self, - table_id: str, - not_found_ok: bool = True, - project_id: Optional[str] = None, + self, table_id: str, not_found_ok: bool = True, project_id: Optional[str] = None, ) -> None: """ Delete an existing table from the dataset. If the table does not exist, return an error @@ -1257,8 +1225,7 @@ def delete_table( :type project_id: str """ self.get_client(project_id=project_id).delete_table( - table=Table.from_string(table_id), - not_found_ok=not_found_ok, + table=Table.from_string(table_id), not_found_ok=not_found_ok, ) self.log.info('Deleted table %s', table_id) @@ -1269,7 +1236,7 @@ def get_tabledata( max_results: Optional[int] = None, selected_fields: Optional[str] = None, page_token: Optional[str] = None, - start_index: Optional[int] = None + start_index: Optional[int] = None, ) -> List[Dict]: """ Get the data of a given dataset.table and optionally with selected columns. @@ -1286,9 +1253,7 @@ def get_tabledata( :return: list of rows """ warnings.warn("This method is deprecated. Please use `list_rows`.", DeprecationWarning) - rows = self.list_rows( - dataset_id, table_id, max_results, selected_fields, page_token, start_index - ) + rows = self.list_rows(dataset_id, table_id, max_results, selected_fields, page_token, start_index) return [dict(r) for r in rows] @GoogleBaseHook.fallback_to_default_project_id @@ -1383,8 +1348,7 @@ def cancel_query(self) -> None: Cancel all started queries that have not yet completed """ warnings.warn( - "This method is deprecated. Please use `BigQueryHook.cancel_job`.", - DeprecationWarning, + "This method is deprecated. Please use `BigQueryHook.cancel_job`.", DeprecationWarning, ) if self.running_job_id: self.cancel_job(job_id=self.running_job_id) @@ -1393,10 +1357,7 @@ def cancel_query(self) -> None: @GoogleBaseHook.fallback_to_default_project_id def cancel_job( - self, - job_id: str, - project_id: Optional[str] = None, - location: Optional[str] = None, + self, job_id: str, project_id: Optional[str] = None, location: Optional[str] = None, ) -> None: """ Cancels a job an wait for cancellation to complete @@ -1430,17 +1391,16 @@ def cancel_job( elif polling_attempts == max_polling_attempts: self.log.info( "Stopping polling due to timeout. Job with id %s " - "has not completed cancel and may or may not finish.", job_id) + "has not completed cancel and may or may not finish.", + job_id, + ) else: self.log.info('Waiting for canceled job with id %s to finish.', job_id) time.sleep(5) @GoogleBaseHook.fallback_to_default_project_id def get_job( - self, - job_id: Optional[str] = None, - project_id: Optional[str] = None, - location: Optional[str] = None, + self, job_id: Optional[str] = None, project_id: Optional[str] = None, location: Optional[str] = None, ) -> Union[CopyJob, QueryJob, LoadJob, ExtractJob]: """ Retrives a BigQuery job. 
For more information see: @@ -1456,11 +1416,7 @@ def get_job( :type location: str """ client = self.get_client(project_id=project_id, location=location) - job = client.get_job( - job_id=job_id, - project=project_id, - location=location - ) + job = client.get_job(job_id=job_id, project=project_id, location=location) return job @GoogleBaseHook.fallback_to_default_project_id @@ -1497,11 +1453,7 @@ def insert_job( client = self.get_client(project_id=project_id, location=location) job_data = { "configuration": configuration, - "jobReference": { - "jobId": job_id, - "projectId": project_id, - "location": location - } + "jobReference": {"jobId": job_id, "projectId": project_id, "location": location}, } # pylint: disable=protected-access supported_jobs = { @@ -1537,35 +1489,34 @@ def run_with_configuration(self, configuration: Dict) -> str: https://cloud.google.com/bigquery/docs/reference/v2/jobs for details. """ - warnings.warn( - "This method is deprecated. Please use `BigQueryHook.insert_job`", - DeprecationWarning - ) + warnings.warn("This method is deprecated. Please use `BigQueryHook.insert_job`", DeprecationWarning) job = self.insert_job(configuration=configuration, project_id=self.project_id) self.running_job_id = job.job_id return job.job_id - def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid-name - destination_project_dataset_table: str, - source_uris: List, - schema_fields: Optional[List] = None, - source_format: str = 'CSV', - create_disposition: str = 'CREATE_IF_NEEDED', - skip_leading_rows: int = 0, - write_disposition: str = 'WRITE_EMPTY', - field_delimiter: str = ',', - max_bad_records: int = 0, - quote_character: Optional[str] = None, - ignore_unknown_values: bool = False, - allow_quoted_newlines: bool = False, - allow_jagged_rows: bool = False, - encoding: str = "UTF-8", - schema_update_options: Optional[Iterable] = None, - src_fmt_configs: Optional[Dict] = None, - time_partitioning: Optional[Dict] = None, - cluster_fields: Optional[List] = None, - autodetect: bool = False, - encryption_configuration: Optional[Dict] = None) -> str: + def run_load( # pylint: disable=too-many-locals,too-many-arguments,invalid-name + self, + destination_project_dataset_table: str, + source_uris: List, + schema_fields: Optional[List] = None, + source_format: str = 'CSV', + create_disposition: str = 'CREATE_IF_NEEDED', + skip_leading_rows: int = 0, + write_disposition: str = 'WRITE_EMPTY', + field_delimiter: str = ',', + max_bad_records: int = 0, + quote_character: Optional[str] = None, + ignore_unknown_values: bool = False, + allow_quoted_newlines: bool = False, + allow_jagged_rows: bool = False, + encoding: str = "UTF-8", + schema_update_options: Optional[Iterable] = None, + src_fmt_configs: Optional[Dict] = None, + time_partitioning: Optional[Dict] = None, + cluster_fields: Optional[List] = None, + autodetect: bool = False, + encryption_configuration: Optional[Dict] = None, + ) -> str: """ Executes a BigQuery load command to load data from Google Cloud Storage to BigQuery. See here: @@ -1649,8 +1600,7 @@ def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid :type encryption_configuration: dict """ warnings.warn( - "This method is deprecated. Please use `BigQueryHook.insert_job` method.", - DeprecationWarning + "This method is deprecated. 
Please use `BigQueryHook.insert_job` method.", DeprecationWarning ) if not self.project_id: @@ -1666,40 +1616,44 @@ def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat # noqa # pylint: disable=line-too-long if schema_fields is None and not autodetect: - raise ValueError( - 'You must either pass a schema or autodetect=True.') + raise ValueError('You must either pass a schema or autodetect=True.') if src_fmt_configs is None: src_fmt_configs = {} source_format = source_format.upper() allowed_formats = [ - "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS", - "DATASTORE_BACKUP", "PARQUET" + "CSV", + "NEWLINE_DELIMITED_JSON", + "AVRO", + "GOOGLE_SHEETS", + "DATASTORE_BACKUP", + "PARQUET", ] if source_format not in allowed_formats: - raise ValueError("{0} is not a valid source format. " - "Please use one of the following types: {1}" - .format(source_format, allowed_formats)) + raise ValueError( + "{0} is not a valid source format. " + "Please use one of the following types: {1}".format(source_format, allowed_formats) + ) # bigquery also allows you to define how you want a table's schema to change # as a side effect of a load # for more details: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schemaUpdateOptions - allowed_schema_update_options = [ - 'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION" - ] - if not set(allowed_schema_update_options).issuperset( - set(schema_update_options)): + allowed_schema_update_options = ['ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"] + if not set(allowed_schema_update_options).issuperset(set(schema_update_options)): raise ValueError( "{0} contains invalid schema update options." 
- "Please only use one or more of the following options: {1}" - .format(schema_update_options, allowed_schema_update_options)) + "Please only use one or more of the following options: {1}".format( + schema_update_options, allowed_schema_update_options + ) + ) - destination_project, destination_dataset, destination_table = \ - _split_tablename(table_input=destination_project_dataset_table, - default_project_id=self.project_id, - var_name='destination_project_dataset_table') + destination_project, destination_dataset, destination_table = _split_tablename( + table_input=destination_project_dataset_table, + default_project_id=self.project_id, + var_name='destination_project_dataset_table', + ) configuration = { 'load': { @@ -1713,18 +1667,13 @@ def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid 'sourceFormat': source_format, 'sourceUris': source_uris, 'writeDisposition': write_disposition, - 'ignoreUnknownValues': ignore_unknown_values + 'ignoreUnknownValues': ignore_unknown_values, } } - time_partitioning = _cleanse_time_partitioning( - destination_project_dataset_table, - time_partitioning - ) + time_partitioning = _cleanse_time_partitioning(destination_project_dataset_table, time_partitioning) if time_partitioning: - configuration['load'].update({ - 'timePartitioning': time_partitioning - }) + configuration['load'].update({'timePartitioning': time_partitioning}) if cluster_fields: configuration['load'].update({'clustering': {'fields': cluster_fields}}) @@ -1734,30 +1683,32 @@ def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid if schema_update_options: if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: - raise ValueError("schema_update_options is only " - "allowed if write_disposition is " - "'WRITE_APPEND' or 'WRITE_TRUNCATE'.") - else: - self.log.info( - "Adding experimental 'schemaUpdateOptions': %s", - schema_update_options + raise ValueError( + "schema_update_options is only " + "allowed if write_disposition is " + "'WRITE_APPEND' or 'WRITE_TRUNCATE'." 
) - configuration['load'][ - 'schemaUpdateOptions'] = schema_update_options + else: + self.log.info("Adding experimental 'schemaUpdateOptions': %s", schema_update_options) + configuration['load']['schemaUpdateOptions'] = schema_update_options if max_bad_records: configuration['load']['maxBadRecords'] = max_bad_records if encryption_configuration: - configuration["load"][ - "destinationEncryptionConfiguration" - ] = encryption_configuration + configuration["load"]["destinationEncryptionConfiguration"] = encryption_configuration src_fmt_to_configs_mapping = { 'CSV': [ - 'allowJaggedRows', 'allowQuotedNewlines', 'autodetect', - 'fieldDelimiter', 'skipLeadingRows', 'ignoreUnknownValues', - 'nullMarker', 'quote', 'encoding' + 'allowJaggedRows', + 'allowQuotedNewlines', + 'autodetect', + 'fieldDelimiter', + 'skipLeadingRows', + 'ignoreUnknownValues', + 'nullMarker', + 'quote', + 'encoding', ], 'DATASTORE_BACKUP': ['projectionFields'], 'NEWLINE_DELIMITED_JSON': ['autodetect', 'ignoreUnknownValues'], @@ -1769,15 +1720,18 @@ def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid # if following fields are not specified in src_fmt_configs, # honor the top-level params for backward-compatibility - backward_compatibility_configs = {'skipLeadingRows': skip_leading_rows, - 'fieldDelimiter': field_delimiter, - 'ignoreUnknownValues': ignore_unknown_values, - 'quote': quote_character, - 'allowQuotedNewlines': allow_quoted_newlines, - 'encoding': encoding} + backward_compatibility_configs = { + 'skipLeadingRows': skip_leading_rows, + 'fieldDelimiter': field_delimiter, + 'ignoreUnknownValues': ignore_unknown_values, + 'quote': quote_character, + 'allowQuotedNewlines': allow_quoted_newlines, + 'encoding': encoding, + } - src_fmt_configs = _validate_src_fmt_configs(source_format, src_fmt_configs, valid_configs, - backward_compatibility_configs) + src_fmt_configs = _validate_src_fmt_configs( + source_format, src_fmt_configs, valid_configs, backward_compatibility_configs + ) configuration['load'].update(src_fmt_configs) @@ -1788,13 +1742,15 @@ def run_load(self, # pylint: disable=too-many-locals,too-many-arguments,invalid self.running_job_id = job.job_id return job.job_id - def run_copy(self, # pylint: disable=invalid-name - source_project_dataset_tables: Union[List, str], - destination_project_dataset_table: str, - write_disposition: str = 'WRITE_EMPTY', - create_disposition: str = 'CREATE_IF_NEEDED', - labels: Optional[Dict] = None, - encryption_configuration: Optional[Dict] = None) -> str: + def run_copy( # pylint: disable=invalid-name + self, + source_project_dataset_tables: Union[List, str], + destination_project_dataset_table: str, + write_disposition: str = 'WRITE_EMPTY', + create_disposition: str = 'CREATE_IF_NEEDED', + labels: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None, + ) -> str: """ Executes a BigQuery copy command to copy data from one BigQuery table to another. See here: @@ -1829,35 +1785,31 @@ def run_copy(self, # pylint: disable=invalid-name :type encryption_configuration: dict """ warnings.warn( - "This method is deprecated. Please use `BigQueryHook.insert_job` method.", - DeprecationWarning + "This method is deprecated. 
Please use `BigQueryHook.insert_job` method.", DeprecationWarning ) if not self.project_id: raise ValueError("The project_id should be set") - source_project_dataset_tables = ([ - source_project_dataset_tables - ] if not isinstance(source_project_dataset_tables, list) else - source_project_dataset_tables) + source_project_dataset_tables = ( + [source_project_dataset_tables] + if not isinstance(source_project_dataset_tables, list) + else source_project_dataset_tables + ) source_project_dataset_tables_fixup = [] for source_project_dataset_table in source_project_dataset_tables: - source_project, source_dataset, source_table = \ - _split_tablename(table_input=source_project_dataset_table, - default_project_id=self.project_id, - var_name='source_project_dataset_table') - source_project_dataset_tables_fixup.append({ - 'projectId': - source_project, - 'datasetId': - source_dataset, - 'tableId': - source_table - }) + source_project, source_dataset, source_table = _split_tablename( + table_input=source_project_dataset_table, + default_project_id=self.project_id, + var_name='source_project_dataset_table', + ) + source_project_dataset_tables_fixup.append( + {'projectId': source_project, 'datasetId': source_dataset, 'tableId': source_table} + ) - destination_project, destination_dataset, destination_table = \ - _split_tablename(table_input=destination_project_dataset_table, - default_project_id=self.project_id) + destination_project, destination_dataset, destination_table = _split_tablename( + table_input=destination_project_dataset_table, default_project_id=self.project_id + ) configuration = { 'copy': { 'createDisposition': create_disposition, @@ -1866,8 +1818,8 @@ def run_copy(self, # pylint: disable=invalid-name 'destinationTable': { 'projectId': destination_project, 'datasetId': destination_dataset, - 'tableId': destination_table - } + 'tableId': destination_table, + }, } } @@ -1875,23 +1827,22 @@ def run_copy(self, # pylint: disable=invalid-name configuration['labels'] = labels if encryption_configuration: - configuration["copy"][ - "destinationEncryptionConfiguration" - ] = encryption_configuration + configuration["copy"]["destinationEncryptionConfiguration"] = encryption_configuration job = self.insert_job(configuration=configuration, project_id=self.project_id) self.running_job_id = job.job_id return job.job_id def run_extract( - self, - source_project_dataset_table: str, - destination_cloud_storage_uris: str, - compression: str = 'NONE', - export_format: str = 'CSV', - field_delimiter: str = ',', - print_header: bool = True, - labels: Optional[Dict] = None) -> str: + self, + source_project_dataset_table: str, + destination_cloud_storage_uris: str, + compression: str = 'NONE', + export_format: str = 'CSV', + field_delimiter: str = ',', + print_header: bool = True, + labels: Optional[Dict] = None, + ) -> str: """ Executes a BigQuery extract command to copy data from BigQuery to Google Cloud Storage. See here: @@ -1921,16 +1872,16 @@ def run_extract( :type labels: dict """ warnings.warn( - "This method is deprecated. Please use `BigQueryHook.insert_job` method.", - DeprecationWarning + "This method is deprecated. 
Please use `BigQueryHook.insert_job` method.", DeprecationWarning ) if not self.project_id: raise ValueError("The project_id should be set") - source_project, source_dataset, source_table = \ - _split_tablename(table_input=source_project_dataset_table, - default_project_id=self.project_id, - var_name='source_project_dataset_table') + source_project, source_dataset, source_table = _split_tablename( + table_input=source_project_dataset_table, + default_project_id=self.project_id, + var_name='source_project_dataset_table', + ) configuration = { 'extract': { @@ -1960,26 +1911,28 @@ def run_extract( return job.job_id # pylint: disable=too-many-locals,too-many-arguments, too-many-branches - def run_query(self, - sql: str, - destination_dataset_table: Optional[str] = None, - write_disposition: str = 'WRITE_EMPTY', - allow_large_results: bool = False, - flatten_results: Optional[bool] = None, - udf_config: Optional[List] = None, - use_legacy_sql: Optional[bool] = None, - maximum_billing_tier: Optional[int] = None, - maximum_bytes_billed: Optional[float] = None, - create_disposition: str = 'CREATE_IF_NEEDED', - query_params: Optional[List] = None, - labels: Optional[Dict] = None, - schema_update_options: Optional[Iterable] = None, - priority: str = 'INTERACTIVE', - time_partitioning: Optional[Dict] = None, - api_resource_configs: Optional[Dict] = None, - cluster_fields: Optional[List[str]] = None, - location: Optional[str] = None, - encryption_configuration: Optional[Dict] = None) -> str: + def run_query( + self, + sql: str, + destination_dataset_table: Optional[str] = None, + write_disposition: str = 'WRITE_EMPTY', + allow_large_results: bool = False, + flatten_results: Optional[bool] = None, + udf_config: Optional[List] = None, + use_legacy_sql: Optional[bool] = None, + maximum_billing_tier: Optional[int] = None, + maximum_bytes_billed: Optional[float] = None, + create_disposition: str = 'CREATE_IF_NEEDED', + query_params: Optional[List] = None, + labels: Optional[Dict] = None, + schema_update_options: Optional[Iterable] = None, + priority: str = 'INTERACTIVE', + time_partitioning: Optional[Dict] = None, + api_resource_configs: Optional[Dict] = None, + cluster_fields: Optional[List[str]] = None, + location: Optional[str] = None, + encryption_configuration: Optional[Dict] = None, + ) -> str: """ Executes a BigQuery SQL query. Optionally persists results in a BigQuery table. See here: @@ -2060,8 +2013,7 @@ def run_query(self, :type encryption_configuration: dict """ warnings.warn( - "This method is deprecated. Please use `BigQueryHook.insert_job` method.", - DeprecationWarning + "This method is deprecated. 
Please use `BigQueryHook.insert_job` method.", DeprecationWarning ) if not self.project_id: raise ValueError("The project_id should be set") @@ -2077,47 +2029,43 @@ def run_query(self, if not api_resource_configs: api_resource_configs = self.api_resource_configs else: - _validate_value('api_resource_configs', - api_resource_configs, dict) + _validate_value('api_resource_configs', api_resource_configs, dict) configuration = deepcopy(api_resource_configs) if 'query' not in configuration: configuration['query'] = {} else: - _validate_value("api_resource_configs['query']", - configuration['query'], dict) + _validate_value("api_resource_configs['query']", configuration['query'], dict) if sql is None and not configuration['query'].get('query', None): - raise TypeError('`BigQueryBaseCursor.run_query` ' - 'missing 1 required positional argument: `sql`') + raise TypeError('`BigQueryBaseCursor.run_query` ' 'missing 1 required positional argument: `sql`') # BigQuery also allows you to define how you want a table's schema to change # as a side effect of a query job # for more details: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions # noqa # pylint: disable=line-too-long - allowed_schema_update_options = [ - 'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION" - ] + allowed_schema_update_options = ['ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"] - if not set(allowed_schema_update_options - ).issuperset(set(schema_update_options)): - raise ValueError("{0} contains invalid schema update options. " - "Please only use one or more of the following " - "options: {1}" - .format(schema_update_options, - allowed_schema_update_options)) + if not set(allowed_schema_update_options).issuperset(set(schema_update_options)): + raise ValueError( + "{0} contains invalid schema update options. " + "Please only use one or more of the following " + "options: {1}".format(schema_update_options, allowed_schema_update_options) + ) if schema_update_options: if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: - raise ValueError("schema_update_options is only " - "allowed if write_disposition is " - "'WRITE_APPEND' or 'WRITE_TRUNCATE'.") + raise ValueError( + "schema_update_options is only " + "allowed if write_disposition is " + "'WRITE_APPEND' or 'WRITE_TRUNCATE'." 
+ ) if destination_dataset_table: - destination_project, destination_dataset, destination_table = \ - _split_tablename(table_input=destination_dataset_table, - default_project_id=self.project_id) + destination_project, destination_dataset, destination_table = _split_tablename( + table_input=destination_dataset_table, default_project_id=self.project_id + ) destination_dataset_table = { # type: ignore 'projectId': destination_project, @@ -2145,15 +2093,13 @@ def run_query(self, for param, param_name, param_default, param_type in query_param_list: if param_name not in configuration['query'] and param in [None, {}, ()]: if param_name == 'timePartitioning': - param_default = _cleanse_time_partitioning( - destination_dataset_table, time_partitioning) + param_default = _cleanse_time_partitioning(destination_dataset_table, time_partitioning) param = param_default if param in [None, {}, ()]: continue - _api_resource_configs_duplication_check( - param_name, param, configuration['query']) + _api_resource_configs_duplication_check(param_name, param, configuration['query']) configuration['query'][param_name] = param @@ -2161,12 +2107,10 @@ def run_query(self, # it last step because we can get param from 2 sources, # and first of all need to find it - _validate_value(param_name, configuration['query'][param_name], - param_type) + _validate_value(param_name, configuration['query'][param_name], param_type) if param_name == 'schemaUpdateOptions' and param: - self.log.info("Adding experimental 'schemaUpdateOptions': " - "%s", schema_update_options) + self.log.info("Adding experimental 'schemaUpdateOptions': " "%s", schema_update_options) if param_name != 'destinationTable': continue @@ -2177,29 +2121,31 @@ def run_query(self, "Not correct 'destinationTable' in " "api_resource_configs. 
'destinationTable' " "must be a dict with {'projectId':'', " - "'datasetId':'', 'tableId':''}") + "'datasetId':'', 'tableId':''}" + ) - configuration['query'].update({ - 'allowLargeResults': allow_large_results, - 'flattenResults': flatten_results, - 'writeDisposition': write_disposition, - 'createDisposition': create_disposition, - }) + configuration['query'].update( + { + 'allowLargeResults': allow_large_results, + 'flattenResults': flatten_results, + 'writeDisposition': write_disposition, + 'createDisposition': create_disposition, + } + ) - if 'useLegacySql' in configuration['query'] and configuration['query']['useLegacySql'] and\ - 'queryParameters' in configuration['query']: - raise ValueError("Query parameters are not allowed " - "when using legacy SQL") + if ( + 'useLegacySql' in configuration['query'] + and configuration['query']['useLegacySql'] + and 'queryParameters' in configuration['query'] + ): + raise ValueError("Query parameters are not allowed " "when using legacy SQL") if labels: - _api_resource_configs_duplication_check( - 'labels', labels, configuration) + _api_resource_configs_duplication_check('labels', labels, configuration) configuration['labels'] = labels if encryption_configuration: - configuration["query"][ - "destinationEncryptionConfiguration" - ] = encryption_configuration + configuration["query"]["destinationEncryptionConfiguration"] = encryption_configuration job = self.insert_job(configuration=configuration, project_id=self.project_id) self.running_job_id = job.job_id @@ -2239,20 +2185,19 @@ def __init__(self, *args, **kwargs) -> None: self._args = args self._kwargs = kwargs - def close(self) -> None: # noqa: D403 + def close(self) -> None: # noqa: D403 """BigQueryConnection does not have anything to close""" - def commit(self) -> None: # noqa: D403 + def commit(self) -> None: # noqa: D403 """BigQueryConnection does not support transactions""" - def cursor(self) -> "BigQueryCursor": # noqa: D403 + def cursor(self) -> "BigQueryCursor": # noqa: D403 """Return a new :py:class:`Cursor` object using the connection""" return BigQueryCursor(*self._args, **self._kwargs) - def rollback(self) -> NoReturn: # noqa: D403 + def rollback(self) -> NoReturn: # noqa: D403 """BigQueryConnection does not have transactions""" - raise NotImplementedError( - "BigQueryConnection does not have transactions") + raise NotImplementedError("BigQueryConnection does not have transactions") class BigQueryBaseCursor(LoggingMixin): @@ -2262,14 +2207,16 @@ class BigQueryBaseCursor(LoggingMixin): PEP 249 cursor isn't needed. 
""" - def __init__(self, - service: Any, - project_id: str, - hook: BigQueryHook, - use_legacy_sql: bool = True, - api_resource_configs: Optional[Dict] = None, - location: Optional[str] = None, - num_retries: int = 5) -> None: + def __init__( + self, + service: Any, + project_id: str, + hook: BigQueryHook, + use_legacy_sql: bool = True, + api_resource_configs: Optional[Dict] = None, + location: Optional[str] = None, + num_retries: int = 5, + ) -> None: super().__init__() self.service = service @@ -2277,8 +2224,7 @@ def __init__(self, self.use_legacy_sql = use_legacy_sql if api_resource_configs: _validate_value("api_resource_configs", api_resource_configs, dict) - self.api_resource_configs = api_resource_configs \ - if api_resource_configs else {} # type Dict + self.api_resource_configs = api_resource_configs if api_resource_configs else {} # type Dict self.running_job_id = None # type: Optional[str] self.location = location self.num_retries = num_retries @@ -2292,7 +2238,9 @@ def create_empty_table(self, *args, **kwargs) -> None: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.create_empty_table(*args, **kwargs) def create_empty_dataset(self, *args, **kwargs) -> None: @@ -2303,7 +2251,9 @@ def create_empty_dataset(self, *args, **kwargs) -> None: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_dataset`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.create_empty_dataset(*args, **kwargs) def get_dataset_tables(self, *args, **kwargs) -> List[Dict[str, Any]]: @@ -2314,7 +2264,9 @@ def get_dataset_tables(self, *args, **kwargs) -> List[Dict[str, Any]]: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset_tables`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.get_dataset_tables(*args, **kwargs) def delete_dataset(self, *args, **kwargs) -> None: @@ -2325,7 +2277,9 @@ def delete_dataset(self, *args, **kwargs) -> None: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.delete_dataset`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.delete_dataset(*args, **kwargs) def create_external_table(self, *args, **kwargs) -> None: @@ -2336,7 +2290,9 @@ def create_external_table(self, *args, **kwargs) -> None: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_external_table`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.create_external_table(*args, **kwargs) def patch_table(self, *args, **kwargs) -> None: @@ -2347,7 +2303,9 @@ def patch_table(self, *args, **kwargs) -> None: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.patch_table`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.patch_table(*args, **kwargs) def insert_all(self, *args, **kwargs) -> None: @@ -2358,7 +2316,9 @@ def insert_all(self, *args, **kwargs) -> None: warnings.warn( "This method is deprecated. 
" "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_all`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.insert_all(*args, **kwargs) def update_dataset(self, *args, **kwargs) -> Dict: @@ -2369,7 +2329,9 @@ def update_dataset(self, *args, **kwargs) -> Dict: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.update_dataset`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return Dataset.to_api_repr(self.hook.update_dataset(*args, **kwargs)) def patch_dataset(self, *args, **kwargs) -> Dict: @@ -2380,7 +2342,9 @@ def patch_dataset(self, *args, **kwargs) -> Dict: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.patch_dataset`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.patch_dataset(*args, **kwargs) def get_dataset_tables_list(self, *args, **kwargs) -> List[Dict[str, Any]]: @@ -2391,7 +2355,9 @@ def get_dataset_tables_list(self, *args, **kwargs) -> List[Dict[str, Any]]: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset_tables_list`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.get_dataset_tables_list(*args, **kwargs) def get_datasets_list(self, *args, **kwargs) -> List: @@ -2402,7 +2368,9 @@ def get_datasets_list(self, *args, **kwargs) -> List: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_datasets_list`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.get_datasets_list(*args, **kwargs) def get_dataset(self, *args, **kwargs) -> Dict: @@ -2413,7 +2381,9 @@ def get_dataset(self, *args, **kwargs) -> Dict: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.get_dataset(*args, **kwargs) def run_grant_dataset_view_access(self, *args, **kwargs) -> Dict: @@ -2425,7 +2395,9 @@ def run_grant_dataset_view_access(self, *args, **kwargs) -> Dict: "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks" ".bigquery.BigQueryHook.run_grant_dataset_view_access`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.run_grant_dataset_view_access(*args, **kwargs) def run_table_upsert(self, *args, **kwargs) -> Dict: @@ -2436,7 +2408,9 @@ def run_table_upsert(self, *args, **kwargs) -> Dict: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_table_upsert`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.run_table_upsert(*args, **kwargs) def run_table_delete(self, *args, **kwargs) -> None: @@ -2447,7 +2421,9 @@ def run_table_delete(self, *args, **kwargs) -> None: warnings.warn( "This method is deprecated. 
" "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_table_delete`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.run_table_delete(*args, **kwargs) def get_tabledata(self, *args, **kwargs) -> List[Dict]: @@ -2458,7 +2434,9 @@ def get_tabledata(self, *args, **kwargs) -> List[Dict]: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_tabledata`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.get_tabledata(*args, **kwargs) def get_schema(self, *args, **kwargs) -> Dict: @@ -2469,7 +2447,9 @@ def get_schema(self, *args, **kwargs) -> Dict: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_schema`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.get_schema(*args, **kwargs) def poll_job_complete(self, *args, **kwargs) -> bool: @@ -2480,7 +2460,9 @@ def poll_job_complete(self, *args, **kwargs) -> bool: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.poll_job_complete`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.poll_job_complete(*args, **kwargs) def cancel_query(self, *args, **kwargs) -> None: @@ -2491,7 +2473,9 @@ def cancel_query(self, *args, **kwargs) -> None: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.cancel_query`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.cancel_query(*args, **kwargs) # type: ignore # noqa def run_with_configuration(self, *args, **kwargs) -> str: @@ -2502,7 +2486,9 @@ def run_with_configuration(self, *args, **kwargs) -> str: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_with_configuration`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.run_with_configuration(*args, **kwargs) def run_load(self, *args, **kwargs) -> str: @@ -2513,7 +2499,9 @@ def run_load(self, *args, **kwargs) -> str: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_load`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.run_load(*args, **kwargs) def run_copy(self, *args, **kwargs) -> str: @@ -2524,7 +2512,9 @@ def run_copy(self, *args, **kwargs) -> str: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_copy`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.run_copy(*args, **kwargs) def run_extract(self, *args, **kwargs) -> str: @@ -2535,7 +2525,9 @@ def run_extract(self, *args, **kwargs) -> str: warnings.warn( "This method is deprecated. " "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_extract`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.run_extract(*args, **kwargs) def run_query(self, *args, **kwargs) -> str: @@ -2546,7 +2538,9 @@ def run_query(self, *args, **kwargs) -> str: warnings.warn( "This method is deprecated. 
" "Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_query`", - DeprecationWarning, stacklevel=3) + DeprecationWarning, + stacklevel=3, + ) return self.hook.run_query(*args, **kwargs) @@ -2604,8 +2598,7 @@ def execute(self, operation: str, parameters: Optional[Dict] = None) -> None: :param parameters: Parameters to substitute into the query. :type parameters: dict """ - sql = _bind_parameters(operation, - parameters) if parameters else operation + sql = _bind_parameters(operation, parameters) if parameters else operation self.flush_results() self.job_id = self.hook.run_query(sql) @@ -2647,11 +2640,16 @@ def next(self) -> Union[List, None]: if self.all_pages_loaded: return None - query_results = (self.service.jobs().getQueryResults( - projectId=self.project_id, - jobId=self.job_id, - location=self.location, - pageToken=self.page_token).execute(num_retries=self.num_retries)) + query_results = ( + self.service.jobs() + .getQueryResults( + projectId=self.project_id, + jobId=self.job_id, + location=self.location, + pageToken=self.page_token, + ) + .execute(num_retries=self.num_retries) + ) if 'rows' in query_results and query_results['rows']: self.page_token = query_results.get('pageToken') @@ -2660,10 +2658,7 @@ def next(self) -> Union[List, None]: rows = query_results['rows'] for dict_row in rows: - typed_row = ([ - _bq_cast(vs['v'], col_types[idx]) - for idx, vs in enumerate(dict_row['f']) - ]) + typed_row = [_bq_cast(vs['v'], col_types[idx]) for idx, vs in enumerate(dict_row['f'])] self.buffer.append(typed_row) if not self.page_token: @@ -2766,20 +2761,20 @@ def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, s return float(string_field) elif bq_type == 'BOOLEAN': if string_field not in ['true', 'false']: - raise ValueError("{} must have value 'true' or 'false'".format( - string_field)) + raise ValueError("{} must have value 'true' or 'false'".format(string_field)) return string_field == 'true' else: return string_field -def _split_tablename(table_input: str, default_project_id: str, - var_name: Optional[str] = None) -> Tuple[str, str, str]: +def _split_tablename( + table_input: str, default_project_id: str, var_name: Optional[str] = None +) -> Tuple[str, str, str]: if '.' not in table_input: raise ValueError( - 'Expected target table name in the format of ' - '.. Got: {}'.format(table_input)) + 'Expected target table name in the format of ' '.
. Got: {}'.format(table_input) + ) if not default_project_id: raise ValueError("INTERNAL: No default project is specified") @@ -2791,9 +2786,11 @@ def var_print(var_name): return "Format exception for {var}: ".format(var=var_name) if table_input.count('.') + table_input.count(':') > 3: - raise Exception(('{var}Use either : or . to specify project ' - 'got {input}').format( - var=var_print(var_name), input=table_input)) + raise Exception( + ('{var}Use either : or . to specify project ' 'got {input}').format( + var=var_print(var_name), input=table_input + ) + ) cmpt = table_input.rsplit(':', 1) project_id = None rest = table_input @@ -2805,16 +2802,16 @@ def var_print(var_name): project_id = cmpt[0] rest = cmpt[1] else: - raise Exception(('{var}Expect format of (.
, ' - 'got {input}').format( - var=var_print(var_name), input=table_input)) + raise Exception( + ('{var}Expect format of (.
, ' 'got {input}').format( + var=var_print(var_name), input=table_input + ) + ) cmpt = rest.split('.') if len(cmpt) == 3: if project_id: - raise ValueError( - "{var}Use either : or . to specify project".format( - var=var_print(var_name))) + raise ValueError("{var}Use either : or . to specify project".format(var=var_print(var_name))) project_id = cmpt[0] dataset_id = cmpt[1] table_id = cmpt[2] @@ -2824,14 +2821,18 @@ def var_print(var_name): table_id = cmpt[1] else: raise Exception( - ('{var}Expect format of (.
, ' - 'got {input}').format(var=var_print(var_name), input=table_input)) + ('{var}Expect format of (.
, ' 'got {input}').format( + var=var_print(var_name), input=table_input + ) + ) if project_id is None: if var_name is not None: log.info( 'Project not included in %s: %s; using project "%s"', - var_name, table_input, default_project_id + var_name, + table_input, + default_project_id, ) project_id = default_project_id @@ -2840,7 +2841,7 @@ def var_print(var_name): def _cleanse_time_partitioning( destination_dataset_table: Optional[str], time_partitioning_in: Optional[Dict] -) -> Dict: # if it is a partitioned table ($ is in the table name) add partition load option +) -> Dict: # if it is a partitioned table ($ is in the table name) add partition load option if time_partitioning_in is None: time_partitioning_in = {} @@ -2855,24 +2856,29 @@ def _cleanse_time_partitioning( def _validate_value(key: Any, value: Any, expected_type: Type) -> None: """Function to check expected type and raise error if type is not correct""" if not isinstance(value, expected_type): - raise TypeError("{} argument must have a type {} not {}".format( - key, expected_type, type(value))) + raise TypeError("{} argument must have a type {} not {}".format(key, expected_type, type(value))) -def _api_resource_configs_duplication_check(key: Any, value: Any, config_dict: Dict, - config_dict_name='api_resource_configs') -> None: +def _api_resource_configs_duplication_check( + key: Any, value: Any, config_dict: Dict, config_dict_name='api_resource_configs' +) -> None: if key in config_dict and value != config_dict[key]: - raise ValueError("Values of {param_name} param are duplicated. " - "{dict_name} contained {param_name} param " - "in `query` config and {param_name} was also provided " - "with arg to run_query() method. Please remove duplicates." - .format(param_name=key, dict_name=config_dict_name)) + raise ValueError( + "Values of {param_name} param are duplicated. " + "{dict_name} contained {param_name} param " + "in `query` config and {param_name} was also provided " + "with arg to run_query() method. Please remove duplicates.".format( + param_name=key, dict_name=config_dict_name + ) + ) -def _validate_src_fmt_configs(source_format: str, - src_fmt_configs: Dict, - valid_configs: List[str], - backward_compatibility_configs: Optional[Dict] = None) -> Dict: +def _validate_src_fmt_configs( + source_format: str, + src_fmt_configs: Dict, + valid_configs: List[str], + backward_compatibility_configs: Optional[Dict] = None, +) -> Dict: """ Validates the given src_fmt_configs against a valid configuration for the source format. Adds the backward compatiblity config to the src_fmt_configs. @@ -2896,7 +2902,6 @@ def _validate_src_fmt_configs(source_format: str, for k, v in src_fmt_configs.items(): if k not in valid_configs: - raise ValueError("{0} is not a valid src_fmt_configs for type {1}." 
- .format(k, source_format)) + raise ValueError("{0} is not a valid src_fmt_configs for type {1}.".format(k, source_format)) return src_fmt_configs diff --git a/airflow/providers/google/cloud/hooks/bigquery_dts.py b/airflow/providers/google/cloud/hooks/bigquery_dts.py index efea8598621bc..0f71c792121b2 100644 --- a/airflow/providers/google/cloud/hooks/bigquery_dts.py +++ b/airflow/providers/google/cloud/hooks/bigquery_dts.py @@ -25,7 +25,9 @@ from google.api_core.retry import Retry from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceClient from google.cloud.bigquery_datatransfer_v1.types import ( - StartManualTransferRunsResponse, TransferConfig, TransferRun, + StartManualTransferRunsResponse, + TransferConfig, + TransferRun, ) from google.protobuf.json_format import MessageToDict, ParseDict from googleapiclient.discovery import Resource @@ -57,9 +59,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) @staticmethod @@ -77,9 +77,7 @@ def _disable_auto_scheduling(config: Union[dict, TransferConfig]) -> TransferCon new_config = copy(config) schedule_options = new_config.get("schedule_options") if schedule_options: - disable_auto_scheduling = schedule_options.get( - "disable_auto_scheduling", None - ) + disable_auto_scheduling = schedule_options.get("disable_auto_scheduling", None) if disable_auto_scheduling is None: schedule_options["disable_auto_scheduling"] = True else: @@ -171,12 +169,8 @@ def delete_transfer_config( :return: None """ client = self.get_conn() - name = client.project_transfer_config_path( - project=project_id, transfer_config=transfer_config_id - ) - return client.delete_transfer_config( - name=name, retry=retry, timeout=timeout, metadata=metadata - ) + name = client.project_transfer_config_path(project=project_id, transfer_config=transfer_config_id) + return client.delete_transfer_config(name=name, retry=retry, timeout=timeout, metadata=metadata) @GoogleBaseHook.fallback_to_default_project_id def start_manual_transfer_runs( @@ -221,9 +215,7 @@ def start_manual_transfer_runs( :return: An ``google.cloud.bigquery_datatransfer_v1.types.StartManualTransferRunsResponse`` instance. """ client = self.get_conn() - parent = client.project_transfer_config_path( - project=project_id, transfer_config=transfer_config_id - ) + parent = client.project_transfer_config_path(project=project_id, transfer_config=transfer_config_id) return client.start_manual_transfer_runs( parent=parent, requested_time_range=requested_time_range, @@ -265,9 +257,5 @@ def get_transfer_run( :return: An ``google.cloud.bigquery_datatransfer_v1.types.TransferRun`` instance. 
""" client = self.get_conn() - name = client.project_run_path( - project=project_id, transfer_config=transfer_config_id, run=run_id - ) - return client.get_transfer_run( - name=name, retry=retry, timeout=timeout, metadata=metadata - ) + name = client.project_run_path(project=project_id, transfer_config=transfer_config_id, run=run_id) + return client.get_transfer_run(name=name, retry=retry, timeout=timeout, metadata=metadata) diff --git a/airflow/providers/google/cloud/hooks/bigtable.py b/airflow/providers/google/cloud/hooks/bigtable.py index e8f42fa4a5bce..5d5286e6df023 100644 --- a/airflow/providers/google/cloud/hooks/bigtable.py +++ b/airflow/providers/google/cloud/hooks/bigtable.py @@ -48,9 +48,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None @@ -60,7 +58,7 @@ def _get_client(self, project_id: str): project=project_id, credentials=self._get_credentials(), client_info=self.client_info, - admin=True + admin=True, ) return self._client @@ -100,8 +98,9 @@ def delete_instance(self, instance_id: str, project_id: str) -> None: if instance: instance.delete() else: - self.log.warning("The instance '%s' does not exist in project '%s'. Exiting", - instance_id, project_id) + self.log.warning( + "The instance '%s' does not exist in project '%s'. Exiting", instance_id, project_id + ) @GoogleBaseHook.fallback_to_default_project_id def create_instance( @@ -118,7 +117,7 @@ def create_instance( instance_labels: Optional[Dict] = None, cluster_nodes: Optional[int] = None, cluster_storage_type: enums.StorageType = enums.StorageType.STORAGE_TYPE_UNSPECIFIED, - timeout: Optional[float] = None + timeout: Optional[float] = None, ) -> Instance: """ Creates new instance. @@ -170,36 +169,31 @@ def create_instance( instance_labels, ) - clusters = [ - instance.cluster( - main_cluster_id, - main_cluster_zone, - cluster_nodes, - cluster_storage_type - ) - ] + clusters = [instance.cluster(main_cluster_id, main_cluster_zone, cluster_nodes, cluster_storage_type)] if replica_cluster_id and replica_cluster_zone: warnings.warn( "The replica_cluster_id and replica_cluster_zone parameter have been deprecated." 
- "You should pass the replica_clusters parameter.", DeprecationWarning, stacklevel=2) - clusters.append(instance.cluster( - replica_cluster_id, - replica_cluster_zone, - cluster_nodes, - cluster_storage_type - )) + "You should pass the replica_clusters parameter.", + DeprecationWarning, + stacklevel=2, + ) + clusters.append( + instance.cluster( + replica_cluster_id, replica_cluster_zone, cluster_nodes, cluster_storage_type + ) + ) if replica_clusters: for replica_cluster in replica_clusters: if "id" in replica_cluster and "zone" in replica_cluster: - clusters.append(instance.cluster( - replica_cluster["id"], - replica_cluster["zone"], - cluster_nodes, - cluster_storage_type - )) - operation = instance.create( - clusters=clusters - ) + clusters.append( + instance.cluster( + replica_cluster["id"], + replica_cluster["zone"], + cluster_nodes, + cluster_storage_type, + ) + ) + operation = instance.create(clusters=clusters) operation.result(timeout) return instance @@ -211,7 +205,7 @@ def update_instance( instance_display_name: Optional[str] = None, instance_type: Optional[Union[enums.Instance.Type, enum.IntEnum]] = None, instance_labels: Optional[Dict] = None, - timeout: Optional[float] = None + timeout: Optional[float] = None, ) -> Instance: """ Update an existing instance. @@ -253,7 +247,7 @@ def create_table( instance: Instance, table_id: str, initial_split_keys: Optional[List] = None, - column_families: Optional[Dict[str, GarbageCollectionRule]] = None + column_families: Optional[Dict[str, GarbageCollectionRule]] = None, ) -> None: """ Creates the specified Cloud Bigtable table. diff --git a/airflow/providers/google/cloud/hooks/cloud_build.py b/airflow/providers/google/cloud/hooks/cloud_build.py index bc979244979e9..821b1ba81be62 100644 --- a/airflow/providers/google/cloud/hooks/cloud_build.py +++ b/airflow/providers/google/cloud/hooks/cloud_build.py @@ -65,9 +65,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version @@ -137,7 +135,9 @@ def _wait_for_operation_to_complete(self, operation_name: str) -> None: while True: operation_response = ( # pylint: disable=no-member - service.operations().get(name=operation_name).execute(num_retries=self.num_retries) + service.operations() + .get(name=operation_name) + .execute(num_retries=self.num_retries) ) if operation_response.get("done"): response = operation_response.get("response") diff --git a/airflow/providers/google/cloud/hooks/cloud_memorystore.py b/airflow/providers/google/cloud/hooks/cloud_memorystore.py index bb37e4a8eff06..7c59e86704b38 100644 --- a/airflow/providers/google/cloud/hooks/cloud_memorystore.py +++ b/airflow/providers/google/cloud/hooks/cloud_memorystore.py @@ -63,9 +63,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None # type: Optional[CloudRedisClient] diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py index a033de09a6683..c70f0648d54c4 100644 --- a/airflow/providers/google/cloud/hooks/cloud_sql.py +++ 
b/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -44,6 +44,7 @@ from sqlalchemy.orm import Session from airflow.exceptions import AirflowException + # Number of retries - used by googleapiclient method calls to perform retries # For requests that are "retriable" from airflow.hooks.base_hook import BaseHook @@ -64,6 +65,7 @@ class CloudSqlOperationStatus: """ Helper class with operation statuses. """ + PENDING = "PENDING" RUNNING = "RUNNING" DONE = "DONE" @@ -77,6 +79,7 @@ class CloudSQLHook(GoogleBaseHook): All the methods in the hook where project_id is used must be called with keyword arguments rather than positional. """ + def __init__( self, api_version: str, @@ -85,9 +88,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version self._conn = None @@ -101,8 +102,7 @@ def get_conn(self): """ if not self._conn: http_authorized = self._authorize() - self._conn = build('sqladmin', self.api_version, - http=http_authorized, cache_discovery=False) + self._conn = build('sqladmin', self.api_version, http=http_authorized, cache_discovery=False) return self._conn @GoogleBaseHook.fallback_to_default_project_id @@ -118,10 +118,12 @@ def get_instance(self, instance: str, project_id: str) -> Dict: :return: A Cloud SQL instance resource. :rtype: dict """ - return self.get_conn().instances().get( # noqa # pylint: disable=no-member - project=project_id, - instance=instance - ).execute(num_retries=self.num_retries) + return ( + self.get_conn() + .instances() + .get(project=project_id, instance=instance) # noqa # pylint: disable=no-member + .execute(num_retries=self.num_retries) + ) @GoogleBaseHook.fallback_to_default_project_id @GoogleBaseHook.operation_in_progress_retry() @@ -137,13 +139,14 @@ def create_instance(self, body: Dict, project_id: str) -> None: :type project_id: str :return: None """ - response = self.get_conn().instances().insert( # noqa # pylint: disable=no-member - project=project_id, - body=body - ).execute(num_retries=self.num_retries) + response = ( + self.get_conn() + .instances() + .insert(project=project_id, body=body) # noqa # pylint: disable=no-member + .execute(num_retries=self.num_retries) + ) operation_name = response["name"] - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id @GoogleBaseHook.operation_in_progress_retry() @@ -164,14 +167,14 @@ def patch_instance(self, body: Dict, instance: str, project_id: str) -> None: :type project_id: str :return: None """ - response = self.get_conn().instances().patch( # noqa # pylint: disable=no-member - project=project_id, - instance=instance, - body=body - ).execute(num_retries=self.num_retries) + response = ( + self.get_conn() + .instances() + .patch(project=project_id, instance=instance, body=body) # noqa # pylint: disable=no-member + .execute(num_retries=self.num_retries) + ) operation_name = response["name"] - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id 
@GoogleBaseHook.operation_in_progress_retry() @@ -186,13 +189,14 @@ def delete_instance(self, instance: str, project_id: str) -> None: :type instance: str :return: None """ - response = self.get_conn().instances().delete( # noqa # pylint: disable=no-member - project=project_id, - instance=instance, - ).execute(num_retries=self.num_retries) + response = ( + self.get_conn() + .instances() + .delete(project=project_id, instance=instance,) # noqa # pylint: disable=no-member + .execute(num_retries=self.num_retries) + ) operation_name = response["name"] - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id def get_database(self, instance: str, database: str, project_id: str) -> Dict: @@ -210,11 +214,12 @@ def get_database(self, instance: str, database: str, project_id: str) -> Dict: https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/databases#resource. :rtype: dict """ - return self.get_conn().databases().get( # noqa # pylint: disable=no-member - project=project_id, - instance=instance, - database=database - ).execute(num_retries=self.num_retries) + return ( + self.get_conn() + .databases() + .get(project=project_id, instance=instance, database=database) # noqa # pylint: disable=no-member + .execute(num_retries=self.num_retries) + ) @GoogleBaseHook.fallback_to_default_project_id @GoogleBaseHook.operation_in_progress_retry() @@ -232,24 +237,18 @@ def create_database(self, instance: str, body: Dict, project_id: str) -> None: :type project_id: str :return: None """ - response = self.get_conn().databases().insert( # noqa # pylint: disable=no-member - project=project_id, - instance=instance, - body=body - ).execute(num_retries=self.num_retries) + response = ( + self.get_conn() + .databases() + .insert(project=project_id, instance=instance, body=body) # noqa # pylint: disable=no-member + .execute(num_retries=self.num_retries) + ) operation_name = response["name"] - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id @GoogleBaseHook.operation_in_progress_retry() - def patch_database( - self, - instance: str, - database: str, - body: Dict, - project_id: str, - ) -> None: + def patch_database(self, instance: str, database: str, body: Dict, project_id: str,) -> None: """ Updates a database resource inside a Cloud SQL instance. 
@@ -268,15 +267,16 @@ def patch_database( :type project_id: str :return: None """ - response = self.get_conn().databases().patch( # noqa # pylint: disable=no-member - project=project_id, - instance=instance, - database=database, - body=body - ).execute(num_retries=self.num_retries) + response = ( + self.get_conn() + .databases() + .patch( # noqa # pylint: disable=no-member + project=project_id, instance=instance, database=database, body=body + ) + .execute(num_retries=self.num_retries) + ) operation_name = response["name"] - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id @GoogleBaseHook.operation_in_progress_retry() @@ -293,14 +293,16 @@ def delete_database(self, instance: str, database: str, project_id: str) -> None :type project_id: str :return: None """ - response = self.get_conn().databases().delete( # noqa # pylint: disable=no-member - project=project_id, - instance=instance, - database=database - ).execute(num_retries=self.num_retries) + response = ( + self.get_conn() + .databases() + .delete( # noqa # pylint: disable=no-member + project=project_id, instance=instance, database=database + ) + .execute(num_retries=self.num_retries) + ) operation_name = response["name"] - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id @GoogleBaseHook.operation_in_progress_retry() @@ -320,14 +322,14 @@ def export_instance(self, instance: str, body: Dict, project_id: str) -> None: :type project_id: str :return: None """ - response = self.get_conn().instances().export( # noqa # pylint: disable=no-member - project=project_id, - instance=instance, - body=body - ).execute(num_retries=self.num_retries) + response = ( + self.get_conn() + .instances() + .export(project=project_id, instance=instance, body=body) # noqa # pylint: disable=no-member + .execute(num_retries=self.num_retries) + ) operation_name = response["name"] - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id def import_instance(self, instance: str, body: Dict, project_id: str) -> None: @@ -347,18 +349,16 @@ def import_instance(self, instance: str, body: Dict, project_id: str) -> None: :return: None """ try: - response = self.get_conn().instances().import_( # noqa # pylint: disable=no-member - project=project_id, - instance=instance, - body=body - ).execute(num_retries=self.num_retries) + response = ( + self.get_conn() + .instances() + .import_(project=project_id, instance=instance, body=body) # noqa # pylint: disable=no-member + .execute(num_retries=self.num_retries) + ) operation_name = response["name"] - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) except HttpError as ex: - raise AirflowException( - 'Importing instance {} failed: {}'.format(instance, ex.content) - ) + raise AirflowException('Importing instance {} failed: {}'.format(instance, ex.content)) def _wait_for_operation_to_complete(self, project_id: str, operation_name: str) -> None: """ @@ -373,10 
+373,11 @@ def _wait_for_operation_to_complete(self, project_id: str, operation_name: str) """ service = self.get_conn() while True: - operation_response = service.operations().get( # noqa # pylint: disable=no-member - project=project_id, - operation=operation_name, - ).execute(num_retries=self.num_retries) + operation_response = ( + service.operations() + .get(project=project_id, operation=operation_name,) # noqa # pylint: disable=no-member + .execute(num_retries=self.num_retries) + ) if operation_response.get("status") == CloudSqlOperationStatus.DONE: error = operation_response.get("error") if error: @@ -389,8 +390,9 @@ def _wait_for_operation_to_complete(self, project_id: str, operation_name: str) CLOUD_SQL_PROXY_DOWNLOAD_URL = "https://dl.google.com/cloudsql/cloud_sql_proxy.{}.{}" -CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL = \ +CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL = ( "https://storage.googleapis.com/cloudsql-proxy/{}/cloud_sql_proxy.{}.{}" +) GCP_CREDENTIALS_KEY_PATH = "extra__google_cloud_platform__key_path" GCP_CREDENTIALS_KEYFILE_DICT = "extra__google_cloud_platform__keyfile_dict" @@ -440,7 +442,7 @@ def __init__( gcp_conn_id: str = 'google_cloud_default', project_id: Optional[str] = None, sql_proxy_version: Optional[str] = None, - sql_proxy_binary_path: Optional[str] = None + sql_proxy_binary_path: Optional[str] = None, ) -> None: super().__init__() self.path_prefix = path_prefix @@ -455,16 +457,15 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.command_line_parameters = [] # type: List[str] self.cloud_sql_proxy_socket_directory = self.path_prefix - self.sql_proxy_path = sql_proxy_binary_path if sql_proxy_binary_path \ - else self.path_prefix + "_cloud_sql_proxy" + self.sql_proxy_path = ( + sql_proxy_binary_path if sql_proxy_binary_path else self.path_prefix + "_cloud_sql_proxy" + ) self.credentials_path = self.path_prefix + "_credentials.json" self._build_command_line_parameters() def _build_command_line_parameters(self) -> None: - self.command_line_parameters.extend( - ['-dir', self.cloud_sql_proxy_socket_directory]) - self.command_line_parameters.extend( - ['-instances', self.instance_specification]) + self.command_line_parameters.extend(['-dir', self.cloud_sql_proxy_socket_directory]) + self.command_line_parameters.extend(['-instances', self.instance_specification]) @staticmethod def _is_os_64bit() -> bool: @@ -480,10 +481,10 @@ def _download_sql_proxy_if_needed(self) -> None: download_url = CLOUD_SQL_PROXY_DOWNLOAD_URL.format(system, processor) else: download_url = CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL.format( - self.sql_proxy_version, system, processor) + self.sql_proxy_version, system, processor + ) proxy_path_tmp = self.sql_proxy_path + ".tmp" - self.log.info("Downloading cloud_sql_proxy from %s to %s", - download_url, proxy_path_tmp) + self.log.info("Downloading cloud_sql_proxy from %s to %s", download_url, proxy_path_tmp) response = requests.get(download_url, allow_redirects=True) # Downloading to .tmp file first to avoid case where partially downloaded # binary is used by parallel operator which uses the same fixed binary path @@ -492,51 +493,46 @@ def _download_sql_proxy_if_needed(self) -> None: if response.status_code != 200: raise AirflowException( "The cloud-sql-proxy could not be downloaded. Status code = {}. 
" - "Reason = {}".format(response.status_code, response.reason)) + "Reason = {}".format(response.status_code, response.reason) + ) - self.log.info("Moving sql_proxy binary from %s to %s", - proxy_path_tmp, self.sql_proxy_path) + self.log.info("Moving sql_proxy binary from %s to %s", proxy_path_tmp, self.sql_proxy_path) shutil.move(proxy_path_tmp, self.sql_proxy_path) os.chmod(self.sql_proxy_path, 0o744) # Set executable bit self.sql_proxy_was_downloaded = True @provide_session def _get_credential_parameters(self, session: Session) -> List[str]: - connection = session.query(Connection). \ - filter(Connection.conn_id == self.gcp_conn_id).first() + connection = session.query(Connection).filter(Connection.conn_id == self.gcp_conn_id).first() session.expunge_all() if connection.extra_dejson.get(GCP_CREDENTIALS_KEY_PATH): - credential_params = [ - '-credential_file', - connection.extra_dejson[GCP_CREDENTIALS_KEY_PATH] - ] + credential_params = ['-credential_file', connection.extra_dejson[GCP_CREDENTIALS_KEY_PATH]] elif connection.extra_dejson.get(GCP_CREDENTIALS_KEYFILE_DICT): - credential_file_content = json.loads( - connection.extra_dejson[GCP_CREDENTIALS_KEYFILE_DICT]) + credential_file_content = json.loads(connection.extra_dejson[GCP_CREDENTIALS_KEYFILE_DICT]) self.log.info("Saving credentials to %s", self.credentials_path) with open(self.credentials_path, "w") as file: json.dump(credential_file_content, file) - credential_params = [ - '-credential_file', - self.credentials_path - ] + credential_params = ['-credential_file', self.credentials_path] else: self.log.info( "The credentials are not supplied by neither key_path nor " "keyfile_dict of the gcp connection %s. Falling back to " - "default activated account", self.gcp_conn_id) + "default activated account", + self.gcp_conn_id, + ) credential_params = [] if not self.instance_specification: - project_id = connection.extra_dejson.get( - 'extra__google_cloud_platform__project') + project_id = connection.extra_dejson.get('extra__google_cloud_platform__project') if self.project_id: project_id = self.project_id if not project_id: - raise AirflowException("For forwarding all instances, the project id " - "for GCP should be provided either " - "by project_id extra in the GCP connection or by " - "project_id provided in the operator.") + raise AirflowException( + "For forwarding all instances, the project id " + "for GCP should be provided either " + "by project_id extra in the GCP connection or by " + "project_id provided in the operator." 
+ ) credential_params.extend(['-projects', project_id]) return credential_params @@ -548,8 +544,7 @@ def start_proxy(self) -> None: """ self._download_sql_proxy_if_needed() if self.sql_proxy_process: - raise AirflowException("The sql proxy is already running: {}".format( - self.sql_proxy_process)) + raise AirflowException("The sql proxy is already running: {}".format(self.sql_proxy_process)) else: command_to_run = [self.sql_proxy_path] command_to_run.extend(self.command_line_parameters) @@ -557,25 +552,25 @@ def start_proxy(self) -> None: Path(self.cloud_sql_proxy_socket_directory).mkdir(parents=True, exist_ok=True) command_to_run.extend(self._get_credential_parameters()) # pylint: disable=no-value-for-parameter self.log.info("Running the command: `%s`", " ".join(command_to_run)) - self.sql_proxy_process = Popen(command_to_run, - stdin=PIPE, stdout=PIPE, stderr=PIPE) + self.sql_proxy_process = Popen(command_to_run, stdin=PIPE, stdout=PIPE, stderr=PIPE) self.log.info("The pid of cloud_sql_proxy: %s", self.sql_proxy_process.pid) while True: - line = self.sql_proxy_process.stderr.readline().decode('utf-8') \ - if self.sql_proxy_process.stderr else "" + line = ( + self.sql_proxy_process.stderr.readline().decode('utf-8') + if self.sql_proxy_process.stderr + else "" + ) return_code = self.sql_proxy_process.poll() if line == '' and return_code is not None: self.sql_proxy_process = None raise AirflowException( - "The cloud_sql_proxy finished early with return code {}!".format( - return_code)) + "The cloud_sql_proxy finished early with return code {}!".format(return_code) + ) if line != '': self.log.info(line) if "googleapi: Error" in line or "invalid instance name:" in line: self.stop_proxy() - raise AirflowException( - "Error when starting the cloud_sql_proxy {}!".format( - line)) + raise AirflowException("Error when starting the cloud_sql_proxy {}!".format(line)) if "Ready for new connections" in line: return @@ -588,13 +583,11 @@ def stop_proxy(self) -> None: if not self.sql_proxy_process: raise AirflowException("The sql proxy is not started yet") else: - self.log.info("Stopping the cloud_sql_proxy pid: %s", - self.sql_proxy_process.pid) + self.log.info("Stopping the cloud_sql_proxy pid: %s", self.sql_proxy_process.pid) self.sql_proxy_process.kill() self.sql_proxy_process = None # Cleanup! 
- self.log.info("Removing the socket directory: %s", - self.cloud_sql_proxy_socket_directory) + self.log.info("Removing the socket directory: %s", self.cloud_sql_proxy_socket_directory) shutil.rmtree(self.cloud_sql_proxy_socket_directory, ignore_errors=True) if self.sql_proxy_was_downloaded: self.log.info("Removing downloaded proxy: %s", self.sql_proxy_path) @@ -605,11 +598,9 @@ def stop_proxy(self) -> None: if e.errno != errno.ENOENT: raise else: - self.log.info("Skipped removing proxy - it was not downloaded: %s", - self.sql_proxy_path) + self.log.info("Skipped removing proxy - it was not downloaded: %s", self.sql_proxy_path) if os.path.isfile(self.credentials_path): - self.log.info("Removing generated credentials file %s", - self.credentials_path) + self.log.info("Removing generated credentials file %s", self.credentials_path) # Here file cannot be delete by concurrent task (each task has its own copy) os.remove(self.credentials_path) @@ -642,38 +633,28 @@ def get_socket_path(self) -> str: CONNECTION_URIS = { "postgres": { "proxy": { - "tcp": - "postgresql://{user}:{password}@127.0.0.1:{proxy_port}/{database}", - "socket": - "postgresql://{user}:{password}@{socket_path}/{database}" + "tcp": "postgresql://{user}:{password}@127.0.0.1:{proxy_port}/{database}", + "socket": "postgresql://{user}:{password}@{socket_path}/{database}", }, "public": { - "ssl": - "postgresql://{user}:{password}@{public_ip}:{public_port}/{database}?" - "sslmode=verify-ca&" - "sslcert={client_cert_file}&" - "sslkey={client_key_file}&" - "sslrootcert={server_ca_file}", - "non-ssl": - "postgresql://{user}:{password}@{public_ip}:{public_port}/{database}" - } + "ssl": "postgresql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "sslmode=verify-ca&" + "sslcert={client_cert_file}&" + "sslkey={client_key_file}&" + "sslrootcert={server_ca_file}", + "non-ssl": "postgresql://{user}:{password}@{public_ip}:{public_port}/{database}", + }, }, "mysql": { "proxy": { - "tcp": - "mysql://{user}:{password}@127.0.0.1:{proxy_port}/{database}", - "socket": - "mysql://{user}:{password}@localhost/{database}?" - "unix_socket={socket_path}" + "tcp": "mysql://{user}:{password}@127.0.0.1:{proxy_port}/{database}", + "socket": "mysql://{user}:{password}@localhost/{database}?" "unix_socket={socket_path}", }, "public": { - "ssl": - "mysql://{user}:{password}@{public_ip}:{public_port}/{database}?" - "ssl={ssl_spec}", - "non-ssl": - "mysql://{user}:{password}@{public_ip}:{public_port}/{database}" - } - } + "ssl": "mysql://{user}:{password}@{public_ip}:{public_port}/{database}?" "ssl={ssl_spec}", + "non-ssl": "mysql://{user}:{password}@{public_ip}:{public_port}/{database}", + }, + }, } # type: Dict[str, Dict[str, Dict[str, str]]] CLOUD_SQL_VALID_DATABASE_TYPES = ['postgres', 'mysql'] @@ -736,7 +717,7 @@ def __init__( self, gcp_cloudsql_conn_id: str = 'google_cloud_sql_default', gcp_conn_id: str = 'google_cloud_default', - default_gcp_project_id: Optional[str] = None + default_gcp_project_id: Optional[str] = None, ) -> None: super().__init__() self.gcp_conn_id = gcp_conn_id @@ -779,11 +760,11 @@ def _get_bool(val: Any) -> bool: @staticmethod def _check_ssl_file(file_to_check, name) -> None: if not file_to_check: - raise AirflowException("SSL connections requires {name} to be set". - format(name=name)) + raise AirflowException("SSL connections requires {name} to be set".format(name=name)) if not os.path.isfile(file_to_check): - raise AirflowException("The {file_to_check} must be a readable file". 
- format(file_to_check=file_to_check)) + raise AirflowException( + "The {file_to_check} must be a readable file".format(file_to_check=file_to_check) + ) def _validate_inputs(self) -> None: if self.project_id == '': @@ -793,13 +774,17 @@ def _validate_inputs(self) -> None: if not self.instance: raise AirflowException("The required extra 'instance' is empty or None") if self.database_type not in CLOUD_SQL_VALID_DATABASE_TYPES: - raise AirflowException("Invalid database type '{}'. Must be one of {}".format( - self.database_type, CLOUD_SQL_VALID_DATABASE_TYPES - )) + raise AirflowException( + "Invalid database type '{}'. Must be one of {}".format( + self.database_type, CLOUD_SQL_VALID_DATABASE_TYPES + ) + ) if self.use_proxy and self.use_ssl: - raise AirflowException("Cloud SQL Proxy does not support SSL connections." - " SSL is not needed as Cloud SQL Proxy " - "provides encryption on its own") + raise AirflowException( + "Cloud SQL Proxy does not support SSL connections." + " SSL is not needed as Cloud SQL Proxy " + "provides encryption on its own" + ) def validate_ssl_certs(self) -> None: """ @@ -824,9 +809,8 @@ def validate_socket_path_length(self) -> None: else: suffix = "" expected_path = "{}/{}:{}:{}{}".format( - self._generate_unique_path(), - self.project_id, self.instance, - self.database, suffix) + self._generate_unique_path(), self.project_id, self.instance, self.database, suffix + ) if len(expected_path) > UNIX_PATH_MAX: self.log.info("Too long (%s) path: %s", len(expected_path), expected_path) raise AirflowException( @@ -834,8 +818,8 @@ def validate_socket_path_length(self) -> None: "on Linux system. Either use shorter instance/database " "name or switch to TCP connection. " "The socket path for Cloud SQL proxy is now:" - "{}".format( - UNIX_PATH_MAX, expected_path)) + "{}".format(UNIX_PATH_MAX, expected_path) + ) @staticmethod def _generate_unique_path() -> str: @@ -849,7 +833,8 @@ def _generate_unique_path() -> str: random.seed() while True: candidate = "/tmp/" + ''.join( - random.choice(string.ascii_lowercase + string.digits) for _ in range(8)) + random.choice(string.ascii_lowercase + string.digits) for _ in range(8) + ) if not os.path.exists(candidate): return candidate @@ -876,20 +861,15 @@ def _generate_connection_uri(self) -> str: format_string = proxy_uris['tcp'] else: format_string = proxy_uris['socket'] - socket_path = \ - "{sql_proxy_socket_path}/{instance_socket_name}".format( - sql_proxy_socket_path=self.sql_proxy_unique_path, - instance_socket_name=self._get_instance_socket_name() - ) + socket_path = "{sql_proxy_socket_path}/{instance_socket_name}".format( + sql_proxy_socket_path=self.sql_proxy_unique_path, + instance_socket_name=self._get_instance_socket_name(), + ) else: public_uris = database_uris['public'] # type: Dict[str, str] if self.use_ssl: format_string = public_uris['ssl'] - ssl_spec = { - 'cert': self.sslcert, - 'key': self.sslkey, - 'ca': self.sslrootcert - } + ssl_spec = {'cert': self.sslcert, 'key': self.sslkey, 'ca': self.sslrootcert} else: format_string = public_uris['non-ssl'] if not self.user: @@ -912,10 +892,14 @@ def _generate_connection_uri(self) -> str: ssl_spec=self._quote(json.dumps(ssl_spec)) if ssl_spec else '', client_cert_file=self._quote(self.sslcert) if self.sslcert else '', client_key_file=self._quote(self.sslkey) if self.sslcert else '', - server_ca_file=self._quote(self.sslrootcert if self.sslcert else '') + server_ca_file=self._quote(self.sslrootcert if self.sslcert else ''), + ) + self.log.info( + "DB connection URI %s", + 
connection_uri.replace( + quote_plus(self.password) if self.password else 'PASSWORD', 'XXXXXXXXXXXX' + ), ) - self.log.info("DB connection URI %s", connection_uri.replace( - quote_plus(self.password) if self.password else 'PASSWORD', 'XXXXXXXXXXXX')) return connection_uri def _get_instance_socket_name(self) -> str: @@ -955,7 +939,7 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner: project_id=self.project_id, sql_proxy_version=self.sql_proxy_version, sql_proxy_binary_path=self.sql_proxy_binary_path, - gcp_conn_id=self.gcp_conn_id + gcp_conn_id=self.gcp_conn_id, ) def get_database_hook(self, connection: Connection) -> Union[PostgresHook, MySqlHook]: diff --git a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py index 10ef7ab00958b..c97e94b4e37a0 100644 --- a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py @@ -43,6 +43,7 @@ class GcpTransferJobsStatus: """ Class with GCP Transfer jobs statuses. """ + ENABLED = "ENABLED" DISABLED = "DISABLED" DELETED = "DELETED" @@ -52,6 +53,7 @@ class GcpTransferOperationStatus: """ Class with GCP Transfer operations statuses. """ + IN_PROGRESS = "IN_PROGRESS" PAUSED = "PAUSED" SUCCESS = "SUCCESS" @@ -125,6 +127,7 @@ class CloudDataTransferServiceHook(GoogleBaseHook): All the methods in the hook where project_id is used must be called with keyword arguments rather than positional. """ + def __init__( self, api_version: str = 'v1', @@ -133,9 +136,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version self._conn = None @@ -168,32 +169,32 @@ def create_transfer_job(self, body: Dict) -> Dict: """ body = self._inject_project_id(body, BODY, PROJECT_ID) try: - transfer_job = self.get_conn().transferJobs()\ - .create(body=body).execute( # pylint: disable=no-member - num_retries=self.num_retries) + # pylint: disable=no-member + transfer_job = ( + self.get_conn().transferJobs().create(body=body).execute(num_retries=self.num_retries) + ) except HttpError as e: # If status code "Conflict" # https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Code.ENUM_VALUES.ALREADY_EXISTS # we should try to find this job job_name = body.get(JOB_NAME, "") if int(e.resp.status) == ALREADY_EXIST_CODE and job_name: - transfer_job = self.get_transfer_job( - job_name=job_name, project_id=body.get(PROJECT_ID)) + transfer_job = self.get_transfer_job(job_name=job_name, project_id=body.get(PROJECT_ID)) # Generate new job_name, if jobs status is deleted # and try to create this job again if transfer_job.get(STATUS) == GcpTransferJobsStatus.DELETED: body[JOB_NAME] = gen_job_name(job_name) self.log.info( - "Job `%s` has been soft deleted. Creating job with " - "new name `%s`", job_name, {body[JOB_NAME]}) + "Job `%s` has been soft deleted. 
Creating job with " "new name `%s`", + job_name, + {body[JOB_NAME]}, + ) # pylint: disable=no-member - return self.get_conn()\ - .transferJobs()\ - .create(body=body)\ - .execute(num_retries=self.num_retries) + return ( + self.get_conn().transferJobs().create(body=body).execute(num_retries=self.num_retries) + ) elif transfer_job.get(STATUS) == GcpTransferJobsStatus.DISABLED: - return self.enable_transfer_job( - job_name=job_name, project_id=body.get(PROJECT_ID)) + return self.enable_transfer_job(job_name=job_name, project_id=body.get(PROJECT_ID)) else: raise e return transfer_job @@ -251,8 +252,8 @@ def list_transfer_job(self, request_filter: Optional[Dict] = None, **kwargs) -> response = request.execute(num_retries=self.num_retries) jobs.extend(response[TRANSFER_JOBS]) - request = conn.transferJobs().list_next(previous_request=request, # pylint: disable=no-member - previous_response=response) + # pylint: disable=no-member + request = conn.transferJobs().list_next(previous_request=request, previous_response=response) return jobs @@ -272,8 +273,8 @@ def enable_transfer_job(self, job_name: str, project_id: str) -> Dict: """ return ( self.get_conn() # pylint: disable=no-member - .transferJobs() - .patch( + .transferJobs() + .patch( jobName=job_name, body={ PROJECT_ID: project_id, @@ -343,8 +344,9 @@ def cancel_transfer_operation(self, operation_name: str) -> None: :rtype: None """ - self.get_conn().transferOperations().cancel( # pylint: disable=no-member - name=operation_name).execute(num_retries=self.num_retries) + self.get_conn().transferOperations().cancel(name=operation_name).execute( # pylint: disable=no-member + num_retries=self.num_retries + ) def get_transfer_operation(self, operation_name: str) -> Dict: """ @@ -400,7 +402,8 @@ def list_transfer_operations(self, request_filter: Optional[Dict] = None, **kwar operations = [] # type: List[Dict] request = conn.transferOperations().list( # pylint: disable=no-member - name=TRANSFER_OPERATIONS, filter=json.dumps(request_filter)) + name=TRANSFER_OPERATIONS, filter=json.dumps(request_filter) + ) while request is not None: response = request.execute(num_retries=self.num_retries) @@ -421,8 +424,9 @@ def pause_transfer_operation(self, operation_name: str) -> None: :type operation_name: str :rtype: None """ - self.get_conn().transferOperations().pause( # pylint: disable=no-member - name=operation_name).execute(num_retries=self.num_retries) + self.get_conn().transferOperations().pause(name=operation_name).execute( # pylint: disable=no-member + num_retries=self.num_retries + ) def resume_transfer_operation(self, operation_name: str) -> None: """ @@ -432,14 +436,15 @@ def resume_transfer_operation(self, operation_name: str) -> None: :type operation_name: str :rtype: None """ - self.get_conn().transferOperations().resume( # pylint: disable=no-member - name=operation_name).execute(num_retries=self.num_retries) + self.get_conn().transferOperations().resume(name=operation_name).execute( # pylint: disable=no-member + num_retries=self.num_retries + ) def wait_for_transfer_job( self, job: Dict, expected_statuses: Optional[Set[str]] = None, - timeout: Optional[Union[float, timedelta]] = None + timeout: Optional[Union[float, timedelta]] = None, ) -> None: """ Waits until the job reaches the expected state. 
@@ -471,8 +476,9 @@ def wait_for_transfer_job( request_filter={FILTER_PROJECT_ID: job[PROJECT_ID], FILTER_JOB_NAMES: [job[NAME]]} ) - if CloudDataTransferServiceHook.\ - operations_contain_expected_statuses(operations, expected_statuses): + if CloudDataTransferServiceHook.operations_contain_expected_statuses( + operations, expected_statuses + ): return time.sleep(TIME_TO_SLEEP_IN_SECONDS) raise AirflowException("Timeout. The operation could not be completed within the allotted time.") @@ -489,8 +495,7 @@ def _inject_project_id(self, body: Dict, param_name: str, target_key: str) -> Di @staticmethod def operations_contain_expected_statuses( - operations: List[Dict], - expected_statuses: Union[Set[str], str] + operations: List[Dict], expected_statuses: Union[Set[str], str] ) -> bool: """ Checks whether the operation list has an operation with the diff --git a/airflow/providers/google/cloud/hooks/compute.py b/airflow/providers/google/cloud/hooks/compute.py index 2c9b06c6560b9..344eb2989f4f4 100644 --- a/airflow/providers/google/cloud/hooks/compute.py +++ b/airflow/providers/google/cloud/hooks/compute.py @@ -35,6 +35,7 @@ class GceOperationStatus: """ Class with GCE operations statuses. """ + PENDING = "PENDING" RUNNING = "RUNNING" DONE = "DONE" @@ -47,6 +48,7 @@ class ComputeEngineHook(GoogleBaseHook): All the methods in the hook where project_id is used must be called with keyword arguments rather than positional. """ + _conn = None # type: Optional[Any] def __init__( @@ -57,9 +59,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version @@ -72,8 +72,7 @@ def get_conn(self): """ if not self._conn: http_authorized = self._authorize() - self._conn = build('compute', self.api_version, - http=http_authorized, cache_discovery=False) + self._conn = build('compute', self.api_version, http=http_authorized, cache_discovery=False) return self._conn @GoogleBaseHook.fallback_to_default_project_id @@ -92,20 +91,20 @@ def start_instance(self, zone: str, resource_id: str, project_id: str) -> None: :type project_id: str :return: None """ - response = self.get_conn().instances().start( # noqa pylint: disable=no-member - project=project_id, - zone=zone, - instance=resource_id - ).execute(num_retries=self.num_retries) + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instances() + .start(project=project_id, zone=zone, instance=resource_id) + .execute(num_retries=self.num_retries) + ) try: operation_name = response["name"] except KeyError: raise AirflowException( - "Wrong response '{}' returned - it should contain " - "'name' field".format(response)) - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name, - zone=zone) + "Wrong response '{}' returned - it should contain " "'name' field".format(response) + ) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name, zone=zone) @GoogleBaseHook.fallback_to_default_project_id def stop_instance(self, zone: str, resource_id: str, project_id: str) -> None: @@ -123,29 +122,23 @@ def stop_instance(self, zone: str, resource_id: str, project_id: str) -> None: :type project_id: str :return: None """ - response = self.get_conn().instances().stop( # noqa pylint: disable=no-member - project=project_id, - 
zone=zone, - instance=resource_id - ).execute(num_retries=self.num_retries) + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instances() + .stop(project=project_id, zone=zone, instance=resource_id) + .execute(num_retries=self.num_retries) + ) try: operation_name = response["name"] except KeyError: raise AirflowException( - "Wrong response '{}' returned - it should contain " - "'name' field".format(response)) - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name, - zone=zone) + "Wrong response '{}' returned - it should contain " "'name' field".format(response) + ) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name, zone=zone) @GoogleBaseHook.fallback_to_default_project_id - def set_machine_type( - self, - zone: str, - resource_id: str, - body: Dict, - project_id: str - ) -> None: + def set_machine_type(self, zone: str, resource_id: str, body: Dict, project_id: str) -> None: """ Sets machine type of an instance defined by project_id, zone and resource_id. Must be called with keyword arguments rather than positional. @@ -169,22 +162,18 @@ def set_machine_type( operation_name = response["name"] except KeyError: raise AirflowException( - "Wrong response '{}' returned - it should contain " - "'name' field".format(response)) - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name, - zone=zone) + "Wrong response '{}' returned - it should contain " "'name' field".format(response) + ) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name, zone=zone) - def _execute_set_machine_type( - self, - zone: str, - resource_id: str, - body: Dict, - project_id: str - ) -> Dict: - return self.get_conn().instances().setMachineType( # noqa pylint: disable=no-member - project=project_id, zone=zone, instance=resource_id, body=body)\ + def _execute_set_machine_type(self, zone: str, resource_id: str, body: Dict, project_id: str) -> Dict: + # noqa pylint: disable=no-member + return ( + self.get_conn() + .instances() + .setMachineType(project=project_id, zone=zone, instance=resource_id, body=body) .execute(num_retries=self.num_retries) + ) @GoogleBaseHook.fallback_to_default_project_id def get_instance_template(self, resource_id: str, project_id: str) -> Dict: @@ -202,18 +191,18 @@ def get_instance_template(self, resource_id: str, project_id: str) -> Dict: https://cloud.google.com/compute/docs/reference/rest/v1/instanceTemplates :rtype: dict """ - response = self.get_conn().instanceTemplates().get( # noqa pylint: disable=no-member - project=project_id, - instanceTemplate=resource_id - ).execute(num_retries=self.num_retries) + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instanceTemplates() + .get(project=project_id, instanceTemplate=resource_id) + .execute(num_retries=self.num_retries) + ) return response @GoogleBaseHook.fallback_to_default_project_id def insert_instance_template( - self, - body: Dict, - project_id: str, - request_id: Optional[str] = None, + self, body: Dict, project_id: str, request_id: Optional[str] = None, ) -> None: """ Inserts instance template using body specified @@ -233,27 +222,23 @@ def insert_instance_template( :type project_id: str :return: None """ - response = self.get_conn().instanceTemplates().insert( # noqa pylint: disable=no-member - project=project_id, - body=body, - requestId=request_id - ).execute(num_retries=self.num_retries) + # noqa pylint: disable=no-member + response = ( + 
self.get_conn() + .instanceTemplates() + .insert(project=project_id, body=body, requestId=request_id) + .execute(num_retries=self.num_retries) + ) try: operation_name = response["name"] except KeyError: raise AirflowException( - "Wrong response '{}' returned - it should contain " - "'name' field".format(response)) - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name) + "Wrong response '{}' returned - it should contain " "'name' field".format(response) + ) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id - def get_instance_group_manager( - self, - zone: str, - resource_id: str, - project_id: str, - ) -> Dict: + def get_instance_group_manager(self, zone: str, resource_id: str, project_id: str,) -> Dict: """ Retrieves Instance Group Manager by project_id, zone and resource_id. Must be called with keyword arguments rather than positional. @@ -270,21 +255,18 @@ def get_instance_group_manager( https://cloud.google.com/compute/docs/reference/rest/beta/instanceGroupManagers :rtype: dict """ - response = self.get_conn().instanceGroupManagers().get( # noqa pylint: disable=no-member - project=project_id, - zone=zone, - instanceGroupManager=resource_id - ).execute(num_retries=self.num_retries) + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instanceGroupManagers() + .get(project=project_id, zone=zone, instanceGroupManager=resource_id) + .execute(num_retries=self.num_retries) + ) return response @GoogleBaseHook.fallback_to_default_project_id def patch_instance_group_manager( - self, - zone: str, - resource_id: str, - body: Dict, - project_id: str, - request_id: Optional[str] = None, + self, zone: str, resource_id: str, body: Dict, project_id: str, request_id: Optional[str] = None, ) -> None: """ Patches Instance Group Manager with the specified body. @@ -309,28 +291,29 @@ def patch_instance_group_manager( :type project_id: str :return: None """ - response = self.get_conn().instanceGroupManagers().patch( # noqa pylint: disable=no-member - project=project_id, - zone=zone, - instanceGroupManager=resource_id, - body=body, - requestId=request_id - ).execute(num_retries=self.num_retries) + # noqa pylint: disable=no-member + response = ( + self.get_conn() + .instanceGroupManagers() + .patch( + project=project_id, + zone=zone, + instanceGroupManager=resource_id, + body=body, + requestId=request_id, + ) + .execute(num_retries=self.num_retries) + ) try: operation_name = response["name"] except KeyError: raise AirflowException( - "Wrong response '{}' returned - it should contain " - "'name' field".format(response)) - self._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name, - zone=zone) + "Wrong response '{}' returned - it should contain " "'name' field".format(response) + ) + self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name, zone=zone) def _wait_for_operation_to_complete( - self, - project_id: str, - operation_name: str, - zone: Optional[str] = None + self, project_id: str, operation_name: str, zone: Optional[str] = None ) -> None: """ Waits for the named operation to complete - checks status of the async call. 
@@ -348,11 +331,12 @@ def _wait_for_operation_to_complete( service=service, operation_name=operation_name, project_id=project_id, - num_retries=self.num_retries + num_retries=self.num_retries, ) else: operation_response = self._check_zone_operation_status( - service, operation_name, project_id, zone, self.num_retries) + service, operation_name, project_id, zone, self.num_retries + ) if operation_response.get("status") == GceOperationStatus.DONE: error = operation_response.get("error") if error: @@ -366,23 +350,20 @@ def _wait_for_operation_to_complete( @staticmethod def _check_zone_operation_status( - service: Any, - operation_name: str, - project_id: str, - zone: str, - num_retries: int + service: Any, operation_name: str, project_id: str, zone: str, num_retries: int ) -> Dict: - return service.zoneOperations().get( - project=project_id, zone=zone, operation=operation_name).execute( - num_retries=num_retries) + return ( + service.zoneOperations() + .get(project=project_id, zone=zone, operation=operation_name) + .execute(num_retries=num_retries) + ) @staticmethod def _check_global_operation_status( - service: Any, - operation_name: str, - project_id: str, - num_retries: int + service: Any, operation_name: str, project_id: str, num_retries: int ) -> Dict: - return service.globalOperations().get( - project=project_id, operation=operation_name).execute( - num_retries=num_retries) + return ( + service.globalOperations() + .get(project=project_id, operation=operation_name) + .execute(num_retries=num_retries) + ) diff --git a/airflow/providers/google/cloud/hooks/datacatalog.py b/airflow/providers/google/cloud/hooks/datacatalog.py index 4b16aae00cd39..d2e5b31191e03 100644 --- a/airflow/providers/google/cloud/hooks/datacatalog.py +++ b/airflow/providers/google/cloud/hooks/datacatalog.py @@ -20,7 +20,13 @@ from google.api_core.retry import Retry from google.cloud.datacatalog_v1beta1 import DataCatalogClient from google.cloud.datacatalog_v1beta1.types import ( - Entry, EntryGroup, FieldMask, SearchCatalogRequest, Tag, TagTemplate, TagTemplateField, + Entry, + EntryGroup, + FieldMask, + SearchCatalogRequest, + Tag, + TagTemplate, + TagTemplateField, ) from airflow import AirflowException @@ -55,9 +61,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client: Optional[DataCatalogClient] = None @@ -67,8 +71,7 @@ def get_conn(self) -> DataCatalogClient: """ if not self._client: self._client = DataCatalogClient( - credentials=self._get_credentials(), - client_info=self.client_info + credentials=self._get_credentials(), client_info=self.client_info ) return self._client @@ -937,7 +940,10 @@ def search_catalog( self.log.info( "Searching catalog: scope=%s, query=%s, page_size=%s, order_by=%s", - scope, query, page_size, order_by + scope, + query, + page_size, + order_by, ) result = client.search_catalog( scope=scope, diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index 569de11c1f5a1..be02ee0054a7a 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -52,7 +52,6 @@ def _fallback_variable_parameter(parameter_name: str, variable_key_name: str) -> Callable[[T], T]: - def _wrapper(func: T) -> T: """ Decorator that provides 
fallback for location from `region` key in `variables` parameters. @@ -60,11 +59,13 @@ def _wrapper(func: T) -> T: :param func: function to wrap :return: result of the function call """ + @functools.wraps(func) def inner_wrapper(self: "DataflowHook", *args, **kwargs): if args: raise AirflowException( - "You must use keyword arguments in this methods rather than positional") + "You must use keyword arguments in this methods rather than positional" + ) parameter_location = kwargs.get(parameter_name) variables_location = kwargs.get('variables', {}).get(variable_key_name) @@ -82,6 +83,7 @@ def inner_wrapper(self: "DataflowHook", *args, **kwargs): kwargs['variables'] = copy_variables return func(self, *args, **kwargs) + return cast(T, inner_wrapper) return _wrapper @@ -95,6 +97,7 @@ class DataflowJobStatus: """ Helper class with Dataflow job statuses. """ + JOB_STATE_DONE = "JOB_STATE_DONE" JOB_STATE_RUNNING = "JOB_STATE_RUNNING" JOB_STATE_FAILED = "JOB_STATE_FAILED" @@ -109,6 +112,7 @@ class DataflowJobType: """ Helper class with Dataflow job types. """ + JOB_TYPE_UNKNOWN = "JOB_TYPE_UNKNOWN" JOB_TYPE_BATCH = "JOB_TYPE_BATCH" JOB_TYPE_STREAMING = "JOB_TYPE_STREAMING" @@ -130,6 +134,7 @@ class _DataflowJobsController(LoggingMixin): :param multiple_jobs: If set to true this task will be searched by name prefix (``name`` parameter), not by specific job ID, then actions will be performed on all matching jobs. """ + def __init__( self, dataflow: Any, @@ -139,7 +144,7 @@ def __init__( name: Optional[str] = None, job_id: Optional[str] = None, num_retries: int = 0, - multiple_jobs: bool = False + multiple_jobs: bool = False, ) -> None: super().__init__() @@ -188,35 +193,37 @@ def _get_current_jobs(self) -> List[Dict]: raise Exception('Missing both dataflow job ID and name.') def _fetch_job_by_id(self, job_id: str) -> Dict: - return self._dataflow.projects().locations().jobs().get( - projectId=self._project_number, - location=self._job_location, - jobId=job_id - ).execute(num_retries=self._num_retries) + return ( + self._dataflow.projects() + .locations() + .jobs() + .get(projectId=self._project_number, location=self._job_location, jobId=job_id) + .execute(num_retries=self._num_retries) + ) def _fetch_all_jobs(self) -> List[Dict]: - request = self._dataflow.projects().locations().jobs().list( - projectId=self._project_number, - location=self._job_location + request = ( + self._dataflow.projects() + .locations() + .jobs() + .list(projectId=self._project_number, location=self._job_location) ) jobs = [] # type: List[Dict] while request is not None: response = request.execute(num_retries=self._num_retries) jobs.extend(response["jobs"]) - request = self._dataflow.projects().locations().jobs().list_next( - previous_request=request, - previous_response=response + request = ( + self._dataflow.projects() + .locations() + .jobs() + .list_next(previous_request=request, previous_response=response) ) return jobs def _fetch_jobs_by_prefix_name(self, prefix_name: str) -> List[Dict]: jobs = self._fetch_all_jobs() - jobs = [ - job - for job in jobs - if job['name'].startswith(prefix_name) - ] + jobs = [job for job in jobs if job['name'].startswith(prefix_name)] return jobs def _refresh_jobs(self) -> None: @@ -230,13 +237,9 @@ def _refresh_jobs(self) -> None: if self._jobs: for job in self._jobs: - self.log.info( - 'Google Cloud DataFlow job %s is state: %s', job['name'], job['currentState'] - ) + self.log.info('Google Cloud DataFlow job %s is state: %s', job['name'], job['currentState']) else: - self.log.info( - 
'Google Cloud DataFlow job not available yet..' - ) + self.log.info('Google Cloud DataFlow job not available yet..') def _check_dataflow_job_state(self, job) -> bool: """ @@ -250,16 +253,18 @@ def _check_dataflow_job_state(self, job) -> bool: if DataflowJobStatus.JOB_STATE_DONE == job['currentState']: return True elif DataflowJobStatus.JOB_STATE_FAILED == job['currentState']: - raise Exception("Google Cloud Dataflow job {} has failed.".format( - job['name'])) + raise Exception("Google Cloud Dataflow job {} has failed.".format(job['name'])) elif DataflowJobStatus.JOB_STATE_CANCELLED == job['currentState']: - raise Exception("Google Cloud Dataflow job {} was cancelled.".format( - job['name'])) - elif DataflowJobStatus.JOB_STATE_RUNNING == job['currentState'] and \ - DataflowJobType.JOB_TYPE_STREAMING == job['type']: + raise Exception("Google Cloud Dataflow job {} was cancelled.".format(job['name'])) + elif ( + DataflowJobStatus.JOB_STATE_RUNNING == job['currentState'] + and DataflowJobType.JOB_TYPE_STREAMING == job['type'] + ): return True - elif job['currentState'] in {DataflowJobStatus.JOB_STATE_RUNNING, - DataflowJobStatus.JOB_STATE_PENDING}: + elif job['currentState'] in { + DataflowJobStatus.JOB_STATE_RUNNING, + DataflowJobStatus.JOB_STATE_PENDING, + }: return False self.log.debug("Current job: %s", str(job)) raise Exception( @@ -301,13 +306,14 @@ def cancel(self) -> None: self.log.info("Canceling jobs: %s", ", ".join(job_ids)) for job_id in job_ids: batch.add( - self._dataflow.projects().locations().jobs().update( + self._dataflow.projects() + .locations() + .jobs() + .update( projectId=self._project_number, location=self._job_location, jobId=job_id, - body={ - "requestedState": DataflowJobStatus.JOB_STATE_CANCELLED - } + body={"requestedState": DataflowJobStatus.JOB_STATE_CANCELLED}, ) ) batch.execute() @@ -315,20 +321,15 @@ def cancel(self) -> None: class _DataflowRunner(LoggingMixin): def __init__( - self, - cmd: List[str], - on_new_job_id_callback: Optional[Callable[[str], None]] = None + self, cmd: List[str], on_new_job_id_callback: Optional[Callable[[str], None]] = None ) -> None: super().__init__() self.log.info("Running command: %s", ' '.join(shlex.quote(c) for c in cmd)) self.on_new_job_id_callback = on_new_job_id_callback self.job_id: Optional[str] = None self._proc = subprocess.Popen( - cmd, - shell=False, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - close_fds=True) + cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True + ) def _process_fd(self, fd): """ @@ -421,9 +422,7 @@ def __init__( ) -> None: self.poll_sleep = poll_sleep super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) def get_conn(self): @@ -431,8 +430,7 @@ def get_conn(self): Returns a Google Cloud Dataflow service object. 
""" http_authorized = self._authorize() - return build( - 'dataflow', 'v1b3', http=http_authorized, cache_discovery=False) + return build('dataflow', 'v1b3', http=http_authorized, cache_discovery=False) @GoogleBaseHook.provide_gcp_credential_file def _start_dataflow( @@ -444,13 +442,10 @@ def _start_dataflow( project_id: str, multiple_jobs: bool = False, on_new_job_id_callback: Optional[Callable[[str], None]] = None, - location: str = DEFAULT_DATAFLOW_LOCATION + location: str = DEFAULT_DATAFLOW_LOCATION, ) -> None: cmd = command_prefix + self._build_cmd(variables, label_formatter, project_id) - runner = _DataflowRunner( - cmd=cmd, - on_new_job_id_callback=on_new_job_id_callback - ) + runner = _DataflowRunner(cmd=cmd, on_new_job_id_callback=on_new_job_id_callback) job_id = runner.wait_for_done() job_controller = _DataflowJobsController( dataflow=self.get_conn(), @@ -460,7 +455,7 @@ def _start_dataflow( poll_sleep=self.poll_sleep, job_id=job_id, num_retries=self.num_retries, - multiple_jobs=multiple_jobs + multiple_jobs=multiple_jobs, ) job_controller.wait_for_done() @@ -477,7 +472,7 @@ def start_java_dataflow( append_job_name: bool = True, multiple_jobs: bool = False, on_new_job_id_callback: Optional[Callable[[str], None]] = None, - location: str = DEFAULT_DATAFLOW_LOCATION + location: str = DEFAULT_DATAFLOW_LOCATION, ) -> None: """ Starts Dataflow java job. @@ -506,11 +501,9 @@ def start_java_dataflow( variables['region'] = location def label_formatter(labels_dict): - return ['--labels={}'.format( - json.dumps(labels_dict).replace(' ', ''))] + return ['--labels={}'.format(json.dumps(labels_dict).replace(' ', ''))] - command_prefix = (["java", "-cp", jar, job_class] if job_class - else ["java", "-jar", jar]) + command_prefix = ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar] self._start_dataflow( variables=variables, name=name, @@ -519,7 +512,7 @@ def label_formatter(labels_dict): project_id=project_id, multiple_jobs=multiple_jobs, on_new_job_id_callback=on_new_job_id_callback, - location=location + location=location, ) @_fallback_to_location_from_variables @@ -534,7 +527,7 @@ def start_template_dataflow( project_id: str, append_job_name: bool = True, on_new_job_id_callback: Optional[Callable[[str], None]] = None, - location: str = DEFAULT_DATAFLOW_LOCATION + location: str = DEFAULT_DATAFLOW_LOCATION, ) -> Dict: """ Starts Dataflow template job. 
@@ -565,15 +558,17 @@ def start_template_dataflow( name = self._build_dataflow_job_name(job_name, append_job_name) service = self.get_conn() - request = service.projects().locations().templates().launch( # pylint: disable=no-member - projectId=project_id, - location=location, - gcsPath=dataflow_template, - body={ - "jobName": name, - "parameters": parameters, - "environment": variables - } + # pylint: disable=no-member + request = ( + service.projects() + .locations() + .templates() + .launch( + projectId=project_id, + location=location, + gcsPath=dataflow_template, + body={"jobName": name, "parameters": parameters, "environment": variables}, + ) ) response = request.execute(num_retries=self.num_retries) @@ -588,7 +583,8 @@ def start_template_dataflow( job_id=job_id, location=location, poll_sleep=self.poll_sleep, - num_retries=self.num_retries) + num_retries=self.num_retries, + ) jobs_controller.wait_for_done() return response["job"] @@ -607,7 +603,7 @@ def start_python_dataflow( # pylint: disable=too-many-arguments py_system_site_packages: bool = False, append_job_name: bool = True, on_new_job_id_callback: Optional[Callable[[str], None]] = None, - location: str = DEFAULT_DATAFLOW_LOCATION + location: str = DEFAULT_DATAFLOW_LOCATION, ): """ Starts Dataflow job. @@ -650,8 +646,7 @@ def start_python_dataflow( # pylint: disable=too-many-arguments variables['region'] = location def label_formatter(labels_dict): - return ['--labels={}={}'.format(key, value) - for key, value in labels_dict.items()] + return ['--labels={}={}'.format(key, value) for key, value in labels_dict.items()] if py_requirements is not None: if not py_requirements and not py_system_site_packages: @@ -683,7 +678,7 @@ def label_formatter(labels_dict): label_formatter=label_formatter, project_id=project_id, on_new_job_id_callback=on_new_job_id_callback, - location=location + location=location, ) else: command_prefix = [py_interpreter] + py_options + [dataflow] @@ -695,7 +690,7 @@ def label_formatter(labels_dict): label_formatter=label_formatter, project_id=project_id, on_new_job_id_callback=on_new_job_id_callback, - location=location + location=location, ) @staticmethod @@ -706,7 +701,8 @@ def _build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str raise ValueError( 'Invalid job_name ({}); the name must consist of' 'only the characters [-a-z0-9], starting with a ' - 'letter and ending with a letter or number '.format(base_job_name)) + 'letter and ending with a letter or number '.format(base_job_name) + ) if append_job_name: safe_job_name = base_job_name + "-" + str(uuid.uuid4())[:8] @@ -748,7 +744,7 @@ def is_job_dataflow_running( name: str, project_id: str, location: str = DEFAULT_DATAFLOW_LOCATION, - variables: Optional[Dict] = None + variables: Optional[Dict] = None, ) -> bool: """ Helper method to check if jos is still running in dataflow @@ -766,13 +762,16 @@ def is_job_dataflow_running( if variables: warnings.warn( "The variables parameter has been deprecated. 
You should pass location using " - "the location parameter.", DeprecationWarning, stacklevel=4) + "the location parameter.", + DeprecationWarning, + stacklevel=4, + ) jobs_controller = _DataflowJobsController( dataflow=self.get_conn(), project_number=project_id, name=name, location=location, - poll_sleep=self.poll_sleep + poll_sleep=self.poll_sleep, ) return jobs_controller.is_job_running() @@ -805,6 +804,6 @@ def cancel_job( name=job_name, job_id=job_id, location=location, - poll_sleep=self.poll_sleep + poll_sleep=self.poll_sleep, ) jobs_controller.cancel() diff --git a/airflow/providers/google/cloud/hooks/datafusion.py b/airflow/providers/google/cloud/hooks/datafusion.py index d153fec629fab..d9de5ef7d3533 100644 --- a/airflow/providers/google/cloud/hooks/datafusion.py +++ b/airflow/providers/google/cloud/hooks/datafusion.py @@ -67,9 +67,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version @@ -125,8 +123,7 @@ def wait_for_pipeline_state( return if current_state in failure_states: raise AirflowException( - f"Pipeline {pipeline_name} state {current_state} is not " - f"one of {success_states}" + f"Pipeline {pipeline_name} state {current_state} is not " f"one of {success_states}" ) sleep(30) @@ -146,9 +143,7 @@ def _parent(project_id: str, location: str) -> str: @staticmethod def _base_url(instance_url: str, namespace: str) -> str: - return os.path.join( - instance_url, "v3", "namespaces", quote(namespace), "apps" - ) + return os.path.join(instance_url, "v3", "namespaces", quote(namespace), "apps") def _cdap_request( self, url: str, method: str, body: Optional[Union[List, Dict]] = None @@ -157,9 +152,7 @@ def _cdap_request( request = google.auth.transport.requests.Request() credentials = self._get_credentials() - credentials.before_request( - request=request, method=method, url=url, headers=headers - ) + credentials.before_request(request=request, method=method, url=url, headers=headers) payload = json.dumps(body) if body else None @@ -172,18 +165,11 @@ def get_conn(self) -> Resource: """ if not self._conn: http_authorized = self._authorize() - self._conn = build( - "datafusion", - self.api_version, - http=http_authorized, - cache_discovery=False, - ) + self._conn = build("datafusion", self.api_version, http=http_authorized, cache_discovery=False,) return self._conn @GoogleBaseHook.fallback_to_default_project_id - def restart_instance( - self, instance_name: str, location: str, project_id: str - ) -> Operation: + def restart_instance(self, instance_name: str, location: str, project_id: str) -> Operation: """ Restart a single Data Fusion instance. At the end of an operation instance is fully restarted. @@ -206,9 +192,7 @@ def restart_instance( return operation @GoogleBaseHook.fallback_to_default_project_id - def delete_instance( - self, instance_name: str, location: str, project_id: str - ) -> Operation: + def delete_instance(self, instance_name: str, location: str, project_id: str) -> Operation: """ Deletes a single Date Fusion instance. 
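
The Data Fusion hook hunks above keep routing every CDAP call through _base_url; a small standalone sketch of that URL composition, using a hypothetical instance URL rather than one from this patch:

import os
from urllib.parse import quote

def cdap_base_url(instance_url: str, namespace: str = "default") -> str:
    # Mirrors DataFusionHook._base_url after the reformat: the namespace is
    # URL-quoted and joined onto the instance's v3 "apps" endpoint.
    return os.path.join(instance_url, "v3", "namespaces", quote(namespace), "apps")

print(cdap_base_url("https://my-instance.datafusion.example.com/api"))
# https://my-instance.datafusion.example.com/api/v3/namespaces/default/apps
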
@@ -231,11 +215,7 @@ def delete_instance( @GoogleBaseHook.fallback_to_default_project_id def create_instance( - self, - instance_name: str, - instance: Dict[str, Any], - location: str, - project_id: str, + self, instance_name: str, instance: Dict[str, Any], location: str, project_id: str, ) -> Operation: """ Creates a new Data Fusion instance in the specified project and location. @@ -255,19 +235,13 @@ def create_instance( .projects() .locations() .instances() - .create( - parent=self._parent(project_id, location), - body=instance, - instanceId=instance_name, - ) + .create(parent=self._parent(project_id, location), body=instance, instanceId=instance_name,) .execute(num_retries=self.num_retries) ) return operation @GoogleBaseHook.fallback_to_default_project_id - def get_instance( - self, instance_name: str, location: str, project_id: str - ) -> Dict[str, Any]: + def get_instance(self, instance_name: str, location: str, project_id: str) -> Dict[str, Any]: """ Gets details of a single Data Fusion instance. @@ -290,12 +264,7 @@ def get_instance( @GoogleBaseHook.fallback_to_default_project_id def patch_instance( - self, - instance_name: str, - instance: Dict[str, Any], - update_mask: str, - location: str, - project_id: str, + self, instance_name: str, instance: Dict[str, Any], update_mask: str, location: str, project_id: str, ) -> Operation: """ Updates a single Data Fusion instance. @@ -323,20 +292,14 @@ def patch_instance( .locations() .instances() .patch( - name=self._name(project_id, location, instance_name), - updateMask=update_mask, - body=instance, + name=self._name(project_id, location, instance_name), updateMask=update_mask, body=instance, ) .execute(num_retries=self.num_retries) ) return operation def create_pipeline( - self, - pipeline_name: str, - pipeline: Dict[str, Any], - instance_url: str, - namespace: str = "default", + self, pipeline_name: str, pipeline: Dict[str, Any], instance_url: str, namespace: str = "default", ) -> None: """ Creates a Cloud Data Fusion pipeline. 
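
A usage sketch for the collapsed instance-method signatures above, assuming application-default Google credentials; the project, location and instance names are hypothetical, and the apiEndpoint field read at the end comes from the Data Fusion instance resource, not from this patch:

from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook

hook = DataFusionHook(gcp_conn_id="google_cloud_default")
# get_instance(instance_name, location, project_id) as reformatted above.
instance = hook.get_instance(
    instance_name="my-instance", location="us-west1", project_id="my-project"
)
print(instance.get("apiEndpoint"))
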
@@ -356,9 +319,7 @@ def create_pipeline( url = os.path.join(self._base_url(instance_url, namespace), quote(pipeline_name)) response = self._cdap_request(url=url, method="PUT", body=pipeline) if response.status != 200: - raise AirflowException( - f"Creating a pipeline failed with code {response.status}" - ) + raise AirflowException(f"Creating a pipeline failed with code {response.status}") def delete_pipeline( self, @@ -387,9 +348,7 @@ def delete_pipeline( response = self._cdap_request(url=url, method="DELETE", body=None) if response.status != 200: - raise AirflowException( - f"Deleting a pipeline failed with code {response.status}" - ) + raise AirflowException(f"Deleting a pipeline failed with code {response.status}") def list_pipelines( self, @@ -423,17 +382,11 @@ def list_pipelines( response = self._cdap_request(url=url, method="GET", body=None) if response.status != 200: - raise AirflowException( - f"Listing pipelines failed with code {response.status}" - ) + raise AirflowException(f"Listing pipelines failed with code {response.status}") return json.loads(response.data) def _get_workflow_state( - self, - pipeline_name: str, - instance_url: str, - pipeline_id: str, - namespace: str = "default", + self, pipeline_name: str, instance_url: str, pipeline_id: str, namespace: str = "default", ) -> str: url = os.path.join( self._base_url(instance_url, namespace), @@ -445,9 +398,7 @@ def _get_workflow_state( ) response = self._cdap_request(url=url, method="GET") if response.status != 200: - raise AirflowException( - f"Retrieving a pipeline state failed with code {response.status}" - ) + raise AirflowException(f"Retrieving a pipeline state failed with code {response.status}") workflow = json.loads(response.data) return workflow["status"] @@ -475,25 +426,19 @@ def start_pipeline( # TODO: This API endpoint starts multiple pipelines. There will eventually be a fix # return the run Id as part of the API request to run a single pipeline. # https://github.com/apache/airflow/pull/8954#discussion_r438223116 - url = os.path.join( - instance_url, - "v3", - "namespaces", - quote(namespace), - "start", - ) + url = os.path.join(instance_url, "v3", "namespaces", quote(namespace), "start",) runtime_args = runtime_args or {} - body = [{ - "appId": pipeline_name, - "programType": "workflow", - "programId": "DataPipelineWorkflow", - "runtimeargs": runtime_args - }] + body = [ + { + "appId": pipeline_name, + "programType": "workflow", + "programId": "DataPipelineWorkflow", + "runtimeargs": runtime_args, + } + ] response = self._cdap_request(url=url, method="POST", body=body) if response.status != 200: - raise AirflowException( - f"Starting a pipeline failed with code {response.status}" - ) + raise AirflowException(f"Starting a pipeline failed with code {response.status}") response_json = json.loads(response.data) pipeline_id = response_json[0]["runId"] @@ -506,9 +451,7 @@ def start_pipeline( ) return pipeline_id - def stop_pipeline( - self, pipeline_name: str, instance_url: str, namespace: str = "default" - ) -> None: + def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str = "default") -> None: """ Stops a Cloud Data Fusion pipeline. Works for both batch and stream pipelines. 
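
The start_pipeline hunk above now builds the CDAP batch-start payload as a single literal; a standalone sketch of that request, with a hypothetical pipeline name and instance URL:

import json
import os
from urllib.parse import quote

instance_url = "https://my-instance.datafusion.example.com/api"  # hypothetical
url = os.path.join(instance_url, "v3", "namespaces", quote("default"), "start")
body = [
    {
        "appId": "example_pipeline",   # pipeline_name
        "programType": "workflow",
        "programId": "DataPipelineWorkflow",
        "runtimeargs": {},             # runtime_args or {}
    }
]
print(url)
print(json.dumps(body, indent=2))
# The hook POSTs this body via _cdap_request and raises AirflowException
# unless the response status is 200.
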
@@ -530,6 +473,4 @@ def stop_pipeline( ) response = self._cdap_request(url=url, method="POST") if response.status != 200: - raise AirflowException( - f"Stopping a pipeline failed with code {response.status}" - ) + raise AirflowException(f"Stopping a pipeline failed with code {response.status}") diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py index 78769cf9bbda0..823db501b72b1 100644 --- a/airflow/providers/google/cloud/hooks/dataproc.py +++ b/airflow/providers/google/cloud/hooks/dataproc.py @@ -28,10 +28,17 @@ from cached_property import cached_property from google.api_core.retry import Retry from google.cloud.dataproc_v1beta2 import ( # pylint: disable=no-name-in-module - ClusterControllerClient, JobControllerClient, WorkflowTemplateServiceClient, + ClusterControllerClient, + JobControllerClient, + WorkflowTemplateServiceClient, ) from google.cloud.dataproc_v1beta2.types import ( # pylint: disable=no-name-in-module - Cluster, Duration, FieldMask, Job, JobStatus, WorkflowTemplate, + Cluster, + Duration, + FieldMask, + Job, + JobStatus, + WorkflowTemplate, ) from airflow.exceptions import AirflowException @@ -43,28 +50,23 @@ class DataProcJobBuilder: """ A helper class for building Dataproc job. """ + def __init__( self, project_id: str, task_id: str, cluster_name: str, job_type: str, - properties: Optional[Dict[str, str]] = None + properties: Optional[Dict[str, str]] = None, ) -> None: name = task_id + "_" + str(uuid.uuid4())[:8] self.job_type = job_type self.job = { "job": { - "reference": { - "project_id": project_id, - "job_id": name, - }, - "placement": { - "cluster_name": cluster_name - }, + "reference": {"project_id": project_id, "job_id": name,}, + "placement": {"cluster_name": cluster_name}, "labels": {'airflow-version': 'v' + airflow_version.replace('.', '-').replace('+', '-')}, - job_type: { - } + job_type: {}, } } # type: Dict[str, Any] if properties is not None: @@ -215,14 +217,12 @@ def get_cluster_client(self, location: Optional[str] = None) -> ClusterControlle """ Returns ClusterControllerClient. """ - client_options = { - 'api_endpoint': '{}-dataproc.googleapis.com:443'.format(location) - } if location else None + client_options = ( + {'api_endpoint': '{}-dataproc.googleapis.com:443'.format(location)} if location else None + ) return ClusterControllerClient( - credentials=self._get_credentials(), - client_info=self.client_info, - client_options=client_options + credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options ) @cached_property @@ -231,22 +231,19 @@ def get_template_client(self) -> WorkflowTemplateServiceClient: Returns WorkflowTemplateServiceClient. """ return WorkflowTemplateServiceClient( - credentials=self._get_credentials(), - client_info=self.client_info + credentials=self._get_credentials(), client_info=self.client_info ) def get_job_client(self, location: Optional[str] = None) -> JobControllerClient: """ Returns JobControllerClient. 
""" - client_options = { - 'api_endpoint': '{}-dataproc.googleapis.com:443'.format(location) - } if location else None + client_options = ( + {'api_endpoint': '{}-dataproc.googleapis.com:443'.format(location)} if location else None + ) return JobControllerClient( - credentials=self._get_credentials(), - client_info=self.client_info, - client_options=client_options + credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options ) @GoogleBaseHook.fallback_to_default_project_id @@ -591,11 +588,7 @@ def create_workflow_template( client = self.get_template_client parent = client.region_path(project_id, location) return client.create_workflow_template( - parent=parent, - template=template, - retry=retry, - timeout=timeout, - metadata=metadata + parent=parent, template=template, retry=retry, timeout=timeout, metadata=metadata ) @GoogleBaseHook.fallback_to_default_project_id @@ -651,7 +644,7 @@ def instantiate_workflow_template( request_id=request_id, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) return operation @@ -697,18 +690,12 @@ def instantiate_inline_workflow_template( request_id=request_id, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) return operation @GoogleBaseHook.fallback_to_default_project_id - def wait_for_job( - self, - job_id: str, - location: str, - project_id: str, - wait_time: int = 10 - ): + def wait_for_job(self, job_id: str, location: str, project_id: str, wait_time: int = 10): """ Helper method which polls a job to check if it finishes. @@ -724,11 +711,7 @@ def wait_for_job( state = None while state not in (JobStatus.ERROR, JobStatus.DONE, JobStatus.CANCELLED): time.sleep(wait_time) - job = self.get_job( - location=location, - job_id=job_id, - project_id=project_id - ) + job = self.get_job(location=location, job_id=job_id, project_id=project_id) state = job.status.state if state == JobStatus.ERROR: raise AirflowException('Job failed:\n{}'.format(job)) @@ -770,7 +753,7 @@ def get_job( job_id=job_id, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) return job @@ -816,7 +799,7 @@ def submit_job( request_id=request_id, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) def submit( @@ -824,7 +807,7 @@ def submit( project_id: str, job: Dict, region: str = 'global', - job_error_states: Optional[Iterable[str]] = None # pylint: disable=unused-argument + job_error_states: Optional[Iterable[str]] = None, # pylint: disable=unused-argument ) -> None: """ Submits Google Cloud Dataproc job. @@ -839,22 +822,10 @@ def submit( :type job_error_states: List[str] """ # TODO: Remover one day - warnings.warn( - "This method is deprecated. Please use `submit_job`", - DeprecationWarning, - stacklevel=2 - ) - job_object = self.submit_job( - location=region, - project_id=project_id, - job=job - ) + warnings.warn("This method is deprecated. 
Please use `submit_job`", DeprecationWarning, stacklevel=2) + job_object = self.submit_job(location=region, project_id=project_id, job=job) job_id = job_object.reference.job_id - self.wait_for_job( - job_id=job_id, - location=region, - project_id=project_id - ) + self.wait_for_job(job_id=job_id, location=region, project_id=project_id) @GoogleBaseHook.fallback_to_default_project_id def cancel_job( diff --git a/airflow/providers/google/cloud/hooks/datastore.py b/airflow/providers/google/cloud/hooks/datastore.py index 27685ec5b20bb..c70621ddc59d5 100644 --- a/airflow/providers/google/cloud/hooks/datastore.py +++ b/airflow/providers/google/cloud/hooks/datastore.py @@ -51,12 +51,13 @@ def __init__( if datastore_conn_id: warnings.warn( "The datastore_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=2) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=2, + ) gcp_conn_id = datastore_conn_id super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.connection = None self.api_version = api_version @@ -70,8 +71,9 @@ def get_conn(self) -> Any: """ if not self.connection: http_authorized = self._authorize() - self.connection = build('datastore', self.api_version, http=http_authorized, - cache_discovery=False) + self.connection = build( + 'datastore', self.api_version, http=http_authorized, cache_discovery=False + ) return self.connection @@ -92,10 +94,11 @@ def allocate_ids(self, partial_keys: List, project_id: str) -> List: """ conn = self.get_conn() # type: Any - resp = (conn # pylint: disable=no-member - .projects() - .allocateIds(projectId=project_id, body={'keys': partial_keys}) - .execute(num_retries=self.num_retries)) + resp = ( + conn.projects() # pylint: disable=no-member + .allocateIds(projectId=project_id, body={'keys': partial_keys}) + .execute(num_retries=self.num_retries) + ) return resp['keys'] @@ -116,10 +119,11 @@ def begin_transaction(self, project_id: str, transaction_options: Dict[str, Any] """ conn = self.get_conn() # type: Any - resp = (conn # pylint: disable=no-member - .projects() - .beginTransaction(projectId=project_id, body={"transactionOptions": transaction_options}) - .execute(num_retries=self.num_retries)) + resp = ( + conn.projects() # pylint: disable=no-member + .beginTransaction(projectId=project_id, body={"transactionOptions": transaction_options}) + .execute(num_retries=self.num_retries) + ) return resp['transaction'] @@ -140,10 +144,11 @@ def commit(self, body: Dict, project_id: str) -> Dict: """ conn = self.get_conn() # type: Any - resp = (conn # pylint: disable=no-member - .projects() - .commit(projectId=project_id, body=body) - .execute(num_retries=self.num_retries)) + resp = ( + conn.projects() # pylint: disable=no-member + .commit(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) return resp @@ -180,10 +185,11 @@ def lookup( body['readConsistency'] = read_consistency if transaction: body['transaction'] = transaction - resp = (conn # pylint: disable=no-member - .projects() - .lookup(projectId=project_id, body=body) - .execute(num_retries=self.num_retries)) + resp = ( + conn.projects() # pylint: disable=no-member + .lookup(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) return resp @@ -223,10 +229,11 @@ def run_query(self, body: Dict, project_id: str) -> Dict: """ 
conn = self.get_conn() # type: Any - resp = (conn # pylint: disable=no-member - .projects() - .runQuery(projectId=project_id, body=body) - .execute(num_retries=self.num_retries)) + resp = ( + conn.projects() # pylint: disable=no-member + .runQuery(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) return resp['batch'] @@ -244,11 +251,12 @@ def get_operation(self, name: str) -> Dict: """ conn = self.get_conn() # type: Any - resp = (conn # pylint: disable=no-member - .projects() - .operations() - .get(name=name) - .execute(num_retries=self.num_retries)) + resp = ( + conn.projects() # pylint: disable=no-member + .operations() + .get(name=name) + .execute(num_retries=self.num_retries) + ) return resp @@ -266,11 +274,12 @@ def delete_operation(self, name: str) -> Dict: """ conn = self.get_conn() # type: Any - resp = (conn # pylint: disable=no-member - .projects() - .operations() - .delete(name=name) - .execute(num_retries=self.num_retries)) + resp = ( + conn.projects() # pylint: disable=no-member + .operations() + .delete(name=name) + .execute(num_retries=self.num_retries) + ) return resp @@ -299,12 +308,12 @@ def poll_operation_until_done(self, name: str, polling_interval_in_seconds: int) @GoogleBaseHook.fallback_to_default_project_id def export_to_storage_bucket( - self, - bucket: str, - project_id: str, - namespace: Optional[str] = None, - entity_filter: Optional[Dict] = None, - labels: Optional[Dict[str, str]] = None, + self, + bucket: str, + project_id: str, + namespace: Optional[str] = None, + entity_filter: Optional[Dict] = None, + labels: Optional[Dict[str, str]] = None, ) -> Dict: """ Export entities from Cloud Datastore to Cloud Storage for backup. @@ -340,22 +349,23 @@ def export_to_storage_bucket( 'entityFilter': entity_filter, 'labels': labels, } # type: Dict - resp = (admin_conn # pylint: disable=no-member - .projects() - .export(projectId=project_id, body=body) - .execute(num_retries=self.num_retries)) + resp = ( + admin_conn.projects() # pylint: disable=no-member + .export(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) return resp @GoogleBaseHook.fallback_to_default_project_id def import_from_storage_bucket( - self, - bucket: str, - file: str, - project_id: str, - namespace: Optional[str] = None, - entity_filter: Optional[Dict] = None, - labels: Optional[Union[Dict, str]] = None, + self, + bucket: str, + file: str, + project_id: str, + namespace: Optional[str] = None, + entity_filter: Optional[Dict] = None, + labels: Optional[Union[Dict, str]] = None, ) -> Dict: """ Import a backup from Cloud Storage to Cloud Datastore. 
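
A usage sketch for the export_to_storage_bucket signature reflowed above, assuming application-default credentials and a placeholder bucket and project, and assuming the Datastore Admin export call returns a long-running operation with a "name" field that can be handed to the hook's own poll_operation_until_done helper:

from airflow.providers.google.cloud.hooks.datastore import DatastoreHook

hook = DatastoreHook(gcp_conn_id="google_cloud_default")
operation = hook.export_to_storage_bucket(
    bucket="my-datastore-backups",
    project_id="my-project",
    namespace="backup-2020-06",
    entity_filter=None,
    labels={"airflow-version": "v1-10-12"},
)
hook.poll_operation_until_done(operation["name"], polling_interval_in_seconds=10)
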
@@ -393,9 +403,10 @@ def import_from_storage_bucket( 'entityFilter': entity_filter, 'labels': labels, } # type: Dict - resp = (admin_conn # pylint: disable=no-member - .projects() - .import_(projectId=project_id, body=body) - .execute(num_retries=self.num_retries)) + resp = ( + admin_conn.projects() # pylint: disable=no-member + .import_(projectId=project_id, body=body) + .execute(num_retries=self.num_retries) + ) return resp diff --git a/airflow/providers/google/cloud/hooks/dlp.py b/airflow/providers/google/cloud/hooks/dlp.py index 08a242d6cf466..753c43041249e 100644 --- a/airflow/providers/google/cloud/hooks/dlp.py +++ b/airflow/providers/google/cloud/hooks/dlp.py @@ -28,10 +28,25 @@ from google.api_core.retry import Retry from google.cloud.dlp_v2 import DlpServiceClient from google.cloud.dlp_v2.types import ( - ByteContentItem, ContentItem, DeidentifyConfig, DeidentifyContentResponse, DeidentifyTemplate, DlpJob, - FieldMask, InspectConfig, InspectContentResponse, InspectJobConfig, InspectTemplate, JobTrigger, - ListInfoTypesResponse, RedactImageRequest, RedactImageResponse, ReidentifyContentResponse, - RiskAnalysisJobConfig, StoredInfoType, StoredInfoTypeConfig, + ByteContentItem, + ContentItem, + DeidentifyConfig, + DeidentifyContentResponse, + DeidentifyTemplate, + DlpJob, + FieldMask, + InspectConfig, + InspectContentResponse, + InspectJobConfig, + InspectTemplate, + JobTrigger, + ListInfoTypesResponse, + RedactImageRequest, + RedactImageResponse, + ReidentifyContentResponse, + RiskAnalysisJobConfig, + StoredInfoType, + StoredInfoTypeConfig, ) from airflow.exceptions import AirflowException @@ -74,9 +89,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None @@ -196,7 +209,7 @@ def create_dlp_job( timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = None, wait_until_finished: bool = True, - time_to_sleep_in_seconds: int = 60 + time_to_sleep_in_seconds: int = 60, ) -> DlpJob: """ Creates a new job to inspect storage or calculate risk metrics. diff --git a/airflow/providers/google/cloud/hooks/functions.py b/airflow/providers/google/cloud/hooks/functions.py index ab451d3022618..5b2805bdb2934 100644 --- a/airflow/providers/google/cloud/hooks/functions.py +++ b/airflow/providers/google/cloud/hooks/functions.py @@ -38,6 +38,7 @@ class CloudFunctionsHook(GoogleBaseHook): All the methods in the hook where project_id is used must be called with keyword arguments rather than positional. 
""" + _conn = None # type: Optional[Any] def __init__( @@ -48,9 +49,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version @@ -77,8 +76,9 @@ def get_conn(self): """ if not self._conn: http_authorized = self._authorize() - self._conn = build('cloudfunctions', self.api_version, - http=http_authorized, cache_discovery=False) + self._conn = build( + 'cloudfunctions', self.api_version, http=http_authorized, cache_discovery=False + ) return self._conn def get_function(self, name: str) -> Dict: @@ -90,8 +90,10 @@ def get_function(self, name: str) -> Dict: :return: A Cloud Functions object representing the function. :rtype: dict """ + # fmt: off return self.get_conn().projects().locations().functions().get( # pylint: disable=no-member name=name).execute(num_retries=self.num_retries) + # fmt: on @GoogleBaseHook.fallback_to_default_project_id def create_new_function(self, location: str, body: Dict, project_id: str) -> None: @@ -107,10 +109,12 @@ def create_new_function(self, location: str, body: Dict, project_id: str) -> Non :type project_id: str :return: None """ + # fmt: off response = self.get_conn().projects().locations().functions().create( # pylint: disable=no-member location=self._full_location(project_id, location), body=body ).execute(num_retries=self.num_retries) + # fmt: on operation_name = response["name"] self._wait_for_operation_to_complete(operation_name=operation_name) @@ -126,11 +130,13 @@ def update_function(self, name: str, body: Dict, update_mask: List[str]) -> None :type update_mask: [str] :return: None """ + # fmt: off response = self.get_conn().projects().locations().functions().patch( # pylint: disable=no-member updateMask=",".join(update_mask), name=name, body=body ).execute(num_retries=self.num_retries) + # fmt: on operation_name = response["name"] self._wait_for_operation_to_complete(operation_name=operation_name) @@ -149,11 +155,13 @@ def upload_function_zip(self, location: str, zip_path: str, project_id: str) -> :return: The upload URL that was returned by generateUploadUrl method. 
:rtype: str """ + # fmt: off # pylint: disable=no-member # noqa response = \ self.get_conn().projects().locations().functions().generateUploadUrl( parent=self._full_location(project_id, location) ).execute(num_retries=self.num_retries) + # fmt: on upload_url = response.get('uploadUrl') with open(zip_path, 'rb') as file: @@ -163,10 +171,7 @@ def upload_function_zip(self, location: str, zip_path: str, project_id: str) -> # Those two headers needs to be specified according to: # https://cloud.google.com/functions/docs/reference/rest/v1/projects.locations.functions/generateUploadUrl # nopep8 - headers={ - 'Content-type': 'application/zip', - 'x-goog-content-length-range': '0,104857600', - } + headers={'Content-type': 'application/zip', 'x-goog-content-length-range': '0,104857600',}, ) return upload_url @@ -178,19 +183,15 @@ def delete_function(self, name: str) -> None: :type name: str :return: None """ + # fmt: off response = self.get_conn().projects().locations().functions().delete( # pylint: disable=no-member name=name).execute(num_retries=self.num_retries) + # fmt: on operation_name = response["name"] self._wait_for_operation_to_complete(operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id - def call_function( - self, - function_id: str, - input_data: Dict, - location: str, - project_id: str, - ) -> Dict: + def call_function(self, function_id: str, input_data: Dict, location: str, project_id: str,) -> Dict: """ Synchronously invokes a deployed Cloud Function. To be used for testing purposes as very limited traffic is allowed. @@ -207,14 +208,14 @@ def call_function( :return: None """ name = "projects/{project_id}/locations/{location}/functions/{function_id}".format( - project_id=project_id, - location=location, - function_id=function_id + project_id=project_id, location=location, function_id=function_id ) + # fmt: off response = self.get_conn().projects().locations().functions().call( # pylint: disable=no-member name=name, body=input_data ).execute(num_retries=self.num_retries) + # fmt: on if 'error' in response: raise AirflowException(response['error']) return response @@ -232,9 +233,11 @@ def _wait_for_operation_to_complete(self, operation_name: str) -> Dict: """ service = self.get_conn() while True: + # fmt: off operation_response = service.operations().get( # pylint: disable=no-member name=operation_name, ).execute(num_retries=self.num_retries) + # fmt: on if operation_response.get("done"): response = operation_response.get("response") error = operation_response.get("error") diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index 470477f1cbecb..d8cbd3405d394 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -58,13 +58,14 @@ def _fallback_object_url_to_object_name_and_bucket_name( :type object_name_keyword_arg_name: str :return: Decorator """ - def _wrapper(func: T): + def _wrapper(func: T): @functools.wraps(func) - def _inner_wrapper(self: "GCSHook", * args, **kwargs) -> RT: + def _inner_wrapper(self: "GCSHook", *args, **kwargs) -> RT: if args: raise AirflowException( - "You must use keyword arguments in this methods rather than positional") + "You must use keyword arguments in this methods rather than positional" + ) object_url = kwargs.get(object_url_keyword_arg_name) bucket_name = kwargs.get(bucket_name_keyword_arg_name) @@ -100,7 +101,9 @@ def _inner_wrapper(self: "GCSHook", * args, **kwargs) -> RT: ) return func(self, *args, **kwargs) + return 
cast(T, _inner_wrapper) + return _wrapper @@ -124,12 +127,13 @@ def __init__( if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=2) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=2, + ) gcp_conn_id = google_cloud_storage_conn_id super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) def get_conn(self): @@ -137,14 +141,13 @@ def get_conn(self): Returns a Google Cloud Storage service object. """ if not self._conn: - self._conn = storage.Client(credentials=self._get_credentials(), - client_info=self.client_info, - project=self.project_id) + self._conn = storage.Client( + credentials=self._get_credentials(), client_info=self.client_info, project=self.project_id + ) return self._conn - def copy(self, source_bucket, source_object, destination_bucket=None, - destination_object=None): + def copy(self, source_bucket, source_object, destination_bucket=None, destination_object=None): """ Copies an object from a bucket to another, with renaming if requested. @@ -165,13 +168,12 @@ def copy(self, source_bucket, source_object, destination_bucket=None, destination_bucket = destination_bucket or source_bucket destination_object = destination_object or source_object - if source_bucket == destination_bucket and \ - source_object == destination_object: + if source_bucket == destination_bucket and source_object == destination_object: raise ValueError( 'Either source/destination bucket or source/destination object ' - 'must be different, not both the same: bucket=%s, object=%s' % - (source_bucket, source_object)) + 'must be different, not both the same: bucket=%s, object=%s' % (source_bucket, source_object) + ) if not source_bucket or not source_object: raise ValueError('source_bucket and source_object cannot be empty.') @@ -180,16 +182,18 @@ def copy(self, source_bucket, source_object, destination_bucket=None, source_object = source_bucket.blob(source_object) destination_bucket = client.bucket(destination_bucket) destination_object = source_bucket.copy_blob( - blob=source_object, - destination_bucket=destination_bucket, - new_name=destination_object) + blob=source_object, destination_bucket=destination_bucket, new_name=destination_object + ) - self.log.info('Object %s in bucket %s copied to object %s in bucket %s', - source_object.name, source_bucket.name, - destination_object.name, destination_bucket.name) + self.log.info( + 'Object %s in bucket %s copied to object %s in bucket %s', + source_object.name, + source_bucket.name, + destination_object.name, + destination_bucket.name, + ) - def rewrite(self, source_bucket, source_object, destination_bucket, - destination_object=None): + def rewrite(self, source_bucket, source_object, destination_bucket, destination_object=None): """ Has the same functionality as copy, except that will work on files over 5 TB, as well as when copying between locations and/or storage @@ -208,12 +212,11 @@ def rewrite(self, source_bucket, source_object, destination_bucket, :type destination_object: str """ destination_object = destination_object or source_object - if (source_bucket == destination_bucket and - source_object == destination_object): + if source_bucket == destination_bucket and source_object == destination_object: raise ValueError( 'Either source/destination 
bucket or source/destination object ' - 'must be different, not both the same: bucket=%s, object=%s' % - (source_bucket, source_object)) + 'must be different, not both the same: bucket=%s, object=%s' % (source_bucket, source_object) + ) if not source_bucket or not source_object: raise ValueError('source_bucket and source_object cannot be empty.') @@ -222,25 +225,25 @@ def rewrite(self, source_bucket, source_object, destination_bucket, source_object = source_bucket.blob(blob_name=source_object) destination_bucket = client.bucket(destination_bucket) - token, bytes_rewritten, total_bytes = destination_bucket.blob( - blob_name=destination_object).rewrite( + token, bytes_rewritten, total_bytes = destination_bucket.blob(blob_name=destination_object).rewrite( source=source_object ) - self.log.info('Total Bytes: %s | Bytes Written: %s', - total_bytes, bytes_rewritten) + self.log.info('Total Bytes: %s | Bytes Written: %s', total_bytes, bytes_rewritten) while token is not None: token, bytes_rewritten, total_bytes = destination_bucket.blob( - blob_name=destination_object).rewrite( - source=source_object, token=token - ) + blob_name=destination_object + ).rewrite(source=source_object, token=token) - self.log.info('Total Bytes: %s | Bytes Written: %s', - total_bytes, bytes_rewritten) - self.log.info('Object %s in bucket %s rewritten to object %s in bucket %s', - source_object.name, source_bucket.name, - destination_object, destination_bucket.name) + self.log.info('Total Bytes: %s | Bytes Written: %s', total_bytes, bytes_rewritten) + self.log.info( + 'Object %s in bucket %s rewritten to object %s in bucket %s', + source_object.name, + source_bucket.name, + destination_object, + destination_bucket.name, + ) def download(self, bucket_name, object_name, filename=None): """ @@ -279,7 +282,7 @@ def provide_file( self, bucket_name: Optional[str] = None, object_name: Optional[str] = None, - object_url: Optional[str] = None # pylint: disable=unused-argument + object_url: Optional[str] = None, # pylint: disable=unused-argument ): """ Downloads the file to a temporary directory and returns a file handle @@ -303,9 +306,16 @@ def provide_file( tmp_file.flush() yield tmp_file - def upload(self, bucket_name: str, object_name: str, filename: Optional[str] = None, - data: Optional[Union[str, bytes]] = None, mime_type: Optional[str] = None, gzip: bool = False, - encoding: str = 'utf-8') -> None: + def upload( + self, + bucket_name: str, + object_name: str, + filename: Optional[str] = None, + data: Optional[Union[str, bytes]] = None, + mime_type: Optional[str] = None, + gzip: bool = False, + encoding: str = 'utf-8', + ) -> None: """ Uploads a local file or file data as string or bytes to Google Cloud Storage. @@ -328,9 +338,11 @@ def upload(self, bucket_name: str, object_name: str, filename: Optional[str] = N bucket = client.bucket(bucket_name) blob = bucket.blob(blob_name=object_name) if filename and data: - raise ValueError("'filename' and 'data' parameter provided. Please " - "specify a single parameter, either 'filename' for " - "local file uploads or 'data' for file content uploads.") + raise ValueError( + "'filename' and 'data' parameter provided. Please " + "specify a single parameter, either 'filename' for " + "local file uploads or 'data' for file content uploads." 
+ ) elif filename: if not mime_type: mime_type = 'application/octet-stream' @@ -342,8 +354,7 @@ def upload(self, bucket_name: str, object_name: str, filename: Optional[str] = N shutil.copyfileobj(f_in, f_out) filename = filename_gz - blob.upload_from_filename(filename=filename, - content_type=mime_type) + blob.upload_from_filename(filename=filename, content_type=mime_type) if gzip: os.remove(filename) self.log.info('File %s uploaded to %s in %s bucket', filename, object_name, bucket_name) @@ -357,12 +368,10 @@ def upload(self, bucket_name: str, object_name: str, filename: Optional[str] = N with gz.GzipFile(fileobj=out, mode="w") as f: f.write(data) data = out.getvalue() - blob.upload_from_string(data, - content_type=mime_type) + blob.upload_from_string(data, content_type=mime_type) self.log.info('Data stream uploaded to %s in %s bucket', object_name, bucket_name) else: - raise ValueError("'filename' and 'data' parameter missing. " - "One is required to upload to gcs.") + raise ValueError("'filename' and 'data' parameter missing. " "One is required to upload to gcs.") def exists(self, bucket_name, object_name): """ @@ -393,8 +402,7 @@ def get_blob_update_time(self, bucket_name, object_name): bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) if blob is None: - raise ValueError("Object ({}) not found in Bucket ({})".format( - object_name, bucket_name)) + raise ValueError("Object ({}) not found in Bucket ({})".format(object_name, bucket_name)) return blob.updated def is_updated_after(self, bucket_name, object_name, ts): @@ -412,6 +420,7 @@ def is_updated_after(self, bucket_name, object_name, ts): blob_update_time = self.get_blob_update_time(bucket_name, object_name) if blob_update_time is not None: import dateutil.tz + if not ts.tzinfo: ts = ts.replace(tzinfo=dateutil.tz.tzutc()) self.log.info("Verify object date: %s > %s", blob_update_time, ts) @@ -436,6 +445,7 @@ def is_updated_between(self, bucket_name, object_name, min_ts, max_ts): blob_update_time = self.get_blob_update_time(bucket_name, object_name) if blob_update_time is not None: import dateutil.tz + if not min_ts.tzinfo: min_ts = min_ts.replace(tzinfo=dateutil.tz.tzutc()) if not max_ts.tzinfo: @@ -460,6 +470,7 @@ def is_updated_before(self, bucket_name, object_name, ts): blob_update_time = self.get_blob_update_time(bucket_name, object_name) if blob_update_time is not None: import dateutil.tz + if not ts.tzinfo: ts = ts.replace(tzinfo=dateutil.tz.tzutc()) self.log.info("Verify object date: %s < %s", blob_update_time, ts) @@ -484,6 +495,7 @@ def is_older_than(self, bucket_name, object_name, seconds): from datetime import timedelta from airflow.utils import timezone + current_time = timezone.utcnow() given_time = current_time - timedelta(seconds=seconds) self.log.info("Verify object date: %s is older than %s", blob_update_time, given_time) @@ -556,7 +568,7 @@ def list(self, bucket_name, versions=None, max_results=None, prefix=None, delimi page_token=page_token, prefix=prefix, delimiter=delimiter, - versions=versions + versions=versions, ) blob_names = [] @@ -586,9 +598,7 @@ def get_size(self, bucket_name, object_name): :type object_name: str """ - self.log.info('Checking the file size of object: %s in bucket_name: %s', - object_name, - bucket_name) + self.log.info('Checking the file size of object: %s in bucket_name: %s', object_name, bucket_name) client = self.get_conn() bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) @@ -606,8 +616,11 @@ def get_crc32c(self, bucket_name, 
object_name): storage bucket_name. :type object_name: str """ - self.log.info('Retrieving the crc32c checksum of ' - 'object_name: %s in bucket_name: %s', object_name, bucket_name) + self.log.info( + 'Retrieving the crc32c checksum of ' 'object_name: %s in bucket_name: %s', + object_name, + bucket_name, + ) client = self.get_conn() bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) @@ -625,8 +638,7 @@ def get_md5hash(self, bucket_name, object_name): storage bucket_name. :type object_name: str """ - self.log.info('Retrieving the MD5 hash of ' - 'object: %s in bucket: %s', object_name, bucket_name) + self.log.info('Retrieving the MD5 hash of ' 'object: %s in bucket: %s', object_name, bucket_name) client = self.get_conn() bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) @@ -635,14 +647,15 @@ def get_md5hash(self, bucket_name, object_name): return blob_md5hash @GoogleBaseHook.fallback_to_default_project_id - def create_bucket(self, - bucket_name, - resource=None, - storage_class='MULTI_REGIONAL', - location='US', - project_id=None, - labels=None - ): + def create_bucket( + self, + bucket_name, + resource=None, + storage_class='MULTI_REGIONAL', + location='US', + project_id=None, + labels=None, + ): """ Creates a new bucket. Google Cloud Storage uses a flat namespace, so you can't create a bucket with a name that is already in use. @@ -684,8 +697,9 @@ def create_bucket(self, :return: If successful, it returns the ``id`` of the bucket. """ - self.log.info('Creating Bucket: %s; Location: %s; Storage Class: %s', - bucket_name, location, storage_class) + self.log.info( + 'Creating Bucket: %s; Location: %s; Storage Class: %s', bucket_name, location, storage_class + ) # Add airflow-version label to the bucket labels = labels or {} @@ -759,8 +773,7 @@ def insert_object_acl(self, bucket_name, object_name, entity, role, generation=N Required for Requester Pays buckets. 
:type user_project: str """ - self.log.info('Creating a new ACL entry for object: %s in bucket: %s', - object_name, bucket_name) + self.log.info('Creating a new ACL entry for object: %s in bucket: %s', object_name, bucket_name) client = self.get_conn() bucket = client.bucket(bucket_name=bucket_name) blob = bucket.blob(blob_name=object_name, generation=generation) @@ -771,8 +784,7 @@ def insert_object_acl(self, bucket_name, object_name, entity, role, generation=N blob.acl.user_project = user_project blob.acl.save() - self.log.info('A new ACL entry created for object: %s in bucket: %s', - object_name, bucket_name) + self.log.info('A new ACL entry created for object: %s in bucket: %s', object_name, bucket_name) def compose(self, bucket_name, source_objects, destination_object): """ @@ -799,15 +811,13 @@ def compose(self, bucket_name, source_objects, destination_object): if not bucket_name or not destination_object: raise ValueError('bucket_name and destination_object cannot be empty.') - self.log.info("Composing %s to %s in the bucket %s", - source_objects, destination_object, bucket_name) + self.log.info("Composing %s to %s in the bucket %s", source_objects, destination_object, bucket_name) client = self.get_conn() bucket = client.bucket(bucket_name) destination_blob = bucket.blob(destination_object) destination_blob.compose( - sources=[ - bucket.blob(blob_name=source_object) for source_object in source_objects - ]) + sources=[bucket.blob(blob_name=source_object) for source_object in source_objects] + ) self.log.info("Completed successfully.") @@ -819,7 +829,7 @@ def sync( destination_object: Optional[str] = None, recursive: bool = True, allow_overwrite: bool = False, - delete_extra_files: bool = False + delete_extra_files: bool = False, ): """ Synchronizes the contents of the buckets. @@ -872,7 +882,7 @@ def sync( destination_bucket=destination_bucket_obj, source_object=source_object, destination_object=destination_object, - recursive=recursive + recursive=recursive, ) self.log.info( "Planned synchronization. 
To delete blobs count: %s, to upload blobs count: %s, " @@ -911,8 +921,9 @@ def sync( self.log.info("Skipped blobs overwriting.") elif allow_overwrite: for blob in to_rewrite_blobs: - dst_object = self._calculate_sync_destination_path(blob, destination_object, - source_object_prefix_len) + dst_object = self._calculate_sync_destination_path( + blob, destination_object, source_object_prefix_len + ) self.rewrite( source_bucket=source_bucket_obj.name, source_object=blob.name, @@ -924,10 +935,7 @@ def sync( self.log.info("Synchronization finished.") def _calculate_sync_destination_path( - self, - blob: storage.Blob, - destination_object: Optional[str], - source_object_prefix_len: int + self, blob: storage.Blob, destination_object: Optional[str], source_object_prefix_len: int ) -> str: return ( path.join(destination_object, blob.name[source_object_prefix_len:]) @@ -936,9 +944,7 @@ def _calculate_sync_destination_path( ) def _normalize_directory_path(self, source_object: Optional[str]) -> Optional[str]: - return ( - source_object + "/" if source_object and not source_object.endswith("/") else source_object - ) + return source_object + "/" if source_object and not source_object.endswith("/") else source_object @staticmethod def _prepare_sync_plan( @@ -956,7 +962,8 @@ def _prepare_sync_plan( # Fetch blobs list source_blobs = list(source_bucket.list_blobs(prefix=source_object, delimiter=delimiter)) destination_blobs = list( - destination_bucket.list_blobs(prefix=destination_object, delimiter=delimiter)) + destination_bucket.list_blobs(prefix=destination_object, delimiter=delimiter) + ) # Create indexes that allow you to identify blobs based on their name source_names_index = {a.name[source_object_prefix_len:]: a for a in source_blobs} destination_names_index = {a.name[destination_object_prefix_len:]: a for a in destination_blobs} diff --git a/airflow/providers/google/cloud/hooks/gdm.py b/airflow/providers/google/cloud/hooks/gdm.py index 60cebbf2d103d..0bd0b7e95b519 100644 --- a/airflow/providers/google/cloud/hooks/gdm.py +++ b/airflow/providers/google/cloud/hooks/gdm.py @@ -38,9 +38,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super(GoogleDeploymentManagerHook, self).__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) def get_conn(self): @@ -53,9 +51,12 @@ def get_conn(self): return build('deploymentmanager', 'v2', http=http_authorized, cache_discovery=False) @GoogleBaseHook.fallback_to_default_project_id - def list_deployments(self, project_id: Optional[str] = None, # pylint: disable=too-many-arguments - deployment_filter: Optional[str] = None, - order_by: Optional[str] = None) -> List[Dict[str, Any]]: + def list_deployments( + self, + project_id: Optional[str] = None, # pylint: disable=too-many-arguments + deployment_filter: Optional[str] = None, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: """ Lists deployments in a google cloud project. 
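
The GCS sync helpers condensed above are small enough to restate standalone; a sketch that takes the blob name as a plain string (the hook itself receives a storage.Blob), with illustrative prefixes:

from os import path
from typing import Optional

def normalize_directory_path(source_object: Optional[str]) -> Optional[str]:
    # Mirrors GCSHook._normalize_directory_path: ensure prefixes end with "/".
    return source_object + "/" if source_object and not source_object.endswith("/") else source_object

def calculate_sync_destination_path(
    blob_name: str, destination_object: Optional[str], source_object_prefix_len: int
) -> str:
    # Mirrors GCSHook._calculate_sync_destination_path: re-root the blob's suffix
    # under the destination prefix, or keep the bare suffix when there is none.
    return (
        path.join(destination_object, blob_name[source_object_prefix_len:])
        if destination_object
        else blob_name[source_object_prefix_len:]
    )

print(normalize_directory_path("data/raw"))  # data/raw/
print(calculate_sync_destination_path("data/raw/a.csv", "backup", len("data/raw/")))  # backup/a.csv
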
@@ -69,9 +70,8 @@ def list_deployments(self, project_id: Optional[str] = None, # pylint: disable= """ deployments = [] # type: List[Dict] conn = self.get_conn() - request = conn.deployments().list(project=project_id, # pylint: disable=no-member - filter=deployment_filter, - orderBy=order_by) + # pylint: disable=no-member + request = conn.deployments().list(project=project_id, filter=deployment_filter, orderBy=order_by) while request is not None: response = request.execute(num_retries=self.num_retries) @@ -83,10 +83,9 @@ def list_deployments(self, project_id: Optional[str] = None, # pylint: disable= return deployments @GoogleBaseHook.fallback_to_default_project_id - def delete_deployment(self, - project_id: Optional[str], - deployment: Optional[str] = None, - delete_policy: Optional[str] = None): + def delete_deployment( + self, project_id: Optional[str], deployment: Optional[str] = None, delete_policy: Optional[str] = None + ): """ Deletes a deployment and all associated resources in a google cloud project. @@ -100,10 +99,12 @@ def delete_deployment(self, :rtype: None """ conn = self.get_conn() - request = conn.deployments().delete(project=project_id, # pylint: disable=no-member - deployment=deployment, - deletePolicy=delete_policy) + # pylint: disable=no-member + request = conn.deployments().delete( + project=project_id, deployment=deployment, deletePolicy=delete_policy + ) resp = request.execute() if 'error' in resp.keys(): - raise AirflowException('Errors deleting deployment: ', - ', '.join([err['message'] for err in resp['error']['errors']])) + raise AirflowException( + 'Errors deleting deployment: ', ', '.join([err['message'] for err in resp['error']['errors']]) + ) diff --git a/airflow/providers/google/cloud/hooks/kms.py b/airflow/providers/google/cloud/hooks/kms.py index 1ecc70678d340..1786ce718f06a 100644 --- a/airflow/providers/google/cloud/hooks/kms.py +++ b/airflow/providers/google/cloud/hooks/kms.py @@ -68,9 +68,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._conn = None # type: Optional[KeyManagementServiceClient] @@ -83,8 +81,7 @@ def get_conn(self) -> KeyManagementServiceClient: """ if not self._conn: self._conn = KeyManagementServiceClient( - credentials=self._get_credentials(), - client_info=self.client_info + credentials=self._get_credentials(), client_info=self.client_info ) return self._conn diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py index 099f48c7ec611..150c24528af24 100644 --- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py +++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py @@ -55,9 +55,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None self.location = location @@ -71,16 +69,17 @@ def get_conn(self) -> container_v1.ClusterManagerClient: if self._client is None: credentials = self._get_credentials() self._client = container_v1.ClusterManagerClient( - credentials=credentials, - client_info=self.client_info + credentials=credentials, 
client_info=self.client_info ) return self._client # To preserve backward compatibility # TODO: remove one day def get_client(self) -> container_v1.ClusterManagerClient: # pylint: disable=missing-docstring - warnings.warn("The get_client method has been deprecated. " - "You should use the get_conn method.", DeprecationWarning) + warnings.warn( + "The get_client method has been deprecated. " "You should use the get_conn method.", + DeprecationWarning, + ) return self.get_conn() def wait_for_operation(self, operation: Operation, project_id: Optional[str] = None) -> Operation: @@ -97,12 +96,10 @@ def wait_for_operation(self, operation: Operation, project_id: Optional[str] = N self.log.info("Waiting for OPERATION_NAME %s", operation.name) time.sleep(OPERATIONAL_POLL_INTERVAL) while operation.status != Operation.Status.DONE: - if operation.status == Operation.Status.RUNNING or operation.status == \ - Operation.Status.PENDING: + if operation.status == Operation.Status.RUNNING or operation.status == Operation.Status.PENDING: time.sleep(OPERATIONAL_POLL_INTERVAL) else: - raise exceptions.GoogleCloudError( - "Operation has failed with status: %s" % operation.status) + raise exceptions.GoogleCloudError("Operation has failed with status: %s" % operation.status) # To update status of operation operation = self.get_operation(operation.name, project_id=project_id or self.project_id) return operation @@ -117,9 +114,9 @@ def get_operation(self, operation_name: str, project_id: Optional[str] = None) - :type project_id: str :return: The new, updated operation from Google Cloud """ - return self.get_conn().get_operation(project_id=project_id or self.project_id, - zone=self.location, - operation_id=operation_name) + return self.get_conn().get_operation( + project_id=project_id or self.project_id, zone=self.location, operation_id=operation_name + ) @staticmethod def _append_label(cluster_proto: Cluster, key: str, val: str) -> Cluster: @@ -144,11 +141,7 @@ def _append_label(cluster_proto: Cluster, key: str, val: str) -> Cluster: @GoogleBaseHook.fallback_to_default_project_id def delete_cluster( - self, - name: str, - project_id: str, - retry: Retry = DEFAULT, - timeout: float = DEFAULT + self, name: str, project_id: str, retry: Retry = DEFAULT, timeout: float = DEFAULT ) -> Optional[str]: """ Deletes the cluster, including the Kubernetes endpoint and all @@ -172,16 +165,12 @@ def delete_cluster( :return: The full url to the delete operation if successful, else None """ - self.log.info( - "Deleting (project_id=%s, zone=%s, cluster_id=%s)", project_id, self.location, name - ) + self.log.info("Deleting (project_id=%s, zone=%s, cluster_id=%s)", project_id, self.location, name) try: - resource = self.get_conn().delete_cluster(project_id=project_id, - zone=self.location, - cluster_id=name, - retry=retry, - timeout=timeout) + resource = self.get_conn().delete_cluster( + project_id=project_id, zone=self.location, cluster_id=name, retry=retry, timeout=timeout + ) resource = self.wait_for_operation(resource) # Returns server-defined url for the resource return resource.self_link @@ -191,11 +180,7 @@ def delete_cluster( @GoogleBaseHook.fallback_to_default_project_id def create_cluster( - self, - cluster: Union[Dict, Cluster], - project_id: str, - retry: Retry = DEFAULT, - timeout: float = DEFAULT + self, cluster: Union[Dict, Cluster], project_id: str, retry: Retry = DEFAULT, timeout: float = DEFAULT ) -> str: """ Creates a cluster, consisting of the specified number and type of Google Compute @@ -225,21 +210,17 @@ def 
create_cluster( cluster_proto = Cluster() cluster = ParseDict(cluster, cluster_proto) elif not isinstance(cluster, Cluster): - raise AirflowException( - "cluster is not instance of Cluster proto or python dict") + raise AirflowException("cluster is not instance of Cluster proto or python dict") self._append_label(cluster, 'airflow-version', 'v' + version.version) self.log.info( - "Creating (project_id=%s, zone=%s, cluster_name=%s)", - project_id, self.location, cluster.name + "Creating (project_id=%s, zone=%s, cluster_name=%s)", project_id, self.location, cluster.name ) try: - resource = self.get_conn().create_cluster(project_id=project_id, - zone=self.location, - cluster=cluster, - retry=retry, - timeout=timeout) + resource = self.get_conn().create_cluster( + project_id=project_id, zone=self.location, cluster=cluster, retry=retry, timeout=timeout + ) resource = self.wait_for_operation(resource) return resource.target_link @@ -249,11 +230,7 @@ def create_cluster( @GoogleBaseHook.fallback_to_default_project_id def get_cluster( - self, - name: str, - project_id: str, - retry: Retry = DEFAULT, - timeout: float = DEFAULT + self, name: str, project_id: str, retry: Retry = DEFAULT, timeout: float = DEFAULT ) -> Cluster: """ Gets details of specified cluster @@ -273,11 +250,15 @@ def get_cluster( """ self.log.info( "Fetching cluster (project_id=%s, zone=%s, cluster_name=%s)", - project_id or self.project_id, self.location, name + project_id or self.project_id, + self.location, + name, ) - return self.get_conn().get_cluster(project_id=project_id, - zone=self.location, - cluster_id=name, - retry=retry, - timeout=timeout).self_link + return ( + self.get_conn() + .get_cluster( + project_id=project_id, zone=self.location, cluster_id=name, retry=retry, timeout=timeout + ) + .self_link + ) diff --git a/airflow/providers/google/cloud/hooks/life_sciences.py b/airflow/providers/google/cloud/hooks/life_sciences.py index 5ab9f9a877aa4..286f71ef20551 100644 --- a/airflow/providers/google/cloud/hooks/life_sciences.py +++ b/airflow/providers/google/cloud/hooks/life_sciences.py @@ -66,9 +66,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version @@ -80,8 +78,7 @@ def get_conn(self): """ if not self._conn: http_authorized = self._authorize() - self._conn = build("lifesciences", self.api_version, - http=http_authorized, cache_discovery=False) + self._conn = build("lifesciences", self.api_version, http=http_authorized, cache_discovery=False) return self._conn @GoogleBaseHook.fallback_to_default_project_id @@ -101,11 +98,12 @@ def run_pipeline(self, body: Dict, location: str, project_id: str): parent = self._location_path(project_id=project_id, location=location) service = self.get_conn() - request = (service.projects() # pylint: disable=no-member - .locations() - .pipelines() - .run(parent=parent, body=body) - ) + request = ( + service.projects() # pylint: disable=no-member + .locations() + .pipelines() + .run(parent=parent, body=body) + ) response = request.execute(num_retries=self.num_retries) @@ -128,9 +126,7 @@ def _location_path(self, project_id: str, location: str): :type location: str """ return google.api_core.path_template.expand( - 'projects/{project}/locations/{location}', - project=project_id, - location=location, + 
'projects/{project}/locations/{location}', project=project_id, location=location, ) def _wait_for_operation_to_complete(self, operation_name: str) -> None: @@ -146,11 +142,13 @@ def _wait_for_operation_to_complete(self, operation_name: str) -> None: """ service = self.get_conn() while True: - operation_response = (service.projects() # pylint: disable=no-member - .locations() - .operations() - .get(name=operation_name) - .execute(num_retries=self.num_retries)) + operation_response = ( + service.projects() # pylint: disable=no-member + .locations() + .operations() + .get(name=operation_name) + .execute(num_retries=self.num_retries) + ) self.log.info('Waiting for pipeline operation to complete') if operation_response.get("done"): response = operation_response.get("response") diff --git a/airflow/providers/google/cloud/hooks/mlengine.py b/airflow/providers/google/cloud/hooks/mlengine.py index 4cbed121ebc64..c05698c0e407a 100644 --- a/airflow/providers/google/cloud/hooks/mlengine.py +++ b/airflow/providers/google/cloud/hooks/mlengine.py @@ -40,20 +40,18 @@ def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func): try: response = request.execute() if is_error_func(response): - raise ValueError( - 'The response contained an error: {}'.format(response) - ) + raise ValueError('The response contained an error: {}'.format(response)) if is_done_func(response): log.info('Operation is done: %s', response) return response - time.sleep((2**i) + (random.randint(0, 1000) / 1000)) + time.sleep((2 ** i) + (random.randint(0, 1000) / 1000)) except HttpError as e: if e.resp.status != 429: log.info('Something went wrong. Not retrying: %s', format(e)) raise else: - time.sleep((2**i) + (random.randint(0, 1000) / 1000)) + time.sleep((2 ** i) + (random.randint(0, 1000) / 1000)) raise ValueError('Connection could not be established after {} retries.'.format(max_n)) @@ -65,6 +63,7 @@ class MLEngineHook(GoogleBaseHook): All the methods in the hook where project_id is used must be called with keyword arguments rather than positional. """ + def get_conn(self): """ Retrieves the connection to MLEngine. @@ -75,12 +74,7 @@ def get_conn(self): return build('ml', 'v1', http=authed_http, cache_discovery=False) @GoogleBaseHook.fallback_to_default_project_id - def create_job( - self, - job: Dict, - project_id: str, - use_existing_job_fn: Optional[Callable] = None - ) -> Dict: + def create_job(self, job: Dict, project_id: str, use_existing_job_fn: Optional[Callable] = None) -> Dict: """ Launches a MLEngine job and wait for it to reach a terminal state. @@ -116,9 +110,8 @@ def create_job( self._append_label(job) self.log.info("Creating job.") - request = hook.projects().jobs().create( # pylint: disable=no-member - parent='projects/{}'.format(project_id), - body=job) + # pylint: disable=no-member + request = hook.projects().jobs().create(parent='projects/{}'.format(project_id), body=job) job_id = job['jobId'] try: @@ -130,15 +123,12 @@ def create_job( existing_job = self._get_job(project_id, job_id) if not use_existing_job_fn(existing_job): self.log.error( - 'Job with job_id %s already exist, but it does ' - 'not match our expectation: %s', - job_id, existing_job + 'Job with job_id %s already exist, but it does ' 'not match our expectation: %s', + job_id, + existing_job, ) raise - self.log.info( - 'Job with job_id %s already exist. Will waiting for it to finish', - job_id - ) + self.log.info('Job with job_id %s already exist. 
Will waiting for it to finish', job_id) else: self.log.error('Failed to create MLEngine job: %s', e) raise @@ -146,11 +136,7 @@ def create_job( return self._wait_for_job_done(project_id, job_id) @GoogleBaseHook.fallback_to_default_project_id - def cancel_job( - self, - job_id: str, - project_id: str, - ) -> Dict: + def cancel_job(self, job_id: str, project_id: str,) -> Dict: """ Cancels a MLEngine job. @@ -166,9 +152,8 @@ def cancel_job( :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() - - request = hook.projects().jobs().cancel( # pylint: disable=no-member - name=f'projects/{project_id}/jobs/{job_id}') + # pylint: disable=no-member + request = hook.projects().jobs().cancel(name=f'projects/{project_id}/jobs/{job_id}') try: return request.execute() @@ -177,9 +162,7 @@ def cancel_job( self.log.error('Job with job_id %s does not exist. ', job_id) raise elif e.resp.status == 400: - self.log.info( - 'Job with job_id %s is already complete, cancellation aborted.', - job_id) + self.log.info('Job with job_id %s is already complete, cancellation aborted.', job_id) return {} else: self.log.error('Failed to cancel MLEngine job: %s', e) @@ -239,12 +222,7 @@ def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30): time.sleep(interval) @GoogleBaseHook.fallback_to_default_project_id - def create_version( - self, - model_name: str, - version_spec: Dict, - project_id: str, - ) -> Dict: + def create_version(self, model_name: str, version_spec: Dict, project_id: str,) -> Dict: """ Creates the Version on Google Cloud ML Engine. @@ -266,25 +244,20 @@ def create_version( self._append_label(version_spec) - create_request = hook.projects().models().versions().create( # pylint: disable=no-member - parent=parent_name, body=version_spec) + # pylint: disable=no-member + create_request = hook.projects().models().versions().create(parent=parent_name, body=version_spec) response = create_request.execute() - get_request = hook.projects().operations().get( # pylint: disable=no-member - name=response['name']) + get_request = hook.projects().operations().get(name=response['name']) # pylint: disable=no-member return _poll_with_exponential_delay( request=get_request, max_n=9, is_done_func=lambda resp: resp.get('done', False), - is_error_func=lambda resp: resp.get('error', None) is not None) + is_error_func=lambda resp: resp.get('error', None) is not None, + ) @GoogleBaseHook.fallback_to_default_project_id - def set_default_version( - self, - model_name: str, - version_name: str, - project_id: str, - ) -> Dict: + def set_default_version(self, model_name: str, version_name: str, project_id: str,) -> Dict: """ Sets a version to be the default. Blocks until finished. 
@@ -302,10 +275,9 @@ def set_default_version( :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() - full_version_name = 'projects/{}/models/{}/versions/{}'.format( - project_id, model_name, version_name) - request = hook.projects().models().versions().setDefault( # pylint: disable=no-member - name=full_version_name, body={}) + full_version_name = 'projects/{}/models/{}/versions/{}'.format(project_id, model_name, version_name) + # pylint: disable=no-member + request = hook.projects().models().versions().setDefault(name=full_version_name, body={}) try: response = request.execute() @@ -316,11 +288,7 @@ def set_default_version( raise @GoogleBaseHook.fallback_to_default_project_id - def list_versions( - self, - model_name: str, - project_id: str, - ) -> List[Dict]: + def list_versions(self, model_name: str, project_id: str,) -> List[Dict]: """ Lists all available versions of a model. Blocks until finished. @@ -336,28 +304,25 @@ def list_versions( """ hook = self.get_conn() result = [] # type: List[Dict] - full_parent_name = 'projects/{}/models/{}'.format( - project_id, model_name) - request = hook.projects().models().versions().list( # pylint: disable=no-member - parent=full_parent_name, pageSize=100) + full_parent_name = 'projects/{}/models/{}'.format(project_id, model_name) + # pylint: disable=no-member + request = hook.projects().models().versions().list(parent=full_parent_name, pageSize=100) while request is not None: response = request.execute() result.extend(response.get('versions', [])) - - request = hook.projects().models().versions().list_next( # pylint: disable=no-member - previous_request=request, - previous_response=response) + # pylint: disable=no-member + request = ( + hook.projects() + .models() + .versions() + .list_next(previous_request=request, previous_response=response) + ) time.sleep(5) return result @GoogleBaseHook.fallback_to_default_project_id - def delete_version( - self, - model_name: str, - version_name: str, - project_id: str, - ) -> Dict: + def delete_version(self, model_name: str, version_name: str, project_id: str,) -> Dict: """ Deletes the given version of a model. Blocks until finished. @@ -372,26 +337,22 @@ def delete_version( :rtype: Dict """ hook = self.get_conn() - full_name = 'projects/{}/models/{}/versions/{}'.format( - project_id, model_name, version_name) - delete_request = hook.projects().models().versions().delete( # pylint: disable=no-member - name=full_name) + full_name = 'projects/{}/models/{}/versions/{}'.format(project_id, model_name, version_name) + delete_request = ( + hook.projects().models().versions().delete(name=full_name) # pylint: disable=no-member + ) response = delete_request.execute() - get_request = hook.projects().operations().get( # pylint: disable=no-member - name=response['name']) + get_request = hook.projects().operations().get(name=response['name']) # pylint: disable=no-member return _poll_with_exponential_delay( request=get_request, max_n=9, is_done_func=lambda resp: resp.get('done', False), - is_error_func=lambda resp: resp.get('error', None) is not None) + is_error_func=lambda resp: resp.get('error', None) is not None, + ) @GoogleBaseHook.fallback_to_default_project_id - def create_model( - self, - model: Dict, - project_id: str, - ) -> Dict: + def create_model(self, model: Dict, project_id: str,) -> Dict: """ Create a Model. Blocks until finished. 
@@ -407,14 +368,12 @@ def create_model( """ hook = self.get_conn() if 'name' not in model or not model['name']: - raise ValueError("Model name must be provided and " - "could not be an empty string") + raise ValueError("Model name must be provided and " "could not be an empty string") project = 'projects/{}'.format(project_id) self._append_label(model) try: - request = hook.projects().models().create( # pylint: disable=no-member - parent=project, body=model) + request = hook.projects().models().create(parent=project, body=model) # pylint: disable=no-member respone = request.execute() except HttpError as e: if e.resp.status != 409: @@ -432,23 +391,16 @@ def create_model( field_violation = error_detail['fieldViolations'][0] if ( - field_violation["field"] != "model.name" or - field_violation["description"] != "A model with the same name already exists." + field_violation["field"] != "model.name" + or field_violation["description"] != "A model with the same name already exists." ): raise e - respone = self.get_model( - model_name=model['name'], - project_id=project_id - ) + respone = self.get_model(model_name=model['name'], project_id=project_id) return respone @GoogleBaseHook.fallback_to_default_project_id - def get_model( - self, - model_name: str, - project_id: str, - ) -> Optional[Dict]: + def get_model(self, model_name: str, project_id: str,) -> Optional[Dict]: """ Gets a Model. Blocks until finished. @@ -464,10 +416,8 @@ def get_model( """ hook = self.get_conn() if not model_name: - raise ValueError("Model name must be provided and " - "it could not be an empty string") - full_model_name = 'projects/{}/models/{}'.format( - project_id, model_name) + raise ValueError("Model name must be provided and " "it could not be an empty string") + full_model_name = 'projects/{}/models/{}'.format(project_id, model_name) request = hook.projects().models().get(name=full_model_name) # pylint: disable=no-member try: return request.execute() @@ -478,12 +428,7 @@ def get_model( raise @GoogleBaseHook.fallback_to_default_project_id - def delete_model( - self, - model_name: str, - project_id: str, - delete_contents: bool = False, - ) -> None: + def delete_model(self, model_name: str, project_id: str, delete_contents: bool = False,) -> None: """ Delete a Model. Blocks until finished. 
diff --git a/airflow/providers/google/cloud/hooks/natural_language.py b/airflow/providers/google/cloud/hooks/natural_language.py index d1fc1344de5bc..d69962e6c200a 100644 --- a/airflow/providers/google/cloud/hooks/natural_language.py +++ b/airflow/providers/google/cloud/hooks/natural_language.py @@ -23,8 +23,14 @@ from google.api_core.retry import Retry from google.cloud.language_v1 import LanguageServiceClient, enums from google.cloud.language_v1.types import ( - AnalyzeEntitiesResponse, AnalyzeEntitySentimentResponse, AnalyzeSentimentResponse, AnalyzeSyntaxResponse, - AnnotateTextRequest, AnnotateTextResponse, ClassifyTextResponse, Document, + AnalyzeEntitiesResponse, + AnalyzeEntitySentimentResponse, + AnalyzeSentimentResponse, + AnalyzeSyntaxResponse, + AnnotateTextRequest, + AnnotateTextResponse, + ClassifyTextResponse, + Document, ) from airflow.providers.google.common.hooks.base_google import GoogleBaseHook @@ -58,9 +64,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._conn = None @@ -73,8 +77,7 @@ def get_conn(self) -> LanguageServiceClient: """ if not self._conn: self._conn = LanguageServiceClient( - credentials=self._get_credentials(), - client_info=self.client_info + credentials=self._get_credentials(), client_info=self.client_info ) return self._conn @@ -85,7 +88,7 @@ def analyze_entities( encoding_type: Optional[enums.EncodingType] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> AnalyzeEntitiesResponse: """ Finds named entities in the text along with entity types, @@ -119,7 +122,7 @@ def analyze_entity_sentiment( encoding_type: Optional[enums.EncodingType] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> AnalyzeEntitySentimentResponse: """ Finds entities, similar to AnalyzeEntities in the text and analyzes sentiment associated with each @@ -153,7 +156,7 @@ def analyze_sentiment( encoding_type: Optional[enums.EncodingType] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> AnalyzeSentimentResponse: """ Analyzes the sentiment of the provided text. 
@@ -186,7 +189,7 @@ def analyze_syntax( encoding_type: Optional[enums.EncodingType] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> AnalyzeSyntaxResponse: """ Analyzes the syntax of the text and provides sentence boundaries and tokenization along with part @@ -221,7 +224,7 @@ def annotate_text( encoding_type: enums.EncodingType = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> AnnotateTextResponse: """ A convenience method that provides all the features that analyzeSentiment, @@ -262,7 +265,7 @@ def classify_text( document: Union[Dict, Document], retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> ClassifyTextResponse: """ Classifies a document into categories. diff --git a/airflow/providers/google/cloud/hooks/pubsub.py b/airflow/providers/google/cloud/hooks/pubsub.py index 9830d56e1e5d5..1efc3ba5e0737 100644 --- a/airflow/providers/google/cloud/hooks/pubsub.py +++ b/airflow/providers/google/cloud/hooks/pubsub.py @@ -29,7 +29,12 @@ from google.cloud.exceptions import NotFound from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient from google.cloud.pubsub_v1.types import ( - DeadLetterPolicy, Duration, ExpirationPolicy, MessageStoragePolicy, PushConfig, ReceivedMessage, + DeadLetterPolicy, + Duration, + ExpirationPolicy, + MessageStoragePolicy, + PushConfig, + ReceivedMessage, RetryPolicy, ) from googleapiclient.errors import HttpError @@ -59,9 +64,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None @@ -73,10 +76,7 @@ def get_conn(self) -> PublisherClient: :rtype: google.cloud.pubsub_v1.PublisherClient """ if not self._client: - self._client = PublisherClient( - credentials=self._get_credentials(), - client_info=self.client_info - ) + self._client = PublisherClient(credentials=self._get_credentials(), client_info=self.client_info) return self._client @cached_property @@ -87,18 +87,10 @@ def subscriber_client(self) -> SubscriberClient: :return: Google Cloud Pub/Sub client object. :rtype: google.cloud.pubsub_v1.SubscriberClient """ - return SubscriberClient( - credentials=self._get_credentials(), - client_info=self.client_info - ) + return SubscriberClient(credentials=self._get_credentials(), client_info=self.client_info) @GoogleBaseHook.fallback_to_default_project_id - def publish( - self, - topic: str, - messages: List[Dict], - project_id: str, - ) -> None: + def publish(self, topic: str, messages: List[Dict], project_id: str,) -> None: """ Publishes messages to a Pub/Sub topic. 
@@ -122,9 +114,7 @@ def publish( try: for message in messages: future = publisher.publish( - topic=topic_path, - data=message.get("data", b''), - **message.get('attributes', {}) + topic=topic_path, data=message.get("data", b''), **message.get('attributes', {}) ) future.result() except GoogleAPICallError as e: @@ -142,7 +132,9 @@ def _validate_messages(messages) -> None: b64decode(message["data"]) warnings.warn( "The base 64 encoded string as 'data' field has been deprecated. " - "You should pass bytestring (utf-8 encoded).", DeprecationWarning, stacklevel=4 + "You should pass bytestring (utf-8 encoded).", + DeprecationWarning, + stacklevel=4, ) except ValueError: pass @@ -153,10 +145,12 @@ def _validate_messages(messages) -> None: raise PubSubException("Wrong message. Dictionary must contain 'data' or 'attributes'.") if "data" in message and not isinstance(message["data"], bytes): raise PubSubException("Wrong message. 'data' must be send as a bytestring") - if ("data" not in message and "attributes" in message and not message["attributes"]) \ - or ("attributes" in message and not isinstance(message["attributes"], dict)): + if ("data" not in message and "attributes" in message and not message["attributes"]) or ( + "attributes" in message and not isinstance(message["attributes"], dict) + ): raise PubSubException( - "Wrong message. If 'data' is not provided 'attributes' must be a non empty dictionary.") + "Wrong message. If 'data' is not provided 'attributes' must be a non empty dictionary." + ) # pylint: disable=too-many-arguments @GoogleBaseHook.fallback_to_default_project_id @@ -275,10 +269,7 @@ def delete_topic( try: # pylint: disable=no-member publisher.delete_topic( - topic=topic_path, - retry=retry, - timeout=timeout, - metadata=metadata, + topic=topic_path, retry=retry, timeout=timeout, metadata=metadata, ) except NotFound: self.log.warning('Topic does not exist: %s', topic_path) @@ -467,16 +458,14 @@ def delete_subscription( :type metadata: Sequence[Tuple[str, str]]] """ subscriber = self.subscriber_client - subscription_path = SubscriberClient.subscription_path(project_id, subscription) # noqa E501 # pylint: disable=no-member,line-too-long + # noqa E501 # pylint: disable=no-member + subscription_path = SubscriberClient.subscription_path(project_id, subscription) self.log.info("Deleting subscription (path) %s", subscription_path) try: # pylint: disable=no-member subscriber.delete_subscription( - subscription=subscription_path, - retry=retry, - timeout=timeout, - metadata=metadata + subscription=subscription_path, retry=retry, timeout=timeout, metadata=metadata ) except NotFound: @@ -530,7 +519,8 @@ def pull( https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/pull#ReceivedMessage """ subscriber = self.subscriber_client - subscription_path = SubscriberClient.subscription_path(project_id, subscription) # noqa E501 # pylint: disable=no-member,line-too-long + # noqa E501 # pylint: disable=no-member,line-too-long + subscription_path = SubscriberClient.subscription_path(project_id, subscription) self.log.info("Pulling max %d messages from subscription (path) %s", max_messages, subscription_path) try: @@ -588,15 +578,13 @@ def acknowledge( if ack_ids is not None and messages is None: pass elif ack_ids is None and messages is not None: - ack_ids = [ - message.ack_id - for message in messages - ] + ack_ids = [message.ack_id for message in messages] else: raise ValueError("One and only one of 'ack_ids' and 'messages' arguments have to be provided") subscriber = 
self.subscriber_client - subscription_path = SubscriberClient.subscription_path(project_id, subscription) # noqa E501 # pylint: disable=no-member,line-too-long + # noqa E501 # pylint: disable=no-member + subscription_path = SubscriberClient.subscription_path(project_id, subscription) self.log.info("Acknowledging %d ack_ids from subscription (path) %s", len(ack_ids), subscription_path) try: @@ -610,7 +598,10 @@ def acknowledge( ) except (HttpError, GoogleAPICallError) as e: raise PubSubException( - 'Error acknowledging {} messages pulled from subscription {}' - .format(len(ack_ids), subscription_path), e) + 'Error acknowledging {} messages pulled from subscription {}'.format( + len(ack_ids), subscription_path + ), + e, + ) self.log.info("Acknowledged ack_ids from subscription (path) %s", subscription_path) diff --git a/airflow/providers/google/cloud/hooks/secret_manager.py b/airflow/providers/google/cloud/hooks/secret_manager.py index 422797e081100..a905bfd5152c5 100644 --- a/airflow/providers/google/cloud/hooks/secret_manager.py +++ b/airflow/providers/google/cloud/hooks/secret_manager.py @@ -47,6 +47,7 @@ class SecretsManagerHook(GoogleBaseHook): account from the list granting this role to the originating account. :type impersonation_chain: Union[str, Sequence[str]] """ + def __init__( self, gcp_conn_id: str = "google_cloud_default", @@ -54,9 +55,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.client = _SecretManagerClient(credentials=self._get_credentials()) @@ -70,9 +69,9 @@ def get_conn(self) -> _SecretManagerClient: return self.client @GoogleBaseHook.fallback_to_default_project_id - def get_secret(self, secret_id: str, - secret_version: str = 'latest', - project_id: Optional[str] = None) -> Optional[str]: + def get_secret( + self, secret_id: str, secret_version: str = 'latest', project_id: Optional[str] = None + ) -> Optional[str]: """ Get secret value from the Secret Manager. 
@@ -83,5 +82,6 @@ def get_secret(self, secret_id: str, :param project_id: Project id (if you want to override the project_id from credentials) :type project_id: str """ - return self.get_conn().get_secret(secret_id=secret_id, secret_version=secret_version, - project_id=project_id) # type: ignore + return self.get_conn().get_secret( + secret_id=secret_id, secret_version=secret_version, project_id=project_id # type: ignore + ) diff --git a/airflow/providers/google/cloud/hooks/spanner.py b/airflow/providers/google/cloud/hooks/spanner.py index 77607b6158462..d9317e535f1c8 100644 --- a/airflow/providers/google/cloud/hooks/spanner.py +++ b/airflow/providers/google/cloud/hooks/spanner.py @@ -46,9 +46,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None @@ -63,18 +61,12 @@ def _get_client(self, project_id: str) -> Client: """ if not self._client: self._client = Client( - project=project_id, - credentials=self._get_credentials(), - client_info=self.client_info + project=project_id, credentials=self._get_credentials(), client_info=self.client_info ) return self._client @GoogleBaseHook.fallback_to_default_project_id - def get_instance( - self, - instance_id: str, - project_id: str, - ) -> Instance: + def get_instance(self, instance_id: str, project_id: str,) -> Instance: """ Gets information about a particular instance. @@ -92,12 +84,13 @@ def get_instance( return instance def _apply_to_instance( - self, project_id: str, + self, + project_id: str, instance_id: str, configuration_name: str, node_count: int, display_name: str, - func: Callable[[Instance], Operation] + func: Callable[[Instance], Operation], ) -> None: """ Invokes a method on a given instance by applying a specified Callable. @@ -120,8 +113,11 @@ def _apply_to_instance( :type func: Callable[google.cloud.spanner_v1.instance.Instance] """ instance = self._get_client(project_id=project_id).instance( - instance_id=instance_id, configuration_name=configuration_name, - node_count=node_count, display_name=display_name) + instance_id=instance_id, + configuration_name=configuration_name, + node_count=node_count, + display_name=display_name, + ) try: operation = func(instance) # type: Operation except GoogleAPICallError as e: @@ -134,12 +130,7 @@ def _apply_to_instance( @GoogleBaseHook.fallback_to_default_project_id def create_instance( - self, - instance_id: str, - configuration_name: str, - node_count: int, - display_name: str, - project_id: str, + self, instance_id: str, configuration_name: str, node_count: int, display_name: str, project_id: str, ) -> None: """ Creates a new Cloud Spanner instance. @@ -162,17 +153,13 @@ def create_instance( :type project_id: str :return: None """ - self._apply_to_instance(project_id, instance_id, configuration_name, - node_count, display_name, lambda x: x.create()) + self._apply_to_instance( + project_id, instance_id, configuration_name, node_count, display_name, lambda x: x.create() + ) @GoogleBaseHook.fallback_to_default_project_id def update_instance( - self, - instance_id: str, - configuration_name: str, - node_count: int, - display_name: str, - project_id: str, + self, instance_id: str, configuration_name: str, node_count: int, display_name: str, project_id: str, ) -> None: """ Updates an existing Cloud Spanner instance. 
@@ -195,8 +182,9 @@ def update_instance( :type project_id: str :return: None """ - self._apply_to_instance(project_id, instance_id, configuration_name, - node_count, display_name, lambda x: x.update()) + self._apply_to_instance( + project_id, instance_id, configuration_name, node_count, display_name, lambda x: x.update() + ) @GoogleBaseHook.fallback_to_default_project_id def delete_instance(self, instance_id: str, project_id: str) -> None: @@ -219,12 +207,7 @@ def delete_instance(self, instance_id: str, project_id: str) -> None: raise e @GoogleBaseHook.fallback_to_default_project_id - def get_database( - self, - instance_id: str, - database_id: str, - project_id: str, - ) -> Optional[Database]: + def get_database(self, instance_id: str, database_id: str, project_id: str,) -> Optional[Database]: """ Retrieves a database in Cloud Spanner. If the database does not exist in the specified instance, it returns None. @@ -239,11 +222,11 @@ def get_database( :return: Database object or None if database does not exist :rtype: google.cloud.spanner_v1.database.Database or None """ - instance = self._get_client(project_id=project_id).instance( - instance_id=instance_id) + instance = self._get_client(project_id=project_id).instance(instance_id=instance_id) if not instance.exists(): - raise AirflowException("The instance {} does not exist in project {} !". - format(instance_id, project_id)) + raise AirflowException( + "The instance {} does not exist in project {} !".format(instance_id, project_id) + ) database = instance.database(database_id=database_id) if not database.exists(): return None @@ -252,11 +235,7 @@ def get_database( @GoogleBaseHook.fallback_to_default_project_id def create_database( - self, - instance_id: str, - database_id: str, - ddl_statements: List[str], - project_id: str, + self, instance_id: str, database_id: str, ddl_statements: List[str], project_id: str, ) -> None: """ Creates a new database in Cloud Spanner. @@ -272,13 +251,12 @@ def create_database( database. If set to None or missing, the default project_id from the GCP connection is used. :return: None """ - instance = self._get_client(project_id=project_id).instance( - instance_id=instance_id) + instance = self._get_client(project_id=project_id).instance(instance_id=instance_id) if not instance.exists(): - raise AirflowException("The instance {} does not exist in project {} !". - format(instance_id, project_id)) - database = instance.database(database_id=database_id, - ddl_statements=ddl_statements) + raise AirflowException( + "The instance {} does not exist in project {} !".format(instance_id, project_id) + ) + database = instance.database(database_id=database_id, ddl_statements=ddl_statements) try: operation = database.create() # type: Operation except GoogleAPICallError as e: @@ -296,7 +274,7 @@ def update_database( database_id: str, ddl_statements: List[str], project_id: str, - operation_id: Optional[str] = None + operation_id: Optional[str] = None, ) -> None: """ Updates DDL of a database in Cloud Spanner. @@ -315,23 +293,24 @@ def update_database( :type operation_id: str :return: None """ - instance = self._get_client(project_id=project_id).instance( - instance_id=instance_id) + instance = self._get_client(project_id=project_id).instance(instance_id=instance_id) if not instance.exists(): - raise AirflowException("The instance {} does not exist in project {} !". 
- format(instance_id, project_id)) + raise AirflowException( + "The instance {} does not exist in project {} !".format(instance_id, project_id) + ) database = instance.database(database_id=database_id) try: - operation = database.update_ddl( - ddl_statements=ddl_statements, operation_id=operation_id) + operation = database.update_ddl(ddl_statements=ddl_statements, operation_id=operation_id) if operation: result = operation.result() self.log.info(result) return except AlreadyExists as e: if e.code == 409 and operation_id in e.message: - self.log.info("Replayed update_ddl message - the operation id %s " - "was already done before.", operation_id) + self.log.info( + "Replayed update_ddl message - the operation id %s " "was already done before.", + operation_id, + ) return except GoogleAPICallError as e: self.log.error('An error occurred: %s. Exiting.', e.message) @@ -352,16 +331,15 @@ def delete_database(self, instance_id: str, database_id, project_id: str) -> boo :return: True if everything succeeded :rtype: bool """ - instance = self._get_client(project_id=project_id).\ - instance(instance_id=instance_id) + instance = self._get_client(project_id=project_id).instance(instance_id=instance_id) if not instance.exists(): - raise AirflowException("The instance {} does not exist in project {} !". - format(instance_id, project_id)) + raise AirflowException( + "The instance {} does not exist in project {} !".format(instance_id, project_id) + ) database = instance.database(database_id=database_id) if not database.exists(): self.log.info( - "The database %s is already deleted from instance %s. Exiting.", - database_id, instance_id + "The database %s is already deleted from instance %s. Exiting.", database_id, instance_id ) return False try: @@ -373,13 +351,7 @@ def delete_database(self, instance_id: str, database_id, project_id: str) -> boo return True @GoogleBaseHook.fallback_to_default_project_id - def execute_dml( - self, - instance_id: str, - database_id: str, - queries: List[str], - project_id: str, - ) -> None: + def execute_dml(self, instance_id: str, database_id: str, queries: List[str], project_id: str,) -> None: """ Executes an arbitrary DML query (INSERT, UPDATE, DELETE). @@ -393,9 +365,9 @@ def execute_dml( database. If set to None or missing, the default project_id from the GCP connection is used. 
:type project_id: str """ - self._get_client(project_id=project_id).instance(instance_id=instance_id).\ - database(database_id=database_id).run_in_transaction( - lambda transaction: self._execute_sql_in_transaction(transaction, queries)) + self._get_client(project_id=project_id).instance(instance_id=instance_id).database( + database_id=database_id + ).run_in_transaction(lambda transaction: self._execute_sql_in_transaction(transaction, queries)) @staticmethod def _execute_sql_in_transaction(transaction: Transaction, queries: List[str]): diff --git a/airflow/providers/google/cloud/hooks/speech_to_text.py b/airflow/providers/google/cloud/hooks/speech_to_text.py index 6cb943cfcdcb5..b0ef03812da81 100644 --- a/airflow/providers/google/cloud/hooks/speech_to_text.py +++ b/airflow/providers/google/cloud/hooks/speech_to_text.py @@ -55,9 +55,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None @@ -78,7 +76,7 @@ def recognize_speech( config: Union[Dict, RecognitionConfig], audio: Union[Dict, RecognitionAudio], retry: Optional[Retry] = None, - timeout: Optional[float] = None + timeout: Optional[float] = None, ): """ Recognizes audio input diff --git a/airflow/providers/google/cloud/hooks/stackdriver.py b/airflow/providers/google/cloud/hooks/stackdriver.py index 2aec897943a06..f814458ba6017 100644 --- a/airflow/providers/google/cloud/hooks/stackdriver.py +++ b/airflow/providers/google/cloud/hooks/stackdriver.py @@ -45,9 +45,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._policy_client = None self._channel_client = None @@ -72,7 +70,7 @@ def list_alert_policies( page_size: Optional[int] = None, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ) -> Any: """ Fetches all the Alert Policies identified by the filter passed as @@ -137,7 +135,7 @@ def _toggle_policy_status( filter_: Optional[str] = None, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ): client = self._get_policy_client() policies_ = self.list_alert_policies(project_id=project_id, filter_=filter_) @@ -147,11 +145,7 @@ def _toggle_policy_status( mask = monitoring_v3.types.field_mask_pb2.FieldMask() mask.paths.append('enabled') # pylint: disable=no-member client.update_alert_policy( - alert_policy=policy, - update_mask=mask, - retry=retry, - timeout=timeout, - metadata=metadata + alert_policy=policy, update_mask=mask, retry=retry, timeout=timeout, metadata=metadata ) @GoogleBaseHook.fallback_to_default_project_id @@ -161,7 +155,7 @@ def enable_alert_policies( filter_: Optional[str] = None, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ) -> None: """ Enables one or more disabled alerting policies identified by filter @@ -189,7 +183,7 @@ def enable_alert_policies( filter_=filter_, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) 
@GoogleBaseHook.fallback_to_default_project_id @@ -199,7 +193,7 @@ def disable_alert_policies( filter_: Optional[str] = None, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ) -> None: """ Disables one or more enabled alerting policies identified by filter @@ -227,7 +221,7 @@ def disable_alert_policies( new_state=False, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) @GoogleBaseHook.fallback_to_default_project_id @@ -237,7 +231,7 @@ def upsert_alert( project_id: str, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ) -> None: """ Creates a new alert or updates an existing policy identified @@ -264,10 +258,13 @@ def upsert_alert( channel_client = self._get_channel_client() record = json.loads(alerts) - existing_policies = [policy['name'] for policy in - self.list_alert_policies(project_id=project_id, format_='dict')] - existing_channels = [channel['name'] for channel in - self.list_notification_channels(project_id=project_id, format_='dict')] + existing_policies = [ + policy['name'] for policy in self.list_alert_policies(project_id=project_id, format_='dict') + ] + existing_channels = [ + channel['name'] + for channel in self.list_notification_channels(project_id=project_id, format_='dict') + ] policies_ = [] channels = [] @@ -281,15 +278,13 @@ def upsert_alert( channel_name_map = {} for channel in channels: - channel.verification_status = monitoring_v3.enums.NotificationChannel. \ - VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED + channel.verification_status = ( + monitoring_v3.enums.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED + ) if channel.name in existing_channels: channel_client.update_notification_channel( - notification_channel=channel, - retry=retry, - timeout=timeout, - metadata=metadata + notification_channel=channel, retry=retry, timeout=timeout, metadata=metadata ) else: old_name = channel.name @@ -299,7 +294,7 @@ def upsert_alert( notification_channel=channel, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) channel_name_map[old_name] = new_channel.name @@ -315,10 +310,7 @@ def upsert_alert( if policy.name in existing_policies: try: policy_client.update_alert_policy( - alert_policy=policy, - retry=retry, - timeout=timeout, - metadata=metadata + alert_policy=policy, retry=retry, timeout=timeout, metadata=metadata ) except InvalidArgument: pass @@ -331,7 +323,7 @@ def upsert_alert( alert_policy=policy, retry=retry, timeout=timeout, - metadata=None + metadata=None, ) def delete_alert_policy( @@ -339,7 +331,7 @@ def delete_alert_policy( name: str, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ) -> None: """ Deletes an alerting policy. @@ -360,16 +352,9 @@ def delete_alert_policy( policy_client = self._get_policy_client() try: - policy_client.delete_alert_policy( - name=name, - retry=retry, - timeout=timeout, - metadata=metadata - ) + policy_client.delete_alert_policy(name=name, retry=retry, timeout=timeout, metadata=metadata) except HttpError as err: - raise AirflowException( - 'Delete alerting policy failed. Error was {}'.format(err.content) - ) + raise AirflowException('Delete alerting policy failed. 
Error was {}'.format(err.content)) @GoogleBaseHook.fallback_to_default_project_id def list_notification_channels( @@ -447,12 +432,11 @@ def _toggle_channel_status( filter_: Optional[str] = None, retry: Optional[str] = DEFAULT, timeout: Optional[str] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ) -> None: client = self._get_channel_client() channels = client.list_notification_channels( - name='projects/{project_id}'.format(project_id=project_id), - filter_=filter_ + name='projects/{project_id}'.format(project_id=project_id), filter_=filter_ ) for channel in channels: if channel.enabled.value != bool(new_state): @@ -464,7 +448,7 @@ def _toggle_channel_status( update_mask=mask, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) @GoogleBaseHook.fallback_to_default_project_id @@ -474,7 +458,7 @@ def enable_notification_channels( filter_: Optional[str] = None, retry: Optional[str] = DEFAULT, timeout: Optional[str] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ) -> None: """ Enables one or more disabled alerting policies identified by filter @@ -503,7 +487,7 @@ def enable_notification_channels( new_state=True, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) @GoogleBaseHook.fallback_to_default_project_id @@ -513,7 +497,7 @@ def disable_notification_channels( filter_: Optional[str] = None, retry: Optional[str] = DEFAULT, timeout: Optional[str] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ) -> None: """ Disables one or more enabled notification channels identified by filter @@ -542,7 +526,7 @@ def disable_notification_channels( new_state=False, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) @GoogleBaseHook.fallback_to_default_project_id @@ -552,7 +536,7 @@ def upsert_channel( project_id: str, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ) -> Dict: """ Creates a new notification or updates an existing notification channel @@ -579,8 +563,10 @@ def upsert_channel( channel_client = self._get_channel_client() record = json.loads(channels) - existing_channels = [channel["name"] for channel in - self.list_notification_channels(project_id=project_id, format_="dict")] + existing_channels = [ + channel["name"] + for channel in self.list_notification_channels(project_id=project_id, format_="dict") + ] channels_list = [] channel_name_map = {} @@ -591,15 +577,13 @@ def upsert_channel( ) for channel in channels_list: - channel.verification_status = monitoring_v3.enums.NotificationChannel. 
\ - VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED + channel.verification_status = ( + monitoring_v3.enums.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED + ) if channel.name in existing_channels: channel_client.update_notification_channel( - notification_channel=channel, - retry=retry, - timeout=timeout, - metadata=metadata + notification_channel=channel, retry=retry, timeout=timeout, metadata=metadata ) else: old_name = channel.name @@ -609,7 +593,7 @@ def upsert_channel( notification_channel=channel, retry=retry, timeout=timeout, - metadata=metadata + metadata=metadata, ) channel_name_map[old_name] = new_channel.name @@ -620,7 +604,7 @@ def delete_notification_channel( name: str, retry: Optional[str] = DEFAULT, timeout: Optional[str] = DEFAULT, - metadata: Optional[str] = None + metadata: Optional[str] = None, ) -> None: """ Deletes a notification channel. @@ -642,12 +626,7 @@ def delete_notification_channel( channel_client = self._get_channel_client() try: channel_client.delete_notification_channel( - name=name, - retry=retry, - timeout=timeout, - metadata=metadata + name=name, retry=retry, timeout=timeout, metadata=metadata ) except HttpError as err: - raise AirflowException( - 'Delete notification channel failed. Error was {}'.format(err.content) - ) + raise AirflowException('Delete notification channel failed. Error was {}'.format(err.content)) diff --git a/airflow/providers/google/cloud/hooks/tasks.py b/airflow/providers/google/cloud/hooks/tasks.py index e324012a32d7e..4dccad7d81c1a 100644 --- a/airflow/providers/google/cloud/hooks/tasks.py +++ b/airflow/providers/google/cloud/hooks/tasks.py @@ -63,9 +63,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None @@ -77,10 +75,7 @@ def get_conn(self): :rtype: google.cloud.tasks_v2.CloudTasksClient """ if not self._client: - self._client = CloudTasksClient( - credentials=self._get_credentials(), - client_info=self.client_info - ) + self._client = CloudTasksClient(credentials=self._get_credentials(), client_info=self.client_info) return self._client @GoogleBaseHook.fallback_to_default_project_id @@ -92,7 +87,7 @@ def create_queue( queue_name: Optional[str] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> Queue: """ Creates a queue in Cloud Tasks. @@ -133,11 +128,7 @@ def create_queue( raise AirflowException('Unable to set queue_name.') full_location_path = CloudTasksClient.location_path(project_id, location) return client.create_queue( - parent=full_location_path, - queue=task_queue, - retry=retry, - timeout=timeout, - metadata=metadata, + parent=full_location_path, queue=task_queue, retry=retry, timeout=timeout, metadata=metadata, ) @GoogleBaseHook.fallback_to_default_project_id @@ -150,7 +141,7 @@ def update_queue( update_mask: Optional[FieldMask] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> Queue: """ Updates a queue in Cloud Tasks. 
@@ -195,11 +186,7 @@ def update_queue( else: raise AirflowException('Unable to set queue_name.') return client.update_queue( - queue=task_queue, - update_mask=update_mask, - retry=retry, - timeout=timeout, - metadata=metadata, + queue=task_queue, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata, ) @GoogleBaseHook.fallback_to_default_project_id @@ -210,7 +197,7 @@ def get_queue( project_id: str, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> Queue: """ Gets a queue from Cloud Tasks. @@ -237,9 +224,7 @@ def get_queue( client = self.get_conn() full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) - return client.get_queue( - name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata - ) + return client.get_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata) @GoogleBaseHook.fallback_to_default_project_id def list_queues( @@ -250,7 +235,7 @@ def list_queues( page_size: Optional[int] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> List[Queue]: """ Lists queues from Cloud Tasks. @@ -298,7 +283,7 @@ def delete_queue( project_id: str, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> None: """ Deletes a queue from Cloud Tasks, even if it has tasks in it. @@ -324,9 +309,7 @@ def delete_queue( client = self.get_conn() full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) - client.delete_queue( - name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata - ) + client.delete_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata) @GoogleBaseHook.fallback_to_default_project_id def purge_queue( @@ -336,7 +319,7 @@ def purge_queue( project_id: str, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> List[Queue]: """ Purges a queue by deleting all of its tasks from Cloud Tasks. @@ -363,9 +346,7 @@ def purge_queue( client = self.get_conn() full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) - return client.purge_queue( - name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata - ) + return client.purge_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata) @GoogleBaseHook.fallback_to_default_project_id def pause_queue( @@ -375,7 +356,7 @@ def pause_queue( project_id: str, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> List[Queue]: """ Pauses a queue in Cloud Tasks. 
@@ -402,9 +383,7 @@ def pause_queue( client = self.get_conn() full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) - return client.pause_queue( - name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata - ) + return client.pause_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata) @GoogleBaseHook.fallback_to_default_project_id def resume_queue( @@ -414,7 +393,7 @@ def resume_queue( project_id: str, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> List[Queue]: """ Resumes a queue in Cloud Tasks. @@ -441,9 +420,7 @@ def resume_queue( client = self.get_conn() full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) - return client.resume_queue( - name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata - ) + return client.resume_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata) @GoogleBaseHook.fallback_to_default_project_id def create_task( @@ -456,7 +433,7 @@ def create_task( response_view: Optional[enums.Task.View] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> Task: """ Creates a task in Cloud Tasks. @@ -492,9 +469,7 @@ def create_task( client = self.get_conn() if task_name: - full_task_name = CloudTasksClient.task_path( - project_id, location, queue_name, task_name - ) + full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) if isinstance(task, Task): task.name = full_task_name elif isinstance(task, dict): @@ -521,7 +496,7 @@ def get_task( response_view: Optional[enums.Task.View] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> Task: """ Gets a task from Cloud Tasks. @@ -554,11 +529,7 @@ def get_task( full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) return client.get_task( - name=full_task_name, - response_view=response_view, - retry=retry, - timeout=timeout, - metadata=metadata, + name=full_task_name, response_view=response_view, retry=retry, timeout=timeout, metadata=metadata, ) @GoogleBaseHook.fallback_to_default_project_id @@ -571,7 +542,7 @@ def list_tasks( page_size: Optional[int] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> List[Task]: """ Lists the tasks in Cloud Tasks. @@ -622,7 +593,7 @@ def delete_task( project_id: str, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> None: """ Deletes a task from Cloud Tasks. 
@@ -650,9 +621,7 @@ def delete_task( client = self.get_conn() full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) - client.delete_task( - name=full_task_name, retry=retry, timeout=timeout, metadata=metadata - ) + client.delete_task(name=full_task_name, retry=retry, timeout=timeout, metadata=metadata) @GoogleBaseHook.fallback_to_default_project_id def run_task( @@ -664,7 +633,7 @@ def run_task( response_view: Optional[enums.Task.View] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> Task: """ Forces to run a task in Cloud Tasks. @@ -697,9 +666,5 @@ def run_task( full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) return client.run_task( - name=full_task_name, - response_view=response_view, - retry=retry, - timeout=timeout, - metadata=metadata, + name=full_task_name, response_view=response_view, retry=retry, timeout=timeout, metadata=metadata, ) diff --git a/airflow/providers/google/cloud/hooks/text_to_speech.py b/airflow/providers/google/cloud/hooks/text_to_speech.py index 7e16cd3cfd4b3..06afe0ddf2d0a 100644 --- a/airflow/providers/google/cloud/hooks/text_to_speech.py +++ b/airflow/providers/google/cloud/hooks/text_to_speech.py @@ -23,7 +23,10 @@ from google.api_core.retry import Retry from google.cloud.texttospeech_v1 import TextToSpeechClient from google.cloud.texttospeech_v1.types import ( - AudioConfig, SynthesisInput, SynthesizeSpeechResponse, VoiceSelectionParams, + AudioConfig, + SynthesisInput, + SynthesizeSpeechResponse, + VoiceSelectionParams, ) from airflow.providers.google.common.hooks.base_google import GoogleBaseHook @@ -60,9 +63,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None # type: Optional[TextToSpeechClient] @@ -76,8 +77,7 @@ def get_conn(self) -> TextToSpeechClient: if not self._client: # pylint: disable=unexpected-keyword-arg self._client = TextToSpeechClient( - credentials=self._get_credentials(), - client_info=self.client_info + credentials=self._get_credentials(), client_info=self.client_info ) # pylint: enable=unexpected-keyword-arg @@ -90,7 +90,7 @@ def synthesize_speech( voice: Union[Dict, VoiceSelectionParams], audio_config: Union[Dict, AudioConfig], retry: Optional[Retry] = None, - timeout: Optional[float] = None + timeout: Optional[float] = None, ) -> SynthesizeSpeechResponse: """ Synthesizes text input diff --git a/airflow/providers/google/cloud/hooks/translate.py b/airflow/providers/google/cloud/hooks/translate.py index e3bc978100fbc..4c06a2de82e79 100644 --- a/airflow/providers/google/cloud/hooks/translate.py +++ b/airflow/providers/google/cloud/hooks/translate.py @@ -40,9 +40,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None # type: Optional[Client] @@ -64,7 +62,7 @@ def translate( target_language: str, format_: Optional[str] = None, source_language: Optional[str] = None, - model: Optional[Union[str, 
List[str]]] = None + model: Optional[Union[str, List[str]]] = None, ) -> Dict: """Translate a string or list of strings. diff --git a/airflow/providers/google/cloud/hooks/video_intelligence.py b/airflow/providers/google/cloud/hooks/video_intelligence.py index a1f1c3f8fe184..421c2fcfe9d4a 100644 --- a/airflow/providers/google/cloud/hooks/video_intelligence.py +++ b/airflow/providers/google/cloud/hooks/video_intelligence.py @@ -59,9 +59,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._conn = None @@ -73,8 +71,7 @@ def get_conn(self) -> VideoIntelligenceServiceClient: """ if not self._conn: self._conn = VideoIntelligenceServiceClient( - credentials=self._get_credentials(), - client_info=self.client_info + credentials=self._get_credentials(), client_info=self.client_info ) return self._conn diff --git a/airflow/providers/google/cloud/hooks/vision.py b/airflow/providers/google/cloud/hooks/vision.py index 84c3df58ea4a4..9665db95e10e8 100644 --- a/airflow/providers/google/cloud/hooks/vision.py +++ b/airflow/providers/google/cloud/hooks/vision.py @@ -26,23 +26,26 @@ from google.api_core.retry import Retry from google.cloud.vision_v1 import ImageAnnotatorClient, ProductSearchClient from google.cloud.vision_v1.types import ( - AnnotateImageRequest, FieldMask, Image, Product, ProductSet, ReferenceImage, + AnnotateImageRequest, + FieldMask, + Image, + Product, + ProductSet, + ReferenceImage, ) from google.protobuf.json_format import MessageToDict from airflow.exceptions import AirflowException from airflow.providers.google.common.hooks.base_google import GoogleBaseHook -ERR_DIFF_NAMES = \ - """The {label} name provided in the object ({explicit_name}) is different than the name created - from the input parameters ({constructed_name}). Please either: +ERR_DIFF_NAMES = """The {label} name provided in the object ({explicit_name}) is different + than the name created from the input parameters ({constructed_name}). Please either: 1) Remove the {label} name, 2) Remove the location and {id_label} parameters, 3) Unify the {label} name and input parameters. """ -ERR_UNABLE_TO_CREATE = \ - """Unable to determine the {label} name. Please either set the name directly +ERR_UNABLE_TO_CREATE = """Unable to determine the {label} name. Please either set the name directly in the {label} object or provide the `location` and `{id_label}` parameters. """ @@ -51,17 +54,14 @@ class NameDeterminer: """ Helper class to determine entity name. 
""" + def __init__(self, label: str, id_label: str, get_path: Callable[[str, str, str], str]) -> None: self.label = label self.id_label = id_label self.get_path = get_path def get_entity_with_name( - self, - entity: Any, - entity_id: Optional[str], - location: Optional[str], - project_id: str + self, entity: Any, entity_id: Optional[str], location: Optional[str], project_id: str ) -> Any: """ Check if entity has the `name` attribute set: @@ -100,20 +100,20 @@ def get_entity_with_name( return entity if explicit_name != constructed_name: - raise AirflowException(ERR_DIFF_NAMES.format( - label=self.label, - explicit_name=explicit_name, - constructed_name=constructed_name, - id_label=self.id_label) + raise AirflowException( + ERR_DIFF_NAMES.format( + label=self.label, + explicit_name=explicit_name, + constructed_name=constructed_name, + id_label=self.id_label, + ) ) # Not enough parameters to construct the name. Trying to use the name from Product / ProductSet. if explicit_name: return entity else: - raise AirflowException( - ERR_UNABLE_TO_CREATE.format(label=self.label, id_label=self.id_label) - ) + raise AirflowException(ERR_UNABLE_TO_CREATE.format(label=self.label, id_label=self.id_label)) class CloudVisionHook(GoogleBaseHook): @@ -136,9 +136,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self._client = None @@ -151,8 +149,7 @@ def get_conn(self) -> ProductSearchClient: """ if not self._client: self._client = ProductSearchClient( - credentials=self._get_credentials(), - client_info=self.client_info + credentials=self._get_credentials(), client_info=self.client_info ) return self._client @@ -214,7 +211,8 @@ def get_product_set( product_set_id: str, project_id: str, retry: Optional[Retry] = None, - timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = None + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> Dict: """ For the documentation see: @@ -264,7 +262,7 @@ def delete_product_set( project_id: str, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ): """ For the documentation see: @@ -285,7 +283,7 @@ def create_product( product_id: Optional[str] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ): """ For the documentation see: @@ -320,7 +318,7 @@ def get_product( project_id: str, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ): """ For the documentation see: @@ -368,7 +366,7 @@ def delete_product( project_id: str, retry: Optional[Retry] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[str, str]]] = None + metadata: Optional[Sequence[Tuple[str, str]]] = None, ): """ For the documentation see: @@ -441,10 +439,8 @@ def delete_reference_image( name = ProductSearchClient.reference_image_path( project=project_id, location=location, product=product_id, reference_image=reference_image_id ) - response = client.delete_reference_image(name=name, # pylint: 
disable=assignment-from-no-return - retry=retry, - timeout=timeout, - metadata=metadata) + # pylint: disable=assignment-from-no-return + response = client.delete_reference_image(name=name, retry=retry, timeout=timeout, metadata=metadata,) self.log.info('ReferenceImage with the name [%s] deleted.', name) return MessageToDict(response) @@ -509,7 +505,7 @@ def annotate_image( self, request: Union[dict, AnnotateImageRequest], retry: Optional[Retry] = None, - timeout: Optional[float] = None + timeout: Optional[float] = None, ) -> Dict: """ For the documentation see: @@ -531,7 +527,7 @@ def batch_annotate_images( self, requests: Union[List[dict], List[AnnotateImageRequest]], retry: Optional[Retry] = None, - timeout: Optional[float] = None + timeout: Optional[float] = None, ) -> Dict: """ For the documentation see: @@ -541,9 +537,9 @@ def batch_annotate_images( self.log.info('Annotating images') - response = client.batch_annotate_images(requests=requests, # pylint: disable=no-member - retry=retry, - timeout=timeout) + response = client.batch_annotate_images( + requests=requests, retry=retry, timeout=timeout # pylint: disable=no-member + ) self.log.info('Images annotated') @@ -556,7 +552,7 @@ def text_detection( max_results: Optional[int] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - additional_properties: Optional[Dict] = None + additional_properties: Optional[Dict] = None, ) -> Dict: """ For the documentation see: @@ -586,7 +582,7 @@ def document_text_detection( max_results: Optional[int] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - additional_properties: Optional[Dict] = None + additional_properties: Optional[Dict] = None, ) -> Dict: """ For the documentation see: @@ -616,7 +612,7 @@ def label_detection( max_results: Optional[int] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - additional_properties: Optional[Dict] = None + additional_properties: Optional[Dict] = None, ) -> Dict: """ For the documentation see: @@ -646,7 +642,7 @@ def safe_search_detection( max_results: Optional[int] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, - additional_properties: Optional[Dict] = None + additional_properties: Optional[Dict] = None, ) -> Dict: """ For the documentation see: diff --git a/airflow/providers/google/cloud/log/gcs_task_handler.py b/airflow/providers/google/cloud/log/gcs_task_handler.py index 077282c22fd48..99a9a55612a44 100644 --- a/airflow/providers/google/cloud/log/gcs_task_handler.py +++ b/airflow/providers/google/cloud/log/gcs_task_handler.py @@ -27,9 +27,7 @@ from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import LoggingMixin -_DEFAULT_SCOPESS = frozenset([ - "https://www.googleapis.com/auth/devstorage.read_write", -]) +_DEFAULT_SCOPESS = frozenset(["https://www.googleapis.com/auth/devstorage.read_write",]) class GCSTaskHandler(FileTaskHandler, LoggingMixin): @@ -59,6 +57,7 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin): will be used. 
:type project_id: str """ + def __init__( self, *, @@ -89,12 +88,12 @@ def client(self) -> storage.Client: key_path=self.gcp_key_path, keyfile_dict=self.gcp_keyfile_dict, scopes=self.scopes, - disable_logging=True + disable_logging=True, ) return storage.Client( credentials=credentials, client_info=ClientInfo(client_library_version='airflow_v' + version.version), - project=self.project_id if self.project_id else project_id + project=self.project_id if self.project_id else project_id, ) def set_context(self, ti): @@ -151,12 +150,10 @@ def _read(self, ti, try_number, metadata=None): try: blob = storage.Blob.from_string(remote_loc, self.client) remote_log = blob.download_as_string() - log = '*** Reading remote log from {}.\n{}\n'.format( - remote_loc, remote_log) + log = '*** Reading remote log from {}.\n{}\n'.format(remote_loc, remote_log) return log, {'end_of_log': True} except Exception as e: # pylint: disable=broad-except - log = '*** Unable to read remote log from {}\n*** {}\n\n'.format( - remote_loc, str(e)) + log = '*** Unable to read remote log from {}\n*** {}\n\n'.format(remote_loc, str(e)) self.log.error(log) local_log, metadata = super()._read(ti, try_number) log += local_log diff --git a/airflow/providers/google/cloud/log/stackdriver_task_handler.py b/airflow/providers/google/cloud/log/stackdriver_task_handler.py index 2161de98f243f..647fc3c12fc41 100644 --- a/airflow/providers/google/cloud/log/stackdriver_task_handler.py +++ b/airflow/providers/google/cloud/log/stackdriver_task_handler.py @@ -34,10 +34,9 @@ DEFAULT_LOGGER_NAME = "airflow" _GLOBAL_RESOURCE = Resource(type="global", labels={}) -_DEFAULT_SCOPESS = frozenset([ - "https://www.googleapis.com/auth/logging.read", - "https://www.googleapis.com/auth/logging.write" -]) +_DEFAULT_SCOPESS = frozenset( + ["https://www.googleapis.com/auth/logging.read", "https://www.googleapis.com/auth/logging.write"] +) class StackdriverTaskHandler(logging.Handler): @@ -107,14 +106,12 @@ def __init__( def _client(self) -> gcp_logging.Client: """Google Cloud Library API client""" credentials, project = get_credentials_and_project_id( - key_path=self.gcp_key_path, - scopes=self.scopes, - disable_logging=True + key_path=self.gcp_key_path, scopes=self.scopes, disable_logging=True ) client = gcp_logging.Client( credentials=credentials, project=project, - client_info=ClientInfo(client_library_version='airflow_v' + version.version) + client_info=ClientInfo(client_library_version='airflow_v' + version.version), ) return client @@ -207,6 +204,7 @@ def _prepare_log_filter(self, ti_labels: Dict[str, str]) -> str: :type: Dict[str, str] :return: logs filter """ + def escape_label_key(key: str) -> str: return f'"{key}"' if "." in key else key @@ -216,7 +214,7 @@ def escale_label_value(value: str) -> str: log_filters = [ f'resource.type={escale_label_value(self.resource.type)}', - f'logName="projects/{self._client.project}/logs/{self.name}"' + f'logName="projects/{self._client.project}/logs/{self.name}"', ] for key, value in self.resource.labels.items(): @@ -227,10 +225,7 @@ def escale_label_value(value: str) -> str: return "\n".join(log_filters) def _read_logs( - self, - log_filter: str, - next_page_token: Optional[str], - all_pages: bool + self, log_filter: str, next_page_token: Optional[str], all_pages: bool ) -> Tuple[str, bool, Optional[str]]: """ Sends requests to the Stackdriver service and downloads logs. 
@@ -249,15 +244,13 @@ def _read_logs( """ messages = [] new_messages, next_page_token = self._read_single_logs_page( - log_filter=log_filter, - page_token=next_page_token, + log_filter=log_filter, page_token=next_page_token, ) messages.append(new_messages) if all_pages: while next_page_token: new_messages, next_page_token = self._read_single_logs_page( - log_filter=log_filter, - page_token=next_page_token + log_filter=log_filter, page_token=next_page_token ) messages.append(new_messages) diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py index 38fd11dfaf33f..dfb84238ca89d 100644 --- a/airflow/providers/google/cloud/operators/automl.py +++ b/airflow/providers/google/cloud/operators/automl.py @@ -69,11 +69,17 @@ class AutoMLTrainModelOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("model", "location", "project_id", "impersonation_chain",) + template_fields = ( + "model", + "location", + "project_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, model: dict, location: str, project_id: Optional[str] = None, @@ -82,7 +88,7 @@ def __init__( retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -96,10 +102,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Creating model.") operation = hook.create_model( model=self.model, @@ -157,11 +160,17 @@ class AutoMLPredictOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("model_id", "location", "project_id", "impersonation_chain",) + template_fields = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, model_id: str, location: str, payload: dict, @@ -172,7 +181,7 @@ def __init__( retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -188,10 +197,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) result = hook.predict( model_id=self.model_id, payload=self.payload, @@ -268,7 +274,8 @@ class AutoMLBatchPredictOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, model_id: str, input_config: dict, output_config: dict, @@ -280,7 +287,7 @@ def __init__( # pylint: disable=too-many-arguments retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -297,10 +304,7 @@ def __init__( # pylint: disable=too-many-arguments self.output_config = output_config def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - 
impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Fetch batch prediction.") operation = hook.batch_predict( model_id=self.model_id, @@ -357,11 +361,17 @@ class AutoMLCreateDatasetOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("dataset", "location", "project_id", "impersonation_chain",) + template_fields = ( + "dataset", + "location", + "project_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, dataset: dict, location: str, project_id: Optional[str] = None, @@ -370,7 +380,7 @@ def __init__( retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -384,10 +394,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Creating dataset") result = hook.create_dataset( dataset=self.dataset, @@ -446,12 +453,18 @@ class AutoMLImportDataOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("dataset_id", "input_config", "location", "project_id", - "impersonation_chain",) + template_fields = ( + "dataset_id", + "input_config", + "location", + "project_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, dataset_id: str, location: str, input_config: dict, @@ -461,7 +474,7 @@ def __init__( retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -476,10 +489,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Importing dataset") operation = hook.import_data( dataset_id=self.dataset_id, @@ -556,7 +566,8 @@ class AutoMLTablesListColumnSpecsOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, dataset_id: str, table_spec_id: str, location: str, @@ -569,7 +580,7 @@ def __init__( # pylint: disable=too-many-arguments retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.dataset_id = dataset_id @@ -586,10 +597,7 @@ def __init__( # pylint: disable=too-many-arguments self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Requesting column specs.") page_iterator = hook.list_column_specs( dataset_id=self.dataset_id, @@ -648,11 +656,17 @@ class AutoMLTablesUpdateDatasetOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = 
("dataset", "update_mask", "location", "impersonation_chain",) + template_fields = ( + "dataset", + "update_mask", + "location", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, dataset: dict, location: str, update_mask: Optional[dict] = None, @@ -661,7 +675,7 @@ def __init__( retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -675,10 +689,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Updating AutoML dataset %s.", self.dataset["name"]) result = hook.update_dataset( dataset=self.dataset, @@ -729,11 +740,17 @@ class AutoMLGetModelOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("model_id", "location", "project_id", "impersonation_chain",) + template_fields = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, model_id: str, location: str, project_id: Optional[str] = None, @@ -742,7 +759,7 @@ def __init__( retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -756,10 +773,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) result = hook.get_model( model_id=self.model_id, location=self.location, @@ -809,11 +823,17 @@ class AutoMLDeleteModelOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("model_id", "location", "project_id", "impersonation_chain",) + template_fields = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, model_id: str, location: str, project_id: Optional[str] = None, @@ -822,7 +842,7 @@ def __init__( retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -836,10 +856,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) operation = hook.delete_model( model_id=self.model_id, location=self.location, @@ -899,11 +916,17 @@ class AutoMLDeployModelOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("model_id", "location", "project_id", "impersonation_chain",) + template_fields = ( + "model_id", + "location", + "project_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, model_id: str, location: str, project_id: Optional[str] = None, @@ -913,7 +936,7 @@ def __init__( retry: Optional[Retry] = None, 
gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -928,10 +951,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Deploying model_id %s", self.model_id) operation = hook.deploy_model( @@ -992,11 +1012,18 @@ class AutoMLTablesListTableSpecsOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("dataset_id", "filter_", "location", "project_id", "impersonation_chain",) + template_fields = ( + "dataset_id", + "filter_", + "location", + "project_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, dataset_id: str, location: str, page_size: Optional[int] = None, @@ -1007,7 +1034,7 @@ def __init__( retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.dataset_id = dataset_id @@ -1022,10 +1049,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Requesting table specs for %s.", self.dataset_id) page_iterator = hook.list_table_specs( dataset_id=self.dataset_id, @@ -1077,11 +1101,16 @@ class AutoMLListDatasetOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("location", "project_id", "impersonation_chain",) + template_fields = ( + "location", + "project_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, project_id: Optional[str] = None, metadata: Optional[MetaData] = None, @@ -1089,7 +1118,7 @@ def __init__( retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -1101,10 +1130,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Requesting datasets") page_iterator = hook.list_datasets( location=self.location, @@ -1117,9 +1143,7 @@ def execute(self, context): self.log.info("Datasets obtained.") self.xcom_push( - context, - key="dataset_id_list", - value=[hook.extract_object_id(d) for d in result], + context, key="dataset_id_list", value=[hook.extract_object_id(d) for d in result], ) return result @@ -1161,11 +1185,17 @@ class AutoMLDeleteDatasetOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("dataset_id", "location", "project_id", "impersonation_chain",) + template_fields = ( + "dataset_id", + "location", + "project_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, dataset_id: Union[str, List[str]], 
location: str, project_id: Optional[str] = None, @@ -1174,7 +1204,7 @@ def __init__( retry: Optional[Retry] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -1197,10 +1227,7 @@ def _parse_dataset_id(dataset_id: Union[str, List[str]]) -> List[str]: return dataset_id.split(",") def execute(self, context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) dataset_id_list = self._parse_dataset_id(self.dataset_id) for dataset_id in dataset_id_list: self.log.info("Deleting dataset %s", dataset_id) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index e02601b40765b..ac469c25c0b96 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -42,12 +42,14 @@ BIGQUERY_JOB_DETAILS_LINK_FMT = "https://console.cloud.google.com/bigquery?j={job_id}" -_DEPRECATION_MSG = "The bigquery_conn_id parameter has been deprecated. " \ - "You should pass the gcp_conn_id parameter." +_DEPRECATION_MSG = ( + "The bigquery_conn_id parameter has been deprecated. You should pass the gcp_conn_id parameter." +) class BigQueryUIColors(enum.Enum): """Hex colors for BigQuery operators""" + CHECK = "#C0D7FF" QUERY = "#A1BBFF" TABLE = "#81A0FF" @@ -58,6 +60,7 @@ class BigQueryConsoleLink(BaseOperatorLink): """ Helper class for constructing BigQuery link. """ + name = 'BigQuery Console' def get_link(self, operator, dttm): @@ -145,13 +148,18 @@ class BigQueryCheckOperator(CheckOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('sql', 'gcp_conn_id', 'impersonation_chain',) + template_fields = ( + 'sql', + 'gcp_conn_id', + 'impersonation_chain', + ) template_ext = ('.sql',) ui_color = BigQueryUIColors.CHECK.value @apply_defaults def __init__( - self, *, + self, + *, sql: str, gcp_conn_id: str = 'google_cloud_default', bigquery_conn_id: Optional[str] = None, @@ -212,13 +220,19 @@ class BigQueryValueCheckOperator(ValueCheckOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('sql', 'gcp_conn_id', 'pass_value', 'impersonation_chain',) + template_fields = ( + 'sql', + 'gcp_conn_id', + 'pass_value', + 'impersonation_chain', + ) template_ext = ('.sql',) ui_color = BigQueryUIColors.CHECK.value @apply_defaults def __init__( - self, *, + self, + *, sql: str, pass_value: Any, tolerance: Any = None, @@ -229,12 +243,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, **kwargs, ) -> None: - super().__init__( - sql=sql, - pass_value=pass_value, - tolerance=tolerance, - **kwargs - ) + super().__init__(sql=sql, pass_value=pass_value, tolerance=tolerance, **kwargs) if bigquery_conn_id: warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3) @@ -299,12 +308,19 @@ class BigQueryIntervalCheckOperator(IntervalCheckOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('table', 'gcp_conn_id', 'sql1', 'sql2', 'impersonation_chain',) + template_fields = ( + 'table', + 'gcp_conn_id', + 'sql1', + 'sql2', + 'impersonation_chain', + ) ui_color = BigQueryUIColors.CHECK.value @apply_defaults def __init__( - self, *, + self, + *, table: str, metrics_thresholds: dict, date_filter_column: str 
= 'ds', @@ -321,7 +337,7 @@ def __init__( metrics_thresholds=metrics_thresholds, date_filter_column=date_filter_column, days_back=days_back, - **kwargs + **kwargs, ) if bigquery_conn_id: @@ -405,13 +421,20 @@ class BigQueryGetDataOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('dataset_id', 'table_id', 'max_results', 'selected_fields', - 'impersonation_chain',) + + template_fields = ( + 'dataset_id', + 'table_id', + 'max_results', + 'selected_fields', + 'impersonation_chain', + ) ui_color = BigQueryUIColors.QUERY.value @apply_defaults def __init__( - self, *, + self, + *, dataset_id: str, table_id: str, max_results: int = 100, @@ -421,14 +444,17 @@ def __init__( delegate_to: Optional[str] = None, location: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) if bigquery_conn_id: warnings.warn( "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = bigquery_conn_id self.dataset_id = dataset_id @@ -441,8 +467,9 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - self.log.info('Fetching Data from %s.%s max results: %s', - self.dataset_id, self.table_id, self.max_results) + self.log.info( + 'Fetching Data from %s.%s max results: %s', self.dataset_id, self.table_id, self.max_results + ) hook = BigQueryHook( bigquery_conn_id=self.gcp_conn_id, @@ -455,7 +482,7 @@ def execute(self, context): table_id=self.table_id, max_results=self.max_results, selected_fields=self.selected_fields, - location=self.location + location=self.location, ) self.log.info('Total extracted rows: %s', len(rows)) @@ -567,9 +594,14 @@ class BigQueryExecuteQueryOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('sql', 'destination_dataset_table', 'labels', 'query_params', - 'impersonation_chain',) - template_ext = ('.sql', ) + template_fields = ( + 'sql', + 'destination_dataset_table', + 'labels', + 'query_params', + 'impersonation_chain', + ) + template_ext = ('.sql',) ui_color = BigQueryUIColors.QUERY.value @property @@ -578,52 +610,51 @@ def operator_extra_links(self): Return operator extra links """ if isinstance(self.sql, str): - return ( - BigQueryConsoleLink(), - ) - return ( - BigQueryConsoleIndexableLink(i) for i, _ in enumerate(self.sql) - ) + return (BigQueryConsoleLink(),) + return (BigQueryConsoleIndexableLink(i) for i, _ in enumerate(self.sql)) # pylint: disable=too-many-arguments, too-many-locals @apply_defaults - def __init__(self, - *, - sql: Union[str, Iterable], - destination_dataset_table: Optional[str] = None, - write_disposition: Optional[str] = 'WRITE_EMPTY', - allow_large_results: Optional[bool] = False, - flatten_results: Optional[bool] = None, - gcp_conn_id: Optional[str] = 'google_cloud_default', - bigquery_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - udf_config: Optional[list] = None, - use_legacy_sql: Optional[bool] = True, - maximum_billing_tier: Optional[int] = None, - maximum_bytes_billed: Optional[float] = None, - create_disposition: Optional[str] = 'CREATE_IF_NEEDED', - schema_update_options: Optional[Union[list, tuple, set]] = None, - query_params: Optional[list] = None, - labels: 
Optional[dict] = None, - priority: Optional[str] = 'INTERACTIVE', - time_partitioning: Optional[dict] = None, - api_resource_configs: Optional[dict] = None, - cluster_fields: Optional[List[str]] = None, - location: Optional[str] = None, - encryption_configuration: Optional[dict] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + sql: Union[str, Iterable], + destination_dataset_table: Optional[str] = None, + write_disposition: Optional[str] = 'WRITE_EMPTY', + allow_large_results: Optional[bool] = False, + flatten_results: Optional[bool] = None, + gcp_conn_id: Optional[str] = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + udf_config: Optional[list] = None, + use_legacy_sql: Optional[bool] = True, + maximum_billing_tier: Optional[int] = None, + maximum_bytes_billed: Optional[float] = None, + create_disposition: Optional[str] = 'CREATE_IF_NEEDED', + schema_update_options: Optional[Union[list, tuple, set]] = None, + query_params: Optional[list] = None, + labels: Optional[dict] = None, + priority: Optional[str] = 'INTERACTIVE', + time_partitioning: Optional[dict] = None, + api_resource_configs: Optional[dict] = None, + cluster_fields: Optional[List[str]] = None, + location: Optional[str] = None, + encryption_configuration: Optional[dict] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) if bigquery_conn_id: warnings.warn( "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning) + "the gcp_conn_id parameter.", + DeprecationWarning, + ) gcp_conn_id = bigquery_conn_id warnings.warn( - "This operator is deprecated. Please use `BigQueryInsertJobOperator`.", - DeprecationWarning, + "This operator is deprecated. Please use `BigQueryInsertJobOperator`.", DeprecationWarning, ) self.sql = sql @@ -678,7 +709,7 @@ def execute(self, context): time_partitioning=self.time_partitioning, api_resource_configs=self.api_resource_configs, cluster_fields=self.cluster_fields, - encryption_configuration=self.encryption_configuration + encryption_configuration=self.encryption_configuration, ) elif isinstance(self.sql, Iterable): job_id = [ @@ -699,12 +730,14 @@ def execute(self, context): time_partitioning=self.time_partitioning, api_resource_configs=self.api_resource_configs, cluster_fields=self.cluster_fields, - encryption_configuration=self.encryption_configuration + encryption_configuration=self.encryption_configuration, ) - for s in self.sql] + for s in self.sql + ] else: raise AirflowException( - "argument 'sql' of type {} is neither a string nor an iterable".format(type(str))) + "argument 'sql' of type {} is neither a string nor an iterable".format(type(str)) + ) context['task_instance'].xcom_push(key='job_id', value=job_id) def on_kill(self): @@ -842,6 +875,7 @@ class BigQueryCreateEmptyTableOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + template_fields = ( 'dataset_id', 'table_id', @@ -857,7 +891,8 @@ class BigQueryCreateEmptyTableOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults def __init__( - self, *, + self, + *, dataset_id: str, table_id: str, table_resource: Optional[Dict[str, Any]] = None, @@ -874,7 +909,7 @@ def __init__( location: Optional[str] = None, cluster_fields: Optional[List[str]] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -910,9 +945,7 @@ def execute(self, context): delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - schema_fields = json.loads(gcs_hook.download( - gcs_bucket, - gcs_object).decode("utf-8")) + schema_fields = json.loads(gcs_hook.download(gcs_bucket, gcs_object).decode("utf-8")) else: schema_fields = self.schema_fields @@ -931,8 +964,9 @@ def execute(self, context): table_resource=self.table_resource, exists_ok=False, ) - self.log.info('Table %s.%s.%s created successfully', - table.project, table.dataset_id, table.table_id) + self.log.info( + 'Table %s.%s.%s created successfully', table.project, table.dataset_id, table.table_id + ) except Conflict: self.log.info('Table %s.%s already exists.', self.dataset_id, self.table_id) @@ -1037,6 +1071,7 @@ class BigQueryCreateExternalTableOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + template_fields = ( 'bucket', 'source_objects', @@ -1051,7 +1086,8 @@ class BigQueryCreateExternalTableOperator(BaseOperator): # pylint: disable=too-many-arguments,too-many-locals @apply_defaults def __init__( - self, *, + self, + *, bucket: str, source_objects: List, destination_project_dataset_table: str, @@ -1074,7 +1110,7 @@ def __init__( encryption_configuration: Optional[Dict] = None, location: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -1084,28 +1120,31 @@ def __init__( self.schema_object = schema_object # BQ config - kwargs_passed = any([ - destination_project_dataset_table, - schema_fields, - source_format, - compression, - skip_leading_rows, - field_delimiter, - max_bad_records, - quote_character, - allow_quoted_newlines, - allow_jagged_rows, - src_fmt_configs, - labels, - encryption_configuration, - ]) + kwargs_passed = any( + [ + destination_project_dataset_table, + schema_fields, + source_format, + compression, + skip_leading_rows, + field_delimiter, + max_bad_records, + quote_character, + allow_quoted_newlines, + allow_jagged_rows, + src_fmt_configs, + labels, + encryption_configuration, + ] + ) if not table_resource: warnings.warn( "Passing table parameters via keywords arguments will be deprecated. " "Please use provide table definition using `table_resource` parameter." "You can still use external `schema_object`. 
", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) if table_resource and kwargs_passed: @@ -1180,7 +1219,7 @@ def execute(self, context): allow_jagged_rows=self.allow_jagged_rows, src_fmt_configs=self.src_fmt_configs, labels=self.labels, - encryption_configuration=self.encryption_configuration + encryption_configuration=self.encryption_configuration, ) @@ -1232,12 +1271,17 @@ class BigQueryDeleteDatasetOperator(BaseOperator): dag=dag) """ - template_fields = ('dataset_id', 'project_id', 'impersonation_chain',) + template_fields = ( + 'dataset_id', + 'project_id', + 'impersonation_chain', + ) ui_color = BigQueryUIColors.DATASET.value @apply_defaults def __init__( - self, *, + self, + *, dataset_id: str, project_id: Optional[str] = None, delete_contents: bool = False, @@ -1245,12 +1289,15 @@ def __init__( bigquery_conn_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: if bigquery_conn_id: warnings.warn( "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = bigquery_conn_id self.dataset_id = dataset_id @@ -1272,9 +1319,7 @@ def execute(self, context): ) bq_hook.delete_dataset( - project_id=self.project_id, - dataset_id=self.dataset_id, - delete_contents=self.delete_contents + project_id=self.project_id, dataset_id=self.dataset_id, delete_contents=self.delete_contents ) @@ -1326,26 +1371,36 @@ class BigQueryCreateEmptyDatasetOperator(BaseOperator): dag=dag) """ - template_fields = ('dataset_id', 'project_id', 'dataset_reference', 'impersonation_chain',) + template_fields = ( + 'dataset_id', + 'project_id', + 'dataset_reference', + 'impersonation_chain', + ) ui_color = BigQueryUIColors.DATASET.value @apply_defaults - def __init__(self, - *, - dataset_id: Optional[str] = None, - project_id: Optional[str] = None, - dataset_reference: Optional[Dict] = None, - location: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - bigquery_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + dataset_id: Optional[str] = None, + project_id: Optional[str] = None, + dataset_reference: Optional[Dict] = None, + location: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: if bigquery_conn_id: warnings.warn( "The bigquery_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = bigquery_conn_id self.dataset_id = dataset_id @@ -1413,18 +1468,24 @@ class BigQueryGetDatasetOperator(BaseOperator): https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource """ - template_fields = ('dataset_id', 'project_id', 'impersonation_chain',) + template_fields = ( + 'dataset_id', + 'project_id', + 'impersonation_chain', + ) ui_color = BigQueryUIColors.DATASET.value @apply_defaults - def __init__(self, - *, - dataset_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + dataset_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.dataset_id = dataset_id self.project_id = project_id @@ -1434,15 +1495,15 @@ def __init__(self, super().__init__(**kwargs) def execute(self, context): - bq_hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to, - impersonation_chain=self.impersonation_chain) + bq_hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + impersonation_chain=self.impersonation_chain, + ) self.log.info('Start getting dataset: %s:%s', self.project_id, self.dataset_id) - dataset = bq_hook.get_dataset( - dataset_id=self.dataset_id, - project_id=self.project_id) + dataset = bq_hook.get_dataset(dataset_id=self.dataset_id, project_id=self.project_id) return dataset.to_api_repr() @@ -1477,19 +1538,25 @@ class BigQueryGetDatasetTablesOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('dataset_id', 'project_id', 'impersonation_chain',) + + template_fields = ( + 'dataset_id', + 'project_id', + 'impersonation_chain', + ) ui_color = BigQueryUIColors.DATASET.value @apply_defaults def __init__( - self, *, + self, + *, dataset_id: str, project_id: Optional[str] = None, max_results: Optional[int] = None, gcp_conn_id: Optional[str] = 'google_cloud_default', delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: self.dataset_id = dataset_id self.project_id = project_id @@ -1507,9 +1574,7 @@ def execute(self, context): ) return bq_hook.get_dataset_tables( - dataset_id=self.dataset_id, - project_id=self.project_id, - max_results=self.max_results, + dataset_id=self.dataset_id, project_id=self.project_id, max_results=self.max_results, ) @@ -1551,12 +1616,17 @@ class BigQueryPatchDatasetOperator(BaseOperator): https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource """ - template_fields = ('dataset_id', 'project_id', 'impersonation_chain',) + template_fields = ( + 'dataset_id', + 'project_id', + 'impersonation_chain', + ) ui_color = BigQueryUIColors.DATASET.value @apply_defaults def __init__( - self, *, + self, + *, dataset_id: str, dataset_resource: dict, project_id: Optional[str] = None, @@ -1568,7 +1638,8 @@ def __init__( warnings.warn( "This operator is deprecated. 
Please use BigQueryUpdateDatasetOperator.", - DeprecationWarning, stacklevel=3 + DeprecationWarning, + stacklevel=3, ) self.dataset_id = dataset_id self.project_id = project_id @@ -1586,9 +1657,7 @@ def execute(self, context): ) return bq_hook.patch_dataset( - dataset_id=self.dataset_id, - dataset_resource=self.dataset_resource, - project_id=self.project_id, + dataset_id=self.dataset_id, dataset_resource=self.dataset_resource, project_id=self.project_id, ) @@ -1635,12 +1704,17 @@ class BigQueryUpdateDatasetOperator(BaseOperator): https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource """ - template_fields = ('dataset_id', 'project_id', 'impersonation_chain',) + template_fields = ( + 'dataset_id', + 'project_id', + 'impersonation_chain', + ) ui_color = BigQueryUIColors.DATASET.value @apply_defaults def __init__( - self, *, + self, + *, dataset_resource: dict, fields: Optional[List[str]] = None, dataset_id: Optional[str] = None, @@ -1648,7 +1722,7 @@ def __init__( gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: self.dataset_id = dataset_id self.project_id = project_id @@ -1712,12 +1786,17 @@ class BigQueryDeleteTableOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('deletion_dataset_table', 'impersonation_chain',) + + template_fields = ( + 'deletion_dataset_table', + 'impersonation_chain', + ) ui_color = BigQueryUIColors.TABLE.value @apply_defaults def __init__( - self, *, + self, + *, deletion_dataset_table: str, gcp_conn_id: str = 'google_cloud_default', bigquery_conn_id: Optional[str] = None, @@ -1732,7 +1811,10 @@ def __init__( if bigquery_conn_id: warnings.warn( "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = bigquery_conn_id self.deletion_dataset_table = deletion_dataset_table @@ -1750,10 +1832,7 @@ def execute(self, context): location=self.location, impersonation_chain=self.impersonation_chain, ) - hook.delete_table( - table_id=self.deletion_dataset_table, - not_found_ok=self.ignore_if_missing - ) + hook.delete_table(table_id=self.deletion_dataset_table, not_found_ok=self.ignore_if_missing) class BigQueryUpsertTableOperator(BaseOperator): @@ -1795,12 +1874,18 @@ class BigQueryUpsertTableOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('dataset_id', 'table_resource', 'impersonation_chain',) + + template_fields = ( + 'dataset_id', + 'table_resource', + 'impersonation_chain', + ) ui_color = BigQueryUIColors.TABLE.value @apply_defaults def __init__( - self, *, + self, + *, dataset_id: str, table_resource: dict, project_id: Optional[str] = None, @@ -1816,7 +1901,10 @@ def __init__( if bigquery_conn_id: warnings.warn( "The bigquery_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = bigquery_conn_id self.dataset_id = dataset_id @@ -1836,9 +1924,7 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) hook.run_table_upsert( - dataset_id=self.dataset_id, - table_resource=self.table_resource, - project_id=self.project_id, + dataset_id=self.dataset_id, table_resource=self.table_resource, project_id=self.project_id, ) @@ -1900,8 +1986,12 @@ class BigQueryInsertJobOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("configuration", "job_id", "impersonation_chain",) - template_ext = (".json", ) + template_fields = ( + "configuration", + "job_id", + "impersonation_chain", + ) + template_ext = (".json",) ui_color = BigQueryUIColors.QUERY.value def __init__( @@ -1934,11 +2024,7 @@ def prepare_template(self) -> None: with open(self.configuration, 'r') as file: self.configuration = json.loads(file.read()) - def _submit_job( - self, - hook: BigQueryHook, - job_id: str, - ) -> BigQueryJob: + def _submit_job(self, hook: BigQueryHook, job_id: str,) -> BigQueryJob: # Submit a new job job = hook.insert_job( configuration=self.configuration, @@ -1983,11 +2069,7 @@ def execute(self, context: Any): self._handle_job_error(job) except Conflict: # If the job already exists retrieve it - job = hook.get_job( - project_id=self.project_id, - location=self.location, - job_id=job_id, - ) + job = hook.get_job(project_id=self.project_id, location=self.location, job_id=job_id,) if job.state in self.reattach_states: # We are reattaching to a job job.result() diff --git a/airflow/providers/google/cloud/operators/bigquery_dts.py b/airflow/providers/google/cloud/operators/bigquery_dts.py index 50e96c7a28772..8475d395906a0 100644 --- a/airflow/providers/google/cloud/operators/bigquery_dts.py +++ b/airflow/providers/google/cloud/operators/bigquery_dts.py @@ -76,7 +76,8 @@ class BigQueryCreateDataTransferOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, transfer_config: dict, project_id: Optional[str] = None, authorization_code: Optional[str] = None, @@ -85,7 +86,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id="google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.transfer_config = transfer_config @@ -99,8 +100,7 @@ def __init__( def execute(self, context): hook = BiqQueryDataTransferServiceHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) self.log.info("Creating DTS transfer config") response = hook.create_transfer_config( @@ -152,11 +152,17 @@ class BigQueryDeleteDataTransferConfigOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("transfer_config_id", "project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "transfer_config_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, transfer_config_id: str, project_id: Optional[str] = None, retry: Retry = None, @@ -164,7 +170,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id="google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + 
**kwargs, ): super().__init__(**kwargs) self.project_id = project_id @@ -177,8 +183,7 @@ def __init__( def execute(self, context): hook = BiqQueryDataTransferServiceHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.delete_transfer_config( transfer_config_id=self.transfer_config_id, @@ -247,7 +252,8 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, transfer_config_id: str, project_id: Optional[str] = None, requested_time_range: Optional[dict] = None, @@ -257,7 +263,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id="google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.project_id = project_id @@ -272,8 +278,7 @@ def __init__( def execute(self, context): hook = BiqQueryDataTransferServiceHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) self.log.info('Submitting manual transfer for %s', self.transfer_config_id) response = hook.start_manual_transfer_runs( diff --git a/airflow/providers/google/cloud/operators/bigtable.py b/airflow/providers/google/cloud/operators/bigtable.py index 13abe1ee7159c..7b0b0374125a2 100644 --- a/airflow/providers/google/cloud/operators/bigtable.py +++ b/airflow/providers/google/cloud/operators/bigtable.py @@ -106,27 +106,35 @@ class BigtableCreateInstanceOperator(BaseOperator, BigtableValidationMixin): """ REQUIRED_ATTRIBUTES: Iterable[str] = ('instance_id', 'main_cluster_id', 'main_cluster_zone') - template_fields: Iterable[str] = ['project_id', 'instance_id', 'main_cluster_id', - 'main_cluster_zone', 'impersonation_chain', ] + template_fields: Iterable[str] = [ + 'project_id', + 'instance_id', + 'main_cluster_id', + 'main_cluster_zone', + 'impersonation_chain', + ] @apply_defaults - def __init__(self, *, # pylint: disable=too-many-arguments - instance_id: str, - main_cluster_id: str, - main_cluster_zone: str, - project_id: Optional[str] = None, - replica_clusters: Optional[List[Dict[str, str]]] = None, - replica_cluster_id: Optional[str] = None, - replica_cluster_zone: Optional[str] = None, - instance_display_name: Optional[str] = None, - instance_type: Optional[enums.Instance.Type] = None, - instance_labels: Optional[Dict] = None, - cluster_nodes: Optional[int] = None, - cluster_storage_type: Optional[enums.StorageType] = None, - timeout: Optional[float] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, # pylint: disable=too-many-arguments + instance_id: str, + main_cluster_id: str, + main_cluster_zone: str, + project_id: Optional[str] = None, + replica_clusters: Optional[List[Dict[str, str]]] = None, + replica_cluster_id: Optional[str] = None, + replica_cluster_zone: Optional[str] = None, + instance_display_name: Optional[str] = None, + instance_type: Optional[enums.Instance.Type] = None, + instance_labels: Optional[Dict] = None, + cluster_nodes: Optional[int] = None, + cluster_storage_type: Optional[enums.StorageType] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.project_id = 
project_id self.instance_id = instance_id self.main_cluster_id = main_cluster_id @@ -146,19 +154,14 @@ def __init__(self, *, # pylint: disable=too-many-arguments super().__init__(**kwargs) def execute(self, context): - hook = BigtableHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - instance = hook.get_instance(project_id=self.project_id, - instance_id=self.instance_id) + hook = BigtableHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + instance = hook.get_instance(project_id=self.project_id, instance_id=self.instance_id) if instance: # Based on Instance.__eq__ instance with the same ID and client is # considered as equal. self.log.info( - "The instance '%s' already exists in this project. " - "Consider it as created", - self.instance_id + "The instance '%s' already exists in this project. " "Consider it as created", + self.instance_id, ) return try: @@ -222,19 +225,26 @@ class BigtableUpdateInstanceOperator(BaseOperator, BigtableValidationMixin): """ REQUIRED_ATTRIBUTES: Iterable[str] = ['instance_id'] - template_fields: Iterable[str] = ['project_id', 'instance_id', 'impersonation_chain', ] + template_fields: Iterable[str] = [ + 'project_id', + 'instance_id', + 'impersonation_chain', + ] @apply_defaults - def __init__(self, *, - instance_id: str, - project_id: Optional[str] = None, - instance_display_name: Optional[str] = None, - instance_type: Optional[Union[enums.Instance.Type, enum.IntEnum]] = None, - instance_labels: Optional[Dict] = None, - timeout: Optional[float] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + project_id: Optional[str] = None, + instance_display_name: Optional[str] = None, + instance_type: Optional[Union[enums.Instance.Type, enum.IntEnum]] = None, + instance_labels: Optional[Dict] = None, + timeout: Optional[float] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.project_id = project_id self.instance_id = instance_id self.instance_display_name = instance_display_name @@ -247,16 +257,10 @@ def __init__(self, *, super().__init__(**kwargs) def execute(self, context): - hook = BigtableHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - instance = hook.get_instance(project_id=self.project_id, - instance_id=self.instance_id) + hook = BigtableHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + instance = hook.get_instance(project_id=self.project_id, instance_id=self.instance_id) if not instance: - raise AirflowException( - f"Dependency: instance '{self.instance_id}' does not exist." - ) + raise AirflowException(f"Dependency: instance '{self.instance_id}' does not exist.") try: hook.update_instance( @@ -300,16 +304,24 @@ class BigtableDeleteInstanceOperator(BaseOperator, BigtableValidationMixin): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + REQUIRED_ATTRIBUTES = ('instance_id',) # type: Iterable[str] - template_fields = ['project_id', 'instance_id', 'impersonation_chain', ] # type: Iterable[str] + template_fields = [ + 'project_id', + 'instance_id', + 'impersonation_chain', + ] # type: Iterable[str] @apply_defaults - def __init__(self, *, - instance_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.project_id = project_id self.instance_id = instance_id self._validate_inputs() @@ -318,18 +330,14 @@ def __init__(self, *, super().__init__(**kwargs) def execute(self, context): - hook = BigtableHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = BigtableHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: - hook.delete_instance(project_id=self.project_id, - instance_id=self.instance_id) + hook.delete_instance(project_id=self.project_id, instance_id=self.instance_id) except google.api_core.exceptions.NotFound: self.log.info( - "The instance '%s' does not exist in project '%s'. " - "Consider it as deleted", - self.instance_id, self.project_id + "The instance '%s' does not exist in project '%s'. " "Consider it as deleted", + self.instance_id, + self.project_id, ) except google.api_core.exceptions.GoogleAPICallError as e: self.log.error('An error occurred. Exiting.') @@ -374,20 +382,28 @@ class BigtableCreateTableOperator(BaseOperator, BigtableValidationMixin): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + REQUIRED_ATTRIBUTES = ('instance_id', 'table_id') # type: Iterable[str] - template_fields = ['project_id', 'instance_id', 'table_id', - 'impersonation_chain', ] # type: Iterable[str] + template_fields = [ + 'project_id', + 'instance_id', + 'table_id', + 'impersonation_chain', + ] # type: Iterable[str] @apply_defaults - def __init__(self, *, - instance_id: str, - table_id: str, - project_id: Optional[str] = None, - initial_split_keys: Optional[List] = None, - column_families: Optional[Dict[str, GarbageCollectionRule]] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + table_id: str, + project_id: Optional[str] = None, + initial_split_keys: Optional[List] = None, + column_families: Optional[Dict[str, GarbageCollectionRule]] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.project_id = project_id self.instance_id = instance_id self.table_id = table_id @@ -401,8 +417,7 @@ def __init__(self, *, def _compare_column_families(self, hook, instance): table_column_families = hook.get_column_families_for_table(instance, self.table_id) if set(table_column_families.keys()) != set(self.column_families.keys()): - self.log.error("Table '%s' has different set of Column Families", - self.table_id) + self.log.error("Table '%s' has different set of Column Families", self.table_id) self.log.error("Expected: %s", self.column_families.keys()) self.log.error("Actual: %s", table_column_families.keys()) return False @@ -416,35 +431,32 @@ def _compare_column_families(self, hook, instance): # For more information about ColumnFamily please refer to the documentation: # https://googleapis.github.io/google-cloud-python/latest/bigtable/column-family.html#google.cloud.bigtable.column_family.ColumnFamily if table_column_families[key].gc_rule != self.column_families[key]: - self.log.error("Column Family '%s' differs for table '%s'.", key, - self.table_id) + self.log.error("Column Family '%s' differs for table '%s'.", key, self.table_id) return False return True def execute(self, context): - hook = BigtableHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = BigtableHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) instance = hook.get_instance(project_id=self.project_id, instance_id=self.instance_id) if not instance: raise AirflowException( - "Dependency: instance '{}' does not exist in project '{}'.". - format(self.instance_id, self.project_id)) + "Dependency: instance '{}' does not exist in project '{}'.".format( + self.instance_id, self.project_id + ) + ) try: hook.create_table( instance=instance, table_id=self.table_id, initial_split_keys=self.initial_split_keys, - column_families=self.column_families + column_families=self.column_families, ) except google.api_core.exceptions.AlreadyExists: if not self._compare_column_families(hook, instance): raise AirflowException( - "Table '{}' already exists with different Column Families.". - format(self.table_id)) - self.log.info("The table '%s' already exists. Consider it as created", - self.table_id) + "Table '{}' already exists with different Column Families.".format(self.table_id) + ) + self.log.info("The table '%s' already exists. 
Consider it as created", self.table_id) class BigtableDeleteTableOperator(BaseOperator, BigtableValidationMixin): @@ -479,19 +491,27 @@ class BigtableDeleteTableOperator(BaseOperator, BigtableValidationMixin): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + REQUIRED_ATTRIBUTES = ('instance_id', 'table_id') # type: Iterable[str] - template_fields = ['project_id', 'instance_id', 'table_id', - 'impersonation_chain', ] # type: Iterable[str] + template_fields = [ + 'project_id', + 'instance_id', + 'table_id', + 'impersonation_chain', + ] # type: Iterable[str] @apply_defaults - def __init__(self, *, - instance_id: str, - table_id: str, - project_id: Optional[str] = None, - app_profile_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + table_id: str, + project_id: Optional[str] = None, + app_profile_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.project_id = project_id self.instance_id = instance_id self.table_id = table_id @@ -502,26 +522,18 @@ def __init__(self, *, super().__init__(**kwargs) def execute(self, context): - hook = BigtableHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - instance = hook.get_instance(project_id=self.project_id, - instance_id=self.instance_id) + hook = BigtableHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + instance = hook.get_instance(project_id=self.project_id, instance_id=self.instance_id) if not instance: - raise AirflowException("Dependency: instance '{}' does not exist.".format( - self.instance_id)) + raise AirflowException("Dependency: instance '{}' does not exist.".format(self.instance_id)) try: hook.delete_table( - project_id=self.project_id, - instance_id=self.instance_id, - table_id=self.table_id, + project_id=self.project_id, instance_id=self.instance_id, table_id=self.table_id, ) except google.api_core.exceptions.NotFound: # It's OK if table doesn't exists. - self.log.info("The table '%s' no longer exists. Consider it as deleted", - self.table_id) + self.log.info("The table '%s' no longer exists. Consider it as deleted", self.table_id) except google.api_core.exceptions.GoogleAPICallError as e: self.log.error('An error occurred. Exiting.') raise e @@ -559,19 +571,28 @@ class BigtableUpdateClusterOperator(BaseOperator, BigtableValidationMixin): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + REQUIRED_ATTRIBUTES = ('instance_id', 'cluster_id', 'nodes') # type: Iterable[str] - template_fields = ['project_id', 'instance_id', 'cluster_id', 'nodes', - 'impersonation_chain', ] # type: Iterable[str] + template_fields = [ + 'project_id', + 'instance_id', + 'cluster_id', + 'nodes', + 'impersonation_chain', + ] # type: Iterable[str] @apply_defaults - def __init__(self, *, - instance_id: str, - cluster_id: str, - nodes: int, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + cluster_id: str, + nodes: int, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.project_id = project_id self.instance_id = instance_id self.cluster_id = cluster_id @@ -582,26 +603,19 @@ def __init__(self, *, super().__init__(**kwargs) def execute(self, context): - hook = BigtableHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - instance = hook.get_instance(project_id=self.project_id, - instance_id=self.instance_id) + hook = BigtableHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + instance = hook.get_instance(project_id=self.project_id, instance_id=self.instance_id) if not instance: - raise AirflowException("Dependency: instance '{}' does not exist.".format( - self.instance_id)) + raise AirflowException("Dependency: instance '{}' does not exist.".format(self.instance_id)) try: - hook.update_cluster( - instance=instance, - cluster_id=self.cluster_id, - nodes=self.nodes - ) + hook.update_cluster(instance=instance, cluster_id=self.cluster_id, nodes=self.nodes) except google.api_core.exceptions.NotFound: raise AirflowException( - "Dependency: cluster '{}' does not exist for instance '{}'.". - format(self.cluster_id, self.instance_id)) + "Dependency: cluster '{}' does not exist for instance '{}'.".format( + self.cluster_id, self.instance_id + ) + ) except google.api_core.exceptions.GoogleAPICallError as e: self.log.error('An error occurred. 
Exiting.') raise e diff --git a/airflow/providers/google/cloud/operators/cloud_build.py b/airflow/providers/google/cloud/operators/cloud_build.py index 159f24e185367..cdf35752513f3 100644 --- a/airflow/providers/google/cloud/operators/cloud_build.py +++ b/airflow/providers/google/cloud/operators/cloud_build.py @@ -189,17 +189,25 @@ class CloudBuildCreateBuildOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("body", "gcp_conn_id", "api_version", "impersonation_chain",) + template_fields = ( + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) template_ext = ['.yml', '.yaml', '.json'] @apply_defaults - def __init__(self, *, - body: Union[dict, str], - project_id: Optional[str] = None, - gcp_conn_id: str = "google_cloud_default", - api_version: str = "v1", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + body: Union[dict, str], + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v1", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.body = body # Not template fields to keep original value @@ -228,7 +236,7 @@ def execute(self, context): hook = CloudBuildHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, - impersonation_chain=self.impersonation_chain + impersonation_chain=self.impersonation_chain, ) body = BuildProcessor(body=self.body).process_body() return hook.create_build(body=body, project_id=self.project_id) diff --git a/airflow/providers/google/cloud/operators/cloud_memorystore.py b/airflow/providers/google/cloud/operators/cloud_memorystore.py index d28e787eb6d94..a3e6fbe9b380d 100644 --- a/airflow/providers/google/cloud/operators/cloud_memorystore.py +++ b/airflow/providers/google/cloud/operators/cloud_memorystore.py @@ -93,7 +93,8 @@ class CloudMemorystoreCreateInstanceOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, instance_id: str, instance: Union[Dict, Instance], @@ -103,7 +104,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -118,8 +119,7 @@ def __init__( def execute(self, context: Dict): hook = CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) result = hook.create_instance( location=self.location, @@ -169,12 +169,21 @@ class CloudMemorystoreDeleteInstanceOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("location", "instance", "project_id", "retry", "timeout", "metadata", - "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "location", + "instance", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, instance: str, project_id: Optional[str] = None, @@ -183,7 +192,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -197,8 
+206,7 @@ def __init__( def execute(self, context: Dict): hook = CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.delete_instance( location=self.location, @@ -266,7 +274,8 @@ class CloudMemorystoreExportInstanceOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, instance: str, output_config: Union[Dict, OutputConfig], @@ -276,7 +285,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -291,8 +300,7 @@ def __init__( def execute(self, context: Dict): hook = CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.export_instance( @@ -359,7 +367,8 @@ class CloudMemorystoreFailoverInstanceOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, instance: str, data_protection_mode: FailoverInstanceRequest.DataProtectionMode, @@ -369,7 +378,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -384,8 +393,7 @@ def __init__( def execute(self, context: Dict): hook = CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.failover_instance( location=self.location, @@ -433,12 +441,21 @@ class CloudMemorystoreGetInstanceOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("location", "instance", "project_id", "retry", "timeout", "metadata", - "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "location", + "instance", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, instance: str, project_id: Optional[str] = None, @@ -447,7 +464,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -461,8 +478,7 @@ def __init__( def execute(self, context: Dict): hook = CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) result = hook.get_instance( location=self.location, @@ -532,7 +548,8 @@ class CloudMemorystoreImportOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, instance: str, input_config: Union[Dict, InputConfig], @@ -542,7 +559,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -557,8 +574,7 @@ def __init__( def execute(self, context: Dict): hook = 
CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.import_instance( location=self.location, @@ -610,12 +626,21 @@ class CloudMemorystoreListInstancesOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("location", "page_size", "project_id", "retry", "timeout", "metadata", - "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "location", + "page_size", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, page_size: int, project_id: Optional[str] = None, @@ -624,7 +649,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -638,8 +663,7 @@ def __init__( def execute(self, context: Dict): hook = CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) result = hook.list_instances( location=self.location, @@ -721,7 +745,8 @@ class CloudMemorystoreUpdateInstanceOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, update_mask: Union[Dict, FieldMask], instance: Union[Dict, Instance], location: Optional[str] = None, @@ -732,7 +757,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.update_mask = update_mask @@ -748,8 +773,7 @@ def __init__( def execute(self, context: Dict): hook = CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.update_instance( update_mask=self.update_mask, @@ -815,7 +839,8 @@ class CloudMemorystoreScaleInstanceOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, memory_size_gb: int, location: Optional[str] = None, instance_id: Optional[str] = None, @@ -825,7 +850,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.memory_size_gb = memory_size_gb @@ -840,8 +865,7 @@ def __init__( def execute(self, context: Dict): hook = CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.update_instance( @@ -928,7 +952,8 @@ class CloudMemorystoreCreateInstanceAndImportOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, instance_id: str, instance: Union[Dict, Instance], @@ -939,7 +964,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -955,8 +980,7 @@ def __init__( def execute(self, context: Dict): hook = 
CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.create_instance( @@ -1037,7 +1061,8 @@ class CloudMemorystoreExportAndDeleteInstanceOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, instance: str, output_config: Union[Dict, OutputConfig], @@ -1047,7 +1072,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -1062,8 +1087,7 @@ def __init__( def execute(self, context: Dict): hook = CloudMemorystoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.export_instance( diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py index ab02455c2be98..c155413a4da0a 100644 --- a/airflow/providers/google/cloud/operators/cloud_sql.py +++ b/airflow/providers/google/cloud/operators/cloud_sql.py @@ -36,93 +36,143 @@ CLOUD_SQL_CREATE_VALIDATION = [ dict(name="name", allow_empty=False), - dict(name="settings", type="dict", fields=[ - dict(name="tier", allow_empty=False), - dict(name="backupConfiguration", type="dict", fields=[ - dict(name="binaryLogEnabled", optional=True), - dict(name="enabled", optional=True), - dict(name="replicationLogArchivingEnabled", optional=True), - dict(name="startTime", allow_empty=False, optional=True) - ], optional=True), - dict(name="activationPolicy", allow_empty=False, optional=True), - dict(name="authorizedGaeApplications", type="list", optional=True), - dict(name="crashSafeReplicationEnabled", optional=True), - dict(name="dataDiskSizeGb", optional=True), - dict(name="dataDiskType", allow_empty=False, optional=True), - dict(name="databaseFlags", type="list", optional=True), - dict(name="ipConfiguration", type="dict", fields=[ - dict(name="authorizedNetworks", type="list", fields=[ - dict(name="expirationTime", optional=True), - dict(name="name", allow_empty=False, optional=True), - dict(name="value", allow_empty=False, optional=True) - ], optional=True), - dict(name="ipv4Enabled", optional=True), - dict(name="privateNetwork", allow_empty=False, optional=True), - dict(name="requireSsl", optional=True), - ], optional=True), - dict(name="locationPreference", type="dict", fields=[ - dict(name="followGaeApplication", allow_empty=False, optional=True), - dict(name="zone", allow_empty=False, optional=True), - ], optional=True), - dict(name="maintenanceWindow", type="dict", fields=[ - dict(name="hour", optional=True), - dict(name="day", optional=True), - dict(name="updateTrack", allow_empty=False, optional=True), - ], optional=True), - dict(name="pricingPlan", allow_empty=False, optional=True), - dict(name="replicationType", allow_empty=False, optional=True), - dict(name="storageAutoResize", optional=True), - dict(name="storageAutoResizeLimit", optional=True), - dict(name="userLabels", type="dict", optional=True), - ]), + dict( + name="settings", + type="dict", + fields=[ + dict(name="tier", allow_empty=False), + dict( + name="backupConfiguration", + type="dict", + fields=[ + dict(name="binaryLogEnabled", optional=True), + dict(name="enabled", optional=True), + dict(name="replicationLogArchivingEnabled", 
optional=True), + dict(name="startTime", allow_empty=False, optional=True), + ], + optional=True, + ), + dict(name="activationPolicy", allow_empty=False, optional=True), + dict(name="authorizedGaeApplications", type="list", optional=True), + dict(name="crashSafeReplicationEnabled", optional=True), + dict(name="dataDiskSizeGb", optional=True), + dict(name="dataDiskType", allow_empty=False, optional=True), + dict(name="databaseFlags", type="list", optional=True), + dict( + name="ipConfiguration", + type="dict", + fields=[ + dict( + name="authorizedNetworks", + type="list", + fields=[ + dict(name="expirationTime", optional=True), + dict(name="name", allow_empty=False, optional=True), + dict(name="value", allow_empty=False, optional=True), + ], + optional=True, + ), + dict(name="ipv4Enabled", optional=True), + dict(name="privateNetwork", allow_empty=False, optional=True), + dict(name="requireSsl", optional=True), + ], + optional=True, + ), + dict( + name="locationPreference", + type="dict", + fields=[ + dict(name="followGaeApplication", allow_empty=False, optional=True), + dict(name="zone", allow_empty=False, optional=True), + ], + optional=True, + ), + dict( + name="maintenanceWindow", + type="dict", + fields=[ + dict(name="hour", optional=True), + dict(name="day", optional=True), + dict(name="updateTrack", allow_empty=False, optional=True), + ], + optional=True, + ), + dict(name="pricingPlan", allow_empty=False, optional=True), + dict(name="replicationType", allow_empty=False, optional=True), + dict(name="storageAutoResize", optional=True), + dict(name="storageAutoResizeLimit", optional=True), + dict(name="userLabels", type="dict", optional=True), + ], + ), dict(name="databaseVersion", allow_empty=False, optional=True), - dict(name="failoverReplica", type="dict", fields=[ - dict(name="name", allow_empty=False) - ], optional=True), + dict(name="failoverReplica", type="dict", fields=[dict(name="name", allow_empty=False)], optional=True), dict(name="masterInstanceName", allow_empty=False, optional=True), dict(name="onPremisesConfiguration", type="dict", optional=True), dict(name="region", allow_empty=False, optional=True), - dict(name="replicaConfiguration", type="dict", fields=[ - dict(name="failoverTarget", optional=True), - dict(name="mysqlReplicaConfiguration", type="dict", fields=[ - dict(name="caCertificate", allow_empty=False, optional=True), - dict(name="clientCertificate", allow_empty=False, optional=True), - dict(name="clientKey", allow_empty=False, optional=True), - dict(name="connectRetryInterval", optional=True), - dict(name="dumpFilePath", allow_empty=False, optional=True), - dict(name="masterHeartbeatPeriod", optional=True), - dict(name="password", allow_empty=False, optional=True), - dict(name="sslCipher", allow_empty=False, optional=True), - dict(name="username", allow_empty=False, optional=True), - dict(name="verifyServerCertificate", optional=True) - ], optional=True), - ], optional=True) + dict( + name="replicaConfiguration", + type="dict", + fields=[ + dict(name="failoverTarget", optional=True), + dict( + name="mysqlReplicaConfiguration", + type="dict", + fields=[ + dict(name="caCertificate", allow_empty=False, optional=True), + dict(name="clientCertificate", allow_empty=False, optional=True), + dict(name="clientKey", allow_empty=False, optional=True), + dict(name="connectRetryInterval", optional=True), + dict(name="dumpFilePath", allow_empty=False, optional=True), + dict(name="masterHeartbeatPeriod", optional=True), + dict(name="password", allow_empty=False, 
optional=True), + dict(name="sslCipher", allow_empty=False, optional=True), + dict(name="username", allow_empty=False, optional=True), + dict(name="verifyServerCertificate", optional=True), + ], + optional=True, + ), + ], + optional=True, + ), ] CLOUD_SQL_EXPORT_VALIDATION = [ - dict(name="exportContext", type="dict", fields=[ - dict(name="fileType", allow_empty=False), - dict(name="uri", allow_empty=False), - dict(name="databases", optional=True, type="list"), - dict(name="sqlExportOptions", type="dict", optional=True, fields=[ - dict(name="tables", optional=True, type="list"), - dict(name="schemaOnly", optional=True) - ]), - dict(name="csvExportOptions", type="dict", optional=True, fields=[ - dict(name="selectQuery") - ]) - ]) + dict( + name="exportContext", + type="dict", + fields=[ + dict(name="fileType", allow_empty=False), + dict(name="uri", allow_empty=False), + dict(name="databases", optional=True, type="list"), + dict( + name="sqlExportOptions", + type="dict", + optional=True, + fields=[ + dict(name="tables", optional=True, type="list"), + dict(name="schemaOnly", optional=True), + ], + ), + dict(name="csvExportOptions", type="dict", optional=True, fields=[dict(name="selectQuery")]), + ], + ) ] CLOUD_SQL_IMPORT_VALIDATION = [ - dict(name="importContext", type="dict", fields=[ - dict(name="fileType", allow_empty=False), - dict(name="uri", allow_empty=False), - dict(name="database", optional=True, allow_empty=False), - dict(name="importUser", optional=True), - dict(name="csvImportOptions", type="dict", optional=True, fields=[ - dict(name="table"), - dict(name="columns", type="list", optional=True) - ]) - ]) + dict( + name="importContext", + type="dict", + fields=[ + dict(name="fileType", allow_empty=False), + dict(name="uri", allow_empty=False), + dict(name="database", optional=True, allow_empty=False), + dict(name="importUser", optional=True), + dict( + name="csvImportOptions", + type="dict", + optional=True, + fields=[dict(name="table"), dict(name="columns", type="list", optional=True)], + ), + ], + ) ] CLOUD_SQL_DATABASE_CREATE_VALIDATION = [ dict(name="instance", allow_empty=False), @@ -162,14 +212,18 @@ class CloudSQLBaseOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + @apply_defaults - def __init__(self, *, - instance: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.project_id = project_id self.instance = instance self.gcp_conn_id = gcp_conn_id @@ -186,10 +240,7 @@ def _validate_inputs(self): def _check_if_instance_exists(self, instance, hook: CloudSQLHook): try: - return hook.get_instance( - project_id=self.project_id, - instance=instance - ) + return hook.get_instance(project_id=self.project_id, instance=instance) except HttpError as e: status = e.resp.status if status == 404: @@ -198,10 +249,7 @@ def _check_if_instance_exists(self, instance, hook: CloudSQLHook): def _check_if_db_exists(self, db_name, hook: CloudSQLHook): try: - return hook.get_database( - project_id=self.project_id, - instance=self.instance, - database=db_name) + return hook.get_database(project_id=self.project_id, instance=self.instance, database=db_name) except HttpError as e: status = e.resp.status if status == 404: @@ -251,26 +299,41 @@ class CloudSQLCreateInstanceOperator(CloudSQLBaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_sql_create_template_fields] - template_fields = ('project_id', 'instance', 'body', 'gcp_conn_id', 'api_version', - 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance', + 'body', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gcp_sql_create_template_fields] @apply_defaults - def __init__(self, *, - body: dict, - instance: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + body: dict, + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.body = body self.validate_body = validate_body super().__init__( - project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, - api_version=api_version, impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def _validate_inputs(self): super()._validate_inputs() @@ -279,8 +342,9 @@ def _validate_inputs(self): def _validate_body_fields(self): if self.validate_body: - GcpBodyFieldValidator(CLOUD_SQL_CREATE_VALIDATION, - api_version=self.api_version).validate(self.body) + GcpBodyFieldValidator(CLOUD_SQL_CREATE_VALIDATION, api_version=self.api_version).validate( + self.body + ) def execute(self, context): hook = CloudSQLHook( @@ -290,18 +354,11 @@ def execute(self, context): ) self._validate_body_fields() if not self._check_if_instance_exists(self.instance, hook): - hook.create_instance( - project_id=self.project_id, - 
body=self.body - ) + hook.create_instance(project_id=self.project_id, body=self.body) else: - self.log.info("Cloud SQL instance with ID %s already exists. " - "Aborting create.", self.instance) + self.log.info("Cloud SQL instance with ID %s already exists. " "Aborting create.", self.instance) - instance_resource = hook.get_instance( - project_id=self.project_id, - instance=self.instance - ) + instance_resource = hook.get_instance(project_id=self.project_id, instance=self.instance) service_account_email = instance_resource["serviceAccountEmailAddress"] task_instance = context['task_instance'] task_instance.xcom_push(key="service_account_email", value=service_account_email) @@ -344,24 +401,39 @@ class CloudSQLInstancePatchOperator(CloudSQLBaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_sql_patch_template_fields] - template_fields = ('project_id', 'instance', 'body', 'gcp_conn_id', 'api_version', - 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance', + 'body', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gcp_sql_patch_template_fields] @apply_defaults - def __init__(self, *, - body: dict, - instance: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + body: dict, + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.body = body super().__init__( - project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, - api_version=api_version, impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def _validate_inputs(self): super()._validate_inputs() @@ -375,14 +447,12 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) if not self._check_if_instance_exists(self.instance, hook): - raise AirflowException('Cloud SQL instance with ID {} does not exist. ' - 'Please specify another instance to patch.' - .format(self.instance)) + raise AirflowException( + 'Cloud SQL instance with ID {} does not exist. ' + 'Please specify another instance to patch.'.format(self.instance) + ) else: - return hook.patch_instance( - project_id=self.project_id, - body=self.body, - instance=self.instance) + return hook.patch_instance(project_id=self.project_id, body=self.body, instance=self.instance) class CloudSQLDeleteInstanceOperator(CloudSQLBaseOperator): @@ -412,22 +482,36 @@ class CloudSQLDeleteInstanceOperator(CloudSQLBaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_sql_delete_template_fields] - template_fields = ('project_id', 'instance', 'gcp_conn_id', 'api_version', - 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gcp_sql_delete_template_fields] @apply_defaults - def __init__(self, *, - instance: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__( - project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, - api_version=api_version, impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def execute(self, context): hook = CloudSQLHook( @@ -436,13 +520,10 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) if not self._check_if_instance_exists(self.instance, hook): - print("Cloud SQL instance with ID {} does not exist. Aborting delete." - .format(self.instance)) + print("Cloud SQL instance with ID {} does not exist. Aborting delete.".format(self.instance)) return True else: - return hook.delete_instance( - project_id=self.project_id, - instance=self.instance) + return hook.delete_instance(project_id=self.project_id, instance=self.instance) class CloudSQLCreateInstanceDatabaseOperator(CloudSQLBaseOperator): @@ -477,26 +558,41 @@ class CloudSQLCreateInstanceDatabaseOperator(CloudSQLBaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_sql_db_create_template_fields] - template_fields = ('project_id', 'instance', 'body', 'gcp_conn_id', 'api_version', - 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance', + 'body', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gcp_sql_db_create_template_fields] @apply_defaults - def __init__(self, *, - instance: str, - body: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.body = body self.validate_body = validate_body super().__init__( - project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, - api_version=api_version, impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def _validate_inputs(self): super()._validate_inputs() @@ -505,15 +601,19 @@ def _validate_inputs(self): def _validate_body_fields(self): if self.validate_body: - GcpBodyFieldValidator(CLOUD_SQL_DATABASE_CREATE_VALIDATION, - api_version=self.api_version).validate(self.body) + GcpBodyFieldValidator( + CLOUD_SQL_DATABASE_CREATE_VALIDATION, api_version=self.api_version + ).validate(self.body) def execute(self, context): self._validate_body_fields() database = self.body.get("name") if not database: - self.log.error("Body doesn't contain 'name'. Cannot check if the" - " database already exists in the instance %s.", self.instance) + self.log.error( + "Body doesn't contain 'name'. Cannot check if the" + " database already exists in the instance %s.", + self.instance, + ) return False hook = CloudSQLHook( gcp_conn_id=self.gcp_conn_id, @@ -521,15 +621,14 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) if self._check_if_db_exists(database, hook): - self.log.info("Cloud SQL instance with ID %s already contains database" - " '%s'. Aborting database insert.", self.instance, database) + self.log.info( + "Cloud SQL instance with ID %s already contains database" " '%s'. Aborting database insert.", + self.instance, + database, + ) return True else: - return hook.create_database( - project_id=self.project_id, - instance=self.instance, - body=self.body - ) + return hook.create_database(project_id=self.project_id, instance=self.instance, body=self.body) class CloudSQLPatchInstanceDatabaseOperator(CloudSQLBaseOperator): @@ -567,28 +666,44 @@ class CloudSQLPatchInstanceDatabaseOperator(CloudSQLBaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_sql_db_patch_template_fields] - template_fields = ('project_id', 'instance', 'body', 'database', 'gcp_conn_id', - 'api_version', 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance', + 'body', + 'database', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gcp_sql_db_patch_template_fields] @apply_defaults - def __init__(self, *, - instance: str, - database: str, - body: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance: str, + database: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.database = database self.body = body self.validate_body = validate_body super().__init__( - project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, - api_version=api_version, impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def _validate_inputs(self): super()._validate_inputs() @@ -599,8 +714,9 @@ def _validate_inputs(self): def _validate_body_fields(self): if self.validate_body: - GcpBodyFieldValidator(CLOUD_SQL_DATABASE_PATCH_VALIDATION, - api_version=self.api_version).validate(self.body) + GcpBodyFieldValidator(CLOUD_SQL_DATABASE_PATCH_VALIDATION, api_version=self.api_version).validate( + self.body + ) def execute(self, context): self._validate_body_fields() @@ -610,16 +726,17 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) if not self._check_if_db_exists(self.database, hook): - raise AirflowException("Cloud SQL instance with ID {instance} does not contain " - "database '{database}'. " - "Please specify another database to patch.". - format(instance=self.instance, database=self.database)) + raise AirflowException( + "Cloud SQL instance with ID {instance} does not contain " + "database '{database}'. " + "Please specify another database to patch.".format( + instance=self.instance, database=self.database + ) + ) else: return hook.patch_database( - project_id=self.project_id, - instance=self.instance, - database=self.database, - body=self.body) + project_id=self.project_id, instance=self.instance, database=self.database, body=self.body + ) class CloudSQLDeleteInstanceDatabaseOperator(CloudSQLBaseOperator): @@ -651,24 +768,39 @@ class CloudSQLDeleteInstanceDatabaseOperator(CloudSQLBaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_sql_db_delete_template_fields] - template_fields = ('project_id', 'instance', 'database', 'gcp_conn_id', - 'api_version', 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance', + 'database', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gcp_sql_db_delete_template_fields] @apply_defaults - def __init__(self, *, - instance: str, - database: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance: str, + database: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.database = database super().__init__( - project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, - api_version=api_version, impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def _validate_inputs(self): super()._validate_inputs() @@ -682,15 +814,15 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) if not self._check_if_db_exists(self.database, hook): - print("Cloud SQL instance with ID {} does not contain database '{}'. " - "Aborting database delete." - .format(self.instance, self.database)) + print( + "Cloud SQL instance with ID {} does not contain database '{}'. " + "Aborting database delete.".format(self.instance, self.database) + ) return True else: return hook.delete_database( - project_id=self.project_id, - instance=self.instance, - database=self.database) + project_id=self.project_id, instance=self.instance, database=self.database + ) class CloudSQLExportInstanceOperator(CloudSQLBaseOperator): @@ -729,26 +861,41 @@ class CloudSQLExportInstanceOperator(CloudSQLBaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_sql_export_template_fields] - template_fields = ('project_id', 'instance', 'body', 'gcp_conn_id', 'api_version', - 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance', + 'body', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gcp_sql_export_template_fields] @apply_defaults - def __init__(self, *, - instance: str, - body: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.body = body self.validate_body = validate_body super().__init__( - project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, - api_version=api_version, impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def _validate_inputs(self): super()._validate_inputs() @@ -757,8 +904,9 @@ def _validate_inputs(self): def _validate_body_fields(self): if self.validate_body: - GcpBodyFieldValidator(CLOUD_SQL_EXPORT_VALIDATION, - api_version=self.api_version).validate(self.body) + GcpBodyFieldValidator(CLOUD_SQL_EXPORT_VALIDATION, api_version=self.api_version).validate( + self.body + ) def execute(self, context): self._validate_body_fields() @@ -767,10 +915,7 @@ def execute(self, context): api_version=self.api_version, impersonation_chain=self.impersonation_chain, ) - return hook.export_instance( - project_id=self.project_id, - instance=self.instance, - body=self.body) + return hook.export_instance(project_id=self.project_id, instance=self.instance, body=self.body) class CloudSQLImportInstanceOperator(CloudSQLBaseOperator): @@ -821,26 +966,41 @@ class CloudSQLImportInstanceOperator(CloudSQLBaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_sql_import_template_fields] - template_fields = ('project_id', 'instance', 'body', 'gcp_conn_id', 'api_version', - 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance', + 'body', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gcp_sql_import_template_fields] @apply_defaults - def __init__(self, *, - instance: str, - body: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1beta4', - validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1beta4', + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.body = body self.validate_body = validate_body super().__init__( - project_id=project_id, instance=instance, gcp_conn_id=gcp_conn_id, - api_version=api_version, impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + instance=instance, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def _validate_inputs(self): super()._validate_inputs() @@ -849,8 +1009,9 @@ def _validate_inputs(self): def _validate_body_fields(self): if self.validate_body: - GcpBodyFieldValidator(CLOUD_SQL_IMPORT_VALIDATION, - api_version=self.api_version).validate(self.body) + GcpBodyFieldValidator(CLOUD_SQL_IMPORT_VALIDATION, api_version=self.api_version).validate( + self.body + ) def execute(self, context): self._validate_body_fields() @@ -859,10 +1020,7 @@ def execute(self, context): api_version=self.api_version, impersonation_chain=self.impersonation_chain, ) - return hook.import_instance( - project_id=self.project_id, - instance=self.instance, - body=self.body) + return hook.import_instance(project_id=self.project_id, instance=self.instance, body=self.body) class CloudSQLExecuteQueryOperator(BaseOperator): @@ -894,19 +1052,23 @@ class CloudSQLExecuteQueryOperator(BaseOperator): details on how to define gcpcloudsql:// connection. 
:type gcp_cloudsql_conn_id: str """ + # [START gcp_sql_query_template_fields] template_fields = ('sql', 'gcp_cloudsql_conn_id', 'gcp_conn_id') template_ext = ('.sql',) # [END gcp_sql_query_template_fields] @apply_defaults - def __init__(self, *, - sql: Union[List[str], str], - autocommit: bool = False, - parameters: Optional[Union[Dict, Iterable]] = None, - gcp_conn_id: str = 'google_cloud_default', - gcp_cloudsql_conn_id: str = 'google_cloud_sql_default', - **kwargs) -> None: + def __init__( + self, + *, + sql: Union[List[str], str], + autocommit: bool = False, + parameters: Optional[Union[Dict, Iterable]] = None, + gcp_conn_id: str = 'google_cloud_default', + gcp_cloudsql_conn_id: str = 'google_cloud_sql_default', + **kwargs, + ) -> None: super().__init__(**kwargs) self.sql = sql self.gcp_conn_id = gcp_conn_id @@ -937,7 +1099,8 @@ def execute(self, context): gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id, gcp_conn_id=self.gcp_conn_id, default_gcp_project_id=self.gcp_connection.extra_dejson.get( - 'extra__google_cloud_platform__project') + 'extra__google_cloud_platform__project' + ), ) hook.validate_ssl_certs() connection = hook.create_connection() diff --git a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py index 114366a4dada1..47e9b4fb51691 100644 --- a/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py @@ -27,10 +27,32 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( - ACCESS_KEY_ID, AWS_ACCESS_KEY, AWS_S3_DATA_SOURCE, BUCKET_NAME, DAY, DESCRIPTION, GCS_DATA_SINK, - GCS_DATA_SOURCE, HOURS, HTTP_DATA_SOURCE, MINUTES, MONTH, OBJECT_CONDITIONS, PROJECT_ID, SCHEDULE, - SCHEDULE_END_DATE, SCHEDULE_START_DATE, SECONDS, SECRET_ACCESS_KEY, START_TIME_OF_DAY, STATUS, - TRANSFER_OPTIONS, TRANSFER_SPEC, YEAR, CloudDataTransferServiceHook, GcpTransferJobsStatus, + ACCESS_KEY_ID, + AWS_ACCESS_KEY, + AWS_S3_DATA_SOURCE, + BUCKET_NAME, + DAY, + DESCRIPTION, + GCS_DATA_SINK, + GCS_DATA_SOURCE, + HOURS, + HTTP_DATA_SOURCE, + MINUTES, + MONTH, + OBJECT_CONDITIONS, + PROJECT_ID, + SCHEDULE, + SCHEDULE_END_DATE, + SCHEDULE_START_DATE, + SECONDS, + SECRET_ACCESS_KEY, + START_TIME_OF_DAY, + STATUS, + TRANSFER_OPTIONS, + TRANSFER_SPEC, + YEAR, + CloudDataTransferServiceHook, + GcpTransferJobsStatus, ) from airflow.utils.decorators import apply_defaults @@ -73,10 +95,7 @@ def _reformat_time(self, field_key): def _reformat_schedule(self): if SCHEDULE not in self.body: if self.default_schedule: - self.body[SCHEDULE] = { - SCHEDULE_START_DATE: date.today(), - SCHEDULE_END_DATE: date.today() - } + self.body[SCHEDULE] = {SCHEDULE_START_DATE: date.today(), SCHEDULE_END_DATE: date.today()} else: return self._reformat_date(SCHEDULE_START_DATE) @@ -115,6 +134,7 @@ class TransferJobValidator: """ Helper class for validating transfer job body. """ + def __init__(self, body: dict) -> None: if not body: raise AirflowException("The required parameter 'body' is empty or None") @@ -200,19 +220,26 @@ class CloudDataTransferServiceCreateJobOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type google_impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_transfer_job_create_template_fields] - template_fields = ('body', 'gcp_conn_id', 'aws_conn_id', 'google_impersonation_chain',) + template_fields = ( + 'body', + 'gcp_conn_id', + 'aws_conn_id', + 'google_impersonation_chain', + ) # [END gcp_transfer_job_create_template_fields] @apply_defaults def __init__( - self, *, + self, + *, body: dict, aws_conn_id: str = 'aws_default', gcp_conn_id: str = 'google_cloud_default', api_version: str = 'v1', google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.body = deepcopy(body) @@ -273,21 +300,28 @@ class CloudDataTransferServiceUpdateJobOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type google_impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_transfer_job_update_template_fields] - template_fields = ('job_name', 'body', 'gcp_conn_id', 'aws_conn_id', - 'google_impersonation_chain',) + template_fields = ( + 'job_name', + 'body', + 'gcp_conn_id', + 'aws_conn_id', + 'google_impersonation_chain', + ) # [END gcp_transfer_job_update_template_fields] @apply_defaults def __init__( - self, *, + self, + *, job_name: str, body: dict, aws_conn_id: str = 'aws_default', gcp_conn_id: str = 'google_cloud_default', api_version: str = 'v1', google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.job_name = job_name @@ -345,20 +379,27 @@ class CloudDataTransferServiceDeleteJobOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type google_impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_transfer_job_delete_template_fields] - template_fields = ('job_name', 'project_id', 'gcp_conn_id', 'api_version', - 'google_impersonation_chain',) + template_fields = ( + 'job_name', + 'project_id', + 'gcp_conn_id', + 'api_version', + 'google_impersonation_chain', + ) # [END gcp_transfer_job_delete_template_fields] @apply_defaults def __init__( - self, *, + self, + *, job_name: str, gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", project_id: Optional[str] = None, google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.job_name = job_name @@ -408,18 +449,24 @@ class CloudDataTransferServiceGetOperationOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type google_impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_transfer_operation_get_template_fields] - template_fields = ('operation_name', 'gcp_conn_id', 'google_impersonation_chain',) + template_fields = ( + 'operation_name', + 'gcp_conn_id', + 'google_impersonation_chain', + ) # [END gcp_transfer_operation_get_template_fields] @apply_defaults def __init__( - self, *, + self, + *, operation_name: str, gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.operation_name = operation_name @@ -469,16 +516,23 @@ class CloudDataTransferServiceListOperationsOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type google_impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_transfer_operations_list_template_fields] - template_fields = ('filter', 'gcp_conn_id', 'google_impersonation_chain',) + template_fields = ( + 'filter', + 'gcp_conn_id', + 'google_impersonation_chain', + ) # [END gcp_transfer_operations_list_template_fields] - def __init__(self, - request_filter: Optional[Dict] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + request_filter: Optional[Dict] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: # To preserve backward compatibility # TODO: remove one day if request_filter is None: @@ -534,19 +588,25 @@ class CloudDataTransferServicePauseOperationOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type google_impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_transfer_operation_pause_template_fields] - template_fields = ('operation_name', 'gcp_conn_id', 'api_version', - 'google_impersonation_chain',) + template_fields = ( + 'operation_name', + 'gcp_conn_id', + 'api_version', + 'google_impersonation_chain', + ) # [END gcp_transfer_operation_pause_template_fields] @apply_defaults def __init__( - self, *, + self, + *, operation_name: str, gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.operation_name = operation_name @@ -592,19 +652,25 @@ class CloudDataTransferServiceResumeOperationOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type google_impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_transfer_operation_resume_template_fields] - template_fields = ('operation_name', 'gcp_conn_id', 'api_version', - 'google_impersonation_chain',) + template_fields = ( + 'operation_name', + 'gcp_conn_id', + 'api_version', + 'google_impersonation_chain', + ) # [END gcp_transfer_operation_resume_template_fields] @apply_defaults def __init__( - self, *, + self, + *, operation_name: str, gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: self.operation_name = operation_name self.gcp_conn_id = gcp_conn_id @@ -651,19 +717,25 @@ class CloudDataTransferServiceCancelOperationOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type google_impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_transfer_operation_cancel_template_fields] - template_fields = ('operation_name', 'gcp_conn_id', 'api_version', - 'google_impersonation_chain',) + template_fields = ( + 'operation_name', + 'gcp_conn_id', + 'api_version', + 'google_impersonation_chain', + ) # [END gcp_transfer_operation_cancel_template_fields] @apply_defaults def __init__( - self, *, + self, + *, operation_name: str, gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.operation_name = operation_name @@ -755,13 +827,20 @@ class CloudDataTransferServiceS3ToGCSOperator(BaseOperator): :type google_impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('gcp_conn_id', 's3_bucket', 'gcs_bucket', 'description', 'object_conditions', - 'google_impersonation_chain',) + template_fields = ( + 'gcp_conn_id', + 's3_bucket', + 'gcs_bucket', + 'description', + 'object_conditions', + 'google_impersonation_chain', + ) ui_color = '#e09411' @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, s3_bucket: str, gcs_bucket: str, project_id: Optional[str] = None, @@ -775,7 +854,7 @@ def __init__( # pylint: disable=too-many-arguments wait: bool = True, timeout: Optional[float] = None, google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -918,7 +997,8 @@ class CloudDataTransferServiceGCSToGCSOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, source_bucket: str, destination_bucket: str, project_id: Optional[str] = None, @@ -931,7 +1011,7 @@ def __init__( # pylint: disable=too-many-arguments wait: bool = True, timeout: Optional[float] = None, google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/google/cloud/operators/compute.py b/airflow/providers/google/cloud/operators/compute.py index e3fe98d553920..839cda3db98a6 100644 --- a/airflow/providers/google/cloud/operators/compute.py +++ b/airflow/providers/google/cloud/operators/compute.py @@ -39,14 +39,17 @@ class ComputeEngineBaseOperator(BaseOperator): """ @apply_defaults - def __init__(self, *, - zone: str, - resource_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + zone: str, + resource_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.project_id = project_id self.zone = zone self.resource_id = resource_id @@ -102,34 +105,47 @@ class ComputeEngineStartInstanceOperator(ComputeEngineBaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gce_instance_start_template_fields] - template_fields = ('project_id', 'zone', 'resource_id', 'gcp_conn_id', 'api_version', - 'impersonation_chain',) + template_fields = ( + 'project_id', + 'zone', + 'resource_id', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gce_instance_start_template_fields] @apply_defaults - def __init__(self, *, - zone: str, - resource_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + zone: str, + resource_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__( - project_id=project_id, zone=zone, resource_id=resource_id, - gcp_conn_id=gcp_conn_id, api_version=api_version, - impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + zone=zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def execute(self, context): hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, - impersonation_chain=self.impersonation_chain + impersonation_chain=self.impersonation_chain, ) - return hook.start_instance(zone=self.zone, - resource_id=self.resource_id, - project_id=self.project_id) + return hook.start_instance(zone=self.zone, resource_id=self.resource_id, project_id=self.project_id) class ComputeEngineStopInstanceOperator(ComputeEngineBaseOperator): @@ -166,34 +182,47 @@ class ComputeEngineStopInstanceOperator(ComputeEngineBaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gce_instance_stop_template_fields] - template_fields = ('project_id', 'zone', 'resource_id', 'gcp_conn_id', 'api_version', - 'impersonation_chain',) + template_fields = ( + 'project_id', + 'zone', + 'resource_id', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gce_instance_stop_template_fields] @apply_defaults - def __init__(self, *, - zone: str, - resource_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + zone: str, + resource_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__( - project_id=project_id, zone=zone, resource_id=resource_id, - gcp_conn_id=gcp_conn_id, api_version=api_version, - impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + zone=zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def execute(self, context): hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, - impersonation_chain=self.impersonation_chain + impersonation_chain=self.impersonation_chain, ) - hook.stop_instance(zone=self.zone, - resource_id=self.resource_id, - project_id=self.project_id) + hook.stop_instance(zone=self.zone, resource_id=self.resource_id, project_id=self.project_id) SET_MACHINE_TYPE_VALIDATION_SPECIFICATION = [ @@ -240,31 +269,48 @@ class ComputeEngineSetMachineTypeOperator(ComputeEngineBaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gce_instance_set_machine_type_template_fields] - template_fields = ('project_id', 'zone', 'resource_id', 'body', 'gcp_conn_id', 'api_version', - 'impersonation_chain',) + template_fields = ( + 'project_id', + 'zone', + 'resource_id', + 'body', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gce_instance_set_machine_type_template_fields] @apply_defaults - def __init__(self, *, - zone: str, - resource_id: str, - body: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + zone: str, + resource_id: str, + body: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.body = body self._field_validator = None # type: Optional[GcpBodyFieldValidator] if validate_body: self._field_validator = GcpBodyFieldValidator( - SET_MACHINE_TYPE_VALIDATION_SPECIFICATION, api_version=api_version) + SET_MACHINE_TYPE_VALIDATION_SPECIFICATION, api_version=api_version + ) super().__init__( - project_id=project_id, zone=zone, resource_id=resource_id, - gcp_conn_id=gcp_conn_id, api_version=api_version, - impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + zone=zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def _validate_all_body_fields(self): if self._field_validator: @@ -274,43 +320,53 @@ def execute(self, context): hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, - impersonation_chain=self.impersonation_chain + impersonation_chain=self.impersonation_chain, ) self._validate_all_body_fields() - return hook.set_machine_type(zone=self.zone, - resource_id=self.resource_id, - body=self.body, - project_id=self.project_id) + return hook.set_machine_type( + zone=self.zone, resource_id=self.resource_id, body=self.body, project_id=self.project_id + ) GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION = [ dict(name="name", regexp="^.+$"), dict(name="description", optional=True), - dict(name="properties", type='dict', optional=True, fields=[ - dict(name="description", optional=True), - dict(name="tags", optional=True, fields=[ - dict(name="items", optional=True) - ]), - dict(name="machineType", optional=True), - dict(name="canIpForward", optional=True), - dict(name="networkInterfaces", optional=True), # not validating deeper - dict(name="disks", optional=True), # not validating the array deeper - dict(name="metadata", optional=True, fields=[ - dict(name="fingerprint", optional=True), - dict(name="items", optional=True), - dict(name="kind", optional=True), - ]), - dict(name="serviceAccounts", optional=True), # not validating deeper - dict(name="scheduling", optional=True, fields=[ - dict(name="onHostMaintenance", optional=True), - dict(name="automaticRestart", optional=True), - dict(name="preemptible", optional=True), - dict(name="nodeAffinitites", optional=True), # not validating deeper - ]), - dict(name="labels", optional=True), - dict(name="guestAccelerators", optional=True), # not validating deeper - dict(name="minCpuPlatform", optional=True), - ]), + dict( + name="properties", + 
type='dict', + optional=True, + fields=[ + dict(name="description", optional=True), + dict(name="tags", optional=True, fields=[dict(name="items", optional=True)]), + dict(name="machineType", optional=True), + dict(name="canIpForward", optional=True), + dict(name="networkInterfaces", optional=True), # not validating deeper + dict(name="disks", optional=True), # not validating the array deeper + dict( + name="metadata", + optional=True, + fields=[ + dict(name="fingerprint", optional=True), + dict(name="items", optional=True), + dict(name="kind", optional=True), + ], + ), + dict(name="serviceAccounts", optional=True), # not validating deeper + dict( + name="scheduling", + optional=True, + fields=[ + dict(name="onHostMaintenance", optional=True), + dict(name="automaticRestart", optional=True), + dict(name="preemptible", optional=True), + dict(name="nodeAffinitites", optional=True), # not validating deeper + ], + ), + dict(name="labels", optional=True), + dict(name="guestAccelerators", optional=True), # not validating deeper + dict(name="minCpuPlatform", optional=True), + ], + ), ] # type: List[Dict[str, Any]] GCE_INSTANCE_TEMPLATE_FIELDS_TO_SANITIZE = [ @@ -327,7 +383,7 @@ def execute(self, context): "properties.networkInterfaces.accessConfigs.kind", "properties.networkInterfaces.name", "properties.metadata.kind", - "selfLink" + "selfLink", ] @@ -377,38 +433,54 @@ class ComputeEngineCopyInstanceTemplateOperator(ComputeEngineBaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START gce_instance_template_copy_operator_template_fields] - template_fields = ('project_id', 'resource_id', 'request_id', - 'gcp_conn_id', 'api_version', 'impersonation_chain',) + template_fields = ( + 'project_id', + 'resource_id', + 'request_id', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gce_instance_template_copy_operator_template_fields] @apply_defaults - def __init__(self, *, - resource_id: str, - body_patch: dict, - project_id: Optional[str] = None, - request_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + resource_id: str, + body_patch: dict, + project_id: Optional[str] = None, + request_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.body_patch = body_patch self.request_id = request_id self._field_validator = None # Optional[GcpBodyFieldValidator] if 'name' not in self.body_patch: - raise AirflowException("The body '{}' should contain at least " - "name for the new operator in the 'name' field". 
- format(body_patch)) + raise AirflowException( + "The body '{}' should contain at least " + "name for the new operator in the 'name' field".format(body_patch) + ) if validate_body: self._field_validator = GcpBodyFieldValidator( - GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, api_version=api_version) - self._field_sanitizer = GcpBodyFieldSanitizer( - GCE_INSTANCE_TEMPLATE_FIELDS_TO_SANITIZE) + GCE_INSTANCE_TEMPLATE_VALIDATION_PATCH_SPECIFICATION, api_version=api_version + ) + self._field_sanitizer = GcpBodyFieldSanitizer(GCE_INSTANCE_TEMPLATE_FIELDS_TO_SANITIZE) super().__init__( - project_id=project_id, zone='global', resource_id=resource_id, - gcp_conn_id=gcp_conn_id, api_version=api_version, - impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + zone='global', + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def _validate_all_body_fields(self): if self._field_validator: @@ -418,7 +490,7 @@ def execute(self, context): hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, - impersonation_chain=self.impersonation_chain + impersonation_chain=self.impersonation_chain, ) self._validate_all_body_fields() try: @@ -430,11 +502,12 @@ def execute(self, context): # that we cannot delete template if it is already used in some Instance # Group Manager. We assume success if the template is simply present existing_template = hook.get_instance_template( - resource_id=self.body_patch['name'], project_id=self.project_id) + resource_id=self.body_patch['name'], project_id=self.project_id + ) self.log.info( "The %s template already existed. It was likely created by previous run of the operator. " "Assuming success.", - existing_template + existing_template, ) return existing_template except HttpError as e: @@ -442,17 +515,13 @@ def execute(self, context): # not yet exist if not e.resp.status == 404: raise e - old_body = hook.get_instance_template(resource_id=self.resource_id, - project_id=self.project_id) + old_body = hook.get_instance_template(resource_id=self.resource_id, project_id=self.project_id) new_body = deepcopy(old_body) self._field_sanitizer.sanitize(new_body) new_body = merge(new_body, self.body_patch) self.log.info("Calling insert instance template with updated body: %s", new_body) - hook.insert_instance_template(body=new_body, - request_id=self.request_id, - project_id=self.project_id) - return hook.get_instance_template(resource_id=self.body_patch['name'], - project_id=self.project_id) + hook.insert_instance_template(body=new_body, request_id=self.request_id, project_id=self.project_id) + return hook.get_instance_template(resource_id=self.body_patch['name'], project_id=self.project_id) class ComputeEngineInstanceGroupUpdateManagerTemplateOperator(ComputeEngineBaseOperator): @@ -501,25 +570,37 @@ class ComputeEngineInstanceGroupUpdateManagerTemplateOperator(ComputeEngineBaseO account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gce_igm_update_template_operator_template_fields] - template_fields = ('project_id', 'resource_id', 'zone', 'request_id', - 'source_template', 'destination_template', - 'gcp_conn_id', 'api_version', 'impersonation_chain',) + template_fields = ( + 'project_id', + 'resource_id', + 'zone', + 'request_id', + 'source_template', + 'destination_template', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gce_igm_update_template_operator_template_fields] @apply_defaults - def __init__(self, *, - resource_id: str, - zone: str, - source_template: str, - destination_template: str, - project_id: Optional[str] = None, - update_policy: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version='beta', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + resource_id: str, + zone: str, + source_template: str, + destination_template: str, + project_id: Optional[str] = None, + update_policy: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version='beta', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.zone = zone self.source_template = source_template self.destination_template = destination_template @@ -527,13 +608,20 @@ def __init__(self, *, self.update_policy = update_policy self._change_performed = False if api_version == 'v1': - raise AirflowException("Api version v1 does not have update/patch " - "operations for Instance Group Managers. Use beta" - " api version or above") + raise AirflowException( + "Api version v1 does not have update/patch " + "operations for Instance Group Managers. 
Use beta" + " api version or above" + ) super().__init__( - project_id=project_id, zone=self.zone, resource_id=resource_id, - gcp_conn_id=gcp_conn_id, api_version=api_version, - impersonation_chain=impersonation_chain, **kwargs) + project_id=project_id, + zone=self.zone, + resource_id=resource_id, + gcp_conn_id=gcp_conn_id, + api_version=api_version, + impersonation_chain=impersonation_chain, + **kwargs, + ) def _possibly_replace_template(self, dictionary: Dict) -> None: if dictionary.get('instanceTemplate') == self.source_template: @@ -544,10 +632,11 @@ def execute(self, context): hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, - impersonation_chain=self.impersonation_chain + impersonation_chain=self.impersonation_chain, ) old_instance_group_manager = hook.get_instance_group_manager( - zone=self.zone, resource_id=self.resource_id, project_id=self.project_id) + zone=self.zone, resource_id=self.resource_id, project_id=self.project_id + ) patch_body = {} if 'versions' in old_instance_group_manager: patch_body['versions'] = old_instance_group_manager['versions'] @@ -560,13 +649,14 @@ def execute(self, context): for version in patch_body['versions']: self._possibly_replace_template(version) if self._change_performed or self.update_policy: - self.log.info( - "Calling patch instance template with updated body: %s", - patch_body) + self.log.info("Calling patch instance template with updated body: %s", patch_body) return hook.patch_instance_group_manager( - zone=self.zone, resource_id=self.resource_id, - body=patch_body, request_id=self.request_id, - project_id=self.project_id) + zone=self.zone, + resource_id=self.resource_id, + body=patch_body, + request_id=self.request_id, + project_id=self.project_id, + ) else: # Idempotence achieved return True diff --git a/airflow/providers/google/cloud/operators/datacatalog.py b/airflow/providers/google/cloud/operators/datacatalog.py index b8404f9ae72b1..e7baa6d8fd606 100644 --- a/airflow/providers/google/cloud/operators/datacatalog.py +++ b/airflow/providers/google/cloud/operators/datacatalog.py @@ -21,7 +21,13 @@ from google.api_core.retry import Retry from google.cloud.datacatalog_v1beta1 import DataCatalogClient from google.cloud.datacatalog_v1beta1.types import ( - Entry, EntryGroup, FieldMask, SearchCatalogRequest, Tag, TagTemplate, TagTemplateField, + Entry, + EntryGroup, + FieldMask, + SearchCatalogRequest, + Tag, + TagTemplate, + TagTemplateField, ) from google.protobuf.json_format import MessageToDict @@ -92,7 +98,8 @@ class CloudDataCatalogCreateEntryOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, entry_group: str, entry_id: str, @@ -103,7 +110,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -119,8 +126,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) try: result = hook.create_entry( @@ -210,7 +216,8 @@ class CloudDataCatalogCreateEntryGroupOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, entry_group_id: str, entry_group: Union[Dict, EntryGroup], @@ -220,7 +227,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, 
str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -235,8 +242,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) try: result = hook.create_entry_group( @@ -328,7 +334,8 @@ class CloudDataCatalogCreateTagOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, location: str, entry_group: str, entry: str, @@ -340,7 +347,7 @@ def __init__( # pylint: disable=too-many-arguments metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -357,8 +364,7 @@ def __init__( # pylint: disable=too-many-arguments def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) try: tag = hook.create_tag( @@ -459,7 +465,8 @@ class CloudDataCatalogCreateTagTemplateOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, tag_template_id: str, tag_template: Union[Dict, TagTemplate], @@ -469,7 +476,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -484,8 +491,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) try: result = hook.create_tag_template( @@ -577,7 +583,8 @@ class CloudDataCatalogCreateTagTemplateFieldOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, tag_template: str, tag_template_field_id: str, @@ -588,7 +595,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -604,8 +611,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) try: result = hook.create_tag_template_field( @@ -688,7 +694,8 @@ class CloudDataCatalogDeleteEntryOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, entry_group: str, entry: str, @@ -698,7 +705,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -713,8 +720,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - 
impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) try: hook.delete_entry( @@ -769,12 +775,21 @@ class CloudDataCatalogDeleteEntryGroupOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("location", "entry_group", "project_id", "retry", "timeout", "metadata", - "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "location", + "entry_group", + "project_id", + "retry", + "timeout", + "metadata", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, entry_group: str, project_id: Optional[str] = None, @@ -783,7 +798,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -797,8 +812,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) try: hook.delete_entry_group( @@ -869,7 +883,8 @@ class CloudDataCatalogDeleteTagOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, entry_group: str, entry: str, @@ -880,7 +895,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -896,8 +911,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) try: hook.delete_tag( @@ -969,7 +983,8 @@ class CloudDataCatalogDeleteTagTemplateOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, tag_template: str, force: bool, @@ -979,7 +994,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -994,8 +1009,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) try: hook.delete_tag_template( @@ -1067,7 +1081,8 @@ class CloudDataCatalogDeleteTagTemplateFieldOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, tag_template: str, field: str, @@ -1078,7 +1093,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -1094,8 +1109,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) try: hook.delete_tag_template_field( @@ -1165,7 +1179,8 @@ class 
CloudDataCatalogGetEntryOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, entry_group: str, entry: str, @@ -1175,7 +1190,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -1190,8 +1205,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) result = hook.get_entry( location=self.location, @@ -1261,7 +1275,8 @@ class CloudDataCatalogGetEntryGroupOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, entry_group: str, read_mask: Union[Dict, FieldMask], @@ -1271,7 +1286,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -1286,8 +1301,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) result = hook.get_entry_group( location=self.location, @@ -1351,7 +1365,8 @@ class CloudDataCatalogGetTagTemplateOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, tag_template: str, project_id: Optional[str] = None, @@ -1360,7 +1375,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -1374,8 +1389,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) result = hook.get_tag_template( location=self.location, @@ -1446,7 +1460,8 @@ class CloudDataCatalogListTagsOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, entry_group: str, entry: str, @@ -1457,7 +1472,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -1473,8 +1488,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) result = hook.list_tags( location=self.location, @@ -1541,7 +1555,8 @@ class CloudDataCatalogLookupEntryOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, linked_resource: Optional[str] = None, sql_resource: Optional[str] = None, project_id: Optional[str] = None, @@ -1550,7 +1565,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + 
**kwargs, ) -> None: super().__init__(**kwargs) self.linked_resource = linked_resource @@ -1564,8 +1579,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) result = hook.lookup_entry( linked_resource=self.linked_resource, @@ -1635,7 +1649,8 @@ class CloudDataCatalogRenameTagTemplateFieldOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, tag_template: str, field: str, @@ -1646,7 +1661,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -1662,8 +1677,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.rename_tag_template_field( location=self.location, @@ -1757,7 +1771,8 @@ class CloudDataCatalogSearchCatalogOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, scope: Union[Dict, SearchCatalogRequest.Scope], query: str, page_size: int = 100, @@ -1767,7 +1782,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.scope = scope @@ -1782,8 +1797,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) result = hook.search_catalog( scope=self.scope, @@ -1863,7 +1877,8 @@ class CloudDataCatalogUpdateEntryOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, entry: Union[Dict, Entry], update_mask: Union[Dict, FieldMask], location: Optional[str] = None, @@ -1875,7 +1890,7 @@ def __init__( # pylint: disable=too-many-arguments metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.entry = entry @@ -1892,8 +1907,7 @@ def __init__( # pylint: disable=too-many-arguments def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.update_entry( entry=self.entry, @@ -1975,7 +1989,8 @@ class CloudDataCatalogUpdateTagOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, tag: Union[Dict, Tag], update_mask: Union[Dict, FieldMask], location: Optional[str] = None, @@ -1988,7 +2003,7 @@ def __init__( # pylint: disable=too-many-arguments metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.tag = tag @@ -2006,8 +2021,7 @@ def __init__( # pylint: disable=too-many-arguments 
def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.update_tag( tag=self.tag, @@ -2091,7 +2105,8 @@ class CloudDataCatalogUpdateTagTemplateOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, tag_template: Union[Dict, TagTemplate], update_mask: Union[Dict, FieldMask], location: Optional[str] = None, @@ -2102,7 +2117,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.tag_template = tag_template @@ -2118,8 +2133,7 @@ def __init__( def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.update_tag_template( tag_template=self.tag_template, @@ -2209,7 +2223,8 @@ class CloudDataCatalogUpdateTagTemplateFieldOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, tag_template_field: Union[Dict, TagTemplateField], update_mask: Union[Dict, FieldMask], tag_template_field_name: Optional[str] = None, @@ -2222,7 +2237,7 @@ def __init__( # pylint: disable=too-many-arguments metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.tag_template_field_name = tag_template_field_name @@ -2240,8 +2255,7 @@ def __init__( # pylint: disable=too-many-arguments def execute(self, context: Dict): hook = CloudDataCatalogHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) hook.update_tag_template_field( tag_template_field=self.tag_template_field, diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index 1d0ce32813f3e..8eb67a8ed55a5 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -39,6 +39,7 @@ class CheckJobRunning(Enum): FinishIfRunning - finish current dag run with no action WaitForRun - wait for job to finish and then continue with new job """ + IgnoreJob = 1 FinishIfRunning = 2 WaitForRun = 3 @@ -174,31 +175,36 @@ class DataflowCreateJavaJobOperator(BaseOperator): dag=my-dag) """ + template_fields = ['options', 'jar', 'job_name'] ui_color = '#0273d4' # pylint: disable=too-many-arguments @apply_defaults - def __init__(self, *, - jar: str, - job_name: str = '{{task.task_id}}', - dataflow_default_options: Optional[dict] = None, - options: Optional[dict] = None, - project_id: Optional[str] = None, - location: str = DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - poll_sleep: int = 10, - job_class: Optional[str] = None, - check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun, - multiple_jobs: Optional[bool] = None, - **kwargs) -> None: + def __init__( + self, + *, + jar: str, + job_name: str = '{{task.task_id}}', + dataflow_default_options: Optional[dict] = None, + options: Optional[dict] = None, + project_id: 
Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + poll_sleep: int = 10, + job_class: Optional[str] = None, + check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun, + multiple_jobs: Optional[bool] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) dataflow_default_options = dataflow_default_options or {} options = options or {} options.setdefault('labels', {}).update( - {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')}) + {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')} + ) self.project_id = project_id self.location = location self.gcp_conn_id = gcp_conn_id @@ -216,9 +222,7 @@ def __init__(self, *, def execute(self, context): self.hook = DataflowHook( - gcp_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to, - poll_sleep=self.poll_sleep + gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, poll_sleep=self.poll_sleep ) dataflow_options = copy.copy(self.dataflow_default_options) dataflow_options.update(self.options) @@ -228,12 +232,14 @@ def execute(self, context): name=self.job_name, variables=dataflow_options, project_id=self.project_id, - location=self.location + location=self.location, ) while is_running and self.check_if_running == CheckJobRunning.WaitForRun: is_running = self.hook.is_job_dataflow_running( - name=self.job_name, variables=dataflow_options, project_id=self.project_id, - location=self.location + name=self.job_name, + variables=dataflow_options, + project_id=self.project_id, + location=self.location, ) if not is_running: @@ -257,7 +263,7 @@ def set_current_job_id(job_id): multiple_jobs=self.multiple_jobs, on_new_job_id_callback=set_current_job_id, project_id=self.project_id, - location=self.location + location=self.location, ) def on_kill(self) -> None: @@ -359,6 +365,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator): For more detail on job template execution have a look at the reference: https://cloud.google.com/dataflow/docs/templates/executing-templates """ + template_fields = [ 'template', 'job_name', @@ -373,19 +380,21 @@ class DataflowTemplatedJobStartOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, - template: str, - job_name: str = '{{task.task_id}}', - options: Optional[Dict[str, Any]] = None, - dataflow_default_options: Optional[Dict[str, Any]] = None, - parameters: Optional[Dict[str, str]] = None, - project_id: Optional[str] = None, - location: str = DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - poll_sleep: int = 10, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + self, + *, + template: str, + job_name: str = '{{task.task_id}}', + options: Optional[Dict[str, Any]] = None, + dataflow_default_options: Optional[Dict[str, Any]] = None, + parameters: Optional[Dict[str, str]] = None, + project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + poll_sleep: int = 10, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.template = template self.job_name = job_name @@ -411,6 +420,7 @@ def execute(self, context): def set_current_job_id(job_id): self.job_id = job_id + options = self.dataflow_default_options options.update(self.options) @@ -421,7 +431,7 @@ def 
set_current_job_id(job_id): dataflow_template=self.template, on_new_job_id_callback=set_current_job_id, project_id=self.project_id, - location=self.location + location=self.location, ) return job @@ -500,25 +510,28 @@ class DataflowCreatePythonJobOperator(BaseOperator): JOB_STATE_RUNNING state. :type poll_sleep: int """ + template_fields = ['options', 'dataflow_default_options', 'job_name', 'py_file'] @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, - py_file: str, - job_name: str = '{{task.task_id}}', - dataflow_default_options: Optional[dict] = None, - options: Optional[dict] = None, - py_interpreter: str = "python3", - py_options: Optional[List[str]] = None, - py_requirements: Optional[List[str]] = None, - py_system_site_packages: bool = False, - project_id: Optional[str] = None, - location: str = DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - poll_sleep: int = 10, - **kwargs) -> None: + self, + *, + py_file: str, + job_name: str = '{{task.task_id}}', + dataflow_default_options: Optional[dict] = None, + options: Optional[dict] = None, + py_interpreter: str = "python3", + py_options: Optional[List[str]] = None, + py_requirements: Optional[List[str]] = None, + py_system_site_packages: bool = False, + project_id: Optional[str] = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + poll_sleep: int = 10, + **kwargs, + ) -> None: super().__init__(**kwargs) @@ -528,7 +541,8 @@ def __init__( # pylint: disable=too-many-arguments self.dataflow_default_options = dataflow_default_options or {} self.options = options or {} self.options.setdefault('labels', {}).update( - {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')}) + {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')} + ) self.py_interpreter = py_interpreter self.py_requirements = py_requirements self.py_system_site_packages = py_system_site_packages @@ -551,16 +565,13 @@ def execute(self, context): self.py_file = tmp_gcs_file.name self.hook = DataflowHook( - gcp_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to, - poll_sleep=self.poll_sleep + gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, poll_sleep=self.poll_sleep ) dataflow_options = self.dataflow_default_options.copy() dataflow_options.update(self.options) # Convert argument names from lowerCamelCase to snake case. 
camel_to_snake = lambda name: re.sub(r'[A-Z]', lambda x: '_' + x.group(0).lower(), name) - formatted_options = {camel_to_snake(key): dataflow_options[key] - for key in dataflow_options} + formatted_options = {camel_to_snake(key): dataflow_options[key] for key in dataflow_options} def set_current_job_id(job_id): self.job_id = job_id diff --git a/airflow/providers/google/cloud/operators/datafusion.py b/airflow/providers/google/cloud/operators/datafusion.py index 4923ea8be76d1..23477627f9ad0 100644 --- a/airflow/providers/google/cloud/operators/datafusion.py +++ b/airflow/providers/google/cloud/operators/datafusion.py @@ -63,11 +63,15 @@ class CloudDataFusionRestartInstanceOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("instance_name", "impersonation_chain",) + template_fields = ( + "instance_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, instance_name: str, location: str, project_id: Optional[str] = None, @@ -75,7 +79,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.instance_name = instance_name @@ -95,9 +99,7 @@ def execute(self, context: Dict): ) self.log.info("Restarting Data Fusion instance: %s", self.instance_name) operation = hook.restart_instance( - instance_name=self.instance_name, - location=self.location, - project_id=self.project_id, + instance_name=self.instance_name, location=self.location, project_id=self.project_id, ) hook.wait_for_operation(operation) self.log.info("Instance %s restarted successfully", self.instance_name) @@ -136,11 +138,15 @@ class CloudDataFusionDeleteInstanceOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("instance_name", "impersonation_chain",) + template_fields = ( + "instance_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, instance_name: str, location: str, project_id: Optional[str] = None, @@ -148,7 +154,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.instance_name = instance_name @@ -168,9 +174,7 @@ def execute(self, context: Dict): ) self.log.info("Deleting Data Fusion instance: %s", self.instance_name) operation = hook.delete_instance( - instance_name=self.instance_name, - location=self.location, - project_id=self.project_id, + instance_name=self.instance_name, location=self.location, project_id=self.project_id, ) hook.wait_for_operation(operation) self.log.info("Instance %s deleted successfully", self.instance_name) @@ -212,11 +216,16 @@ class CloudDataFusionCreateInstanceOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("instance_name", "instance", "impersonation_chain",) + template_fields = ( + "instance_name", + "instance", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, instance_name: str, instance: Dict[str, Any], location: str, @@ -225,7 +234,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.instance_name = instance_name @@ -315,11 +324,16 
@@ class CloudDataFusionUpdateInstanceOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("instance_name", "instance", "impersonation_chain",) + template_fields = ( + "instance_name", + "instance", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, instance_name: str, instance: Dict[str, Any], update_mask: str, @@ -329,7 +343,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.update_mask = update_mask @@ -394,11 +408,15 @@ class CloudDataFusionGetInstanceOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("instance_name", "impersonation_chain",) + template_fields = ( + "instance_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, instance_name: str, location: str, project_id: Optional[str] = None, @@ -406,7 +424,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.instance_name = instance_name @@ -426,9 +444,7 @@ def execute(self, context: Dict): ) self.log.info("Retrieving Data Fusion instance: %s", self.instance_name) instance = hook.get_instance( - instance_name=self.instance_name, - location=self.location, - project_id=self.project_id, + instance_name=self.instance_name, location=self.location, project_id=self.project_id, ) return instance @@ -473,11 +489,16 @@ class CloudDataFusionCreatePipelineOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("instance_name", "pipeline_name", "impersonation_chain",) + template_fields = ( + "instance_name", + "pipeline_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, pipeline_name: str, pipeline: Dict[str, Any], instance_name: str, @@ -488,7 +509,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.pipeline_name = pipeline_name @@ -511,9 +532,7 @@ def execute(self, context: Dict): ) self.log.info("Creating Data Fusion pipeline: %s", self.pipeline_name) instance = hook.get_instance( - instance_name=self.instance_name, - location=self.location, - project_id=self.project_id, + instance_name=self.instance_name, location=self.location, project_id=self.project_id, ) api_url = instance["apiEndpoint"] hook.create_pipeline( @@ -564,11 +583,17 @@ class CloudDataFusionDeletePipelineOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("instance_name", "version_id", "pipeline_name", "impersonation_chain",) + template_fields = ( + "instance_name", + "version_id", + "pipeline_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, pipeline_name: str, instance_name: str, location: str, @@ -579,7 +604,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.pipeline_name = pipeline_name @@ -602,9 +627,7 @@ def execute(self, context: Dict): ) 
self.log.info("Deleting Data Fusion pipeline: %s", self.pipeline_name) instance = hook.get_instance( - instance_name=self.instance_name, - location=self.location, - project_id=self.project_id, + instance_name=self.instance_name, location=self.location, project_id=self.project_id, ) api_url = instance["apiEndpoint"] hook.delete_pipeline( @@ -656,11 +679,17 @@ class CloudDataFusionListPipelinesOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("instance_name", "artifact_name", "artifact_version", "impersonation_chain",) + template_fields = ( + "instance_name", + "artifact_name", + "artifact_version", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, instance_name: str, location: str, artifact_name: Optional[str] = None, @@ -671,7 +700,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.artifact_version = artifact_version @@ -694,9 +723,7 @@ def execute(self, context: Dict): ) self.log.info("Listing Data Fusion pipelines") instance = hook.get_instance( - instance_name=self.instance_name, - location=self.location, - project_id=self.project_id, + instance_name=self.instance_name, location=self.location, project_id=self.project_id, ) api_url = instance["apiEndpoint"] pipelines = hook.list_pipelines( @@ -754,11 +781,17 @@ class CloudDataFusionStartPipelineOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("instance_name", "pipeline_name", "runtime_args", "impersonation_chain",) + template_fields = ( + "instance_name", + "pipeline_name", + "runtime_args", + "impersonation_chain", + ) @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, pipeline_name: str, instance_name: str, location: str, @@ -771,7 +804,7 @@ def __init__( # pylint: disable=too-many-arguments gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.pipeline_name = pipeline_name @@ -796,9 +829,7 @@ def execute(self, context: Dict): ) self.log.info("Starting Data Fusion pipeline: %s", self.pipeline_name) instance = hook.get_instance( - instance_name=self.instance_name, - location=self.location, - project_id=self.project_id, + instance_name=self.instance_name, location=self.location, project_id=self.project_id, ) api_url = instance["apiEndpoint"] pipeline_id = hook.start_pipeline( @@ -816,7 +847,7 @@ def execute(self, context: Dict): pipeline_name=self.pipeline_name, namespace=self.namespace, instance_url=api_url, - timeout=self.pipeline_timeout + timeout=self.pipeline_timeout, ) @@ -857,11 +888,16 @@ class CloudDataFusionStopPipelineOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("instance_name", "pipeline_name", "impersonation_chain",) + template_fields = ( + "instance_name", + "pipeline_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, pipeline_name: str, instance_name: str, location: str, @@ -871,7 +907,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.pipeline_name = 
pipeline_name @@ -893,14 +929,10 @@ def execute(self, context: Dict): ) self.log.info("Starting Data Fusion pipeline: %s", self.pipeline_name) instance = hook.get_instance( - instance_name=self.instance_name, - location=self.location, - project_id=self.project_id, + instance_name=self.instance_name, location=self.location, project_id=self.project_id, ) api_url = instance["apiEndpoint"] hook.stop_pipeline( - pipeline_name=self.pipeline_name, - instance_url=api_url, - namespace=self.namespace, + pipeline_name=self.pipeline_name, instance_url=api_url, namespace=self.namespace, ) self.log.info("Pipeline started") diff --git a/airflow/providers/google/cloud/operators/dataprep.py b/airflow/providers/google/cloud/operators/dataprep.py index 4b1d7b5ba79b5..c78fa6e728ce0 100644 --- a/airflow/providers/google/cloud/operators/dataprep.py +++ b/airflow/providers/google/cloud/operators/dataprep.py @@ -43,9 +43,7 @@ class DataprepGetJobsForJobGroupOperator(BaseOperator): template_fields = ("job_id",) @apply_defaults - def __init__( - self, *, job_id: int, **kwargs - ) -> None: + def __init__(self, *, job_id: int, **kwargs) -> None: super().__init__(**kwargs) self.job_id = job_id diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 2855188895316..458f19533c8c3 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -34,7 +34,9 @@ from google.api_core.exceptions import AlreadyExists, NotFound from google.api_core.retry import Retry, exponential_sleep_generator from google.cloud.dataproc_v1beta2.types import ( # pylint: disable=no-name-in-module - Cluster, Duration, FieldMask, + Cluster, + Duration, + FieldMask, ) from google.protobuf.json_format import MessageToDict @@ -151,44 +153,46 @@ class ClusterGenerator: ``projects/[PROJECT_STORING_KEYS]/locations/[LOCATION]/keyRings/[KEY_RING_NAME]/cryptoKeys/[KEY_NAME]`` # noqa # pylint: disable=line-too-long :type customer_managed_key: str """ + # pylint: disable=too-many-arguments,too-many-locals - def __init__(self, - project_id: Optional[str] = None, - cluster_name: Optional[str] = None, - num_workers: Optional[int] = None, - zone: Optional[str] = None, - network_uri: Optional[str] = None, - subnetwork_uri: Optional[str] = None, - internal_ip_only: Optional[bool] = None, - tags: Optional[List[str]] = None, - storage_bucket: Optional[str] = None, - init_actions_uris: Optional[List[str]] = None, - init_action_timeout: str = "10m", - metadata: Optional[Dict] = None, - custom_image: Optional[str] = None, - custom_image_project_id: Optional[str] = None, - image_version: Optional[str] = None, - autoscaling_policy: Optional[str] = None, - properties: Optional[Dict] = None, - optional_components: Optional[List[str]] = None, - num_masters: int = 1, - master_machine_type: str = 'n1-standard-4', - master_disk_type: str = 'pd-standard', - master_disk_size: int = 1024, - worker_machine_type: str = 'n1-standard-4', - worker_disk_type: str = 'pd-standard', - worker_disk_size: int = 1024, - num_preemptible_workers: int = 0, - labels: Optional[Dict] = None, - region: Optional[str] = None, - service_account: Optional[str] = None, - service_account_scopes: Optional[List[str]] = None, - idle_delete_ttl: Optional[int] = None, - auto_delete_time: Optional[datetime] = None, - auto_delete_ttl: Optional[int] = None, - customer_managed_key: Optional[str] = None, - **kwargs - ) -> None: + def __init__( + self, + project_id: Optional[str] = 
None, + cluster_name: Optional[str] = None, + num_workers: Optional[int] = None, + zone: Optional[str] = None, + network_uri: Optional[str] = None, + subnetwork_uri: Optional[str] = None, + internal_ip_only: Optional[bool] = None, + tags: Optional[List[str]] = None, + storage_bucket: Optional[str] = None, + init_actions_uris: Optional[List[str]] = None, + init_action_timeout: str = "10m", + metadata: Optional[Dict] = None, + custom_image: Optional[str] = None, + custom_image_project_id: Optional[str] = None, + image_version: Optional[str] = None, + autoscaling_policy: Optional[str] = None, + properties: Optional[Dict] = None, + optional_components: Optional[List[str]] = None, + num_masters: int = 1, + master_machine_type: str = 'n1-standard-4', + master_disk_type: str = 'pd-standard', + master_disk_size: int = 1024, + worker_machine_type: str = 'n1-standard-4', + worker_disk_type: str = 'pd-standard', + worker_disk_size: int = 1024, + num_preemptible_workers: int = 0, + labels: Optional[Dict] = None, + region: Optional[str] = None, + service_account: Optional[str] = None, + service_account_scopes: Optional[List[str]] = None, + idle_delete_ttl: Optional[int] = None, + auto_delete_time: Optional[datetime] = None, + auto_delete_ttl: Optional[int] = None, + customer_managed_key: Optional[str] = None, + **kwargs, + ) -> None: self.cluster_name = cluster_name self.project_id = project_id @@ -243,14 +247,14 @@ def _get_init_action_timeout(self): raise AirflowException( "DataprocClusterCreateOperator init_action_timeout" - " should be expressed in minutes or seconds. i.e. 10m, 30s") + " should be expressed in minutes or seconds. i.e. 10m, 30s" + ) def _build_gce_cluster_config(self, cluster_data): if self.zone: - zone_uri = \ - 'https://www.googleapis.com/compute/v1/projects/{}/zones/{}'.format( - self.project_id, self.zone - ) + zone_uri = 'https://www.googleapis.com/compute/v1/projects/{}/zones/{}'.format( + self.project_id, self.zone + ) cluster_data['config']['gce_cluster_config']['zone_uri'] = zone_uri if self.metadata: @@ -260,51 +264,50 @@ def _build_gce_cluster_config(self, cluster_data): cluster_data['config']['gce_cluster_config']['network_uri'] = self.network_uri if self.subnetwork_uri: - cluster_data['config']['gce_cluster_config']['subnetwork_uri'] = \ - self.subnetwork_uri + cluster_data['config']['gce_cluster_config']['subnetwork_uri'] = self.subnetwork_uri if self.internal_ip_only: if not self.subnetwork_uri: - raise AirflowException("Set internal_ip_only to true only when" - " you pass a subnetwork_uri.") + raise AirflowException("Set internal_ip_only to true only when" " you pass a subnetwork_uri.") cluster_data['config']['gce_cluster_config']['internal_ip_only'] = True if self.tags: cluster_data['config']['gce_cluster_config']['tags'] = self.tags if self.service_account: - cluster_data['config']['gce_cluster_config']['service_account'] = \ - self.service_account + cluster_data['config']['gce_cluster_config']['service_account'] = self.service_account if self.service_account_scopes: - cluster_data['config']['gce_cluster_config']['service_account_scopes'] = \ - self.service_account_scopes + cluster_data['config']['gce_cluster_config'][ + 'service_account_scopes' + ] = self.service_account_scopes return cluster_data def _build_lifecycle_config(self, cluster_data): if self.idle_delete_ttl: - cluster_data['config']['lifecycle_config']['idle_delete_ttl'] = \ - "{}s".format(self.idle_delete_ttl) + cluster_data['config']['lifecycle_config']['idle_delete_ttl'] = 
"{}s".format(self.idle_delete_ttl) if self.auto_delete_time: utc_auto_delete_time = timezone.convert_to_utc(self.auto_delete_time) - cluster_data['config']['lifecycle_config']['auto_delete_time'] = \ - utc_auto_delete_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') + cluster_data['config']['lifecycle_config']['auto_delete_time'] = utc_auto_delete_time.strftime( + '%Y-%m-%dT%H:%M:%S.%fZ' + ) elif self.auto_delete_ttl: - cluster_data['config']['lifecycle_config']['auto_delete_ttl'] = \ - "{}s".format(self.auto_delete_ttl) + cluster_data['config']['lifecycle_config']['auto_delete_ttl'] = "{}s".format(self.auto_delete_ttl) return cluster_data def _build_cluster_data(self): if self.zone: - master_type_uri = \ - "https://www.googleapis.com/compute/v1/projects/{}/zones/{}/machineTypes/{}".format( - self.project_id, self.zone, self.master_machine_type) - worker_type_uri = \ - "https://www.googleapis.com/compute/v1/projects/{}/zones/{}/machineTypes/{}".format( - self.project_id, self.zone, self.worker_machine_type) + master_type_uri = ( + "https://www.googleapis.com/compute/v1/projects" + f"/{self.project_id}/zones/{self.zone}/machineTypes/{self.master_machine_type}" + ) + worker_type_uri = ( + "https://www.googleapis.com/compute/v1/projects" + f"/{self.project_id}/zones/{self.zone}/machineTypes/{self.worker_machine_type}" + ) else: master_type_uri = self.master_machine_type worker_type_uri = self.worker_machine_type @@ -313,30 +316,29 @@ def _build_cluster_data(self): 'project_id': self.project_id, 'cluster_name': self.cluster_name, 'config': { - 'gce_cluster_config': { - }, + 'gce_cluster_config': {}, 'master_config': { 'num_instances': self.num_masters, 'machine_type_uri': master_type_uri, 'disk_config': { 'boot_disk_type': self.master_disk_type, - 'boot_disk_size_gb': self.master_disk_size - } + 'boot_disk_size_gb': self.master_disk_size, + }, }, 'worker_config': { 'num_instances': self.num_workers, 'machine_type_uri': worker_type_uri, 'disk_config': { 'boot_disk_type': self.worker_disk_type, - 'boot_disk_size_gb': self.worker_disk_size - } + 'boot_disk_size_gb': self.worker_disk_size, + }, }, 'secondary_worker_config': {}, 'software_config': {}, 'lifecycle_config': {}, 'encryption_config': {}, 'autoscaling_config': {}, - } + }, } if self.num_preemptible_workers > 0: cluster_data['config']['secondary_worker_config'] = { @@ -344,9 +346,9 @@ def _build_cluster_data(self): 'machine_type_uri': worker_type_uri, 'disk_config': { 'boot_disk_type': self.worker_disk_type, - 'boot_disk_size_gb': self.worker_disk_size + 'boot_disk_size_gb': self.worker_disk_size, }, - 'is_preemptible': True + 'is_preemptible': True, } cluster_data['labels'] = self.labels or {} @@ -354,9 +356,9 @@ def _build_cluster_data(self): # Dataproc labels must conform to the following regex: # [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows # semantic versioning spec: x.y.z). 
- cluster_data['labels'].update({ - 'airflow-version': 'v' + airflow_version.replace('.', '-').replace('+', '-') - }) + cluster_data['labels'].update( + {'airflow-version': 'v' + airflow_version.replace('.', '-').replace('+', '-')} + ) if self.storage_bucket: cluster_data['config']['config_bucket'] = self.storage_bucket @@ -365,9 +367,10 @@ def _build_cluster_data(self): elif self.custom_image: project_id = self.custom_image_project_id or self.project_id - custom_image_url = 'https://www.googleapis.com/compute/beta/projects/' \ - '{}/global/images/{}'.format(project_id, - self.custom_image) + custom_image_url = ( + 'https://www.googleapis.com/compute/beta/projects/' + '{}/global/images/{}'.format(project_id, self.custom_image) + ) cluster_data['config']['master_config']['image_uri'] = custom_image_url if not self.single_node: cluster_data['config']['worker_config']['image_uri'] = custom_image_url @@ -387,16 +390,13 @@ def _build_cluster_data(self): if self.init_actions_uris: init_actions_dict = [ - { - 'executable_file': uri, - 'execution_timeout': self._get_init_action_timeout() - } for uri in self.init_actions_uris + {'executable_file': uri, 'execution_timeout': self._get_init_action_timeout()} + for uri in self.init_actions_uris ] cluster_data['config']['initialization_actions'] = init_actions_dict if self.customer_managed_key: - cluster_data['config']['encryption_config'] = \ - {'gce_pd_kms_key_name': self.customer_managed_key} + cluster_data['config']['encryption_config'] = {'gce_pd_kms_key_name': self.customer_managed_key} if self.autoscaling_policy: cluster_data['config']['autoscaling_config'] = {'policy_uri': self.autoscaling_policy} @@ -467,7 +467,12 @@ class DataprocCreateClusterOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('project_id', 'region', 'cluster', 'impersonation_chain',) + template_fields = ( + 'project_id', + 'region', + 'cluster', + 'impersonation_chain', + ) @apply_defaults def __init__( # pylint: disable=too-many-arguments @@ -484,7 +489,7 @@ def __init__( # pylint: disable=too-many-arguments metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: # TODO: remove one day if cluster is None: @@ -493,7 +498,8 @@ def __init__( # pylint: disable=too-many-arguments "will be deprecated. Please provide cluster object using `cluster` parameter. 
" "You can use `airflow.dataproc.ClusterGenerator.generate_cluster` method to " "obtain cluster object.".format(type(self).__name__), - DeprecationWarning, stacklevel=1 + DeprecationWarning, + stacklevel=1, ) # Remove result of apply defaults if 'params' in kwargs: @@ -545,9 +551,7 @@ def _create_cluster(self, hook): def _delete_cluster(self, hook): self.log.info("Deleting the cluster") hook.delete_cluster( - region=self.region, - cluster_name=self.cluster_name, - project_id=self.project_id, + region=self.region, cluster_name=self.cluster_name, project_id=self.project_id, ) def _get_cluster(self, hook: DataprocHook): @@ -565,14 +569,9 @@ def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None: return self.log.info("Cluster is in ERROR state") gcs_uri = hook.diagnose_cluster( - region=self.region, - cluster_name=self.cluster_name, - project_id=self.project_id, - ) - self.log.info( - 'Diagnostic information for cluster %s available at: %s', - self.cluster_name, gcs_uri + region=self.region, cluster_name=self.cluster_name, project_id=self.project_id, ) + self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri) if self.delete_on_error: self._delete_cluster(hook) raise AirflowException("Cluster was created but was in ERROR state.") @@ -582,9 +581,7 @@ def _wait_for_cluster_in_deleting_state(self, hook: DataprocHook) -> None: time_left = self.timeout for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120): if time_left < 0: - raise AirflowException( - f"Cluster {self.cluster_name} is still DELETING state, aborting" - ) + raise AirflowException(f"Cluster {self.cluster_name} is still DELETING state, aborting") time.sleep(time_to_sleep) time_left = time_left - time_to_sleep try: @@ -599,9 +596,7 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster: if cluster.status.state != cluster.status.CREATING: break if time_left < 0: - raise AirflowException( - f"Cluster {self.cluster_name} is still CREATING state, aborting" - ) + raise AirflowException(f"Cluster {self.cluster_name} is still CREATING state, aborting") time.sleep(time_to_sleep) time_left = time_left - time_to_sleep cluster = self._get_cluster(hook) @@ -609,10 +604,7 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster: def execute(self, context): self.log.info('Creating cluster: %s', self.cluster_name) - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: # First try to create a new cluster cluster = self._create_cluster(hook) @@ -685,19 +677,27 @@ class DataprocScaleClusterOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['cluster_name', 'project_id', 'region', 'impersonation_chain', ] + template_fields = [ + 'cluster_name', + 'project_id', + 'region', + 'impersonation_chain', + ] @apply_defaults - def __init__(self, *, - cluster_name: str, - project_id: Optional[str] = None, - region: str = 'global', - num_workers: int = 2, - num_preemptible_workers: int = 0, - graceful_decommission_timeout: Optional[str] = None, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + cluster_name: str, + project_id: Optional[str] = None, + region: str = 'global', + num_workers: int = 2, + num_preemptible_workers: 
int = 0, + graceful_decommission_timeout: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.project_id = project_id self.region = region @@ -714,18 +714,14 @@ def __init__(self, *, cls=type(self).__name__ ), DeprecationWarning, - stacklevel=1 + stacklevel=1, ) def _build_scale_cluster_data(self): scale_data = { 'config': { - 'worker_config': { - 'num_instances': self.num_workers - }, - 'secondary_worker_config': { - 'num_instances': self.num_preemptible_workers - } + 'worker_config': {'num_instances': self.num_workers}, + 'secondary_worker_config': {'num_instances': self.num_preemptible_workers}, } } return scale_data @@ -766,15 +762,9 @@ def execute(self, context): self.log.info("Scaling cluster: %s", self.cluster_name) scaling_cluster_data = self._build_scale_cluster_data() - update_mask = [ - "config.worker_config.num_instances", - "config.secondary_worker_config.num_instances" - ] - - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + update_mask = ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances"] + + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) operation = hook.update_cluster( project_id=self.project_id, location=self.region, @@ -829,7 +819,8 @@ class DataprocDeleteClusterOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, project_id: str, region: str, cluster_name: str, @@ -840,7 +831,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.project_id = project_id @@ -855,10 +846,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Deleting cluster: %s", self.cluster_name) operation = hook.delete_cluster( project_id=self.project_id, @@ -925,21 +913,25 @@ class DataprocJobBaseOperator(BaseOperator): an 8 character random string. 
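Before the job-operator hunks continue, a quick usage sketch for the DataprocScaleClusterOperator reformatted above. Only parameters visible in the hunk are used; cluster, project and region values are placeholder assumptions, and note the operator emits a DeprecationWarning in __init__ as the hunk shows.

from airflow import models
from airflow.providers.google.cloud.operators.dataproc import DataprocScaleClusterOperator
from airflow.utils.dates import days_ago

with models.DAG(
    "example_dataproc_scale", start_date=days_ago(1), schedule_interval=None,
) as dag:
    # Resize the worker pools of an existing cluster; names and region are placeholders.
    scale_cluster = DataprocScaleClusterOperator(
        task_id="scale_cluster",
        cluster_name="example-cluster",
        project_id="my-project",
        region="europe-west1",
        num_workers=4,
        num_preemptible_workers=2,
    )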
:vartype dataproc_job_id: str """ + job_type = "" @apply_defaults - def __init__(self, *, - job_name: str = '{{task.task_id}}_{{ds_nodash}}', - cluster_name: str = "cluster-1", - dataproc_properties: Optional[Dict] = None, - dataproc_jars: Optional[List[str]] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - labels: Optional[Dict] = None, - region: str = 'global', - job_error_states: Optional[Set[str]] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + job_name: str = '{{task.task_id}}_{{ds_nodash}}', + cluster_name: str = "cluster-1", + dataproc_properties: Optional[Dict] = None, + dataproc_jars: Optional[List[str]] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + labels: Optional[Dict] = None, + region: str = 'global', + job_error_states: Optional[Set[str]] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to @@ -967,7 +959,7 @@ def create_job_template(self): task_id=self.task_id, cluster_name=self.cluster_name, job_type=self.job_type, - properties=self.dataproc_properties + properties=self.dataproc_properties, ) self.job_template.set_job_name(self.job_name) self.job_template.add_jar_file_uris(self.dataproc_jars) @@ -985,16 +977,10 @@ def execute(self, context): self.dataproc_job_id = self.job["job"]["reference"]["job_id"] self.log.info('Submitting %s job %s', self.job_type, self.dataproc_job_id) job_object = self.hook.submit_job( - project_id=self.project_id, - job=self.job["job"], - location=self.region, + project_id=self.project_id, job=self.job["job"], location=self.region, ) job_id = job_object.reference.job_id - self.hook.wait_for_job( - job_id=job_id, - location=self.region, - project_id=self.project_id - ) + self.hook.wait_for_job(job_id=job_id, location=self.region, project_id=self.project_id) self.log.info('Job executed correctly.') else: raise AirflowException("Create a job template before") @@ -1006,9 +992,7 @@ def on_kill(self): """ if self.dataproc_job_id: self.hook.cancel_job( - project_id=self.project_id, - job_id=self.dataproc_job_id, - location=self.region + project_id=self.project_id, job_id=self.dataproc_job_id, location=self.region ) @@ -1054,19 +1038,32 @@ class DataprocSubmitPigJobOperator(DataprocJobBaseOperator): :param variables: Map of named parameters for the query. 
(templated) :type variables: dict """ - template_fields = ['query', 'variables', 'job_name', 'cluster_name', - 'region', 'dataproc_jars', 'dataproc_properties', 'impersonation_chain', ] - template_ext = ('.pg', '.pig',) + + template_fields = [ + 'query', + 'variables', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', + ] + template_ext = ( + '.pg', + '.pig', + ) ui_color = '#0273d4' job_type = 'pig_job' @apply_defaults def __init__( - self, *, + self, + *, query: Optional[str] = None, query_uri: Optional[str] = None, variables: Optional[Dict] = None, - **kwargs + **kwargs, ) -> None: # TODO: Remove one day warnings.warn( @@ -1074,7 +1071,7 @@ def __init__( " `generate_job` method of `{cls}` to generate dictionary representing your job" " and use it with the new operator.".format(cls=type(self).__name__), DeprecationWarning, - stacklevel=1 + stacklevel=1, ) super().__init__(**kwargs) @@ -1119,19 +1116,32 @@ class DataprocSubmitHiveJobOperator(DataprocJobBaseOperator): :param variables: Map of named parameters for the query. :type variables: dict """ - template_fields = ['query', 'variables', 'job_name', 'cluster_name', - 'region', 'dataproc_jars', 'dataproc_properties', 'impersonation_chain', ] - template_ext = ('.q', '.hql',) + + template_fields = [ + 'query', + 'variables', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', + ] + template_ext = ( + '.q', + '.hql', + ) ui_color = '#0273d4' job_type = 'hive_job' @apply_defaults def __init__( - self, *, + self, + *, query: Optional[str] = None, query_uri: Optional[str] = None, variables: Optional[Dict] = None, - **kwargs + **kwargs, ) -> None: # TODO: Remove one day warnings.warn( @@ -1139,7 +1149,7 @@ def __init__( " `generate_job` method of `{cls}` to generate dictionary representing your job" " and use it with the new operator.".format(cls=type(self).__name__), DeprecationWarning, - stacklevel=1 + stacklevel=1, ) super().__init__(**kwargs) @@ -1184,19 +1194,29 @@ class DataprocSubmitSparkSqlJobOperator(DataprocJobBaseOperator): :param variables: Map of named parameters for the query. 
(templated) :type variables: dict """ - template_fields = ['query', 'variables', 'job_name', 'cluster_name', - 'region', 'dataproc_jars', 'dataproc_properties', 'impersonation_chain', ] + + template_fields = [ + 'query', + 'variables', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', + ] template_ext = ('.q',) ui_color = '#0273d4' job_type = 'spark_sql_job' @apply_defaults def __init__( - self, *, + self, + *, query: Optional[str] = None, query_uri: Optional[str] = None, variables: Optional[Dict] = None, - **kwargs + **kwargs, ) -> None: # TODO: Remove one day warnings.warn( @@ -1204,7 +1224,7 @@ def __init__( " `generate_job` method of `{cls}` to generate dictionary representing your job" " and use it with the new operator.".format(cls=type(self).__name__), DeprecationWarning, - stacklevel=1 + stacklevel=1, ) super().__init__(**kwargs) @@ -1257,20 +1277,28 @@ class DataprocSubmitSparkJobOperator(DataprocJobBaseOperator): :type files: list """ - template_fields = ['arguments', 'job_name', 'cluster_name', - 'region', 'dataproc_jars', 'dataproc_properties', 'impersonation_chain', ] + template_fields = [ + 'arguments', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', + ] ui_color = '#0273d4' job_type = 'spark_job' @apply_defaults def __init__( - self, *, + self, + *, main_jar: Optional[str] = None, main_class: Optional[str] = None, arguments: Optional[List] = None, archives: Optional[List] = None, files: Optional[List] = None, - **kwargs + **kwargs, ) -> None: # TODO: Remove one day warnings.warn( @@ -1278,7 +1306,7 @@ def __init__( " `generate_job` method of `{cls}` to generate dictionary representing your job" " and use it with the new operator.".format(cls=type(self).__name__), DeprecationWarning, - stacklevel=1 + stacklevel=1, ) super().__init__(**kwargs) @@ -1329,20 +1357,28 @@ class DataprocSubmitHadoopJobOperator(DataprocJobBaseOperator): :type files: list """ - template_fields = ['arguments', 'job_name', 'cluster_name', - 'region', 'dataproc_jars', 'dataproc_properties', 'impersonation_chain', ] + template_fields = [ + 'arguments', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', + ] ui_color = '#0273d4' job_type = 'hadoop_job' @apply_defaults def __init__( - self, *, + self, + *, main_jar: Optional[str] = None, main_class: Optional[str] = None, arguments: Optional[List] = None, archives: Optional[List] = None, files: Optional[List] = None, - **kwargs + **kwargs, ) -> None: # TODO: Remove one day warnings.warn( @@ -1350,7 +1386,7 @@ def __init__( " `generate_job` method of `{cls}` to generate dictionary representing your job" " and use it with the new operator.".format(cls=type(self).__name__), DeprecationWarning, - stacklevel=1 + stacklevel=1, ) super().__init__(**kwargs) @@ -1401,8 +1437,16 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator): :type pyfiles: list """ - template_fields = ['main', 'arguments', 'job_name', 'cluster_name', - 'region', 'dataproc_jars', 'dataproc_properties', 'impersonation_chain', ] + template_fields = [ + 'main', + 'arguments', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', + ] ui_color = '#0273d4' job_type = 'pyspark_job' @@ -1419,30 +1463,31 @@ def _upload_file_temp(self, bucket, local_file): if not bucket: raise AirflowException( "If you want Airflow to upload the local file to a 
temporary bucket, set " - "the 'temp_bucket' key in the connection string") + "the 'temp_bucket' key in the connection string" + ) self.log.info("Uploading %s to %s", local_file, temp_filename) GCSHook( - google_cloud_storage_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain + google_cloud_storage_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ).upload( bucket_name=bucket, object_name=temp_filename, mime_type='application/x-python', - filename=local_file + filename=local_file, ) return "gs://{}/{}".format(bucket, temp_filename) @apply_defaults def __init__( - self, *, + self, + *, main: str, arguments: Optional[List] = None, archives: Optional[List] = None, pyfiles: Optional[List] = None, files: Optional[List] = None, - **kwargs + **kwargs, ) -> None: # TODO: Remove one day warnings.warn( @@ -1450,7 +1495,7 @@ def __init__( " `generate_job` method of `{cls}` to generate dictionary representing your job" " and use it with the new operator.".format(cls=type(self).__name__), DeprecationWarning, - stacklevel=1 + stacklevel=1, ) super().__init__(**kwargs) @@ -1469,9 +1514,7 @@ def generate_job(self): # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( - project_id=self.hook.project_id, - region=self.region, - cluster_name=self.cluster_name + project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name ) bucket = cluster_info['config']['config_bucket'] self.main = "gs://{}/{}".format(bucket, self.main) @@ -1488,9 +1531,7 @@ def execute(self, context): # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( - project_id=self.hook.project_id, - region=self.region, - cluster_name=self.cluster_name + project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name ) bucket = cluster_info['config']['config_bucket'] self.main = self._upload_file_temp(bucket, self.main) @@ -1555,11 +1596,15 @@ class DataprocInstantiateWorkflowTemplateOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['template_id', 'impersonation_chain', ] + template_fields = [ + 'template_id', + 'impersonation_chain', + ] @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, template_id: str, region: str, project_id: Optional[str] = None, @@ -1571,7 +1616,7 @@ def __init__( # pylint: disable=too-many-arguments metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -1588,10 +1633,7 @@ def __init__( # pylint: disable=too-many-arguments self.impersonation_chain = impersonation_chain def execute(self, context): - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info('Instantiating template %s', self.template_id) operation = hook.instantiate_workflow_template( project_id=self.project_id, @@ -1659,11 +1701,15 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['template', 'impersonation_chain', ] + template_fields = [ + 'template', + 'impersonation_chain', + ] 
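A short usage sketch for the DataprocInstantiateWorkflowTemplateOperator reformatted just above; the template id, region and project are placeholder assumptions. The inline variant, whose hunk follows, takes a full template dict via its template argument instead of a template_id.

from airflow import models
from airflow.providers.google.cloud.operators.dataproc import DataprocInstantiateWorkflowTemplateOperator
from airflow.utils.dates import days_ago

with models.DAG(
    "example_dataproc_workflow", start_date=days_ago(1), schedule_interval=None,
) as dag:
    # Runs an existing workflow template registered in Dataproc.
    instantiate_workflow = DataprocInstantiateWorkflowTemplateOperator(
        task_id="instantiate_workflow",
        template_id="sparkpi-template",   # placeholder template id
        region="europe-west1",
        project_id="my-project",
    )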
@apply_defaults def __init__( - self, *, + self, + *, template: Dict, region: str, project_id: Optional[str] = None, @@ -1673,7 +1719,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.template = template @@ -1689,10 +1735,7 @@ def __init__( def execute(self, context): self.log.info('Instantiating Inline Template') - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) operation = hook.instantiate_inline_workflow_template( template=self.template, project_id=self.project_id, @@ -1744,11 +1787,17 @@ class DataprocSubmitJobOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('project_id', 'location', 'job', 'impersonation_chain',) + template_fields = ( + 'project_id', + 'location', + 'job', + 'impersonation_chain', + ) @apply_defaults def __init__( - self, *, + self, + *, project_id: str, location: str, job: Dict, @@ -1758,7 +1807,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -1773,10 +1822,7 @@ def __init__( def execute(self, context: Dict): self.log.info("Submitting job") - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) job_object = hook.submit_job( project_id=self.project_id, location=self.location, @@ -1788,11 +1834,7 @@ def execute(self, context: Dict): ) job_id = job_object.reference.job_id self.log.info("Waiting for job %s to complete", job_id) - hook.wait_for_job( - job_id=job_id, - project_id=self.project_id, - location=self.location - ) + hook.wait_for_job(job_id=job_id, project_id=self.project_id, location=self.location) self.log.info("Job completed successfully.") @@ -1847,11 +1889,13 @@ class DataprocUpdateClusterOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
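The DataprocSubmitJobOperator reformatted above takes a job dict and waits for completion via the hook, as its execute method shows. A hedged sketch follows; the PySpark payload shape is an assumption based on the Dataproc Job message rather than on this diff, and all names are placeholders.

from airflow import models
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator
from airflow.utils.dates import days_ago

# Assumed Dataproc Job shape: reference / placement / pyspark_job.
PYSPARK_JOB = {
    "reference": {"project_id": "my-project"},
    "placement": {"cluster_name": "example-cluster"},
    "pyspark_job": {"main_python_file_uri": "gs://my-bucket/jobs/hello_world.py"},
}

with models.DAG(
    "example_dataproc_submit_job", start_date=days_ago(1), schedule_interval=None,
) as dag:
    submit_job = DataprocSubmitJobOperator(
        task_id="submit_pyspark_job",
        project_id="my-project",
        location="europe-west1",
        job=PYSPARK_JOB,
    )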
:type impersonation_chain: Union[str, Sequence[str]] """ + template_fields = ('impersonation_chain',) @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, location: str, cluster_name: str, cluster: Union[Dict, Cluster], @@ -1864,7 +1908,7 @@ def __init__( # pylint: disable=too-many-arguments metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.project_id = project_id @@ -1881,10 +1925,7 @@ def __init__( # pylint: disable=too-many-arguments self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) self.log.info("Updating %s cluster.", self.cluster_name) operation = hook.update_cluster( project_id=self.project_id, diff --git a/airflow/providers/google/cloud/operators/datastore.py b/airflow/providers/google/cloud/operators/datastore.py index 15d804d5d4185..333a17a5f96c5 100644 --- a/airflow/providers/google/cloud/operators/datastore.py +++ b/airflow/providers/google/cloud/operators/datastore.py @@ -72,23 +72,32 @@ class CloudDatastoreExportEntitiesOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['bucket', 'namespace', 'entity_filter', 'labels', 'impersonation_chain', ] + + template_fields = [ + 'bucket', + 'namespace', + 'entity_filter', + 'labels', + 'impersonation_chain', + ] @apply_defaults - def __init__(self, # pylint: disable=too-many-arguments - *, - bucket: str, - namespace: Optional[str] = None, - datastore_conn_id: str = 'google_cloud_default', - cloud_storage_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - entity_filter: Optional[dict] = None, - labels: Optional[dict] = None, - polling_interval_in_seconds: int = 10, - overwrite_existing: bool = False, - project_id: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, # pylint: disable=too-many-arguments + *, + bucket: str, + namespace: Optional[str] = None, + datastore_conn_id: str = 'google_cloud_default', + cloud_storage_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + entity_filter: Optional[dict] = None, + labels: Optional[dict] = None, + polling_interval_in_seconds: int = 10, + overwrite_existing: bool = False, + project_id: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.datastore_conn_id = datastore_conn_id self.cloud_storage_conn_id = cloud_storage_conn_id @@ -108,28 +117,23 @@ def execute(self, context): self.log.info('Exporting data to Cloud Storage bucket %s', self.bucket) if self.overwrite_existing and self.namespace: - gcs_hook = GCSHook( - self.cloud_storage_conn_id, - impersonation_chain=self.impersonation_chain - ) + gcs_hook = GCSHook(self.cloud_storage_conn_id, impersonation_chain=self.impersonation_chain) objects = gcs_hook.list(self.bucket, prefix=self.namespace) for obj in objects: gcs_hook.delete(self.bucket, obj) ds_hook = DatastoreHook( - self.datastore_conn_id, - self.delegate_to, - 
impersonation_chain=self.impersonation_chain, + self.datastore_conn_id, self.delegate_to, impersonation_chain=self.impersonation_chain, + ) + result = ds_hook.export_to_storage_bucket( + bucket=self.bucket, + namespace=self.namespace, + entity_filter=self.entity_filter, + labels=self.labels, + project_id=self.project_id, ) - result = ds_hook.export_to_storage_bucket(bucket=self.bucket, - namespace=self.namespace, - entity_filter=self.entity_filter, - labels=self.labels, - project_id=self.project_id - ) operation_name = result['name'] - result = ds_hook.poll_operation_until_done(operation_name, - self.polling_interval_in_seconds) + result = ds_hook.poll_operation_until_done(operation_name, self.polling_interval_in_seconds) state = result['metadata']['common']['state'] if state != 'SUCCESSFUL': @@ -179,23 +183,31 @@ class CloudDatastoreImportEntitiesOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['bucket', 'file', 'namespace', 'entity_filter', 'labels', - 'impersonation_chain', ] + template_fields = [ + 'bucket', + 'file', + 'namespace', + 'entity_filter', + 'labels', + 'impersonation_chain', + ] @apply_defaults - def __init__(self, - *, - bucket: str, - file: str, - namespace: Optional[str] = None, - entity_filter: Optional[dict] = None, - labels: Optional[dict] = None, - datastore_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - polling_interval_in_seconds: float = 10, - project_id: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + bucket: str, + file: str, + namespace: Optional[str] = None, + entity_filter: Optional[dict] = None, + labels: Optional[dict] = None, + datastore_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + polling_interval_in_seconds: float = 10, + project_id: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.datastore_conn_id = datastore_conn_id self.delegate_to = delegate_to @@ -213,20 +225,18 @@ def __init__(self, def execute(self, context): self.log.info('Importing data from Cloud Storage bucket %s', self.bucket) ds_hook = DatastoreHook( - self.datastore_conn_id, - self.delegate_to, - impersonation_chain=self.impersonation_chain, + self.datastore_conn_id, self.delegate_to, impersonation_chain=self.impersonation_chain, + ) + result = ds_hook.import_from_storage_bucket( + bucket=self.bucket, + file=self.file, + namespace=self.namespace, + entity_filter=self.entity_filter, + labels=self.labels, + project_id=self.project_id, ) - result = ds_hook.import_from_storage_bucket(bucket=self.bucket, - file=self.file, - namespace=self.namespace, - entity_filter=self.entity_filter, - labels=self.labels, - project_id=self.project_id - ) operation_name = result['name'] - result = ds_hook.poll_operation_until_done(operation_name, - self.polling_interval_in_seconds) + result = ds_hook.poll_operation_until_done(operation_name, self.polling_interval_in_seconds) state = result['metadata']['common']['state'] if state != 'SUCCESSFUL': @@ -266,17 +276,22 @@ class CloudDatastoreAllocateIdsOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
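A usage sketch for the Datastore export/import operators reformatted above, using only parameters visible in these hunks. The bucket, namespace and export-file path are placeholder assumptions; the real import file path depends on what the export produces.

from airflow import models
from airflow.providers.google.cloud.operators.datastore import (
    CloudDatastoreExportEntitiesOperator,
    CloudDatastoreImportEntitiesOperator,
)
from airflow.utils.dates import days_ago

BUCKET = "my-datastore-exports"  # placeholder GCS bucket

with models.DAG(
    "example_datastore_export_import", start_date=days_ago(1), schedule_interval=None,
) as dag:
    export_entities = CloudDatastoreExportEntitiesOperator(
        task_id="export_entities",
        bucket=BUCKET,
        namespace="testing",
        overwrite_existing=True,   # clears prior objects under the namespace prefix first
        project_id="my-project",
    )
    import_entities = CloudDatastoreImportEntitiesOperator(
        task_id="import_entities",
        bucket=BUCKET,
        # Placeholder path of the export metadata file written by the export task.
        file="testing/some_export.overall_export_metadata",
        namespace="testing",
        project_id="my-project",
    )
    export_entities >> import_entities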
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("partial_keys", "impersonation_chain",) + + template_fields = ( + "partial_keys", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, partial_keys: List, project_id: Optional[str] = None, delegate_to: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -287,14 +302,8 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = DatastoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - keys = hook.allocate_ids( - partial_keys=self.partial_keys, - project_id=self.project_id, - ) + hook = DatastoreHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + keys = hook.allocate_ids(partial_keys=self.partial_keys, project_id=self.project_id,) return keys @@ -329,17 +338,22 @@ class CloudDatastoreBeginTransactionOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("transaction_options", "impersonation_chain",) + + template_fields = ( + "transaction_options", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, transaction_options: Dict[str, Any], project_id: Optional[str] = None, delegate_to: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -350,13 +364,9 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = DatastoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = DatastoreHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) handle = hook.begin_transaction( - transaction_options=self.transaction_options, - project_id=self.project_id, + transaction_options=self.transaction_options, project_id=self.project_id, ) return handle @@ -392,17 +402,22 @@ class CloudDatastoreCommitOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("body", "impersonation_chain",) + + template_fields = ( + "body", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, body: Dict[str, Any], project_id: Optional[str] = None, delegate_to: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -413,14 +428,8 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = DatastoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - response = hook.commit( - body=self.body, - project_id=self.project_id, - ) + hook = DatastoreHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + response = hook.commit(body=self.body, project_id=self.project_id,) return response @@ -455,17 +464,22 @@ class CloudDatastoreRollbackOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
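Because "body" is in CloudDatastoreCommitOperator's template_fields above, the transaction handle returned by CloudDatastoreBeginTransactionOperator can be pulled from XCom inside the commit body. The sketch below assumes REST-style (camelCase) payload keys for the Datastore API; they are not shown in this diff, and all values are placeholders.

from airflow import models
from airflow.providers.google.cloud.operators.datastore import (
    CloudDatastoreBeginTransactionOperator,
    CloudDatastoreCommitOperator,
)
from airflow.utils.dates import days_ago

with models.DAG(
    "example_datastore_commit", start_date=days_ago(1), schedule_interval=None,
) as dag:
    begin_transaction = CloudDatastoreBeginTransactionOperator(
        task_id="begin_transaction",
        transaction_options={"readWrite": {}},   # assumption: REST-style transaction options
        project_id="my-project",
    )
    # The commit body is templated, so the handle returned by the task above
    # is injected at runtime via Jinja.
    commit = CloudDatastoreCommitOperator(
        task_id="commit",
        body={
            "mode": "TRANSACTIONAL",
            "mutations": [
                {
                    "insert": {
                        "key": {"partitionId": {"projectId": "my-project"}, "path": [{"kind": "Task"}]},
                        "properties": {"name": {"stringValue": "example"}},
                    }
                }
            ],
            "transaction": "{{ task_instance.xcom_pull('begin_transaction') }}",
        },
        project_id="my-project",
    )
    begin_transaction >> commit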
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("transaction", "impersonation_chain",) + + template_fields = ( + "transaction", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, transaction: str, project_id: Optional[str] = None, delegate_to: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -476,13 +490,9 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = DatastoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = DatastoreHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) hook.rollback( - transaction=self.transaction, - project_id=self.project_id, + transaction=self.transaction, project_id=self.project_id, ) @@ -517,17 +527,22 @@ class CloudDatastoreRunQueryOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("body", "impersonation_chain",) + + template_fields = ( + "body", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, body: Dict[str, Any], project_id: Optional[str] = None, delegate_to: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -538,14 +553,8 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = DatastoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - response = hook.run_query( - body=self.body, - project_id=self.project_id, - ) + hook = DatastoreHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + response = hook.run_query(body=self.body, project_id=self.project_id,) return response @@ -574,16 +583,21 @@ class CloudDatastoreGetOperationOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("name", "impersonation_chain",) + + template_fields = ( + "name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, name: str, delegate_to: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -593,10 +607,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = DatastoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = DatastoreHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) op = hook.get_operation(name=self.name) return op @@ -626,16 +637,21 @@ class CloudDatastoreDeleteOperationOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("name", "impersonation_chain",) + + template_fields = ( + "name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, name: str, delegate_to: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -645,8 +661,5 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = DatastoreHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = DatastoreHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) hook.delete_operation(name=self.name) diff --git a/airflow/providers/google/cloud/operators/dlp.py b/airflow/providers/google/cloud/operators/dlp.py index 43c698eb363a2..7a5aef4e0a71a 100644 --- a/airflow/providers/google/cloud/operators/dlp.py +++ b/airflow/providers/google/cloud/operators/dlp.py @@ -27,8 +27,17 @@ from google.api_core.exceptions import AlreadyExists, NotFound from google.api_core.retry import Retry from google.cloud.dlp_v2.types import ( - ByteContentItem, ContentItem, DeidentifyConfig, DeidentifyTemplate, FieldMask, InspectConfig, - InspectJobConfig, InspectTemplate, JobTrigger, RedactImageRequest, RiskAnalysisJobConfig, + ByteContentItem, + ContentItem, + DeidentifyConfig, + DeidentifyTemplate, + FieldMask, + InspectConfig, + InspectJobConfig, + InspectTemplate, + JobTrigger, + RedactImageRequest, + RiskAnalysisJobConfig, StoredInfoTypeConfig, ) from google.protobuf.json_format import MessageToDict @@ -70,11 +79,17 @@ class CloudDLPCancelDLPJobOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("dlp_job_id", "project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "dlp_job_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, dlp_job_id: str, project_id: Optional[str] = None, retry: Optional[Retry] = None, @@ -82,7 +97,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.dlp_job_id = dlp_job_id @@ -94,10 +109,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) hook.cancel_dlp_job( dlp_job_id=self.dlp_job_id, project_id=self.project_id, @@ -158,7 +170,8 @@ class CloudDLPCreateDeidentifyTemplateOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, organization_id: Optional[str] = None, project_id: Optional[str] = None, deidentify_template: Optional[Union[Dict, DeidentifyTemplate]] = None, @@ -168,7 +181,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.organization_id = organization_id @@ -182,10 +195,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - 
gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: template = hook.create_deidentify_template( organization_id=self.organization_id, @@ -250,12 +260,19 @@ class CloudDLPCreateDLPJobOperator(BaseOperator): :rtype: google.cloud.dlp_v2.types.DlpJob """ - template_fields = ("project_id", "inspect_job", "risk_job", "job_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "project_id", + "inspect_job", + "risk_job", + "job_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, project_id: Optional[str] = None, inspect_job: Optional[Union[Dict, InspectJobConfig]] = None, risk_job: Optional[Union[Dict, RiskAnalysisJobConfig]] = None, @@ -266,7 +283,7 @@ def __init__( wait_until_finished: bool = True, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -281,10 +298,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: job = hook.create_dlp_job( project_id=self.project_id, @@ -358,7 +372,8 @@ class CloudDLPCreateInspectTemplateOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, organization_id: Optional[str] = None, project_id: Optional[str] = None, inspect_template: Optional[InspectTemplate] = None, @@ -368,7 +383,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.organization_id = organization_id @@ -382,10 +397,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: template = hook.create_inspect_template( organization_id=self.organization_id, @@ -445,12 +457,18 @@ class CloudDLPCreateJobTriggerOperator(BaseOperator): :rtype: google.cloud.dlp_v2.types.JobTrigger """ - template_fields = ("project_id", "job_trigger", "trigger_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "project_id", + "job_trigger", + "trigger_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, project_id: Optional[str] = None, job_trigger: Optional[Union[Dict, JobTrigger]] = None, trigger_id: Optional[str] = None, @@ -459,7 +477,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -472,10 +490,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain,) try: trigger = hook.create_job_trigger( project_id=self.project_id, @@ -546,7 +561,8 @@ class CloudDLPCreateStoredInfoTypeOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, organization_id: Optional[str] = None, project_id: Optional[str] = None, config: Optional[StoredInfoTypeConfig] = None, @@ -556,7 +572,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.organization_id = organization_id @@ -570,10 +586,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: info = hook.create_stored_info_type( organization_id=self.organization_id, @@ -658,7 +671,8 @@ class CloudDLPDeidentifyContentOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, project_id: Optional[str] = None, deidentify_config: Optional[Union[Dict, DeidentifyConfig]] = None, inspect_config: Optional[Union[Dict, InspectConfig]] = None, @@ -670,7 +684,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -686,10 +700,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) response = hook.deidentify_content( project_id=self.project_id, deidentify_config=self.deidentify_config, @@ -739,12 +750,18 @@ class CloudDLPDeleteDeidentifyTemplateOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("template_id", "organization_id", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "template_id", + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, template_id: str, organization_id: Optional[str] = None, project_id: Optional[str] = None, @@ -753,7 +770,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.template_id = template_id @@ -766,10 +783,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: hook.delete_deidentify_template( template_id=self.template_id, @@ -816,11 +830,17 @@ class CloudDLPDeleteDLPJobOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("dlp_job_id", "project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "dlp_job_id", + "project_id", + "gcp_conn_id", + 
"impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, dlp_job_id: str, project_id: Optional[str] = None, retry: Optional[Retry] = None, @@ -828,7 +848,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.dlp_job_id = dlp_job_id @@ -840,10 +860,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: hook.delete_dlp_job( dlp_job_id=self.dlp_job_id, @@ -891,12 +908,18 @@ class CloudDLPDeleteInspectTemplateOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("template_id", "organization_id", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "template_id", + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, template_id: str, organization_id: Optional[str] = None, project_id: Optional[str] = None, @@ -905,7 +928,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.template_id = template_id @@ -918,10 +941,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: hook.delete_inspect_template( template_id=self.template_id, @@ -967,11 +987,17 @@ class CloudDLPDeleteJobTriggerOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("job_trigger_id", "project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "job_trigger_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, job_trigger_id: str, project_id: Optional[str] = None, retry: Optional[Retry] = None, @@ -979,7 +1005,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.job_trigger_id = job_trigger_id @@ -991,10 +1017,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: hook.delete_job_trigger( job_trigger_id=self.job_trigger_id, @@ -1052,7 +1075,8 @@ class CloudDLPDeleteStoredInfoTypeOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, stored_info_type_id: str, organization_id: Optional[str] = None, project_id: Optional[str] = None, @@ -1061,7 +1085,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: 
Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.stored_info_type_id = stored_info_type_id @@ -1074,10 +1098,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: hook.delete_stored_info_type( stored_info_type_id=self.stored_info_type_id, @@ -1128,12 +1149,18 @@ class CloudDLPGetDeidentifyTemplateOperator(BaseOperator): :rtype: google.cloud.dlp_v2.types.DeidentifyTemplate """ - template_fields = ("template_id", "organization_id", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "template_id", + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, template_id: str, organization_id: Optional[str] = None, project_id: Optional[str] = None, @@ -1142,7 +1169,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.template_id = template_id @@ -1155,10 +1182,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) template = hook.get_deidentify_template( template_id=self.template_id, organization_id=self.organization_id, @@ -1204,11 +1228,17 @@ class CloudDLPGetDLPJobOperator(BaseOperator): :rtype: google.cloud.dlp_v2.types.DlpJob """ - template_fields = ("dlp_job_id", "project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "dlp_job_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, dlp_job_id: str, project_id: Optional[str] = None, retry: Optional[Retry] = None, @@ -1216,7 +1246,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.dlp_job_id = dlp_job_id @@ -1228,10 +1258,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) job = hook.get_dlp_job( dlp_job_id=self.dlp_job_id, project_id=self.project_id, @@ -1279,12 +1306,18 @@ class CloudDLPGetInspectTemplateOperator(BaseOperator): :rtype: google.cloud.dlp_v2.types.InspectTemplate """ - template_fields = ("template_id", "organization_id", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "template_id", + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, template_id: str, organization_id: Optional[str] = None, project_id: Optional[str] = None, @@ -1293,7 +1326,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", 
impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.template_id = template_id @@ -1306,10 +1339,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) template = hook.get_inspect_template( template_id=self.template_id, organization_id=self.organization_id, @@ -1355,11 +1385,17 @@ class CloudDLPGetDLPJobTriggerOperator(BaseOperator): :rtype: google.cloud.dlp_v2.types.JobTrigger """ - template_fields = ("job_trigger_id", "project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "job_trigger_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, job_trigger_id: str, project_id: Optional[str] = None, retry: Optional[Retry] = None, @@ -1367,7 +1403,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.job_trigger_id = job_trigger_id @@ -1379,10 +1415,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) trigger = hook.get_job_trigger( job_trigger_id=self.job_trigger_id, project_id=self.project_id, @@ -1440,7 +1473,8 @@ class CloudDLPGetStoredInfoTypeOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, stored_info_type_id: str, organization_id: Optional[str] = None, project_id: Optional[str] = None, @@ -1449,7 +1483,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.stored_info_type_id = stored_info_type_id @@ -1462,10 +1496,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) info = hook.get_stored_info_type( stored_info_type_id=self.stored_info_type_id, organization_id=self.organization_id, @@ -1529,7 +1560,8 @@ class CloudDLPInspectContentOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, project_id: Optional[str] = None, inspect_config: Optional[Union[Dict, InspectConfig]] = None, item: Optional[Union[Dict, ContentItem]] = None, @@ -1539,7 +1571,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -1553,10 +1585,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = 
CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) response = hook.inspect_content( project_id=self.project_id, inspect_config=self.inspect_config, @@ -1610,11 +1639,17 @@ class CloudDLPListDeidentifyTemplatesOperator(BaseOperator): :rtype: list[google.cloud.dlp_v2.types.DeidentifyTemplate] """ - template_fields = ("organization_id", "project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, organization_id: Optional[str] = None, project_id: Optional[str] = None, page_size: Optional[int] = None, @@ -1624,7 +1659,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.organization_id = organization_id @@ -1638,10 +1673,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) template = hook.list_deidentify_templates( organization_id=self.organization_id, project_id=self.project_id, @@ -1696,11 +1728,16 @@ class CloudDLPListDLPJobsOperator(BaseOperator): :rtype: list[google.cloud.dlp_v2.types.DlpJob] """ - template_fields = ("project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, project_id: Optional[str] = None, results_filter: Optional[str] = None, page_size: Optional[int] = None, @@ -1711,7 +1748,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -1726,10 +1763,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) job = hook.list_dlp_jobs( project_id=self.project_id, results_filter=self.results_filter, @@ -1777,11 +1811,16 @@ class CloudDLPListInfoTypesOperator(BaseOperator): :rtype: ListInfoTypesResponse """ - template_fields = ("language_code", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "language_code", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, language_code: Optional[str] = None, results_filter: Optional[str] = None, retry: Optional[Retry] = None, @@ -1789,7 +1828,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.language_code = language_code @@ -1801,10 +1840,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain,) response = hook.list_info_types( language_code=self.language_code, results_filter=self.results_filter, @@ -1856,11 +1892,17 @@ class CloudDLPListInspectTemplatesOperator(BaseOperator): :rtype: list[google.cloud.dlp_v2.types.InspectTemplate] """ - template_fields = ("organization_id", "project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, organization_id: Optional[str] = None, project_id: Optional[str] = None, page_size: Optional[int] = None, @@ -1870,7 +1912,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.organization_id = organization_id @@ -1884,10 +1926,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) templates = hook.list_inspect_templates( organization_id=self.organization_id, project_id=self.project_id, @@ -1940,11 +1979,16 @@ class CloudDLPListJobTriggersOperator(BaseOperator): :rtype: list[google.cloud.dlp_v2.types.JobTrigger] """ - template_fields = ("project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, project_id: Optional[str] = None, page_size: Optional[int] = None, order_by: Optional[str] = None, @@ -1954,7 +1998,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -1968,10 +2012,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) jobs = hook.list_job_triggers( project_id=self.project_id, page_size=self.page_size, @@ -2025,11 +2066,17 @@ class CloudDLPListStoredInfoTypesOperator(BaseOperator): :rtype: list[google.cloud.dlp_v2.types.StoredInfoType] """ - template_fields = ("organization_id", "project_id", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "organization_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, organization_id: Optional[str] = None, project_id: Optional[str] = None, page_size: Optional[int] = None, @@ -2039,7 +2086,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.organization_id = organization_id @@ -2053,10 +2100,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = 
CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) infos = hook.list_stored_info_types( organization_id=self.organization_id, project_id=self.project_id, @@ -2126,7 +2170,8 @@ class CloudDLPRedactImageOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, project_id: Optional[str] = None, inspect_config: Optional[Union[Dict, InspectConfig]] = None, image_redaction_configs: Optional[Union[Dict, RedactImageRequest.ImageRedactionConfig]] = None, @@ -2137,7 +2182,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -2152,10 +2197,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) response = hook.redact_image( project_id=self.project_id, inspect_config=self.inspect_config, @@ -2228,7 +2270,8 @@ class CloudDLPReidentifyContentOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, project_id: Optional[str] = None, reidentify_config: Optional[Union[Dict, DeidentifyConfig]] = None, inspect_config: Optional[Union[Dict, InspectConfig]] = None, @@ -2240,7 +2283,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -2256,10 +2299,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) response = hook.reidentify_content( project_id=self.project_id, reidentify_config=self.reidentify_config, @@ -2327,7 +2367,8 @@ class CloudDLPUpdateDeidentifyTemplateOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, template_id: str, organization_id: Optional[str] = None, project_id: Optional[str] = None, @@ -2338,7 +2379,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.template_id = template_id @@ -2353,10 +2394,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) template = hook.update_deidentify_template( template_id=self.template_id, organization_id=self.organization_id, @@ -2423,7 +2461,8 @@ class CloudDLPUpdateInspectTemplateOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, template_id: str, organization_id: Optional[str] = None, project_id: Optional[str] = None, @@ -2434,7 +2473,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: 
Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.template_id = template_id @@ -2449,10 +2488,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) template = hook.update_inspect_template( template_id=self.template_id, organization_id=self.organization_id, @@ -2515,7 +2551,8 @@ class CloudDLPUpdateJobTriggerOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, job_trigger_id, project_id: Optional[str] = None, job_trigger: Optional[JobTrigger] = None, @@ -2525,7 +2562,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.job_trigger_id = job_trigger_id @@ -2539,10 +2576,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) trigger = hook.update_job_trigger( job_trigger_id=self.job_trigger_id, project_id=self.project_id, @@ -2609,7 +2643,8 @@ class CloudDLPUpdateStoredInfoTypeOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, stored_info_type_id, organization_id: Optional[str] = None, project_id: Optional[str] = None, @@ -2620,7 +2655,7 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.stored_info_type_id = stored_info_type_id @@ -2635,10 +2670,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudDLPHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudDLPHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) info = hook.update_stored_info_type( stored_info_type_id=self.stored_info_type_id, organization_id=self.organization_id, diff --git a/airflow/providers/google/cloud/operators/functions.py b/airflow/providers/google/cloud/operators/functions.py index 4a19300675510..ca51a8faebb95 100644 --- a/airflow/providers/google/cloud/operators/functions.py +++ b/airflow/providers/google/cloud/operators/functions.py @@ -28,7 +28,8 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.functions import CloudFunctionsHook from airflow.providers.google.cloud.utils.field_validator import ( - GcpBodyFieldValidator, GcpFieldValidationException, + GcpBodyFieldValidator, + GcpFieldValidationException, ) from airflow.utils.decorators import apply_defaults from airflow.version import version @@ -41,8 +42,7 @@ def _validate_available_memory_in_mb(value): def _validate_max_instances(value): if int(value) <= 0: - raise GcpFieldValidationException( - "The max instances parameter has to be greater than 0") + raise GcpFieldValidationException("The max instances parameter has to be greater than 0") CLOUD_FUNCTION_VALIDATION = [ @@ -51,35 +51,49 @@ def _validate_max_instances(value): 
dict(name="entryPoint", regexp=r'^.+$', optional=True), dict(name="runtime", regexp=r'^.+$', optional=True), dict(name="timeout", regexp=r'^.+$', optional=True), - dict(name="availableMemoryMb", custom_validation=_validate_available_memory_in_mb, - optional=True), + dict(name="availableMemoryMb", custom_validation=_validate_available_memory_in_mb, optional=True), dict(name="labels", optional=True), dict(name="environmentVariables", optional=True), dict(name="network", regexp=r'^.+$', optional=True), dict(name="maxInstances", optional=True, custom_validation=_validate_max_instances), - - dict(name="source_code", type="union", fields=[ - dict(name="sourceArchiveUrl", regexp=r'^.+$'), - dict(name="sourceRepositoryUrl", regexp=r'^.+$', api_version='v1beta2'), - dict(name="sourceRepository", type="dict", fields=[ - dict(name="url", regexp=r'^.+$') - ]), - dict(name="sourceUploadUrl") - ]), - - dict(name="trigger", type="union", fields=[ - dict(name="httpsTrigger", type="dict", fields=[ - # This dict should be empty at input (url is added at output) - ]), - dict(name="eventTrigger", type="dict", fields=[ - dict(name="eventType", regexp=r'^.+$'), - dict(name="resource", regexp=r'^.+$'), - dict(name="service", regexp=r'^.+$', optional=True), - dict(name="failurePolicy", type="dict", optional=True, fields=[ - dict(name="retry", type="dict", optional=True) - ]) - ]) - ]), + dict( + name="source_code", + type="union", + fields=[ + dict(name="sourceArchiveUrl", regexp=r'^.+$'), + dict(name="sourceRepositoryUrl", regexp=r'^.+$', api_version='v1beta2'), + dict(name="sourceRepository", type="dict", fields=[dict(name="url", regexp=r'^.+$')]), + dict(name="sourceUploadUrl"), + ], + ), + dict( + name="trigger", + type="union", + fields=[ + dict( + name="httpsTrigger", + type="dict", + fields=[ + # This dict should be empty at input (url is added at output) + ], + ), + dict( + name="eventTrigger", + type="dict", + fields=[ + dict(name="eventType", regexp=r'^.+$'), + dict(name="resource", regexp=r'^.+$'), + dict(name="service", regexp=r'^.+$', optional=True), + dict( + name="failurePolicy", + type="dict", + optional=True, + fields=[dict(name="retry", type="dict", optional=True)], + ), + ], + ), + ], + ), ] # type: List[Dict[str, Any]] @@ -126,22 +140,32 @@ class CloudFunctionDeployFunctionOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcf_function_deploy_template_fields] - template_fields = ('body', 'project_id', 'location', 'gcp_conn_id', 'api_version', - 'impersonation_chain',) + template_fields = ( + 'body', + 'project_id', + 'location', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gcf_function_deploy_template_fields] @apply_defaults - def __init__(self, *, - location: str, - body: Dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - zip_path: Optional[str] = None, - validate_body: bool = True, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + location: str, + body: Dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + zip_path: Optional[str] = None, + validate_body: bool = True, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.project_id = project_id self.location = location self.body = body @@ -152,8 +176,7 @@ def __init__(self, *, self._field_validator = None # type: Optional[GcpBodyFieldValidator] self.impersonation_chain = impersonation_chain if validate_body: - self._field_validator = GcpBodyFieldValidator(CLOUD_FUNCTION_VALIDATION, - api_version=api_version) + self._field_validator = GcpBodyFieldValidator(CLOUD_FUNCTION_VALIDATION, api_version=api_version) self._validate_inputs() super().__init__(**kwargs) @@ -169,10 +192,7 @@ def _validate_all_body_fields(self): self._field_validator.validate(self.body) def _create_new_function(self, hook): - hook.create_new_function( - project_id=self.project_id, - location=self.location, - body=self.body) + hook.create_new_function(project_id=self.project_id, location=self.location, body=self.body) def _update_function(self, hook): hook.update_function(self.body['name'], self.body, self.body.keys()) @@ -180,8 +200,9 @@ def _update_function(self, hook): def _check_if_function_exists(self, hook): name = self.body.get('name') if not name: - raise GcpFieldValidationException("The 'name' field should be present in " - "body: '{}'.".format(self.body)) + raise GcpFieldValidationException( + "The 'name' field should be present in " "body: '{}'.".format(self.body) + ) try: hook.get_function(name) except HttpError as e: @@ -192,15 +213,14 @@ def _check_if_function_exists(self, hook): return True def _upload_source_code(self, hook): - return hook.upload_function_zip(project_id=self.project_id, - location=self.location, - zip_path=self.zip_path) + return hook.upload_function_zip( + project_id=self.project_id, location=self.location, zip_path=self.zip_path + ) def _set_airflow_version_label(self): if 'labels' not in self.body.keys(): self.body['labels'] = {} - self.body['labels'].update( - {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')}) + self.body['labels'].update({'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')}) def execute(self, context): hook = CloudFunctionsHook( @@ -246,6 +266,7 @@ class ZipPathPreprocessor: :type zip_path: str """ + upload_function = None # type: Optional[bool] def __init__(self, body: dict, zip_path: Optional[str] = None) -> None: @@ -262,9 +283,7 @@ def _verify_upload_url_and_no_zip_path(self): raise AirflowException( "Parameter '{url}' is empty in the body and argument '{path}' " "is missing or empty. 
You need to have non empty '{path}' " - "when '{url}' is present and empty.".format( - url=GCF_SOURCE_UPLOAD_URL, - path=GCF_ZIP_PATH) + "when '{url}' is present and empty.".format(url=GCF_SOURCE_UPLOAD_URL, path=GCF_ZIP_PATH) ) def _verify_upload_url_and_zip_path(self): @@ -272,15 +291,17 @@ def _verify_upload_url_and_zip_path(self): if not self.body[GCF_SOURCE_UPLOAD_URL]: self.upload_function = True else: - raise AirflowException("Only one of '{}' in body or '{}' argument " - "allowed. Found both." - .format(GCF_SOURCE_UPLOAD_URL, GCF_ZIP_PATH)) + raise AirflowException( + "Only one of '{}' in body or '{}' argument " + "allowed. Found both.".format(GCF_SOURCE_UPLOAD_URL, GCF_ZIP_PATH) + ) def _verify_archive_url_and_zip_path(self): if GCF_SOURCE_ARCHIVE_URL in self.body and self.zip_path: - raise AirflowException("Only one of '{}' in body or '{}' argument " - "allowed. Found both." - .format(GCF_SOURCE_ARCHIVE_URL, GCF_ZIP_PATH)) + raise AirflowException( + "Only one of '{}' in body or '{}' argument " + "allowed. Found both.".format(GCF_SOURCE_ARCHIVE_URL, GCF_ZIP_PATH) + ) def should_upload_function(self) -> bool: """ @@ -289,8 +310,7 @@ def should_upload_function(self) -> bool: :rtype: bool """ if self.upload_function is None: - raise AirflowException('validate() method has to be invoked before ' - 'should_upload_function') + raise AirflowException('validate() method has to be invoked before ' 'should_upload_function') return self.upload_function def preprocess_body(self): @@ -334,17 +354,26 @@ class CloudFunctionDeleteFunctionOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcf_function_delete_template_fields] - template_fields = ('name', 'gcp_conn_id', 'api_version', 'impersonation_chain',) + template_fields = ( + 'name', + 'gcp_conn_id', + 'api_version', + 'impersonation_chain', + ) # [END gcf_function_delete_template_fields] @apply_defaults - def __init__(self, *, - name: str, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v1', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + name: str, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v1', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.name = name self.gcp_conn_id = gcp_conn_id self.api_version = api_version @@ -358,8 +387,7 @@ def _validate_inputs(self): else: pattern = FUNCTION_NAME_COMPILED_PATTERN if not pattern.match(self.name): - raise AttributeError( - 'Parameter name must match pattern: {}'.format(FUNCTION_NAME_PATTERN)) + raise AttributeError('Parameter name must match pattern: {}'.format(FUNCTION_NAME_PATTERN)) def execute(self, context): hook = CloudFunctionsHook( @@ -408,12 +436,19 @@ class CloudFunctionInvokeFunctionOperator(BaseOperator): :return: None """ - template_fields = ('function_id', 'input_data', 'location', 'project_id', - 'impersonation_chain',) + + template_fields = ( + 'function_id', + 'input_data', + 'location', + 'project_id', + 'impersonation_chain', + ) @apply_defaults def __init__( - self, *, + self, + *, function_id: str, input_data: Dict, location: str, @@ -421,7 +456,7 @@ def __init__( gcp_conn_id: str = 'google_cloud_default', api_version: str = 'v1', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.function_id = function_id @@ -443,7 
+478,7 @@ def execute(self, context: Dict): function_id=self.function_id, input_data=self.input_data, location=self.location, - project_id=self.project_id + project_id=self.project_id, ) self.log.info('Function called successfully. Execution id %s', result.get('executionId', None)) self.xcom_push(context=context, key='execution_id', value=result.get('executionId', None)) diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index 06e03e72a8e67..02492d1d8aa7e 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -104,29 +104,41 @@ class GCSCreateBucketOperator(BaseOperator): ) """ - template_fields = ('bucket_name', 'storage_class', - 'location', 'project_id', 'impersonation_chain',) + + template_fields = ( + 'bucket_name', + 'storage_class', + 'location', + 'project_id', + 'impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, *, - bucket_name: str, - resource: Optional[Dict] = None, - storage_class: str = 'MULTI_REGIONAL', - location: str = 'US', - project_id: Optional[str] = None, - labels: Optional[Dict] = None, - gcp_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + bucket_name: str, + resource: Optional[Dict] = None, + storage_class: str = 'MULTI_REGIONAL', + location: str = 'US', + project_id: Optional[str] = None, + labels: Optional[Dict] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.bucket_name = bucket_name @@ -146,12 +158,14 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) try: - hook.create_bucket(bucket_name=self.bucket_name, - resource=self.resource, - storage_class=self.storage_class, - location=self.location, - project_id=self.project_id, - labels=self.labels) + hook.create_bucket( + bucket_name=self.bucket_name, + resource=self.resource, + storage_class=self.storage_class, + location=self.location, + project_id=self.project_id, + labels=self.labels, + ) except Conflict: # HTTP 409 self.log.warning("Bucket %s already exists", self.bucket_name) @@ -203,26 +217,38 @@ class GCSListObjectsOperator(BaseOperator): gcp_conn_id=google_cloud_conn_id ) """ - template_fields: Iterable[str] = ('bucket', 'prefix', 'delimiter', 'impersonation_chain',) + + template_fields: Iterable[str] = ( + 'bucket', + 'prefix', + 'delimiter', + 'impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, *, - bucket: str, - prefix: Optional[str] = None, - delimiter: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + bucket: str, + prefix: Optional[str] = None, + delimiter: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.bucket = bucket @@ -240,12 +266,14 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) - self.log.info('Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s', - self.bucket, self.delimiter, self.prefix) + self.log.info( + 'Getting list of the files. 
Bucket: %s; Delimiter: %s; Prefix: %s', + self.bucket, + self.delimiter, + self.prefix, + ) - return hook.list(bucket_name=self.bucket, - prefix=self.prefix, - delimiter=self.delimiter) + return hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter) class GCSDeleteObjectsOperator(BaseOperator): @@ -281,23 +309,34 @@ class GCSDeleteObjectsOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('bucket_name', 'prefix', 'objects', 'impersonation_chain',) + template_fields = ( + 'bucket_name', + 'prefix', + 'objects', + 'impersonation_chain', + ) @apply_defaults - def __init__(self, *, - bucket_name: str, - objects: Optional[Iterable[str]] = None, - prefix: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + bucket_name: str, + objects: Optional[Iterable[str]] = None, + prefix: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.bucket_name = bucket_name @@ -322,14 +361,11 @@ def execute(self, context): if self.objects: objects = self.objects else: - objects = hook.list(bucket_name=self.bucket_name, - prefix=self.prefix) + objects = hook.list(bucket_name=self.bucket_name, prefix=self.prefix) - self.log.info("Deleting %s objects from %s", - len(objects), self.bucket_name) + self.log.info("Deleting %s objects from %s", len(objects), self.bucket_name) for object_name in objects: - hook.delete(bucket_name=self.bucket_name, - object_name=object_name) + hook.delete(bucket_name=self.bucket_name, object_name=object_name) class GCSBucketCreateAclEntryOperator(BaseOperator): @@ -367,13 +403,21 @@ class GCSBucketCreateAclEntryOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcs_bucket_create_acl_template_fields] - template_fields = ('bucket', 'entity', 'role', 'user_project', 'impersonation_chain',) + template_fields = ( + 'bucket', + 'entity', + 'role', + 'user_project', + 'impersonation_chain', + ) # [END gcs_bucket_create_acl_template_fields] @apply_defaults def __init__( - self, *, + self, + *, bucket: str, entity: str, role: str, @@ -381,14 +425,17 @@ def __init__( gcp_conn_id: str = 'google_cloud_default', google_cloud_storage_conn_id: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.bucket = bucket @@ -400,11 +447,11 @@ def __init__( def execute(self, context): hook = GCSHook( - google_cloud_storage_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + google_cloud_storage_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, + ) + hook.insert_bucket_acl( + bucket_name=self.bucket, entity=self.entity, role=self.role, user_project=self.user_project ) - hook.insert_bucket_acl(bucket_name=self.bucket, entity=self.entity, role=self.role, - user_project=self.user_project) class GCSObjectCreateAclEntryOperator(BaseOperator): @@ -448,29 +495,43 @@ class GCSObjectCreateAclEntryOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcs_object_create_acl_template_fields] - template_fields = ('bucket', 'object_name', 'entity', 'generation', 'role', 'user_project', - 'impersonation_chain',) + template_fields = ( + 'bucket', + 'object_name', + 'entity', + 'generation', + 'role', + 'user_project', + 'impersonation_chain', + ) # [END gcs_object_create_acl_template_fields] @apply_defaults - def __init__(self, *, - bucket: str, - object_name: str, - entity: str, - role: str, - generation: Optional[int] = None, - user_project: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + bucket: str, + object_name: str, + entity: str, + role: str, + generation: Optional[int] = None, + user_project: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.bucket = bucket @@ -484,15 +545,16 @@ def __init__(self, *, def execute(self, context): hook = GCSHook( - google_cloud_storage_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + google_cloud_storage_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, + ) + hook.insert_object_acl( + bucket_name=self.bucket, + object_name=self.object_name, + entity=self.entity, + role=self.role, + generation=self.generation, + user_project=self.user_project, ) - hook.insert_object_acl(bucket_name=self.bucket, - object_name=self.object_name, - entity=self.entity, - role=self.role, - generation=self.generation, - user_project=self.user_project) class GCSFileTransformOperator(BaseOperator): @@ -529,12 +591,17 @@ class GCSFileTransformOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('source_bucket', 'destination_bucket', 'transform_script', - 'impersonation_chain',) + template_fields = ( + 'source_bucket', + 'destination_bucket', + 'transform_script', + 'impersonation_chain', + ) @apply_defaults def __init__( - self, *, + self, + *, source_bucket: str, source_object: str, transform_script: Union[str, List[str]], @@ -542,7 +609,7 @@ def __init__( destination_object: Optional[str] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.source_bucket = source_bucket @@ -561,19 +628,14 @@ def execute(self, context: Dict): with NamedTemporaryFile() as source_file, NamedTemporaryFile() as destination_file: self.log.info("Downloading file from %s", self.source_bucket) hook.download( - bucket_name=self.source_bucket, - object_name=self.source_object, - filename=source_file.name + bucket_name=self.source_bucket, object_name=self.source_object, filename=source_file.name ) self.log.info("Starting the transformation") cmd = [self.transform_script] if isinstance(self.transform_script, str) else self.transform_script cmd += [source_file.name, destination_file.name] process = subprocess.Popen( - args=cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - close_fds=True + args=cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True ) self.log.info("Process output:") if process.stdout: @@ -582,24 +644,15 @@ def execute(self, context: Dict): process.wait() if process.returncode: - raise AirflowException( - "Transform script failed: {0}".format(process.returncode) - ) + raise AirflowException("Transform script failed: {0}".format(process.returncode)) - self.log.info( - "Transformation succeeded. Output temporarily located at %s", - destination_file.name - ) + self.log.info("Transformation succeeded. 
Output temporarily located at %s", destination_file.name) - self.log.info( - "Uploading file to %s as %s", - self.destination_bucket, - self.destination_object - ) + self.log.info("Uploading file to %s as %s", self.destination_bucket, self.destination_object) hook.upload( bucket_name=self.destination_bucket, object_name=self.destination_object, - filename=destination_file.name + filename=destination_file.name, ) @@ -629,15 +682,22 @@ class GCSDeleteBucketOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('bucket_name', "gcp_conn_id", "impersonation_chain",) + template_fields = ( + 'bucket_name', + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults - def __init__(self, *, - bucket_name: str, - force: bool = True, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + bucket_name: str, + force: bool = True, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.bucket_name = bucket_name @@ -717,7 +777,8 @@ class GCSSynchronizeBucketsOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, source_bucket: str, destination_bucket: str, source_object: Optional[str] = None, @@ -728,7 +789,7 @@ def __init__( gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.source_bucket = source_bucket @@ -755,5 +816,5 @@ def execute(self, context): destination_object=self.destination_object, recursive=self.recursive, delete_extra_files=self.delete_extra_files, - allow_overwrite=self.allow_overwrite + allow_overwrite=self.allow_overwrite, ) diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py index ab671f1fdc184..b5ed6f595c5fd 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -79,19 +79,28 @@ class GKEDeleteClusterOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['project_id', 'gcp_conn_id', 'name', 'location', 'api_version', - 'impersonation_chain', ] + + template_fields = [ + 'project_id', + 'gcp_conn_id', + 'name', + 'location', + 'api_version', + 'impersonation_chain', + ] @apply_defaults - def __init__(self, - *, - name: str, - location: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v2', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + name: str, + location: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v2', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -104,8 +113,7 @@ def __init__(self, def _check_input(self): if not all([self.project_id, self.name, self.location]): - self.log.error( - 'One of (project_id, name, location) is missing or incorrect') + self.log.error('One of (project_id, name, location) is missing or incorrect') raise AirflowException('Operator has incorrect or missing input.') def execute(self, context): @@ -174,19 +182,28 @@ class GKECreateClusterOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['project_id', 'gcp_conn_id', 'location', 'api_version', 'body', - 'impersonation_chain', ] + + template_fields = [ + 'project_id', + 'gcp_conn_id', + 'location', + 'api_version', + 'body', + 'impersonation_chain', + ] @apply_defaults - def __init__(self, - *, - location: str, - body: Optional[Union[Dict, Cluster]], - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - api_version: str = 'v2', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + location: str, + body: Optional[Union[Dict, Cluster]], + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + api_version: str = 'v2', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -199,8 +216,8 @@ def __init__(self, def _check_input(self): if not all([self.project_id, self.location, self.body]) or not ( - (isinstance(self.body, dict) and "name" in self.body and "initial_node_count" in self.body) or - (getattr(self.body, "name", None) and getattr(self.body, "initial_node_count", None)) + (isinstance(self.body, dict) and "name" in self.body and "initial_node_count" in self.body) + or (getattr(self.body, "name", None) and getattr(self.body, "initial_node_count", None)) ): self.log.error( "One of (project_id, location, body, body['name'], " @@ -254,17 +271,20 @@ class GKEStartPodOperator(KubernetesPodOperator): users to specify a service account. 
:type gcp_conn_id: str """ + template_fields = {'project_id', 'location', 'cluster_name'} | set(KubernetesPodOperator.template_fields) @apply_defaults - def __init__(self, - *, - location: str, - cluster_name: str, - use_internal_ip: bool = False, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - **kwargs) -> None: + def __init__( + self, + *, + location: str, + cluster_name: str, + use_internal_ip: bool = False, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + **kwargs, + ) -> None: super().__init__(**kwargs) self.project_id = project_id self.location = location @@ -284,15 +304,17 @@ def execute(self, context): self.project_id = self.project_id or hook.project_id if not self.project_id: - raise AirflowException("The project id must be passed either as " - "keyword project_id parameter or as project_id extra " - "in GCP connection definition. Both are not set!") + raise AirflowException( + "The project id must be passed either as " + "keyword project_id parameter or as project_id extra " + "in GCP connection definition. Both are not set!" + ) # Write config to a temp file and set the environment variable to point to it. # This is to avoid race conditions of reading/writing a single file - with tempfile.NamedTemporaryFile() as conf_file,\ - patch_environ({KUBE_CONFIG_ENV_VAR: conf_file.name}), \ - hook.provide_authorized_gcloud(): + with tempfile.NamedTemporaryFile() as conf_file, patch_environ( + {KUBE_CONFIG_ENV_VAR: conf_file.name} + ), hook.provide_authorized_gcloud(): # Attempt to get/update credentials # We call gcloud directly instead of using google-cloud-python api # because there is no way to write kubernetes config to a file, which is @@ -305,8 +327,10 @@ def execute(self, context): "clusters", "get-credentials", self.cluster_name, - "--zone", self.location, - "--project", self.project_id + "--zone", + self.location, + "--project", + self.project_id, ] if self.use_internal_ip: cmd.append('--internal-ip') diff --git a/airflow/providers/google/cloud/operators/life_sciences.py b/airflow/providers/google/cloud/operators/life_sciences.py index 6af8c9d5ec4fc..67760fff9eaee 100644 --- a/airflow/providers/google/cloud/operators/life_sciences.py +++ b/airflow/providers/google/cloud/operators/life_sciences.py @@ -55,18 +55,25 @@ class LifeSciencesRunPipelineOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("body", "gcp_conn_id", "api_version", "impersonation_chain",) + template_fields = ( + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) @apply_defaults - def __init__(self, - *, - body: dict, - location: str, - project_id: Optional[str] = None, - gcp_conn_id: str = "google_cloud_default", - api_version: str = "v2beta", - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + body: dict, + location: str, + project_id: Optional[str] = None, + gcp_conn_id: str = "google_cloud_default", + api_version: str = "v2beta", + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.body = body self.location = location @@ -89,6 +96,4 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) - return hook.run_pipeline(body=self.body, - location=self.location, - project_id=self.project_id) + return hook.run_pipeline(body=self.body, location=self.location, project_id=self.project_id) diff --git 
a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py index 0b23bc5e9f930..383684b392430 100644 --- a/airflow/providers/google/cloud/operators/mlengine.py +++ b/airflow/providers/google/cloud/operators/mlengine.py @@ -56,9 +56,8 @@ def _normalize_mlengine_job_id(job_id: str) -> str: tracker = 0 cleansed_job_id = '' for match in re.finditer(r'\{{2}.+?\}{2}', job): - cleansed_job_id += re.sub(r'[^0-9a-zA-Z]+', '_', - job[tracker:match.start()]) - cleansed_job_id += job[match.start():match.end()] + cleansed_job_id += re.sub(r'[^0-9a-zA-Z]+', '_', job[tracker : match.start()]) + cleansed_job_id += job[match.start() : match.end()] tracker = match.end() # Clean up last substring or the full string if no templates @@ -181,25 +180,27 @@ class MLEngineStartBatchPredictionJobOperator(BaseOperator): ] @apply_defaults - def __init__(self, # pylint: disable=too-many-arguments - *, - job_id: str, - region: str, - data_format: str, - input_paths: List[str], - output_path: str, - model_name: Optional[str] = None, - version_name: Optional[str] = None, - uri: Optional[str] = None, - max_worker_count: Optional[int] = None, - runtime_version: Optional[str] = None, - signature_name: Optional[str] = None, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, # pylint: disable=too-many-arguments + *, + job_id: str, + region: str, + data_format: str, + input_paths: List[str], + output_path: str, + model_name: Optional[str] = None, + version_name: Optional[str] = None, + uri: Optional[str] = None, + max_worker_count: Optional[int] = None, + runtime_version: Optional[str] = None, + signature_name: Optional[str] = None, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id @@ -222,24 +223,24 @@ def __init__(self, # pylint: disable=too-many-arguments if not self._project_id: raise AirflowException('Google Cloud project id is required.') if not self._job_id: - raise AirflowException( - 'An unique job id is required for Google MLEngine prediction ' - 'job.') + raise AirflowException('An unique job id is required for Google MLEngine prediction ' 'job.') if self._uri: if self._model_name or self._version_name: - raise AirflowException('Ambiguous model origin: Both uri and ' - 'model/version name are provided.') + raise AirflowException( + 'Ambiguous model origin: Both uri and ' 'model/version name are provided.' + ) if self._version_name and not self._model_name: raise AirflowException( - 'Missing model: Batch prediction expects ' - 'a model name when a version name is provided.') + 'Missing model: Batch prediction expects ' 'a model name when a version name is provided.' + ) if not (self._uri or self._model_name): raise AirflowException( 'Missing model origin: Batch prediction expects a model, ' - 'a model & version combination, or a URI to a savedModel.') + 'a model & version combination, or a URI to a savedModel.' 
+ ) def execute(self, context): job_id = _normalize_mlengine_job_id(self._job_id) @@ -249,8 +250,8 @@ def execute(self, context): 'dataFormat': self._data_format, 'inputPaths': self._input_paths, 'outputPath': self._output_path, - 'region': self._region - } + 'region': self._region, + }, } if self._labels: prediction_request['labels'] = self._labels @@ -258,47 +259,38 @@ def execute(self, context): if self._uri: prediction_request['predictionInput']['uri'] = self._uri elif self._model_name: - origin_name = 'projects/{}/models/{}'.format( - self._project_id, self._model_name) + origin_name = 'projects/{}/models/{}'.format(self._project_id, self._model_name) if not self._version_name: - prediction_request['predictionInput'][ - 'modelName'] = origin_name + prediction_request['predictionInput']['modelName'] = origin_name else: - prediction_request['predictionInput']['versionName'] = \ - origin_name + '/versions/{}'.format(self._version_name) + prediction_request['predictionInput']['versionName'] = origin_name + '/versions/{}'.format( + self._version_name + ) if self._max_worker_count: - prediction_request['predictionInput'][ - 'maxWorkerCount'] = self._max_worker_count + prediction_request['predictionInput']['maxWorkerCount'] = self._max_worker_count if self._runtime_version: - prediction_request['predictionInput'][ - 'runtimeVersion'] = self._runtime_version + prediction_request['predictionInput']['runtimeVersion'] = self._runtime_version if self._signature_name: - prediction_request['predictionInput'][ - 'signatureName'] = self._signature_name + prediction_request['predictionInput']['signatureName'] = self._signature_name hook = MLEngineHook( - self._gcp_conn_id, - self._delegate_to, - impersonation_chain=self._impersonation_chain + self._gcp_conn_id, self._delegate_to, impersonation_chain=self._impersonation_chain ) # Helper method to check if the existing job's prediction input is the # same as the request we get here. def check_existing_job(existing_job): - return existing_job.get('predictionInput', None) == \ - prediction_request['predictionInput'] + return existing_job.get('predictionInput', None) == prediction_request['predictionInput'] finished_prediction_job = hook.create_job( project_id=self._project_id, job=prediction_request, use_existing_job_fn=check_existing_job ) if finished_prediction_job['state'] != 'SUCCEEDED': - self.log.error( - 'MLEngine batch prediction job failed: %s', str(finished_prediction_job) - ) + self.log.error('MLEngine batch prediction job failed: %s', str(finished_prediction_job)) raise RuntimeError(finished_prediction_job['errorMessage']) return finished_prediction_job['predictionOutput'] @@ -351,22 +343,24 @@ class MLEngineManageModelOperator(BaseOperator): ] @apply_defaults - def __init__(self, - *, - model: dict, - operation: str = 'create', - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + model: dict, + operation: str = 'create', + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) warnings.warn( "This operator is deprecated. 
Consider using operators for specific operations: " "MLEngineCreateModelOperator, MLEngineGetModelOperator.", DeprecationWarning, - stacklevel=3 + stacklevel=3, ) self._project_id = project_id @@ -429,14 +423,16 @@ class MLEngineCreateModelOperator(BaseOperator): ] @apply_defaults - def __init__(self, - *, - model: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + model: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id self._model = model @@ -492,14 +488,16 @@ class MLEngineGetModelOperator(BaseOperator): ] @apply_defaults - def __init__(self, - *, - model_name: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + model_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id self._model_name = model_name @@ -559,15 +557,17 @@ class MLEngineDeleteModelOperator(BaseOperator): ] @apply_defaults - def __init__(self, - *, - model_name: str, - delete_contents: bool = False, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + model_name: str, + delete_contents: bool = False, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id self._model_name = model_name @@ -580,7 +580,7 @@ def execute(self, context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to, - impersonation_chain=self._impersonation_chain + impersonation_chain=self._impersonation_chain, ) return hook.delete_model( @@ -660,17 +660,19 @@ class MLEngineManageVersionOperator(BaseOperator): ] @apply_defaults - def __init__(self, - *, - model_name: str, - version_name: Optional[str] = None, - version: Optional[dict] = None, - operation: str = 'create', - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + model_name: str, + version_name: Optional[str] = None, + version: Optional[dict] = None, + operation: str = 'create', + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id self._model_name = model_name @@ -685,7 +687,7 @@ def __init__(self, "This operator is deprecated. 
Consider using operators for specific operations: " "MLEngineCreateVersion, MLEngineSetDefaultVersion, MLEngineListVersions, MLEngineDeleteVersion.", DeprecationWarning, - stacklevel=3 + stacklevel=3, ) def execute(self, context): @@ -700,29 +702,21 @@ def execute(self, context): if self._operation == 'create': if not self._version: - raise ValueError("version attribute of {} could not " - "be empty".format(self.__class__.__name__)) + raise ValueError( + "version attribute of {} could not " "be empty".format(self.__class__.__name__) + ) return hook.create_version( - project_id=self._project_id, - model_name=self._model_name, - version_spec=self._version + project_id=self._project_id, model_name=self._model_name, version_spec=self._version ) elif self._operation == 'set_default': return hook.set_default_version( - project_id=self._project_id, - model_name=self._model_name, - version_name=self._version['name'] + project_id=self._project_id, model_name=self._model_name, version_name=self._version['name'] ) elif self._operation == 'list': - return hook.list_versions( - project_id=self._project_id, - model_name=self._model_name - ) + return hook.list_versions(project_id=self._project_id, model_name=self._model_name) elif self._operation == 'delete': return hook.delete_version( - project_id=self._project_id, - model_name=self._model_name, - version_name=self._version['name'] + project_id=self._project_id, model_name=self._model_name, version_name=self._version['name'] ) else: raise ValueError('Unknown operation: {}'.format(self._operation)) @@ -771,15 +765,17 @@ class MLEngineCreateVersionOperator(BaseOperator): ] @apply_defaults - def __init__(self, - *, - model_name: str, - version: dict, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + model_name: str, + version: dict, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id @@ -805,9 +801,7 @@ def execute(self, context): ) return hook.create_version( - project_id=self._project_id, - model_name=self._model_name, - version_spec=self._version + project_id=self._project_id, model_name=self._model_name, version_spec=self._version ) @@ -854,15 +848,17 @@ class MLEngineSetDefaultVersionOperator(BaseOperator): ] @apply_defaults - def __init__(self, - *, - model_name: str, - version_name: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + model_name: str, + version_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id @@ -888,9 +884,7 @@ def execute(self, context): ) return hook.set_default_version( - project_id=self._project_id, - model_name=self._model_name, - version_name=self._version_name + project_id=self._project_id, model_name=self._model_name, version_name=self._version_name ) @@ -926,6 +920,7 @@ class 
MLEngineListVersionsOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + template_fields = [ '_project_id', '_model_name', @@ -933,14 +928,16 @@ class MLEngineListVersionsOperator(BaseOperator): ] @apply_defaults - def __init__(self, - *, - model_name: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + model_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id @@ -961,10 +958,7 @@ def execute(self, context): impersonation_chain=self._impersonation_chain, ) - return hook.list_versions( - project_id=self._project_id, - model_name=self._model_name, - ) + return hook.list_versions(project_id=self._project_id, model_name=self._model_name,) class MLEngineDeleteVersionOperator(BaseOperator): @@ -1002,6 +996,7 @@ class MLEngineDeleteVersionOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + template_fields = [ '_project_id', '_model_name', @@ -1010,15 +1005,17 @@ class MLEngineDeleteVersionOperator(BaseOperator): ] @apply_defaults - def __init__(self, - *, - model_name: str, - version_name: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + model_name: str, + version_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id @@ -1044,9 +1041,7 @@ def execute(self, context): ) return hook.delete_version( - project_id=self._project_id, - model_name=self._model_name, - version_name=self._version_name + project_id=self._project_id, model_name=self._model_name, version_name=self._version_name ) @@ -1054,6 +1049,7 @@ class AIPlatformConsoleLink(BaseOperatorLink): """ Helper class for constructing AI Platform Console link. 
""" + name = "AI Platform Console" def get_link(self, operator, dttm): @@ -1147,30 +1143,30 @@ class MLEngineStartTrainingJobOperator(BaseOperator): '_impersonation_chain', ] - operator_extra_links = ( - AIPlatformConsoleLink(), - ) + operator_extra_links = (AIPlatformConsoleLink(),) @apply_defaults - def __init__(self, # pylint: disable=too-many-arguments - *, - job_id: str, - package_uris: List[str], - training_python_module: str, - training_args: List[str], - region: str, - scale_tier: Optional[str] = None, - master_type: Optional[str] = None, - runtime_version: Optional[str] = None, - python_version: Optional[str] = None, - job_dir: Optional[str] = None, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - mode: str = 'PRODUCTION', - labels: Optional[Dict[str, str]] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, # pylint: disable=too-many-arguments + *, + job_id: str, + package_uris: List[str], + training_python_module: str, + training_args: List[str], + region: str, + scale_tier: Optional[str] = None, + master_type: Optional[str] = None, + runtime_version: Optional[str] = None, + python_version: Optional[str] = None, + job_dir: Optional[str] = None, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + mode: str = 'PRODUCTION', + labels: Optional[Dict[str, str]] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id self._job_id = job_id @@ -1192,22 +1188,17 @@ def __init__(self, # pylint: disable=too-many-arguments if not self._project_id: raise AirflowException('Google Cloud project id is required.') if not self._job_id: - raise AirflowException( - 'An unique job id is required for Google MLEngine training ' - 'job.') + raise AirflowException('An unique job id is required for Google MLEngine training ' 'job.') if not package_uris: - raise AirflowException( - 'At least one python package is required for MLEngine ' - 'Training job.') + raise AirflowException('At least one python package is required for MLEngine ' 'Training job.') if not training_python_module: raise AirflowException( - 'Python module name to run after installing required ' - 'packages is required.') + 'Python module name to run after installing required ' 'packages is required.' 
+ ) if not self._region: raise AirflowException('Google Compute Engine region is required.') if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM" and not self._master_type: - raise AirflowException( - 'master_type must be set when scale_tier is CUSTOM') + raise AirflowException('master_type must be set when scale_tier is CUSTOM') def execute(self, context): job_id = _normalize_mlengine_job_id(self._job_id) @@ -1219,7 +1210,7 @@ def execute(self, context): 'pythonModule': self._training_python_module, 'region': self._region, 'args': self._training_args, - } + }, } if self._labels: training_request['labels'] = self._labels @@ -1256,8 +1247,9 @@ def check_existing_job(existing_job): existing_training_input['scaleTier'] = None existing_training_input['args'] = existing_training_input.get('args', None) - requested_training_input["args"] = requested_training_input['args'] \ - if requested_training_input["args"] else None + requested_training_input["args"] = ( + requested_training_input['args'] if requested_training_input["args"] else None + ) return existing_training_input == requested_training_input @@ -1310,14 +1302,16 @@ class MLEngineTrainingCancelJobOperator(BaseOperator): ] @apply_defaults - def __init__(self, - *, - job_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + job_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self._project_id = project_id self._job_id = job_id diff --git a/airflow/providers/google/cloud/operators/natural_language.py b/airflow/providers/google/cloud/operators/natural_language.py index eed2ee550bf7c..a2c4124e46a64 100644 --- a/airflow/providers/google/cloud/operators/natural_language.py +++ b/airflow/providers/google/cloud/operators/natural_language.py @@ -65,13 +65,19 @@ class CloudNaturalLanguageAnalyzeEntitiesOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START natural_language_analyze_entities_template_fields] - template_fields = ("document", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "document", + "gcp_conn_id", + "impersonation_chain", + ) # [END natural_language_analyze_entities_template_fields] @apply_defaults def __init__( - self, *, + self, + *, document: Union[dict, Document], encoding_type: Optional[enums.EncodingType] = None, retry: Optional[Retry] = None, @@ -79,7 +85,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.document = document @@ -92,8 +98,7 @@ def __init__( def execute(self, context): hook = CloudNaturalLanguageHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) self.log.info("Start analyzing entities") @@ -140,13 +145,19 @@ class CloudNaturalLanguageAnalyzeEntitySentimentOperator(BaseOperator): :rtype: google.cloud.language_v1.types.AnalyzeEntitiesResponse """ + # [START natural_language_analyze_entity_sentiment_template_fields] - template_fields = ("document", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "document", + "gcp_conn_id", + "impersonation_chain", + ) # [END natural_language_analyze_entity_sentiment_template_fields] @apply_defaults def __init__( - self, *, + self, + *, document: Union[dict, Document], encoding_type: Optional[enums.EncodingType] = None, retry: Optional[Retry] = None, @@ -154,7 +165,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.document = document @@ -167,8 +178,7 @@ def __init__( def execute(self, context): hook = CloudNaturalLanguageHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) self.log.info("Start entity sentiment analyze") @@ -218,13 +228,19 @@ class CloudNaturalLanguageAnalyzeSentimentOperator(BaseOperator): :rtype: google.cloud.language_v1.types.AnalyzeEntitiesResponse """ + # [START natural_language_analyze_sentiment_template_fields] - template_fields = ("document", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "document", + "gcp_conn_id", + "impersonation_chain", + ) # [END natural_language_analyze_sentiment_template_fields] @apply_defaults def __init__( - self, *, + self, + *, document: Union[dict, Document], encoding_type: Optional[enums.EncodingType] = None, retry: Optional[Retry] = None, @@ -232,7 +248,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.document = document @@ -245,8 +261,7 @@ def __init__( def execute(self, context): hook = CloudNaturalLanguageHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) self.log.info("Start sentiment analyze") @@ -288,20 +303,26 @@ class CloudNaturalLanguageClassifyTextOperator(BaseOperator): account from the list granting this role to the originating 
account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START natural_language_classify_text_template_fields] - template_fields = ("document", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "document", + "gcp_conn_id", + "impersonation_chain", + ) # [END natural_language_classify_text_template_fields] @apply_defaults def __init__( - self, *, + self, + *, document: Union[dict, Document], retry: Optional[Retry] = None, timeout: Optional[float] = None, metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.document = document @@ -313,8 +334,7 @@ def __init__( def execute(self, context): hook = CloudNaturalLanguageHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) self.log.info("Start text classify") diff --git a/airflow/providers/google/cloud/operators/pubsub.py b/airflow/providers/google/cloud/operators/pubsub.py index 425b8c6b91c77..7bb61d3e7b90f 100644 --- a/airflow/providers/google/cloud/operators/pubsub.py +++ b/airflow/providers/google/cloud/operators/pubsub.py @@ -23,7 +23,12 @@ from google.api_core.retry import Retry from google.cloud.pubsub_v1.types import ( - DeadLetterPolicy, Duration, ExpirationPolicy, MessageStoragePolicy, PushConfig, ReceivedMessage, + DeadLetterPolicy, + Duration, + ExpirationPolicy, + MessageStoragePolicy, + PushConfig, + ReceivedMessage, RetryPolicy, ) from google.protobuf.json_format import MessageToDict @@ -115,13 +120,19 @@ class PubSubCreateTopicOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['project_id', 'topic', 'impersonation_chain', ] + + template_fields = [ + 'project_id', + 'topic', + 'impersonation_chain', + ] ui_color = '#0273d4' # pylint: disable=too-many-arguments @apply_defaults def __init__( - self, *, + self, + *, topic: str, project_id: Optional[str] = None, fail_if_exists: bool = False, @@ -135,14 +146,16 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, project: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: # To preserve backward compatibility # TODO: remove one day if project: warnings.warn( - "The project parameter has been deprecated. You should pass " - "the project_id parameter.", DeprecationWarning, stacklevel=2) + "The project parameter has been deprecated. You should pass " "the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) project_id = project super().__init__(**kwargs) @@ -176,7 +189,7 @@ def execute(self, context): kms_key_name=self.kms_key_name, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) self.log.info("Created topic %s", self.topic) @@ -325,14 +338,21 @@ class PubSubCreateSubscriptionOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['project_id', 'topic', 'subscription', 'subscription_project_id', - 'impersonation_chain', ] + + template_fields = [ + 'project_id', + 'topic', + 'subscription', + 'subscription_project_id', + 'impersonation_chain', + ] ui_color = '#0273d4' # pylint: disable=too-many-arguments, too-many-locals @apply_defaults def __init__( - self, *, + self, + *, topic: str, project_id: Optional[str] = None, subscription: Optional[str] = None, @@ -356,7 +376,7 @@ def __init__( topic_project: Optional[str] = None, subscription_project: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: # To preserve backward compatibility @@ -364,12 +384,18 @@ def __init__( if topic_project: warnings.warn( "The topic_project parameter has been deprecated. You should pass " - "the project_id parameter.", DeprecationWarning, stacklevel=2) + "the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) project_id = topic_project if subscription_project: warnings.warn( "The project_id parameter has been deprecated. You should pass " - "the subscription_project parameter.", DeprecationWarning, stacklevel=2) + "the subscription_project parameter.", + DeprecationWarning, + stacklevel=2, + ) subscription_project_id = subscription_project super().__init__(**kwargs) @@ -421,7 +447,7 @@ def execute(self, context): retry_policy=self.retry_policy, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) self.log.info("Created subscription for topic %s", self.topic) @@ -495,12 +521,18 @@ class PubSubDeleteTopicOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['project_id', 'topic', 'impersonation_chain', ] + + template_fields = [ + 'project_id', + 'topic', + 'impersonation_chain', + ] ui_color = '#cb4335' @apply_defaults def __init__( - self, *, + self, + *, topic: str, project_id: Optional[str] = None, fail_if_not_exists: bool = False, @@ -511,14 +543,16 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, project: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: # To preserve backward compatibility # TODO: remove one day if project: warnings.warn( - "The project parameter has been deprecated. You should pass " - "the project_id parameter.", DeprecationWarning, stacklevel=2) + "The project parameter has been deprecated. You should pass " "the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) project_id = project super().__init__(**kwargs) @@ -546,7 +580,7 @@ def execute(self, context): fail_if_not_exists=self.fail_if_not_exists, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) self.log.info("Deleted topic %s", self.topic) @@ -620,12 +654,18 @@ class PubSubDeleteSubscriptionOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['project_id', 'subscription', 'impersonation_chain', ] + + template_fields = [ + 'project_id', + 'subscription', + 'impersonation_chain', + ] ui_color = '#cb4335' @apply_defaults def __init__( - self, *, + self, + *, subscription: str, project_id: Optional[str] = None, fail_if_not_exists: bool = False, @@ -636,14 +676,16 @@ def __init__( metadata: Optional[Sequence[Tuple[str, str]]] = None, project: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: # To preserve backward compatibility # TODO: remove one day if project: warnings.warn( - "The project parameter has been deprecated. You should pass " - "the project_id parameter.", DeprecationWarning, stacklevel=2) + "The project parameter has been deprecated. You should pass " "the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) project_id = project super().__init__(**kwargs) @@ -671,7 +713,7 @@ def execute(self, context): fail_if_not_exists=self.fail_if_not_exists, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) self.log.info("Deleted subscription %s", self.subscription) @@ -738,12 +780,19 @@ class PubSubPublishMessageOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['project_id', 'topic', 'messages', 'impersonation_chain', ] + + template_fields = [ + 'project_id', + 'topic', + 'messages', + 'impersonation_chain', + ] ui_color = '#0273d4' @apply_defaults def __init__( - self, *, + self, + *, topic: str, messages: List, project_id: Optional[str] = None, @@ -751,14 +800,16 @@ def __init__( delegate_to: Optional[str] = None, project: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: # To preserve backward compatibility # TODO: remove one day if project: warnings.warn( - "The project parameter has been deprecated. You should pass " - "the project_id parameter.", DeprecationWarning, stacklevel=2) + "The project parameter has been deprecated. You should pass " "the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) project_id = project super().__init__(**kwargs) @@ -837,11 +888,17 @@ class PubSubPullOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['project_id', 'subscription', 'impersonation_chain', ] + + template_fields = [ + 'project_id', + 'subscription', + 'impersonation_chain', + ] @apply_defaults def __init__( - self, *, + self, + *, project_id: str, subscription: str, max_messages: int = 5, @@ -850,7 +907,7 @@ def __init__( gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.gcp_conn_id = gcp_conn_id @@ -882,9 +939,7 @@ def execute(self, context): if pulled_messages and self.ack_messages: hook.acknowledge( - project_id=self.project_id, - subscription=self.subscription, - messages=pulled_messages, + project_id=self.project_id, subscription=self.subscription, messages=pulled_messages, ) return ret @@ -904,9 +959,6 @@ def _default_message_callback( :return: value to be saved to XCom. 
""" - messages_json = [ - MessageToDict(m) - for m in pulled_messages - ] + messages_json = [MessageToDict(m) for m in pulled_messages] return messages_json diff --git a/airflow/providers/google/cloud/operators/spanner.py b/airflow/providers/google/cloud/operators/spanner.py index 5daf90fd1c281..a9db04b4bbc20 100644 --- a/airflow/providers/google/cloud/operators/spanner.py +++ b/airflow/providers/google/cloud/operators/spanner.py @@ -63,21 +63,31 @@ class SpannerDeployInstanceOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_spanner_deploy_template_fields] - template_fields = ('project_id', 'instance_id', 'configuration_name', 'display_name', - 'gcp_conn_id', 'impersonation_chain', ) + template_fields = ( + 'project_id', + 'instance_id', + 'configuration_name', + 'display_name', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END gcp_spanner_deploy_template_fields] @apply_defaults - def __init__(self, *, - instance_id: str, - configuration_name: str, - node_count: int, - display_name: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + configuration_name: str, + node_count: int, + display_name: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.instance_id = instance_id self.project_id = project_id self.configuration_name = configuration_name @@ -92,25 +102,23 @@ def _validate_inputs(self): if self.project_id == '': raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: - raise AirflowException("The required parameter 'instance_id' " - "is empty or None") + raise AirflowException("The required parameter 'instance_id' " "is empty or None") def execute(self, context): - hook = SpannerHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = SpannerHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) if not hook.get_instance(project_id=self.project_id, instance_id=self.instance_id): self.log.info("Creating Cloud Spanner instance '%s'", self.instance_id) func = hook.create_instance else: self.log.info("Updating Cloud Spanner instance '%s'", self.instance_id) func = hook.update_instance - func(project_id=self.project_id, - instance_id=self.instance_id, - configuration_name=self.configuration_name, - node_count=self.node_count, - display_name=self.display_name) + func( + project_id=self.project_id, + instance_id=self.instance_id, + configuration_name=self.configuration_name, + node_count=self.node_count, + display_name=self.display_name, + ) class SpannerDeleteInstanceOperator(BaseOperator): @@ -139,17 +147,26 @@ class SpannerDeleteInstanceOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_spanner_delete_template_fields] - template_fields = ('project_id', 'instance_id', 'gcp_conn_id', 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance_id', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END gcp_spanner_delete_template_fields] @apply_defaults - def __init__(self, *, - instance_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.instance_id = instance_id self.project_id = project_id self.gcp_conn_id = gcp_conn_id @@ -161,20 +178,18 @@ def _validate_inputs(self): if self.project_id == '': raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: - raise AirflowException("The required parameter 'instance_id' " - "is empty or None") + raise AirflowException("The required parameter 'instance_id' " "is empty or None") def execute(self, context): - hook = SpannerHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = SpannerHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) if hook.get_instance(project_id=self.project_id, instance_id=self.instance_id): - return hook.delete_instance(project_id=self.project_id, - instance_id=self.instance_id) + return hook.delete_instance(project_id=self.project_id, instance_id=self.instance_id) else: - self.log.info("Instance '%s' does not exist in project '%s'. " - "Aborting delete.", self.instance_id, self.project_id) + self.log.info( + "Instance '%s' does not exist in project '%s'. " "Aborting delete.", + self.instance_id, + self.project_id, + ) return True @@ -208,21 +223,31 @@ class SpannerQueryDatabaseInstanceOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_spanner_query_template_fields] - template_fields = ('project_id', 'instance_id', 'database_id', 'query', 'gcp_conn_id', - 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance_id', + 'database_id', + 'query', + 'gcp_conn_id', + 'impersonation_chain', + ) template_ext = ('.sql',) # [END gcp_spanner_query_template_fields] @apply_defaults - def __init__(self, *, - instance_id: str, - database_id: str, - query: Union[str, List[str]], - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + database_id: str, + query: Union[str, List[str]], + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.instance_id = instance_id self.project_id = project_id self.database_id = database_id @@ -236,31 +261,31 @@ def _validate_inputs(self): if self.project_id == '': raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: - raise AirflowException("The required parameter 'instance_id' " - "is empty or None") + raise AirflowException("The required parameter 'instance_id' " "is empty or None") if not self.database_id: - raise AirflowException("The required parameter 'database_id' " - "is empty or None") + raise AirflowException("The required parameter 'database_id' " "is empty or None") if not self.query: raise AirflowException("The required parameter 'query' is empty") def execute(self, context): - hook = SpannerHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = SpannerHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) queries = self.query if isinstance(self.query, str): queries = [x.strip() for x in self.query.split(';')] self.sanitize_queries(queries) - self.log.info("Executing DML query(-ies) on " - "projects/%s/instances/%s/databases/%s", - self.project_id, self.instance_id, self.database_id) + self.log.info( + "Executing DML query(-ies) on " "projects/%s/instances/%s/databases/%s", + self.project_id, + self.instance_id, + self.database_id, + ) self.log.info(queries) - hook.execute_dml(project_id=self.project_id, - instance_id=self.instance_id, - database_id=self.database_id, - queries=queries) + hook.execute_dml( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + queries=queries, + ) @staticmethod def sanitize_queries(queries): @@ -305,21 +330,31 @@ class SpannerDeployDatabaseInstanceOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_spanner_database_deploy_template_fields] - template_fields = ('project_id', 'instance_id', 'database_id', 'ddl_statements', - 'gcp_conn_id', 'impersonation_chain',) - template_ext = ('.sql', ) + template_fields = ( + 'project_id', + 'instance_id', + 'database_id', + 'ddl_statements', + 'gcp_conn_id', + 'impersonation_chain', + ) + template_ext = ('.sql',) # [END gcp_spanner_database_deploy_template_fields] @apply_defaults - def __init__(self, *, - instance_id: str, - database_id: str, - ddl_statements: List[str], - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + database_id: str, + ddl_statements: List[str], + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.instance_id = instance_id self.project_id = project_id self.database_id = database_id @@ -333,31 +368,35 @@ def _validate_inputs(self): if self.project_id == '': raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: - raise AirflowException("The required parameter 'instance_id' is empty " - "or None") + raise AirflowException("The required parameter 'instance_id' is empty " "or None") if not self.database_id: - raise AirflowException("The required parameter 'database_id' is empty" - " or None") + raise AirflowException("The required parameter 'database_id' is empty" " or None") def execute(self, context): - hook = SpannerHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - if not hook.get_database(project_id=self.project_id, - instance_id=self.instance_id, - database_id=self.database_id): - self.log.info("Creating Cloud Spanner database " - "'%s' in project '%s' and instance '%s'", - self.database_id, self.project_id, self.instance_id) - return hook.create_database(project_id=self.project_id, - instance_id=self.instance_id, - database_id=self.database_id, - ddl_statements=self.ddl_statements) + hook = SpannerHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + if not hook.get_database( + project_id=self.project_id, instance_id=self.instance_id, database_id=self.database_id + ): + self.log.info( + "Creating Cloud Spanner database " "'%s' in project '%s' and instance '%s'", + self.database_id, + self.project_id, + self.instance_id, + ) + return hook.create_database( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + ddl_statements=self.ddl_statements, + ) else: - self.log.info("The database '%s' in project '%s' and instance '%s'" - " already exists. Nothing to do. Exiting.", - self.database_id, self.project_id, self.instance_id) + self.log.info( + "The database '%s' in project '%s' and instance '%s'" + " already exists. Nothing to do. Exiting.", + self.database_id, + self.project_id, + self.instance_id, + ) return True @@ -393,22 +432,32 @@ class SpannerUpdateDatabaseInstanceOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_spanner_database_update_template_fields] - template_fields = ('project_id', 'instance_id', 'database_id', 'ddl_statements', - 'gcp_conn_id', 'impersonation_chain',) - template_ext = ('.sql', ) + template_fields = ( + 'project_id', + 'instance_id', + 'database_id', + 'ddl_statements', + 'gcp_conn_id', + 'impersonation_chain', + ) + template_ext = ('.sql',) # [END gcp_spanner_database_update_template_fields] @apply_defaults - def __init__(self, *, - instance_id: str, - database_id: str, - ddl_statements: List[str], - project_id: Optional[str] = None, - operation_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + database_id: str, + ddl_statements: List[str], + project_id: Optional[str] = None, + operation_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.instance_id = instance_id self.project_id = project_id self.database_id = database_id @@ -423,34 +472,30 @@ def _validate_inputs(self): if self.project_id == '': raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: - raise AirflowException("The required parameter 'instance_id' is empty" - " or None") + raise AirflowException("The required parameter 'instance_id' is empty" " or None") if not self.database_id: - raise AirflowException("The required parameter 'database_id' is empty" - " or None") + raise AirflowException("The required parameter 'database_id' is empty" " or None") if not self.ddl_statements: - raise AirflowException("The required parameter 'ddl_statements' is empty" - " or None") + raise AirflowException("The required parameter 'ddl_statements' is empty" " or None") def execute(self, context): - hook = SpannerHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - if not hook.get_database(project_id=self.project_id, - instance_id=self.instance_id, - database_id=self.database_id): - raise AirflowException("The Cloud Spanner database '{}' in project '{}' and " - "instance '{}' is missing. Create the database first " - "before you can update it.".format(self.database_id, - self.project_id, - self.instance_id)) + hook = SpannerHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + if not hook.get_database( + project_id=self.project_id, instance_id=self.instance_id, database_id=self.database_id + ): + raise AirflowException( + "The Cloud Spanner database '{}' in project '{}' and " + "instance '{}' is missing. Create the database first " + "before you can update it.".format(self.database_id, self.project_id, self.instance_id) + ) else: - return hook.update_database(project_id=self.project_id, - instance_id=self.instance_id, - database_id=self.database_id, - ddl_statements=self.ddl_statements, - operation_id=self.operation_id) + return hook.update_database( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + ddl_statements=self.ddl_statements, + operation_id=self.operation_id, + ) class SpannerDeleteDatabaseInstanceOperator(BaseOperator): @@ -480,19 +525,28 @@ class SpannerDeleteDatabaseInstanceOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_spanner_database_delete_template_fields] - template_fields = ('project_id', 'instance_id', 'database_id', - 'gcp_conn_id', 'impersonation_chain',) + template_fields = ( + 'project_id', + 'instance_id', + 'database_id', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END gcp_spanner_database_delete_template_fields] @apply_defaults - def __init__(self, *, - instance_id: str, - database_id: str, - project_id: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + instance_id: str, + database_id: str, + project_id: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: self.instance_id = instance_id self.project_id = project_id self.database_id = database_id @@ -505,26 +559,25 @@ def _validate_inputs(self): if self.project_id == '': raise AirflowException("The required parameter 'project_id' is empty") if not self.instance_id: - raise AirflowException("The required parameter 'instance_id' is empty" - " or None") + raise AirflowException("The required parameter 'instance_id' is empty" " or None") if not self.database_id: - raise AirflowException("The required parameter 'database_id' is empty" - " or None") + raise AirflowException("The required parameter 'database_id' is empty" " or None") def execute(self, context): - hook = SpannerHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + hook = SpannerHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + database = hook.get_database( + project_id=self.project_id, instance_id=self.instance_id, database_id=self.database_id ) - database = hook.get_database(project_id=self.project_id, - instance_id=self.instance_id, - database_id=self.database_id) if not database: - self.log.info("The Cloud Spanner database was missing: " - "'%s' in project '%s' and instance '%s'. Assuming success.", - self.database_id, self.project_id, self.instance_id) + self.log.info( + "The Cloud Spanner database was missing: " + "'%s' in project '%s' and instance '%s'. Assuming success.", + self.database_id, + self.project_id, + self.instance_id, + ) return True else: - return hook.delete_database(project_id=self.project_id, - instance_id=self.instance_id, - database_id=self.database_id) + return hook.delete_database( + project_id=self.project_id, instance_id=self.instance_id, database_id=self.database_id + ) diff --git a/airflow/providers/google/cloud/operators/speech_to_text.py b/airflow/providers/google/cloud/operators/speech_to_text.py index 3fce779e00004..d60be88967f47 100644 --- a/airflow/providers/google/cloud/operators/speech_to_text.py +++ b/airflow/providers/google/cloud/operators/speech_to_text.py @@ -67,14 +67,22 @@ class CloudSpeechToTextRecognizeSpeechOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_speech_to_text_synthesize_template_fields] - template_fields = ("audio", "config", "project_id", "gcp_conn_id", "timeout", - "impersonation_chain",) + template_fields = ( + "audio", + "config", + "project_id", + "gcp_conn_id", + "timeout", + "impersonation_chain", + ) # [END gcp_speech_to_text_synthesize_template_fields] @apply_defaults def __init__( - self, *, + self, + *, audio: RecognitionAudio, config: RecognitionConfig, project_id: Optional[str] = None, @@ -82,7 +90,7 @@ def __init__( retry: Optional[Retry] = None, timeout: Optional[float] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: self.audio = audio self.config = config @@ -102,8 +110,7 @@ def _validate_inputs(self): def execute(self, context): hook = CloudSpeechToTextHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) respones = hook.recognize_speech( config=self.config, audio=self.audio, retry=self.retry, timeout=self.timeout diff --git a/airflow/providers/google/cloud/operators/stackdriver.py b/airflow/providers/google/cloud/operators/stackdriver.py index 851cc04fc4efe..7e1c127d1fa29 100644 --- a/airflow/providers/google/cloud/operators/stackdriver.py +++ b/airflow/providers/google/cloud/operators/stackdriver.py @@ -85,13 +85,17 @@ class StackdriverListAlertPoliciesOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('filter_', 'impersonation_chain',) + template_fields = ( + 'filter_', + 'impersonation_chain', + ) ui_color = "#e5ffcc" # pylint: disable=too-many-arguments @apply_defaults def __init__( - self, *, + self, + *, format_: Optional[str] = None, filter_: Optional[str] = None, order_by: Optional[str] = None, @@ -103,7 +107,7 @@ def __init__( project_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.format_ = format_ @@ -120,8 +124,14 @@ def __init__( self.hook = None def execute(self, context): - self.log.info('List Alert Policies: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %d', - self.project_id, self.format_, self.filter_, self.order_by, self.page_size) + self.log.info( + 'List Alert Policies: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %d', + self.project_id, + self.format_, + self.filter_, + self.order_by, + self.page_size, + ) if self.hook is None: self.hook = StackdriverHook( gcp_conn_id=self.gcp_conn_id, @@ -137,7 +147,7 @@ def execute(self, context): page_size=self.page_size, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) @@ -182,12 +192,17 @@ class StackdriverEnableAlertPoliciesOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + ui_color = "#e5ffcc" - template_fields = ('filter_', 'impersonation_chain',) + template_fields = ( + 'filter_', + 'impersonation_chain', + ) @apply_defaults def __init__( - self, *, + self, + *, filter_: Optional[str] = None, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, @@ -196,7 +211,7 @@ def __init__( project_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.gcp_conn_id = gcp_conn_id @@ -222,7 +237,7 @@ def execute(self, context): project_id=self.project_id, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) @@ -270,11 +285,15 @@ class StackdriverDisableAlertPoliciesOperator(BaseOperator): """ ui_color = "#e5ffcc" - template_fields = ('filter_', 'impersonation_chain',) + template_fields = ( + 'filter_', + 'impersonation_chain', + ) @apply_defaults def __init__( - self, *, + self, + *, filter_: Optional[str] = None, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, @@ -283,7 +302,7 @@ def __init__( project_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.gcp_conn_id = gcp_conn_id @@ -309,7 +328,7 @@ def execute(self, context): project_id=self.project_id, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) @@ -356,14 +375,18 @@ class StackdriverUpsertAlertOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('alerts', 'impersonation_chain',) + template_fields = ( + 'alerts', + 'impersonation_chain', + ) template_ext = ('.json',) ui_color = "#e5ffcc" @apply_defaults def __init__( - self, *, + self, + *, alerts: str, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, @@ -372,7 +395,7 @@ def __init__( project_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.alerts = alerts @@ -398,7 +421,7 @@ def execute(self, context): project_id=self.project_id, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) @@ -442,13 +465,17 @@ class StackdriverDeleteAlertOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('name', 'impersonation_chain',) + template_fields = ( + 'name', + 'impersonation_chain', + ) ui_color = "#e5ffcc" @apply_defaults def __init__( - self, *, + self, + *, name: str, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, @@ -457,7 +484,7 @@ def __init__( project_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.name = name @@ -479,10 +506,7 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) self.hook.delete_alert_policy( - name=self.name, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, + name=self.name, retry=self.retry, timeout=self.timeout, metadata=self.metadata, ) @@ -546,14 +570,18 @@ class StackdriverListNotificationChannelsOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('filter_', 'impersonation_chain',) 
+ template_fields = ( + 'filter_', + 'impersonation_chain', + ) ui_color = "#e5ffcc" # pylint: disable=too-many-arguments @apply_defaults def __init__( - self, *, + self, + *, format_: Optional[str] = None, filter_: Optional[str] = None, order_by: Optional[str] = None, @@ -565,7 +593,7 @@ def __init__( project_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.format_ = format_ @@ -584,7 +612,11 @@ def __init__( def execute(self, context): self.log.info( 'List Notification Channels: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %d', - self.project_id, self.format_, self.filter_, self.order_by, self.page_size + self.project_id, + self.format_, + self.filter_, + self.order_by, + self.page_size, ) if self.hook is None: self.hook = StackdriverHook( @@ -600,7 +632,7 @@ def execute(self, context): page_size=self.page_size, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) @@ -646,13 +678,17 @@ class StackdriverEnableNotificationChannelsOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('filter_', 'impersonation_chain',) + template_fields = ( + 'filter_', + 'impersonation_chain', + ) ui_color = "#e5ffcc" @apply_defaults def __init__( - self, *, + self, + *, filter_: Optional[str] = None, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, @@ -661,7 +697,7 @@ def __init__( project_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.filter_ = filter_ @@ -675,8 +711,9 @@ def __init__( self.hook = None def execute(self, context): - self.log.info('Enable Notification Channels: Project id: %s Filter: %s', - self.project_id, self.filter_) + self.log.info( + 'Enable Notification Channels: Project id: %s Filter: %s', self.project_id, self.filter_ + ) if self.hook is None: self.hook = StackdriverHook( gcp_conn_id=self.gcp_conn_id, @@ -688,7 +725,7 @@ def execute(self, context): project_id=self.project_id, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) @@ -734,13 +771,17 @@ class StackdriverDisableNotificationChannelsOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('filter_', 'impersonation_chain',) + template_fields = ( + 'filter_', + 'impersonation_chain', + ) ui_color = "#e5ffcc" @apply_defaults def __init__( - self, *, + self, + *, filter_: Optional[str] = None, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, @@ -749,7 +790,7 @@ def __init__( project_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.filter_ = filter_ @@ -763,8 +804,9 @@ def __init__( self.hook = None def execute(self, context): - self.log.info('Disable Notification Channels: Project id: %s Filter: %s', - self.project_id, self.filter_) + self.log.info( + 'Disable Notification Channels: Project id: %s Filter: %s', self.project_id, self.filter_ + ) if self.hook is None: self.hook = StackdriverHook( gcp_conn_id=self.gcp_conn_id, @@ -776,7 +818,7 @@ def execute(self, context): project_id=self.project_id, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) @@ -823,14 
+865,18 @@ class StackdriverUpsertNotificationChannelOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('channels', 'impersonation_chain',) + template_fields = ( + 'channels', + 'impersonation_chain', + ) template_ext = ('.json',) ui_color = "#e5ffcc" @apply_defaults def __init__( - self, *, + self, + *, channels: str, retry: Optional[str] = DEFAULT, timeout: Optional[str] = DEFAULT, @@ -839,7 +885,7 @@ def __init__( project_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.channels = channels @@ -853,8 +899,9 @@ def __init__( self.hook = None def execute(self, context): - self.log.info('Upsert Notification Channels: Channels: %s Project id: %s', - self.channels, self.project_id) + self.log.info( + 'Upsert Notification Channels: Channels: %s Project id: %s', self.channels, self.project_id + ) if self.hook is None: self.hook = StackdriverHook( gcp_conn_id=self.gcp_conn_id, @@ -866,7 +913,7 @@ def execute(self, context): project_id=self.project_id, retry=self.retry, timeout=self.timeout, - metadata=self.metadata + metadata=self.metadata, ) @@ -910,13 +957,17 @@ class StackdriverDeleteNotificationChannelOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('name', 'impersonation_chain',) + template_fields = ( + 'name', + 'impersonation_chain', + ) ui_color = "#e5ffcc" @apply_defaults def __init__( - self, *, + self, + *, name: str, retry: Optional[str] = DEFAULT, timeout: Optional[float] = DEFAULT, @@ -925,7 +976,7 @@ def __init__( project_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.name = name @@ -947,8 +998,5 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) self.hook.delete_notification_channel( - name=self.name, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata + name=self.name, retry=self.retry, timeout=self.timeout, metadata=self.metadata ) diff --git a/airflow/providers/google/cloud/operators/tasks.py b/airflow/providers/google/cloud/operators/tasks.py index e4f5f42dd79bf..7074b3e173920 100644 --- a/airflow/providers/google/cloud/operators/tasks.py +++ b/airflow/providers/google/cloud/operators/tasks.py @@ -87,7 +87,8 @@ class CloudTasksQueueCreateOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, task_queue: Queue, project_id: Optional[str] = None, @@ -97,7 +98,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -111,10 +112,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: queue = hook.create_queue( location=self.location, @@ -195,7 +193,8 @@ class CloudTasksQueueUpdateOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, task_queue: Queue, project_id: Optional[str] = None, location: Optional[str] = None, @@ -206,7 +205,7 @@ def __init__( metadata: 
Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.task_queue = task_queue @@ -221,10 +220,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) queue = hook.update_queue( task_queue=self.task_queue, project_id=self.project_id, @@ -273,12 +269,18 @@ class CloudTasksQueueGetOperator(BaseOperator): :rtype: google.cloud.tasks_v2.types.Queue """ - template_fields = ("location", "queue_name", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, queue_name: str, project_id: Optional[str] = None, @@ -287,7 +289,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -300,10 +302,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) queue = hook.get_queue( location=self.location, queue_name=self.queue_name, @@ -353,12 +352,17 @@ class CloudTasksQueuesListOperator(BaseOperator): :rtype: list[google.cloud.tasks_v2.types.Queue] """ - template_fields = ("location", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, project_id: Optional[str] = None, results_filter: Optional[str] = None, @@ -368,7 +372,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -382,10 +386,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) queues = hook.list_queues( location=self.location, project_id=self.project_id, @@ -431,12 +432,18 @@ class CloudTasksQueueDeleteOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("location", "queue_name", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, queue_name: str, project_id: Optional[str] = None, @@ -445,7 +452,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: 
super().__init__(**kwargs) self.location = location @@ -458,10 +465,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) hook.delete_queue( location=self.location, queue_name=self.queue_name, @@ -507,12 +511,18 @@ class CloudTasksQueuePurgeOperator(BaseOperator): :rtype: list[google.cloud.tasks_v2.types.Queue] """ - template_fields = ("location", "queue_name", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, queue_name: str, project_id: Optional[str] = None, @@ -521,7 +531,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -534,10 +544,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) queue = hook.purge_queue( location=self.location, queue_name=self.queue_name, @@ -584,12 +591,18 @@ class CloudTasksQueuePauseOperator(BaseOperator): :rtype: list[google.cloud.tasks_v2.types.Queue] """ - template_fields = ("location", "queue_name", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, queue_name: str, project_id: Optional[str] = None, @@ -598,7 +611,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -611,10 +624,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) queues = hook.pause_queue( location=self.location, queue_name=self.queue_name, @@ -661,12 +671,18 @@ class CloudTasksQueueResumeOperator(BaseOperator): :rtype: list[google.cloud.tasks_v2.types.Queue] """ - template_fields = ("location", "queue_name", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, queue_name: str, project_id: Optional[str] = None, @@ -675,7 +691,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -688,10 +704,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - 
gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) queue = hook.resume_queue( location=self.location, queue_name=self.queue_name, @@ -759,7 +772,8 @@ class CloudTasksTaskCreateOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, location: str, queue_name: str, task: Union[Dict, Task], @@ -771,7 +785,7 @@ def __init__( # pylint: disable=too-many-arguments metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -787,10 +801,7 @@ def __init__( # pylint: disable=too-many-arguments self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) task = hook.create_task( location=self.location, queue_name=self.queue_name, @@ -856,7 +867,8 @@ class CloudTasksTaskGetOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, queue_name: str, task_name: str, @@ -867,7 +879,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -882,10 +894,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) task = hook.get_task( location=self.location, queue_name=self.queue_name, @@ -940,12 +949,18 @@ class CloudTasksTasksListOperator(BaseOperator): :rtype: list[google.cloud.tasks_v2.types.Task] """ - template_fields = ("location", "queue_name", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "location", + "queue_name", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, location: str, queue_name: str, project_id: Optional[str] = None, @@ -956,7 +971,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -971,10 +986,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) tasks = hook.list_tasks( location=self.location, queue_name=self.queue_name, @@ -1034,7 +1046,8 @@ class CloudTasksTaskDeleteOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, queue_name: str, task_name: str, @@ -1044,7 +1057,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: 
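For context on the Cloud Tasks operators being reformatted here, a minimal sketch of creating and purging a queue inside an existing DAG; the region and queue name are placeholders, and an empty Queue message simply keeps the API defaults.

from google.cloud.tasks_v2.types import Queue

from airflow.providers.google.cloud.operators.tasks import (
    CloudTasksQueueCreateOperator,
    CloudTasksQueuePurgeOperator,
)

create_queue = CloudTasksQueueCreateOperator(
    task_id="create_queue",
    location="europe-west1",   # placeholder region
    task_queue=Queue(),        # empty Queue message, API defaults apply
    queue_name="sample-queue",  # placeholder queue name
)
purge_queue = CloudTasksQueuePurgeOperator(
    task_id="purge_queue",
    location="europe-west1",
    queue_name="sample-queue",
)
create_queue >> purge_queue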
super().__init__(**kwargs) self.location = location @@ -1058,10 +1071,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) hook.delete_task( location=self.location, queue_name=self.queue_name, @@ -1124,7 +1134,8 @@ class CloudTasksTaskRunOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, queue_name: str, task_name: str, @@ -1135,7 +1146,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -1150,10 +1161,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTasksHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) task = hook.run_task( location=self.location, queue_name=self.queue_name, diff --git a/airflow/providers/google/cloud/operators/text_to_speech.py b/airflow/providers/google/cloud/operators/text_to_speech.py index 06f0aa3d79640..b6a6cd3bad42b 100644 --- a/airflow/providers/google/cloud/operators/text_to_speech.py +++ b/airflow/providers/google/cloud/operators/text_to_speech.py @@ -76,6 +76,7 @@ class CloudTextToSpeechSynthesizeOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_text_to_speech_synthesize_template_fields] template_fields = ( "input_data", @@ -91,7 +92,8 @@ class CloudTextToSpeechSynthesizeOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, input_data: Union[Dict, SynthesisInput], voice: Union[Dict, VoiceSelectionParams], audio_config: Union[Dict, AudioConfig], @@ -102,7 +104,7 @@ def __init__( retry: Optional[Retry] = None, timeout: Optional[float] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: self.input_data = input_data self.voice = voice @@ -130,8 +132,7 @@ def _validate_inputs(self): def execute(self, context): hook = CloudTextToSpeechHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) result = hook.synthesize_speech( input_data=self.input_data, @@ -143,8 +144,7 @@ def execute(self, context): with NamedTemporaryFile() as temp_file: temp_file.write(result.audio_content) cloud_storage_hook = GCSHook( - google_cloud_storage_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + google_cloud_storage_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) cloud_storage_hook.upload( bucket_name=self.target_bucket_name, object_name=self.target_filename, filename=temp_file.name diff --git a/airflow/providers/google/cloud/operators/translate.py b/airflow/providers/google/cloud/operators/translate.py index 43ed29f691ef7..32f75485b9651 100644 --- a/airflow/providers/google/cloud/operators/translate.py +++ b/airflow/providers/google/cloud/operators/translate.py @@ -83,14 +83,23 @@ class CloudTranslateTextOperator(BaseOperator): account from the 
list granting this role to the originating account (templated). """ + # [START translate_template_fields] - template_fields = ('values', 'target_language', 'format_', 'source_language', 'model', - 'gcp_conn_id', 'impersonation_chain',) + template_fields = ( + 'values', + 'target_language', + 'format_', + 'source_language', + 'model', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END translate_template_fields] @apply_defaults def __init__( - self, *, + self, + *, values: Union[List[str], str], target_language: str, format_: str, @@ -98,7 +107,7 @@ def __init__( model: str, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.values = values @@ -110,10 +119,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudTranslateHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudTranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: translation = hook.translate( values=self.values, diff --git a/airflow/providers/google/cloud/operators/translate_speech.py b/airflow/providers/google/cloud/operators/translate_speech.py index 1707e271ea365..36d3857992c94 100644 --- a/airflow/providers/google/cloud/operators/translate_speech.py +++ b/airflow/providers/google/cloud/operators/translate_speech.py @@ -107,14 +107,23 @@ class CloudTranslateSpeechOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ + # [START translate_speech_template_fields] - template_fields = ('target_language', 'format_', 'source_language', 'model', 'project_id', - 'gcp_conn_id', 'impersonation_chain',) + template_fields = ( + 'target_language', + 'format_', + 'source_language', + 'model', + 'project_id', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END translate_speech_template_fields] @apply_defaults def __init__( - self, *, + self, + *, audio: RecognitionAudio, config: RecognitionConfig, target_language: str, @@ -124,7 +133,7 @@ def __init__( project_id: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.audio = audio @@ -139,17 +148,13 @@ def __init__( def execute(self, context): speech_to_text_hook = CloudSpeechToTextHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) translate_hook = CloudTranslateHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - recognize_result = speech_to_text_hook.recognize_speech( - config=self.config, audio=self.audio - ) + recognize_result = speech_to_text_hook.recognize_speech(config=self.config, audio=self.audio) recognize_dict = MessageToDict(recognize_result) self.log.info("Recognition operation finished") @@ -162,8 +167,9 @@ def execute(self, context): try: transcript = recognize_dict['results'][0]['alternatives'][0]['transcript'] except KeyError as key: - raise AirflowException("Wrong response '{}' returned - it should contain {} field" - .format(recognize_dict, key)) + raise AirflowException( + "Wrong response '{}' returned - it should contain {} field".format(recognize_dict, key) + ) try: translation = translate_hook.translate( @@ 
-171,7 +177,7 @@ def execute(self, context): target_language=self.target_language, format_=self.format_, source_language=self.source_language, - model=self.model + model=self.model, ) self.log.info('Translated output: %s', translation) return translation diff --git a/airflow/providers/google/cloud/operators/video_intelligence.py b/airflow/providers/google/cloud/operators/video_intelligence.py index 17e0add92cf64..37f9475c00624 100644 --- a/airflow/providers/google/cloud/operators/video_intelligence.py +++ b/airflow/providers/google/cloud/operators/video_intelligence.py @@ -74,13 +74,20 @@ class CloudVideoIntelligenceDetectVideoLabelsOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_video_intelligence_detect_labels_template_fields] - template_fields = ("input_uri", "output_uri", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "input_uri", + "output_uri", + "gcp_conn_id", + "impersonation_chain", + ) # [END gcp_video_intelligence_detect_labels_template_fields] @apply_defaults def __init__( - self, *, + self, + *, input_uri: str, input_content: Optional[bytes] = None, output_uri: Optional[str] = None, @@ -90,7 +97,7 @@ def __init__( timeout: Optional[float] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.input_uri = input_uri @@ -105,8 +112,7 @@ def __init__( def execute(self, context): hook = CloudVideoIntelligenceHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) operation = hook.annotate_video( input_uri=self.input_uri, @@ -115,7 +121,7 @@ def execute(self, context): location=self.location, retry=self.retry, features=[enums.Feature.LABEL_DETECTION], - timeout=self.timeout + timeout=self.timeout, ) self.log.info("Processing video for label annotations") result = MessageToDict(operation.result()) @@ -167,13 +173,20 @@ class CloudVideoIntelligenceDetectVideoExplicitContentOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
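A brief illustrative sketch (not part of this patch) of the video label detection operator shown above; the GCS URI is a placeholder, and output_uri is left unset so only the annotation result returned by the operator is used.

from airflow.providers.google.cloud.operators.video_intelligence import (
    CloudVideoIntelligenceDetectVideoLabelsOperator,
)

detect_video_labels = CloudVideoIntelligenceDetectVideoLabelsOperator(
    task_id="detect_video_labels",
    input_uri="gs://my-bucket/my-video.mp4",  # placeholder GCS URI
    output_uri=None,  # optional GCS destination for the raw API output
)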
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_video_intelligence_detect_explicit_content_template_fields] - template_fields = ("input_uri", "output_uri", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "input_uri", + "output_uri", + "gcp_conn_id", + "impersonation_chain", + ) # [END gcp_video_intelligence_detect_explicit_content_template_fields] @apply_defaults def __init__( - self, *, + self, + *, input_uri: str, output_uri: Optional[str] = None, input_content: Optional[bytes] = None, @@ -183,7 +196,7 @@ def __init__( timeout: Optional[float] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.input_uri = input_uri @@ -198,8 +211,7 @@ def __init__( def execute(self, context): hook = CloudVideoIntelligenceHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) operation = hook.annotate_video( input_uri=self.input_uri, @@ -208,7 +220,7 @@ def execute(self, context): location=self.location, retry=self.retry, features=[enums.Feature.EXPLICIT_CONTENT_DETECTION], - timeout=self.timeout + timeout=self.timeout, ) self.log.info("Processing video for explicit content annotations") result = MessageToDict(operation.result()) @@ -260,13 +272,20 @@ class CloudVideoIntelligenceDetectVideoShotsOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START gcp_video_intelligence_detect_video_shots_template_fields] - template_fields = ("input_uri", "output_uri", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "input_uri", + "output_uri", + "gcp_conn_id", + "impersonation_chain", + ) # [END gcp_video_intelligence_detect_video_shots_template_fields] @apply_defaults def __init__( - self, *, + self, + *, input_uri: str, output_uri: Optional[str] = None, input_content: Optional[bytes] = None, @@ -276,7 +295,7 @@ def __init__( timeout: Optional[float] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.input_uri = input_uri @@ -291,8 +310,7 @@ def __init__( def execute(self, context): hook = CloudVideoIntelligenceHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) operation = hook.annotate_video( input_uri=self.input_uri, @@ -301,7 +319,7 @@ def execute(self, context): location=self.location, retry=self.retry, features=[enums.Feature.SHOT_CHANGE_DETECTION], - timeout=self.timeout + timeout=self.timeout, ) self.log.info("Processing video for video shots annotations") result = MessageToDict(operation.result()) diff --git a/airflow/providers/google/cloud/operators/vision.py b/airflow/providers/google/cloud/operators/vision.py index 6eaf6b5c643d2..3e60783195aa0 100644 --- a/airflow/providers/google/cloud/operators/vision.py +++ b/airflow/providers/google/cloud/operators/vision.py @@ -25,7 +25,12 @@ from google.api_core.exceptions import AlreadyExists from google.api_core.retry import Retry from google.cloud.vision_v1.types import ( - AnnotateImageRequest, FieldMask, Image, Product, ProductSet, ReferenceImage, + AnnotateImageRequest, + FieldMask, + Image, + Product, + ProductSet, 
+ ReferenceImage, ) from airflow.models import BaseOperator @@ -78,14 +83,21 @@ class CloudVisionCreateProductSetOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_productset_create_template_fields] - template_fields = ("location", "project_id", "product_set_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "location", + "project_id", + "product_set_id", + "gcp_conn_id", + "impersonation_chain", + ) # [END vision_productset_create_template_fields] @apply_defaults def __init__( - self, *, + self, + *, product_set: Union[dict, ProductSet], location: str, project_id: Optional[str] = None, @@ -95,7 +107,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -109,10 +121,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: return hook.create_product_set( location=self.location, @@ -168,14 +177,21 @@ class CloudVisionGetProductSetOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_productset_get_template_fields] - template_fields = ('location', 'project_id', 'product_set_id', 'gcp_conn_id', - 'impersonation_chain',) + template_fields = ( + 'location', + 'project_id', + 'product_set_id', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END vision_productset_get_template_fields] @apply_defaults def __init__( - self, *, + self, + *, location: str, product_set_id: str, project_id: Optional[str] = None, @@ -184,7 +200,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -197,10 +213,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) return hook.get_product_set( location=self.location, product_set_id=self.product_set_id, @@ -267,14 +280,21 @@ class CloudVisionUpdateProductSetOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_productset_update_template_fields] - template_fields = ('location', 'project_id', 'product_set_id', 'gcp_conn_id', - 'impersonation_chain',) + template_fields = ( + 'location', + 'project_id', + 'product_set_id', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END vision_productset_update_template_fields] @apply_defaults def __init__( - self, *, + self, + *, product_set: Union[Dict, ProductSet], location: Optional[str] = None, product_set_id: Optional[str] = None, @@ -285,7 +305,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.product_set = product_set @@ -300,10 +320,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) return hook.update_product_set( location=self.location, product_set_id=self.product_set_id, @@ -355,14 +372,21 @@ class CloudVisionDeleteProductSetOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_productset_delete_template_fields] - template_fields = ('location', 'project_id', 'product_set_id', 'gcp_conn_id', - 'impersonation_chain',) + template_fields = ( + 'location', + 'project_id', + 'product_set_id', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END vision_productset_delete_template_fields] @apply_defaults def __init__( - self, *, + self, + *, location: str, product_set_id: str, project_id: Optional[str] = None, @@ -371,7 +395,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -384,10 +408,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) hook.delete_product_set( location=self.location, product_set_id=self.product_set_id, @@ -447,14 +468,21 @@ class CloudVisionCreateProductOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_product_create_template_fields] - template_fields = ('location', 'project_id', 'product_id', 'gcp_conn_id', - 'impersonation_chain',) + template_fields = ( + 'location', + 'project_id', + 'product_id', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END vision_product_create_template_fields] @apply_defaults def __init__( - self, *, + self, + *, location: str, product: str, project_id: Optional[str] = None, @@ -464,7 +492,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -478,10 +506,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) try: return hook.create_product( location=self.location, @@ -540,14 +565,21 @@ class CloudVisionGetProductOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_product_get_template_fields] - template_fields = ('location', 'project_id', 'product_id', 'gcp_conn_id', - 'impersonation_chain',) + template_fields = ( + 'location', + 'project_id', + 'product_id', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END vision_product_get_template_fields] @apply_defaults def __init__( - self, *, + self, + *, location: str, product_id: str, project_id: Optional[str] = None, @@ -556,7 +588,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -569,10 +601,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) return hook.get_product( location=self.location, product_id=self.product_id, @@ -650,14 +679,21 @@ class CloudVisionUpdateProductOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_product_update_template_fields] - template_fields = ('location', 'project_id', 'product_id', 'gcp_conn_id', - 'impersonation_chain',) + template_fields = ( + 'location', + 'project_id', + 'product_id', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END vision_product_update_template_fields] @apply_defaults def __init__( - self, *, + self, + *, product: Union[Dict, Product], location: Optional[str] = None, product_id: Optional[str] = None, @@ -668,7 +704,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.product = product @@ -683,10 +719,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) return hook.update_product( product=self.product, location=self.location, @@ -743,14 +776,21 @@ class CloudVisionDeleteProductOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_product_delete_template_fields] - template_fields = ('location', 'project_id', 'product_id', 'gcp_conn_id', - 'impersonation_chain',) + template_fields = ( + 'location', + 'project_id', + 'product_id', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END vision_product_delete_template_fields] @apply_defaults def __init__( - self, *, + self, + *, location: str, product_id: str, project_id: Optional[str] = None, @@ -759,7 +799,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -772,10 +812,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) hook.delete_product( location=self.location, product_id=self.product_id, @@ -818,19 +855,25 @@ class CloudVisionImageAnnotateOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
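As a small, assumption-laden sketch of the Vision product operators reformatted here: location and product_id values are placeholders, and the tasks are assumed to run inside an existing DAG.

from airflow.providers.google.cloud.operators.vision import (
    CloudVisionDeleteProductOperator,
    CloudVisionGetProductOperator,
)

get_product = CloudVisionGetProductOperator(
    task_id="get_product",
    location="europe-west1",    # placeholder region
    product_id="my-product-1",  # placeholder product id
)
delete_product = CloudVisionDeleteProductOperator(
    task_id="delete_product",
    location="europe-west1",
    product_id="my-product-1",
)
get_product >> delete_product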
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_annotate_image_template_fields] - template_fields = ('request', 'gcp_conn_id', 'impersonation_chain',) + template_fields = ( + 'request', + 'gcp_conn_id', + 'impersonation_chain', + ) # [END vision_annotate_image_template_fields] @apply_defaults def __init__( - self, *, + self, + *, request: Union[Dict, AnnotateImageRequest], retry: Optional[Retry] = None, timeout: Optional[float] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.request = request @@ -840,18 +883,13 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) if not isinstance(self.request, list): response = hook.annotate_image(request=self.request, retry=self.retry, timeout=self.timeout) else: response = hook.batch_annotate_images( - requests=self.request, - retry=self.retry, - timeout=self.timeout + requests=self.request, retry=self.retry, timeout=self.timeout ) return response @@ -904,6 +942,7 @@ class CloudVisionCreateReferenceImageOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_reference_image_create_template_fields] template_fields = ( "location", @@ -918,7 +957,8 @@ class CloudVisionCreateReferenceImageOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, reference_image: Union[Dict, ReferenceImage], product_id: str, @@ -929,7 +969,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -946,8 +986,7 @@ def __init__( def execute(self, context): try: hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) return hook.create_reference_image( location=self.location, @@ -1009,6 +1048,7 @@ class CloudVisionDeleteReferenceImageOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_reference_image_create_template_fields] template_fields = ( "location", @@ -1022,7 +1062,8 @@ class CloudVisionDeleteReferenceImageOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, location: str, product_id: str, reference_image_id: str, @@ -1032,7 +1073,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.location = location @@ -1046,10 +1087,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) hook.delete_reference_image( location=self.location, product_id=self.product_id, @@ -1106,14 +1144,22 @@ class CloudVisionAddProductToProductSetOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_add_product_to_product_set_template_fields] - template_fields = ("location", "product_set_id", "product_id", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "location", + "product_set_id", + "product_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) # [END vision_add_product_to_product_set_template_fields] @apply_defaults def __init__( - self, *, + self, + *, product_set_id: str, product_id: str, location: str, @@ -1123,7 +1169,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.product_set_id = product_set_id @@ -1137,10 +1183,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) return hook.add_product_to_product_set( product_set_id=self.product_set_id, product_id=self.product_id, @@ -1191,14 +1234,22 @@ class CloudVisionRemoveProductFromProductSetOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_remove_product_from_product_set_template_fields] - template_fields = ("location", "product_set_id", "product_id", "project_id", "gcp_conn_id", - "impersonation_chain",) + template_fields = ( + "location", + "product_set_id", + "product_id", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) # [END vision_remove_product_from_product_set_template_fields] @apply_defaults def __init__( - self, *, + self, + *, product_set_id: str, product_id: str, location: str, @@ -1208,7 +1259,7 @@ def __init__( metadata: Optional[MetaData] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.product_set_id = product_set_id @@ -1222,10 +1273,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) return hook.remove_product_from_product_set( product_set_id=self.product_set_id, product_id=self.product_id, @@ -1276,8 +1324,15 @@ class CloudVisionDetectTextOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_detect_text_set_template_fields] - template_fields = ("image", "max_results", "timeout", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "image", + "max_results", + "timeout", + "gcp_conn_id", + "impersonation_chain", + ) # [END vision_detect_text_set_template_fields] def __init__( @@ -1291,7 +1346,7 @@ def __init__( additional_properties: Optional[Dict] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.image = image @@ -1308,10 +1363,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) return hook.text_detection( image=self.image, max_results=self.max_results, @@ -1360,9 +1412,15 @@ class CloudVisionTextDetectOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_document_detect_text_set_template_fields] - template_fields = ("image", "max_results", "timeout", "gcp_conn_id", - "impersonation_chain",) # Iterable[str] + template_fields = ( + "image", + "max_results", + "timeout", + "gcp_conn_id", + "impersonation_chain", + ) # Iterable[str] # [END vision_document_detect_text_set_template_fields] def __init__( @@ -1376,7 +1434,7 @@ def __init__( additional_properties: Optional[Dict] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.image = image @@ -1392,10 +1450,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) return hook.document_text_detection( image=self.image, max_results=self.max_results, @@ -1438,8 +1493,15 @@ class CloudVisionDetectImageLabelsOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_detect_labels_template_fields] - template_fields = ("image", "max_results", "timeout", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "image", + "max_results", + "timeout", + "gcp_conn_id", + "impersonation_chain", + ) # [END vision_detect_labels_template_fields] def __init__( @@ -1451,7 +1513,7 @@ def __init__( additional_properties: Optional[Dict] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.image = image @@ -1463,10 +1525,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) return hook.label_detection( image=self.image, max_results=self.max_results, @@ -1509,8 +1568,15 @@ class CloudVisionDetectImageSafeSearchOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ + # [START vision_detect_safe_search_template_fields] - template_fields = ("image", "max_results", "timeout", "gcp_conn_id", "impersonation_chain",) + template_fields = ( + "image", + "max_results", + "timeout", + "gcp_conn_id", + "impersonation_chain", + ) # [END vision_detect_safe_search_template_fields] def __init__( @@ -1522,7 +1588,7 @@ def __init__( additional_properties: Optional[Dict] = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.image = image @@ -1534,10 +1600,7 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context): - hook = CloudVisionHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) return hook.safe_search_detection( image=self.image, max_results=self.max_results, @@ -1548,9 +1611,7 @@ def execute(self, context): def prepare_additional_parameters( - additional_properties: Optional[Dict], - language_hints: Any, - web_detection_params: Any + additional_properties: Optional[Dict], language_hints: Any, web_detection_params: Any ) -> Optional[Dict]: """ Creates additional_properties parameter based on language_hints, web_detection_params and diff --git a/airflow/providers/google/cloud/secrets/secret_manager.py b/airflow/providers/google/cloud/secrets/secret_manager.py index 88f1421ed8f45..dcd961d30ea45 100644 --- a/airflow/providers/google/cloud/secrets/secret_manager.py +++ b/airflow/providers/google/cloud/secrets/secret_manager.py @@ -70,6 +70,7 @@ class CloudSecretManagerBackend(BaseSecretsBackend, LoggingMixin): :param sep: Separator used to concatenate connections_prefix and conn_id. Default: "-" :type sep: str """ + def __init__( self, connections_prefix: str = "airflow-connections", @@ -79,7 +80,7 @@ def __init__( gcp_scopes: Optional[str] = None, project_id: Optional[str] = None, sep: str = "-", - **kwargs + **kwargs, ): super().__init__(**kwargs) self.connections_prefix = connections_prefix @@ -91,9 +92,7 @@ def __init__( f"follows that pattern {SECRET_ID_PATTERN}" ) self.credentials, self.project_id = get_credentials_and_project_id( - keyfile_dict=gcp_keyfile_dict, - key_path=gcp_key_path, - scopes=gcp_scopes + keyfile_dict=gcp_keyfile_dict, key_path=gcp_key_path, scopes=gcp_scopes ) # In case project id provided if project_id: diff --git a/airflow/providers/google/cloud/sensors/bigquery.py b/airflow/providers/google/cloud/sensors/bigquery.py index e07b824282ccc..bd325eb63070e 100644 --- a/airflow/providers/google/cloud/sensors/bigquery.py +++ b/airflow/providers/google/cloud/sensors/bigquery.py @@ -55,18 +55,27 @@ class BigQueryTableExistenceSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('project_id', 'dataset_id', 'table_id', 'impersonation_chain',) + + template_fields = ( + 'project_id', + 'dataset_id', + 'table_id', + 'impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, *, - project_id: str, - dataset_id: str, - table_id: str, - bigquery_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + project_id: str, + dataset_id: str, + table_id: str, + bigquery_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -85,9 +94,8 @@ def poke(self, context): impersonation_chain=self.impersonation_chain, ) return hook.table_exists( - project_id=self.project_id, - dataset_id=self.dataset_id, - table_id=self.table_id) + project_id=self.project_id, dataset_id=self.dataset_id, table_id=self.table_id + ) class BigQueryTablePartitionExistenceSensor(BaseSensorOperator): @@ -122,20 +130,29 @@ class BigQueryTablePartitionExistenceSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('project_id', 'dataset_id', 'table_id', 'partition_id', - 'impersonation_chain',) + + template_fields = ( + 'project_id', + 'dataset_id', + 'table_id', + 'partition_id', + 'impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, *, - project_id: str, - dataset_id: str, - table_id: str, - partition_id: str, - bigquery_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + project_id: str, + dataset_id: str, + table_id: str, + partition_id: str, + bigquery_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.project_id = project_id @@ -158,5 +175,5 @@ def poke(self, context): project_id=self.project_id, dataset_id=self.dataset_id, table_id=self.table_id, - partition_id=self.partition_id + partition_id=self.partition_id, ) diff --git a/airflow/providers/google/cloud/sensors/bigquery_dts.py b/airflow/providers/google/cloud/sensors/bigquery_dts.py index e774221eb4bfd..84fedd6e46cad 100644 --- a/airflow/providers/google/cloud/sensors/bigquery_dts.py +++ b/airflow/providers/google/cloud/sensors/bigquery_dts.py @@ -90,7 +90,7 @@ def __init__( request_timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.run_id = run_id @@ -99,9 +99,7 @@ def __init__( self.request_timeout = request_timeout self.metadata = metadata self.expected_statuses = ( - {expected_statuses} - if isinstance(expected_statuses, str) - else expected_statuses + {expected_statuses} if isinstance(expected_statuses, str) else expected_statuses ) self.project_id = project_id self.gcp_cloud_conn_id = gcp_conn_id @@ -109,8 +107,7 @@ def __init__( def poke(self, context): hook = BiqQueryDataTransferServiceHook( - gcp_conn_id=self.gcp_cloud_conn_id, - 
impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_cloud_conn_id, impersonation_chain=self.impersonation_chain, ) run = hook.get_transfer_run( run_id=self.run_id, diff --git a/airflow/providers/google/cloud/sensors/bigtable.py b/airflow/providers/google/cloud/sensors/bigtable.py index a46d0c26c2d14..0d1671ab5a2ad 100644 --- a/airflow/providers/google/cloud/sensors/bigtable.py +++ b/airflow/providers/google/cloud/sensors/bigtable.py @@ -58,8 +58,14 @@ class BigtableTableReplicationCompletedSensor(BaseSensorOperator, BigtableValida account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ + REQUIRED_ATTRIBUTES = ('instance_id', 'table_id') - template_fields = ['project_id', 'instance_id', 'table_id', 'impersonation_chain', ] + template_fields = [ + 'project_id', + 'instance_id', + 'table_id', + 'impersonation_chain', + ] @apply_defaults def __init__( @@ -70,7 +76,7 @@ def __init__( project_id: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: self.project_id = project_id self.instance_id = instance_id @@ -81,10 +87,7 @@ def __init__( super().__init__(**kwargs) def poke(self, context): - hook = BigtableHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = BigtableHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) instance = hook.get_instance(project_id=self.project_id, instance_id=self.instance_id) if not instance: self.log.info("Dependency: instance '%s' does not exist.", self.instance_id) @@ -94,8 +97,8 @@ def poke(self, context): cluster_states = hook.get_cluster_states_for_table(instance=instance, table_id=self.table_id) except google.api_core.exceptions.NotFound: self.log.info( - "Dependency: table '%s' does not exist in instance '%s'.", - self.table_id, self.instance_id) + "Dependency: table '%s' does not exist in instance '%s'.", self.table_id, self.instance_id + ) return False ready_state = ClusterState(enums.Table.ClusterState.ReplicationState.READY) @@ -103,8 +106,7 @@ def poke(self, context): is_table_replicated = True for cluster_id in cluster_states.keys(): if cluster_states[cluster_id] != ready_state: - self.log.info("Table '%s' is not yet replicated on cluster '%s'.", - self.table_id, cluster_id) + self.log.info("Table '%s' is not yet replicated on cluster '%s'.", self.table_id, cluster_id) is_table_replicated = False if not is_table_replicated: diff --git a/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py index 28957ad8f9bbe..78410363fb75a 100644 --- a/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py @@ -59,18 +59,22 @@ class CloudDataTransferServiceJobStatusSensor(BaseSensorOperator): """ # [START gcp_transfer_job_sensor_template_fields] - template_fields = ('job_name', 'impersonation_chain',) + template_fields = ( + 'job_name', + 'impersonation_chain', + ) # [END gcp_transfer_job_sensor_template_fields] @apply_defaults def __init__( - self, *, + self, + *, job_name: str, expected_statuses: Union[Set[str], str], project_id: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: 
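To round out the sensor changes above, a minimal sketch of the BigQuery table existence sensor; the project, dataset, and table identifiers are placeholders, and poke_interval is the standard BaseSensorOperator setting.

from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor

wait_for_table = BigQueryTableExistenceSensor(
    task_id="wait_for_table",
    project_id="my-project",  # placeholder project
    dataset_id="my_dataset",  # placeholder dataset
    table_id="my_table",      # placeholder table
    poke_interval=60,         # re-check every 60 seconds
)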
super().__init__(**kwargs) self.job_name = job_name @@ -83,8 +87,7 @@ def __init__( def poke(self, context): hook = CloudDataTransferServiceHook( - gcp_conn_id=self.gcp_cloud_conn_id, - impersonation_chain=self.impersonation_chain, + gcp_conn_id=self.gcp_cloud_conn_id, impersonation_chain=self.impersonation_chain, ) operations = hook.list_transfer_operations( request_filter={'project_id': self.project_id, 'job_names': [self.job_name]} diff --git a/airflow/providers/google/cloud/sensors/gcs.py b/airflow/providers/google/cloud/sensors/gcs.py index c5599a7bd5aef..1df443c83716b 100644 --- a/airflow/providers/google/cloud/sensors/gcs.py +++ b/airflow/providers/google/cloud/sensors/gcs.py @@ -55,17 +55,25 @@ class GCSObjectExistenceSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('bucket', 'object', 'impersonation_chain',) + + template_fields = ( + 'bucket', + 'object', + 'impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, *, - bucket: str, - object: str, # pylint: disable=redefined-builtin - google_cloud_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + bucket: str, + object: str, # pylint: disable=redefined-builtin + google_cloud_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.bucket = bucket @@ -123,18 +131,25 @@ class GCSObjectUpdateSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('bucket', 'object', 'impersonation_chain',) + + template_fields = ( + 'bucket', + 'object', + 'impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, - bucket: str, - object: str, # pylint: disable=redefined-builtin - ts_func: Callable = ts_function, - google_cloud_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + bucket: str, + object: str, # pylint: disable=redefined-builtin + ts_func: Callable = ts_function, + google_cloud_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.bucket = bucket @@ -184,17 +199,24 @@ class GCSObjectsWtihPrefixExistenceSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('bucket', 'prefix', 'impersonation_chain',) + + template_fields = ( + 'bucket', + 'prefix', + 'impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, - bucket: str, - prefix: str, - google_cloud_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + bucket: str, + prefix: str, + google_cloud_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.bucket = bucket self.prefix = prefix @@ -204,8 +226,7 @@ def __init__(self, self.impersonation_chain = impersonation_chain def poke(self, context): - self.log.info('Sensor checks existence of objects: %s, %s', - self.bucket, self.prefix) + self.log.info('Sensor checks existence of objects: %s, %s', self.bucket, self.prefix) hook = GCSHook( google_cloud_storage_conn_id=self.google_cloud_conn_id, delegate_to=self.delegate_to, @@ -274,21 +295,27 @@ class GCSUploadSessionCompleteSensor(BaseSensorOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('bucket', 'prefix', 'impersonation_chain',) + template_fields = ( + 'bucket', + 'prefix', + 'impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, - bucket: str, - prefix: str, - inactivity_period: float = 60 * 60, - min_objects: int = 1, - previous_objects: Optional[Set[str]] = None, - allow_delete: bool = True, - google_cloud_conn_id: str = 'google_cloud_default', - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + bucket: str, + prefix: str, + inactivity_period: float = 60 * 60, + min_objects: int = 1, + previous_objects: Optional[Set[str]] = None, + allow_delete: bool = True, + google_cloud_conn_id: str = 'google_cloud_default', + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) @@ -328,10 +355,11 @@ def is_bucket_updated(self, current_objects: Set[str]) -> bool: if current_objects > self.previous_objects: # When new objects arrived, reset the inactivity_seconds # and update previous_objects for the next poke. - self.log.info("New objects found at %s resetting last_activity_time.", - os.path.join(self.bucket, self.prefix)) - self.log.debug("New objects: %s", - "\n".join(current_objects - self.previous_objects)) + self.log.info( + "New objects found at %s resetting last_activity_time.", + os.path.join(self.bucket, self.prefix), + ) + self.log.debug("New objects: %s", "\n".join(current_objects - self.previous_objects)) self.last_activity_time = get_time() self.inactivity_seconds = 0 self.previous_objects = current_objects @@ -348,14 +376,17 @@ def is_bucket_updated(self, current_objects: Set[str]) -> bool: poke interval. Updating the file counter and resetting last_activity_time. %s - """, self.previous_objects - current_objects + """, + self.previous_objects - current_objects, ) return False raise AirflowException( """ Illegal behavior: objects were deleted in {} between pokes. 
- """.format(os.path.join(self.bucket, self.prefix)) + """.format( + os.path.join(self.bucket, self.prefix) + ) ) if self.last_activity_time: @@ -369,10 +400,15 @@ def is_bucket_updated(self, current_objects: Set[str]) -> bool: path = os.path.join(self.bucket, self.prefix) if current_num_objects >= self.min_objects: - self.log.info("""SUCCESS: + self.log.info( + """SUCCESS: Sensor found %s objects at %s. Waited at least %s seconds, with no new objects dropped. - """, current_num_objects, path, self.inactivity_period) + """, + current_num_objects, + path, + self.inactivity_period, + ) return True self.log.error("FAILURE: Inactivity Period passed, not enough objects found in %s", path) diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py index a0baa01495cdd..9153c874a9dda 100644 --- a/airflow/providers/google/cloud/sensors/pubsub.py +++ b/airflow/providers/google/cloud/sensors/pubsub.py @@ -97,30 +97,38 @@ class PubSubPullSensor(BaseSensorOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ['project_id', 'subscription', 'impersonation_chain', ] + + template_fields = [ + 'project_id', + 'subscription', + 'impersonation_chain', + ] ui_color = '#ff7f50' @apply_defaults def __init__( - self, *, - project_id: str, - subscription: str, - max_messages: int = 5, - return_immediately: bool = True, - ack_messages: bool = False, - gcp_conn_id: str = 'google_cloud_default', - messages_callback: Optional[Callable[[List[ReceivedMessage], Dict[str, Any]], Any]] = None, - delegate_to: Optional[str] = None, - project: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + self, + *, + project_id: str, + subscription: str, + max_messages: int = 5, + return_immediately: bool = True, + ack_messages: bool = False, + gcp_conn_id: str = 'google_cloud_default', + messages_callback: Optional[Callable[[List[ReceivedMessage], Dict[str, Any]], Any]] = None, + delegate_to: Optional[str] = None, + project: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, ) -> None: # To preserve backward compatibility # TODO: remove one day if project: warnings.warn( - "The project parameter has been deprecated. You should pass " - "the project_id parameter.", DeprecationWarning, stacklevel=2) + "The project parameter has been deprecated. You should pass " "the project_id parameter.", + DeprecationWarning, + stacklevel=2, + ) project_id = project if not return_immediately: @@ -132,7 +140,7 @@ def __init__( " If is here only because of backwards compatibility.\n" " If may be removed in the future.\n", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) super().__init__(**kwargs) @@ -173,17 +181,15 @@ def poke(self, context): if pulled_messages and self.ack_messages: hook.acknowledge( - project_id=self.project_id, - subscription=self.subscription, - messages=pulled_messages, + project_id=self.project_id, subscription=self.subscription, messages=pulled_messages, ) return bool(pulled_messages) def _default_message_callback( - self, - pulled_messages: List[ReceivedMessage], - context: Dict[str, Any], # pylint: disable=unused-argument + self, + pulled_messages: List[ReceivedMessage], + context: Dict[str, Any], # pylint: disable=unused-argument ): """ This method can be overridden by subclasses or by `messages_callback` constructor argument. 
@@ -195,9 +201,6 @@ def _default_message_callback( :return: value to be saved to XCom. """ - messages_json = [ - MessageToDict(m) - for m in pulled_messages - ] + messages_json = [MessageToDict(m) for m in pulled_messages] return messages_json diff --git a/airflow/providers/google/cloud/transfers/adls_to_gcs.py b/airflow/providers/google/cloud/transfers/adls_to_gcs.py index 5f92f488aa37e..04f8b9fc01204 100644 --- a/airflow/providers/google/cloud/transfers/adls_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/adls_to_gcs.py @@ -103,32 +103,39 @@ class ADLSToGCSOperator(AzureDataLakeStorageListOperator): gcp_conn_id='google_cloud_default' ) """ - template_fields: Sequence[str] = ('src_adls', 'dest_gcs', 'google_impersonation_chain',) + + template_fields: Sequence[str] = ( + 'src_adls', + 'dest_gcs', + 'google_impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, *, - src_adls: str, - dest_gcs: str, - azure_data_lake_conn_id: str, - gcp_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - replace: bool = False, - gzip: bool = False, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: - - super().__init__( - path=src_adls, - azure_data_lake_conn_id=azure_data_lake_conn_id, - **kwargs - ) + def __init__( + self, + *, + src_adls: str, + dest_gcs: str, + azure_data_lake_conn_id: str, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + replace: bool = False, + gzip: bool = False, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: + + super().__init__(path=src_adls, azure_data_lake_conn_id=azure_data_lake_conn_id, **kwargs) if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.src_adls = src_adls @@ -157,9 +164,7 @@ def execute(self, context): files = set(files) - set(existing_files) if files: - hook = AzureDataLakeHook( - azure_data_lake_conn_id=self.azure_data_lake_conn_id - ) + hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) for obj in files: with NamedTemporaryFile(mode='wb', delete=True) as f: @@ -170,10 +175,7 @@ def execute(self, context): self.log.info("Saving file to %s", dest_path) g_hook.upload( - bucket_name=dest_gcs_bucket, - object_name=dest_path, - filename=f.name, - gzip=self.gzip + bucket_name=dest_gcs_bucket, object_name=dest_path, filename=f.name, gzip=self.gzip ) self.log.info("All done, uploaded %d files to GCS", len(files)) diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py b/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py index 7a92d35188aa0..004753648e966 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py @@ -78,31 +78,42 @@ class BigQueryToBigQueryOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('source_project_dataset_tables', - 'destination_project_dataset_table', 'labels', 'impersonation_chain',) + + template_fields = ( + 'source_project_dataset_tables', + 'destination_project_dataset_table', + 'labels', + 'impersonation_chain', + ) template_ext = ('.sql',) ui_color = '#e6f0e4' @apply_defaults - def __init__(self, *, # pylint: disable=too-many-arguments - source_project_dataset_tables: Union[List[str], str], - destination_project_dataset_table: str, - write_disposition: str = 'WRITE_EMPTY', - create_disposition: str = 'CREATE_IF_NEEDED', - gcp_conn_id: str = 'google_cloud_default', - bigquery_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - labels: Optional[Dict] = None, - encryption_configuration: Optional[Dict] = None, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, # pylint: disable=too-many-arguments + source_project_dataset_tables: Union[List[str], str], + destination_project_dataset_table: str, + write_disposition: str = 'WRITE_EMPTY', + create_disposition: str = 'CREATE_IF_NEEDED', + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + labels: Optional[Dict] = None, + encryption_configuration: Optional[Dict] = None, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) if bigquery_conn_id: warnings.warn( "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = bigquery_conn_id self.source_project_dataset_tables = source_project_dataset_tables @@ -119,12 +130,15 @@ def __init__(self, *, # pylint: disable=too-many-arguments def execute(self, context): self.log.info( 'Executing copy of %s into: %s', - self.source_project_dataset_tables, self.destination_project_dataset_table + self.source_project_dataset_tables, + self.destination_project_dataset_table, + ) + hook = BigQueryHook( + bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, ) - hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to, - location=self.location, - impersonation_chain=self.impersonation_chain) conn = hook.get_conn() cursor = conn.cursor() cursor.run_copy( @@ -133,4 +147,5 @@ def execute(self, context): write_disposition=self.write_disposition, create_disposition=self.create_disposition, labels=self.labels, - encryption_configuration=self.encryption_configuration) + encryption_configuration=self.encryption_configuration, + ) diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index f2eec028c5cda..65705f0175b52 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -76,32 +76,43 @@ class BigQueryToGCSOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('source_project_dataset_table', - 'destination_cloud_storage_uris', 'labels', 'impersonation_chain',) + + template_fields = ( + 'source_project_dataset_table', + 'destination_cloud_storage_uris', + 'labels', + 'impersonation_chain', + ) template_ext = () ui_color = '#e4e6f0' @apply_defaults - def __init__(self, *, # pylint: disable=too-many-arguments - source_project_dataset_table: str, - destination_cloud_storage_uris: List[str], - compression: str = 'NONE', - export_format: str = 'CSV', - field_delimiter: str = ',', - print_header: bool = True, - gcp_conn_id: str = 'google_cloud_default', - bigquery_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - labels: Optional[Dict] = None, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, # pylint: disable=too-many-arguments + source_project_dataset_table: str, + destination_cloud_storage_uris: List[str], + compression: str = 'NONE', + export_format: str = 'CSV', + field_delimiter: str = ',', + print_header: bool = True, + gcp_conn_id: str = 'google_cloud_default', + bigquery_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + labels: Optional[Dict] = None, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) if bigquery_conn_id: warnings.warn( "The bigquery_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = bigquery_conn_id self.source_project_dataset_table = source_project_dataset_table @@ -117,13 +128,17 @@ def __init__(self, *, # pylint: disable=too-many-arguments self.impersonation_chain = impersonation_chain def execute(self, context): - self.log.info('Executing extract of %s into: %s', - self.source_project_dataset_table, - self.destination_cloud_storage_uris) - hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to, - location=self.location, - impersonation_chain=self.impersonation_chain) + self.log.info( + 'Executing extract of %s into: %s', + self.source_project_dataset_table, + self.destination_cloud_storage_uris, + ) + hook = BigQueryHook( + bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) conn = hook.get_conn() cursor = conn.cursor() cursor.run_extract( @@ -133,4 +148,5 @@ def execute(self, context): export_format=self.export_format, field_delimiter=self.field_delimiter, print_header=self.print_header, - labels=self.labels) + labels=self.labels, + ) diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py b/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py index d4cc5b1e4a725..d400fd2efd494 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_mysql.py @@ -85,22 +85,31 @@ class BigQueryToMySqlOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('dataset_id', 'table_id', 'mysql_table', 'impersonation_chain',) + + template_fields = ( + 'dataset_id', + 'table_id', + 'mysql_table', + 'impersonation_chain', + ) @apply_defaults - def __init__(self, *, # pylint: disable=too-many-arguments - dataset_table: str, - mysql_table: str, - selected_fields: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - mysql_conn_id: str = 'mysql_default', - database: Optional[str] = None, - delegate_to: Optional[str] = None, - replace: bool = False, - batch_size: int = 1000, - location: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, # pylint: disable=too-many-arguments + dataset_table: str, + mysql_table: str, + selected_fields: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + mysql_conn_id: str = 'mysql_default', + database: Optional[str] = None, + delegate_to: Optional[str] = None, + replace: bool = False, + batch_size: int = 1000, + location: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.selected_fields = selected_fields self.gcp_conn_id = gcp_conn_id @@ -115,28 +124,30 @@ def __init__(self, *, # pylint: disable=too-many-arguments try: self.dataset_id, self.table_id = dataset_table.split('.') except ValueError: - raise ValueError('Could not parse {} as .
<table>' - .format(dataset_table)) + raise ValueError('Could not parse {} as <dataset>.
'.format(dataset_table)) def _bq_get_data(self): self.log.info('Fetching Data from:') - self.log.info('Dataset: %s ; Table: %s', - self.dataset_id, self.table_id) + self.log.info('Dataset: %s ; Table: %s', self.dataset_id, self.table_id) - hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id, - delegate_to=self.delegate_to, - location=self.location, - impersonation_chain=self.impersonation_chain) + hook = BigQueryHook( + bigquery_conn_id=self.gcp_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) conn = hook.get_conn() cursor = conn.cursor() i = 0 while True: - response = cursor.get_tabledata(dataset_id=self.dataset_id, - table_id=self.table_id, - max_results=self.batch_size, - selected_fields=self.selected_fields, - start_index=i * self.batch_size) + response = cursor.get_tabledata( + dataset_id=self.dataset_id, + table_id=self.table_id, + max_results=self.batch_size, + selected_fields=self.selected_fields, + start_index=i * self.batch_size, + ) if 'rows' in response: rows = response['rows'] diff --git a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py index cf7da69e93cc1..3aa4640842944 100644 --- a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py @@ -86,30 +86,43 @@ class CassandraToGCSOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('cql', 'bucket', 'filename', 'schema_filename', 'impersonation_chain',) + + template_fields = ( + 'cql', + 'bucket', + 'filename', + 'schema_filename', + 'impersonation_chain', + ) template_ext = ('.cql',) ui_color = '#a0e08c' @apply_defaults - def __init__(self, *, # pylint: disable=too-many-arguments - cql: str, - bucket: str, - filename: str, - schema_filename: Optional[str] = None, - approx_max_file_size_bytes: int = 1900000000, - gzip: bool = False, - cassandra_conn_id: str = 'cassandra_default', - gcp_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, # pylint: disable=too-many-arguments + cql: str, + bucket: str, + filename: str, + schema_filename: Optional[str] = None, + approx_max_file_size_bytes: int = 1900000000, + gzip: bool = False, + cassandra_conn_id: str = 'cassandra_default', + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.cql = cql @@ -229,7 +242,7 @@ def _upload_to_gcs(self, files_to_upload: Dict[str, Any]): object_name=obj, filename=tmp_file_handle.name, mime_type='application/json', - gzip=self.gzip + gzip=self.gzip, ) @classmethod @@ -241,8 +254,7 @@ def generate_data_dict(cls, names: Iterable[str], values: Any) -> Dict[str, Any] @classmethod def convert_value( # pylint: disable=too-many-return-statements - cls, - value: Optional[Any] + cls, value: Optional[Any] ) -> Optional[Any]: """ Convert value to BQ type. @@ -308,10 +320,7 @@ def convert_map_type(cls, value: OrderedMapSerializedKey) -> List[Dict[str, Any] """ converted_map = [] for k, v in zip(value.keys(), value.values()): - converted_map.append({ - 'key': cls.convert_value(k), - 'value': cls.convert_value(v) - }) + converted_map.append({'key': cls.convert_value(k), 'value': cls.convert_value(v)}) return converted_map @classmethod diff --git a/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py b/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py index 46e188a166af1..1bf6b8bc8b154 100644 --- a/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py @@ -77,11 +77,17 @@ class FacebookAdsReportToGcsOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("facebook_conn_id", "bucket_name", "object_name", "impersonation_chain",) + template_fields = ( + "facebook_conn_id", + "bucket_name", + "object_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, bucket_name: str, object_name: str, fields: List[str], @@ -105,10 +111,10 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context: Dict): - service = FacebookAdsReportingHook(facebook_conn_id=self.facebook_conn_id, - api_version=self.api_version) - rows = service.bulk_facebook_report(params=self.params, - fields=self.fields) + service = FacebookAdsReportingHook( + facebook_conn_id=self.facebook_conn_id, api_version=self.api_version + ) + rows = service.bulk_facebook_report(params=self.params, fields=self.fields) converted_rows = [dict(row) for row in rows] self.log.info("Facebook Returned %s data points", len(converted_rows)) @@ -120,10 +126,7 @@ def execute(self, context: Dict): writer.writeheader() writer.writerows(converted_rows) csvfile.flush() - hook = GCSHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) hook.upload( bucket_name=self.bucket_name, object_name=self.object_name, diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index a39b5cf936b91..b441f4eaa5523 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -160,45 +160,54 @@ class GCSToBigQueryOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('bucket', 'source_objects', - 'schema_object', 'destination_project_dataset_table', 'impersonation_chain',) + + template_fields = ( + 'bucket', + 'source_objects', + 'schema_object', + 'destination_project_dataset_table', + 'impersonation_chain', + ) template_ext = ('.sql',) ui_color = '#f0eee4' # pylint: disable=too-many-locals,too-many-arguments @apply_defaults - def __init__(self, *, - bucket, - source_objects, - destination_project_dataset_table, - schema_fields=None, - schema_object=None, - source_format='CSV', - compression='NONE', - create_disposition='CREATE_IF_NEEDED', - skip_leading_rows=0, - write_disposition='WRITE_EMPTY', - field_delimiter=',', - max_bad_records=0, - quote_character=None, - ignore_unknown_values=False, - allow_quoted_newlines=False, - allow_jagged_rows=False, - encoding="UTF-8", - max_id_key=None, - bigquery_conn_id='google_cloud_default', - google_cloud_storage_conn_id='google_cloud_default', - delegate_to=None, - schema_update_options=(), - src_fmt_configs=None, - external_table=False, - time_partitioning=None, - cluster_fields=None, - autodetect=True, - encryption_configuration=None, - location=None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs): + def __init__( + self, + *, + bucket, + source_objects, + destination_project_dataset_table, + schema_fields=None, + schema_object=None, + source_format='CSV', + compression='NONE', + create_disposition='CREATE_IF_NEEDED', + skip_leading_rows=0, + write_disposition='WRITE_EMPTY', + field_delimiter=',', + max_bad_records=0, + quote_character=None, + ignore_unknown_values=False, + allow_quoted_newlines=False, + allow_jagged_rows=False, + encoding="UTF-8", + max_id_key=None, + bigquery_conn_id='google_cloud_default', + google_cloud_storage_conn_id='google_cloud_default', + delegate_to=None, + schema_update_options=(), + src_fmt_configs=None, + external_table=False, + time_partitioning=None, + cluster_fields=None, + autodetect=True, + encryption_configuration=None, + location=None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): super().__init__(**kwargs) @@ -243,10 +252,12 @@ def __init__(self, *, self.impersonation_chain = impersonation_chain def execute(self, context): - bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, - delegate_to=self.delegate_to, - location=self.location, - impersonation_chain=self.impersonation_chain) + bq_hook = BigQueryHook( + bigquery_conn_id=self.bigquery_conn_id, + delegate_to=self.delegate_to, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) if not self.schema_fields: if self.schema_object and self.source_format != 'DATASTORE_BACKUP': @@ -255,20 +266,20 @@ def execute(self, context): delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) - schema_fields = json.loads(gcs_hook.download( - self.bucket, - self.schema_object).decode("utf-8")) + schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object).decode("utf-8")) elif self.schema_object is None and self.autodetect is False: - raise AirflowException('At least one of `schema_fields`, ' - '`schema_object`, or `autodetect` must be passed.') + raise AirflowException( + 'At least one of `schema_fields`, ' '`schema_object`, or `autodetect` must be passed.' 
+ ) else: schema_fields = None else: schema_fields = self.schema_fields - source_uris = ['gs://{}/{}'.format(self.bucket, source_object) - for source_object in self.source_objects] + source_uris = [ + 'gs://{}/{}'.format(self.bucket, source_object) for source_object in self.source_objects + ] conn = bq_hook.get_conn() cursor = conn.cursor() @@ -288,7 +299,7 @@ def execute(self, context): allow_jagged_rows=self.allow_jagged_rows, encoding=self.encoding, src_fmt_configs=self.src_fmt_configs, - encryption_configuration=self.encryption_configuration + encryption_configuration=self.encryption_configuration, ) else: cursor.run_load( @@ -311,7 +322,8 @@ def execute(self, context): src_fmt_configs=self.src_fmt_configs, time_partitioning=self.time_partitioning, cluster_fields=self.cluster_fields, - encryption_configuration=self.encryption_configuration) + encryption_configuration=self.encryption_configuration, + ) if cursor.use_legacy_sql: escaped_table_name = f'[{self.destination_project_dataset_table}]' @@ -319,12 +331,12 @@ def execute(self, context): escaped_table_name = f'`{self.destination_project_dataset_table}`' if self.max_id_key: - cursor.execute('SELECT MAX({}) FROM {}'.format( - self.max_id_key, - escaped_table_name)) + cursor.execute('SELECT MAX({}) FROM {}'.format(self.max_id_key, escaped_table_name)) row = cursor.fetchone() max_id = row[0] if row[0] else 0 self.log.info( 'Loaded BQ data with max %s.%s=%s', - self.destination_project_dataset_table, self.max_id_key, max_id + self.destination_project_dataset_table, + self.max_id_key, + max_id, ) diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py index 3c9d0add4c05f..300d34cd1c613 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py @@ -175,33 +175,47 @@ class GCSToGCSOperator(BaseOperator): ) """ - template_fields = ('source_bucket', 'source_object', 'source_objects', 'destination_bucket', - 'destination_object', 'delimiter', 'impersonation_chain',) + + template_fields = ( + 'source_bucket', + 'source_object', + 'source_objects', + 'destination_bucket', + 'destination_object', + 'delimiter', + 'impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, *, # pylint: disable=too-many-arguments - source_bucket, - source_object=None, - source_objects=None, - destination_bucket=None, - destination_object=None, - delimiter=None, - move_object=False, - replace=True, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - delegate_to=None, - last_modified_time=None, - maximum_modified_time=None, - is_older_than=None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs): + def __init__( + self, + *, # pylint: disable=too-many-arguments + source_bucket, + source_object=None, + source_objects=None, + destination_bucket=None, + destination_object=None, + delimiter=None, + move_object=False, + replace=True, + gcp_conn_id='google_cloud_default', + google_cloud_storage_conn_id=None, + delegate_to=None, + last_modified_time=None, + maximum_modified_time=None, + is_older_than=None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): super().__init__(**kwargs) if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. 
You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.source_bucket = source_bucket @@ -227,9 +241,11 @@ def execute(self, context): impersonation_chain=self.impersonation_chain, ) if self.source_objects and self.source_object: - error_msg = "You can either set source_object parameter or source_objects " \ - "parameter but not both. Found source_object={} and" \ - " source_objects={}".format(self.source_object, self.source_objects) + error_msg = ( + "You can either set source_object parameter or source_objects " + "parameter but not both. Found source_object={} and" + " source_objects={}".format(self.source_object, self.source_objects) + ) raise AirflowException(error_msg) if not self.source_object and not self.source_objects: @@ -245,8 +261,8 @@ def execute(self, context): if self.destination_bucket is None: self.log.warning( - 'destination_bucket is None. Defaulting it to source_bucket (%s)', - self.source_bucket) + 'destination_bucket is None. Defaulting it to source_bucket (%s)', self.source_bucket + ) self.destination_bucket = self.source_bucket # An empty source_object means to copy all files @@ -272,21 +288,25 @@ def _copy_source_without_wildcard(self, hook, prefix): # and copy directly if len(objects) == 0 and prefix: if hook.exists(self.source_bucket, prefix): - self._copy_single_object(hook=hook, source_object=prefix, - destination_object=self.destination_object) + self._copy_single_object( + hook=hook, source_object=prefix, destination_object=self.destination_object + ) for source_obj in objects: if self.destination_object is None: destination_object = source_obj else: destination_object = self.destination_object - self._copy_single_object(hook=hook, source_object=source_obj, - destination_object=destination_object) + self._copy_single_object( + hook=hook, source_object=source_obj, destination_object=destination_object + ) def _copy_source_with_wildcard(self, hook, prefix): total_wildcards = prefix.count(WILDCARD) if total_wildcards > 1: - error_msg = "Only one wildcard '*' is allowed in source_object parameter. " \ - "Found {} in {}.".format(total_wildcards, prefix) + error_msg = ( + "Only one wildcard '*' is allowed in source_object parameter. " + "Found {} in {}.".format(total_wildcards, prefix) + ) raise AirflowException(error_msg) self.log.info('Delimiter ignored because wildcard is in prefix') @@ -301,30 +321,24 @@ def _copy_source_with_wildcard(self, hook, prefix): objects = set(objects) - set(existing_objects) if len(objects) > 0: - self.log.info( - '%s files are going to be synced: %s.', len(objects), objects - ) + self.log.info('%s files are going to be synced: %s.', len(objects), objects) else: - self.log.info( - 'There are no new files to sync. Have a nice day!') + self.log.info('There are no new files to sync. 
Have a nice day!') for source_object in objects: if self.destination_object is None: destination_object = source_object else: - destination_object = source_object.replace(prefix_, - self.destination_object, 1) + destination_object = source_object.replace(prefix_, self.destination_object, 1) - self._copy_single_object(hook=hook, source_object=source_object, - destination_object=destination_object) + self._copy_single_object( + hook=hook, source_object=source_object, destination_object=destination_object + ) def _copy_single_object(self, hook, source_object, destination_object): if self.is_older_than: # Here we check if the given object is older than the given time # If given, last_modified_time and maximum_modified_time is ignored - if hook.is_older_than(self.source_bucket, - source_object, - self.is_older_than - ): + if hook.is_older_than(self.source_bucket, source_object, self.is_older_than): self.log.info("Object is older than %s seconds ago", self.is_older_than) else: self.log.debug("Object is not older than %s seconds ago", self.is_older_than) @@ -332,43 +346,46 @@ def _copy_single_object(self, hook, source_object, destination_object): elif self.last_modified_time and self.maximum_modified_time: # check to see if object was modified between last_modified_time and # maximum_modified_time - if hook.is_updated_between(self.source_bucket, - source_object, - self.last_modified_time, - self.maximum_modified_time - ): - self.log.info("Object has been modified between %s and %s", - self.last_modified_time, self.maximum_modified_time) + if hook.is_updated_between( + self.source_bucket, source_object, self.last_modified_time, self.maximum_modified_time + ): + self.log.info( + "Object has been modified between %s and %s", + self.last_modified_time, + self.maximum_modified_time, + ) else: - self.log.debug("Object was not modified between %s and %s", - self.last_modified_time, self.maximum_modified_time) + self.log.debug( + "Object was not modified between %s and %s", + self.last_modified_time, + self.maximum_modified_time, + ) return elif self.last_modified_time is not None: # Check to see if object was modified after last_modified_time - if hook.is_updated_after(self.source_bucket, - source_object, - self.last_modified_time): + if hook.is_updated_after(self.source_bucket, source_object, self.last_modified_time): self.log.info("Object has been modified after %s ", self.last_modified_time) else: self.log.debug("Object was not modified after %s ", self.last_modified_time) return elif self.maximum_modified_time is not None: # Check to see if object was modified before maximum_modified_time - if hook.is_updated_before(self.source_bucket, - source_object, - self.maximum_modified_time): + if hook.is_updated_before(self.source_bucket, source_object, self.maximum_modified_time): self.log.info("Object has been modified before %s ", self.maximum_modified_time) else: self.log.debug("Object was not modified before %s ", self.maximum_modified_time) return - self.log.info('Executing copy of gs://%s/%s to gs://%s/%s', - self.source_bucket, source_object, - self.destination_bucket, destination_object) + self.log.info( + 'Executing copy of gs://%s/%s to gs://%s/%s', + self.source_bucket, + source_object, + self.destination_bucket, + destination_object, + ) - hook.rewrite(self.source_bucket, source_object, - self.destination_bucket, destination_object) + hook.rewrite(self.source_bucket, source_object, self.destination_bucket, destination_object) if self.move_object: hook.delete(self.source_bucket, source_object) 
diff --git a/airflow/providers/google/cloud/transfers/gcs_to_local.py b/airflow/providers/google/cloud/transfers/gcs_to_local.py index efbcc89fe46b1..471140794e271 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_local.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_local.py @@ -72,20 +72,30 @@ class GCSToLocalFilesystemOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('bucket', 'object', 'filename', 'store_to_xcom_key', 'impersonation_chain',) + + template_fields = ( + 'bucket', + 'object', + 'filename', + 'store_to_xcom_key', + 'impersonation_chain', + ) ui_color = '#f0eee4' @apply_defaults - def __init__(self, *, - bucket: str, - object_name: Optional[str] = None, - filename: Optional[str] = None, - store_to_xcom_key: Optional[str] = None, - gcp_conn_id: str = 'google_cloud_default', - google_cloud_storage_conn_id: Optional[str] = None, - delegate_to: Optional[str] = None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs) -> None: + def __init__( + self, + *, + bucket: str, + object_name: Optional[str] = None, + filename: Optional[str] = None, + store_to_xcom_key: Optional[str] = None, + gcp_conn_id: str = 'google_cloud_default', + google_cloud_storage_conn_id: Optional[str] = None, + delegate_to: Optional[str] = None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ) -> None: # To preserve backward compatibility # TODO: Remove one day if object_name is None: @@ -101,21 +111,23 @@ def __init__(self, *, if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id super().__init__(**kwargs) self.bucket = bucket self.object = object_name self.filename = filename # noqa - self.store_to_xcom_key = store_to_xcom_key # noqa + self.store_to_xcom_key = store_to_xcom_key # noqa self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain def execute(self, context): - self.log.info('Executing download: %s, %s, %s', self.bucket, - self.object, self.filename) + self.log.info('Executing download: %s, %s, %s', self.bucket, self.object, self.filename) hook = GCSHook( google_cloud_storage_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, @@ -123,15 +135,10 @@ def execute(self, context): ) if self.store_to_xcom_key: - file_bytes = hook.download(bucket_name=self.bucket, - object_name=self.object) + file_bytes = hook.download(bucket_name=self.bucket, object_name=self.object) if sys.getsizeof(file_bytes) < MAX_XCOM_SIZE: context['ti'].xcom_push(key=self.store_to_xcom_key, value=file_bytes) else: - raise AirflowException( - 'The size of the downloaded file is too large to push to XCom!' 
- ) + raise AirflowException('The size of the downloaded file is too large to push to XCom!') else: - hook.download(bucket_name=self.bucket, - object_name=self.object, - filename=self.filename) + hook.download(bucket_name=self.bucket, object_name=self.object, filename=self.filename) diff --git a/airflow/providers/google/cloud/transfers/gcs_to_sftp.py b/airflow/providers/google/cloud/transfers/gcs_to_sftp.py index e766ad7211596..ef1a4206732fe 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_sftp.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_sftp.py @@ -76,13 +76,19 @@ class GCSToSFTPOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("source_bucket", "source_object", "destination_path", "impersonation_chain",) + template_fields = ( + "source_bucket", + "source_object", + "destination_path", + "impersonation_chain", + ) ui_color = "#f0eee4" # pylint: disable=too-many-arguments @apply_defaults def __init__( - self, *, + self, + *, source_bucket: str, source_object: str, destination_path: str, @@ -91,7 +97,7 @@ def __init__( sftp_conn_id: str = "ssh_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -123,43 +129,26 @@ def execute(self, context): ) prefix, delimiter = self.source_object.split(WILDCARD, 1) - objects = gcs_hook.list( - self.source_bucket, prefix=prefix, delimiter=delimiter - ) + objects = gcs_hook.list(self.source_bucket, prefix=prefix, delimiter=delimiter) for source_object in objects: destination_path = os.path.join(self.destination_path, source_object) - self._copy_single_object( - gcs_hook, sftp_hook, source_object, destination_path - ) + self._copy_single_object(gcs_hook, sftp_hook, source_object, destination_path) - self.log.info( - "Done. Uploaded '%d' files to %s", len(objects), self.destination_path - ) + self.log.info("Done. Uploaded '%d' files to %s", len(objects), self.destination_path) else: destination_path = os.path.join(self.destination_path, self.source_object) - self._copy_single_object( - gcs_hook, sftp_hook, self.source_object, destination_path - ) - self.log.info( - "Done. Uploaded '%s' file to %s", self.source_object, destination_path - ) + self._copy_single_object(gcs_hook, sftp_hook, self.source_object, destination_path) + self.log.info("Done. Uploaded '%s' file to %s", self.source_object, destination_path) def _copy_single_object( - self, - gcs_hook: GCSHook, - sftp_hook: SFTPHook, - source_object: str, - destination_path: str, + self, gcs_hook: GCSHook, sftp_hook: SFTPHook, source_object: str, destination_path: str, ) -> None: """ Helper function to copy single object. 
""" self.log.info( - "Executing copy of gs://%s/%s to %s", - self.source_bucket, - source_object, - destination_path, + "Executing copy of gs://%s/%s to %s", self.source_bucket, source_object, destination_path, ) dir_path = os.path.dirname(destination_path) @@ -167,14 +156,10 @@ def _copy_single_object( with NamedTemporaryFile("w") as tmp: gcs_hook.download( - bucket_name=self.source_bucket, - object_name=source_object, - filename=tmp.name, + bucket_name=self.source_bucket, object_name=source_object, filename=tmp.name, ) sftp_hook.store_file(destination_path, tmp.name) if self.move_object: - self.log.info( - "Executing delete of gs://%s/%s", self.source_bucket, source_object - ) + self.log.info("Executing delete of gs://%s/%s", self.source_bucket, source_object) gcs_hook.delete(self.source_bucket, source_object) diff --git a/airflow/providers/google/cloud/transfers/local_to_gcs.py b/airflow/providers/google/cloud/transfers/local_to_gcs.py index c0f35d478d9aa..4c5008411a594 100644 --- a/airflow/providers/google/cloud/transfers/local_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/local_to_gcs.py @@ -67,26 +67,38 @@ class LocalFilesystemToGCSOperator(BaseOperator): account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('src', 'dst', 'bucket', 'impersonation_chain',) + + template_fields = ( + 'src', + 'dst', + 'bucket', + 'impersonation_chain', + ) @apply_defaults - def __init__(self, *, - src, - dst, - bucket, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - mime_type='application/octet-stream', - delegate_to=None, - gzip=False, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs): + def __init__( + self, + *, + src, + dst, + bucket, + gcp_conn_id='google_cloud_default', + google_cloud_storage_conn_id=None, + mime_type='application/octet-stream', + delegate_to=None, + gzip=False, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): super().__init__(**kwargs) if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.src = src @@ -111,13 +123,14 @@ def execute(self, context): filepaths = self.src if isinstance(self.src, list) else glob(self.src) if os.path.basename(self.dst): # path to a file if len(filepaths) > 1: # multiple file upload - raise ValueError("'dst' parameter references filepath. Please specifiy " - "directory (with trailing backslash) to upload multiple " - "files. e.g. /path/to/directory/") + raise ValueError( + "'dst' parameter references filepath. Please specifiy " + "directory (with trailing backslash) to upload multiple " + "files. e.g. 
/path/to/directory/" + ) object_paths = [self.dst] else: # directory is provided - object_paths = [os.path.join(self.dst, os.path.basename(filepath)) - for filepath in filepaths] + object_paths = [os.path.join(self.dst, os.path.basename(filepath)) for filepath in filepaths] for filepath, object_path in zip(filepaths, object_paths): hook.upload( diff --git a/airflow/providers/google/cloud/transfers/mssql_to_gcs.py b/airflow/providers/google/cloud/transfers/mssql_to_gcs.py index face7667e05ac..1614c9da5ab58 100644 --- a/airflow/providers/google/cloud/transfers/mssql_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/mssql_to_gcs.py @@ -49,18 +49,13 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator): dag=dag ) """ + ui_color = '#e0a98c' - type_map = { - 3: 'INTEGER', - 4: 'TIMESTAMP', - 5: 'NUMERIC' - } + type_map = {3: 'INTEGER', 4: 'TIMESTAMP', 5: 'NUMERIC'} @apply_defaults - def __init__(self, *, - mssql_conn_id='mssql_default', - **kwargs): + def __init__(self, *, mssql_conn_id='mssql_default', **kwargs): super().__init__(**kwargs) self.mssql_conn_id = mssql_conn_id diff --git a/airflow/providers/google/cloud/transfers/mysql_to_gcs.py b/airflow/providers/google/cloud/transfers/mysql_to_gcs.py index 5d98a2d175e09..607149e141510 100644 --- a/airflow/providers/google/cloud/transfers/mysql_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/mysql_to_gcs.py @@ -41,6 +41,7 @@ class MySQLToGCSOperator(BaseSQLToGCSOperator): default timezone. :type ensure_utc: bool """ + ui_color = '#a0e08c' type_map = { @@ -62,10 +63,7 @@ class MySQLToGCSOperator(BaseSQLToGCSOperator): } @apply_defaults - def __init__(self, *, - mysql_conn_id='mysql_default', - ensure_utc=False, - **kwargs): + def __init__(self, *, mysql_conn_id='mysql_default', ensure_utc=False, **kwargs): super().__init__(**kwargs) self.mysql_conn_id = mysql_conn_id self.ensure_utc = ensure_utc diff --git a/airflow/providers/google/cloud/transfers/postgres_to_gcs.py b/airflow/providers/google/cloud/transfers/postgres_to_gcs.py index 816ad277552fb..102a26a7dca0b 100644 --- a/airflow/providers/google/cloud/transfers/postgres_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/postgres_to_gcs.py @@ -38,6 +38,7 @@ class PostgresToGCSOperator(BaseSQLToGCSOperator): :param postgres_conn_id: Reference to a specific Postgres hook. 
:type postgres_conn_id: str """ + ui_color = '#a0e08c' type_map = { @@ -58,9 +59,7 @@ class PostgresToGCSOperator(BaseSQLToGCSOperator): } @apply_defaults - def __init__(self, *, - postgres_conn_id='postgres_default', - **kwargs): + def __init__(self, *, postgres_conn_id='postgres_default', **kwargs): super().__init__(**kwargs) self.postgres_conn_id = postgres_conn_id @@ -78,8 +77,7 @@ def field_to_bigquery(self, field): return { 'name': field[0], 'type': self.type_map.get(field[1], "STRING"), - 'mode': 'REPEATED' if field[1] in (1009, 1005, 1007, - 1016) else 'NULLABLE' + 'mode': 'REPEATED' if field[1] in (1009, 1005, 1007, 1016) else 'NULLABLE', } def convert_type(self, value, schema_type): @@ -92,10 +90,11 @@ def convert_type(self, value, schema_type): return pendulum.parse(value.isoformat()).float_timestamp if isinstance(value, datetime.time): formated_time = time.strptime(str(value), "%H:%M:%S") - return int(datetime.timedelta( - hours=formated_time.tm_hour, - minutes=formated_time.tm_min, - seconds=formated_time.tm_sec).total_seconds()) + return int( + datetime.timedelta( + hours=formated_time.tm_hour, minutes=formated_time.tm_min, seconds=formated_time.tm_sec + ).total_seconds() + ) if isinstance(value, dict): return json.dumps(value) if isinstance(value, Decimal): diff --git a/airflow/providers/google/cloud/transfers/presto_to_gcs.py b/airflow/providers/google/cloud/transfers/presto_to_gcs.py index f94024fdbebdd..f0ab92231cf93 100644 --- a/airflow/providers/google/cloud/transfers/presto_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/presto_to_gcs.py @@ -179,11 +179,7 @@ class PrestoToGCSOperator(BaseSQLToGCSOperator): } @apply_defaults - def __init__( - self, *, - presto_conn_id: str = "presto_default", - **kwargs - ): + def __init__(self, *, presto_conn_id: str = "presto_default", **kwargs): super().__init__(**kwargs) self.presto_conn_id = presto_conn_id diff --git a/airflow/providers/google/cloud/transfers/s3_to_gcs.py b/airflow/providers/google/cloud/transfers/s3_to_gcs.py index fadf99949142e..0ca3889e2d235 100644 --- a/airflow/providers/google/cloud/transfers/s3_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/s3_to_gcs.py @@ -97,38 +97,44 @@ class S3ToGCSOperator(S3ListOperator): templated, so you can use variables in them if you wish. 
""" - template_fields: Iterable[str] = ('bucket', 'prefix', 'delimiter', 'dest_gcs', - 'google_impersonation_chain',) + template_fields: Iterable[str] = ( + 'bucket', + 'prefix', + 'delimiter', + 'dest_gcs', + 'google_impersonation_chain', + ) ui_color = '#e09411' # pylint: disable=too-many-arguments @apply_defaults - def __init__(self, *, - bucket, - prefix='', - delimiter='', - aws_conn_id='aws_default', - verify=None, - gcp_conn_id='google_cloud_default', - dest_gcs_conn_id=None, - dest_gcs=None, - delegate_to=None, - replace=False, - gzip=False, - google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs): - - super().__init__( - bucket=bucket, - prefix=prefix, - delimiter=delimiter, - aws_conn_id=aws_conn_id, - **kwargs) + def __init__( + self, + *, + bucket, + prefix='', + delimiter='', + aws_conn_id='aws_default', + verify=None, + gcp_conn_id='google_cloud_default', + dest_gcs_conn_id=None, + dest_gcs=None, + delegate_to=None, + replace=False, + gzip=False, + google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): + + super().__init__(bucket=bucket, prefix=prefix, delimiter=delimiter, aws_conn_id=aws_conn_id, **kwargs) if dest_gcs_conn_id: warnings.warn( "The dest_gcs_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = dest_gcs_conn_id self.gcp_conn_id = gcp_conn_id @@ -143,9 +149,11 @@ def __init__(self, *, self.log.info( 'Destination Google Cloud Storage path is not a valid ' '"directory", define a path that ends with a slash "/" or ' - 'leave it empty for the root of the bucket.') - raise AirflowException('The destination Google Cloud Storage path ' - 'must end with a slash "/" or be empty.') + 'leave it empty for the root of the bucket.' + ) + raise AirflowException( + 'The destination Google Cloud Storage path ' 'must end with a slash "/" or be empty.' + ) def execute(self, context): # use the super method to list all the files in an S3 bucket/key @@ -163,8 +171,7 @@ def execute(self, context): # and only keep those files which are present in # S3 and not in Google Cloud Storage bucket_name, object_prefix = _parse_gcs_url(self.dest_gcs) - existing_files_prefixed = gcs_hook.list( - bucket_name, prefix=object_prefix) + existing_files_prefixed = gcs_hook.list(bucket_name, prefix=object_prefix) existing_files = [] @@ -176,18 +183,15 @@ def execute(self, context): # Remove the object prefix from all object string paths for f in existing_files_prefixed: if f.startswith(object_prefix): - existing_files.append(f[len(object_prefix):]) + existing_files.append(f[len(object_prefix) :]) else: existing_files.append(f) files = list(set(files) - set(existing_files)) if len(files) > 0: - self.log.info( - '%s files are going to be synced: %s.', len(files), files - ) + self.log.info('%s files are going to be synced: %s.', len(files), files) else: - self.log.info( - 'There are no new files to sync. Have a nice day!') + self.log.info('There are no new files to sync. 
Have a nice day!') if files: hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) @@ -200,8 +204,7 @@ def execute(self, context): file_object.download_fileobj(f) f.flush() - dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url( - self.dest_gcs) + dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(self.dest_gcs) # There will always be a '/' before file because it is # enforced at instantiation time dest_gcs_object = dest_gcs_object_prefix + file @@ -216,13 +219,9 @@ def execute(self, context): gcs_hook.upload(dest_gcs_bucket, dest_gcs_object, f.name, gzip=self.gzip) - self.log.info( - "All done, uploaded %d files to Google Cloud Storage", - len(files)) + self.log.info("All done, uploaded %d files to Google Cloud Storage", len(files)) else: - self.log.info( - 'In sync, no files needed to be uploaded to Google Cloud' - 'Storage') + self.log.info('In sync, no files needed to be uploaded to Google Cloud' 'Storage') return files diff --git a/airflow/providers/google/cloud/transfers/sftp_to_gcs.py b/airflow/providers/google/cloud/transfers/sftp_to_gcs.py index 930cdfa401140..ca42c737346f4 100644 --- a/airflow/providers/google/cloud/transfers/sftp_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/sftp_to_gcs.py @@ -81,12 +81,17 @@ class SFTPToGCSOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("source_path", "destination_path", "destination_bucket", - "impersonation_chain",) + template_fields = ( + "source_path", + "destination_path", + "destination_bucket", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, source_path: str, destination_bucket: str, destination_path: Optional[str] = None, @@ -97,7 +102,7 @@ def __init__( gzip: bool = False, move_object: bool = False, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) @@ -132,9 +137,7 @@ def execute(self, context): prefix, delimiter = self.source_path.split(WILDCARD, 1) base_path = os.path.dirname(prefix) - files, _, _ = sftp_hook.get_tree_map( - base_path, prefix=prefix, delimiter=delimiter - ) + files, _, _ = sftp_hook.get_tree_map(base_path, prefix=prefix, delimiter=delimiter) for file in files: destination_path = file.replace(base_path, self.destination_path, 1) @@ -142,29 +145,18 @@ def execute(self, context): else: destination_object = ( - self.destination_path - if self.destination_path - else self.source_path.rsplit("/", 1)[1] - ) - self._copy_single_object( - gcs_hook, sftp_hook, self.source_path, destination_object + self.destination_path if self.destination_path else self.source_path.rsplit("/", 1)[1] ) + self._copy_single_object(gcs_hook, sftp_hook, self.source_path, destination_object) def _copy_single_object( - self, - gcs_hook: GCSHook, - sftp_hook: SFTPHook, - source_path: str, - destination_object: str, + self, gcs_hook: GCSHook, sftp_hook: SFTPHook, source_path: str, destination_object: str, ) -> None: """ Helper function to copy single object. 
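
The wildcard branch of SFTPToGCSOperator.execute above splits source_path once on the wildcard, walks the base directory through the SFTP hook, and rebases every match onto the GCS destination prefix. A rough, self-contained sketch of that path arithmetic; the in-memory file list stands in for the SFTP hook, and the single-character "*" value of WILDCARD is an assumption here:

import os

WILDCARD = "*"  # assumption: stand-in for the module-level constant used by the operator

def plan_transfers(source_path, destination_path, remote_files):
    # Split once on the wildcard, keep files that match the prefix/suffix pair,
    # then rebase each match from the SFTP base directory onto the destination.
    prefix, delimiter = source_path.split(WILDCARD, 1)
    base_path = os.path.dirname(prefix)
    matches = [f for f in remote_files if f.startswith(prefix) and f.endswith(delimiter)]
    return [(f, f.replace(base_path, destination_path, 1)) for f in matches]

print(plan_transfers("/data/exports/report-*.csv", "backups",
                     ["/data/exports/report-1.csv", "/data/exports/notes.txt"]))
# [('/data/exports/report-1.csv', 'backups/report-1.csv')]
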
""" self.log.info( - "Executing copy of %s to gs://%s/%s", - source_path, - self.destination_bucket, - destination_object, + "Executing copy of %s to gs://%s/%s", source_path, self.destination_bucket, destination_object, ) with NamedTemporaryFile("w") as tmp: diff --git a/airflow/providers/google/cloud/transfers/sheets_to_gcs.py b/airflow/providers/google/cloud/transfers/sheets_to_gcs.py index a39e73008028d..f56f93e974e20 100644 --- a/airflow/providers/google/cloud/transfers/sheets_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/sheets_to_gcs.py @@ -61,12 +61,18 @@ class GoogleSheetsToGCSOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ["spreadsheet_id", "destination_bucket", "destination_path", "sheet_filter", - "impersonation_chain", ] + template_fields = [ + "spreadsheet_id", + "destination_bucket", + "destination_path", + "sheet_filter", + "impersonation_chain", + ] @apply_defaults def __init__( - self, *, + self, + *, spreadsheet_id: str, destination_bucket: str, sheet_filter: Optional[List[str]] = None, @@ -86,21 +92,13 @@ def __init__( self.impersonation_chain = impersonation_chain def _upload_data( - self, - gcs_hook: GCSHook, - hook: GSheetsHook, - sheet_range: str, - sheet_values: List[Any], + self, gcs_hook: GCSHook, hook: GSheetsHook, sheet_range: str, sheet_values: List[Any], ) -> str: # Construct destination file path sheet = hook.get_spreadsheet(self.spreadsheet_id) - file_name = f"{sheet['properties']['title']}_{sheet_range}.csv".replace( - " ", "_" - ) + file_name = f"{sheet['properties']['title']}_{sheet_range}.csv".replace(" ", "_") dest_file_name = ( - f"{self.destination_path.strip('/')}/{file_name}" - if self.destination_path - else file_name + f"{self.destination_path.strip('/')}/{file_name}" if self.destination_path else file_name ) with NamedTemporaryFile("w+") as temp_file: @@ -111,9 +109,7 @@ def _upload_data( # Upload to GCS gcs_hook.upload( - bucket_name=self.destination_bucket, - object_name=dest_file_name, - filename=temp_file.name, + bucket_name=self.destination_bucket, object_name=dest_file_name, filename=temp_file.name, ) return dest_file_name @@ -135,12 +131,8 @@ def execute(self, context): spreadsheet_id=self.spreadsheet_id, sheet_filter=self.sheet_filter ) for sheet_range in sheet_titles: - data = sheet_hook.get_values( - spreadsheet_id=self.spreadsheet_id, range_=sheet_range - ) - gcs_path_to_file = self._upload_data( - gcs_hook, sheet_hook, sheet_range, data - ) + data = sheet_hook.get_values(spreadsheet_id=self.spreadsheet_id, range_=sheet_range) + gcs_path_to_file = self._upload_data(gcs_hook, sheet_hook, sheet_range, data) destination_array.append(gcs_path_to_file) self.xcom_push(context, "destination_objects", destination_array) diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py index 578602ee39481..70d172f4f294d 100644 --- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py @@ -85,34 +85,48 @@ class BaseSQLToGCSOperator(BaseOperator): account from the list granting this role to the originating account (templated). 
:type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ('sql', 'bucket', 'filename', 'schema_filename', 'schema', 'parameters', - 'impersonation_chain',) + + template_fields = ( + 'sql', + 'bucket', + 'filename', + 'schema_filename', + 'schema', + 'parameters', + 'impersonation_chain', + ) template_ext = ('.sql',) ui_color = '#a0e08c' @apply_defaults - def __init__(self, *, # pylint: disable=too-many-arguments - sql, - bucket, - filename, - schema_filename=None, - approx_max_file_size_bytes=1900000000, - export_format='json', - field_delimiter=',', - gzip=False, - schema=None, - parameters=None, - gcp_conn_id='google_cloud_default', - google_cloud_storage_conn_id=None, - delegate_to=None, - impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs): + def __init__( + self, + *, # pylint: disable=too-many-arguments + sql, + bucket, + filename, + schema_filename=None, + approx_max_file_size_bytes=1900000000, + export_format='json', + field_delimiter=',', + gzip=False, + schema=None, + parameters=None, + gcp_conn_id='google_cloud_default', + google_cloud_storage_conn_id=None, + delegate_to=None, + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + **kwargs, + ): super().__init__(**kwargs) if google_cloud_storage_conn_id: warnings.warn( "The google_cloud_storage_conn_id parameter has been deprecated. You should pass " - "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3) + "the gcp_conn_id parameter.", + DeprecationWarning, + stacklevel=3, + ) gcp_conn_id = google_cloud_storage_conn_id self.sql = sql @@ -154,10 +168,7 @@ def execute(self, context): def convert_types(self, schema, col_type_dict, row): """Convert values from DBAPI to output-friendly formats.""" - return [ - self.convert_type(value, col_type_dict.get(name)) - for name, value in zip(schema, row) - ] + return [self.convert_type(value, col_type_dict.get(name)) for name, value in zip(schema, row)] def _write_local_data_files(self, cursor): """ @@ -175,11 +186,13 @@ def _write_local_data_files(self, cursor): file_mime_type = 'text/csv' else: file_mime_type = 'application/json' - files_to_upload = [{ - 'file_name': self.filename.format(file_no), - 'file_handle': tmp_file_handle, - 'file_mime_type': file_mime_type - }] + files_to_upload = [ + { + 'file_name': self.filename.format(file_no), + 'file_handle': tmp_file_handle, + 'file_mime_type': file_mime_type, + } + ] self.log.info("Current file count: %d", len(files_to_upload)) if self.export_format == 'csv': @@ -206,11 +219,13 @@ def _write_local_data_files(self, cursor): if tmp_file_handle.tell() >= self.approx_max_file_size_bytes: file_no += 1 tmp_file_handle = NamedTemporaryFile(delete=True) - files_to_upload.append({ - 'file_name': self.filename.format(file_no), - 'file_handle': tmp_file_handle, - 'file_mime_type': file_mime_type - }) + files_to_upload.append( + { + 'file_name': self.filename.format(file_no), + 'file_handle': tmp_file_handle, + 'file_mime_type': file_mime_type, + } + ) self.log.info("Current file count: %d", len(files_to_upload)) if self.export_format == 'csv': csv_writer = self._configure_csv_file(tmp_file_handle, schema) @@ -221,8 +236,7 @@ def _configure_csv_file(self, file_handle, schema): """Configure a csv writer with the file_handle and write schema as headers for the new file. 
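
BaseSQLToGCSOperator._write_local_data_files above streams rows into a NamedTemporaryFile and starts a new chunk once the current handle passes approx_max_file_size_bytes, recording every chunk for the later upload pass. A condensed sketch of just that rotation mechanic; the row payloads and the 1 KiB threshold are illustrative, not the operator's defaults:

from tempfile import NamedTemporaryFile

def write_rotated(rows, filename_template="split_{}.json", max_bytes=1024):
    # The size check runs after each row, so a chunk may slightly overshoot
    # max_bytes, exactly like the "approximate" threshold in the operator.
    file_no = 0
    handle = NamedTemporaryFile(delete=True)
    chunks = [{"file_name": filename_template.format(file_no), "file_handle": handle}]
    for row in rows:
        handle.write(row)
        if handle.tell() >= max_bytes:
            file_no += 1
            handle = NamedTemporaryFile(delete=True)
            chunks.append({"file_name": filename_template.format(file_no), "file_handle": handle})
    return chunks

chunks = write_rotated([b"x" * 300 for _ in range(10)])
print([c["file_name"] for c in chunks])  # ['split_0.json', 'split_1.json', 'split_2.json']
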
""" - csv_writer = csv.writer(file_handle, encoding='utf-8', - delimiter=self.field_delimiter) + csv_writer = csv.writer(file_handle, encoding='utf-8', delimiter=self.field_delimiter) csv_writer.writerow(schema) return csv_writer @@ -248,16 +262,17 @@ def _get_col_type_dict(self): elif isinstance(self.schema, list): schema = self.schema elif self.schema is not None: - self.log.warning('Using default schema due to unexpected type.' - 'Should be a string or list.') + self.log.warning('Using default schema due to unexpected type.' 'Should be a string or list.') col_type_dict = {} try: col_type_dict = {col['name']: col['type'] for col in schema} except KeyError: - self.log.warning('Using default schema due to missing name or type. Please ' - 'refer to: https://cloud.google.com/bigquery/docs/schemas' - '#specifying_a_json_schema_file') + self.log.warning( + 'Using default schema due to missing name or type. Please ' + 'refer to: https://cloud.google.com/bigquery/docs/schemas' + '#specifying_a_json_schema_file' + ) return col_type_dict def _write_local_schema_file(self, cursor): @@ -303,7 +318,10 @@ def _upload_to_gcs(self, files_to_upload): impersonation_chain=self.impersonation_chain, ) for tmp_file in files_to_upload: - hook.upload(self.bucket, tmp_file.get('file_name'), - tmp_file.get('file_handle').name, - mime_type=tmp_file.get('file_mime_type'), - gzip=self.gzip if tmp_file.get('file_name') != self.schema_filename else False) + hook.upload( + self.bucket, + tmp_file.get('file_name'), + tmp_file.get('file_handle').name, + mime_type=tmp_file.get('file_mime_type'), + gzip=self.gzip if tmp_file.get('file_name') != self.schema_filename else False, + ) diff --git a/airflow/providers/google/cloud/utils/credentials_provider.py b/airflow/providers/google/cloud/utils/credentials_provider.py index bdeef6be83104..fcbdb498ed7eb 100644 --- a/airflow/providers/google/cloud/utils/credentials_provider.py +++ b/airflow/providers/google/cloud/utils/credentials_provider.py @@ -76,9 +76,7 @@ def build_gcp_conn( @contextmanager -def provide_gcp_credentials( - key_file_path: Optional[str] = None, key_file_dict: Optional[Dict] = None -): +def provide_gcp_credentials(key_file_path: Optional[str] = None, key_file_dict: Optional[Dict] = None): """ Context manager that provides a GCP credentials for application supporting `Application Default Credentials (ADC) strategy `__. @@ -95,9 +93,7 @@ def provide_gcp_credentials( raise ValueError("Please provide `key_file_path` or `key_file_dict`.") if key_file_path and key_file_path.endswith(".p12"): - raise AirflowException( - "Legacy P12 key file are not supported, use a JSON key file." - ) + raise AirflowException("Legacy P12 key file are not supported, use a JSON key file.") with tempfile.NamedTemporaryFile(mode="w+t") as conf_file: if not key_file_path and key_file_dict: @@ -114,9 +110,7 @@ def provide_gcp_credentials( @contextmanager def provide_gcp_connection( - key_file_path: Optional[str] = None, - scopes: Optional[Sequence] = None, - project_id: Optional[str] = None, + key_file_path: Optional[str] = None, scopes: Optional[Sequence] = None, project_id: Optional[str] = None, ): """ Context manager that provides a temporary value of :envvar:`AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT` @@ -131,13 +125,9 @@ def provide_gcp_connection( :type project_id: str """ if key_file_path and key_file_path.endswith(".p12"): - raise AirflowException( - "Legacy P12 key file are not supported, use a JSON key file." 
- ) + raise AirflowException("Legacy P12 key file are not supported, use a JSON key file.") - conn = build_gcp_conn( - scopes=scopes, key_file_path=key_file_path, project_id=project_id - ) + conn = build_gcp_conn(scopes=scopes, key_file_path=key_file_path, project_id=project_id) with patch_environ({AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT: conn}): yield @@ -145,9 +135,7 @@ def provide_gcp_connection( @contextmanager def provide_gcp_conn_and_credentials( - key_file_path: Optional[str] = None, - scopes: Optional[Sequence] = None, - project_id: Optional[str] = None, + key_file_path: Optional[str] = None, scopes: Optional[Sequence] = None, project_id: Optional[str] = None, ): """ Context manager that provides both: @@ -209,6 +197,7 @@ class to configure Logger. granting the role to the last account from the list. :type delegates: Sequence[str] """ + def __init__( self, key_path: Optional[str] = None, @@ -263,7 +252,7 @@ def get_credentials_and_project(self): source_credentials=credentials, target_principal=self.target_principal, delegates=self.delegates, - target_scopes=self.scopes + target_scopes=self.scopes, ) project_id = _get_project_id_from_service_account_email(self.target_principal) @@ -275,26 +264,22 @@ def _get_credentials_using_keyfile_dict(self): # Depending on how the JSON was formatted, it may contain # escaped newlines. Convert those to actual newlines. self.keyfile_dict['private_key'] = self.keyfile_dict['private_key'].replace('\\n', '\n') - credentials = ( - google.oauth2.service_account.Credentials.from_service_account_info( - self.keyfile_dict, scopes=self.scopes) + credentials = google.oauth2.service_account.Credentials.from_service_account_info( + self.keyfile_dict, scopes=self.scopes ) project_id = credentials.project_id return credentials, project_id def _get_credentials_using_key_path(self): if self.key_path.endswith('.p12'): - raise AirflowException( - 'Legacy P12 key file are not supported, use a JSON key file.' - ) + raise AirflowException('Legacy P12 key file are not supported, use a JSON key file.') if not self.key_path.endswith('.json'): raise AirflowException('Unrecognised extension for key file.') self._log_debug('Getting connection using JSON key file %s', self.key_path) - credentials = ( - google.oauth2.service_account.Credentials.from_service_account_file( - self.key_path, scopes=self.scopes) + credentials = google.oauth2.service_account.Credentials.from_service_account_file( + self.key_path, scopes=self.scopes ) project_id = credentials.project_id return credentials, project_id @@ -315,9 +300,7 @@ def _log_debug(self, *args, **kwargs): self.log.debug(*args, **kwargs) -def get_credentials_and_project_id( - *args, **kwargs -) -> Tuple[google.auth.credentials.Credentials, str]: +def get_credentials_and_project_id(*args, **kwargs) -> Tuple[google.auth.credentials.Credentials, str]: """ Returns the Credentials object for Google API and the associated project_id. 
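
Two small invariants from credentials_provider.py above are easy to exercise on their own: keyfile dictionaries may arrive with JSON-escaped newlines in the private key, and key paths must be JSON rather than legacy P12. A dependency-free sketch of both checks; ValueError stands in for AirflowException so the snippet runs without Airflow installed:

def normalize_keyfile_dict(keyfile_dict):
    # Escaped "\n" sequences in the private key become real newlines before the
    # dict is handed to google.oauth2.service_account.
    keyfile_dict = dict(keyfile_dict)
    keyfile_dict["private_key"] = keyfile_dict["private_key"].replace("\\n", "\n")
    return keyfile_dict

def check_key_path(key_path):
    # Same guards as _get_credentials_using_key_path: reject legacy P12 keys and
    # accept only JSON service-account files.
    if key_path.endswith(".p12"):
        raise ValueError("Legacy P12 key files are not supported, use a JSON key file.")
    if not key_path.endswith(".json"):
        raise ValueError("Unrecognised extension for key file.")

print(normalize_keyfile_dict({"private_key": "-----BEGIN\\nKEY\\n-----END"})["private_key"])
check_key_path("service-account.json")  # passes silently
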
""" @@ -334,8 +317,7 @@ def _get_scopes(scopes: Optional[str] = None) -> Sequence[str]: :return: Returns the scope defined in the connection configuration, or the default scope :rtype: Sequence[str] """ - return [s.strip() for s in scopes.split(',')] \ - if scopes else _DEFAULT_SCOPES + return [s.strip() for s in scopes.split(',')] if scopes else _DEFAULT_SCOPES def _get_target_principal_and_delegates( @@ -375,5 +357,6 @@ def _get_project_id_from_service_account_email(service_account_email: str) -> st try: return service_account_email.split('@')[1].split('.')[0] except IndexError: - raise AirflowException(f"Could not extract project_id from service account's email: " - f"{service_account_email}.") + raise AirflowException( + f"Could not extract project_id from service account's email: " f"{service_account_email}." + ) diff --git a/airflow/providers/google/cloud/utils/field_sanitizer.py b/airflow/providers/google/cloud/utils/field_sanitizer.py index f9ebc976b960c..0dd8180e00be1 100644 --- a/airflow/providers/google/cloud/utils/field_sanitizer.py +++ b/airflow/providers/google/cloud/utils/field_sanitizer.py @@ -116,6 +116,7 @@ class GcpBodyFieldSanitizer(LoggingMixin): :type sanitize_specs: list[str] """ + def __init__(self, sanitize_specs: List[str]) -> None: super().__init__() self._sanitize_specs = sanitize_specs @@ -140,21 +141,23 @@ def _sanitize(self, dictionary, remaining_field_spec, current_path): "The field %s is missing in %s at the path %s. ", field_name, dictionary, current_path ) elif isinstance(child, dict): - self._sanitize(child, remaining_path, "{}.{}".format( - current_path, field_name)) + self._sanitize(child, remaining_path, "{}.{}".format(current_path, field_name)) elif isinstance(child, list): for index, elem in enumerate(child): if not isinstance(elem, dict): self.log.warning( "The field %s element at index %s is of wrong type. " "It should be dict and is %s. Skipping it.", - current_path, index, elem) - self._sanitize(elem, remaining_path, "{}.{}[{}]".format( - current_path, field_name, index)) + current_path, + index, + elem, + ) + self._sanitize(elem, remaining_path, "{}.{}[{}]".format(current_path, field_name, index)) else: self.log.warning( "The field %s is of wrong type. It should be dict or list and it is %s. 
Skipping it.", - current_path, child + current_path, + child, ) def sanitize(self, body): diff --git a/airflow/providers/google/cloud/utils/field_validator.py b/airflow/providers/google/cloud/utils/field_validator.py index 0eb4497fdee9e..9d2d0159e461b 100644 --- a/airflow/providers/google/cloud/utils/field_validator.py +++ b/airflow/providers/google/cloud/utils/field_validator.py @@ -160,17 +160,18 @@ def _int_greater_than_zero(value): EXAMPLE_VALIDATION_SPECIFICATION = [ dict(name="name", allow_empty=False), dict(name="description", allow_empty=False, optional=True), - dict(name="availableMemoryMb", custom_validation=_int_greater_than_zero, - optional=True), + dict(name="availableMemoryMb", custom_validation=_int_greater_than_zero, optional=True), dict(name="labels", optional=True, type="dict"), - dict(name="an_union", type="union", fields=[ - dict(name="variant_1", regexp=r'^.+$'), - dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'), - dict(name="variant_3", type="dict", fields=[ - dict(name="url", regexp=r'^.+$') - ]), - dict(name="variant_4") - ]), + dict( + name="an_union", + type="union", + fields=[ + dict(name="variant_1", regexp=r'^.+$'), + dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'), + dict(name="variant_3", type="dict", fields=[dict(name="url", regexp=r'^.+$')]), + dict(name="variant_4"), + ], + ), ] @@ -188,6 +189,7 @@ class GcpBodyFieldValidator(LoggingMixin): :type api_version: str """ + def __init__(self, validation_specs: Sequence[Dict], api_version: str) -> None: super().__init__() self._validation_specs = validation_specs @@ -200,32 +202,43 @@ def _get_field_name_with_parent(field_name, parent): return field_name @staticmethod - def _sanity_checks(children_validation_specs: Dict, field_type: str, full_field_path: str, - regexp: str, allow_empty: bool, custom_validation: Callable, value) -> None: + def _sanity_checks( + children_validation_specs: Dict, + field_type: str, + full_field_path: str, + regexp: str, + allow_empty: bool, + custom_validation: Callable, + value, + ) -> None: if value is None and field_type != 'union': raise GcpFieldValidationException( - "The required body field '{}' is missing. Please add it.". - format(full_field_path)) + "The required body field '{}' is missing. Please add it.".format(full_field_path) + ) if regexp and field_type: raise GcpValidationSpecificationException( "The validation specification entry '{}' has both type and regexp. " "The regexp is only allowed without type (i.e. assume type is 'str' " - "that can be validated with regexp)".format(full_field_path)) + "that can be validated with regexp)".format(full_field_path) + ) if allow_empty is not None and field_type: raise GcpValidationSpecificationException( "The validation specification entry '{}' has both type and allow_empty. " "The allow_empty is only allowed without type (i.e. assume type is 'str' " - "that can be validated with allow_empty)".format(full_field_path)) + "that can be validated with allow_empty)".format(full_field_path) + ) if children_validation_specs and field_type not in COMPOSITE_FIELD_TYPES: raise GcpValidationSpecificationException( "Nested fields are specified in field '{}' of type '{}'. " - "Nested fields are only allowed for fields of those types: ('{}').". 
- format(full_field_path, field_type, COMPOSITE_FIELD_TYPES)) + "Nested fields are only allowed for fields of those types: ('{}').".format( + full_field_path, field_type, COMPOSITE_FIELD_TYPES + ) + ) if custom_validation and field_type: raise GcpValidationSpecificationException( "The validation specification field '{}' has both type and " - "custom_validation. Custom validation is only allowed without type.". - format(full_field_path)) + "custom_validation. Custom validation is only allowed without type.".format(full_field_path) + ) @staticmethod def _validate_regexp(full_field_path: str, regexp: str, value: str) -> None: @@ -233,21 +246,21 @@ def _validate_regexp(full_field_path: str, regexp: str, value: str) -> None: # Note matching of only the beginning as we assume the regexps all-or-nothing raise GcpFieldValidationException( "The body field '{}' of value '{}' does not match the field " - "specification regexp: '{}'.". - format(full_field_path, value, regexp)) + "specification regexp: '{}'.".format(full_field_path, value, regexp) + ) @staticmethod def _validate_is_empty(full_field_path: str, value: str) -> None: if not value: raise GcpFieldValidationException( - "The body field '{}' can't be empty. Please provide a value." - .format(full_field_path)) + "The body field '{}' can't be empty. Please provide a value.".format(full_field_path) + ) def _validate_dict(self, children_validation_specs: Dict, full_field_path: str, value: Dict) -> None: for child_validation_spec in children_validation_specs: - self._validate_field(validation_spec=child_validation_spec, - dictionary_to_validate=value, - parent=full_field_path) + self._validate_field( + validation_spec=child_validation_spec, dictionary_to_validate=value, parent=full_field_path + ) all_dict_keys = [spec['name'] for spec in children_validation_specs] for field_name in value.keys(): if field_name not in all_dict_keys: @@ -259,10 +272,12 @@ def _validate_dict(self, children_validation_specs: Dict, full_field_path: str, "can be safely ignored, or you might want to upgrade the operator" "to the version that supports the new API version.", self._get_field_name_with_parent(field_name, full_field_path), - children_validation_specs) + children_validation_specs, + ) - def _validate_union(self, children_validation_specs: Dict, full_field_path: str, - dictionary_to_validate: Dict) -> None: + def _validate_union( + self, children_validation_specs: Dict, full_field_path: str, dictionary_to_validate: Dict + ) -> None: field_found = False found_field_name = None for child_validation_spec in children_validation_specs: @@ -272,13 +287,16 @@ def _validate_union(self, children_validation_specs: Dict, full_field_path: str, validation_spec=child_validation_spec, dictionary_to_validate=dictionary_to_validate, parent=full_field_path, - force_optional=True) + force_optional=True, + ) field_name = child_validation_spec['name'] if new_field_found and field_found: raise GcpFieldValidationException( "The mutually exclusive fields '{}' and '{}' belonging to the " - "union '{}' are both present. Please remove one". - format(field_name, found_field_name, full_field_path)) + "union '{}' are both present. Please remove one".format( + field_name, found_field_name, full_field_path + ) + ) if new_field_found: field_found = True found_field_name = field_name @@ -290,11 +308,12 @@ def _validate_union(self, children_validation_specs: Dict, full_field_path: str, "defined for that version. 
Then the warning can be safely ignored, " "or you might want to upgrade the operator to the version that " "supports the new API version.", - full_field_path, dictionary_to_validate, - [field['name'] for field in children_validation_specs]) + full_field_path, + dictionary_to_validate, + [field['name'] for field in children_validation_specs], + ) - def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, - force_optional=False): + def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, force_optional=False): """ Validates if field is OK. @@ -318,13 +337,15 @@ def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, required_api_version = validation_spec.get('api_version') custom_validation = validation_spec.get('custom_validation') - full_field_path = self._get_field_name_with_parent(field_name=field_name, - parent=parent) + full_field_path = self._get_field_name_with_parent(field_name=field_name, parent=parent) if required_api_version and required_api_version != self._api_version: self.log.debug( "Skipping validation of the field '%s' for API version '%s' " "as it is only valid for API version '%s'", - field_name, self._api_version, required_api_version) + field_name, + self._api_version, + required_api_version, + ) return False value = dictionary_to_validate.get(field_name) @@ -335,13 +356,15 @@ def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, # Certainly down from here the field is present (value is not None) # so we should only return True from now on - self._sanity_checks(children_validation_specs=children_validation_specs, - field_type=field_type, - full_field_path=full_field_path, - regexp=regexp, - allow_empty=allow_empty, - custom_validation=custom_validation, - value=value) + self._sanity_checks( + children_validation_specs=children_validation_specs, + field_type=field_type, + full_field_path=full_field_path, + regexp=regexp, + allow_empty=allow_empty, + custom_validation=custom_validation, + value=value, + ) if allow_empty is False: self._validate_is_empty(full_field_path, value) @@ -351,13 +374,16 @@ def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, if not isinstance(value, dict): raise GcpFieldValidationException( "The field '{}' should be of dictionary type according to the " - "specification '{}' but it is '{}'". - format(full_field_path, validation_spec, value)) + "specification '{}' but it is '{}'".format(full_field_path, validation_spec, value) + ) if children_validation_specs is None: self.log.debug( "The dict field '%s' has no nested fields defined in the " "specification '%s'. That's perfectly ok - it's content will " - "not be validated.", full_field_path, validation_spec) + "not be validated.", + full_field_path, + validation_spec, + ) else: self._validate_dict(children_validation_specs, full_field_path, value) elif field_type == 'union': @@ -365,30 +391,35 @@ def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, raise GcpValidationSpecificationException( "The union field '{}' has no nested fields " "defined in specification '{}'. 
Unions should have at least one " - "nested field defined.".format(full_field_path, validation_spec)) - self._validate_union(children_validation_specs, full_field_path, - dictionary_to_validate) + "nested field defined.".format(full_field_path, validation_spec) + ) + self._validate_union(children_validation_specs, full_field_path, dictionary_to_validate) elif field_type == 'list': if not isinstance(value, list): raise GcpFieldValidationException( "The field '{}' should be of list type according to the " - "specification '{}' but it is '{}'". - format(full_field_path, validation_spec, value)) + "specification '{}' but it is '{}'".format(full_field_path, validation_spec, value) + ) elif custom_validation: try: custom_validation(value) except Exception as e: raise GcpFieldValidationException( - "Error while validating custom field '{}' specified by '{}': '{}'". - format(full_field_path, validation_spec, e)) + "Error while validating custom field '{}' specified by '{}': '{}'".format( + full_field_path, validation_spec, e + ) + ) elif field_type is None: - self.log.debug("The type of field '%s' is not specified in '%s'. " - "Not validating its content.", full_field_path, validation_spec) + self.log.debug( + "The type of field '%s' is not specified in '%s'. " "Not validating its content.", + full_field_path, + validation_spec, + ) else: raise GcpValidationSpecificationException( "The field '{}' is of type '{}' in specification '{}'." - "This type is unknown to validation!".format( - full_field_path, field_type, validation_spec)) + "This type is unknown to validation!".format(full_field_path, field_type, validation_spec) + ) return True def validate(self, body_to_validate): @@ -404,22 +435,26 @@ def validate(self, body_to_validate): """ try: for validation_spec in self._validation_specs: - self._validate_field(validation_spec=validation_spec, - dictionary_to_validate=body_to_validate) + self._validate_field(validation_spec=validation_spec, dictionary_to_validate=body_to_validate) except GcpFieldValidationException as e: raise GcpFieldValidationException( - "There was an error when validating: body '{}': '{}'". - format(body_to_validate, e)) - all_field_names = [spec['name'] for spec in self._validation_specs - if spec.get('type') != 'union' and - spec.get('api_version') != self._api_version] - all_union_fields = [spec for spec in self._validation_specs - if spec.get('type') == 'union'] + "There was an error when validating: body '{}': '{}'".format(body_to_validate, e) + ) + all_field_names = [ + spec['name'] + for spec in self._validation_specs + if spec.get('type') != 'union' and spec.get('api_version') != self._api_version + ] + all_union_fields = [spec for spec in self._validation_specs if spec.get('type') == 'union'] for union_field in all_union_fields: all_field_names.extend( - [nested_union_spec['name'] for nested_union_spec in union_field['fields'] - if nested_union_spec.get('type') != 'union' and - nested_union_spec.get('api_version') != self._api_version]) + [ + nested_union_spec['name'] + for nested_union_spec in union_field['fields'] + if nested_union_spec.get('type') != 'union' + and nested_union_spec.get('api_version') != self._api_version + ] + ) for field_name in body_to_validate.keys(): if field_name not in all_field_names: self.log.warning( @@ -429,4 +464,6 @@ def validate(self, body_to_validate): "new field names defined for that version. 
Then the warning " "can be safely ignored, or you might want to upgrade the operator" "to the version that supports the new API version.", - field_name, self._validation_specs) + field_name, + self._validation_specs, + ) diff --git a/airflow/providers/google/cloud/utils/mlengine_operator_utils.py b/airflow/providers/google/cloud/utils/mlengine_operator_utils.py index 2231b2d14641d..b1b538b60a36b 100644 --- a/airflow/providers/google/cloud/utils/mlengine_operator_utils.py +++ b/airflow/providers/google/cloud/utils/mlengine_operator_utils.py @@ -39,21 +39,23 @@ T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name -def create_evaluate_ops(task_prefix: str, # pylint: disable=too-many-arguments - data_format: str, - input_paths: List[str], - prediction_path: str, - metric_fn_and_keys: Tuple[T, Iterable[str]], - validate_fn: T, - batch_prediction_job_id: Optional[str] = None, - region: Optional[str] = None, - project_id: Optional[str] = None, - dataflow_options: Optional[Dict] = None, - model_uri: Optional[str] = None, - model_name: Optional[str] = None, - version_name: Optional[str] = None, - dag: Optional[DAG] = None, - py_interpreter="python3"): +def create_evaluate_ops( # pylint: disable=too-many-arguments + task_prefix: str, + data_format: str, + input_paths: List[str], + prediction_path: str, + metric_fn_and_keys: Tuple[T, Iterable[str]], + validate_fn: T, + batch_prediction_job_id: Optional[str] = None, + region: Optional[str] = None, + project_id: Optional[str] = None, + dataflow_options: Optional[Dict] = None, + model_uri: Optional[str] = None, + model_name: Optional[str] = None, + version_name: Optional[str] = None, + dag: Optional[DAG] = None, + py_interpreter="python3", +): """ Creates Operators needed for model evaluation and returns. 
@@ -199,7 +201,8 @@ def validate_err_and_count(summary): if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix): raise AirflowException( "Malformed task_id for DataFlowPythonOperator (only alphanumeric " - "and hyphens are allowed but got: " + task_prefix) + "and hyphens are allowed but got: " + task_prefix + ) metric_fn, metric_keys = metric_fn_and_keys if not callable(metric_fn): @@ -213,8 +216,7 @@ def validate_err_and_count(summary): region = region or default_args['region'] model_name = model_name or default_args.get('model_name') version_name = version_name or default_args.get('version_name') - dataflow_options = dataflow_options or \ - default_args.get('dataflow_default_options') + dataflow_options = dataflow_options or default_args.get('dataflow_default_options') evaluate_prediction = MLEngineStartBatchPredictionJobOperator( task_id=(task_prefix + "-prediction"), @@ -227,7 +229,8 @@ def validate_err_and_count(summary): uri=model_uri, model_name=model_name, version_name=version_name, - dag=dag) + dag=dag, + ) metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True)).decode() evaluate_summary = DataflowCreatePythonJobOperator( @@ -237,13 +240,12 @@ def validate_err_and_count(summary): options={ "prediction_path": prediction_path, "metric_fn_encoded": metric_fn_encoded, - "metric_keys": ','.join(metric_keys) + "metric_keys": ','.join(metric_keys), }, py_interpreter=py_interpreter, - py_requirements=[ - 'apache-beam[gcp]>=2.14.0' - ], - dag=dag) + py_requirements=['apache-beam[gcp]>=2.14.0'], + dag=dag, + ) evaluate_summary.set_upstream(evaluate_prediction) def apply_validate_fn(*args, templates_dict, **kwargs): @@ -251,8 +253,7 @@ def apply_validate_fn(*args, templates_dict, **kwargs): scheme, bucket, obj, _, _ = urlsplit(prediction_path) if scheme != "gs" or not bucket or not obj: raise ValueError("Wrong format prediction_path: {}".format(prediction_path)) - summary = os.path.join(obj.strip("/"), - "prediction.summary.json") + summary = os.path.join(obj.strip("/"), "prediction.summary.json") gcs_hook = GCSHook() summary = json.loads(gcs_hook.download(bucket, summary)) return validate_fn(summary) @@ -261,7 +262,8 @@ def apply_validate_fn(*args, templates_dict, **kwargs): task_id=(task_prefix + "-validation"), python_callable=apply_validate_fn, templates_dict={"prediction_path": prediction_path}, - dag=dag) + dag=dag, + ) evaluate_validation.set_upstream(evaluate_summary) return evaluate_prediction, evaluate_summary, evaluate_validation diff --git a/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py b/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py index 63733b209c9e8..3d9133df55684 100644 --- a/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +++ b/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py @@ -97,6 +97,7 @@ class JsonCoder: """ JSON encoder/decoder. """ + @staticmethod def encode(x): """JSON encoder.""" @@ -114,15 +115,17 @@ def MakeSummary(pcoll, metric_fn, metric_keys): # pylint: disable=invalid-name Summary PTransofrm used in Dataflow. 
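
create_evaluate_ops above ships metric_fn to the Dataflow job as a base64-encoded dill payload, and the summary module below reverses it with dill.loads(base64.b64decode(...)). A small round-trip sketch of that encoding; dill must be installed, and the sample metric function is made up for illustration:

import base64
import re

import dill

def encode_metric_fn(metric_fn):
    # recurse=True lets dill pick up names the callable closes over, so the
    # function can be rebuilt inside the Dataflow worker process.
    return base64.b64encode(dill.dumps(metric_fn, recurse=True)).decode()

def decode_metric_fn(encoded):
    return dill.loads(base64.b64decode(encoded))

def error_and_count(summary):
    return (summary.get("err", 0.0), 1)

assert re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", "evaluate-pi")  # the task_prefix rule above
roundtripped = decode_metric_fn(encode_metric_fn(error_and_count))
print(roundtripped({"err": 0.25}))  # (0.25, 1)
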
""" return ( - pcoll | - "ApplyMetricFnPerInstance" >> beam.Map(metric_fn) | - "PairWith1" >> beam.Map(lambda tup: tup + (1,)) | - "SumTuple" >> beam.CombineGlobally(beam.combiners.TupleCombineFn( - *([sum] * (len(metric_keys) + 1)))) | - "AverageAndMakeDict" >> beam.Map( + pcoll + | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn) + | "PairWith1" >> beam.Map(lambda tup: tup + (1,)) + | "SumTuple" >> beam.CombineGlobally(beam.combiners.TupleCombineFn(*([sum] * (len(metric_keys) + 1)))) + | "AverageAndMakeDict" + >> beam.Map( lambda tup: dict( - [(name, tup[i] / tup[-1]) for i, name in enumerate(metric_keys)] + - [("count", tup[-1])]))) + [(name, tup[i] / tup[-1]) for i, name in enumerate(metric_keys)] + [("count", tup[-1])] + ) + ) + ) def run(argv=None): @@ -131,26 +134,35 @@ def run(argv=None): """ parser = argparse.ArgumentParser() parser.add_argument( - "--prediction_path", required=True, + "--prediction_path", + required=True, help=( "The GCS folder that contains BatchPrediction results, containing " "prediction.results-NNNNN-of-NNNNN files in the json format. " "Output will be also stored in this folder, as a file" - "'prediction.summary.json'.")) + "'prediction.summary.json'." + ), + ) parser.add_argument( - "--metric_fn_encoded", required=True, + "--metric_fn_encoded", + required=True, help=( "An encoded function that calculates and returns a tuple of " "metric(s) for a given instance (as a dictionary). It should be " - "encoded via base64.b64encode(dill.dumps(fn, recurse=True)).")) + "encoded via base64.b64encode(dill.dumps(fn, recurse=True))." + ), + ) parser.add_argument( - "--metric_keys", required=True, + "--metric_keys", + required=True, help=( "A comma-separated keys of the aggregated metric(s) in the summary " "output. The order and the size of the keys must match to the " "output of metric_fn. The summary will have an additional key, " "'count', to represent the total number of instances, so this flag " - "shouldn't include 'count'.")) + "shouldn't include 'count'." + ), + ) known_args, pipeline_args = parser.parse_known_args(argv) metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded)) @@ -164,13 +176,15 @@ def run(argv=None): prediction_summary_path = os.path.join(known_args.prediction_path, "prediction.summary.json") # This is apache-beam ptransform's convention _ = ( - pipe | "ReadPredictionResult" >> beam.io.ReadFromText( - prediction_result_pattern, coder=JsonCoder()) - | "Summary" >> MakeSummary(metric_fn, metric_keys) - | "Write" >> beam.io.WriteToText( - prediction_summary_path, - shard_name_template='', # without trailing -NNNNN-of-NNNNN. - coder=JsonCoder()) + pipe + | "ReadPredictionResult" >> beam.io.ReadFromText(prediction_result_pattern, coder=JsonCoder()) + | "Summary" >> MakeSummary(metric_fn, metric_keys) + | "Write" + >> beam.io.WriteToText( + prediction_summary_path, + shard_name_template='', # without trailing -NNNNN-of-NNNNN. 
+ coder=JsonCoder(), + ) ) diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 62e22e1736c62..ebecafa74bcf4 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -44,7 +44,9 @@ from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.providers.google.cloud.utils.credentials_provider import ( - _get_scopes, _get_target_principal_and_delegates, get_credentials_and_project_id, + _get_scopes, + _get_target_principal_and_delegates, + get_credentials_and_project_id, ) from airflow.utils.process_utils import patch_environ @@ -76,18 +78,10 @@ def is_soft_quota_exception(exception: Exception): * Google Video Intelligence """ if isinstance(exception, Forbidden): - return any( - reason in error.details() - for reason in INVALID_REASONS - for error in exception.errors - ) + return any(reason in error.details() for reason in INVALID_REASONS for error in exception.errors) if isinstance(exception, (ResourceExhausted, TooManyRequests)): - return any( - key in error.details() - for key in INVALID_KEYS - for error in exception.errors - ) + return any(key in error.details() for key in INVALID_KEYS for error in exception.errors) return False @@ -315,6 +309,7 @@ def quota_retry(*args, **kwargs) -> Callable: A decorator that provides a mechanism to repeat requests in response to exceeding a temporary quote limit. """ + def decorator(fun: Callable): default_kwargs = { 'wait': tenacity.wait_exponential(multiplier=1, max=100), @@ -323,9 +318,8 @@ def decorator(fun: Callable): 'after': tenacity.after_log(log, logging.DEBUG), } default_kwargs.update(**kwargs) - return tenacity.retry( - *args, **default_kwargs - )(fun) + return tenacity.retry(*args, **default_kwargs)(fun) + return decorator @staticmethod @@ -335,6 +329,7 @@ def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]: operation in progress (HTTP 409) limit. """ + def decorator(fun: T): default_kwargs = { 'wait': tenacity.wait_exponential(multiplier=1, max=300), @@ -343,9 +338,8 @@ def decorator(fun: T): 'after': tenacity.after_log(log, logging.DEBUG), } default_kwargs.update(**kwargs) - return cast(T, tenacity.retry( - *args, **default_kwargs - )(fun)) + return cast(T, tenacity.retry(*args, **default_kwargs)(fun)) + return decorator @staticmethod @@ -359,21 +353,25 @@ def fallback_to_default_project_id(func: Callable[..., RT]) -> Callable[..., RT] :param func: function to wrap :return: result of the function call """ + @functools.wraps(func) def inner_wrapper(self: GoogleBaseHook, *args, **kwargs) -> RT: if args: raise AirflowException( - "You must use keyword arguments in this methods rather than" - " positional") + "You must use keyword arguments in this methods rather than" " positional" + ) if 'project_id' in kwargs: kwargs['project_id'] = kwargs['project_id'] or self.project_id else: kwargs['project_id'] = self.project_id if not kwargs['project_id']: - raise AirflowException("The project id must be passed either as " - "keyword project_id parameter or as project_id extra " - "in GCP connection definition. Both are not set!") + raise AirflowException( + "The project id must be passed either as " + "keyword project_id parameter or as project_id extra " + "in GCP connection definition. Both are not set!" 
+ ) return func(self, *args, **kwargs) + return inner_wrapper @staticmethod @@ -386,10 +384,12 @@ def provide_gcp_credential_file(func: T) -> T: scope when authorization data is available. Using context manager also makes it easier to use multiple connection in one function. """ + @functools.wraps(func) def wrapper(self: GoogleBaseHook, *args, **kwargs): with self.provide_gcp_credential_file_as_context(): return func(self, *args, **kwargs) + return cast(T, wrapper) @contextmanager @@ -401,8 +401,12 @@ def provide_gcp_credential_file_as_context(self): It can be used to provide credentials for external programs (e.g. gcloud) that expect authorization file in ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable. """ - key_path = self._get_field('key_path', None) # type: Optional[str] # noqa: E501 # pylint: disable=protected-access - keyfile_dict = self._get_field('keyfile_dict', None) # type: Optional[Dict] # noqa: E501 # pylint: disable=protected-access + key_path = self._get_field( + 'key_path', None + ) # type: Optional[str] # noqa: E501 # pylint: disable=protected-access + keyfile_dict = self._get_field( + 'keyfile_dict', None + ) # type: Optional[Dict] # noqa: E501 # pylint: disable=protected-access if key_path and keyfile_dict: raise AirflowException( "The `keyfile_dict` and `key_path` fields are mutually exclusive. " @@ -410,9 +414,7 @@ def provide_gcp_credential_file_as_context(self): ) elif key_path: if key_path.endswith('.p12'): - raise AirflowException( - 'Legacy P12 key file are not supported, use a JSON key file.' - ) + raise AirflowException('Legacy P12 key file are not supported, use a JSON key file.') with patch_environ({CREDENTIALS: key_path}): yield key_path elif keyfile_dict: @@ -438,42 +440,42 @@ def provide_authorized_gcloud(self): credentials_path = _cloud_sdk.get_application_default_credentials_path() project_id = self.project_id + # fmt: off with self.provide_gcp_credential_file_as_context(), \ tempfile.TemporaryDirectory() as gcloud_config_tmp, \ patch_environ({'CLOUDSDK_CONFIG': gcloud_config_tmp}): + # fmt: on if project_id: # Don't display stdout/stderr for security reason - check_output([ - "gcloud", "config", "set", "core/project", project_id - ]) + check_output(["gcloud", "config", "set", "core/project", project_id]) if CREDENTIALS in os.environ: # This solves most cases when we are logged in using the service key in Airflow. # Don't display stdout/stderr for security reason - check_output([ - "gcloud", "auth", "activate-service-account", f"--key-file={os.environ[CREDENTIALS]}", - ]) + check_output( + ["gcloud", "auth", "activate-service-account", f"--key-file={os.environ[CREDENTIALS]}",] + ) elif os.path.exists(credentials_path): # If we are logged in by `gcloud auth application-default` then we need to log in manually. # This will make the `gcloud auth application-default` and `gcloud auth` credentials equals. 
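
quota_retry and operation_in_progress_retry above share one shape: a decorator factory that seeds tenacity with default wait/before/after settings and lets the caller override any of them. A hedged sketch of that pattern; the retry predicate here is a simple stand-in for the hook's own quota-detection logic:

import logging

import tenacity

log = logging.getLogger(__name__)

def retry_with_defaults(*args, **kwargs):
    def decorator(fun):
        default_kwargs = {
            "wait": tenacity.wait_exponential(multiplier=1, max=100),
            "retry": tenacity.retry_if_exception_type(ValueError),  # stand-in predicate
            "before": tenacity.before_log(log, logging.DEBUG),
            "after": tenacity.after_log(log, logging.DEBUG),
        }
        default_kwargs.update(**kwargs)
        return tenacity.retry(*args, **default_kwargs)(fun)
    return decorator

@retry_with_defaults(stop=tenacity.stop_after_attempt(3))
def flaky_call():
    raise ValueError("temporary quota exceeded")  # would be retried 3 times, then re-raised
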
with open(credentials_path) as creds_file: creds_content = json.loads(creds_file.read()) # Don't display stdout/stderr for security reason - check_output([ - "gcloud", "config", "set", "auth/client_id", creds_content["client_id"] - ]) + check_output(["gcloud", "config", "set", "auth/client_id", creds_content["client_id"]]) # Don't display stdout/stderr for security reason - check_output([ - "gcloud", "config", "set", "auth/client_secret", creds_content["client_secret"] - ]) + check_output( + ["gcloud", "config", "set", "auth/client_secret", creds_content["client_secret"]] + ) # Don't display stdout/stderr for security reason - check_output([ - "gcloud", - "auth", - "activate-refresh-token", - creds_content["client_id"], - creds_content["refresh_token"], - ]) + check_output( + [ + "gcloud", + "auth", + "activate-refresh-token", + creds_content["client_id"], + creds_content["refresh_token"], + ] + ) yield @staticmethod diff --git a/airflow/providers/google/common/hooks/discovery_api.py b/airflow/providers/google/common/hooks/discovery_api.py index b001b547df32d..3c9f2fa58446e 100644 --- a/airflow/providers/google/common/hooks/discovery_api.py +++ b/airflow/providers/google/common/hooks/discovery_api.py @@ -51,6 +51,7 @@ class GoogleDiscoveryApiHook(GoogleBaseHook): account from the list granting this role to the originating account. :type impersonation_chain: Union[str, Sequence[str]] """ + _conn = None # type: Optional[Resource] def __init__( @@ -62,9 +63,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_service_name = api_service_name self.api_version = api_version @@ -84,7 +83,7 @@ def get_conn(self): serviceName=self.api_service_name, version=self.api_version, http=http_authorized, - cache_discovery=False + cache_discovery=False, ) return self._conn @@ -117,9 +116,7 @@ def _call_api_request(self, google_api_conn_client, endpoint, data, paginate, nu api_endpoint_parts = endpoint.split('.') google_api_endpoint_instance = self._build_api_request( - google_api_conn_client, - api_sub_functions=api_endpoint_parts[1:], - api_endpoint_params=data + google_api_conn_client, api_sub_functions=api_endpoint_parts[1:], api_endpoint_params=data ) if paginate: diff --git a/airflow/providers/google/firebase/example_dags/example_firestore.py b/airflow/providers/google/firebase/example_dags/example_firestore.py index 7bf7e012bd571..39df4cac4b02c 100644 --- a/airflow/providers/google/firebase/example_dags/example_firestore.py +++ b/airflow/providers/google/firebase/example_dags/example_firestore.py @@ -48,7 +48,9 @@ from airflow import models from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCreateEmptyDatasetOperator, BigQueryCreateExternalTableOperator, BigQueryDeleteDatasetOperator, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateExternalTableOperator, + BigQueryDeleteDatasetOperator, BigQueryExecuteQueryOperator, ) from airflow.providers.google.firebase.operators.firestore import CloudFirestoreExportDatabaseOperator diff --git a/airflow/providers/google/firebase/hooks/firestore.py b/airflow/providers/google/firebase/hooks/firestore.py index 3a6b9c5ff0240..366b0d5e3af57 100644 --- a/airflow/providers/google/firebase/hooks/firestore.py +++ b/airflow/providers/google/firebase/hooks/firestore.py @@ -65,9 +65,7 @@ def 
__init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version @@ -87,8 +85,7 @@ def get_conn(self): # At the same time, the Non-Authorized Client has no problems. non_authorized_conn = build("firestore", self.api_version, cache_discovery=False) self._conn = build_from_document( - non_authorized_conn._rootDesc, # pylint: disable=protected-access - http=http_authorized + non_authorized_conn._rootDesc, http=http_authorized # pylint: disable=protected-access ) return self._conn diff --git a/airflow/providers/google/firebase/operators/firestore.py b/airflow/providers/google/firebase/operators/firestore.py index f2adbe9b55363..074bfa2152dff 100644 --- a/airflow/providers/google/firebase/operators/firestore.py +++ b/airflow/providers/google/firebase/operators/firestore.py @@ -56,7 +56,12 @@ class CloudFirestoreExportDatabaseOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("body", "gcp_conn_id", "api_version", "impersonation_chain",) + template_fields = ( + "body", + "gcp_conn_id", + "api_version", + "impersonation_chain", + ) @apply_defaults def __init__( @@ -68,7 +73,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", api_version: str = "v1", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.database_id = database_id @@ -87,6 +92,6 @@ def execute(self, context): hook = CloudFirestoreHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, - impersonation_chain=self.impersonation_chain + impersonation_chain=self.impersonation_chain, ) return hook.export_documents(database_id=self.database_id, body=self.body, project_id=self.project_id) diff --git a/airflow/providers/google/marketing_platform/example_dags/example_analytics.py b/airflow/providers/google/marketing_platform/example_dags/example_analytics.py index 79578c4dc78c5..337446a213f82 100644 --- a/airflow/providers/google/marketing_platform/example_dags/example_analytics.py +++ b/airflow/providers/google/marketing_platform/example_dags/example_analytics.py @@ -21,9 +21,12 @@ from airflow import models from airflow.providers.google.marketing_platform.operators.analytics import ( - GoogleAnalyticsDataImportUploadOperator, GoogleAnalyticsDeletePreviousDataUploadsOperator, - GoogleAnalyticsGetAdsLinkOperator, GoogleAnalyticsListAccountsOperator, - GoogleAnalyticsModifyFileHeadersDataImportOperator, GoogleAnalyticsRetrieveAdsLinksListOperator, + GoogleAnalyticsDataImportUploadOperator, + GoogleAnalyticsDeletePreviousDataUploadsOperator, + GoogleAnalyticsGetAdsLinkOperator, + GoogleAnalyticsListAccountsOperator, + GoogleAnalyticsModifyFileHeadersDataImportOperator, + GoogleAnalyticsRetrieveAdsLinksListOperator, ) from airflow.utils import dates @@ -32,9 +35,7 @@ BUCKET = os.environ.get("GMP_ANALYTICS_BUCKET", "test-airflow-analytics-bucket") BUCKET_FILENAME = "data.csv" WEB_PROPERTY_ID = os.environ.get("GA_WEB_PROPERTY", "UA-12345678-1") -WEB_PROPERTY_AD_WORDS_LINK_ID = os.environ.get( - "GA_WEB_PROPERTY_AD_WORDS_LINK_ID", "rQafFTPOQdmkx4U-fxUfhj" -) +WEB_PROPERTY_AD_WORDS_LINK_ID = os.environ.get("GA_WEB_PROPERTY_AD_WORDS_LINK_ID", "rQafFTPOQdmkx4U-fxUfhj") DATA_ID = "kjdDu3_tQa6n8Q1kXFtSmg" with models.DAG( @@ -78,9 +79,7 @@ ) transform 
= GoogleAnalyticsModifyFileHeadersDataImportOperator( - task_id="transform", - storage_bucket=BUCKET, - storage_name_object=BUCKET_FILENAME, + task_id="transform", storage_bucket=BUCKET, storage_name_object=BUCKET_FILENAME, ) upload >> [delete, transform] diff --git a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py index 0fba00f29ba7c..b5482a66a8115 100644 --- a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py +++ b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py @@ -23,9 +23,12 @@ from airflow import models from airflow.providers.google.marketing_platform.operators.campaign_manager import ( - GoogleCampaignManagerBatchInsertConversionsOperator, GoogleCampaignManagerBatchUpdateConversionsOperator, - GoogleCampaignManagerDeleteReportOperator, GoogleCampaignManagerDownloadReportOperator, - GoogleCampaignManagerInsertReportOperator, GoogleCampaignManagerRunReportOperator, + GoogleCampaignManagerBatchInsertConversionsOperator, + GoogleCampaignManagerBatchUpdateConversionsOperator, + GoogleCampaignManagerDeleteReportOperator, + GoogleCampaignManagerDownloadReportOperator, + GoogleCampaignManagerInsertReportOperator, + GoogleCampaignManagerRunReportOperator, ) from airflow.providers.google.marketing_platform.sensors.campaign_manager import ( GoogleCampaignManagerReportSensor, @@ -44,13 +47,8 @@ "type": "STANDARD", "name": REPORT_NAME, "criteria": { - "dateRange": { - "kind": "dfareporting#dateRange", - "relativeDateRange": "LAST_365_DAYS", - }, - "dimensions": [ - {"kind": "dfareporting#sortedDimension", "name": "dfa:advertiser"} - ], + "dateRange": {"kind": "dfareporting#dateRange", "relativeDateRange": "LAST_365_DAYS",}, + "dimensions": [{"kind": "dfareporting#sortedDimension", "name": "dfa:advertiser"}], "metricNames": ["dfa:activeViewImpressionDistributionViewable"], }, } @@ -64,13 +62,7 @@ "quantity": 42, "value": 123.4, "timestampMicros": int(time.time()) * 1000000, - "customVariables": [ - { - "kind": "dfareporting#customFloodlightVariable", - "type": "U4", - "value": "value", - } - ], + "customVariables": [{"kind": "dfareporting#customFloodlightVariable", "type": "U4", "value": "value",}], } CONVERSION_UPDATE = { @@ -86,7 +78,7 @@ with models.DAG( "example_campaign_manager", schedule_interval=None, # Override to match your needs, - start_date=dates.days_ago(1) + start_date=dates.days_ago(1), ) as dag: # [START howto_campaign_manager_insert_report_operator] create_report = GoogleCampaignManagerInsertReportOperator( @@ -104,10 +96,7 @@ # [START howto_campaign_manager_wait_for_operation] wait_for_report = GoogleCampaignManagerReportSensor( - task_id="wait_for_report", - profile_id=PROFILE_ID, - report_id=report_id, - file_id=file_id, + task_id="wait_for_report", profile_id=PROFILE_ID, report_id=report_id, file_id=file_id, ) # [END howto_campaign_manager_wait_for_operation] diff --git a/airflow/providers/google/marketing_platform/example_dags/example_display_video.py b/airflow/providers/google/marketing_platform/example_dags/example_display_video.py index daa008a5e2108..dca61bc9ed01f 100644 --- a/airflow/providers/google/marketing_platform/example_dags/example_display_video.py +++ b/airflow/providers/google/marketing_platform/example_dags/example_display_video.py @@ -25,13 +25,18 @@ from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator from 
airflow.providers.google.marketing_platform.hooks.display_video import GoogleDisplayVideo360Hook from airflow.providers.google.marketing_platform.operators.display_video import ( - GoogleDisplayVideo360CreateReportOperator, GoogleDisplayVideo360CreateSDFDownloadTaskOperator, - GoogleDisplayVideo360DeleteReportOperator, GoogleDisplayVideo360DownloadLineItemsOperator, - GoogleDisplayVideo360DownloadReportOperator, GoogleDisplayVideo360RunReportOperator, - GoogleDisplayVideo360SDFtoGCSOperator, GoogleDisplayVideo360UploadLineItemsOperator, + GoogleDisplayVideo360CreateReportOperator, + GoogleDisplayVideo360CreateSDFDownloadTaskOperator, + GoogleDisplayVideo360DeleteReportOperator, + GoogleDisplayVideo360DownloadLineItemsOperator, + GoogleDisplayVideo360DownloadReportOperator, + GoogleDisplayVideo360RunReportOperator, + GoogleDisplayVideo360SDFtoGCSOperator, + GoogleDisplayVideo360UploadLineItemsOperator, ) from airflow.providers.google.marketing_platform.sensors.display_video import ( - GoogleDisplayVideo360GetSDFDownloadOperationSensor, GoogleDisplayVideo360ReportSensor, + GoogleDisplayVideo360GetSDFDownloadOperationSensor, + GoogleDisplayVideo360ReportSensor, ) from airflow.utils import dates @@ -74,21 +79,16 @@ "inventorySourceFilter": {"inventorySourceIds": []}, } -DOWNLOAD_LINE_ITEMS_REQUEST: Dict = { - "filterType": ADVERTISER_ID, - "format": "CSV", - "fileSpec": "EWF"} +DOWNLOAD_LINE_ITEMS_REQUEST: Dict = {"filterType": ADVERTISER_ID, "format": "CSV", "fileSpec": "EWF"} # [END howto_display_video_env_variables] with models.DAG( "example_display_video", schedule_interval=None, # Override to match your needs, - start_date=dates.days_ago(1) + start_date=dates.days_ago(1), ) as dag1: # [START howto_google_display_video_createquery_report_operator] - create_report = GoogleDisplayVideo360CreateReportOperator( - body=REPORT, task_id="create_report" - ) + create_report = GoogleDisplayVideo360CreateReportOperator(body=REPORT, task_id="create_report") report_id = "{{ task_instance.xcom_pull('create_report', key='report_id') }}" # [END howto_google_display_video_createquery_report_operator] @@ -99,24 +99,17 @@ # [END howto_google_display_video_runquery_report_operator] # [START howto_google_display_video_wait_report_operator] - wait_for_report = GoogleDisplayVideo360ReportSensor( - task_id="wait_for_report", report_id=report_id - ) + wait_for_report = GoogleDisplayVideo360ReportSensor(task_id="wait_for_report", report_id=report_id) # [END howto_google_display_video_wait_report_operator] # [START howto_google_display_video_getquery_report_operator] get_report = GoogleDisplayVideo360DownloadReportOperator( - report_id=report_id, - task_id="get_report", - bucket_name=BUCKET, - report_name="test1.csv", + report_id=report_id, task_id="get_report", bucket_name=BUCKET, report_name="test1.csv", ) # [END howto_google_display_video_getquery_report_operator] # [START howto_google_display_video_deletequery_report_operator] - delete_report = GoogleDisplayVideo360DeleteReportOperator( - report_id=report_id, task_id="delete_report" - ) + delete_report = GoogleDisplayVideo360DeleteReportOperator(report_id=report_id, task_id="delete_report") # [END howto_google_display_video_deletequery_report_operator] create_report >> run_report >> wait_for_report >> get_report >> delete_report @@ -124,7 +117,7 @@ with models.DAG( "example_display_video_misc", schedule_interval=None, # Override to match your needs, - start_date=dates.days_ago(1) + start_date=dates.days_ago(1), ) as dag2: # [START 
howto_google_display_video_upload_multiple_entity_read_files_to_big_query] upload_erf_to_bq = GCSToBigQueryOperator( @@ -148,16 +141,14 @@ # [START howto_google_display_video_upload_line_items_operator] upload_line_items = GoogleDisplayVideo360UploadLineItemsOperator( - task_id="upload_line_items", - bucket_name=BUCKET, - object_name=BUCKET_FILE_LOCATION, + task_id="upload_line_items", bucket_name=BUCKET, object_name=BUCKET_FILE_LOCATION, ) # [END howto_google_display_video_upload_line_items_operator] with models.DAG( "example_display_video_sdf", schedule_interval=None, # Override to match your needs, - start_date=dates.days_ago(1) + start_date=dates.days_ago(1), ) as dag3: # [START howto_google_display_video_create_sdf_download_task_operator] create_sdf_download_task = GoogleDisplayVideo360CreateSDFDownloadTaskOperator( diff --git a/airflow/providers/google/marketing_platform/example_dags/example_search_ads.py b/airflow/providers/google/marketing_platform/example_dags/example_search_ads.py index 4cc9328b50f8d..66c860d54172f 100644 --- a/airflow/providers/google/marketing_platform/example_dags/example_search_ads.py +++ b/airflow/providers/google/marketing_platform/example_dags/example_search_ads.py @@ -22,7 +22,8 @@ from airflow import models from airflow.providers.google.marketing_platform.operators.search_ads import ( - GoogleSearchAdsDownloadReportOperator, GoogleSearchAdsInsertReportOperator, + GoogleSearchAdsDownloadReportOperator, + GoogleSearchAdsInsertReportOperator, ) from airflow.providers.google.marketing_platform.sensors.search_ads import GoogleSearchAdsReportSensor from airflow.utils import dates @@ -46,12 +47,10 @@ with models.DAG( "example_search_ads", schedule_interval=None, # Override to match your needs, - start_date=dates.days_ago(1) + start_date=dates.days_ago(1), ) as dag: # [START howto_search_ads_generate_report_operator] - generate_report = GoogleSearchAdsInsertReportOperator( - report=REPORT, task_id="generate_report" - ) + generate_report = GoogleSearchAdsInsertReportOperator(report=REPORT, task_id="generate_report") # [END howto_search_ads_generate_report_operator] # [START howto_search_ads_get_report_id] @@ -59,9 +58,7 @@ # [END howto_search_ads_get_report_id] # [START howto_search_ads_get_report_operator] - wait_for_report = GoogleSearchAdsReportSensor( - report_id=report_id, task_id="wait_for_report" - ) + wait_for_report = GoogleSearchAdsReportSensor(report_id=report_id, task_id="wait_for_report") # [END howto_search_ads_get_report_operator] # [START howto_search_ads_getfile_report_operator] diff --git a/airflow/providers/google/marketing_platform/hooks/analytics.py b/airflow/providers/google/marketing_platform/hooks/analytics.py index 546ad29b709c1..662473f56e78f 100644 --- a/airflow/providers/google/marketing_platform/hooks/analytics.py +++ b/airflow/providers/google/marketing_platform/hooks/analytics.py @@ -28,12 +28,7 @@ class GoogleAnalyticsHook(GoogleBaseHook): Hook for Google Analytics 360. 
""" - def __init__( - self, - api_version: str = "v3", - *args, - **kwargs - ): + def __init__(self, api_version: str = "v3", *args, **kwargs): super().__init__(*args, **kwargs) self.api_version = api_version self._conn = None @@ -59,12 +54,7 @@ def get_conn(self) -> Resource: """ if not self._conn: http_authorized = self._authorize() - self._conn = build( - "analytics", - self.api_version, - http=http_authorized, - cache_discovery=False, - ) + self._conn = build("analytics", self.api_version, http=http_authorized, cache_discovery=False,) return self._conn def list_accounts(self) -> List[Dict[str, Any]]: @@ -109,9 +99,7 @@ def get_ad_words_link( ) return ad_words_link - def list_ad_words_links( - self, account_id: str, web_property_id: str - ) -> List[Dict[str, Any]]: + def list_ad_words_links(self, account_id: str, web_property_id: str) -> List[Dict[str, Any]]: """ Lists webProperty-Google Ads links for a given web property. @@ -126,9 +114,7 @@ def list_ad_words_links( self.log.info("Retrieving ad words list...") conn = self.get_conn() - ads_links = ( - conn.management().webPropertyAdWordsLinks() # pylint: disable=no-member - ) + ads_links = conn.management().webPropertyAdWordsLinks() # pylint: disable=no-member list_args = {"accountId": account_id, "webPropertyId": web_property_id} result = self._paginate(ads_links, list_args) return result @@ -158,9 +144,7 @@ def upload_data( """ media = MediaFileUpload( - file_location, - mimetype="application/octet-stream", - resumable=resumable_upload, + file_location, mimetype="application/octet-stream", resumable=resumable_upload, ) self.log.info( @@ -212,9 +196,7 @@ def delete_upload_data( body=delete_request_body, ).execute() - def list_uploads( - self, account_id, web_property_id, custom_data_source_id - ) -> List[Dict[str, Any]]: + def list_uploads(self, account_id, web_property_id, custom_data_source_id) -> List[Dict[str, Any]]: """ Get list of data upload from GA diff --git a/airflow/providers/google/marketing_platform/hooks/campaign_manager.py b/airflow/providers/google/marketing_platform/hooks/campaign_manager.py index 797d9539d5b6c..afcf332c88db4 100644 --- a/airflow/providers/google/marketing_platform/hooks/campaign_manager.py +++ b/airflow/providers/google/marketing_platform/hooks/campaign_manager.py @@ -42,9 +42,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version @@ -54,12 +52,7 @@ def get_conn(self) -> Resource: """ if not self._conn: http_authorized = self._authorize() - self._conn = build( - "dfareporting", - self.api_version, - http=http_authorized, - cache_discovery=False, - ) + self._conn = build("dfareporting", self.api_version, http=http_authorized, cache_discovery=False,) return self._conn def delete_report(self, profile_id: str, report_id: str) -> Any: @@ -156,9 +149,7 @@ def patch_report(self, profile_id: str, report_id: str, update_mask: Dict) -> An ) return response - def run_report( - self, profile_id: str, report_id: str, synchronous: Optional[bool] = None - ) -> Any: + def run_report(self, profile_id: str, report_id: str, synchronous: Optional[bool] = None) -> Any: """ Runs a report. 
@@ -214,9 +205,7 @@ def get_report(self, file_id: str, profile_id: str, report_id: str) -> Any: ) return response - def get_report_file( - self, file_id: str, profile_id: str, report_id: str - ) -> http.HttpRequest: + def get_report_file(self, file_id: str, profile_id: str, report_id: str) -> http.HttpRequest: """ Retrieves a media part of report file. diff --git a/airflow/providers/google/marketing_platform/hooks/display_video.py b/airflow/providers/google/marketing_platform/hooks/display_video.py index c3c7364a5b0dc..33d4a06ee8a1f 100644 --- a/airflow/providers/google/marketing_platform/hooks/display_video.py +++ b/airflow/providers/google/marketing_platform/hooks/display_video.py @@ -41,9 +41,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version @@ -54,10 +52,7 @@ def get_conn(self) -> Resource: if not self._conn: http_authorized = self._authorize() self._conn = build( - "doubleclickbidmanager", - self.api_version, - http=http_authorized, - cache_discovery=False, + "doubleclickbidmanager", self.api_version, http=http_authorized, cache_discovery=False, ) return self._conn @@ -67,12 +62,7 @@ def get_conn_to_display_video(self) -> Resource: """ if not self._conn: http_authorized = self._authorize() - self._conn = build( - "displayvideo", - self.api_version, - http=http_authorized, - cache_discovery=False, - ) + self._conn = build("displayvideo", self.api_version, http=http_authorized, cache_discovery=False,) return self._conn @staticmethod @@ -141,7 +131,7 @@ def get_query(self, query_id: str) -> Dict: ) return response - def list_queries(self, ) -> List[Dict]: + def list_queries(self,) -> List[Dict]: """ Retrieves stored queries. 
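The hunks above repeatedly collapse exploded call sites (for example the discovery-client build(...) calls in get_conn) onto a single line once the call fits within the configured line length, which is consistent with running black over these modules. Below is a minimal sketch of that collapse using black's Python API; the line length of 110, the sample source string, and the installed black version are assumptions for illustration, not taken from this patch:

import black

# Illustrative input: an exploded call shaped like the get_conn() bodies above.
# No trailing comma is used here; recent black releases treat a pre-existing
# trailing comma as "magic" and keep such a call exploded instead of joining it.
src = (
    'conn = build(\n'
    '    "displayvideo",\n'
    '    api_version,\n'
    '    http=http_authorized,\n'
    '    cache_discovery=False\n'
    ')\n'
)

# format_str returns the reformatted source; with a 110-character line length
# (an assumed value) the call is joined onto one line, similar to the "+" lines
# in the hunks above.
print(black.format_str(src, mode=black.FileMode(line_length=110)))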
diff --git a/airflow/providers/google/marketing_platform/hooks/search_ads.py b/airflow/providers/google/marketing_platform/hooks/search_ads.py index f6342aaba5021..021f4b9b36053 100644 --- a/airflow/providers/google/marketing_platform/hooks/search_ads.py +++ b/airflow/providers/google/marketing_platform/hooks/search_ads.py @@ -40,9 +40,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version @@ -53,10 +51,7 @@ def get_conn(self): if not self._conn: http_authorized = self._authorize() self._conn = build( - "doubleclicksearch", - self.api_version, - http=http_authorized, - cache_discovery=False, + "doubleclicksearch", self.api_version, http=http_authorized, cache_discovery=False, ) return self._conn diff --git a/airflow/providers/google/marketing_platform/operators/analytics.py b/airflow/providers/google/marketing_platform/operators/analytics.py index 52f5e077ba09a..f6bdc225dfd6c 100644 --- a/airflow/providers/google/marketing_platform/operators/analytics.py +++ b/airflow/providers/google/marketing_platform/operators/analytics.py @@ -65,11 +65,12 @@ class GoogleAnalyticsListAccountsOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, api_version: str = "v3", gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -127,14 +128,15 @@ class GoogleAnalyticsGetAdsLinkOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, account_id: str, web_property_ad_words_link_id: str, web_property_id: str, api_version: str = "v3", gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -196,13 +198,14 @@ class GoogleAnalyticsRetrieveAdsLinksListOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, account_id: str, web_property_id: str, api_version: str = "v3", gcp_conn_id: str = "google_cloud_default", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -218,9 +221,7 @@ def execute(self, context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - result = hook.list_ad_words_links( - account_id=self.account_id, web_property_id=self.web_property_id, - ) + result = hook.list_ad_words_links(account_id=self.account_id, web_property_id=self.web_property_id,) return result @@ -262,11 +263,16 @@ class GoogleAnalyticsDataImportUploadOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("storage_bucket", "storage_name_object", "impersonation_chain",) + template_fields = ( + "storage_bucket", + "storage_name_object", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, storage_bucket: str, storage_name_object: str, account_id: str, @@ -277,7 +283,7 @@ def __init__( delegate_to: Optional[str] = None, api_version: str = "v3", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.storage_bucket = storage_bucket @@ -307,14 +313,10 @@ def execute(self, context): with NamedTemporaryFile("w+") as tmp_file: self.log.info( - 
"Downloading file from GCS: %s/%s ", - self.storage_bucket, - self.storage_name_object, + "Downloading file from GCS: %s/%s ", self.storage_bucket, self.storage_name_object, ) gcs_hook.download( - bucket_name=self.storage_bucket, - object_name=self.storage_name_object, - filename=tmp_file.name, + bucket_name=self.storage_bucket, object_name=self.storage_name_object, filename=tmp_file.name, ) ga_hook.upload_data( @@ -366,7 +368,7 @@ def __init__( delegate_to: Optional[str] = None, api_version: str = "v3", impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -396,10 +398,7 @@ def execute(self, context): delete_request_body = {"customDataImportUids": cids} ga_hook.delete_upload_data( - self.account_id, - self.web_property_id, - self.custom_data_source_id, - delete_request_body, + self.account_id, self.web_property_id, self.custom_data_source_id, delete_request_body, ) @@ -437,7 +436,11 @@ class GoogleAnalyticsModifyFileHeadersDataImportOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("storage_bucket", "storage_name_object", "impersonation_chain",) + template_fields = ( + "storage_bucket", + "storage_name_object", + "impersonation_chain", + ) def __init__( self, @@ -447,11 +450,9 @@ def __init__( delegate_to: Optional[str] = None, custom_dimension_header_mapping: Optional[Dict[str, str]] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): - super(GoogleAnalyticsModifyFileHeadersDataImportOperator, self).__init__( - **kwargs - ) + super(GoogleAnalyticsModifyFileHeadersDataImportOperator, self).__init__(**kwargs) self.storage_bucket = storage_bucket self.storage_name_object = storage_name_object self.gcp_conn_id = gcp_conn_id @@ -501,15 +502,11 @@ def execute(self, context): with NamedTemporaryFile("w+") as tmp_file: # Download file from GCS self.log.info( - "Downloading file from GCS: %s/%s ", - self.storage_bucket, - self.storage_name_object, + "Downloading file from GCS: %s/%s ", self.storage_bucket, self.storage_name_object, ) gcs_hook.download( - bucket_name=self.storage_bucket, - object_name=self.storage_name_object, - filename=tmp_file.name, + bucket_name=self.storage_bucket, object_name=self.storage_name_object, filename=tmp_file.name, ) # Modify file @@ -521,12 +518,8 @@ def execute(self, context): # Upload newly formatted file to cloud storage self.log.info( - "Uploading file to GCS: %s/%s ", - self.storage_bucket, - self.storage_name_object, + "Uploading file to GCS: %s/%s ", self.storage_bucket, self.storage_name_object, ) gcs_hook.upload( - bucket_name=self.storage_bucket, - object_name=self.storage_name_object, - filename=tmp_file.name, + bucket_name=self.storage_bucket, object_name=self.storage_name_object, filename=tmp_file.name, ) diff --git a/airflow/providers/google/marketing_platform/operators/campaign_manager.py b/airflow/providers/google/marketing_platform/operators/campaign_manager.py index 9326876f08694..fb719ca36be5f 100644 --- a/airflow/providers/google/marketing_platform/operators/campaign_manager.py +++ b/airflow/providers/google/marketing_platform/operators/campaign_manager.py @@ -81,7 +81,8 @@ class GoogleCampaignManagerDeleteReportOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, profile_id: str, report_name: Optional[str] = None, report_id: Optional[str] = None, @@ -89,15 +90,13 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, 
impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) if not (report_name or report_id): raise AirflowException("Please provide `report_name` or `report_id`.") if report_name and report_id: - raise AirflowException( - "Please provide only one parameter `report_name` or `report_id`." - ) + raise AirflowException("Please provide only one parameter `report_name` or `report_id`.") self.profile_id = profile_id self.report_name = report_name @@ -188,7 +187,8 @@ class GoogleCampaignManagerDownloadReportOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, + self, + *, profile_id: str, report_id: str, file_id: str, @@ -200,7 +200,7 @@ def __init__( # pylint: disable=too-many-arguments gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.profile_id = profile_id @@ -242,9 +242,7 @@ def execute(self, context: Dict): impersonation_chain=self.impersonation_chain, ) # Get name of the report - report = hook.get_report( - file_id=self.file_id, profile_id=self.profile_id, report_id=self.report_id - ) + report = hook.get_report(file_id=self.file_id, profile_id=self.profile_id, report_id=self.report_id) report_name = self.report_name or report.get("fileName", str(uuid.uuid4())) report_name = self._resolve_file_name(report_name) @@ -254,9 +252,7 @@ def execute(self, context: Dict): profile_id=self.profile_id, report_id=self.report_id, file_id=self.file_id ) with tempfile.NamedTemporaryFile() as temp_file: - downloader = http.MediaIoBaseDownload( - fd=temp_file, request=request, chunksize=self.chunk_size - ) + downloader = http.MediaIoBaseDownload(fd=temp_file, request=request, chunksize=self.chunk_size) download_finished = False while not download_finished: _, download_finished = downloader.next_chunk() @@ -322,14 +318,15 @@ class GoogleCampaignManagerInsertReportOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, profile_id: str, report: Dict[str, Any], api_version: str = "v3.3", gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.profile_id = profile_id @@ -353,9 +350,7 @@ def execute(self, context: Dict): impersonation_chain=self.impersonation_chain, ) self.log.info("Inserting Campaign Manager report.") - response = hook.insert_report( - profile_id=self.profile_id, report=self.report - ) + response = hook.insert_report(profile_id=self.profile_id, report=self.report) report_id = response.get("id") self.xcom_push(context, key="report_id", value=report_id) self.log.info("Report successfully inserted. 
Report id: %s", report_id) @@ -411,7 +406,8 @@ class GoogleCampaignManagerRunReportOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, profile_id: str, report_id: str, synchronous: bool = False, @@ -419,7 +415,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.profile_id = profile_id @@ -439,9 +435,7 @@ def execute(self, context: Dict): ) self.log.info("Running report %s", self.report_id) response = hook.run_report( - profile_id=self.profile_id, - report_id=self.report_id, - synchronous=self.synchronous, + profile_id=self.profile_id, report_id=self.report_id, synchronous=self.synchronous, ) file_id = response.get("id") self.xcom_push(context, key="file_id", value=file_id) @@ -507,7 +501,8 @@ class GoogleCampaignManagerBatchInsertConversionsOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, profile_id: str, conversions: List[Dict[str, Any]], encryption_entity_type: str, @@ -518,7 +513,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.profile_id = profile_id @@ -545,7 +540,7 @@ def execute(self, context: Dict): encryption_entity_type=self.encryption_entity_type, encryption_entity_id=self.encryption_entity_id, encryption_source=self.encryption_source, - max_failed_inserts=self.max_failed_inserts + max_failed_inserts=self.max_failed_inserts, ) return response @@ -608,7 +603,8 @@ class GoogleCampaignManagerBatchUpdateConversionsOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, profile_id: str, conversions: List[Dict[str, Any]], encryption_entity_type: str, @@ -619,7 +615,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.profile_id = profile_id @@ -646,6 +642,6 @@ def execute(self, context: Dict): encryption_entity_type=self.encryption_entity_type, encryption_entity_id=self.encryption_entity_id, encryption_source=self.encryption_source, - max_failed_updates=self.max_failed_updates + max_failed_updates=self.max_failed_updates, ) return response diff --git a/airflow/providers/google/marketing_platform/operators/display_video.py b/airflow/providers/google/marketing_platform/operators/display_video.py index 678af81b49835..92708d9ba1b8b 100644 --- a/airflow/providers/google/marketing_platform/operators/display_video.py +++ b/airflow/providers/google/marketing_platform/operators/display_video.py @@ -67,12 +67,16 @@ class GoogleDisplayVideo360CreateReportOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("body", "impersonation_chain",) + template_fields = ( + "body", + "impersonation_chain", + ) template_ext = (".json",) @apply_defaults def __init__( - self, *, + self, + *, body: Dict[str, Any], api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", @@ -143,11 +147,15 @@ class GoogleDisplayVideo360DeleteReportOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("report_id", "impersonation_chain",) + template_fields = ( + "report_id", + "impersonation_chain", + ) @apply_defaults def __init__( - 
self, *, + self, + *, report_id: Optional[str] = None, report_name: Optional[str] = None, api_version: str = "v1", @@ -168,9 +176,7 @@ def __init__( raise AirflowException("Use only one value - `report_name` or `report_id`.") if not (report_name or report_id): - raise AirflowException( - "Provide one of the values: `report_name` or `report_id`." - ) + raise AirflowException("Provide one of the values: `report_name` or `report_id`.") def execute(self, context: Dict): hook = GoogleDisplayVideo360Hook( @@ -184,9 +190,7 @@ def execute(self, context: Dict): else: reports = hook.list_queries() reports_ids_to_delete = [ - report["queryId"] - for report in reports - if report["metadata"]["title"] == self.report_name + report["queryId"] for report in reports if report["metadata"]["title"] == self.report_name ] for report_id in reports_ids_to_delete: @@ -236,11 +240,17 @@ class GoogleDisplayVideo360DownloadReportOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("report_id", "bucket_name", "report_name", "impersonation_chain",) + template_fields = ( + "report_id", + "bucket_name", + "report_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, report_id: str, bucket_name: str, report_name: Optional[str] = None, @@ -312,10 +322,7 @@ def execute(self, context: Dict): mime_type="text/csv", ) self.log.info( - "Report %s was saved in bucket %s as %s.", - self.report_id, - self.bucket_name, - report_name, + "Report %s was saved in bucket %s as %s.", self.report_id, self.bucket_name, report_name, ) self.xcom_push(context, key="report_name", value=report_name) @@ -356,11 +363,16 @@ class GoogleDisplayVideo360RunReportOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("report_id", "params", "impersonation_chain",) + template_fields = ( + "report_id", + "params", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, report_id: str, params: Dict[str, Any], api_version: str = "v1", @@ -385,9 +397,7 @@ def execute(self, context: Dict): impersonation_chain=self.impersonation_chain, ) self.log.info( - "Running report %s with the following params:\n %s", - self.report_id, - self.params, + "Running report %s with the following params:\n %s", self.report_id, self.params, ) hook.run_query(query_id=self.report_id, params=self.params) @@ -410,11 +420,17 @@ class GoogleDisplayVideo360DownloadLineItemsOperator(BaseOperator): :type request_body: Dict[str, Any], """ - template_fields = ("request_body", "bucket_name", "object_name", "impersonation_chain",) + template_fields = ( + "request_body", + "bucket_name", + "object_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, request_body: Dict[str, Any], bucket_name: str, object_name: str, @@ -423,7 +439,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.request_body = request_body @@ -497,7 +513,8 @@ class GoogleDisplayVideo360UploadLineItemsOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, bucket_name: str, object_name: str, api_version: str = "v1.1", @@ -532,9 +549,7 @@ def execute(self, context: Dict): # downloaded file from the GCS could be a 1GB size or even more with tempfile.NamedTemporaryFile("w+") as f: line_items = gcs_hook.download( - 
bucket_name=self.bucket_name, - object_name=self.object_name, - filename=f.name, + bucket_name=self.bucket_name, object_name=self.object_name, filename=f.name, ) f.flush() hook.upload_line_items(line_items=line_items) @@ -581,11 +596,15 @@ class GoogleDisplayVideo360CreateSDFDownloadTaskOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("body_request", "impersonation_chain",) + template_fields = ( + "body_request", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, body_request: Dict[str, Any], api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", @@ -609,9 +628,7 @@ def execute(self, context: Dict): ) self.log.info("Creating operation for SDF download task...") - operation = hook.create_sdf_download_operation( - body_request=self.body_request - ) + operation = hook.create_sdf_download_operation(body_request=self.body_request) return operation @@ -657,11 +674,17 @@ class GoogleDisplayVideo360SDFtoGCSOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("operation_name", "bucket_name", "object_name", "impersonation_chain",) + template_fields = ( + "operation_name", + "bucket_name", + "object_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, operation_name: str, bucket_name: str, object_name: str, @@ -703,9 +726,7 @@ def execute(self, context: Dict): self.log.info("Sending file to the Google Cloud Storage...") with tempfile.NamedTemporaryFile() as temp_file: - hook.download_content_from_request( - temp_file, media, chunk_size=1024 * 1024 - ) + hook.download_content_from_request(temp_file, media, chunk_size=1024 * 1024) temp_file.flush() gcs_hook.upload( bucket_name=self.bucket_name, diff --git a/airflow/providers/google/marketing_platform/operators/search_ads.py b/airflow/providers/google/marketing_platform/operators/search_ads.py index 0645e214887e3..02a32d79e1894 100644 --- a/airflow/providers/google/marketing_platform/operators/search_ads.py +++ b/airflow/providers/google/marketing_platform/operators/search_ads.py @@ -62,18 +62,22 @@ class GoogleSearchAdsInsertReportOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("report", "impersonation_chain",) + template_fields = ( + "report", + "impersonation_chain", + ) template_ext = (".json",) @apply_defaults def __init__( - self, *, + self, + *, report: Dict[str, Any], api_version: str = "v2", gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.report = report @@ -143,11 +147,17 @@ class GoogleSearchAdsDownloadReportOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("report_name", "report_id", "bucket_name", "impersonation_chain",) + template_fields = ( + "report_name", + "report_id", + "bucket_name", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, report_id: str, bucket_name: str, report_name: Optional[str] = None, @@ -157,7 +167,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.report_id = report_id @@ -221,14 +231,8 @@ def execute(self, context: Dict): self.log.info("Downloading Search Ads report %s", 
self.report_id) with NamedTemporaryFile() as temp_file: for i in range(fragments_count): - byte_content = hook.get_file( - report_fragment=i, report_id=self.report_id - ) - fragment = ( - byte_content - if i == 0 - else self._handle_report_fragment(byte_content) - ) + byte_content = hook.get_file(report_fragment=i, report_id=self.report_id) + fragment = byte_content if i == 0 else self._handle_report_fragment(byte_content) temp_file.write(fragment) temp_file.flush() diff --git a/airflow/providers/google/marketing_platform/sensors/campaign_manager.py b/airflow/providers/google/marketing_platform/sensors/campaign_manager.py index 3207db91726e6..b0519c1fbc13f 100644 --- a/airflow/providers/google/marketing_platform/sensors/campaign_manager.py +++ b/airflow/providers/google/marketing_platform/sensors/campaign_manager.py @@ -62,7 +62,12 @@ class GoogleCampaignManagerReportSensor(BaseSensorOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("profile_id", "report_id", "file_id", "impersonation_chain",) + template_fields = ( + "profile_id", + "report_id", + "file_id", + "impersonation_chain", + ) def poke(self, context: Dict) -> bool: hook = GoogleCampaignManagerHook( @@ -71,15 +76,14 @@ def poke(self, context: Dict) -> bool: api_version=self.api_version, impersonation_chain=self.impersonation_chain, ) - response = hook.get_report( - profile_id=self.profile_id, report_id=self.report_id, file_id=self.file_id - ) + response = hook.get_report(profile_id=self.profile_id, report_id=self.report_id, file_id=self.file_id) self.log.info("Report status: %s", response["status"]) return response["status"] != "PROCESSING" @apply_defaults def __init__( - self, *, + self, + *, profile_id: str, report_id: str, file_id: str, @@ -89,7 +93,7 @@ def __init__( mode: str = "reschedule", poke_interval: int = 60 * 5, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.mode = mode diff --git a/airflow/providers/google/marketing_platform/sensors/display_video.py b/airflow/providers/google/marketing_platform/sensors/display_video.py index 01a175f90c0bf..2de6ab4e73e8f 100644 --- a/airflow/providers/google/marketing_platform/sensors/display_video.py +++ b/airflow/providers/google/marketing_platform/sensors/display_video.py @@ -54,16 +54,20 @@ class GoogleDisplayVideo360ReportSensor(BaseSensorOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("report_id", "impersonation_chain",) + template_fields = ( + "report_id", + "impersonation_chain", + ) def __init__( - self, *, + self, + *, report_id: str, api_version: str = "v1", gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -117,7 +121,10 @@ class GoogleDisplayVideo360GetSDFDownloadOperationSensor(BaseSensorOperator): """ - template_fields = ("operation_name", "impersonation_chain",) + template_fields = ( + "operation_name", + "impersonation_chain", + ) def __init__( self, @@ -129,7 +136,7 @@ def __init__( poke_interval: int = 60 * 5, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, *args, - **kwargs + **kwargs, ): super().__init__(*args, **kwargs) self.mode = mode diff --git a/airflow/providers/google/marketing_platform/sensors/search_ads.py b/airflow/providers/google/marketing_platform/sensors/search_ads.py index 450fe0fca6561..fe8e6011615f1 100644 --- 
a/airflow/providers/google/marketing_platform/sensors/search_ads.py +++ b/airflow/providers/google/marketing_platform/sensors/search_ads.py @@ -58,11 +58,15 @@ class GoogleSearchAdsReportSensor(BaseSensorOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("report_id", "impersonation_chain",) + template_fields = ( + "report_id", + "impersonation_chain", + ) @apply_defaults def __init__( - self, *, + self, + *, report_id: str, api_version: str = "v2", gcp_conn_id: str = "google_cloud_default", @@ -70,7 +74,7 @@ def __init__( mode: str = "reschedule", poke_interval: int = 5 * 60, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(mode=mode, poke_interval=poke_interval, **kwargs) self.report_id = report_id diff --git a/airflow/providers/google/suite/example_dags/example_gcs_to_sheets.py b/airflow/providers/google/suite/example_dags/example_gcs_to_sheets.py index 1b5eefda33a55..128c0ed00e68e 100644 --- a/airflow/providers/google/suite/example_dags/example_gcs_to_sheets.py +++ b/airflow/providers/google/suite/example_dags/example_gcs_to_sheets.py @@ -35,9 +35,7 @@ ) as dag: upload_sheet_to_gcs = GoogleSheetsToGCSOperator( - task_id="upload_sheet_to_gcs", - destination_bucket=BUCKET, - spreadsheet_id=SPREADSHEET_ID, + task_id="upload_sheet_to_gcs", destination_bucket=BUCKET, spreadsheet_id=SPREADSHEET_ID, ) # [START upload_gcs_to_sheets] diff --git a/airflow/providers/google/suite/example_dags/example_sheets.py b/airflow/providers/google/suite/example_dags/example_sheets.py index 171cdcfeb9032..a95157bb8c2c8 100644 --- a/airflow/providers/google/suite/example_dags/example_sheets.py +++ b/airflow/providers/google/suite/example_dags/example_sheets.py @@ -42,9 +42,7 @@ ) as dag: # [START upload_sheet_to_gcs] upload_sheet_to_gcs = GoogleSheetsToGCSOperator( - task_id="upload_sheet_to_gcs", - destination_bucket=GCS_BUCKET, - spreadsheet_id=SPREADSHEET_ID, + task_id="upload_sheet_to_gcs", destination_bucket=GCS_BUCKET, spreadsheet_id=SPREADSHEET_ID, ) # [END upload_sheet_to_gcs] diff --git a/airflow/providers/google/suite/hooks/drive.py b/airflow/providers/google/suite/hooks/drive.py index bae4c7599c17a..ab94bb8f14adb 100644 --- a/airflow/providers/google/suite/hooks/drive.py +++ b/airflow/providers/google/suite/hooks/drive.py @@ -57,9 +57,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.api_version = api_version diff --git a/airflow/providers/google/suite/hooks/sheets.py b/airflow/providers/google/suite/hooks/sheets.py index 2f973b3d8728a..9604aa829e4c0 100644 --- a/airflow/providers/google/suite/hooks/sheets.py +++ b/airflow/providers/google/suite/hooks/sheets.py @@ -61,9 +61,7 @@ def __init__( impersonation_chain: Optional[Union[str, Sequence[str]]] = None, ) -> None: super().__init__( - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, + gcp_conn_id=gcp_conn_id, delegate_to=delegate_to, impersonation_chain=impersonation_chain, ) self.gcp_conn_id = gcp_conn_id self.api_version = api_version @@ -89,7 +87,7 @@ def get_values( range_: str, major_dimension: str = 'DIMENSION_UNSPECIFIED', value_render_option: str = 'FORMATTED_VALUE', - date_time_render_option: str = 'SERIAL_NUMBER' + date_time_render_option: str 
= 'SERIAL_NUMBER', ) -> List: """ Gets values from Google Sheet from a single range @@ -112,13 +110,19 @@ def get_values( :rtype: List """ service = self.get_conn() - response = service.spreadsheets().values().get( # pylint: disable=no-member - spreadsheetId=spreadsheet_id, - range=range_, - majorDimension=major_dimension, - valueRenderOption=value_render_option, - dateTimeRenderOption=date_time_render_option - ).execute(num_retries=self.num_retries) + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .get( + spreadsheetId=spreadsheet_id, + range=range_, + majorDimension=major_dimension, + valueRenderOption=value_render_option, + dateTimeRenderOption=date_time_render_option, + ) + .execute(num_retries=self.num_retries) + ) return response['values'] @@ -128,7 +132,7 @@ def batch_get_values( ranges: List, major_dimension: str = 'DIMENSION_UNSPECIFIED', value_render_option: str = 'FORMATTED_VALUE', - date_time_render_option: str = 'SERIAL_NUMBER' + date_time_render_option: str = 'SERIAL_NUMBER', ) -> Dict: """ Gets values from Google Sheet from a list of ranges @@ -151,13 +155,19 @@ def batch_get_values( :rtype: Dict """ service = self.get_conn() - response = service.spreadsheets().values().batchGet( # pylint: disable=no-member - spreadsheetId=spreadsheet_id, - ranges=ranges, - majorDimension=major_dimension, - valueRenderOption=value_render_option, - dateTimeRenderOption=date_time_render_option - ).execute(num_retries=self.num_retries) + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .batchGet( + spreadsheetId=spreadsheet_id, + ranges=ranges, + majorDimension=major_dimension, + valueRenderOption=value_render_option, + dateTimeRenderOption=date_time_render_option, + ) + .execute(num_retries=self.num_retries) + ) return response @@ -170,7 +180,7 @@ def update_values( value_input_option: str = 'RAW', include_values_in_response: bool = False, value_render_option: str = 'FORMATTED_VALUE', - date_time_render_option: str = 'SERIAL_NUMBER' + date_time_render_option: str = 'SERIAL_NUMBER', ) -> Dict: """ Updates values from Google Sheet from a single range @@ -201,20 +211,22 @@ def update_values( :rtype: Dict """ service = self.get_conn() - body = { - "range": range_, - "majorDimension": major_dimension, - "values": values - } - response = service.spreadsheets().values().update( # pylint: disable=no-member - spreadsheetId=spreadsheet_id, - range=range_, - valueInputOption=value_input_option, - includeValuesInResponse=include_values_in_response, - responseValueRenderOption=value_render_option, - responseDateTimeRenderOption=date_time_render_option, - body=body - ).execute(num_retries=self.num_retries) + body = {"range": range_, "majorDimension": major_dimension, "values": values} + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .update( + spreadsheetId=spreadsheet_id, + range=range_, + valueInputOption=value_input_option, + includeValuesInResponse=include_values_in_response, + responseValueRenderOption=value_render_option, + responseDateTimeRenderOption=date_time_render_option, + body=body, + ) + .execute(num_retries=self.num_retries) + ) return response @@ -227,7 +239,7 @@ def batch_update_values( value_input_option: str = 'RAW', include_values_in_response: bool = False, value_render_option: str = 'FORMATTED_VALUE', - date_time_render_option: str = 'SERIAL_NUMBER' + date_time_render_option: str = 'SERIAL_NUMBER', ) -> Dict: """ Updates values from Google Sheet for multiple ranges @@ -261,27 +273,29 
@@ def batch_update_values( raise AirflowException( "'Ranges' and and 'Lists' must be of equal length. \n \ 'Ranges' is of length: {} and \n \ - 'Values' is of length: {}.".format(str(len(ranges)), str(len(values)))) + 'Values' is of length: {}.".format( + str(len(ranges)), str(len(values)) + ) + ) service = self.get_conn() data = [] for idx, range_ in enumerate(ranges): - value_range = { - "range": range_, - "majorDimension": major_dimension, - "values": values[idx] - } + value_range = {"range": range_, "majorDimension": major_dimension, "values": values[idx]} data.append(value_range) body = { "valueInputOption": value_input_option, "data": data, "includeValuesInResponse": include_values_in_response, "responseValueRenderOption": value_render_option, - "responseDateTimeRenderOption": date_time_render_option + "responseDateTimeRenderOption": date_time_render_option, } - response = service.spreadsheets().values().batchUpdate( # pylint: disable=no-member - spreadsheetId=spreadsheet_id, - body=body - ).execute(num_retries=self.num_retries) + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .batchUpdate(spreadsheetId=spreadsheet_id, body=body) + .execute(num_retries=self.num_retries) + ) return response @@ -295,7 +309,7 @@ def append_values( insert_data_option: str = 'OVERWRITE', include_values_in_response: bool = False, value_render_option: str = 'FORMATTED_VALUE', - date_time_render_option: str = 'SERIAL_NUMBER' + date_time_render_option: str = 'SERIAL_NUMBER', ) -> Dict: """ Append values from Google Sheet from a single range @@ -329,21 +343,23 @@ def append_values( :rtype: Dict """ service = self.get_conn() - body = { - "range": range_, - "majorDimension": major_dimension, - "values": values - } - response = service.spreadsheets().values().append( # pylint: disable=no-member - spreadsheetId=spreadsheet_id, - range=range_, - valueInputOption=value_input_option, - insertDataOption=insert_data_option, - includeValuesInResponse=include_values_in_response, - responseValueRenderOption=value_render_option, - responseDateTimeRenderOption=date_time_render_option, - body=body - ).execute(num_retries=self.num_retries) + body = {"range": range_, "majorDimension": major_dimension, "values": values} + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .append( + spreadsheetId=spreadsheet_id, + range=range_, + valueInputOption=value_input_option, + insertDataOption=insert_data_option, + includeValuesInResponse=include_values_in_response, + responseValueRenderOption=value_render_option, + responseDateTimeRenderOption=date_time_render_option, + body=body, + ) + .execute(num_retries=self.num_retries) + ) return response @@ -360,10 +376,13 @@ def clear(self, spreadsheet_id: str, range_: str) -> Dict: :rtype: Dict """ service = self.get_conn() - response = service.spreadsheets().values().clear( # pylint: disable=no-member - spreadsheetId=spreadsheet_id, - range=range_ - ).execute(num_retries=self.num_retries) + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .clear(spreadsheetId=spreadsheet_id, range=range_) + .execute(num_retries=self.num_retries) + ) return response @@ -380,13 +399,14 @@ def batch_clear(self, spreadsheet_id: str, ranges: List) -> Dict: :rtype: Dict """ service = self.get_conn() - body = { - "ranges": ranges - } - response = service.spreadsheets().values().batchClear( # pylint: disable=no-member - spreadsheetId=spreadsheet_id, - body=body - ).execute(num_retries=self.num_retries) + body = 
{"ranges": ranges} + # pylint: disable=no-member + response = ( + service.spreadsheets() + .values() + .batchClear(spreadsheetId=spreadsheet_id, body=body) + .execute(num_retries=self.num_retries) + ) return response @@ -421,7 +441,8 @@ def get_sheet_titles(self, spreadsheet_id: str, sheet_filter: Optional[List[str] if sheet_filter: titles = [ - sh['properties']['title'] for sh in response['sheets'] + sh['properties']['title'] + for sh in response['sheets'] if sh['properties']['title'] in sheet_filter ] else: @@ -438,11 +459,9 @@ def create_spreadsheet(self, spreadsheet: Dict[str, Any]) -> Dict[str, Any]: :return: An spreadsheet object. """ self.log.info("Creating spreadsheet: %s", spreadsheet['properties']['title']) + # pylint: disable=no-member response = ( - self.get_conn() # pylint: disable=no-member - .spreadsheets() - .create(body=spreadsheet) - .execute(num_retries=self.num_retries) + self.get_conn().spreadsheets().create(body=spreadsheet).execute(num_retries=self.num_retries) ) self.log.info("Spreadsheet: %s created", spreadsheet['properties']['title']) return response diff --git a/airflow/providers/google/suite/operators/sheets.py b/airflow/providers/google/suite/operators/sheets.py index ee5ff0575c3ce..4ec8271a9482b 100644 --- a/airflow/providers/google/suite/operators/sheets.py +++ b/airflow/providers/google/suite/operators/sheets.py @@ -50,11 +50,15 @@ class GoogleSheetsCreateSpreadsheetOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ["spreadsheet", "impersonation_chain", ] + template_fields = [ + "spreadsheet", + "impersonation_chain", + ] @apply_defaults def __init__( - self, *, + self, + *, spreadsheet: Dict[str, Any], gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, diff --git a/airflow/providers/google/suite/transfers/gcs_to_gdrive.py b/airflow/providers/google/suite/transfers/gcs_to_gdrive.py index c525bfbeade6f..c10df4ff1dea5 100644 --- a/airflow/providers/google/suite/transfers/gcs_to_gdrive.py +++ b/airflow/providers/google/suite/transfers/gcs_to_gdrive.py @@ -83,13 +83,18 @@ class GCSToGoogleDriveOperator(BaseOperator): :type impersonation_chain: Union[str, Sequence[str]] """ - template_fields = ("source_bucket", "source_object", "destination_object", - "impersonation_chain",) + template_fields = ( + "source_bucket", + "source_object", + "destination_object", + "impersonation_chain", + ) ui_color = "#f0eee4" @apply_defaults def __init__( - self, *, + self, + *, source_bucket: str, source_object: str, destination_object: Optional[str] = None, @@ -97,7 +102,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) diff --git a/airflow/providers/google/suite/transfers/gcs_to_sheets.py b/airflow/providers/google/suite/transfers/gcs_to_sheets.py index 4fcedfa59bdd8..e09f41da4a24a 100644 --- a/airflow/providers/google/suite/transfers/gcs_to_sheets.py +++ b/airflow/providers/google/suite/transfers/gcs_to_sheets.py @@ -68,7 +68,8 @@ class GCSToGoogleSheetsOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, spreadsheet_id: str, bucket_name: str, object_name: Optional[str] = None, @@ -102,15 +103,11 @@ def execute(self, context: Any): with NamedTemporaryFile("w+") as temp_file: # Download data gcs_hook.download( - bucket_name=self.bucket_name, - object_name=self.object_name, - filename=temp_file.name, + 
bucket_name=self.bucket_name, object_name=self.object_name, filename=temp_file.name, ) # Upload data values = list(csv.reader(temp_file)) sheet_hook.update_values( - spreadsheet_id=self.spreadsheet_id, - range_=self.spreadsheet_range, - values=values, + spreadsheet_id=self.spreadsheet_id, range_=self.spreadsheet_range, values=values, ) diff --git a/airflow/providers/grpc/hooks/grpc.py b/airflow/providers/grpc/hooks/grpc.py index 54b68febfccea..ccf8c2e413a26 100644 --- a/airflow/providers/grpc/hooks/grpc.py +++ b/airflow/providers/grpc/hooks/grpc.py @@ -22,7 +22,8 @@ from google import auth as google_auth from google.auth import jwt as google_auth_jwt from google.auth.transport import ( - grpc as google_auth_transport_grpc, requests as google_auth_transport_requests, + grpc as google_auth_transport_grpc, + requests as google_auth_transport_requests, ) from airflow.exceptions import AirflowConfigException @@ -46,10 +47,12 @@ class GrpcHook(BaseHook): its only arg. Could be partial or lambda. """ - def __init__(self, - grpc_conn_id: str, - interceptors: Optional[List[Callable]] = None, - custom_connection_func: Optional[Callable] = None) -> None: + def __init__( + self, + grpc_conn_id: str, + interceptors: Optional[List[Callable]] = None, + custom_connection_func: Optional[Callable] = None, + ) -> None: super().__init__() self.grpc_conn_id = grpc_conn_id self.conn = self.get_connection(self.grpc_conn_id) @@ -73,38 +76,35 @@ def get_conn(self) -> grpc.Channel: channel = grpc.secure_channel(base_url, creds) elif auth_type == "JWT_GOOGLE": credentials, _ = google_auth.default() - jwt_creds = google_auth_jwt.OnDemandCredentials.from_signing_credentials( - credentials) - channel = google_auth_transport_grpc.secure_authorized_channel( - jwt_creds, None, base_url) + jwt_creds = google_auth_jwt.OnDemandCredentials.from_signing_credentials(credentials) + channel = google_auth_transport_grpc.secure_authorized_channel(jwt_creds, None, base_url) elif auth_type == "OATH_GOOGLE": scopes = self._get_field("scopes").split(",") credentials, _ = google_auth.default(scopes=scopes) request = google_auth_transport_requests.Request() - channel = google_auth_transport_grpc.secure_authorized_channel( - credentials, request, base_url) + channel = google_auth_transport_grpc.secure_authorized_channel(credentials, request, base_url) elif auth_type == "CUSTOM": if not self.custom_connection_func: raise AirflowConfigException( - "Customized connection function not set, not able to establish a channel") + "Customized connection function not set, not able to establish a channel" + ) channel = self.custom_connection_func(self.conn) else: raise AirflowConfigException( "auth_type not supported or not provided, channel cannot be established,\ - given value: %s" % str(auth_type)) + given value: %s" + % str(auth_type) + ) if self.interceptors: for interceptor in self.interceptors: - channel = grpc.intercept_channel(channel, - interceptor) + channel = grpc.intercept_channel(channel, interceptor) return channel - def run(self, - stub_class: Callable, - call_func: str, - streaming: bool = False, - data: Optional[dict] = None) -> Generator: + def run( + self, stub_class: Callable, call_func: str, streaming: bool = False, data: Optional[dict] = None + ) -> Generator: """ Call gRPC function and yield response to caller """ @@ -126,7 +126,7 @@ def run(self, stub.__class__.__name__, call_func, ex.code(), # pylint: disable=no-member - ex.details() # pylint: disable=no-member + ex.details(), # pylint: disable=no-member ) raise ex diff 
--git a/airflow/providers/grpc/operators/grpc.py b/airflow/providers/grpc/operators/grpc.py index 10ecd49e7635c..2d54c2b107cfa 100644 --- a/airflow/providers/grpc/operators/grpc.py +++ b/airflow/providers/grpc/operators/grpc.py @@ -53,17 +53,20 @@ class GrpcOperator(BaseOperator): template_fields = ('stub_class', 'call_func', 'data') @apply_defaults - def __init__(self, *, - stub_class: Callable, - call_func: str, - grpc_conn_id: str = "grpc_default", - data: Optional[dict] = None, - interceptors: Optional[List[Callable]] = None, - custom_connection_func: Optional[Callable] = None, - streaming: bool = False, - response_callback: Optional[Callable] = None, - log_response: bool = False, - **kwargs) -> None: + def __init__( + self, + *, + stub_class: Callable, + call_func: str, + grpc_conn_id: str = "grpc_default", + data: Optional[dict] = None, + interceptors: Optional[List[Callable]] = None, + custom_connection_func: Optional[Callable] = None, + streaming: bool = False, + response_callback: Optional[Callable] = None, + log_response: bool = False, + **kwargs, + ) -> None: super().__init__(**kwargs) self.stub_class = stub_class self.call_func = call_func @@ -79,7 +82,7 @@ def _get_grpc_hook(self) -> GrpcHook: return GrpcHook( self.grpc_conn_id, interceptors=self.interceptors, - custom_connection_func=self.custom_connection_func + custom_connection_func=self.custom_connection_func, ) def execute(self, context: Dict) -> None: diff --git a/airflow/providers/hashicorp/_internal_client/vault_client.py b/airflow/providers/hashicorp/_internal_client/vault_client.py index 9c2fa48bb9b32..f573c2420a8f9 100644 --- a/airflow/providers/hashicorp/_internal_client/vault_client.py +++ b/airflow/providers/hashicorp/_internal_client/vault_client.py @@ -38,7 +38,7 @@ 'ldap', 'radius', 'token', - 'userpass' + 'userpass', ] @@ -104,6 +104,7 @@ class _VaultClient(LoggingMixin): # pylint: disable=too-many-instance-attribute :param radius_port: Port for radius (for ``radius`` auth_type). :type radius_port: int """ + def __init__( # pylint: disable=too-many-arguments self, url: Optional[str] = None, @@ -128,15 +129,18 @@ def __init__( # pylint: disable=too-many-arguments radius_host: Optional[str] = None, radius_secret: Optional[str] = None, radius_port: Optional[int] = None, - **kwargs + **kwargs, ): super().__init__() if kv_engine_version and kv_engine_version not in VALID_KV_VERSIONS: - raise VaultError(f"The version is not supported: {kv_engine_version}. " - f"It should be one of {VALID_KV_VERSIONS}") + raise VaultError( + f"The version is not supported: {kv_engine_version}. " + f"It should be one of {VALID_KV_VERSIONS}" + ) if auth_type not in VALID_AUTH_TYPES: - raise VaultError(f"The auth_type is not supported: {auth_type}. " - f"It should be one of {VALID_AUTH_TYPES}") + raise VaultError( + f"The auth_type is not supported: {auth_type}. 
" f"It should be one of {VALID_AUTH_TYPES}" + ) if auth_type == "token" and not token and not token_path: raise VaultError("The 'token' authentication type requires 'token' or 'token_path'") if auth_type == "github" and not token and not token_path: @@ -223,29 +227,32 @@ def client(self) -> hvac.Client: def _auth_userpass(self, _client: hvac.Client) -> None: if self.auth_mount_point: - _client.auth_userpass(username=self.username, password=self.password, - mount_point=self.auth_mount_point) + _client.auth_userpass( + username=self.username, password=self.password, mount_point=self.auth_mount_point + ) else: _client.auth_userpass(username=self.username, password=self.password) def _auth_radius(self, _client: hvac.Client) -> None: if self.auth_mount_point: - _client.auth.radius.configure(host=self.radius_host, - secret=self.radius_secret, - port=self.radius_port, - mount_point=self.auth_mount_point) + _client.auth.radius.configure( + host=self.radius_host, + secret=self.radius_secret, + port=self.radius_port, + mount_point=self.auth_mount_point, + ) else: - _client.auth.radius.configure(host=self.radius_host, - secret=self.radius_secret, - port=self.radius_port) + _client.auth.radius.configure( + host=self.radius_host, secret=self.radius_secret, port=self.radius_port + ) def _auth_ldap(self, _client: hvac.Client) -> None: if self.auth_mount_point: _client.auth.ldap.login( - username=self.username, password=self.password, mount_point=self.auth_mount_point) + username=self.username, password=self.password, mount_point=self.auth_mount_point + ) else: - _client.auth.ldap.login( - username=self.username, password=self.password) + _client.auth.ldap.login(username=self.username, password=self.password) def _auth_kubernetes(self, _client: hvac.Client) -> None: if not self.kubernetes_jwt_path: @@ -253,8 +260,7 @@ def _auth_kubernetes(self, _client: hvac.Client) -> None: with open(self.kubernetes_jwt_path) as f: jwt = f.read() if self.auth_mount_point: - _client.auth_kubernetes(role=self.kubernetes_role, jwt=jwt, - mount_point=self.auth_mount_point) + _client.auth_kubernetes(role=self.kubernetes_role, jwt=jwt, mount_point=self.auth_mount_point) else: _client.auth_kubernetes(role=self.kubernetes_role, jwt=jwt) @@ -266,12 +272,14 @@ def _auth_github(self, _client: hvac.Client) -> None: def _auth_gcp(self, _client: hvac.Client) -> None: from airflow.providers.google.cloud.utils.credentials_provider import ( # noqa - _get_scopes, get_credentials_and_project_id, + _get_scopes, + get_credentials_and_project_id, ) + scopes = _get_scopes(self.gcp_scopes) - credentials, _ = get_credentials_and_project_id(key_path=self.gcp_key_path, - keyfile_dict=self.gcp_keyfile_dict, - scopes=scopes) + credentials, _ = get_credentials_and_project_id( + key_path=self.gcp_key_path, keyfile_dict=self.gcp_keyfile_dict, scopes=scopes + ) if self.auth_mount_point: _client.auth.gcp.configure(credentials=credentials, mount_point=self.auth_mount_point) else: @@ -284,28 +292,32 @@ def _auth_azure(self, _client: hvac.Client) -> None: resource=self.azure_resource, client_id=self.key_id, client_secret=self.secret_id, - mount_point=self.auth_mount_point + mount_point=self.auth_mount_point, ) else: _client.auth.azure.configure( tenant_id=self.azure_tenant_id, resource=self.azure_resource, client_id=self.key_id, - client_secret=self.secret_id + client_secret=self.secret_id, ) def _auth_aws_iam(self, _client: hvac.Client) -> None: if self.auth_mount_point: - _client.auth_aws_iam(access_key=self.key_id, secret_key=self.secret_id, - 
role=self.role_id, mount_point=self.auth_mount_point) + _client.auth_aws_iam( + access_key=self.key_id, + secret_key=self.secret_id, + role=self.role_id, + mount_point=self.auth_mount_point, + ) else: - _client.auth_aws_iam(access_key=self.key_id, secret_key=self.secret_id, - role=self.role_id) + _client.auth_aws_iam(access_key=self.key_id, secret_key=self.secret_id, role=self.role_id) def _auth_approle(self, _client: hvac.Client) -> None: if self.auth_mount_point: - _client.auth_approle(role_id=self.role_id, secret_id=self.secret_id, - mount_point=self.auth_mount_point) + _client.auth_approle( + role_id=self.role_id, secret_id=self.secret_id, mount_point=self.auth_mount_point + ) else: _client.auth_approle(role_id=self.role_id, secret_id=self.secret_id) @@ -336,10 +348,12 @@ def get_secret(self, secret_path: str, secret_version: Optional[int] = None) -> if secret_version: raise VaultError("Secret version can only be used with version 2 of the KV engine") response = self.client.secrets.kv.v1.read_secret( - path=secret_path, mount_point=self.mount_point) + path=secret_path, mount_point=self.mount_point + ) else: response = self.client.secrets.kv.v2.read_secret_version( - path=secret_path, mount_point=self.mount_point, version=secret_version) + path=secret_path, mount_point=self.mount_point, version=secret_version + ) except InvalidPath: self.log.debug("Secret not found %s with mount point %s", secret_path, self.mount_point) return None @@ -363,15 +377,15 @@ def get_secret_metadata(self, secret_path: str) -> Optional[dict]: raise VaultError("Metadata might only be used with version 2 of the KV engine.") try: return self.client.secrets.kv.v2.read_secret_metadata( - path=secret_path, - mount_point=self.mount_point) + path=secret_path, mount_point=self.mount_point + ) except InvalidPath: self.log.debug("Secret not found %s with mount point %s", secret_path, self.mount_point) return None - def get_secret_including_metadata(self, - secret_path: str, - secret_version: Optional[int] = None) -> Optional[dict]: + def get_secret_including_metadata( + self, secret_path: str, secret_version: Optional[int] = None + ) -> Optional[dict]: """ Reads secret including metadata. It is only valid for KV version 2. @@ -390,18 +404,20 @@ def get_secret_including_metadata(self, raise VaultError("Metadata might only be used with version 2 of the KV engine.") try: return self.client.secrets.kv.v2.read_secret_version( - path=secret_path, mount_point=self.mount_point, - version=secret_version) + path=secret_path, mount_point=self.mount_point, version=secret_version + ) except InvalidPath: - self.log.debug("Secret not found %s with mount point %s and version %s", - secret_path, self.mount_point, secret_version) + self.log.debug( + "Secret not found %s with mount point %s and version %s", + secret_path, + self.mount_point, + secret_version, + ) return None - def create_or_update_secret(self, - secret_path: str, - secret: dict, - method: Optional[str] = None, - cas: Optional[int] = None) -> Response: + def create_or_update_secret( + self, secret_path: str, secret: dict, method: Optional[str] = None, cas: Optional[int] = None + ) -> Response: """ Creates or updates secret. 
@@ -431,8 +447,10 @@ def create_or_update_secret(self, raise VaultError("The cas parameter is only valid for version 2") if self.kv_engine_version == 1: response = self.client.secrets.kv.v1.create_or_update_secret( - secret_path=secret_path, secret=secret, mount_point=self.mount_point, method=method) + secret_path=secret_path, secret=secret, mount_point=self.mount_point, method=method + ) else: response = self.client.secrets.kv.v2.create_or_update_secret( - secret_path=secret_path, secret=secret, mount_point=self.mount_point, cas=cas) + secret_path=secret_path, secret=secret, mount_point=self.mount_point, cas=cas + ) return response diff --git a/airflow/providers/hashicorp/hooks/vault.py b/airflow/providers/hashicorp/hooks/vault.py index c1ed5a03f7464..24da524399cba 100644 --- a/airflow/providers/hashicorp/hooks/vault.py +++ b/airflow/providers/hashicorp/hooks/vault.py @@ -25,7 +25,9 @@ from airflow.hooks.base_hook import BaseHook from airflow.providers.hashicorp._internal_client.vault_client import ( # noqa - DEFAULT_KUBERNETES_JWT_PATH, DEFAULT_KV_ENGINE_VERSION, _VaultClient, + DEFAULT_KUBERNETES_JWT_PATH, + DEFAULT_KV_ENGINE_VERSION, + _VaultClient, ) @@ -107,6 +109,7 @@ class VaultHook(BaseHook): :type radius_port: int """ + def __init__( # pylint: disable=too-many-arguments self, vault_conn_id: str, @@ -122,7 +125,7 @@ def __init__( # pylint: disable=too-many-arguments azure_tenant_id: Optional[str] = None, azure_resource: Optional[str] = None, radius_host: Optional[str] = None, - radius_port: Optional[int] = None + radius_port: Optional[int] = None, ): super().__init__() self.connection = self.get_connection(vault_conn_id) @@ -144,17 +147,26 @@ def __init__( # pylint: disable=too-many-arguments if not role_id: role_id = self.connection.extra_dejson.get('role_id') - azure_resource, azure_tenant_id = \ - self._get_azure_parameters_from_connection(azure_resource, azure_tenant_id) \ - if auth_type == 'azure' else (None, None) - gcp_key_path, gcp_keyfile_dict, gcp_scopes = \ - self._get_gcp_parameters_from_connection(gcp_key_path, gcp_scopes) \ - if auth_type == 'gcp' else (None, None, None) - kubernetes_jwt_path, kubernetes_role = \ - self._get_kubernetes_parameters_from_connection(kubernetes_jwt_path, kubernetes_role) \ - if auth_type == 'kubernetes' else (None, None) - radius_host, radius_port = self._get_radius_parameters_from_connection(radius_host, radius_port) \ - if auth_type == 'radius' else (None, None) + azure_resource, azure_tenant_id = ( + self._get_azure_parameters_from_connection(azure_resource, azure_tenant_id) + if auth_type == 'azure' + else (None, None) + ) + gcp_key_path, gcp_keyfile_dict, gcp_scopes = ( + self._get_gcp_parameters_from_connection(gcp_key_path, gcp_scopes) + if auth_type == 'gcp' + else (None, None, None) + ) + kubernetes_jwt_path, kubernetes_role = ( + self._get_kubernetes_parameters_from_connection(kubernetes_jwt_path, kubernetes_role) + if auth_type == 'kubernetes' + else (None, None) + ) + radius_host, radius_port = ( + self._get_radius_parameters_from_connection(radius_host, radius_port) + if auth_type == 'radius' + else (None, None) + ) if self.connection.conn_type == 'vault': conn_protocol = 'http' @@ -196,12 +208,12 @@ def __init__( # pylint: disable=too-many-arguments azure_resource=azure_resource, radius_host=radius_host, radius_secret=self.connection.password, - radius_port=radius_port + radius_port=radius_port, ) def _get_kubernetes_parameters_from_connection( - self, kubernetes_jwt_path: Optional[str], kubernetes_role: Optional[str]) \ - 
-> Tuple[str, Optional[str]]: + self, kubernetes_jwt_path: Optional[str], kubernetes_role: Optional[str] + ) -> Tuple[str, Optional[str]]: if not kubernetes_jwt_path: kubernetes_jwt_path = self.connection.extra_dejson.get("kubernetes_jwt_path") if not kubernetes_jwt_path: @@ -211,9 +223,7 @@ def _get_kubernetes_parameters_from_connection( return kubernetes_jwt_path, kubernetes_role def _get_gcp_parameters_from_connection( - self, - gcp_key_path: Optional[str], - gcp_scopes: Optional[str], + self, gcp_key_path: Optional[str], gcp_scopes: Optional[str], ) -> Tuple[Optional[str], Optional[dict], Optional[str]]: if not gcp_scopes: gcp_scopes = self.connection.extra_dejson.get("gcp_scopes") @@ -224,8 +234,8 @@ def _get_gcp_parameters_from_connection( return gcp_key_path, gcp_keyfile_dict, gcp_scopes def _get_azure_parameters_from_connection( - self, azure_resource: Optional[str], azure_tenant_id: Optional[str]) \ - -> Tuple[Optional[str], Optional[str]]: + self, azure_resource: Optional[str], azure_tenant_id: Optional[str] + ) -> Tuple[Optional[str], Optional[str]]: if not azure_tenant_id: azure_tenant_id = self.connection.extra_dejson.get("azure_tenant_id") if not azure_resource: @@ -233,8 +243,8 @@ def _get_azure_parameters_from_connection( return azure_resource, azure_tenant_id def _get_radius_parameters_from_connection( - self, radius_host: Optional[str], radius_port: Optional[int]) \ - -> Tuple[Optional[str], Optional[int]]: + self, radius_host: Optional[str], radius_port: Optional[int] + ) -> Tuple[Optional[str], Optional[int]]: if not radius_port: radius_port_str = self.connection.extra_dejson.get("radius_port") if radius_port_str: @@ -288,9 +298,9 @@ def get_secret_metadata(self, secret_path: str) -> Optional[dict]: """ return self.vault_client.get_secret_metadata(secret_path=secret_path) - def get_secret_including_metadata(self, - secret_path: str, - secret_version: Optional[int] = None) -> Optional[dict]: + def get_secret_including_metadata( + self, secret_path: str, secret_version: Optional[int] = None + ) -> Optional[dict]: """ Reads secret including metadata. It is only valid for KV version 2. @@ -306,13 +316,12 @@ def get_secret_including_metadata(self, """ return self.vault_client.get_secret_including_metadata( - secret_path=secret_path, secret_version=secret_version) + secret_path=secret_path, secret_version=secret_version + ) - def create_or_update_secret(self, - secret_path: str, - secret: dict, - method: Optional[str] = None, - cas: Optional[int] = None) -> Response: + def create_or_update_secret( + self, secret_path: str, secret: dict, method: Optional[str] = None, cas: Optional[int] = None + ) -> Response: """ Creates or updates secret. @@ -337,7 +346,5 @@ def create_or_update_secret(self, """ return self.vault_client.create_or_update_secret( - secret_path=secret_path, - secret=secret, - method=method, - cas=cas) + secret_path=secret_path, secret=secret, method=method, cas=cas + ) diff --git a/airflow/providers/hashicorp/secrets/vault.py b/airflow/providers/hashicorp/secrets/vault.py index 635971666ce57..86d53513e02af 100644 --- a/airflow/providers/hashicorp/secrets/vault.py +++ b/airflow/providers/hashicorp/secrets/vault.py @@ -110,6 +110,7 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin): :param radius_port: Port for radius (for ``radius`` auth_type). 
:type radius_port: str """ + def __init__( # pylint: disable=too-many-arguments self, connections_path: str = 'connections', @@ -137,7 +138,7 @@ def __init__( # pylint: disable=too-many-arguments radius_host: Optional[str] = None, radius_secret: Optional[str] = None, radius_port: Optional[int] = None, - **kwargs + **kwargs, ): super().__init__() self.connections_path = connections_path.rstrip('/') @@ -168,7 +169,7 @@ def __init__( # pylint: disable=too-many-arguments radius_host=radius_host, radius_secret=radius_secret, radius_port=radius_port, - **kwargs + **kwargs, ) def get_conn_uri(self, conn_id: str) -> Optional[str]: diff --git a/airflow/providers/http/hooks/http.py b/airflow/providers/http/hooks/http.py index 578cec13013df..00464fe4013e5 100644 --- a/airflow/providers/http/hooks/http.py +++ b/airflow/providers/http/hooks/http.py @@ -40,10 +40,7 @@ class HttpHook(BaseHook): """ def __init__( - self, - method: str = 'POST', - http_conn_id: str = 'http_default', - auth_type: Any = HTTPBasicAuth, + self, method: str = 'POST', http_conn_id: str = 'http_default', auth_type: Any = HTTPBasicAuth, ) -> None: super().__init__() self.http_conn_id = http_conn_id @@ -88,12 +85,14 @@ def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> requests.Session return session - def run(self, - endpoint: Optional[str], - data: Optional[Union[Dict[str, Any], str]] = None, - headers: Optional[Dict[str, Any]] = None, - extra_options: Optional[Dict[str, Any]] = None, - **request_kwargs: Any) -> Any: + def run( + self, + endpoint: Optional[str], + data: Optional[Union[Dict[str, Any], str]] = None, + headers: Optional[Dict[str, Any]] = None, + extra_options: Optional[Dict[str, Any]] = None, + **request_kwargs: Any, + ) -> Any: r""" Performs the request @@ -114,32 +113,20 @@ def run(self, session = self.get_conn(headers) - if self.base_url and not self.base_url.endswith('/') and \ - endpoint and not endpoint.startswith('/'): + if self.base_url and not self.base_url.endswith('/') and endpoint and not endpoint.startswith('/'): url = self.base_url + '/' + endpoint else: url = (self.base_url or '') + (endpoint or '') if self.method == 'GET': # GET uses params - req = requests.Request(self.method, - url, - params=data, - headers=headers, - **request_kwargs) + req = requests.Request(self.method, url, params=data, headers=headers, **request_kwargs) elif self.method == 'HEAD': # HEAD doesn't use params - req = requests.Request(self.method, - url, - headers=headers, - **request_kwargs) + req = requests.Request(self.method, url, headers=headers, **request_kwargs) else: # Others use data - req = requests.Request(self.method, - url, - data=data, - headers=headers, - **request_kwargs) + req = requests.Request(self.method, url, data=data, headers=headers, **request_kwargs) prepped_request = session.prepare_request(req) self.log.info("Sending '%s' to url: %s", self.method, url) @@ -160,11 +147,12 @@ def check_response(self, response: requests.Response) -> None: self.log.error(response.text) raise AirflowException(str(response.status_code) + ":" + response.reason) - def run_and_check(self, - session: requests.Session, - prepped_request: requests.PreparedRequest, - extra_options: Dict[Any, Any] - ) -> Any: + def run_and_check( + self, + session: requests.Session, + prepped_request: requests.PreparedRequest, + extra_options: Dict[Any, Any], + ) -> Any: """ Grabs extra options like timeout and actually runs the request, checking for the result @@ -188,7 +176,8 @@ def run_and_check(self, 
proxies=extra_options.get("proxies", {}), cert=extra_options.get("cert"), timeout=extra_options.get("timeout"), - allow_redirects=extra_options.get("allow_redirects", True)) + allow_redirects=extra_options.get("allow_redirects", True), + ) if extra_options.get('check_response', True): self.check_response(response) @@ -198,8 +187,7 @@ def run_and_check(self, self.log.warning('%s Tenacity will retry to execute the operation', ex) raise ex - def run_with_advanced_retry(self, _retry_args: Dict[Any, Any], - *args: Any, **kwargs: Any) -> Any: + def run_with_advanced_retry(self, _retry_args: Dict[Any, Any], *args: Any, **kwargs: Any) -> Any: """ Runs Hook.run() with a Tenacity decorator attached to it. This is useful for connectors which might be disturbed by intermittent issues and should not @@ -224,8 +212,6 @@ def run_with_advanced_retry(self, _retry_args: Dict[Any, Any], ) """ - self._retry_obj = tenacity.Retrying( - **_retry_args - ) + self._retry_obj = tenacity.Retrying(**_retry_args) return self._retry_obj(self.run, *args, **kwargs) diff --git a/airflow/providers/http/operators/http.py b/airflow/providers/http/operators/http.py index 7505ea2bfdc54..8457b46baee86 100644 --- a/airflow/providers/http/operators/http.py +++ b/airflow/providers/http/operators/http.py @@ -59,22 +59,29 @@ class SimpleHttpOperator(BaseOperator): :type log_response: bool """ - template_fields = ['endpoint', 'data', 'headers', ] + template_fields = [ + 'endpoint', + 'data', + 'headers', + ] template_ext = () ui_color = '#f4a460' @apply_defaults - def __init__(self, *, - endpoint: Optional[str] = None, - method: str = 'POST', - data: Any = None, - headers: Optional[Dict[str, str]] = None, - response_check: Optional[Callable[..., Any]] = None, - response_filter: Optional[Callable[[requests.Response], Any]] = None, - extra_options: Optional[Dict[str, Any]] = None, - http_conn_id: str = 'http_default', - log_response: bool = False, - **kwargs: Any) -> None: + def __init__( + self, + *, + endpoint: Optional[str] = None, + method: str = 'POST', + data: Any = None, + headers: Optional[Dict[str, str]] = None, + response_check: Optional[Callable[..., Any]] = None, + response_filter: Optional[Callable[[requests.Response], Any]] = None, + extra_options: Optional[Dict[str, Any]] = None, + http_conn_id: str = 'http_default', + log_response: bool = False, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.http_conn_id = http_conn_id self.method = method @@ -93,10 +100,7 @@ def execute(self, context: Dict[str, Any]) -> Any: self.log.info("Calling HTTP method") - response = http.run(self.endpoint, - self.data, - self.headers, - self.extra_options) + response = http.run(self.endpoint, self.data, self.headers, self.extra_options) if self.log_response: self.log.info(response.text) if self.response_check: diff --git a/airflow/providers/http/sensors/http.py b/airflow/providers/http/sensors/http.py index e730b566ada60..ec024fe150379 100644 --- a/airflow/providers/http/sensors/http.py +++ b/airflow/providers/http/sensors/http.py @@ -70,15 +70,18 @@ def response_check(response, task_instance): template_fields = ('endpoint', 'request_params') @apply_defaults - def __init__(self, *, - endpoint: str, - http_conn_id: str = 'http_default', - method: str = 'GET', - request_params: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, Any]] = None, - response_check: Optional[Callable[..., Any]] = None, - extra_options: Optional[Dict[str, Any]] = None, - **kwargs: Any) -> None: + def __init__( + self, + *, + endpoint: str, + 
http_conn_id: str = 'http_default', + method: str = 'GET', + request_params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + response_check: Optional[Callable[..., Any]] = None, + extra_options: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.endpoint = endpoint self.http_conn_id = http_conn_id @@ -87,17 +90,17 @@ def __init__(self, *, self.extra_options = extra_options or {} self.response_check = response_check - self.hook = HttpHook( - method=method, - http_conn_id=http_conn_id) + self.hook = HttpHook(method=method, http_conn_id=http_conn_id) def poke(self, context: Dict[Any, Any]) -> bool: self.log.info('Poking: %s', self.endpoint) try: - response = self.hook.run(self.endpoint, - data=self.request_params, - headers=self.headers, - extra_options=self.extra_options) + response = self.hook.run( + self.endpoint, + data=self.request_params, + headers=self.headers, + extra_options=self.extra_options, + ) if self.response_check: op_kwargs = PythonOperator.determine_op_kwargs(self.response_check, context) return self.response_check(response, **op_kwargs) diff --git a/airflow/providers/imap/hooks/imap.py b/airflow/providers/imap/hooks/imap.py index 60a46f6f80778..6b96faf5344eb 100644 --- a/airflow/providers/imap/hooks/imap.py +++ b/airflow/providers/imap/hooks/imap.py @@ -71,12 +71,9 @@ def get_conn(self) -> 'ImapHook': return self - def has_mail_attachment(self, - name: str, - *, - check_regex: bool = False, - mail_folder: str = 'INBOX', - mail_filter: str = 'All') -> bool: + def has_mail_attachment( + self, name: str, *, check_regex: bool = False, mail_folder: str = 'INBOX', mail_filter: str = 'All' + ) -> bool: """ Checks the mail folder for mails containing attachments with the given name. @@ -92,21 +89,21 @@ def has_mail_attachment(self, :returns: True if there is an attachment with the given name and False if not. :rtype: bool """ - mail_attachments = self._retrieve_mails_attachments_by_name(name, - check_regex, - True, - mail_folder, - mail_filter) + mail_attachments = self._retrieve_mails_attachments_by_name( + name, check_regex, True, mail_folder, mail_filter + ) return len(mail_attachments) > 0 - def retrieve_mail_attachments(self, - name: str, - *, - check_regex: bool = False, - latest_only: bool = False, - mail_folder: str = 'INBOX', - mail_filter: str = 'All', - not_found_mode: str = 'raise') -> List[Tuple]: + def retrieve_mail_attachments( + self, + name: str, + *, + check_regex: bool = False, + latest_only: bool = False, + mail_folder: str = 'INBOX', + mail_filter: str = 'All', + not_found_mode: str = 'raise', + ) -> List[Tuple]: """ Retrieves mail's attachments in the mail folder by its name. @@ -130,26 +127,26 @@ def retrieve_mail_attachments(self, :returns: a list of tuple each containing the attachment filename and its payload. 
:rtype: a list of tuple """ - mail_attachments = self._retrieve_mails_attachments_by_name(name, - check_regex, - latest_only, - mail_folder, - mail_filter) + mail_attachments = self._retrieve_mails_attachments_by_name( + name, check_regex, latest_only, mail_folder, mail_filter + ) if not mail_attachments: self._handle_not_found_mode(not_found_mode) return mail_attachments - def download_mail_attachments(self, - name: str, - local_output_directory: str, - *, - check_regex: bool = False, - latest_only: bool = False, - mail_folder: str = 'INBOX', - mail_filter: str = 'All', - not_found_mode: str = 'raise'): + def download_mail_attachments( + self, + name: str, + local_output_directory: str, + *, + check_regex: bool = False, + latest_only: bool = False, + mail_folder: str = 'INBOX', + mail_filter: str = 'All', + not_found_mode: str = 'raise', + ): """ Downloads mail's attachments in the mail folder by its name to the local directory. @@ -174,11 +171,9 @@ def download_mail_attachments(self, if set to 'ignore' it won't notify you at all. :type not_found_mode: str """ - mail_attachments = self._retrieve_mails_attachments_by_name(name, - check_regex, - latest_only, - mail_folder, - mail_filter) + mail_attachments = self._retrieve_mails_attachments_by_name( + name, check_regex, latest_only, mail_folder, mail_filter + ) if not mail_attachments: self._handle_not_found_mode(not_found_mode) @@ -195,8 +190,9 @@ def _handle_not_found_mode(self, not_found_mode: str): else: self.log.error('Invalid "not_found_mode" %s', not_found_mode) - def _retrieve_mails_attachments_by_name(self, name: str, check_regex: bool, latest_only: bool, - mail_folder: str, mail_filter: str) -> List: + def _retrieve_mails_attachments_by_name( + self, name: str, check_regex: bool, latest_only: bool, mail_folder: str, mail_filter: str + ) -> List: if not self.mail_client: raise Exception("The 'mail_client' should be initialized before!") @@ -232,8 +228,9 @@ def _fetch_mail_body(self, mail_id: str) -> str: mail_body_str = mail_body.decode('utf-8') # type: ignore return mail_body_str - def _check_mail_body(self, response_mail_body: str, name: str, check_regex: bool, - latest_only: bool) -> List[Tuple[Any, Any]]: + def _check_mail_body( + self, response_mail_body: str, name: str, check_regex: bool, latest_only: bool + ) -> List[Tuple[Any, Any]]: mail = Mail(response_mail_body) if mail.has_attachments(): return mail.get_attachments_by_name(name, check_regex, find_first=latest_only) @@ -257,8 +254,11 @@ def _is_escaping_current_directory(self, name: str): return '../' in name def _correct_path(self, name: str, local_output_directory: str): - return local_output_directory + name if local_output_directory.endswith('/') \ + return ( + local_output_directory + name + if local_output_directory.endswith('/') else local_output_directory + '/' + name + ) def _create_file(self, name: str, payload: Any, local_output_directory: str): file_path = self._correct_path(name, local_output_directory) @@ -288,10 +288,9 @@ def has_attachments(self) -> bool: """ return self.mail.get_content_maintype() == 'multipart' - def get_attachments_by_name(self, - name: str, - check_regex: bool, - find_first: bool = False) -> List[Tuple[Any, Any]]: + def get_attachments_by_name( + self, name: str, check_regex: bool, find_first: bool = False + ) -> List[Tuple[Any, Any]]: """ Gets all attachments by name for the mail. 
@@ -308,8 +307,9 @@ def get_attachments_by_name(self, attachments = [] for attachment in self._iterate_attachments(): - found_attachment = attachment.has_matching_name(name) if check_regex \ - else attachment.has_equal_name(name) + found_attachment = ( + attachment.has_matching_name(name) if check_regex else attachment.has_equal_name(name) + ) if found_attachment: file_name, file_payload = attachment.get_file() self.log.info('Found attachment: %s', file_name) diff --git a/airflow/providers/imap/sensors/imap_attachment.py b/airflow/providers/imap/sensors/imap_attachment.py index 4468df35e750e..6e4d1b714262d 100644 --- a/airflow/providers/imap/sensors/imap_attachment.py +++ b/airflow/providers/imap/sensors/imap_attachment.py @@ -41,16 +41,20 @@ class ImapAttachmentSensor(BaseSensorOperator): :param conn_id: The connection to run the sensor against. :type conn_id: str """ + template_fields = ('attachment_name', 'mail_filter') @apply_defaults - def __init__(self, *, - attachment_name, - check_regex=False, - mail_folder='INBOX', - mail_filter='All', - conn_id='imap_default', - **kwargs): + def __init__( + self, + *, + attachment_name, + check_regex=False, + mail_folder='INBOX', + mail_filter='All', + conn_id='imap_default', + **kwargs, + ): super().__init__(**kwargs) self.attachment_name = attachment_name @@ -75,5 +79,5 @@ def poke(self, context): name=self.attachment_name, check_regex=self.check_regex, mail_folder=self.mail_folder, - mail_filter=self.mail_filter + mail_filter=self.mail_filter, ) diff --git a/airflow/providers/jdbc/hooks/jdbc.py b/airflow/providers/jdbc/hooks/jdbc.py index 619aefbcbe781..e048d6adcf342 100644 --- a/airflow/providers/jdbc/hooks/jdbc.py +++ b/airflow/providers/jdbc/hooks/jdbc.py @@ -42,10 +42,12 @@ def get_conn(self): jdbc_driver_loc = conn.extra_dejson.get('extra__jdbc__drv_path') jdbc_driver_name = conn.extra_dejson.get('extra__jdbc__drv_clsname') - conn = jaydebeapi.connect(jclassname=jdbc_driver_name, - url=str(host), - driver_args=[str(login), str(psw)], - jars=jdbc_driver_loc.split(",")) + conn = jaydebeapi.connect( + jclassname=jdbc_driver_name, + url=str(host), + driver_args=[str(login), str(psw)], + jars=jdbc_driver_loc.split(","), + ) return conn def set_autocommit(self, conn, autocommit): diff --git a/airflow/providers/jdbc/operators/jdbc.py b/airflow/providers/jdbc/operators/jdbc.py index 1590d903bd81e..94901210dd4c6 100644 --- a/airflow/providers/jdbc/operators/jdbc.py +++ b/airflow/providers/jdbc/operators/jdbc.py @@ -46,12 +46,15 @@ class JdbcOperator(BaseOperator): ui_color = '#ededed' @apply_defaults - def __init__(self, *, - sql: str, - jdbc_conn_id: str = 'jdbc_default', - autocommit: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, - **kwargs) -> None: + def __init__( + self, + *, + sql: str, + jdbc_conn_id: str = 'jdbc_default', + autocommit: bool = False, + parameters: Optional[Union[Mapping, Iterable]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.parameters = parameters self.sql = sql diff --git a/airflow/providers/jenkins/example_dags/example_jenkins_job_trigger.py b/airflow/providers/jenkins/example_dags/example_jenkins_job_trigger.py index 6dc503523237f..e5d00b5a6bba6 100644 --- a/airflow/providers/jenkins/example_dags/example_jenkins_job_trigger.py +++ b/airflow/providers/jenkins/example_dags/example_jenkins_job_trigger.py @@ -30,23 +30,19 @@ "retry_delay": timedelta(minutes=5), "depends_on_past": False, "concurrency": 8, - "max_active_runs": 8 - + "max_active_runs": 8, } with DAG( - 
"test_jenkins", - default_args=default_args, - start_date=datetime(2017, 6, 1), - schedule_interval=None + "test_jenkins", default_args=default_args, start_date=datetime(2017, 6, 1), schedule_interval=None ) as dag: job_trigger = JenkinsJobTriggerOperator( task_id="trigger_job", job_name="generate-merlin-config", parameters={"first_parameter": "a_value", "second_parameter": "18"}, # parameters="resources/paremeter.json", You can also pass a path to a json file containing your param - jenkins_connection_id="your_jenkins_connection" # T he connection must be configured first + jenkins_connection_id="your_jenkins_connection", # T he connection must be configured first ) def grab_artifact_from_jenkins(**context): @@ -66,9 +62,6 @@ def grab_artifact_from_jenkins(**context): response = jenkins_server.jenkins_open(request) return response # We store the artifact content in a xcom variable for later use - artifact_grabber = PythonOperator( - task_id='artifact_grabber', - python_callable=grab_artifact_from_jenkins - ) + artifact_grabber = PythonOperator(task_id='artifact_grabber', python_callable=grab_artifact_from_jenkins) job_trigger >> artifact_grabber diff --git a/airflow/providers/jenkins/operators/jenkins_job_trigger.py b/airflow/providers/jenkins/operators/jenkins_job_trigger.py index 9e38ede8b883a..ec6cf8d7b5ce1 100644 --- a/airflow/providers/jenkins/operators/jenkins_job_trigger.py +++ b/airflow/providers/jenkins/operators/jenkins_job_trigger.py @@ -54,8 +54,8 @@ def jenkins_request_with_headers(jenkins_server: Jenkins, req: Request) -> Optio response_headers = response.headers if response_body is None: raise jenkins.EmptyResponseException( - "Error communicating with server[%s]: " - "empty response" % jenkins_server.server) + "Error communicating with server[%s]: " "empty response" % jenkins_server.server + ) return {'body': response_body.decode('utf-8'), 'headers': response_headers} except HTTPError as e: # Jenkins's funky authentication means its nigh impossible to distinguish errors. @@ -95,18 +95,22 @@ class JenkinsJobTriggerOperator(BaseOperator): while waiting for the job to appears on jenkins server (default 10) :type max_try_before_job_appears: int """ + template_fields = ('parameters',) template_ext = ('.json',) ui_color = '#f9ec86' @apply_defaults - def __init__(self, *, - jenkins_connection_id: str, - job_name: str, - parameters: ParamType = "", - sleep_time: int = 10, - max_try_before_job_appears: int = 10, - **kwargs): + def __init__( + self, + *, + jenkins_connection_id: str, + job_name: str, + parameters: ParamType = "", + sleep_time: int = 10, + max_try_before_job_appears: int = 10, + **kwargs, + ): super().__init__(**kwargs) self.job_name = job_name self.parameters = parameters @@ -116,9 +120,7 @@ def __init__(self, *, self.jenkins_connection_id = jenkins_connection_id self.max_try_before_job_appears = max_try_before_job_appears - def build_job(self, - jenkins_server: Jenkins, - params: ParamType = "") -> Optional[JenkinsRequest]: + def build_job(self, jenkins_server: Jenkins, params: ParamType = "") -> Optional[JenkinsRequest]: """ This function makes an API call to Jenkins to trigger a build for 'job_name' It returned a dict with 2 keys : body and headers. 
@@ -139,9 +141,7 @@ def build_job(self, if not params: params = None - request = Request( - method='POST', - url=jenkins_server.build_job_url(self.job_name, params, None)) + request = Request(method='POST', url=jenkins_server.build_job_url(self.job_name, params, None)) return jenkins_request_with_headers(jenkins_server, request) def poll_job_in_queue(self, location: str, jenkins_server: Jenkins) -> int: @@ -166,18 +166,19 @@ def poll_job_in_queue(self, location: str, jenkins_server: Jenkins) -> int: self.log.info('Polling jenkins queue at the url %s', location) while try_count < self.max_try_before_job_appears: location_answer = jenkins_request_with_headers( - jenkins_server, Request(method='POST', url=location)) + jenkins_server, Request(method='POST', url=location) + ) if location_answer is not None: json_response = json.loads(location_answer['body']) if 'executable' in json_response: build_number = json_response['executable']['number'] - self.log.info('Job executed on Jenkins side with the build number %s', - build_number) + self.log.info('Job executed on Jenkins side with the build number %s', build_number) return build_number try_count += 1 time.sleep(self.sleep_time) - raise AirflowException("The job hasn't been executed after polling " - f"the queue {self.max_try_before_job_appears} times") + raise AirflowException( + "The job hasn't been executed after polling " f"the queue {self.max_try_before_job_appears} times" + ) def get_hook(self) -> JenkinsHook: """ @@ -190,23 +191,26 @@ def execute(self, context: Mapping[Any, Any]) -> Optional[str]: self.log.error( 'Please specify the jenkins connection id to use.' 'You must create a Jenkins connection before' - ' being able to use this operator') - raise AirflowException('The jenkins_connection_id parameter is missing,' - 'impossible to trigger the job') + ' being able to use this operator' + ) + raise AirflowException( + 'The jenkins_connection_id parameter is missing,' 'impossible to trigger the job' + ) if not self.job_name: self.log.error("Please specify the job name to use in the job_name parameter") - raise AirflowException('The job_name parameter is missing,' - 'impossible to trigger the job') + raise AirflowException('The job_name parameter is missing,' 'impossible to trigger the job') self.log.info( 'Triggering the job %s on the jenkins : %s with the parameters : %s', - self.job_name, self.jenkins_connection_id, self.parameters) + self.job_name, + self.jenkins_connection_id, + self.parameters, + ) jenkins_server = self.get_hook().get_jenkins_server() jenkins_response = self.build_job(jenkins_server, self.parameters) if jenkins_response: - build_number = self.poll_job_in_queue( - jenkins_response['headers']['Location'], jenkins_server) + build_number = self.poll_job_in_queue(jenkins_response['headers']['Location'], jenkins_server) time.sleep(self.sleep_time) keep_polling_job = True @@ -214,8 +218,7 @@ def execute(self, context: Mapping[Any, Any]) -> Optional[str]: # pylint: disable=too-many-nested-blocks while keep_polling_job: try: - build_info = jenkins_server.get_build_info(name=self.job_name, - number=build_number) + build_info = jenkins_server.get_build_info(name=self.job_name, number=build_number) if build_info['result'] is not None: keep_polling_job = False # Check if job had errors. @@ -223,23 +226,24 @@ def execute(self, context: Mapping[Any, Any]) -> Optional[str]: raise AirflowException( 'Jenkins job failed, final state : %s.' 
'Find more information on job url : %s' - % (build_info['result'], build_info['url'])) + % (build_info['result'], build_info['url']) + ) else: - self.log.info('Waiting for job to complete : %s , build %s', - self.job_name, build_number) + self.log.info('Waiting for job to complete : %s , build %s', self.job_name, build_number) time.sleep(self.sleep_time) except jenkins.NotFoundException as err: # pylint: disable=no-member raise AirflowException( - 'Jenkins job status check failed. Final error was: ' - f'{err.resp.status}') + 'Jenkins job status check failed. Final error was: ' f'{err.resp.status}' + ) except jenkins.JenkinsException as err: raise AirflowException( f'Jenkins call failed with error : {err}, if you have parameters ' 'double check them, jenkins sends back ' 'this exception for unknown parameters' 'You can also check logs for more details on this exception ' - '(jenkins_url/log/rss)') + '(jenkins_url/log/rss)' + ) if build_info: # If we can we return the url of the job # for later use (like retrieving an artifact) diff --git a/airflow/providers/jira/hooks/jira.py b/airflow/providers/jira/hooks/jira.py index 3afc9ae9dc6f9..daf573f3a546d 100644 --- a/airflow/providers/jira/hooks/jira.py +++ b/airflow/providers/jira/hooks/jira.py @@ -32,9 +32,8 @@ class JiraHook(BaseHook): :param jira_conn_id: reference to a pre-defined Jira Connection :type jira_conn_id: str """ - def __init__(self, - jira_conn_id: str = 'jira_default', - proxies: Optional[Any] = None) -> None: + + def __init__(self, jira_conn_id: str = 'jira_default', proxies: Optional[Any] = None) -> None: super().__init__() self.jira_conn_id = jira_conn_id self.proxies = proxies @@ -58,31 +57,28 @@ def get_conn(self) -> JIRA: # more can be added ex: async, logging, max_retries # verify - if 'verify' in extra_options \ - and extra_options['verify'].lower() == 'false': + if 'verify' in extra_options and extra_options['verify'].lower() == 'false': extra_options['verify'] = False # validate - if 'validate' in extra_options \ - and extra_options['validate'].lower() == 'false': + if 'validate' in extra_options and extra_options['validate'].lower() == 'false': validate = False - if 'get_server_info' in extra_options \ - and extra_options['get_server_info'].lower() == 'false': + if 'get_server_info' in extra_options and extra_options['get_server_info'].lower() == 'false': get_server_info = False try: - self.client = JIRA(conn.host, - options=extra_options, - basic_auth=(conn.login, conn.password), - get_server_info=get_server_info, - validate=validate, - proxies=self.proxies) + self.client = JIRA( + conn.host, + options=extra_options, + basic_auth=(conn.login, conn.password), + get_server_info=get_server_info, + validate=validate, + proxies=self.proxies, + ) except JIRAError as jira_error: - raise AirflowException('Failed to create jira client, jira error: %s' - % str(jira_error)) + raise AirflowException('Failed to create jira client, jira error: %s' % str(jira_error)) except Exception as e: - raise AirflowException('Failed to create jira client, error: %s' - % str(e)) + raise AirflowException('Failed to create jira client, error: %s' % str(e)) return self.client diff --git a/airflow/providers/jira/operators/jira.py b/airflow/providers/jira/operators/jira.py index 0ba7f7565b9a4..3550d1f186570 100644 --- a/airflow/providers/jira/operators/jira.py +++ b/airflow/providers/jira/operators/jira.py @@ -45,13 +45,16 @@ class JiraOperator(BaseOperator): template_fields = ("jira_method_args",) @apply_defaults - def __init__(self, *, - 
jira_method: str, - jira_conn_id: str = 'jira_default', - jira_method_args: Optional[dict] = None, - result_processor: Optional[Callable] = None, - get_jira_resource_method: Optional[Callable] = None, - **kwargs) -> None: + def __init__( + self, + *, + jira_method: str, + jira_conn_id: str = 'jira_default', + jira_method_args: Optional[dict] = None, + result_processor: Optional[Callable] = None, + get_jira_resource_method: Optional[Callable] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.jira_conn_id = jira_conn_id self.method_name = jira_method @@ -86,7 +89,6 @@ def execute(self, context: Dict) -> Any: return jira_result except JIRAError as jira_error: - raise AirflowException("Failed to execute jiraOperator, error: %s" - % str(jira_error)) + raise AirflowException("Failed to execute jiraOperator, error: %s" % str(jira_error)) except Exception as e: raise AirflowException("Jira operator error: %s" % str(e)) diff --git a/airflow/providers/jira/sensors/jira.py b/airflow/providers/jira/sensors/jira.py index b3d8ca250e831..f82e7b01f3d8d 100644 --- a/airflow/providers/jira/sensors/jira.py +++ b/airflow/providers/jira/sensors/jira.py @@ -39,12 +39,15 @@ class JiraSensor(BaseSensorOperator): """ @apply_defaults - def __init__(self, *, - method_name: str, - jira_conn_id: str = 'jira_default', - method_params: Optional[dict] = None, - result_processor: Optional[Callable] = None, - **kwargs) -> None: + def __init__( + self, + *, + method_name: str, + jira_conn_id: str = 'jira_default', + method_params: Optional[dict] = None, + result_processor: Optional[Callable] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.jira_conn_id = jira_conn_id self.result_processor = None @@ -52,11 +55,13 @@ def __init__(self, *, self.result_processor = result_processor self.method_name = method_name self.method_params = method_params - self.jira_operator = JiraOperator(task_id=self.task_id, - jira_conn_id=self.jira_conn_id, - jira_method=self.method_name, - jira_method_args=self.method_params, - result_processor=self.result_processor) + self.jira_operator = JiraOperator( + task_id=self.task_id, + jira_conn_id=self.jira_conn_id, + jira_method=self.method_name, + jira_method_args=self.method_params, + result_processor=self.result_processor, + ) def poke(self, context: Dict) -> Any: return self.jira_operator.execute(context=context) @@ -81,13 +86,16 @@ class JiraTicketSensor(JiraSensor): template_fields = ("ticket_id",) @apply_defaults - def __init__(self, *, - jira_conn_id: str = 'jira_default', - ticket_id: Optional[str] = None, - field: Optional[str] = None, - expected_value: Optional[str] = None, - field_checker_func: Optional[Callable] = None, - **kwargs) -> None: + def __init__( + self, + *, + jira_conn_id: str = 'jira_default', + ticket_id: Optional[str] = None, + field: Optional[str] = None, + expected_value: Optional[str] = None, + field_checker_func: Optional[Callable] = None, + **kwargs, + ) -> None: self.jira_conn_id = jira_conn_id self.ticket_id = ticket_id @@ -96,27 +104,20 @@ def __init__(self, *, if field_checker_func is None: field_checker_func = self.issue_field_checker - super().__init__(jira_conn_id=jira_conn_id, - result_processor=field_checker_func, - **kwargs) + super().__init__(jira_conn_id=jira_conn_id, result_processor=field_checker_func, **kwargs) def poke(self, context: Dict) -> Any: self.log.info('Jira Sensor checking for change in ticket: %s', self.ticket_id) self.jira_operator.method_name = "issue" - self.jira_operator.jira_method_args = { - 'id': 
self.ticket_id, - 'fields': self.field - } + self.jira_operator.jira_method_args = {'id': self.ticket_id, 'fields': self.field} return JiraSensor.poke(self, context=context) def issue_field_checker(self, issue: Issue) -> Optional[bool]: """Check issue using different conditions to prepare to evaluate sensor.""" result = None try: # pylint: disable=too-many-nested-blocks - if issue is not None \ - and self.field is not None \ - and self.expected_value is not None: + if issue is not None and self.field is not None and self.expected_value is not None: field_val = getattr(issue.fields, self.field) if field_val is not None: @@ -130,20 +131,18 @@ def issue_field_checker(self, issue: Issue) -> Optional[bool]: self.log.warning( "Not implemented checker for issue field %s which " "is neither string nor list nor Jira Resource", - self.field + self.field, ) except JIRAError as jira_error: - self.log.error("Jira error while checking with expected value: %s", - jira_error) + self.log.error("Jira error while checking with expected value: %s", jira_error) except Exception as e: # pylint: disable=broad-except - self.log.error("Error while checking with expected value %s:", - self.expected_value) + self.log.error("Error while checking with expected value %s:", self.expected_value) self.log.exception(e) if result is True: - self.log.info("Issue field %s has expected value %s, returning success", - self.field, self.expected_value) + self.log.info( + "Issue field %s has expected value %s, returning success", self.field, self.expected_value + ) else: - self.log.info("Issue field %s don't have expected value %s yet.", - self.field, self.expected_value) + self.log.info("Issue field %s don't have expected value %s yet.", self.field, self.expected_value) return result diff --git a/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py b/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py index 06121394a49fd..06c7a49c967f0 100644 --- a/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py +++ b/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py @@ -54,5 +54,5 @@ volumes=[], memory_in_gb=4.0, cpu=1.0, - task_id='start_container' + task_id='start_container', ) diff --git a/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py b/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py index 77740cff3cdef..a9bc6a786b29a 100644 --- a/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py +++ b/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py @@ -36,7 +36,7 @@ 'depends_on_past': False, 'email': ['airflow@example.com'], 'email_on_failure': False, - 'email_on_retry': False + 'email_on_retry': False, } with DAG( diff --git a/airflow/providers/microsoft/azure/hooks/adx.py b/airflow/providers/microsoft/azure/hooks/adx.py index d390c79800bf0..741b369201c02 100644 --- a/airflow/providers/microsoft/azure/hooks/adx.py +++ b/airflow/providers/microsoft/azure/hooks/adx.py @@ -78,9 +78,7 @@ class AzureDataExplorerHook(BaseHook): :type azure_data_explorer_conn_id: str """ - def __init__( - self, - azure_data_explorer_conn_id: str = 'azure_data_explorer_default'): + def __init__(self, azure_data_explorer_conn_id: str = 'azure_data_explorer_default'): super().__init__() self.conn_id = azure_data_explorer_conn_id self.connection = self.get_conn() @@ -97,37 +95,36 @@ def get_required_param(name): value = conn.extra_dejson.get(name) if not 
value: raise AirflowException( - 'Extra connection option is missing required parameter: `{}`'. - format(name)) + 'Extra connection option is missing required parameter: `{}`'.format(name) + ) return value auth_method = get_required_param('auth_method') if auth_method == 'AAD_APP': kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication( - cluster, conn.login, conn.password, - get_required_param('tenant')) + cluster, conn.login, conn.password, get_required_param('tenant') + ) elif auth_method == 'AAD_APP_CERT': kcsb = KustoConnectionStringBuilder.with_aad_application_certificate_authentication( - cluster, conn.login, get_required_param('certificate'), - get_required_param('thumbprint'), get_required_param('tenant')) + cluster, + conn.login, + get_required_param('certificate'), + get_required_param('thumbprint'), + get_required_param('tenant'), + ) elif auth_method == 'AAD_CREDS': kcsb = KustoConnectionStringBuilder.with_aad_user_password_authentication( - cluster, conn.login, conn.password, - get_required_param('tenant')) + cluster, conn.login, conn.password, get_required_param('tenant') + ) elif auth_method == 'AAD_DEVICE': - kcsb = KustoConnectionStringBuilder.with_aad_device_authentication( - cluster) + kcsb = KustoConnectionStringBuilder.with_aad_device_authentication(cluster) else: - raise AirflowException( - 'Unknown authentication method: {}'.format(auth_method)) + raise AirflowException('Unknown authentication method: {}'.format(auth_method)) return KustoClient(kcsb) - def run_query(self, - query: str, - database: str, - options: Optional[Dict] = None) -> KustoResponseDataSetV2: + def run_query(self, query: str, database: str, options: Optional[Dict] = None) -> KustoResponseDataSetV2: """ Run KQL query using provided configuration, and return `azure.kusto.data.response.KustoResponseDataSet` instance. @@ -147,8 +144,6 @@ def run_query(self, for k, v in options.items(): properties.set_option(k, v) try: - return self.connection.execute( - database, query, properties=properties) + return self.connection.execute(database, query, properties=properties) except KustoServiceError as error: - raise AirflowException( - 'Error running Kusto query: {}'.format(error)) + raise AirflowException('Error running Kusto query: {}'.format(error)) diff --git a/airflow/providers/microsoft/azure/hooks/azure_batch.py b/airflow/providers/microsoft/azure/hooks/azure_batch.py index e64a57f0cc8c1..ad7e341437cbf 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_batch.py +++ b/airflow/providers/microsoft/azure/hooks/azure_batch.py @@ -58,30 +58,29 @@ def _get_required_param(name): value = conn.extra_dejson.get(name) if not value: raise AirflowException( - 'Extra connection option is missing required parameter: `{}`'. 
- format(name)) + 'Extra connection option is missing required parameter: `{}`'.format(name) + ) return value + batch_account_name = _get_required_param('account_name') batch_account_key = _get_required_param('account_key') batch_account_url = _get_required_param('account_url') - credentials = batch_auth.SharedKeyCredentials(batch_account_name, - batch_account_key) - batch_client = BatchServiceClient( - credentials, - batch_url=batch_account_url) + credentials = batch_auth.SharedKeyCredentials(batch_account_name, batch_account_key) + batch_client = BatchServiceClient(credentials, batch_url=batch_account_url) return batch_client - def configure_pool(self, - pool_id: str, - vm_size: str, - display_name: Optional[str] = None, - target_dedicated_nodes: Optional[int] = None, - use_latest_image_and_sku: bool = False, - vm_publisher: Optional[str] = None, - vm_offer: Optional[str] = None, - sku_starts_with: Optional[str] = None, - **kwargs - ): + def configure_pool( + self, + pool_id: str, + vm_size: str, + display_name: Optional[str] = None, + target_dedicated_nodes: Optional[int] = None, + use_latest_image_and_sku: bool = False, + vm_publisher: Optional[str] = None, + vm_offer: Optional[str] = None, + sku_starts_with: Optional[str] = None, + **kwargs, + ): """ Configures a pool @@ -113,34 +112,34 @@ def configure_pool(self, """ if use_latest_image_and_sku: self.log.info('Using latest verified virtual machine image with node agent sku') - sku_to_use, image_ref_to_use = \ - self._get_latest_verified_image_vm_and_sku(publisher=vm_publisher, - offer=vm_offer, - sku_starts_with=sku_starts_with) + sku_to_use, image_ref_to_use = self._get_latest_verified_image_vm_and_sku( + publisher=vm_publisher, offer=vm_offer, sku_starts_with=sku_starts_with + ) pool = batch_models.PoolAddParameter( id=pool_id, vm_size=vm_size, display_name=display_name, virtual_machine_configuration=batch_models.VirtualMachineConfiguration( - image_reference=image_ref_to_use, - node_agent_sku_id=sku_to_use + image_reference=image_ref_to_use, node_agent_sku_id=sku_to_use ), target_dedicated_nodes=target_dedicated_nodes, - **kwargs) + **kwargs, + ) elif self.extra.get('os_family'): - self.log.info('Using cloud service configuration to create pool, ' - 'virtual machine configuration ignored') + self.log.info( + 'Using cloud service configuration to create pool, ' 'virtual machine configuration ignored' + ) pool = batch_models.PoolAddParameter( id=pool_id, vm_size=vm_size, display_name=display_name, cloud_service_configuration=batch_models.CloudServiceConfiguration( - os_family=self.extra.get('os_family'), - os_version=self.extra.get('os_version') + os_family=self.extra.get('os_family'), os_version=self.extra.get('os_version') ), target_dedicated_nodes=target_dedicated_nodes, - **kwargs) + **kwargs, + ) else: self.log.info('Using virtual machine configuration to create a pool') @@ -153,12 +152,13 @@ def configure_pool(self, publisher=self.extra.get('vm_publisher'), offer=self.extra.get('vm_offer'), sku=self.extra.get('vm_sku'), - version=self.extra.get("vm_version") + version=self.extra.get("vm_version"), ), - node_agent_sku_id=self.extra.get('node_agent_sku_id') + node_agent_sku_id=self.extra.get('node_agent_sku_id'), ), target_dedicated_nodes=target_dedicated_nodes, - **kwargs) + **kwargs, + ) return pool def create_pool(self, pool): @@ -193,16 +193,15 @@ def _get_latest_verified_image_vm_and_sku(self, publisher, offer, sku_starts_wit :type sku_starts_with: str """ - options = batch_models.AccountListSupportedImagesOptions( - 
filter="verificationType eq 'verified'") - images = self.connection.account.list_supported_images( - account_list_supported_images_options=options) + options = batch_models.AccountListSupportedImagesOptions(filter="verificationType eq 'verified'") + images = self.connection.account.list_supported_images(account_list_supported_images_options=options) # pick the latest supported sku skus_to_use = [ - (image.node_agent_sku_id, image.image_reference) for image in images - if image.image_reference.publisher.lower() == publisher.lower() and - image.image_reference.offer.lower() == offer.lower() and - image.image_reference.sku.startswith(sku_starts_with) + (image.node_agent_sku_id, image.image_reference) + for image in images + if image.image_reference.publisher.lower() == publisher.lower() + and image.image_reference.offer.lower() == offer.lower() + and image.image_reference.sku.startswith(sku_starts_with) ] # pick first @@ -218,30 +217,22 @@ def wait_for_all_node_state(self, pool_id, node_state): :param node_state: A set of batch_models.ComputeNodeState :type node_state: set """ - self.log.info('waiting for all nodes in pool %s to reach one of: %s', - pool_id, node_state) + self.log.info('waiting for all nodes in pool %s to reach one of: %s', pool_id, node_state) while True: # refresh pool to ensure that there is no resize error pool = self.connection.pool.get(pool_id) if pool.resize_errors is not None: resize_errors = "\n".join([repr(e) for e in pool.resize_errors]) - raise RuntimeError( - 'resize error encountered for pool {}:\n{}'.format( - pool.id, resize_errors)) + raise RuntimeError('resize error encountered for pool {}:\n{}'.format(pool.id, resize_errors)) nodes = list(self.connection.compute_node.list(pool.id)) - if (len(nodes) >= pool.target_dedicated_nodes and - all(node.state in node_state for node in nodes)): + if len(nodes) >= pool.target_dedicated_nodes and all(node.state in node_state for node in nodes): return nodes # Allow the timeout to be controlled by the AzureBatchOperator # specified timeout. This way we don't interrupt a startTask inside # the pool time.sleep(10) - def configure_job(self, - job_id: str, - pool_id: str, - display_name: Optional[str] = None, - **kwargs): + def configure_job(self, job_id: str, pool_id: str, display_name: Optional[str] = None, **kwargs): """ Configures a job for use in the pool @@ -253,12 +244,12 @@ def configure_job(self, :type display_name: str """ - job = batch_models.JobAddParameter(id=job_id, - pool_info=batch_models.PoolInformation( - pool_id=pool_id), - display_name=display_name, - **kwargs - ) + job = batch_models.JobAddParameter( + id=job_id, + pool_info=batch_models.PoolInformation(pool_id=pool_id), + display_name=display_name, + **kwargs, + ) return job def create_job(self, job): @@ -277,12 +268,14 @@ def create_job(self, job): else: self.log.info("Job %s already exists", job.id) - def configure_task(self, - task_id: str, - command_line: str, - display_name: Optional[str] = None, - container_settings=None, - **kwargs): + def configure_task( + self, + task_id: str, + command_line: str, + display_name: Optional[str] = None, + container_settings=None, + **kwargs, + ): """ Creates a task @@ -298,11 +291,13 @@ def configure_task(self, containerConfiguration set, this must not be set. 
:type container_settings: batch_models.TaskContainerSettings """ - task = batch_models.TaskAddParameter(id=task_id, - command_line=command_line, - display_name=display_name, - container_settings=container_settings, - **kwargs) + task = batch_models.TaskAddParameter( + id=task_id, + command_line=command_line, + display_name=display_name, + container_settings=container_settings, + **kwargs, + ) self.log.info("Task created: %s", task_id) return task @@ -317,8 +312,7 @@ def add_single_task_to_job(self, job_id, task): """ try: - self.connection.task.add(job_id=job_id, - task=task) + self.connection.task.add(job_id=job_id, task=task) except batch_models.BatchErrorException as err: if err.error.code != "TaskExists": raise @@ -338,8 +332,7 @@ def wait_for_job_tasks_to_complete(self, job_id, timeout): while timezone.utcnow() < timeout_time: tasks = self.connection.task.list(job_id) - incomplete_tasks = [task for task in tasks if - task.state != batch_models.TaskState.completed] + incomplete_tasks = [task for task in tasks if task.state != batch_models.TaskState.completed] if not incomplete_tasks: return for task in incomplete_tasks: diff --git a/airflow/providers/microsoft/azure/hooks/azure_container_instance.py b/airflow/providers/microsoft/azure/hooks/azure_container_instance.py index 4bcb79c6ee253..1294509253990 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_container_instance.py +++ b/airflow/providers/microsoft/azure/hooks/azure_container_instance.py @@ -54,9 +54,7 @@ def create_or_update(self, resource_group, name, container_group): :param container_group: the properties of the container group :type container_group: azure.mgmt.containerinstance.models.ContainerGroup """ - self.connection.container_groups.create_or_update(resource_group, - name, - container_group) + self.connection.container_groups.create_or_update(resource_group, name, container_group) def get_state_exitcode_details(self, resource_group, name): """ @@ -73,7 +71,7 @@ def get_state_exitcode_details(self, resource_group, name): warnings.warn( "get_state_exitcode_details() is deprecated. Related method is get_state()", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) cg_state = self.get_state(resource_group, name) c_state = cg_state.containers[0].instance_view.current_state @@ -91,9 +89,7 @@ def get_messages(self, resource_group, name): :rtype: list[str] """ warnings.warn( - "get_messages() is deprecated. Related method is get_state()", - DeprecationWarning, - stacklevel=2 + "get_messages() is deprecated. 
Related method is get_state()", DeprecationWarning, stacklevel=2 ) cg_state = self.get_state(resource_group, name) instance_view = cg_state.containers[0].instance_view @@ -110,9 +106,7 @@ def get_state(self, resource_group, name): :return: ContainerGroup :rtype: ~azure.mgmt.containerinstance.models.ContainerGroup """ - return self.connection.container_groups.get(resource_group, - name, - raw=False) + return self.connection.container_groups.get(resource_group, name, raw=False) def get_logs(self, resource_group, name, tail=1000): """ diff --git a/airflow/providers/microsoft/azure/hooks/azure_container_volume.py b/airflow/providers/microsoft/azure/hooks/azure_container_volume.py index fe72c359cf7cd..372871dd8daf9 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_container_volume.py +++ b/airflow/providers/microsoft/azure/hooks/azure_container_volume.py @@ -48,13 +48,16 @@ def get_storagekey(self): return value return conn.password - def get_file_volume(self, mount_name, share_name, - storage_account_name, read_only=False): + def get_file_volume(self, mount_name, share_name, storage_account_name, read_only=False): """ Get Azure File Volume """ - return Volume(name=mount_name, - azure_file=AzureFileVolume(share_name=share_name, - storage_account_name=storage_account_name, - read_only=read_only, - storage_account_key=self.get_storagekey())) + return Volume( + name=mount_name, + azure_file=AzureFileVolume( + share_name=share_name, + storage_account_name=storage_account_name, + read_only=read_only, + storage_account_key=self.get_storagekey(), + ), + ) diff --git a/airflow/providers/microsoft/azure/hooks/azure_cosmos.py b/airflow/providers/microsoft/azure/hooks/azure_cosmos.py index 12599f50095b4..e1e9a8063f229 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_cosmos.py +++ b/airflow/providers/microsoft/azure/hooks/azure_cosmos.py @@ -98,13 +98,15 @@ def does_collection_exist(self, collection_name, database_name=None): if collection_name is None: raise AirflowBadRequest("Collection name cannot be None.") - existing_container = list(self.get_conn().QueryContainers( - get_database_link(self.__get_database_name(database_name)), { - "query": "SELECT * FROM r WHERE r.id=@id", - "parameters": [ - {"name": "@id", "value": collection_name} - ] - })) + existing_container = list( + self.get_conn().QueryContainers( + get_database_link(self.__get_database_name(database_name)), + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": collection_name}], + }, + ) + ) if len(existing_container) == 0: return False @@ -119,19 +121,21 @@ def create_collection(self, collection_name, database_name=None): # We need to check to see if this container already exists so we don't try # to create it twice - existing_container = list(self.get_conn().QueryContainers( - get_database_link(self.__get_database_name(database_name)), { - "query": "SELECT * FROM r WHERE r.id=@id", - "parameters": [ - {"name": "@id", "value": collection_name} - ] - })) + existing_container = list( + self.get_conn().QueryContainers( + get_database_link(self.__get_database_name(database_name)), + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": collection_name}], + }, + ) + ) # Only create if we did not find it already existing if len(existing_container) == 0: self.get_conn().CreateContainer( - get_database_link(self.__get_database_name(database_name)), - {"id": collection_name}) + get_database_link(self.__get_database_name(database_name)), {"id": collection_name} 
+ ) def does_database_exist(self, database_name): """ @@ -140,12 +144,14 @@ def does_database_exist(self, database_name): if database_name is None: raise AirflowBadRequest("Database name cannot be None.") - existing_database = list(self.get_conn().QueryDatabases({ - "query": "SELECT * FROM r WHERE r.id=@id", - "parameters": [ - {"name": "@id", "value": database_name} - ] - })) + existing_database = list( + self.get_conn().QueryDatabases( + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": database_name}], + } + ) + ) if len(existing_database) == 0: return False @@ -160,12 +166,14 @@ def create_database(self, database_name): # We need to check to see if this database already exists so we don't try # to create it twice - existing_database = list(self.get_conn().QueryDatabases({ - "query": "SELECT * FROM r WHERE r.id=@id", - "parameters": [ - {"name": "@id", "value": database_name} - ] - })) + existing_database = list( + self.get_conn().QueryDatabases( + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": database_name}], + } + ) + ) # Only create if we did not find it already existing if len(existing_database) == 0: @@ -188,7 +196,8 @@ def delete_collection(self, collection_name, database_name=None): raise AirflowBadRequest("Collection name cannot be None.") self.get_conn().DeleteContainer( - get_collection_link(self.__get_database_name(database_name), collection_name)) + get_collection_link(self.__get_database_name(database_name), collection_name) + ) def upsert_document(self, document, database_name=None, collection_name=None, document_id=None): """ @@ -211,9 +220,10 @@ def upsert_document(self, document, database_name=None, collection_name=None, do created_document = self.get_conn().CreateItem( get_collection_link( - self.__get_database_name(database_name), - self.__get_collection_name(collection_name)), - document) + self.__get_database_name(database_name), self.__get_collection_name(collection_name) + ), + document, + ) return created_document @@ -229,9 +239,11 @@ def insert_documents(self, documents, database_name=None, collection_name=None): created_documents.append( self.get_conn().CreateItem( get_collection_link( - self.__get_database_name(database_name), - self.__get_collection_name(collection_name)), - single_document)) + self.__get_database_name(database_name), self.__get_collection_name(collection_name) + ), + single_document, + ) + ) return created_documents @@ -246,7 +258,9 @@ def delete_document(self, document_id, database_name=None, collection_name=None) get_document_link( self.__get_database_name(database_name), self.__get_collection_name(collection_name), - document_id)) + document_id, + ) + ) def get_document(self, document_id, database_name=None, collection_name=None): """ @@ -260,7 +274,9 @@ def get_document(self, document_id, database_name=None, collection_name=None): get_document_link( self.__get_database_name(database_name), self.__get_collection_name(collection_name), - document_id)) + document_id, + ) + ) except HTTPFailure: return None @@ -277,10 +293,11 @@ def get_documents(self, sql_string, database_name=None, collection_name=None, pa try: result_iterable = self.get_conn().QueryItems( get_collection_link( - self.__get_database_name(database_name), - self.__get_collection_name(collection_name)), + self.__get_database_name(database_name), self.__get_collection_name(collection_name) + ), query, - partition_key) + partition_key, + ) return list(result_iterable) except HTTPFailure: diff --git 
a/airflow/providers/microsoft/azure/hooks/azure_data_lake.py b/airflow/providers/microsoft/azure/hooks/azure_data_lake.py index 3feda65d413cc..ba3a020eff739 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_data_lake.py +++ b/airflow/providers/microsoft/azure/hooks/azure_data_lake.py @@ -54,9 +54,9 @@ def get_conn(self): service_options = conn.extra_dejson self.account_name = service_options.get('account_name') - adl_creds = lib.auth(tenant_id=service_options.get('tenant'), - client_secret=conn.password, - client_id=conn.login) + adl_creds = lib.auth( + tenant_id=service_options.get('tenant'), client_secret=conn.password, client_id=conn.login + ) self._conn = core.AzureDLFileSystem(adl_creds, store_name=self.account_name) self._conn.connect() return self._conn @@ -76,8 +76,9 @@ def check_for_file(self, file_path): except FileNotFoundError: return False - def upload_file(self, local_path, remote_path, nthreads=64, overwrite=True, - buffersize=4194304, blocksize=4194304): + def upload_file( + self, local_path, remote_path, nthreads=64, overwrite=True, buffersize=4194304, blocksize=4194304 + ): """ Upload a file to Azure Data Lake. @@ -104,16 +105,19 @@ def upload_file(self, local_path, remote_path, nthreads=64, overwrite=True, block for each API call. This block cannot be bigger than a chunk. :type blocksize: int """ - multithread.ADLUploader(self.get_conn(), - lpath=local_path, - rpath=remote_path, - nthreads=nthreads, - overwrite=overwrite, - buffersize=buffersize, - blocksize=blocksize) - - def download_file(self, local_path, remote_path, nthreads=64, overwrite=True, - buffersize=4194304, blocksize=4194304): + multithread.ADLUploader( + self.get_conn(), + lpath=local_path, + rpath=remote_path, + nthreads=nthreads, + overwrite=overwrite, + buffersize=buffersize, + blocksize=blocksize, + ) + + def download_file( + self, local_path, remote_path, nthreads=64, overwrite=True, buffersize=4194304, blocksize=4194304 + ): """ Download a file from Azure Blob Storage. @@ -141,13 +145,15 @@ def download_file(self, local_path, remote_path, nthreads=64, overwrite=True, block for each API call. This block cannot be bigger than a chunk. :type blocksize: int """ - multithread.ADLDownloader(self.get_conn(), - lpath=local_path, - rpath=remote_path, - nthreads=nthreads, - overwrite=overwrite, - buffersize=buffersize, - blocksize=blocksize) + multithread.ADLDownloader( + self.get_conn(), + lpath=local_path, + rpath=remote_path, + nthreads=nthreads, + overwrite=overwrite, + buffersize=buffersize, + blocksize=blocksize, + ) def list(self, path): """ diff --git a/airflow/providers/microsoft/azure/hooks/azure_fileshare.py b/airflow/providers/microsoft/azure/hooks/azure_fileshare.py index 84e592060eba3..be48c0f2b9839 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_fileshare.py +++ b/airflow/providers/microsoft/azure/hooks/azure_fileshare.py @@ -43,8 +43,7 @@ def get_conn(self): if not self._conn: conn = self.get_connection(self.conn_id) service_options = conn.extra_dejson - self._conn = FileService(account_name=conn.login, - account_key=conn.password, **service_options) + self._conn = FileService(account_name=conn.login, account_key=conn.password, **service_options) return self._conn def check_for_directory(self, share_name, directory_name, **kwargs): @@ -61,8 +60,7 @@ def check_for_directory(self, share_name, directory_name, **kwargs): :return: True if the file exists, False otherwise. 
:rtype: bool """ - return self.get_conn().exists(share_name, directory_name, - **kwargs) + return self.get_conn().exists(share_name, directory_name, **kwargs) def check_for_file(self, share_name, directory_name, file_name, **kwargs): """ @@ -80,8 +78,7 @@ def check_for_file(self, share_name, directory_name, file_name, **kwargs): :return: True if the file exists, False otherwise. :rtype: bool """ - return self.get_conn().exists(share_name, directory_name, - file_name, **kwargs) + return self.get_conn().exists(share_name, directory_name, file_name, **kwargs) def list_directories_and_files(self, share_name, directory_name=None, **kwargs): """ @@ -97,9 +94,7 @@ def list_directories_and_files(self, share_name, directory_name=None, **kwargs): :return: A list of files and directories :rtype: list """ - return self.get_conn().list_directories_and_files(share_name, - directory_name, - **kwargs) + return self.get_conn().list_directories_and_files(share_name, directory_name, **kwargs) def create_directory(self, share_name, directory_name, **kwargs): """ @@ -133,8 +128,7 @@ def get_file(self, file_path, share_name, directory_name, file_name, **kwargs): `FileService.get_file_to_path()` takes. :type kwargs: object """ - self.get_conn().get_file_to_path(share_name, directory_name, - file_name, file_path, **kwargs) + self.get_conn().get_file_to_path(share_name, directory_name, file_name, file_path, **kwargs) def get_file_to_stream(self, stream, share_name, directory_name, file_name, **kwargs): """ @@ -152,8 +146,7 @@ def get_file_to_stream(self, stream, share_name, directory_name, file_name, **kw `FileService.get_file_to_stream()` takes. :type kwargs: object """ - self.get_conn().get_file_to_stream(share_name, directory_name, - file_name, stream, **kwargs) + self.get_conn().get_file_to_stream(share_name, directory_name, file_name, stream, **kwargs) def load_file(self, file_path, share_name, directory_name, file_name, **kwargs): """ @@ -171,8 +164,7 @@ def load_file(self, file_path, share_name, directory_name, file_name, **kwargs): `FileService.create_file_from_path()` takes. :type kwargs: object """ - self.get_conn().create_file_from_path(share_name, directory_name, - file_name, file_path, **kwargs) + self.get_conn().create_file_from_path(share_name, directory_name, file_name, file_path, **kwargs) def load_string(self, string_data, share_name, directory_name, file_name, **kwargs): """ @@ -190,8 +182,7 @@ def load_string(self, string_data, share_name, directory_name, file_name, **kwar `FileService.create_file_from_text()` takes. :type kwargs: object """ - self.get_conn().create_file_from_text(share_name, directory_name, - file_name, string_data, **kwargs) + self.get_conn().create_file_from_text(share_name, directory_name, file_name, string_data, **kwargs) def load_stream(self, stream, share_name, directory_name, file_name, count, **kwargs): """ @@ -211,5 +202,6 @@ def load_stream(self, stream, share_name, directory_name, file_name, count, **kw `FileService.create_file_from_stream()` takes. 
:type kwargs: object """ - self.get_conn().create_file_from_stream(share_name, directory_name, - file_name, stream, count, **kwargs) + self.get_conn().create_file_from_stream( + share_name, directory_name, file_name, stream, count, **kwargs + ) diff --git a/airflow/providers/microsoft/azure/hooks/base_azure.py b/airflow/providers/microsoft/azure/hooks/base_azure.py index 58950fda2d8f9..2ab0952114226 100644 --- a/airflow/providers/microsoft/azure/hooks/base_azure.py +++ b/airflow/providers/microsoft/azure/hooks/base_azure.py @@ -51,25 +51,17 @@ def get_conn(self) -> Any: if not key_path.endswith('.json'): raise AirflowException('Unrecognised extension for key file.') self.log.info('Getting connection using a JSON key file.') - return get_client_from_auth_file( - client_class=self.sdk_client, - auth_path=key_path - ) + return get_client_from_auth_file(client_class=self.sdk_client, auth_path=key_path) key_json = conn.extra_dejson.get('key_json') if key_json: self.log.info('Getting connection using a JSON config.') - return get_client_from_json_dict( - client_class=self.sdk_client, - config_dict=key_json - ) + return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json) self.log.info('Getting connection using specific credentials and subscription_id.') return self.sdk_client( credentials=ServicePrincipalCredentials( - client_id=conn.login, - secret=conn.password, - tenant=conn.extra_dejson.get('tenantId') + client_id=conn.login, secret=conn.password, tenant=conn.extra_dejson.get('tenantId') ), - subscription_id=conn.extra_dejson.get('subscriptionId') + subscription_id=conn.extra_dejson.get('subscriptionId'), ) diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py index 3c91cf330ec48..07fbe6123c1d9 100644 --- a/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/airflow/providers/microsoft/azure/hooks/wasb.py @@ -51,8 +51,7 @@ def get_conn(self): """Return the BlockBlobService object.""" conn = self.get_connection(self.conn_id) service_options = conn.extra_dejson - return BlockBlobService(account_name=conn.login, - account_key=conn.password, **service_options) + return BlockBlobService(account_name=conn.login, account_key=conn.password, **service_options) def check_for_blob(self, container_name, blob_name, **kwargs): """ @@ -84,8 +83,7 @@ def check_for_prefix(self, container_name, prefix, **kwargs): :return: True if blobs matching the prefix exist, False otherwise. :rtype: bool """ - matches = self.connection.list_blobs(container_name, prefix, - num_results=1, **kwargs) + matches = self.connection.list_blobs(container_name, prefix, num_results=1, **kwargs) return len(list(matches)) > 0 def get_blobs_list(self, container_name: str, prefix: str, **kwargs): @@ -120,8 +118,7 @@ def load_file(self, file_path, container_name, blob_name, **kwargs): :type kwargs: object """ # Reorder the argument order from airflow.providers.amazon.aws.hooks.s3.load_file. - self.connection.create_blob_from_path(container_name, blob_name, - file_path, **kwargs) + self.connection.create_blob_from_path(container_name, blob_name, file_path, **kwargs) def load_string(self, string_data, container_name, blob_name, **kwargs): """ @@ -138,8 +135,7 @@ def load_string(self, string_data, container_name, blob_name, **kwargs): :type kwargs: object """ # Reorder the argument order from airflow.providers.amazon.aws.hooks.s3.load_string. 
- self.connection.create_blob_from_text(container_name, blob_name, - string_data, **kwargs) + self.connection.create_blob_from_text(container_name, blob_name, string_data, **kwargs) def get_file(self, file_path, container_name, blob_name, **kwargs): """ @@ -155,8 +151,7 @@ def get_file(self, file_path, container_name, blob_name, **kwargs): `BlockBlobService.create_blob_from_path()` takes. :type kwargs: object """ - return self.connection.get_blob_to_path(container_name, blob_name, - file_path, **kwargs) + return self.connection.get_blob_to_path(container_name, blob_name, file_path, **kwargs) def read_file(self, container_name, blob_name, **kwargs): """ @@ -170,12 +165,9 @@ def read_file(self, container_name, blob_name, **kwargs): `BlockBlobService.create_blob_from_path()` takes. :type kwargs: object """ - return self.connection.get_blob_to_text(container_name, - blob_name, - **kwargs).content + return self.connection.get_blob_to_text(container_name, blob_name, **kwargs).content - def delete_file(self, container_name, blob_name, is_prefix=False, - ignore_if_missing=False, **kwargs): + def delete_file(self, container_name, blob_name, is_prefix=False, ignore_if_missing=False, **kwargs): """ Delete a file from Azure Blob Storage. @@ -195,9 +187,7 @@ def delete_file(self, container_name, blob_name, is_prefix=False, if is_prefix: blobs_to_delete = [ - blob.name for blob in self.connection.list_blobs( - container_name, prefix=blob_name, **kwargs - ) + blob.name for blob in self.connection.list_blobs(container_name, prefix=blob_name, **kwargs) ] elif self.check_for_blob(container_name, blob_name): blobs_to_delete = [blob_name] @@ -209,7 +199,4 @@ def delete_file(self, container_name, blob_name, is_prefix=False, for blob_uri in blobs_to_delete: self.log.info("Deleting blob: %s", blob_uri) - self.connection.delete_blob(container_name, - blob_uri, - delete_snapshots='include', - **kwargs) + self.connection.delete_blob(container_name, blob_uri, delete_snapshots='include', **kwargs) diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/airflow/providers/microsoft/azure/log/wasb_task_handler.py index 4650fd45659e8..49fcbdb8b1318 100644 --- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py +++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py @@ -33,8 +33,9 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin): uploads to and reads from Wasb remote storage. """ - def __init__(self, base_log_folder, wasb_log_folder, wasb_container, - filename_template, delete_local_copy): + def __init__( + self, base_log_folder, wasb_log_folder, wasb_container, filename_template, delete_local_copy + ): super().__init__(base_log_folder, filename_template) self.wasb_container = wasb_container self.remote_base = wasb_log_folder @@ -52,12 +53,14 @@ def hook(self): remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID') try: from airflow.providers.microsoft.azure.hooks.wasb import WasbHook + return WasbHook(remote_conn_id) except AzureHttpError: self.log.error( 'Could not create an WasbHook with connection id "%s". ' 'Please make sure that airflow[azure] is installed and ' - 'the Wasb connection exists.', remote_conn_id + 'the Wasb connection exists.', + remote_conn_id, ) def set_context(self, ti): @@ -117,8 +120,7 @@ def _read(self, ti, try_number, metadata=None): # local machine even if there are errors reading remote logs, as # returned remote_log will contain error messages. 
remote_log = self.wasb_read(remote_loc, return_error=True) - log = '*** Reading remote log from {}.\n{}\n'.format( - remote_loc, remote_log) + log = '*** Reading remote log from {}.\n{}\n'.format(remote_loc, remote_log) return log, {'end_of_log': True} else: return super()._read(ti, try_number) @@ -175,10 +177,7 @@ def wasb_write(self, log, remote_log_location, append=True): try: self.hook.load_string( - log, - self.wasb_container, - remote_log_location, + log, self.wasb_container, remote_log_location, ) except AzureHttpError: - self.log.exception('Could not write logs to %s', - remote_log_location) + self.log.exception('Could not write logs to %s', remote_log_location) diff --git a/airflow/providers/microsoft/azure/operators/adls_list.py b/airflow/providers/microsoft/azure/operators/adls_list.py index 576a2d8b150d8..ad9755712fae4 100644 --- a/airflow/providers/microsoft/azure/operators/adls_list.py +++ b/airflow/providers/microsoft/azure/operators/adls_list.py @@ -46,24 +46,21 @@ class AzureDataLakeStorageListOperator(BaseOperator): azure_data_lake_conn_id='azure_data_lake_default' ) """ + template_fields: Sequence[str] = ('path',) ui_color = '#901dd2' @apply_defaults - def __init__(self, *, - path: str, - azure_data_lake_conn_id: str = 'azure_data_lake_default', - **kwargs) -> None: + def __init__( + self, *, path: str, azure_data_lake_conn_id: str = 'azure_data_lake_default', **kwargs + ) -> None: super().__init__(**kwargs) self.path = path self.azure_data_lake_conn_id = azure_data_lake_conn_id - def execute(self, - context: Dict[Any, Any]) -> List: + def execute(self, context: Dict[Any, Any]) -> List: - hook = AzureDataLakeHook( - azure_data_lake_conn_id=self.azure_data_lake_conn_id - ) + hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) self.log.info('Getting list of ADLS files in path: %s', self.path) diff --git a/airflow/providers/microsoft/azure/operators/adx.py b/airflow/providers/microsoft/azure/operators/adx.py index eb4c8e9117e58..db1e485a939c5 100644 --- a/airflow/providers/microsoft/azure/operators/adx.py +++ b/airflow/providers/microsoft/azure/operators/adx.py @@ -44,16 +44,18 @@ class AzureDataExplorerQueryOperator(BaseOperator): ui_color = '#00a1f2' template_fields = ('query', 'database') - template_ext = ('.kql', ) + template_ext = ('.kql',) @apply_defaults def __init__( - self, *, - query: str, - database: str, - options: Optional[Dict] = None, - azure_data_explorer_conn_id: str = 'azure_data_explorer_default', - **kwargs) -> None: + self, + *, + query: str, + database: str, + options: Optional[Dict] = None, + azure_data_explorer_conn_id: str = 'azure_data_explorer_default', + **kwargs, + ) -> None: super().__init__(**kwargs) self.query = query self.database = database diff --git a/airflow/providers/microsoft/azure/operators/azure_batch.py b/airflow/providers/microsoft/azure/operators/azure_batch.py index 6cf080f907e14..1c8610a32e231 100644 --- a/airflow/providers/microsoft/azure/operators/azure_batch.py +++ b/airflow/providers/microsoft/azure/operators/azure_batch.py @@ -143,42 +143,50 @@ class AzureBatchOperator(BaseOperator): """ - template_fields = ('batch_pool_id', 'batch_pool_vm_size', 'batch_job_id', - 'batch_task_id', 'batch_task_command_line') + template_fields = ( + 'batch_pool_id', + 'batch_pool_vm_size', + 'batch_job_id', + 'batch_task_id', + 'batch_task_command_line', + ) ui_color = '#f0f0e4' @apply_defaults - def __init__(self, *, # pylint: disable=too-many-arguments,too-many-locals - batch_pool_id: str, - batch_pool_vm_size: 
str, - batch_job_id: str, - batch_task_command_line: str, - batch_task_id: str, - batch_pool_display_name: Optional[str] = None, - batch_job_display_name: Optional[str] = None, - batch_job_manager_task: Optional[batch_models.JobManagerTask] = None, - batch_job_preparation_task: Optional[batch_models.JobPreparationTask] = None, - batch_job_release_task: Optional[batch_models.JobReleaseTask] = None, - batch_task_display_name: Optional[str] = None, - batch_task_container_settings: Optional[batch_models.TaskContainerSettings] = None, - batch_start_task: Optional[batch_models.StartTask] = None, - batch_max_retries: int = 3, - batch_task_resource_files: Optional[List[batch_models.ResourceFile]] = None, - batch_task_output_files: Optional[List[batch_models.OutputFile]] = None, - batch_task_user_identity: Optional[batch_models.UserIdentity] = None, - target_low_priority_nodes: Optional[int] = None, - target_dedicated_nodes: Optional[int] = None, - enable_auto_scale: bool = False, - auto_scale_formula: Optional[str] = None, - azure_batch_conn_id='azure_batch_default', - use_latest_verified_vm_image_and_sku: bool = False, - vm_publisher: Optional[str] = None, - vm_offer: Optional[str] = None, - sku_starts_with: Optional[str] = None, - timeout: int = 25, - should_delete_job: bool = False, - should_delete_pool: bool = False, - **kwargs) -> None: + def __init__( + self, + *, # pylint: disable=too-many-arguments,too-many-locals + batch_pool_id: str, + batch_pool_vm_size: str, + batch_job_id: str, + batch_task_command_line: str, + batch_task_id: str, + batch_pool_display_name: Optional[str] = None, + batch_job_display_name: Optional[str] = None, + batch_job_manager_task: Optional[batch_models.JobManagerTask] = None, + batch_job_preparation_task: Optional[batch_models.JobPreparationTask] = None, + batch_job_release_task: Optional[batch_models.JobReleaseTask] = None, + batch_task_display_name: Optional[str] = None, + batch_task_container_settings: Optional[batch_models.TaskContainerSettings] = None, + batch_start_task: Optional[batch_models.StartTask] = None, + batch_max_retries: int = 3, + batch_task_resource_files: Optional[List[batch_models.ResourceFile]] = None, + batch_task_output_files: Optional[List[batch_models.OutputFile]] = None, + batch_task_user_identity: Optional[batch_models.UserIdentity] = None, + target_low_priority_nodes: Optional[int] = None, + target_dedicated_nodes: Optional[int] = None, + enable_auto_scale: bool = False, + auto_scale_formula: Optional[str] = None, + azure_batch_conn_id='azure_batch_default', + use_latest_verified_vm_image_and_sku: bool = False, + vm_publisher: Optional[str] = None, + vm_offer: Optional[str] = None, + sku_starts_with: Optional[str] = None, + timeout: int = 25, + should_delete_job: bool = False, + should_delete_pool: bool = False, + **kwargs, + ) -> None: super().__init__(**kwargs) self.batch_pool_id = batch_pool_id @@ -216,61 +224,81 @@ def _check_inputs(self) -> Any: if self.use_latest_image: if not all(elem for elem in [self.vm_publisher, self.vm_offer, self.sku_starts_with]): - raise AirflowException("If use_latest_image_and_sku is" - " set to True then the parameters vm_publisher, vm_offer, " - "sku_starts_with must all be set. Found " - "vm_publisher={}, vm_offer={}, sku_starts_with={}". - format(self.vm_publisher, self.vm_offer, self.sku_starts_with)) + raise AirflowException( + "If use_latest_image_and_sku is" + " set to True then the parameters vm_publisher, vm_offer, " + "sku_starts_with must all be set. 
Found " + "vm_publisher={}, vm_offer={}, sku_starts_with={}".format( + self.vm_publisher, self.vm_offer, self.sku_starts_with + ) + ) if self.enable_auto_scale: if self.target_dedicated_nodes or self.target_low_priority_nodes: - raise AirflowException("If enable_auto_scale is set, then the parameters " - "target_dedicated_nodes and target_low_priority_nodes must not " - "be set. Found target_dedicated_nodes={}," - " target_low_priority_nodes={}" - .format(self.target_dedicated_nodes, self.target_low_priority_nodes)) + raise AirflowException( + "If enable_auto_scale is set, then the parameters " + "target_dedicated_nodes and target_low_priority_nodes must not " + "be set. Found target_dedicated_nodes={}," + " target_low_priority_nodes={}".format( + self.target_dedicated_nodes, self.target_low_priority_nodes + ) + ) if not self.auto_scale_formula: - raise AirflowException("The auto_scale_formula is required when enable_auto_scale is" - " set") + raise AirflowException("The auto_scale_formula is required when enable_auto_scale is" " set") if self.batch_job_release_task and not self.batch_job_preparation_task: - raise AirflowException("A batch_job_release_task cannot be specified without also " - " specifying a batch_job_preparation_task for the Job.") - if not all([self.batch_pool_id, self.batch_job_id, self.batch_pool_vm_size, - self.batch_task_id, self.batch_task_command_line]): - raise AirflowException("Some required parameters are missing.Please you must set " - "all the required parameters. ") - - def execute(self, - context: Dict[Any, Any]) -> None: + raise AirflowException( + "A batch_job_release_task cannot be specified without also " + " specifying a batch_job_preparation_task for the Job." + ) + if not all( + [ + self.batch_pool_id, + self.batch_job_id, + self.batch_pool_vm_size, + self.batch_task_id, + self.batch_task_command_line, + ] + ): + raise AirflowException( + "Some required parameters are missing.Please you must set " "all the required parameters. 
" + ) + + def execute(self, context: Dict[Any, Any]) -> None: self._check_inputs() self.hook.connection.config.retry_policy = self.batch_max_retries - pool = self.hook.configure_pool(pool_id=self.batch_pool_id, - vm_size=self.batch_pool_vm_size, - display_name=self.batch_pool_display_name, - target_dedicated_nodes=self.target_dedicated_nodes, - use_latest_image_and_sku=self.use_latest_image, - vm_publisher=self.vm_publisher, - vm_offer=self.vm_offer, - sku_starts_with=self.sku_starts_with, - target_low_priority_nodes=self.target_low_priority_nodes, - enable_auto_scale=self.enable_auto_scale, - auto_scale_formula=self.auto_scale_formula, - start_task=self.batch_start_task, - ) + pool = self.hook.configure_pool( + pool_id=self.batch_pool_id, + vm_size=self.batch_pool_vm_size, + display_name=self.batch_pool_display_name, + target_dedicated_nodes=self.target_dedicated_nodes, + use_latest_image_and_sku=self.use_latest_image, + vm_publisher=self.vm_publisher, + vm_offer=self.vm_offer, + sku_starts_with=self.sku_starts_with, + target_low_priority_nodes=self.target_low_priority_nodes, + enable_auto_scale=self.enable_auto_scale, + auto_scale_formula=self.auto_scale_formula, + start_task=self.batch_start_task, + ) self.hook.create_pool(pool) # Wait for nodes to reach complete state - self.hook.wait_for_all_node_state(self.batch_pool_id, - {batch_models.ComputeNodeState.start_task_failed, - batch_models.ComputeNodeState.unusable, - batch_models.ComputeNodeState.idle} - ) + self.hook.wait_for_all_node_state( + self.batch_pool_id, + { + batch_models.ComputeNodeState.start_task_failed, + batch_models.ComputeNodeState.unusable, + batch_models.ComputeNodeState.idle, + }, + ) # Create job if not already exist - job = self.hook.configure_job(job_id=self.batch_job_id, - pool_id=self.batch_pool_id, - display_name=self.batch_job_display_name, - job_manager_task=self.batch_job_manager_task, - job_preparation_task=self.batch_job_preparation_task, - job_release_task=self.batch_job_release_task) + job = self.hook.configure_job( + job_id=self.batch_job_id, + pool_id=self.batch_pool_id, + display_name=self.batch_job_display_name, + job_manager_task=self.batch_job_manager_task, + job_preparation_task=self.batch_job_preparation_task, + job_release_task=self.batch_job_release_task, + ) self.hook.create_job(job) # Create task task = self.hook.configure_task( @@ -280,17 +308,12 @@ def execute(self, container_settings=self.batch_task_container_settings, resource_files=self.batch_task_resource_files, output_files=self.batch_task_output_files, - user_identity=self.batch_task_user_identity + user_identity=self.batch_task_user_identity, ) # Add task to job - self.hook.add_single_task_to_job( - job_id=self.batch_job_id, - task=task - ) + self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task) # Wait for tasks to complete - self.hook.wait_for_job_tasks_to_complete( - job_id=self.batch_job_id, - timeout=self.timeout) + self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, timeout=self.timeout) # Clean up if self.should_delete_job: # delete job first @@ -300,8 +323,7 @@ def execute(self, def on_kill(self) -> None: response = self.hook.connection.job.terminate( - job_id=self.batch_job_id, - terminate_reason='Job killed by user' + job_id=self.batch_job_id, terminate_reason='Job killed by user' ) self.log.info("Azure Batch job (%s) terminated: %s", self.batch_job_id, response) @@ -310,13 +332,9 @@ def get_hook(self) -> AzureBatchHook: Create and return an AzureBatchHook. 
""" - return AzureBatchHook( - azure_batch_conn_id=self.azure_batch_conn_id - ) + return AzureBatchHook(azure_batch_conn_id=self.azure_batch_conn_id) - def clean_up(self, - pool_id: Optional[str] = None, - job_id: Optional[str] = None) -> None: + def clean_up(self, pool_id: Optional[str] = None, job_id: Optional[str] = None) -> None: """ Delete the given pool and job in the batch account diff --git a/airflow/providers/microsoft/azure/operators/azure_container_instances.py b/airflow/providers/microsoft/azure/operators/azure_container_instances.py index 9f4b407eb693e..923768324b44c 100644 --- a/airflow/providers/microsoft/azure/operators/azure_container_instances.py +++ b/airflow/providers/microsoft/azure/operators/azure_container_instances.py @@ -22,7 +22,12 @@ from typing import Any, Dict, List, Optional, Sequence, Union from azure.mgmt.containerinstance.models import ( - Container, ContainerGroup, EnvironmentVariable, ResourceRequests, ResourceRequirements, VolumeMount, + Container, + ContainerGroup, + EnvironmentVariable, + ResourceRequests, + ResourceRequirements, + VolumeMount, ) from msrestazure.azure_exceptions import CloudError @@ -33,10 +38,7 @@ from airflow.providers.microsoft.azure.hooks.azure_container_volume import AzureContainerVolumeHook from airflow.utils.decorators import apply_defaults -Volume = namedtuple( - 'Volume', - ['conn_id', 'account_name', 'share_name', 'mount_path', 'read_only'], -) +Volume = namedtuple('Volume', ['conn_id', 'account_name', 'share_name', 'mount_path', 'read_only'],) DEFAULT_ENVIRONMENT_VARIABLES = {} # type: Dict[str, str] @@ -122,24 +124,27 @@ class AzureContainerInstancesOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults - def __init__(self, *, - ci_conn_id: str, - registry_conn_id: Optional[str], - resource_group: str, - name: str, - image: str, - region: str, - environment_variables: Optional[Dict[Any, Any]] = None, - secured_variables: Optional[str] = None, - volumes: Optional[List[Any]] = None, - memory_in_gb: Optional[Any] = None, - cpu: Optional[Any] = None, - gpu: Optional[Any] = None, - command: Optional[str] = None, - remove_on_error: bool = True, - fail_if_exists: bool = True, - tags: Optional[Dict[str, str]] = None, - **kwargs) -> None: + def __init__( + self, + *, + ci_conn_id: str, + registry_conn_id: Optional[str], + resource_group: str, + name: str, + image: str, + region: str, + environment_variables: Optional[Dict[Any, Any]] = None, + secured_variables: Optional[str] = None, + volumes: Optional[List[Any]] = None, + memory_in_gb: Optional[Any] = None, + cpu: Optional[Any] = None, + gpu: Optional[Any] = None, + command: Optional[str] = None, + remove_on_error: bool = True, + fail_if_exists: bool = True, + tags: Optional[Dict[str, str]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.ci_conn_id = ci_conn_id @@ -160,8 +165,7 @@ def __init__(self, *, self._ci_hook: Any = None self.tags = tags - def execute(self, - context: Dict[Any, Any]) -> int: + def execute(self, context: Dict[Any, Any]) -> int: # Check name again in case it was templated. 
self._check_name(self.name) @@ -174,7 +178,9 @@ def execute(self, if self.registry_conn_id: registry_hook = AzureContainerRegistryHook(self.registry_conn_id) - image_registry_credentials: Optional[List[Any]] = [registry_hook.connection, ] + image_registry_credentials: Optional[List[Any]] = [ + registry_hook.connection, + ] else: image_registry_credentials = None @@ -192,26 +198,18 @@ def execute(self, hook = AzureContainerVolumeHook(conn_id) mount_name = "mount-%d" % len(volumes) - volumes.append(hook.get_file_volume(mount_name, - share_name, - account_name, - read_only)) - volume_mounts.append(VolumeMount(name=mount_name, - mount_path=mount_path, - read_only=read_only)) + volumes.append(hook.get_file_volume(mount_name, share_name, account_name, read_only)) + volume_mounts.append(VolumeMount(name=mount_name, mount_path=mount_path, read_only=read_only)) exit_code = 1 try: - self.log.info("Starting container group with %.1f cpu %.1f mem", - self.cpu, self.memory_in_gb) + self.log.info("Starting container group with %.1f cpu %.1f mem", self.cpu, self.memory_in_gb) if self.gpu: - self.log.info("GPU count: %.1f, GPU SKU: %s", - self.gpu.count, self.gpu.sku) + self.log.info("GPU count: %.1f, GPU SKU: %s", self.gpu.count, self.gpu.sku) - resources = ResourceRequirements(requests=ResourceRequests( - memory_in_gb=self.memory_in_gb, - cpu=self.cpu, - gpu=self.gpu)) + resources = ResourceRequirements( + requests=ResourceRequests(memory_in_gb=self.memory_in_gb, cpu=self.cpu, gpu=self.gpu) + ) container = Container( name=self.name, @@ -219,16 +217,18 @@ def execute(self, resources=resources, command=self.command, environment_variables=environment_variables, - volume_mounts=volume_mounts) + volume_mounts=volume_mounts, + ) container_group = ContainerGroup( location=self.region, - containers=[container, ], + containers=[container,], image_registry_credentials=image_registry_credentials, volumes=volumes, restart_policy='Never', os_type='Linux', - tags=self.tags) + tags=self.tags, + ) self._ci_hook.create_or_update(self.resource_group, self.name, container_group) @@ -238,8 +238,7 @@ def execute(self, self.log.info("Container had exit code: %s", exit_code) if exit_code != 0: - raise AirflowException("Container had a non-zero exit code, %s" - % exit_code) + raise AirflowException("Container had a non-zero exit code, %s" % exit_code) return exit_code except CloudError: @@ -272,9 +271,11 @@ def _monitor_logging(self, resource_group: str, name: str) -> int: # If there is no instance view, we show the provisioning state if instance_view is not None: c_state = instance_view.current_state - state, exit_code, detail_status = (c_state.state, - c_state.exit_code, - c_state.detail_status) + state, exit_code, detail_status = ( + c_state.state, + c_state.exit_code, + c_state.detail_status, + ) messages = [event.message for event in instance_view.events] last_message_logged = self._log_last(messages, last_message_logged) @@ -292,8 +293,9 @@ def _monitor_logging(self, resource_group: str, name: str) -> int: logs = self._ci_hook.get_logs(resource_group, name) last_line_logged = self._log_last(logs, last_line_logged) except CloudError: - self.log.exception("Exception while getting logs from " - "container instance, retrying...") + self.log.exception( + "Exception while getting logs from " "container instance, retrying..." 
+ ) if state == "Terminated": self.log.error("Container exited with detail_status %s", detail_status) @@ -307,9 +309,11 @@ def _monitor_logging(self, resource_group: str, name: str) -> int: raise except CloudError as err: if 'ResourceNotFound' in str(err): - self.log.warning("ResourceNotFound, container is probably removed " - "by another process " - "(make sure that the name is unique).") + self.log.warning( + "ResourceNotFound, container is probably removed " + "by another process " + "(make sure that the name is unique)." + ) return 1 else: self.log.exception("Exception while getting container groups") @@ -318,9 +322,7 @@ def _monitor_logging(self, resource_group: str, name: str) -> int: sleep(1) - def _log_last(self, - logs: Optional[List[Any]], - last_line_logged: Any) -> Optional[Any]: + def _log_last(self, logs: Optional[List[Any]], last_line_logged: Any) -> Optional[Any]: if logs: # determine the last line which was logged before last_line_index = 0 diff --git a/airflow/providers/microsoft/azure/operators/azure_cosmos.py b/airflow/providers/microsoft/azure/operators/azure_cosmos.py index a206f25cc0868..23d5fee6501dd 100644 --- a/airflow/providers/microsoft/azure/operators/azure_cosmos.py +++ b/airflow/providers/microsoft/azure/operators/azure_cosmos.py @@ -36,16 +36,20 @@ class AzureCosmosInsertDocumentOperator(BaseOperator): :param azure_cosmos_conn_id: reference to a CosmosDB connection. :type azure_cosmos_conn_id: str """ + template_fields = ('database_name', 'collection_name') ui_color = '#e4f0e8' @apply_defaults - def __init__(self, *, - database_name: str, - collection_name: str, - document: dict, - azure_cosmos_conn_id: str = 'azure_cosmos_default', - **kwargs) -> None: + def __init__( + self, + *, + database_name: str, + collection_name: str, + document: dict, + azure_cosmos_conn_id: str = 'azure_cosmos_default', + **kwargs, + ) -> None: super().__init__(**kwargs) self.database_name = database_name self.collection_name = collection_name diff --git a/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py b/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py index 7779a60222c1f..5e5d6f2559e6d 100644 --- a/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py +++ b/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py @@ -45,14 +45,17 @@ class WasbDeleteBlobOperator(BaseOperator): template_fields = ('container_name', 'blob_name') @apply_defaults - def __init__(self, *, - container_name: str, - blob_name: str, - wasb_conn_id: str = 'wasb_default', - check_options: Any = None, - is_prefix: bool = False, - ignore_if_missing: bool = False, - **kwargs) -> None: + def __init__( + self, + *, + container_name: str, + blob_name: str, + wasb_conn_id: str = 'wasb_default', + check_options: Any = None, + is_prefix: bool = False, + ignore_if_missing: bool = False, + **kwargs, + ) -> None: super().__init__(**kwargs) if check_options is None: check_options = {} @@ -64,11 +67,9 @@ def __init__(self, *, self.ignore_if_missing = ignore_if_missing def execute(self, context: Dict[Any, Any]) -> None: - self.log.info( - 'Deleting blob: %s\nin wasb://%s', self.blob_name, self.container_name - ) + self.log.info('Deleting blob: %s\nin wasb://%s', self.blob_name, self.container_name) hook = WasbHook(wasb_conn_id=self.wasb_conn_id) - hook.delete_file(self.container_name, self.blob_name, - self.is_prefix, self.ignore_if_missing, - **self.check_options) + hook.delete_file( + self.container_name, self.blob_name, self.is_prefix, self.ignore_if_missing, 
**self.check_options + ) diff --git a/airflow/providers/microsoft/azure/sensors/azure_cosmos.py b/airflow/providers/microsoft/azure/sensors/azure_cosmos.py index 7235b94e86918..1b7eab23a97a9 100644 --- a/airflow/providers/microsoft/azure/sensors/azure_cosmos.py +++ b/airflow/providers/microsoft/azure/sensors/azure_cosmos.py @@ -42,16 +42,19 @@ class AzureCosmosDocumentSensor(BaseSensorOperator): :param azure_cosmos_conn_id: Reference to the Azure CosmosDB connection. :type azure_cosmos_conn_id: str """ + template_fields = ('database_name', 'collection_name', 'document_id') @apply_defaults def __init__( - self, *, - database_name: str, - collection_name: str, - document_id: str, - azure_cosmos_conn_id: str = "azure_cosmos_default", - **kwargs) -> None: + self, + *, + database_name: str, + collection_name: str, + document_id: str, + azure_cosmos_conn_id: str = "azure_cosmos_default", + **kwargs, + ) -> None: super().__init__(**kwargs) self.azure_cosmos_conn_id = azure_cosmos_conn_id self.database_name = database_name diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py b/airflow/providers/microsoft/azure/sensors/wasb.py index 033b89fe154cd..57d016b2db45b 100644 --- a/airflow/providers/microsoft/azure/sensors/wasb.py +++ b/airflow/providers/microsoft/azure/sensors/wasb.py @@ -41,12 +41,15 @@ class WasbBlobSensor(BaseSensorOperator): template_fields = ('container_name', 'blob_name') @apply_defaults - def __init__(self, *, - container_name: str, - blob_name: str, - wasb_conn_id: str = 'wasb_default', - check_options: Optional[dict] = None, - **kwargs): + def __init__( + self, + *, + container_name: str, + blob_name: str, + wasb_conn_id: str = 'wasb_default', + check_options: Optional[dict] = None, + **kwargs, + ): super().__init__(**kwargs) if check_options is None: check_options = {} @@ -56,12 +59,9 @@ def __init__(self, *, self.check_options = check_options def poke(self, context: Dict[Any, Any]): - self.log.info( - 'Poking for blob: %s\nin wasb://%s', self.blob_name, self.container_name - ) + self.log.info('Poking for blob: %s\nin wasb://%s', self.blob_name, self.container_name) hook = WasbHook(wasb_conn_id=self.wasb_conn_id) - return hook.check_for_blob(self.container_name, self.blob_name, - **self.check_options) + return hook.check_for_blob(self.container_name, self.blob_name, **self.check_options) class WasbPrefixSensor(BaseSensorOperator): @@ -82,12 +82,15 @@ class WasbPrefixSensor(BaseSensorOperator): template_fields = ('container_name', 'prefix') @apply_defaults - def __init__(self, *, - container_name: str, - prefix: str, - wasb_conn_id: str = 'wasb_default', - check_options: Optional[dict] = None, - **kwargs) -> None: + def __init__( + self, + *, + container_name: str, + prefix: str, + wasb_conn_id: str = 'wasb_default', + check_options: Optional[dict] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) if check_options is None: check_options = {} @@ -96,9 +99,7 @@ def __init__(self, *, self.prefix = prefix self.check_options = check_options - def poke(self, - context: Dict[Any, Any]) -> bool: + def poke(self, context: Dict[Any, Any]) -> bool: self.log.info('Poking for prefix: %s in wasb://%s', self.prefix, self.container_name) hook = WasbHook(wasb_conn_id=self.wasb_conn_id) - return hook.check_for_prefix(self.container_name, self.prefix, - **self.check_options) + return hook.check_for_prefix(self.container_name, self.prefix, **self.check_options) diff --git a/airflow/providers/microsoft/azure/transfers/file_to_wasb.py 
b/airflow/providers/microsoft/azure/transfers/file_to_wasb.py index ac1b415f2ceb1..cb9d8e70d126b 100644 --- a/airflow/providers/microsoft/azure/transfers/file_to_wasb.py +++ b/airflow/providers/microsoft/azure/transfers/file_to_wasb.py @@ -39,16 +39,20 @@ class FileToWasbOperator(BaseOperator): `WasbHook.load_file()` takes. :type load_options: Optional[dict] """ + template_fields = ('file_path', 'container_name', 'blob_name') @apply_defaults - def __init__(self, *, - file_path: str, - container_name: str, - blob_name: str, - wasb_conn_id: str = 'wasb_default', - load_options: Optional[dict] = None, - **kwargs) -> None: + def __init__( + self, + *, + file_path: str, + container_name: str, + blob_name: str, + wasb_conn_id: str = 'wasb_default', + load_options: Optional[dict] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) if load_options is None: load_options = {} @@ -58,13 +62,10 @@ def __init__(self, *, self.wasb_conn_id = wasb_conn_id self.load_options = load_options - def execute(self, - context: Dict[Any, Any]) -> None: + def execute(self, context: Dict[Any, Any]) -> None: """Upload a file to Azure Blob Storage.""" hook = WasbHook(wasb_conn_id=self.wasb_conn_id) self.log.info( - 'Uploading %s to wasb://%s as %s', - self.file_path, self.container_name, self.blob_name, + 'Uploading %s to wasb://%s as %s', self.file_path, self.container_name, self.blob_name, ) - hook.load_file(self.file_path, self.container_name, - self.blob_name, **self.load_options) + hook.load_file(self.file_path, self.container_name, self.blob_name, **self.load_options) diff --git a/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py b/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py index a83c1ae727da5..60630c5e9c806 100644 --- a/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py +++ b/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py @@ -62,18 +62,20 @@ class OracleToAzureDataLakeOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults def __init__( - self, *, - filename: str, - azure_data_lake_conn_id: str, - azure_data_lake_path: str, - oracle_conn_id: str, - sql: str, - sql_params: Optional[dict] = None, - delimiter: str = ",", - encoding: str = "utf-8", - quotechar: str = '"', - quoting: str = csv.QUOTE_MINIMAL, - **kwargs) -> None: + self, + *, + filename: str, + azure_data_lake_conn_id: str, + azure_data_lake_path: str, + oracle_conn_id: str, + sql: str, + sql_params: Optional[dict] = None, + delimiter: str = ",", + encoding: str = "utf-8", + quotechar: str = '"', + quoting: str = csv.QUOTE_MINIMAL, + **kwargs, + ) -> None: super().__init__(**kwargs) if sql_params is None: sql_params = {} @@ -88,22 +90,22 @@ def __init__( self.quotechar = quotechar self.quoting = quoting - def _write_temp_file(self, - cursor: Any, - path_to_save: Union[str, bytes, int]) -> None: + def _write_temp_file(self, cursor: Any, path_to_save: Union[str, bytes, int]) -> None: with open(path_to_save, 'wb') as csvfile: - csv_writer = csv.writer(csvfile, delimiter=self.delimiter, - encoding=self.encoding, quotechar=self.quotechar, - quoting=self.quoting) + csv_writer = csv.writer( + csvfile, + delimiter=self.delimiter, + encoding=self.encoding, + quotechar=self.quotechar, + quoting=self.quoting, + ) csv_writer.writerow(map(lambda field: field[0], cursor.description)) csv_writer.writerows(cursor) csvfile.flush() - def execute(self, - context: Dict[Any, Any]) -> None: + def execute(self, context: Dict[Any, Any]) -> None: 
oracle_hook = OracleHook(oracle_conn_id=self.oracle_conn_id) - azure_data_lake_hook = AzureDataLakeHook( - azure_data_lake_conn_id=self.azure_data_lake_conn_id) + azure_data_lake_hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) self.log.info("Dumping Oracle query results to local file") conn = oracle_hook.get_conn() @@ -113,8 +115,8 @@ def execute(self, with TemporaryDirectory(prefix='airflow_oracle_to_azure_op_') as temp: self._write_temp_file(cursor, os.path.join(temp, self.filename)) self.log.info("Uploading local file to Azure Data Lake") - azure_data_lake_hook.upload_file(os.path.join(temp, self.filename), - os.path.join(self.azure_data_lake_path, - self.filename)) + azure_data_lake_hook.upload_file( + os.path.join(temp, self.filename), os.path.join(self.azure_data_lake_path, self.filename) + ) cursor.close() conn.close() diff --git a/airflow/providers/microsoft/mssql/operators/mssql.py b/airflow/providers/microsoft/mssql/operators/mssql.py index 6b8f4d9396590..25d6815a5e9fd 100644 --- a/airflow/providers/microsoft/mssql/operators/mssql.py +++ b/airflow/providers/microsoft/mssql/operators/mssql.py @@ -53,13 +53,14 @@ class MsSqlOperator(BaseOperator): @apply_defaults def __init__( - self, *, + self, + *, sql: str, mssql_conn_id: str = 'mssql_default', parameters: Optional[Union[Mapping, Iterable]] = None, autocommit: bool = False, database: Optional[str] = None, - **kwargs + **kwargs, ) -> None: super().__init__(**kwargs) self.mssql_conn_id = mssql_conn_id diff --git a/airflow/providers/microsoft/winrm/example_dags/example_winrm.py b/airflow/providers/microsoft/winrm/example_dags/example_winrm.py index 85b4038fb77f9..d91aa533bda1f 100644 --- a/airflow/providers/microsoft/winrm/example_dags/example_winrm.py +++ b/airflow/providers/microsoft/winrm/example_dags/example_winrm.py @@ -54,22 +54,10 @@ winRMHook = WinRMHook(ssh_conn_id='ssh_POC1') - t1 = WinRMOperator( - task_id="wintask1", - command='ls -altr', - winrm_hook=winRMHook - ) + t1 = WinRMOperator(task_id="wintask1", command='ls -altr', winrm_hook=winRMHook) - t2 = WinRMOperator( - task_id="wintask2", - command='sleep 60', - winrm_hook=winRMHook - ) + t2 = WinRMOperator(task_id="wintask2", command='sleep 60', winrm_hook=winRMHook) - t3 = WinRMOperator( - task_id="wintask3", - command='echo \'luke test\' ', - winrm_hook=winRMHook - ) + t3 = WinRMOperator(task_id="wintask3", command='echo \'luke test\' ', winrm_hook=winRMHook) [t1, t2, t3] >> run_this_last diff --git a/airflow/providers/microsoft/winrm/hooks/winrm.py b/airflow/providers/microsoft/winrm/hooks/winrm.py index f58ef18f3f503..ad6e5ca57caa4 100644 --- a/airflow/providers/microsoft/winrm/hooks/winrm.py +++ b/airflow/providers/microsoft/winrm/hooks/winrm.py @@ -88,27 +88,29 @@ class WinRMHook(BaseHook): :type send_cbt: bool """ - def __init__(self, - ssh_conn_id=None, - endpoint=None, - remote_host=None, - remote_port=5985, - transport='plaintext', - username=None, - password=None, - service='HTTP', - keytab=None, - ca_trust_path=None, - cert_pem=None, - cert_key_pem=None, - server_cert_validation='validate', - kerberos_delegation=False, - read_timeout_sec=30, - operation_timeout_sec=20, - kerberos_hostname_override=None, - message_encryption='auto', - credssp_disable_tlsv1_2=False, - send_cbt=True): + def __init__( + self, + ssh_conn_id=None, + endpoint=None, + remote_host=None, + remote_port=5985, + transport='plaintext', + username=None, + password=None, + service='HTTP', + keytab=None, + ca_trust_path=None, + cert_pem=None, + 
cert_key_pem=None, + server_cert_validation='validate', + kerberos_delegation=False, + read_timeout_sec=30, + operation_timeout_sec=20, + kerberos_hostname_override=None, + message_encryption='auto', + credssp_disable_tlsv1_2=False, + send_cbt=True, + ): super().__init__() self.ssh_conn_id = ssh_conn_id self.endpoint = endpoint @@ -181,8 +183,9 @@ def get_conn(self): if "message_encryption" in extra_options: self.message_encryption = str(extra_options["message_encryption"]) if "credssp_disable_tlsv1_2" in extra_options: - self.credssp_disable_tlsv1_2 = \ + self.credssp_disable_tlsv1_2 = ( str(extra_options["credssp_disable_tlsv1_2"]).lower() == 'true' + ) if "send_cbt" in extra_options: self.send_cbt = str(extra_options["send_cbt"]).lower() == 'true' @@ -194,7 +197,8 @@ def get_conn(self): self.log.debug( "username to WinRM to host: %s is not specified for connection id" " %s. Using system's default provided by getpass.getuser()", - self.remote_host, self.ssh_conn_id + self.remote_host, + self.ssh_conn_id, ) self.username = getpass.getuser() @@ -221,7 +225,7 @@ def get_conn(self): kerberos_hostname_override=self.kerberos_hostname_override, message_encryption=self.message_encryption, credssp_disable_tlsv1_2=self.credssp_disable_tlsv1_2, - send_cbt=self.send_cbt + send_cbt=self.send_cbt, ) self.log.info("Establishing WinRM connection to host: %s", self.remote_host) diff --git a/airflow/providers/microsoft/winrm/operators/winrm.py b/airflow/providers/microsoft/winrm/operators/winrm.py index 7afbcfbe72016..a0c2c7621381b 100644 --- a/airflow/providers/microsoft/winrm/operators/winrm.py +++ b/airflow/providers/microsoft/winrm/operators/winrm.py @@ -48,16 +48,13 @@ class WinRMOperator(BaseOperator): :param timeout: timeout for executing the command. :type timeout: int """ + template_fields = ('command',) @apply_defaults - def __init__(self, *, - winrm_hook=None, - ssh_conn_id=None, - remote_host=None, - command=None, - timeout=10, - **kwargs): + def __init__( + self, *, winrm_hook=None, ssh_conn_id=None, remote_host=None, command=None, timeout=10, **kwargs + ): super().__init__(**kwargs) self.winrm_hook = winrm_hook self.ssh_conn_id = ssh_conn_id @@ -84,10 +81,7 @@ def execute(self, context): # pylint: disable=too-many-nested-blocks try: self.log.info("Running command: '%s'...", self.command) - command_id = self.winrm_hook.winrm_protocol.run_command( - winrm_client, - self.command - ) + command_id = self.winrm_hook.winrm_protocol.run_command(winrm_client, self.command) # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py stdout_buffer = [] @@ -96,11 +90,12 @@ def execute(self, context): while not command_done: try: # pylint: disable=protected-access - stdout, stderr, return_code, command_done = \ - self.winrm_hook.winrm_protocol._raw_get_command_output( - winrm_client, - command_id - ) + ( + stdout, + stderr, + return_code, + command_done, + ) = self.winrm_hook.winrm_protocol._raw_get_command_output(winrm_client, command_id) # Only buffer stdout if we need to so that we minimize memory usage. 
if self.do_xcom_push: @@ -124,17 +119,13 @@ def execute(self, context): if return_code == 0: # returning output if do_xcom_push is set - enable_pickling = conf.getboolean( - 'core', 'enable_xcom_pickling' - ) + enable_pickling = conf.getboolean('core', 'enable_xcom_pickling') if enable_pickling: return stdout_buffer else: return b64encode(b''.join(stdout_buffer)).decode('utf-8') else: error_msg = "Error running cmd: {0}, return code: {1}, error: {2}".format( - self.command, - return_code, - b''.join(stderr_buffer).decode('utf-8') + self.command, return_code, b''.join(stderr_buffer).decode('utf-8') ) raise AirflowException(error_msg) diff --git a/airflow/providers/mongo/hooks/mongo.py b/airflow/providers/mongo/hooks/mongo.py index d46fde3629268..71bd3b53d2923 100644 --- a/airflow/providers/mongo/hooks/mongo.py +++ b/airflow/providers/mongo/hooks/mongo.py @@ -39,6 +39,7 @@ class MongoHook(BaseHook): ex. {"srv": true, "replicaSet": "test", "ssl": true, "connectTimeoutMS": 30000} """ + conn_type = 'mongo' def __init__(self, conn_id: str = 'mongo_default', *args, **kwargs) -> None: @@ -54,22 +55,23 @@ def __init__(self, conn_id: str = 'mongo_default', *args, **kwargs) -> None: self.uri = '{scheme}://{creds}{host}{port}/{database}'.format( scheme=scheme, - creds='{}:{}@'.format( - self.connection.login, self.connection.password - ) if self.connection.login else '', - + creds='{}:{}@'.format(self.connection.login, self.connection.password) + if self.connection.login + else '', host=self.connection.host, port='' if self.connection.port is None else ':{}'.format(self.connection.port), - database=self.connection.schema + database=self.connection.schema, ) def __enter__(self): return self - def __exit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: if self.client is not None: self.close_conn() @@ -98,9 +100,9 @@ def close_conn(self) -> None: client.close() self.client = None - def get_collection(self, - mongo_collection: str, - mongo_db: Optional[str] = None) -> pymongo.collection.Collection: + def get_collection( + self, mongo_collection: str, mongo_db: Optional[str] = None + ) -> pymongo.collection.Collection: """ Fetches a mongo collection object for querying. 
@@ -111,11 +113,9 @@ def get_collection(self, return mongo_conn.get_database(mongo_db).get_collection(mongo_collection) - def aggregate(self, - mongo_collection: str, - aggregate_query: list, - mongo_db: Optional[str] = None, - **kwargs) -> pymongo.command_cursor.CommandCursor: + def aggregate( + self, mongo_collection: str, aggregate_query: list, mongo_db: Optional[str] = None, **kwargs + ) -> pymongo.command_cursor.CommandCursor: """ Runs an aggregation pipeline and returns the results https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.aggregate @@ -125,12 +125,14 @@ def aggregate(self, return collection.aggregate(aggregate_query, **kwargs) - def find(self, - mongo_collection: str, - query: dict, - find_one: bool = False, - mongo_db: Optional[str] = None, - **kwargs) -> pymongo.cursor.Cursor: + def find( + self, + mongo_collection: str, + query: dict, + find_one: bool = False, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.cursor.Cursor: """ Runs a mongo find query and returns the results https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.find @@ -142,11 +144,9 @@ def find(self, else: return collection.find(query, **kwargs) - def insert_one(self, - mongo_collection: str, - doc: dict, - mongo_db: Optional[str] = None, - **kwargs) -> pymongo.results.InsertOneResult: + def insert_one( + self, mongo_collection: str, doc: dict, mongo_db: Optional[str] = None, **kwargs + ) -> pymongo.results.InsertOneResult: """ Inserts a single document into a mongo collection https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.insert_one @@ -155,11 +155,9 @@ def insert_one(self, return collection.insert_one(doc, **kwargs) - def insert_many(self, - mongo_collection: str, - docs: dict, - mongo_db: Optional[str] = None, - **kwargs) -> pymongo.results.InsertManyResult: + def insert_many( + self, mongo_collection: str, docs: dict, mongo_db: Optional[str] = None, **kwargs + ) -> pymongo.results.InsertManyResult: """ Inserts many docs into a mongo collection. https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.insert_many @@ -168,12 +166,14 @@ def insert_many(self, return collection.insert_many(docs, **kwargs) - def update_one(self, - mongo_collection: str, - filter_doc: dict, - update_doc: dict, - mongo_db: Optional[str] = None, - **kwargs) -> pymongo.results.UpdateResult: + def update_one( + self, + mongo_collection: str, + filter_doc: dict, + update_doc: dict, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.results.UpdateResult: """ Updates a single document in a mongo collection. https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_one @@ -193,12 +193,14 @@ def update_one(self, return collection.update_one(filter_doc, update_doc, **kwargs) - def update_many(self, - mongo_collection: str, - filter_doc: dict, - update_doc: dict, - mongo_db: Optional[str] = None, - **kwargs) -> pymongo.results.UpdateResult: + def update_many( + self, + mongo_collection: str, + filter_doc: dict, + update_doc: dict, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.results.UpdateResult: """ Updates one or more documents in a mongo collection. 
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_many @@ -218,12 +220,14 @@ def update_many(self, return collection.update_many(filter_doc, update_doc, **kwargs) - def replace_one(self, - mongo_collection: str, - doc: dict, - filter_doc: Optional[dict] = None, - mongo_db: Optional[str] = None, - **kwargs) -> pymongo.results.UpdateResult: + def replace_one( + self, + mongo_collection: str, + doc: dict, + filter_doc: Optional[dict] = None, + mongo_db: Optional[str] = None, + **kwargs, + ) -> pymongo.results.UpdateResult: """ Replaces a single document in a mongo collection. https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.replace_one @@ -250,14 +254,16 @@ def replace_one(self, return collection.replace_one(filter_doc, doc, **kwargs) - def replace_many(self, - mongo_collection: str, - docs: List[dict], - filter_docs: Optional[List[dict]] = None, - mongo_db: Optional[str] = None, - upsert: bool = False, - collation: Optional[pymongo.collation.Collation] = None, - **kwargs) -> pymongo.results.BulkWriteResult: + def replace_many( + self, + mongo_collection: str, + docs: List[dict], + filter_docs: Optional[List[dict]] = None, + mongo_db: Optional[str] = None, + upsert: bool = False, + collation: Optional[pymongo.collation.Collation] = None, + **kwargs, + ) -> pymongo.results.BulkWriteResult: """ Replaces many documents in a mongo collection. @@ -294,21 +300,14 @@ def replace_many(self, filter_docs = [{'_id': doc['_id']} for doc in docs] requests = [ - ReplaceOne( - filter_docs[i], - docs[i], - upsert=upsert, - collation=collation) - for i in range(len(docs)) + ReplaceOne(filter_docs[i], docs[i], upsert=upsert, collation=collation) for i in range(len(docs)) ] return collection.bulk_write(requests, **kwargs) - def delete_one(self, - mongo_collection: str, - filter_doc: dict, - mongo_db: Optional[str] = None, - **kwargs) -> pymongo.results.DeleteResult: + def delete_one( + self, mongo_collection: str, filter_doc: dict, mongo_db: Optional[str] = None, **kwargs + ) -> pymongo.results.DeleteResult: """ Deletes a single document in a mongo collection. https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_one @@ -326,11 +325,9 @@ def delete_one(self, return collection.delete_one(filter_doc, **kwargs) - def delete_many(self, - mongo_collection: str, - filter_doc: dict, - mongo_db: Optional[str] = None, - **kwargs) -> pymongo.results.DeleteResult: + def delete_many( + self, mongo_collection: str, filter_doc: dict, mongo_db: Optional[str] = None, **kwargs + ) -> pymongo.results.DeleteResult: """ Deletes one or more documents in a mongo collection. https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_many diff --git a/airflow/providers/mongo/sensors/mongo.py b/airflow/providers/mongo/sensors/mongo.py index 58a0e893b6a68..1c80bcb586255 100644 --- a/airflow/providers/mongo/sensors/mongo.py +++ b/airflow/providers/mongo/sensors/mongo.py @@ -38,21 +38,21 @@ class MongoSensor(BaseSensorOperator): when connecting to MongoDB. 
:type mongo_conn_id: str """ + template_fields = ('collection', 'query') @apply_defaults - def __init__(self, *, - collection: str, - query: dict, - mongo_conn_id: str = "mongo_default", - **kwargs) -> None: + def __init__( + self, *, collection: str, query: dict, mongo_conn_id: str = "mongo_default", **kwargs + ) -> None: super().__init__(**kwargs) self.mongo_conn_id = mongo_conn_id self.collection = collection self.query = query def poke(self, context: dict) -> bool: - self.log.info("Sensor check existence of the document " - "that matches the following query: %s", self.query) + self.log.info( + "Sensor check existence of the document " "that matches the following query: %s", self.query + ) hook = MongoHook(self.mongo_conn_id) return hook.find(self.collection, self.query, find_one=True) is not None diff --git a/airflow/providers/mysql/hooks/mysql.py b/airflow/providers/mysql/hooks/mysql.py index 6e3ca125afc80..2e8432e0b0eff 100644 --- a/airflow/providers/mysql/hooks/mysql.py +++ b/airflow/providers/mysql/hooks/mysql.py @@ -50,13 +50,13 @@ def __init__(self, *args, **kwargs) -> None: self.schema = kwargs.pop("schema", None) self.connection = kwargs.pop("connection", None) - def set_autocommit(self, conn: Connection, autocommit: bool) -> None: # noqa: D403 + def set_autocommit(self, conn: Connection, autocommit: bool) -> None: # noqa: D403 """ MySql connection sets autocommit in a different way. """ conn.autocommit(autocommit) - def get_autocommit(self, conn: Connection) -> bool: # noqa: D403 + def get_autocommit(self, conn: Connection) -> bool: # noqa: D403 """ MySql connection gets autocommit in a different way. @@ -72,7 +72,7 @@ def _get_conn_config_mysql_client(self, conn: Connection) -> Dict: "user": conn.login, "passwd": conn.password or '', "host": conn.host or 'localhost', - "db": self.schema or conn.schema or '' + "db": self.schema or conn.schema or '', } # check for authentication via AWS IAM @@ -88,6 +88,7 @@ def _get_conn_config_mysql_client(self, conn: Connection) -> Dict: conn_config["use_unicode"] = True if conn.extra_dejson.get('cursor', False): import MySQLdb.cursors + if (conn.extra_dejson["cursor"]).lower() == 'sscursor': conn_config["cursorclass"] = MySQLdb.cursors.SSCursor elif (conn.extra_dejson["cursor"]).lower() == 'dictcursor': @@ -115,7 +116,7 @@ def _get_conn_config_mysql_connector_python(self, conn: Connection) -> Dict: 'password': conn.password or '', 'host': conn.host or 'localhost', 'database': self.schema or conn.schema or '', - 'port': int(conn.port) if conn.port else 3306 + 'port': int(conn.port) if conn.port else 3306, } if conn.extra_dejson.get('allow_local_infile', False): @@ -140,11 +141,13 @@ def get_conn(self): if client_name == 'mysqlclient': import MySQLdb + conn_config = self._get_conn_config_mysql_client(conn) return MySQLdb.connect(**conn_config) if client_name == 'mysql-connector-python': import mysql.connector # pylint: disable=no-name-in-module + conn_config = self._get_conn_config_mysql_connector_python(conn) return mysql.connector.connect(**conn_config) # pylint: disable=no-member @@ -164,10 +167,14 @@ def bulk_load(self, table: str, tmp_file: str) -> None: """ conn = self.get_conn() cur = conn.cursor() - cur.execute(""" + cur.execute( + """ LOAD DATA LOCAL INFILE '{tmp_file}' INTO TABLE {table} - """.format(tmp_file=tmp_file, table=table)) + """.format( + tmp_file=tmp_file, table=table + ) + ) conn.commit() def bulk_dump(self, table: str, tmp_file: str) -> None: @@ -176,16 +183,19 @@ def bulk_dump(self, table: str, tmp_file: str) -> None: 
""" conn = self.get_conn() cur = conn.cursor() - cur.execute(""" + cur.execute( + """ SELECT * INTO OUTFILE '{tmp_file}' FROM {table} - """.format(tmp_file=tmp_file, table=table)) + """.format( + tmp_file=tmp_file, table=table + ) + ) conn.commit() @staticmethod def _serialize_cell( - cell: object, - conn: Optional[Connection] = None + cell: object, conn: Optional[Connection] = None ) -> object: # pylint: disable=signature-differs # noqa: D403 """ MySQLdb converts an argument to a literal @@ -218,10 +228,9 @@ def get_iam_token(self, conn: Connection) -> Tuple[str, int]: token = client.generate_db_auth_token(conn.host, port, conn.login) return token, port - def bulk_load_custom(self, table: str, - tmp_file: str, - duplicate_key_handling: str = 'IGNORE', - extra_options: str = '') -> None: + def bulk_load_custom( + self, table: str, tmp_file: str, duplicate_key_handling: str = 'IGNORE', extra_options: str = '' + ) -> None: """ A more configurable way to load local data from a file into the database. @@ -248,17 +257,19 @@ def bulk_load_custom(self, table: str, conn = self.get_conn() cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ LOAD DATA LOCAL INFILE '{tmp_file}' {duplicate_key_handling} INTO TABLE {table} {extra_options} """.format( - tmp_file=tmp_file, - table=table, - duplicate_key_handling=duplicate_key_handling, - extra_options=extra_options - )) + tmp_file=tmp_file, + table=table, + duplicate_key_handling=duplicate_key_handling, + extra_options=extra_options, + ) + ) cursor.close() conn.commit() diff --git a/airflow/providers/mysql/operators/mysql.py b/airflow/providers/mysql/operators/mysql.py index 85c74cc805341..85abe64a18a3f 100644 --- a/airflow/providers/mysql/operators/mysql.py +++ b/airflow/providers/mysql/operators/mysql.py @@ -48,13 +48,15 @@ class MySqlOperator(BaseOperator): @apply_defaults def __init__( - self, *, - sql: str, - mysql_conn_id: str = 'mysql_default', - parameters: Optional[Union[Mapping, Iterable]] = None, - autocommit: bool = False, - database: Optional[str] = None, - **kwargs) -> None: + self, + *, + sql: str, + mysql_conn_id: str = 'mysql_default', + parameters: Optional[Union[Mapping, Iterable]] = None, + autocommit: bool = False, + database: Optional[str] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.mysql_conn_id = mysql_conn_id self.sql = sql @@ -64,9 +66,5 @@ def __init__( def execute(self, context: Dict) -> None: self.log.info('Executing: %s', self.sql) - hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, - schema=self.database) - hook.run( - self.sql, - autocommit=self.autocommit, - parameters=self.parameters) + hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database) + hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) diff --git a/airflow/providers/mysql/transfers/presto_to_mysql.py b/airflow/providers/mysql/transfers/presto_to_mysql.py index 10b5d68ceea4e..081ca0a0a29d2 100644 --- a/airflow/providers/mysql/transfers/presto_to_mysql.py +++ b/airflow/providers/mysql/transfers/presto_to_mysql.py @@ -50,13 +50,16 @@ class PrestoToMySqlOperator(BaseOperator): ui_color = '#a0e08c' @apply_defaults - def __init__(self, *, - sql: str, - mysql_table: str, - presto_conn_id: str = 'presto_default', - mysql_conn_id: str = 'mysql_default', - mysql_preoperator: Optional[str] = None, - **kwargs) -> None: + def __init__( + self, + *, + sql: str, + mysql_table: str, + presto_conn_id: str = 'presto_default', + mysql_conn_id: str = 'mysql_default', + mysql_preoperator: Optional[str] = 
None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.sql = sql self.mysql_table = mysql_table diff --git a/airflow/providers/mysql/transfers/s3_to_mysql.py b/airflow/providers/mysql/transfers/s3_to_mysql.py index 2bee95abee7bc..39391ca9e6682 100644 --- a/airflow/providers/mysql/transfers/s3_to_mysql.py +++ b/airflow/providers/mysql/transfers/s3_to_mysql.py @@ -46,19 +46,25 @@ class S3ToMySqlOperator(BaseOperator): :type mysql_conn_id: str """ - template_fields = ('s3_source_key', 'mysql_table',) + template_fields = ( + 's3_source_key', + 'mysql_table', + ) template_ext = () ui_color = '#f4a460' @apply_defaults - def __init__(self, *, - s3_source_key: str, - mysql_table: str, - mysql_duplicate_key_handling: str = 'IGNORE', - mysql_extra_options: Optional[str] = None, - aws_conn_id: str = 'aws_default', - mysql_conn_id: str = 'mysql_default', - **kwargs) -> None: + def __init__( + self, + *, + s3_source_key: str, + mysql_table: str, + mysql_duplicate_key_handling: str = 'IGNORE', + mysql_extra_options: Optional[str] = None, + aws_conn_id: str = 'aws_default', + mysql_conn_id: str = 'mysql_default', + **kwargs, + ) -> None: super().__init__(**kwargs) self.s3_source_key = s3_source_key self.mysql_table = mysql_table @@ -85,7 +91,7 @@ def execute(self, context: Dict) -> None: table=self.mysql_table, tmp_file=file, duplicate_key_handling=self.mysql_duplicate_key_handling, - extra_options=self.mysql_extra_options + extra_options=self.mysql_extra_options, ) finally: # Remove file downloaded from s3 to be idempotent. diff --git a/airflow/providers/mysql/transfers/vertica_to_mysql.py b/airflow/providers/mysql/transfers/vertica_to_mysql.py index b5a62f0458029..8abc76999e5d5 100644 --- a/airflow/providers/mysql/transfers/vertica_to_mysql.py +++ b/airflow/providers/mysql/transfers/vertica_to_mysql.py @@ -58,22 +58,23 @@ class VerticaToMySqlOperator(BaseOperator): :type bulk_load: bool """ - template_fields = ('sql', 'mysql_table', 'mysql_preoperator', - 'mysql_postoperator') + template_fields = ('sql', 'mysql_table', 'mysql_preoperator', 'mysql_postoperator') template_ext = ('.sql',) ui_color = '#a0e08c' @apply_defaults def __init__( - self, - sql: str, - mysql_table: str, - vertica_conn_id: str = 'vertica_default', - mysql_conn_id: str = 'mysql_default', - mysql_preoperator: Optional[str] = None, - mysql_postoperator: Optional[str] = None, - bulk_load: bool = False, - *args, **kwargs) -> None: + self, + sql: str, + mysql_table: str, + vertica_conn_id: str = 'vertica_default', + mysql_conn_id: str = 'mysql_default', + mysql_preoperator: Optional[str] = None, + mysql_postoperator: Optional[str] = None, + bulk_load: bool = False, + *args, + **kwargs, + ) -> None: super().__init__(*args, **kwargs) self.sql = sql self.mysql_table = mysql_table @@ -101,9 +102,7 @@ def execute(self, context): if self.bulk_load: tmpfile = NamedTemporaryFile("w") - self.log.info( - "Selecting rows from Vertica to local file %s...", - tmpfile.name) + self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name) self.log.info(self.sql) csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8') @@ -130,18 +129,16 @@ def execute(self, context): self.log.info("Bulk inserting rows into MySQL...") with closing(mysql.get_conn()) as conn: with closing(conn.cursor()) as cursor: - cursor.execute("LOAD DATA LOCAL INFILE '%s' INTO " - "TABLE %s LINES TERMINATED BY '\r\n' (%s)" % - (tmpfile.name, - self.mysql_table, - ", ".join(selected_columns))) + cursor.execute( + "LOAD DATA LOCAL INFILE '%s' INTO " + 
"TABLE %s LINES TERMINATED BY '\r\n' (%s)" + % (tmpfile.name, self.mysql_table, ", ".join(selected_columns)) + ) conn.commit() tmpfile.close() else: self.log.info("Inserting rows into MySQL...") - mysql.insert_rows(table=self.mysql_table, - rows=result, - target_fields=selected_columns) + mysql.insert_rows(table=self.mysql_table, rows=result, target_fields=selected_columns) self.log.info("Inserted rows into MySQL %s", count) except (MySQLdb.Error, MySQLdb.Warning): # pylint: disable=no-member self.log.info("Inserted rows into MySQL 0") diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py index 7a2a498cd453f..0f066b4dccd78 100644 --- a/airflow/providers/odbc/hooks/odbc.py +++ b/airflow/providers/odbc/hooks/odbc.py @@ -90,9 +90,9 @@ def sqlalchemy_scheme(self): Database provided in init if exists; otherwise, ``schema`` from ``Connection`` object. """ return ( - self._sqlalchemy_scheme or - self.connection_extra_lower.get('sqlalchemy_scheme') or - self.DEFAULT_SQLALCHEMY_SCHEME + self._sqlalchemy_scheme + or self.connection_extra_lower.get('sqlalchemy_scheme') + or self.DEFAULT_SQLALCHEMY_SCHEME ) @property @@ -154,9 +154,7 @@ def odbc_connection_string(self): extra_exclude = {'driver', 'dsn', 'connect_kwargs', 'sqlalchemy_scheme'} extra_params = { - k: v - for k, v in self.connection.extra_dejson.items() - if not k.lower() in extra_exclude + k: v for k, v in self.connection.extra_dejson.items() if not k.lower() in extra_exclude } for k, v in extra_params.items(): conn_str += f"{k}={v};" diff --git a/airflow/providers/openfaas/hooks/openfaas.py b/airflow/providers/openfaas/hooks/openfaas.py index c3eb87dd9f834..f54462be62115 100644 --- a/airflow/providers/openfaas/hooks/openfaas.py +++ b/airflow/providers/openfaas/hooks/openfaas.py @@ -42,10 +42,7 @@ class OpenFaasHook(BaseHook): DEPLOY_FUNCTION = "/system/functions" UPDATE_FUNCTION = "/system/functions" - def __init__(self, - function_name=None, - conn_id: str = 'open_faas_default', - *args, **kwargs) -> None: + def __init__(self, function_name=None, conn_id: str = 'open_faas_default', *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.function_name = function_name self.conn_id = conn_id diff --git a/airflow/providers/opsgenie/hooks/opsgenie_alert.py b/airflow/providers/opsgenie/hooks/opsgenie_alert.py index 821c246fd8694..beb1cfec7a725 100644 --- a/airflow/providers/opsgenie/hooks/opsgenie_alert.py +++ b/airflow/providers/opsgenie/hooks/opsgenie_alert.py @@ -39,11 +39,8 @@ class OpsgenieAlertHook(HttpHook): :type opsgenie_conn_id: str """ - def __init__(self, - opsgenie_conn_id='opsgenie_default', - *args, - **kwargs - ): + + def __init__(self, opsgenie_conn_id='opsgenie_default', *args, **kwargs): super().__init__(http_conn_id=opsgenie_conn_id, *args, **kwargs) def _get_api_key(self): @@ -53,8 +50,9 @@ def _get_api_key(self): conn = self.get_connection(self.http_conn_id) api_key = conn.password if not api_key: - raise AirflowException('Opsgenie API Key is required for this hook, ' - 'please check your conn_id configuration.') + raise AirflowException( + 'Opsgenie API Key is required for this hook, ' 'please check your conn_id configuration.' 
+ ) return api_key def get_conn(self, headers=None): @@ -82,7 +80,8 @@ def execute(self, payload=None): """ payload = payload or {} api_key = self._get_api_key() - return self.run(endpoint='v2/alerts', - data=json.dumps(payload), - headers={'Content-Type': 'application/json', - 'Authorization': 'GenieKey %s' % api_key}) + return self.run( + endpoint='v2/alerts', + data=json.dumps(payload), + headers={'Content-Type': 'application/json', 'Authorization': 'GenieKey %s' % api_key}, + ) diff --git a/airflow/providers/opsgenie/operators/opsgenie_alert.py b/airflow/providers/opsgenie/operators/opsgenie_alert.py index f086816da91e7..4a6f0321d97c5 100644 --- a/airflow/providers/opsgenie/operators/opsgenie_alert.py +++ b/airflow/providers/opsgenie/operators/opsgenie_alert.py @@ -64,27 +64,30 @@ class OpsgenieAlertOperator(BaseOperator): :param note: Additional note that will be added while creating the alert. (templated) :type note: str """ + template_fields = ('message', 'alias', 'description', 'entity', 'priority', 'note') # pylint: disable=too-many-arguments @apply_defaults - def __init__(self, *, - message, - opsgenie_conn_id='opsgenie_default', - alias=None, - description=None, - responders=None, - visible_to=None, - actions=None, - tags=None, - details=None, - entity=None, - source=None, - priority=None, - user=None, - note=None, - **kwargs - ): + def __init__( + self, + *, + message, + opsgenie_conn_id='opsgenie_default', + alias=None, + description=None, + responders=None, + visible_to=None, + actions=None, + tags=None, + details=None, + entity=None, + source=None, + priority=None, + user=None, + note=None, + **kwargs, + ): super().__init__(**kwargs) self.message = message @@ -113,9 +116,19 @@ def _build_opsgenie_payload(self): payload = {} for key in [ - "message", "alias", "description", "responders", - "visible_to", "actions", "tags", "details", "entity", - "source", "priority", "user", "note" + "message", + "alias", + "description", + "responders", + "visible_to", + "actions", + "tags", + "details", + "entity", + "source", + "priority", + "user", + "note", ]: val = getattr(self, key) if val: diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py index b2385c83b046d..e6f1a8439bb58 100644 --- a/airflow/providers/oracle/hooks/oracle.py +++ b/airflow/providers/oracle/hooks/oracle.py @@ -28,6 +28,7 @@ class OracleHook(DbApiHook): """ Interact with Oracle SQL. 
""" + conn_name_attr = 'oracle_conn_id' default_conn_name = 'oracle_default' supports_autocommit = False @@ -52,10 +53,7 @@ def get_conn(self): `cx_Oracle.connect `_ """ conn = self.get_connection(self.oracle_conn_id) # pylint: disable=no-member - conn_config = { - 'user': conn.login, - 'password': conn.password - } + conn_config = {'user': conn.login, 'password': conn.password} dsn = conn.extra_dejson.get('dsn', None) sid = conn.extra_dejson.get('sid', None) mod = conn.extra_dejson.get('module', None) @@ -155,22 +153,20 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000): lst.append("'" + str(cell).replace("'", "''") + "'") elif cell is None: lst.append('NULL') - elif isinstance(cell, float) and \ - numpy.isnan(cell): # coerce numpy NaN to NULL + elif isinstance(cell, float) and numpy.isnan(cell): # coerce numpy NaN to NULL lst.append('NULL') elif isinstance(cell, numpy.datetime64): lst.append("'" + str(cell) + "'") elif isinstance(cell, datetime): - lst.append("to_date('" + - cell.strftime('%Y-%m-%d %H:%M:%S') + - "','YYYY-MM-DD HH24:MI:SS')") + lst.append( + "to_date('" + cell.strftime('%Y-%m-%d %H:%M:%S') + "','YYYY-MM-DD HH24:MI:SS')" + ) else: lst.append(str(cell)) values = tuple(lst) - sql = 'INSERT /*+ APPEND */ ' \ - 'INTO {0} {1} VALUES ({2})'.format(table, - target_fields, - ','.join(values)) + sql = 'INSERT /*+ APPEND */ ' 'INTO {0} {1} VALUES ({2})'.format( + table, target_fields, ','.join(values) + ) cur.execute(sql) if i % commit_every == 0: conn.commit() diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py index 5d7c98eb747fc..8858bc58e3456 100644 --- a/airflow/providers/oracle/operators/oracle.py +++ b/airflow/providers/oracle/operators/oracle.py @@ -46,12 +46,14 @@ class OracleOperator(BaseOperator): @apply_defaults def __init__( - self, *, - sql: str, - oracle_conn_id: str = 'oracle_default', - parameters: Optional[Union[Mapping, Iterable]] = None, - autocommit: bool = False, - **kwargs) -> None: + self, + *, + sql: str, + oracle_conn_id: str = 'oracle_default', + parameters: Optional[Union[Mapping, Iterable]] = None, + autocommit: bool = False, + **kwargs, + ) -> None: super().__init__(**kwargs) self.oracle_conn_id = oracle_conn_id self.sql = sql @@ -61,7 +63,4 @@ def __init__( def execute(self, context): self.log.info('Executing: %s', self.sql) hook = OracleHook(oracle_conn_id=self.oracle_conn_id) - hook.run( - self.sql, - autocommit=self.autocommit, - parameters=self.parameters) + hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) diff --git a/airflow/providers/oracle/transfers/oracle_to_oracle.py b/airflow/providers/oracle/transfers/oracle_to_oracle.py index 64ed4c7821fb0..b3edd78e3060d 100644 --- a/airflow/providers/oracle/transfers/oracle_to_oracle.py +++ b/airflow/providers/oracle/transfers/oracle_to_oracle.py @@ -46,14 +46,16 @@ class OracleToOracleOperator(BaseOperator): @apply_defaults def __init__( - self, *, - oracle_destination_conn_id, - destination_table, - oracle_source_conn_id, - source_sql, - source_sql_params=None, - rows_chunk=5000, - **kwargs): + self, + *, + oracle_destination_conn_id, + destination_table, + oracle_source_conn_id, + source_sql, + source_sql_params=None, + rows_chunk=5000, + **kwargs, + ): super().__init__(**kwargs) if source_sql_params is None: source_sql_params = {} @@ -76,9 +78,9 @@ def _execute(self, src_hook, dest_hook, context): rows = cursor.fetchmany(self.rows_chunk) while len(rows) > 0: rows_total += len(rows) - 
dest_hook.bulk_insert_rows(self.destination_table, rows, - target_fields=target_fields, - commit_every=self.rows_chunk) + dest_hook.bulk_insert_rows( + self.destination_table, rows, target_fields=target_fields, commit_every=self.rows_chunk + ) rows = cursor.fetchmany(self.rows_chunk) self.log.info("Total inserted: %s rows", rows_total) diff --git a/airflow/providers/pagerduty/hooks/pagerduty.py b/airflow/providers/pagerduty/hooks/pagerduty.py index d14971c53f88e..55f9a476dcc7f 100644 --- a/airflow/providers/pagerduty/hooks/pagerduty.py +++ b/airflow/providers/pagerduty/hooks/pagerduty.py @@ -50,8 +50,7 @@ def __init__(self, token: Optional[str] = None, pagerduty_conn_id: Optional[str] self.token = token if self.token is None: - raise AirflowException( - 'Cannot get token: No valid api token nor pagerduty_conn_id supplied.') + raise AirflowException('Cannot get token: No valid api token nor pagerduty_conn_id supplied.') # pylint: disable=too-many-arguments def create_event( diff --git a/airflow/providers/papermill/example_dags/example_papermill.py b/airflow/providers/papermill/example_dags/example_papermill.py index 6b9910ab5f441..c9e645a331f71 100644 --- a/airflow/providers/papermill/example_dags/example_papermill.py +++ b/airflow/providers/papermill/example_dags/example_papermill.py @@ -48,7 +48,7 @@ task_id="run_example_notebook", input_nb="/tmp/hello_world.ipynb", output_nb="/tmp/out-{{ execution_date }}.ipynb", - parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"} + parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"}, ) # [END howto_operator_papermill] @@ -72,21 +72,16 @@ def check_notebook(inlets, execution_date): default_args=default_args, schedule_interval='0 0 * * *', start_date=days_ago(2), - dagrun_timeout=timedelta(minutes=60) + dagrun_timeout=timedelta(minutes=60), ) as dag_2: run_this = PapermillOperator( task_id="run_example_notebook", - input_nb=os.path.join(os.path.dirname(os.path.realpath(__file__)), - "input_notebook.ipynb"), + input_nb=os.path.join(os.path.dirname(os.path.realpath(__file__)), "input_notebook.ipynb"), output_nb="/tmp/out-{{ execution_date }}.ipynb", - parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"} + parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"}, ) - check_output = PythonOperator( - task_id='check_out', - python_callable=check_notebook, - inlets=AUTO - ) + check_output = PythonOperator(task_id='check_out', python_callable=check_notebook, inlets=AUTO) check_output.set_upstream(run_this) diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py index 4c800506368f3..4abbc8036ec96 100644 --- a/airflow/providers/papermill/operators/papermill.py +++ b/airflow/providers/papermill/operators/papermill.py @@ -30,6 +30,7 @@ class NoteBook(File): """ Jupyter notebook """ + type_hint: Optional[str] = "jupyter_notebook" parameters: Optional[Dict] = {} @@ -47,19 +48,22 @@ class PapermillOperator(BaseOperator): :param parameters: the notebook parameters to set :type parameters: dict """ + supports_lineage = True @apply_defaults - def __init__(self, *, - input_nb: Optional[str] = None, - output_nb: Optional[str] = None, - parameters: Optional[Dict] = None, - **kwargs) -> None: + def __init__( + self, + *, + input_nb: Optional[str] = None, + output_nb: Optional[str] = None, + parameters: Optional[Dict] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) if input_nb: - self.inlets.append(NoteBook(url=input_nb, - parameters=parameters)) + 
self.inlets.append(NoteBook(url=input_nb, parameters=parameters)) if output_nb: self.outlets.append(NoteBook(url=output_nb)) @@ -68,6 +72,10 @@ def execute(self, context): raise ValueError("Input notebook or output notebook is not specified") for i in range(len(self.inlets)): - pm.execute_notebook(self.inlets[i].url, self.outlets[i].url, - parameters=self.inlets[i].parameters, - progress_bar=False, report_mode=True) + pm.execute_notebook( + self.inlets[i].url, + self.outlets[i].url, + parameters=self.inlets[i].parameters, + progress_bar=False, + report_mode=True, + ) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index bf5fa4d094614..5dda0e051299c 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -48,6 +48,7 @@ class PostgresHook(DbApiHook): the host field, so is optional. It can however be overridden in the extra field. extras example: ``{"iam":true, "redshift":true, "cluster-identifier": "my_cluster_id"}`` """ + conn_name_attr = 'postgres_conn_id' default_conn_name = 'postgres_default' supports_autocommit = True @@ -82,15 +83,22 @@ def get_conn(self): user=conn.login, password=conn.password, dbname=self.schema or conn.schema, - port=conn.port) + port=conn.port, + ) raw_cursor = conn.extra_dejson.get('cursor', False) if raw_cursor: conn_args['cursor_factory'] = self._get_cursor(raw_cursor) # check for ssl parameters in conn.extra for arg_name, arg_val in conn.extra_dejson.items(): - if arg_name in ['sslmode', 'sslcert', 'sslkey', - 'sslrootcert', 'sslcrl', 'application_name', - 'keepalives_idle']: + if arg_name in [ + 'sslmode', + 'sslcert', + 'sslkey', + 'sslrootcert', + 'sslcrl', + 'application_name', + 'keepalives_idle', + ]: conn_args[arg_name] = arg_val self.conn = psycopg2.connect(**conn_args) @@ -174,7 +182,8 @@ def get_iam_token(self, conn): DbUser=conn.login, DbName=self.schema or conn.schema, ClusterIdentifier=cluster_identifier, - AutoCreate=False) + AutoCreate=False, + ) token = cluster_creds['DbPassword'] login = cluster_creds['DbUser'] else: @@ -201,7 +210,7 @@ def _generate_insert_sql(table, values, target_fields, replace, **kwargs): :return: The generated INSERT or REPLACE SQL statement :rtype: str """ - placeholders = ["%s", ] * len(values) + placeholders = ["%s",] * len(values) replace_index = kwargs.get("replace_index", None) if target_fields: @@ -210,10 +219,7 @@ def _generate_insert_sql(table, values, target_fields, replace, **kwargs): else: target_fields_fragment = '' - sql = "INSERT INTO {0} {1} VALUES ({2})".format( - table, - target_fields_fragment, - ",".join(placeholders)) + sql = "INSERT INTO {0} {1} VALUES ({2})".format(table, target_fields_fragment, ",".join(placeholders)) if replace: if target_fields is None: @@ -225,12 +231,9 @@ def _generate_insert_sql(table, values, target_fields, replace, **kwargs): replace_index_set = set(replace_index) replace_target = [ - "{0} = excluded.{0}".format(col) - for col in target_fields - if col not in replace_index_set + "{0} = excluded.{0}".format(col) for col in target_fields if col not in replace_index_set ] sql += " ON CONFLICT ({0}) DO UPDATE SET {1}".format( - ", ".join(replace_index), - ", ".join(replace_target), + ", ".join(replace_index), ", ".join(replace_target), ) return sql diff --git a/airflow/providers/postgres/operators/postgres.py b/airflow/providers/postgres/operators/postgres.py index f0bce2186aee2..88f7662a8a1de 100644 --- a/airflow/providers/postgres/operators/postgres.py +++ 
b/airflow/providers/postgres/operators/postgres.py @@ -47,13 +47,15 @@ class PostgresOperator(BaseOperator): @apply_defaults def __init__( - self, *, - sql: str, - postgres_conn_id: str = 'postgres_default', - autocommit: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, - database: Optional[str] = None, - **kwargs) -> None: + self, + *, + sql: str, + postgres_conn_id: str = 'postgres_default', + autocommit: bool = False, + parameters: Optional[Union[Mapping, Iterable]] = None, + database: Optional[str] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.sql = sql self.postgres_conn_id = postgres_conn_id @@ -64,8 +66,7 @@ def __init__( def execute(self, context): self.log.info('Executing: %s', self.sql) - self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, - schema=self.database) + self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, schema=self.database) self.hook.run(self.sql, self.autocommit, parameters=self.parameters) for output in self.hook.conn.notices: self.log.info(output) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 6abb6a310706d..be731068089e2 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -55,7 +55,7 @@ def get_conn(self): catalog=db.extra_dejson.get('catalog', 'hive'), schema=db.schema, auth=auth, - isolation_level=self.get_isolation_level() + isolation_level=self.get_isolation_level(), ) def get_isolation_level(self): @@ -73,8 +73,7 @@ def get_records(self, hql, parameters=None): Get a set of records from Presto """ try: - return super().get_records( - self._strip_sql(hql), parameters) + return super().get_records(self._strip_sql(hql), parameters) except DatabaseError as e: raise PrestoException(e) @@ -84,8 +83,7 @@ def get_first(self, hql, parameters=None): returns. """ try: - return super().get_first( - self._strip_sql(hql), parameters) + return super().get_first(self._strip_sql(hql), parameters) except DatabaseError as e: raise PrestoException(e) @@ -94,6 +92,7 @@ def get_pandas_df(self, hql, parameters=None, **kwargs): Get a pandas dataframe from a sql query. 
""" import pandas + cursor = self.get_cursor() try: cursor.execute(self._strip_sql(hql), parameters) diff --git a/airflow/providers/qubole/example_dags/example_qubole.py b/airflow/providers/qubole/example_dags/example_qubole.py index 7076e34d530a9..8da5a64f23dee 100644 --- a/airflow/providers/qubole/example_dags/example_qubole.py +++ b/airflow/providers/qubole/example_dags/example_qubole.py @@ -32,7 +32,7 @@ 'depends_on_past': False, 'email': ['airflow@example.com'], 'email_on_failure': False, - 'email_on_retry': False + 'email_on_retry': False, } with DAG( @@ -81,9 +81,7 @@ def compare_result_fn(**kwargs): # To attach tags to qubole command, auto attach 3 tags - dag_id, task_id, run_id qubole_conn_id='qubole_default', # Connection id to submit commands inside QDS, if not set "qubole_default" is used - params={ - 'cluster_label': 'default', - } + params={'cluster_label': 'default',}, ) hive_s3_location = QuboleOperator( @@ -94,45 +92,35 @@ def compare_result_fn(**kwargs): tags=['tag1', 'tag2'], # If the script at s3 location has any qubole specific macros to be replaced # macros='[{"date": "{{ ds }}"}, {"name" : "abc"}]', - trigger_rule="all_done" + trigger_rule="all_done", ) compare_result = PythonOperator( - task_id='compare_result', - python_callable=compare_result_fn, - trigger_rule="all_done" + task_id='compare_result', python_callable=compare_result_fn, trigger_rule="all_done" ) compare_result << [hive_show_table, hive_s3_location] options = ['hadoop_jar_cmd', 'presto_cmd', 'db_query', 'spark_cmd'] - branching = BranchPythonOperator( - task_id='branching', - python_callable=lambda: random.choice(options) - ) + branching = BranchPythonOperator(task_id='branching', python_callable=lambda: random.choice(options)) branching << compare_result - join = DummyOperator( - task_id='join', - trigger_rule='one_success' - ) + join = DummyOperator(task_id='join', trigger_rule='one_success') hadoop_jar_cmd = QuboleOperator( task_id='hadoop_jar_cmd', command_type='hadoopcmd', sub_command='jar s3://paid-qubole/HadoopAPIExamples/' - 'jars/hadoop-0.20.1-dev-streaming.jar ' - '-mapper wc ' - '-numReduceTasks 0 -input s3://paid-qubole/HadoopAPITests/' - 'data/3.tsv -output ' - 's3://paid-qubole/HadoopAPITests/data/3_wc', + 'jars/hadoop-0.20.1-dev-streaming.jar ' + '-mapper wc ' + '-numReduceTasks 0 -input s3://paid-qubole/HadoopAPITests/' + 'data/3.tsv -output ' + 's3://paid-qubole/HadoopAPITests/data/3_wc', cluster_label='{{ params.cluster_label }}', fetch_logs=True, - params={ - 'cluster_label': 'default', - } + params={'cluster_label': 'default',}, ) pig_cmd = QuboleOperator( @@ -140,34 +128,27 @@ def compare_result_fn(**kwargs): command_type="pigcmd", script_location="s3://public-qubole/qbol-library/scripts/script1-hadoop-s3-small.pig", parameters="key1=value1 key2=value2", - trigger_rule="all_done" + trigger_rule="all_done", ) pig_cmd << hadoop_jar_cmd << branching pig_cmd >> join - presto_cmd = QuboleOperator( - task_id='presto_cmd', - command_type='prestocmd', - query='show tables' - ) + presto_cmd = QuboleOperator(task_id='presto_cmd', command_type='prestocmd', query='show tables') shell_cmd = QuboleOperator( task_id='shell_cmd', command_type="shellcmd", script_location="s3://public-qubole/qbol-library/scripts/shellx.sh", parameters="param1 param2", - trigger_rule="all_done" + trigger_rule="all_done", ) shell_cmd << presto_cmd << branching shell_cmd >> join db_query = QuboleOperator( - task_id='db_query', - command_type='dbtapquerycmd', - query='show tables', - db_tap_id=2064 + 
task_id='db_query', command_type='dbtapquerycmd', query='show tables', db_tap_id=2064 ) db_export = QuboleOperator( @@ -178,7 +159,7 @@ def compare_result_fn(**kwargs): db_table='exported_airline_origin_destination', partition_spec='dt=20110104-02', dbtap_id=2064, - trigger_rule="all_done" + trigger_rule="all_done", ) db_export << db_query << branching @@ -193,7 +174,7 @@ def compare_result_fn(**kwargs): where_clause='id < 10', parallelism=2, dbtap_id=2064, - trigger_rule="all_done" + trigger_rule="all_done", ) prog = ''' @@ -226,7 +207,7 @@ def main(args: Array[String]) { program=prog, language='scala', arguments='--class SparkPi', - tags='airflow_example_run' + tags='airflow_example_run', ) spark_cmd << db_import << branching @@ -259,27 +240,25 @@ def main(args: Array[String]) { poke_interval=60, timeout=600, data={ - "files": - [ - "s3://paid-qubole/HadoopAPIExamples/jars/hadoop-0.20.1-dev-streaming.jar", - "s3://paid-qubole/HadoopAPITests/data/{{ ds.split('-')[2] }}.tsv" - ] # will check for availability of all the files in array - } + "files": [ + "s3://paid-qubole/HadoopAPIExamples/jars/hadoop-0.20.1-dev-streaming.jar", + "s3://paid-qubole/HadoopAPITests/data/{{ ds.split('-')[2] }}.tsv", + ] # will check for availability of all the files in array + }, ) check_hive_partition = QubolePartitionSensor( task_id='check_hive_partition', poke_interval=10, timeout=60, - data={"schema": "default", - "table": "my_partitioned_table", - "columns": [ - {"column": "month", "values": - ["{{ ds.split('-')[1] }}"]}, - {"column": "day", "values": - ["{{ ds.split('-')[2] }}", "{{ yesterday_ds.split('-')[2] }}"]} - ] # will check for partitions like [month=12/day=12,month=12/day=13] - } + data={ + "schema": "default", + "table": "my_partitioned_table", + "columns": [ + {"column": "month", "values": ["{{ ds.split('-')[1] }}"]}, + {"column": "day", "values": ["{{ ds.split('-')[2] }}", "{{ yesterday_ds.split('-')[2] }}"]}, + ], # will check for partitions like [month=12/day=12,month=12/day=13] + }, ) check_s3_file >> check_hive_partition diff --git a/airflow/providers/qubole/hooks/qubole.py b/airflow/providers/qubole/hooks/qubole.py index 508ead3f9d5c0..4a03ae89961de 100644 --- a/airflow/providers/qubole/hooks/qubole.py +++ b/airflow/providers/qubole/hooks/qubole.py @@ -24,8 +24,17 @@ import time from qds_sdk.commands import ( - Command, DbExportCommand, DbImportCommand, DbTapQueryCommand, HadoopCommand, HiveCommand, PigCommand, - PrestoCommand, ShellCommand, SparkCommand, SqlCommand, + Command, + DbExportCommand, + DbImportCommand, + DbTapQueryCommand, + HadoopCommand, + HiveCommand, + PigCommand, + PrestoCommand, + ShellCommand, + SparkCommand, + SqlCommand, ) from qds_sdk.qubole import Qubole @@ -46,14 +55,10 @@ "dbtapquerycmd": DbTapQueryCommand, "dbexportcmd": DbExportCommand, "dbimportcmd": DbImportCommand, - "sqlcmd": SqlCommand + "sqlcmd": SqlCommand, } -POSITIONAL_ARGS = { - 'hadoopcmd': ['sub_command'], - 'shellcmd': ['parameters'], - 'pigcmd': ['parameters'] -} +POSITIONAL_ARGS = {'hadoopcmd': ['sub_command'], 'shellcmd': ['parameters'], 'pigcmd': ['parameters']} def flatten_list(list_of_lists): @@ -100,6 +105,7 @@ def build_command_args(): class QuboleHook(BaseHook): """Hook for Qubole communication""" + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument super().__init__() conn = self.get_connection(kwargs['qubole_conn_id']) @@ -121,8 +127,9 @@ def handle_failure_retry(context): cmd = Command.find(cmd_id) if cmd is not None: if cmd.status == 'done': - log.info('Command ID: %s 
has been succeeded, hence marking this ' - 'TI as Success.', cmd_id) + log.info( + 'Command ID: %s has been succeeded, hence marking this ' 'TI as Success.', cmd_id + ) ti.state = State.SUCCESS elif cmd.status == 'running': log.info('Cancelling the Qubole Command Id: %s', cmd_id) @@ -134,10 +141,7 @@ def execute(self, context): self.cmd = self.cls.create(**args) self.task_instance = context['task_instance'] context['task_instance'].xcom_push(key='qbol_cmd_id', value=self.cmd.id) - self.log.info( - "Qubole command created with Id: %s and Status: %s", - self.cmd.id, self.cmd.status - ) + self.log.info("Qubole command created with Id: %s and Status: %s", self.cmd.id, self.cmd.status) while not Command.is_done(self.cmd.status): time.sleep(Qubole.poll_interval) @@ -148,8 +152,9 @@ def execute(self, context): self.log.info("Logs for Command Id: %s \n%s", self.cmd.id, self.cmd.get_log()) if self.cmd.status != 'done': - raise AirflowException('Command Id: {0} failed with Status: {1}'.format( - self.cmd.id, self.cmd.status)) + raise AirflowException( + 'Command Id: {0} failed with Status: {1}'.format(self.cmd.id, self.cmd.status) + ) def kill(self, ti): """ @@ -182,9 +187,7 @@ def get_results(self, ti=None, fp=None, inline=True, delim=None, fetch=True): """ if fp is None: iso = datetime.datetime.utcnow().isoformat() - logpath = os.path.expanduser( - conf.get('logging', 'BASE_LOG_FOLDER') - ) + logpath = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER')) resultpath = logpath + '/' + self.dag_id + '/' + self.task_id + '/results' pathlib.Path(resultpath).mkdir(parents=True, exist_ok=True) fp = open(resultpath + '/' + iso, 'wb') diff --git a/airflow/providers/qubole/hooks/qubole_check.py b/airflow/providers/qubole/hooks/qubole_check.py index 473391a1bd95a..5274ba5dee886 100644 --- a/airflow/providers/qubole/hooks/qubole_check.py +++ b/airflow/providers/qubole/hooks/qubole_check.py @@ -75,7 +75,7 @@ def parse_first_row(row_list): elif isfloat(col_value): col_value = float(col_value) elif isbool(col_value): - col_value = (col_value.lower() == "true") + col_value = col_value.lower() == "true" record_list.append(col_value) return record_list @@ -85,11 +85,11 @@ class QuboleCheckHook(QuboleHook): """ Qubole check hook """ + def __init__(self, context, *args, **kwargs): super().__init__(*args, **kwargs) self.results_parser_callable = parse_first_row - if 'results_parser_callable' in kwargs and \ - kwargs['results_parser_callable'] is not None: + if 'results_parser_callable' in kwargs and kwargs['results_parser_callable'] is not None: if not callable(kwargs['results_parser_callable']): raise AirflowException('`results_parser_callable` param must be callable') self.results_parser_callable = kwargs['results_parser_callable'] diff --git a/airflow/providers/qubole/operators/qubole.py b/airflow/providers/qubole/operators/qubole.py index cf987824ece4a..51592ea4f0a01 100644 --- a/airflow/providers/qubole/operators/qubole.py +++ b/airflow/providers/qubole/operators/qubole.py @@ -23,13 +23,18 @@ from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.taskinstance import TaskInstance from airflow.providers.qubole.hooks.qubole import ( - COMMAND_ARGS, HYPHEN_ARGS, POSITIONAL_ARGS, QuboleHook, flatten_list, + COMMAND_ARGS, + HYPHEN_ARGS, + POSITIONAL_ARGS, + QuboleHook, + flatten_list, ) from airflow.utils.decorators import apply_defaults class QDSLink(BaseOperatorLink): """Link to QDS""" + name = 'Go to QDS' def get_link(self, operator, dttm): @@ -42,7 +47,8 @@ def get_link(self, operator, 
dttm): """ ti = TaskInstance(task=operator, execution_date=dttm) conn = BaseHook.get_connection( - getattr(operator, "qubole_conn_id", None) or operator.kwargs['qubole_conn_id']) + getattr(operator, "qubole_conn_id", None) or operator.kwargs['qubole_conn_id'] + ) if conn and conn.host: host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host) else: @@ -166,12 +172,34 @@ class QuboleOperator(BaseOperator): """ template_fields: Iterable[str] = ( - 'query', 'script_location', 'sub_command', 'script', 'files', - 'archives', 'program', 'cmdline', 'sql', 'where_clause', 'tags', - 'extract_query', 'boundary_query', 'macros', 'name', 'parameters', - 'dbtap_id', 'hive_table', 'db_table', 'split_column', 'note_id', - 'db_update_keys', 'export_dir', 'partition_spec', 'qubole_conn_id', - 'arguments', 'user_program_arguments', 'cluster_label' + 'query', + 'script_location', + 'sub_command', + 'script', + 'files', + 'archives', + 'program', + 'cmdline', + 'sql', + 'where_clause', + 'tags', + 'extract_query', + 'boundary_query', + 'macros', + 'name', + 'parameters', + 'dbtap_id', + 'hive_table', + 'db_table', + 'split_column', + 'note_id', + 'db_update_keys', + 'export_dir', + 'partition_spec', + 'qubole_conn_id', + 'arguments', + 'user_program_arguments', + 'cluster_label', ) template_ext: Iterable[str] = ('.txt',) @@ -179,9 +207,7 @@ class QuboleOperator(BaseOperator): ui_fgcolor = '#fff' qubole_hook_allowed_args_list = ['command_type', 'qubole_conn_id', 'fetch_logs'] - operator_extra_links = ( - QDSLink(), - ) + operator_extra_links = (QDSLink(),) @apply_defaults def __init__(self, *, qubole_conn_id="qubole_default", **kwargs): @@ -198,8 +224,12 @@ def __init__(self, *, qubole_conn_id="qubole_default", **kwargs): self.on_retry_callback = QuboleHook.handle_failure_retry def _get_filtered_args(self, all_kwargs): - qubole_args = flatten_list(COMMAND_ARGS.values()) + HYPHEN_ARGS + \ - flatten_list(POSITIONAL_ARGS.values()) + self.qubole_hook_allowed_args_list + qubole_args = ( + flatten_list(COMMAND_ARGS.values()) + + HYPHEN_ARGS + + flatten_list(POSITIONAL_ARGS.values()) + + self.qubole_hook_allowed_args_list + ) return {key: value for key, value in all_kwargs.items() if key not in qubole_args} def execute(self, context): diff --git a/airflow/providers/qubole/operators/qubole_check.py b/airflow/providers/qubole/operators/qubole_check.py index e010281005ef9..a2a3befe4cd71 100644 --- a/airflow/providers/qubole/operators/qubole_check.py +++ b/airflow/providers/qubole/operators/qubole_check.py @@ -162,14 +162,20 @@ class QuboleValueCheckOperator(ValueCheckOperator, QuboleOperator): ui_fgcolor = '#000' @apply_defaults - def __init__(self, *, pass_value, tolerance=None, results_parser_callable=None, - qubole_conn_id="qubole_default", **kwargs): + def __init__( + self, + *, + pass_value, + tolerance=None, + results_parser_callable=None, + qubole_conn_id="qubole_default", + **kwargs, + ): sql = get_sql_from_qbol_cmd(kwargs) super().__init__( - qubole_conn_id=qubole_conn_id, - sql=sql, pass_value=pass_value, tolerance=tolerance, - **kwargs) + qubole_conn_id=qubole_conn_id, sql=sql, pass_value=pass_value, tolerance=tolerance, **kwargs + ) self.results_parser_callable = results_parser_callable self.on_failure_callback = QuboleCheckHook.handle_failure_retry @@ -190,9 +196,7 @@ def get_hook(self, context=None): return self.hook else: return QuboleCheckHook( - context=context, - results_parser_callable=self.results_parser_callable, - **self.kwargs + context=context, 
results_parser_callable=self.results_parser_callable, **self.kwargs ) def __getattribute__(self, name): @@ -232,11 +236,12 @@ def handle_airflow_exception(airflow_exception, hook): if cmd.is_success(cmd.status): qubole_command_results = hook.get_query_results() qubole_command_id = cmd.id - exception_message = \ - '\nQubole Command Id: {qubole_command_id}' \ - '\nQubole Command Results:' \ + exception_message = ( + '\nQubole Command Id: {qubole_command_id}' + '\nQubole Command Results:' '\n{qubole_command_results}'.format( - qubole_command_id=qubole_command_id, - qubole_command_results=qubole_command_results) + qubole_command_id=qubole_command_id, qubole_command_results=qubole_command_results + ) + ) raise AirflowException(str(airflow_exception) + exception_message) raise AirflowException(str(airflow_exception)) diff --git a/airflow/providers/qubole/sensors/qubole.py b/airflow/providers/qubole/sensors/qubole.py index b432ede1edbfd..5dfeaeb3d7450 100644 --- a/airflow/providers/qubole/sensors/qubole.py +++ b/airflow/providers/qubole/sensors/qubole.py @@ -40,9 +40,10 @@ def __init__(self, *, data, qubole_conn_id="qubole_default", **kwargs): self.qubole_conn_id = qubole_conn_id if 'poke_interval' in kwargs and kwargs['poke_interval'] < 5: - raise AirflowException("Sorry, poke_interval can't be less than 5 sec for " - "task '{0}' in dag '{1}'." - .format(kwargs['task_id'], kwargs['dag'].dag_id)) + raise AirflowException( + "Sorry, poke_interval can't be less than 5 sec for " + "task '{0}' in dag '{1}'.".format(kwargs['task_id'], kwargs['dag'].dag_id) + ) super().__init__(**kwargs) diff --git a/airflow/providers/redis/hooks/redis.py b/airflow/providers/redis/hooks/redis.py index e82b9425f8f2e..9ed90e725e5a0 100644 --- a/airflow/providers/redis/hooks/redis.py +++ b/airflow/providers/redis/hooks/redis.py @@ -32,6 +32,7 @@ class RedisHook(BaseHook): Also you can set ssl parameters as: ``{"ssl": true, "ssl_cert_reqs": "require", "ssl_cert_file": "/path/to/cert.pem", etc}``. """ + def __init__(self, redis_conn_id: str = 'redis_default') -> None: """ Prepares hook to connect to a Redis database. 
@@ -58,20 +59,24 @@ def get_conn(self): self.db = conn.extra_dejson.get('db', None) # check for ssl parameters in conn.extra - ssl_arg_names = ["ssl", "ssl_cert_reqs", "ssl_ca_certs", "ssl_keyfile", "ssl_cert_file", - "ssl_check_hostname"] + ssl_arg_names = [ + "ssl", + "ssl_cert_reqs", + "ssl_ca_certs", + "ssl_keyfile", + "ssl_cert_file", + "ssl_check_hostname", + ] ssl_args = {name: val for name, val in conn.extra_dejson.items() if name in ssl_arg_names} if not self.redis: self.log.debug( 'Initializing redis object for conn_id "%s" on %s:%s:%s', - self.redis_conn_id, self.host, self.port, self.db + self.redis_conn_id, + self.host, + self.port, + self.db, ) - self.redis = Redis( - host=self.host, - port=self.port, - password=self.password, - db=self.db, - **ssl_args) + self.redis = Redis(host=self.host, port=self.port, password=self.password, db=self.db, **ssl_args) return self.redis diff --git a/airflow/providers/redis/operators/redis_publish.py b/airflow/providers/redis/operators/redis_publish.py index 6cec1443158f8..fc85201dc06fb 100644 --- a/airflow/providers/redis/operators/redis_publish.py +++ b/airflow/providers/redis/operators/redis_publish.py @@ -38,12 +38,7 @@ class RedisPublishOperator(BaseOperator): template_fields = ('channel', 'message') @apply_defaults - def __init__( - self, *, - channel: str, - message: str, - redis_conn_id: str = 'redis_default', - **kwargs) -> None: + def __init__(self, *, channel: str, message: str, redis_conn_id: str = 'redis_default', **kwargs) -> None: super().__init__(**kwargs) self.redis_conn_id = redis_conn_id diff --git a/airflow/providers/redis/sensors/redis_key.py b/airflow/providers/redis/sensors/redis_key.py index c90c965addbe1..4fbbdb31884ce 100644 --- a/airflow/providers/redis/sensors/redis_key.py +++ b/airflow/providers/redis/sensors/redis_key.py @@ -26,6 +26,7 @@ class RedisKeySensor(BaseSensorOperator): """ Checks for the existence of a key in a Redis """ + template_fields = ('key',) ui_color = '#f0eee4' diff --git a/airflow/providers/redis/sensors/redis_pub_sub.py b/airflow/providers/redis/sensors/redis_pub_sub.py index b2f5e8b148738..df6db7dc6dd30 100644 --- a/airflow/providers/redis/sensors/redis_pub_sub.py +++ b/airflow/providers/redis/sensors/redis_pub_sub.py @@ -32,6 +32,7 @@ class RedisPubSubSensor(BaseSensorOperator): :param redis_conn_id: the redis connection id :type redis_conn_id: str """ + template_fields = ('channels',) ui_color = '#f0eee4' diff --git a/airflow/providers/salesforce/example_dags/example_tableau_refresh_workbook.py b/airflow/providers/salesforce/example_dags/example_tableau_refresh_workbook.py index 68b30ca0c63fb..32b347ce62451 100644 --- a/airflow/providers/salesforce/example_dags/example_tableau_refresh_workbook.py +++ b/airflow/providers/salesforce/example_dags/example_tableau_refresh_workbook.py @@ -32,7 +32,7 @@ 'depends_on_past': False, 'email': ['airflow@example.com'], 'email_on_failure': False, - 'email_on_retry': False + 'email_on_retry': False, } with DAG( diff --git a/airflow/providers/salesforce/hooks/salesforce.py b/airflow/providers/salesforce/hooks/salesforce.py index f7fce1ad405e2..02d073f6bd75e 100644 --- a/airflow/providers/salesforce/hooks/salesforce.py +++ b/airflow/providers/salesforce/hooks/salesforce.py @@ -71,7 +71,7 @@ def get_conn(self): password=connection.password, security_token=extras['security_token'], instance_url=connection.host, - domain=extras.get('domain', None) + domain=extras.get('domain', None), ) return self.conn @@ -94,8 +94,9 @@ def make_query(self, query, 
include_deleted=False, query_params=None): query_params = query_params or {} query_results = conn.query_all(query, include_deleted=include_deleted, **query_params) - self.log.info("Received results: Total size: %s; Done: %s", - query_results['totalSize'], query_results['done']) + self.log.info( + "Received results: Total size: %s; Done: %s", query_results['totalSize'], query_results['done'] + ) return query_results @@ -146,8 +147,10 @@ def get_object_from_salesforce(self, obj, fields): """ query = "SELECT {} FROM {}".format(",".join(fields), obj) - self.log.info("Making query to Salesforce: %s", - query if len(query) < 30 else " ... ".join([query[:15], query[-15:]])) + self.log.info( + "Making query to Salesforce: %s", + query if len(query) < 30 else " ... ".join([query[:15], query[-15:]]), + ) return self.make_query(query) @@ -189,12 +192,9 @@ def _to_timestamp(cls, column): return pd.Series(converted, index=column.index) - def write_object_to_file(self, - query_results, - filename, - fmt="csv", - coerce_to_timestamp=False, - record_time_added=False): + def write_object_to_file( + self, query_results, filename, fmt="csv", coerce_to_timestamp=False, record_time_added=False + ): """ Write query results to file. @@ -236,8 +236,11 @@ def write_object_to_file(self, if fmt not in ['csv', 'json', 'ndjson']: raise ValueError("Format value is not recognized: {}".format(fmt)) - df = self.object_to_df(query_results=query_results, coerce_to_timestamp=coerce_to_timestamp, - record_time_added=record_time_added) + df = self.object_to_df( + query_results=query_results, + coerce_to_timestamp=coerce_to_timestamp, + record_time_added=record_time_added, + ) # write the CSV or JSON file depending on the option # NOTE: @@ -253,8 +256,10 @@ def write_object_to_file(self, # we remove these newlines so that the output is a valid CSV format self.log.info("Cleaning data and writing to CSV") possible_strings = df.columns[df.dtypes == "object"] - df[possible_strings] = df[possible_strings].astype(str).apply( - lambda x: x.str.replace("\r\n", "").str.replace("\n", "") + df[possible_strings] = ( + df[possible_strings] + .astype(str) + .apply(lambda x: x.str.replace("\r\n", "").str.replace("\n", "")) ) # write the dataframe df.to_csv(filename, index=False) @@ -265,8 +270,7 @@ def write_object_to_file(self, return df - def object_to_df(self, query_results, coerce_to_timestamp=False, - record_time_added=False): + def object_to_df(self, query_results, coerce_to_timestamp=False, record_time_added=False): """ Export query results to dataframe. diff --git a/airflow/providers/salesforce/hooks/tableau.py b/airflow/providers/salesforce/hooks/tableau.py index 63ae14cc18637..56d1bab09cc40 100644 --- a/airflow/providers/salesforce/hooks/tableau.py +++ b/airflow/providers/salesforce/hooks/tableau.py @@ -30,6 +30,7 @@ class TableauJobFinishCode(Enum): .. 
seealso:: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref.htm#query_job """ + PENDING = -1 SUCCESS = 0 ERROR = 1 @@ -81,9 +82,7 @@ def get_conn(self) -> Auth.contextmgr: def _auth_via_password(self) -> Auth.contextmgr: tableau_auth = TableauAuth( - username=self.conn.login, - password=self.conn.password, - site_id=self.site_id + username=self.conn.login, password=self.conn.password, site_id=self.site_id ) return self.server.auth.sign_in(tableau_auth) @@ -91,7 +90,7 @@ def _auth_via_token(self) -> Auth.contextmgr: tableau_auth = PersonalAccessTokenAuth( token_name=self.conn.extra_dejson['token_name'], personal_access_token=self.conn.extra_dejson['personal_access_token'], - site_id=self.site_id + site_id=self.site_id, ) return self.server.auth.sign_in_with_personal_access_token(tableau_auth) diff --git a/airflow/providers/salesforce/operators/tableau_refresh_workbook.py b/airflow/providers/salesforce/operators/tableau_refresh_workbook.py index 8000680ee5f91..95400b98f617b 100644 --- a/airflow/providers/salesforce/operators/tableau_refresh_workbook.py +++ b/airflow/providers/salesforce/operators/tableau_refresh_workbook.py @@ -42,12 +42,15 @@ class TableauRefreshWorkbookOperator(BaseOperator): """ @apply_defaults - def __init__(self, *, - workbook_name: str, - site_id: Optional[str] = None, - blocking: bool = True, - tableau_conn_id: str = 'tableau_default', - **kwargs): + def __init__( + self, + *, + workbook_name: str, + site_id: Optional[str] = None, + blocking: bool = True, + tableau_conn_id: str = 'tableau_default', + **kwargs, + ): super().__init__(**kwargs) self.workbook_name = workbook_name self.site_id = site_id @@ -69,12 +72,13 @@ def execute(self, context: dict) -> str: job_id = self._refresh_workbook(tableau_hook, workbook.id) if self.blocking: from airflow.providers.salesforce.sensors.tableau_job_status import TableauJobStatusSensor + TableauJobStatusSensor( job_id=job_id, site_id=self.site_id, tableau_conn_id=self.tableau_conn_id, task_id='wait_until_succeeded', - dag=None + dag=None, ).execute(context={}) self.log.info('Workbook %s has been successfully refreshed.', self.workbook_name) return job_id diff --git a/airflow/providers/salesforce/sensors/tableau_job_status.py b/airflow/providers/salesforce/sensors/tableau_job_status.py index 7ec6b3c293bf1..7544964d7dcce 100644 --- a/airflow/providers/salesforce/sensors/tableau_job_status.py +++ b/airflow/providers/salesforce/sensors/tableau_job_status.py @@ -46,11 +46,14 @@ class TableauJobStatusSensor(BaseSensorOperator): template_fields = ('job_id',) @apply_defaults - def __init__(self, *, - job_id: str, - site_id: Optional[str] = None, - tableau_conn_id: str = 'tableau_default', - **kwargs): + def __init__( + self, + *, + job_id: str, + site_id: Optional[str] = None, + tableau_conn_id: str = 'tableau_default', + **kwargs, + ): super().__init__(**kwargs) self.tableau_conn_id = tableau_conn_id self.job_id = job_id diff --git a/airflow/providers/samba/hooks/samba.py b/airflow/providers/samba/hooks/samba.py index aebc61f27d8cd..273b600c20dd1 100644 --- a/airflow/providers/samba/hooks/samba.py +++ b/airflow/providers/samba/hooks/samba.py @@ -38,7 +38,8 @@ def get_conn(self): share=self.conn.schema, username=self.conn.login, ip=self.conn.host, - password=self.conn.password) + password=self.conn.password, + ) return samba def push_from_local(self, destination_filepath, local_filepath): diff --git a/airflow/providers/segment/hooks/segment.py b/airflow/providers/segment/hooks/segment.py index 
2ff223ec7fb46..aba01c366a3c7 100644 --- a/airflow/providers/segment/hooks/segment.py +++ b/airflow/providers/segment/hooks/segment.py @@ -52,12 +52,9 @@ class SegmentHook(BaseHook): So we define it in the `Extras` field as: `{"write_key":"YOUR_SECURITY_TOKEN"}` """ + def __init__( - self, - segment_conn_id: str = 'segment_default', - segment_debug_mode: bool = False, - *args, - **kwargs + self, segment_conn_id: str = 'segment_default', segment_debug_mode: bool = False, *args, **kwargs ) -> None: super().__init__() self.segment_conn_id = segment_conn_id @@ -85,6 +82,5 @@ def on_error(self, error: str, items: str) -> None: """ Handles error callbacks when using Segment with segment_debug_mode set to True """ - self.log.error('Encountered Segment error: %s with ' - 'items: %s', error, items) + self.log.error('Encountered Segment error: %s with ' 'items: %s', error, items) raise AirflowException('Segment error: {}'.format(error)) diff --git a/airflow/providers/segment/operators/segment_track_event.py b/airflow/providers/segment/operators/segment_track_event.py index 01681a4034799..f4494b63cdee4 100644 --- a/airflow/providers/segment/operators/segment_track_event.py +++ b/airflow/providers/segment/operators/segment_track_event.py @@ -38,17 +38,21 @@ class SegmentTrackEventOperator(BaseOperator): Defaults to False :type segment_debug_mode: bool """ + template_fields = ('user_id', 'event', 'properties') ui_color = '#ffd700' @apply_defaults - def __init__(self, *, - user_id: str, - event: str, - properties: Optional[dict] = None, - segment_conn_id: str = 'segment_default', - segment_debug_mode: bool = False, - **kwargs) -> None: + def __init__( + self, + *, + user_id: str, + event: str, + properties: Optional[dict] = None, + segment_conn_id: str = 'segment_default', + segment_debug_mode: bool = False, + **kwargs, + ) -> None: super().__init__(**kwargs) self.user_id = user_id self.event = event @@ -58,15 +62,16 @@ def __init__(self, *, self.segment_conn_id = segment_conn_id def execute(self, context: Dict) -> None: - hook = SegmentHook(segment_conn_id=self.segment_conn_id, - segment_debug_mode=self.segment_debug_mode) + hook = SegmentHook(segment_conn_id=self.segment_conn_id, segment_debug_mode=self.segment_debug_mode) self.log.info( 'Sending track event (%s) for user id: %s with properties: %s', - self.event, self.user_id, self.properties) + self.event, + self.user_id, + self.properties, + ) # pylint: disable=no-member hook.track( # type: ignore - user_id=self.user_id, - event=self.event, - properties=self.properties) + user_id=self.user_id, event=self.event, properties=self.properties + ) diff --git a/airflow/providers/sendgrid/utils/emailer.py b/airflow/providers/sendgrid/utils/emailer.py index f37fc82859369..64c211461a2a6 100644 --- a/airflow/providers/sendgrid/utils/emailer.py +++ b/airflow/providers/sendgrid/utils/emailer.py @@ -27,7 +27,15 @@ import sendgrid from sendgrid.helpers.mail import ( - Attachment, Category, Content, CustomArg, Email, Mail, MailSettings, Personalization, SandBoxMode, + Attachment, + Category, + Content, + CustomArg, + Email, + Mail, + MailSettings, + Personalization, + SandBoxMode, ) from airflow.utils.email import get_email_address_list @@ -37,14 +45,16 @@ AddressesType = Union[str, Iterable[str]] -def send_email(to: AddressesType, - subject: str, - html_content: str, - files: Optional[AddressesType] = None, - cc: Optional[AddressesType] = None, - bcc: Optional[AddressesType] = None, - sandbox_mode: bool = False, - **kwargs) -> None: +def send_email( + to: 
AddressesType, + subject: str, + html_content: str, + files: Optional[AddressesType] = None, + cc: Optional[AddressesType] = None, + bcc: Optional[AddressesType] = None, + sandbox_mode: bool = False, + **kwargs, +) -> None: """ Send an email with html content using `Sendgrid `__. @@ -104,7 +114,7 @@ def send_email(to: AddressesType, file_type=mimetypes.guess_type(basename)[0], file_name=basename, disposition="attachment", - content_id=f"<{basename}>" + content_id=f"<{basename}>", ) mail.add_attachment(attachment) @@ -116,8 +126,14 @@ def _post_sendgrid_mail(mail_data: Dict) -> None: response = sendgrid_client.client.mail.send.post(request_body=mail_data) # 2xx status code. if 200 <= response.status_code < 300: - log.info('Email with subject %s is successfully sent to recipients: %s', - mail_data['subject'], mail_data['personalizations']) + log.info( + 'Email with subject %s is successfully sent to recipients: %s', + mail_data['subject'], + mail_data['personalizations'], + ) else: - log.error('Failed to send out email with subject %s, status code: %s', - mail_data['subject'], response.status_code) + log.error( + 'Failed to send out email with subject %s, status code: %s', + mail_data['subject'], + response.status_code, + ) diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py index 20ba24829a31b..f5afe5bde3e20 100644 --- a/airflow/providers/sftp/hooks/sftp.py +++ b/airflow/providers/sftp/hooks/sftp.py @@ -67,6 +67,7 @@ def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None: # For backward compatibility # TODO: remove in Airflow 2.1 import warnings + if 'ignore_hostkey_verification' in extra_options: warnings.warn( 'Extra option `ignore_hostkey_verification` is deprecated.' @@ -75,13 +76,12 @@ def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None: DeprecationWarning, stacklevel=2, ) - self.no_host_key_check = str( - extra_options['ignore_hostkey_verification'] - ).lower() == 'true' + self.no_host_key_check = ( + str(extra_options['ignore_hostkey_verification']).lower() == 'true' + ) if 'no_host_key_check' in extra_options: - self.no_host_key_check = str( - extra_options['no_host_key_check']).lower() == 'true' + self.no_host_key_check = str(extra_options['no_host_key_check']).lower() == 'true' if 'private_key' in extra_options: warnings.warn( @@ -106,7 +106,7 @@ def get_conn(self) -> pysftp.Connection: 'host': self.remote_host, 'port': self.port, 'username': self.username, - 'cnopts': cnopts + 'cnopts': cnopts, } if self.password and self.password.strip(): conn_params['password'] = self.password @@ -138,12 +138,12 @@ def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]: flist = conn.listdir_attr(path) files = {} for f in flist: - modify = datetime.datetime.fromtimestamp( - f.st_mtime).strftime('%Y%m%d%H%M%S') + modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime('%Y%m%d%H%M%S') files[f.filename] = { 'size': f.st_size, 'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file', - 'modify': modify} + 'modify': modify, + } return files def list_directory(self, path: str) -> List[str]: @@ -278,11 +278,7 @@ def get_tree_map( files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str] def append_matching_path_callback(list_): - return ( - lambda item: list_.append(item) - if self._is_path_match(item, prefix, delimiter) - else None - ) + return lambda item: list_.append(item) if self._is_path_match(item, prefix, delimiter) else None conn.walktree( remotepath=path, diff --git 
a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py index fc2060bd9a238..8c2d113691e7b 100644 --- a/airflow/providers/sftp/operators/sftp.py +++ b/airflow/providers/sftp/operators/sftp.py @@ -31,6 +31,7 @@ class SFTPOperation: """ Operation that can be used with SFTP/ """ + PUT = 'put' GET = 'get' @@ -79,19 +80,23 @@ class SFTPOperator(BaseOperator): :type create_intermediate_dirs: bool """ + template_fields = ('local_filepath', 'remote_filepath', 'remote_host') @apply_defaults - def __init__(self, *, - ssh_hook=None, - ssh_conn_id=None, - remote_host=None, - local_filepath=None, - remote_filepath=None, - operation=SFTPOperation.PUT, - confirm=True, - create_intermediate_dirs=False, - **kwargs): + def __init__( + self, + *, + ssh_hook=None, + ssh_conn_id=None, + remote_host=None, + local_filepath=None, + remote_filepath=None, + operation=SFTPOperation.PUT, + confirm=True, + create_intermediate_dirs=False, + **kwargs, + ): super().__init__(**kwargs) self.ssh_hook = ssh_hook self.ssh_conn_id = ssh_conn_id @@ -101,10 +106,12 @@ def __init__(self, *, self.operation = operation self.confirm = confirm self.create_intermediate_dirs = create_intermediate_dirs - if not (self.operation.lower() == SFTPOperation.GET or - self.operation.lower() == SFTPOperation.PUT): - raise TypeError("unsupported operation value {0}, expected {1} or {2}" - .format(self.operation, SFTPOperation.GET, SFTPOperation.PUT)) + if not (self.operation.lower() == SFTPOperation.GET or self.operation.lower() == SFTPOperation.PUT): + raise TypeError( + "unsupported operation value {0}, expected {1} or {2}".format( + self.operation, SFTPOperation.GET, SFTPOperation.PUT + ) + ) def execute(self, context): file_msg = None @@ -113,17 +120,20 @@ def execute(self, context): if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") else: - self.log.info("ssh_hook is not provided or invalid. " - "Trying ssh_conn_id to create SSHHook.") + self.log.info( + "ssh_hook is not provided or invalid. " "Trying ssh_conn_id to create SSHHook." + ) self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) if not self.ssh_hook: raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") if self.remote_host is not None: - self.log.info("remote_host is provided explicitly. " - "It will replace the remote_host which was defined " - "in ssh_hook or predefined in connection of ssh_conn_id.") + self.log.info( + "remote_host is provided explicitly. " + "It will replace the remote_host which was defined " + "in ssh_hook or predefined in connection of ssh_conn_id." 
+ ) self.ssh_hook.remote_host = self.remote_host with self.ssh_hook.get_conn() as ssh_client: @@ -132,27 +142,21 @@ def execute(self, context): local_folder = os.path.dirname(self.local_filepath) if self.create_intermediate_dirs: Path(local_folder).mkdir(parents=True, exist_ok=True) - file_msg = "from {0} to {1}".format(self.remote_filepath, - self.local_filepath) + file_msg = "from {0} to {1}".format(self.remote_filepath, self.local_filepath) self.log.info("Starting to transfer %s", file_msg) sftp_client.get(self.remote_filepath, self.local_filepath) else: remote_folder = os.path.dirname(self.remote_filepath) if self.create_intermediate_dirs: _make_intermediate_dirs( - sftp_client=sftp_client, - remote_directory=remote_folder, + sftp_client=sftp_client, remote_directory=remote_folder, ) - file_msg = "from {0} to {1}".format(self.local_filepath, - self.remote_filepath) + file_msg = "from {0} to {1}".format(self.local_filepath, self.remote_filepath) self.log.info("Starting to transfer file %s", file_msg) - sftp_client.put(self.local_filepath, - self.remote_filepath, - confirm=self.confirm) + sftp_client.put(self.local_filepath, self.remote_filepath, confirm=self.confirm) except Exception as e: - raise AirflowException("Error while transferring {0}, error: {1}" - .format(file_msg, str(e))) + raise AirflowException("Error while transferring {0}, error: {1}".format(file_msg, str(e))) return self.local_filepath diff --git a/airflow/providers/sftp/sensors/sftp.py b/airflow/providers/sftp/sensors/sftp.py index de3e993fc726e..f59da3d536ea4 100644 --- a/airflow/providers/sftp/sensors/sftp.py +++ b/airflow/providers/sftp/sensors/sftp.py @@ -34,6 +34,7 @@ class SFTPSensor(BaseSensorOperator): :param sftp_conn_id: The connection to run the sensor against :type sftp_conn_id: str """ + template_fields = ('path',) @apply_defaults diff --git a/airflow/providers/singularity/example_dags/example_singularity.py b/airflow/providers/singularity/example_dags/example_singularity.py index c4feea4193414..fc33b0dd65db1 100644 --- a/airflow/providers/singularity/example_dags/example_singularity.py +++ b/airflow/providers/singularity/example_dags/example_singularity.py @@ -30,34 +30,25 @@ 'email_on_failure': False, 'email_on_retry': False, 'retries': 1, - 'retry_delay': timedelta(minutes=5) + 'retry_delay': timedelta(minutes=5), } -with DAG('singularity_sample', - default_args=default_args, - schedule_interval=timedelta(minutes=10), - start_date=days_ago(0)) as dag: - - t1 = BashOperator( - task_id='print_date', - bash_command='date', - dag=dag) - - t2 = BashOperator( - task_id='sleep', - bash_command='sleep 5', - retries=3, - dag=dag) - - t3 = SingularityOperator(command='/bin/sleep 30', - image='docker://busybox:1.30.1', - task_id='singularity_op_tester', - dag=dag) - - t4 = BashOperator( - task_id='print_hello', - bash_command='echo "hello world!!!"', - dag=dag) +with DAG( + 'singularity_sample', + default_args=default_args, + schedule_interval=timedelta(minutes=10), + start_date=days_ago(0), +) as dag: + + t1 = BashOperator(task_id='print_date', bash_command='date', dag=dag) + + t2 = BashOperator(task_id='sleep', bash_command='sleep 5', retries=3, dag=dag) + + t3 = SingularityOperator( + command='/bin/sleep 30', image='docker://busybox:1.30.1', task_id='singularity_op_tester', dag=dag + ) + + t4 = BashOperator(task_id='print_hello', bash_command='echo "hello world!!!"', dag=dag) t1 >> [t2, t3] t3 >> t4 diff --git a/airflow/providers/singularity/operators/singularity.py 
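The hunks above all follow the same Black rule: a call that fits within the configured line length is collapsed onto a single line, while a call that does not fit is exploded to one argument per line, gains a trailing comma, and gets its closing parenthesis on its own line. A minimal before/after sketch of that behaviour, using a made-up helper function rather than anything from this diff:

def create_job(name, image, retries):
    # Stand-in helper so the formatting examples below are runnable; not part of the diff.
    return {"name": name, "image": image, "retries": retries}

# Before Black: hand-wrapped call, arguments aligned under the opening parenthesis.
job = create_job(name="demo",
                 image="busybox:1.30.1",
                 retries=3)

# After Black, when the call fits within the line-length limit: collapsed onto one line.
job = create_job(name="demo", image="busybox:1.30.1", retries=3)

# After Black, when the call does not fit: one argument per line, a trailing comma
# after the last argument, and the closing parenthesis dedented onto its own line.
job = create_job(
    name="demo",
    image="busybox:1.30.1",
    retries=3,
)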
b/airflow/providers/singularity/operators/singularity.py index 17d17494827d0..bbd13135cc66e 100644 --- a/airflow/providers/singularity/operators/singularity.py +++ b/airflow/providers/singularity/operators/singularity.py @@ -60,23 +60,32 @@ class SingularityOperator(BaseOperator): set on the container (equivalent to the -w switch the docker client) :type working_dir: str """ - template_fields = ('command', 'environment',) - template_ext = ('.sh', '.bash',) + + template_fields = ( + 'command', + 'environment', + ) + template_ext = ( + '.sh', + '.bash', + ) @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, *, - image: str, - command: Union[str, List[str]], - start_command: Optional[Union[str, List[str]]] = None, - environment: Optional[Dict[str, Any]] = None, - pull_folder: Optional[str] = None, - working_dir: Optional[str] = None, - force_pull: Optional[bool] = False, - volumes: Optional[List[str]] = None, - options: Optional[List[str]] = None, - auto_remove: Optional[bool] = False, - **kwargs) -> None: + self, + *, + image: str, + command: Union[str, List[str]], + start_command: Optional[Union[str, List[str]]] = None, + environment: Optional[Dict[str, Any]] = None, + pull_folder: Optional[str] = None, + working_dir: Optional[str] = None, + force_pull: Optional[bool] = False, + volumes: Optional[List[str]] = None, + options: Optional[List[str]] = None, + auto_remove: Optional[bool] = False, + **kwargs, + ) -> None: super(SingularityOperator, self).__init__(**kwargs) self.auto_remove = auto_remove @@ -132,10 +141,9 @@ def execute(self, context): # Create a container instance self.log.debug('Options include: %s', self.options) - self.instance = self.cli.instance(self.image, - options=self.options, - args=self.start_command, - start=False) + self.instance = self.cli.instance( + self.image, options=self.options, args=self.start_command, start=False + ) self.instance.start() self.log.info(self.instance.cmd) @@ -143,9 +151,7 @@ def execute(self, context): self.log.info('Running command %s', self._get_command()) self.cli.quiet = True - result = self.cli.execute(self.instance, - self._get_command(), - return_result=True) + result = self.cli.execute(self.instance, self._get_command(), return_result=True) # Stop the instance self.log.info('Stopping instance %s', self.instance) diff --git a/airflow/providers/slack/hooks/slack.py b/airflow/providers/slack/hooks/slack.py index e2e4b126eaaeb..a9cfb2f8f399e 100644 --- a/airflow/providers/slack/hooks/slack.py +++ b/airflow/providers/slack/hooks/slack.py @@ -60,10 +60,7 @@ class SlackHook(BaseHook): # noqa """ def __init__( - self, - token: Optional[str] = None, - slack_conn_id: Optional[str] = None, - **client_args: Any, + self, token: Optional[str] = None, slack_conn_id: Optional[str] = None, **client_args: Any, ) -> None: super().__init__() self.token = self.__get_token(token, slack_conn_id) @@ -80,8 +77,7 @@ def __get_token(self, token, slack_conn_id): raise AirflowException('Missing token(password) in Slack connection') return conn.password - raise AirflowException('Cannot get token: ' - 'No valid Slack token nor slack_conn_id supplied.') + raise AirflowException('Cannot get token: ' 'No valid Slack token nor slack_conn_id supplied.') def call(self, api_method, *args, **kwargs) -> None: """ diff --git a/airflow/providers/slack/hooks/slack_webhook.py b/airflow/providers/slack/hooks/slack_webhook.py index de50d1f3cc201..c486e63063216 100644 --- a/airflow/providers/slack/hooks/slack_webhook.py +++ 
b/airflow/providers/slack/hooks/slack_webhook.py @@ -60,21 +60,22 @@ class SlackWebhookHook(HttpHook): """ # pylint: disable=too-many-arguments - def __init__(self, - http_conn_id=None, - webhook_token=None, - message="", - attachments=None, - blocks=None, - channel=None, - username=None, - icon_emoji=None, - icon_url=None, - link_names=False, - proxy=None, - *args, - **kwargs - ): + def __init__( + self, + http_conn_id=None, + webhook_token=None, + message="", + attachments=None, + blocks=None, + channel=None, + username=None, + icon_emoji=None, + icon_url=None, + link_names=False, + proxy=None, + *args, + **kwargs, + ): super().__init__(http_conn_id=http_conn_id, *args, **kwargs) self.webhook_token = self._get_token(webhook_token, http_conn_id) self.message = message @@ -105,8 +106,7 @@ def _get_token(self, token, http_conn_id): extra = conn.extra_dejson return extra.get('webhook_token', '') else: - raise AirflowException('Cannot get token: No valid Slack ' - 'webhook token nor conn_id supplied') + raise AirflowException('Cannot get token: No valid Slack ' 'webhook token nor conn_id supplied') def _build_slack_message(self): """ @@ -146,7 +146,9 @@ def execute(self): proxies = {'https': self.proxy} slack_message = self._build_slack_message() - self.run(endpoint=self.webhook_token, - data=slack_message, - headers={'Content-type': 'application/json'}, - extra_options={'proxies': proxies}) + self.run( + endpoint=self.webhook_token, + data=slack_message, + headers={'Content-type': 'application/json'}, + extra_options={'proxies': proxies}, + ) diff --git a/airflow/providers/slack/operators/slack.py b/airflow/providers/slack/operators/slack.py index 30254f495d2c9..2278f1a149979 100644 --- a/airflow/providers/slack/operators/slack.py +++ b/airflow/providers/slack/operators/slack.py @@ -44,12 +44,15 @@ class SlackAPIOperator(BaseOperator): """ @apply_defaults - def __init__(self, *, - slack_conn_id: Optional[str] = None, - token: Optional[str] = None, - method: Optional[str] = None, - api_params: Optional[Dict] = None, - **kwargs) -> None: + def __init__( + self, + *, + slack_conn_id: Optional[str] = None, + token: Optional[str] = None, + method: Optional[str] = None, + api_params: Optional[Dict] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.token = token # type: Optional[str] @@ -73,7 +76,7 @@ def construct_api_call_params(self): "SlackAPIOperator should not be used directly. Chose one of the subclasses instead" ) - def execute(self, **kwargs): # noqa: D403 + def execute(self, **kwargs): # noqa: D403 """ SlackAPIOperator calls will not fail even if the call is not unsuccessful. 
It should not prevent a DAG from completing in success @@ -120,17 +123,19 @@ class SlackAPIPostOperator(SlackAPIOperator): ui_color = '#FFBA40' @apply_defaults - def __init__(self, - channel: str = '#general', - username: str = 'Airflow', - text: str = 'No message has been set.\n' - 'Here is a cat video instead\n' - 'https://www.youtube.com/watch?v=J---aiyznGQ', - icon_url: str = 'https://raw.githubusercontent.com/apache/' - 'airflow/master/airflow/www/static/pin_100.png', - attachments: Optional[List] = None, - blocks: Optional[List] = None, - **kwargs): + def __init__( + self, + channel: str = '#general', + username: str = 'Airflow', + text: str = 'No message has been set.\n' + 'Here is a cat video instead\n' + 'https://www.youtube.com/watch?v=J---aiyznGQ', + icon_url: str = 'https://raw.githubusercontent.com/apache/' + 'airflow/master/airflow/www/static/pin_100.png', + attachments: Optional[List] = None, + blocks: Optional[List] = None, + **kwargs, + ): self.method = 'chat.postMessage' self.channel = channel self.username = username @@ -186,13 +191,15 @@ class SlackAPIFileOperator(SlackAPIOperator): ui_color = '#44BEDF' @apply_defaults - def __init__(self, - channel: str = '#general', - initial_comment: str = 'No message has been set!', - filename: str = 'default_name.csv', - filetype: str = 'csv', - content: str = 'default,content,csv,file', - **kwargs): + def __init__( + self, + channel: str = '#general', + initial_comment: str = 'No message has been set!', + filename: str = 'default_name.csv', + filetype: str = 'csv', + content: str = 'default,content,csv,file', + **kwargs, + ): self.method = 'files.upload' self.channel = channel self.initial_comment = initial_comment @@ -207,5 +214,5 @@ def construct_api_call_params(self): 'content': self.content, 'filename': self.filename, 'filetype': self.filetype, - 'initial_comment': self.initial_comment + 'initial_comment': self.initial_comment, } diff --git a/airflow/providers/slack/operators/slack_webhook.py b/airflow/providers/slack/operators/slack_webhook.py index 2c0023e14bdc3..ea4103bbaafc8 100644 --- a/airflow/providers/slack/operators/slack_webhook.py +++ b/airflow/providers/slack/operators/slack_webhook.py @@ -58,26 +58,35 @@ class SlackWebhookOperator(SimpleHttpOperator): :type proxy: str """ - template_fields = ['webhook_token', 'message', 'attachments', 'blocks', 'channel', - 'username', 'proxy', ] + template_fields = [ + 'webhook_token', + 'message', + 'attachments', + 'blocks', + 'channel', + 'username', + 'proxy', + ] # pylint: disable=too-many-arguments @apply_defaults - def __init__(self, *, - http_conn_id=None, - webhook_token=None, - message="", - attachments=None, - blocks=None, - channel=None, - username=None, - icon_emoji=None, - icon_url=None, - link_names=False, - proxy=None, - **kwargs): - super().__init__(endpoint=webhook_token, - **kwargs) + def __init__( + self, + *, + http_conn_id=None, + webhook_token=None, + message="", + attachments=None, + blocks=None, + channel=None, + username=None, + icon_emoji=None, + icon_url=None, + link_names=False, + proxy=None, + **kwargs, + ): + super().__init__(endpoint=webhook_token, **kwargs) self.http_conn_id = http_conn_id self.webhook_token = webhook_token self.message = message @@ -106,6 +115,6 @@ def execute(self, context): self.icon_emoji, self.icon_url, self.link_names, - self.proxy + self.proxy, ) self.hook.execute() diff --git a/airflow/providers/snowflake/example_dags/example_snowflake.py b/airflow/providers/snowflake/example_dags/example_snowflake.py index 
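The SlackWebhookOperator hunk above only re-wraps its argument list, so the operator is still constructed exactly as before. A minimal usage sketch, assuming a Slack connection id of slack_default whose Extra JSON carries the webhook token; the connection id, DAG, and message text are illustrative assumptions, not part of this diff:

from airflow import DAG
from airflow.providers.slack.operators.slack_webhook import SlackWebhookOperator
from airflow.utils.dates import days_ago

with DAG("slack_webhook_demo", start_date=days_ago(1), schedule_interval=None) as dag:
    notify_slack = SlackWebhookOperator(
        task_id="notify_slack",
        http_conn_id="slack_default",  # connection whose Extra holds the webhook token (assumption)
        message="Example notification sent from Airflow.",
        channel="#general",
        username="Airflow",
    )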
4fea24abcea6e..a273ea47fe822 100644 --- a/airflow/providers/snowflake/example_dags/example_snowflake.py +++ b/airflow/providers/snowflake/example_dags/example_snowflake.py @@ -37,26 +37,19 @@ SNOWFLAKE_SELECT_SQL = f"SELECT * FROM {SNOWFLAKE_SAMPLE_TABLE} LIMIT 100;" SNOWFLAKE_SLACK_SQL = f"SELECT O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS FROM {SNOWFLAKE_SAMPLE_TABLE} LIMIT 10;" -SNOWFLAKE_SLACK_MESSAGE = "Results in an ASCII table:\n" \ - "```{{ results_df | tabulate(tablefmt='pretty', headers='keys') }}```" +SNOWFLAKE_SLACK_MESSAGE = ( + "Results in an ASCII table:\n```{{ results_df | tabulate(tablefmt='pretty', headers='keys') }}```" +) SNOWFLAKE_CREATE_TABLE_SQL = f"CREATE TRANSIENT TABLE IF NOT EXISTS {SNOWFLAKE_LOAD_TABLE}(data VARIANT);" default_args = { 'owner': 'airflow', } -dag = DAG( - 'example_snowflake', - default_args=default_args, - start_date=days_ago(2), - tags=['example'], -) +dag = DAG('example_snowflake', default_args=default_args, start_date=days_ago(2), tags=['example'],) select = SnowflakeOperator( - task_id='select', - snowflake_conn_id=SNOWFLAKE_CONN_ID, - sql=SNOWFLAKE_SELECT_SQL, - dag=dag, + task_id='select', snowflake_conn_id=SNOWFLAKE_CONN_ID, sql=SNOWFLAKE_SELECT_SQL, dag=dag, ) slack_report = SnowflakeToSlackOperator( @@ -65,7 +58,7 @@ slack_message=SNOWFLAKE_SLACK_MESSAGE, snowflake_conn_id=SNOWFLAKE_CONN_ID, slack_conn_id=SLACK_CONN_ID, - dag=dag + dag=dag, ) create_table = SnowflakeOperator( diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index f8e5ad3316628..e5eadd0d97678 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -18,6 +18,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization + # pylint: disable=no-name-in-module from snowflake import connector @@ -29,6 +30,7 @@ class SnowflakeHook(DbApiHook): Interact with Snowflake. 
get_sqlalchemy_engine() depends on snowflake-sqlalchemy """ + conn_name_attr = 'snowflake_conn_id' default_conn_name = 'snowflake_default' supports_autocommit = True @@ -66,7 +68,7 @@ def _get_conn_params(self): "warehouse": self.warehouse or warehouse, "region": self.region or region, "role": self.role or role, - "authenticator": self.authenticator or authenticator + "authenticator": self.authenticator or authenticator, } # If private_key_file is specified in the extra json, load the contents of the file as a private @@ -82,14 +84,14 @@ def _get_conn_params(self): passphrase = conn.password.strip().encode() p_key = serialization.load_pem_private_key( - key.read(), - password=passphrase, - backend=default_backend() + key.read(), password=passphrase, backend=default_backend() ) - pkb = p_key.private_bytes(encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption()) + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) conn_config['private_key'] = pkb conn_config.pop('password', None) @@ -101,8 +103,10 @@ def get_uri(self): Override DbApiHook get_uri method for get_sqlalchemy_engine() """ conn_config = self._get_conn_params() - uri = 'snowflake://{user}:{password}@{account}/{database}/{schema}' \ - '?warehouse={warehouse}&role={role}&authenticator={authenticator}' + uri = ( + 'snowflake://{user}:{password}@{account}/{database}/{schema}' + '?warehouse={warehouse}&role={role}&authenticator={authenticator}' + ) return uri.format(**conn_config) def get_conn(self): @@ -123,10 +127,8 @@ def _get_aws_credentials(self): if self.snowflake_conn_id: # pylint: disable=no-member connection_object = self.get_connection(self.snowflake_conn_id) # pylint: disable=no-member if 'aws_secret_access_key' in connection_object.extra_dejson: - aws_access_key_id = connection_object.extra_dejson.get( - 'aws_access_key_id') - aws_secret_access_key = connection_object.extra_dejson.get( - 'aws_secret_access_key') + aws_access_key_id = connection_object.extra_dejson.get('aws_access_key_id') + aws_secret_access_key = connection_object.extra_dejson.get('aws_secret_access_key') return aws_access_key_id, aws_secret_access_key def set_autocommit(self, conn, autocommit): diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py index 6a7e1766de425..e8c298224d669 100644 --- a/airflow/providers/snowflake/operators/snowflake.py +++ b/airflow/providers/snowflake/operators/snowflake.py @@ -63,9 +63,19 @@ class SnowflakeOperator(BaseOperator): @apply_defaults def __init__( - self, *, sql, snowflake_conn_id='snowflake_default', parameters=None, - autocommit=True, warehouse=None, database=None, role=None, - schema=None, authenticator=None, **kwargs): + self, + *, + sql, + snowflake_conn_id='snowflake_default', + parameters=None, + autocommit=True, + warehouse=None, + database=None, + role=None, + schema=None, + authenticator=None, + **kwargs, + ): super().__init__(**kwargs) self.snowflake_conn_id = snowflake_conn_id self.sql = sql @@ -83,9 +93,14 @@ def get_hook(self): :return: a SnowflakeHook instance. 
:rtype: SnowflakeHook """ - return SnowflakeHook(snowflake_conn_id=self.snowflake_conn_id, - warehouse=self.warehouse, database=self.database, - role=self.role, schema=self.schema, authenticator=self.authenticator) + return SnowflakeHook( + snowflake_conn_id=self.snowflake_conn_id, + warehouse=self.warehouse, + database=self.database, + role=self.role, + schema=self.schema, + authenticator=self.authenticator, + ) def execute(self, context): """ @@ -93,7 +108,4 @@ def execute(self, context): """ self.log.info('Executing: %s', self.sql) hook = self.get_hook() - hook.run( - self.sql, - autocommit=self.autocommit, - parameters=self.parameters) + hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) diff --git a/airflow/providers/snowflake/transfers/s3_to_snowflake.py b/airflow/providers/snowflake/transfers/s3_to_snowflake.py index 361545dfd28e9..c94137c8044b5 100644 --- a/airflow/providers/snowflake/transfers/s3_to_snowflake.py +++ b/airflow/providers/snowflake/transfers/s3_to_snowflake.py @@ -46,17 +46,19 @@ class S3ToSnowflakeOperator(BaseOperator): """ @apply_defaults - def __init__(self, - *, - s3_keys, - table, - stage, - file_format, - schema, # TODO: shouldn't be required, rely on session/user defaults - columns_array=None, - autocommit=True, - snowflake_conn_id='snowflake_default', - **kwargs): + def __init__( + self, + *, + s3_keys, + table, + stage, + file_format, + schema, # TODO: shouldn't be required, rely on session/user defaults + columns_array=None, + autocommit=True, + snowflake_conn_id='snowflake_default', + **kwargs, + ): super().__init__(**kwargs) self.s3_keys = s3_keys self.table = table @@ -82,27 +84,20 @@ def execute(self, context): files={files} file_format={file_format} """.format( - stage=self.stage, - files=files, - file_format=self.file_format + stage=self.stage, files=files, file_format=self.file_format ) if self.columns_array: copy_query = """ COPY INTO {schema}.{table}({columns}) {base_sql} """.format( - schema=self.schema, - table=self.table, - columns=",".join(self.columns_array), - base_sql=base_sql + schema=self.schema, table=self.table, columns=",".join(self.columns_array), base_sql=base_sql ) else: copy_query = """ COPY INTO {schema}.{table} {base_sql} """.format( - schema=self.schema, - table=self.table, - base_sql=base_sql + schema=self.schema, table=self.table, base_sql=base_sql ) self.log.info('Executing COPY command...') diff --git a/airflow/providers/snowflake/transfers/snowflake_to_slack.py b/airflow/providers/snowflake/transfers/snowflake_to_slack.py index b820a2a0824cd..b25e6012e48e3 100644 --- a/airflow/providers/snowflake/transfers/snowflake_to_slack.py +++ b/airflow/providers/snowflake/transfers/snowflake_to_slack.py @@ -62,6 +62,7 @@ class SnowflakeToSlackOperator(BaseOperator): 'webhook_token' attribute needs to be specified in the 'Extra' JSON field against the slack_conn_id :type slack_token: Optional[str] """ + template_fields = ['sql', 'slack_message'] template_ext = ['.sql', '.jinja', '.j2'] times_rendered = 0 @@ -81,7 +82,7 @@ def __init__( # pylint: disable=too-many-arguments schema: Optional[str] = None, role: Optional[str] = None, slack_token: Optional[str] = None, - **kwargs + **kwargs, ) -> None: super(SnowflakeToSlackOperator, self).__init__(**kwargs) @@ -114,13 +115,18 @@ def _render_and_send_slack_message(self, context, df) -> None: slack_hook.execute() def _get_snowflake_hook(self) -> SnowflakeHook: - return SnowflakeHook(snowflake_conn_id=self.snowflake_conn_id, - warehouse=self.warehouse, 
database=self.database, - role=self.role, schema=self.schema) + return SnowflakeHook( + snowflake_conn_id=self.snowflake_conn_id, + warehouse=self.warehouse, + database=self.database, + role=self.role, + schema=self.schema, + ) def _get_slack_hook(self) -> SlackWebhookHook: - return SlackWebhookHook(http_conn_id=self.slack_conn_id, message=self.slack_message, - webhook_token=self.slack_token) + return SlackWebhookHook( + http_conn_id=self.slack_conn_id, message=self.slack_message, webhook_token=self.slack_token + ) def render_template_fields(self, context, jinja_env=None) -> None: # If this is the first render of the template fields, exclude slack_message from rendering since diff --git a/airflow/providers/sqlite/operators/sqlite.py b/airflow/providers/sqlite/operators/sqlite.py index c51b4e9e38c6d..a65dc44bca59e 100644 --- a/airflow/providers/sqlite/operators/sqlite.py +++ b/airflow/providers/sqlite/operators/sqlite.py @@ -41,12 +41,13 @@ class SqliteOperator(BaseOperator): @apply_defaults def __init__( - self, - *, - sql: str, - sqlite_conn_id: str = 'sqlite_default', - parameters: Optional[Union[Mapping, Iterable]] = None, - **kwargs) -> None: + self, + *, + sql: str, + sqlite_conn_id: str = 'sqlite_default', + parameters: Optional[Union[Mapping, Iterable]] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.sqlite_conn_id = sqlite_conn_id self.sql = sql diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py index 4017b495a2c9e..e4c1850304121 100644 --- a/airflow/providers/ssh/hooks/ssh.py +++ b/airflow/providers/ssh/hooks/ssh.py @@ -57,16 +57,17 @@ class SSHHook(BaseHook): :type keepalive_interval: int """ - def __init__(self, - ssh_conn_id=None, - remote_host=None, - username=None, - password=None, - key_file=None, - port=None, - timeout=10, - keepalive_interval=30 - ): + def __init__( + self, + ssh_conn_id=None, + remote_host=None, + username=None, + password=None, + key_file=None, + port=None, + timeout=10, + keepalive_interval=30, + ): super().__init__() self.ssh_conn_id = ssh_conn_id self.remote_host = remote_host @@ -111,25 +112,28 @@ def __init__(self, if "timeout" in extra_options: self.timeout = int(extra_options["timeout"], 10) - if "compress" in extra_options\ - and str(extra_options["compress"]).lower() == 'false': + if "compress" in extra_options and str(extra_options["compress"]).lower() == 'false': self.compress = False - if "no_host_key_check" in extra_options\ - and\ - str(extra_options["no_host_key_check"]).lower() == 'false': + if ( + "no_host_key_check" in extra_options + and str(extra_options["no_host_key_check"]).lower() == 'false' + ): self.no_host_key_check = False - if "allow_host_key_change" in extra_options\ - and\ - str(extra_options["allow_host_key_change"]).lower() == 'true': + if ( + "allow_host_key_change" in extra_options + and str(extra_options["allow_host_key_change"]).lower() == 'true' + ): self.allow_host_key_change = True - if "look_for_keys" in extra_options\ - and\ - str(extra_options["look_for_keys"]).lower() == 'false': + if ( + "look_for_keys" in extra_options + and str(extra_options["look_for_keys"]).lower() == 'false' + ): self.look_for_keys = False if self.pkey and self.key_file: raise AirflowException( - "Params key_file and private_key both provided. Must provide no more than one.") + "Params key_file and private_key both provided. Must provide no more than one." 
+ ) if not self.remote_host: raise AirflowException("Missing required param: remote_host") @@ -139,7 +143,8 @@ def __init__(self, self.log.debug( "username to ssh to host: %s is not specified for connection id" " %s. Using system's default provided by getpass.getuser()", - self.remote_host, self.ssh_conn_id + self.remote_host, + self.ssh_conn_id, ) self.username = getpass.getuser() @@ -169,12 +174,15 @@ def get_conn(self) -> paramiko.SSHClient: client = paramiko.SSHClient() if not self.allow_host_key_change: - self.log.warning('Remote Identification Change is not verified. ' - 'This wont protect against Man-In-The-Middle attacks') + self.log.warning( + 'Remote Identification Change is not verified. ' + 'This wont protect against Man-In-The-Middle attacks' + ) client.load_system_host_keys() if self.no_host_key_check: - self.log.warning('No Host Key Verification. This wont protect ' - 'against Man-In-The-Middle attacks') + self.log.warning( + 'No Host Key Verification. This wont protect ' 'against Man-In-The-Middle attacks' + ) # Default is RejectPolicy client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) connect_kwargs = dict( @@ -184,7 +192,7 @@ def get_conn(self) -> paramiko.SSHClient: compress=self.compress, port=self.port, sock=self.host_proxy, - look_for_keys=self.look_for_keys + look_for_keys=self.look_for_keys, ) if self.password: @@ -206,10 +214,12 @@ def get_conn(self) -> paramiko.SSHClient: return client def __enter__(self): - warnings.warn('The contextmanager of SSHHook is deprecated.' - 'Please use get_conn() as a contextmanager instead.' - 'This method will be removed in Airflow 2.0', - category=DeprecationWarning) + warnings.warn( + 'The contextmanager of SSHHook is deprecated.' + 'Please use get_conn() as a contextmanager instead.' + 'This method will be removed in Airflow 2.0', + category=DeprecationWarning, + ) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -243,28 +253,21 @@ def get_tunnel(self, remote_port, remote_host="localhost", local_port=None): ssh_proxy=self.host_proxy, local_bind_address=local_bind_address, remote_bind_address=(remote_host, remote_port), - logger=self.log + logger=self.log, ) if self.password: password = self.password.strip() - tunnel_kwargs.update( - ssh_password=password, - ) + tunnel_kwargs.update(ssh_password=password,) else: - tunnel_kwargs.update( - host_pkey_directories=[], - ) + tunnel_kwargs.update(host_pkey_directories=[],) client = SSHTunnelForwarder(self.remote_host, **tunnel_kwargs) return client def create_tunnel( - self, - local_port: int, - remote_port: Optional[int] = None, - remote_host: str = "localhost" + self, local_port: int, remote_port: Optional[int] = None, remote_host: str = "localhost" ) -> SSHTunnelForwarder: """ Creates tunnel for SSH connection [Deprecated]. @@ -274,10 +277,12 @@ def create_tunnel( :param remote_host: remote host :return: """ - warnings.warn('SSHHook.create_tunnel is deprecated, Please' - 'use get_tunnel() instead. But please note that the' - 'order of the parameters have changed' - 'This method will be removed in Airflow 2.0', - category=DeprecationWarning) + warnings.warn( + 'SSHHook.create_tunnel is deprecated, Please' + 'use get_tunnel() instead. 
But please note that the' + 'order of the parameters have changed' + 'This method will be removed in Airflow 2.0', + category=DeprecationWarning, + ) return self.get_tunnel(remote_port, remote_host, local_port) diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py index 80b9c34f08191..1a14e44c49d31 100644 --- a/airflow/providers/ssh/operators/ssh.py +++ b/airflow/providers/ssh/operators/ssh.py @@ -58,16 +58,18 @@ class SSHOperator(BaseOperator): template_ext = ('.sh',) @apply_defaults - def __init__(self, - *, - ssh_hook=None, - ssh_conn_id=None, - remote_host=None, - command=None, - timeout=10, - environment=None, - get_pty=False, - **kwargs): + def __init__( + self, + *, + ssh_hook=None, + ssh_conn_id=None, + remote_host=None, + command=None, + timeout=10, + environment=None, + get_pty=False, + **kwargs, + ): super().__init__(**kwargs) self.ssh_hook = ssh_hook self.ssh_conn_id = ssh_conn_id @@ -83,18 +85,20 @@ def execute(self, context): if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") else: - self.log.info("ssh_hook is not provided or invalid. " - "Trying ssh_conn_id to create SSHHook.") - self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, - timeout=self.timeout) + self.log.info( + "ssh_hook is not provided or invalid. " "Trying ssh_conn_id to create SSHHook." + ) + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, timeout=self.timeout) if not self.ssh_hook: raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") if self.remote_host is not None: - self.log.info("remote_host is provided explicitly. " - "It will replace the remote_host which was defined " - "in ssh_hook or predefined in connection of ssh_conn_id.") + self.log.info( + "remote_host is provided explicitly. " + "It will replace the remote_host which was defined " + "in ssh_hook or predefined in connection of ssh_conn_id." 
+ ) self.ssh_hook.remote_host = self.remote_host if not self.command: @@ -104,11 +108,12 @@ def execute(self, context): self.log.info("Running command: %s", self.command) # set timeout taken as params - stdin, stdout, stderr = ssh_client.exec_command(command=self.command, - get_pty=self.get_pty, - timeout=self.timeout, - environment=self.environment - ) + stdin, stdout, stderr = ssh_client.exec_command( + command=self.command, + get_pty=self.get_pty, + timeout=self.timeout, + environment=self.environment, + ) # get channels channel = stdout.channel @@ -126,9 +131,7 @@ def execute(self, context): agg_stdout += stdout.channel.recv(stdout_buffer_length) # read from both stdout and stderr - while not channel.closed or \ - channel.recv_ready() or \ - channel.recv_stderr_ready(): + while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready(): readq, _, _ = select([channel], [], [], self.timeout) for recv in readq: if recv.recv_ready(): @@ -139,9 +142,11 @@ def execute(self, context): line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer)) agg_stderr += line self.log.warning(line.decode('utf-8').strip('\n')) - if stdout.channel.exit_status_ready()\ - and not stderr.channel.recv_stderr_ready()\ - and not stdout.channel.recv_ready(): + if ( + stdout.channel.exit_status_ready() + and not stderr.channel.recv_stderr_ready() + and not stdout.channel.recv_ready() + ): stdout.channel.shutdown_read() stdout.channel.close() break @@ -151,9 +156,7 @@ def execute(self, context): exit_status = stdout.channel.recv_exit_status() if exit_status == 0: - enable_pickling = conf.getboolean( - 'core', 'enable_xcom_pickling' - ) + enable_pickling = conf.getboolean('core', 'enable_xcom_pickling') if enable_pickling: return agg_stdout else: @@ -161,8 +164,9 @@ def execute(self, context): else: error_msg = agg_stderr.decode('utf-8') - raise AirflowException("error running cmd: {0}, error: {1}" - .format(self.command, error_msg)) + raise AirflowException( + "error running cmd: {0}, error: {1}".format(self.command, error_msg) + ) except Exception as e: raise AirflowException("SSH operator error: {0}".format(str(e))) diff --git a/airflow/providers/vertica/hooks/vertica.py b/airflow/providers/vertica/hooks/vertica.py index 96cce40b1ff71..deff220a182fe 100644 --- a/airflow/providers/vertica/hooks/vertica.py +++ b/airflow/providers/vertica/hooks/vertica.py @@ -40,7 +40,7 @@ def get_conn(self) -> connect: "user": conn.login, "password": conn.password or '', "database": conn.schema, - "host": conn.host or 'localhost' + "host": conn.host or 'localhost', } if not conn.port: diff --git a/airflow/providers/vertica/operators/vertica.py b/airflow/providers/vertica/operators/vertica.py index 923b1f687a3d2..e374365d277bc 100644 --- a/airflow/providers/vertica/operators/vertica.py +++ b/airflow/providers/vertica/operators/vertica.py @@ -39,9 +39,9 @@ class VerticaOperator(BaseOperator): ui_color = '#b4e0ff' @apply_defaults - def __init__(self, *, sql: Union[str, List[str]], - vertica_conn_id: str = 'vertica_default', - **kwargs: Any) -> None: + def __init__( + self, *, sql: Union[str, List[str]], vertica_conn_id: str = 'vertica_default', **kwargs: Any + ) -> None: super().__init__(**kwargs) self.vertica_conn_id = vertica_conn_id self.sql = sql diff --git a/airflow/providers/yandex/example_dags/example_yandexcloud_dataproc.py b/airflow/providers/yandex/example_dags/example_yandexcloud_dataproc.py index df4fa319d9955..ed4b6bb94072e 100644 --- a/airflow/providers/yandex/example_dags/example_yandexcloud_dataproc.py 
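The SSHOperator and Vertica hunks above are likewise formatting-only, so the operators are still called the same way. A minimal SSHOperator usage sketch, assuming an ssh_default connection id and a throwaway command that are purely illustrative and not part of this diff:

from airflow import DAG
from airflow.providers.ssh.operators.ssh import SSHOperator
from airflow.utils.dates import days_ago

with DAG("ssh_demo", start_date=days_ago(1), schedule_interval=None) as dag:
    run_remote = SSHOperator(
        task_id="run_remote_command",
        ssh_conn_id="ssh_default",  # illustrative connection id (assumption)
        command="echo 'hello from the remote host'",
        timeout=10,  # same value as the operator's default shown in the hunk above
        get_pty=False,
    )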
+++ b/airflow/providers/yandex/example_dags/example_yandexcloud_dataproc.py @@ -17,8 +17,12 @@ from airflow import DAG from airflow.providers.yandex.operators.yandexcloud_dataproc import ( - DataprocCreateClusterOperator, DataprocCreateHiveJobOperator, DataprocCreateMapReduceJobOperator, - DataprocCreatePysparkJobOperator, DataprocCreateSparkJobOperator, DataprocDeleteClusterOperator, + DataprocCreateClusterOperator, + DataprocCreateHiveJobOperator, + DataprocCreateMapReduceJobOperator, + DataprocCreatePysparkJobOperator, + DataprocCreateSparkJobOperator, + DataprocDeleteClusterOperator, ) from airflow.utils.dates import days_ago @@ -53,10 +57,7 @@ s3_bucket=S3_BUCKET_NAME_FOR_JOB_LOGS, ) - create_hive_query = DataprocCreateHiveJobOperator( - task_id='create_hive_query', - query='SELECT 1;', - ) + create_hive_query = DataprocCreateHiveJobOperator(task_id='create_hive_query', query='SELECT 1;',) create_hive_query_from_file = DataprocCreateHiveJobOperator( task_id='create_hive_query_from_file', @@ -64,7 +65,7 @@ script_variables={ 'CITIES_URI': 's3a://data-proc-public/jobs/sources/hive-001/cities/', 'COUNTRY_CODE': 'RU', - } + }, ) create_mapreduce_job = DataprocCreateMapReduceJobOperator( @@ -72,14 +73,19 @@ main_class='org.apache.hadoop.streaming.HadoopStreaming', file_uris=[ 's3a://data-proc-public/jobs/sources/mapreduce-001/mapper.py', - 's3a://data-proc-public/jobs/sources/mapreduce-001/reducer.py' + 's3a://data-proc-public/jobs/sources/mapreduce-001/reducer.py', ], args=[ - '-mapper', 'mapper.py', - '-reducer', 'reducer.py', - '-numReduceTasks', '1', - '-input', 's3a://data-proc-public/jobs/sources/data/cities500.txt.bz2', - '-output', 's3a://{bucket}/dataproc/job/results'.format(bucket=S3_BUCKET_NAME_FOR_JOB_LOGS) + '-mapper', + 'mapper.py', + '-reducer', + 'reducer.py', + '-numReduceTasks', + '1', + '-input', + 's3a://data-proc-public/jobs/sources/data/cities500.txt.bz2', + '-output', + 's3a://{bucket}/dataproc/job/results'.format(bucket=S3_BUCKET_NAME_FOR_JOB_LOGS), ], properties={ 'yarn.app.mapreduce.am.resource.mb': '2048', @@ -92,39 +98,27 @@ task_id='create_spark_job', main_jar_file_uri='s3a://data-proc-public/jobs/sources/java/dataproc-examples-1.0.jar', main_class='ru.yandex.cloud.dataproc.examples.PopulationSparkJob', - file_uris=[ - 's3a://data-proc-public/jobs/sources/data/config.json', - ], - archive_uris=[ - 's3a://data-proc-public/jobs/sources/data/country-codes.csv.zip', - ], + file_uris=['s3a://data-proc-public/jobs/sources/data/config.json',], + archive_uris=['s3a://data-proc-public/jobs/sources/data/country-codes.csv.zip',], jar_file_uris=[ 's3a://data-proc-public/jobs/sources/java/icu4j-61.1.jar', 's3a://data-proc-public/jobs/sources/java/commons-lang-2.6.jar', 's3a://data-proc-public/jobs/sources/java/opencsv-4.1.jar', - 's3a://data-proc-public/jobs/sources/java/json-20190722.jar' + 's3a://data-proc-public/jobs/sources/java/json-20190722.jar', ], args=[ 's3a://data-proc-public/jobs/sources/data/cities500.txt.bz2', 's3a://{bucket}/dataproc/job/results/${{JOB_ID}}'.format(bucket=S3_BUCKET_NAME_FOR_JOB_LOGS), ], - properties={ - 'spark.submit.deployMode': 'cluster', - }, + properties={'spark.submit.deployMode': 'cluster',}, ) create_pyspark_job = DataprocCreatePysparkJobOperator( task_id='create_pyspark_job', main_python_file_uri='s3a://data-proc-public/jobs/sources/pyspark-001/main.py', - python_file_uris=[ - 's3a://data-proc-public/jobs/sources/pyspark-001/geonames.py', - ], - file_uris=[ - 's3a://data-proc-public/jobs/sources/data/config.json', - ], - 
archive_uris=[ - 's3a://data-proc-public/jobs/sources/data/country-codes.csv.zip', - ], + python_file_uris=['s3a://data-proc-public/jobs/sources/pyspark-001/geonames.py',], + file_uris=['s3a://data-proc-public/jobs/sources/data/config.json',], + archive_uris=['s3a://data-proc-public/jobs/sources/data/country-codes.csv.zip',], args=[ 's3a://data-proc-public/jobs/sources/data/cities500.txt.bz2', 's3a://{bucket}/jobs/results/${{JOB_ID}}'.format(bucket=S3_BUCKET_NAME_FOR_JOB_LOGS), @@ -134,14 +128,10 @@ 's3a://data-proc-public/jobs/sources/java/icu4j-61.1.jar', 's3a://data-proc-public/jobs/sources/java/commons-lang-2.6.jar', ], - properties={ - 'spark.submit.deployMode': 'cluster', - }, + properties={'spark.submit.deployMode': 'cluster',}, ) - delete_cluster = DataprocDeleteClusterOperator( - task_id='delete_cluster', - ) + delete_cluster = DataprocDeleteClusterOperator(task_id='delete_cluster',) - create_cluster >> create_mapreduce_job >> create_hive_query >> create_hive_query_from_file \ - >> create_spark_job >> create_pyspark_job >> delete_cluster + create_cluster >> create_mapreduce_job >> create_hive_query >> create_hive_query_from_file + create_hive_query_from_file >> create_spark_job >> create_pyspark_job >> delete_cluster diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py index 815237323f11d..01a97e9a0adb9 100644 --- a/airflow/providers/yandex/hooks/yandex.py +++ b/airflow/providers/yandex/hooks/yandex.py @@ -31,11 +31,7 @@ class YandexCloudBaseHook(BaseHook): :type connection_id: str """ - def __init__(self, - connection_id=None, - default_folder_id=None, - default_public_ssh_key=None - ): + def __init__(self, connection_id=None, default_folder_id=None, default_public_ssh_key=None): super().__init__() self.connection_id = connection_id or 'yandexcloud_default' self.connection = self.get_connection(self.connection_id) @@ -52,8 +48,8 @@ def _get_credentials(self): oauth_token = self._get_field('oauth', False) if not (service_account_json or oauth_token or service_account_json_path): raise AirflowException( - 'No credentials are found in connection. Specify either service account ' + - 'authentication JSON or user OAuth token in Yandex.Cloud connection' + 'No credentials are found in connection. 
Specify either service account ' + + 'authentication JSON or user OAuth token in Yandex.Cloud connection' ) if service_account_json_path: with open(service_account_json_path) as infile: diff --git a/airflow/providers/yandex/hooks/yandexcloud_dataproc.py b/airflow/providers/yandex/hooks/yandexcloud_dataproc.py index 680161d581ec9..170f780d7e408 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_dataproc.py +++ b/airflow/providers/yandex/hooks/yandexcloud_dataproc.py @@ -31,6 +31,5 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.cluster_id = None self.client = self.sdk.wrappers.Dataproc( - default_folder_id=self.default_folder_id, - default_public_ssh_key=self.default_public_ssh_key, + default_folder_id=self.default_folder_id, default_public_ssh_key=self.default_public_ssh_key, ) diff --git a/airflow/providers/yandex/operators/yandexcloud_dataproc.py b/airflow/providers/yandex/operators/yandexcloud_dataproc.py index 9d1b5b0b66367..cd7803a19946d 100644 --- a/airflow/providers/yandex/operators/yandexcloud_dataproc.py +++ b/airflow/providers/yandex/operators/yandexcloud_dataproc.py @@ -82,31 +82,33 @@ class DataprocCreateClusterOperator(BaseOperator): # pylint: disable=too-many-arguments # pylint: disable=too-many-locals @apply_defaults - def __init__(self, - *, - folder_id: Optional[str] = None, - cluster_name: Optional[str] = None, - cluster_description: str = '', - cluster_image_version: str = '1.1', - ssh_public_keys: Optional[Union[str, Iterable[str]]] = None, - subnet_id: Optional[str] = None, - services: Iterable[str] = ('HDFS', 'YARN', 'MAPREDUCE', 'HIVE', 'SPARK'), - s3_bucket: Optional[str] = None, - zone: str = 'ru-central1-b', - service_account_id: Optional[str] = None, - masternode_resource_preset: str = 's2.small', - masternode_disk_size: int = 15, - masternode_disk_type: str = 'network-ssd', - datanode_resource_preset: str = 's2.small', - datanode_disk_size: int = 15, - datanode_disk_type: str = 'network-ssd', - datanode_count: int = 2, - computenode_resource_preset: str = 's2.small', - computenode_disk_size: int = 15, - computenode_disk_type: str = 'network-ssd', - computenode_count: int = 0, - connection_id: Optional[str] = None, - **kwargs): + def __init__( + self, + *, + folder_id: Optional[str] = None, + cluster_name: Optional[str] = None, + cluster_description: str = '', + cluster_image_version: str = '1.1', + ssh_public_keys: Optional[Union[str, Iterable[str]]] = None, + subnet_id: Optional[str] = None, + services: Iterable[str] = ('HDFS', 'YARN', 'MAPREDUCE', 'HIVE', 'SPARK'), + s3_bucket: Optional[str] = None, + zone: str = 'ru-central1-b', + service_account_id: Optional[str] = None, + masternode_resource_preset: str = 's2.small', + masternode_disk_size: int = 15, + masternode_disk_type: str = 'network-ssd', + datanode_resource_preset: str = 's2.small', + datanode_disk_size: int = 15, + datanode_disk_type: str = 'network-ssd', + datanode_count: int = 2, + computenode_resource_preset: str = 's2.small', + computenode_disk_size: int = 15, + computenode_disk_type: str = 'network-ssd', + computenode_count: int = 0, + connection_id: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) self.folder_id = folder_id self.connection_id = connection_id @@ -133,9 +135,7 @@ def __init__(self, self.hook = None def execute(self, context): - self.hook = DataprocHook( - connection_id=self.connection_id, - ) + self.hook = DataprocHook(connection_id=self.connection_id,) operation_result = self.hook.client.create_cluster( 
folder_id=self.folder_id, cluster_name=self.cluster_name, @@ -175,10 +175,7 @@ class DataprocDeleteClusterOperator(BaseOperator): template_fields = ['cluster_id'] @apply_defaults - def __init__(self, *, - connection_id: Optional[str] = None, - cluster_id: Optional[str] = None, - **kwargs): + def __init__(self, *, connection_id: Optional[str] = None, cluster_id: Optional[str] = None, **kwargs): super().__init__(**kwargs) self.connection_id = connection_id self.cluster_id = cluster_id @@ -189,9 +186,7 @@ def execute(self, context): connection_id = self.connection_id or context['task_instance'].xcom_pull( key='yandexcloud_connection_id' ) - self.hook = DataprocHook( - connection_id=connection_id, - ) + self.hook = DataprocHook(connection_id=connection_id,) self.hook.client.delete_cluster(cluster_id) @@ -221,16 +216,19 @@ class DataprocCreateHiveJobOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults - def __init__(self, *, - query: Optional[str] = None, - query_file_uri: Optional[str] = None, - script_variables: Optional[Dict[str, str]] = None, - continue_on_failure: bool = False, - properties: Optional[Dict[str, str]] = None, - name: str = 'Hive job', - cluster_id: Optional[str] = None, - connection_id: Optional[str] = None, - **kwargs): + def __init__( + self, + *, + query: Optional[str] = None, + query_file_uri: Optional[str] = None, + script_variables: Optional[Dict[str, str]] = None, + continue_on_failure: bool = False, + properties: Optional[Dict[str, str]] = None, + name: str = 'Hive job', + cluster_id: Optional[str] = None, + connection_id: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) self.query = query self.query_file_uri = query_file_uri @@ -247,9 +245,7 @@ def execute(self, context): connection_id = self.connection_id or context['task_instance'].xcom_pull( key='yandexcloud_connection_id' ) - self.hook = DataprocHook( - connection_id=connection_id, - ) + self.hook = DataprocHook(connection_id=connection_id,) self.hook.client.create_hive_job( query=self.query, query_file_uri=self.query_file_uri, @@ -292,18 +288,21 @@ class DataprocCreateMapReduceJobOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults - def __init__(self, *, - main_class: Optional[str] = None, - main_jar_file_uri: Optional[str] = None, - jar_file_uris: Optional[Iterable[str]] = None, - archive_uris: Optional[Iterable[str]] = None, - file_uris: Optional[Iterable[str]] = None, - args: Optional[Iterable[str]] = None, - properties: Optional[Dict[str, str]] = None, - name: str = 'Mapreduce job', - cluster_id: Optional[str] = None, - connection_id: Optional[str] = None, - **kwargs): + def __init__( + self, + *, + main_class: Optional[str] = None, + main_jar_file_uri: Optional[str] = None, + jar_file_uris: Optional[Iterable[str]] = None, + archive_uris: Optional[Iterable[str]] = None, + file_uris: Optional[Iterable[str]] = None, + args: Optional[Iterable[str]] = None, + properties: Optional[Dict[str, str]] = None, + name: str = 'Mapreduce job', + cluster_id: Optional[str] = None, + connection_id: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) self.main_class = main_class self.main_jar_file_uri = main_jar_file_uri @@ -322,9 +321,7 @@ def execute(self, context): connection_id = self.connection_id or context['task_instance'].xcom_pull( key='yandexcloud_connection_id' ) - self.hook = DataprocHook( - connection_id=connection_id, - ) + self.hook = DataprocHook(connection_id=connection_id,) self.hook.client.create_mapreduce_job( 
main_class=self.main_class, main_jar_file_uri=self.main_jar_file_uri, @@ -368,18 +365,21 @@ class DataprocCreateSparkJobOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults - def __init__(self, *, - main_class: Optional[str] = None, - main_jar_file_uri: Optional[str] = None, - jar_file_uris: Optional[Iterable[str]] = None, - archive_uris: Optional[Iterable[str]] = None, - file_uris: Optional[Iterable[str]] = None, - args: Optional[Iterable[str]] = None, - properties: Optional[Dict[str, str]] = None, - name: str = 'Spark job', - cluster_id: Optional[str] = None, - connection_id: Optional[str] = None, - **kwargs): + def __init__( + self, + *, + main_class: Optional[str] = None, + main_jar_file_uri: Optional[str] = None, + jar_file_uris: Optional[Iterable[str]] = None, + archive_uris: Optional[Iterable[str]] = None, + file_uris: Optional[Iterable[str]] = None, + args: Optional[Iterable[str]] = None, + properties: Optional[Dict[str, str]] = None, + name: str = 'Spark job', + cluster_id: Optional[str] = None, + connection_id: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) self.main_class = main_class self.main_jar_file_uri = main_jar_file_uri @@ -398,9 +398,7 @@ def execute(self, context): connection_id = self.connection_id or context['task_instance'].xcom_pull( key='yandexcloud_connection_id' ) - self.hook = DataprocHook( - connection_id=connection_id, - ) + self.hook = DataprocHook(connection_id=connection_id,) self.hook.client.create_spark_job( main_class=self.main_class, main_jar_file_uri=self.main_jar_file_uri, @@ -444,18 +442,21 @@ class DataprocCreatePysparkJobOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults - def __init__(self, *, - main_python_file_uri: Optional[str] = None, - python_file_uris: Optional[Iterable[str]] = None, - jar_file_uris: Optional[Iterable[str]] = None, - archive_uris: Optional[Iterable[str]] = None, - file_uris: Optional[Iterable[str]] = None, - args: Optional[Iterable[str]] = None, - properties: Optional[Dict[str, str]] = None, - name: str = 'Pyspark job', - cluster_id: Optional[str] = None, - connection_id: Optional[str] = None, - **kwargs): + def __init__( + self, + *, + main_python_file_uri: Optional[str] = None, + python_file_uris: Optional[Iterable[str]] = None, + jar_file_uris: Optional[Iterable[str]] = None, + archive_uris: Optional[Iterable[str]] = None, + file_uris: Optional[Iterable[str]] = None, + args: Optional[Iterable[str]] = None, + properties: Optional[Dict[str, str]] = None, + name: str = 'Pyspark job', + cluster_id: Optional[str] = None, + connection_id: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) self.main_python_file_uri = main_python_file_uri self.python_file_uris = python_file_uris @@ -474,9 +475,7 @@ def execute(self, context): connection_id = self.connection_id or context['task_instance'].xcom_pull( key='yandexcloud_connection_id' ) - self.hook = DataprocHook( - connection_id=connection_id, - ) + self.hook = DataprocHook(connection_id=connection_id,) self.hook.client.create_pyspark_job( main_python_file_uri=self.main_python_file_uri, python_file_uris=self.python_file_uris, diff --git a/airflow/providers/zendesk/hooks/zendesk.py b/airflow/providers/zendesk/hooks/zendesk.py index 1134bfa5d8aae..c045e0a297bfc 100644 --- a/airflow/providers/zendesk/hooks/zendesk.py +++ b/airflow/providers/zendesk/hooks/zendesk.py @@ -27,6 +27,7 @@ class ZendeskHook(BaseHook): """ A hook to talk to Zendesk """ + def __init__(self, zendesk_conn_id): super().__init__() 
self.__zendesk_conn_id = zendesk_conn_id @@ -35,20 +36,17 @@ def __init__(self, zendesk_conn_id): def get_conn(self): conn = self.get_connection(self.__zendesk_conn_id) self.__url = "https://" + conn.host - return Zendesk(zdesk_url=self.__url, zdesk_email=conn.login, zdesk_password=conn.password, - zdesk_token=True) + return Zendesk( + zdesk_url=self.__url, zdesk_email=conn.login, zdesk_password=conn.password, zdesk_token=True + ) def __handle_rate_limit_exception(self, rate_limit_exception): """ Sleep for the time specified in the exception. If not specified, wait for 60 seconds. """ - retry_after = int( - rate_limit_exception.response.headers.get('Retry-After', 60)) - self.log.info( - "Hit Zendesk API rate limit. Pausing for %s seconds", - retry_after - ) + retry_after = int(rate_limit_exception.response.headers.get('Retry-After', 60)) + self.log.info("Hit Zendesk API rate limit. Pausing for %s seconds", retry_after) time.sleep(retry_after) def call(self, path, query=None, get_all_pages=True, side_loading=False): diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py index 58f3ae78145cd..a9bd757464d6d 100644 --- a/tests/providers/amazon/aws/hooks/test_athena.py +++ b/tests/providers/amazon/aws/hooks/test_athena.py @@ -28,15 +28,11 @@ 'workgroup': 'primary', 'query_execution_id': 'eac427d0-1c6d-4dfb-96aa-2835d3ac6595', 'next_token_id': 'eac427d0-1c6d-4dfb-96aa-2835d3ac6595', - 'max_items': 1000 + 'max_items': 1000, } -mock_query_context = { - 'Database': MOCK_DATA['database'] -} -mock_result_configuration = { - 'OutputLocation': MOCK_DATA['outputLocation'] -} +mock_query_context = {'Database': MOCK_DATA['database']} +mock_result_configuration = {'OutputLocation': MOCK_DATA['outputLocation']} MOCK_RUNNING_QUERY_EXECUTION = {'QueryExecution': {'Status': {'State': 'RUNNING'}}} MOCK_SUCCEEDED_QUERY_EXECUTION = {'QueryExecution': {'Status': {'State': 'SUCCEEDED'}}} @@ -45,7 +41,6 @@ class TestAWSAthenaHook(unittest.TestCase): - def setUp(self): self.athena = AWSAthenaHook(sleep_time=0) @@ -56,14 +51,16 @@ def test_init(self): @mock.patch.object(AWSAthenaHook, 'get_conn') def test_hook_run_query_without_token(self, mock_conn): mock_conn.return_value.start_query_execution.return_value = MOCK_QUERY_EXECUTION - result = self.athena.run_query(query=MOCK_DATA['query'], - query_context=mock_query_context, - result_configuration=mock_result_configuration) + result = self.athena.run_query( + query=MOCK_DATA['query'], + query_context=mock_query_context, + result_configuration=mock_result_configuration, + ) expected_call_params = { 'QueryString': MOCK_DATA['query'], 'QueryExecutionContext': mock_query_context, 'ResultConfiguration': mock_result_configuration, - 'WorkGroup': MOCK_DATA['workgroup'] + 'WorkGroup': MOCK_DATA['workgroup'], } mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params) self.assertEqual(result, MOCK_DATA['query_execution_id']) @@ -71,16 +68,18 @@ def test_hook_run_query_without_token(self, mock_conn): @mock.patch.object(AWSAthenaHook, 'get_conn') def test_hook_run_query_with_token(self, mock_conn): mock_conn.return_value.start_query_execution.return_value = MOCK_QUERY_EXECUTION - result = self.athena.run_query(query=MOCK_DATA['query'], - query_context=mock_query_context, - result_configuration=mock_result_configuration, - client_request_token=MOCK_DATA['client_request_token']) + result = self.athena.run_query( + query=MOCK_DATA['query'], + query_context=mock_query_context, + 
result_configuration=mock_result_configuration, + client_request_token=MOCK_DATA['client_request_token'], + ) expected_call_params = { 'QueryString': MOCK_DATA['query'], 'QueryExecutionContext': mock_query_context, 'ResultConfiguration': mock_result_configuration, 'ClientRequestToken': MOCK_DATA['client_request_token'], - 'WorkGroup': MOCK_DATA['workgroup'] + 'WorkGroup': MOCK_DATA['workgroup'], } mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params) self.assertEqual(result, MOCK_DATA['query_execution_id']) @@ -95,21 +94,19 @@ def test_hook_get_query_results_with_non_succeeded_query(self, mock_conn): def test_hook_get_query_results_with_default_params(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id']) - expected_call_params = { - 'QueryExecutionId': MOCK_DATA['query_execution_id'], - 'MaxResults': 1000 - } + expected_call_params = {'QueryExecutionId': MOCK_DATA['query_execution_id'], 'MaxResults': 1000} mock_conn.return_value.get_query_results.assert_called_with(**expected_call_params) @mock.patch.object(AWSAthenaHook, 'get_conn') def test_hook_get_query_results_with_next_token(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION - self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id'], - next_token_id=MOCK_DATA['next_token_id']) + self.athena.get_query_results( + query_execution_id=MOCK_DATA['query_execution_id'], next_token_id=MOCK_DATA['next_token_id'] + ) expected_call_params = { 'QueryExecutionId': MOCK_DATA['query_execution_id'], 'NextToken': MOCK_DATA['next_token_id'], - 'MaxResults': 1000 + 'MaxResults': 1000, } mock_conn.return_value.get_query_results.assert_called_with(**expected_call_params) @@ -125,30 +122,26 @@ def test_hook_get_paginator_with_default_params(self, mock_conn): self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id']) expected_call_params = { 'QueryExecutionId': MOCK_DATA['query_execution_id'], - 'PaginationConfig': { - 'MaxItems': None, - 'PageSize': None, - 'StartingToken': None - - } + 'PaginationConfig': {'MaxItems': None, 'PageSize': None, 'StartingToken': None}, } mock_conn.return_value.get_paginator.return_value.paginate.assert_called_with(**expected_call_params) @mock.patch.object(AWSAthenaHook, 'get_conn') def test_hook_get_paginator_with_pagination_config(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION - self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id'], - max_items=MOCK_DATA['max_items'], - page_size=MOCK_DATA['max_items'], - starting_token=MOCK_DATA['next_token_id']) + self.athena.get_query_results_paginator( + query_execution_id=MOCK_DATA['query_execution_id'], + max_items=MOCK_DATA['max_items'], + page_size=MOCK_DATA['max_items'], + starting_token=MOCK_DATA['next_token_id'], + ) expected_call_params = { 'QueryExecutionId': MOCK_DATA['query_execution_id'], 'PaginationConfig': { 'MaxItems': MOCK_DATA['max_items'], 'PageSize': MOCK_DATA['max_items'], - 'StartingToken': MOCK_DATA['next_token_id'] - - } + 'StartingToken': MOCK_DATA['next_token_id'], + }, } mock_conn.return_value.get_paginator.return_value.paginate.assert_called_with(**expected_call_params) @@ -162,8 +155,9 @@ def test_hook_poll_query_when_final(self, mock_conn): @mock.patch.object(AWSAthenaHook, 'get_conn') 
def test_hook_poll_query_with_timeout(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION - result = self.athena.poll_query_status(query_execution_id=MOCK_DATA['query_execution_id'], - max_tries=1) + result = self.athena.poll_query_status( + query_execution_id=MOCK_DATA['query_execution_id'], max_tries=1 + ) mock_conn.return_value.get_query_execution.assert_called_once() self.assertEqual(result, 'RUNNING') diff --git a/tests/providers/amazon/aws/hooks/test_aws_dynamodb.py b/tests/providers/amazon/aws/hooks/test_aws_dynamodb.py index cb6e202e232bd..907582e594e05 100644 --- a/tests/providers/amazon/aws/hooks/test_aws_dynamodb.py +++ b/tests/providers/amazon/aws/hooks/test_aws_dynamodb.py @@ -29,7 +29,6 @@ class TestDynamoDBHook(unittest.TestCase): - @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present') @mock_dynamodb2 def test_get_conn_returns_a_boto3_connection(self): @@ -40,37 +39,23 @@ def test_get_conn_returns_a_boto3_connection(self): @mock_dynamodb2 def test_insert_batch_items_dynamodb_table(self): - hook = AwsDynamoDBHook(aws_conn_id='aws_default', - table_name='test_airflow', table_keys=['id'], region_name='us-east-1') + hook = AwsDynamoDBHook( + aws_conn_id='aws_default', table_name='test_airflow', table_keys=['id'], region_name='us-east-1' + ) # this table needs to be created in production table = hook.get_conn().create_table( TableName='test_airflow', - KeySchema=[ - { - 'AttributeName': 'id', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'id', - 'AttributeType': 'S' - } - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 10 - } + KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'},], + AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], + ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10}, ) table = hook.get_conn().Table('test_airflow') - items = [{'id': str(uuid.uuid4()), 'name': 'airflow'} - for _ in range(10)] + items = [{'id': str(uuid.uuid4()), 'name': 'airflow'} for _ in range(10)] hook.write_batch_data(items) - table.meta.client.get_waiter( - 'table_exists').wait(TableName='test_airflow') + table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow') self.assertEqual(table.item_count, 10) diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py index f15c479be59df..8e3743776eb51 100644 --- a/tests/providers/amazon/aws/hooks/test_base_aws.py +++ b/tests/providers/amazon/aws/hooks/test_base_aws.py @@ -55,26 +55,12 @@ def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self): # this table needs to be created in production table = resource_from_hook.create_table( # pylint: disable=no-member TableName='test_airflow', - KeySchema=[ - { - 'AttributeName': 'id', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'id', - 'AttributeType': 'S' - } - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 10 - } + KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'},], + AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], + ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10}, ) - table.meta.client.get_waiter( - 'table_exists').wait(TableName='test_airflow') + table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow') self.assertEqual(table.item_count, 0) @@ -86,35 +72,22 @@ def 
test_get_session_returns_a_boto3_session(self): resource_from_session = session_from_hook.resource('dynamodb') table = resource_from_session.create_table( # pylint: disable=no-member TableName='test_airflow', - KeySchema=[ - { - 'AttributeName': 'id', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'id', - 'AttributeType': 'S' - } - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 10 - } + KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'},], + AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], + ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10}, ) - table.meta.client.get_waiter( - 'table_exists').wait(TableName='test_airflow') + table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow') self.assertEqual(table.item_count, 0) @mock.patch.object(AwsBaseHook, 'get_connection') def test_get_credentials_from_login_with_token(self, mock_get_connection): - mock_connection = Connection(login='aws_access_key_id', - password='aws_secret_access_key', - extra='{"aws_session_token": "test_token"}' - ) + mock_connection = Connection( + login='aws_access_key_id', + password='aws_secret_access_key', + extra='{"aws_session_token": "test_token"}', + ) mock_get_connection.return_value = mock_connection hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test') credentials_from_hook = hook.get_credentials() @@ -124,9 +97,7 @@ def test_get_credentials_from_login_with_token(self, mock_get_connection): @mock.patch.object(AwsBaseHook, 'get_connection') def test_get_credentials_from_login_without_token(self, mock_get_connection): - mock_connection = Connection(login='aws_access_key_id', - password='aws_secret_access_key', - ) + mock_connection = Connection(login='aws_access_key_id', password='aws_secret_access_key',) mock_get_connection.return_value = mock_connection hook = AwsBaseHook(aws_conn_id='aws_default', client_type='spam') @@ -139,8 +110,8 @@ def test_get_credentials_from_login_without_token(self, mock_get_connection): def test_get_credentials_from_extra_with_token(self, mock_get_connection): mock_connection = Connection( extra='{"aws_access_key_id": "aws_access_key_id",' - '"aws_secret_access_key": "aws_secret_access_key",' - ' "aws_session_token": "session_token"}' + '"aws_secret_access_key": "aws_secret_access_key",' + ' "aws_session_token": "session_token"}' ) mock_get_connection.return_value = mock_connection hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test') @@ -153,7 +124,7 @@ def test_get_credentials_from_extra_with_token(self, mock_get_connection): def test_get_credentials_from_extra_without_token(self, mock_get_connection): mock_connection = Connection( extra='{"aws_access_key_id": "aws_access_key_id",' - '"aws_secret_access_key": "aws_secret_access_key"}' + '"aws_secret_access_key": "aws_secret_access_key"}' ) mock_get_connection.return_value = mock_connection hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test') @@ -162,32 +133,30 @@ def test_get_credentials_from_extra_without_token(self, mock_get_connection): self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key') self.assertIsNone(credentials_from_hook.token) - @mock.patch('airflow.providers.amazon.aws.hooks.base_aws._parse_s3_config', - return_value=('aws_access_key_id', 'aws_secret_access_key')) + @mock.patch( + 'airflow.providers.amazon.aws.hooks.base_aws._parse_s3_config', + return_value=('aws_access_key_id', 'aws_secret_access_key'), + ) 
@mock.patch.object(AwsBaseHook, 'get_connection') def test_get_credentials_from_extra_with_s3_config_and_profile( self, mock_get_connection, mock_parse_s3_config ): mock_connection = Connection( extra='{"s3_config_format": "aws", ' - '"profile": "test", ' - '"s3_config_file": "aws-credentials", ' - '"region_name": "us-east-1"}') + '"profile": "test", ' + '"s3_config_file": "aws-credentials", ' + '"region_name": "us-east-1"}' + ) mock_get_connection.return_value = mock_connection hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test') hook._get_credentials(region_name=None) - mock_parse_s3_config.assert_called_once_with( - 'aws-credentials', - 'aws', - 'test' - ) + mock_parse_s3_config.assert_called_once_with('aws-credentials', 'aws', 'test') @unittest.skipIf(mock_sts is None, 'mock_sts package not present') @mock.patch.object(AwsBaseHook, 'get_connection') @mock_sts def test_get_credentials_from_role_arn(self, mock_get_connection): - mock_connection = Connection( - extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}') + mock_connection = Connection(extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}') mock_get_connection.return_value = mock_connection hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test') credentials_from_hook = hook.get_credentials() diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py b/tests/providers/amazon/aws/hooks/test_batch_client.py index 0eb8242e4dcfa..272b645939219 100644 --- a/tests/providers/amazon/aws/hooks/test_batch_client.py +++ b/tests/providers/amazon/aws/hooks/test_batch_client.py @@ -68,27 +68,19 @@ def test_init(self): self.assertEqual(self.batch_client.aws_conn_id, 'airflow_test') self.assertEqual(self.batch_client.client, self.client_mock) - self.get_client_type_mock.assert_called_once_with( - "batch", region_name=AWS_REGION - ) + self.get_client_type_mock.assert_called_once_with("batch", region_name=AWS_REGION) def test_wait_for_job_with_success(self): - self.client_mock.describe_jobs.return_value = { - "jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}] - } + self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]} with mock.patch.object( - self.batch_client, - "poll_for_job_running", - wraps=self.batch_client.poll_for_job_running, + self.batch_client, "poll_for_job_running", wraps=self.batch_client.poll_for_job_running, ) as job_running: self.batch_client.wait_for_job(JOB_ID) job_running.assert_called_once_with(JOB_ID, None) with mock.patch.object( - self.batch_client, - "poll_for_job_complete", - wraps=self.batch_client.poll_for_job_complete, + self.batch_client, "poll_for_job_complete", wraps=self.batch_client.poll_for_job_complete, ) as job_complete: self.batch_client.wait_for_job(JOB_ID) job_complete.assert_called_once_with(JOB_ID, None) @@ -96,22 +88,16 @@ def test_wait_for_job_with_success(self): self.assertEqual(self.client_mock.describe_jobs.call_count, 4) def test_wait_for_job_with_failure(self): - self.client_mock.describe_jobs.return_value = { - "jobs": [{"jobId": JOB_ID, "status": "FAILED"}] - } + self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "FAILED"}]} with mock.patch.object( - self.batch_client, - "poll_for_job_running", - wraps=self.batch_client.poll_for_job_running, + self.batch_client, "poll_for_job_running", wraps=self.batch_client.poll_for_job_running, ) as job_running: self.batch_client.wait_for_job(JOB_ID) job_running.assert_called_once_with(JOB_ID, None) with mock.patch.object( - 
self.batch_client, - "poll_for_job_complete", - wraps=self.batch_client.poll_for_job_complete, + self.batch_client, "poll_for_job_complete", wraps=self.batch_client.poll_for_job_complete, ) as job_complete: self.batch_client.wait_for_job(JOB_ID) job_complete.assert_called_once_with(JOB_ID, None) @@ -119,23 +105,17 @@ def test_wait_for_job_with_failure(self): self.assertEqual(self.client_mock.describe_jobs.call_count, 4) def test_poll_job_running_for_status_running(self): - self.client_mock.describe_jobs.return_value = { - "jobs": [{"jobId": JOB_ID, "status": "RUNNING"}] - } + self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "RUNNING"}]} self.batch_client.poll_for_job_running(JOB_ID) self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID]) def test_poll_job_complete_for_status_success(self): - self.client_mock.describe_jobs.return_value = { - "jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}] - } + self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]} self.batch_client.poll_for_job_complete(JOB_ID) self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID]) def test_poll_job_complete_raises_for_max_retries(self): - self.client_mock.describe_jobs.return_value = { - "jobs": [{"jobId": JOB_ID, "status": "RUNNING"}] - } + self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "RUNNING"}]} with self.assertRaises(AirflowException) as e: self.batch_client.poll_for_job_complete(JOB_ID) msg = "AWS Batch job ({}) status checks exceed max_retries".format(JOB_ID) @@ -158,8 +138,7 @@ def test_poll_job_status_hit_api_throttle(self): def test_poll_job_status_with_client_error(self): self.client_mock.describe_jobs.side_effect = botocore.exceptions.ClientError( - error_response={"Error": {"Code": "InvalidClientTokenId"}}, - operation_name="get job description", + error_response={"Error": {"Code": "InvalidClientTokenId"}}, operation_name="get job description", ) with self.assertRaises(AirflowException) as e: self.batch_client.poll_for_job_complete(JOB_ID) @@ -169,9 +148,7 @@ def test_poll_job_status_with_client_error(self): self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID]) def test_check_job_success(self): - self.client_mock.describe_jobs.return_value = { - "jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}] - } + self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]} status = self.batch_client.check_job_success(JOB_ID) self.assertTrue(status) self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID]) @@ -211,9 +188,7 @@ def test_check_job_success_raises_failed_for_multiple_attempts(self): self.assertIn(msg, str(e.exception)) def test_check_job_success_raises_incomplete(self): - self.client_mock.describe_jobs.return_value = { - "jobs": [{"jobId": JOB_ID, "status": "RUNNABLE"}] - } + self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "RUNNABLE"}]} with self.assertRaises(AirflowException) as e: self.batch_client.check_job_success(JOB_ID) self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID]) @@ -222,9 +197,7 @@ def test_check_job_success_raises_incomplete(self): def test_check_job_success_raises_unknown_status(self): status = "STRANGE" - self.client_mock.describe_jobs.return_value = { - "jobs": [{"jobId": JOB_ID, "status": status}] - } + self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": status}]} with 
self.assertRaises(AirflowException) as e: self.batch_client.check_job_success(JOB_ID) self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID]) @@ -253,9 +226,7 @@ class TestAwsBatchClientDelays(unittest.TestCase): @mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID) @mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY) def setUp(self): - self.batch_client = AwsBatchClientHook( - aws_conn_id='airflow_test', - region_name=AWS_REGION) + self.batch_client = AwsBatchClientHook(aws_conn_id='airflow_test', region_name=AWS_REGION) def test_init(self): self.assertEqual(self.batch_client.max_retries, self.batch_client.MAX_RETRIES) diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py b/tests/providers/amazon/aws/hooks/test_batch_waiters.py index 951792d74d830..19cfcb41f041f 100644 --- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py +++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py @@ -113,9 +113,7 @@ def logs_client(aws_region): @pytest.fixture(scope="module") def aws_clients(batch_client, ec2_client, ecs_client, iam_client, logs_client): - return AwsClients( - batch=batch_client, ec2=ec2_client, ecs=ecs_client, iam=iam_client, log=logs_client - ) + return AwsClients(batch=batch_client, ec2=ec2_client, ecs=ecs_client, iam=iam_client, log=logs_client) # @@ -164,17 +162,12 @@ def batch_infrastructure( ) sg_id = resp["GroupId"] - resp = aws_clients.iam.create_role( - RoleName="MotoTestRole", AssumeRolePolicyDocument="moto_test_policy" - ) + resp = aws_clients.iam.create_role(RoleName="MotoTestRole", AssumeRolePolicyDocument="moto_test_policy") iam_arn = resp["Role"]["Arn"] compute_env_name = "moto_test_compute_env" resp = aws_clients.batch.create_compute_environment( - computeEnvironmentName=compute_env_name, - type="UNMANAGED", - state="ENABLED", - serviceRole=iam_arn, + computeEnvironmentName=compute_env_name, type="UNMANAGED", state="ENABLED", serviceRole=iam_arn, ) compute_env_arn = resp["computeEnvironmentArn"] @@ -191,20 +184,13 @@ def batch_infrastructure( resp = aws_clients.batch.register_job_definition( jobDefinitionName=job_definition_name, type="container", - containerProperties={ - "image": "busybox", - "vcpus": 1, - "memory": 64, - "command": ["sleep", "10"], - }, + containerProperties={"image": "busybox", "vcpus": 1, "memory": 64, "command": ["sleep", "10"],}, ) assert resp["jobDefinitionName"] == job_definition_name assert resp["jobDefinitionArn"] job_definition_arn = resp["jobDefinitionArn"] assert resp["revision"] - assert resp["jobDefinitionArn"].endswith( - "{0}:{1}".format(resp["jobDefinitionName"], resp["revision"]) - ) + assert resp["jobDefinitionArn"].endswith("{0}:{1}".format(resp["jobDefinitionName"], resp["revision"])) infrastructure.vpc_id = vpc_id infrastructure.subnet_id = subnet_id @@ -251,9 +237,7 @@ def test_aws_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_defi - https://github.com/spulec/moto/blob/master/tests/test_batch/test_batch.py """ - aws_resources = batch_infrastructure( - aws_clients, aws_region, job_queue_name, job_definition_name - ) + aws_resources = batch_infrastructure(aws_clients, aws_region, job_queue_name, job_definition_name) batch_waiters = AwsBatchWaitersHook(region_name=aws_resources.aws_region) job_exists_waiter = batch_waiters.get_waiter("JobExists") @@ -345,9 +329,7 @@ def setUp(self, get_client_type_mock): # init the mock client self.client_mock = self.batch_waiters.client - get_client_type_mock.assert_called_once_with( - "batch", 
region_name=self.region_name - ) + get_client_type_mock.assert_called_once_with("batch", region_name=self.region_name) # don't pause in these unit tests self.mock_delay = mock.Mock(return_value=None) diff --git a/tests/providers/amazon/aws/hooks/test_cloud_formation.py b/tests/providers/amazon/aws/hooks/test_cloud_formation.py index 186f4bc837aea..7906412b9fa10 100644 --- a/tests/providers/amazon/aws/hooks/test_cloud_formation.py +++ b/tests/providers/amazon/aws/hooks/test_cloud_formation.py @@ -29,30 +29,22 @@ @unittest.skipIf(mock_cloudformation is None, 'moto package not present') class TestAWSCloudFormationHook(unittest.TestCase): - def setUp(self): self.hook = AWSCloudFormationHook(aws_conn_id='aws_default') def create_stack(self, stack_name): timeout = 15 - template_body = json.dumps({ - 'Resources': { - "myResource": { - "Type": "emr", - "Properties": { - "myProperty": "myPropertyValue" - } - } - } - }) + template_body = json.dumps( + {'Resources': {"myResource": {"Type": "emr", "Properties": {"myProperty": "myPropertyValue"}}}} + ) self.hook.create_stack( stack_name=stack_name, params={ 'TimeoutInMinutes': timeout, 'TemplateBody': template_body, - 'Parameters': [{'ParameterKey': 'myParam', 'ParameterValue': 'myParamValue'}] - } + 'Parameters': [{'ParameterKey': 'myParam', 'ParameterValue': 'myParamValue'}], + }, ) @mock_cloudformation @@ -82,11 +74,7 @@ def test_create_stack(self): self.assertEqual(len(matching_stacks), 1, f'stack with name {stack_name} should exist') stack = matching_stacks[0] - self.assertEqual( - stack['StackStatus'], - 'CREATE_COMPLETE', - 'Stack should be in status CREATE_COMPLETE' - ) + self.assertEqual(stack['StackStatus'], 'CREATE_COMPLETE', 'Stack should be in status CREATE_COMPLETE') @mock_cloudformation def test_delete_stack(self): diff --git a/tests/providers/amazon/aws/hooks/test_datasync.py b/tests/providers/amazon/aws/hooks/test_datasync.py index 5af89328d8b14..22a7f42d1f4ba 100644 --- a/tests/providers/amazon/aws/hooks/test_datasync.py +++ b/tests/providers/amazon/aws/hooks/test_datasync.py @@ -43,9 +43,7 @@ def no_datasync(x): @mock_datasync -@unittest.skipIf( - mock_datasync == no_datasync, "moto datasync package missing" -) # pylint: disable=W0143 +@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143 class TestAwsDataSyncHook(unittest.TestCase): def test_get_conn(self): hook = AWSDataSyncHook(aws_conn_id="aws_default") @@ -67,9 +65,7 @@ def test_get_conn(self): @mock_datasync @mock.patch.object(AWSDataSyncHook, "get_conn") -@unittest.skipIf( - mock_datasync == no_datasync, "moto datasync package missing" -) # pylint: disable=W0143 +@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143 class TestAWSDataSyncHookMocked(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -102,8 +98,7 @@ def setUp(self): S3Config={"BucketAccessRoleArn": "role"}, )["LocationArn"] self.task_arn = self.client.create_task( - SourceLocationArn=self.source_location_arn, - DestinationLocationArn=self.destination_location_arn, + SourceLocationArn=self.source_location_arn, DestinationLocationArn=self.destination_location_arn, )["TaskArn"] def tearDown(self): @@ -226,7 +221,7 @@ def test_create_task(self, mock_get_conn): task_arn = self.hook.create_task( source_location_arn=self.source_location_arn, destination_location_arn=self.destination_location_arn, - **create_task_kwargs + **create_task_kwargs, ) task = 
self.client.describe_task(TaskArn=task_arn) @@ -272,9 +267,7 @@ def test_get_location_arns(self, mock_get_conn): # ### Begin tests: # Get true location_arn from boto/moto self.client - location_uri = "smb://{0}/{1}".format( - self.source_server_hostname, self.source_subdirectory - ) + location_uri = "smb://{0}/{1}".format(self.source_server_hostname, self.source_subdirectory) locations = self.client.list_locations() for location in locations["Locations"]: if location["LocationUri"] == location_uri: @@ -292,9 +285,7 @@ def test_get_location_arns_case_sensitive(self, mock_get_conn): # ### Begin tests: # Get true location_arn from boto/moto self.client - location_uri = "smb://{0}/{1}".format( - self.source_server_hostname.upper(), self.source_subdirectory - ) + location_uri = "smb://{0}/{1}".format(self.source_server_hostname.upper(), self.source_subdirectory) locations = self.client.list_locations() for location in locations["Locations"]: if location["LocationUri"] == location_uri.lower(): @@ -313,22 +304,16 @@ def test_get_location_arns_trailing_slash(self, mock_get_conn): # ### Begin tests: # Get true location_arn from boto/moto self.client - location_uri = "smb://{0}/{1}/".format( - self.source_server_hostname, self.source_subdirectory - ) + location_uri = "smb://{0}/{1}/".format(self.source_server_hostname, self.source_subdirectory) locations = self.client.list_locations() for location in locations["Locations"]: if location["LocationUri"] == location_uri[:-1]: location_arn = location["LocationArn"] # Verify our self.hook manages trailing / correctly - location_arns = self.hook.get_location_arns( - location_uri, ignore_trailing_slash=False - ) + location_arns = self.hook.get_location_arns(location_uri, ignore_trailing_slash=False) self.assertEqual(len(location_arns), 0) - location_arns = self.hook.get_location_arns( - location_uri, ignore_trailing_slash=True - ) + location_arns = self.hook.get_location_arns(location_uri, ignore_trailing_slash=True) self.assertEqual(len(location_arns), 1) self.assertEqual(location_arns[0], location_arn) @@ -361,9 +346,7 @@ def test_start_task_execution(self, mock_get_conn): self.assertIn("CurrentTaskExecutionArn", task) self.assertEqual(task["CurrentTaskExecutionArn"], task_execution_arn) - task_execution = self.client.describe_task_execution( - TaskExecutionArn=task_execution_arn - ) + task_execution = self.client.describe_task_execution(TaskExecutionArn=task_execution_arn) self.assertIn("Status", task_execution) def test_cancel_task_execution(self, mock_get_conn): @@ -407,9 +390,7 @@ def test_wait_for_task_execution(self, mock_get_conn): # ### Begin tests: task_execution_arn = self.hook.start_task_execution(self.task_arn) - result = self.hook.wait_for_task_execution( - task_execution_arn, max_iterations=20 - ) + result = self.hook.wait_for_task_execution(task_execution_arn, max_iterations=20) self.assertIsNotNone(result) @@ -420,7 +401,5 @@ def test_wait_for_task_execution_timeout(self, mock_get_conn): task_execution_arn = self.hook.start_task_execution(self.task_arn) with self.assertRaises(AirflowTaskTimeout): - result = self.hook.wait_for_task_execution( - task_execution_arn, max_iterations=1 - ) + result = self.hook.wait_for_task_execution(task_execution_arn, max_iterations=1) self.assertIsNone(result) diff --git a/tests/providers/amazon/aws/hooks/test_ec2.py b/tests/providers/amazon/aws/hooks/test_ec2.py index 9ffa12329519e..9a509095b9f69 100644 --- a/tests/providers/amazon/aws/hooks/test_ec2.py +++ b/tests/providers/amazon/aws/hooks/test_ec2.py @@ 
-25,12 +25,8 @@ class TestEC2Hook(unittest.TestCase): - def test_init(self): - ec2_hook = EC2Hook( - aws_conn_id="aws_conn_test", - region_name="region-test", - ) + ec2_hook = EC2Hook(aws_conn_id="aws_conn_test", region_name="region-test",) self.assertEqual(ec2_hook.aws_conn_id, "aws_conn_test") self.assertEqual(ec2_hook.region_name, "region-test") @@ -43,29 +39,19 @@ def test_get_conn_returns_boto3_resource(self): @mock_ec2 def test_get_instance(self): ec2_hook = EC2Hook() - created_instances = ec2_hook.conn.create_instances( - MaxCount=1, - MinCount=1, - ) + created_instances = ec2_hook.conn.create_instances(MaxCount=1, MinCount=1,) created_instance_id = created_instances[0].instance_id # test get_instance method - existing_instance = ec2_hook.get_instance( - instance_id=created_instance_id - ) + existing_instance = ec2_hook.get_instance(instance_id=created_instance_id) self.assertEqual(created_instance_id, existing_instance.instance_id) @mock_ec2 def test_get_instance_state(self): ec2_hook = EC2Hook() - created_instances = ec2_hook.conn.create_instances( - MaxCount=1, - MinCount=1, - ) + created_instances = ec2_hook.conn.create_instances(MaxCount=1, MinCount=1,) created_instance_id = created_instances[0].instance_id all_instances = list(ec2_hook.conn.instances.all()) created_instance_state = all_instances[0].state["Name"] # test get_instance_state method - existing_instance_state = ec2_hook.get_instance_state( - instance_id=created_instance_id - ) + existing_instance_state = ec2_hook.get_instance_state(instance_id=created_instance_id) self.assertEqual(created_instance_state, existing_instance_state) diff --git a/tests/providers/amazon/aws/hooks/test_emr.py b/tests/providers/amazon/aws/hooks/test_emr.py index 5fbcd96b5a226..7d954e88ee0f3 100644 --- a/tests/providers/amazon/aws/hooks/test_emr.py +++ b/tests/providers/amazon/aws/hooks/test_emr.py @@ -43,8 +43,7 @@ def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self): hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default') cluster = hook.create_job_flow({'Name': 'test_cluster'}) - self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], - cluster['JobFlowId']) + self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId']) @mock_emr def test_create_job_flow_extra_args(self): @@ -60,9 +59,7 @@ def test_create_job_flow_extra_args(self): # AmiVersion is really old and almost no one will use it anymore, but # it's one of the "optional" request params that moto supports - it's # coverage of EMR isn't 100% it turns out. 
- cluster = hook.create_job_flow({'Name': 'test_cluster', - 'ReleaseLabel': '', - 'AmiVersion': '3.2'}) + cluster = hook.create_job_flow({'Name': 'test_cluster', 'ReleaseLabel': '', 'AmiVersion': '3.2'}) cluster = client.describe_cluster(ClusterId=cluster['JobFlowId'])['Cluster'] @@ -76,8 +73,9 @@ def test_get_cluster_id_by_name(self): """ hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default') - job_flow = hook.create_job_flow({'Name': 'test_cluster', - 'Instances': {'KeepJobFlowAliveWhenNoSteps': True}}) + job_flow = hook.create_job_flow( + {'Name': 'test_cluster', 'Instances': {'KeepJobFlowAliveWhenNoSteps': True}} + ) job_flow_id = job_flow['JobFlowId'] diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py index 3871025666fe2..96c89085713ef 100644 --- a/tests/providers/amazon/aws/hooks/test_glue.py +++ b/tests/providers/amazon/aws/hooks/test_glue.py @@ -35,50 +35,48 @@ def setUp(self): @unittest.skipIf(mock_iam is None, 'mock_iam package not present') @mock_iam def test_get_iam_execution_role(self): - hook = AwsGlueJobHook(job_name='aws_test_glue_job', - s3_bucket='some_bucket', - iam_role_name='my_test_role') + hook = AwsGlueJobHook( + job_name='aws_test_glue_job', s3_bucket='some_bucket', iam_role_name='my_test_role' + ) iam_role = hook.get_client_type('iam').create_role( Path="/", RoleName='my_test_role', - AssumeRolePolicyDocument=json.dumps({ - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Principal": {"Service": "glue.amazonaws.com"}, - "Action": "sts:AssumeRole" + AssumeRolePolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Principal": {"Service": "glue.amazonaws.com"}, + "Action": "sts:AssumeRole", + }, } - }) + ), ) iam_role = hook.get_iam_execution_role() self.assertIsNotNone(iam_role) @mock.patch.object(AwsGlueJobHook, "get_iam_execution_role") @mock.patch.object(AwsGlueJobHook, "get_conn") - def test_get_or_create_glue_job(self, mock_get_conn, - mock_get_iam_execution_role - ): - mock_get_iam_execution_role.return_value = \ - mock.MagicMock(Role={'RoleName': 'my_test_role'}) + def test_get_or_create_glue_job(self, mock_get_conn, mock_get_iam_execution_role): + mock_get_iam_execution_role.return_value = mock.MagicMock(Role={'RoleName': 'my_test_role'}) some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py" some_s3_bucket = "my-includes" mock_glue_job = mock_get_conn.return_value.get_job()['Job']['Name'] - glue_job = AwsGlueJobHook(job_name='aws_test_glue_job', - desc='This is test case job from Airflow', - script_location=some_script, - iam_role_name='my_test_role', - s3_bucket=some_s3_bucket, - region_name=self.some_aws_region)\ - .get_or_create_glue_job() + glue_job = AwsGlueJobHook( + job_name='aws_test_glue_job', + desc='This is test case job from Airflow', + script_location=some_script, + iam_role_name='my_test_role', + s3_bucket=some_s3_bucket, + region_name=self.some_aws_region, + ).get_or_create_glue_job() self.assertEqual(glue_job, mock_glue_job) @mock.patch.object(AwsGlueJobHook, "get_job_state") @mock.patch.object(AwsGlueJobHook, "get_or_create_glue_job") @mock.patch.object(AwsGlueJobHook, "get_conn") - def test_initialize_job(self, mock_get_conn, - mock_get_or_create_glue_job, - mock_get_job_state): + def test_initialize_job(self, mock_get_conn, mock_get_or_create_glue_job, mock_get_job_state): some_data_path = "s3://glue-datasets/examples/medicare/SampleData.csv" some_script_arguments = {"--s3_input_data_path": 
some_data_path} some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py" @@ -88,12 +86,14 @@ def test_initialize_job(self, mock_get_conn, mock_get_conn.return_value.start_job_run() mock_job_run_state = mock_get_job_state.return_value - glue_job_hook = AwsGlueJobHook(job_name='aws_test_glue_job', - desc='This is test case job from Airflow', - iam_role_name='my_test_role', - script_location=some_script, - s3_bucket=some_s3_bucket, - region_name=self.some_aws_region) + glue_job_hook = AwsGlueJobHook( + job_name='aws_test_glue_job', + desc='This is test case job from Airflow', + iam_role_name='my_test_role', + script_location=some_script, + s3_bucket=some_s3_bucket, + region_name=self.some_aws_region, + ) glue_job_run = glue_job_hook.initialize_job(some_script_arguments) glue_job_run_state = glue_job_hook.get_job_state(glue_job_run['JobName'], glue_job_run['JobRunId']) self.assertEqual(glue_job_run_state, mock_job_run_state, msg='Mocks but be equal') diff --git a/tests/providers/amazon/aws/hooks/test_glue_catalog.py b/tests/providers/amazon/aws/hooks/test_glue_catalog.py index b6b594b434aae..4d35ee9d848a5 100644 --- a/tests/providers/amazon/aws/hooks/test_glue_catalog.py +++ b/tests/providers/amazon/aws/hooks/test_glue_catalog.py @@ -33,22 +33,14 @@ TABLE_INPUT = { "Name": TABLE_NAME, "StorageDescriptor": { - "Columns": [ - { - "Name": "string", - "Type": "string", - "Comment": "string" - } - ], + "Columns": [{"Name": "string", "Type": "string", "Comment": "string"}], "Location": "s3://mybucket/{}/{}".format(DB_NAME, TABLE_NAME), - } + }, } -@unittest.skipIf(mock_glue is None, - "Skipping test because moto.mock_glue is not available") +@unittest.skipIf(mock_glue is None, "Skipping test because moto.mock_glue is not available") class TestAwsGlueCatalogHook(unittest.TestCase): - @mock_glue def setUp(self): self.client = boto3.client('glue', region_name='us-east-1') @@ -81,31 +73,23 @@ def test_get_partitions_empty(self, mock_get_conn): @mock_glue @mock.patch.object(AwsGlueCatalogHook, 'get_conn') def test_get_partitions(self, mock_get_conn): - response = [{ - 'Partitions': [{ - 'Values': ['2015-01-01'] - }] - }] + response = [{'Partitions': [{'Values': ['2015-01-01']}]}] mock_paginator = mock.Mock() mock_paginator.paginate.return_value = response mock_conn = mock.Mock() mock_conn.get_paginator.return_value = mock_paginator mock_get_conn.return_value = mock_conn hook = AwsGlueCatalogHook(region_name="us-east-1") - result = hook.get_partitions('db', - 'tbl', - expression='foo=bar', - page_size=2, - max_items=3) + result = hook.get_partitions('db', 'tbl', expression='foo=bar', page_size=2, max_items=3) self.assertEqual(result, {('2015-01-01',)}) mock_conn.get_paginator.assert_called_once_with('get_partitions') - mock_paginator.paginate.assert_called_once_with(DatabaseName='db', - TableName='tbl', - Expression='foo=bar', - PaginationConfig={ - 'PageSize': 2, - 'MaxItems': 3}) + mock_paginator.paginate.assert_called_once_with( + DatabaseName='db', + TableName='tbl', + Expression='foo=bar', + PaginationConfig={'PageSize': 2, 'MaxItems': 3}, + ) @mock_glue @mock.patch.object(AwsGlueCatalogHook, 'get_partitions') @@ -126,48 +110,28 @@ def test_check_for_partition_false(self, mock_get_partitions): @mock_glue def test_get_table_exists(self): - self.client.create_database( - DatabaseInput={ - 'Name': DB_NAME - } - ) - self.client.create_table( - DatabaseName=DB_NAME, - TableInput=TABLE_INPUT - ) + self.client.create_database(DatabaseInput={'Name': DB_NAME}) + 
self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT) result = self.hook.get_table(DB_NAME, TABLE_NAME) self.assertEqual(result['Name'], TABLE_INPUT['Name']) - self.assertEqual(result['StorageDescriptor']['Location'], - TABLE_INPUT['StorageDescriptor']['Location']) + self.assertEqual( + result['StorageDescriptor']['Location'], TABLE_INPUT['StorageDescriptor']['Location'] + ) @mock_glue def test_get_table_not_exists(self): - self.client.create_database( - DatabaseInput={ - 'Name': DB_NAME - } - ) - self.client.create_table( - DatabaseName=DB_NAME, - TableInput=TABLE_INPUT - ) + self.client.create_database(DatabaseInput={'Name': DB_NAME}) + self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT) with self.assertRaises(Exception): self.hook.get_table(DB_NAME, 'dummy_table') @mock_glue def test_get_table_location(self): - self.client.create_database( - DatabaseInput={ - 'Name': DB_NAME - } - ) - self.client.create_table( - DatabaseName=DB_NAME, - TableInput=TABLE_INPUT - ) + self.client.create_database(DatabaseInput={'Name': DB_NAME}) + self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT) result = self.hook.get_table_location(DB_NAME, TABLE_NAME) self.assertEqual(result, TABLE_INPUT['StorageDescriptor']['Location']) diff --git a/tests/providers/amazon/aws/hooks/test_kinesis.py b/tests/providers/amazon/aws/hooks/test_kinesis.py index 8973d6486b4f2..7f2e65dae89b1 100644 --- a/tests/providers/amazon/aws/hooks/test_kinesis.py +++ b/tests/providers/amazon/aws/hooks/test_kinesis.py @@ -28,19 +28,20 @@ class TestAwsFirehoseHook(unittest.TestCase): - @unittest.skipIf(mock_kinesis is None, 'mock_kinesis package not present') @mock_kinesis def test_get_conn_returns_a_boto3_connection(self): - hook = AwsFirehoseHook(aws_conn_id='aws_default', - delivery_stream="test_airflow", region_name="us-east-1") + hook = AwsFirehoseHook( + aws_conn_id='aws_default', delivery_stream="test_airflow", region_name="us-east-1" + ) self.assertIsNotNone(hook.get_conn()) @unittest.skipIf(mock_kinesis is None, 'mock_kinesis package not present') @mock_kinesis def test_insert_batch_records_kinesis_firehose(self): - hook = AwsFirehoseHook(aws_conn_id='aws_default', - delivery_stream="test_airflow", region_name="us-east-1") + hook = AwsFirehoseHook( + aws_conn_id='aws_default', delivery_stream="test_airflow", region_name="us-east-1" + ) response = hook.get_conn().create_delivery_stream( DeliveryStreamName="test_airflow", @@ -48,20 +49,15 @@ def test_insert_batch_records_kinesis_firehose(self): 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', 'BucketARN': 'arn:aws:s3:::kinesis-test', 'Prefix': 'airflow/', - 'BufferingHints': { - 'SizeInMBs': 123, - 'IntervalInSeconds': 124 - }, + 'BufferingHints': {'SizeInMBs': 123, 'IntervalInSeconds': 124}, 'CompressionFormat': 'UNCOMPRESSED', - } + }, ) stream_arn = response['DeliveryStreamARN'] - self.assertEqual( - stream_arn, "arn:aws:firehose:us-east-1:123456789012:deliverystream/test_airflow") + self.assertEqual(stream_arn, "arn:aws:firehose:us-east-1:123456789012:deliverystream/test_airflow") - records = [{"Data": str(uuid.uuid4())} - for _ in range(100)] + records = [{"Data": str(uuid.uuid4())} for _ in range(100)] response = hook.put_records(records) diff --git a/tests/providers/amazon/aws/hooks/test_lambda_function.py b/tests/providers/amazon/aws/hooks/test_lambda_function.py index 0ce61f5512ac8..64ffe7d7402a0 100644 --- a/tests/providers/amazon/aws/hooks/test_lambda_function.py +++ 
b/tests/providers/amazon/aws/hooks/test_lambda_function.py @@ -29,15 +29,17 @@ class TestAwsLambdaHook: @mock_lambda def test_get_conn_returns_a_boto3_connection(self): - hook = AwsLambdaHook(aws_conn_id='aws_default', - function_name="test_function", region_name="us-east-1") + hook = AwsLambdaHook( + aws_conn_id='aws_default', function_name="test_function", region_name="us-east-1" + ) assert hook.conn is not None @mock_lambda def test_invoke_lambda_function(self): - hook = AwsLambdaHook(aws_conn_id='aws_default', - function_name="test_function", region_name="us-east-1") + hook = AwsLambdaHook( + aws_conn_id='aws_default', function_name="test_function", region_name="us-east-1" + ) with patch.object(hook.conn, 'invoke') as mock_invoke: payload = '{"hello": "airflow"}' @@ -48,5 +50,5 @@ def test_invoke_lambda_function(self): InvocationType="RequestResponse", LogType="None", Payload=payload, - Qualifier="$LATEST" + Qualifier="$LATEST", ) diff --git a/tests/providers/amazon/aws/hooks/test_logs.py b/tests/providers/amazon/aws/hooks/test_logs.py index 666e93efc7a74..65852aa458f2f 100644 --- a/tests/providers/amazon/aws/hooks/test_logs.py +++ b/tests/providers/amazon/aws/hooks/test_logs.py @@ -28,12 +28,10 @@ class TestAwsLogsHook(unittest.TestCase): - @unittest.skipIf(mock_logs is None, 'mock_logs package not present') @mock_logs def test_get_conn_returns_a_boto3_connection(self): - hook = AwsLogsHook(aws_conn_id='aws_default', - region_name="us-east-1") + hook = AwsLogsHook(aws_conn_id='aws_default', region_name="us-east-1") self.assertIsNotNone(hook.get_conn()) @unittest.skipIf(mock_logs is None, 'mock_logs package not present') @@ -44,30 +42,21 @@ def test_get_log_events(self): log_group_name = 'example-group' log_stream_name = 'example-log-stream' - hook = AwsLogsHook(aws_conn_id='aws_default', - region_name="us-east-1") + hook = AwsLogsHook(aws_conn_id='aws_default', region_name="us-east-1") # First we create some log events conn = hook.get_conn() conn.create_log_group(logGroupName=log_group_name) conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) - input_events = [ - { - 'timestamp': 1, - 'message': 'Test Message 1' - } - ] - - conn.put_log_events(logGroupName=log_group_name, - logStreamName=log_stream_name, - logEvents=input_events) + input_events = [{'timestamp': 1, 'message': 'Test Message 1'}] - events = hook.get_log_events( - log_group=log_group_name, - log_stream_name=log_stream_name + conn.put_log_events( + logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=input_events ) + events = hook.get_log_events(log_group=log_group_name, log_stream_name=log_stream_name) + # Iterate through entire generator events = list(events) count = len(events) diff --git a/tests/providers/amazon/aws/hooks/test_redshift.py b/tests/providers/amazon/aws/hooks/test_redshift.py index f53b0c9d2cc30..551a7c3b6daaf 100644 --- a/tests/providers/amazon/aws/hooks/test_redshift.py +++ b/tests/providers/amazon/aws/hooks/test_redshift.py @@ -38,13 +38,13 @@ def _create_clusters(): ClusterIdentifier='test_cluster', NodeType='dc1.large', MasterUsername='admin', - MasterUserPassword='mock_password' + MasterUserPassword='mock_password', ) client.create_cluster( ClusterIdentifier='test_cluster_2', NodeType='dc1.large', MasterUsername='admin', - MasterUserPassword='mock_password' + MasterUserPassword='mock_password', ) if not client.describe_clusters()['Clusters']: raise ValueError('AWS not properly mocked') @@ -66,10 +66,9 @@ def 
test_restore_from_cluster_snapshot_returns_dict_with_cluster_data(self): hook = RedshiftHook(aws_conn_id='aws_default') hook.create_cluster_snapshot('test_snapshot', 'test_cluster') self.assertEqual( - hook.restore_from_cluster_snapshot( - 'test_cluster_3', 'test_snapshot' - )['ClusterIdentifier'], - 'test_cluster_3') + hook.restore_from_cluster_snapshot('test_cluster_3', 'test_snapshot')['ClusterIdentifier'], + 'test_cluster_3', + ) @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present') @mock_redshift diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py index 59f823a39dd2e..8370596439244 100644 --- a/tests/providers/amazon/aws/hooks/test_s3.py +++ b/tests/providers/amazon/aws/hooks/test_s3.py @@ -38,7 +38,6 @@ @pytest.mark.skipif(mock_s3 is None, reason='moto package not present') class TestAwsS3Hook: - @mock_s3 def test_get_conn(self): hook = S3Hook() @@ -183,8 +182,9 @@ def test_read_key(self, s3_bucket): # As of 1.3.2, Moto doesn't support select_object_content yet. @mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_client_type') def test_select_key(self, mock_get_client_type, s3_bucket): - mock_get_client_type.return_value.select_object_content.return_value = \ - {'Payload': [{'Records': {'Payload': b'Cont\xC3\xA9nt'}}]} + mock_get_client_type.return_value.select_object_content.return_value = { + 'Payload': [{'Records': {'Payload': b'Cont\xC3\xA9nt'}}] + } hook = S3Hook() assert hook.select_key('my_key', s3_bucket) == 'Contént' @@ -233,11 +233,11 @@ def test_load_string(self, s3_bucket): def test_load_string_acl(self, s3_bucket): hook = S3Hook() - hook.load_string("Contént", "my_key", s3_bucket, - acl_policy='public-read') + hook.load_string("Contént", "my_key", s3_bucket, acl_policy='public-read') response = boto3.client('s3').get_object_acl(Bucket=s3_bucket, Key="my_key", RequestPayer='requester') - assert ((response['Grants'][1]['Permission'] == 'READ') and - (response['Grants'][0]['Permission'] == 'FULL_CONTROL')) + assert (response['Grants'][1]['Permission'] == 'READ') and ( + response['Grants'][0]['Permission'] == 'FULL_CONTROL' + ) def test_load_bytes(self, s3_bucket): hook = S3Hook() @@ -247,11 +247,11 @@ def test_load_bytes(self, s3_bucket): def test_load_bytes_acl(self, s3_bucket): hook = S3Hook() - hook.load_bytes(b"Content", "my_key", s3_bucket, - acl_policy='public-read') + hook.load_bytes(b"Content", "my_key", s3_bucket, acl_policy='public-read') response = boto3.client('s3').get_object_acl(Bucket=s3_bucket, Key="my_key", RequestPayer='requester') - assert ((response['Grants'][1]['Permission'] == 'READ') and - (response['Grants'][0]['Permission'] == 'FULL_CONTROL')) + assert (response['Grants'][1]['Permission'] == 'READ') and ( + response['Grants'][0]['Permission'] == 'FULL_CONTROL' + ) def test_load_fileobj(self, s3_bucket): hook = S3Hook() @@ -267,13 +267,13 @@ def test_load_fileobj_acl(self, s3_bucket): with tempfile.TemporaryFile() as temp_file: temp_file.write(b"Content") temp_file.seek(0) - hook.load_file_obj(temp_file, "my_key", s3_bucket, - acl_policy='public-read') - response = boto3.client('s3').get_object_acl(Bucket=s3_bucket, - Key="my_key", - RequestPayer='requester') # pylint: disable=no-member # noqa: E501 # pylint: disable=C0301 - assert ((response['Grants'][1]['Permission'] == 'READ') and - (response['Grants'][0]['Permission'] == 'FULL_CONTROL')) + hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy='public-read') + response = 
boto3.client('s3').get_object_acl( + Bucket=s3_bucket, Key="my_key", RequestPayer='requester' + ) # pylint: disable=no-member # noqa: E501 # pylint: disable=C0301 + assert (response['Grants'][1]['Permission'] == 'READ') and ( + response['Grants'][0]['Permission'] == 'FULL_CONTROL' + ) def test_load_file_gzip(self, s3_bucket): hook = S3Hook() @@ -290,13 +290,13 @@ def test_load_file_acl(self, s3_bucket): with tempfile.NamedTemporaryFile(delete=False) as temp_file: temp_file.write(b"Content") temp_file.seek(0) - hook.load_file(temp_file.name, "my_key", s3_bucket, gzip=True, - acl_policy='public-read') - response = boto3.client('s3').get_object_acl(Bucket=s3_bucket, - Key="my_key", - RequestPayer='requester') # pylint: disable=no-member # noqa: E501 # pylint: disable=C0301 - assert ((response['Grants'][1]['Permission'] == 'READ') and - (response['Grants'][0]['Permission'] == 'FULL_CONTROL')) + hook.load_file(temp_file.name, "my_key", s3_bucket, gzip=True, acl_policy='public-read') + response = boto3.client('s3').get_object_acl( + Bucket=s3_bucket, Key="my_key", RequestPayer='requester' + ) # pylint: disable=no-member # noqa: E501 # pylint: disable=C0301 + assert (response['Grants'][1]['Permission'] == 'READ') and ( + response['Grants'][0]['Permission'] == 'FULL_CONTROL' + ) os.unlink(temp_file.name) def test_copy_object_acl(self, s3_bucket): @@ -306,11 +306,10 @@ def test_copy_object_acl(self, s3_bucket): temp_file.seek(0) hook.load_file_obj(temp_file, "my_key", s3_bucket) hook.copy_object("my_key", "my_key", s3_bucket, s3_bucket) - response = boto3.client('s3').get_object_acl(Bucket=s3_bucket, - Key="my_key", - RequestPayer='requester') # pylint: disable=no-member # noqa: E501 # pylint: disable=C0301 - assert ((response['Grants'][0]['Permission'] == 'FULL_CONTROL') and - (len(response['Grants']) == 1)) + response = boto3.client('s3').get_object_acl( + Bucket=s3_bucket, Key="my_key", RequestPayer='requester' + ) # pylint: disable=no-member # noqa: E501 # pylint: disable=C0301 + assert (response['Grants'][0]['Permission'] == 'FULL_CONTROL') and (len(response['Grants']) == 1) @mock_s3 def test_delete_bucket_if_bucket_exist(self, s3_bucket): @@ -332,9 +331,7 @@ def test_delete_bucket_if_not_bucket_exist(self, s3_bucket): @mock.patch.object(S3Hook, 'get_connection', return_value=Connection(schema='test_bucket')) def test_provide_bucket_name(self, mock_get_connection): - class FakeS3Hook(S3Hook): - @provide_bucket_name def test_function(self, bucket_name=None): return bucket_name @@ -376,9 +373,7 @@ def test_delete_objects_many_keys(self, mocked_s3_res, s3_bucket): assert [o.key for o in mocked_s3_res.Bucket(s3_bucket).objects.all()] == [] def test_unify_bucket_name_and_key(self): - class FakeS3Hook(S3Hook): - @unify_bucket_name_and_key def test_function_with_wildcard_key(self, wildcard_key, bucket_name=None): return bucket_name, wildcard_key @@ -422,12 +417,11 @@ def test_download_file(self, mock_temp_file): def test_generate_presigned_url(self, s3_bucket): hook = S3Hook() - presigned_url = hook.generate_presigned_url(client_method="get_object", - params={'Bucket': s3_bucket, - 'Key': "my_key"}) + presigned_url = hook.generate_presigned_url( + client_method="get_object", params={'Bucket': s3_bucket, 'Key': "my_key"} + ) url = presigned_url.split("?")[1] - params = {x[0]: x[1] - for x in [x.split("=") for x in url[0:].split("&")]} + params = {x[0]: x[1] for x in [x.split("=") for x in url[0:].split("&")]} assert {"AWSAccessKeyId", "Signature", "Expires"}.issubset(set(params.keys())) diff --git 
a/tests/providers/amazon/aws/hooks/test_sagemaker.py b/tests/providers/amazon/aws/hooks/test_sagemaker.py index f1275748e2bab..ff456129410e5 100644 --- a/tests/providers/amazon/aws/hooks/test_sagemaker.py +++ b/tests/providers/amazon/aws/hooks/test_sagemaker.py @@ -28,7 +28,10 @@ from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.hooks.sagemaker import ( - LogState, SageMakerHook, secondary_training_status_changed, secondary_training_status_message, + LogState, + SageMakerHook, + secondary_training_status_changed, + secondary_training_status_message, ) role = 'arn:aws:iam:role/test-role' @@ -48,29 +51,13 @@ output_url = 's3://{}/test/output'.format(bucket) create_training_params = { - 'AlgorithmSpecification': { - 'TrainingImage': image, - 'TrainingInputMode': 'File' - }, + 'AlgorithmSpecification': {'TrainingImage': image, 'TrainingInputMode': 'File'}, 'RoleArn': role, - 'OutputDataConfig': { - 'S3OutputPath': output_url - }, - 'ResourceConfig': { - 'InstanceCount': 2, - 'InstanceType': 'ml.c4.8xlarge', - 'VolumeSizeInGB': 50 - }, + 'OutputDataConfig': {'S3OutputPath': output_url}, + 'ResourceConfig': {'InstanceCount': 2, 'InstanceType': 'ml.c4.8xlarge', 'VolumeSizeInGB': 50}, 'TrainingJobName': job_name, - 'HyperParameters': { - 'k': '10', - 'feature_dim': '784', - 'mini_batch_size': '500', - 'force_dense': 'True' - }, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': 60 * 60 - }, + 'HyperParameters': {'k': '10', 'feature_dim': '784', 'mini_batch_size': '500', 'force_dense': 'True'}, + 'StoppingCondition': {'MaxRuntimeInSeconds': 60 * 60}, 'InputDataConfig': [ { 'ChannelName': 'train', @@ -78,37 +65,22 @@ 'S3DataSource': { 'S3DataType': 'S3Prefix', 'S3Uri': data_url, - 'S3DataDistributionType': 'FullyReplicated' + 'S3DataDistributionType': 'FullyReplicated', } }, 'CompressionType': 'None', - 'RecordWrapperType': 'None' + 'RecordWrapperType': 'None', } - ] + ], } create_tuning_params = { 'HyperParameterTuningJobName': job_name, 'HyperParameterTuningJobConfig': { 'Strategy': 'Bayesian', - 'HyperParameterTuningJobObjective': { - 'Type': 'Maximize', - 'MetricName': 'test_metric' - }, - 'ResourceLimits': { - 'MaxNumberOfTrainingJobs': 123, - 'MaxParallelTrainingJobs': 123 - }, - 'ParameterRanges': { - 'IntegerParameterRanges': [ - { - 'Name': 'k', - 'MinValue': '2', - 'MaxValue': '10' - }, - - ] - } + 'HyperParameterTuningJobObjective': {'Type': 'Maximize', 'MetricName': 'test_metric'}, + 'ResourceLimits': {'MaxNumberOfTrainingJobs': 123, 'MaxParallelTrainingJobs': 123}, + 'ParameterRanges': {'IntegerParameterRanges': [{'Name': 'k', 'MinValue': '2', 'MaxValue': '10'},]}, }, 'TrainingJobDefinition': { 'StaticHyperParameters': create_training_params['HyperParameters'], @@ -117,38 +89,23 @@ 'InputDataConfig': create_training_params['InputDataConfig'], 'OutputDataConfig': create_training_params['OutputDataConfig'], 'ResourceConfig': create_training_params['ResourceConfig'], - 'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60) - } + 'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60), + }, } create_transform_params = { 'TransformJobName': job_name, 'ModelName': model_name, 'BatchStrategy': 'MultiRecord', - 'TransformInput': { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': data_url - } - } - }, - 'TransformOutput': { - 'S3OutputPath': output_url, - }, - 'TransformResources': { - 'InstanceType': 'ml.m4.xlarge', - 'InstanceCount': 123 - } + 'TransformInput': 
{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': data_url}}}, + 'TransformOutput': {'S3OutputPath': output_url,}, + 'TransformResources': {'InstanceType': 'ml.m4.xlarge', 'InstanceCount': 123}, } create_model_params = { 'ModelName': model_name, - 'PrimaryContainer': { - 'Image': image, - 'ModelDataUrl': output_url, - }, - 'ExecutionRoleArn': role + 'PrimaryContainer': {'Image': image, 'ModelDataUrl': output_url,}, + 'ExecutionRoleArn': role, } create_endpoint_config_params = { @@ -158,38 +115,28 @@ 'VariantName': 'AllTraffic', 'ModelName': model_name, 'InitialInstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge' + 'InstanceType': 'ml.c4.xlarge', } - ] + ], } -create_endpoint_params = { - 'EndpointName': endpoint_name, - 'EndpointConfigName': config_name -} +create_endpoint_params = {'EndpointName': endpoint_name, 'EndpointConfigName': config_name} update_endpoint_params = create_endpoint_params DESCRIBE_TRAINING_COMPLETED_RETURN = { 'TrainingJobStatus': 'Completed', - 'ResourceConfig': { - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge', - 'VolumeSizeInGB': 10 - }, + 'ResourceConfig': {'InstanceCount': 1, 'InstanceType': 'ml.c4.xlarge', 'VolumeSizeInGB': 10}, 'TrainingStartTime': datetime(2018, 2, 17, 7, 15, 0, 103000), 'TrainingEndTime': datetime(2018, 2, 17, 7, 19, 34, 953000), - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': {'HTTPStatusCode': 200,}, } DESCRIBE_TRAINING_INPROGRESS_RETURN = dict(DESCRIBE_TRAINING_COMPLETED_RETURN) DESCRIBE_TRAINING_INPROGRESS_RETURN.update({'TrainingJobStatus': 'InProgress'}) DESCRIBE_TRAINING_FAILED_RETURN = dict(DESCRIBE_TRAINING_COMPLETED_RETURN) -DESCRIBE_TRAINING_FAILED_RETURN.update({'TrainingJobStatus': 'Failed', - 'FailureReason': 'Unknown'}) +DESCRIBE_TRAINING_FAILED_RETURN.update({'TrainingJobStatus': 'Failed', 'FailureReason': 'Unknown'}) DESCRIBE_TRAINING_STOPPING_RETURN = dict(DESCRIBE_TRAINING_COMPLETED_RETURN) DESCRIBE_TRAINING_STOPPING_RETURN.update({'TrainingJobStatus': 'Stopping'}) @@ -204,43 +151,45 @@ } DEFAULT_LOG_STREAMS = {'logStreams': [{'logStreamName': job_name + '/xxxxxxxxx'}]} -LIFECYCLE_LOG_STREAMS = [DEFAULT_LOG_STREAMS, - DEFAULT_LOG_STREAMS, - DEFAULT_LOG_STREAMS, - DEFAULT_LOG_STREAMS, - DEFAULT_LOG_STREAMS, - DEFAULT_LOG_STREAMS] - -DEFAULT_LOG_EVENTS = [{'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'}]}, - {'nextForwardToken': None, 'events': []}] -STREAM_LOG_EVENTS = [{'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'}]}, - {'nextForwardToken': None, 'events': []}, - {'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'}, - {'timestamp': 2, 'message': 'hi there #2'}]}, - {'nextForwardToken': None, 'events': []}, - {'nextForwardToken': None, 'events': [{'timestamp': 2, 'message': 'hi there #2'}, - {'timestamp': 2, 'message': 'hi there #2a'}, - {'timestamp': 3, 'message': 'hi there #3'}]}, - {'nextForwardToken': None, 'events': []}] +LIFECYCLE_LOG_STREAMS = [ + DEFAULT_LOG_STREAMS, + DEFAULT_LOG_STREAMS, + DEFAULT_LOG_STREAMS, + DEFAULT_LOG_STREAMS, + DEFAULT_LOG_STREAMS, + DEFAULT_LOG_STREAMS, +] + +DEFAULT_LOG_EVENTS = [ + {'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'}]}, + {'nextForwardToken': None, 'events': []}, +] +STREAM_LOG_EVENTS = [ + {'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'}]}, + {'nextForwardToken': None, 'events': []}, + { + 'nextForwardToken': None, + 'events': [{'timestamp': 1, 'message': 'hi there 
#1'}, {'timestamp': 2, 'message': 'hi there #2'}], + }, + {'nextForwardToken': None, 'events': []}, + { + 'nextForwardToken': None, + 'events': [ + {'timestamp': 2, 'message': 'hi there #2'}, + {'timestamp': 2, 'message': 'hi there #2a'}, + {'timestamp': 3, 'message': 'hi there #3'}, + ], + }, + {'nextForwardToken': None, 'events': []}, +] test_evaluation_config = { 'Image': image, 'Role': role, 'S3Operations': { - 'S3CreateBucket': [ - { - 'Bucket': bucket - } - ], - 'S3Upload': [ - { - 'Path': path, - 'Bucket': bucket, - 'Key': key, - 'Tar': False - } - ] - } + 'S3CreateBucket': [{'Bucket': bucket}], + 'S3Upload': [{'Path': path, 'Bucket': bucket, 'Key': key, 'Tar': False}], + }, } @@ -257,10 +206,7 @@ def test_multi_stream_iter(self, mock_log_stream): @mock.patch.object(S3Hook, 'load_file') def test_configure_s3_resources(self, mock_load_file, mock_create_bucket): hook = SageMakerHook() - evaluation_result = { - 'Image': image, - 'Role': role - } + evaluation_result = {'Image': image, 'Role': role} hook.configure_s3_resources(test_evaluation_config) self.assertEqual(test_evaluation_config, evaluation_result) mock_create_bucket.assert_called_once_with(bucket_name=bucket) @@ -270,20 +216,14 @@ def test_configure_s3_resources(self, mock_load_file, mock_create_bucket): @mock.patch.object(S3Hook, 'check_for_key') @mock.patch.object(S3Hook, 'check_for_bucket') @mock.patch.object(S3Hook, 'check_for_prefix') - def test_check_s3_url(self, - mock_check_prefix, - mock_check_bucket, - mock_check_key, - mock_client): + def test_check_s3_url(self, mock_check_prefix, mock_check_bucket, mock_check_key, mock_client): mock_client.return_value = None hook = SageMakerHook() mock_check_bucket.side_effect = [False, True, True, True] mock_check_key.side_effect = [False, True, False] mock_check_prefix.side_effect = [False, True, True] - self.assertRaises(AirflowException, - hook.check_s3_url, data_url) - self.assertRaises(AirflowException, - hook.check_s3_url, data_url) + self.assertRaises(AirflowException, hook.check_s3_url, data_url) + self.assertRaises(AirflowException, hook.check_s3_url, data_url) self.assertEqual(hook.check_s3_url(data_url), True) self.assertEqual(hook.check_s3_url(data_url), True) @@ -318,14 +258,13 @@ def test_conn(self, mock_get_client_type): def test_create_training_job(self, mock_client, mock_check_training): mock_check_training.return_value = True mock_session = mock.Mock() - attrs = {'create_training_job.return_value': - test_arn_return} + attrs = {'create_training_job.return_value': test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') - response = hook.create_training_job(create_training_params, - wait_for_completion=False, - print_log=False) + response = hook.create_training_job( + create_training_params, wait_for_completion=False, print_log=False + ) mock_session.create_training_job.assert_called_once_with(**create_training_params) self.assertEqual(response, test_arn_return) @@ -334,56 +273,60 @@ def test_create_training_job(self, mock_client, mock_check_training): def test_training_ends_with_wait(self, mock_client, mock_check_training): mock_check_training.return_value = True mock_session = mock.Mock() - attrs = {'create_training_job.return_value': - test_arn_return, - 'describe_training_job.side_effect': - [DESCRIBE_TRAINING_INPROGRESS_RETURN, - DESCRIBE_TRAINING_STOPPING_RETURN, - DESCRIBE_TRAINING_COMPLETED_RETURN, - DESCRIBE_TRAINING_COMPLETED_RETURN] - } + attrs = { + 
'create_training_job.return_value': test_arn_return, + 'describe_training_job.side_effect': [ + DESCRIBE_TRAINING_INPROGRESS_RETURN, + DESCRIBE_TRAINING_STOPPING_RETURN, + DESCRIBE_TRAINING_COMPLETED_RETURN, + DESCRIBE_TRAINING_COMPLETED_RETURN, + ], + } mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1') - hook.create_training_job(create_training_params, wait_for_completion=True, - print_log=False, check_interval=1) + hook.create_training_job( + create_training_params, wait_for_completion=True, print_log=False, check_interval=1 + ) self.assertEqual(mock_session.describe_training_job.call_count, 4) @mock.patch.object(SageMakerHook, 'check_training_config') @mock.patch.object(SageMakerHook, 'get_conn') - def test_training_throws_error_when_failed_with_wait( - self, mock_client, mock_check_training): + def test_training_throws_error_when_failed_with_wait(self, mock_client, mock_check_training): mock_check_training.return_value = True mock_session = mock.Mock() - attrs = {'create_training_job.return_value': - test_arn_return, - 'describe_training_job.side_effect': - [DESCRIBE_TRAINING_INPROGRESS_RETURN, - DESCRIBE_TRAINING_STOPPING_RETURN, - DESCRIBE_TRAINING_FAILED_RETURN, - DESCRIBE_TRAINING_COMPLETED_RETURN] - } + attrs = { + 'create_training_job.return_value': test_arn_return, + 'describe_training_job.side_effect': [ + DESCRIBE_TRAINING_INPROGRESS_RETURN, + DESCRIBE_TRAINING_STOPPING_RETURN, + DESCRIBE_TRAINING_FAILED_RETURN, + DESCRIBE_TRAINING_COMPLETED_RETURN, + ], + } mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1') - self.assertRaises(AirflowException, hook.create_training_job, - create_training_params, wait_for_completion=True, - print_log=False, check_interval=1) + self.assertRaises( + AirflowException, + hook.create_training_job, + create_training_params, + wait_for_completion=True, + print_log=False, + check_interval=1, + ) self.assertEqual(mock_session.describe_training_job.call_count, 3) @mock.patch.object(SageMakerHook, 'check_tuning_config') @mock.patch.object(SageMakerHook, 'get_conn') def test_create_tuning_job(self, mock_client, mock_check_tuning_config): mock_session = mock.Mock() - attrs = {'create_hyper_parameter_tuning_job.return_value': - test_arn_return} + attrs = {'create_hyper_parameter_tuning_job.return_value': test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') - response = hook.create_tuning_job(create_tuning_params, - wait_for_completion=False) - mock_session.create_hyper_parameter_tuning_job.\ - assert_called_once_with(**create_tuning_params) + response = hook.create_tuning_job(create_tuning_params, wait_for_completion=False) + mock_session.create_hyper_parameter_tuning_job.assert_called_once_with(**create_tuning_params) self.assertEqual(response, test_arn_return) @mock.patch.object(SageMakerHook, 'check_s3_url') @@ -391,22 +334,18 @@ def test_create_tuning_job(self, mock_client, mock_check_tuning_config): def test_create_transform_job(self, mock_client, mock_check_url): mock_check_url.return_value = True mock_session = mock.Mock() - attrs = {'create_transform_job.return_value': - test_arn_return} + attrs = {'create_transform_job.return_value': test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = 
SageMakerHook(aws_conn_id='sagemaker_test_conn_id') - response = hook.create_transform_job(create_transform_params, - wait_for_completion=False) - mock_session.create_transform_job.assert_called_once_with( - **create_transform_params) + response = hook.create_transform_job(create_transform_params, wait_for_completion=False) + mock_session.create_transform_job.assert_called_once_with(**create_transform_params) self.assertEqual(response, test_arn_return) @mock.patch.object(SageMakerHook, 'get_conn') def test_create_model(self, mock_client): mock_session = mock.Mock() - attrs = {'create_model.return_value': - test_arn_return} + attrs = {'create_model.return_value': test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') @@ -417,42 +356,34 @@ def test_create_model(self, mock_client): @mock.patch.object(SageMakerHook, 'get_conn') def test_create_endpoint_config(self, mock_client): mock_session = mock.Mock() - attrs = {'create_endpoint_config.return_value': - test_arn_return} + attrs = {'create_endpoint_config.return_value': test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.create_endpoint_config(create_endpoint_config_params) - mock_session.create_endpoint_config\ - .assert_called_once_with(**create_endpoint_config_params) + mock_session.create_endpoint_config.assert_called_once_with(**create_endpoint_config_params) self.assertEqual(response, test_arn_return) @mock.patch.object(SageMakerHook, 'get_conn') def test_create_endpoint(self, mock_client): mock_session = mock.Mock() - attrs = {'create_endpoint.return_value': - test_arn_return} + attrs = {'create_endpoint.return_value': test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') - response = hook.create_endpoint(create_endpoint_params, - wait_for_completion=False) - mock_session.create_endpoint\ - .assert_called_once_with(**create_endpoint_params) + response = hook.create_endpoint(create_endpoint_params, wait_for_completion=False) + mock_session.create_endpoint.assert_called_once_with(**create_endpoint_params) self.assertEqual(response, test_arn_return) @mock.patch.object(SageMakerHook, 'get_conn') def test_update_endpoint(self, mock_client): mock_session = mock.Mock() - attrs = {'update_endpoint.return_value': - test_arn_return} + attrs = {'update_endpoint.return_value': test_arn_return} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') - response = hook.update_endpoint(update_endpoint_params, - wait_for_completion=False) - mock_session.update_endpoint\ - .assert_called_once_with(**update_endpoint_params) + response = hook.update_endpoint(update_endpoint_params, wait_for_completion=False) + mock_session.update_endpoint.assert_called_once_with(**update_endpoint_params) self.assertEqual(response, test_arn_return) @mock.patch.object(SageMakerHook, 'get_conn') @@ -463,83 +394,76 @@ def test_describe_training_job(self, mock_client): mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.describe_training_job(job_name) - mock_session.describe_training_job.\ - assert_called_once_with(TrainingJobName=job_name) + mock_session.describe_training_job.assert_called_once_with(TrainingJobName=job_name) 
self.assertEqual(response, 'InProgress') @mock.patch.object(SageMakerHook, 'get_conn') def test_describe_tuning_job(self, mock_client): mock_session = mock.Mock() - attrs = {'describe_hyper_parameter_tuning_job.return_value': - 'InProgress'} + attrs = {'describe_hyper_parameter_tuning_job.return_value': 'InProgress'} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.describe_tuning_job(job_name) - mock_session.describe_hyper_parameter_tuning_job.\ - assert_called_once_with(HyperParameterTuningJobName=job_name) + mock_session.describe_hyper_parameter_tuning_job.assert_called_once_with( + HyperParameterTuningJobName=job_name + ) self.assertEqual(response, 'InProgress') @mock.patch.object(SageMakerHook, 'get_conn') def test_describe_transform_job(self, mock_client): mock_session = mock.Mock() - attrs = {'describe_transform_job.return_value': - 'InProgress'} + attrs = {'describe_transform_job.return_value': 'InProgress'} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.describe_transform_job(job_name) - mock_session.describe_transform_job.\ - assert_called_once_with(TransformJobName=job_name) + mock_session.describe_transform_job.assert_called_once_with(TransformJobName=job_name) self.assertEqual(response, 'InProgress') @mock.patch.object(SageMakerHook, 'get_conn') def test_describe_model(self, mock_client): mock_session = mock.Mock() - attrs = {'describe_model.return_value': - model_name} + attrs = {'describe_model.return_value': model_name} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.describe_model(model_name) - mock_session.describe_model.\ - assert_called_once_with(ModelName=model_name) + mock_session.describe_model.assert_called_once_with(ModelName=model_name) self.assertEqual(response, model_name) @mock.patch.object(SageMakerHook, 'get_conn') def test_describe_endpoint_config(self, mock_client): mock_session = mock.Mock() - attrs = {'describe_endpoint_config.return_value': - config_name} + attrs = {'describe_endpoint_config.return_value': config_name} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.describe_endpoint_config(config_name) - mock_session.describe_endpoint_config.\ - assert_called_once_with(EndpointConfigName=config_name) + mock_session.describe_endpoint_config.assert_called_once_with(EndpointConfigName=config_name) self.assertEqual(response, config_name) @mock.patch.object(SageMakerHook, 'get_conn') def test_describe_endpoint(self, mock_client): mock_session = mock.Mock() - attrs = {'describe_endpoint.return_value': - 'InProgress'} + attrs = {'describe_endpoint.return_value': 'InProgress'} mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') response = hook.describe_endpoint(endpoint_name) - mock_session.describe_endpoint.\ - assert_called_once_with(EndpointName=endpoint_name) + mock_session.describe_endpoint.assert_called_once_with(EndpointName=endpoint_name) self.assertEqual(response, 'InProgress') def test_secondary_training_status_changed_true(self): - changed = secondary_training_status_changed(SECONDARY_STATUS_DESCRIPTION_1, - SECONDARY_STATUS_DESCRIPTION_2) + changed = 
secondary_training_status_changed( + SECONDARY_STATUS_DESCRIPTION_1, SECONDARY_STATUS_DESCRIPTION_2 + ) self.assertTrue(changed) def test_secondary_training_status_changed_false(self): - changed = secondary_training_status_changed(SECONDARY_STATUS_DESCRIPTION_1, - SECONDARY_STATUS_DESCRIPTION_1) + changed = secondary_training_status_changed( + SECONDARY_STATUS_DESCRIPTION_1, SECONDARY_STATUS_DESCRIPTION_1 + ) self.assertFalse(changed) def test_secondary_training_status_message_status_changed(self): @@ -548,11 +472,12 @@ def test_secondary_training_status_message_status_changed(self): expected = '{} {} - {}'.format( datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime('%Y-%m-%d %H:%M:%S'), status, - message + message, ) self.assertEqual( secondary_training_status_message(SECONDARY_STATUS_DESCRIPTION_1, SECONDARY_STATUS_DESCRIPTION_2), - expected) + expected, + ) @mock.patch.object(AwsLogsHook, 'get_conn') @mock.patch.object(SageMakerHook, 'get_conn') @@ -560,27 +485,26 @@ def test_secondary_training_status_message_status_changed(self): def test_describe_training_job_with_logs_in_progress(self, mock_time, mock_client, mock_log_client): mock_session = mock.Mock() mock_log_session = mock.Mock() - attrs = {'describe_training_job.return_value': - DESCRIBE_TRAINING_COMPLETED_RETURN - } - log_attrs = {'describe_log_streams.side_effect': - LIFECYCLE_LOG_STREAMS, - 'get_log_events.side_effect': - STREAM_LOG_EVENTS - } + attrs = {'describe_training_job.return_value': DESCRIBE_TRAINING_COMPLETED_RETURN} + log_attrs = { + 'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS, + 'get_log_events.side_effect': STREAM_LOG_EVENTS, + } mock_time.return_value = 50 mock_session.configure_mock(**attrs) mock_client.return_value = mock_session mock_log_session.configure_mock(**log_attrs) mock_log_client.return_value = mock_log_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') - response = hook.describe_training_job_with_log(job_name=job_name, - positions={}, - stream_names=[], - instance_count=1, - state=LogState.WAIT_IN_PROGRESS, - last_description={}, - last_describe_job_call=0) + response = hook.describe_training_job_with_log( + job_name=job_name, + positions={}, + stream_names=[], + instance_count=1, + state=LogState.WAIT_IN_PROGRESS, + last_description={}, + last_describe_job_call=0, + ) self.assertEqual(response, (LogState.JOB_COMPLETE, {}, 50)) @mock.patch.object(AwsLogsHook, 'get_conn') @@ -588,26 +512,25 @@ def test_describe_training_job_with_logs_in_progress(self, mock_time, mock_clien def test_describe_training_job_with_logs_job_complete(self, mock_client, mock_log_client): mock_session = mock.Mock() mock_log_session = mock.Mock() - attrs = {'describe_training_job.return_value': - DESCRIBE_TRAINING_COMPLETED_RETURN - } - log_attrs = {'describe_log_streams.side_effect': - LIFECYCLE_LOG_STREAMS, - 'get_log_events.side_effect': - STREAM_LOG_EVENTS - } + attrs = {'describe_training_job.return_value': DESCRIBE_TRAINING_COMPLETED_RETURN} + log_attrs = { + 'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS, + 'get_log_events.side_effect': STREAM_LOG_EVENTS, + } mock_session.configure_mock(**attrs) mock_client.return_value = mock_session mock_log_session.configure_mock(**log_attrs) mock_log_client.return_value = mock_log_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') - response = hook.describe_training_job_with_log(job_name=job_name, - positions={}, - stream_names=[], - instance_count=1, - state=LogState.JOB_COMPLETE, - last_description={}, - 
last_describe_job_call=0) + response = hook.describe_training_job_with_log( + job_name=job_name, + positions={}, + stream_names=[], + instance_count=1, + state=LogState.JOB_COMPLETE, + last_description={}, + last_describe_job_call=0, + ) self.assertEqual(response, (LogState.COMPLETE, {}, 0)) @mock.patch.object(AwsLogsHook, 'get_conn') @@ -615,26 +538,25 @@ def test_describe_training_job_with_logs_job_complete(self, mock_client, mock_lo def test_describe_training_job_with_logs_complete(self, mock_client, mock_log_client): mock_session = mock.Mock() mock_log_session = mock.Mock() - attrs = {'describe_training_job.return_value': - DESCRIBE_TRAINING_COMPLETED_RETURN - } - log_attrs = {'describe_log_streams.side_effect': - LIFECYCLE_LOG_STREAMS, - 'get_log_events.side_effect': - STREAM_LOG_EVENTS - } + attrs = {'describe_training_job.return_value': DESCRIBE_TRAINING_COMPLETED_RETURN} + log_attrs = { + 'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS, + 'get_log_events.side_effect': STREAM_LOG_EVENTS, + } mock_session.configure_mock(**attrs) mock_client.return_value = mock_session mock_log_session.configure_mock(**log_attrs) mock_log_client.return_value = mock_log_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id') - response = hook.describe_training_job_with_log(job_name=job_name, - positions={}, - stream_names=[], - instance_count=1, - state=LogState.COMPLETE, - last_description={}, - last_describe_job_call=0) + response = hook.describe_training_job_with_log( + job_name=job_name, + positions={}, + stream_names=[], + instance_count=1, + state=LogState.COMPLETE, + last_description={}, + last_describe_job_call=0, + ) self.assertEqual(response, (LogState.COMPLETE, {}, 0)) @mock.patch.object(SageMakerHook, 'check_training_config') @@ -643,28 +565,28 @@ def test_describe_training_job_with_logs_complete(self, mock_client, mock_log_cl @mock.patch.object(SageMakerHook, 'describe_training_job_with_log') def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, mock_check_training): mock_check_training.return_value = True - mock_describe.side_effect = \ - [(LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0), - (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0), - (LogState.COMPLETE, DESCRIBE_TRAINING_COMPLETED_RETURN, 0)] + mock_describe.side_effect = [ + (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0), + (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0), + (LogState.COMPLETE, DESCRIBE_TRAINING_COMPLETED_RETURN, 0), + ] mock_session = mock.Mock() mock_log_session = mock.Mock() - attrs = {'create_training_job.return_value': - test_arn_return, - 'describe_training_job.return_value': - DESCRIBE_TRAINING_COMPLETED_RETURN - } - log_attrs = {'describe_log_streams.side_effect': - LIFECYCLE_LOG_STREAMS, - 'get_log_events.side_effect': - STREAM_LOG_EVENTS - } + attrs = { + 'create_training_job.return_value': test_arn_return, + 'describe_training_job.return_value': DESCRIBE_TRAINING_COMPLETED_RETURN, + } + log_attrs = { + 'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS, + 'get_log_events.side_effect': STREAM_LOG_EVENTS, + } mock_session.configure_mock(**attrs) mock_log_session.configure_mock(**log_attrs) mock_client.return_value = mock_session mock_log_client.return_value = mock_log_session hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1') - hook.create_training_job(create_training_params, wait_for_completion=True, - print_log=True, check_interval=1) + hook.create_training_job( + 
create_training_params, wait_for_completion=True, print_log=True, check_interval=1 + ) self.assertEqual(mock_describe.call_count, 3) self.assertEqual(mock_session.describe_training_job.call_count, 1) diff --git a/tests/providers/amazon/aws/hooks/test_ses.py b/tests/providers/amazon/aws/hooks/test_ses.py index 555755d255f64..06d56f908d992 100644 --- a/tests/providers/amazon/aws/hooks/test_ses.py +++ b/tests/providers/amazon/aws/hooks/test_ses.py @@ -31,24 +31,15 @@ def test_get_conn(): @mock_ses -@pytest.mark.parametrize('to', - [ - 'to@domain.com', - ['to1@domain.com', 'to2@domain.com'], - 'to1@domain.com,to2@domain.com' - ]) -@pytest.mark.parametrize('cc', - [ - 'cc@domain.com', - ['cc1@domain.com', 'cc2@domain.com'], - 'cc1@domain.com,cc2@domain.com' - ]) -@pytest.mark.parametrize('bcc', - [ - 'bcc@domain.com', - ['bcc1@domain.com', 'bcc2@domain.com'], - 'bcc1@domain.com,bcc2@domain.com' - ]) +@pytest.mark.parametrize( + 'to', ['to@domain.com', ['to1@domain.com', 'to2@domain.com'], 'to1@domain.com,to2@domain.com'] +) +@pytest.mark.parametrize( + 'cc', ['cc@domain.com', ['cc1@domain.com', 'cc2@domain.com'], 'cc1@domain.com,cc2@domain.com'] +) +@pytest.mark.parametrize( + 'bcc', ['bcc@domain.com', ['bcc1@domain.com', 'bcc2@domain.com'], 'bcc1@domain.com,bcc2@domain.com'] +) def test_send_email(to, cc, bcc): # Given hook = SESHook() diff --git a/tests/providers/amazon/aws/hooks/test_sns.py b/tests/providers/amazon/aws/hooks/test_sns.py index bbec25ebd42b5..f14fde6b9ac71 100644 --- a/tests/providers/amazon/aws/hooks/test_sns.py +++ b/tests/providers/amazon/aws/hooks/test_sns.py @@ -29,7 +29,6 @@ @unittest.skipIf(mock_sns is None, 'moto package not present') class TestAwsSnsHook(unittest.TestCase): - @mock_sns def test_get_conn_returns_a_boto3_connection(self): hook = AwsSnsHook(aws_conn_id='aws_default') @@ -56,12 +55,16 @@ def test_publish_to_target_with_attributes(self): topic_name = "test-topic" target = hook.get_conn().create_topic(Name=topic_name).get('TopicArn') - response = hook.publish_to_target(target, message, message_attributes={ - 'test-string': 'string-value', - 'test-number': 123456, - 'test-array': ['first', 'second', 'third'], - 'test-binary': b'binary-value', - }) + response = hook.publish_to_target( + target, + message, + message_attributes={ + 'test-string': 'string-value', + 'test-number': 123456, + 'test-array': ['first', 'second', 'third'], + 'test-binary': b'binary-value', + }, + ) assert 'MessageId' in response diff --git a/tests/providers/amazon/aws/hooks/test_sqs.py b/tests/providers/amazon/aws/hooks/test_sqs.py index 5f39c57365667..4b85af9d7cc15 100644 --- a/tests/providers/amazon/aws/hooks/test_sqs.py +++ b/tests/providers/amazon/aws/hooks/test_sqs.py @@ -29,7 +29,6 @@ @unittest.skipIf(mock_sqs is None, 'moto sqs package missing') class TestAwsSQSHook(unittest.TestCase): - @mock_sqs def test_get_conn(self): hook = SQSHook(aws_conn_id='aws_default') diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py b/tests/providers/amazon/aws/hooks/test_step_function.py index 679d2e44039ad..82b7f1fe2c9f5 100644 --- a/tests/providers/amazon/aws/hooks/test_step_function.py +++ b/tests/providers/amazon/aws/hooks/test_step_function.py @@ -29,7 +29,6 @@ @unittest.skipIf(mock_stepfunctions is None, 'moto package not present') class TestStepFunctionHook(unittest.TestCase): - @mock_stepfunctions def test_get_conn_returns_a_boto3_connection(self): hook = StepFunctionHook(aws_conn_id='aws_default') @@ -39,12 +38,14 @@ def test_get_conn_returns_a_boto3_connection(self): def 
test_start_execution(self): hook = StepFunctionHook(aws_conn_id='aws_default', region_name='us-east-1') state_machine = hook.get_conn().create_state_machine( - name='pseudo-state-machine', definition='{}', roleArn='arn:aws:iam::000000000000:role/Role') + name='pseudo-state-machine', definition='{}', roleArn='arn:aws:iam::000000000000:role/Role' + ) state_machine_arn = state_machine.get('stateMachineArn', None) execution_arn = hook.start_execution( - state_machine_arn=state_machine_arn, name=None, state_machine_input={}) + state_machine_arn=state_machine_arn, name=None, state_machine_input={} + ) assert execution_arn is not None @@ -52,12 +53,14 @@ def test_start_execution(self): def test_describe_execution(self): hook = StepFunctionHook(aws_conn_id='aws_default', region_name='us-east-1') state_machine = hook.get_conn().create_state_machine( - name='pseudo-state-machine', definition='{}', roleArn='arn:aws:iam::000000000000:role/Role') + name='pseudo-state-machine', definition='{}', roleArn='arn:aws:iam::000000000000:role/Role' + ) state_machine_arn = state_machine.get('stateMachineArn', None) execution_arn = hook.start_execution( - state_machine_arn=state_machine_arn, name=None, state_machine_input={}) + state_machine_arn=state_machine_arn, name=None, state_machine_input={} + ) response = hook.describe_execution(execution_arn) assert 'input' in response diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py index ab032cd4cfbf3..7509be222b34c 100644 --- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py +++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py @@ -38,11 +38,9 @@ mock_logs = None -@unittest.skipIf(mock_logs is None, - "Skipping test because moto.mock_logs is not available") +@unittest.skipIf(mock_logs is None, "Skipping test because moto.mock_logs is not available") @mock_logs class TestCloudwatchTaskHandler(unittest.TestCase): - @conf_vars({('logging', 'remote_log_conn_id'): 'aws_default'}) def setUp(self): self.remote_log_group = 'log_group_name' @@ -52,7 +50,7 @@ def setUp(self): self.cloudwatch_task_handler = CloudwatchTaskHandler( self.local_log_location, "arn:aws:logs:{}:11111111:log-group:{}".format(self.region_name, self.remote_log_group), - self.filename_template + self.filename_template, ) self.cloudwatch_task_handler.hook @@ -83,7 +81,7 @@ def test_hook_raises(self): handler = CloudwatchTaskHandler( self.local_log_location, "arn:aws:logs:{}:11111111:log-group:{}".format(self.region_name, self.remote_log_group), - self.filename_template + self.filename_template, ) with mock.patch.object(handler.log, 'error') as mock_error: @@ -95,7 +93,7 @@ def test_hook_raises(self): mock_error.assert_called_once_with( 'Could not create an AwsLogsHook with connection id "%s". 
Please make ' 'sure that airflow[aws] is installed and the Cloudwatch logs connection exists.', - 'aws_default' + 'aws_default', ) def test_handler(self): @@ -118,26 +116,18 @@ def test_read(self): self.remote_log_group, self.remote_log_stream, [ - { - 'timestamp': 20000, - 'message': 'Second' - }, - { - 'timestamp': 10000, - 'message': 'First' - }, - { - 'timestamp': 30000, - 'message': 'Third' - }, - ] + {'timestamp': 20000, 'message': 'Second'}, + {'timestamp': 10000, 'message': 'First'}, + {'timestamp': 30000, 'message': 'Third'}, + ], ) - expected = '*** Reading remote log from Cloudwatch log_group: {} ' \ - 'log_stream: {}.\nFirst\nSecond\nThird\n' + expected = ( + '*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\nFirst\nSecond\nThird\n' + ) self.assertEqual( self.cloudwatch_task_handler.read(self.ti), - ([expected.format(self.remote_log_group, self.remote_log_stream)], [{'end_of_log': True}]) + ([expected.format(self.remote_log_group, self.remote_log_stream)], [{'end_of_log': True}]), ) def test_read_wrong_log_stream(self): @@ -146,33 +136,22 @@ def test_read_wrong_log_stream(self): self.remote_log_group, 'alternate_log_stream', [ - { - 'timestamp': 20000, - 'message': 'Second' - }, - { - 'timestamp': 10000, - 'message': 'First' - }, - { - 'timestamp': 30000, - 'message': 'Third' - }, - ] + {'timestamp': 20000, 'message': 'Second'}, + {'timestamp': 10000, 'message': 'First'}, + {'timestamp': 30000, 'message': 'Third'}, + ], ) - msg_template = '*** Reading remote log from Cloudwatch log_group: {} ' \ - 'log_stream: {}.\n{}\n' + msg_template = '*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}\n' error_msg = 'Could not read remote logs from log_group: {} log_stream: {}.'.format( - self.remote_log_group, - self.remote_log_stream + self.remote_log_group, self.remote_log_stream ) self.assertEqual( self.cloudwatch_task_handler.read(self.ti), - ([msg_template.format( - self.remote_log_group, - self.remote_log_stream, error_msg)], - [{'end_of_log': True}]) + ( + [msg_template.format(self.remote_log_group, self.remote_log_stream, error_msg)], + [{'end_of_log': True}], + ), ) def test_read_wrong_log_group(self): @@ -181,32 +160,22 @@ def test_read_wrong_log_group(self): 'alternate_log_group', self.remote_log_stream, [ - { - 'timestamp': 20000, - 'message': 'Second' - }, - { - 'timestamp': 10000, - 'message': 'First' - }, - { - 'timestamp': 30000, - 'message': 'Third' - }, - ] + {'timestamp': 20000, 'message': 'Second'}, + {'timestamp': 10000, 'message': 'First'}, + {'timestamp': 30000, 'message': 'Third'}, + ], ) msg_template = '*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}\n' error_msg = 'Could not read remote logs from log_group: {} log_stream: {}.'.format( - self.remote_log_group, - self.remote_log_stream + self.remote_log_group, self.remote_log_stream ) self.assertEqual( self.cloudwatch_task_handler.read(self.ti), - ([msg_template.format( - self.remote_log_group, - self.remote_log_stream, error_msg)], - [{'end_of_log': True}]) + ( + [msg_template.format(self.remote_log_group, self.remote_log_stream, error_msg)], + [{'end_of_log': True}], + ), ) def test_close_prevents_duplicate_calls(self): @@ -221,12 +190,5 @@ def test_close_prevents_duplicate_calls(self): def generate_log_events(conn, log_group_name, log_stream_name, log_events): conn.create_log_group(logGroupName=log_group_name) - conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name - ) - conn.put_log_events( - 
logGroupName=log_group_name, - logStreamName=log_stream_name, - logEvents=log_events - ) + conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) + conn.put_log_events(logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=log_events) diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py b/tests/providers/amazon/aws/log/test_s3_task_handler.py index ff2ae6517e595..4c737e05da324 100644 --- a/tests/providers/amazon/aws/log/test_s3_task_handler.py +++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py @@ -36,11 +36,9 @@ mock_s3 = None -@unittest.skipIf(mock_s3 is None, - "Skipping test because moto.mock_s3 is not available") +@unittest.skipIf(mock_s3 is None, "Skipping test because moto.mock_s3 is not available") @mock_s3 class TestS3TaskHandler(unittest.TestCase): - @conf_vars({('logging', 'remote_log_conn_id'): 'aws_default'}) def setUp(self): super().setUp() @@ -50,9 +48,7 @@ def setUp(self): self.local_log_location = 'local/log/location' self.filename_template = '{try_number}.log' self.s3_task_handler = S3TaskHandler( - self.local_log_location, - self.remote_log_base, - self.filename_template + self.local_log_location, self.remote_log_base, self.filename_template ) # Verify the hook now with the config override assert self.s3_task_handler.hook is not None @@ -83,11 +79,7 @@ def test_hook(self): @conf_vars({('logging', 'remote_log_conn_id'): 'aws_default'}) def test_hook_raises(self): - handler = S3TaskHandler( - self.local_log_location, - self.remote_log_base, - self.filename_template - ) + handler = S3TaskHandler(self.local_log_location, self.remote_log_base, self.filename_template) with mock.patch.object(handler.log, 'error') as mock_error: with mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook") as mock_hook: mock_hook.side_effect = Exception('Failed to connect') @@ -138,8 +130,10 @@ def test_read(self): self.conn.put_object(Bucket='bucket', Key=self.remote_log_key, Body=b'Log line\n') self.assertEqual( self.s3_task_handler.read(self.ti), - (['*** Reading remote log from s3://bucket/remote/log/location/1.log.\n' - 'Log line\n\n'], [{'end_of_log': True}]) + ( + ['*** Reading remote log from s3://bucket/remote/log/location/1.log.\nLog line\n\n'], + [{'end_of_log': True}], + ), ) def test_read_when_s3_log_missing(self): @@ -164,16 +158,24 @@ def test_write(self): self.s3_task_handler.s3_write('text', self.remote_log_location) # We shouldn't expect any error logs in the default working case.
mock_error.assert_not_called() - body = boto3.resource('s3').Object( # pylint: disable=no-member - 'bucket', self.remote_log_key).get()['Body'].read() + body = ( + boto3.resource('s3') + .Object('bucket', self.remote_log_key) # pylint: disable=no-member + .get()['Body'] + .read() + ) self.assertEqual(body, b'text') def test_write_existing(self): self.conn.put_object(Bucket='bucket', Key=self.remote_log_key, Body=b'previous ') self.s3_task_handler.s3_write('text', self.remote_log_location) - body = boto3.resource('s3').Object( # pylint: disable=no-member - 'bucket', self.remote_log_key).get()['Body'].read() + body = ( + boto3.resource('s3') + .Object('bucket', self.remote_log_key) # pylint: disable=no-member + .get()['Body'] + .read() + ) self.assertEqual(body, b'previous \ntext') @@ -183,8 +185,7 @@ def test_write_raises(self): with mock.patch.object(handler.log, 'error') as mock_error: handler.s3_write('text', url) self.assertEqual - mock_error.assert_called_once_with( - 'Could not write logs to %s', url, exc_info=True) + mock_error.assert_called_once_with('Could not write logs to %s', url, exc_info=True) def test_close(self): self.s3_task_handler.set_context(self.ti) @@ -192,8 +193,7 @@ def test_close(self): self.s3_task_handler.close() # Should not raise - boto3.resource('s3').Object( # pylint: disable=no-member - 'bucket', self.remote_log_key).get() + boto3.resource('s3').Object('bucket', self.remote_log_key).get() # pylint: disable=no-member def test_close_no_upload(self): self.ti.raw = True @@ -202,5 +202,4 @@ def test_close_no_upload(self): self.s3_task_handler.close() with self.assertRaises(self.conn.exceptions.NoSuchKey): - boto3.resource('s3').Object( # pylint: disable=no-member - 'bucket', self.remote_log_key).get() + boto3.resource('s3').Object('bucket', self.remote_log_key).get() # pylint: disable=no-member diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 15bcfcfac4698..5357c4c4cc27b 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -35,33 +35,32 @@ 'database': 'TEST_DATABASE', 'outputLocation': 's3://test_s3_bucket/', 'client_request_token': 'eac427d0-1c6d-4dfb-96aa-2835d3ac6595', - 'workgroup': 'primary' + 'workgroup': 'primary', } -query_context = { - 'Database': MOCK_DATA['database'] -} -result_configuration = { - 'OutputLocation': MOCK_DATA['outputLocation'] -} +query_context = {'Database': MOCK_DATA['database']} +result_configuration = {'OutputLocation': MOCK_DATA['outputLocation']} # pylint: disable=unused-argument class TestAWSAthenaOperator(unittest.TestCase): - def setUp(self): args = { 'owner': 'airflow', 'start_date': DEFAULT_DATE, } - self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', - default_args=args, - schedule_interval='@once') - self.athena = AWSAthenaOperator(task_id='test_aws_athena_operator', query='SELECT * FROM TEST_TABLE', - database='TEST_DATABASE', output_location='s3://test_s3_bucket/', - client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595', - sleep_time=0, max_tries=3, dag=self.dag) + self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args, schedule_interval='@once') + self.athena = AWSAthenaOperator( + task_id='test_aws_athena_operator', + query='SELECT * FROM TEST_TABLE', + database='TEST_DATABASE', + output_location='s3://test_s3_bucket/', + client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595', + sleep_time=0, + max_tries=3, + dag=self.dag, + ) def 
test_init(self): self.assertEqual(self.athena.task_id, MOCK_DATA['task_id']) @@ -78,8 +77,13 @@ def test_init(self): @mock.patch.object(AWSAthenaHook, 'get_conn') def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_check_query_status): self.athena.execute(None) - mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration, - MOCK_DATA['client_request_token'], MOCK_DATA['workgroup']) + mock_run_query.assert_called_once_with( + MOCK_DATA['query'], + query_context, + result_configuration, + MOCK_DATA['client_request_token'], + MOCK_DATA['workgroup'], + ) self.assertEqual(mock_check_query_status.call_count, 1) @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "SUCCESS",)) @@ -87,8 +91,13 @@ def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_chec @mock.patch.object(AWSAthenaHook, 'get_conn') def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_query_status): self.athena.execute(None) - mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration, - MOCK_DATA['client_request_token'], MOCK_DATA['workgroup']) + mock_run_query.assert_called_once_with( + MOCK_DATA['query'], + query_context, + result_configuration, + MOCK_DATA['client_request_token'], + MOCK_DATA['workgroup'], + ) self.assertEqual(mock_check_query_status.call_count, 3) @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=(None, None,)) @@ -97,20 +106,31 @@ def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_ def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, mock_check_query_status): with self.assertRaises(Exception): self.athena.execute(None) - mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration, - MOCK_DATA['client_request_token'], MOCK_DATA['workgroup']) + mock_run_query.assert_called_once_with( + MOCK_DATA['query'], + query_context, + result_configuration, + MOCK_DATA['client_request_token'], + MOCK_DATA['workgroup'], + ) self.assertEqual(mock_check_query_status.call_count, 3) @mock.patch.object(AWSAthenaHook, 'get_state_change_reason') @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "FAILED",)) @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) @mock.patch.object(AWSAthenaHook, 'get_conn') - def test_hook_run_failure_query(self, mock_conn, mock_run_query, mock_check_query_status, - mock_get_state_change_reason): + def test_hook_run_failure_query( + self, mock_conn, mock_run_query, mock_check_query_status, mock_get_state_change_reason + ): with self.assertRaises(Exception): self.athena.execute(None) - mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration, - MOCK_DATA['client_request_token'], MOCK_DATA['workgroup']) + mock_run_query.assert_called_once_with( + MOCK_DATA['query'], + query_context, + result_configuration, + MOCK_DATA['client_request_token'], + MOCK_DATA['workgroup'], + ) self.assertEqual(mock_check_query_status.call_count, 2) self.assertEqual(mock_get_state_change_reason.call_count, 1) @@ -120,8 +140,13 @@ def test_hook_run_failure_query(self, mock_conn, mock_run_query, mock_check_quer def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_query_status): with self.assertRaises(Exception): self.athena.execute(None) - mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, 
result_configuration, - MOCK_DATA['client_request_token'], MOCK_DATA['workgroup']) + mock_run_query.assert_called_once_with( + MOCK_DATA['query'], + query_context, + result_configuration, + MOCK_DATA['client_request_token'], + MOCK_DATA['workgroup'], + ) self.assertEqual(mock_check_query_status.call_count, 3) @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "RUNNING",)) @@ -130,8 +155,13 @@ def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_qu def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, mock_check_query_status): with self.assertRaises(Exception): self.athena.execute(None) - mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration, - MOCK_DATA['client_request_token'], MOCK_DATA['workgroup']) + mock_run_query.assert_called_once_with( + MOCK_DATA['query'], + query_context, + result_configuration, + MOCK_DATA['client_request_token'], + MOCK_DATA['workgroup'], + ) self.assertEqual(mock_check_query_status.call_count, 3) @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",)) @@ -141,6 +171,7 @@ def test_xcom_push_and_pull(self, mock_conn, mock_run_query, mock_check_query_st ti = TaskInstance(task=self.athena, execution_date=timezone.utcnow()) ti.run() - self.assertEqual(ti.xcom_pull(task_ids='test_aws_athena_operator'), - ATHENA_QUERY_ID) + self.assertEqual(ti.xcom_pull(task_ids='test_aws_athena_operator'), ATHENA_QUERY_ID) + + # pylint: enable=unused-argument diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 45c94c82457b1..395a21d27a97b 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -93,9 +93,7 @@ def test_init(self): self.assertEqual(self.batch.hook.aws_conn_id, "airflow_test") self.assertEqual(self.batch.hook.client, self.client_mock) - self.get_client_type_mock.assert_called_once_with( - "batch", region_name="eu-west-1" - ) + self.get_client_type_mock.assert_called_once_with("batch", region_name="eu-west-1") def test_template_fields_overrides(self): self.assertEqual(self.batch.template_fields, ("job_name", "overrides", "parameters",)) @@ -144,9 +142,7 @@ def test_wait_job_complete_using_waiters(self, check_mock): self.batch.waiters = mock_waiters self.client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES - self.client_mock.describe_jobs.return_value = { - "jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}] - } + self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]} self.batch.execute(None) mock_waiters.wait_for_job.assert_called_once_with(JOB_ID) @@ -155,6 +151,4 @@ def test_wait_job_complete_using_waiters(self, check_mock): def test_kill_job(self): self.client_mock.terminate_job.return_value = {} self.batch.on_kill() - self.client_mock.terminate_job.assert_called_once_with( - jobId=JOB_ID, reason="Task killed by the user" - ) + self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason="Task killed by the user") diff --git a/tests/providers/amazon/aws/operators/test_cloud_formation.py b/tests/providers/amazon/aws/operators/test_cloud_formation.py index 941182b53fd58..16557f218c3ed 100644 --- a/tests/providers/amazon/aws/operators/test_cloud_formation.py +++ b/tests/providers/amazon/aws/operators/test_cloud_formation.py @@ -20,7 +20,8 @@ from airflow.models.dag import DAG from 
airflow.providers.amazon.aws.operators.cloud_formation import ( - CloudFormationCreateStackOperator, CloudFormationDeleteStackOperator, + CloudFormationCreateStackOperator, + CloudFormationDeleteStackOperator, ) from airflow.utils import timezone @@ -28,12 +29,8 @@ class TestCloudFormationCreateStackOperator(unittest.TestCase): - def setUp(self): - self.args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} # Mock out the cloudformation_client (moto fails with an exception). self.cloudformation_client_mock = MagicMock() @@ -53,28 +50,21 @@ def test_create_stack(self): operator = CloudFormationCreateStackOperator( task_id='test_task', stack_name=stack_name, - params={ - 'TimeoutInMinutes': timeout, - 'TemplateBody': template_body - }, + params={'TimeoutInMinutes': timeout, 'TemplateBody': template_body}, dag=DAG('test_dag_id', default_args=self.args), ) with patch('boto3.session.Session', self.boto3_session_mock): operator.execute(self.mock_context) - self.cloudformation_client_mock.create_stack.assert_any_call(StackName=stack_name, - TemplateBody=template_body, - TimeoutInMinutes=timeout) + self.cloudformation_client_mock.create_stack.assert_any_call( + StackName=stack_name, TemplateBody=template_body, TimeoutInMinutes=timeout + ) class TestCloudFormationDeleteStackOperator(unittest.TestCase): - def setUp(self): - self.args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} # Mock out the cloudformation_client (moto fails with an exception). self.cloudformation_client_mock = MagicMock() @@ -90,9 +80,7 @@ def test_delete_stack(self): stack_name = "myStackToBeDeleted" operator = CloudFormationDeleteStackOperator( - task_id='test_task', - stack_name=stack_name, - dag=DAG('test_dag_id', default_args=self.args), + task_id='test_task', stack_name=stack_name, dag=DAG('test_dag_id', default_args=self.args), ) with patch('boto3.session.Session', self.boto3_session_mock): diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py index 94bdd9922b079..f6dd924089dab 100644 --- a/tests/providers/amazon/aws/operators/test_datasync.py +++ b/tests/providers/amazon/aws/operators/test_datasync.py @@ -82,9 +82,7 @@ def no_datasync(x): @mock_datasync @mock.patch.object(AWSDataSyncHook, "get_conn") -@unittest.skipIf( - mock_datasync == no_datasync, "moto datasync package missing" -) # pylint: disable=W0143 +@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143 class AWSDataSyncTestCaseBase(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -97,11 +95,7 @@ def setUp(self): "start_date": DEFAULT_DATE, } - self.dag = DAG( - TEST_DAG_ID + "test_schedule_dag_once", - default_args=args, - schedule_interval="@once", - ) + self.dag = DAG(TEST_DAG_ID + "test_schedule_dag_once", default_args=args, schedule_interval="@once",) self.client = boto3.client("datasync", region_name="us-east-1") @@ -112,8 +106,7 @@ def setUp(self): **MOCK_DATA["create_destination_location_kwargs"] )["LocationArn"] self.task_arn = self.client.create_task( - SourceLocationArn=self.source_location_arn, - DestinationLocationArn=self.destination_location_arn, + SourceLocationArn=self.source_location_arn, DestinationLocationArn=self.destination_location_arn, )["TaskArn"] def tearDown(self): @@ -130,16 +123,14 @@ def tearDown(self): @mock_datasync 
@mock.patch.object(AWSDataSyncHook, "get_conn") -@unittest.skipIf( - mock_datasync == no_datasync, "moto datasync package missing" -) # pylint: disable=W0143 +@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143 class TestAWSDataSyncOperatorCreate(AWSDataSyncTestCaseBase): def set_up_operator( self, task_arn=None, source_location_uri=SOURCE_LOCATION_URI, destination_location_uri=DESTINATION_LOCATION_URI, - allow_random_location_choice=False + allow_random_location_choice=False, ): # Create operator self.datasync = AWSDataSyncOperator( @@ -171,31 +162,20 @@ def test_init(self, mock_get_conn): # Defaults self.assertEqual(self.datasync.aws_conn_id, "aws_default") self.assertFalse(self.datasync.allow_random_task_choice) - self.assertFalse( # Empty dict - self.datasync.task_execution_kwargs - ) + self.assertFalse(self.datasync.task_execution_kwargs) # Empty dict # Assignments + self.assertEqual(self.datasync.source_location_uri, MOCK_DATA["source_location_uri"]) self.assertEqual( - self.datasync.source_location_uri, MOCK_DATA["source_location_uri"] - ) - self.assertEqual( - self.datasync.destination_location_uri, - MOCK_DATA["destination_location_uri"], - ) - self.assertEqual( - self.datasync.create_task_kwargs, MOCK_DATA["create_task_kwargs"] + self.datasync.destination_location_uri, MOCK_DATA["destination_location_uri"], ) + self.assertEqual(self.datasync.create_task_kwargs, MOCK_DATA["create_task_kwargs"]) self.assertEqual( - self.datasync.create_source_location_kwargs, - MOCK_DATA["create_source_location_kwargs"], + self.datasync.create_source_location_kwargs, MOCK_DATA["create_source_location_kwargs"], ) self.assertEqual( - self.datasync.create_destination_location_kwargs, - MOCK_DATA["create_destination_location_kwargs"], - ) - self.assertFalse( - self.datasync.allow_random_location_choice + self.datasync.create_destination_location_kwargs, MOCK_DATA["create_destination_location_kwargs"], ) + self.assertFalse(self.datasync.allow_random_location_choice) # ### Check mocks: mock_get_conn.assert_not_called() @@ -209,9 +189,7 @@ def test_init_fails(self, mock_get_conn): with self.assertRaises(AirflowException): self.set_up_operator(destination_location_uri=None) with self.assertRaises(AirflowException): - self.set_up_operator( - source_location_uri=None, destination_location_uri=None - ) + self.set_up_operator(source_location_uri=None, destination_location_uri=None) # ### Check mocks: mock_get_conn.assert_not_called() @@ -305,9 +283,7 @@ def create_task_many_locations(self, mock_get_conn): # ### Begin tests: # Create duplicate source location to choose from - self.client.create_location_smb( - **MOCK_DATA["create_source_location_kwargs"] - ) + self.client.create_location_smb(**MOCK_DATA["create_source_location_kwargs"]) self.set_up_operator(task_arn=self.task_arn) with self.assertRaises(AirflowException): @@ -323,8 +299,7 @@ def test_execute_specific_task(self, mock_get_conn): mock_get_conn.return_value = self.client # ### Begin tests: task_arn = self.client.create_task( - SourceLocationArn=self.source_location_arn, - DestinationLocationArn=self.destination_location_arn, + SourceLocationArn=self.source_location_arn, DestinationLocationArn=self.destination_location_arn, )["TaskArn"] self.set_up_operator(task_arn=task_arn) @@ -351,16 +326,14 @@ def test_xcom_push(self, mock_get_conn): @mock_datasync @mock.patch.object(AWSDataSyncHook, "get_conn") -@unittest.skipIf( - mock_datasync == no_datasync, "moto datasync package missing" -) # pylint: 
disable=W0143 +@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143 class TestAWSDataSyncOperatorGetTasks(AWSDataSyncTestCaseBase): def set_up_operator( self, task_arn=None, source_location_uri=SOURCE_LOCATION_URI, destination_location_uri=DESTINATION_LOCATION_URI, - allow_random_task_choice=False + allow_random_task_choice=False, ): # Create operator self.datasync = AWSDataSyncOperator( @@ -370,9 +343,7 @@ def set_up_operator( source_location_uri=source_location_uri, destination_location_uri=destination_location_uri, create_source_location_kwargs=MOCK_DATA["create_source_location_kwargs"], - create_destination_location_kwargs=MOCK_DATA[ - "create_destination_location_kwargs" - ], + create_destination_location_kwargs=MOCK_DATA["create_destination_location_kwargs"], create_task_kwargs=MOCK_DATA["create_task_kwargs"], allow_random_task_choice=allow_random_task_choice, wait_interval_seconds=0, @@ -386,12 +357,9 @@ def test_init(self, mock_get_conn): self.assertEqual(self.datasync.aws_conn_id, "aws_default") self.assertFalse(self.datasync.allow_random_location_choice) # Assignments + self.assertEqual(self.datasync.source_location_uri, MOCK_DATA["source_location_uri"]) self.assertEqual( - self.datasync.source_location_uri, MOCK_DATA["source_location_uri"] - ) - self.assertEqual( - self.datasync.destination_location_uri, - MOCK_DATA["destination_location_uri"], + self.datasync.destination_location_uri, MOCK_DATA["destination_location_uri"], ) self.assertFalse(self.datasync.allow_random_task_choice) # ### Check mocks: @@ -407,9 +375,7 @@ def test_init_fails(self, mock_get_conn): with self.assertRaises(AirflowException): self.set_up_operator(destination_location_uri=None) with self.assertRaises(AirflowException): - self.set_up_operator( - source_location_uri=None, destination_location_uri=None - ) + self.set_up_operator(source_location_uri=None, destination_location_uri=None) # ### Check mocks: mock_get_conn.assert_not_called() @@ -495,8 +461,7 @@ def test_get_many_tasks(self, mock_get_conn): self.set_up_operator() self.client.create_task( - SourceLocationArn=self.source_location_arn, - DestinationLocationArn=self.destination_location_arn, + SourceLocationArn=self.source_location_arn, DestinationLocationArn=self.destination_location_arn, ) # Check how many tasks and locations we have @@ -525,8 +490,7 @@ def test_execute_specific_task(self, mock_get_conn): mock_get_conn.return_value = self.client # ### Begin tests: task_arn = self.client.create_task( - SourceLocationArn=self.source_location_arn, - DestinationLocationArn=self.destination_location_arn, + SourceLocationArn=self.source_location_arn, DestinationLocationArn=self.destination_location_arn, )["TaskArn"] self.set_up_operator(task_arn=task_arn) @@ -545,9 +509,7 @@ def test_xcom_push(self, mock_get_conn): self.set_up_operator() ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow()) ti.run() - pushed_task_arn = ti.xcom_pull( - task_ids=self.datasync.task_id, key="return_value" - )["TaskArn"] + pushed_task_arn = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")["TaskArn"] self.assertEqual(pushed_task_arn, self.task_arn) # ### Check mocks: mock_get_conn.assert_called() @@ -555,9 +517,7 @@ def test_xcom_push(self, mock_get_conn): @mock_datasync @mock.patch.object(AWSDataSyncHook, "get_conn") -@unittest.skipIf( - mock_datasync == no_datasync, "moto datasync package missing" -) # pylint: disable=W0143 +@unittest.skipIf(mock_datasync == no_datasync, "moto datasync 
package missing") # pylint: disable=W0143 class TestAWSDataSyncOperatorUpdate(AWSDataSyncTestCaseBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -567,9 +527,7 @@ def set_up_operator(self, task_arn="self", update_task_kwargs="default"): if task_arn == "self": task_arn = self.task_arn if update_task_kwargs == "default": - update_task_kwargs = { - "Options": {"VerifyMode": "BEST_EFFORT", "Atime": "NONE"} - } + update_task_kwargs = {"Options": {"VerifyMode": "BEST_EFFORT", "Atime": "NONE"}} # Create operator self.datasync = AWSDataSyncOperator( task_id="test_aws_datasync_update_task_operator", @@ -587,9 +545,7 @@ def test_init(self, mock_get_conn): self.assertEqual(self.datasync.aws_conn_id, "aws_default") # Assignments self.assertEqual(self.datasync.task_arn, self.task_arn) - self.assertEqual( - self.datasync.update_task_kwargs, MOCK_DATA["update_task_kwargs"] - ) + self.assertEqual(self.datasync.update_task_kwargs, MOCK_DATA["update_task_kwargs"]) # ### Check mocks: mock_get_conn.assert_not_called() @@ -632,8 +588,7 @@ def test_execute_specific_task(self, mock_get_conn): mock_get_conn.return_value = self.client # ### Begin tests: task_arn = self.client.create_task( - SourceLocationArn=self.source_location_arn, - DestinationLocationArn=self.destination_location_arn, + SourceLocationArn=self.source_location_arn, DestinationLocationArn=self.destination_location_arn, )["TaskArn"] self.set_up_operator(task_arn=task_arn) @@ -652,9 +607,7 @@ def test_xcom_push(self, mock_get_conn): self.set_up_operator() ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow()) ti.run() - pushed_task_arn = ti.xcom_pull( - task_ids=self.datasync.task_id, key="return_value" - )["TaskArn"] + pushed_task_arn = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")["TaskArn"] self.assertEqual(pushed_task_arn, self.task_arn) # ### Check mocks: mock_get_conn.assert_called() @@ -662,9 +615,7 @@ def test_xcom_push(self, mock_get_conn): @mock_datasync @mock.patch.object(AWSDataSyncHook, "get_conn") -@unittest.skipIf( - mock_datasync == no_datasync, "moto datasync package missing" -) # pylint: disable=W0143 +@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143 class TestAWSDataSyncOperator(AWSDataSyncTestCaseBase): def set_up_operator(self, task_arn="self"): if task_arn == "self": @@ -727,9 +678,7 @@ def test_execute_task(self, mock_get_conn): self.assertEqual(len(locations["Locations"]), len_locations_before) # Check with the DataSync client what happened - task_execution = self.client.describe_task_execution( - TaskExecutionArn=task_execution_arn - ) + task_execution = self.client.describe_task_execution(TaskExecutionArn=task_execution_arn) self.assertEqual(task_execution["Status"], "SUCCESS") # Insist that this specific task was executed, not anything else @@ -782,9 +731,7 @@ def kill_task(*args): # Verify the task was killed task = self.client.describe_task(TaskArn=self.task_arn) self.assertEqual(task["Status"], "AVAILABLE") - task_execution = self.client.describe_task_execution( - TaskExecutionArn=task_execution_arn - ) + task_execution = self.client.describe_task_execution(TaskExecutionArn=task_execution_arn) self.assertEqual(task_execution["Status"], "ERROR") # ### Check mocks: mock_get_conn.assert_called() @@ -794,8 +741,7 @@ def test_execute_specific_task(self, mock_get_conn): mock_get_conn.return_value = self.client # ### Begin tests: task_arn = self.client.create_task( - 
SourceLocationArn=self.source_location_arn, - DestinationLocationArn=self.destination_location_arn, + SourceLocationArn=self.source_location_arn, DestinationLocationArn=self.destination_location_arn, )["TaskArn"] self.set_up_operator(task_arn=task_arn) @@ -822,9 +768,7 @@ def test_xcom_push(self, mock_get_conn): @mock_datasync @mock.patch.object(AWSDataSyncHook, "get_conn") -@unittest.skipIf( - mock_datasync == no_datasync, "moto datasync package missing" -) # pylint: disable=W0143 +@unittest.skipIf(mock_datasync == no_datasync, "moto datasync package missing") # pylint: disable=W0143 class TestAWSDataSyncOperatorDelete(AWSDataSyncTestCaseBase): def set_up_operator(self, task_arn="self"): if task_arn == "self": @@ -890,8 +834,7 @@ def test_execute_specific_task(self, mock_get_conn): mock_get_conn.return_value = self.client # ### Begin tests: task_arn = self.client.create_task( - SourceLocationArn=self.source_location_arn, - DestinationLocationArn=self.destination_location_arn, + SourceLocationArn=self.source_location_arn, DestinationLocationArn=self.destination_location_arn, )["TaskArn"] self.set_up_operator(task_arn=task_arn) @@ -910,9 +853,7 @@ def test_xcom_push(self, mock_get_conn): self.set_up_operator() ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow()) ti.run() - pushed_task_arn = ti.xcom_pull( - task_ids=self.datasync.task_id, key="return_value" - )["TaskArn"] + pushed_task_arn = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")["TaskArn"] self.assertEqual(pushed_task_arn, self.task_arn) # ### Check mocks: mock_get_conn.assert_called() diff --git a/tests/providers/amazon/aws/operators/test_ec2_start_instance.py b/tests/providers/amazon/aws/operators/test_ec2_start_instance.py index 4081d692103d8..627a10db681f1 100644 --- a/tests/providers/amazon/aws/operators/test_ec2_start_instance.py +++ b/tests/providers/amazon/aws/operators/test_ec2_start_instance.py @@ -26,7 +26,6 @@ class TestEC2Operator(unittest.TestCase): - def test_init(self): ec2_operator = EC2StartInstanceOperator( task_id="task_test", @@ -45,20 +44,11 @@ def test_init(self): def test_start_instance(self): # create instance ec2_hook = EC2Hook() - instances = ec2_hook.conn.create_instances( - MaxCount=1, - MinCount=1, - ) + instances = ec2_hook.conn.create_instances(MaxCount=1, MinCount=1,) instance_id = instances[0].instance_id # start instance - start_test = EC2StartInstanceOperator( - task_id="start_test", - instance_id=instance_id, - ) + start_test = EC2StartInstanceOperator(task_id="start_test", instance_id=instance_id,) start_test.execute(None) # assert instance state is running - self.assertEqual( - ec2_hook.get_instance_state(instance_id=instance_id), - "running" - ) + self.assertEqual(ec2_hook.get_instance_state(instance_id=instance_id), "running") diff --git a/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py b/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py index 35764c1bff860..0a23401f96081 100644 --- a/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py +++ b/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py @@ -26,7 +26,6 @@ class TestEC2Operator(unittest.TestCase): - def test_init(self): ec2_operator = EC2StopInstanceOperator( task_id="task_test", @@ -45,20 +44,11 @@ def test_init(self): def test_stop_instance(self): # create instance ec2_hook = EC2Hook() - instances = ec2_hook.conn.create_instances( - MaxCount=1, - MinCount=1, - ) + instances = ec2_hook.conn.create_instances(MaxCount=1, MinCount=1,) instance_id = 
instances[0].instance_id # stop instance - stop_test = EC2StopInstanceOperator( - task_id="stop_test", - instance_id=instance_id, - ) + stop_test = EC2StopInstanceOperator(task_id="stop_test", instance_id=instance_id,) stop_test.execute(None) # assert instance state is running - self.assertEqual( - ec2_hook.get_instance_state(instance_id=instance_id), - "stopped" - ) + self.assertEqual(ec2_hook.get_instance_state(instance_id=instance_id), "stopped") diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 49b27d2ea7276..16f32b6e5c2c3 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -27,6 +27,7 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.operators.ecs import ECSOperator +# fmt: off RESPONSE_WITHOUT_FAILURES = { "failures": [], "tasks": [ @@ -47,10 +48,10 @@ } ] } +# fmt: on class TestECSOperator(unittest.TestCase): - @mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsBaseHook') def set_up_operator(self, aws_hook_mock, **kwargs): self.aws_hook_mock = aws_hook_mock @@ -63,17 +64,13 @@ def set_up_operator(self, aws_hook_mock, **kwargs): 'aws_conn_id': None, 'region_name': 'eu-west-1', 'group': 'group', - 'placement_constraints': [{ - 'expression': 'attribute:ecs.instance-type =~ t2.*', - 'type': 'memberOf' - }], + 'placement_constraints': [ + {'expression': 'attribute:ecs.instance-type =~ t2.*', 'type': 'memberOf'} + ], 'network_configuration': { - 'awsvpcConfiguration': { - 'securityGroups': ['sg-123abc'], - 'subnets': ['subnet-123456ab'] - } + 'awsvpcConfiguration': {'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab']} }, - 'propagate_tags': 'TASK_DEFINITION' + 'propagate_tags': 'TASK_DEFINITION', } self.ecs = ECSOperator(**self.ecs_operator_args, **kwargs) self.ecs.get_hook() @@ -94,16 +91,17 @@ def test_init(self): def test_template_fields_overrides(self): self.assertEqual(self.ecs.template_fields, ('overrides',)) - @parameterized.expand([ - ['EC2', None], - ['FARGATE', None], - ['EC2', {'testTagKey': 'testTagValue'}], - ['', {'testTagKey': 'testTagValue'}], - ]) + @parameterized.expand( + [ + ['EC2', None], + ['FARGATE', None], + ['EC2', {'testTagKey': 'testTagValue'}], + ['', {'testTagKey': 'testTagValue'}], + ] + ) @mock.patch.object(ECSOperator, '_wait_for_task_ended') @mock.patch.object(ECSOperator, '_check_success_task') - def test_execute_without_failures(self, launch_type, tags, - check_mock, wait_mock): + def test_execute_without_failures(self, launch_type, tags, check_mock, wait_mock): self.set_up_operator(launch_type=launch_type, tags=tags) # pylint: disable=no-value-for-parameter client_mock = self.aws_hook_mock.return_value.get_conn.return_value @@ -126,26 +124,19 @@ def test_execute_without_failures(self, launch_type, tags, startedBy=mock.ANY, # Can by 'airflow' or 'Airflow' taskDefinition='t', group='group', - placementConstraints=[ - { - 'expression': 'attribute:ecs.instance-type =~ t2.*', - 'type': 'memberOf' - } - ], + placementConstraints=[{'expression': 'attribute:ecs.instance-type =~ t2.*', 'type': 'memberOf'}], networkConfiguration={ - 'awsvpcConfiguration': { - 'securityGroups': ['sg-123abc'], - 'subnets': ['subnet-123456ab'] - } + 'awsvpcConfiguration': {'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab']} }, propagateTags='TASK_DEFINITION', - **extend_args + **extend_args, ) wait_mock.assert_called_once_with() check_mock.assert_called_once_with() - 
self.assertEqual(self.ecs.arn, - 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55') + self.assertEqual( + self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55' + ) def test_execute_with_failures(self): client_mock = self.aws_hook_mock.return_value.get_conn.return_value @@ -164,19 +155,11 @@ def test_execute_with_failures(self): startedBy=mock.ANY, # Can by 'airflow' or 'Airflow' taskDefinition='t', group='group', - placementConstraints=[ - { - 'expression': 'attribute:ecs.instance-type =~ t2.*', - 'type': 'memberOf' - } - ], + placementConstraints=[{'expression': 'attribute:ecs.instance-type =~ t2.*', 'type': 'memberOf'}], networkConfiguration={ - 'awsvpcConfiguration': { - 'securityGroups': ['sg-123abc'], - 'subnets': ['subnet-123456ab'], - } + 'awsvpcConfiguration': {'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab'],} }, - propagateTags='TASK_DEFINITION' + propagateTags='TASK_DEFINITION', ) def test_wait_end_tasks(self): @@ -186,10 +169,8 @@ def test_wait_end_tasks(self): self.ecs._wait_for_task_ended() client_mock.get_waiter.assert_called_once_with('tasks_stopped') - client_mock.get_waiter.return_value.wait.assert_called_once_with( - cluster='c', tasks=['arn']) - self.assertEqual( - sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts) + client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn']) + self.assertEqual(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts) def test_check_success_tasks_raises(self): client_mock = mock.Mock() @@ -197,13 +178,7 @@ def test_check_success_tasks_raises(self): self.ecs.client = client_mock client_mock.describe_tasks.return_value = { - 'tasks': [{ - 'containers': [{ - 'name': 'foo', - 'lastStatus': 'STOPPED', - 'exitCode': 1 - }] - }] + 'tasks': [{'containers': [{'name': 'foo', 'lastStatus': 'STOPPED', 'exitCode': 1}]}] } with self.assertRaises(Exception) as e: self.ecs._check_success_task() @@ -213,20 +188,14 @@ def test_check_success_tasks_raises(self): self.assertIn("'name': 'foo'", str(e.exception)) self.assertIn("'lastStatus': 'STOPPED'", str(e.exception)) self.assertIn("'exitCode': 1", str(e.exception)) - client_mock.describe_tasks.assert_called_once_with( - cluster='c', tasks=['arn']) + client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_tasks_raises_pending(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { - 'tasks': [{ - 'containers': [{ - 'name': 'container-name', - 'lastStatus': 'PENDING' - }] - }] + 'tasks': [{'containers': [{'name': 'container-name', 'lastStatus': 'PENDING'}]}] } with self.assertRaises(Exception) as e: self.ecs._check_success_task() @@ -234,76 +203,63 @@ def test_check_success_tasks_raises_pending(self): self.assertIn("This task is still pending ", str(e.exception)) self.assertIn("'name': 'container-name'", str(e.exception)) self.assertIn("'lastStatus': 'PENDING'", str(e.exception)) - client_mock.describe_tasks.assert_called_once_with( - cluster='c', tasks=['arn']) + client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_tasks_raises_multiple(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { - 'tasks': [{ - 'containers': [{ - 'name': 'foo', - 'exitCode': 1 - }, { - 'name': 'bar', - 'lastStatus': 'STOPPED', - 'exitCode': 0 - }] 
- }] + 'tasks': [ + { + 'containers': [ + {'name': 'foo', 'exitCode': 1}, + {'name': 'bar', 'lastStatus': 'STOPPED', 'exitCode': 0}, + ] + } + ] } self.ecs._check_success_task() - client_mock.describe_tasks.assert_called_once_with( - cluster='c', tasks=['arn']) + client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_host_terminated_raises(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { - 'tasks': [{ - 'stoppedReason': 'Host EC2 (instance i-1234567890abcdef) terminated.', - "containers": [ - { - "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868", # noqa: E501 # pylint: disable=line-too-long - "lastStatus": "RUNNING", - "name": "wordpress", - "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55" # noqa: E501 # pylint: disable=line-too-long - } - ], - "desiredStatus": "STOPPED", - "lastStatus": "STOPPED", - "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55", # noqa: E501 # pylint: disable=line-too-long - "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11" # noqa: E501 # pylint: disable=line-too-long - - }] + 'tasks': [ + { + 'stoppedReason': 'Host EC2 (instance i-1234567890abcdef) terminated.', + "containers": [ + { + "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868", # noqa: E501 # pylint: disable=line-too-long + "lastStatus": "RUNNING", + "name": "wordpress", + "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55", # noqa: E501 # pylint: disable=line-too-long + } + ], + "desiredStatus": "STOPPED", + "lastStatus": "STOPPED", + "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55", # noqa: E501 # pylint: disable=line-too-long + "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11", # noqa: E501 # pylint: disable=line-too-long + } + ] } with self.assertRaises(AirflowException) as e: self.ecs._check_success_task() - self.assertIn( - "The task was stopped because the host instance terminated:", - str(e.exception)) + self.assertIn("The task was stopped because the host instance terminated:", str(e.exception)) self.assertIn("Host EC2 (", str(e.exception)) self.assertIn(") terminated", str(e.exception)) - client_mock.describe_tasks.assert_called_once_with( - cluster='c', tasks=['arn']) + client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_task_not_raises(self): client_mock = mock.Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { - 'tasks': [{ - 'containers': [{ - 'name': 'container-name', - 'lastStatus': 'STOPPED', - 'exitCode': 0 - }] - }] + 'tasks': [{'containers': [{'name': 'container-name', 'lastStatus': 'STOPPED', 'exitCode': 0}]}] } self.ecs._check_success_task() - client_mock.describe_tasks.assert_called_once_with( - cluster='c', tasks=['arn']) + client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) diff --git a/tests/providers/amazon/aws/operators/test_ecs_system.py b/tests/providers/amazon/aws/operators/test_ecs_system.py index 1a6eec70a1d57..5a9ad10d82ff3 100644 --- a/tests/providers/amazon/aws/operators/test_ecs_system.py +++ b/tests/providers/amazon/aws/operators/test_ecs_system.py @@ -55,20 +55,17 @@ class 
ECSSystemTest(AmazonSystemTest): @classmethod def setup_class(cls): cls.create_connection( - aws_conn_id=cls.aws_conn_id, - region=cls._region_name(), + aws_conn_id=cls.aws_conn_id, region=cls._region_name(), ) # create ecs cluster if it does not exist cls.create_ecs_cluster( - aws_conn_id=cls.aws_conn_id, - cluster_name=cls.cluster, + aws_conn_id=cls.aws_conn_id, cluster_name=cls.cluster, ) # create task_definition if it does not exist task_definition_exists = cls.is_ecs_task_definition_exists( - aws_conn_id=cls.aws_conn_id, - task_definition=cls.task_definition, + aws_conn_id=cls.aws_conn_id, task_definition=cls.task_definition, ) if not task_definition_exists: cls.create_ecs_task_definition( @@ -87,12 +84,10 @@ def teardown_class(cls): # remove all created/existing resources in tear down if cls._remove_resources(): cls.delete_ecs_cluster( - aws_conn_id=cls.aws_conn_id, - cluster_name=cls.cluster, + aws_conn_id=cls.aws_conn_id, cluster_name=cls.cluster, ) cls.delete_ecs_task_definition( - aws_conn_id=cls.aws_conn_id, - task_definition=cls.task_definition, + aws_conn_id=cls.aws_conn_id, task_definition=cls.task_definition, ) def test_run_example_dag_ecs_fargate_dag(self): diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 49808e3a6da2c..2d3e1a6c79c38 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -33,39 +33,28 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1) -ADD_STEPS_SUCCESS_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200 - }, - 'StepIds': ['s-2LH3R5GW3A53T'] -} +ADD_STEPS_SUCCESS_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 200}, 'StepIds': ['s-2LH3R5GW3A53T']} TEMPLATE_SEARCHPATH = os.path.join( - AIRFLOW_MAIN_FOLDER, - 'tests', 'providers', 'amazon', 'aws', 'config_templates' + AIRFLOW_MAIN_FOLDER, 'tests', 'providers', 'amazon', 'aws', 'config_templates' ) class TestEmrAddStepsOperator(unittest.TestCase): # When - _config = [{ - 'Name': 'test_step', - 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': { - 'Jar': 'command-runner.jar', - 'Args': [ - '/usr/lib/spark/bin/run-example', - '{{ macros.ds_add(ds, -1) }}', - '{{ ds }}' - ] + _config = [ + { + 'Name': 'test_step', + 'ActionOnFailure': 'CONTINUE', + 'HadoopJarStep': { + 'Jar': 'command-runner.jar', + 'Args': ['/usr/lib/spark/bin/run-example', '{{ macros.ds_add(ds, -1) }}', '{{ ds }}'], + }, } - }] + ] def setUp(self): - self.args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} # Mock out the emr_client (moto has incorrect response) self.emr_client_mock = MagicMock() @@ -82,7 +71,7 @@ def setUp(self): job_flow_id='j-8989898989', aws_conn_id='aws_default', steps=self._config, - dag=DAG('test_dag_id', default_args=self.args) + dag=DAG('test_dag_id', default_args=self.args), ) def test_init(self): @@ -93,45 +82,37 @@ def test_render_template(self): ti = TaskInstance(self.operator, DEFAULT_DATE) ti.render_templates() - expected_args = [{ - 'Name': 'test_step', - 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': { - 'Jar': 'command-runner.jar', - 'Args': [ - '/usr/lib/spark/bin/run-example', - (DEFAULT_DATE - timedelta(days=1)).strftime("%Y-%m-%d"), - DEFAULT_DATE.strftime("%Y-%m-%d"), - ] + expected_args = [ + { + 'Name': 'test_step', + 'ActionOnFailure': 'CONTINUE', + 'HadoopJarStep': { + 'Jar': 'command-runner.jar', + 'Args': [ + '/usr/lib/spark/bin/run-example', + 
(DEFAULT_DATE - timedelta(days=1)).strftime("%Y-%m-%d"), + DEFAULT_DATE.strftime("%Y-%m-%d"), + ], + }, } - }] + ] self.assertListEqual(self.operator.steps, expected_args) def test_render_template_2(self): - dag = DAG( - dag_id='test_xcom', default_args=self.args) + dag = DAG(dag_id='test_xcom', default_args=self.args) xcom_steps = [ { 'Name': 'test_step1', 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': { - 'Jar': 'command-runner.jar', - 'Args': [ - '/usr/lib/spark/bin/run-example1' - ] - } - }, { + 'HadoopJarStep': {'Jar': 'command-runner.jar', 'Args': ['/usr/lib/spark/bin/run-example1']}, + }, + { 'Name': 'test_step2', 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': { - 'Jar': 'command-runner.jar', - 'Args': [ - '/usr/lib/spark/bin/run-example2' - ] - } - } + 'HadoopJarStep': {'Jar': 'command-runner.jar', 'Args': ['/usr/lib/spark/bin/run-example2']}, + }, ] make_steps = DummyOperator(task_id='make_steps', dag=dag, owner='airflow') @@ -146,34 +127,30 @@ def test_render_template_2(self): job_flow_id='j-8989898989', aws_conn_id='aws_default', steps="{{ ti.xcom_pull(task_ids='make_steps',key='steps') }}", - dag=dag) + dag=dag, + ) with patch('boto3.session.Session', self.boto3_session_mock): ti = TaskInstance(task=test_task, execution_date=execution_date) ti.run() self.emr_client_mock.add_job_flow_steps.assert_called_once_with( - JobFlowId='j-8989898989', - Steps=xcom_steps) + JobFlowId='j-8989898989', Steps=xcom_steps + ) def test_render_template_from_file(self): dag = DAG( dag_id='test_file', default_args=self.args, template_searchpath=TEMPLATE_SEARCHPATH, - template_undefined=StrictUndefined + template_undefined=StrictUndefined, ) file_steps = [ { 'Name': 'test_step1', 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': { - 'Jar': 'command-runner.jar', - 'Args': [ - '/usr/lib/spark/bin/run-example1' - ] - } + 'HadoopJarStep': {'Jar': 'command-runner.jar', 'Args': ['/usr/lib/spark/bin/run-example1']}, } ] @@ -186,15 +163,16 @@ def test_render_template_from_file(self): job_flow_id='j-8989898989', aws_conn_id='aws_default', steps='steps.j2.json', - dag=dag) + dag=dag, + ) with patch('boto3.session.Session', self.boto3_session_mock): ti = TaskInstance(task=test_task, execution_date=execution_date) ti.run() self.emr_client_mock.add_job_flow_steps.assert_called_once_with( - JobFlowId='j-8989898989', - Steps=file_steps) + JobFlowId='j-8989898989', Steps=file_steps + ) def test_execute_returns_step_id(self): self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN @@ -208,8 +186,9 @@ def test_init_with_cluster_name(self): self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN with patch('boto3.session.Session', self.boto3_session_mock): - with patch('airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name') \ - as mock_get_cluster_id_by_name: + with patch( + 'airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name' + ) as mock_get_cluster_id_by_name: mock_get_cluster_id_by_name.return_value = expected_job_flow_id operator = EmrAddStepsOperator( @@ -217,7 +196,7 @@ def test_init_with_cluster_name(self): job_flow_name='test_cluster', cluster_states=['RUNNING', 'WAITING'], aws_conn_id='aws_default', - dag=DAG('test_dag_id', default_args=self.args) + dag=DAG('test_dag_id', default_args=self.args), ) operator.execute(self.mock_context) @@ -229,8 +208,9 @@ def test_init_with_cluster_name(self): def test_init_with_nonexistent_cluster_name(self): cluster_name = 'test_cluster' - with 
patch('airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name') \ - as mock_get_cluster_id_by_name: + with patch( + 'airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name' + ) as mock_get_cluster_id_by_name: mock_get_cluster_id_by_name.return_value = None operator = EmrAddStepsOperator( @@ -238,7 +218,7 @@ def test_init_with_nonexistent_cluster_name(self): job_flow_name=cluster_name, cluster_states=['RUNNING', 'WAITING'], aws_conn_id='aws_default', - dag=DAG('test_dag_id', default_args=self.args) + dag=DAG('test_dag_id', default_args=self.args), ) with self.assertRaises(AirflowException) as error: diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py index 7be4326dc9bef..5050b2cea51c1 100644 --- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py @@ -32,16 +32,10 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1) -RUN_JOB_FLOW_SUCCESS_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200 - }, - 'JobFlowId': 'j-8989898989' -} +RUN_JOB_FLOW_SUCCESS_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 200}, 'JobFlowId': 'j-8989898989'} TEMPLATE_SEARCHPATH = os.path.join( - AIRFLOW_MAIN_FOLDER, - 'tests', 'providers', 'amazon', 'aws', 'config_templates' + AIRFLOW_MAIN_FOLDER, 'tests', 'providers', 'amazon', 'aws', 'config_templates' ) @@ -50,25 +44,20 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase): _config = { 'Name': 'test_job_flow', 'ReleaseLabel': '5.11.0', - 'Steps': [{ - 'Name': 'test_step', - 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': { - 'Jar': 'command-runner.jar', - 'Args': [ - '/usr/lib/spark/bin/run-example', - '{{ macros.ds_add(ds, -1) }}', - '{{ ds }}' - ] + 'Steps': [ + { + 'Name': 'test_step', + 'ActionOnFailure': 'CONTINUE', + 'HadoopJarStep': { + 'Jar': 'command-runner.jar', + 'Args': ['/usr/lib/spark/bin/run-example', '{{ macros.ds_add(ds, -1) }}', '{{ ds }}'], + }, } - }] + ], } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} # Mock out the emr_client (moto has incorrect response) self.emr_client_mock = MagicMock() @@ -81,8 +70,8 @@ def setUp(self): 'test_dag_id', default_args=args, template_searchpath=TEMPLATE_SEARCHPATH, - template_undefined=StrictUndefined - ) + template_undefined=StrictUndefined, + ), ) def test_init(self): @@ -98,18 +87,20 @@ def test_render_template(self): expected_args = { 'Name': 'test_job_flow', 'ReleaseLabel': '5.11.0', - 'Steps': [{ - 'Name': 'test_step', - 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': { - 'Jar': 'command-runner.jar', - 'Args': [ - '/usr/lib/spark/bin/run-example', - (DEFAULT_DATE - timedelta(days=1)).strftime("%Y-%m-%d"), - DEFAULT_DATE.strftime("%Y-%m-%d"), - ] + 'Steps': [ + { + 'Name': 'test_step', + 'ActionOnFailure': 'CONTINUE', + 'HadoopJarStep': { + 'Jar': 'command-runner.jar', + 'Args': [ + '/usr/lib/spark/bin/run-example', + (DEFAULT_DATE - timedelta(days=1)).strftime("%Y-%m-%d"), + DEFAULT_DATE.strftime("%Y-%m-%d"), + ], + }, } - }] + ], } self.assertDictEqual(self.operator.job_flow_overrides, expected_args) @@ -132,18 +123,16 @@ def test_render_template_from_file(self): expected_args = { 'Name': 'test_job_flow', 'ReleaseLabel': '5.11.0', - 'Steps': [{ - 'Name': 'test_step', - 'ActionOnFailure': 'CONTINUE', - 'HadoopJarStep': { - 'Jar': 'command-runner.jar', - 'Args': [ - '/usr/lib/spark/bin/run-example', 
- '2016-12-31', - '2017-01-01', - ] + 'Steps': [ + { + 'Name': 'test_step', + 'ActionOnFailure': 'CONTINUE', + 'HadoopJarStep': { + 'Jar': 'command-runner.jar', + 'Args': ['/usr/lib/spark/bin/run-example', '2016-12-31', '2017-01-01',], + }, } - }] + ], } self.assertDictEqual(self.operator.job_flow_overrides, expected_args) diff --git a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py index 209e009a3fda2..ef284ef756963 100644 --- a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py +++ b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py @@ -26,26 +26,14 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1) -MODIFY_CLUSTER_SUCCESS_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200 - }, - 'StepConcurrencyLevel': 1 -} +MODIFY_CLUSTER_SUCCESS_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 200}, 'StepConcurrencyLevel': 1} -MODIFY_CLUSTER_ERROR_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 400 - } -} +MODIFY_CLUSTER_ERROR_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 400}} class TestEmrModifyClusterOperator(unittest.TestCase): def setUp(self): - self.args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} # Mock out the emr_client (moto has incorrect response) self.emr_client_mock = MagicMock() @@ -62,7 +50,7 @@ def setUp(self): cluster_id='j-8989898989', step_concurrency_level=1, aws_conn_id='aws_default', - dag=DAG('test_dag_id', default_args=self.args) + dag=DAG('test_dag_id', default_args=self.args), ) def test_init(self): diff --git a/tests/providers/amazon/aws/operators/test_emr_system.py b/tests/providers/amazon/aws/operators/test_emr_system.py index a5f0e1bfc8f6a..c9433438989a6 100644 --- a/tests/providers/amazon/aws/operators/test_emr_system.py +++ b/tests/providers/amazon/aws/operators/test_emr_system.py @@ -22,6 +22,7 @@ class EmrSystemTest(AmazonSystemTest): """ System tests for AWS EMR operators """ + @classmethod def setup_class(cls): cls.create_emr_default_roles() diff --git a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py index 6f564a7947f5f..87dbe8df6ef02 100644 --- a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py @@ -21,11 +21,7 @@ from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator -TERMINATE_SUCCESS_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200 - } -} +TERMINATE_SUCCESS_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 200}} class TestEmrTerminateJobFlowOperator(unittest.TestCase): @@ -43,9 +39,7 @@ def setUp(self): def test_execute_terminates_the_job_flow_and_does_not_error(self): with patch('boto3.session.Session', self.boto3_session_mock): operator = EmrTerminateJobFlowOperator( - task_id='test_task', - job_flow_id='j-8989898989', - aws_conn_id='aws_default' + task_id='test_task', job_flow_id='j-8989898989', aws_conn_id='aws_default' ) operator.execute(None) diff --git a/tests/providers/amazon/aws/operators/test_example_s3_bucket.py b/tests/providers/amazon/aws/operators/test_example_s3_bucket.py index 84be5f7c1af7a..3873fa1fd71b9 100644 --- a/tests/providers/amazon/aws/operators/test_example_s3_bucket.py +++ b/tests/providers/amazon/aws/operators/test_example_s3_bucket.py @@ -21,6 +21,7 @@ class 
S3BucketExampleDagsSystemTest(AmazonSystemTest): """ System tests for AWS S3 operators """ + @provide_aws_context() def test_run_example_dag_s3(self): self.run_dag('s3_bucket_dag', AWS_DAG_FOLDER) diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py index 83c2c0f3bfd79..db4d458db8987 100644 --- a/tests/providers/amazon/aws/operators/test_glue.py +++ b/tests/providers/amazon/aws/operators/test_glue.py @@ -1,5 +1,3 @@ - - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -28,30 +26,29 @@ class TestAwsGlueJobOperator(unittest.TestCase): - @mock.patch('airflow.providers.amazon.aws.hooks.glue.AwsGlueJobHook') def setUp(self, glue_hook_mock): configuration.load_test_config() self.glue_hook_mock = glue_hook_mock some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py" - self.glue = AwsGlueJobOperator(task_id='test_glue_operator', - job_name='my_test_job', - script_location=some_script, - aws_conn_id='aws_default', - region_name='us-west-2', - s3_bucket='some_bucket', - iam_role_name='my_test_role') + self.glue = AwsGlueJobOperator( + task_id='test_glue_operator', + job_name='my_test_job', + script_location=some_script, + aws_conn_id='aws_default', + region_name='us-west-2', + s3_bucket='some_bucket', + iam_role_name='my_test_role', + ) @mock.patch.object(AwsGlueJobHook, 'get_job_state') @mock.patch.object(AwsGlueJobHook, 'initialize_job') @mock.patch.object(AwsGlueJobHook, "get_conn") @mock.patch.object(S3Hook, "load_file") - def test_execute_without_failure(self, - mock_load_file, - mock_get_conn, - mock_initialize_job, - mock_get_job_state): + def test_execute_without_failure( + self, mock_load_file, mock_get_conn, mock_initialize_job, mock_get_job_state + ): mock_initialize_job.return_value = {'JobRunState': 'RUNNING', 'JobRunId': '11111'} mock_get_job_state.return_value = 'SUCCEEDED' self.glue.execute(None) diff --git a/tests/providers/amazon/aws/operators/test_s3_bucket.py b/tests/providers/amazon/aws/operators/test_s3_bucket.py index 5913271edf441..43f8dd9021db2 100644 --- a/tests/providers/amazon/aws/operators/test_s3_bucket.py +++ b/tests/providers/amazon/aws/operators/test_s3_bucket.py @@ -30,10 +30,7 @@ class TestS3CreateBucketOperator(unittest.TestCase): def setUp(self): - self.create_bucket_operator = S3CreateBucketOperator( - task_id=TASK_ID, - bucket_name=BUCKET_NAME, - ) + self.create_bucket_operator = S3CreateBucketOperator(task_id=TASK_ID, bucket_name=BUCKET_NAME,) @mock_s3 @mock.patch.object(S3Hook, "create_bucket") @@ -58,10 +55,7 @@ def test_execute_if_not_bucket_exist(self, mock_check_for_bucket, mock_create_bu class TestS3DeleteBucketOperator(unittest.TestCase): def setUp(self): - self.delete_bucket_operator = S3DeleteBucketOperator( - task_id=TASK_ID, - bucket_name=BUCKET_NAME, - ) + self.delete_bucket_operator = S3DeleteBucketOperator(task_id=TASK_ID, bucket_name=BUCKET_NAME,) @mock_s3 @mock.patch.object(S3Hook, "delete_bucket") diff --git a/tests/providers/amazon/aws/operators/test_s3_copy_object.py b/tests/providers/amazon/aws/operators/test_s3_copy_object.py index ecc5951aeef16..e049d0b41856f 100644 --- a/tests/providers/amazon/aws/operators/test_s3_copy_object.py +++ b/tests/providers/amazon/aws/operators/test_s3_copy_object.py @@ -26,7 +26,6 @@ class TestS3CopyObjectOperator(unittest.TestCase): - def setUp(self): self.source_bucket = "bucket1" self.source_key = 
"path1/data.txt" @@ -38,23 +37,21 @@ def test_s3_copy_object_arg_combination_1(self): conn = boto3.client('s3') conn.create_bucket(Bucket=self.source_bucket) conn.create_bucket(Bucket=self.dest_bucket) - conn.upload_fileobj(Bucket=self.source_bucket, - Key=self.source_key, - Fileobj=io.BytesIO(b"input")) + conn.upload_fileobj(Bucket=self.source_bucket, Key=self.source_key, Fileobj=io.BytesIO(b"input")) # there should be nothing found before S3CopyObjectOperator is executed - self.assertFalse('Contents' in conn.list_objects(Bucket=self.dest_bucket, - Prefix=self.dest_key)) + self.assertFalse('Contents' in conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key)) - op = S3CopyObjectOperator(task_id="test_task_s3_copy_object", - source_bucket_key=self.source_key, - source_bucket_name=self.source_bucket, - dest_bucket_key=self.dest_key, - dest_bucket_name=self.dest_bucket) + op = S3CopyObjectOperator( + task_id="test_task_s3_copy_object", + source_bucket_key=self.source_key, + source_bucket_name=self.source_bucket, + dest_bucket_key=self.dest_key, + dest_bucket_name=self.dest_bucket, + ) op.execute(None) - objects_in_dest_bucket = conn.list_objects(Bucket=self.dest_bucket, - Prefix=self.dest_key) + objects_in_dest_bucket = conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key) # there should be object found, and there should only be one object found self.assertEqual(len(objects_in_dest_bucket['Contents']), 1) # the object found should be consistent with dest_key specified earlier @@ -65,23 +62,21 @@ def test_s3_copy_object_arg_combination_2(self): conn = boto3.client('s3') conn.create_bucket(Bucket=self.source_bucket) conn.create_bucket(Bucket=self.dest_bucket) - conn.upload_fileobj(Bucket=self.source_bucket, - Key=self.source_key, - Fileobj=io.BytesIO(b"input")) + conn.upload_fileobj(Bucket=self.source_bucket, Key=self.source_key, Fileobj=io.BytesIO(b"input")) # there should be nothing found before S3CopyObjectOperator is executed - self.assertFalse('Contents' in conn.list_objects(Bucket=self.dest_bucket, - Prefix=self.dest_key)) + self.assertFalse('Contents' in conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key)) source_key_s3_url = "s3://{}/{}".format(self.source_bucket, self.source_key) dest_key_s3_url = "s3://{}/{}".format(self.dest_bucket, self.dest_key) - op = S3CopyObjectOperator(task_id="test_task_s3_copy_object", - source_bucket_key=source_key_s3_url, - dest_bucket_key=dest_key_s3_url) + op = S3CopyObjectOperator( + task_id="test_task_s3_copy_object", + source_bucket_key=source_key_s3_url, + dest_bucket_key=dest_key_s3_url, + ) op.execute(None) - objects_in_dest_bucket = conn.list_objects(Bucket=self.dest_bucket, - Prefix=self.dest_key) + objects_in_dest_bucket = conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key) # there should be object found, and there should only be one object found self.assertEqual(len(objects_in_dest_bucket['Contents']), 1) # the object found should be consistent with dest_key specified earlier diff --git a/tests/providers/amazon/aws/operators/test_s3_delete_objects.py b/tests/providers/amazon/aws/operators/test_s3_delete_objects.py index 7db9816948aaf..5d7821c4d043f 100644 --- a/tests/providers/amazon/aws/operators/test_s3_delete_objects.py +++ b/tests/providers/amazon/aws/operators/test_s3_delete_objects.py @@ -26,7 +26,6 @@ class TestS3DeleteObjectsOperator(unittest.TestCase): - @mock_s3 def test_s3_delete_single_object(self): bucket = "testbucket" @@ -34,24 +33,18 @@ def test_s3_delete_single_object(self): conn = 
boto3.client('s3') conn.create_bucket(Bucket=bucket) - conn.upload_fileobj(Bucket=bucket, - Key=key, - Fileobj=io.BytesIO(b"input")) + conn.upload_fileobj(Bucket=bucket, Key=key, Fileobj=io.BytesIO(b"input")) # The object should be detected before the DELETE action is taken - objects_in_dest_bucket = conn.list_objects(Bucket=bucket, - Prefix=key) + objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key) self.assertEqual(len(objects_in_dest_bucket['Contents']), 1) self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], key) - op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", - bucket=bucket, - keys=key) + op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=key) op.execute(None) # There should be no object found in the bucket created earlier - self.assertFalse('Contents' in conn.list_objects(Bucket=bucket, - Prefix=key)) + self.assertFalse('Contents' in conn.list_objects(Bucket=bucket, Prefix=key)) @mock_s3 def test_s3_delete_multiple_objects(self): @@ -63,25 +56,18 @@ def test_s3_delete_multiple_objects(self): conn = boto3.client('s3') conn.create_bucket(Bucket=bucket) for k in keys: - conn.upload_fileobj(Bucket=bucket, - Key=k, - Fileobj=io.BytesIO(b"input")) + conn.upload_fileobj(Bucket=bucket, Key=k, Fileobj=io.BytesIO(b"input")) # The objects should be detected before the DELETE action is taken - objects_in_dest_bucket = conn.list_objects(Bucket=bucket, - Prefix=key_pattern) + objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key_pattern) self.assertEqual(len(objects_in_dest_bucket['Contents']), n_keys) - self.assertEqual(sorted([x['Key'] for x in objects_in_dest_bucket['Contents']]), - sorted(keys)) + self.assertEqual(sorted([x['Key'] for x in objects_in_dest_bucket['Contents']]), sorted(keys)) - op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_multiple_objects", - bucket=bucket, - keys=keys) + op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_multiple_objects", bucket=bucket, keys=keys) op.execute(None) # There should be no object found in the bucket created earlier - self.assertFalse('Contents' in conn.list_objects(Bucket=bucket, - Prefix=key_pattern)) + self.assertFalse('Contents' in conn.list_objects(Bucket=bucket, Prefix=key_pattern)) @mock_s3 def test_s3_delete_prefix(self): @@ -93,22 +79,15 @@ def test_s3_delete_prefix(self): conn = boto3.client('s3') conn.create_bucket(Bucket=bucket) for k in keys: - conn.upload_fileobj(Bucket=bucket, - Key=k, - Fileobj=io.BytesIO(b"input")) + conn.upload_fileobj(Bucket=bucket, Key=k, Fileobj=io.BytesIO(b"input")) # The objects should be detected before the DELETE action is taken - objects_in_dest_bucket = conn.list_objects(Bucket=bucket, - Prefix=key_pattern) + objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key_pattern) self.assertEqual(len(objects_in_dest_bucket['Contents']), n_keys) - self.assertEqual(sorted([x['Key'] for x in objects_in_dest_bucket['Contents']]), - sorted(keys)) + self.assertEqual(sorted([x['Key'] for x in objects_in_dest_bucket['Contents']]), sorted(keys)) - op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_prefix", - bucket=bucket, - prefix=key_pattern) + op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_prefix", bucket=bucket, prefix=key_pattern) op.execute(None) # There should be no object found in the bucket created earlier - self.assertFalse('Contents' in conn.list_objects(Bucket=bucket, - Prefix=key_pattern)) + self.assertFalse('Contents' in 
conn.list_objects(Bucket=bucket, Prefix=key_pattern)) diff --git a/tests/providers/amazon/aws/operators/test_s3_file_transform.py b/tests/providers/amazon/aws/operators/test_s3_file_transform.py index a1bf7beeb5ed0..8394a5078570c 100644 --- a/tests/providers/amazon/aws/operators/test_s3_file_transform.py +++ b/tests/providers/amazon/aws/operators/test_s3_file_transform.py @@ -34,7 +34,6 @@ class TestS3FileTransformOperator(unittest.TestCase): - def setUp(self): self.content = b"input" self.bucket = "bucket" @@ -66,12 +65,13 @@ def test_execute_with_transform_script(self, mock_log, mock_popen): dest_s3_key=output_path, transform_script=self.transform_script, replace=True, - task_id="task_id") + task_id="task_id", + ) op.execute(None) - mock_log.info.assert_has_calls([ - mock.call(line.decode(sys.getdefaultencoding())) for line in process_output - ]) + mock_log.info.assert_has_calls( + [mock.call(line.decode(sys.getdefaultencoding())) for line in process_output] + ) @mock.patch('subprocess.Popen') @mock_s3 @@ -84,7 +84,8 @@ def test_execute_with_failing_transform_script(self, mock_popen): dest_s3_key=output_path, transform_script=self.transform_script, replace=True, - task_id="task_id") + task_id="task_id", + ) with self.assertRaises(AirflowException) as e: op.execute(None) @@ -104,7 +105,8 @@ def test_execute_with_transform_script_args(self, mock_popen): transform_script=self.transform_script, script_args=script_args, replace=True, - task_id="task_id") + task_id="task_id", + ) op.execute(None) self.assertEqual(script_args, mock_popen.call_args[0][0][3:]) @@ -120,13 +122,11 @@ def test_execute_with_select_expression(self, mock_select_key): dest_s3_key=output_path, select_expression=select_expression, replace=True, - task_id="task_id") + task_id="task_id", + ) op.execute(None) - mock_select_key.assert_called_once_with( - key=input_path, - expression=select_expression - ) + mock_select_key.assert_called_once_with(key=input_path, expression=select_expression) conn = boto3.client('s3') result = conn.get_object(Bucket=self.bucket, Key=self.output_key) diff --git a/tests/providers/amazon/aws/operators/test_s3_list.py b/tests/providers/amazon/aws/operators/test_s3_list.py index 3e5f484fad648..245983cce6bd2 100644 --- a/tests/providers/amazon/aws/operators/test_s3_list.py +++ b/tests/providers/amazon/aws/operators/test_s3_list.py @@ -35,11 +35,11 @@ def test_execute(self, mock_hook): mock_hook.return_value.list_keys.return_value = MOCK_FILES - operator = S3ListOperator( - task_id=TASK_ID, bucket=BUCKET, prefix=PREFIX, delimiter=DELIMITER) + operator = S3ListOperator(task_id=TASK_ID, bucket=BUCKET, prefix=PREFIX, delimiter=DELIMITER) files = operator.execute(None) mock_hook.return_value.list_keys.assert_called_once_with( - bucket_name=BUCKET, prefix=PREFIX, delimiter=DELIMITER) + bucket_name=BUCKET, prefix=PREFIX, delimiter=DELIMITER + ) self.assertEqual(sorted(files), sorted(MOCK_FILES)) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_base.py b/tests/providers/amazon/aws/operators/test_sagemaker_base.py index 0f59efa140b44..57f3e36a2a4e3 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_base.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_base.py @@ -19,51 +19,18 @@ from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator -config = { - 'key1': '1', - 'key2': { - 'key3': '3', - 'key4': '4' - }, - 'key5': [ - { - 'key6': '6' - }, - { - 'key6': '7' - } - ] -} +config = {'key1': '1', 'key2': {'key3': '3', 'key4': '4'}, 'key5': 
[{'key6': '6'}, {'key6': '7'}]} -parsed_config = { - 'key1': 1, - 'key2': { - 'key3': 3, - 'key4': 4 - }, - 'key5': [ - { - 'key6': 6 - }, - { - 'key6': 7 - } - ] -} +parsed_config = {'key1': 1, 'key2': {'key3': 3, 'key4': 4}, 'key5': [{'key6': 6}, {'key6': 7}]} class TestSageMakerBaseOperator(unittest.TestCase): - def setUp(self): self.sagemaker = SageMakerBaseOperator( - task_id='test_sagemaker_operator', - aws_conn_id='sagemaker_test_id', - config=config + task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_id', config=config ) def test_parse_integer(self): - self.sagemaker.integer_fields = [ - ['key1'], ['key2', 'key3'], ['key2', 'key4'], ['key5', 'key6'] - ] + self.sagemaker.integer_fields = [['key1'], ['key2', 'key3'], ['key2', 'key4'], ['key5', 'key6']] self.sagemaker.parse_config_integers() self.assertEqual(self.sagemaker.config, parsed_config) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py index 05913ead8dd88..39148190eb513 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py @@ -35,11 +35,8 @@ create_model_params = { 'ModelName': model_name, - 'PrimaryContainer': { - 'Image': image, - 'ModelDataUrl': output_url, - }, - 'ExecutionRoleArn': role + 'PrimaryContainer': {'Image': image, 'ModelDataUrl': output_url,}, + 'ExecutionRoleArn': role, } create_endpoint_config_params = { @@ -49,25 +46,21 @@ 'VariantName': 'AllTraffic', 'ModelName': model_name, 'InitialInstanceCount': '1', - 'InstanceType': 'ml.c4.xlarge' + 'InstanceType': 'ml.c4.xlarge', } - ] + ], } -create_endpoint_params = { - 'EndpointName': endpoint_name, - 'EndpointConfigName': config_name -} +create_endpoint_params = {'EndpointName': endpoint_name, 'EndpointConfigName': config_name} config = { 'Model': create_model_params, 'EndpointConfig': create_endpoint_config_params, - 'Endpoint': create_endpoint_params + 'Endpoint': create_endpoint_params, } class TestSageMakerEndpointOperator(unittest.TestCase): - def setUp(self): self.sagemaker = SageMakerEndpointOperator( task_id='test_sagemaker_operator', @@ -75,42 +68,33 @@ def setUp(self): config=config, wait_for_completion=False, check_interval=5, - operation='create' + operation='create', ) def test_parse_config_integers(self): self.sagemaker.parse_config_integers() for variant in self.sagemaker.config['EndpointConfig']['ProductionVariants']: - self.assertEqual(variant['InitialInstanceCount'], - int(variant['InitialInstanceCount'])) + self.assertEqual(variant['InitialInstanceCount'], int(variant['InitialInstanceCount'])) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_model') @mock.patch.object(SageMakerHook, 'create_endpoint_config') @mock.patch.object(SageMakerHook, 'create_endpoint') - def test_execute(self, mock_endpoint, mock_endpoint_config, - mock_model, mock_client): - mock_endpoint.return_value = {'EndpointArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 200}} + def test_execute(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client): + mock_endpoint.return_value = {'EndpointArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}} self.sagemaker.execute(None) mock_model.assert_called_once_with(create_model_params) mock_endpoint_config.assert_called_once_with(create_endpoint_config_params) - mock_endpoint.assert_called_once_with(create_endpoint_params, - wait_for_completion=False, - check_interval=5, - 
max_ingestion_time=None - ) + mock_endpoint.assert_called_once_with( + create_endpoint_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None + ) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_model') @mock.patch.object(SageMakerHook, 'create_endpoint_config') @mock.patch.object(SageMakerHook, 'create_endpoint') - def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config, - mock_model, mock_client): - mock_endpoint.return_value = {'EndpointArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 404}} + def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client): + mock_endpoint.return_value = {'EndpointArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}} self.assertRaises(AirflowException, self.sagemaker.execute, None) @mock.patch.object(SageMakerHook, 'get_conn') @@ -118,13 +102,15 @@ def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config, @mock.patch.object(SageMakerHook, 'create_endpoint_config') @mock.patch.object(SageMakerHook, 'create_endpoint') @mock.patch.object(SageMakerHook, 'update_endpoint') - def test_execute_with_duplicate_endpoint_creation(self, mock_endpoint_update, - mock_endpoint, mock_endpoint_config, - mock_model, mock_client): - response = {"Error": {"Code": "ValidationException", - "Message": "Cannot create already existing endpoint."}} + def test_execute_with_duplicate_endpoint_creation( + self, mock_endpoint_update, mock_endpoint, mock_endpoint_config, mock_model, mock_client + ): + response = { + "Error": {"Code": "ValidationException", "Message": "Cannot create already existing endpoint."} + } mock_endpoint.side_effect = ClientError(error_response=response, operation_name="CreateEndpoint") - mock_endpoint_update.return_value = {'EndpointArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 200}} + mock_endpoint_update.return_value = { + 'EndpointArn': 'testarn', + 'ResponseMetadata': {'HTTPStatusCode': 200}, + } self.sagemaker.execute(None) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py index d60bf7b62b7d7..cba265aa1e1bd 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py @@ -34,35 +34,31 @@ 'VariantName': 'AllTraffic', 'ModelName': model_name, 'InitialInstanceCount': '1', - 'InstanceType': 'ml.c4.xlarge' + 'InstanceType': 'ml.c4.xlarge', } - ] + ], } class TestSageMakerEndpointConfigOperator(unittest.TestCase): - def setUp(self): self.sagemaker = SageMakerEndpointConfigOperator( task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_id', - config=create_endpoint_config_params + config=create_endpoint_config_params, ) def test_parse_config_integers(self): self.sagemaker.parse_config_integers() for variant in self.sagemaker.config['ProductionVariants']: - self.assertEqual(variant['InitialInstanceCount'], - int(variant['InitialInstanceCount'])) + self.assertEqual(variant['InitialInstanceCount'], int(variant['InitialInstanceCount'])) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_endpoint_config') def test_execute(self, mock_model, mock_client): mock_model.return_value = { 'EndpointConfigArn': 'testarn', - 'ResponseMetadata': { - 'HTTPStatusCode': 200 - } + 'ResponseMetadata': {'HTTPStatusCode': 200}, } self.sagemaker.execute(None) 
mock_model.assert_called_once_with(create_endpoint_config_params) @@ -72,8 +68,6 @@ def test_execute(self, mock_model, mock_client): def test_execute_with_failure(self, mock_model, mock_client): mock_model.return_value = { 'EndpointConfigArn': 'testarn', - 'ResponseMetadata': { - 'HTTPStatusCode': 200 - } + 'ResponseMetadata': {'HTTPStatusCode': 200}, } self.assertRaises(AirflowException, self.sagemaker.execute, None) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py index a3ddb93a060fa..53a392eb2d69a 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py @@ -35,36 +35,26 @@ output_url = 's3://{}/test/output'.format(bucket) create_model_params = { 'ModelName': model_name, - 'PrimaryContainer': { - 'Image': image, - 'ModelDataUrl': output_url, - }, - 'ExecutionRoleArn': role + 'PrimaryContainer': {'Image': image, 'ModelDataUrl': output_url,}, + 'ExecutionRoleArn': role, } class TestSageMakerModelOperator(unittest.TestCase): - def setUp(self): self.sagemaker = SageMakerModelOperator( - task_id='test_sagemaker_operator', - aws_conn_id='sagemaker_test_id', - config=create_model_params + task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_id', config=create_model_params ) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_model') def test_execute(self, mock_model, mock_client): - mock_model.return_value = {'ModelArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 200}} + mock_model.return_value = {'ModelArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}} self.sagemaker.execute(None) mock_model.assert_called_once_with(create_model_params) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_model') def test_execute_with_failure(self, mock_model, mock_client): - mock_model.return_value = {'ModelArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 404}} + mock_model.return_value = {'ModelArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}} self.assertRaises(AirflowException, self.sagemaker.execute, None) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py index db75cde1cebc3..8546e9b7dcb6a 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py @@ -82,57 +82,75 @@ class TestSageMakerProcessingOperator(unittest.TestCase): - def setUp(self): - self.processing_config_kwargs = dict(task_id='test_sagemaker_operator', - aws_conn_id='sagemaker_test_id', - wait_for_completion=False, - check_interval=5) + self.processing_config_kwargs = dict( + task_id='test_sagemaker_operator', + aws_conn_id='sagemaker_test_id', + wait_for_completion=False, + check_interval=5, + ) - @parameterized.expand([ - (create_processing_params, [['ProcessingResources', 'ClusterConfig', 'InstanceCount'], - ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB']]), - (create_processing_params_with_stopping_condition, [ - ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], - ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'], - ['StoppingCondition', 'MaxRuntimeInSeconds']])]) + @parameterized.expand( + [ + ( + create_processing_params, + [ + ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], + ['ProcessingResources', 'ClusterConfig', 
'VolumeSizeInGB'], + ], + ), + ( + create_processing_params_with_stopping_condition, + [ + ['ProcessingResources', 'ClusterConfig', 'InstanceCount'], + ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'], + ['StoppingCondition', 'MaxRuntimeInSeconds'], + ], + ), + ] + ) def test_integer_fields_are_set(self, config, expected_fields): sagemaker = SageMakerProcessingOperator(**self.processing_config_kwargs, config=config) assert sagemaker.integer_fields == expected_fields @mock.patch.object(SageMakerHook, 'get_conn') - @mock.patch.object(SageMakerHook, 'create_processing_job', - return_value={'ProcessingJobArn': 'testarn', - 'ResponseMetadata': {'HTTPStatusCode': 200}}) + @mock.patch.object( + SageMakerHook, + 'create_processing_job', + return_value={'ProcessingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}}, + ) def test_execute(self, mock_processing, mock_client): - sagemaker = SageMakerProcessingOperator(**self.processing_config_kwargs, - config=create_processing_params) + sagemaker = SageMakerProcessingOperator( + **self.processing_config_kwargs, config=create_processing_params + ) sagemaker.execute(None) - mock_processing.assert_called_once_with(create_processing_params, - wait_for_completion=False, - check_interval=5, - max_ingestion_time=None - ) + mock_processing.assert_called_once_with( + create_processing_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None + ) @mock.patch.object(SageMakerHook, 'get_conn') - @mock.patch.object(SageMakerHook, 'create_processing_job', - return_value={'ProcessingJobArn': 'testarn', - 'ResponseMetadata': {'HTTPStatusCode': 404}}) + @mock.patch.object( + SageMakerHook, + 'create_processing_job', + return_value={'ProcessingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}}, + ) def test_execute_with_failure(self, mock_processing, mock_client): - sagemaker = SageMakerProcessingOperator(**self.processing_config_kwargs, - config=create_processing_params) + sagemaker = SageMakerProcessingOperator( + **self.processing_config_kwargs, config=create_processing_params + ) self.assertRaises(AirflowException, sagemaker.execute, None) @mock.patch.object(SageMakerHook, "get_conn") - @mock.patch.object(SageMakerHook, "list_processing_jobs", - return_value=[{"ProcessingJobName": job_name}]) - @mock.patch.object(SageMakerHook, "create_processing_job", - return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}) + @mock.patch.object(SageMakerHook, "list_processing_jobs", return_value=[{"ProcessingJobName": job_name}]) + @mock.patch.object( + SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}} + ) def test_execute_with_existing_job_increment( self, mock_create_processing_job, mock_list_processing_jobs, mock_client ): - sagemaker = SageMakerProcessingOperator(**self.processing_config_kwargs, - config=create_processing_params) + sagemaker = SageMakerProcessingOperator( + **self.processing_config_kwargs, config=create_processing_params + ) sagemaker.action_if_job_exists = "increment" sagemaker.execute(None) @@ -140,28 +158,26 @@ def test_execute_with_existing_job_increment( # Expect to see ProcessingJobName suffixed with "-2" because we return one existing job expected_config["ProcessingJobName"] = f"{job_name}-2" mock_create_processing_job.assert_called_once_with( - expected_config, - wait_for_completion=False, - check_interval=5, - max_ingestion_time=None, + expected_config, wait_for_completion=False, check_interval=5, max_ingestion_time=None, ) 
@mock.patch.object(SageMakerHook, "get_conn") - @mock.patch.object(SageMakerHook, "list_processing_jobs", - return_value=[{"ProcessingJobName": job_name}]) - @mock.patch.object(SageMakerHook, "create_processing_job", - return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}) + @mock.patch.object(SageMakerHook, "list_processing_jobs", return_value=[{"ProcessingJobName": job_name}]) + @mock.patch.object( + SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}} + ) def test_execute_with_existing_job_fail( self, mock_create_processing_job, mock_list_processing_jobs, mock_client ): - sagemaker = SageMakerProcessingOperator(**self.processing_config_kwargs, - config=create_processing_params) + sagemaker = SageMakerProcessingOperator( + **self.processing_config_kwargs, config=create_processing_params + ) sagemaker.action_if_job_exists = "fail" self.assertRaises(AirflowException, sagemaker.execute, None) @mock.patch.object(SageMakerHook, "get_conn") def test_action_if_job_exists_validation(self, mock_client): - sagemaker = SageMakerProcessingOperator(**self.processing_config_kwargs, - config=create_processing_params) - self.assertRaises(AirflowException, sagemaker.__init__, - action_if_job_exists="not_fail_or_increment") + sagemaker = SageMakerProcessingOperator( + **self.processing_config_kwargs, config=create_processing_params + ) + self.assertRaises(AirflowException, sagemaker.__init__, action_if_job_exists="not_fail_or_increment") diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py index 25a45b08fe85d..d8da49d712a6e 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py @@ -35,91 +35,83 @@ image = 'test-image' output_url = 's3://{}/test/output'.format(bucket) -create_training_params = \ - { - 'AlgorithmSpecification': { - 'TrainingImage': image, - 'TrainingInputMode': 'File' - }, - 'RoleArn': role, - 'OutputDataConfig': { - 'S3OutputPath': output_url - }, - 'ResourceConfig': { - 'InstanceCount': '2', - 'InstanceType': 'ml.c4.8xlarge', - 'VolumeSizeInGB': '50' - }, - 'TrainingJobName': job_name, - 'HyperParameters': { - 'k': '10', - 'feature_dim': '784', - 'mini_batch_size': '500', - 'force_dense': 'True' - }, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': '3600' - }, - 'InputDataConfig': [ - { - 'ChannelName': 'train', - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': data_url, - 'S3DataDistributionType': 'FullyReplicated' - } - }, - 'CompressionType': 'None', - 'RecordWrapperType': 'None' - } - ] - } +create_training_params = { + 'AlgorithmSpecification': {'TrainingImage': image, 'TrainingInputMode': 'File'}, + 'RoleArn': role, + 'OutputDataConfig': {'S3OutputPath': output_url}, + 'ResourceConfig': {'InstanceCount': '2', 'InstanceType': 'ml.c4.8xlarge', 'VolumeSizeInGB': '50'}, + 'TrainingJobName': job_name, + 'HyperParameters': {'k': '10', 'feature_dim': '784', 'mini_batch_size': '500', 'force_dense': 'True'}, + 'StoppingCondition': {'MaxRuntimeInSeconds': '3600'}, + 'InputDataConfig': [ + { + 'ChannelName': 'train', + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': data_url, + 'S3DataDistributionType': 'FullyReplicated', + } + }, + 'CompressionType': 'None', + 'RecordWrapperType': 'None', + } + ], +} # pylint: disable=unused-argument class TestSageMakerTrainingOperator(unittest.TestCase): - def setUp(self): 
self.sagemaker = SageMakerTrainingOperator( task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_id', config=create_training_params, wait_for_completion=False, - check_interval=5 + check_interval=5, ) def test_parse_config_integers(self): self.sagemaker.parse_config_integers() - self.assertEqual(self.sagemaker.config['ResourceConfig']['InstanceCount'], - int(self.sagemaker.config['ResourceConfig']['InstanceCount'])) - self.assertEqual(self.sagemaker.config['ResourceConfig']['VolumeSizeInGB'], - int(self.sagemaker.config['ResourceConfig']['VolumeSizeInGB'])) - self.assertEqual(self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'], - int(self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'])) + self.assertEqual( + self.sagemaker.config['ResourceConfig']['InstanceCount'], + int(self.sagemaker.config['ResourceConfig']['InstanceCount']), + ) + self.assertEqual( + self.sagemaker.config['ResourceConfig']['VolumeSizeInGB'], + int(self.sagemaker.config['ResourceConfig']['VolumeSizeInGB']), + ) + self.assertEqual( + self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'], + int(self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds']), + ) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_training_job') def test_execute(self, mock_training, mock_client): - mock_training.return_value = {'TrainingJobArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 200}} + mock_training.return_value = { + 'TrainingJobArn': 'testarn', + 'ResponseMetadata': {'HTTPStatusCode': 200}, + } self.sagemaker.execute(None) - mock_training.assert_called_once_with(create_training_params, - wait_for_completion=False, - print_log=True, - check_interval=5, - max_ingestion_time=None - ) + mock_training.assert_called_once_with( + create_training_params, + wait_for_completion=False, + print_log=True, + check_interval=5, + max_ingestion_time=None, + ) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_training_job') def test_execute_with_failure(self, mock_training, mock_client): - mock_training.return_value = {'TrainingJobArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 404}} + mock_training.return_value = { + 'TrainingJobArn': 'testarn', + 'ResponseMetadata': {'HTTPStatusCode': 404}, + } self.assertRaises(AirflowException, self.sagemaker.execute, None) -# pylint: enable=unused-argument + + # pylint: enable=unused-argument @mock.patch.object(SageMakerHook, "get_conn") @mock.patch.object(SageMakerHook, "list_training_jobs") diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py index 837f8937e9f90..838a740ff9c7f 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py @@ -45,79 +45,60 @@ 'MaxConcurrentTransforms': '12', 'MaxPayloadInMB': '6', 'BatchStrategy': 'MultiRecord', - 'TransformInput': { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': data_url - } - } - }, - 'TransformOutput': { - 'S3OutputPath': output_url, - }, - 'TransformResources': { - 'InstanceType': 'ml.m4.xlarge', - 'InstanceCount': '3' - } + 'TransformInput': {'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': data_url}}}, + 'TransformOutput': {'S3OutputPath': output_url,}, + 'TransformResources': {'InstanceType': 'ml.m4.xlarge', 'InstanceCount': '3'}, } create_model_params = { 'ModelName': model_name, 
- 'PrimaryContainer': { - 'Image': image, - 'ModelDataUrl': output_url, - }, - 'ExecutionRoleArn': role + 'PrimaryContainer': {'Image': image, 'ModelDataUrl': output_url,}, + 'ExecutionRoleArn': role, } -config = { - 'Model': create_model_params, - 'Transform': create_transform_params -} +config = {'Model': create_model_params, 'Transform': create_transform_params} class TestSageMakerTransformOperator(unittest.TestCase): - def setUp(self): self.sagemaker = SageMakerTransformOperator( task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_id', config=config, wait_for_completion=False, - check_interval=5 + check_interval=5, ) def test_parse_config_integers(self): self.sagemaker.parse_config_integers() test_config = self.sagemaker.config['Transform'] - self.assertEqual(test_config['TransformResources']['InstanceCount'], - int(test_config['TransformResources']['InstanceCount'])) - self.assertEqual(test_config['MaxConcurrentTransforms'], - int(test_config['MaxConcurrentTransforms'])) - self.assertEqual(test_config['MaxPayloadInMB'], - int(test_config['MaxPayloadInMB'])) + self.assertEqual( + test_config['TransformResources']['InstanceCount'], + int(test_config['TransformResources']['InstanceCount']), + ) + self.assertEqual(test_config['MaxConcurrentTransforms'], int(test_config['MaxConcurrentTransforms'])) + self.assertEqual(test_config['MaxPayloadInMB'], int(test_config['MaxPayloadInMB'])) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_model') @mock.patch.object(SageMakerHook, 'create_transform_job') def test_execute(self, mock_transform, mock_model, mock_client): - mock_transform.return_value = {'TransformJobArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 200}} + mock_transform.return_value = { + 'TransformJobArn': 'testarn', + 'ResponseMetadata': {'HTTPStatusCode': 200}, + } self.sagemaker.execute(None) mock_model.assert_called_once_with(create_model_params) - mock_transform.assert_called_once_with(create_transform_params, - wait_for_completion=False, - check_interval=5, - max_ingestion_time=None - ) + mock_transform.assert_called_once_with( + create_transform_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None + ) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_model') @mock.patch.object(SageMakerHook, 'create_transform_job') def test_execute_with_failure(self, mock_transform, mock_model, mock_client): - mock_transform.return_value = {'TransformJobArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 404}} + mock_transform.return_value = { + 'TransformJobArn': 'testarn', + 'ResponseMetadata': {'HTTPStatusCode': 404}, + } self.assertRaises(AirflowException, self.sagemaker.execute, None) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py index a55be39bd76c5..6b089598c25b2 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py @@ -37,119 +37,96 @@ output_url = 's3://{}/test/output'.format(bucket) -create_tuning_params = {'HyperParameterTuningJobName': job_name, - 'HyperParameterTuningJobConfig': { - 'Strategy': 'Bayesian', - 'HyperParameterTuningJobObjective': { - 'Type': 'Maximize', - 'MetricName': 'test_metric' - }, - 'ResourceLimits': { - 'MaxNumberOfTrainingJobs': '123', - 'MaxParallelTrainingJobs': '123' - }, - 'ParameterRanges': { - 'IntegerParameterRanges': [ - { - 'Name': 'k', - 
'MinValue': '2', - 'MaxValue': '10' - }, - ] - } - }, - 'TrainingJobDefinition': { - 'StaticHyperParameters': - { - 'k': '10', - 'feature_dim': '784', - 'mini_batch_size': '500', - 'force_dense': 'True' - }, - 'AlgorithmSpecification': - { - 'TrainingImage': image, - 'TrainingInputMode': 'File' - }, - 'RoleArn': role, - 'InputDataConfig': - [ - { - 'ChannelName': 'train', - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': data_url, - 'S3DataDistributionType': - 'FullyReplicated' - } - }, - 'CompressionType': 'None', - 'RecordWrapperType': 'None' - } - ], - 'OutputDataConfig': - { - 'S3OutputPath': output_url - }, - 'ResourceConfig': - { - 'InstanceCount': '2', - 'InstanceType': 'ml.c4.8xlarge', - 'VolumeSizeInGB': '50' - }, - 'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60) - } - } +create_tuning_params = { + 'HyperParameterTuningJobName': job_name, + 'HyperParameterTuningJobConfig': { + 'Strategy': 'Bayesian', + 'HyperParameterTuningJobObjective': {'Type': 'Maximize', 'MetricName': 'test_metric'}, + 'ResourceLimits': {'MaxNumberOfTrainingJobs': '123', 'MaxParallelTrainingJobs': '123'}, + 'ParameterRanges': {'IntegerParameterRanges': [{'Name': 'k', 'MinValue': '2', 'MaxValue': '10'},]}, + }, + 'TrainingJobDefinition': { + 'StaticHyperParameters': { + 'k': '10', + 'feature_dim': '784', + 'mini_batch_size': '500', + 'force_dense': 'True', + }, + 'AlgorithmSpecification': {'TrainingImage': image, 'TrainingInputMode': 'File'}, + 'RoleArn': role, + 'InputDataConfig': [ + { + 'ChannelName': 'train', + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': data_url, + 'S3DataDistributionType': 'FullyReplicated', + } + }, + 'CompressionType': 'None', + 'RecordWrapperType': 'None', + } + ], + 'OutputDataConfig': {'S3OutputPath': output_url}, + 'ResourceConfig': {'InstanceCount': '2', 'InstanceType': 'ml.c4.8xlarge', 'VolumeSizeInGB': '50'}, + 'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60), + }, +} class TestSageMakerTuningOperator(unittest.TestCase): - def setUp(self): self.sagemaker = SageMakerTuningOperator( task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_conn', config=create_tuning_params, wait_for_completion=False, - check_interval=5 + check_interval=5, ) def test_parse_config_integers(self): self.sagemaker.parse_config_integers() - self.assertEqual(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig'] - ['InstanceCount'], - int(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig'] - ['InstanceCount'])) - self.assertEqual(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig'] - ['VolumeSizeInGB'], - int(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig'] - ['VolumeSizeInGB'])) - self.assertEqual(self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'] - ['MaxNumberOfTrainingJobs'], - int(self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'] - ['MaxNumberOfTrainingJobs'])) - self.assertEqual(self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'] - ['MaxParallelTrainingJobs'], - int(self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'] - ['MaxParallelTrainingJobs'])) + self.assertEqual( + self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['InstanceCount'], + int(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['InstanceCount']), + ) + self.assertEqual( + self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB'], + 
int(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB']), + ) + self.assertEqual( + self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][ + 'MaxNumberOfTrainingJobs' + ], + int( + self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][ + 'MaxNumberOfTrainingJobs' + ] + ), + ) + self.assertEqual( + self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][ + 'MaxParallelTrainingJobs' + ], + int( + self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][ + 'MaxParallelTrainingJobs' + ] + ), + ) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_tuning_job') def test_execute(self, mock_tuning, mock_client): - mock_tuning.return_value = {'TrainingJobArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 200}} + mock_tuning.return_value = {'TrainingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}} self.sagemaker.execute(None) - mock_tuning.assert_called_once_with(create_tuning_params, - wait_for_completion=False, - check_interval=5, - max_ingestion_time=None - ) + mock_tuning.assert_called_once_with( + create_tuning_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None + ) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_tuning_job') def test_execute_with_failure(self, mock_tuning, mock_client): - mock_tuning.return_value = {'TrainingJobArn': 'testarn', - 'ResponseMetadata': - {'HTTPStatusCode': 404}} + mock_tuning.return_value = {'TrainingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}} self.assertRaises(AirflowException, self.sagemaker.execute, None) diff --git a/tests/providers/amazon/aws/operators/test_sns.py b/tests/providers/amazon/aws/operators/test_sns.py index 2c1993b396085..78eda16e530e7 100644 --- a/tests/providers/amazon/aws/operators/test_sns.py +++ b/tests/providers/amazon/aws/operators/test_sns.py @@ -31,7 +31,6 @@ class TestSnsPublishOperator(unittest.TestCase): - def test_init(self): # Given / When operator = SnsPublishOperator( diff --git a/tests/providers/amazon/aws/operators/test_sqs.py b/tests/providers/amazon/aws/operators/test_sqs.py index dac49b540f0e0..ef65b67c6d8e1 100644 --- a/tests/providers/amazon/aws/operators/test_sqs.py +++ b/tests/providers/amazon/aws/operators/test_sqs.py @@ -31,12 +31,8 @@ class TestSQSPublishOperator(unittest.TestCase): - def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) self.operator = SQSPublishOperator( @@ -44,7 +40,7 @@ def setUp(self): dag=self.dag, sqs_queue='test', message_content='hello', - aws_conn_id='aws_default' + aws_conn_id='aws_default', ) self.mock_context = MagicMock() diff --git a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py index 8997df9fdaba8..2aba95bc29d99 100644 --- a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py +++ b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py @@ -26,24 +26,22 @@ ) TASK_ID = 'step_function_get_execution_output' -EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:'\ - 'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' +EXECUTION_ARN = ( + 'arn:aws:states:us-east-1:123456789012:execution:' + 
    'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
+)
 AWS_CONN_ID = 'aws_non_default'
 REGION_NAME = 'us-west-2'
 
 
 class TestStepFunctionGetExecutionOutputOperator(unittest.TestCase):
-
     def setUp(self):
         self.mock_context = MagicMock()
 
     def test_init(self):
         # Given / When
         operator = StepFunctionGetExecutionOutputOperator(
-            task_id=TASK_ID,
-            execution_arn=EXECUTION_ARN,
-            aws_conn_id=AWS_CONN_ID,
-            region_name=REGION_NAME
+            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
         )
 
         # Then
@@ -55,18 +53,13 @@ def test_init(self):
     @mock.patch('airflow.providers.amazon.aws.operators.step_function_get_execution_output.StepFunctionHook')
     def test_execute(self, mock_hook):
         # Given
-        hook_response = {
-            'output': '{}'
-        }
+        hook_response = {'output': '{}'}
 
         hook_instance = mock_hook.return_value
         hook_instance.describe_execution.return_value = hook_response
 
         operator = StepFunctionGetExecutionOutputOperator(
-            task_id=TASK_ID,
-            execution_arn=EXECUTION_ARN,
-            aws_conn_id=AWS_CONN_ID,
-            region_name=REGION_NAME
+            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
         )
 
         # When
diff --git a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
index 5f6c336521594..4ceedbcd37c6d 100644
--- a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
+++ b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
@@ -34,7 +34,6 @@
 
 
 class TestStepFunctionStartExecutionOperator(unittest.TestCase):
-
     def setUp(self):
         self.mock_context = MagicMock()
 
@@ -46,7 +45,7 @@ def test_init(self):
             name=NAME,
             state_machine_input=INPUT,
             aws_conn_id=AWS_CONN_ID,
-            region_name=REGION_NAME
+            region_name=REGION_NAME,
         )
 
         # Then
@@ -60,8 +59,10 @@ def test_init(self):
     @mock.patch('airflow.providers.amazon.aws.operators.step_function_start_execution.StepFunctionHook')
     def test_execute(self, mock_hook):
         # Given
-        hook_response = 'arn:aws:states:us-east-1:123456789012:execution:'\
-            'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
+        hook_response = (
+            'arn:aws:states:us-east-1:123456789012:execution:'
+            'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
+        )
 
         hook_instance = mock_hook.return_value
         hook_instance.start_execution.return_value = hook_response
@@ -72,7 +73,7 @@ def test_execute(self, mock_hook):
             name=NAME,
             state_machine_input=INPUT,
             aws_conn_id=AWS_CONN_ID,
-            region_name=REGION_NAME
+            region_name=REGION_NAME,
         )
 
         # When
diff --git a/tests/providers/amazon/aws/secrets/test_secrets_manager.py b/tests/providers/amazon/aws/secrets/test_secrets_manager.py
index feb29c2b1b71b..86c9d772c29e3 100644
--- a/tests/providers/amazon/aws/secrets/test_secrets_manager.py
+++ b/tests/providers/amazon/aws/secrets/test_secrets_manager.py
@@ -23,8 +23,7 @@
 
 
 class TestSecretsManagerBackend(TestCase):
-    @mock.patch("airflow.providers.amazon.aws.secrets.secrets_manager."
-                "SecretsManagerBackend.get_conn_uri")
+    @mock.patch("airflow.providers.amazon.aws.secrets.secrets_manager." "SecretsManagerBackend.get_conn_uri")
     def test_aws_secrets_manager_get_connections(self, mock_get_uri):
         mock_get_uri.return_value = "scheme://user:pass@host:100"
         conn_list = SecretsManagerBackend().get_connections("fake_conn")
@@ -35,7 +34,7 @@ def test_aws_secrets_manager_get_connections(self, mock_get_uri):
     def test_get_conn_uri(self):
         param = {
             'SecretId': 'airflow/connections/test_postgres',
-            'SecretString': 'postgresql://airflow:airflow@host:5432/airflow'
+            'SecretString': 'postgresql://airflow:airflow@host:5432/airflow',
         }
 
         secrets_manager_backend = SecretsManagerBackend()
@@ -53,7 +52,7 @@ def test_get_conn_uri_non_existent_key(self):
         conn_id = "test_mysql"
         param = {
             'SecretId': 'airflow/connections/test_postgres',
-            'SecretString': 'postgresql://airflow:airflow@host:5432/airflow'
+            'SecretString': 'postgresql://airflow:airflow@host:5432/airflow',
         }
 
         secrets_manager_backend = SecretsManagerBackend()
@@ -64,10 +63,7 @@ def test_get_conn_uri_non_existent_key(self):
 
     @mock_secretsmanager
     def test_get_variable(self):
-        param = {
-            'SecretId': 'airflow/variables/hello',
-            'SecretString': 'world'
-        }
+        param = {'SecretId': 'airflow/variables/hello', 'SecretString': 'world'}
 
         secrets_manager_backend = SecretsManagerBackend()
         secrets_manager_backend.client.put_secret_value(**param)
@@ -81,10 +77,7 @@ def test_get_variable_non_existent_key(self):
         """
         Test that if Variable key is not present,
         SystemsManagerParameterStoreBackend.get_variables should return None
         """
-        param = {
-            'SecretId': 'airflow/variables/hello',
-            'SecretString': 'world'
-        }
+        param = {'SecretId': 'airflow/variables/hello', 'SecretString': 'world'}
         secrets_manager_backend = SecretsManagerBackend()
         secrets_manager_backend.client.put_secret_value(**param)
diff --git a/tests/providers/amazon/aws/secrets/test_systems_manager.py b/tests/providers/amazon/aws/secrets/test_systems_manager.py
index 7556552072dc7..d43fadc005559 100644
--- a/tests/providers/amazon/aws/secrets/test_systems_manager.py
+++ b/tests/providers/amazon/aws/secrets/test_systems_manager.py
@@ -25,8 +25,10 @@
 
 
 class TestSsmSecrets(TestCase):
-    @mock.patch("airflow.providers.amazon.aws.secrets.systems_manager."
-                "SystemsManagerParameterStoreBackend.get_conn_uri")
+    @mock.patch(
+        "airflow.providers.amazon.aws.secrets.systems_manager."
+ "SystemsManagerParameterStoreBackend.get_conn_uri" + ) def test_aws_ssm_get_connections(self, mock_get_uri): mock_get_uri.return_value = "scheme://user:pass@host:100" conn_list = SystemsManagerParameterStoreBackend().get_connections("fake_conn") @@ -38,7 +40,7 @@ def test_get_conn_uri(self): param = { 'Name': '/airflow/connections/test_postgres', 'Type': 'String', - 'Value': 'postgresql://airflow:airflow@host:5432/airflow' + 'Value': 'postgresql://airflow:airflow@host:5432/airflow', } ssm_backend = SystemsManagerParameterStoreBackend() @@ -57,7 +59,7 @@ def test_get_conn_uri_non_existent_key(self): param = { 'Name': '/airflow/connections/test_postgres', 'Type': 'String', - 'Value': 'postgresql://airflow:airflow@host:5432/airflow' + 'Value': 'postgresql://airflow:airflow@host:5432/airflow', } ssm_backend = SystemsManagerParameterStoreBackend() @@ -68,11 +70,7 @@ def test_get_conn_uri_non_existent_key(self): @mock_ssm def test_get_variable(self): - param = { - 'Name': '/airflow/variables/hello', - 'Type': 'String', - 'Value': 'world' - } + param = {'Name': '/airflow/variables/hello', 'Type': 'String', 'Value': 'world'} ssm_backend = SystemsManagerParameterStoreBackend() ssm_backend.client.put_parameter(**param) @@ -82,11 +80,7 @@ def test_get_variable(self): @mock_ssm def test_get_variable_secret_string(self): - param = { - 'Name': '/airflow/variables/hello', - 'Type': 'SecureString', - 'Value': 'world' - } + param = {'Name': '/airflow/variables/hello', 'Type': 'SecureString', 'Value': 'world'} ssm_backend = SystemsManagerParameterStoreBackend() ssm_backend.client.put_parameter(**param) returned_uri = ssm_backend.get_variable('hello') @@ -98,27 +92,26 @@ def test_get_variable_non_existent_key(self): Test that if Variable key is not present in SSM, SystemsManagerParameterStoreBackend.get_variables should return None """ - param = { - 'Name': '/airflow/variables/hello', - 'Type': 'String', - 'Value': 'world' - } + param = {'Name': '/airflow/variables/hello', 'Type': 'String', 'Value': 'world'} ssm_backend = SystemsManagerParameterStoreBackend() ssm_backend.client.put_parameter(**param) self.assertIsNone(ssm_backend.get_variable("test_mysql")) - @conf_vars({ - ('secrets', 'backend'): 'airflow.providers.amazon.aws.secrets.systems_manager.' - 'SystemsManagerParameterStoreBackend', - ('secrets', 'backend_kwargs'): '{"use_ssl": false}' - }) + @conf_vars( + { + ('secrets', 'backend'): 'airflow.providers.amazon.aws.secrets.systems_manager.' 
+ 'SystemsManagerParameterStoreBackend', + ('secrets', 'backend_kwargs'): '{"use_ssl": false}', + } + ) @mock.patch("airflow.providers.amazon.aws.secrets.systems_manager.boto3.Session.client") def test_passing_client_kwargs(self, mock_ssm_client): backends = initialize_secrets_backends() systems_manager = [ - backend for backend in backends + backend + for backend in backends if backend.__class__.__name__ == 'SystemsManagerParameterStoreBackend' ][0] diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py index c28084edc4225..6a243bfa45d3c 100644 --- a/tests/providers/amazon/aws/sensors/test_athena.py +++ b/tests/providers/amazon/aws/sensors/test_athena.py @@ -26,13 +26,14 @@ class TestAthenaSensor(unittest.TestCase): - def setUp(self): - self.sensor = AthenaSensor(task_id='test_athena_sensor', - query_execution_id='abc', - sleep_time=5, - max_retries=1, - aws_conn_id='aws_default') + self.sensor = AthenaSensor( + task_id='test_athena_sensor', + query_execution_id='abc', + sleep_time=5, + max_retries=1, + aws_conn_id='aws_default', + ) @mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("SUCCEEDED",)) def test_poke_success(self, mock_poll_query_status): diff --git a/tests/providers/amazon/aws/sensors/test_cloud_formation.py b/tests/providers/amazon/aws/sensors/test_cloud_formation.py index 0148db8838e13..af62d3eb40659 100644 --- a/tests/providers/amazon/aws/sensors/test_cloud_formation.py +++ b/tests/providers/amazon/aws/sensors/test_cloud_formation.py @@ -21,7 +21,8 @@ from mock import MagicMock, patch from airflow.providers.amazon.aws.sensors.cloud_formation import ( - CloudFormationCreateStackSensor, CloudFormationDeleteStackSensor, + CloudFormationCreateStackSensor, + CloudFormationDeleteStackSensor, ) try: @@ -30,8 +31,9 @@ mock_cloudformation = None -@unittest.skipIf(mock_cloudformation is None, - "Skipping test because moto.mock_cloudformation is not available") +@unittest.skipIf( + mock_cloudformation is None, "Skipping test because moto.mock_cloudformation is not available" +) class TestCloudFormationCreateStackSensor(unittest.TestCase): task_id = 'test_cloudformation_cluster_create_sensor' @@ -73,8 +75,9 @@ def test_poke_stack_in_unsuccessful_state(self): self.assertEqual('Stack foo in bad state: bar', str(error.exception)) -@unittest.skipIf(mock_cloudformation is None, - "Skipping test because moto.mock_cloudformation is not available") +@unittest.skipIf( + mock_cloudformation is None, "Skipping test because moto.mock_cloudformation is not available" +) class TestCloudFormationDeleteStackSensor(unittest.TestCase): task_id = 'test_cloudformation_cluster_delete_sensor' diff --git a/tests/providers/amazon/aws/sensors/test_ec2_instance_state.py b/tests/providers/amazon/aws/sensors/test_ec2_instance_state.py index 5608ddf1d4816..d886029a4bd16 100644 --- a/tests/providers/amazon/aws/sensors/test_ec2_instance_state.py +++ b/tests/providers/amazon/aws/sensors/test_ec2_instance_state.py @@ -26,7 +26,6 @@ class TestEC2InstanceStateSensor(unittest.TestCase): - def test_init(self): ec2_operator = EC2InstanceStateSensor( task_id="task_test", @@ -45,9 +44,7 @@ def test_init_invalid_target_state(self): invalid_target_state = "target_state_test" with self.assertRaises(ValueError) as cm: EC2InstanceStateSensor( - task_id="task_test", - target_state=invalid_target_state, - instance_id="i-123abc", + task_id="task_test", target_state=invalid_target_state, instance_id="i-123abc", ) msg = f"Invalid target_state: 
{invalid_target_state}" self.assertEqual(str(cm.exception), msg) @@ -56,19 +53,14 @@ def test_init_invalid_target_state(self): def test_running(self): # create instance ec2_hook = EC2Hook() - instances = ec2_hook.conn.create_instances( - MaxCount=1, - MinCount=1, - ) + instances = ec2_hook.conn.create_instances(MaxCount=1, MinCount=1,) instance_id = instances[0].instance_id # stop instance ec2_hook.get_instance(instance_id=instance_id).stop() # start sensor, waits until ec2 instance state became running start_sensor = EC2InstanceStateSensor( - task_id="start_sensor", - target_state="running", - instance_id=instance_id, + task_id="start_sensor", target_state="running", instance_id=instance_id, ) # assert instance state is not running self.assertFalse(start_sensor.poke(None)) @@ -81,19 +73,14 @@ def test_running(self): def test_stopped(self): # create instance ec2_hook = EC2Hook() - instances = ec2_hook.conn.create_instances( - MaxCount=1, - MinCount=1, - ) + instances = ec2_hook.conn.create_instances(MaxCount=1, MinCount=1,) instance_id = instances[0].instance_id # start instance ec2_hook.get_instance(instance_id=instance_id).start() # stop sensor, waits until ec2 instance state became stopped stop_sensor = EC2InstanceStateSensor( - task_id="stop_sensor", - target_state="stopped", - instance_id=instance_id, + task_id="stop_sensor", target_state="stopped", instance_id=instance_id, ) # assert instance state is not stopped self.assertFalse(stop_sensor.poke(None)) @@ -106,19 +93,14 @@ def test_stopped(self): def test_terminated(self): # create instance ec2_hook = EC2Hook() - instances = ec2_hook.conn.create_instances( - MaxCount=1, - MinCount=1, - ) + instances = ec2_hook.conn.create_instances(MaxCount=1, MinCount=1,) instance_id = instances[0].instance_id # start instance ec2_hook.get_instance(instance_id=instance_id).start() # stop sensor, waits until ec2 instance state became terminated stop_sensor = EC2InstanceStateSensor( - task_id="stop_sensor", - target_state="terminated", - instance_id=instance_id, + task_id="stop_sensor", target_state="terminated", instance_id=instance_id, ) # assert instance state is not terminated self.assertFalse(stop_sensor.poke(None)) diff --git a/tests/providers/amazon/aws/sensors/test_emr_base.py b/tests/providers/amazon/aws/sensors/test_emr_base.py index 71c590120d74b..b297342612daf 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_base.py +++ b/tests/providers/amazon/aws/sensors/test_emr_base.py @@ -51,57 +51,44 @@ def failure_message_from_response(response): change_reason = response['SomeKey'].get('StateChangeReason') if change_reason: return 'for code: {} with message {}'.format( - change_reason.get('Code', EMPTY_CODE), - change_reason.get('Message', 'Unknown')) + change_reason.get('Code', EMPTY_CODE), change_reason.get('Message', 'Unknown') + ) return None class TestEmrBaseSensor(unittest.TestCase): def test_poke_returns_true_when_state_is_in_target_states(self): - operator = EmrBaseSensorSubclass( - task_id='test_task', - poke_interval=2, - ) + operator = EmrBaseSensorSubclass(task_id='test_task', poke_interval=2,) operator.response = { 'SomeKey': {'State': TARGET_STATE}, - 'ResponseMetadata': {'HTTPStatusCode': GOOD_HTTP_STATUS} + 'ResponseMetadata': {'HTTPStatusCode': GOOD_HTTP_STATUS}, } operator.execute(None) def test_poke_returns_false_when_state_is_not_in_target_states(self): - operator = EmrBaseSensorSubclass( - task_id='test_task', - poke_interval=2, - ) + operator = EmrBaseSensorSubclass(task_id='test_task', poke_interval=2,) 
operator.response = { 'SomeKey': {'State': NON_TARGET_STATE}, - 'ResponseMetadata': {'HTTPStatusCode': GOOD_HTTP_STATUS} + 'ResponseMetadata': {'HTTPStatusCode': GOOD_HTTP_STATUS}, } self.assertEqual(operator.poke(None), False) def test_poke_returns_false_when_http_response_is_bad(self): - operator = EmrBaseSensorSubclass( - task_id='test_task', - poke_interval=2, - ) + operator = EmrBaseSensorSubclass(task_id='test_task', poke_interval=2,) operator.response = { 'SomeKey': {'State': TARGET_STATE}, - 'ResponseMetadata': {'HTTPStatusCode': BAD_HTTP_STATUS} + 'ResponseMetadata': {'HTTPStatusCode': BAD_HTTP_STATUS}, } self.assertEqual(operator.poke(None), False) def test_poke_raises_error_when_state_is_in_failed_states(self): - operator = EmrBaseSensorSubclass( - task_id='test_task', - poke_interval=2, - ) + operator = EmrBaseSensorSubclass(task_id='test_task', poke_interval=2,) operator.response = { - 'SomeKey': {'State': FAILED_STATE, - 'StateChangeReason': {'Code': EXPECTED_CODE}}, - 'ResponseMetadata': {'HTTPStatusCode': GOOD_HTTP_STATUS} + 'SomeKey': {'State': FAILED_STATE, 'StateChangeReason': {'Code': EXPECTED_CODE}}, + 'ResponseMetadata': {'HTTPStatusCode': GOOD_HTTP_STATUS}, } with self.assertRaises(AirflowException) as context: diff --git a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py index f5de8de0c1a26..a98f8f7ea5ab8 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py +++ b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py @@ -27,9 +27,7 @@ DESCRIBE_CLUSTER_STARTING_RETURN = { 'Cluster': { - 'Applications': [ - {'Name': 'Spark', 'Version': '1.6.1'} - ], + 'Applications': [{'Name': 'Spark', 'Version': '1.6.1'}], 'AutoTerminate': True, 'Configurations': [], 'Ec2InstanceAttributes': {'IamInstanceProfile': 'EMR_EC2_DefaultRole'}, @@ -43,26 +41,19 @@ 'State': 'STARTING', 'StateChangeReason': {}, 'Timeline': { - 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())} + 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal()) + }, }, - 'Tags': [ - {'Key': 'app', 'Value': 'analytics'}, - {'Key': 'environment', 'Value': 'development'} - ], + 'Tags': [{'Key': 'app', 'Value': 'analytics'}, {'Key': 'environment', 'Value': 'development'}], 'TerminationProtected': False, - 'VisibleToAllUsers': True + 'VisibleToAllUsers': True, }, - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e' - } + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e'}, } DESCRIBE_CLUSTER_BOOTSTRAPPING_RETURN = { 'Cluster': { - 'Applications': [ - {'Name': 'Spark', 'Version': '1.6.1'} - ], + 'Applications': [{'Name': 'Spark', 'Version': '1.6.1'}], 'AutoTerminate': True, 'Configurations': [], 'Ec2InstanceAttributes': {'IamInstanceProfile': 'EMR_EC2_DefaultRole'}, @@ -76,26 +67,19 @@ 'State': 'BOOTSTRAPPING', 'StateChangeReason': {}, 'Timeline': { - 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())} + 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal()) + }, }, - 'Tags': [ - {'Key': 'app', 'Value': 'analytics'}, - {'Key': 'environment', 'Value': 'development'} - ], + 'Tags': [{'Key': 'app', 'Value': 'analytics'}, {'Key': 'environment', 'Value': 'development'}], 'TerminationProtected': False, - 'VisibleToAllUsers': True + 'VisibleToAllUsers': True, }, - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': 
'd5456308-3caa-11e6-9d46-951401f04e0e' - } + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e'}, } DESCRIBE_CLUSTER_RUNNING_RETURN = { 'Cluster': { - 'Applications': [ - {'Name': 'Spark', 'Version': '1.6.1'} - ], + 'Applications': [{'Name': 'Spark', 'Version': '1.6.1'}], 'AutoTerminate': True, 'Configurations': [], 'Ec2InstanceAttributes': {'IamInstanceProfile': 'EMR_EC2_DefaultRole'}, @@ -109,26 +93,19 @@ 'State': 'RUNNING', 'StateChangeReason': {}, 'Timeline': { - 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())} + 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal()) + }, }, - 'Tags': [ - {'Key': 'app', 'Value': 'analytics'}, - {'Key': 'environment', 'Value': 'development'} - ], + 'Tags': [{'Key': 'app', 'Value': 'analytics'}, {'Key': 'environment', 'Value': 'development'}], 'TerminationProtected': False, - 'VisibleToAllUsers': True + 'VisibleToAllUsers': True, }, - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e' - } + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e'}, } DESCRIBE_CLUSTER_WAITING_RETURN = { 'Cluster': { - 'Applications': [ - {'Name': 'Spark', 'Version': '1.6.1'} - ], + 'Applications': [{'Name': 'Spark', 'Version': '1.6.1'}], 'AutoTerminate': True, 'Configurations': [], 'Ec2InstanceAttributes': {'IamInstanceProfile': 'EMR_EC2_DefaultRole'}, @@ -142,26 +119,19 @@ 'State': 'WAITING', 'StateChangeReason': {}, 'Timeline': { - 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())} + 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal()) + }, }, - 'Tags': [ - {'Key': 'app', 'Value': 'analytics'}, - {'Key': 'environment', 'Value': 'development'} - ], + 'Tags': [{'Key': 'app', 'Value': 'analytics'}, {'Key': 'environment', 'Value': 'development'}], 'TerminationProtected': False, - 'VisibleToAllUsers': True + 'VisibleToAllUsers': True, }, - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e' - } + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e'}, } DESCRIBE_CLUSTER_TERMINATED_RETURN = { 'Cluster': { - 'Applications': [ - {'Name': 'Spark', 'Version': '1.6.1'} - ], + 'Applications': [{'Name': 'Spark', 'Version': '1.6.1'}], 'AutoTerminate': True, 'Configurations': [], 'Ec2InstanceAttributes': {'IamInstanceProfile': 'EMR_EC2_DefaultRole'}, @@ -175,26 +145,19 @@ 'State': 'TERMINATED', 'StateChangeReason': {}, 'Timeline': { - 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())} + 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal()) + }, }, - 'Tags': [ - {'Key': 'app', 'Value': 'analytics'}, - {'Key': 'environment', 'Value': 'development'} - ], + 'Tags': [{'Key': 'app', 'Value': 'analytics'}, {'Key': 'environment', 'Value': 'development'}], 'TerminationProtected': False, - 'VisibleToAllUsers': True + 'VisibleToAllUsers': True, }, - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e' - } + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e'}, } DESCRIBE_CLUSTER_TERMINATED_WITH_ERRORS_RETURN = { 'Cluster': { - 'Applications': [ - {'Name': 'Spark', 'Version': '1.6.1'} - ], + 'Applications': [{'Name': 'Spark', 'Version': '1.6.1'}], 'AutoTerminate': True, 
'Configurations': [], 'Ec2InstanceAttributes': {'IamInstanceProfile': 'EMR_EC2_DefaultRole'}, @@ -209,22 +172,17 @@ 'StateChangeReason': { 'Code': 'BOOTSTRAP_FAILURE', 'Message': 'Master instance (i-0663047709b12345c) failed attempting to ' - 'download bootstrap action 1 file from S3' + 'download bootstrap action 1 file from S3', }, 'Timeline': { - 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())} + 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal()) + }, }, - 'Tags': [ - {'Key': 'app', 'Value': 'analytics'}, - {'Key': 'environment', 'Value': 'development'} - ], + 'Tags': [{'Key': 'app', 'Value': 'analytics'}, {'Key': 'environment', 'Value': 'development'}], 'TerminationProtected': False, - 'VisibleToAllUsers': True + 'VisibleToAllUsers': True, }, - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e' - } + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e'}, } @@ -243,14 +201,11 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self self.mock_emr_client.describe_cluster.side_effect = [ DESCRIBE_CLUSTER_STARTING_RETURN, DESCRIBE_CLUSTER_RUNNING_RETURN, - DESCRIBE_CLUSTER_TERMINATED_RETURN + DESCRIBE_CLUSTER_TERMINATED_RETURN, ] with patch('boto3.session.Session', self.boto3_session_mock): operator = EmrJobFlowSensor( - task_id='test_task', - poke_interval=0, - job_flow_id='j-8989898989', - aws_conn_id='aws_default' + task_id='test_task', poke_interval=0, job_flow_id='j-8989898989', aws_conn_id='aws_default' ) operator.execute(None) @@ -265,14 +220,11 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self def test_execute_calls_with_the_job_flow_id_until_it_reaches_failed_state_with_exception(self): self.mock_emr_client.describe_cluster.side_effect = [ DESCRIBE_CLUSTER_RUNNING_RETURN, - DESCRIBE_CLUSTER_TERMINATED_WITH_ERRORS_RETURN + DESCRIBE_CLUSTER_TERMINATED_WITH_ERRORS_RETURN, ] with patch('boto3.session.Session', self.boto3_session_mock): operator = EmrJobFlowSensor( - task_id='test_task', - poke_interval=0, - job_flow_id='j-8989898989', - aws_conn_id='aws_default' + task_id='test_task', poke_interval=0, job_flow_id='j-8989898989', aws_conn_id='aws_default' ) with self.assertRaises(AirflowException): @@ -299,7 +251,7 @@ def test_different_target_states(self): poke_interval=0, job_flow_id='j-8989898989', aws_conn_id='aws_default', - target_states=['RUNNING', 'WAITING'] + target_states=['RUNNING', 'WAITING'], ) operator.execute(None) diff --git a/tests/providers/amazon/aws/sensors/test_emr_step.py b/tests/providers/amazon/aws/sensors/test_emr_step.py index aaf1d909b5156..512ec5b972486 100644 --- a/tests/providers/amazon/aws/sensors/test_emr_step.py +++ b/tests/providers/amazon/aws/sensors/test_emr_step.py @@ -26,20 +26,13 @@ from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor DESCRIBE_JOB_STEP_RUNNING_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6' - }, + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6'}, 'Step': { 'ActionOnFailure': 'CONTINUE', 'Config': { - 'Args': [ - '/usr/lib/spark/bin/run-example', - 'SparkPi', - '10' - ], + 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'], 'Jar': 'command-runner.jar', - 'Properties': {} + 'Properties': {}, }, 'Id': 's-VK57YR1Z9Z5N', 'Name': 'calculate_pi', @@ -48,27 +41,20 @@ 
'StateChangeReason': {}, 'Timeline': { 'CreationDateTime': datetime(2016, 6, 20, 19, 0, 18, tzinfo=tzlocal()), - 'StartDateTime': datetime(2016, 6, 20, 19, 2, 34, tzinfo=tzlocal()) - } - } - } + 'StartDateTime': datetime(2016, 6, 20, 19, 2, 34, tzinfo=tzlocal()), + }, + }, + }, } DESCRIBE_JOB_STEP_CANCELLED_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6' - }, + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6'}, 'Step': { 'ActionOnFailure': 'CONTINUE', 'Config': { - 'Args': [ - '/usr/lib/spark/bin/run-example', - 'SparkPi', - '10' - ], + 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'], 'Jar': 'command-runner.jar', - 'Properties': {} + 'Properties': {}, }, 'Id': 's-VK57YR1Z9Z5N', 'Name': 'calculate_pi', @@ -77,27 +63,20 @@ 'StateChangeReason': {}, 'Timeline': { 'CreationDateTime': datetime(2016, 6, 20, 19, 0, 18, tzinfo=tzlocal()), - 'StartDateTime': datetime(2016, 6, 20, 19, 2, 34, tzinfo=tzlocal()) - } - } - } + 'StartDateTime': datetime(2016, 6, 20, 19, 2, 34, tzinfo=tzlocal()), + }, + }, + }, } DESCRIBE_JOB_STEP_FAILED_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6' - }, + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6'}, 'Step': { 'ActionOnFailure': 'CONTINUE', 'Config': { - 'Args': [ - '/usr/lib/spark/bin/run-example', - 'SparkPi', - '10' - ], + 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'], 'Jar': 'command-runner.jar', - 'Properties': {} + 'Properties': {}, }, 'Id': 's-VK57YR1Z9Z5N', 'Name': 'calculate_pi', @@ -106,31 +85,24 @@ 'StateChangeReason': {}, 'FailureDetails': { 'LogFile': 's3://fake-log-files/emr-logs/j-8989898989/steps/s-VK57YR1Z9Z5N', - 'Reason': 'Unknown Error.' 
+ 'Reason': 'Unknown Error.', }, 'Timeline': { 'CreationDateTime': datetime(2016, 6, 20, 19, 0, 18, tzinfo=tzlocal()), - 'StartDateTime': datetime(2016, 6, 20, 19, 2, 34, tzinfo=tzlocal()) - } - } - } + 'StartDateTime': datetime(2016, 6, 20, 19, 2, 34, tzinfo=tzlocal()), + }, + }, + }, } DESCRIBE_JOB_STEP_INTERRUPTED_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6' - }, + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6'}, 'Step': { 'ActionOnFailure': 'CONTINUE', 'Config': { - 'Args': [ - '/usr/lib/spark/bin/run-example', - 'SparkPi', - '10' - ], + 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'], 'Jar': 'command-runner.jar', - 'Properties': {} + 'Properties': {}, }, 'Id': 's-VK57YR1Z9Z5N', 'Name': 'calculate_pi', @@ -139,27 +111,20 @@ 'StateChangeReason': {}, 'Timeline': { 'CreationDateTime': datetime(2016, 6, 20, 19, 0, 18, tzinfo=tzlocal()), - 'StartDateTime': datetime(2016, 6, 20, 19, 2, 34, tzinfo=tzlocal()) - } - } - } + 'StartDateTime': datetime(2016, 6, 20, 19, 2, 34, tzinfo=tzlocal()), + }, + }, + }, } DESCRIBE_JOB_STEP_COMPLETED_RETURN = { - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6' - }, + 'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6'}, 'Step': { 'ActionOnFailure': 'CONTINUE', 'Config': { - 'Args': [ - '/usr/lib/spark/bin/run-example', - 'SparkPi', - '10' - ], + 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'], 'Jar': 'command-runner.jar', - 'Properties': {} + 'Properties': {}, }, 'Id': 's-VK57YR1Z9Z5N', 'Name': 'calculate_pi', @@ -168,10 +133,10 @@ 'StateChangeReason': {}, 'Timeline': { 'CreationDateTime': datetime(2016, 6, 20, 19, 0, 18, tzinfo=tzlocal()), - 'StartDateTime': datetime(2016, 6, 20, 19, 2, 34, tzinfo=tzlocal()) - } - } - } + 'StartDateTime': datetime(2016, 6, 20, 19, 2, 34, tzinfo=tzlocal()), + }, + }, + }, } @@ -195,7 +160,7 @@ def setUp(self): def test_step_completed(self): self.emr_client_mock.describe_step.side_effect = [ DESCRIBE_JOB_STEP_RUNNING_RETURN, - DESCRIBE_JOB_STEP_COMPLETED_RETURN + DESCRIBE_JOB_STEP_COMPLETED_RETURN, ] with patch('boto3.session.Session', self.boto3_session_mock): @@ -204,14 +169,14 @@ def test_step_completed(self): self.assertEqual(self.emr_client_mock.describe_step.call_count, 2) calls = [ unittest.mock.call(ClusterId='j-8989898989', StepId='s-VK57YR1Z9Z5N'), - unittest.mock.call(ClusterId='j-8989898989', StepId='s-VK57YR1Z9Z5N') + unittest.mock.call(ClusterId='j-8989898989', StepId='s-VK57YR1Z9Z5N'), ] self.emr_client_mock.describe_step.assert_has_calls(calls) def test_step_cancelled(self): self.emr_client_mock.describe_step.side_effect = [ DESCRIBE_JOB_STEP_RUNNING_RETURN, - DESCRIBE_JOB_STEP_CANCELLED_RETURN + DESCRIBE_JOB_STEP_CANCELLED_RETURN, ] with patch('boto3.session.Session', self.boto3_session_mock): @@ -220,7 +185,7 @@ def test_step_cancelled(self): def test_step_failed(self): self.emr_client_mock.describe_step.side_effect = [ DESCRIBE_JOB_STEP_RUNNING_RETURN, - DESCRIBE_JOB_STEP_FAILED_RETURN + DESCRIBE_JOB_STEP_FAILED_RETURN, ] with patch('boto3.session.Session', self.boto3_session_mock): @@ -229,7 +194,7 @@ def test_step_failed(self): def test_step_interrupted(self): self.emr_client_mock.describe_step.side_effect = [ DESCRIBE_JOB_STEP_RUNNING_RETURN, - DESCRIBE_JOB_STEP_INTERRUPTED_RETURN + DESCRIBE_JOB_STEP_INTERRUPTED_RETURN, ] with patch('boto3.session.Session', 
self.boto3_session_mock): diff --git a/tests/providers/amazon/aws/sensors/test_glue.py b/tests/providers/amazon/aws/sensors/test_glue.py index ec6f921a05a64..67f91d0b146ae 100644 --- a/tests/providers/amazon/aws/sensors/test_glue.py +++ b/tests/providers/amazon/aws/sensors/test_glue.py @@ -1,5 +1,3 @@ - - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -27,7 +25,6 @@ class TestAwsGlueJobSensor(unittest.TestCase): - def setUp(self): configuration.load_test_config() @@ -36,12 +33,14 @@ def setUp(self): def test_poke(self, mock_get_job_state, mock_conn): mock_conn.return_value.get_job_run() mock_get_job_state.return_value = 'SUCCEEDED' - op = AwsGlueJobSensor(task_id='test_glue_job_sensor', - job_name='aws_test_glue_job', - run_id='5152fgsfsjhsh61661', - poke_interval=1, - timeout=5, - aws_conn_id='aws_default') + op = AwsGlueJobSensor( + task_id='test_glue_job_sensor', + job_name='aws_test_glue_job', + run_id='5152fgsfsjhsh61661', + poke_interval=1, + timeout=5, + aws_conn_id='aws_default', + ) self.assertTrue(op.poke(None)) @mock.patch.object(AwsGlueJobHook, 'get_conn') @@ -49,12 +48,14 @@ def test_poke(self, mock_get_job_state, mock_conn): def test_poke_false(self, mock_get_job_state, mock_conn): mock_conn.return_value.get_job_run() mock_get_job_state.return_value = 'RUNNING' - op = AwsGlueJobSensor(task_id='test_glue_job_sensor', - job_name='aws_test_glue_job', - run_id='5152fgsfsjhsh61661', - poke_interval=1, - timeout=5, - aws_conn_id='aws_default') + op = AwsGlueJobSensor( + task_id='test_glue_job_sensor', + job_name='aws_test_glue_job', + run_id='5152fgsfsjhsh61661', + poke_interval=1, + timeout=5, + aws_conn_id='aws_default', + ) self.assertFalse(op.poke(None)) diff --git a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py index 8e69875bd456c..92103572c4b9e 100644 --- a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py +++ b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py @@ -29,8 +29,7 @@ mock_glue = None -@unittest.skipIf(mock_glue is None, - "Skipping test because moto.mock_glue is not available") +@unittest.skipIf(mock_glue is None, "Skipping test because moto.mock_glue is not available") class TestAwsGlueCatalogPartitionSensor(unittest.TestCase): task_id = 'test_glue_catalog_partition_sensor' @@ -39,31 +38,26 @@ class TestAwsGlueCatalogPartitionSensor(unittest.TestCase): @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition') def test_poke(self, mock_check_for_partition): mock_check_for_partition.return_value = True - op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, - table_name='tbl') + op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, table_name='tbl') self.assertTrue(op.poke(None)) @mock_glue @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition') def test_poke_false(self, mock_check_for_partition): mock_check_for_partition.return_value = False - op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, - table_name='tbl') + op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, table_name='tbl') self.assertFalse(op.poke(None)) @mock_glue @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition') def test_poke_default_args(self, mock_check_for_partition): table_name = 'test_glue_catalog_partition_sensor_tbl' - op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, - table_name=table_name) + op = 
AwsGlueCatalogPartitionSensor(task_id=self.task_id, table_name=table_name) op.poke(None) self.assertEqual(op.hook.region_name, None) self.assertEqual(op.hook.aws_conn_id, 'aws_default') - mock_check_for_partition.assert_called_once_with('default', - table_name, - "ds='{{ ds }}'") + mock_check_for_partition.assert_called_once_with('default', table_name, "ds='{{ ds }}'") @mock_glue @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition') @@ -75,32 +69,29 @@ def test_poke_nondefault_args(self, mock_check_for_partition): database_name = 'my_db' poke_interval = 2 timeout = 3 - op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, - table_name=table_name, - expression=expression, - aws_conn_id=aws_conn_id, - region_name=region_name, - database_name=database_name, - poke_interval=poke_interval, - timeout=timeout) + op = AwsGlueCatalogPartitionSensor( + task_id=self.task_id, + table_name=table_name, + expression=expression, + aws_conn_id=aws_conn_id, + region_name=region_name, + database_name=database_name, + poke_interval=poke_interval, + timeout=timeout, + ) op.poke(None) self.assertEqual(op.hook.region_name, region_name) self.assertEqual(op.hook.aws_conn_id, aws_conn_id) self.assertEqual(op.poke_interval, poke_interval) self.assertEqual(op.timeout, timeout) - mock_check_for_partition.assert_called_once_with(database_name, - table_name, - expression) + mock_check_for_partition.assert_called_once_with(database_name, table_name, expression) @mock_glue @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition') def test_dot_notation(self, mock_check_for_partition): db_table = 'my_db.my_tbl' - op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, - table_name=db_table) + op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, table_name=db_table) op.poke(None) - mock_check_for_partition.assert_called_once_with('my_db', - 'my_tbl', - "ds='{{ ds }}'") + mock_check_for_partition.assert_called_once_with('my_db', 'my_tbl', "ds='{{ ds }}'") diff --git a/tests/providers/amazon/aws/sensors/test_redshift.py b/tests/providers/amazon/aws/sensors/test_redshift.py index 64bc7cbe29ead..ddd3e3179f6b9 100644 --- a/tests/providers/amazon/aws/sensors/test_redshift.py +++ b/tests/providers/amazon/aws/sensors/test_redshift.py @@ -37,7 +37,7 @@ def _create_cluster(): ClusterIdentifier='test_cluster', NodeType='dc1.large', MasterUsername='admin', - MasterUserPassword='mock_password' + MasterUserPassword='mock_password', ) if not client.describe_clusters()['Clusters']: raise ValueError('AWS not properly mocked') @@ -46,24 +46,28 @@ def _create_cluster(): @mock_redshift def test_poke(self): self._create_cluster() - op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor', - poke_interval=1, - timeout=5, - aws_conn_id='aws_default', - cluster_identifier='test_cluster', - target_status='available') + op = AwsRedshiftClusterSensor( + task_id='test_cluster_sensor', + poke_interval=1, + timeout=5, + aws_conn_id='aws_default', + cluster_identifier='test_cluster', + target_status='available', + ) self.assertTrue(op.poke(None)) @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present') @mock_redshift def test_poke_false(self): self._create_cluster() - op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor', - poke_interval=1, - timeout=5, - aws_conn_id='aws_default', - cluster_identifier='test_cluster_not_found', - target_status='available') + op = AwsRedshiftClusterSensor( + task_id='test_cluster_sensor', + poke_interval=1, + timeout=5, + aws_conn_id='aws_default', + 
cluster_identifier='test_cluster_not_found', + target_status='available', + ) self.assertFalse(op.poke(None)) @@ -71,11 +75,13 @@ def test_poke_false(self): @mock_redshift def test_poke_cluster_not_found(self): self._create_cluster() - op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor', - poke_interval=1, - timeout=5, - aws_conn_id='aws_default', - cluster_identifier='test_cluster_not_found', - target_status='cluster_not_found') + op = AwsRedshiftClusterSensor( + task_id='test_cluster_sensor', + poke_interval=1, + timeout=5, + aws_conn_id='aws_default', + cluster_identifier='test_cluster_not_found', + target_status='cluster_not_found', + ) self.assertTrue(op.poke(None)) diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py index 373600be22b64..36a13d0336c6a 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_key.py +++ b/tests/providers/amazon/aws/sensors/test_s3_key.py @@ -26,7 +26,6 @@ class TestS3KeySensor(unittest.TestCase): - def test_bucket_name_none_and_bucket_key_as_relative_path(self): """ Test if exception is raised when bucket_name is None @@ -34,9 +33,7 @@ def test_bucket_name_none_and_bucket_key_as_relative_path(self): :return: """ with self.assertRaises(AirflowException): - S3KeySensor( - task_id='s3_key_sensor', - bucket_key="file_in_bucket") + S3KeySensor(task_id='s3_key_sensor', bucket_key="file_in_bucket") def test_bucket_name_provided_and_bucket_key_is_s3_url(self): """ @@ -46,28 +43,20 @@ def test_bucket_name_provided_and_bucket_key_is_s3_url(self): """ with self.assertRaises(AirflowException): S3KeySensor( - task_id='s3_key_sensor', - bucket_key="s3://test_bucket/file", - bucket_name='test_bucket') + task_id='s3_key_sensor', bucket_key="s3://test_bucket/file", bucket_name='test_bucket' + ) - @parameterized.expand([ - ['s3://bucket/key', None, 'key', 'bucket'], - ['key', 'bucket', 'key', 'bucket'], - ]) + @parameterized.expand( + [['s3://bucket/key', None, 'key', 'bucket'], ['key', 'bucket', 'key', 'bucket'],] + ) def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket): - op = S3KeySensor( - task_id='s3_key_sensor', - bucket_key=key, - bucket_name=bucket, - ) + op = S3KeySensor(task_id='s3_key_sensor', bucket_key=key, bucket_name=bucket,) self.assertEqual(op.bucket_key, parsed_key) self.assertEqual(op.bucket_name, parsed_bucket) @mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3Hook') def test_poke(self, mock_hook): - op = S3KeySensor( - task_id='s3_key_sensor', - bucket_key='s3://test_bucket/file') + op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file') mock_check_for_key = mock_hook.return_value.check_for_key mock_check_for_key.return_value = False @@ -79,10 +68,7 @@ def test_poke(self, mock_hook): @mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3Hook') def test_poke_wildcard(self, mock_hook): - op = S3KeySensor( - task_id='s3_key_sensor', - bucket_key='s3://test_bucket/file', - wildcard_match=True) + op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file', wildcard_match=True) mock_check_for_wildcard_key = mock_hook.return_value.check_for_wildcard_key mock_check_for_wildcard_key.return_value = False diff --git a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py index 504c0a7446841..cff6f2dab6b0f 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py +++ b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py @@ -30,7 
+30,6 @@ class TestS3KeysUnchangedSensor(TestCase): - def setUp(self): args = { 'owner': 'airflow', @@ -48,7 +47,7 @@ def setUp(self): poke_interval=0.1, min_objects=1, allow_delete=True, - dag=self.dag + dag=self.dag, ) def test_reschedule_mode_not_allowed(self): @@ -59,7 +58,7 @@ def test_reschedule_mode_not_allowed(self): prefix='test-prefix/path', poke_interval=0.1, mode='reschedule', - dag=self.dag + dag=self.dag, ) @freeze_time(DEFAULT_DATE, auto_tick_seconds=10) @@ -69,16 +68,18 @@ def test_files_deleted_between_pokes_throw_error(self): with self.assertRaises(AirflowException): self.sensor.is_keys_unchanged({'a'}) - @parameterized.expand([ - # Test: resetting inactivity period after key change - (({'a'}, {'a', 'b'}, {'a', 'b', 'c'}), (False, False, False), (0, 0, 0)), - # ..and in case an item was deleted with option `allow_delete=True` - (({'a', 'b'}, {'a'}, {'a', 'c'}), (False, False, False), (0, 0, 0)), - # Test: passes after inactivity period was exceeded - (({'a'}, {'a'}, {'a'}), (False, False, True), (0, 10, 20)), - # ..and do not pass if empty key is given - ((set(), set(), set()), (False, False, False), (0, 10, 20)) - ]) + @parameterized.expand( + [ + # Test: resetting inactivity period after key change + (({'a'}, {'a', 'b'}, {'a', 'b', 'c'}), (False, False, False), (0, 0, 0)), + # ..and in case an item was deleted with option `allow_delete=True` + (({'a', 'b'}, {'a'}, {'a', 'c'}), (False, False, False), (0, 0, 0)), + # Test: passes after inactivity period was exceeded + (({'a'}, {'a'}, {'a'}), (False, False, True), (0, 10, 20)), + # ..and do not pass if empty key is given + ((set(), set(), set()), (False, False, False), (0, 10, 20)), + ] + ) @freeze_time(DEFAULT_DATE, auto_tick_seconds=10) def test_key_changes(self, current_objects, expected_returns, inactivity_periods): self.assertEqual(self.sensor.is_keys_unchanged(current_objects[0]), expected_returns[0]) diff --git a/tests/providers/amazon/aws/sensors/test_s3_prefix.py b/tests/providers/amazon/aws/sensors/test_s3_prefix.py index 4950f07dc4237..a06d6f0186259 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_prefix.py +++ b/tests/providers/amazon/aws/sensors/test_s3_prefix.py @@ -23,20 +23,15 @@ class TestS3PrefixSensor(unittest.TestCase): - @mock.patch('airflow.providers.amazon.aws.sensors.s3_prefix.S3Hook') def test_poke(self, mock_hook): - op = S3PrefixSensor( - task_id='s3_prefix', - bucket_name='bucket', - prefix='prefix') + op = S3PrefixSensor(task_id='s3_prefix', bucket_name='bucket', prefix='prefix') mock_hook.return_value.check_for_prefix.return_value = False self.assertFalse(op.poke(None)) mock_hook.return_value.check_for_prefix.assert_called_once_with( - prefix='prefix', - delimiter='/', - bucket_name='bucket') + prefix='prefix', delimiter='/', bucket_name='bucket' + ) mock_hook.return_value.check_for_prefix.return_value = True self.assertTrue(op.poke(None)) diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_base.py b/tests/providers/amazon/aws/sensors/test_sagemaker_base.py index aacbc4c6a72dc..a52f7e53827e5 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_base.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_base.py @@ -32,19 +32,12 @@ def failed_states(self): return ['FAILED'] def get_sagemaker_response(self): - return { - 'SomeKey': {'State': 'COMPLETED'}, - 'ResponseMetadata': {'HTTPStatusCode': 200} - } + return {'SomeKey': {'State': 'COMPLETED'}, 'ResponseMetadata': {'HTTPStatusCode': 200}} def state_from_response(self, response): return 
response['SomeKey']['State'] - sensor = SageMakerBaseSensorSubclass( - task_id='test_task', - poke_interval=2, - aws_conn_id='aws_test' - ) + sensor = SageMakerBaseSensorSubclass(task_id='test_task', poke_interval=2, aws_conn_id='aws_test') sensor.execute(None) @@ -57,19 +50,12 @@ def failed_states(self): return ['FAILED'] def get_sagemaker_response(self): - return { - 'SomeKey': {'State': 'PENDING'}, - 'ResponseMetadata': {'HTTPStatusCode': 200} - } + return {'SomeKey': {'State': 'PENDING'}, 'ResponseMetadata': {'HTTPStatusCode': 200}} def state_from_response(self, response): return response['SomeKey']['State'] - sensor = SageMakerBaseSensorSubclass( - task_id='test_task', - poke_interval=2, - aws_conn_id='aws_test' - ) + sensor = SageMakerBaseSensorSubclass(task_id='test_task', poke_interval=2, aws_conn_id='aws_test') self.assertEqual(sensor.poke(None), False) @@ -81,11 +67,7 @@ def non_terminal_states(self): def failed_states(self): return ['FAILED'] - sensor = SageMakerBaseSensorSubclass( - task_id='test_task', - poke_interval=2, - aws_conn_id='aws_test' - ) + sensor = SageMakerBaseSensorSubclass(task_id='test_task', poke_interval=2, aws_conn_id='aws_test') self.assertRaises(NotImplementedError, sensor.poke, None) @@ -98,19 +80,12 @@ def failed_states(self): return ['FAILED'] def get_sagemaker_response(self): - return { - 'SomeKey': {'State': 'COMPLETED'}, - 'ResponseMetadata': {'HTTPStatusCode': 400} - } + return {'SomeKey': {'State': 'COMPLETED'}, 'ResponseMetadata': {'HTTPStatusCode': 400}} def state_from_response(self, response): return response['SomeKey']['State'] - sensor = SageMakerBaseSensorSubclass( - task_id='test_task', - poke_interval=2, - aws_conn_id='aws_test' - ) + sensor = SageMakerBaseSensorSubclass(task_id='test_task', poke_interval=2, aws_conn_id='aws_test') self.assertEqual(sensor.poke(None), False) @@ -123,18 +98,11 @@ def failed_states(self): return ['FAILED'] def get_sagemaker_response(self): - return { - 'SomeKey': {'State': 'FAILED'}, - 'ResponseMetadata': {'HTTPStatusCode': 200} - } + return {'SomeKey': {'State': 'FAILED'}, 'ResponseMetadata': {'HTTPStatusCode': 200}} def state_from_response(self, response): return response['SomeKey']['State'] - sensor = SageMakerBaseSensorSubclass( - task_id='test_task', - poke_interval=2, - aws_conn_id='aws_test' - ) + sensor = SageMakerBaseSensorSubclass(task_id='test_task', poke_interval=2, aws_conn_id='aws_test') self.assertRaises(AirflowException, sensor.poke, None) diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py index 992136fb591c7..56d8a6c7a1a68 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py @@ -26,30 +26,22 @@ DESCRIBE_ENDPOINT_CREATING_RESPONSE = { 'EndpointStatus': 'Creating', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': {'HTTPStatusCode': 200,}, } DESCRIBE_ENDPOINT_INSERVICE_RESPONSE = { 'EndpointStatus': 'InService', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': {'HTTPStatusCode': 200,}, } DESCRIBE_ENDPOINT_FAILED_RESPONSE = { 'EndpointStatus': 'Failed', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - }, - 'FailureReason': 'Unknown' + 'ResponseMetadata': {'HTTPStatusCode': 200,}, + 'FailureReason': 'Unknown', } DESCRIBE_ENDPOINT_UPDATING_RESPONSE = { 'EndpointStatus': 'Updating', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': 
{'HTTPStatusCode': 200,}, } @@ -59,10 +51,7 @@ class TestSageMakerEndpointSensor(unittest.TestCase): def test_sensor_with_failure(self, mock_describe, mock_get_conn): mock_describe.side_effect = [DESCRIBE_ENDPOINT_FAILED_RESPONSE] sensor = SageMakerEndpointSensor( - task_id='test_task', - poke_interval=1, - aws_conn_id='aws_test', - endpoint_name='test_job_name' + task_id='test_task', poke_interval=1, aws_conn_id='aws_test', endpoint_name='test_job_name' ) self.assertRaises(AirflowException, sensor.execute, None) mock_describe.assert_called_once_with('test_job_name') @@ -76,13 +65,10 @@ def test_sensor(self, mock_describe, hook_init, mock_get_conn): mock_describe.side_effect = [ DESCRIBE_ENDPOINT_CREATING_RESPONSE, DESCRIBE_ENDPOINT_UPDATING_RESPONSE, - DESCRIBE_ENDPOINT_INSERVICE_RESPONSE + DESCRIBE_ENDPOINT_INSERVICE_RESPONSE, ] sensor = SageMakerEndpointSensor( - task_id='test_task', - poke_interval=1, - aws_conn_id='aws_test', - endpoint_name='test_job_name' + task_id='test_task', poke_interval=1, aws_conn_id='aws_test', endpoint_name='test_job_name' ) sensor.execute(None) @@ -91,7 +77,5 @@ def test_sensor(self, mock_describe, hook_init, mock_get_conn): self.assertEqual(mock_describe.call_count, 3) # make sure the hook was initialized with the specific params - calls = [ - mock.call(aws_conn_id='aws_test') - ] + calls = [mock.call(aws_conn_id='aws_test')] hook_init.assert_has_calls(calls) diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py index aeb020a28586f..c14d3a1f1c9b2 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py @@ -28,24 +28,17 @@ DESCRIBE_TRAINING_COMPLETED_RESPONSE = { 'TrainingJobStatus': 'Completed', - 'ResourceConfig': { - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge', - 'VolumeSizeInGB': 10 - }, + 'ResourceConfig': {'InstanceCount': 1, 'InstanceType': 'ml.c4.xlarge', 'VolumeSizeInGB': 10}, 'TrainingStartTime': datetime(2018, 2, 17, 7, 15, 0, 103000), 'TrainingEndTime': datetime(2018, 2, 17, 7, 19, 34, 953000), - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': {'HTTPStatusCode': 200,}, } DESCRIBE_TRAINING_INPROGRESS_RESPONSE = dict(DESCRIBE_TRAINING_COMPLETED_RESPONSE) DESCRIBE_TRAINING_INPROGRESS_RESPONSE.update({'TrainingJobStatus': 'InProgress'}) DESCRIBE_TRAINING_FAILED_RESPONSE = dict(DESCRIBE_TRAINING_COMPLETED_RESPONSE) -DESCRIBE_TRAINING_FAILED_RESPONSE.update({'TrainingJobStatus': 'Failed', - 'FailureReason': 'Unknown'}) +DESCRIBE_TRAINING_FAILED_RESPONSE.update({'TrainingJobStatus': 'Failed', 'FailureReason': 'Unknown'}) DESCRIBE_TRAINING_STOPPING_RESPONSE = dict(DESCRIBE_TRAINING_COMPLETED_RESPONSE) DESCRIBE_TRAINING_STOPPING_RESPONSE.update({'TrainingJobStatus': 'Stopping'}) @@ -64,7 +57,7 @@ def test_sensor_with_failure(self, mock_describe_job, hook_init, mock_client): poke_interval=2, aws_conn_id='aws_test', job_name='test_job_name', - print_log=False + print_log=False, ) self.assertRaises(AirflowException, sensor.execute, None) mock_describe_job.assert_called_once_with('test_job_name') @@ -78,14 +71,14 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client): mock_describe_job.side_effect = [ DESCRIBE_TRAINING_INPROGRESS_RESPONSE, DESCRIBE_TRAINING_STOPPING_RESPONSE, - DESCRIBE_TRAINING_COMPLETED_RESPONSE + DESCRIBE_TRAINING_COMPLETED_RESPONSE, ] sensor = SageMakerTrainingSensor( task_id='test_task', poke_interval=2, 
aws_conn_id='aws_test', job_name='test_job_name', - print_log=False + print_log=False, ) sensor.execute(None) @@ -94,9 +87,7 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client): self.assertEqual(mock_describe_job.call_count, 3) # make sure the hook was initialized with the specific params - calls = [ - mock.call(aws_conn_id='aws_test') - ] + calls = [mock.call(aws_conn_id='aws_test')] hook_init.assert_has_calls(calls) @mock.patch.object(SageMakerHook, 'get_conn') @@ -104,22 +95,23 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client): @mock.patch.object(SageMakerHook, '__init__') @mock.patch.object(SageMakerHook, 'describe_training_job_with_log') @mock.patch.object(SageMakerHook, 'describe_training_job') - def test_sensor_with_log(self, mock_describe_job, mock_describe_job_with_log, - hook_init, mock_log_client, mock_client): + def test_sensor_with_log( + self, mock_describe_job, mock_describe_job_with_log, hook_init, mock_log_client, mock_client + ): hook_init.return_value = None mock_describe_job.return_value = DESCRIBE_TRAINING_COMPLETED_RESPONSE mock_describe_job_with_log.side_effect = [ (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RESPONSE, 0), (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RESPONSE, 0), - (LogState.COMPLETE, DESCRIBE_TRAINING_COMPLETED_RESPONSE, 0) + (LogState.COMPLETE, DESCRIBE_TRAINING_COMPLETED_RESPONSE, 0), ] sensor = SageMakerTrainingSensor( task_id='test_task', poke_interval=2, aws_conn_id='aws_test', job_name='test_job_name', - print_log=True + print_log=True, ) sensor.execute(None) @@ -127,7 +119,5 @@ def test_sensor_with_log(self, mock_describe_job, mock_describe_job_with_log, self.assertEqual(mock_describe_job_with_log.call_count, 3) self.assertEqual(mock_describe_job.call_count, 1) - calls = [ - mock.call(aws_conn_id='aws_test') - ] + calls = [mock.call(aws_conn_id='aws_test')] hook_init.assert_has_calls(calls) diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py index 6548c29aa2fd7..d548e483ec865 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py @@ -26,28 +26,20 @@ DESCRIBE_TRANSFORM_INPROGRESS_RESPONSE = { 'TransformJobStatus': 'InProgress', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': {'HTTPStatusCode': 200,}, } DESCRIBE_TRANSFORM_COMPLETED_RESPONSE = { 'TransformJobStatus': 'Completed', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': {'HTTPStatusCode': 200,}, } DESCRIBE_TRANSFORM_FAILED_RESPONSE = { 'TransformJobStatus': 'Failed', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - }, - 'FailureReason': 'Unknown' + 'ResponseMetadata': {'HTTPStatusCode': 200,}, + 'FailureReason': 'Unknown', } DESCRIBE_TRANSFORM_STOPPING_RESPONSE = { 'TransformJobStatus': 'Stopping', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': {'HTTPStatusCode': 200,}, } @@ -57,10 +49,7 @@ class TestSageMakerTransformSensor(unittest.TestCase): def test_sensor_with_failure(self, mock_describe_job, mock_client): mock_describe_job.side_effect = [DESCRIBE_TRANSFORM_FAILED_RESPONSE] sensor = SageMakerTransformSensor( - task_id='test_task', - poke_interval=2, - aws_conn_id='aws_test', - job_name='test_job_name' + task_id='test_task', poke_interval=2, aws_conn_id='aws_test', job_name='test_job_name' ) self.assertRaises(AirflowException, sensor.execute, None) 
mock_describe_job.assert_called_once_with('test_job_name') @@ -74,13 +63,10 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client): mock_describe_job.side_effect = [ DESCRIBE_TRANSFORM_INPROGRESS_RESPONSE, DESCRIBE_TRANSFORM_STOPPING_RESPONSE, - DESCRIBE_TRANSFORM_COMPLETED_RESPONSE + DESCRIBE_TRANSFORM_COMPLETED_RESPONSE, ] sensor = SageMakerTransformSensor( - task_id='test_task', - poke_interval=2, - aws_conn_id='aws_test', - job_name='test_job_name' + task_id='test_task', poke_interval=2, aws_conn_id='aws_test', job_name='test_job_name' ) sensor.execute(None) @@ -89,7 +75,5 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client): self.assertEqual(mock_describe_job.call_count, 3) # make sure the hook was initialized with the specific params - calls = [ - mock.call(aws_conn_id='aws_test') - ] + calls = [mock.call(aws_conn_id='aws_test')] hook_init.assert_has_calls(calls) diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py index 94bb3431e0280..ade7edf312e73 100644 --- a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py +++ b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py @@ -26,31 +26,23 @@ DESCRIBE_TUNING_INPROGRESS_RESPONSE = { 'HyperParameterTuningJobStatus': 'InProgress', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': {'HTTPStatusCode': 200,}, } DESCRIBE_TUNING_COMPLETED_RESPONSE = { 'HyperParameterTuningJobStatus': 'Completed', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': {'HTTPStatusCode': 200,}, } DESCRIBE_TUNING_FAILED_RESPONSE = { 'HyperParameterTuningJobStatus': 'Failed', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - }, - 'FailureReason': 'Unknown' + 'ResponseMetadata': {'HTTPStatusCode': 200,}, + 'FailureReason': 'Unknown', } DESCRIBE_TUNING_STOPPING_RESPONSE = { 'HyperParameterTuningJobStatus': 'Stopping', - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - } + 'ResponseMetadata': {'HTTPStatusCode': 200,}, } @@ -60,10 +52,7 @@ class TestSageMakerTuningSensor(unittest.TestCase): def test_sensor_with_failure(self, mock_describe_job, mock_client): mock_describe_job.side_effect = [DESCRIBE_TUNING_FAILED_RESPONSE] sensor = SageMakerTuningSensor( - task_id='test_task', - poke_interval=2, - aws_conn_id='aws_test', - job_name='test_job_name' + task_id='test_task', poke_interval=2, aws_conn_id='aws_test', job_name='test_job_name' ) self.assertRaises(AirflowException, sensor.execute, None) mock_describe_job.assert_called_once_with('test_job_name') @@ -77,13 +66,10 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client): mock_describe_job.side_effect = [ DESCRIBE_TUNING_INPROGRESS_RESPONSE, DESCRIBE_TUNING_STOPPING_RESPONSE, - DESCRIBE_TUNING_COMPLETED_RESPONSE + DESCRIBE_TUNING_COMPLETED_RESPONSE, ] sensor = SageMakerTuningSensor( - task_id='test_task', - poke_interval=2, - aws_conn_id='aws_test', - job_name='test_job_name' + task_id='test_task', poke_interval=2, aws_conn_id='aws_test', job_name='test_job_name' ) sensor.execute(None) @@ -92,7 +78,5 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client): self.assertEqual(mock_describe_job.call_count, 3) # make sure the hook was initialized with the specific params - calls = [ - mock.call(aws_conn_id='aws_test') - ] + calls = [mock.call(aws_conn_id='aws_test')] hook_init.assert_has_calls(calls) diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py index 
032cf19503dad..ee9b544f47fb2 100644 --- a/tests/providers/amazon/aws/sensors/test_sqs.py +++ b/tests/providers/amazon/aws/sensors/test_sqs.py @@ -32,19 +32,12 @@ class TestSQSSensor(unittest.TestCase): - def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) self.sensor = SQSSensor( - task_id='test_task', - dag=self.dag, - sqs_queue='test', - aws_conn_id='aws_default' + task_id='test_task', dag=self.dag, sqs_queue='test', aws_conn_id='aws_default' ) self.mock_context = mock.MagicMock() @@ -58,8 +51,10 @@ def test_poke_success(self): result = self.sensor.poke(self.mock_context) self.assertTrue(result) - self.assertTrue("'Body': 'hello'" in str(self.mock_context['ti'].method_calls), - "context call should contain message hello") + self.assertTrue( + "'Body': 'hello'" in str(self.mock_context['ti'].method_calls), + "context call should contain message hello", + ) @mock_sqs def test_poke_no_messsage_failed(self): @@ -74,20 +69,31 @@ def test_poke_no_messsage_failed(self): @mock.patch.object(SQSHook, 'get_conn') def test_poke_delete_raise_airflow_exception(self, mock_conn): - message = {'Messages': [{'MessageId': 'c585e508-2ea0-44c7-bf3e-d1ba0cb87834', - 'ReceiptHandle': 'mockHandle', - 'MD5OfBody': 'e5a9d8684a8edfed460b8d42fd28842f', - 'Body': 'h21'}], - 'ResponseMetadata': {'RequestId': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411', - 'HTTPStatusCode': 200, - 'HTTPHeaders': { - 'x-amzn-requestid': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411', - 'date': 'Mon, 18 Feb 2019 18:41:52 GMT', - 'content-type': 'text/xml', 'mock_sqs_hook-length': '830'}, - 'RetryAttempts': 0}} + message = { + 'Messages': [ + { + 'MessageId': 'c585e508-2ea0-44c7-bf3e-d1ba0cb87834', + 'ReceiptHandle': 'mockHandle', + 'MD5OfBody': 'e5a9d8684a8edfed460b8d42fd28842f', + 'Body': 'h21', + } + ], + 'ResponseMetadata': { + 'RequestId': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411', + 'HTTPStatusCode': 200, + 'HTTPHeaders': { + 'x-amzn-requestid': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411', + 'date': 'Mon, 18 Feb 2019 18:41:52 GMT', + 'content-type': 'text/xml', + 'mock_sqs_hook-length': '830', + }, + 'RetryAttempts': 0, + }, + } mock_conn.return_value.receive_message.return_value = message - mock_conn.return_value.delete_message_batch.return_value = \ - {'Failed': [{'Id': '22f67273-4dbc-4c19-83b5-aee71bfeb832'}]} + mock_conn.return_value.delete_message_batch.return_value = { + 'Failed': [{'Id': '22f67273-4dbc-4c19-83b5-aee71bfeb832'}] + } with self.assertRaises(AirflowException) as context: self.sensor.poke(self.mock_context) diff --git a/tests/providers/amazon/aws/sensors/test_step_function_execution.py b/tests/providers/amazon/aws/sensors/test_step_function_execution.py index 237f8ef424cb2..bbfffacbb9c5c 100644 --- a/tests/providers/amazon/aws/sensors/test_step_function_execution.py +++ b/tests/providers/amazon/aws/sensors/test_step_function_execution.py @@ -26,23 +26,21 @@ from airflow.providers.amazon.aws.sensors.step_function_execution import StepFunctionExecutionSensor TASK_ID = 'step_function_execution_sensor' -EXECUTION_ARN = 'arn:aws:states:us-east-1:123456789012:execution:'\ - 'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' +EXECUTION_ARN = ( + 'arn:aws:states:us-east-1:123456789012:execution:' + 'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934' +) AWS_CONN_ID = 'aws_non_default' REGION_NAME = 'us-west-2' class TestStepFunctionExecutionSensor(unittest.TestCase): - def setUp(self): 
self.mock_context = MagicMock() def test_init(self): sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, - execution_arn=EXECUTION_ARN, - aws_conn_id=AWS_CONN_ID, - region_name=REGION_NAME + task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME ) self.assertEqual(TASK_ID, sensor.task_id) @@ -53,18 +51,13 @@ def test_init(self): @parameterized.expand([('FAILED',), ('TIMED_OUT',), ('ABORTED',)]) @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') def test_exceptions(self, mock_status, mock_hook): - hook_response = { - 'status': mock_status - } + hook_response = {'status': mock_status} hook_instance = mock_hook.return_value hook_instance.describe_execution.return_value = hook_response sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, - execution_arn=EXECUTION_ARN, - aws_conn_id=AWS_CONN_ID, - region_name=REGION_NAME + task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME ) with self.assertRaises(AirflowException): @@ -72,36 +65,26 @@ def test_exceptions(self, mock_status, mock_hook): @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') def test_running(self, mock_hook): - hook_response = { - 'status': 'RUNNING' - } + hook_response = {'status': 'RUNNING'} hook_instance = mock_hook.return_value hook_instance.describe_execution.return_value = hook_response sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, - execution_arn=EXECUTION_ARN, - aws_conn_id=AWS_CONN_ID, - region_name=REGION_NAME + task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME ) self.assertFalse(sensor.poke(self.mock_context)) @mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook') def test_succeeded(self, mock_hook): - hook_response = { - 'status': 'SUCCEEDED' - } + hook_response = {'status': 'SUCCEEDED'} hook_instance = mock_hook.return_value hook_instance.describe_execution.return_value = hook_response sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, - execution_arn=EXECUTION_ARN, - aws_conn_id=AWS_CONN_ID, - region_name=REGION_NAME + task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME ) self.assertTrue(sensor.poke(self.mock_context)) diff --git a/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py b/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py index ee27aa0cfd285..2bf69addd0e4e 100644 --- a/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py @@ -24,7 +24,6 @@ class DynamodbToS3Test(unittest.TestCase): - def setUp(self): self.output_queue = [] @@ -38,13 +37,8 @@ def mock_upload_file(self, Filename, Bucket, Key): # pylint: disable=unused-arg @patch('airflow.providers.amazon.aws.transfers.dynamodb_to_s3.AwsDynamoDBHook') def test_dynamodb_to_s3_success(self, mock_aws_dynamodb_hook, mock_s3_hook): responses = [ - { - 'Items': [{'a': 1}, {'b': 2}], - 'LastEvaluatedKey': '123', - }, - { - 'Items': [{'c': 3}], - }, + {'Items': [{'a': 1}, {'b': 2}], 'LastEvaluatedKey': '123',}, + {'Items': [{'c': 3}],}, ] table = MagicMock() table.return_value.scan.side_effect = responses diff --git a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py index eb5d0582f0bbd..83b1239dc800f 100644 --- a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py +++ 
b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py @@ -47,13 +47,15 @@ def test_execute_incremental(self, mock_hook, mock_hook2): mock_hook.return_value.download.return_value = b"testing" mock_hook2.return_value.list.return_value = MOCK_FILES - operator = GCSToS3Operator(task_id=TASK_ID, - bucket=GCS_BUCKET, - prefix=PREFIX, - delimiter=DELIMITER, - dest_aws_conn_id="aws_default", - dest_s3_key=S3_BUCKET, - replace=False) + operator = GCSToS3Operator( + task_id=TASK_ID, + bucket=GCS_BUCKET, + prefix=PREFIX, + delimiter=DELIMITER, + dest_aws_conn_id="aws_default", + dest_s3_key=S3_BUCKET, + replace=False, + ) # create dest bucket hook = S3Hook(aws_conn_id='airflow_gcs_test') bucket = hook.get_bucket('bucket') @@ -63,10 +65,8 @@ def test_execute_incremental(self, mock_hook, mock_hook2): # we expect all except first file in MOCK_FILES to be uploaded # and all the MOCK_FILES to be present at the S3 bucket uploaded_files = operator.execute(None) - self.assertEqual(sorted(MOCK_FILES[1:]), - sorted(uploaded_files)) - self.assertEqual(sorted(MOCK_FILES), - sorted(hook.list_keys('bucket', delimiter='/'))) + self.assertEqual(sorted(MOCK_FILES[1:]), sorted(uploaded_files)) + self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/'))) # Test2: All the files are already in origin and destination without replace @mock_s3 @@ -77,13 +77,15 @@ def test_execute_without_replace(self, mock_hook, mock_hook2): mock_hook.return_value.download.return_value = b"testing" mock_hook2.return_value.list.return_value = MOCK_FILES - operator = GCSToS3Operator(task_id=TASK_ID, - bucket=GCS_BUCKET, - prefix=PREFIX, - delimiter=DELIMITER, - dest_aws_conn_id="aws_default", - dest_s3_key=S3_BUCKET, - replace=False) + operator = GCSToS3Operator( + task_id=TASK_ID, + bucket=GCS_BUCKET, + prefix=PREFIX, + delimiter=DELIMITER, + dest_aws_conn_id="aws_default", + dest_s3_key=S3_BUCKET, + replace=False, + ) # create dest bucket with all the files hook = S3Hook(aws_conn_id='airflow_gcs_test') bucket = hook.get_bucket('bucket') @@ -94,10 +96,8 @@ def test_execute_without_replace(self, mock_hook, mock_hook2): # we expect nothing to be uploaded # and all the MOCK_FILES to be present at the S3 bucket uploaded_files = operator.execute(None) - self.assertEqual([], - uploaded_files) - self.assertEqual(sorted(MOCK_FILES), - sorted(hook.list_keys('bucket', delimiter='/'))) + self.assertEqual([], uploaded_files) + self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/'))) # Test3: There are no files in destination bucket @mock_s3 @@ -108,13 +108,15 @@ def test_execute(self, mock_hook, mock_hook2): mock_hook.return_value.download.return_value = b"testing" mock_hook2.return_value.list.return_value = MOCK_FILES - operator = GCSToS3Operator(task_id=TASK_ID, - bucket=GCS_BUCKET, - prefix=PREFIX, - delimiter=DELIMITER, - dest_aws_conn_id="aws_default", - dest_s3_key=S3_BUCKET, - replace=False) + operator = GCSToS3Operator( + task_id=TASK_ID, + bucket=GCS_BUCKET, + prefix=PREFIX, + delimiter=DELIMITER, + dest_aws_conn_id="aws_default", + dest_s3_key=S3_BUCKET, + replace=False, + ) # create dest bucket without files hook = S3Hook(aws_conn_id='airflow_gcs_test') bucket = hook.get_bucket('bucket') @@ -123,10 +125,8 @@ def test_execute(self, mock_hook, mock_hook2): # we expect all MOCK_FILES to be uploaded # and all MOCK_FILES to be present at the S3 bucket uploaded_files = operator.execute(None) - self.assertEqual(sorted(MOCK_FILES), - sorted(uploaded_files)) - 
self.assertEqual(sorted(MOCK_FILES), - sorted(hook.list_keys('bucket', delimiter='/'))) + self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files)) + self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/'))) # Test4: Destination and Origin are in sync but replace all files in destination @mock_s3 @@ -137,13 +137,15 @@ def test_execute_with_replace(self, mock_hook, mock_hook2): mock_hook.return_value.download.return_value = b"testing" mock_hook2.return_value.list.return_value = MOCK_FILES - operator = GCSToS3Operator(task_id=TASK_ID, - bucket=GCS_BUCKET, - prefix=PREFIX, - delimiter=DELIMITER, - dest_aws_conn_id="aws_default", - dest_s3_key=S3_BUCKET, - replace=True) + operator = GCSToS3Operator( + task_id=TASK_ID, + bucket=GCS_BUCKET, + prefix=PREFIX, + delimiter=DELIMITER, + dest_aws_conn_id="aws_default", + dest_s3_key=S3_BUCKET, + replace=True, + ) # create dest bucket with all the files hook = S3Hook(aws_conn_id='airflow_gcs_test') bucket = hook.get_bucket('bucket') @@ -154,10 +156,8 @@ def test_execute_with_replace(self, mock_hook, mock_hook2): # we expect all MOCK_FILES to be uploaded and replace the existing ones # and all MOCK_FILES to be present at the S3 bucket uploaded_files = operator.execute(None) - self.assertEqual(sorted(MOCK_FILES), - sorted(uploaded_files)) - self.assertEqual(sorted(MOCK_FILES), - sorted(hook.list_keys('bucket', delimiter='/'))) + self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files)) + self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/'))) # Test5: Incremental sync with replace @mock_s3 @@ -168,13 +168,15 @@ def test_execute_incremental_with_replace(self, mock_hook, mock_hook2): mock_hook.return_value.download.return_value = b"testing" mock_hook2.return_value.list.return_value = MOCK_FILES - operator = GCSToS3Operator(task_id=TASK_ID, - bucket=GCS_BUCKET, - prefix=PREFIX, - delimiter=DELIMITER, - dest_aws_conn_id="aws_default", - dest_s3_key=S3_BUCKET, - replace=True) + operator = GCSToS3Operator( + task_id=TASK_ID, + bucket=GCS_BUCKET, + prefix=PREFIX, + delimiter=DELIMITER, + dest_aws_conn_id="aws_default", + dest_s3_key=S3_BUCKET, + replace=True, + ) # create dest bucket with just two files (the first two files in MOCK_FILES) hook = S3Hook(aws_conn_id='airflow_gcs_test') bucket = hook.get_bucket('bucket') @@ -185,7 +187,5 @@ def test_execute_incremental_with_replace(self, mock_hook, mock_hook2): # we expect all the MOCK_FILES to be uploaded and replace the existing ones # and all MOCK_FILES to be present at the S3 bucket uploaded_files = operator.execute(None) - self.assertEqual(sorted(MOCK_FILES), - sorted(uploaded_files)) - self.assertEqual(sorted(MOCK_FILES), - sorted(hook.list_keys('bucket', delimiter='/'))) + self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files)) + self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/'))) diff --git a/tests/providers/amazon/aws/transfers/test_google_api_to_s3.py b/tests/providers/amazon/aws/transfers/test_google_api_to_s3.py index 0283937d6ea0a..44d8e3f6d404a 100644 --- a/tests/providers/amazon/aws/transfers/test_google_api_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_google_api_to_s3.py @@ -27,7 +27,6 @@ class TestGoogleApiToS3(unittest.TestCase): - def setUp(self): load_test_config() @@ -38,7 +37,7 @@ def setUp(self): conn_type="google_cloud_platform", schema='refresh_token', login='client_id', - password='client_secret' + password='client_secret', ) ) db.merge_conn( @@ -47,7 +46,7 @@ def 
setUp(self): conn_type='s3', schema='test', extra='{"aws_access_key_id": "aws_access_key_id", "aws_secret_access_key":' - ' "aws_secret_access_key"}' + ' "aws_secret_access_key"}', ) ) @@ -63,7 +62,7 @@ def setUp(self): 's3_destination_key': 'test/google_api_to_s3_test.csv', 's3_overwrite': True, 'task_id': 'task_id', - 'dag': None + 'dag': None, } @patch('airflow.providers.amazon.aws.transfers.google_api_to_s3.GoogleDiscoveryApiHook.query') @@ -78,13 +77,13 @@ def test_execute(self, mock_json_dumps, mock_s3_hook_load_string, mock_google_ap endpoint=self.kwargs['google_api_endpoint_path'], data=self.kwargs['google_api_endpoint_params'], paginate=self.kwargs['google_api_pagination'], - num_retries=self.kwargs['google_api_num_retries'] + num_retries=self.kwargs['google_api_num_retries'], ) mock_json_dumps.assert_called_once_with(mock_google_api_hook_query.return_value) mock_s3_hook_load_string.assert_called_once_with( string_data=mock_json_dumps.return_value, key=self.kwargs['s3_destination_key'], - replace=self.kwargs['s3_overwrite'] + replace=self.kwargs['s3_overwrite'], ) context['task_instance'].xcom_pull.assert_not_called() context['task_instance'].xcom_push.assert_not_called() @@ -107,36 +106,30 @@ def test_execute_with_xcom(self, mock_json_dumps, mock_s3_hook_load_string, mock endpoint=self.kwargs['google_api_endpoint_path'], data=self.kwargs['google_api_endpoint_params'], paginate=self.kwargs['google_api_pagination'], - num_retries=self.kwargs['google_api_num_retries'] + num_retries=self.kwargs['google_api_num_retries'], ) mock_json_dumps.assert_called_once_with(mock_google_api_hook_query.return_value) mock_s3_hook_load_string.assert_called_once_with( string_data=mock_json_dumps.return_value, key=self.kwargs['s3_destination_key'], - replace=self.kwargs['s3_overwrite'] + replace=self.kwargs['s3_overwrite'], ) context['task_instance'].xcom_pull.assert_called_once_with( task_ids=xcom_kwargs['google_api_endpoint_params_via_xcom_task_ids'], - key=xcom_kwargs['google_api_endpoint_params_via_xcom'] + key=xcom_kwargs['google_api_endpoint_params_via_xcom'], ) context['task_instance'].xcom_push.assert_called_once_with( - key=xcom_kwargs['google_api_response_via_xcom'], - value=mock_google_api_hook_query.return_value + key=xcom_kwargs['google_api_response_via_xcom'], value=mock_google_api_hook_query.return_value ) @patch('airflow.providers.amazon.aws.transfers.google_api_to_s3.GoogleDiscoveryApiHook.query') @patch('airflow.providers.amazon.aws.transfers.google_api_to_s3.S3Hook.load_string') @patch('airflow.providers.amazon.aws.transfers.google_api_to_s3.json.dumps') @patch( - 'airflow.providers.amazon.aws.transfers.google_api_to_s3.sys.getsizeof', - return_value=MAX_XCOM_SIZE + 'airflow.providers.amazon.aws.transfers.google_api_to_s3.sys.getsizeof', return_value=MAX_XCOM_SIZE ) def test_execute_with_xcom_exceeded_max_xcom_size( - self, - mock_sys_getsizeof, - mock_json_dumps, - mock_s3_hook_load_string, - mock_google_api_hook_query + self, mock_sys_getsizeof, mock_json_dumps, mock_s3_hook_load_string, mock_google_api_hook_query ): context = {'task_instance': Mock()} xcom_kwargs = { @@ -146,24 +139,23 @@ def test_execute_with_xcom_exceeded_max_xcom_size( } context['task_instance'].xcom_pull.return_value = {} - self.assertRaises(RuntimeError, - GoogleApiToS3Operator(**self.kwargs, **xcom_kwargs).execute, context) + self.assertRaises(RuntimeError, GoogleApiToS3Operator(**self.kwargs, **xcom_kwargs).execute, context) mock_google_api_hook_query.assert_called_once_with( 
endpoint=self.kwargs['google_api_endpoint_path'], data=self.kwargs['google_api_endpoint_params'], paginate=self.kwargs['google_api_pagination'], - num_retries=self.kwargs['google_api_num_retries'] + num_retries=self.kwargs['google_api_num_retries'], ) mock_json_dumps.assert_called_once_with(mock_google_api_hook_query.return_value) mock_s3_hook_load_string.assert_called_once_with( string_data=mock_json_dumps.return_value, key=self.kwargs['s3_destination_key'], - replace=self.kwargs['s3_overwrite'] + replace=self.kwargs['s3_overwrite'], ) context['task_instance'].xcom_pull.assert_called_once_with( task_ids=xcom_kwargs['google_api_endpoint_params_via_xcom_task_ids'], - key=xcom_kwargs['google_api_endpoint_params_via_xcom'] + key=xcom_kwargs['google_api_endpoint_params_via_xcom'], ) context['task_instance'].xcom_push.assert_not_called() mock_sys_getsizeof.assert_called_once_with(mock_google_api_hook_query.return_value) diff --git a/tests/providers/amazon/aws/transfers/test_google_api_to_s3_system.py b/tests/providers/amazon/aws/transfers/test_google_api_to_s3_system.py index f9e6b3b58080e..a60b2fe29a6ac 100644 --- a/tests/providers/amazon/aws/transfers/test_google_api_to_s3_system.py +++ b/tests/providers/amazon/aws/transfers/test_google_api_to_s3_system.py @@ -28,7 +28,10 @@ from airflow.providers.amazon.aws.hooks.s3 import S3Hook from tests.providers.google.cloud.utils.gcp_authenticator import GMP_KEY from tests.test_utils.amazon_system_helpers import ( - AWS_DAG_FOLDER, AmazonSystemTest, provide_aws_context, provide_aws_s3_bucket, + AWS_DAG_FOLDER, + AmazonSystemTest, + provide_aws_context, + provide_aws_s3_bucket, ) from tests.test_utils.gcp_system_helpers import GoogleSystemTest, provide_gcp_context @@ -51,7 +54,6 @@ def provide_s3_bucket_advanced(): @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GMP_KEY) class GoogleApiToS3TransferExampleDagsSystemTest(GoogleSystemTest, AmazonSystemTest): - @pytest.mark.usefixtures("provide_s3_bucket_basic") @provide_aws_context() @provide_gcp_context(GMP_KEY, scopes=['https://www.googleapis.com/auth/spreadsheets.readonly']) diff --git a/tests/providers/amazon/aws/transfers/test_hive_to_dynamodb.py b/tests/providers/amazon/aws/transfers/test_hive_to_dynamodb.py index b27a9c482029e..179cf5fb4d18b 100644 --- a/tests/providers/amazon/aws/transfers/test_hive_to_dynamodb.py +++ b/tests/providers/amazon/aws/transfers/test_hive_to_dynamodb.py @@ -39,14 +39,12 @@ class TestHiveToDynamoDBOperator(unittest.TestCase): - def setUp(self): args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG('test_dag_id', default_args=args) self.dag = dag self.sql = 'SELECT 1' - self.hook = AwsDynamoDBHook( - aws_conn_id='aws_default', region_name='us-east-1') + self.hook = AwsDynamoDBHook(aws_conn_id='aws_default', region_name='us-east-1') @staticmethod def process_data(data, *args, **kwargs): @@ -58,30 +56,19 @@ def test_get_conn_returns_a_boto3_connection(self): hook = AwsDynamoDBHook(aws_conn_id='aws_default') self.assertIsNotNone(hook.get_conn()) - @mock.patch('airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_pandas_df', - return_value=pd.DataFrame(data=[('1', 'sid')], columns=['id', 'name'])) + @mock.patch( + 'airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_pandas_df', + return_value=pd.DataFrame(data=[('1', 'sid')], columns=['id', 'name']), + ) @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present') @mock_dynamodb2 def test_get_records_with_schema(self, mock_get_pandas_df): # this table needs 
to be created in production self.hook.get_conn().create_table( TableName='test_airflow', - KeySchema=[ - { - 'AttributeName': 'id', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'id', - 'AttributeType': 'S' - } - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 10 - } + KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'},], + AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], + ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10}, ) operator = airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator( @@ -89,39 +76,28 @@ def test_get_records_with_schema(self, mock_get_pandas_df): table_name="test_airflow", task_id='hive_to_dynamodb_check', table_keys=['id'], - dag=self.dag) + dag=self.dag, + ) operator.execute(None) table = self.hook.get_conn().Table('test_airflow') - table.meta.client.get_waiter( - 'table_exists').wait(TableName='test_airflow') + table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow') self.assertEqual(table.item_count, 1) - @mock.patch('airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_pandas_df', - return_value=pd.DataFrame(data=[('1', 'sid'), ('1', 'gupta')], columns=['id', 'name'])) + @mock.patch( + 'airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_pandas_df', + return_value=pd.DataFrame(data=[('1', 'sid'), ('1', 'gupta')], columns=['id', 'name']), + ) @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present') @mock_dynamodb2 def test_pre_process_records_with_schema(self, mock_get_pandas_df): # this table needs to be created in production self.hook.get_conn().create_table( TableName='test_airflow', - KeySchema=[ - { - 'AttributeName': 'id', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'id', - 'AttributeType': 'S' - } - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 10 - } + KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'},], + AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], + ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10}, ) operator = airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator( @@ -130,7 +106,8 @@ def test_pre_process_records_with_schema(self, mock_get_pandas_df): task_id='hive_to_dynamodb_check', table_keys=['id'], pre_process=self.process_data, - dag=self.dag) + dag=self.dag, + ) operator.execute(None) diff --git a/tests/providers/amazon/aws/transfers/test_imap_attachment_to_s3.py b/tests/providers/amazon/aws/transfers/test_imap_attachment_to_s3.py index 18e87887d24f6..16bdf590eafba 100644 --- a/tests/providers/amazon/aws/transfers/test_imap_attachment_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_imap_attachment_to_s3.py @@ -23,7 +23,6 @@ class TestImapAttachmentToS3Operator(unittest.TestCase): - def setUp(self): self.kwargs = dict( imap_attachment_name='test_file', @@ -33,7 +32,7 @@ def setUp(self): imap_mail_filter='All', s3_overwrite=False, task_id='test_task', - dag=None + dag=None, ) @patch('airflow.providers.amazon.aws.transfers.imap_attachment_to_s3.S3Hook') @@ -49,10 +48,10 @@ def test_execute(self, mock_imap_hook, mock_s3_hook): check_regex=self.kwargs['imap_check_regex'], latest_only=True, mail_folder=self.kwargs['imap_mail_folder'], - mail_filter=self.kwargs['imap_mail_filter'] + mail_filter=self.kwargs['imap_mail_filter'], ) mock_s3_hook.return_value.load_bytes.assert_called_once_with( 
bytes_data=mock_imap_hook.return_value.retrieve_mail_attachments.return_value[0][1], key=self.kwargs['s3_key'], - replace=self.kwargs['s3_overwrite'] + replace=self.kwargs['s3_overwrite'], ) diff --git a/tests/providers/amazon/aws/transfers/test_imap_attachment_to_s3_system.py b/tests/providers/amazon/aws/transfers/test_imap_attachment_to_s3_system.py index 51b26f6d377a0..a8a91cd6e4651 100644 --- a/tests/providers/amazon/aws/transfers/test_imap_attachment_to_s3_system.py +++ b/tests/providers/amazon/aws/transfers/test_imap_attachment_to_s3_system.py @@ -20,7 +20,10 @@ from airflow.providers.amazon.aws.example_dags.example_imap_attachment_to_s3 import S3_DESTINATION_KEY from airflow.providers.amazon.aws.hooks.s3 import S3Hook from tests.test_utils.amazon_system_helpers import ( - AWS_DAG_FOLDER, AmazonSystemTest, provide_aws_context, provide_aws_s3_bucket, + AWS_DAG_FOLDER, + AmazonSystemTest, + provide_aws_context, + provide_aws_s3_bucket, ) BUCKET, _ = S3Hook.parse_s3_url(S3_DESTINATION_KEY) @@ -35,7 +38,6 @@ def provide_s3_bucket(): @pytest.mark.backend("mysql", "postgres") @pytest.mark.system("imap") class TestImapAttachmentToS3ExampleDags(AmazonSystemTest): - @pytest.mark.usefixtures("provide_s3_bucket") @provide_aws_context() def test_run_example_dag_imap_attachment_to_s3(self): diff --git a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py index 81db1746f47d9..f53444cd9da3b 100644 --- a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py @@ -35,17 +35,13 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1) MOCK_MONGO_RETURN = [ {'example_return_key_1': 'example_return_value_1'}, - {'example_return_key_2': 'example_return_value_2'} + {'example_return_key_2': 'example_return_value_2'}, ] class TestMongoToS3Operator(unittest.TestCase): - def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) @@ -57,7 +53,7 @@ def setUp(self): mongo_query=MONGO_QUERY, s3_bucket=S3_BUCKET, s3_key=S3_KEY, - dag=self.dag + dag=self.dag, ) def test_init(self): @@ -78,10 +74,7 @@ def test_render_template(self): expected_rendered_template = {'$lt': '2017-01-01T00:00:00+00:00Z'} - self.assertDictEqual( - expected_rendered_template, - getattr(self.mock_operator, 'mongo_query') - ) + self.assertDictEqual(expected_rendered_template, getattr(self.mock_operator, 'mongo_query')) @mock.patch('airflow.providers.amazon.aws.transfers.mongo_to_s3.MongoHook') @mock.patch('airflow.providers.amazon.aws.transfers.mongo_to_s3.S3Hook') @@ -94,9 +87,7 @@ def test_execute(self, mock_s3_hook, mock_mongo_hook): operator.execute(None) mock_mongo_hook.return_value.find.assert_called_once_with( - mongo_collection=MONGO_COLLECTION, - query=MONGO_QUERY, - mongo_db=None + mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None ) op_stringify = self.mock_operator._stringify @@ -105,8 +96,5 @@ def test_execute(self, mock_s3_hook, mock_mongo_hook): s3_doc_str = op_stringify(op_transform(MOCK_MONGO_RETURN)) mock_s3_hook.return_value.load_string.assert_called_once_with( - string_data=s3_doc_str, - key=S3_KEY, - bucket_name=S3_BUCKET, - replace=False + string_data=s3_doc_str, key=S3_KEY, bucket_name=S3_BUCKET, replace=False ) diff --git a/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py b/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py index 
66172b4c40ead..6f3eba18fbda3 100644 --- a/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_mysql_to_s3.py @@ -26,7 +26,6 @@ class TestMySqlToS3Operator(unittest.TestCase): - @mock.patch("airflow.providers.amazon.aws.transfers.mysql_to_s3.NamedTemporaryFile") @mock.patch("airflow.providers.amazon.aws.transfers.mysql_to_s3.S3Hook") @mock.patch("airflow.providers.amazon.aws.transfers.mysql_to_s3.MySqlHook") @@ -41,15 +40,16 @@ def test_execute(self, mock_mysql_hook, mock_s3_hook, temp_mock): with NamedTemporaryFile() as f: temp_mock.return_value.__enter__.return_value.name = f.name - op = MySQLToS3Operator(query=query, - s3_bucket=s3_bucket, - s3_key=s3_key, - mysql_conn_id="mysql_conn_id", - aws_conn_id="aws_conn_id", - task_id="task_id", - pd_csv_kwargs={'index': False, 'header': False}, - dag=None - ) + op = MySQLToS3Operator( + query=query, + s3_bucket=s3_bucket, + s3_key=s3_key, + mysql_conn_id="mysql_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + pd_csv_kwargs={'index': False, 'header': False}, + dag=None, + ) op.execute(None) mock_mysql_hook.assert_called_once_with(mysql_conn_id="mysql_conn_id") mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None) @@ -57,6 +57,6 @@ def test_execute(self, mock_mysql_hook, mock_s3_hook, temp_mock): get_pandas_df_mock.assert_called_once_with(query) temp_mock.assert_called_once_with(mode='r+', suffix=".csv") - mock_s3_hook.return_value.load_file.assert_called_once_with(filename=f.name, - key=s3_key, - bucket_name=s3_bucket) + mock_s3_hook.return_value.load_file.assert_called_once_with( + filename=f.name, key=s3_key, bucket_name=s3_bucket + ) diff --git a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py index 84783fa307499..518d992309239 100644 --- a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py @@ -28,14 +28,14 @@ class TestRedshiftToS3Transfer(unittest.TestCase): - - @parameterized.expand([ - [True, "key/table_"], - [False, "key"], - ]) + @parameterized.expand( + [[True, "key/table_"], [False, "key"],] + ) @mock.patch("boto3.session.Session") @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run") - def test_execute(self, table_as_file_name, expected_s3_key, mock_run, mock_session,): + def test_execute( + self, table_as_file_name, expected_s3_key, mock_run, mock_session, + ): access_key = "aws_access_key_id" secret_key = "aws_secret_access_key" mock_session.return_value = Session(access_key, secret_key) @@ -43,7 +43,9 @@ def test_execute(self, table_as_file_name, expected_s3_key, mock_run, mock_sessi table = "table" s3_bucket = "bucket" s3_key = "key" - unload_options = ['HEADER', ] + unload_options = [ + 'HEADER', + ] RedshiftToS3Operator( schema=schema, @@ -56,7 +58,7 @@ def test_execute(self, table_as_file_name, expected_s3_key, mock_run, mock_sessi aws_conn_id="aws_conn_id", task_id="task_id", table_as_file_name=table_as_file_name, - dag=None + dag=None, ).execute(None) unload_options = '\n\t\t\t'.join(unload_options) @@ -67,12 +69,14 @@ def test_execute(self, table_as_file_name, expected_s3_key, mock_run, mock_sessi with credentials 'aws_access_key_id={access_key};aws_secret_access_key={secret_key}' {unload_options}; - """.format(select_query=select_query, - s3_bucket=s3_bucket, - s3_key=expected_s3_key, - access_key=access_key, - secret_key=secret_key, - unload_options=unload_options) + 
""".format( + select_query=select_query, + s3_bucket=s3_bucket, + s3_key=expected_s3_key, + access_key=access_key, + secret_key=secret_key, + unload_options=unload_options, + ) assert mock_run.call_count == 1 assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], unload_query) diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py index 0c619db8d6097..1f2d1b2edf775 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py @@ -27,7 +27,6 @@ class TestS3ToRedshiftTransfer(unittest.TestCase): - @mock.patch("boto3.session.Session") @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run") def test_execute(self, mock_run, mock_session): @@ -50,7 +49,8 @@ def test_execute(self, mock_run, mock_session): redshift_conn_id="redshift_conn_id", aws_conn_id="aws_conn_id", task_id="task_id", - dag=None) + dag=None, + ) op.execute(None) copy_query = """ @@ -59,13 +59,15 @@ def test_execute(self, mock_run, mock_session): with credentials 'aws_access_key_id={access_key};aws_secret_access_key={secret_key}' {copy_options}; - """.format(schema=schema, - table=table, - s3_bucket=s3_bucket, - s3_key=s3_key, - access_key=access_key, - secret_key=secret_key, - copy_options=copy_options) + """.format( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + access_key=access_key, + secret_key=secret_key, + copy_options=copy_options, + ) assert mock_run.call_count == 1 assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query) diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py b/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py index 85e5e5bb51b02..a39be307a1bb2 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py @@ -73,9 +73,10 @@ def setUp(self): @conf_vars({("core", "enable_xcom_pickling"): "True"}) def test_s3_to_sftp_operation(self): # Setting - test_remote_file_content = \ - "This is remote file content \n which is also multiline " \ + test_remote_file_content = ( + "This is remote file content \n which is also multiline " "another line here \n this is last line. 
EOF" + ) # Test for creation of s3 bucket conn = boto3.client('s3') @@ -87,8 +88,7 @@ def test_s3_to_sftp_operation(self): self.s3_hook.load_file(LOCAL_FILE_PATH, self.s3_key, bucket_name=BUCKET) # Check if object was created in s3 - objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket, - Prefix=self.s3_key) + objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket, Prefix=self.s3_key) # there should be object found, and there should only be one object found self.assertEqual(len(objects_in_dest_bucket['Contents']), 1) @@ -103,7 +103,7 @@ def test_s3_to_sftp_operation(self): sftp_conn_id=SFTP_CONN_ID, s3_conn_id=S3_CONN_ID, task_id=TASK_ID, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(run_task) @@ -115,14 +115,15 @@ def test_s3_to_sftp_operation(self): ssh_hook=self.hook, command="cat {0}".format(self.sftp_path), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(check_file_task) ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) ti3.run() self.assertEqual( ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), - test_remote_file_content.encode('utf-8')) + test_remote_file_content.encode('utf-8'), + ) # Clean up after finishing with test conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key) @@ -136,7 +137,7 @@ def delete_remote_resource(self): ssh_hook=self.hook, command="rm {0}".format(self.sftp_path), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(remove_file_task) ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow()) diff --git a/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py b/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py index 3621104afce0c..381a463e0a056 100644 --- a/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py @@ -44,7 +44,6 @@ class TestSFTPToS3Operator(unittest.TestCase): - @mock_s3 def setUp(self): hook = SSHHook(ssh_conn_id='ssh_default') @@ -73,18 +72,18 @@ def setUp(self): @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) def test_sftp_to_s3_operation(self): # Setting - test_remote_file_content = \ - "This is remote file content \n which is also multiline " \ + test_remote_file_content = ( + "This is remote file content \n which is also multiline " "another line here \n this is last line. 
EOF" + ) # create a test file remotely create_file_task = SSHOperator( task_id="test_create_file", ssh_hook=self.hook, - command="echo '{0}' > {1}".format(test_remote_file_content, - self.sftp_path), + command="echo '{0}' > {1}".format(test_remote_file_content, self.sftp_path), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(create_file_task) ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) @@ -103,15 +102,14 @@ def test_sftp_to_s3_operation(self): sftp_conn_id=SFTP_CONN_ID, s3_conn_id=S3_CONN_ID, task_id='test_sftp_to_s3', - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(run_task) run_task.execute(None) # Check if object was created in s3 - objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket, - Prefix=self.s3_key) + objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket, Prefix=self.s3_key) # there should be object found, and there should only be one object found self.assertEqual(len(objects_in_dest_bucket['Contents']), 1) diff --git a/tests/providers/apache/cassandra/hooks/test_cassandra.py b/tests/providers/apache/cassandra/hooks/test_cassandra.py index a519c67d5156a..06cf665918249 100644 --- a/tests/providers/apache/cassandra/hooks/test_cassandra.py +++ b/tests/providers/apache/cassandra/hooks/test_cassandra.py @@ -22,7 +22,10 @@ import pytest from cassandra.cluster import Cluster from cassandra.policies import ( - DCAwareRoundRobinPolicy, RoundRobinPolicy, TokenAwarePolicy, WhiteListRoundRobinPolicy, + DCAwareRoundRobinPolicy, + RoundRobinPolicy, + TokenAwarePolicy, + WhiteListRoundRobinPolicy, ) from airflow.models import Connection @@ -35,13 +38,23 @@ class TestCassandraHook(unittest.TestCase): def setUp(self): db.merge_conn( Connection( - conn_id='cassandra_test', conn_type='cassandra', - host='host-1,host-2', port='9042', schema='test_keyspace', - extra='{"load_balancing_policy":"TokenAwarePolicy"}')) + conn_id='cassandra_test', + conn_type='cassandra', + host='host-1,host-2', + port='9042', + schema='test_keyspace', + extra='{"load_balancing_policy":"TokenAwarePolicy"}', + ) + ) db.merge_conn( Connection( - conn_id='cassandra_default_with_schema', conn_type='cassandra', - host='cassandra', port='9042', schema='s')) + conn_id='cassandra_default_with_schema', + conn_type='cassandra', + host='cassandra', + port='9042', + schema='s', + ) + ) hook = CassandraHook("cassandra_default") session = hook.get_conn() @@ -59,8 +72,9 @@ def setUp(self): hook.shutdown_cluster() def test_get_conn(self): - with mock.patch.object(Cluster, "connect") as mock_connect, \ - mock.patch("socket.getaddrinfo", return_value=[]) as mock_getaddrinfo: + with mock.patch.object(Cluster, "connect") as mock_connect, mock.patch( + "socket.getaddrinfo", return_value=[] + ) as mock_getaddrinfo: mock_connect.return_value = 'session' hook = CassandraHook(cassandra_conn_id='cassandra_test') hook.get_conn() @@ -76,65 +90,81 @@ def test_get_lb_policy_with_no_args(self): # test LB policies with no args self._assert_get_lb_policy('RoundRobinPolicy', {}, RoundRobinPolicy) self._assert_get_lb_policy('DCAwareRoundRobinPolicy', {}, DCAwareRoundRobinPolicy) - self._assert_get_lb_policy('TokenAwarePolicy', {}, TokenAwarePolicy, - expected_child_policy_type=RoundRobinPolicy) + self._assert_get_lb_policy( + 'TokenAwarePolicy', {}, TokenAwarePolicy, expected_child_policy_type=RoundRobinPolicy + ) def test_get_lb_policy_with_args(self): # test DCAwareRoundRobinPolicy with args - self._assert_get_lb_policy('DCAwareRoundRobinPolicy', - {'local_dc': 'foo', 
'used_hosts_per_remote_dc': '3'}, - DCAwareRoundRobinPolicy) + self._assert_get_lb_policy( + 'DCAwareRoundRobinPolicy', + {'local_dc': 'foo', 'used_hosts_per_remote_dc': '3'}, + DCAwareRoundRobinPolicy, + ) # test WhiteListRoundRobinPolicy with args - fake_addr_info = [['family', 'sockettype', 'proto', - 'canonname', ('2606:2800:220:1:248:1893:25c8:1946', 80, 0, 0)]] + fake_addr_info = [ + ['family', 'sockettype', 'proto', 'canonname', ('2606:2800:220:1:248:1893:25c8:1946', 80, 0, 0)] + ] with mock.patch('socket.getaddrinfo', return_value=fake_addr_info): - self._assert_get_lb_policy('WhiteListRoundRobinPolicy', - {'hosts': ['host1', 'host2']}, - WhiteListRoundRobinPolicy) + self._assert_get_lb_policy( + 'WhiteListRoundRobinPolicy', {'hosts': ['host1', 'host2']}, WhiteListRoundRobinPolicy + ) # test TokenAwarePolicy with args with mock.patch('socket.getaddrinfo', return_value=fake_addr_info): self._assert_get_lb_policy( 'TokenAwarePolicy', - {'child_load_balancing_policy': 'WhiteListRoundRobinPolicy', - 'child_load_balancing_policy_args': {'hosts': ['host-1', 'host-2']} - }, TokenAwarePolicy, expected_child_policy_type=WhiteListRoundRobinPolicy) + { + 'child_load_balancing_policy': 'WhiteListRoundRobinPolicy', + 'child_load_balancing_policy_args': {'hosts': ['host-1', 'host-2']}, + }, + TokenAwarePolicy, + expected_child_policy_type=WhiteListRoundRobinPolicy, + ) def test_get_lb_policy_invalid_policy(self): # test invalid policy name should default to RoundRobinPolicy self._assert_get_lb_policy('DoesNotExistPolicy', {}, RoundRobinPolicy) # test invalid child policy name should default child policy to RoundRobinPolicy - self._assert_get_lb_policy('TokenAwarePolicy', {}, TokenAwarePolicy, - expected_child_policy_type=RoundRobinPolicy) - self._assert_get_lb_policy('TokenAwarePolicy', - {'child_load_balancing_policy': 'DoesNotExistPolicy'}, - TokenAwarePolicy, - expected_child_policy_type=RoundRobinPolicy) + self._assert_get_lb_policy( + 'TokenAwarePolicy', {}, TokenAwarePolicy, expected_child_policy_type=RoundRobinPolicy + ) + self._assert_get_lb_policy( + 'TokenAwarePolicy', + {'child_load_balancing_policy': 'DoesNotExistPolicy'}, + TokenAwarePolicy, + expected_child_policy_type=RoundRobinPolicy, + ) def test_get_lb_policy_no_host_for_allow_list(self): # test host not specified for WhiteListRoundRobinPolicy should throw exception - self._assert_get_lb_policy('WhiteListRoundRobinPolicy', - {}, - WhiteListRoundRobinPolicy, - should_throw=True) - self._assert_get_lb_policy('TokenAwarePolicy', - {'child_load_balancing_policy': 'WhiteListRoundRobinPolicy'}, - TokenAwarePolicy, - expected_child_policy_type=RoundRobinPolicy, - should_throw=True) - - def _assert_get_lb_policy(self, policy_name, policy_args, expected_policy_type, - expected_child_policy_type=None, - should_throw=False): + self._assert_get_lb_policy( + 'WhiteListRoundRobinPolicy', {}, WhiteListRoundRobinPolicy, should_throw=True + ) + self._assert_get_lb_policy( + 'TokenAwarePolicy', + {'child_load_balancing_policy': 'WhiteListRoundRobinPolicy'}, + TokenAwarePolicy, + expected_child_policy_type=RoundRobinPolicy, + should_throw=True, + ) + + def _assert_get_lb_policy( + self, + policy_name, + policy_args, + expected_policy_type, + expected_child_policy_type=None, + should_throw=False, + ): thrown = False try: policy = CassandraHook.get_lb_policy(policy_name, policy_args) self.assertTrue(isinstance(policy, expected_policy_type)) if expected_child_policy_type: - self.assertTrue(isinstance(policy._child_policy, - 
expected_child_policy_type)) + self.assertTrue(isinstance(policy._child_policy, expected_child_policy_type)) except Exception: # pylint: disable=broad-except thrown = True self.assertEqual(should_throw, thrown) diff --git a/tests/providers/apache/cassandra/sensors/test_table.py b/tests/providers/apache/cassandra/sensors/test_table.py index 4f35bac4b9dcb..384528b70133d 100644 --- a/tests/providers/apache/cassandra/sensors/test_table.py +++ b/tests/providers/apache/cassandra/sensors/test_table.py @@ -30,9 +30,7 @@ class TestCassandraTableSensor(unittest.TestCase): @patch("airflow.providers.apache.cassandra.sensors.table.CassandraHook") def test_poke(self, mock_hook): sensor = CassandraTableSensor( - task_id='test_task', - cassandra_conn_id=TEST_CASSANDRA_CONN_ID, - table=TEST_CASSANDRA_TABLE, + task_id='test_task', cassandra_conn_id=TEST_CASSANDRA_CONN_ID, table=TEST_CASSANDRA_TABLE, ) exists = sensor.poke(dict()) @@ -46,9 +44,7 @@ def test_poke_should_return_false_for_non_existing_table(self, mock_hook): mock_hook.return_value.table_exists.return_value = False sensor = CassandraTableSensor( - task_id='test_task', - cassandra_conn_id=TEST_CASSANDRA_CONN_ID, - table=TEST_CASSANDRA_TABLE, + task_id='test_task', cassandra_conn_id=TEST_CASSANDRA_CONN_ID, table=TEST_CASSANDRA_TABLE, ) exists = sensor.poke(dict()) diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py index 365568711ab74..85e67fd8eefa2 100644 --- a/tests/providers/apache/druid/hooks/test_druid.py +++ b/tests/providers/apache/druid/hooks/test_druid.py @@ -28,7 +28,6 @@ class TestDruidHook(unittest.TestCase): - def setUp(self): super().setUp() session = requests.Session() @@ -38,18 +37,18 @@ def setUp(self): class TestDRuidhook(DruidHook): def get_conn_url(self): return 'http://druid-overlord:8081/druid/indexer/v1/task' + self.db_hook = TestDRuidhook() @requests_mock.mock() def test_submit_gone_wrong(self, m): task_post = m.post( 'http://druid-overlord:8081/druid/indexer/v1/task', - text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}' + text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}', ) status_check = m.get( - 'http://druid-overlord:8081/druid/indexer/v1/task/' - '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status', - text='{"status":{"status": "FAILED"}}' + 'http://druid-overlord:8081/druid/indexer/v1/task/' '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status', + text='{"status":{"status": "FAILED"}}', ) # The job failed for some reason @@ -63,12 +62,11 @@ def test_submit_gone_wrong(self, m): def test_submit_ok(self, m): task_post = m.post( 'http://druid-overlord:8081/druid/indexer/v1/task', - text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}' + text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}', ) status_check = m.get( - 'http://druid-overlord:8081/druid/indexer/v1/task/' - '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status', - text='{"status":{"status": "SUCCESS"}}' + 'http://druid-overlord:8081/druid/indexer/v1/task/' '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status', + text='{"status":{"status": "SUCCESS"}}', ) # Exists just as it should @@ -81,12 +79,11 @@ def test_submit_ok(self, m): def test_submit_correct_json_body(self, m): task_post = m.post( 'http://druid-overlord:8081/druid/indexer/v1/task', - text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}' + text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}', ) status_check = m.get( - 'http://druid-overlord:8081/druid/indexer/v1/task/' - '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status', - text='{"status":{"status": "SUCCESS"}}' 
+ 'http://druid-overlord:8081/druid/indexer/v1/task/' '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status', + text='{"status":{"status": "SUCCESS"}}', ) json_ingestion_string = """ @@ -106,12 +103,11 @@ def test_submit_correct_json_body(self, m): def test_submit_unknown_response(self, m): task_post = m.post( 'http://druid-overlord:8081/druid/indexer/v1/task', - text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}' + text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}', ) status_check = m.get( - 'http://druid-overlord:8081/druid/indexer/v1/task/' - '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status', - text='{"status":{"status": "UNKNOWN"}}' + 'http://druid-overlord:8081/druid/indexer/v1/task/' '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status', + text='{"status":{"status": "UNKNOWN"}}', ) # An unknown error code @@ -127,17 +123,16 @@ def test_submit_timeout(self, m): self.db_hook.max_ingestion_time = 5 task_post = m.post( 'http://druid-overlord:8081/druid/indexer/v1/task', - text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}' + text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}', ) status_check = m.get( - 'http://druid-overlord:8081/druid/indexer/v1/task/' - '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status', - text='{"status":{"status": "RUNNING"}}' + 'http://druid-overlord:8081/druid/indexer/v1/task/' '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status', + text='{"status":{"status": "RUNNING"}}', ) shutdown_post = m.post( 'http://druid-overlord:8081/druid/indexer/v1/task/' '9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/shutdown', - text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}' + text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}', ) # Because the jobs keeps running @@ -194,7 +189,6 @@ def test_get_auth_with_no_user_and_password(self, mock_get_connection): class TestDruidDbApiHook(unittest.TestCase): - def setUp(self): super().setUp() self.cur = MagicMock() diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py index 75a44addb482f..9e563e4c07077 100644 --- a/tests/providers/apache/druid/operators/test_druid.py +++ b/tests/providers/apache/druid/operators/test_druid.py @@ -29,10 +29,7 @@ class TestDruidOperator(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': timezone.datetime(2017, 1, 1) - } + args = {'owner': 'airflow', 'start_date': timezone.datetime(2017, 1, 1)} self.dag = DAG('test_dag_id', default_args=args) def test_render_template(self): @@ -52,11 +49,8 @@ def test_render_template(self): operator = DruidOperator( task_id='spark_submit_job', json_index_file=json_str, - params={ - 'index_type': 'index_hadoop', - 'datasource': 'datasource_prd' - }, - dag=self.dag + params={'index_type': 'index_hadoop', 'datasource': 'datasource_prd'}, + dag=self.dag, ) ti = TaskInstance(operator, DEFAULT_DATE) ti.render_templates() diff --git a/tests/providers/apache/druid/operators/test_druid_check.py b/tests/providers/apache/druid/operators/test_druid_check.py index 70b66c0c01c84..c6bfc8c8f28e3 100644 --- a/tests/providers/apache/druid/operators/test_druid_check.py +++ b/tests/providers/apache/druid/operators/test_druid_check.py @@ -28,7 +28,6 @@ class TestDruidCheckOperator(unittest.TestCase): - def setUp(self): self.task_id = 'test_task' self.druid_broker_conn_id = 'default_conn' @@ -38,10 +37,8 @@ def __construct_operator(self, sql): dag = DAG('test_dag', start_date=datetime(2017, 1, 1)) return DruidCheckOperator( - dag=dag, - task_id=self.task_id, - druid_broker_conn_id=self.druid_broker_conn_id, - sql=sql) + 
dag=dag, task_id=self.task_id, druid_broker_conn_id=self.druid_broker_conn_id, sql=sql + ) @mock.patch.object(DruidCheckOperator, 'get_first') def test_execute_pass(self, mock_get_first): diff --git a/tests/providers/apache/druid/transfers/test_hive_to_druid.py b/tests/providers/apache/druid/transfers/test_hive_to_druid.py index 8951fe66997f5..968dd1d78f90a 100644 --- a/tests/providers/apache/druid/transfers/test_hive_to_druid.py +++ b/tests/providers/apache/druid/transfers/test_hive_to_druid.py @@ -37,7 +37,7 @@ class TestDruidHook(unittest.TestCase): 'ts_dim': 'timedimension_column', 'metric_spec': [ {"name": "count", "type": "count"}, - {"name": "amountSum", "type": "doubleSum", "fieldName": "amount"} + {"name": "amountSum", "type": "doubleSum", "fieldName": "amount"}, ], 'hive_cli_conn_id': 'hive_cli_custom', 'druid_ingest_conn_id': 'druid_ingest_default', @@ -51,22 +51,16 @@ class TestDruidHook(unittest.TestCase): 'job_properties': { "mapreduce.job.user.classpath.first": "false", "mapreduce.map.output.compress": "false", - "mapreduce.output.fileoutputformat.compress": "false" - } + "mapreduce.output.fileoutputformat.compress": "false", + }, } - index_spec_config = { - 'static_path': '/apps/db/warehouse/hive/', - 'columns': ['country', 'segment'] - } + index_spec_config = {'static_path': '/apps/db/warehouse/hive/', 'columns': ['country', 'segment']} def setUp(self): super().setUp() - args = { - 'owner': 'airflow', - 'start_date': '2017-01-01' - } + args = {'owner': 'airflow', 'start_date': '2017-01-01'} self.dag = DAG('hive_to_druid', default_args=args) session = requests.Session() @@ -74,15 +68,9 @@ def setUp(self): session.mount('mock', adapter) def test_construct_ingest_query(self): - operator = HiveToDruidOperator( - task_id='hive_to_druid', - dag=self.dag, - **self.hook_config - ) + operator = HiveToDruidOperator(task_id='hive_to_druid', dag=self.dag, **self.hook_config) - provided_index_spec = operator.construct_ingest_query( - **self.index_spec_config - ) + provided_index_spec = operator.construct_ingest_query(**self.index_spec_config) expected_index_spec = { "hadoopDependencyCoordinates": self.hook_config['hadoop_dependency_coordinates'], @@ -103,16 +91,13 @@ def test_construct_ingest_query(self): "dimensionsSpec": { "dimensionExclusions": [], "dimensions": self.index_spec_config['columns'], - "spatialDimensions": [] - }, - "timestampSpec": { - "column": self.hook_config['ts_dim'], - "format": "auto" + "spatialDimensions": [], }, - "format": "tsv" - } + "timestampSpec": {"column": self.hook_config['ts_dim'], "format": "auto"}, + "format": "tsv", + }, }, - "dataSource": self.hook_config['druid_datasource'] + "dataSource": self.hook_config['druid_datasource'], }, "tuningConfig": { "type": "hadoop", @@ -124,13 +109,10 @@ def test_construct_ingest_query(self): }, }, "ioConfig": { - "inputSpec": { - "paths": self.index_spec_config['static_path'], - "type": "static" - }, - "type": "hadoop" - } - } + "inputSpec": {"paths": self.index_spec_config['static_path'], "type": "static"}, + "type": "hadoop", + }, + }, } # Make sure it is like we expect it diff --git a/tests/providers/apache/hdfs/hooks/test_hdfs.py b/tests/providers/apache/hdfs/hooks/test_hdfs.py index ba04d453ed12c..22d48588d8446 100644 --- a/tests/providers/apache/hdfs/hooks/test_hdfs.py +++ b/tests/providers/apache/hdfs/hooks/test_hdfs.py @@ -25,20 +25,17 @@ try: import snakebite + snakebite_loaded = True except ImportError: snakebite_loaded = False if not snakebite_loaded: - raise unittest.SkipTest( - "Skipping test because 
HDFSHook is not installed" - ) + raise unittest.SkipTest("Skipping test because HDFSHook is not installed") class TestHDFSHook(unittest.TestCase): - @mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020', - }) + @mock.patch.dict('os.environ', {'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020',}) def test_get_client(self): client = HDFSHook(proxy_user='foo').get_conn() self.assertIsInstance(client, snakebite.client.Client) @@ -46,9 +43,7 @@ def test_get_client(self): self.assertEqual(8020, client.port) self.assertEqual('foo', client.service.channel.effective_user) - @mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020', - }) + @mock.patch.dict('os.environ', {'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020',}) @mock.patch('airflow.providers.apache.hdfs.hooks.hdfs.AutoConfigClient') @mock.patch('airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook.get_connections') def test_get_autoconfig_client(self, mock_get_connections, mock_client): @@ -58,15 +53,13 @@ def test_get_autoconfig_client(self, mock_get_connections, mock_client): host='localhost', port=8020, login='foo', - extra=json.dumps({'autoconfig': True}) + extra=json.dumps({'autoconfig': True}), ) mock_get_connections.return_value = [conn] HDFSHook(hdfs_conn_id='hdfs').get_conn() mock_client.assert_called_once_with(effective_user='foo', use_sasl=False) - @mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020', - }) + @mock.patch.dict('os.environ', {'AIRFLOW_CONN_HDFS_DEFAULT': 'hdfs://localhost:8020',}) @mock.patch('airflow.providers.apache.hdfs.hooks.hdfs.AutoConfigClient') def test_get_autoconfig_client_no_conn(self, mock_client): HDFSHook(hdfs_conn_id='hdfs_missing', autoconfig=True).get_conn() @@ -74,18 +67,8 @@ def test_get_autoconfig_client_no_conn(self, mock_client): @mock.patch('airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook.get_connections') def test_get_ha_client(self, mock_get_connections): - conn_1 = Connection( - conn_id='hdfs_default', - conn_type='hdfs', - host='localhost', - port=8020 - ) - conn_2 = Connection( - conn_id='hdfs_default', - conn_type='hdfs', - host='localhost2', - port=8020 - ) + conn_1 = Connection(conn_id='hdfs_default', conn_type='hdfs', host='localhost', port=8020) + conn_2 = Connection(conn_id='hdfs_default', conn_type='hdfs', host='localhost2', port=8020) mock_get_connections.return_value = [conn_1, conn_2] client = HDFSHook().get_conn() self.assertIsInstance(client, snakebite.client.HAClient) diff --git a/tests/providers/apache/hdfs/hooks/test_webhdfs.py b/tests/providers/apache/hdfs/hooks/test_webhdfs.py index 4c70bd99dd56c..05ca418a7c6ef 100644 --- a/tests/providers/apache/hdfs/hooks/test_webhdfs.py +++ b/tests/providers/apache/hdfs/hooks/test_webhdfs.py @@ -26,46 +26,49 @@ class TestWebHDFSHook(unittest.TestCase): - def setUp(self): self.webhdfs_hook = WebHDFSHook() @patch('airflow.providers.apache.hdfs.hooks.webhdfs.InsecureClient') - @patch('airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook.get_connections', return_value=[ - Connection(host='host_1', port=123), - Connection(host='host_2', port=321, login='user') - ]) + @patch( + 'airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook.get_connections', + return_value=[Connection(host='host_1', port=123), Connection(host='host_2', port=321, login='user')], + ) @patch("airflow.providers.apache.hdfs.hooks.webhdfs.socket") def test_get_conn(self, socket_mock, mock_get_connections, mock_insecure_client): mock_insecure_client.side_effect = 
[HdfsError('Error'), mock_insecure_client.return_value] socket_mock.socket.return_value.connect_ex.return_value = 0 conn = self.webhdfs_hook.get_conn() - mock_insecure_client.assert_has_calls([ - call('http://{host}:{port}'.format(host=connection.host, port=connection.port), - user=connection.login) - for connection in mock_get_connections.return_value - ]) + mock_insecure_client.assert_has_calls( + [ + call( + 'http://{host}:{port}'.format(host=connection.host, port=connection.port), + user=connection.login, + ) + for connection in mock_get_connections.return_value + ] + ) mock_insecure_client.return_value.status.assert_called_once_with('/') self.assertEqual(conn, mock_insecure_client.return_value) @patch('airflow.providers.apache.hdfs.hooks.webhdfs.KerberosClient', create=True) - @patch('airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook.get_connections', return_value=[ - Connection(host='host_1', port=123) - ]) + @patch( + 'airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook.get_connections', + return_value=[Connection(host='host_1', port=123)], + ) @patch('airflow.providers.apache.hdfs.hooks.webhdfs._kerberos_security_mode', return_value=True) @patch("airflow.providers.apache.hdfs.hooks.webhdfs.socket") - def test_get_conn_kerberos_security_mode(self, - socket_mock, - mock_kerberos_security_mode, - mock_get_connections, - mock_kerberos_client): + def test_get_conn_kerberos_security_mode( + self, socket_mock, mock_kerberos_security_mode, mock_get_connections, mock_kerberos_client + ): socket_mock.socket.return_value.connect_ex.return_value = 0 conn = self.webhdfs_hook.get_conn() connection = mock_get_connections.return_value[0] mock_kerberos_client.assert_called_once_with( - 'http://{host}:{port}'.format(host=connection.host, port=connection.port)) + 'http://{host}:{port}'.format(host=connection.host, port=connection.port) + ) self.assertEqual(conn, mock_kerberos_client.return_value) @patch('airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook._find_valid_server', return_value=None) @@ -94,10 +97,7 @@ def test_load_file(self, mock_get_conn): mock_get_conn.assert_called_once_with() mock_upload = mock_get_conn.return_value.upload mock_upload.assert_called_once_with( - hdfs_path=destination, - local_path=source, - overwrite=True, - n_threads=1 + hdfs_path=destination, local_path=source, overwrite=True, n_threads=1 ) def test_simple_init(self): diff --git a/tests/providers/apache/hdfs/sensors/test_hdfs.py b/tests/providers/apache/hdfs/sensors/test_hdfs.py index 536122997bc66..4d25cad191f9a 100644 --- a/tests/providers/apache/hdfs/sensors/test_hdfs.py +++ b/tests/providers/apache/hdfs/sensors/test_hdfs.py @@ -30,7 +30,6 @@ class TestHdfsSensor(unittest.TestCase): - def setUp(self): self.hook = FakeHDFSHook @@ -40,12 +39,14 @@ def test_legacy_file_exist(self): :return: """ # When - task = HdfsSensor(task_id='Should_be_file_legacy', - filepath='/datadirectory/datafile', - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsSensor( + task_id='Should_be_file_legacy', + filepath='/datadirectory/datafile', + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) task.execute(None) # Then @@ -57,13 +58,15 @@ def test_legacy_file_exist_but_filesize(self): :return: """ # When - task = HdfsSensor(task_id='Should_be_file_legacy', - filepath='/datadirectory/datafile', - timeout=1, - file_size=20, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsSensor( + 
task_id='Should_be_file_legacy', + filepath='/datadirectory/datafile', + timeout=1, + file_size=20, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When # Then @@ -75,12 +78,14 @@ def test_legacy_file_does_not_exists(self): Test the legacy behaviour :return: """ - task = HdfsSensor(task_id='Should_not_be_file_legacy', - filepath='/datadirectory/not_existing_file_or_directory', - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsSensor( + task_id='Should_not_be_file_legacy', + filepath='/datadirectory/not_existing_file_or_directory', + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When # Then @@ -103,13 +108,15 @@ def test_should_be_empty_directory(self): self.log.debug('#' * 10) self.log.debug('Running %s', self._testMethodName) self.log.debug('#' * 10) - task = HdfsFolderSensor(task_id='Should_be_empty_directory', - filepath='/datadirectory/empty_directory', - be_empty=True, - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsFolderSensor( + task_id='Should_be_empty_directory', + filepath='/datadirectory/empty_directory', + be_empty=True, + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When task.execute(None) @@ -126,13 +133,15 @@ def test_should_be_empty_directory_fail(self): self.log.debug('#' * 10) self.log.debug('Running %s', self._testMethodName) self.log.debug('#' * 10) - task = HdfsFolderSensor(task_id='Should_be_empty_directory_fail', - filepath='/datadirectory/not_empty_directory', - be_empty=True, - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsFolderSensor( + task_id='Should_be_empty_directory_fail', + filepath='/datadirectory/not_empty_directory', + be_empty=True, + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When # Then @@ -148,12 +157,14 @@ def test_should_be_a_non_empty_directory(self): self.log.debug('#' * 10) self.log.debug('Running %s', self._testMethodName) self.log.debug('#' * 10) - task = HdfsFolderSensor(task_id='Should_be_non_empty_directory', - filepath='/datadirectory/not_empty_directory', - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsFolderSensor( + task_id='Should_be_non_empty_directory', + filepath='/datadirectory/not_empty_directory', + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When task.execute(None) @@ -170,12 +181,14 @@ def test_should_be_non_empty_directory_fail(self): self.log.debug('#' * 10) self.log.debug('Running %s', self._testMethodName) self.log.debug('#' * 10) - task = HdfsFolderSensor(task_id='Should_be_empty_directory_fail', - filepath='/datadirectory/empty_directory', - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsFolderSensor( + task_id='Should_be_empty_directory_fail', + filepath='/datadirectory/empty_directory', + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When # Then @@ -199,13 +212,15 @@ def test_should_match_regex(self): self.log.debug('Running %s', self._testMethodName) self.log.debug('#' * 10) compiled_regex = re.compile("test[1-2]file") - task = HdfsRegexSensor(task_id='Should_match_the_regex', - filepath='/datadirectory/regex_dir', - regex=compiled_regex, - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - 
hook=self.hook) + task = HdfsRegexSensor( + task_id='Should_match_the_regex', + filepath='/datadirectory/regex_dir', + regex=compiled_regex, + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When task.execute(None) @@ -223,13 +238,15 @@ def test_should_not_match_regex(self): self.log.debug('Running %s', self._testMethodName) self.log.debug('#' * 10) compiled_regex = re.compile("^IDoNotExist") - task = HdfsRegexSensor(task_id='Should_not_match_the_regex', - filepath='/datadirectory/regex_dir', - regex=compiled_regex, - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsRegexSensor( + task_id='Should_not_match_the_regex', + filepath='/datadirectory/regex_dir', + regex=compiled_regex, + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When # Then @@ -246,16 +263,18 @@ def test_should_match_regex_and_filesize(self): self.log.debug('Running %s', self._testMethodName) self.log.debug('#' * 10) compiled_regex = re.compile("test[1-2]file") - task = HdfsRegexSensor(task_id='Should_match_the_regex_and_filesize', - filepath='/datadirectory/regex_dir', - regex=compiled_regex, - ignore_copying=True, - ignored_ext=['_COPYING_', 'sftp'], - file_size=10, - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsRegexSensor( + task_id='Should_match_the_regex_and_filesize', + filepath='/datadirectory/regex_dir', + regex=compiled_regex, + ignore_copying=True, + ignored_ext=['_COPYING_', 'sftp'], + file_size=10, + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When task.execute(None) @@ -273,14 +292,16 @@ def test_should_match_regex_but_filesize(self): self.log.debug('Running %s', self._testMethodName) self.log.debug('#' * 10) compiled_regex = re.compile("test[1-2]file") - task = HdfsRegexSensor(task_id='Should_match_the_regex_but_filesize', - filepath='/datadirectory/regex_dir', - regex=compiled_regex, - file_size=20, - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsRegexSensor( + task_id='Should_match_the_regex_but_filesize', + filepath='/datadirectory/regex_dir', + regex=compiled_regex, + file_size=20, + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When # Then @@ -297,15 +318,17 @@ def test_should_match_regex_but_copyingext(self): self.log.debug('Running %s', self._testMethodName) self.log.debug('#' * 10) compiled_regex = re.compile(r"copying_file_\d+.txt") - task = HdfsRegexSensor(task_id='Should_match_the_regex_but_filesize', - filepath='/datadirectory/regex_dir', - regex=compiled_regex, - ignored_ext=['_COPYING_', 'sftp'], - file_size=20, - timeout=1, - retry_delay=timedelta(seconds=1), - poke_interval=1, - hook=self.hook) + task = HdfsRegexSensor( + task_id='Should_match_the_regex_but_filesize', + filepath='/datadirectory/regex_dir', + regex=compiled_regex, + ignored_ext=['_COPYING_', 'sftp'], + file_size=20, + timeout=1, + retry_delay=timedelta(seconds=1), + poke_interval=1, + hook=self.hook, + ) # When # Then diff --git a/tests/providers/apache/hdfs/sensors/test_web_hdfs.py b/tests/providers/apache/hdfs/sensors/test_web_hdfs.py index cad72403ff629..bc53cfc0d0be1 100644 --- a/tests/providers/apache/hdfs/sensors/test_web_hdfs.py +++ b/tests/providers/apache/hdfs/sensors/test_web_hdfs.py @@ -26,14 +26,9 @@ class TestWebHdfsSensor(TestHiveEnvironment): - 
@mock.patch('airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook') def test_poke(self, mock_hook): - sensor = WebHdfsSensor( - task_id='test_task', - webhdfs_conn_id=TEST_HDFS_CONN, - filepath=TEST_HDFS_PATH, - ) + sensor = WebHdfsSensor(task_id='test_task', webhdfs_conn_id=TEST_HDFS_CONN, filepath=TEST_HDFS_PATH,) exists = sensor.poke(dict()) self.assertTrue(exists) @@ -45,11 +40,7 @@ def test_poke(self, mock_hook): def test_poke_should_return_false_for_non_existing_table(self, mock_hook): mock_hook.return_value.check_for_path.return_value = False - sensor = WebHdfsSensor( - task_id='test_task', - webhdfs_conn_id=TEST_HDFS_CONN, - filepath=TEST_HDFS_PATH, - ) + sensor = WebHdfsSensor(task_id='test_task', webhdfs_conn_id=TEST_HDFS_CONN, filepath=TEST_HDFS_PATH,) exists = sensor.poke(dict()) self.assertFalse(exists) diff --git a/tests/providers/apache/hive/__init__.py b/tests/providers/apache/hive/__init__.py index 1de6e09f73f0e..08245b0654812 100644 --- a/tests/providers/apache/hive/__init__.py +++ b/tests/providers/apache/hive/__init__.py @@ -27,7 +27,6 @@ class TestHiveEnvironment(TestCase): - def setUp(self): args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG('test_dag_id', default_args=args) diff --git a/tests/providers/apache/hive/hooks/test_hive.py b/tests/providers/apache/hive/hooks/test_hive.py index decfcbc4bfb28..f658957c5cb81 100644 --- a/tests/providers/apache/hive/hooks/test_hive.py +++ b/tests/providers/apache/hive/hooks/test_hive.py @@ -44,15 +44,14 @@ class TestHiveEnvironment(unittest.TestCase): - def setUp(self): - self.next_day = (DEFAULT_DATE + - datetime.timedelta(days=1)).isoformat()[:10] + self.next_day = (DEFAULT_DATE + datetime.timedelta(days=1)).isoformat()[:10] self.database = 'airflow' self.partition_by = 'ds' self.table = 'static_babynames_partitioned' - with mock.patch('airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook.get_metastore_client' - ) as get_metastore_mock: + with mock.patch( + 'airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook.get_metastore_client' + ) as get_metastore_mock: get_metastore_mock.return_value = mock.MagicMock() self.hook = HiveMetastoreHook() @@ -67,107 +66,132 @@ def test_run_cli(self, mock_popen, mock_temp_dir): mock_popen.return_value = mock_subprocess mock_temp_dir.return_value = "test_run_cli" - with mock.patch.dict('os.environ', { - 'AIRFLOW_CTX_DAG_ID': 'test_dag_id', - 'AIRFLOW_CTX_TASK_ID': 'test_task_id', - 'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00', - 'AIRFLOW_CTX_DAG_RUN_ID': '55', - 'AIRFLOW_CTX_DAG_OWNER': 'airflow', - 'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com', - }): + with mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CTX_DAG_ID': 'test_dag_id', + 'AIRFLOW_CTX_TASK_ID': 'test_task_id', + 'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00', + 'AIRFLOW_CTX_DAG_RUN_ID': '55', + 'AIRFLOW_CTX_DAG_OWNER': 'airflow', + 'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com', + }, + ): hook = MockHiveCliHook() hook.run_cli("SHOW DATABASES") - hive_cmd = ['beeline', '-u', '"jdbc:hive2://localhost:10000/default"', '-hiveconf', - 'airflow.ctx.dag_id=test_dag_id', '-hiveconf', 'airflow.ctx.task_id=test_task_id', - '-hiveconf', 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', '-hiveconf', - 'airflow.ctx.dag_run_id=55', '-hiveconf', 'airflow.ctx.dag_owner=airflow', - '-hiveconf', 'airflow.ctx.dag_email=test@airflow.com', '-hiveconf', - 'mapreduce.job.queuename=airflow', '-hiveconf', 'mapred.job.queue.name=airflow', - '-hiveconf', 'tez.queue.name=airflow', '-f', - 
'/tmp/airflow_hiveop_test_run_cli/tmptest_run_cli'] + hive_cmd = [ + 'beeline', + '-u', + '"jdbc:hive2://localhost:10000/default"', + '-hiveconf', + 'airflow.ctx.dag_id=test_dag_id', + '-hiveconf', + 'airflow.ctx.task_id=test_task_id', + '-hiveconf', + 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', + '-hiveconf', + 'airflow.ctx.dag_run_id=55', + '-hiveconf', + 'airflow.ctx.dag_owner=airflow', + '-hiveconf', + 'airflow.ctx.dag_email=test@airflow.com', + '-hiveconf', + 'mapreduce.job.queuename=airflow', + '-hiveconf', + 'mapred.job.queue.name=airflow', + '-hiveconf', + 'tez.queue.name=airflow', + '-f', + '/tmp/airflow_hiveop_test_run_cli/tmptest_run_cli', + ] mock_popen.assert_called_with( hive_cmd, stdout=mock_subprocess.PIPE, stderr=mock_subprocess.STDOUT, cwd="/tmp/airflow_hiveop_test_run_cli", - close_fds=True + close_fds=True, ) @mock.patch('subprocess.Popen') def test_run_cli_with_hive_conf(self, mock_popen): - hql = "set key;\n" \ - "set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n" \ - "set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n" - - dag_id_ctx_var_name = \ - AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format'] - task_id_ctx_var_name = \ - AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format'] - execution_date_ctx_var_name = \ - AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][ - 'env_var_format'] - dag_run_id_ctx_var_name = \ - AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][ - 'env_var_format'] - - mock_output = ['Connecting to jdbc:hive2://localhost:10000/default', - 'log4j:WARN No appenders could be found for logger (org.apache.hive.jdbc.Utils).', - 'log4j:WARN Please initialize the log4j system properly.', - 'log4j:WARN See http://logging.apache.org/log4j/1.2/faq.html#noconfig for more info.', - 'Connected to: Apache Hive (version 1.2.1.2.3.2.0-2950)', - 'Driver: Hive JDBC (version 1.2.1.spark2)', - 'Transaction isolation: TRANSACTION_REPEATABLE_READ', - '0: jdbc:hive2://localhost:10000/default> USE default;', - 'No rows affected (0.37 seconds)', - '0: jdbc:hive2://localhost:10000/default> set key;', - '+------------+--+', - '| set |', - '+------------+--+', - '| key=value |', - '+------------+--+', - '1 row selected (0.133 seconds)', - '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.dag_id;', - '+---------------------------------+--+', - '| set |', - '+---------------------------------+--+', - '| airflow.ctx.dag_id=test_dag_id |', - '+---------------------------------+--+', - '1 row selected (0.008 seconds)', - '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.dag_run_id;', - '+-----------------------------------------+--+', - '| set |', - '+-----------------------------------------+--+', - '| airflow.ctx.dag_run_id=test_dag_run_id |', - '+-----------------------------------------+--+', - '1 row selected (0.007 seconds)', - '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.task_id;', - '+-----------------------------------+--+', - '| set |', - '+-----------------------------------+--+', - '| airflow.ctx.task_id=test_task_id |', - '+-----------------------------------+--+', - '1 row selected (0.009 seconds)', - '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.execution_date;', - '+-------------------------------------------------+--+', - '| set |', - '+-------------------------------------------------+--+', - '| airflow.ctx.execution_date=test_execution_date |', - '+-------------------------------------------------+--+', - '1 row selected 
(0.006 seconds)', - '0: jdbc:hive2://localhost:10000/default> ', - '0: jdbc:hive2://localhost:10000/default> ', - 'Closing: 0: jdbc:hive2://localhost:10000/default', - ''] - - with mock.patch.dict('os.environ', { - dag_id_ctx_var_name: 'test_dag_id', - task_id_ctx_var_name: 'test_task_id', - execution_date_ctx_var_name: 'test_execution_date', - dag_run_id_ctx_var_name: 'test_dag_run_id', - }): + hql = ( + "set key;\n" + "set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n" + "set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n" + ) + + dag_id_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format'] + task_id_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format'] + execution_date_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][ + 'env_var_format' + ] + dag_run_id_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][ + 'env_var_format' + ] + + mock_output = [ + 'Connecting to jdbc:hive2://localhost:10000/default', + 'log4j:WARN No appenders could be found for logger (org.apache.hive.jdbc.Utils).', + 'log4j:WARN Please initialize the log4j system properly.', + 'log4j:WARN See http://logging.apache.org/log4j/1.2/faq.html#noconfig for more info.', + 'Connected to: Apache Hive (version 1.2.1.2.3.2.0-2950)', + 'Driver: Hive JDBC (version 1.2.1.spark2)', + 'Transaction isolation: TRANSACTION_REPEATABLE_READ', + '0: jdbc:hive2://localhost:10000/default> USE default;', + 'No rows affected (0.37 seconds)', + '0: jdbc:hive2://localhost:10000/default> set key;', + '+------------+--+', + '| set |', + '+------------+--+', + '| key=value |', + '+------------+--+', + '1 row selected (0.133 seconds)', + '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.dag_id;', + '+---------------------------------+--+', + '| set |', + '+---------------------------------+--+', + '| airflow.ctx.dag_id=test_dag_id |', + '+---------------------------------+--+', + '1 row selected (0.008 seconds)', + '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.dag_run_id;', + '+-----------------------------------------+--+', + '| set |', + '+-----------------------------------------+--+', + '| airflow.ctx.dag_run_id=test_dag_run_id |', + '+-----------------------------------------+--+', + '1 row selected (0.007 seconds)', + '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.task_id;', + '+-----------------------------------+--+', + '| set |', + '+-----------------------------------+--+', + '| airflow.ctx.task_id=test_task_id |', + '+-----------------------------------+--+', + '1 row selected (0.009 seconds)', + '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.execution_date;', + '+-------------------------------------------------+--+', + '| set |', + '+-------------------------------------------------+--+', + '| airflow.ctx.execution_date=test_execution_date |', + '+-------------------------------------------------+--+', + '1 row selected (0.006 seconds)', + '0: jdbc:hive2://localhost:10000/default> ', + '0: jdbc:hive2://localhost:10000/default> ', + 'Closing: 0: jdbc:hive2://localhost:10000/default', + '', + ] + + with mock.patch.dict( + 'os.environ', + { + dag_id_ctx_var_name: 'test_dag_id', + task_id_ctx_var_name: 'test_task_id', + execution_date_ctx_var_name: 'test_execution_date', + dag_run_id_ctx_var_name: 'test_dag_run_id', + }, + ): hook = MockHiveCliHook() mock_popen.return_value = MockSubProcess(output=mock_output) @@ -195,14 +219,10 @@ def 
test_load_file_without_create_table(self, mock_run_cli): hook = MockHiveCliHook() hook.load_file(filepath=filepath, table=table, create=False) - query = ( - "LOAD DATA LOCAL INPATH '{filepath}' " - "OVERWRITE INTO TABLE {table} ;\n" - .format(filepath=filepath, table=table) + query = "LOAD DATA LOCAL INPATH '{filepath}' " "OVERWRITE INTO TABLE {table} ;\n".format( + filepath=filepath, table=table ) - calls = [ - mock.call(query) - ] + calls = [mock.call(query)] mock_run_cli.assert_has_calls(calls, any_order=True) @mock.patch('airflow.providers.apache.hive.hooks.hive.HiveCliHook.run_cli') @@ -210,12 +230,10 @@ def test_load_file_create_table(self, mock_run_cli): filepath = "/path/to/input/file" table = "output_table" field_dict = OrderedDict([("name", "string"), ("gender", "string")]) - fields = ",\n ".join( - ['`{k}` {v}'.format(k=k.strip('`'), v=v) for k, v in field_dict.items()]) + fields = ",\n ".join(['`{k}` {v}'.format(k=k.strip('`'), v=v) for k, v in field_dict.items()]) hook = MockHiveCliHook() - hook.load_file(filepath=filepath, table=table, - field_dict=field_dict, create=True, recreate=True) + hook.load_file(filepath=filepath, table=table, field_dict=field_dict, create=True, recreate=True) create_table = ( "DROP TABLE IF EXISTS {table};\n" @@ -225,15 +243,10 @@ def test_load_file_create_table(self, mock_run_cli): "STORED AS textfile\n;".format(table=table, fields=fields) ) - load_data = ( - "LOAD DATA LOCAL INPATH '{filepath}' " - "OVERWRITE INTO TABLE {table} ;\n" - .format(filepath=filepath, table=table) + load_data = "LOAD DATA LOCAL INPATH '{filepath}' " "OVERWRITE INTO TABLE {table} ;\n".format( + filepath=filepath, table=table ) - calls = [ - mock.call(create_table), - mock.call(load_data) - ] + calls = [mock.call(create_table), mock.call(load_data)] mock_run_cli.assert_has_calls(calls, any_order=True) @mock.patch('airflow.providers.apache.hive.hooks.hive.HiveCliHook.load_file') @@ -245,10 +258,7 @@ def test_load_df(self, mock_to_csv, mock_load_file): encoding = "utf-8" hook = MockHiveCliHook() - hook.load_df(df=df, - table=table, - delimiter=delimiter, - encoding=encoding) + hook.load_df(df=df, table=table, delimiter=delimiter, encoding=encoding) assert mock_to_csv.call_count == 1 kwargs = mock_to_csv.call_args[1] @@ -270,10 +280,7 @@ def test_load_df_with_optional_parameters(self, mock_to_csv, mock_load_file): bools = (True, False) for create, recreate in itertools.product(bools, bools): mock_load_file.reset_mock() - hook.load_df(df=pd.DataFrame({"c": range(0, 10)}), - table="t", - create=create, - recreate=recreate) + hook.load_df(df=pd.DataFrame({"c": range(0, 10)}), table="t", create=create, recreate=recreate) assert mock_load_file.call_count == 1 kwargs = mock_load_file.call_args[1] @@ -315,84 +322,80 @@ def test_load_df_with_data_types(self, mock_run_cli): STORED AS textfile ; """ - assert_equal_ignore_multiple_spaces( - self, mock_run_cli.call_args_list[0][0][0], query) + assert_equal_ignore_multiple_spaces(self, mock_run_cli.call_args_list[0][0][0], query) class TestHiveMetastoreHook(TestHiveEnvironment): VALID_FILTER_MAP = {'key2': 'value2'} def test_get_max_partition_from_empty_part_specs(self): - max_partition = \ - HiveMetastoreHook._get_max_partition_from_part_specs([], - 'key1', - self.VALID_FILTER_MAP) + max_partition = HiveMetastoreHook._get_max_partition_from_part_specs( + [], 'key1', self.VALID_FILTER_MAP + ) self.assertIsNone(max_partition) # @mock.patch('airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook', 'get_metastore_client') def 
test_get_max_partition_from_valid_part_specs_and_invalid_filter_map(self): with self.assertRaises(AirflowException): HiveMetastoreHook._get_max_partition_from_part_specs( - [{'key1': 'value1', 'key2': 'value2'}, - {'key1': 'value3', 'key2': 'value4'}], + [{'key1': 'value1', 'key2': 'value2'}, {'key1': 'value3', 'key2': 'value4'}], 'key1', - {'key3': 'value5'}) + {'key3': 'value5'}, + ) def test_get_max_partition_from_valid_part_specs_and_invalid_partition_key(self): with self.assertRaises(AirflowException): HiveMetastoreHook._get_max_partition_from_part_specs( - [{'key1': 'value1', 'key2': 'value2'}, - {'key1': 'value3', 'key2': 'value4'}], + [{'key1': 'value1', 'key2': 'value2'}, {'key1': 'value3', 'key2': 'value4'}], 'key3', - self.VALID_FILTER_MAP) + self.VALID_FILTER_MAP, + ) def test_get_max_partition_from_valid_part_specs_and_none_partition_key(self): with self.assertRaises(AirflowException): HiveMetastoreHook._get_max_partition_from_part_specs( - [{'key1': 'value1', 'key2': 'value2'}, - {'key1': 'value3', 'key2': 'value4'}], + [{'key1': 'value1', 'key2': 'value2'}, {'key1': 'value3', 'key2': 'value4'}], None, - self.VALID_FILTER_MAP) + self.VALID_FILTER_MAP, + ) def test_get_max_partition_from_valid_part_specs_and_none_filter_map(self): - max_partition = \ - HiveMetastoreHook._get_max_partition_from_part_specs( - [{'key1': 'value1', 'key2': 'value2'}, - {'key1': 'value3', 'key2': 'value4'}], - 'key1', - None) + max_partition = HiveMetastoreHook._get_max_partition_from_part_specs( + [{'key1': 'value1', 'key2': 'value2'}, {'key1': 'value3', 'key2': 'value4'}], 'key1', None + ) # No partition will be filtered out. self.assertEqual(max_partition, 'value3') def test_get_max_partition_from_valid_part_specs(self): - max_partition = \ - HiveMetastoreHook._get_max_partition_from_part_specs( - [{'key1': 'value1', 'key2': 'value2'}, - {'key1': 'value3', 'key2': 'value4'}], - 'key1', - self.VALID_FILTER_MAP) + max_partition = HiveMetastoreHook._get_max_partition_from_part_specs( + [{'key1': 'value1', 'key2': 'value2'}, {'key1': 'value3', 'key2': 'value4'}], + 'key1', + self.VALID_FILTER_MAP, + ) self.assertEqual(max_partition, 'value1') def test_get_max_partition_from_valid_part_specs_return_type(self): - max_partition = \ - HiveMetastoreHook._get_max_partition_from_part_specs( - [{'key1': 'value1', 'key2': 'value2'}, - {'key1': 'value3', 'key2': 'value4'}], - 'key1', - self.VALID_FILTER_MAP) + max_partition = HiveMetastoreHook._get_max_partition_from_part_specs( + [{'key1': 'value1', 'key2': 'value2'}, {'key1': 'value3', 'key2': 'value4'}], + 'key1', + self.VALID_FILTER_MAP, + ) self.assertIsInstance(max_partition, str) - @mock.patch("airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook.get_connection", - return_value=[Connection(host="localhost", port=9802)]) + @mock.patch( + "airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook.get_connection", + return_value=[Connection(host="localhost", port=9802)], + ) @mock.patch("airflow.providers.apache.hive.hooks.hive.socket") def test_error_metastore_client(self, socket_mock, _find_valid_server_mock): socket_mock.socket.return_value.connect_ex.return_value = 0 self.hook.get_metastore_client() def test_get_conn(self): - with mock.patch('airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook._find_valid_server' - ) as find_valid_server: + with mock.patch( + 'airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook._find_valid_server' + ) as find_valid_server: find_valid_server.return_value = mock.MagicMock(return_value={}) 
metastore_hook = HiveMetastoreHook() @@ -405,82 +408,66 @@ def test_check_for_partition(self): metastore = self.hook.metastore.__enter__() - partition = "{p_by}='{date}'".format(date=DEFAULT_DATE_DS, - p_by=self.partition_by) + partition = "{p_by}='{date}'".format(date=DEFAULT_DATE_DS, p_by=self.partition_by) - metastore.get_partitions_by_filter = mock.MagicMock( - return_value=[fake_partition]) + metastore.get_partitions_by_filter = mock.MagicMock(return_value=[fake_partition]) - self.assertTrue( - self.hook.check_for_partition(self.database, self.table, - partition) - ) + self.assertTrue(self.hook.check_for_partition(self.database, self.table, partition)) - metastore.get_partitions_by_filter( - self.database, self.table, partition, 1) + metastore.get_partitions_by_filter(self.database, self.table, partition, 1) # Check for non-existent partition. - missing_partition = "{p_by}='{date}'".format(date=self.next_day, - p_by=self.partition_by) + missing_partition = "{p_by}='{date}'".format(date=self.next_day, p_by=self.partition_by) metastore.get_partitions_by_filter = mock.MagicMock(return_value=[]) - self.assertFalse( - self.hook.check_for_partition(self.database, self.table, - missing_partition) - ) + self.assertFalse(self.hook.check_for_partition(self.database, self.table, missing_partition)) - metastore.get_partitions_by_filter.assert_called_with( - self.database, self.table, missing_partition, 1) + metastore.get_partitions_by_filter.assert_called_with(self.database, self.table, missing_partition, 1) def test_check_for_named_partition(self): # Check for existing partition. - partition = "{p_by}={date}".format(date=DEFAULT_DATE_DS, - p_by=self.partition_by) + partition = "{p_by}={date}".format(date=DEFAULT_DATE_DS, p_by=self.partition_by) - self.hook.metastore.__enter__( - ).check_for_named_partition = mock.MagicMock(return_value=True) + self.hook.metastore.__enter__().check_for_named_partition = mock.MagicMock(return_value=True) - self.assertTrue( - self.hook.check_for_named_partition(self.database, - self.table, - partition)) + self.assertTrue(self.hook.check_for_named_partition(self.database, self.table, partition)) self.hook.metastore.__enter__().check_for_named_partition.assert_called_with( - self.database, self.table, partition) + self.database, self.table, partition + ) # Check for non-existent partition - missing_partition = "{p_by}={date}".format(date=self.next_day, - p_by=self.partition_by) + missing_partition = "{p_by}={date}".format(date=self.next_day, p_by=self.partition_by) - self.hook.metastore.__enter__().check_for_named_partition = mock.MagicMock( - return_value=False) + self.hook.metastore.__enter__().check_for_named_partition = mock.MagicMock(return_value=False) - self.assertFalse( - self.hook.check_for_named_partition(self.database, - self.table, - missing_partition) - ) + self.assertFalse(self.hook.check_for_named_partition(self.database, self.table, missing_partition)) self.hook.metastore.__enter__().check_for_named_partition.assert_called_with( - self.database, self.table, missing_partition) + self.database, self.table, missing_partition + ) def test_get_table(self): self.hook.metastore.__enter__().get_table = mock.MagicMock() self.hook.get_table(db=self.database, table_name=self.table) self.hook.metastore.__enter__().get_table.assert_called_with( - dbname=self.database, tbl_name=self.table) + dbname=self.database, tbl_name=self.table + ) def test_get_tables(self): # static_babynames_partitioned self.hook.metastore.__enter__().get_tables = mock.MagicMock( - 
return_value=['static_babynames_partitioned']) + return_value=['static_babynames_partitioned'] + ) self.hook.get_tables(db=self.database, pattern=self.table + "*") self.hook.metastore.__enter__().get_tables.assert_called_with( - db_name='airflow', pattern='static_babynames_partitioned*') + db_name='airflow', pattern='static_babynames_partitioned*' + ) self.hook.metastore.__enter__().get_table_objects_by_name.assert_called_with( - 'airflow', ['static_babynames_partitioned']) + 'airflow', ['static_babynames_partitioned'] + ) def test_get_databases(self): metastore = self.hook.metastore.__enter__() @@ -500,18 +487,16 @@ def test_get_partitions(self): metastore = self.hook.metastore.__enter__() metastore.get_table = mock.MagicMock(return_value=fake_table) - metastore.get_partitions = mock.MagicMock( - return_value=[fake_partition]) + metastore.get_partitions = mock.MagicMock(return_value=[fake_partition]) - partitions = self.hook.get_partitions(schema=self.database, - table_name=self.table) + partitions = self.hook.get_partitions(schema=self.database, table_name=self.table) self.assertEqual(len(partitions), 1) self.assertEqual(partitions, [{self.partition_by: DEFAULT_DATE_DS}]) - metastore.get_table.assert_called_with( - dbname=self.database, tbl_name=self.table) + metastore.get_table.assert_called_with(dbname=self.database, tbl_name=self.table) metastore.get_partitions.assert_called_with( - db_name=self.database, tbl_name=self.table, max_parts=HiveMetastoreHook.MAX_PART_COUNT) + db_name=self.database, tbl_name=self.table, max_parts=HiveMetastoreHook.MAX_PART_COUNT + ) def test_max_partition(self): FakeFieldSchema = namedtuple('FakeFieldSchema', ['name']) @@ -522,22 +507,19 @@ def test_max_partition(self): metastore = self.hook.metastore.__enter__() metastore.get_table = mock.MagicMock(return_value=fake_table) - metastore.get_partition_names = mock.MagicMock( - return_value=['ds=2015-01-01']) - metastore.partition_name_to_spec = mock.MagicMock( - return_value={'ds': '2015-01-01'}) + metastore.get_partition_names = mock.MagicMock(return_value=['ds=2015-01-01']) + metastore.partition_name_to_spec = mock.MagicMock(return_value={'ds': '2015-01-01'}) filter_map = {self.partition_by: DEFAULT_DATE_DS} - partition = self.hook.max_partition(schema=self.database, - table_name=self.table, - field=self.partition_by, - filter_map=filter_map) + partition = self.hook.max_partition( + schema=self.database, table_name=self.table, field=self.partition_by, filter_map=filter_map + ) self.assertEqual(partition, DEFAULT_DATE_DS) - metastore.get_table.assert_called_with( - dbname=self.database, tbl_name=self.table) + metastore.get_table.assert_called_with(dbname=self.database, tbl_name=self.table) metastore.get_partition_names.assert_called_with( - self.database, self.table, max_parts=HiveMetastoreHook.MAX_PART_COUNT) + self.database, self.table, max_parts=HiveMetastoreHook.MAX_PART_COUNT + ) metastore.partition_name_to_spec.assert_called_with('ds=2015-01-01') def test_table_exists(self): @@ -546,16 +528,16 @@ def test_table_exists(self): self.assertTrue(self.hook.table_exists(self.table, db=self.database)) self.hook.metastore.__enter__().get_table.assert_called_with( - dbname='airflow', tbl_name='static_babynames_partitioned') + dbname='airflow', tbl_name='static_babynames_partitioned' + ) # Test with non-existent table. 
self.hook.metastore.__enter__().get_table = mock.MagicMock(side_effect=Exception()) - self.assertFalse( - self.hook.table_exists("does-not-exist") - ) + self.assertFalse(self.hook.table_exists("does-not-exist")) self.hook.metastore.__enter__().get_table.assert_called_with( - dbname='default', tbl_name='does-not-exist') + dbname='default', tbl_name='does-not-exist' + ) @mock.patch('airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook.table_exists') @mock.patch('airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook.get_metastore_client') @@ -564,12 +546,10 @@ def test_drop_partition(self, get_metastore_client_mock, table_exist_mock): table_exist_mock.return_value = True ret = self.hook.drop_partitions(self.table, db=self.database, part_vals=[DEFAULT_DATE_DS]) table_exist_mock.assert_called_once_with(self.table, self.database) - assert metastore_mock.drop_partition( - self.table, db=self.database, part_vals=[DEFAULT_DATE_DS]), ret + assert metastore_mock.drop_partition(self.table, db=self.database, part_vals=[DEFAULT_DATE_DS]), ret class TestHiveServer2Hook(unittest.TestCase): - def _upload_dataframe(self): df = pd.DataFrame({'a': [1, 2], 'b': [1, 2]}) self.local_path = '/tmp/TestHiveServer2Hook.csv' @@ -594,11 +574,11 @@ def setUp(self): LOAD DATA LOCAL INPATH '{{ params.csv_path }}' OVERWRITE INTO TABLE {{ params.table }}; """ - self.columns = ['{}.a'.format(self.table), - '{}.b'.format(self.table)] + self.columns = ['{}.a'.format(self.table), '{}.b'.format(self.table)] - with mock.patch('airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook.get_metastore_client' - ) as get_metastore_mock: + with mock.patch( + 'airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook.get_metastore_client' + ) as get_metastore_mock: get_metastore_mock.return_value = mock.MagicMock() self.hook = HiveMetastoreHook() @@ -614,7 +594,7 @@ def test_get_conn_with_password(self, mock_connect): with mock.patch.dict( 'os.environ', - {conn_env: "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP"} + {conn_env: "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP"}, ): HiveServer2Hook(hiveserver2_conn_id=conn_id).get_conn() mock_connect.assert_called_once_with( @@ -624,68 +604,63 @@ def test_get_conn_with_password(self, mock_connect): kerberos_service_name=None, username='conn_id', password='conn_pass', - database='default') + database='default', + ) def test_get_records(self): hook = MockHiveServer2Hook() query = "SELECT * FROM {}".format(self.table) - with mock.patch.dict('os.environ', { - 'AIRFLOW_CTX_DAG_ID': 'test_dag_id', - 'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835', - 'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00', - 'AIRFLOW_CTX_DAG_RUN_ID': '55', - 'AIRFLOW_CTX_DAG_OWNER': 'airflow', - 'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com', - }): + with mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CTX_DAG_ID': 'test_dag_id', + 'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835', + 'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00', + 'AIRFLOW_CTX_DAG_RUN_ID': '55', + 'AIRFLOW_CTX_DAG_OWNER': 'airflow', + 'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com', + }, + ): results = hook.get_records(query, schema=self.database) self.assertListEqual(results, [(1, 1), (2, 2)]) hook.get_conn.assert_called_with(self.database) - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_id=test_dag_id') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.task_id=HiveHook_3835') - hook.mock_cursor.execute.assert_any_call( - 'set 
airflow.ctx.execution_date=2015-01-01T00:00:00+00:00') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_run_id=55') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_owner=airflow') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_email=test@airflow.com') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_id=test_dag_id') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.task_id=HiveHook_3835') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.execution_date=2015-01-01T00:00:00+00:00') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_run_id=55') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_owner=airflow') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_email=test@airflow.com') def test_get_pandas_df(self): hook = MockHiveServer2Hook() query = "SELECT * FROM {}".format(self.table) - with mock.patch.dict('os.environ', { - 'AIRFLOW_CTX_DAG_ID': 'test_dag_id', - 'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835', - 'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00', - 'AIRFLOW_CTX_DAG_RUN_ID': '55', - 'AIRFLOW_CTX_DAG_OWNER': 'airflow', - 'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com', - }): + with mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CTX_DAG_ID': 'test_dag_id', + 'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835', + 'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00', + 'AIRFLOW_CTX_DAG_RUN_ID': '55', + 'AIRFLOW_CTX_DAG_OWNER': 'airflow', + 'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com', + }, + ): df = hook.get_pandas_df(query, schema=self.database) self.assertEqual(len(df), 2) self.assertListEqual(df["hive_server_hook.a"].values.tolist(), [1, 2]) hook.get_conn.assert_called_with(self.database) - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_id=test_dag_id') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.task_id=HiveHook_3835') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.execution_date=2015-01-01T00:00:00+00:00') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_run_id=55') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_owner=airflow') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_email=test@airflow.com') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_id=test_dag_id') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.task_id=HiveHook_3835') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.execution_date=2015-01-01T00:00:00+00:00') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_run_id=55') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_owner=airflow') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_email=test@airflow.com') def test_get_results_header(self): hook = MockHiveServer2Hook() @@ -693,8 +668,7 @@ def test_get_results_header(self): query = "SELECT * FROM {}".format(self.table) results = hook.get_results(query, schema=self.database) - self.assertListEqual([col[0] for col in results['header']], - self.columns) + self.assertListEqual([col[0] for col in results['header']], self.columns) def test_get_results_data(self): hook = MockHiveServer2Hook() @@ -706,16 +680,29 @@ def test_get_results_data(self): def test_to_csv(self): hook = MockHiveServer2Hook() - hook._get_results = mock.MagicMock(return_value=iter([ - [ - ('hive_server_hook.a', 'INT_TYPE', None, None, None, None, True), - ('hive_server_hook.b', 'INT_TYPE', None, None, None, None, True) - ], (1, 1), (2, 2) - 
])) + hook._get_results = mock.MagicMock( + return_value=iter( + [ + [ + ('hive_server_hook.a', 'INT_TYPE', None, None, None, None, True), + ('hive_server_hook.b', 'INT_TYPE', None, None, None, None, True), + ], + (1, 1), + (2, 2), + ] + ) + ) query = "SELECT * FROM {}".format(self.table) csv_filepath = 'query_results.csv' - hook.to_csv(query, csv_filepath, schema=self.database, - delimiter=',', lineterminator='\n', output_header=True, fetch_size=2) + hook.to_csv( + query, + csv_filepath, + schema=self.database, + delimiter=',', + lineterminator='\n', + output_header=True, + fetch_size=2, + ) df = pd.read_csv(csv_filepath, sep=',') self.assertListEqual(df.columns.tolist(), self.columns) self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2]) @@ -730,14 +717,17 @@ def test_multi_statements(self): hook = MockHiveServer2Hook() - with mock.patch.dict('os.environ', { - 'AIRFLOW_CTX_DAG_ID': 'test_dag_id', - 'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835', - 'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00', - 'AIRFLOW_CTX_DAG_RUN_ID': '55', - 'AIRFLOW_CTX_DAG_OWNER': 'airflow', - 'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com', - }): + with mock.patch.dict( + 'os.environ', + { + 'AIRFLOW_CTX_DAG_ID': 'test_dag_id', + 'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835', + 'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00', + 'AIRFLOW_CTX_DAG_RUN_ID': '55', + 'AIRFLOW_CTX_DAG_OWNER': 'airflow', + 'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com', + }, + ): # df = hook.get_pandas_df(query, schema=self.database) results = hook.get_records(sqls, schema=self.database) self.assertListEqual(results, [(1, 1), (2, 2)]) @@ -746,58 +736,60 @@ def test_multi_statements(self): # self.assertListEqual(df["hive_server_hook.a"].values.tolist(), [1, 2]) hook.get_conn.assert_called_with(self.database) - hook.mock_cursor.execute.assert_any_call( - 'CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)') - hook.mock_cursor.execute.assert_any_call( - 'SELECT * FROM {}'.format(self.table)) - hook.mock_cursor.execute.assert_any_call( - 'DROP TABLE test_multi_statements') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_id=test_dag_id') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.task_id=HiveHook_3835') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.execution_date=2015-01-01T00:00:00+00:00') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_run_id=55') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_owner=airflow') - hook.mock_cursor.execute.assert_any_call( - 'set airflow.ctx.dag_email=test@airflow.com') + hook.mock_cursor.execute.assert_any_call('CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)') + hook.mock_cursor.execute.assert_any_call('SELECT * FROM {}'.format(self.table)) + hook.mock_cursor.execute.assert_any_call('DROP TABLE test_multi_statements') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_id=test_dag_id') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.task_id=HiveHook_3835') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.execution_date=2015-01-01T00:00:00+00:00') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_run_id=55') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_owner=airflow') + hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_email=test@airflow.com') def test_get_results_with_hive_conf(self): - hql = ["set key", - "set airflow.ctx.dag_id", - "set airflow.ctx.dag_run_id", - "set airflow.ctx.task_id", - "set 
airflow.ctx.execution_date"] - - dag_id_ctx_var_name = \ - AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format'] - task_id_ctx_var_name = \ - AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format'] - execution_date_ctx_var_name = \ - AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][ - 'env_var_format'] - dag_run_id_ctx_var_name = \ - AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][ - 'env_var_format'] - - with mock.patch.dict('os.environ', { - dag_id_ctx_var_name: 'test_dag_id', - task_id_ctx_var_name: 'test_task_id', - execution_date_ctx_var_name: 'test_execution_date', - dag_run_id_ctx_var_name: 'test_dag_run_id', - - }): - hook = MockHiveServer2Hook() - hook._get_results = mock.MagicMock(return_value=iter( - ["header", ("value", "test"), ("test_dag_id", "test"), ("test_task_id", "test"), - ("test_execution_date", "test"), ("test_dag_run_id", "test")] - )) + hql = [ + "set key", + "set airflow.ctx.dag_id", + "set airflow.ctx.dag_run_id", + "set airflow.ctx.task_id", + "set airflow.ctx.execution_date", + ] - output = '\n'.join(res_tuple[0] for res_tuple in hook.get_results( - hql=hql, hive_conf={'key': 'value'})['data']) + dag_id_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format'] + task_id_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format'] + execution_date_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][ + 'env_var_format' + ] + dag_run_id_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][ + 'env_var_format' + ] + + with mock.patch.dict( + 'os.environ', + { + dag_id_ctx_var_name: 'test_dag_id', + task_id_ctx_var_name: 'test_task_id', + execution_date_ctx_var_name: 'test_execution_date', + dag_run_id_ctx_var_name: 'test_dag_run_id', + }, + ): + hook = MockHiveServer2Hook() + hook._get_results = mock.MagicMock( + return_value=iter( + [ + "header", + ("value", "test"), + ("test_dag_id", "test"), + ("test_task_id", "test"), + ("test_execution_date", "test"), + ("test_dag_run_id", "test"), + ] + ) + ) + + output = '\n'.join( + res_tuple[0] for res_tuple in hook.get_results(hql=hql, hive_conf={'key': 'value'})['data'] + ) self.assertIn('value', output) self.assertIn('test_dag_id', output) self.assertIn('test_task_id', output) @@ -806,7 +798,6 @@ def test_get_results_with_hive_conf(self): class TestHiveCli(unittest.TestCase): - def setUp(self): self.nondefault_schema = "nondefault" os.environ["AIRFLOW__CORE__SECURITY"] = "kerberos" diff --git a/tests/providers/apache/hive/operators/test_hive.py b/tests/providers/apache/hive/operators/test_hive.py index 63c04b2beb00d..6541c8561b18f 100644 --- a/tests/providers/apache/hive/operators/test_hive.py +++ b/tests/providers/apache/hive/operators/test_hive.py @@ -30,22 +30,18 @@ class HiveOperatorConfigTest(TestHiveEnvironment): - def test_hive_airflow_default_config_queue(self): op = MockHiveOperator( task_id='test_default_config_queue', hql=self.hql, mapred_queue_priority='HIGH', mapred_job_name='airflow.test_default_config_queue', - dag=self.dag) + dag=self.dag, + ) # just check that the correct default value in test_default.cfg is used - test_config_hive_mapred_queue = conf.get( - 'hive', - 'default_hive_mapred_queue' - ) - self.assertEqual(op.get_hook().mapred_queue, - test_config_hive_mapred_queue) + test_config_hive_mapred_queue = conf.get('hive', 'default_hive_mapred_queue') + self.assertEqual(op.get_hook().mapred_queue, 
test_config_hive_mapred_queue) def test_hive_airflow_default_config_queue_override(self): specific_mapred_queue = 'default' @@ -55,18 +51,18 @@ def test_hive_airflow_default_config_queue_override(self): mapred_queue=specific_mapred_queue, mapred_queue_priority='HIGH', mapred_job_name='airflow.test_default_config_queue', - dag=self.dag) + dag=self.dag, + ) self.assertEqual(op.get_hook().mapred_queue, specific_mapred_queue) class HiveOperatorTest(TestHiveEnvironment): - def test_hiveconf_jinja_translate(self): hql = "SELECT ${num_col} FROM ${hiveconf:table};" op = MockHiveOperator( - hiveconf_jinja_translate=True, - task_id='dry_run_basic_hql', hql=hql, dag=self.dag) + hiveconf_jinja_translate=True, task_id='dry_run_basic_hql', hql=hql, dag=self.dag + ) op.prepare_template() self.assertEqual(op.hql, "SELECT {{ num_col }} FROM {{ table }};") @@ -74,20 +70,18 @@ def test_hiveconf(self): hql = "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});" op = MockHiveOperator( hiveconfs={'table': 'static_babynames', 'day': '{{ ds }}'}, - task_id='dry_run_basic_hql', hql=hql, dag=self.dag) + task_id='dry_run_basic_hql', + hql=hql, + dag=self.dag, + ) op.prepare_template() - self.assertEqual( - op.hql, - "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});") + self.assertEqual(op.hql, "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});") @mock.patch('airflow.providers.apache.hive.operators.hive.HiveOperator.get_hook') def test_mapred_job_name(self, mock_get_hook): mock_hook = mock.MagicMock() mock_get_hook.return_value = mock_hook - op = MockHiveOperator( - task_id='test_mapred_job_name', - hql=self.hql, - dag=self.dag) + op = MockHiveOperator(task_id='test_mapred_job_name', hql=self.hql, dag=self.dag) fake_execution_date = timezone.datetime(2018, 6, 19) fake_ti = TaskInstance(task=op, execution_date=fake_execution_date) @@ -96,17 +90,15 @@ def test_mapred_job_name(self, mock_get_hook): op.execute(fake_context) self.assertEqual( - "Airflow HiveOperator task for {}.{}.{}.{}" - .format(fake_ti.hostname, - self.dag.dag_id, op.task_id, - fake_execution_date.isoformat()), mock_hook.mapred_job_name) + "Airflow HiveOperator task for {}.{}.{}.{}".format( + fake_ti.hostname, self.dag.dag_id, op.task_id, fake_execution_date.isoformat() + ), + mock_hook.mapred_job_name, + ) -@unittest.skipIf( - 'AIRFLOW_RUNALL_TESTS' not in os.environ, - "Skipped because AIRFLOW_RUNALL_TESTS is not set") +@unittest.skipIf('AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set") class TestHivePresto(TestHiveEnvironment): - @mock.patch('tempfile.tempdir', '/tmp/') @mock.patch('tempfile._RandomNameSequence.__next__') @mock.patch('subprocess.Popen') @@ -114,25 +106,43 @@ def test_hive(self, mock_popen, mock_temp_dir): mock_subprocess = MockSubProcess() mock_popen.return_value = mock_subprocess mock_temp_dir.return_value = "tst" - op = HiveOperator( - task_id='basic_hql', hql=self.hql, dag=self.dag, mapred_job_name="test_job_name") - - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) - hive_cmd = ['beeline', '-u', '"jdbc:hive2://localhost:10000/default"', '-hiveconf', - 'airflow.ctx.dag_id=test_dag_id', '-hiveconf', 'airflow.ctx.task_id=basic_hql', - '-hiveconf', 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', '-hiveconf', - 'airflow.ctx.dag_run_id=', '-hiveconf', 'airflow.ctx.dag_owner=airflow', '-hiveconf', - 'airflow.ctx.dag_email=', '-hiveconf', 'mapreduce.job.queuename=airflow', '-hiveconf', - 'mapred.job.queue.name=airflow', 
'-hiveconf', 'tez.queue.name=airflow', '-hiveconf', - 'mapred.job.name=test_job_name', '-f', '/tmp/airflow_hiveop_tst/tmptst'] + op = HiveOperator(task_id='basic_hql', hql=self.hql, dag=self.dag, mapred_job_name="test_job_name") + + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + hive_cmd = [ + 'beeline', + '-u', + '"jdbc:hive2://localhost:10000/default"', + '-hiveconf', + 'airflow.ctx.dag_id=test_dag_id', + '-hiveconf', + 'airflow.ctx.task_id=basic_hql', + '-hiveconf', + 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', + '-hiveconf', + 'airflow.ctx.dag_run_id=', + '-hiveconf', + 'airflow.ctx.dag_owner=airflow', + '-hiveconf', + 'airflow.ctx.dag_email=', + '-hiveconf', + 'mapreduce.job.queuename=airflow', + '-hiveconf', + 'mapred.job.queue.name=airflow', + '-hiveconf', + 'tez.queue.name=airflow', + '-hiveconf', + 'mapred.job.name=test_job_name', + '-f', + '/tmp/airflow_hiveop_tst/tmptst', + ] mock_popen.assert_called_with( hive_cmd, stdout=mock_subprocess.PIPE, stderr=mock_subprocess.STDOUT, cwd="/tmp/airflow_hiveop_tst", - close_fds=True + close_fds=True, ) @mock.patch('tempfile.tempdir', '/tmp/') @@ -143,28 +153,51 @@ def test_hive_queues(self, mock_popen, mock_temp_dir): mock_popen.return_value = mock_subprocess mock_temp_dir.return_value = "tst" - hive_cmd = ['beeline', '-u', '"jdbc:hive2://localhost:10000/default"', '-hiveconf', - 'airflow.ctx.dag_id=test_dag_id', '-hiveconf', 'airflow.ctx.task_id=test_hive_queues', - '-hiveconf', 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', '-hiveconf', - 'airflow.ctx.dag_run_id=', '-hiveconf', 'airflow.ctx.dag_owner=airflow', - '-hiveconf', 'airflow.ctx.dag_email=', '-hiveconf', 'mapreduce.job.queuename=default', - '-hiveconf', 'mapred.job.queue.name=default', '-hiveconf', 'tez.queue.name=default', - '-hiveconf', 'mapreduce.job.priority=HIGH', '-hiveconf', - 'mapred.job.name=airflow.test_hive_queues', '-f', '/tmp/airflow_hiveop_tst/tmptst'] + hive_cmd = [ + 'beeline', + '-u', + '"jdbc:hive2://localhost:10000/default"', + '-hiveconf', + 'airflow.ctx.dag_id=test_dag_id', + '-hiveconf', + 'airflow.ctx.task_id=test_hive_queues', + '-hiveconf', + 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', + '-hiveconf', + 'airflow.ctx.dag_run_id=', + '-hiveconf', + 'airflow.ctx.dag_owner=airflow', + '-hiveconf', + 'airflow.ctx.dag_email=', + '-hiveconf', + 'mapreduce.job.queuename=default', + '-hiveconf', + 'mapred.job.queue.name=default', + '-hiveconf', + 'tez.queue.name=default', + '-hiveconf', + 'mapreduce.job.priority=HIGH', + '-hiveconf', + 'mapred.job.name=airflow.test_hive_queues', + '-f', + '/tmp/airflow_hiveop_tst/tmptst', + ] op = HiveOperator( - task_id='test_hive_queues', hql=self.hql, - mapred_queue='default', mapred_queue_priority='HIGH', + task_id='test_hive_queues', + hql=self.hql, + mapred_queue='default', + mapred_queue_priority='HIGH', mapred_job_name='airflow.test_hive_queues', - dag=self.dag) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) mock_popen.assert_called_with( hive_cmd, stdout=mock_subprocess.PIPE, stderr=mock_subprocess.STDOUT, cwd="/tmp/airflow_hiveop_tst", - close_fds=True + close_fds=True, ) @mock.patch('tempfile.tempdir', '/tmp/') @@ -175,23 +208,40 @@ def test_hive_dryrun(self, mock_popen, mock_temp_dir): mock_popen.return_value = mock_subprocess mock_temp_dir.return_value = "tst" - op = HiveOperator(task_id='dry_run_basic_hql', - hql=self.hql, 
dag=self.dag) + op = HiveOperator(task_id='dry_run_basic_hql', hql=self.hql, dag=self.dag) op.dry_run() - hive_cmd = ['beeline', '-u', '"jdbc:hive2://localhost:10000/default"', - '-hiveconf', 'airflow.ctx.dag_id=', '-hiveconf', 'airflow.ctx.task_id=', - '-hiveconf', 'airflow.ctx.execution_date=', '-hiveconf', 'airflow.ctx.dag_run_id=', - '-hiveconf', 'airflow.ctx.dag_owner=', '-hiveconf', 'airflow.ctx.dag_email=', - '-hiveconf', 'mapreduce.job.queuename=airflow', '-hiveconf', - 'mapred.job.queue.name=airflow', '-hiveconf', 'tez.queue.name=airflow', - '-f', '/tmp/airflow_hiveop_tst/tmptst'] + hive_cmd = [ + 'beeline', + '-u', + '"jdbc:hive2://localhost:10000/default"', + '-hiveconf', + 'airflow.ctx.dag_id=', + '-hiveconf', + 'airflow.ctx.task_id=', + '-hiveconf', + 'airflow.ctx.execution_date=', + '-hiveconf', + 'airflow.ctx.dag_run_id=', + '-hiveconf', + 'airflow.ctx.dag_owner=', + '-hiveconf', + 'airflow.ctx.dag_email=', + '-hiveconf', + 'mapreduce.job.queuename=airflow', + '-hiveconf', + 'mapred.job.queue.name=airflow', + '-hiveconf', + 'tez.queue.name=airflow', + '-f', + '/tmp/airflow_hiveop_tst/tmptst', + ] mock_popen.assert_called_with( hive_cmd, stdout=mock_subprocess.PIPE, stderr=mock_subprocess.STDOUT, cwd="/tmp/airflow_hiveop_tst", - close_fds=True + close_fds=True, ) @mock.patch('tempfile.tempdir', '/tmp/') @@ -202,25 +252,46 @@ def test_beeline(self, mock_popen, mock_temp_dir): mock_popen.return_value = mock_subprocess mock_temp_dir.return_value = "tst" - hive_cmd = ['beeline', '-u', '"jdbc:hive2://localhost:10000/default"', - '-hiveconf', 'airflow.ctx.dag_id=test_dag_id', '-hiveconf', - 'airflow.ctx.task_id=beeline_hql', '-hiveconf', - 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', - '-hiveconf', 'airflow.ctx.dag_run_id=', '-hiveconf', 'airflow.ctx.dag_owner=airflow', - '-hiveconf', 'airflow.ctx.dag_email=', - '-hiveconf', 'mapreduce.job.queuename=airflow', '-hiveconf', - 'mapred.job.queue.name=airflow', '-hiveconf', 'tez.queue.name=airflow', '-hiveconf', - 'mapred.job.name=test_job_name', '-f', '/tmp/airflow_hiveop_tst/tmptst'] + hive_cmd = [ + 'beeline', + '-u', + '"jdbc:hive2://localhost:10000/default"', + '-hiveconf', + 'airflow.ctx.dag_id=test_dag_id', + '-hiveconf', + 'airflow.ctx.task_id=beeline_hql', + '-hiveconf', + 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', + '-hiveconf', + 'airflow.ctx.dag_run_id=', + '-hiveconf', + 'airflow.ctx.dag_owner=airflow', + '-hiveconf', + 'airflow.ctx.dag_email=', + '-hiveconf', + 'mapreduce.job.queuename=airflow', + '-hiveconf', + 'mapred.job.queue.name=airflow', + '-hiveconf', + 'tez.queue.name=airflow', + '-hiveconf', + 'mapred.job.name=test_job_name', + '-f', + '/tmp/airflow_hiveop_tst/tmptst', + ] op = HiveOperator( - task_id='beeline_hql', hive_cli_conn_id='hive_cli_default', - hql=self.hql, dag=self.dag, mapred_job_name="test_job_name") - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + task_id='beeline_hql', + hive_cli_conn_id='hive_cli_default', + hql=self.hql, + dag=self.dag, + mapred_job_name="test_job_name", + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) mock_popen.assert_called_with( hive_cmd, stdout=mock_subprocess.PIPE, stderr=mock_subprocess.STDOUT, cwd="/tmp/airflow_hiveop_tst", - close_fds=True + close_fds=True, ) diff --git a/tests/providers/apache/hive/operators/test_hive_stats.py b/tests/providers/apache/hive/operators/test_hive_stats.py index 5c2b33717bbe8..46e739ebf2669 100644 --- 
a/tests/providers/apache/hive/operators/test_hive_stats.py +++ b/tests/providers/apache/hive/operators/test_hive_stats.py @@ -38,7 +38,6 @@ def __init__(self, col_name, col_type): class TestHiveStatsCollectionOperator(TestHiveEnvironment): - def setUp(self): self.kwargs = dict( table='table', @@ -55,9 +54,7 @@ def test_get_default_exprs(self): default_exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(col, None) - self.assertEqual(default_exprs, { - (col, 'non_null'): 'COUNT({})'.format(col) - }) + self.assertEqual(default_exprs, {(col, 'non_null'): 'COUNT({})'.format(col)}) def test_get_default_exprs_excluded_cols(self): col = 'excluded_col' @@ -72,13 +69,16 @@ def test_get_default_exprs_number(self): for col_type in ['double', 'int', 'bigint', 'float']: default_exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(col, col_type) - self.assertEqual(default_exprs, { - (col, 'avg'): 'AVG({})'.format(col), - (col, 'max'): 'MAX({})'.format(col), - (col, 'min'): 'MIN({})'.format(col), - (col, 'non_null'): 'COUNT({})'.format(col), - (col, 'sum'): 'SUM({})'.format(col) - }) + self.assertEqual( + default_exprs, + { + (col, 'avg'): 'AVG({})'.format(col), + (col, 'max'): 'MAX({})'.format(col), + (col, 'min'): 'MIN({})'.format(col), + (col, 'non_null'): 'COUNT({})'.format(col), + (col, 'sum'): 'SUM({})'.format(col), + }, + ) def test_get_default_exprs_boolean(self): col = 'col' @@ -86,11 +86,14 @@ def test_get_default_exprs_boolean(self): default_exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(col, col_type) - self.assertEqual(default_exprs, { - (col, 'false'): 'SUM(CASE WHEN NOT {} THEN 1 ELSE 0 END)'.format(col), - (col, 'non_null'): 'COUNT({})'.format(col), - (col, 'true'): 'SUM(CASE WHEN {} THEN 1 ELSE 0 END)'.format(col) - }) + self.assertEqual( + default_exprs, + { + (col, 'false'): 'SUM(CASE WHEN NOT {} THEN 1 ELSE 0 END)'.format(col), + (col, 'non_null'): 'COUNT({})'.format(col), + (col, 'true'): 'SUM(CASE WHEN {} THEN 1 ELSE 0 END)'.format(col), + }, + ) def test_get_default_exprs_string(self): col = 'col' @@ -98,11 +101,14 @@ def test_get_default_exprs_string(self): default_exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(col, col_type) - self.assertEqual(default_exprs, { - (col, 'approx_distinct'): 'APPROX_DISTINCT({})'.format(col), - (col, 'len'): 'SUM(CAST(LENGTH({}) AS BIGINT))'.format(col), - (col, 'non_null'): 'COUNT({})'.format(col) - }) + self.assertEqual( + default_exprs, + { + (col, 'approx_distinct'): 'APPROX_DISTINCT({})'.format(col), + (col, 'len'): 'SUM(CAST(LENGTH({}) AS BIGINT))'.format(col), + (col, 'non_null'): 'COUNT({})'.format(col), + }, + ) @patch('airflow.providers.apache.hive.operators.hive_stats.json.dumps') @patch('airflow.providers.apache.hive.operators.hive_stats.MySqlHook') @@ -116,54 +122,46 @@ def test_execute(self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_ho hive_stats_collection_operator.execute(context={}) mock_hive_metastore_hook.assert_called_once_with( - metastore_conn_id=hive_stats_collection_operator.metastore_conn_id) + metastore_conn_id=hive_stats_collection_operator.metastore_conn_id + ) mock_hive_metastore_hook.return_value.get_table.assert_called_once_with( - table_name=hive_stats_collection_operator.table) + table_name=hive_stats_collection_operator.table + ) mock_presto_hook.assert_called_once_with(presto_conn_id=hive_stats_collection_operator.presto_conn_id) mock_mysql_hook.assert_called_once_with(hive_stats_collection_operator.mysql_conn_id) 
mock_json_dumps.assert_called_once_with(hive_stats_collection_operator.partition, sort_keys=True) field_types = { col.name: col.type for col in mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols } - exprs = { - ('', 'count'): 'COUNT(*)' - } + exprs = {('', 'count'): 'COUNT(*)'} for col, col_type in list(field_types.items()): exprs.update(hive_stats_collection_operator.get_default_exprs(col, col_type)) exprs = OrderedDict(exprs) - rows = [(hive_stats_collection_operator.ds, - hive_stats_collection_operator.dttm, - hive_stats_collection_operator.table, - mock_json_dumps.return_value) + - (r[0][0], r[0][1], r[1]) - for r in zip(exprs, mock_presto_hook.return_value.get_first.return_value)] + rows = [ + ( + hive_stats_collection_operator.ds, + hive_stats_collection_operator.dttm, + hive_stats_collection_operator.table, + mock_json_dumps.return_value, + ) + + (r[0][0], r[0][1], r[1]) + for r in zip(exprs, mock_presto_hook.return_value.get_first.return_value) + ] mock_mysql_hook.return_value.insert_rows.assert_called_once_with( table='hive_stats', rows=rows, - target_fields=[ - 'ds', - 'dttm', - 'table_name', - 'partition_repr', - 'col', - 'metric', - 'value', - ] + target_fields=['ds', 'dttm', 'table_name', 'partition_repr', 'col', 'metric', 'value',], ) @patch('airflow.providers.apache.hive.operators.hive_stats.json.dumps') @patch('airflow.providers.apache.hive.operators.hive_stats.MySqlHook') @patch('airflow.providers.apache.hive.operators.hive_stats.PrestoHook') @patch('airflow.providers.apache.hive.operators.hive_stats.HiveMetastoreHook') - def test_execute_with_assignment_func(self, - mock_hive_metastore_hook, - mock_presto_hook, - mock_mysql_hook, - mock_json_dumps): + def test_execute_with_assignment_func( + self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook, mock_json_dumps + ): def assignment_func(col, _): - return { - (col, 'test'): 'TEST({})'.format(col) - } + return {(col, 'test'): 'TEST({})'.format(col)} self.kwargs.update(dict(assignment_func=assignment_func)) mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col] @@ -175,41 +173,33 @@ def assignment_func(col, _): field_types = { col.name: col.type for col in mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols } - exprs = { - ('', 'count'): 'COUNT(*)' - } + exprs = {('', 'count'): 'COUNT(*)'} for col, col_type in list(field_types.items()): exprs.update(hive_stats_collection_operator.assignment_func(col, col_type)) exprs = OrderedDict(exprs) - rows = [(hive_stats_collection_operator.ds, - hive_stats_collection_operator.dttm, - hive_stats_collection_operator.table, - mock_json_dumps.return_value) + - (r[0][0], r[0][1], r[1]) - for r in zip(exprs, mock_presto_hook.return_value.get_first.return_value)] + rows = [ + ( + hive_stats_collection_operator.ds, + hive_stats_collection_operator.dttm, + hive_stats_collection_operator.table, + mock_json_dumps.return_value, + ) + + (r[0][0], r[0][1], r[1]) + for r in zip(exprs, mock_presto_hook.return_value.get_first.return_value) + ] mock_mysql_hook.return_value.insert_rows.assert_called_once_with( table='hive_stats', rows=rows, - target_fields=[ - 'ds', - 'dttm', - 'table_name', - 'partition_repr', - 'col', - 'metric', - 'value', - ] + target_fields=['ds', 'dttm', 'table_name', 'partition_repr', 'col', 'metric', 'value',], ) @patch('airflow.providers.apache.hive.operators.hive_stats.json.dumps') @patch('airflow.providers.apache.hive.operators.hive_stats.MySqlHook') 
@patch('airflow.providers.apache.hive.operators.hive_stats.PrestoHook') @patch('airflow.providers.apache.hive.operators.hive_stats.HiveMetastoreHook') - def test_execute_with_assignment_func_no_return_value(self, - mock_hive_metastore_hook, - mock_presto_hook, - mock_mysql_hook, - mock_json_dumps): + def test_execute_with_assignment_func_no_return_value( + self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook, mock_json_dumps + ): def assignment_func(_, __): pass @@ -223,30 +213,24 @@ def assignment_func(_, __): field_types = { col.name: col.type for col in mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols } - exprs = { - ('', 'count'): 'COUNT(*)' - } + exprs = {('', 'count'): 'COUNT(*)'} for col, col_type in list(field_types.items()): exprs.update(hive_stats_collection_operator.get_default_exprs(col, col_type)) exprs = OrderedDict(exprs) - rows = [(hive_stats_collection_operator.ds, - hive_stats_collection_operator.dttm, - hive_stats_collection_operator.table, - mock_json_dumps.return_value) + - (r[0][0], r[0][1], r[1]) - for r in zip(exprs, mock_presto_hook.return_value.get_first.return_value)] + rows = [ + ( + hive_stats_collection_operator.ds, + hive_stats_collection_operator.dttm, + hive_stats_collection_operator.table, + mock_json_dumps.return_value, + ) + + (r[0][0], r[0][1], r[1]) + for r in zip(exprs, mock_presto_hook.return_value.get_first.return_value) + ] mock_mysql_hook.return_value.insert_rows.assert_called_once_with( table='hive_stats', rows=rows, - target_fields=[ - 'ds', - 'dttm', - 'table_name', - 'partition_repr', - 'col', - 'metric', - 'value', - ] + target_fields=['ds', 'dttm', 'table_name', 'partition_repr', 'col', 'metric', 'value',], ) @patch('airflow.providers.apache.hive.operators.hive_stats.MySqlHook') @@ -263,11 +247,9 @@ def test_execute_no_query_results(self, mock_hive_metastore_hook, mock_presto_ho @patch('airflow.providers.apache.hive.operators.hive_stats.MySqlHook') @patch('airflow.providers.apache.hive.operators.hive_stats.PrestoHook') @patch('airflow.providers.apache.hive.operators.hive_stats.HiveMetastoreHook') - def test_execute_delete_previous_runs_rows(self, - mock_hive_metastore_hook, - mock_presto_hook, - mock_mysql_hook, - mock_json_dumps): + def test_execute_delete_previous_runs_rows( + self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook, mock_json_dumps + ): mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col] mock_mysql_hook.return_value.get_records.return_value = True @@ -283,57 +265,67 @@ def test_execute_delete_previous_runs_rows(self, """.format( hive_stats_collection_operator.table, mock_json_dumps.return_value, - hive_stats_collection_operator.dttm + hive_stats_collection_operator.dttm, ) mock_mysql_hook.return_value.run.assert_called_once_with(sql) @unittest.skipIf( - 'AIRFLOW_RUNALL_TESTS' not in os.environ, - "Skipped because AIRFLOW_RUNALL_TESTS is not set") - @patch('airflow.providers.apache.hive.operators.hive_stats.HiveMetastoreHook', - side_effect=MockHiveMetastoreHook) + 'AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set" + ) + @patch( + 'airflow.providers.apache.hive.operators.hive_stats.HiveMetastoreHook', + side_effect=MockHiveMetastoreHook, + ) def test_runs_for_hive_stats(self, mock_hive_metastore_hook): mock_mysql_hook = MockMySqlHook() mock_presto_hook = MockPrestoHook() - with patch('airflow.providers.apache.hive.operators.hive_stats.PrestoHook', - return_value=mock_presto_hook): - with 
patch('airflow.providers.apache.hive.operators.hive_stats.MySqlHook', - return_value=mock_mysql_hook): + with patch( + 'airflow.providers.apache.hive.operators.hive_stats.PrestoHook', return_value=mock_presto_hook + ): + with patch( + 'airflow.providers.apache.hive.operators.hive_stats.MySqlHook', return_value=mock_mysql_hook + ): op = HiveStatsCollectionOperator( task_id='hive_stats_check', table="airflow.static_babynames_partitioned", partition={'ds': DEFAULT_DATE_DS}, - dag=self.dag) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - select_count_query = "SELECT COUNT(*) AS __count FROM airflow." \ + select_count_query = ( + "SELECT COUNT(*) AS __count FROM airflow." + "static_babynames_partitioned WHERE ds = '2015-01-01';" + ) mock_presto_hook.get_first.assert_called_with(hql=select_count_query) - expected_stats_select_query = "SELECT 1 FROM hive_stats WHERE table_name='airflow." \ - + "static_babynames_partitioned' AND " \ - + "partition_repr='{\"ds\": \"2015-01-01\"}' AND " \ - + "dttm='2015-01-01T00:00:00+00:00' " \ + expected_stats_select_query = ( + "SELECT 1 FROM hive_stats WHERE table_name='airflow." + + "static_babynames_partitioned' AND " + + "partition_repr='{\"ds\": \"2015-01-01\"}' AND " + + "dttm='2015-01-01T00:00:00+00:00' " + "LIMIT 1;" + ) raw_stats_select_query = mock_mysql_hook.get_records.call_args_list[0][0][0] actual_stats_select_query = re.sub(r'\s{2,}', ' ', raw_stats_select_query).strip() self.assertEqual(expected_stats_select_query, actual_stats_select_query) - insert_rows_val = [('2015-01-01', '2015-01-01T00:00:00+00:00', - 'airflow.static_babynames_partitioned', - '{"ds": "2015-01-01"}', '', 'count', ['val_0', 'val_1'])] - - mock_mysql_hook.insert_rows.assert_called_with(table='hive_stats', - rows=insert_rows_val, - target_fields=[ - 'ds', - 'dttm', - 'table_name', - 'partition_repr', - 'col', - 'metric', - 'value', - ]) + insert_rows_val = [ + ( + '2015-01-01', + '2015-01-01T00:00:00+00:00', + 'airflow.static_babynames_partitioned', + '{"ds": "2015-01-01"}', + '', + 'count', + ['val_0', 'val_1'], + ) + ] + + mock_mysql_hook.insert_rows.assert_called_with( + table='hive_stats', + rows=insert_rows_val, + target_fields=['ds', 'dttm', 'table_name', 'partition_repr', 'col', 'metric', 'value',], + ) diff --git a/tests/providers/apache/hive/sensors/test_hdfs.py b/tests/providers/apache/hive/sensors/test_hdfs.py index b658cdbeceafa..aa4fca0138ce2 100644 --- a/tests/providers/apache/hive/sensors/test_hdfs.py +++ b/tests/providers/apache/hive/sensors/test_hdfs.py @@ -23,15 +23,12 @@ from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment -@unittest.skipIf( - 'AIRFLOW_RUNALL_TESTS' not in os.environ, - "Skipped because AIRFLOW_RUNALL_TESTS is not set") +@unittest.skipIf('AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set") class TestHdfsSensor(TestHiveEnvironment): - def test_hdfs_sensor(self): op = HdfsSensor( task_id='hdfs_sensor_check', filepath='hdfs://user/hive/warehouse/airflow.db/static_babynames', - dag=self.dag) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/providers/apache/hive/sensors/test_hive_partition.py b/tests/providers/apache/hive/sensors/test_hive_partition.py index ccb551422b105..e992ec0e97c91 100644 --- 
a/tests/providers/apache/hive/sensors/test_hive_partition.py +++ b/tests/providers/apache/hive/sensors/test_hive_partition.py @@ -25,16 +25,14 @@ from tests.test_utils.mock_hooks import MockHiveMetastoreHook -@unittest.skipIf( - 'AIRFLOW_RUNALL_TESTS' not in os.environ, - "Skipped because AIRFLOW_RUNALL_TESTS is not set") -@patch('airflow.providers.apache.hive.sensors.hive_partition.HiveMetastoreHook', - side_effect=MockHiveMetastoreHook) +@unittest.skipIf('AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set") +@patch( + 'airflow.providers.apache.hive.sensors.hive_partition.HiveMetastoreHook', + side_effect=MockHiveMetastoreHook, +) class TestHivePartitionSensor(TestHiveEnvironment): def test_hive_partition_sensor(self, mock_hive_metastore_hook): op = HivePartitionSensor( - task_id='hive_partition_check', - table='airflow.static_babynames_partitioned', - dag=self.dag) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + task_id='hive_partition_check', table='airflow.static_babynames_partitioned', dag=self.dag + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/providers/apache/hive/sensors/test_metastore_partition.py b/tests/providers/apache/hive/sensors/test_metastore_partition.py index d7227adef001f..fa8b9eda50d6e 100644 --- a/tests/providers/apache/hive/sensors/test_metastore_partition.py +++ b/tests/providers/apache/hive/sensors/test_metastore_partition.py @@ -25,9 +25,7 @@ from tests.test_utils.mock_process import MockDBConnection -@unittest.skipIf( - 'AIRFLOW_RUNALL_TESTS' not in os.environ, - "Skipped because AIRFLOW_RUNALL_TESTS is not set") +@unittest.skipIf('AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set") class TestHivePartitionSensor(TestHiveEnvironment): def test_hive_metastore_sql_sensor(self): op = MetastorePartitionSensor( @@ -36,7 +34,7 @@ def test_hive_metastore_sql_sensor(self): sql='test_sql', table='airflow.static_babynames_partitioned', partition_name='ds={}'.format(DEFAULT_DATE_DS), - dag=self.dag) + dag=self.dag, + ) op._get_hook = mock.MagicMock(return_value=MockDBConnection({})) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/providers/apache/hive/sensors/test_named_hive_partition.py b/tests/providers/apache/hive/sensors/test_named_hive_partition.py index bbe507c2dfcb1..8a7ab7d319ba1 100644 --- a/tests/providers/apache/hive/sensors/test_named_hive_partition.py +++ b/tests/providers/apache/hive/sensors/test_named_hive_partition.py @@ -38,8 +38,7 @@ class TestNamedHivePartitionSensor(unittest.TestCase): def setUp(self): args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) - self.next_day = (DEFAULT_DATE + - timedelta(days=1)).isoformat()[:10] + self.next_day = (DEFAULT_DATE + timedelta(days=1)).isoformat()[:10] self.database = 'airflow' self.partition_by = 'ds' self.table = 'static_babynames_partitioned' @@ -60,26 +59,19 @@ def setUp(self): self.hook = MockHiveMetastoreHook() op = MockHiveOperator( task_id='HiveHook_' + str(random.randint(1, 10000)), - params={ - 'database': self.database, - 'table': self.table, - 'partition_by': self.partition_by - }, + params={'database': self.database, 'table': self.table, 'partition_by': self.partition_by}, hive_cli_conn_id='hive_cli_default', - hql=self.hql, dag=self.dag) - 
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + hql=self.hql, + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_parse_partition_name_correct(self): schema = 'default' table = 'users' partition = 'ds=2016-01-01/state=IT' - name = '{schema}.{table}/{partition}'.format(schema=schema, - table=table, - partition=partition) - parsed_schema, parsed_table, parsed_partition = ( - NamedHivePartitionSensor.parse_partition_name(name) - ) + name = '{schema}.{table}/{partition}'.format(schema=schema, table=table, partition=partition) + parsed_schema, parsed_table, parsed_partition = NamedHivePartitionSensor.parse_partition_name(name) self.assertEqual(schema, parsed_schema) self.assertEqual(table, parsed_table) self.assertEqual(partition, parsed_partition) @@ -92,95 +84,87 @@ def test_parse_partition_name_incorrect(self): def test_parse_partition_name_default(self): table = 'users' partition = 'ds=2016-01-01/state=IT' - name = '{table}/{partition}'.format(table=table, - partition=partition) - parsed_schema, parsed_table, parsed_partition = ( - NamedHivePartitionSensor.parse_partition_name(name) - ) + name = '{table}/{partition}'.format(table=table, partition=partition) + parsed_schema, parsed_table, parsed_partition = NamedHivePartitionSensor.parse_partition_name(name) self.assertEqual('default', parsed_schema) self.assertEqual(table, parsed_table) self.assertEqual(partition, parsed_partition) def test_poke_existing(self): self.hook.metastore.__enter__().check_for_named_partition.return_value = True - partitions = ["{}.{}/{}={}".format(self.database, - self.table, - self.partition_by, - DEFAULT_DATE_DS)] - sensor = NamedHivePartitionSensor(partition_names=partitions, - task_id='test_poke_existing', - poke_interval=1, - hook=self.hook, - dag=self.dag) + partitions = ["{}.{}/{}={}".format(self.database, self.table, self.partition_by, DEFAULT_DATE_DS)] + sensor = NamedHivePartitionSensor( + partition_names=partitions, + task_id='test_poke_existing', + poke_interval=1, + hook=self.hook, + dag=self.dag, + ) self.assertTrue(sensor.poke(None)) self.hook.metastore.__enter__().check_for_named_partition.assert_called_with( - self.database, self.table, f"{self.partition_by}={DEFAULT_DATE_DS}") + self.database, self.table, f"{self.partition_by}={DEFAULT_DATE_DS}" + ) def test_poke_non_existing(self): self.hook.metastore.__enter__().check_for_named_partition.return_value = False - partitions = ["{}.{}/{}={}".format(self.database, - self.table, - self.partition_by, - self.next_day)] - sensor = NamedHivePartitionSensor(partition_names=partitions, - task_id='test_poke_non_existing', - poke_interval=1, - hook=self.hook, - dag=self.dag) + partitions = ["{}.{}/{}={}".format(self.database, self.table, self.partition_by, self.next_day)] + sensor = NamedHivePartitionSensor( + partition_names=partitions, + task_id='test_poke_non_existing', + poke_interval=1, + hook=self.hook, + dag=self.dag, + ) self.assertFalse(sensor.poke(None)) self.hook.metastore.__enter__().check_for_named_partition.assert_called_with( - self.database, self.table, f"{self.partition_by}={self.next_day}") + self.database, self.table, f"{self.partition_by}={self.next_day}" + ) -@unittest.skipIf( - 'AIRFLOW_RUNALL_TESTS' not in os.environ, - "Skipped because AIRFLOW_RUNALL_TESTS is not set") +@unittest.skipIf('AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set") class TestPartitions(TestHiveEnvironment): - def 
test_succeeds_on_one_partition(self): mock_hive_metastore_hook = MockHiveMetastoreHook() - mock_hive_metastore_hook.check_for_named_partition = mock.MagicMock( - return_value=True) + mock_hive_metastore_hook.check_for_named_partition = mock.MagicMock(return_value=True) op = NamedHivePartitionSensor( task_id='hive_partition_check', - partition_names=[ - "airflow.static_babynames_partitioned/ds={{ds}}" - ], + partition_names=["airflow.static_babynames_partitioned/ds={{ds}}"], dag=self.dag, - hook=mock_hive_metastore_hook + hook=mock_hive_metastore_hook, ) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) mock_hive_metastore_hook.check_for_named_partition.assert_called_once_with( - 'airflow', 'static_babynames_partitioned', 'ds=2015-01-01') + 'airflow', 'static_babynames_partitioned', 'ds=2015-01-01' + ) def test_succeeds_on_multiple_partitions(self): mock_hive_metastore_hook = MockHiveMetastoreHook() - mock_hive_metastore_hook.check_for_named_partition = mock.MagicMock( - return_value=True) + mock_hive_metastore_hook.check_for_named_partition = mock.MagicMock(return_value=True) op = NamedHivePartitionSensor( task_id='hive_partition_check', partition_names=[ "airflow.static_babynames_partitioned/ds={{ds}}", - "airflow.static_babynames_partitioned2/ds={{ds}}" + "airflow.static_babynames_partitioned2/ds={{ds}}", ], dag=self.dag, - hook=mock_hive_metastore_hook + hook=mock_hive_metastore_hook, ) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) mock_hive_metastore_hook.check_for_named_partition.assert_any_call( - 'airflow', 'static_babynames_partitioned', 'ds=2015-01-01') + 'airflow', 'static_babynames_partitioned', 'ds=2015-01-01' + ) mock_hive_metastore_hook.check_for_named_partition.assert_any_call( - 'airflow', 'static_babynames_partitioned2', 'ds=2015-01-01') + 'airflow', 'static_babynames_partitioned2', 'ds=2015-01-01' + ) def test_parses_partitions_with_periods(self): name = NamedHivePartitionSensor.parse_partition_name( - partition="schema.table/part1=this.can.be.an.issue/part2=ok") + partition="schema.table/part1=this.can.be.an.issue/part2=ok" + ) self.assertEqual(name[0], "schema") self.assertEqual(name[1], "table") self.assertEqual(name[2], "part1=this.can.be.an.issue/part2=ok") @@ -188,19 +172,17 @@ def test_parses_partitions_with_periods(self): def test_times_out_on_nonexistent_partition(self): with self.assertRaises(AirflowSensorTimeout): mock_hive_metastore_hook = MockHiveMetastoreHook() - mock_hive_metastore_hook.check_for_named_partition = mock.MagicMock( - return_value=False) + mock_hive_metastore_hook.check_for_named_partition = mock.MagicMock(return_value=False) op = NamedHivePartitionSensor( task_id='hive_partition_check', partition_names=[ "airflow.static_babynames_partitioned/ds={{ds}}", - "airflow.static_babynames_partitioned/ds=nonexistent" + "airflow.static_babynames_partitioned/ds=nonexistent", ], poke_interval=0.1, timeout=1, dag=self.dag, - hook=mock_hive_metastore_hook + hook=mock_hive_metastore_hook, ) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py index 1a437d8f3dc03..4253325eec475 100644 --- 
a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py +++ b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py @@ -30,7 +30,6 @@ class TestHiveToMySqlTransfer(TestHiveEnvironment): - def setUp(self): self.kwargs = dict( sql='sql', @@ -50,8 +49,7 @@ def test_execute(self, mock_hive_hook, mock_mysql_hook): mock_hive_hook.return_value.get_records.assert_called_once_with('sql', hive_conf={}) mock_mysql_hook.assert_called_once_with(mysql_conn_id=self.kwargs['mysql_conn_id']) mock_mysql_hook.return_value.insert_rows.assert_called_once_with( - table=self.kwargs['mysql_table'], - rows=mock_hive_hook.return_value.get_records.return_value + table=self.kwargs['mysql_table'], rows=mock_hive_hook.return_value.get_records.return_value ) @patch('airflow.providers.apache.hive.transfers.hive_to_mysql.MySqlHook') @@ -89,11 +87,10 @@ def test_execute_bulk_load(self, mock_hive_hook, mock_tmp_file, mock_mysql_hook) delimiter='\t', lineterminator='\n', output_header=False, - hive_conf=context_to_airflow_vars(context) + hive_conf=context_to_airflow_vars(context), ) mock_mysql_hook.return_value.bulk_load.assert_called_once_with( - table=self.kwargs['mysql_table'], - tmp_file=mock_tmp_file.return_value.name + table=self.kwargs['mysql_table'], tmp_file=mock_tmp_file.return_value.name ) mock_tmp_file.return_value.close.assert_called_once_with() @@ -105,21 +102,20 @@ def test_execute_with_hive_conf(self, mock_mysql_hook): self.kwargs.update(dict(hive_conf={'mapreduce.job.queuename': 'fake_queue'})) - with patch('airflow.providers.apache.hive.transfers.hive_to_mysql.HiveServer2Hook', - return_value=mock_hive_hook): + with patch( + 'airflow.providers.apache.hive.transfers.hive_to_mysql.HiveServer2Hook', + return_value=mock_hive_hook, + ): HiveToMySqlOperator(**self.kwargs).execute(context=context) hive_conf = context_to_airflow_vars(context) hive_conf.update(self.kwargs['hive_conf']) - mock_hive_hook.get_records.assert_called_once_with( - self.kwargs['sql'], - hive_conf=hive_conf - ) + mock_hive_hook.get_records.assert_called_once_with(self.kwargs['sql'], hive_conf=hive_conf) @unittest.skipIf( - 'AIRFLOW_RUNALL_TESTS' not in os.environ, - "Skipped because AIRFLOW_RUNALL_TESTS is not set") + 'AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set" + ) def test_hive_to_mysql(self): test_hive_results = 'test_hive_results' @@ -130,10 +126,14 @@ def test_hive_to_mysql(self): mock_mysql_hook.run = MagicMock() mock_mysql_hook.insert_rows = MagicMock() - with patch('airflow.providers.apache.hive.transfers.hive_to_mysql.HiveServer2Hook', - return_value=mock_hive_hook): - with patch('airflow.providers.apache.hive.transfers.hive_to_mysql.MySqlHook', - return_value=mock_mysql_hook): + with patch( + 'airflow.providers.apache.hive.transfers.hive_to_mysql.HiveServer2Hook', + return_value=mock_hive_hook, + ): + with patch( + 'airflow.providers.apache.hive.transfers.hive_to_mysql.MySqlHook', + return_value=mock_mysql_hook, + ): op = HiveToMySqlOperator( mysql_conn_id='airflow_db', @@ -148,10 +148,10 @@ def test_hive_to_mysql(self): 'DROP TABLE IF EXISTS test_static_babynames;', 'CREATE TABLE test_static_babynames (name VARCHAR(500))', ], - dag=self.dag) + dag=self.dag, + ) op.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) raw_select_name_query = mock_hive_hook.get_records.call_args_list[0][0][0] actual_select_name_query 
= re.sub(r'\s{2,}', ' ', raw_select_name_query).strip() @@ -159,14 +159,18 @@ def test_hive_to_mysql(self): self.assertEqual(expected_select_name_query, actual_select_name_query) actual_hive_conf = mock_hive_hook.get_records.call_args_list[0][1]['hive_conf'] - expected_hive_conf = {'airflow.ctx.dag_owner': 'airflow', - 'airflow.ctx.dag_id': 'test_dag_id', - 'airflow.ctx.task_id': 'hive_to_mysql_check', - 'airflow.ctx.execution_date': '2015-01-01T00:00:00+00:00'} + expected_hive_conf = { + 'airflow.ctx.dag_owner': 'airflow', + 'airflow.ctx.dag_id': 'test_dag_id', + 'airflow.ctx.task_id': 'hive_to_mysql_check', + 'airflow.ctx.execution_date': '2015-01-01T00:00:00+00:00', + } self.assertEqual(expected_hive_conf, actual_hive_conf) - expected_mysql_preoperator = ['DROP TABLE IF EXISTS test_static_babynames;', - 'CREATE TABLE test_static_babynames (name VARCHAR(500))'] + expected_mysql_preoperator = [ + 'DROP TABLE IF EXISTS test_static_babynames;', + 'CREATE TABLE test_static_babynames (name VARCHAR(500))', + ] mock_mysql_hook.run.assert_called_with(expected_mysql_preoperator) mock_mysql_hook.insert_rows.assert_called_with(table='test_static_babynames', rows=test_hive_results) diff --git a/tests/providers/apache/hive/transfers/test_hive_to_samba.py b/tests/providers/apache/hive/transfers/test_hive_to_samba.py index dbd71f567f51c..26c33292ca828 100644 --- a/tests/providers/apache/hive/transfers/test_hive_to_samba.py +++ b/tests/providers/apache/hive/transfers/test_hive_to_samba.py @@ -26,7 +26,6 @@ class TestHive2SambaOperator(TestHiveEnvironment): - def setUp(self): self.kwargs = dict( hql='hql', @@ -47,41 +46,43 @@ def test_execute(self, mock_tmp_file, mock_hive_hook, mock_samba_hook): HiveToSambaOperator(**self.kwargs).execute(context) - mock_hive_hook.assert_called_once_with( - hiveserver2_conn_id=self.kwargs['hiveserver2_conn_id']) + mock_hive_hook.assert_called_once_with(hiveserver2_conn_id=self.kwargs['hiveserver2_conn_id']) mock_hive_hook.return_value.to_csv.assert_called_once_with( hql=self.kwargs['hql'], csv_filepath=mock_tmp_file.name, - hive_conf=context_to_airflow_vars(context)) - mock_samba_hook.assert_called_once_with( - samba_conn_id=self.kwargs['samba_conn_id']) + hive_conf=context_to_airflow_vars(context), + ) + mock_samba_hook.assert_called_once_with(samba_conn_id=self.kwargs['samba_conn_id']) mock_samba_hook.return_value.push_from_local.assert_called_once_with( - self.kwargs['destination_filepath'], mock_tmp_file.name) + self.kwargs['destination_filepath'], mock_tmp_file.name + ) @unittest.skipIf( - 'AIRFLOW_RUNALL_TESTS' not in os.environ, - "Skipped because AIRFLOW_RUNALL_TESTS is not set") + 'AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set" + ) @patch('tempfile.tempdir', '/tmp/') @patch('tempfile._RandomNameSequence.__next__') - @patch('airflow.providers.apache.hive.transfers.hive_to_samba.HiveServer2Hook', - side_effect=MockHiveServer2Hook) + @patch( + 'airflow.providers.apache.hive.transfers.hive_to_samba.HiveServer2Hook', + side_effect=MockHiveServer2Hook, + ) def test_hive2samba(self, mock_hive_server_hook, mock_temp_dir): mock_temp_dir.return_value = "tst" samba_hook = MockSambaHook(self.kwargs['samba_conn_id']) samba_hook.upload = MagicMock() - with patch('airflow.providers.apache.hive.transfers.hive_to_samba.SambaHook', - return_value=samba_hook): + with patch( + 'airflow.providers.apache.hive.transfers.hive_to_samba.SambaHook', return_value=samba_hook + ): samba_hook.conn.upload = MagicMock() op = HiveToSambaOperator( 
task_id='hive2samba_check', samba_conn_id='tableau_samba', hql="SELECT * FROM airflow.static_babynames LIMIT 10000", destination_filepath='test_airflow.csv', - dag=self.dag) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - samba_hook.conn.upload.assert_called_with( - '/tmp/tmptst', 'test_airflow.csv') + samba_hook.conn.upload.assert_called_with('/tmp/tmptst', 'test_airflow.csv') diff --git a/tests/providers/apache/hive/transfers/test_mssql_to_hive.py b/tests/providers/apache/hive/transfers/test_mssql_to_hive.py index e6ca835fccc4d..43ecb934aaaee 100644 --- a/tests/providers/apache/hive/transfers/test_mssql_to_hive.py +++ b/tests/providers/apache/hive/transfers/test_mssql_to_hive.py @@ -37,31 +37,24 @@ @unittest.skipIf(PY38, "Mssql package not available when Python >= 3.8.") @unittest.skipIf(pymssql is None, 'pymssql package not present') class TestMsSqlToHiveTransfer(unittest.TestCase): - def setUp(self): - self.kwargs = dict( - sql='sql', - hive_table='table', - task_id='test_mssql_to_hive', - dag=None - ) + self.kwargs = dict(sql='sql', hive_table='table', task_id='test_mssql_to_hive', dag=None) - # pylint: disable=c-extension-no-member def test_type_map_binary(self): - mapped_type = MsSqlToHiveOperator( - **self.kwargs).type_map(pymssql.BINARY.value) # pylint: disable=c-extension-no-member + # pylint: disable=c-extension-no-member + mapped_type = MsSqlToHiveOperator(**self.kwargs).type_map(pymssql.BINARY.value) self.assertEqual(mapped_type, 'INT') def test_type_map_decimal(self): - mapped_type = MsSqlToHiveOperator( - **self.kwargs).type_map(pymssql.DECIMAL.value) # pylint: disable=c-extension-no-member + # pylint: disable=c-extension-no-member + mapped_type = MsSqlToHiveOperator(**self.kwargs).type_map(pymssql.DECIMAL.value) self.assertEqual(mapped_type, 'FLOAT') def test_type_map_number(self): - mapped_type = MsSqlToHiveOperator( - **self.kwargs).type_map(pymssql.NUMBER.value) # pylint: disable=c-extension-no-member + # pylint: disable=c-extension-no-member + mapped_type = MsSqlToHiveOperator(**self.kwargs).type_map(pymssql.NUMBER.value) self.assertEqual(mapped_type, 'INT') @@ -86,7 +79,8 @@ def test_execute(self, mock_hive_hook, mock_mssql_hook, mock_tmp_file, mock_csv) mock_mssql_hook_cursor.return_value.execute.assert_called_once_with(mssql_to_hive_transfer.sql) mock_csv.writer.assert_called_once_with( - mock_tmp_file, delimiter=mssql_to_hive_transfer.delimiter, encoding='utf-8') + mock_tmp_file, delimiter=mssql_to_hive_transfer.delimiter, encoding='utf-8' + ) field_dict = OrderedDict() for field in mock_mssql_hook_cursor.return_value.description: field_dict[field[0]] = mssql_to_hive_transfer.type_map(field[1]) @@ -99,7 +93,8 @@ def test_execute(self, mock_hive_hook, mock_mssql_hook, mock_tmp_file, mock_csv) partition=mssql_to_hive_transfer.partition, delimiter=mssql_to_hive_transfer.delimiter, recreate=mssql_to_hive_transfer.recreate, - tblproperties=mssql_to_hive_transfer.tblproperties) + tblproperties=mssql_to_hive_transfer.tblproperties, + ) @patch('airflow.providers.apache.hive.transfers.mssql_to_hive.csv') @patch('airflow.providers.apache.hive.transfers.mssql_to_hive.NamedTemporaryFile') @@ -129,4 +124,5 @@ def test_execute_empty_description_field(self, mock_hive_hook, mock_mssql_hook, partition=mssql_to_hive_transfer.partition, delimiter=mssql_to_hive_transfer.delimiter, recreate=mssql_to_hive_transfer.recreate, - 
tblproperties=mssql_to_hive_transfer.tblproperties) + tblproperties=mssql_to_hive_transfer.tblproperties, + ) diff --git a/tests/providers/apache/hive/transfers/test_mysql_to_hive.py b/tests/providers/apache/hive/transfers/test_mysql_to_hive.py index e5d6ec6fe048f..71f4de9050031 100644 --- a/tests/providers/apache/hive/transfers/test_mysql_to_hive.py +++ b/tests/providers/apache/hive/transfers/test_mysql_to_hive.py @@ -108,14 +108,16 @@ def setUp(self): } with MySqlHook().get_conn() as cur: - cur.execute(''' + cur.execute( + ''' CREATE TABLE IF NOT EXISTS baby_names ( org_year integer(4), baby_name VARCHAR(25), rate FLOAT(7,6), sex VARCHAR(4) ) - ''') + ''' + ) for row in rows: cur.execute("INSERT INTO baby_names VALUES(%s, %s, %s, %s);", row) @@ -141,25 +143,42 @@ def test_mysql_to_hive(self, mock_popen, mock_temp_dir): hive_table='test_mysql_to_hive', recreate=True, delimiter=",", - dag=self.dag) - op.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) - - hive_cmd = ['beeline', '-u', '"jdbc:hive2://localhost:10000/default"', '-hiveconf', - 'airflow.ctx.dag_id=unit_test_dag', '-hiveconf', 'airflow.ctx.task_id=test_m2h', - '-hiveconf', 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', '-hiveconf', - 'airflow.ctx.dag_run_id=55', '-hiveconf', 'airflow.ctx.dag_owner=airflow', - '-hiveconf', 'airflow.ctx.dag_email=test@airflow.com', '-hiveconf', - 'mapreduce.job.queuename=airflow', '-hiveconf', 'mapred.job.queue.name=airflow', - '-hiveconf', 'tez.queue.name=airflow', '-f', - '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive'] + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + hive_cmd = [ + 'beeline', + '-u', + '"jdbc:hive2://localhost:10000/default"', + '-hiveconf', + 'airflow.ctx.dag_id=unit_test_dag', + '-hiveconf', + 'airflow.ctx.task_id=test_m2h', + '-hiveconf', + 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', + '-hiveconf', + 'airflow.ctx.dag_run_id=55', + '-hiveconf', + 'airflow.ctx.dag_owner=airflow', + '-hiveconf', + 'airflow.ctx.dag_email=test@airflow.com', + '-hiveconf', + 'mapreduce.job.queuename=airflow', + '-hiveconf', + 'mapred.job.queue.name=airflow', + '-hiveconf', + 'tez.queue.name=airflow', + '-f', + '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive', + ] mock_popen.assert_called_with( hive_cmd, stdout=mock_subprocess.PIPE, stderr=mock_subprocess.STDOUT, cwd="/tmp/airflow_hiveop_test_mysql_to_hive", - close_fds=True + close_fds=True, ) @mock.patch('tempfile.tempdir', '/tmp/') @@ -181,25 +200,42 @@ def test_mysql_to_hive_partition(self, mock_popen, mock_temp_dir): recreate=False, create=True, delimiter=",", - dag=self.dag) - op.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) - - hive_cmd = ['beeline', '-u', '"jdbc:hive2://localhost:10000/default"', '-hiveconf', - 'airflow.ctx.dag_id=unit_test_dag', '-hiveconf', 'airflow.ctx.task_id=test_m2h', - '-hiveconf', 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', '-hiveconf', - 'airflow.ctx.dag_run_id=55', '-hiveconf', 'airflow.ctx.dag_owner=airflow', - '-hiveconf', 'airflow.ctx.dag_email=test@airflow.com', '-hiveconf', - 'mapreduce.job.queuename=airflow', '-hiveconf', 'mapred.job.queue.name=airflow', - '-hiveconf', 'tez.queue.name=airflow', '-f', - '/tmp/airflow_hiveop_test_mysql_to_hive_partition/tmptest_mysql_to_hive_partition'] + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + hive_cmd = [ + 'beeline', + '-u', + 
'"jdbc:hive2://localhost:10000/default"', + '-hiveconf', + 'airflow.ctx.dag_id=unit_test_dag', + '-hiveconf', + 'airflow.ctx.task_id=test_m2h', + '-hiveconf', + 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', + '-hiveconf', + 'airflow.ctx.dag_run_id=55', + '-hiveconf', + 'airflow.ctx.dag_owner=airflow', + '-hiveconf', + 'airflow.ctx.dag_email=test@airflow.com', + '-hiveconf', + 'mapreduce.job.queuename=airflow', + '-hiveconf', + 'mapred.job.queue.name=airflow', + '-hiveconf', + 'tez.queue.name=airflow', + '-f', + '/tmp/airflow_hiveop_test_mysql_to_hive_partition/tmptest_mysql_to_hive_partition', + ] mock_popen.assert_called_with( hive_cmd, stdout=mock_subprocess.PIPE, stderr=mock_subprocess.STDOUT, cwd="/tmp/airflow_hiveop_test_mysql_to_hive_partition", - close_fds=True + close_fds=True, ) @mock.patch('tempfile.tempdir', '/tmp/') @@ -220,25 +256,42 @@ def test_mysql_to_hive_tblproperties(self, mock_popen, mock_temp_dir): recreate=True, delimiter=",", tblproperties={'test_property': 'test_value'}, - dag=self.dag) - op.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) - - hive_cmd = ['beeline', '-u', '"jdbc:hive2://localhost:10000/default"', '-hiveconf', - 'airflow.ctx.dag_id=unit_test_dag', '-hiveconf', 'airflow.ctx.task_id=test_m2h', - '-hiveconf', 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', '-hiveconf', - 'airflow.ctx.dag_run_id=55', '-hiveconf', 'airflow.ctx.dag_owner=airflow', - '-hiveconf', 'airflow.ctx.dag_email=test@airflow.com', '-hiveconf', - 'mapreduce.job.queuename=airflow', '-hiveconf', 'mapred.job.queue.name=airflow', - '-hiveconf', 'tez.queue.name=airflow', - '-f', '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive'] + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + hive_cmd = [ + 'beeline', + '-u', + '"jdbc:hive2://localhost:10000/default"', + '-hiveconf', + 'airflow.ctx.dag_id=unit_test_dag', + '-hiveconf', + 'airflow.ctx.task_id=test_m2h', + '-hiveconf', + 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', + '-hiveconf', + 'airflow.ctx.dag_run_id=55', + '-hiveconf', + 'airflow.ctx.dag_owner=airflow', + '-hiveconf', + 'airflow.ctx.dag_email=test@airflow.com', + '-hiveconf', + 'mapreduce.job.queuename=airflow', + '-hiveconf', + 'mapred.job.queue.name=airflow', + '-hiveconf', + 'tez.queue.name=airflow', + '-f', + '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive', + ] mock_popen.assert_called_with( hive_cmd, stdout=mock_subprocess.PIPE, stderr=mock_subprocess.STDOUT, cwd="/tmp/airflow_hiveop_test_mysql_to_hive", - close_fds=True + close_fds=True, ) @mock.patch('airflow.providers.apache.hive.hooks.hive.HiveCliHook.load_file') @@ -250,7 +303,8 @@ def test_mysql_to_hive_type_conversion(self, mock_load_file): try: with hook.get_conn() as conn: conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table)) - conn.execute(""" + conn.execute( + """ CREATE TABLE {} ( c0 TINYINT, c1 SMALLINT, @@ -259,16 +313,19 @@ def test_mysql_to_hive_type_conversion(self, mock_load_file): c4 BIGINT, c5 TIMESTAMP ) - """.format(mysql_table)) + """.format( + mysql_table + ) + ) op = MySqlToHiveOperator( task_id='test_m2h', hive_cli_conn_id='hive_cli_default', sql="SELECT * FROM {}".format(mysql_table), hive_table='test_mysql_to_hive', - dag=self.dag) - op.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) assert mock_load_file.call_count == 1 ordered_dict = 
OrderedDict() @@ -278,8 +335,7 @@ def test_mysql_to_hive_type_conversion(self, mock_load_file): ordered_dict["c3"] = "BIGINT" ordered_dict["c4"] = "DECIMAL(38,0)" ordered_dict["c5"] = "TIMESTAMP" - self.assertEqual( - mock_load_file.call_args[1]["field_dict"], ordered_dict) + self.assertEqual(mock_load_file.call_args[1]["field_dict"], ordered_dict) finally: with hook.get_conn() as conn: conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table)) @@ -298,26 +354,32 @@ def test_mysql_to_hive_verify_csv_special_char(self, mock_popen, mock_temp_dir): hook = MySqlHook() try: - db_record = ( - 'c0', - '["true"]' - ) + db_record = ('c0', '["true"]') with hook.get_conn() as conn: conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table)) - conn.execute(""" + conn.execute( + """ CREATE TABLE {} ( c0 VARCHAR(25), c1 VARCHAR(25) ) - """.format(mysql_table)) - conn.execute(""" + """.format( + mysql_table + ) + ) + conn.execute( + """ INSERT INTO {} VALUES ( '{}', '{}' ) - """.format(mysql_table, *db_record)) + """.format( + mysql_table, *db_record + ) + ) with mock.patch.dict('os.environ', self.env_vars): import unicodecsv as csv + op = MySqlToHiveOperator( task_id='test_m2h', hive_cli_conn_id='hive_cli_default', @@ -328,33 +390,49 @@ def test_mysql_to_hive_verify_csv_special_char(self, mock_popen, mock_temp_dir): quoting=csv.QUOTE_NONE, quotechar='', escapechar='@', - dag=self.dag) - op.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) mock_cursor = MockConnectionCursor() mock_cursor.iterable = [('c0', '["true"]'), (2, 2)] hive_hook = MockHiveServer2Hook(connection_cursor=mock_cursor) - result = hive_hook.get_records( - "SELECT * FROM {}".format(hive_table)) + result = hive_hook.get_records("SELECT * FROM {}".format(hive_table)) self.assertEqual(result[0], db_record) - hive_cmd = ['beeline', '-u', '"jdbc:hive2://localhost:10000/default"', '-hiveconf', - 'airflow.ctx.dag_id=unit_test_dag', '-hiveconf', 'airflow.ctx.task_id=test_m2h', - '-hiveconf', 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', '-hiveconf', - 'airflow.ctx.dag_run_id=55', '-hiveconf', 'airflow.ctx.dag_owner=airflow', - '-hiveconf', 'airflow.ctx.dag_email=test@airflow.com', '-hiveconf', - 'mapreduce.job.queuename=airflow', '-hiveconf', 'mapred.job.queue.name=airflow', - '-hiveconf', 'tez.queue.name=airflow', '-f', - '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive'] + hive_cmd = [ + 'beeline', + '-u', + '"jdbc:hive2://localhost:10000/default"', + '-hiveconf', + 'airflow.ctx.dag_id=unit_test_dag', + '-hiveconf', + 'airflow.ctx.task_id=test_m2h', + '-hiveconf', + 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', + '-hiveconf', + 'airflow.ctx.dag_run_id=55', + '-hiveconf', + 'airflow.ctx.dag_owner=airflow', + '-hiveconf', + 'airflow.ctx.dag_email=test@airflow.com', + '-hiveconf', + 'mapreduce.job.queuename=airflow', + '-hiveconf', + 'mapred.job.queue.name=airflow', + '-hiveconf', + 'tez.queue.name=airflow', + '-f', + '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive', + ] mock_popen.assert_called_with( hive_cmd, stdout=mock_subprocess.PIPE, stderr=mock_subprocess.STDOUT, cwd="/tmp/airflow_hiveop_test_mysql_to_hive", - close_fds=True + close_fds=True, ) finally: with hook.get_conn() as conn: @@ -384,12 +462,13 @@ def test_mysql_to_hive_verify_loaded_values(self, mock_popen, mock_temp_dir): -32768, -8388608, -2147483648, - -9223372036854775808 + -9223372036854775808, ) with 
hook.get_conn() as conn: conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table)) - conn.execute(""" + conn.execute( + """ CREATE TABLE {} ( c0 TINYINT UNSIGNED, c1 SMALLINT UNSIGNED, @@ -402,12 +481,19 @@ def test_mysql_to_hive_verify_loaded_values(self, mock_popen, mock_temp_dir): c8 INT, c9 BIGINT ) - """.format(mysql_table)) - conn.execute(""" + """.format( + mysql_table + ) + ) + conn.execute( + """ INSERT INTO {} VALUES ( {}, {}, {}, {}, {}, {}, {}, {}, {}, {} ) - """.format(mysql_table, *minmax)) + """.format( + mysql_table, *minmax + ) + ) with mock.patch.dict('os.environ', self.env_vars): op = MySqlToHiveOperator( @@ -417,33 +503,49 @@ def test_mysql_to_hive_verify_loaded_values(self, mock_popen, mock_temp_dir): hive_table=hive_table, recreate=True, delimiter=",", - dag=self.dag) - op.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) mock_cursor = MockConnectionCursor() mock_cursor.iterable = [minmax] hive_hook = MockHiveServer2Hook(connection_cursor=mock_cursor) - result = hive_hook.get_records( - "SELECT * FROM {}".format(hive_table)) + result = hive_hook.get_records("SELECT * FROM {}".format(hive_table)) self.assertEqual(result[0], minmax) - hive_cmd = ['beeline', '-u', '"jdbc:hive2://localhost:10000/default"', '-hiveconf', - 'airflow.ctx.dag_id=unit_test_dag', '-hiveconf', 'airflow.ctx.task_id=test_m2h', - '-hiveconf', 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', '-hiveconf', - 'airflow.ctx.dag_run_id=55', '-hiveconf', 'airflow.ctx.dag_owner=airflow', - '-hiveconf', 'airflow.ctx.dag_email=test@airflow.com', '-hiveconf', - 'mapreduce.job.queuename=airflow', '-hiveconf', 'mapred.job.queue.name=airflow', - '-hiveconf', 'tez.queue.name=airflow', '-f', - '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive'] + hive_cmd = [ + 'beeline', + '-u', + '"jdbc:hive2://localhost:10000/default"', + '-hiveconf', + 'airflow.ctx.dag_id=unit_test_dag', + '-hiveconf', + 'airflow.ctx.task_id=test_m2h', + '-hiveconf', + 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', + '-hiveconf', + 'airflow.ctx.dag_run_id=55', + '-hiveconf', + 'airflow.ctx.dag_owner=airflow', + '-hiveconf', + 'airflow.ctx.dag_email=test@airflow.com', + '-hiveconf', + 'mapreduce.job.queuename=airflow', + '-hiveconf', + 'mapred.job.queue.name=airflow', + '-hiveconf', + 'tez.queue.name=airflow', + '-f', + '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive', + ] mock_popen.assert_called_with( hive_cmd, stdout=mock_subprocess.PIPE, stderr=mock_subprocess.STDOUT, cwd="/tmp/airflow_hiveop_test_mysql_to_hive", - close_fds=True + close_fds=True, ) finally: diff --git a/tests/providers/apache/hive/transfers/test_s3_to_hive.py b/tests/providers/apache/hive/transfers/test_s3_to_hive.py index 67aec2a2330a8..81617cedaa946 100644 --- a/tests/providers/apache/hive/transfers/test_s3_to_hive.py +++ b/tests/providers/apache/hive/transfers/test_s3_to_hive.py @@ -40,7 +40,6 @@ class TestS3ToHiveTransfer(unittest.TestCase): - def setUp(self): self.file_names = {} self.task_id = 'S3ToHiveTransferTest' @@ -55,28 +54,27 @@ def setUp(self): self.check_headers = True self.wildcard_match = False self.input_compressed = False - self.kwargs = {'task_id': self.task_id, - 's3_key': self.s3_key, - 'field_dict': self.field_dict, - 'hive_table': self.hive_table, - 'delimiter': self.delimiter, - 'create': self.create, - 'recreate': self.recreate, - 'partition': self.partition, - 'headers': 
self.headers, - 'check_headers': self.check_headers, - 'wildcard_match': self.wildcard_match, - 'input_compressed': self.input_compressed - } + self.kwargs = { + 'task_id': self.task_id, + 's3_key': self.s3_key, + 'field_dict': self.field_dict, + 'hive_table': self.hive_table, + 'delimiter': self.delimiter, + 'create': self.create, + 'recreate': self.recreate, + 'partition': self.partition, + 'headers': self.headers, + 'check_headers': self.check_headers, + 'wildcard_match': self.wildcard_match, + 'input_compressed': self.input_compressed, + } try: header = b"Sno\tSome,Text \n" line1 = b"1\tAirflow Test\n" line2 = b"2\tS32HiveTransfer\n" self.tmp_dir = mkdtemp(prefix='test_tmps32hive_') # create sample txt, gz and bz2 with and without headers - with NamedTemporaryFile(mode='wb+', - dir=self.tmp_dir, - delete=False) as f_txt_h: + with NamedTemporaryFile(mode='wb+', dir=self.tmp_dir, delete=False) as f_txt_h: self._set_fn(f_txt_h.name, '.txt', True) f_txt_h.writelines([header, line1, line2]) fn_gz = self._get_fn('.txt', True) + ".gz" @@ -156,55 +154,54 @@ def _check_file_equality(fn_1, fn_2, ext): def test_bad_parameters(self): self.kwargs['check_headers'] = True self.kwargs['headers'] = False - self.assertRaisesRegex(AirflowException, "To check_headers.*", - S3ToHiveOperator, **self.kwargs) + self.assertRaisesRegex(AirflowException, "To check_headers.*", S3ToHiveOperator, **self.kwargs) def test__get_top_row_as_list(self): self.kwargs['delimiter'] = '\t' fn_txt = self._get_fn('.txt', True) - header_list = S3ToHiveOperator(**self.kwargs). \ - _get_top_row_as_list(fn_txt) - self.assertEqual(header_list, ['Sno', 'Some,Text'], - msg="Top row from file doesnt matched expected value") + header_list = S3ToHiveOperator(**self.kwargs)._get_top_row_as_list(fn_txt) + self.assertEqual( + header_list, ['Sno', 'Some,Text'], msg="Top row from file doesnt matched expected value" + ) self.kwargs['delimiter'] = ',' - header_list = S3ToHiveOperator(**self.kwargs). \ - _get_top_row_as_list(fn_txt) - self.assertEqual(header_list, ['Sno\tSome', 'Text'], - msg="Top row from file doesnt matched expected value") + header_list = S3ToHiveOperator(**self.kwargs)._get_top_row_as_list(fn_txt) + self.assertEqual( + header_list, ['Sno\tSome', 'Text'], msg="Top row from file doesnt matched expected value" + ) def test__match_headers(self): - self.kwargs['field_dict'] = OrderedDict([('Sno', 'BIGINT'), - ('Some,Text', 'STRING')]) - self.assertTrue(S3ToHiveOperator(**self.kwargs). - _match_headers(['Sno', 'Some,Text']), - msg="Header row doesnt match expected value") + self.kwargs['field_dict'] = OrderedDict([('Sno', 'BIGINT'), ('Some,Text', 'STRING')]) + self.assertTrue( + S3ToHiveOperator(**self.kwargs)._match_headers(['Sno', 'Some,Text']), + msg="Header row doesnt match expected value", + ) # Testing with different column order - self.assertFalse(S3ToHiveOperator(**self.kwargs). - _match_headers(['Some,Text', 'Sno']), - msg="Header row doesnt match expected value") + self.assertFalse( + S3ToHiveOperator(**self.kwargs)._match_headers(['Some,Text', 'Sno']), + msg="Header row doesnt match expected value", + ) # Testing with extra column in header - self.assertFalse(S3ToHiveOperator(**self.kwargs). 
- _match_headers(['Sno', 'Some,Text', 'ExtraColumn']), - msg="Header row doesnt match expected value") + self.assertFalse( + S3ToHiveOperator(**self.kwargs)._match_headers(['Sno', 'Some,Text', 'ExtraColumn']), + msg="Header row doesnt match expected value", + ) def test__delete_top_row_and_compress(self): s32hive = S3ToHiveOperator(**self.kwargs) # Testing gz file type fn_txt = self._get_fn('.txt', True) - gz_txt_nh = s32hive._delete_top_row_and_compress(fn_txt, - '.gz', - self.tmp_dir) + gz_txt_nh = s32hive._delete_top_row_and_compress(fn_txt, '.gz', self.tmp_dir) fn_gz = self._get_fn('.gz', False) - self.assertTrue(self._check_file_equality(gz_txt_nh, fn_gz, '.gz'), - msg="gz Compressed file not as expected") + self.assertTrue( + self._check_file_equality(gz_txt_nh, fn_gz, '.gz'), msg="gz Compressed file not as expected" + ) # Testing bz2 file type - bz2_txt_nh = s32hive._delete_top_row_and_compress(fn_txt, - '.bz2', - self.tmp_dir) + bz2_txt_nh = s32hive._delete_top_row_and_compress(fn_txt, '.bz2', self.tmp_dir) fn_bz2 = self._get_fn('.bz2', False) - self.assertTrue(self._check_file_equality(bz2_txt_nh, fn_bz2, '.bz2'), - msg="bz2 Compressed file not as expected") + self.assertTrue( + self._check_file_equality(bz2_txt_nh, fn_bz2, '.bz2'), msg="bz2 Compressed file not as expected" + ) @unittest.skipIf(mock is None, 'mock package not present') @unittest.skipIf(mock_s3 is None, 'moto package not present') @@ -229,10 +226,10 @@ def test_execute(self, mock_hiveclihook): # file parameter to HiveCliHook.load_file is compared # against expected file output - mock_hiveclihook().load_file.side_effect = \ - lambda *args, **kwargs: self.assertTrue( - self._check_file_equality(args[0], op_fn, ext), - msg='{0} output file not as expected'.format(ext)) + mock_hiveclihook().load_file.side_effect = lambda *args, **kwargs: self.assertTrue( + self._check_file_equality(args[0], op_fn, ext), + msg='{0} output file not as expected'.format(ext), + ) # Execute S3ToHiveTransfer s32hive = S3ToHiveOperator(**self.kwargs) s32hive.execute(None) @@ -266,23 +263,23 @@ def test_execute_with_select_expression(self, mock_hiveclihook): # Upload the file into the Mocked S3 bucket conn.upload_file(ip_fn, bucket, key) - input_serialization = { - 'CSV': {'FieldDelimiter': self.delimiter} - } + input_serialization = {'CSV': {'FieldDelimiter': self.delimiter}} if input_compressed: input_serialization['CompressionType'] = 'GZIP' if has_header: input_serialization['CSV']['FileHeaderInfo'] = 'USE' # Confirm that select_key was called with the right params - with mock.patch('airflow.providers.amazon.aws.hooks.s3.S3Hook.select_key', - return_value="") as mock_select_key: + with mock.patch( + 'airflow.providers.amazon.aws.hooks.s3.S3Hook.select_key', return_value="" + ) as mock_select_key: # Execute S3ToHiveTransfer s32hive = S3ToHiveOperator(**self.kwargs) s32hive.execute(None) mock_select_key.assert_called_once_with( - bucket_name=bucket, key=key, + bucket_name=bucket, + key=key, expression=select_expression, - input_serialization=input_serialization + input_serialization=input_serialization, ) diff --git a/tests/providers/apache/hive/transfers/test_vertica_to_hive.py b/tests/providers/apache/hive/transfers/test_vertica_to_hive.py index a454030743c1b..fc09db4e2b2fc 100644 --- a/tests/providers/apache/hive/transfers/test_vertica_to_hive.py +++ b/tests/providers/apache/hive/transfers/test_vertica_to_hive.py @@ -25,34 +25,24 @@ def mock_get_conn(): - commit_mock = mock.MagicMock( - ) + commit_mock = mock.MagicMock() cursor_mock = 
mock.MagicMock( - execute=[], - fetchall=[['1', '2', '3']], - description=['a', 'b', 'c'], - iterate=[['1', '2', '3']], - ) - conn_mock = mock.MagicMock( - commit=commit_mock, - cursor=cursor_mock, + execute=[], fetchall=[['1', '2', '3']], description=['a', 'b', 'c'], iterate=[['1', '2', '3']], ) + conn_mock = mock.MagicMock(commit=commit_mock, cursor=cursor_mock,) return conn_mock class TestVerticaToHiveTransfer(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2017, 1, 1) - } + args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)} self.dag = DAG('test_dag_id', default_args=args) @mock.patch( 'airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaHook.get_conn', - side_effect=mock_get_conn) - @mock.patch( - 'airflow.providers.apache.hive.transfers.vertica_to_hive.HiveCliHook.load_file') + side_effect=mock_get_conn, + ) + @mock.patch('airflow.providers.apache.hive.transfers.vertica_to_hive.HiveCliHook.load_file') def test_select_insert_transfer(self, *args): """ Test check selection from vertica into memory and @@ -64,5 +54,6 @@ def test_select_insert_transfer(self, *args): hive_table='test_table', vertica_conn_id='test_vertica_conn_id', hive_cli_conn_id='hive_cli_default', - dag=self.dag) + dag=self.dag, + ) task.execute(None) diff --git a/tests/providers/apache/kylin/hooks/test_kylin.py b/tests/providers/apache/kylin/hooks/test_kylin.py index 5d7b1c3da859d..8e76aff30602a 100644 --- a/tests/providers/apache/kylin/hooks/test_kylin.py +++ b/tests/providers/apache/kylin/hooks/test_kylin.py @@ -27,7 +27,6 @@ class TestKylinHook(unittest.TestCase): - def setUp(self) -> None: self.hook = KylinHook(kylin_conn_id='kylin_default', project='learn_kylin') @@ -40,12 +39,23 @@ def test_get_job_status(self, mock_job): @patch("kylinpy.Kylin.get_datasource") def test_cube_run(self, cube_source): - class MockCubeSource: def invoke_command(self, command, **kwargs): - invoke_command_list = ['fullbuild', 'build', 'merge', 'refresh', - 'delete', 'build_streaming', 'merge_streaming', 'refresh_streaming', - 'disable', 'enable', 'purge', 'clone', 'drop'] + invoke_command_list = [ + 'fullbuild', + 'build', + 'merge', + 'refresh', + 'delete', + 'build_streaming', + 'merge_streaming', + 'refresh_streaming', + 'disable', + 'enable', + 'purge', + 'clone', + 'drop', + ] if command in invoke_command_list: return {"code": "000", "data": {}} else: @@ -57,4 +67,6 @@ def invoke_command(self, command, **kwargs): self.assertDictEqual(self.hook.cube_run('kylin_sales_cube', 'refresh'), response_data) self.assertDictEqual(self.hook.cube_run('kylin_sales_cube', 'merge'), response_data) self.assertDictEqual(self.hook.cube_run('kylin_sales_cube', 'build_streaming'), response_data) - self.assertRaises(AirflowException, self.hook.cube_run, 'kylin_sales_cube', 'build123',) + self.assertRaises( + AirflowException, self.hook.cube_run, 'kylin_sales_cube', 'build123', + ) diff --git a/tests/providers/apache/kylin/operators/test_kylin_cube.py b/tests/providers/apache/kylin/operators/test_kylin_cube.py index f71ec3d88c94f..f7d21a276d3c9 100644 --- a/tests/providers/apache/kylin/operators/test_kylin_cube.py +++ b/tests/providers/apache/kylin/operators/test_kylin_cube.py @@ -37,32 +37,48 @@ class TestKylinCubeOperator(unittest.TestCase): 'command': 'build', 'start_time': datetime(2012, 1, 2, 0, 0).strftime("%s") + '000', 'end_time': datetime(2012, 1, 3, 0, 0).strftime("%s") + '000', - } - cube_command = ['fullbuild', 'build', 'merge', 'refresh', - 'delete', 
'build_streaming', 'merge_streaming', 'refresh_streaming', - 'disable', 'enable', 'purge', 'clone', 'drop'] + cube_command = [ + 'fullbuild', + 'build', + 'merge', + 'refresh', + 'delete', + 'build_streaming', + 'merge_streaming', + 'refresh_streaming', + 'disable', + 'enable', + 'purge', + 'clone', + 'drop', + ] build_response = {"uuid": "c143e0e4-ac5f-434d-acf3-46b0d15e3dc6"} def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) @patch('airflow.providers.apache.kylin.operators.kylin_cube.KylinHook') def test_execute(self, mock_hook): - operator = KylinCubeOperator( - task_id='kylin_task', - dag=self.dag, - **self._config - ) + operator = KylinCubeOperator(task_id='kylin_task', dag=self.dag, **self._config) hook = MagicMock() - hook.invoke_command = ['fullbuild', 'build', 'merge', 'refresh', - 'delete', 'build_streaming', 'merge_streaming', 'refresh_streaming', - 'disable', 'enable', 'purge', 'clone', 'drop'] + hook.invoke_command = [ + 'fullbuild', + 'build', + 'merge', + 'refresh', + 'delete', + 'build_streaming', + 'merge_streaming', + 'refresh_streaming', + 'disable', + 'enable', + 'purge', + 'clone', + 'drop', + ] mock_hook.return_value = hook mock_hook.cube_run.return_value = {} @@ -75,28 +91,23 @@ def test_execute(self, mock_hook): self.assertEqual(self._config['end_time'], operator.end_time) operator.execute(None) mock_hook.assert_called_once_with( - kylin_conn_id=self._config['kylin_conn_id'], - project=self._config['project'], - dsn=None + kylin_conn_id=self._config['kylin_conn_id'], project=self._config['project'], dsn=None ) - mock_hook.return_value.cube_run.assert_called_once_with('kylin_sales_cube', - 'build', - end=datetime(2012, 1, 3, 0, 0), - name=None, - offset_end=None, - offset_start=None, - start=datetime(2012, 1, 2, 0, 0)) + mock_hook.return_value.cube_run.assert_called_once_with( + 'kylin_sales_cube', + 'build', + end=datetime(2012, 1, 3, 0, 0), + name=None, + offset_end=None, + offset_start=None, + start=datetime(2012, 1, 2, 0, 0), + ) @patch('airflow.providers.apache.kylin.operators.kylin_cube.KylinHook') def test_execute_build(self, mock_hook): operator = KylinCubeOperator( - is_track_job=True, - timeout=5, - interval=1, - task_id='kylin_task', - dag=self.dag, - **self._config + is_track_job=True, timeout=5, interval=1, task_id='kylin_task', dag=self.dag, **self._config ) hook = MagicMock() hook.invoke_command = self.cube_command @@ -109,12 +120,7 @@ def test_execute_build(self, mock_hook): @patch('airflow.providers.apache.kylin.operators.kylin_cube.KylinHook') def test_execute_build_status_error(self, mock_hook): operator = KylinCubeOperator( - is_track_job=True, - timeout=5, - interval=1, - task_id='kylin_task', - dag=self.dag, - **self._config + is_track_job=True, timeout=5, interval=1, task_id='kylin_task', dag=self.dag, **self._config ) hook = MagicMock() hook.invoke_command = self.cube_command @@ -127,12 +133,7 @@ def test_execute_build_status_error(self, mock_hook): @patch('airflow.providers.apache.kylin.operators.kylin_cube.KylinHook') def test_execute_build_time_out_error(self, mock_hook): operator = KylinCubeOperator( - is_track_job=True, - timeout=5, - interval=1, - task_id='kylin_task', - dag=self.dag, - **self._config + is_track_job=True, timeout=5, interval=1, task_id='kylin_task', dag=self.dag, **self._config ) hook = MagicMock() hook.invoke_command = self.cube_command diff --git 
a/tests/providers/apache/livy/hooks/test_livy.py b/tests/providers/apache/livy/hooks/test_livy.py index aa1338cee81c8..4985fd12caf35 100644 --- a/tests/providers/apache/livy/hooks/test_livy.py +++ b/tests/providers/apache/livy/hooks/test_livy.py @@ -35,13 +35,15 @@ class TestLivyHook(unittest.TestCase): @classmethod def setUpClass(cls): db.merge_conn( - Connection(conn_id='livy_default', conn_type='http', host='host', schema='http', port=8998)) + Connection(conn_id='livy_default', conn_type='http', host='host', schema='http', port=8998) + ) db.merge_conn(Connection(conn_id='default_port', conn_type='http', host='http://host')) db.merge_conn(Connection(conn_id='default_protocol', conn_type='http', host='host')) db.merge_conn(Connection(conn_id='port_set', host='host', conn_type='http', port=1234)) db.merge_conn(Connection(conn_id='schema_set', host='host', conn_type='http', schema='zzz')) db.merge_conn( - Connection(conn_id='dont_override_schema', conn_type='http', host='http://host', schema='zzz')) + Connection(conn_id='dont_override_schema', conn_type='http', host='http://host', schema='zzz') + ) db.merge_conn(Connection(conn_id='missing_host', conn_type='http', port=1234)) db.merge_conn(Connection(conn_id='invalid_uri', uri='http://invalid_uri:4321')) @@ -94,24 +96,27 @@ def test_build_body(self): num_executors='10', ) - self.assertEqual(body, { - 'file': 'appname', - 'className': 'org.example.livy', - 'proxyUser': 'proxyUser', - 'args': ['a', '1'], - 'jars': ['jar1', 'jar2'], - 'files': ['file1', 'file2'], - 'pyFiles': ['py1', 'py2'], - 'archives': ['arch1', 'arch2'], - 'queue': 'queue', - 'name': 'name', - 'conf': {'a': 'b'}, - 'driverCores': 2, - 'driverMemory': '1M', - 'executorMemory': '1m', - 'executorCores': '1', - 'numExecutors': '10' - }) + self.assertEqual( + body, + { + 'file': 'appname', + 'className': 'org.example.livy', + 'proxyUser': 'proxyUser', + 'args': ['a', '1'], + 'jars': ['jar1', 'jar2'], + 'files': ['file1', 'file2'], + 'pyFiles': ['py1', 'py2'], + 'archives': ['arch1', 'arch2'], + 'queue': 'queue', + 'name': 'name', + 'conf': {'a': 'b'}, + 'driverCores': 2, + 'driverMemory': '1M', + 'executorMemory': '1m', + 'executorCores': '1', + 'numExecutors': '10', + }, + ) def test_parameters_validation(self): with self.subTest('not a size'): @@ -120,8 +125,7 @@ def test_parameters_validation(self): with self.subTest('list of stringables'): self.assertEqual( - LivyHook.build_post_batch_body(file='appname', args=['a', 1, 0.1])['args'], - ['a', '1', '0.1'] + LivyHook.build_post_batch_body(file='appname', args=['a', 1, 0.1])['args'], ['a', '1', '0.1'] ) def test_validate_size_format(self): @@ -244,16 +248,14 @@ def test_post_batch_arguments(self, mock_request): mock_request.return_value.json.return_value = { 'id': BATCH_ID, 'state': BatchState.STARTING.value, - 'log': [] + 'log': [], } hook = LivyHook() resp = hook.post_batch(file='sparkapp') mock_request.assert_called_once_with( - method='POST', - endpoint='/batches', - data=json.dumps({'file': 'sparkapp'}) + method='POST', endpoint='/batches', data=json.dumps({'file': 'sparkapp'}) ) request_args = mock_request.call_args[1] @@ -266,9 +268,10 @@ def test_post_batch_arguments(self, mock_request): @requests_mock.mock() def test_post_batch_success(self, mock): mock.register_uri( - 'POST', '//livy:8998/batches', + 'POST', + '//livy:8998/batches', json={'id': BATCH_ID, 'state': BatchState.STARTING.value, 'log': []}, - status_code=201 + status_code=201, ) resp = LivyHook().post_batch(file='sparkapp') @@ -278,12 +281,7 @@ def 
test_post_batch_success(self, mock): @requests_mock.mock() def test_post_batch_fail(self, mock): - mock.register_uri( - 'POST', '//livy:8998/batches', - json={}, - status_code=400, - reason='ERROR' - ) + mock.register_uri('POST', '//livy:8998/batches', json={}, status_code=400, reason='ERROR') hook = LivyHook() with self.assertRaises(AirflowException): @@ -292,9 +290,7 @@ def test_post_batch_fail(self, mock): @requests_mock.mock() def test_get_batch_success(self, mock): mock.register_uri( - 'GET', '//livy:8998/batches/{}'.format(BATCH_ID), - json={'id': BATCH_ID}, - status_code=200 + 'GET', '//livy:8998/batches/{}'.format(BATCH_ID), json={'id': BATCH_ID}, status_code=200 ) hook = LivyHook() @@ -306,10 +302,11 @@ def test_get_batch_success(self, mock): @requests_mock.mock() def test_get_batch_fail(self, mock): mock.register_uri( - 'GET', '//livy:8998/batches/{}'.format(BATCH_ID), + 'GET', + '//livy:8998/batches/{}'.format(BATCH_ID), json={'msg': 'Unable to find batch'}, status_code=404, - reason='ERROR' + reason='ERROR', ) hook = LivyHook() @@ -327,9 +324,10 @@ def test_get_batch_state_success(self, mock): running = BatchState.RUNNING mock.register_uri( - 'GET', '//livy:8998/batches/{}/state'.format(BATCH_ID), + 'GET', + '//livy:8998/batches/{}/state'.format(BATCH_ID), json={'id': BATCH_ID, 'state': running.value}, - status_code=200 + status_code=200, ) state = LivyHook().get_batch_state(BATCH_ID) @@ -340,10 +338,7 @@ def test_get_batch_state_success(self, mock): @requests_mock.mock() def test_get_batch_state_fail(self, mock): mock.register_uri( - 'GET', '//livy:8998/batches/{}/state'.format(BATCH_ID), - json={}, - status_code=400, - reason='ERROR' + 'GET', '//livy:8998/batches/{}/state'.format(BATCH_ID), json={}, status_code=400, reason='ERROR' ) hook = LivyHook() @@ -352,11 +347,7 @@ def test_get_batch_state_fail(self, mock): @requests_mock.mock() def test_get_batch_state_missing(self, mock): - mock.register_uri( - 'GET', '//livy:8998/batches/{}/state'.format(BATCH_ID), - json={}, - status_code=200 - ) + mock.register_uri('GET', '//livy:8998/batches/{}/state'.format(BATCH_ID), json={}, status_code=200) hook = LivyHook() with self.assertRaises(AirflowException): @@ -370,9 +361,7 @@ def test_parse_post_response(self): @requests_mock.mock() def test_delete_batch_success(self, mock): mock.register_uri( - 'DELETE', '//livy:8998/batches/{}'.format(BATCH_ID), - json={'msg': 'deleted'}, - status_code=200 + 'DELETE', '//livy:8998/batches/{}'.format(BATCH_ID), json={'msg': 'deleted'}, status_code=200 ) resp = LivyHook().delete_batch(BATCH_ID) @@ -382,10 +371,7 @@ def test_delete_batch_success(self, mock): @requests_mock.mock() def test_delete_batch_fail(self, mock): mock.register_uri( - 'DELETE', '//livy:8998/batches/{}'.format(BATCH_ID), - json={}, - status_code=400, - reason='ERROR' + 'DELETE', '//livy:8998/batches/{}'.format(BATCH_ID), json={}, status_code=400, reason='ERROR' ) hook = LivyHook() @@ -394,11 +380,7 @@ def test_delete_batch_fail(self, mock): @requests_mock.mock() def test_missing_batch_id(self, mock): - mock.register_uri( - 'POST', '//livy:8998/batches', - json={}, - status_code=201 - ) + mock.register_uri('POST', '//livy:8998/batches', json={}, status_code=201) hook = LivyHook() with self.assertRaises(AirflowException): @@ -407,9 +389,7 @@ def test_missing_batch_id(self, mock): @requests_mock.mock() def test_get_batch_validation(self, mock): mock.register_uri( - 'GET', '//livy:8998/batches/{}'.format(BATCH_ID), - json=SAMPLE_GET_RESPONSE, - status_code=200 + 'GET', 
'//livy:8998/batches/{}'.format(BATCH_ID), json=SAMPLE_GET_RESPONSE, status_code=200 ) hook = LivyHook() @@ -425,9 +405,7 @@ def test_get_batch_validation(self, mock): @requests_mock.mock() def test_get_batch_state_validation(self, mock): mock.register_uri( - 'GET', '//livy:8998/batches/{}/state'.format(BATCH_ID), - json=SAMPLE_GET_RESPONSE, - status_code=200 + 'GET', '//livy:8998/batches/{}/state'.format(BATCH_ID), json=SAMPLE_GET_RESPONSE, status_code=200 ) hook = LivyHook() @@ -442,9 +420,7 @@ def test_get_batch_state_validation(self, mock): @requests_mock.mock() def test_delete_batch_validation(self, mock): mock.register_uri( - 'DELETE', '//livy:8998/batches/{}'.format(BATCH_ID), - json={'id': BATCH_ID}, - status_code=200 + 'DELETE', '//livy:8998/batches/{}'.format(BATCH_ID), json={'id': BATCH_ID}, status_code=200 ) hook = LivyHook() diff --git a/tests/providers/apache/livy/operators/test_livy.py b/tests/providers/apache/livy/operators/test_livy.py index 8c56d4654e5b1..86a0aaea706be 100644 --- a/tests/providers/apache/livy/operators/test_livy.py +++ b/tests/providers/apache/livy/operators/test_livy.py @@ -34,15 +34,13 @@ class TestLivyOperator(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) - db.merge_conn(Connection( - conn_id='livyunittest', conn_type='livy', - host='localhost:8998', port='8998', schema='http' - )) + db.merge_conn( + Connection( + conn_id='livyunittest', conn_type='livy', host='localhost:8998', port='8998', schema='http' + ) + ) @patch('airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state') def test_poll_for_termination(self, mock_livy): @@ -57,12 +55,7 @@ def side_effect(_): mock_livy.side_effect = side_effect - task = LivyOperator( - file='sparkapp', - polling_interval=1, - dag=self.dag, - task_id='livy_example' - ) + task = LivyOperator(file='sparkapp', polling_interval=1, dag=self.dag, task_id='livy_example') task._livy_hook = task.get_hook() task.poll_for_termination(BATCH_ID) @@ -82,12 +75,7 @@ def side_effect(_): mock_livy.side_effect = side_effect - task = LivyOperator( - file='sparkapp', - polling_interval=1, - dag=self.dag, - task_id='livy_example' - ) + task = LivyOperator(file='sparkapp', polling_interval=1, dag=self.dag, task_id='livy_example') task._livy_hook = task.get_hook() with self.assertRaises(AirflowException): @@ -96,8 +84,10 @@ def side_effect(_): mock_livy.assert_called_with(BATCH_ID) self.assertEqual(mock_livy.call_count, 3) - @patch('airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state', - return_value=BatchState.SUCCESS) + @patch( + 'airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state', + return_value=BatchState.SUCCESS, + ) @patch('airflow.providers.apache.livy.operators.livy.LivyHook.post_batch', return_value=BATCH_ID) def test_execution(self, mock_post, mock_get): task = LivyOperator( @@ -105,7 +95,7 @@ def test_execution(self, mock_post, mock_get): file='sparkapp', polling_interval=1, dag=self.dag, - task_id='livy_example' + task_id='livy_example', ) task.execute(context={}) @@ -117,10 +107,7 @@ def test_execution(self, mock_post, mock_get): @patch('airflow.providers.apache.livy.operators.livy.LivyHook.post_batch', return_value=BATCH_ID) def test_deletion(self, mock_post, mock_delete): task = LivyOperator( - livy_conn_id='livyunittest', - file='sparkapp', - dag=self.dag, - task_id='livy_example' + livy_conn_id='livyunittest', 
file='sparkapp', dag=self.dag, task_id='livy_example' ) task.execute(context={}) task.kill() @@ -130,11 +117,7 @@ def test_deletion(self, mock_post, mock_delete): def test_injected_hook(self): def_hook = LivyHook(livy_conn_id='livyunittest') - task = LivyOperator( - file='sparkapp', - dag=self.dag, - task_id='livy_example' - ) + task = LivyOperator(file='sparkapp', dag=self.dag, task_id='livy_example') task._livy_hook = def_hook self.assertEqual(task.get_hook(), def_hook) diff --git a/tests/providers/apache/livy/sensors/test_livy.py b/tests/providers/apache/livy/sensors/test_livy.py index 99f6d618ad667..8440dd9fa69ab 100644 --- a/tests/providers/apache/livy/sensors/test_livy.py +++ b/tests/providers/apache/livy/sensors/test_livy.py @@ -29,25 +29,15 @@ class TestLivySensor(unittest.TestCase): - def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) - db.merge_conn(Connection( - conn_id='livyunittest', conn_type='livy', - host='http://localhost:8998' - )) + db.merge_conn(Connection(conn_id='livyunittest', conn_type='livy', host='http://localhost:8998')) @patch('airflow.providers.apache.livy.hooks.livy.LivyHook.get_batch_state') def test_poke(self, mock_state): sensor = LivySensor( - livy_conn_id='livyunittest', - task_id='livy_sensor_test', - dag=self.dag, - batch_id=100 + livy_conn_id='livyunittest', task_id='livy_sensor_test', dag=self.dag, batch_id=100 ) for state in BatchState: diff --git a/tests/providers/apache/pig/hooks/test_pig.py b/tests/providers/apache/pig/hooks/test_pig.py index bab66cace61f2..5ff4e3bc6d3b9 100644 --- a/tests/providers/apache/pig/hooks/test_pig.py +++ b/tests/providers/apache/pig/hooks/test_pig.py @@ -24,7 +24,6 @@ class TestPigCliHook(unittest.TestCase): - def setUp(self): super().setUp() @@ -66,6 +65,7 @@ def test_run_cli_fail(self, popen_mock): hook = self.pig_hook() from airflow.exceptions import AirflowException + self.assertRaises(AirflowException, hook.run_cli, "") @mock.patch('subprocess.Popen') diff --git a/tests/providers/apache/pinot/hooks/test_pinot.py b/tests/providers/apache/pinot/hooks/test_pinot.py index 445d71cdcdab0..fcd90828bc9e2 100644 --- a/tests/providers/apache/pinot/hooks/test_pinot.py +++ b/tests/providers/apache/pinot/hooks/test_pinot.py @@ -28,7 +28,6 @@ class TestPinotAdminHook(unittest.TestCase): - def setUp(self): super().setUp() self.conn = conn = mock.MagicMock() @@ -46,19 +45,33 @@ def get_connection(self, conn_id): def test_add_schema(self, mock_run_cli): params = ["schema_file", False] self.db_hook.add_schema(*params) - mock_run_cli.assert_called_once_with(['AddSchema', - '-controllerHost', self.conn.host, - '-controllerPort', self.conn.port, - '-schemaFile', params[0]]) + mock_run_cli.assert_called_once_with( + [ + 'AddSchema', + '-controllerHost', + self.conn.host, + '-controllerPort', + self.conn.port, + '-schemaFile', + params[0], + ] + ) @mock.patch('airflow.providers.apache.pinot.hooks.pinot.PinotAdminHook.run_cli') def test_add_table(self, mock_run_cli): params = ["config_file", False] self.db_hook.add_table(*params) - mock_run_cli.assert_called_once_with(['AddTable', - '-controllerHost', self.conn.host, - '-controllerPort', self.conn.port, - '-filePath', params[0]]) + mock_run_cli.assert_called_once_with( + [ + 'AddTable', + '-controllerHost', + self.conn.host, + '-controllerPort', + self.conn.port, + '-filePath', + params[0], + ] + ) 
@mock.patch('airflow.providers.apache.pinot.hooks.pinot.PinotAdminHook.run_cli') def test_create_segment(self, mock_run_cli): @@ -85,34 +98,61 @@ def test_create_segment(self, mock_run_cli): self.db_hook.create_segment(**params) - mock_run_cli.assert_called_once_with([ - 'CreateSegment', - '-generatorConfigFile', params["generator_config_file"], - '-dataDir', params["data_dir"], - '-format', params["segment_format"], - '-outDir', params["out_dir"], - '-overwrite', params["overwrite"], - '-tableName', params["table_name"], - '-segmentName', params["segment_name"], - '-timeColumnName', params["time_column_name"], - '-schemaFile', params["schema_file"], - '-readerConfigFile', params["reader_config_file"], - '-starTreeIndexSpecFile', params["star_tree_index_spec_file"], - '-hllSize', params["hll_size"], - '-hllColumns', params["hll_columns"], - '-hllSuffix', params["hll_suffix"], - '-numThreads', params["num_threads"], - '-postCreationVerification', params["post_creation_verification"], - '-retry', params["retry"]]) + mock_run_cli.assert_called_once_with( + [ + 'CreateSegment', + '-generatorConfigFile', + params["generator_config_file"], + '-dataDir', + params["data_dir"], + '-format', + params["segment_format"], + '-outDir', + params["out_dir"], + '-overwrite', + params["overwrite"], + '-tableName', + params["table_name"], + '-segmentName', + params["segment_name"], + '-timeColumnName', + params["time_column_name"], + '-schemaFile', + params["schema_file"], + '-readerConfigFile', + params["reader_config_file"], + '-starTreeIndexSpecFile', + params["star_tree_index_spec_file"], + '-hllSize', + params["hll_size"], + '-hllColumns', + params["hll_columns"], + '-hllSuffix', + params["hll_suffix"], + '-numThreads', + params["num_threads"], + '-postCreationVerification', + params["post_creation_verification"], + '-retry', + params["retry"], + ] + ) @mock.patch('airflow.providers.apache.pinot.hooks.pinot.PinotAdminHook.run_cli') def test_upload_segment(self, mock_run_cli): params = ["segment_dir", False] self.db_hook.upload_segment(*params) - mock_run_cli.assert_called_once_with(['UploadSegment', - '-controllerHost', self.conn.host, - '-controllerPort', self.conn.port, - '-segmentDir', params[0]]) + mock_run_cli.assert_called_once_with( + [ + 'UploadSegment', + '-controllerHost', + self.conn.host, + '-controllerPort', + self.conn.port, + '-segmentDir', + params[0], + ] + ) @mock.patch('subprocess.Popen') def test_run_cli_success(self, mock_popen): @@ -124,11 +164,9 @@ def test_run_cli_success(self, mock_popen): params = ["foo", "bar", "baz"] self.db_hook.run_cli(params) params.insert(0, self.conn.extra_dejson.get('cmd_path')) - mock_popen.assert_called_once_with(params, - stderr=subprocess.STDOUT, - stdout=subprocess.PIPE, - close_fds=True, - env=None) + mock_popen.assert_called_once_with( + params, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, close_fds=True, env=None + ) @mock.patch('subprocess.Popen') def test_run_cli_failure_error_message(self, mock_popen): @@ -142,11 +180,9 @@ def test_run_cli_failure_error_message(self, mock_popen): with self.assertRaises(AirflowException, msg=msg): self.db_hook.run_cli(params) params.insert(0, self.conn.extra_dejson.get('cmd_path')) - mock_popen.assert_called_once_with(params, - stderr=subprocess.STDOUT, - stdout=subprocess.PIPE, - close_fds=True, - env=None) + mock_popen.assert_called_once_with( + params, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, close_fds=True, env=None + ) @mock.patch('subprocess.Popen') def 
test_run_cli_failure_status_code(self, mock_popen): @@ -162,15 +198,12 @@ def test_run_cli_failure_status_code(self, mock_popen): params.insert(0, self.conn.extra_dejson.get('cmd_path')) env = os.environ.copy() env.update({"JAVA_OPTS": "-Dpinot.admin.system.exit=true "}) - mock_popen.assert_called_once_with(params, - stderr=subprocess.STDOUT, - stdout=subprocess.PIPE, - close_fds=True, - env=env) + mock_popen.assert_called_once_with( + params, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, close_fds=True, env=env + ) class TestPinotDbApiHook(unittest.TestCase): - def setUp(self): super().setUp() self.conn = conn = mock.MagicMock() diff --git a/tests/providers/apache/spark/hooks/test_spark_jdbc.py b/tests/providers/apache/spark/hooks/test_spark_jdbc.py index 32ae4426414f4..bd80bccdd8072 100644 --- a/tests/providers/apache/spark/hooks/test_spark_jdbc.py +++ b/tests/providers/apache/spark/hooks/test_spark_jdbc.py @@ -40,7 +40,7 @@ class TestSparkJDBCHook(unittest.TestCase): 'lower_bound': '10', 'upper_bound': '20', 'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64),' - 'comments VARCHAR(1024)' + 'comments VARCHAR(1024)', } # this config is invalid because if one of [partitionColumn, lowerBound, upperBound] @@ -59,22 +59,28 @@ class TestSparkJDBCHook(unittest.TestCase): 'partition_column': 'columnMcColumnFace', 'upper_bound': '20', 'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64),' - 'comments VARCHAR(1024)' + 'comments VARCHAR(1024)', } def setUp(self): db.merge_conn( Connection( - conn_id='spark-default', conn_type='spark', + conn_id='spark-default', + conn_type='spark', host='yarn://yarn-master', - extra='{"queue": "root.etl", "deploy-mode": "cluster"}') + extra='{"queue": "root.etl", "deploy-mode": "cluster"}', + ) ) db.merge_conn( Connection( - conn_id='jdbc-default', conn_type='postgres', - host='localhost', schema='default', port=5432, - login='user', password='supersecret', - extra='{"conn_prefix":"jdbc:postgresql://"}' + conn_id='jdbc-default', + conn_type='postgres', + host='localhost', + schema='default', + port=5432, + login='user', + password='supersecret', + extra='{"conn_prefix":"jdbc:postgresql://"}', ) ) @@ -86,7 +92,7 @@ def test_resolve_jdbc_connection(self): 'schema': 'default', 'conn_prefix': 'jdbc:postgresql://', 'user': 'user', - 'password': 'supersecret' + 'password': 'supersecret', } # When @@ -104,23 +110,38 @@ def test_build_jdbc_arguments(self): # Then expected_jdbc_arguments = [ - '-cmdType', 'spark_to_jdbc', - '-url', 'jdbc:postgresql://localhost:5432/default', - '-user', 'user', - '-password', 'supersecret', - '-metastoreTable', 'hiveMcHiveFace', - '-jdbcTable', 'tableMcTableFace', - '-jdbcDriver', 'org.postgresql.Driver', - '-batchsize', '100', - '-fetchsize', '200', - '-numPartitions', '10', - '-partitionColumn', 'columnMcColumnFace', - '-lowerBound', '10', - '-upperBound', '20', - '-saveMode', 'append', - '-saveFormat', 'parquet', - '-createTableColumnTypes', 'columnMcColumnFace INTEGER(100), name CHAR(64),' - 'comments VARCHAR(1024)' + '-cmdType', + 'spark_to_jdbc', + '-url', + 'jdbc:postgresql://localhost:5432/default', + '-user', + 'user', + '-password', + 'supersecret', + '-metastoreTable', + 'hiveMcHiveFace', + '-jdbcTable', + 'tableMcTableFace', + '-jdbcDriver', + 'org.postgresql.Driver', + '-batchsize', + '100', + '-fetchsize', + '200', + '-numPartitions', + '10', + '-partitionColumn', + 'columnMcColumnFace', + '-lowerBound', + '10', + '-upperBound', + '20', + '-saveMode', + 'append', + 
'-saveFormat', + 'parquet', + '-createTableColumnTypes', + 'columnMcColumnFace INTEGER(100), name CHAR(64),comments VARCHAR(1024)', ] self.assertEqual(expected_jdbc_arguments, cmd) diff --git a/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py b/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py index 4ea7c3121664e..4be4a14aff9b1 100644 --- a/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py +++ b/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py @@ -19,8 +19,13 @@ from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter from airflow.providers.apache.spark.hooks.spark_jdbc_script import ( - SPARK_READ_FROM_JDBC, SPARK_WRITE_TO_JDBC, _create_spark_session, _parse_arguments, _run_spark, - spark_read_from_jdbc, spark_write_to_jdbc, + SPARK_READ_FROM_JDBC, + SPARK_WRITE_TO_JDBC, + _create_spark_session, + _parse_arguments, + _run_spark, + spark_read_from_jdbc, + spark_write_to_jdbc, ) @@ -32,25 +37,42 @@ def mock_spark_session(): class TestSparkJDBCScrip: jdbc_arguments = [ - '-cmdType', 'spark_to_jdbc', - '-url', 'jdbc:postgresql://localhost:5432/default', - '-user', 'user', - '-password', 'supersecret', - '-metastoreTable', 'hiveMcHiveFace', - '-jdbcTable', 'tableMcTableFace', - '-jdbcDriver', 'org.postgresql.Driver', - '-jdbcTruncate', 'false', - '-saveMode', 'append', - '-saveFormat', 'parquet', - '-batchsize', '100', - '-fetchsize', '200', - '-name', 'airflow-spark-jdbc-script-test', - '-numPartitions', '10', - '-partitionColumn', 'columnMcColumnFace', - '-lowerBound', '10', - '-upperBound', '20', - '-createTableColumnTypes', 'columnMcColumnFace INTEGER(100), name CHAR(64),' - 'comments VARCHAR(1024)' + '-cmdType', + 'spark_to_jdbc', + '-url', + 'jdbc:postgresql://localhost:5432/default', + '-user', + 'user', + '-password', + 'supersecret', + '-metastoreTable', + 'hiveMcHiveFace', + '-jdbcTable', + 'tableMcTableFace', + '-jdbcDriver', + 'org.postgresql.Driver', + '-jdbcTruncate', + 'false', + '-saveMode', + 'append', + '-saveFormat', + 'parquet', + '-batchsize', + '100', + '-fetchsize', + '200', + '-name', + 'airflow-spark-jdbc-script-test', + '-numPartitions', + '10', + '-partitionColumn', + 'columnMcColumnFace', + '-lowerBound', + '10', + '-upperBound', + '20', + '-createTableColumnTypes', + 'columnMcColumnFace INTEGER(100), name CHAR(64),comments VARCHAR(1024)', ] default_arguments = { @@ -72,7 +94,7 @@ class TestSparkJDBCScrip: 'lower_bound': '10', 'upper_bound': '20', 'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64),' - 'comments VARCHAR(1024)' + 'comments VARCHAR(1024)', } def test_parse_arguments(self): @@ -87,10 +109,7 @@ def test_parse_arguments(self): def test_run_spark_write_to_jdbc(self, mock_spark_write_to_jdbc, mock_spark_session): # Given arguments = _parse_arguments(['-cmdType', SPARK_WRITE_TO_JDBC] + self.jdbc_arguments[2:]) - spark_session = mock_spark_session.builder \ - .appName(arguments.name) \ - .enableHiveSupport() \ - .getOrCreate() + spark_session = mock_spark_session.builder.appName(arguments.name).enableHiveSupport().getOrCreate() # When _run_spark(arguments=arguments) @@ -115,10 +134,7 @@ def test_run_spark_write_to_jdbc(self, mock_spark_write_to_jdbc, mock_spark_sess def test_run_spark_read_from_jdbc(self, mock_spark_read_from_jdbc, mock_spark_session): # Given arguments = _parse_arguments(['-cmdType', SPARK_READ_FROM_JDBC] + self.jdbc_arguments[2:]) - spark_session = mock_spark_session.builder \ - .appName(arguments.name) \ - .enableHiveSupport() \ - .getOrCreate() + 
spark_session = mock_spark_session.builder.appName(arguments.name).enableHiveSupport().getOrCreate() # When _run_spark(arguments=arguments) @@ -138,7 +154,7 @@ def test_run_spark_read_from_jdbc(self, mock_spark_read_from_jdbc, mock_spark_se arguments.num_partitions, arguments.partition_column, arguments.lower_bound, - arguments.upper_bound + arguments.upper_bound, ) @pytest.mark.system("spark") diff --git a/tests/providers/apache/spark/hooks/test_spark_sql.py b/tests/providers/apache/spark/hooks/test_spark_sql.py index 882936d31269c..adc420cf30b60 100644 --- a/tests/providers/apache/spark/hooks/test_spark_sql.py +++ b/tests/providers/apache/spark/hooks/test_spark_sql.py @@ -44,16 +44,12 @@ class TestSparkSqlHook(unittest.TestCase): 'num_executors': 10, 'verbose': True, 'sql': ' /path/to/sql/file.sql ', - 'conf': 'key=value,PROP=VALUE' + 'conf': 'key=value,PROP=VALUE', } def setUp(self): - db.merge_conn( - Connection( - conn_id='spark_default', conn_type='spark', - host='yarn://yarn-master') - ) + db.merge_conn(Connection(conn_id='spark_default', conn_type='spark', host='yarn://yarn-master')) def test_build_command(self): hook = SparkSqlHook(**self._config) @@ -86,27 +82,46 @@ def test_spark_process_runcmd(self, mock_popen): mock_popen.return_value.wait.return_value = 0 # When - hook = SparkSqlHook( - conn_id='spark_default', - sql='SELECT 1' - ) + hook = SparkSqlHook(conn_id='spark_default', sql='SELECT 1') with patch.object(hook.log, 'debug') as mock_debug: with patch.object(hook.log, 'info') as mock_info: hook.run_query() mock_debug.assert_called_once_with( 'Spark-Sql cmd: %s', - ['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose', - '--queue', 'default'] - ) - mock_info.assert_called_once_with( - 'Spark-sql communicates using stdout' + [ + 'spark-sql', + '-e', + 'SELECT 1', + '--master', + 'yarn', + '--name', + 'default-name', + '--verbose', + '--queue', + 'default', + ], ) + mock_info.assert_called_once_with('Spark-sql communicates using stdout') # Then self.assertEqual( mock_popen.mock_calls[0], - call(['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose', - '--queue', 'default'], stderr=-2, stdout=-1) + call( + [ + 'spark-sql', + '-e', + 'SELECT 1', + '--master', + 'yarn', + '--name', + 'default-name', + '--verbose', + '--queue', + 'default', + ], + stderr=-2, + stdout=-1, + ), ) @patch('airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen') @@ -115,17 +130,30 @@ def test_spark_process_runcmd_with_str(self, mock_popen): mock_popen.return_value.wait.return_value = 0 # When - hook = SparkSqlHook( - conn_id='spark_default', - sql='SELECT 1' - ) + hook = SparkSqlHook(conn_id='spark_default', sql='SELECT 1') hook.run_query('--deploy-mode cluster') # Then self.assertEqual( mock_popen.mock_calls[0], - call(['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose', - '--queue', 'default', '--deploy-mode', 'cluster'], stderr=-2, stdout=-1) + call( + [ + 'spark-sql', + '-e', + 'SELECT 1', + '--master', + 'yarn', + '--name', + 'default-name', + '--verbose', + '--queue', + 'default', + '--deploy-mode', + 'cluster', + ], + stderr=-2, + stdout=-1, + ), ) @patch('airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen') @@ -134,17 +162,30 @@ def test_spark_process_runcmd_with_list(self, mock_popen): mock_popen.return_value.wait.return_value = 0 # When - hook = SparkSqlHook( - conn_id='spark_default', - sql='SELECT 1' - ) + hook = SparkSqlHook(conn_id='spark_default', 
sql='SELECT 1') hook.run_query(['--deploy-mode', 'cluster']) # Then self.assertEqual( mock_popen.mock_calls[0], - call(['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose', - '--queue', 'default', '--deploy-mode', 'cluster'], stderr=-2, stdout=-1) + call( + [ + 'spark-sql', + '-e', + 'SELECT 1', + '--master', + 'yarn', + '--name', + 'default-name', + '--verbose', + '--queue', + 'default', + '--deploy-mode', + 'cluster', + ], + stderr=-2, + stdout=-1, + ), ) @patch('airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen') @@ -158,11 +199,7 @@ def test_spark_process_runcmd_and_fail(self, mock_popen): # When with self.assertRaises(AirflowException) as e: - hook = SparkSqlHook( - conn_id='spark_default', - sql=sql, - master=master, - ) + hook = SparkSqlHook(conn_id='spark_default', sql=sql, master=master,) hook.run_query(params) # Then @@ -170,5 +207,5 @@ def test_spark_process_runcmd_and_fail(self, mock_popen): str(e.exception), "Cannot execute '{}' on {} (additional parameters: '{}'). Process exit code: {}.".format( sql, master, params, status - ) + ), ) diff --git a/tests/providers/apache/spark/hooks/test_spark_submit.py b/tests/providers/apache/spark/hooks/test_spark_submit.py index 3879c8bda2c24..f5918b25aa435 100644 --- a/tests/providers/apache/spark/hooks/test_spark_submit.py +++ b/tests/providers/apache/spark/hooks/test_spark_submit.py @@ -33,9 +33,7 @@ class TestSparkSubmitHook(unittest.TestCase): _spark_job_file = 'test_application.py' _config = { - 'conf': { - 'parquet.compression': 'SNAPPY' - }, + 'conf': {'parquet.compression': 'SNAPPY'}, 'conn_id': 'default_spark', 'files': 'hive-site.xml', 'py_files': 'sample_library.py', @@ -56,11 +54,14 @@ class TestSparkSubmitHook(unittest.TestCase): 'driver_memory': '3g', 'java_class': 'com.foo.bar.AppMain', 'application_args': [ - '-f', 'foo', - '--bar', 'bar', - '--with-spaces', 'args should keep embdedded spaces', - 'baz' - ] + '-f', + 'foo', + '--bar', + 'bar', + '--with-spaces', + 'args should keep embdedded spaces', + 'baz', + ], } @staticmethod @@ -75,59 +76,67 @@ def cmd_args_to_dict(list_cmd): def setUp(self): db.merge_conn( Connection( - conn_id='spark_yarn_cluster', conn_type='spark', + conn_id='spark_yarn_cluster', + conn_type='spark', host='yarn://yarn-master', - extra='{"queue": "root.etl", "deploy-mode": "cluster"}') + extra='{"queue": "root.etl", "deploy-mode": "cluster"}', + ) ) db.merge_conn( Connection( - conn_id='spark_k8s_cluster', conn_type='spark', + conn_id='spark_k8s_cluster', + conn_type='spark', host='k8s://https://k8s-master', - extra='{"spark-home": "/opt/spark", ' + - '"deploy-mode": "cluster", ' + - '"namespace": "mynamespace"}') + extra='{"spark-home": "/opt/spark", ' + + '"deploy-mode": "cluster", ' + + '"namespace": "mynamespace"}', + ) ) db.merge_conn( - Connection( - conn_id='spark_default_mesos', conn_type='spark', - host='mesos://host', port=5050) + Connection(conn_id='spark_default_mesos', conn_type='spark', host='mesos://host', port=5050) ) db.merge_conn( Connection( - conn_id='spark_home_set', conn_type='spark', + conn_id='spark_home_set', + conn_type='spark', host='yarn://yarn-master', - extra='{"spark-home": "/opt/myspark"}') + extra='{"spark-home": "/opt/myspark"}', + ) ) + db.merge_conn(Connection(conn_id='spark_home_not_set', conn_type='spark', host='yarn://yarn-master')) db.merge_conn( Connection( - conn_id='spark_home_not_set', conn_type='spark', - host='yarn://yarn-master') - ) - db.merge_conn( - Connection( - conn_id='spark_binary_set', 
conn_type='spark', - host='yarn', extra='{"spark-binary": "custom-spark-submit"}') + conn_id='spark_binary_set', + conn_type='spark', + host='yarn', + extra='{"spark-binary": "custom-spark-submit"}', + ) ) db.merge_conn( Connection( - conn_id='spark_binary_and_home_set', conn_type='spark', + conn_id='spark_binary_and_home_set', + conn_type='spark', host='yarn', - extra='{"spark-home": "/path/to/spark_home", ' + - '"spark-binary": "custom-spark-submit"}') + extra='{"spark-home": "/path/to/spark_home", ' + '"spark-binary": "custom-spark-submit"}', + ) ) db.merge_conn( Connection( - conn_id='spark_standalone_cluster', conn_type='spark', + conn_id='spark_standalone_cluster', + conn_type='spark', host='spark://spark-standalone-master:6066', - extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "cluster"}') + extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "cluster"}', + ) ) db.merge_conn( Connection( - conn_id='spark_standalone_cluster_client_mode', conn_type='spark', + conn_id='spark_standalone_cluster_client_mode', + conn_type='spark', host='spark://spark-standalone-master:6066', - extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "client"}') + extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "client"}', + ) ) def test_build_spark_submit_command(self): @@ -140,31 +149,53 @@ def test_build_spark_submit_command(self): # Then expected_build_cmd = [ 'spark-submit', - '--master', 'yarn', - '--conf', 'parquet.compression=SNAPPY', - '--files', 'hive-site.xml', - '--py-files', 'sample_library.py', - '--archives', 'sample_archive.zip#SAMPLE', - '--jars', 'parquet.jar', - '--packages', 'com.databricks:spark-avro_2.11:3.2.0', - '--exclude-packages', 'org.bad.dependency:1.0.0', - '--repositories', 'http://myrepo.org', - '--num-executors', '10', - '--total-executor-cores', '4', - '--executor-cores', '4', - '--executor-memory', '22g', - '--driver-memory', '3g', - '--keytab', 'privileged_user.keytab', - '--principal', 'user/spark@airflow.org', - '--proxy-user', 'sample_user', - '--name', 'spark-job', - '--class', 'com.foo.bar.AppMain', + '--master', + 'yarn', + '--conf', + 'parquet.compression=SNAPPY', + '--files', + 'hive-site.xml', + '--py-files', + 'sample_library.py', + '--archives', + 'sample_archive.zip#SAMPLE', + '--jars', + 'parquet.jar', + '--packages', + 'com.databricks:spark-avro_2.11:3.2.0', + '--exclude-packages', + 'org.bad.dependency:1.0.0', + '--repositories', + 'http://myrepo.org', + '--num-executors', + '10', + '--total-executor-cores', + '4', + '--executor-cores', + '4', + '--executor-memory', + '22g', + '--driver-memory', + '3g', + '--keytab', + 'privileged_user.keytab', + '--principal', + 'user/spark@airflow.org', + '--proxy-user', + 'sample_user', + '--name', + 'spark-job', + '--class', + 'com.foo.bar.AppMain', '--verbose', 'test_application.py', - '-f', 'foo', - '--bar', 'bar', - '--with-spaces', 'args should keep embdedded spaces', - 'baz' + '-f', + 'foo', + '--bar', + 'bar', + '--with-spaces', + 'args should keep embdedded spaces', + 'baz', ] self.assertEqual(expected_build_cmd, cmd) @@ -173,27 +204,33 @@ def test_build_track_driver_status_command(self): # 'spark://' in self._connection['master'] and self._connection['deploy_mode'] == 'cluster' # Given - hook_spark_standalone_cluster = SparkSubmitHook( - conn_id='spark_standalone_cluster') + hook_spark_standalone_cluster = SparkSubmitHook(conn_id='spark_standalone_cluster') hook_spark_standalone_cluster._driver_id = 'driver-20171128111416-0001' - hook_spark_yarn_cluster = SparkSubmitHook( - 
conn_id='spark_yarn_cluster') + hook_spark_yarn_cluster = SparkSubmitHook(conn_id='spark_yarn_cluster') hook_spark_yarn_cluster._driver_id = 'driver-20171128111417-0001' # When - build_track_driver_status_spark_standalone_cluster = \ + build_track_driver_status_spark_standalone_cluster = ( hook_spark_standalone_cluster._build_track_driver_status_command() - build_track_driver_status_spark_yarn_cluster = \ + ) + build_track_driver_status_spark_yarn_cluster = ( hook_spark_yarn_cluster._build_track_driver_status_command() + ) # Then expected_spark_standalone_cluster = [ '/usr/bin/curl', '--max-time', '30', - 'http://spark-standalone-master:6066/v1/submissions/status/driver-20171128111416-0001'] + 'http://spark-standalone-master:6066/v1/submissions/status/driver-20171128111416-0001', + ] expected_spark_yarn_cluster = [ - 'spark-submit', '--master', 'yarn://yarn-master', '--status', 'driver-20171128111417-0001'] + 'spark-submit', + '--master', + 'yarn://yarn-master', + '--status', + 'driver-20171128111417-0001', + ] assert expected_spark_standalone_cluster == build_track_driver_status_spark_standalone_cluster assert expected_spark_yarn_cluster == build_track_driver_status_spark_yarn_cluster @@ -210,10 +247,16 @@ def test_spark_process_runcmd(self, mock_popen): hook.submit() # Then - self.assertEqual(mock_popen.mock_calls[0], - call(['spark-submit', '--master', 'yarn', - '--name', 'default-name', ''], - stderr=-2, stdout=-1, universal_newlines=True, bufsize=-1)) + self.assertEqual( + mock_popen.mock_calls[0], + call( + ['spark-submit', '--master', 'yarn', '--name', 'default-name', ''], + stderr=-2, + stdout=-1, + universal_newlines=True, + bufsize=-1, + ), + ) def test_resolve_should_track_driver_status(self): # Given @@ -224,30 +267,33 @@ def test_resolve_should_track_driver_status(self): hook_spark_home_set = SparkSubmitHook(conn_id='spark_home_set') hook_spark_home_not_set = SparkSubmitHook(conn_id='spark_home_not_set') hook_spark_binary_set = SparkSubmitHook(conn_id='spark_binary_set') - hook_spark_binary_and_home_set = SparkSubmitHook( - conn_id='spark_binary_and_home_set') - hook_spark_standalone_cluster = SparkSubmitHook( - conn_id='spark_standalone_cluster') + hook_spark_binary_and_home_set = SparkSubmitHook(conn_id='spark_binary_and_home_set') + hook_spark_standalone_cluster = SparkSubmitHook(conn_id='spark_standalone_cluster') # When - should_track_driver_status_default = hook_default \ - ._resolve_should_track_driver_status() - should_track_driver_status_spark_yarn_cluster = hook_spark_yarn_cluster \ - ._resolve_should_track_driver_status() - should_track_driver_status_spark_k8s_cluster = hook_spark_k8s_cluster \ - ._resolve_should_track_driver_status() - should_track_driver_status_spark_default_mesos = hook_spark_default_mesos \ - ._resolve_should_track_driver_status() - should_track_driver_status_spark_home_set = hook_spark_home_set \ - ._resolve_should_track_driver_status() - should_track_driver_status_spark_home_not_set = hook_spark_home_not_set \ - ._resolve_should_track_driver_status() - should_track_driver_status_spark_binary_set = hook_spark_binary_set \ - ._resolve_should_track_driver_status() - should_track_driver_status_spark_binary_and_home_set = \ + should_track_driver_status_default = hook_default._resolve_should_track_driver_status() + should_track_driver_status_spark_yarn_cluster = ( + hook_spark_yarn_cluster._resolve_should_track_driver_status() + ) + should_track_driver_status_spark_k8s_cluster = ( + hook_spark_k8s_cluster._resolve_should_track_driver_status() + ) + 
should_track_driver_status_spark_default_mesos = ( + hook_spark_default_mesos._resolve_should_track_driver_status() + ) + should_track_driver_status_spark_home_set = hook_spark_home_set._resolve_should_track_driver_status() + should_track_driver_status_spark_home_not_set = ( + hook_spark_home_not_set._resolve_should_track_driver_status() + ) + should_track_driver_status_spark_binary_set = ( + hook_spark_binary_set._resolve_should_track_driver_status() + ) + should_track_driver_status_spark_binary_and_home_set = ( hook_spark_binary_and_home_set._resolve_should_track_driver_status() - should_track_driver_status_spark_standalone_cluster = \ + ) + should_track_driver_status_spark_standalone_cluster = ( hook_spark_standalone_cluster._resolve_should_track_driver_status() + ) # Then self.assertEqual(should_track_driver_status_default, False) @@ -270,12 +316,14 @@ def test_resolve_connection_yarn_default(self): # Then dict_cmd = self.cmd_args_to_dict(cmd) - expected_spark_connection = {"master": "yarn", - "spark_binary": "spark-submit", - "deploy_mode": None, - "queue": None, - "spark_home": None, - "namespace": None} + expected_spark_connection = { + "master": "yarn", + "spark_binary": "spark-submit", + "deploy_mode": None, + "queue": None, + "spark_home": None, + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(dict_cmd["--master"], "yarn") @@ -289,12 +337,14 @@ def test_resolve_connection_yarn_default_connection(self): # Then dict_cmd = self.cmd_args_to_dict(cmd) - expected_spark_connection = {"master": "yarn", - "spark_binary": "spark-submit", - "deploy_mode": None, - "queue": "root.default", - "spark_home": None, - "namespace": None} + expected_spark_connection = { + "master": "yarn", + "spark_binary": "spark-submit", + "deploy_mode": None, + "queue": "root.default", + "spark_home": None, + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(dict_cmd["--master"], "yarn") self.assertEqual(dict_cmd["--queue"], "root.default") @@ -309,12 +359,14 @@ def test_resolve_connection_mesos_default_connection(self): # Then dict_cmd = self.cmd_args_to_dict(cmd) - expected_spark_connection = {"master": "mesos://host:5050", - "spark_binary": "spark-submit", - "deploy_mode": None, - "queue": None, - "spark_home": None, - "namespace": None} + expected_spark_connection = { + "master": "mesos://host:5050", + "spark_binary": "spark-submit", + "deploy_mode": None, + "queue": None, + "spark_home": None, + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(dict_cmd["--master"], "mesos://host:5050") @@ -328,12 +380,14 @@ def test_resolve_connection_spark_yarn_cluster_connection(self): # Then dict_cmd = self.cmd_args_to_dict(cmd) - expected_spark_connection = {"master": "yarn://yarn-master", - "spark_binary": "spark-submit", - "deploy_mode": "cluster", - "queue": "root.etl", - "spark_home": None, - "namespace": None} + expected_spark_connection = { + "master": "yarn://yarn-master", + "spark_binary": "spark-submit", + "deploy_mode": "cluster", + "queue": "root.etl", + "spark_home": None, + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(dict_cmd["--master"], "yarn://yarn-master") self.assertEqual(dict_cmd["--queue"], "root.etl") @@ -349,12 +403,14 @@ def test_resolve_connection_spark_k8s_cluster_connection(self): # Then dict_cmd = self.cmd_args_to_dict(cmd) - expected_spark_connection = {"spark_home": "/opt/spark", - "queue": None, - 
"spark_binary": "spark-submit", - "master": "k8s://https://k8s-master", - "deploy_mode": "cluster", - "namespace": "mynamespace"} + expected_spark_connection = { + "spark_home": "/opt/spark", + "queue": None, + "spark_binary": "spark-submit", + "master": "k8s://https://k8s-master", + "deploy_mode": "cluster", + "namespace": "mynamespace", + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(dict_cmd["--master"], "k8s://https://k8s-master") self.assertEqual(dict_cmd["--deploy-mode"], "cluster") @@ -372,12 +428,14 @@ def test_resolve_connection_spark_k8s_cluster_ns_conf(self): # Then dict_cmd = self.cmd_args_to_dict(cmd) - expected_spark_connection = {"spark_home": "/opt/spark", - "queue": None, - "spark_binary": "spark-submit", - "master": "k8s://https://k8s-master", - "deploy_mode": "cluster", - "namespace": "airflow"} + expected_spark_connection = { + "spark_home": "/opt/spark", + "queue": None, + "spark_binary": "spark-submit", + "master": "k8s://https://k8s-master", + "deploy_mode": "cluster", + "namespace": "airflow", + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(dict_cmd["--master"], "k8s://https://k8s-master") self.assertEqual(dict_cmd["--deploy-mode"], "cluster") @@ -392,12 +450,14 @@ def test_resolve_connection_spark_home_set_connection(self): cmd = hook._build_spark_submit_command(self._spark_job_file) # Then - expected_spark_connection = {"master": "yarn://yarn-master", - "spark_binary": "spark-submit", - "deploy_mode": None, - "queue": None, - "spark_home": "/opt/myspark", - "namespace": None} + expected_spark_connection = { + "master": "yarn://yarn-master", + "spark_binary": "spark-submit", + "deploy_mode": None, + "queue": None, + "spark_home": "/opt/myspark", + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(cmd[0], '/opt/myspark/bin/spark-submit') @@ -410,12 +470,14 @@ def test_resolve_connection_spark_home_not_set_connection(self): cmd = hook._build_spark_submit_command(self._spark_job_file) # Then - expected_spark_connection = {"master": "yarn://yarn-master", - "spark_binary": "spark-submit", - "deploy_mode": None, - "queue": None, - "spark_home": None, - "namespace": None} + expected_spark_connection = { + "master": "yarn://yarn-master", + "spark_binary": "spark-submit", + "deploy_mode": None, + "queue": None, + "spark_home": None, + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(cmd[0], 'spark-submit') @@ -428,31 +490,34 @@ def test_resolve_connection_spark_binary_set_connection(self): cmd = hook._build_spark_submit_command(self._spark_job_file) # Then - expected_spark_connection = {"master": "yarn", - "spark_binary": "custom-spark-submit", - "deploy_mode": None, - "queue": None, - "spark_home": None, - "namespace": None} + expected_spark_connection = { + "master": "yarn", + "spark_binary": "custom-spark-submit", + "deploy_mode": None, + "queue": None, + "spark_home": None, + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(cmd[0], 'custom-spark-submit') def test_resolve_connection_spark_binary_default_value_override(self): # Given - hook = SparkSubmitHook(conn_id='spark_binary_set', - spark_binary='another-custom-spark-submit') + hook = SparkSubmitHook(conn_id='spark_binary_set', spark_binary='another-custom-spark-submit') # When connection = hook._resolve_connection() cmd = hook._build_spark_submit_command(self._spark_job_file) # Then - expected_spark_connection = 
{"master": "yarn", - "spark_binary": "another-custom-spark-submit", - "deploy_mode": None, - "queue": None, - "spark_home": None, - "namespace": None} + expected_spark_connection = { + "master": "yarn", + "spark_binary": "another-custom-spark-submit", + "deploy_mode": None, + "queue": None, + "spark_home": None, + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(cmd[0], 'another-custom-spark-submit') @@ -465,12 +530,14 @@ def test_resolve_connection_spark_binary_default_value(self): cmd = hook._build_spark_submit_command(self._spark_job_file) # Then - expected_spark_connection = {"master": "yarn", - "spark_binary": "spark-submit", - "deploy_mode": None, - "queue": 'root.default', - "spark_home": None, - "namespace": None} + expected_spark_connection = { + "master": "yarn", + "spark_binary": "spark-submit", + "deploy_mode": None, + "queue": 'root.default', + "spark_home": None, + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(cmd[0], 'spark-submit') @@ -483,12 +550,14 @@ def test_resolve_connection_spark_binary_and_home_set_connection(self): cmd = hook._build_spark_submit_command(self._spark_job_file) # Then - expected_spark_connection = {"master": "yarn", - "spark_binary": "custom-spark-submit", - "deploy_mode": None, - "queue": None, - "spark_home": "/path/to/spark_home", - "namespace": None} + expected_spark_connection = { + "master": "yarn", + "spark_binary": "custom-spark-submit", + "deploy_mode": None, + "queue": None, + "spark_home": "/path/to/spark_home", + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(cmd[0], '/path/to/spark_home/bin/custom-spark-submit') @@ -501,19 +570,20 @@ def test_resolve_connection_spark_standalone_cluster_connection(self): cmd = hook._build_spark_submit_command(self._spark_job_file) # Then - expected_spark_connection = {"master": "spark://spark-standalone-master:6066", - "spark_binary": "spark-submit", - "deploy_mode": "cluster", - "queue": None, - "spark_home": "/path/to/spark_home", - "namespace": None} + expected_spark_connection = { + "master": "spark://spark-standalone-master:6066", + "spark_binary": "spark-submit", + "deploy_mode": "cluster", + "queue": None, + "spark_home": "/path/to/spark_home", + "namespace": None, + } self.assertEqual(connection, expected_spark_connection) self.assertEqual(cmd[0], '/path/to/spark_home/bin/spark-submit') def test_resolve_spark_submit_env_vars_standalone_client_mode(self): # Given - hook = SparkSubmitHook(conn_id='spark_standalone_cluster_client_mode', - env_vars={"bar": "foo"}) + hook = SparkSubmitHook(conn_id='spark_standalone_cluster_client_mode', env_vars={"bar": "foo"}) # When hook._build_spark_submit_command(self._spark_job_file) @@ -522,23 +592,19 @@ def test_resolve_spark_submit_env_vars_standalone_client_mode(self): self.assertEqual(hook._env, {"bar": "foo"}) def test_resolve_spark_submit_env_vars_standalone_cluster_mode(self): - def env_vars_exception_in_standalone_cluster_mode(): # Given - hook = SparkSubmitHook(conn_id='spark_standalone_cluster', - env_vars={"bar": "foo"}) + hook = SparkSubmitHook(conn_id='spark_standalone_cluster', env_vars={"bar": "foo"}) # When hook._build_spark_submit_command(self._spark_job_file) # Then - self.assertRaises(AirflowException, - env_vars_exception_in_standalone_cluster_mode) + self.assertRaises(AirflowException, env_vars_exception_in_standalone_cluster_mode) def test_resolve_spark_submit_env_vars_yarn(self): # Given - hook = 
SparkSubmitHook(conn_id='spark_yarn_cluster', - env_vars={"bar": "foo"}) + hook = SparkSubmitHook(conn_id='spark_yarn_cluster', env_vars={"bar": "foo"}) # When cmd = hook._build_spark_submit_command(self._spark_job_file) @@ -549,8 +615,7 @@ def test_resolve_spark_submit_env_vars_yarn(self): def test_resolve_spark_submit_env_vars_k8s(self): # Given - hook = SparkSubmitHook(conn_id='spark_k8s_cluster', - env_vars={"bar": "foo"}) + hook = SparkSubmitHook(conn_id='spark_k8s_cluster', env_vars={"bar": "foo"}) # When cmd = hook._build_spark_submit_command(self._spark_job_file) @@ -563,13 +628,12 @@ def test_process_spark_submit_log_yarn(self): hook = SparkSubmitHook(conn_id='spark_yarn_cluster') log_lines = [ 'SPARK_MAJOR_VERSION is set to 2, using Spark2', - 'WARN NativeCodeLoader: Unable to load native-hadoop library for your ' + - 'platform... using builtin-java classes where applicable', + 'WARN NativeCodeLoader: Unable to load native-hadoop library for your ' + + 'platform... using builtin-java classes where applicable', 'WARN DomainSocketFactory: The short-circuit local reads feature cannot ' 'be used because libhadoop cannot be loaded.', 'INFO Client: Requesting a new application from cluster with 10 NodeManagers', - 'INFO Client: Submitting application application_1486558679801_1820 ' + - 'to ResourceManager' + 'INFO Client: Submitting application application_1486558679801_1820 ' + 'to ResourceManager', ] # When hook._process_spark_submit_log(log_lines) @@ -582,42 +646,39 @@ def test_process_spark_submit_log_k8s(self): # Given hook = SparkSubmitHook(conn_id='spark_k8s_cluster') log_lines = [ - 'INFO LoggingPodStatusWatcherImpl:54 - State changed, new state:' + - 'pod name: spark-pi-edf2ace37be7353a958b38733a12f8e6-driver' + - 'namespace: default' + - 'labels: spark-app-selector -> spark-465b868ada474bda82ccb84ab2747fcd,' + - 'spark-role -> driver' + - 'pod uid: ba9c61f6-205f-11e8-b65f-d48564c88e42' + - 'creation time: 2018-03-05T10:26:55Z' + - 'service account name: spark' + - 'volumes: spark-init-properties, download-jars-volume,' + - 'download-files-volume, spark-token-2vmlm' + - 'node name: N/A' + - 'start time: N/A' + - 'container images: N/A' + - 'phase: Pending' + - 'status: []' + - '2018-03-05 11:26:56 INFO LoggingPodStatusWatcherImpl:54 - State changed,' + - ' new state:' + - 'pod name: spark-pi-edf2ace37be7353a958b38733a12f8e6-driver' + - 'namespace: default' + - 'Exit code: 999' + 'INFO LoggingPodStatusWatcherImpl:54 - State changed, new state:' + + 'pod name: spark-pi-edf2ace37be7353a958b38733a12f8e6-driver' + + 'namespace: default' + + 'labels: spark-app-selector -> spark-465b868ada474bda82ccb84ab2747fcd,' + + 'spark-role -> driver' + + 'pod uid: ba9c61f6-205f-11e8-b65f-d48564c88e42' + + 'creation time: 2018-03-05T10:26:55Z' + + 'service account name: spark' + + 'volumes: spark-init-properties, download-jars-volume,' + + 'download-files-volume, spark-token-2vmlm' + + 'node name: N/A' + + 'start time: N/A' + + 'container images: N/A' + + 'phase: Pending' + + 'status: []' + + '2018-03-05 11:26:56 INFO LoggingPodStatusWatcherImpl:54 - State changed,' + + ' new state:' + + 'pod name: spark-pi-edf2ace37be7353a958b38733a12f8e6-driver' + + 'namespace: default' + + 'Exit code: 999' ] # When hook._process_spark_submit_log(log_lines) # Then - self.assertEqual(hook._kubernetes_driver_pod, - 'spark-pi-edf2ace37be7353a958b38733a12f8e6-driver') + self.assertEqual(hook._kubernetes_driver_pod, 'spark-pi-edf2ace37be7353a958b38733a12f8e6-driver') self.assertEqual(hook._spark_exit_code, 999) 
def test_process_spark_submit_log_k8s_spark_3(self): # Given hook = SparkSubmitHook(conn_id='spark_k8s_cluster') - log_lines = [ - 'exit code: 999' - ] + log_lines = ['exit code: 999'] # When hook._process_spark_submit_log(log_lines) @@ -632,8 +693,8 @@ def test_process_spark_submit_log_standalone_cluster(self): 'Running Spark using the REST application submission protocol.', '17/11/28 11:14:15 INFO RestSubmissionClient: Submitting a request ' 'to launch an application in spark://spark-standalone-master:6066', - '17/11/28 11:14:15 INFO RestSubmissionClient: Submission successfully ' + - 'created as driver-20171128111415-0001. Polling submission state...' + '17/11/28 11:14:15 INFO RestSubmissionClient: Submission successfully ' + + 'created as driver-20171128111415-0001. Polling submission state...', ] # When hook._process_spark_submit_log(log_lines) @@ -646,10 +707,10 @@ def test_process_spark_driver_status_log(self): # Given hook = SparkSubmitHook(conn_id='spark_standalone_cluster') log_lines = [ - 'Submitting a request for the status of submission ' + - 'driver-20171128111415-0001 in spark://spark-standalone-master:6066', - '17/11/28 11:15:37 INFO RestSubmissionClient: Server responded with ' + - 'SubmissionStatusResponse:', + 'Submitting a request for the status of submission ' + + 'driver-20171128111415-0001 in spark://spark-standalone-master:6066', + '17/11/28 11:15:37 INFO RestSubmissionClient: Server responded with ' + + 'SubmissionStatusResponse:', '{', '"action" : "SubmissionStatusResponse",', '"driverState" : "RUNNING",', @@ -658,7 +719,7 @@ def test_process_spark_driver_status_log(self): '"success" : true,', '"workerHostPort" : "172.18.0.7:38561",', '"workerId" : "worker-20171128110741-172.18.0.7-38561"', - '}' + '}', ] # When hook._process_spark_status_log(log_lines) @@ -677,14 +738,13 @@ def test_yarn_process_on_kill(self, mock_popen, mock_renew_from_kt): mock_popen.return_value.wait.return_value = 0 log_lines = [ 'SPARK_MAJOR_VERSION is set to 2, using Spark2', - 'WARN NativeCodeLoader: Unable to load native-hadoop library for your ' + - 'platform... using builtin-java classes where applicable', - 'WARN DomainSocketFactory: The short-circuit local reads feature cannot ' + - 'be used because libhadoop cannot be loaded.', - 'INFO Client: Requesting a new application from cluster with 10 ' + - 'NodeManagerapplication_1486558679801_1820s', - 'INFO Client: Submitting application application_1486558679801_1820 ' + - 'to ResourceManager' + 'WARN NativeCodeLoader: Unable to load native-hadoop library for your ' + + 'platform... 
using builtin-java classes where applicable', + 'WARN DomainSocketFactory: The short-circuit local reads feature cannot ' + + 'be used because libhadoop cannot be loaded.', + 'INFO Client: Requesting a new application from cluster with 10 ' + + 'NodeManagerapplication_1486558679801_1820s', + 'INFO Client: Submitting application application_1486558679801_1820 ' + 'to ResourceManager', ] hook = SparkSubmitHook(conn_id='spark_yarn_cluster') hook._process_spark_submit_log(log_lines) @@ -694,15 +754,21 @@ def test_yarn_process_on_kill(self, mock_popen, mock_renew_from_kt): hook.on_kill() # Then - self.assertIn(call(['yarn', 'application', '-kill', - 'application_1486558679801_1820'], - env=None, stderr=-1, stdout=-1), - mock_popen.mock_calls) + self.assertIn( + call( + ['yarn', 'application', '-kill', 'application_1486558679801_1820'], + env=None, + stderr=-1, + stdout=-1, + ), + mock_popen.mock_calls, + ) # resetting the mock to test kill with keytab & principal mock_popen.reset_mock() # Given - hook = SparkSubmitHook(conn_id='spark_yarn_cluster', keytab='privileged_user.keytab', - principal='user/spark@airflow.org') + hook = SparkSubmitHook( + conn_id='spark_yarn_cluster', keytab='privileged_user.keytab', principal='user/spark@airflow.org' + ) hook._process_spark_submit_log(log_lines) hook.submit() @@ -711,19 +777,24 @@ def test_yarn_process_on_kill(self, mock_popen, mock_renew_from_kt): # Then expected_env = os.environ.copy() expected_env["KRB5CCNAME"] = '/tmp/airflow_krb5_ccache' - self.assertIn(call(['yarn', 'application', '-kill', - 'application_1486558679801_1820'], - env=expected_env, stderr=-1, stdout=-1), - mock_popen.mock_calls) + self.assertIn( + call( + ['yarn', 'application', '-kill', 'application_1486558679801_1820'], + env=expected_env, + stderr=-1, + stdout=-1, + ), + mock_popen.mock_calls, + ) def test_standalone_cluster_process_on_kill(self): # Given log_lines = [ 'Running Spark using the REST application submission protocol.', - '17/11/28 11:14:15 INFO RestSubmissionClient: Submitting a request ' + - 'to launch an application in spark://spark-standalone-master:6066', - '17/11/28 11:14:15 INFO RestSubmissionClient: Submission successfully ' + - 'created as driver-20171128111415-0001. Polling submission state...' + '17/11/28 11:14:15 INFO RestSubmissionClient: Submitting a request ' + + 'to launch an application in spark://spark-standalone-master:6066', + '17/11/28 11:14:15 INFO RestSubmissionClient: Submission successfully ' + + 'created as driver-20171128111415-0001. 
Polling submission state...', ] hook = SparkSubmitHook(conn_id='spark_standalone_cluster') hook._process_spark_submit_log(log_lines) @@ -749,26 +820,26 @@ def test_k8s_process_on_kill(self, mock_popen, mock_client_method): client = mock_client_method.return_value hook = SparkSubmitHook(conn_id='spark_k8s_cluster') log_lines = [ - 'INFO LoggingPodStatusWatcherImpl:54 - State changed, new state:' + - 'pod name: spark-pi-edf2ace37be7353a958b38733a12f8e6-driver' + - 'namespace: default' + - 'labels: spark-app-selector -> spark-465b868ada474bda82ccb84ab2747fcd,' + - 'spark-role -> driver' + - 'pod uid: ba9c61f6-205f-11e8-b65f-d48564c88e42' + - 'creation time: 2018-03-05T10:26:55Z' + - 'service account name: spark' + - 'volumes: spark-init-properties, download-jars-volume,' + - 'download-files-volume, spark-token-2vmlm' + - 'node name: N/A' + - 'start time: N/A' + - 'container images: N/A' + - 'phase: Pending' + - 'status: []' + - '2018-03-05 11:26:56 INFO LoggingPodStatusWatcherImpl:54 - State changed,' + - ' new state:' + - 'pod name: spark-pi-edf2ace37be7353a958b38733a12f8e6-driver' + - 'namespace: default' + - 'Exit code: 0' + 'INFO LoggingPodStatusWatcherImpl:54 - State changed, new state:' + + 'pod name: spark-pi-edf2ace37be7353a958b38733a12f8e6-driver' + + 'namespace: default' + + 'labels: spark-app-selector -> spark-465b868ada474bda82ccb84ab2747fcd,' + + 'spark-role -> driver' + + 'pod uid: ba9c61f6-205f-11e8-b65f-d48564c88e42' + + 'creation time: 2018-03-05T10:26:55Z' + + 'service account name: spark' + + 'volumes: spark-init-properties, download-jars-volume,' + + 'download-files-volume, spark-token-2vmlm' + + 'node name: N/A' + + 'start time: N/A' + + 'container images: N/A' + + 'phase: Pending' + + 'status: []' + + '2018-03-05 11:26:56 INFO LoggingPodStatusWatcherImpl:54 - State changed,' + + ' new state:' + + 'pod name: spark-pi-edf2ace37be7353a958b38733a12f8e6-driver' + + 'namespace: default' + + 'Exit code: 0' ] hook._process_spark_submit_log(log_lines) hook.submit() @@ -778,10 +849,11 @@ def test_k8s_process_on_kill(self, mock_popen, mock_client_method): # Then import kubernetes + kwargs = {'pretty': True, 'body': kubernetes.client.V1DeleteOptions()} client.delete_namespaced_pod.assert_called_once_with( - 'spark-pi-edf2ace37be7353a958b38733a12f8e6-driver', - 'mynamespace', **kwargs) + 'spark-pi-edf2ace37be7353a958b38733a12f8e6-driver', 'mynamespace', **kwargs + ) @parameterized.expand( ( @@ -813,10 +885,7 @@ def test_k8s_process_on_kill(self, mock_popen, mock_client_method): ("spark-submit", "foo", "--bar", "baz", '--password="sec\'ret"'), 'spark-submit foo --bar baz --password="******"', ), - ( - ("spark-submit",), - "spark-submit", - ), + (("spark-submit",), "spark-submit",), ) ) def test_masks_passwords(self, command: str, expected: str) -> None: diff --git a/tests/providers/apache/spark/operators/test_spark_jdbc.py b/tests/providers/apache/spark/operators/test_spark_jdbc.py index 807725b849251..25143adf28499 100644 --- a/tests/providers/apache/spark/operators/test_spark_jdbc.py +++ b/tests/providers/apache/spark/operators/test_spark_jdbc.py @@ -29,9 +29,7 @@ class TestSparkJDBCOperator(unittest.TestCase): _config = { 'spark_app_name': '{{ task_instance.task_id }}', - 'spark_conf': { - 'parquet.compression': 'SNAPPY' - }, + 'spark_conf': {'parquet.compression': 'SNAPPY'}, 'spark_files': 'hive-site.xml', 'spark_py_files': 'sample_library.py', 'spark_jars': 'parquet.jar', @@ -56,14 +54,11 @@ class TestSparkJDBCOperator(unittest.TestCase): 'lower_bound': '10', 'upper_bound': '20', 
'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64),' - 'comments VARCHAR(1024)' + 'comments VARCHAR(1024)', } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) def test_execute(self): @@ -71,18 +66,12 @@ def test_execute(self): spark_conn_id = 'spark-default' jdbc_conn_id = 'jdbc-default' - operator = SparkJDBCOperator( - task_id='spark_jdbc_job', - dag=self.dag, - **self._config - ) + operator = SparkJDBCOperator(task_id='spark_jdbc_job', dag=self.dag, **self._config) # Then expected_dict = { 'spark_app_name': '{{ task_instance.task_id }}', - 'spark_conf': { - 'parquet.compression': 'SNAPPY' - }, + 'spark_conf': {'parquet.compression': 'SNAPPY'}, 'spark_files': 'hive-site.xml', 'spark_py_files': 'sample_library.py', 'spark_jars': 'parquet.jar', @@ -107,7 +96,7 @@ def test_execute(self): 'lower_bound': '10', 'upper_bound': '20', 'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64),' - 'comments VARCHAR(1024)' + 'comments VARCHAR(1024)', } self.assertEqual(spark_conn_id, operator._spark_conn_id) @@ -137,5 +126,4 @@ def test_execute(self): self.assertEqual(expected_dict['partition_column'], operator._partition_column) self.assertEqual(expected_dict['lower_bound'], operator._lower_bound) self.assertEqual(expected_dict['upper_bound'], operator._upper_bound) - self.assertEqual(expected_dict['create_table_column_types'], - operator._create_table_column_types) + self.assertEqual(expected_dict['create_table_column_types'], operator._create_table_column_types) diff --git a/tests/providers/apache/spark/operators/test_spark_sql.py b/tests/providers/apache/spark/operators/test_spark_sql.py index dabf1a4fef200..a282e0170afc3 100644 --- a/tests/providers/apache/spark/operators/test_spark_sql.py +++ b/tests/providers/apache/spark/operators/test_spark_sql.py @@ -39,23 +39,16 @@ class TestSparkSqlOperator(unittest.TestCase): 'name': 'special-application-name', 'num_executors': 8, 'verbose': False, - 'yarn_queue': 'special-queue' + 'yarn_queue': 'special-queue', } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) def test_execute(self): # Given / When - operator = SparkSqlOperator( - task_id='spark_sql_job', - dag=self.dag, - **self._config - ) + operator = SparkSqlOperator(task_id='spark_sql_job', dag=self.dag, **self._config) self.assertEqual(self._config['sql'], operator._sql) self.assertEqual(self._config['conn_id'], operator._conn_id) diff --git a/tests/providers/apache/spark/operators/test_spark_submit.py b/tests/providers/apache/spark/operators/test_spark_submit.py index fda353f581727..4dc9a5fca01c7 100644 --- a/tests/providers/apache/spark/operators/test_spark_submit.py +++ b/tests/providers/apache/spark/operators/test_spark_submit.py @@ -31,9 +31,7 @@ class TestSparkSubmitOperator(unittest.TestCase): _config = { - 'conf': { - 'parquet.compression': 'SNAPPY' - }, + 'conf': {'parquet.compression': 'SNAPPY'}, 'files': 'hive-site.xml', 'py_files': 'sample_library.py', 'archives': 'sample_archive.zip#SAMPLE', @@ -56,19 +54,21 @@ class TestSparkSubmitOperator(unittest.TestCase): 'driver_memory': '3g', 'java_class': 'com.foo.bar.AppMain', 'application_args': [ - '-f', 'foo', - '--bar', 'bar', - '--start', '{{ macros.ds_add(ds, -1)}}', - '--end', '{{ ds }}', - '--with-spaces', 'args 
should keep embdedded spaces', - ] + '-f', + 'foo', + '--bar', + 'bar', + '--start', + '{{ macros.ds_add(ds, -1)}}', + '--end', + '{{ ds }}', + '--with-spaces', + 'args should keep embdedded spaces', + ], } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) def test_execute(self): @@ -76,17 +76,12 @@ def test_execute(self): # Given / When conn_id = 'spark_default' operator = SparkSubmitOperator( - task_id='spark_submit_job', - spark_binary="sparky", - dag=self.dag, - **self._config + task_id='spark_submit_job', spark_binary="sparky", dag=self.dag, **self._config ) # Then expected results expected_dict = { - 'conf': { - 'parquet.compression': 'SNAPPY' - }, + 'conf': {'parquet.compression': 'SNAPPY'}, 'files': 'hive-site.xml', 'py_files': 'sample_library.py', 'archives': 'sample_archive.zip#SAMPLE', @@ -109,13 +104,18 @@ def test_execute(self): 'driver_memory': '3g', 'java_class': 'com.foo.bar.AppMain', 'application_args': [ - '-f', 'foo', - '--bar', 'bar', - '--start', '{{ macros.ds_add(ds, -1)}}', - '--end', '{{ ds }}', - '--with-spaces', 'args should keep embdedded spaces', + '-f', + 'foo', + '--bar', + 'bar', + '--start', + '{{ macros.ds_add(ds, -1)}}', + '--end', + '{{ ds }}', + '--with-spaces', + 'args should keep embdedded spaces', ], - 'spark_binary': 'sparky' + 'spark_binary': 'sparky', } self.assertEqual(conn_id, operator._conn_id) @@ -129,8 +129,7 @@ def test_execute(self): self.assertEqual(expected_dict['packages'], operator._packages) self.assertEqual(expected_dict['exclude_packages'], operator._exclude_packages) self.assertEqual(expected_dict['repositories'], operator._repositories) - self.assertEqual(expected_dict['total_executor_cores'], - operator._total_executor_cores) + self.assertEqual(expected_dict['total_executor_cores'], operator._total_executor_cores) self.assertEqual(expected_dict['executor_cores'], operator._executor_cores) self.assertEqual(expected_dict['executor_memory'], operator._executor_memory) self.assertEqual(expected_dict['keytab'], operator._keytab) @@ -147,23 +146,25 @@ def test_execute(self): def test_render_template(self): # Given - operator = SparkSubmitOperator(task_id='spark_submit_job', - dag=self.dag, **self._config) + operator = SparkSubmitOperator(task_id='spark_submit_job', dag=self.dag, **self._config) ti = TaskInstance(operator, DEFAULT_DATE) # When ti.render_templates() # Then - expected_application_args = ['-f', 'foo', - '--bar', 'bar', - '--start', (DEFAULT_DATE - timedelta(days=1)) - .strftime("%Y-%m-%d"), - '--end', DEFAULT_DATE.strftime("%Y-%m-%d"), - '--with-spaces', - 'args should keep embdedded spaces', - ] + expected_application_args = [ + '-f', + 'foo', + '--bar', + 'bar', + '--start', + (DEFAULT_DATE - timedelta(days=1)).strftime("%Y-%m-%d"), + '--end', + DEFAULT_DATE.strftime("%Y-%m-%d"), + '--with-spaces', + 'args should keep embdedded spaces', + ] expected_name = 'spark_submit_job' - self.assertListEqual(expected_application_args, - getattr(operator, '_application_args')) + self.assertListEqual(expected_application_args, getattr(operator, '_application_args')) self.assertEqual(expected_name, getattr(operator, '_name')) diff --git a/tests/providers/apache/sqoop/hooks/test_sqoop.py b/tests/providers/apache/sqoop/hooks/test_sqoop.py index 7ddf6bd68d6cf..6249be2375efd 100644 --- a/tests/providers/apache/sqoop/hooks/test_sqoop.py +++ b/tests/providers/apache/sqoop/hooks/test_sqoop.py @@ -34,11 
+34,9 @@ class TestSqoopHook(unittest.TestCase): 'conn_id': 'sqoop_test', 'num_mappers': 22, 'verbose': True, - 'properties': { - 'mapred.map.max.attempts': '1' - }, + 'properties': {'mapred.map.max.attempts': '1'}, 'hcatalog_database': 'hive_database', - 'hcatalog_table': 'hive_table' + 'hcatalog_table': 'hive_table', } _config_export = { 'table': 'domino.export_data_to', @@ -54,11 +52,9 @@ class TestSqoopHook(unittest.TestCase): 'input_optionally_enclosed_by': '"', 'batch': True, 'relaxed_isolation': True, - 'extra_export_options': collections.OrderedDict([ - ('update-key', 'id'), - ('update-mode', 'allowinsert'), - ('fetch-size', 1) - ]) + 'extra_export_options': collections.OrderedDict( + [('update-key', 'id'), ('update-mode', 'allowinsert'), ('fetch-size', 1)] + ), } _config_import = { 'target_dir': '/hdfs/data/target/location', @@ -70,8 +66,8 @@ class TestSqoopHook(unittest.TestCase): 'extra_import_options': { 'hcatalog-storage-stanza': "\"stored as orcfile\"", 'show': '', - 'fetch-size': 1 - } + 'fetch-size': 1, + }, } _config_json = { @@ -79,14 +75,18 @@ class TestSqoopHook(unittest.TestCase): 'job_tracker': 'http://0.0.0.0:50030/', 'libjars': '/path/to/jars', 'files': '/path/to/files', - 'archives': '/path/to/archives' + 'archives': '/path/to/archives', } def setUp(self): db.merge_conn( Connection( - conn_id='sqoop_test', conn_type='sqoop', schema='schema', - host='rmdbs', port=5050, extra=json.dumps(self._config_json) + conn_id='sqoop_test', + conn_type='sqoop', + schema='schema', + host='rmdbs', + port=5050, + extra=json.dumps(self._config_json), ) ) @@ -96,39 +96,68 @@ def test_popen(self, mock_popen): mock_popen.return_value.stdout = StringIO('stdout') mock_popen.return_value.stderr = StringIO('stderr') mock_popen.return_value.returncode = 0 - mock_popen.return_value.communicate.return_value = \ - [StringIO('stdout\nstdout'), StringIO('stderr\nstderr')] + mock_popen.return_value.communicate.return_value = [ + StringIO('stdout\nstdout'), + StringIO('stderr\nstderr'), + ] # When hook = SqoopHook(conn_id='sqoop_test') hook.export_table(**self._config_export) # Then - self.assertEqual(mock_popen.mock_calls[0], call( - ['sqoop', - 'export', - '-fs', self._config_json['namenode'], - '-jt', self._config_json['job_tracker'], - '-libjars', self._config_json['libjars'], - '-files', self._config_json['files'], - '-archives', self._config_json['archives'], - '--connect', 'rmdbs:5050/schema', - '--input-null-string', self._config_export['input_null_string'], - '--input-null-non-string', self._config_export['input_null_non_string'], - '--staging-table', self._config_export['staging_table'], - '--clear-staging-table', - '--enclosed-by', self._config_export['enclosed_by'], - '--escaped-by', self._config_export['escaped_by'], - '--input-fields-terminated-by', self._config_export['input_fields_terminated_by'], - '--input-lines-terminated-by', self._config_export['input_lines_terminated_by'], - '--input-optionally-enclosed-by', self._config_export['input_optionally_enclosed_by'], - '--batch', - '--relaxed-isolation', - '--export-dir', self._config_export['export_dir'], - '--update-key', 'id', - '--update-mode', 'allowinsert', - '--fetch-size', str(self._config_export['extra_export_options'].get('fetch-size')), - '--table', self._config_export['table']], stderr=-2, stdout=-1)) + self.assertEqual( + mock_popen.mock_calls[0], + call( + [ + 'sqoop', + 'export', + '-fs', + self._config_json['namenode'], + '-jt', + self._config_json['job_tracker'], + '-libjars', + self._config_json['libjars'], + 
'-files', + self._config_json['files'], + '-archives', + self._config_json['archives'], + '--connect', + 'rmdbs:5050/schema', + '--input-null-string', + self._config_export['input_null_string'], + '--input-null-non-string', + self._config_export['input_null_non_string'], + '--staging-table', + self._config_export['staging_table'], + '--clear-staging-table', + '--enclosed-by', + self._config_export['enclosed_by'], + '--escaped-by', + self._config_export['escaped_by'], + '--input-fields-terminated-by', + self._config_export['input_fields_terminated_by'], + '--input-lines-terminated-by', + self._config_export['input_lines_terminated_by'], + '--input-optionally-enclosed-by', + self._config_export['input_optionally_enclosed_by'], + '--batch', + '--relaxed-isolation', + '--export-dir', + self._config_export['export_dir'], + '--update-key', + 'id', + '--update-mode', + 'allowinsert', + '--fetch-size', + str(self._config_export['extra_export_options'].get('fetch-size')), + '--table', + self._config_export['table'], + ], + stderr=-2, + stdout=-1, + ), + ) def test_submit_none_mappers(self): """ @@ -201,40 +230,35 @@ def test_export_cmd(self): self._config_export['table'], self._config_export['export_dir'], input_null_string=self._config_export['input_null_string'], - input_null_non_string=self._config_export[ - 'input_null_non_string'], + input_null_non_string=self._config_export['input_null_non_string'], staging_table=self._config_export['staging_table'], clear_staging_table=self._config_export['clear_staging_table'], enclosed_by=self._config_export['enclosed_by'], escaped_by=self._config_export['escaped_by'], - input_fields_terminated_by=self._config_export[ - 'input_fields_terminated_by'], - input_lines_terminated_by=self._config_export[ - 'input_lines_terminated_by'], - input_optionally_enclosed_by=self._config_export[ - 'input_optionally_enclosed_by'], + input_fields_terminated_by=self._config_export['input_fields_terminated_by'], + input_lines_terminated_by=self._config_export['input_lines_terminated_by'], + input_optionally_enclosed_by=self._config_export['input_optionally_enclosed_by'], batch=self._config_export['batch'], relaxed_isolation=self._config_export['relaxed_isolation'], - extra_export_options=self._config_export['extra_export_options'] + extra_export_options=self._config_export['extra_export_options'], ) ) - self.assertIn("--input-null-string {}".format( - self._config_export['input_null_string']), cmd) - self.assertIn("--input-null-non-string {}".format( - self._config_export['input_null_non_string']), cmd) - self.assertIn("--staging-table {}".format( - self._config_export['staging_table']), cmd) - self.assertIn("--enclosed-by {}".format( - self._config_export['enclosed_by']), cmd) - self.assertIn("--escaped-by {}".format( - self._config_export['escaped_by']), cmd) - self.assertIn("--input-fields-terminated-by {}".format( - self._config_export['input_fields_terminated_by']), cmd) - self.assertIn("--input-lines-terminated-by {}".format( - self._config_export['input_lines_terminated_by']), cmd) - self.assertIn("--input-optionally-enclosed-by {}".format( - self._config_export['input_optionally_enclosed_by']), cmd) + self.assertIn("--input-null-string {}".format(self._config_export['input_null_string']), cmd) + self.assertIn("--input-null-non-string {}".format(self._config_export['input_null_non_string']), cmd) + self.assertIn("--staging-table {}".format(self._config_export['staging_table']), cmd) + self.assertIn("--enclosed-by {}".format(self._config_export['enclosed_by']), cmd) + 
self.assertIn("--escaped-by {}".format(self._config_export['escaped_by']), cmd) + self.assertIn( + "--input-fields-terminated-by {}".format(self._config_export['input_fields_terminated_by']), cmd + ) + self.assertIn( + "--input-lines-terminated-by {}".format(self._config_export['input_lines_terminated_by']), cmd + ) + self.assertIn( + "--input-optionally-enclosed-by {}".format(self._config_export['input_optionally_enclosed_by']), + cmd, + ) # these options are from the extra export options self.assertIn("--update-key id", cmd) self.assertIn("--update-mode allowinsert", cmd) @@ -268,7 +292,7 @@ def test_import_cmd(self): split_by=self._config_import['split_by'], direct=self._config_import['direct'], driver=self._config_import['driver'], - extra_import_options=None + extra_import_options=None, ) ) @@ -278,8 +302,7 @@ def test_import_cmd(self): if self._config_import['direct']: self.assertIn('--direct', cmd) - self.assertIn('--target-dir {}'.format( - self._config_import['target_dir']), cmd) + self.assertIn('--target-dir {}'.format(self._config_import['target_dir']), cmd) self.assertIn('--driver {}'.format(self._config_import['driver']), cmd) self.assertIn('--split-by {}'.format(self._config_import['split_by']), cmd) @@ -295,7 +318,7 @@ def test_import_cmd(self): split_by=self._config_import['split_by'], direct=self._config_import['direct'], driver=self._config_import['driver'], - extra_import_options=self._config_import['extra_import_options'] + extra_import_options=self._config_import['extra_import_options'], ) ) @@ -311,14 +334,10 @@ def test_get_export_format_argument(self): correct Sqoop command with correct format type. """ hook = SqoopHook() - self.assertIn("--as-avrodatafile", - hook._get_export_format_argument('avro')) - self.assertIn("--as-parquetfile", - hook._get_export_format_argument('parquet')) - self.assertIn("--as-sequencefile", - hook._get_export_format_argument('sequence')) - self.assertIn("--as-textfile", - hook._get_export_format_argument('text')) + self.assertIn("--as-avrodatafile", hook._get_export_format_argument('avro')) + self.assertIn("--as-parquetfile", hook._get_export_format_argument('parquet')) + self.assertIn("--as-sequencefile", hook._get_export_format_argument('sequence')) + self.assertIn("--as-textfile", hook._get_export_format_argument('text')) with self.assertRaises(AirflowException): hook._get_export_format_argument('unknown') @@ -327,13 +346,7 @@ def test_cmd_mask_password(self): Tests to verify the hook masking function will correctly mask a user password in Sqoop command. 
""" hook = SqoopHook() - self.assertEqual( - hook.cmd_mask_password(['--password', 'supersecret']), - ['--password', 'MASKED'] - ) + self.assertEqual(hook.cmd_mask_password(['--password', 'supersecret']), ['--password', 'MASKED']) cmd = ['--target', 'targettable'] - self.assertEqual( - hook.cmd_mask_password(cmd), - cmd - ) + self.assertEqual(hook.cmd_mask_password(cmd), cmd) diff --git a/tests/providers/apache/sqoop/operators/test_sqoop.py b/tests/providers/apache/sqoop/operators/test_sqoop.py index e54304d746a23..295ca841e61fd 100644 --- a/tests/providers/apache/sqoop/operators/test_sqoop.py +++ b/tests/providers/apache/sqoop/operators/test_sqoop.py @@ -54,36 +54,20 @@ class TestSqoopOperator(unittest.TestCase): 'create_hcatalog_table': True, 'hcatalog_database': 'hive_database', 'hcatalog_table': 'hive_table', - 'properties': { - 'mapred.map.max.attempts': '1' - }, - 'extra_import_options': { - 'hcatalog-storage-stanza': "\"stored as orcfile\"", - 'show': '' - }, - 'extra_export_options': { - 'update-key': 'id', - 'update-mode': 'allowinsert', - 'fetch-size': 1 - } + 'properties': {'mapred.map.max.attempts': '1'}, + 'extra_import_options': {'hcatalog-storage-stanza': "\"stored as orcfile\"", 'show': ''}, + 'extra_export_options': {'update-key': 'id', 'update-mode': 'allowinsert', 'fetch-size': 1}, } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2017, 1, 1) - } + args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)} self.dag = DAG('test_dag_id', default_args=args) def test_execute(self): """ Tests to verify values of the SqoopOperator match that passed in from the config. """ - operator = SqoopOperator( - task_id='sqoop_job', - dag=self.dag, - **self._config - ) + operator = SqoopOperator(task_id='sqoop_job', dag=self.dag, **self._config) self.assertEqual(self._config['conn_id'], operator.conn_id) self.assertEqual(self._config['query'], operator.query) @@ -121,7 +105,7 @@ def test_execute(self): hcatalog_table='import_table_1', create_hcatalog_table=True, extra_import_options={'hcatalog-storage-stanza': "\"stored as orcfile\""}, - dag=self.dag + dag=self.dag, ) SqoopOperator( @@ -137,7 +121,7 @@ def test_execute(self): hcatalog_table='import_table_2', create_hcatalog_table=True, extra_import_options={'hcatalog-storage-stanza': "\"stored as orcfile\""}, - dag=self.dag + dag=self.dag, ) SqoopOperator( @@ -154,9 +138,9 @@ def test_execute(self): 'hcatalog-storage-stanza': "\"stored as orcfile\"", 'hive-partition-key': 'day', 'hive-partition-value': '2017-10-18', - 'fetch-size': 1 + 'fetch-size': 1, }, - dag=self.dag + dag=self.dag, ) SqoopOperator( @@ -169,7 +153,7 @@ def test_execute(self): hcatalog_database='default', hcatalog_table='hive_export_table_1', extra_export_options=None, - dag=self.dag + dag=self.dag, ) SqoopOperator( @@ -182,15 +166,14 @@ def test_execute(self): verbose=True, num_mappers=None, extra_export_options=None, - dag=self.dag + dag=self.dag, ) def test_invalid_cmd_type(self): """ Tests to verify if the cmd_type is not import or export, an exception is raised. 
""" - operator = SqoopOperator(task_id='sqoop_job', dag=self.dag, - cmd_type='invalid') + operator = SqoopOperator(task_id='sqoop_job', dag=self.dag, cmd_type='invalid') with self.assertRaises(AirflowException): operator.execute({}) @@ -200,10 +183,6 @@ def test_invalid_import_options(self): """ import_query_and_table_configs = self._config.copy() import_query_and_table_configs['cmd_type'] = 'import' - operator = SqoopOperator( - task_id='sqoop_job', - dag=self.dag, - **import_query_and_table_configs - ) + operator = SqoopOperator(task_id='sqoop_job', dag=self.dag, **import_query_and_table_configs) with self.assertRaises(AirflowException): operator.execute({}) diff --git a/tests/providers/celery/sensors/test_celery_queue.py b/tests/providers/celery/sensors/test_celery_queue.py index 1ae4e30020c8d..ba2b15903076c 100644 --- a/tests/providers/celery/sensors/test_celery_queue.py +++ b/tests/providers/celery/sensors/test_celery_queue.py @@ -23,10 +23,8 @@ class TestCeleryQueueSensor(unittest.TestCase): - def setUp(self): class TestCeleryqueueSensor(CeleryQueueSensor): - def _check_task_id(self, context): return True @@ -36,43 +34,29 @@ def _check_task_id(self, context): def test_poke_success(self, mock_inspect): mock_inspect_result = mock_inspect.return_value # test success - mock_inspect_result.reserved.return_value = { - 'test_queue': [] - } + mock_inspect_result.reserved.return_value = {'test_queue': []} - mock_inspect_result.scheduled.return_value = { - 'test_queue': [] - } + mock_inspect_result.scheduled.return_value = {'test_queue': []} - mock_inspect_result.active.return_value = { - 'test_queue': [] - } - test_sensor = self.sensor(celery_queue='test_queue', - task_id='test-task') + mock_inspect_result.active.return_value = {'test_queue': []} + test_sensor = self.sensor(celery_queue='test_queue', task_id='test-task') self.assertTrue(test_sensor.poke(None)) @patch('celery.app.control.Inspect') def test_poke_fail(self, mock_inspect): mock_inspect_result = mock_inspect.return_value # test success - mock_inspect_result.reserved.return_value = { - 'test_queue': [] - } + mock_inspect_result.reserved.return_value = {'test_queue': []} - mock_inspect_result.scheduled.return_value = { - 'test_queue': [] - } + mock_inspect_result.scheduled.return_value = {'test_queue': []} - mock_inspect_result.active.return_value = { - 'test_queue': ['task'] - } - test_sensor = self.sensor(celery_queue='test_queue', - task_id='test-task') + mock_inspect_result.active.return_value = {'test_queue': ['task']} + test_sensor = self.sensor(celery_queue='test_queue', task_id='test-task') self.assertFalse(test_sensor.poke(None)) @patch('celery.app.control.Inspect') def test_poke_success_with_taskid(self, mock_inspect): - test_sensor = self.sensor(celery_queue='test_queue', - task_id='test-task', - target_task_id='target-task') + test_sensor = self.sensor( + celery_queue='test_queue', task_id='test-task', target_task_id='target-task' + ) self.assertTrue(test_sensor.poke(None)) diff --git a/tests/providers/cloudant/hooks/test_cloudant.py b/tests/providers/cloudant/hooks/test_cloudant.py index 147673da70ec6..9d7bb74f2bb47 100644 --- a/tests/providers/cloudant/hooks/test_cloudant.py +++ b/tests/providers/cloudant/hooks/test_cloudant.py @@ -25,12 +25,13 @@ class TestCloudantHook(unittest.TestCase): - def setUp(self): self.cloudant_hook = CloudantHook() - @patch('airflow.providers.cloudant.hooks.cloudant.CloudantHook.get_connection', - return_value=Connection(login='user', password='password', host='account')) + @patch( + 
'airflow.providers.cloudant.hooks.cloudant.CloudantHook.get_connection', + return_value=Connection(login='user', password='password', host='account'), + ) @patch('airflow.providers.cloudant.hooks.cloudant.cloudant') def test_get_conn(self, mock_cloudant, mock_get_connection): cloudant_session = self.cloudant_hook.get_conn() @@ -39,8 +40,10 @@ def test_get_conn(self, mock_cloudant, mock_get_connection): mock_cloudant.assert_called_once_with(user=conn.login, passwd=conn.password, account=conn.host) self.assertEqual(cloudant_session, mock_cloudant.return_value) - @patch('airflow.providers.cloudant.hooks.cloudant.CloudantHook.get_connection', - return_value=Connection(login='user')) + @patch( + 'airflow.providers.cloudant.hooks.cloudant.CloudantHook.get_connection', + return_value=Connection(login='user'), + ) def test_get_conn_invalid_connection(self, mock_get_connection): with self.assertRaises(AirflowException): self.cloudant_hook.get_conn() diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py index 25654430b209d..e56073e0acdb4 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py @@ -33,20 +33,28 @@ class TestKubernetesHook(unittest.TestCase): def setUp(self): db.merge_conn( Connection( - conn_id='kubernetes_in_cluster', conn_type='kubernetes', - extra=json.dumps({'extra__kubernetes__in_cluster': True}))) + conn_id='kubernetes_in_cluster', + conn_type='kubernetes', + extra=json.dumps({'extra__kubernetes__in_cluster': True}), + ) + ) db.merge_conn( Connection( - conn_id='kubernetes_kube_config', conn_type='kubernetes', - extra=json.dumps({'extra__kubernetes__kube_config': '{"test": "kube"}'}))) + conn_id='kubernetes_kube_config', + conn_type='kubernetes', + extra=json.dumps({'extra__kubernetes__kube_config': '{"test": "kube"}'}), + ) + ) db.merge_conn( - Connection( - conn_id='kubernetes_default_kube_config', conn_type='kubernetes', - extra=json.dumps({}))) + Connection(conn_id='kubernetes_default_kube_config', conn_type='kubernetes', extra=json.dumps({})) + ) db.merge_conn( Connection( - conn_id='kubernetes_with_namespace', conn_type='kubernetes', - extra=json.dumps({'extra__kubernetes__namespace': 'mock_namespace'}))) + conn_id='kubernetes_with_namespace', + conn_type='kubernetes', + extra=json.dumps({'extra__kubernetes__namespace': 'mock_namespace'}), + ) + ) @patch("kubernetes.config.incluster_config.InClusterConfigLoader") def test_in_cluster_connection(self, mock_kube_config_loader): @@ -58,10 +66,7 @@ def test_in_cluster_connection(self, mock_kube_config_loader): @patch("kubernetes.config.kube_config.KubeConfigLoader") @patch("kubernetes.config.kube_config.KubeConfigMerger") @patch.object(tempfile, 'NamedTemporaryFile') - def test_kube_config_connection(self, - mock_kube_config_loader, - mock_kube_config_merger, - mock_tempfile): + def test_kube_config_connection(self, mock_kube_config_loader, mock_kube_config_merger, mock_tempfile): kubernetes_hook = KubernetesHook(conn_id='kubernetes_kube_config') api_conn = kubernetes_hook.get_conn() mock_tempfile.is_called_once() @@ -72,10 +77,9 @@ def test_kube_config_connection(self, @patch("kubernetes.config.kube_config.KubeConfigLoader") @patch("kubernetes.config.kube_config.KubeConfigMerger") @patch("kubernetes.config.kube_config.KUBE_CONFIG_DEFAULT_LOCATION", "/mock/config") - def test_default_kube_config_connection(self, - mock_kube_config_loader, - mock_kube_config_merger, - ): + def 
test_default_kube_config_connection( + self, mock_kube_config_loader, mock_kube_config_merger, + ): kubernetes_hook = KubernetesHook(conn_id='kubernetes_default_kube_config') api_conn = kubernetes_hook.get_conn() mock_kube_config_loader.assert_called_once_with("/mock/config") diff --git a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py index 4dbc2633bea22..1d4dc01f442bf 100644 --- a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py @@ -28,14 +28,12 @@ class TestKubernetesPodOperator(unittest.TestCase): - @staticmethod def create_context(task): dag = DAG(dag_id="dag") tzinfo = pendulum.timezone("Europe/Amsterdam") execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo) - task_instance = TaskInstance(task=task, - execution_date=execution_date) + task_instance = TaskInstance(task=task, execution_date=execution_date) return { "dag": dag, "ts": execution_date.isoformat(), @@ -68,9 +66,7 @@ def test_config_path(self, client_mock, monitor_mock, start_mock): # pylint: di context = self.create_context(k) k.execute(context=context) client_mock.assert_called_once_with( - in_cluster=False, - cluster_context='default', - config_file=file_path, + in_cluster=False, cluster_context='default', config_file=file_path, ) @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod") @@ -98,7 +94,7 @@ def test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock, start k.execute(context=context) self.assertEqual( start_mock.call_args[0][0].spec.image_pull_secrets, - [k8s.V1LocalObjectReference(name=fake_pull_secrets)] + [k8s.V1LocalObjectReference(name=fake_pull_secrets)], ) @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod") @@ -106,11 +102,8 @@ def test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock, start @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.delete_pod") @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_pod_delete_even_on_launcher_error( - self, - mock_client, - delete_pod_mock, - monitor_pod_mock, - start_pod_mock): + self, mock_client, delete_pod_mock, monitor_pod_mock, start_pod_mock + ): k = KubernetesPodOperator( namespace='default', image="ubuntu:16.04", diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py index 8c6688dd0af3f..44ffec63465a4 100644 --- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -26,8 +26,7 @@ from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator from airflow.utils import db, timezone -TEST_VALID_APPLICATION_YAML = \ - """ +TEST_VALID_APPLICATION_YAML = """ apiVersion: "sparkoperator.k8s.io/v1beta2" kind: SparkApplication metadata: @@ -68,8 +67,7 @@ - name: "test-volume" mountPath: "/tmp" """ -TEST_VALID_APPLICATION_JSON = \ - """ +TEST_VALID_APPLICATION_JSON = """ { "apiVersion":"sparkoperator.k8s.io/v1beta2", "kind":"SparkApplication", @@ -129,105 +127,124 @@ } } """ -TEST_APPLICATION_DICT = \ - {'apiVersion': 'sparkoperator.k8s.io/v1beta2', - 'kind': 'SparkApplication', - 'metadata': {'name': 'spark-pi', 'namespace': 'default'}, - 'spec': {'driver': {'coreLimit': '1200m', - 'cores': 1, - 'labels': {'version': '2.4.5'}, - 'memory': '512m', - 'serviceAccount': 'spark', - 
'volumeMounts': [{'mountPath': '/tmp', - 'name': 'test-volume'}]}, - 'executor': {'cores': 1, - 'instances': 1, - 'labels': {'version': '2.4.5'}, - 'memory': '512m', - 'volumeMounts': [{'mountPath': '/tmp', - 'name': 'test-volume'}]}, - 'image': 'gcr.io/spark-operator/spark:v2.4.5', - 'imagePullPolicy': 'Always', - 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar', - 'mainClass': 'org.apache.spark.examples.SparkPi', - 'mode': 'cluster', - 'restartPolicy': {'type': 'Never'}, - 'sparkVersion': '2.4.5', - 'type': 'Scala', - 'volumes': [{'hostPath': {'path': '/tmp', 'type': 'Directory'}, - 'name': 'test-volume'}]}} +TEST_APPLICATION_DICT = { + 'apiVersion': 'sparkoperator.k8s.io/v1beta2', + 'kind': 'SparkApplication', + 'metadata': {'name': 'spark-pi', 'namespace': 'default'}, + 'spec': { + 'driver': { + 'coreLimit': '1200m', + 'cores': 1, + 'labels': {'version': '2.4.5'}, + 'memory': '512m', + 'serviceAccount': 'spark', + 'volumeMounts': [{'mountPath': '/tmp', 'name': 'test-volume'}], + }, + 'executor': { + 'cores': 1, + 'instances': 1, + 'labels': {'version': '2.4.5'}, + 'memory': '512m', + 'volumeMounts': [{'mountPath': '/tmp', 'name': 'test-volume'}], + }, + 'image': 'gcr.io/spark-operator/spark:v2.4.5', + 'imagePullPolicy': 'Always', + 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar', + 'mainClass': 'org.apache.spark.examples.SparkPi', + 'mode': 'cluster', + 'restartPolicy': {'type': 'Never'}, + 'sparkVersion': '2.4.5', + 'type': 'Scala', + 'volumes': [{'hostPath': {'path': '/tmp', 'type': 'Directory'}, 'name': 'test-volume'}], + }, +} @patch('airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_conn') class TestSparkKubernetesOperator(unittest.TestCase): def setUp(self): db.merge_conn( - Connection( - conn_id='kubernetes_default_kube_config', conn_type='kubernetes', - extra=json.dumps({}))) + Connection(conn_id='kubernetes_default_kube_config', conn_type='kubernetes', extra=json.dumps({})) + ) db.merge_conn( Connection( - conn_id='kubernetes_with_namespace', conn_type='kubernetes', - extra=json.dumps({'extra__kubernetes__namespace': 'mock_namespace'}))) - args = { - 'owner': 'airflow', - 'start_date': timezone.datetime(2020, 2, 1) - } + conn_id='kubernetes_with_namespace', + conn_type='kubernetes', + extra=json.dumps({'extra__kubernetes__namespace': 'mock_namespace'}), + ) + ) + args = {'owner': 'airflow', 'start_date': timezone.datetime(2020, 2, 1)} self.dag = DAG('test_dag_id', default_args=args) @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object') def test_create_application_from_yaml(self, mock_create_namespaced_crd, mock_kubernetes_hook): - op = SparkKubernetesOperator(application_file=TEST_VALID_APPLICATION_YAML, - dag=self.dag, - kubernetes_conn_id='kubernetes_default_kube_config', - task_id='test_task_id') + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_YAML, + dag=self.dag, + kubernetes_conn_id='kubernetes_default_kube_config', + task_id='test_task_id', + ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() - mock_create_namespaced_crd.assert_called_with(body=TEST_APPLICATION_DICT, - group='sparkoperator.k8s.io', - namespace='default', - plural='sparkapplications', - version='v1beta2') + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group='sparkoperator.k8s.io', + namespace='default', + plural='sparkapplications', + version='v1beta2', + ) 
@patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object') def test_create_application_from_json(self, mock_create_namespaced_crd, mock_kubernetes_hook): - op = SparkKubernetesOperator(application_file=TEST_VALID_APPLICATION_JSON, - dag=self.dag, - kubernetes_conn_id='kubernetes_default_kube_config', - task_id='test_task_id') + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_JSON, + dag=self.dag, + kubernetes_conn_id='kubernetes_default_kube_config', + task_id='test_task_id', + ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() - mock_create_namespaced_crd.assert_called_with(body=TEST_APPLICATION_DICT, - group='sparkoperator.k8s.io', - namespace='default', - plural='sparkapplications', - version='v1beta2') + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group='sparkoperator.k8s.io', + namespace='default', + plural='sparkapplications', + version='v1beta2', + ) @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object') def test_namespace_from_operator(self, mock_create_namespaced_crd, mock_kubernetes_hook): - op = SparkKubernetesOperator(application_file=TEST_VALID_APPLICATION_JSON, - dag=self.dag, - namespace='operator_namespace', - kubernetes_conn_id='kubernetes_with_namespace', - task_id='test_task_id') + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_JSON, + dag=self.dag, + namespace='operator_namespace', + kubernetes_conn_id='kubernetes_with_namespace', + task_id='test_task_id', + ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() - mock_create_namespaced_crd.assert_called_with(body=TEST_APPLICATION_DICT, - group='sparkoperator.k8s.io', - namespace='operator_namespace', - plural='sparkapplications', - version='v1beta2') + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group='sparkoperator.k8s.io', + namespace='operator_namespace', + plural='sparkapplications', + version='v1beta2', + ) @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object') def test_namespace_from_connection(self, mock_create_namespaced_crd, mock_kubernetes_hook): - op = SparkKubernetesOperator(application_file=TEST_VALID_APPLICATION_JSON, - dag=self.dag, - kubernetes_conn_id='kubernetes_with_namespace', - task_id='test_task_id') + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_JSON, + dag=self.dag, + kubernetes_conn_id='kubernetes_with_namespace', + task_id='test_task_id', + ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() - mock_create_namespaced_crd.assert_called_with(body=TEST_APPLICATION_DICT, - group='sparkoperator.k8s.io', - namespace='mock_namespace', - plural='sparkapplications', - version='v1beta2') + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group='sparkoperator.k8s.io', + namespace='mock_namespace', + plural='sparkapplications', + version='v1beta2', + ) diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes_system.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes_system.py index 289f5d74bbfbc..b85de2d53f8e4 100644 --- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes_system.py +++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes_system.py @@ -25,20 +25,22 @@ from tests.test_utils.system_tests_class import SystemTest KUBERNETES_DAG_FOLDER = os.path.join( - AIRFLOW_MAIN_FOLDER, "airflow", 
"providers", "cncf", "kubernetes", "example_dags") + AIRFLOW_MAIN_FOLDER, "airflow", "providers", "cncf", "kubernetes", "example_dags" +) SPARK_OPERATOR_VERSION = "v1beta2-1.1.1-2.4.5" -MANIFEST_BASE_URL = \ - f'https://raw.githubusercontent.com/GoogleCloudPlatform/spark-on-k8s-operator/' \ +MANIFEST_BASE_URL = ( + f'https://raw.githubusercontent.com/GoogleCloudPlatform/spark-on-k8s-operator/' f'{SPARK_OPERATOR_VERSION}/manifest/' +) SPARK_OPERATOR_MANIFESTS = [ f"{MANIFEST_BASE_URL}crds/sparkoperator.k8s.io_sparkapplications.yaml", f"{MANIFEST_BASE_URL}crds/sparkoperator.k8s.io_scheduledsparkapplications.yaml", f"{MANIFEST_BASE_URL}spark-operator-rbac.yaml", f"{MANIFEST_BASE_URL}spark-operator.yaml", - f"{MANIFEST_BASE_URL}spark-rbac.yaml" + f"{MANIFEST_BASE_URL}spark-rbac.yaml", ] @@ -56,15 +58,11 @@ def kubectl_delete_list(manifests): @pytest.mark.system("cncf.kubernetes") class SparkKubernetesExampleDagsSystemTest(SystemTest): - def setUp(self): super().setUp() kubectl_apply_list(SPARK_OPERATOR_MANIFESTS) if os.environ.get("RUN_AIRFLOW_1_10") == "true": - db.merge_conn( - Connection( - conn_id='kubernetes_default', conn_type='kubernetes' - )) + db.merge_conn(Connection(conn_id='kubernetes_default', conn_type='kubernetes')) def tearDown(self): super().tearDown() diff --git a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py index 62747fc8f94bf..98717a22d0233 100644 --- a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py @@ -27,523 +27,593 @@ from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import SparkKubernetesSensor from airflow.utils import db, timezone -TEST_COMPLETED_APPLICATION = \ - {'apiVersion': 'sparkoperator.k8s.io/v1beta2', - 'kind': 'SparkApplication', - 'metadata': {'creationTimestamp': '2020-02-24T07:34:22Z', - 'generation': 1, - 'labels': {'spark_flow_name': 'spark-pi'}, - 'name': 'spark-pi-2020-02-24-1', - 'namespace': 'default', - 'resourceVersion': '455577', - 'selfLink': - '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', - 'uid': '9f825516-6e1a-4af1-8967-b05661e8fb08'}, - 'spec': {'driver': {'coreLimit': '1200m', - 'cores': 1, - 'labels': {'spark_flow_name': 'spark-pi', - 'version': '2.4.4'}, - 'memory': '512m', - 'serviceAccount': 'default', - 'volumeMounts': [{'mountPath': '/tmp', - 'name': 'test-volume'}]}, - 'executor': {'cores': 1, - 'instances': 3, - 'labels': {'spark_flow_name': 'spark-pi', - 'version': '2.4.4'}, - 'memory': '512m', - 'volumeMounts': [{'mountPath': '/tmp', - 'name': 'test-volume'}]}, - 'image': 'gcr.io/spark-operator/spark:v2.4.4', - 'imagePullPolicy': 'Always', - 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', - 'mainClass': 'org.apache.spark.examples.SparkPi', - 'mode': 'cluster', - 'restartPolicy': {'type': 'Never'}, - 'sparkVersion': '2.4.4', - 'type': 'Scala', - 'volumes': [{'hostPath': {'path': '/tmp', 'type': 'Directory'}, - 'name': 'test-volume'}]}, - 'status': {'applicationState': {'state': 'COMPLETED'}, - 'driverInfo': {'podName': 'spark-pi-2020-02-24-1-driver', - 'webUIAddress': '10.97.130.44:4040', - 'webUIPort': 4040, - 'webUIServiceName': 'spark-pi-2020-02-24-1-ui-svc'}, - 'executionAttempts': 1, - 'executorState': {'spark-pi-2020-02-24-1-1582529666227-exec-1': 'FAILED', - 'spark-pi-2020-02-24-1-1582529666227-exec-2': 'FAILED', - 'spark-pi-2020-02-24-1-1582529666227-exec-3': 
'FAILED'}, - 'lastSubmissionAttemptTime': '2020-02-24T07:34:30Z', - 'sparkApplicationId': 'spark-7bb432c422ca46f3854838c419460fec', - 'submissionAttempts': 1, - 'submissionID': '1a1f9c5e-6bdd-4824-806f-40a814c1cf43', - 'terminationTime': '2020-02-24T07:35:01Z'}} +TEST_COMPLETED_APPLICATION = { + 'apiVersion': 'sparkoperator.k8s.io/v1beta2', + 'kind': 'SparkApplication', + 'metadata': { + 'creationTimestamp': '2020-02-24T07:34:22Z', + 'generation': 1, + 'labels': {'spark_flow_name': 'spark-pi'}, + 'name': 'spark-pi-2020-02-24-1', + 'namespace': 'default', + 'resourceVersion': '455577', + 'selfLink': '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', + 'uid': '9f825516-6e1a-4af1-8967-b05661e8fb08', + }, + 'spec': { + 'driver': { + 'coreLimit': '1200m', + 'cores': 1, + 'labels': {'spark_flow_name': 'spark-pi', 'version': '2.4.4'}, + 'memory': '512m', + 'serviceAccount': 'default', + 'volumeMounts': [{'mountPath': '/tmp', 'name': 'test-volume'}], + }, + 'executor': { + 'cores': 1, + 'instances': 3, + 'labels': {'spark_flow_name': 'spark-pi', 'version': '2.4.4'}, + 'memory': '512m', + 'volumeMounts': [{'mountPath': '/tmp', 'name': 'test-volume'}], + }, + 'image': 'gcr.io/spark-operator/spark:v2.4.4', + 'imagePullPolicy': 'Always', + 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', + 'mainClass': 'org.apache.spark.examples.SparkPi', + 'mode': 'cluster', + 'restartPolicy': {'type': 'Never'}, + 'sparkVersion': '2.4.4', + 'type': 'Scala', + 'volumes': [{'hostPath': {'path': '/tmp', 'type': 'Directory'}, 'name': 'test-volume'}], + }, + 'status': { + 'applicationState': {'state': 'COMPLETED'}, + 'driverInfo': { + 'podName': 'spark-pi-2020-02-24-1-driver', + 'webUIAddress': '10.97.130.44:4040', + 'webUIPort': 4040, + 'webUIServiceName': 'spark-pi-2020-02-24-1-ui-svc', + }, + 'executionAttempts': 1, + 'executorState': { + 'spark-pi-2020-02-24-1-1582529666227-exec-1': 'FAILED', + 'spark-pi-2020-02-24-1-1582529666227-exec-2': 'FAILED', + 'spark-pi-2020-02-24-1-1582529666227-exec-3': 'FAILED', + }, + 'lastSubmissionAttemptTime': '2020-02-24T07:34:30Z', + 'sparkApplicationId': 'spark-7bb432c422ca46f3854838c419460fec', + 'submissionAttempts': 1, + 'submissionID': '1a1f9c5e-6bdd-4824-806f-40a814c1cf43', + 'terminationTime': '2020-02-24T07:35:01Z', + }, +} -TEST_FAILED_APPLICATION = \ - {'apiVersion': 'sparkoperator.k8s.io/v1beta2', - 'kind': 'SparkApplication', - 'metadata': { - 'creationTimestamp': '2020-02-26T11:59:30Z', - 'generation': 1, - 'name': 'spark-pi', - 'namespace': 'default', - 'resourceVersion': '531657', - 'selfLink': - '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', - 'uid': 'f507ee3a-4461-45ef-86d8-ff42e4211e7d'}, - 'spec': {'arguments': ['100000'], - 'driver': {'coreLimit': '1200m', - 'cores': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m', - 'serviceAccount': 'default'}, - 'executor': {'cores': 1, - 'instances': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m'}, - 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', - 'imagePullPolicy': 'Always', - 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', - 'mainClass': 'org.apache.spark.examples.SparkPi123', - 'mode': 'cluster', - 'monitoring': {'exposeDriverMetrics': True, - 'exposeExecutorMetrics': True, - 'prometheus': - {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', - 'port': 8090}}, - 'restartPolicy': {'type': 'Never'}, - 'sparkVersion': '2.4.4', - 
'type': 'Scala'}, - 'status': {'applicationState': {'errorMessage': 'driver pod failed with ' - 'ExitCode: 101, Reason: Error', - 'state': 'FAILED'}, - 'driverInfo': {'podName': 'spark-pi-driver', - 'webUIAddress': '10.108.18.168:4040', - 'webUIPort': 4040, - 'webUIServiceName': 'spark-pi-ui-svc'}, - 'executionAttempts': 1, - 'lastSubmissionAttemptTime': '2020-02-26T11:59:38Z', - 'sparkApplicationId': 'spark-5fb7445d988f434cbe1e86166a0c038a', - 'submissionAttempts': 1, - 'submissionID': '26654a75-5bf6-4618-b191-0340280d2d3d', - 'terminationTime': '2020-02-26T11:59:49Z'}} +TEST_FAILED_APPLICATION = { + 'apiVersion': 'sparkoperator.k8s.io/v1beta2', + 'kind': 'SparkApplication', + 'metadata': { + 'creationTimestamp': '2020-02-26T11:59:30Z', + 'generation': 1, + 'name': 'spark-pi', + 'namespace': 'default', + 'resourceVersion': '531657', + 'selfLink': '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', + 'uid': 'f507ee3a-4461-45ef-86d8-ff42e4211e7d', + }, + 'spec': { + 'arguments': ['100000'], + 'driver': { + 'coreLimit': '1200m', + 'cores': 1, + 'labels': {'version': '2.4.4'}, + 'memory': '512m', + 'serviceAccount': 'default', + }, + 'executor': {'cores': 1, 'instances': 1, 'labels': {'version': '2.4.4'}, 'memory': '512m'}, + 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', + 'imagePullPolicy': 'Always', + 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', + 'mainClass': 'org.apache.spark.examples.SparkPi123', + 'mode': 'cluster', + 'monitoring': { + 'exposeDriverMetrics': True, + 'exposeExecutorMetrics': True, + 'prometheus': {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', 'port': 8090}, + }, + 'restartPolicy': {'type': 'Never'}, + 'sparkVersion': '2.4.4', + 'type': 'Scala', + }, + 'status': { + 'applicationState': { + 'errorMessage': 'driver pod failed with ' 'ExitCode: 101, Reason: Error', + 'state': 'FAILED', + }, + 'driverInfo': { + 'podName': 'spark-pi-driver', + 'webUIAddress': '10.108.18.168:4040', + 'webUIPort': 4040, + 'webUIServiceName': 'spark-pi-ui-svc', + }, + 'executionAttempts': 1, + 'lastSubmissionAttemptTime': '2020-02-26T11:59:38Z', + 'sparkApplicationId': 'spark-5fb7445d988f434cbe1e86166a0c038a', + 'submissionAttempts': 1, + 'submissionID': '26654a75-5bf6-4618-b191-0340280d2d3d', + 'terminationTime': '2020-02-26T11:59:49Z', + }, +} -TEST_UNKNOWN_APPLICATION = \ - {'apiVersion': 'sparkoperator.k8s.io/v1beta2', - 'kind': 'SparkApplication', - 'metadata': {'creationTimestamp': '2020-02-24T07:34:22Z', - 'generation': 1, - 'labels': {'spark_flow_name': 'spark-pi'}, - 'name': 'spark-pi-2020-02-24-1', - 'namespace': 'default', - 'resourceVersion': '455577', - 'selfLink': - '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', - 'uid': '9f825516-6e1a-4af1-8967-b05661e8fb08'}, - 'spec': {'driver': {'coreLimit': '1200m', - 'cores': 1, - 'labels': {'spark_flow_name': 'spark-pi', - 'version': '2.4.4'}, - 'memory': '512m', - 'serviceAccount': 'default', - 'volumeMounts': [{'mountPath': '/tmp', - 'name': 'test-volume'}]}, - 'executor': {'cores': 1, - 'instances': 3, - 'labels': {'spark_flow_name': 'spark-pi', - 'version': '2.4.4'}, - 'memory': '512m', - 'volumeMounts': [{'mountPath': '/tmp', - 'name': 'test-volume'}]}, - 'image': 'gcr.io/spark-operator/spark:v2.4.4', - 'imagePullPolicy': 'Always', - 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', - 'mainClass': 'org.apache.spark.examples.SparkPi', - 'mode': 
'cluster', - 'restartPolicy': {'type': 'Never'}, - 'sparkVersion': '2.4.4', - 'type': 'Scala', - 'volumes': [{'hostPath': {'path': '/tmp', 'type': 'Directory'}, - 'name': 'test-volume'}]}, - 'status': {'applicationState': {'state': 'UNKNOWN'}, - 'driverInfo': {'podName': 'spark-pi-2020-02-24-1-driver', - 'webUIAddress': '10.97.130.44:4040', - 'webUIPort': 4040, - 'webUIServiceName': 'spark-pi-2020-02-24-1-ui-svc'}, - 'executionAttempts': 1, - 'executorState': {'spark-pi-2020-02-24-1-1582529666227-exec-1': 'FAILED', - 'spark-pi-2020-02-24-1-1582529666227-exec-2': 'FAILED', - 'spark-pi-2020-02-24-1-1582529666227-exec-3': 'FAILED'}, - 'lastSubmissionAttemptTime': '2020-02-24T07:34:30Z', - 'sparkApplicationId': 'spark-7bb432c422ca46f3854838c419460fec', - 'submissionAttempts': 1, - 'submissionID': '1a1f9c5e-6bdd-4824-806f-40a814c1cf43', - 'terminationTime': '2020-02-24T07:35:01Z'}} -TEST_NOT_PROCESSED_APPLICATION = \ - {'apiVersion': 'sparkoperator.k8s.io/v1beta2', - 'kind': 'SparkApplication', - 'metadata': { - 'creationTimestamp': '2020-02-26T09:14:48Z', - 'generation': 1, - 'name': 'spark-pi', - 'namespace': 'default', - 'resourceVersion': '525235', - 'selfLink': - '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', - 'uid': '58da0778-fa72-4e90-8ddc-18b5e658f93d'}, - 'spec': {'arguments': ['100000'], - 'driver': {'coreLimit': '1200m', - 'cores': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m', - 'serviceAccount': 'default'}, - 'executor': {'cores': 1, - 'instances': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m'}, - 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', - 'imagePullPolicy': 'Always', - 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', - 'mainClass': 'org.apache.spark.examples.SparkPi', - 'mode': 'cluster', - 'monitoring': {'exposeDriverMetrics': True, - 'exposeExecutorMetrics': True, - 'prometheus': - {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', - 'port': 8090}}, - 'restartPolicy': {'type': 'Never'}, - 'sparkVersion': '2.4.4', - 'type': 'Scala'}} +TEST_UNKNOWN_APPLICATION = { + 'apiVersion': 'sparkoperator.k8s.io/v1beta2', + 'kind': 'SparkApplication', + 'metadata': { + 'creationTimestamp': '2020-02-24T07:34:22Z', + 'generation': 1, + 'labels': {'spark_flow_name': 'spark-pi'}, + 'name': 'spark-pi-2020-02-24-1', + 'namespace': 'default', + 'resourceVersion': '455577', + 'selfLink': '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', + 'uid': '9f825516-6e1a-4af1-8967-b05661e8fb08', + }, + 'spec': { + 'driver': { + 'coreLimit': '1200m', + 'cores': 1, + 'labels': {'spark_flow_name': 'spark-pi', 'version': '2.4.4'}, + 'memory': '512m', + 'serviceAccount': 'default', + 'volumeMounts': [{'mountPath': '/tmp', 'name': 'test-volume'}], + }, + 'executor': { + 'cores': 1, + 'instances': 3, + 'labels': {'spark_flow_name': 'spark-pi', 'version': '2.4.4'}, + 'memory': '512m', + 'volumeMounts': [{'mountPath': '/tmp', 'name': 'test-volume'}], + }, + 'image': 'gcr.io/spark-operator/spark:v2.4.4', + 'imagePullPolicy': 'Always', + 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', + 'mainClass': 'org.apache.spark.examples.SparkPi', + 'mode': 'cluster', + 'restartPolicy': {'type': 'Never'}, + 'sparkVersion': '2.4.4', + 'type': 'Scala', + 'volumes': [{'hostPath': {'path': '/tmp', 'type': 'Directory'}, 'name': 'test-volume'}], + }, + 'status': { + 'applicationState': {'state': 'UNKNOWN'}, + 'driverInfo': { + 
'podName': 'spark-pi-2020-02-24-1-driver', + 'webUIAddress': '10.97.130.44:4040', + 'webUIPort': 4040, + 'webUIServiceName': 'spark-pi-2020-02-24-1-ui-svc', + }, + 'executionAttempts': 1, + 'executorState': { + 'spark-pi-2020-02-24-1-1582529666227-exec-1': 'FAILED', + 'spark-pi-2020-02-24-1-1582529666227-exec-2': 'FAILED', + 'spark-pi-2020-02-24-1-1582529666227-exec-3': 'FAILED', + }, + 'lastSubmissionAttemptTime': '2020-02-24T07:34:30Z', + 'sparkApplicationId': 'spark-7bb432c422ca46f3854838c419460fec', + 'submissionAttempts': 1, + 'submissionID': '1a1f9c5e-6bdd-4824-806f-40a814c1cf43', + 'terminationTime': '2020-02-24T07:35:01Z', + }, +} +TEST_NOT_PROCESSED_APPLICATION = { + 'apiVersion': 'sparkoperator.k8s.io/v1beta2', + 'kind': 'SparkApplication', + 'metadata': { + 'creationTimestamp': '2020-02-26T09:14:48Z', + 'generation': 1, + 'name': 'spark-pi', + 'namespace': 'default', + 'resourceVersion': '525235', + 'selfLink': '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', + 'uid': '58da0778-fa72-4e90-8ddc-18b5e658f93d', + }, + 'spec': { + 'arguments': ['100000'], + 'driver': { + 'coreLimit': '1200m', + 'cores': 1, + 'labels': {'version': '2.4.4'}, + 'memory': '512m', + 'serviceAccount': 'default', + }, + 'executor': {'cores': 1, 'instances': 1, 'labels': {'version': '2.4.4'}, 'memory': '512m'}, + 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', + 'imagePullPolicy': 'Always', + 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', + 'mainClass': 'org.apache.spark.examples.SparkPi', + 'mode': 'cluster', + 'monitoring': { + 'exposeDriverMetrics': True, + 'exposeExecutorMetrics': True, + 'prometheus': {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', 'port': 8090}, + }, + 'restartPolicy': {'type': 'Never'}, + 'sparkVersion': '2.4.4', + 'type': 'Scala', + }, +} -TEST_RUNNING_APPLICATION = \ - {'apiVersion': 'sparkoperator.k8s.io/v1beta2', - 'kind': 'SparkApplication', - 'metadata': { - 'creationTimestamp': '2020-02-26T09:11:25Z', - 'generation': 1, - 'name': 'spark-pi', - 'namespace': 'default', - 'resourceVersion': '525001', - 'selfLink': - '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', - 'uid': '95ff1418-eeb5-454c-b59e-9e021aa3a239'}, - 'spec': {'arguments': ['100000'], - 'driver': {'coreLimit': '1200m', - 'cores': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m', - 'serviceAccount': 'default'}, - 'executor': {'cores': 1, - 'instances': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m'}, - 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', - 'imagePullPolicy': 'Always', - 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', - 'mainClass': 'org.apache.spark.examples.SparkPi', - 'mode': 'cluster', - 'monitoring': {'exposeDriverMetrics': True, - 'exposeExecutorMetrics': True, - 'prometheus': - {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', - 'port': 8090}}, - 'restartPolicy': {'type': 'Never'}, - 'sparkVersion': '2.4.4', - 'type': 'Scala'}, - 'status': {'applicationState': {'state': 'RUNNING'}, - 'driverInfo': {'podName': 'spark-pi-driver', - 'webUIAddress': '10.106.36.53:4040', - 'webUIPort': 4040, - 'webUIServiceName': 'spark-pi-ui-svc'}, - 'executionAttempts': 1, - 'executorState': {'spark-pi-1582708290692-exec-1': 'RUNNING'}, - 'lastSubmissionAttemptTime': '2020-02-26T09:11:35Z', - 'sparkApplicationId': 'spark-a47a002df46448f1a8395d7dd79ba448', - 'submissionAttempts': 1, - 
'submissionID': 'd4f5a768-b9d1-4a79-92b0-54779124d997', - 'terminationTime': None}} +TEST_RUNNING_APPLICATION = { + 'apiVersion': 'sparkoperator.k8s.io/v1beta2', + 'kind': 'SparkApplication', + 'metadata': { + 'creationTimestamp': '2020-02-26T09:11:25Z', + 'generation': 1, + 'name': 'spark-pi', + 'namespace': 'default', + 'resourceVersion': '525001', + 'selfLink': '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', + 'uid': '95ff1418-eeb5-454c-b59e-9e021aa3a239', + }, + 'spec': { + 'arguments': ['100000'], + 'driver': { + 'coreLimit': '1200m', + 'cores': 1, + 'labels': {'version': '2.4.4'}, + 'memory': '512m', + 'serviceAccount': 'default', + }, + 'executor': {'cores': 1, 'instances': 1, 'labels': {'version': '2.4.4'}, 'memory': '512m'}, + 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', + 'imagePullPolicy': 'Always', + 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', + 'mainClass': 'org.apache.spark.examples.SparkPi', + 'mode': 'cluster', + 'monitoring': { + 'exposeDriverMetrics': True, + 'exposeExecutorMetrics': True, + 'prometheus': {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', 'port': 8090}, + }, + 'restartPolicy': {'type': 'Never'}, + 'sparkVersion': '2.4.4', + 'type': 'Scala', + }, + 'status': { + 'applicationState': {'state': 'RUNNING'}, + 'driverInfo': { + 'podName': 'spark-pi-driver', + 'webUIAddress': '10.106.36.53:4040', + 'webUIPort': 4040, + 'webUIServiceName': 'spark-pi-ui-svc', + }, + 'executionAttempts': 1, + 'executorState': {'spark-pi-1582708290692-exec-1': 'RUNNING'}, + 'lastSubmissionAttemptTime': '2020-02-26T09:11:35Z', + 'sparkApplicationId': 'spark-a47a002df46448f1a8395d7dd79ba448', + 'submissionAttempts': 1, + 'submissionID': 'd4f5a768-b9d1-4a79-92b0-54779124d997', + 'terminationTime': None, + }, +} -TEST_SUBMITTED_APPLICATION = \ - {'apiVersion': 'sparkoperator.k8s.io/v1beta2', - 'kind': 'SparkApplication', - 'metadata': { - 'creationTimestamp': '2020-02-26T09:16:53Z', - 'generation': 1, - 'name': 'spark-pi', - 'namespace': 'default', - 'resourceVersion': '525536', - 'selfLink': - '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', - 'uid': '424a682b-6e5c-40d5-8a41-164253500b58'}, - 'spec': {'arguments': ['100000'], - 'driver': {'coreLimit': '1200m', - 'cores': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m', - 'serviceAccount': 'default'}, - 'executor': {'cores': 1, - 'instances': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m'}, - 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', - 'imagePullPolicy': 'Always', - 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', - 'mainClass': 'org.apache.spark.examples.SparkPi', - 'mode': 'cluster', - 'monitoring': {'exposeDriverMetrics': True, - 'exposeExecutorMetrics': True, - 'prometheus': - {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', - 'port': 8090}}, - 'restartPolicy': {'type': 'Never'}, - 'sparkVersion': '2.4.4', - 'type': 'Scala'}, - 'status': {'applicationState': {'state': 'SUBMITTED'}, - 'driverInfo': {'podName': 'spark-pi-driver', - 'webUIAddress': '10.108.175.17:4040', - 'webUIPort': 4040, - 'webUIServiceName': 'spark-pi-ui-svc'}, - 'executionAttempts': 1, - 'lastSubmissionAttemptTime': '2020-02-26T09:17:03Z', - 'sparkApplicationId': 'spark-ae1a522d200246a99470743e880c5650', - 'submissionAttempts': 1, - 'submissionID': 'f8b70b0b-3c81-403f-8c6d-e7f6c3653409', - 'terminationTime': None}} 
+TEST_SUBMITTED_APPLICATION = { + 'apiVersion': 'sparkoperator.k8s.io/v1beta2', + 'kind': 'SparkApplication', + 'metadata': { + 'creationTimestamp': '2020-02-26T09:16:53Z', + 'generation': 1, + 'name': 'spark-pi', + 'namespace': 'default', + 'resourceVersion': '525536', + 'selfLink': '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', + 'uid': '424a682b-6e5c-40d5-8a41-164253500b58', + }, + 'spec': { + 'arguments': ['100000'], + 'driver': { + 'coreLimit': '1200m', + 'cores': 1, + 'labels': {'version': '2.4.4'}, + 'memory': '512m', + 'serviceAccount': 'default', + }, + 'executor': {'cores': 1, 'instances': 1, 'labels': {'version': '2.4.4'}, 'memory': '512m'}, + 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', + 'imagePullPolicy': 'Always', + 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', + 'mainClass': 'org.apache.spark.examples.SparkPi', + 'mode': 'cluster', + 'monitoring': { + 'exposeDriverMetrics': True, + 'exposeExecutorMetrics': True, + 'prometheus': {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', 'port': 8090}, + }, + 'restartPolicy': {'type': 'Never'}, + 'sparkVersion': '2.4.4', + 'type': 'Scala', + }, + 'status': { + 'applicationState': {'state': 'SUBMITTED'}, + 'driverInfo': { + 'podName': 'spark-pi-driver', + 'webUIAddress': '10.108.175.17:4040', + 'webUIPort': 4040, + 'webUIServiceName': 'spark-pi-ui-svc', + }, + 'executionAttempts': 1, + 'lastSubmissionAttemptTime': '2020-02-26T09:17:03Z', + 'sparkApplicationId': 'spark-ae1a522d200246a99470743e880c5650', + 'submissionAttempts': 1, + 'submissionID': 'f8b70b0b-3c81-403f-8c6d-e7f6c3653409', + 'terminationTime': None, + }, +} -TEST_NEW_APPLICATION = \ - {'apiVersion': 'sparkoperator.k8s.io/v1beta2', - 'kind': 'SparkApplication', - 'metadata': { - 'creationTimestamp': '2020-02-26T09:16:53Z', - 'generation': 1, - 'name': 'spark-pi', - 'namespace': 'default', - 'resourceVersion': '525536', - 'selfLink': - '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', - 'uid': '424a682b-6e5c-40d5-8a41-164253500b58'}, - 'spec': {'arguments': ['100000'], - 'driver': {'coreLimit': '1200m', - 'cores': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m', - 'serviceAccount': 'default'}, - 'executor': {'cores': 1, - 'instances': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m'}, - 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', - 'imagePullPolicy': 'Always', - 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', - 'mainClass': 'org.apache.spark.examples.SparkPi', - 'mode': 'cluster', - 'monitoring': {'exposeDriverMetrics': True, - 'exposeExecutorMetrics': True, - 'prometheus': - {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', - 'port': 8090}}, - 'restartPolicy': {'type': 'Never'}, - 'sparkVersion': '2.4.4', - 'type': 'Scala'}, - 'status': {'applicationState': {'state': ''}}} +TEST_NEW_APPLICATION = { + 'apiVersion': 'sparkoperator.k8s.io/v1beta2', + 'kind': 'SparkApplication', + 'metadata': { + 'creationTimestamp': '2020-02-26T09:16:53Z', + 'generation': 1, + 'name': 'spark-pi', + 'namespace': 'default', + 'resourceVersion': '525536', + 'selfLink': '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', + 'uid': '424a682b-6e5c-40d5-8a41-164253500b58', + }, + 'spec': { + 'arguments': ['100000'], + 'driver': { + 'coreLimit': '1200m', + 'cores': 1, + 'labels': {'version': '2.4.4'}, + 'memory': '512m', + 
'serviceAccount': 'default', + }, + 'executor': {'cores': 1, 'instances': 1, 'labels': {'version': '2.4.4'}, 'memory': '512m'}, + 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', + 'imagePullPolicy': 'Always', + 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', + 'mainClass': 'org.apache.spark.examples.SparkPi', + 'mode': 'cluster', + 'monitoring': { + 'exposeDriverMetrics': True, + 'exposeExecutorMetrics': True, + 'prometheus': {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', 'port': 8090}, + }, + 'restartPolicy': {'type': 'Never'}, + 'sparkVersion': '2.4.4', + 'type': 'Scala', + }, + 'status': {'applicationState': {'state': ''}}, +} -TEST_PENDING_RERUN_APPLICATION = \ - {'apiVersion': 'sparkoperator.k8s.io/v1beta2', - 'kind': 'SparkApplication', - 'metadata': { - 'creationTimestamp': '2020-02-27T08:03:02Z', - 'generation': 4, - 'name': 'spark-pi', - 'namespace': 'default', - 'resourceVersion': '552073', - 'selfLink': - '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', - 'uid': '0c93527d-4dd9-4006-b40a-1672872e8d6f'}, - 'spec': {'arguments': ['100000'], - 'driver': {'coreLimit': '1200m', - 'cores': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m', - 'serviceAccount': 'default'}, - 'executor': {'cores': 1, - 'instances': 1, - 'labels': {'version': '2.4.4'}, - 'memory': '512m'}, - 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', - 'imagePullPolicy': 'Always', - 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', - 'mainClass': 'org.apache.spark.examples.SparkPi', - 'mode': 'cluster', - 'monitoring': {'exposeDriverMetrics': True, - 'exposeExecutorMetrics': True, - 'prometheus': - {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', - 'port': 8090}}, - 'restartPolicy': {'type': 'Never'}, - 'sparkVersion': '2.4.4', - 'type': 'Scala'}, - 'status': {'applicationState': {'state': 'PENDING_RERUN'}, - 'driverInfo': {}, - 'lastSubmissionAttemptTime': None, - 'terminationTime': None}} +TEST_PENDING_RERUN_APPLICATION = { + 'apiVersion': 'sparkoperator.k8s.io/v1beta2', + 'kind': 'SparkApplication', + 'metadata': { + 'creationTimestamp': '2020-02-27T08:03:02Z', + 'generation': 4, + 'name': 'spark-pi', + 'namespace': 'default', + 'resourceVersion': '552073', + 'selfLink': '/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi', + 'uid': '0c93527d-4dd9-4006-b40a-1672872e8d6f', + }, + 'spec': { + 'arguments': ['100000'], + 'driver': { + 'coreLimit': '1200m', + 'cores': 1, + 'labels': {'version': '2.4.4'}, + 'memory': '512m', + 'serviceAccount': 'default', + }, + 'executor': {'cores': 1, 'instances': 1, 'labels': {'version': '2.4.4'}, 'memory': '512m'}, + 'image': 'gcr.io/spark-operator/spark:v2.4.4-gcs-prometheus', + 'imagePullPolicy': 'Always', + 'mainApplicationFile': 'local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar', + 'mainClass': 'org.apache.spark.examples.SparkPi', + 'mode': 'cluster', + 'monitoring': { + 'exposeDriverMetrics': True, + 'exposeExecutorMetrics': True, + 'prometheus': {'jmxExporterJar': '/prometheus/jmx_prometheus_javaagent-0.11.0.jar', 'port': 8090}, + }, + 'restartPolicy': {'type': 'Never'}, + 'sparkVersion': '2.4.4', + 'type': 'Scala', + }, + 'status': { + 'applicationState': {'state': 'PENDING_RERUN'}, + 'driverInfo': {}, + 'lastSubmissionAttemptTime': None, + 'terminationTime': None, + }, +} 
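The fixture dictionaries above cover every applicationState that the sensor tests below exercise: COMPLETED satisfies the sensor, FAILED and UNKNOWN raise, and everything else keeps the sensor poking. A hedged sketch of that mapping (illustrative constant and function names, not the provider's SparkKubernetesSensor code):

# Sketch of the state handling pinned down by the sensor tests below; names here are
# illustrative assumptions, not the provider's implementation.
FAILURE_STATES = {"FAILED", "UNKNOWN"}  # tests expect an exception for these
SUCCESS_STATES = {"COMPLETED"}          # tests expect poke() -> True


def poke_outcome(application: dict) -> bool:
    """Map a SparkApplication resource dict to the sensor's poke() result."""
    state = application.get("status", {}).get("applicationState", {}).get("state", "")
    if state in FAILURE_STATES:
        # The real sensor raises AirflowException here, which is what the
        # FAILED/UNKNOWN tests assert via assertRaises.
        raise RuntimeError("Spark application failed with state: %s" % state)
    # An empty state, SUBMITTED, RUNNING, PENDING_RERUN, or a missing status block all
    # fall through to False, so the sensor keeps rescheduling pokes.
    return state in SUCCESS_STATES


# Example: poke_outcome(TEST_COMPLETED_APPLICATION) is True,
# poke_outcome(TEST_RUNNING_APPLICATION) is False, and
# poke_outcome(TEST_FAILED_APPLICATION) raises.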
@patch('airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_conn') class TestSparkKubernetesSensor(unittest.TestCase): def setUp(self): + db.merge_conn(Connection(conn_id='kubernetes_default', conn_type='kubernetes', extra=json.dumps({}))) db.merge_conn( Connection( - conn_id='kubernetes_default', conn_type='kubernetes', - extra=json.dumps({}))) - db.merge_conn( - Connection( - conn_id='kubernetes_with_namespace', conn_type='kubernetes', - extra=json.dumps({'extra__kubernetes__namespace': 'mock_namespace'}))) - args = { - 'owner': 'airflow', - 'start_date': timezone.datetime(2020, 2, 1) - } + conn_id='kubernetes_with_namespace', + conn_type='kubernetes', + extra=json.dumps({'extra__kubernetes__namespace': 'mock_namespace'}), + ) + ) + args = {'owner': 'airflow', 'start_date': timezone.datetime(2020, 2, 1)} self.dag = DAG('test_dag_id', default_args=args) - @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', - return_value=TEST_COMPLETED_APPLICATION) + @patch( + 'kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', + return_value=TEST_COMPLETED_APPLICATION, + ) def test_completed_application(self, mock_get_namespaced_crd, mock_kubernetes_hook): - sensor = SparkKubernetesSensor(application_name='spark_pi', - dag=self.dag, - task_id='test_task_id') + sensor = SparkKubernetesSensor(application_name='spark_pi', dag=self.dag, task_id='test_task_id') self.assertTrue(sensor.poke(None)) mock_kubernetes_hook.assert_called_once_with() - mock_get_namespaced_crd.assert_called_once_with(group='sparkoperator.k8s.io', - name='spark_pi', - namespace='default', - plural='sparkapplications', - version='v1beta2') + mock_get_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + name='spark_pi', + namespace='default', + plural='sparkapplications', + version='v1beta2', + ) - @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', - return_value=TEST_FAILED_APPLICATION) + @patch( + 'kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', + return_value=TEST_FAILED_APPLICATION, + ) def test_failed_application(self, mock_get_namespaced_crd, mock_kubernetes_hook): - sensor = SparkKubernetesSensor(application_name='spark_pi', - dag=self.dag, - task_id='test_task_id') + sensor = SparkKubernetesSensor(application_name='spark_pi', dag=self.dag, task_id='test_task_id') self.assertRaises(AirflowException, sensor.poke, None) mock_kubernetes_hook.assert_called_once_with() - mock_get_namespaced_crd.assert_called_once_with(group='sparkoperator.k8s.io', - name='spark_pi', - namespace='default', - plural='sparkapplications', - version='v1beta2') + mock_get_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + name='spark_pi', + namespace='default', + plural='sparkapplications', + version='v1beta2', + ) - @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', - return_value=TEST_NOT_PROCESSED_APPLICATION) + @patch( + 'kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', + return_value=TEST_NOT_PROCESSED_APPLICATION, + ) def test_not_processed_application(self, mock_get_namespaced_crd, mock_kubernetes_hook): - sensor = SparkKubernetesSensor(application_name='spark_pi', - dag=self.dag, - task_id='test_task_id') + sensor = SparkKubernetesSensor(application_name='spark_pi', dag=self.dag, task_id='test_task_id') 
self.assertFalse(sensor.poke(None)) mock_kubernetes_hook.assert_called_once_with() - mock_get_namespaced_crd.assert_called_once_with(group='sparkoperator.k8s.io', - name='spark_pi', - namespace='default', - plural='sparkapplications', - version='v1beta2') + mock_get_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + name='spark_pi', + namespace='default', + plural='sparkapplications', + version='v1beta2', + ) - @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', - return_value=TEST_NEW_APPLICATION) + @patch( + 'kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', + return_value=TEST_NEW_APPLICATION, + ) def test_new_application(self, mock_get_namespaced_crd, mock_kubernetes_hook): - sensor = SparkKubernetesSensor(application_name='spark_pi', - dag=self.dag, - task_id='test_task_id') + sensor = SparkKubernetesSensor(application_name='spark_pi', dag=self.dag, task_id='test_task_id') self.assertFalse(sensor.poke(None)) mock_kubernetes_hook.assert_called_once_with() - mock_get_namespaced_crd.assert_called_once_with(group='sparkoperator.k8s.io', - name='spark_pi', - namespace='default', - plural='sparkapplications', - version='v1beta2') + mock_get_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + name='spark_pi', + namespace='default', + plural='sparkapplications', + version='v1beta2', + ) - @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', - return_value=TEST_RUNNING_APPLICATION) + @patch( + 'kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', + return_value=TEST_RUNNING_APPLICATION, + ) def test_running_application(self, mock_get_namespaced_crd, mock_kubernetes_hook): - sensor = SparkKubernetesSensor(application_name='spark_pi', - dag=self.dag, - task_id='test_task_id') + sensor = SparkKubernetesSensor(application_name='spark_pi', dag=self.dag, task_id='test_task_id') self.assertFalse(sensor.poke(None)) mock_kubernetes_hook.assert_called_once_with() - mock_get_namespaced_crd.assert_called_once_with(group='sparkoperator.k8s.io', - name='spark_pi', - namespace='default', - plural='sparkapplications', - version='v1beta2') + mock_get_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + name='spark_pi', + namespace='default', + plural='sparkapplications', + version='v1beta2', + ) - @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', - return_value=TEST_SUBMITTED_APPLICATION) + @patch( + 'kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', + return_value=TEST_SUBMITTED_APPLICATION, + ) def test_submitted_application(self, mock_get_namespaced_crd, mock_kubernetes_hook): - sensor = SparkKubernetesSensor(application_name='spark_pi', - dag=self.dag, - task_id='test_task_id') + sensor = SparkKubernetesSensor(application_name='spark_pi', dag=self.dag, task_id='test_task_id') self.assertFalse(sensor.poke(None)) mock_kubernetes_hook.assert_called_once_with() - mock_get_namespaced_crd.assert_called_once_with(group='sparkoperator.k8s.io', - name='spark_pi', - namespace='default', - plural='sparkapplications', - version='v1beta2') + mock_get_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + name='spark_pi', + namespace='default', + plural='sparkapplications', + version='v1beta2', + ) - 
@patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', - return_value=TEST_PENDING_RERUN_APPLICATION) + @patch( + 'kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', + return_value=TEST_PENDING_RERUN_APPLICATION, + ) def test_pending_rerun_application(self, mock_get_namespaced_crd, mock_kubernetes_hook): - sensor = SparkKubernetesSensor(application_name='spark_pi', - dag=self.dag, - task_id='test_task_id') + sensor = SparkKubernetesSensor(application_name='spark_pi', dag=self.dag, task_id='test_task_id') self.assertFalse(sensor.poke(None)) mock_kubernetes_hook.assert_called_once_with() - mock_get_namespaced_crd.assert_called_once_with(group='sparkoperator.k8s.io', - name='spark_pi', - namespace='default', - plural='sparkapplications', - version='v1beta2') + mock_get_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + name='spark_pi', + namespace='default', + plural='sparkapplications', + version='v1beta2', + ) - @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', - return_value=TEST_UNKNOWN_APPLICATION) + @patch( + 'kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', + return_value=TEST_UNKNOWN_APPLICATION, + ) def test_unknown_application(self, mock_get_namespaced_crd, mock_kubernetes_hook): - sensor = SparkKubernetesSensor(application_name='spark_pi', - dag=self.dag, - task_id='test_task_id') + sensor = SparkKubernetesSensor(application_name='spark_pi', dag=self.dag, task_id='test_task_id') self.assertRaises(AirflowException, sensor.poke, None) mock_kubernetes_hook.assert_called_once_with() - mock_get_namespaced_crd.assert_called_once_with(group='sparkoperator.k8s.io', - name='spark_pi', - namespace='default', - plural='sparkapplications', - version='v1beta2') + mock_get_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + name='spark_pi', + namespace='default', + plural='sparkapplications', + version='v1beta2', + ) - @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', - return_value=TEST_COMPLETED_APPLICATION) + @patch( + 'kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', + return_value=TEST_COMPLETED_APPLICATION, + ) def test_namespace_from_sensor(self, mock_get_namespaced_crd, mock_kubernetes_hook): - sensor = SparkKubernetesSensor(application_name='spark_pi', - dag=self.dag, - kubernetes_conn_id='kubernetes_with_namespace', - namespace='sensor_namespace', - task_id='test_task_id') + sensor = SparkKubernetesSensor( + application_name='spark_pi', + dag=self.dag, + kubernetes_conn_id='kubernetes_with_namespace', + namespace='sensor_namespace', + task_id='test_task_id', + ) sensor.poke(None) mock_kubernetes_hook.assert_called_once_with() - mock_get_namespaced_crd.assert_called_once_with(group='sparkoperator.k8s.io', - name='spark_pi', - namespace='sensor_namespace', - plural='sparkapplications', - version='v1beta2') + mock_get_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + name='spark_pi', + namespace='sensor_namespace', + plural='sparkapplications', + version='v1beta2', + ) - @patch('kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', - return_value=TEST_COMPLETED_APPLICATION) + @patch( + 'kubernetes.client.apis.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object', + return_value=TEST_COMPLETED_APPLICATION, + ) def 
test_namespace_from_connection(self, mock_get_namespaced_crd, mock_kubernetes_hook): - sensor = SparkKubernetesSensor(application_name='spark_pi', - dag=self.dag, - kubernetes_conn_id='kubernetes_with_namespace', - task_id='test_task_id') + sensor = SparkKubernetesSensor( + application_name='spark_pi', + dag=self.dag, + kubernetes_conn_id='kubernetes_with_namespace', + task_id='test_task_id', + ) sensor.poke(None) mock_kubernetes_hook.assert_called_once_with() - mock_get_namespaced_crd.assert_called_once_with(group='sparkoperator.k8s.io', - name='spark_pi', - namespace='mock_namespace', - plural='sparkapplications', - version='v1beta2') + mock_get_namespaced_crd.assert_called_once_with( + group='sparkoperator.k8s.io', + name='spark_pi', + namespace='mock_namespace', + plural='sparkapplications', + version='v1beta2', + ) diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 2ef1273264f6e..aa00cbedc964c 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -32,18 +32,9 @@ TASK_ID = 'databricks-operator' DEFAULT_CONN_ID = 'databricks_default' -NOTEBOOK_TASK = { - 'notebook_path': '/test' -} -SPARK_PYTHON_TASK = { - 'python_file': 'test.py', - 'parameters': ['--param', '123'] -} -NEW_CLUSTER = { - 'spark_version': '2.0.x-scala2.10', - 'node_type_id': 'r3.xlarge', - 'num_workers': 1 -} +NOTEBOOK_TASK = {'notebook_path': '/test'} +SPARK_PYTHON_TASK = {'python_file': 'test.py', 'parameters': ['--param', '123']} +NEW_CLUSTER = {'spark_version': '2.0.x-scala2.10', 'node_type_id': 'r3.xlarge', 'num_workers': 1} CLUSTER_ID = 'cluster_id' RUN_ID = 1 JOB_ID = 42 @@ -59,15 +50,9 @@ GET_RUN_RESPONSE = { 'job_id': JOB_ID, 'run_page_url': RUN_PAGE_URL, - 'state': { - 'life_cycle_state': LIFE_CYCLE_STATE, - 'state_message': STATE_MESSAGE - } -} -NOTEBOOK_PARAMS = { - "dry-run": "true", - "oldest-time-to-consider": "1457570074236" + 'state': {'life_cycle_state': LIFE_CYCLE_STATE, 'state_message': STATE_MESSAGE}, } +NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"} JAR_PARAMS = ["param1", "param2"] RESULT_STATE = None # type: None @@ -137,11 +122,7 @@ def create_post_side_effect(exception, status_code=500): return response -def setup_mock_requests(mock_requests, - exception, - status_code=500, - error_count=None, - response_content=None): +def setup_mock_requests(mock_requests, exception, status_code=500, error_count=None, response_content=None): side_effect = create_post_side_effect(exception, status_code) if error_count is None: @@ -149,8 +130,9 @@ def setup_mock_requests(mock_requests, mock_requests.post.side_effect = itertools.repeat(side_effect) else: # POST requests will fail 'error_count' times, and then they will succeed (once) - mock_requests.post.side_effect = \ - [side_effect] * error_count + [create_valid_response_mock(response_content)] + mock_requests.post.side_effect = [side_effect] * error_count + [ + create_valid_response_mock(response_content) + ] class TestDatabricksHook(unittest.TestCase): @@ -160,9 +142,7 @@ class TestDatabricksHook(unittest.TestCase): @provide_session def setUp(self, session=None): - conn = session.query(Connection) \ - .filter(Connection.conn_id == DEFAULT_CONN_ID) \ - .first() + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() conn.host = HOST conn.login = LOGIN conn.password = PASSWORD @@ -184,11 +164,13 @@ def test_init_bad_retry_limit(self): 
DatabricksHook(retry_limit=0) def test_do_api_call_retries_with_retryable_error(self): - for exception in [requests_exceptions.ConnectionError, - requests_exceptions.SSLError, - requests_exceptions.Timeout, - requests_exceptions.ConnectTimeout, - requests_exceptions.HTTPError]: + for exception in [ + requests_exceptions.ConnectionError, + requests_exceptions.SSLError, + requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, + requests_exceptions.HTTPError, + ]: with mock.patch('airflow.providers.databricks.hooks.databricks.requests') as mock_requests: with mock.patch.object(self.hook.log, 'error') as mock_errors: setup_mock_requests(mock_requests, exception) @@ -200,9 +182,7 @@ def test_do_api_call_retries_with_retryable_error(self): @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests): - setup_mock_requests( - mock_requests, requests_exceptions.HTTPError, status_code=400 - ) + setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=400) with mock.patch.object(self.hook.log, 'error') as mock_errors: with self.assertRaises(AirflowException): @@ -211,18 +191,17 @@ def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests mock_errors.assert_not_called() def test_do_api_call_succeeds_after_retrying(self): - for exception in [requests_exceptions.ConnectionError, - requests_exceptions.SSLError, - requests_exceptions.Timeout, - requests_exceptions.ConnectTimeout, - requests_exceptions.HTTPError]: + for exception in [ + requests_exceptions.ConnectionError, + requests_exceptions.SSLError, + requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, + requests_exceptions.HTTPError, + ]: with mock.patch('airflow.providers.databricks.hooks.databricks.requests') as mock_requests: with mock.patch.object(self.hook.log, 'error') as mock_errors: setup_mock_requests( - mock_requests, - exception, - error_count=2, - response_content={'run_id': '1'} + mock_requests, exception, error_count=2, response_content={'run_id': '1'} ) response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) @@ -235,11 +214,13 @@ def test_do_api_call_waits_between_retries(self, mock_sleep): retry_delay = 5 self.hook = DatabricksHook(retry_delay=retry_delay) - for exception in [requests_exceptions.ConnectionError, - requests_exceptions.SSLError, - requests_exceptions.Timeout, - requests_exceptions.ConnectTimeout, - requests_exceptions.HTTPError]: + for exception in [ + requests_exceptions.ConnectionError, + requests_exceptions.SSLError, + requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, + requests_exceptions.HTTPError, + ]: with mock.patch('airflow.providers.databricks.hooks.databricks.requests') as mock_requests: with mock.patch.object(self.hook.log, 'error'): mock_sleep.reset_mock() @@ -249,72 +230,56 @@ def test_do_api_call_waits_between_retries(self, mock_sleep): self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) self.assertEqual(len(mock_sleep.mock_calls), self.hook.retry_limit - 1) - calls = [ - mock.call(retry_delay), - mock.call(retry_delay) - ] + calls = [mock.call(retry_delay), mock.call(retry_delay)] mock_sleep.assert_has_calls(calls) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_do_api_call_patch(self, mock_requests): mock_requests.patch.return_value.json.return_value = {'cluster_name': 'new_name'} - data = { - 'cluster_name': 'new_name' - } + data = {'cluster_name': 'new_name'} patched_cluster_name = 
self.hook._do_api_call(('PATCH', 'api/2.0/jobs/runs/submit'), data) self.assertEqual(patched_cluster_name['cluster_name'], 'new_name') mock_requests.patch.assert_called_once_with( submit_run_endpoint(HOST), - json={ - 'cluster_name': 'new_name' - }, + json={'cluster_name': 'new_name'}, params=None, auth=(LOGIN, PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_submit_run(self, mock_requests): mock_requests.post.return_value.json.return_value = {'run_id': '1'} - data = { - 'notebook_task': NOTEBOOK_TASK, - 'new_cluster': NEW_CLUSTER - } + data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER} run_id = self.hook.submit_run(data) self.assertEqual(run_id, '1') mock_requests.post.assert_called_once_with( submit_run_endpoint(HOST), - json={ - 'notebook_task': NOTEBOOK_TASK, - 'new_cluster': NEW_CLUSTER, - }, + json={'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER,}, params=None, auth=(LOGIN, PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_spark_python_submit_run(self, mock_requests): mock_requests.post.return_value.json.return_value = {'run_id': '1'} - data = { - 'spark_python_task': SPARK_PYTHON_TASK, - 'new_cluster': NEW_CLUSTER - } + data = {'spark_python_task': SPARK_PYTHON_TASK, 'new_cluster': NEW_CLUSTER} run_id = self.hook.submit_run(data) self.assertEqual(run_id, '1') mock_requests.post.assert_called_once_with( submit_run_endpoint(HOST), - json={ - 'spark_python_task': SPARK_PYTHON_TASK, - 'new_cluster': NEW_CLUSTER, - }, + json={'spark_python_task': SPARK_PYTHON_TASK, 'new_cluster': NEW_CLUSTER,}, params=None, auth=(LOGIN, PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_run_now(self, mock_requests): @@ -322,26 +287,19 @@ def test_run_now(self, mock_requests): mock_requests.post.return_value.json.return_value = {'run_id': '1'} status_code_mock = mock.PropertyMock(return_value=200) type(mock_requests.post.return_value).status_code = status_code_mock - data = { - 'notebook_params': NOTEBOOK_PARAMS, - 'jar_params': JAR_PARAMS, - 'job_id': JOB_ID - } + data = {'notebook_params': NOTEBOOK_PARAMS, 'jar_params': JAR_PARAMS, 'job_id': JOB_ID} run_id = self.hook.run_now(data) self.assertEqual(run_id, '1') mock_requests.post.assert_called_once_with( run_now_endpoint(HOST), - json={ - 'notebook_params': NOTEBOOK_PARAMS, - 'jar_params': JAR_PARAMS, - 'job_id': JOB_ID - }, + json={'notebook_params': NOTEBOOK_PARAMS, 'jar_params': JAR_PARAMS, 'job_id': JOB_ID}, params=None, auth=(LOGIN, PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_get_run_page_url(self, mock_requests): @@ -356,7 +314,8 @@ def test_get_run_page_url(self, mock_requests): params={'run_id': RUN_ID}, auth=(LOGIN, PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_get_job_id(self, mock_requests): @@ -371,7 +330,8 @@ def test_get_job_id(self, mock_requests): params={'run_id': RUN_ID}, auth=(LOGIN, 
PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_get_run_state(self, mock_requests): @@ -379,17 +339,15 @@ def test_get_run_state(self, mock_requests): run_state = self.hook.get_run_state(RUN_ID) - self.assertEqual(run_state, RunState( - LIFE_CYCLE_STATE, - RESULT_STATE, - STATE_MESSAGE)) + self.assertEqual(run_state, RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)) mock_requests.get.assert_called_once_with( get_run_endpoint(HOST), json=None, params={'run_id': RUN_ID}, auth=(LOGIN, PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_cancel_run(self, mock_requests): @@ -403,7 +361,8 @@ def test_cancel_run(self, mock_requests): params=None, auth=(LOGIN, PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_start_cluster(self, mock_requests): @@ -420,7 +379,8 @@ def test_start_cluster(self, mock_requests): params=None, auth=(LOGIN, PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_restart_cluster(self, mock_requests): @@ -437,7 +397,8 @@ def test_restart_cluster(self, mock_requests): params=None, auth=(LOGIN, PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_terminate_cluster(self, mock_requests): @@ -454,7 +415,8 @@ def test_terminate_cluster(self, mock_requests): params=None, auth=(LOGIN, PASSWORD), headers=USER_AGENT_HEADER, - timeout=self.hook.timeout_seconds) + timeout=self.hook.timeout_seconds, + ) class TestDatabricksHookToken(unittest.TestCase): @@ -464,9 +426,7 @@ class TestDatabricksHookToken(unittest.TestCase): @provide_session def setUp(self, session=None): - conn = session.query(Connection) \ - .filter(Connection.conn_id == DEFAULT_CONN_ID) \ - .first() + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() conn.extra = json.dumps({'token': TOKEN, 'host': HOST}) session.commit() @@ -479,10 +439,7 @@ def test_submit_run(self, mock_requests): mock_requests.post.return_value.json.return_value = {'run_id': '1'} status_code_mock = mock.PropertyMock(return_value=200) type(mock_requests.post.return_value).status_code = status_code_mock - data = { - 'notebook_task': NOTEBOOK_TASK, - 'new_cluster': NEW_CLUSTER - } + data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER} run_id = self.hook.submit_run(data) self.assertEqual(run_id, '1') diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 78f5e71e1c8fc..539037a26982f 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -26,56 +26,30 @@ from airflow.providers.databricks.hooks.databricks import RunState from airflow.providers.databricks.operators import databricks as databricks_operator from airflow.providers.databricks.operators.databricks import ( - DatabricksRunNowOperator, DatabricksSubmitRunOperator, + 
DatabricksRunNowOperator, + DatabricksSubmitRunOperator, ) DATE = '2017-04-20' TASK_ID = 'databricks-operator' DEFAULT_CONN_ID = 'databricks_default' -NOTEBOOK_TASK = { - 'notebook_path': '/test' -} -TEMPLATED_NOTEBOOK_TASK = { - 'notebook_path': '/test-{{ ds }}' -} -RENDERED_TEMPLATED_NOTEBOOK_TASK = { - 'notebook_path': '/test-{0}'.format(DATE) -} -SPARK_JAR_TASK = { - 'main_class_name': 'com.databricks.Test' -} -SPARK_PYTHON_TASK = { - 'python_file': 'test.py', - 'parameters': ['--param', '123'] -} +NOTEBOOK_TASK = {'notebook_path': '/test'} +TEMPLATED_NOTEBOOK_TASK = {'notebook_path': '/test-{{ ds }}'} +RENDERED_TEMPLATED_NOTEBOOK_TASK = {'notebook_path': '/test-{0}'.format(DATE)} +SPARK_JAR_TASK = {'main_class_name': 'com.databricks.Test'} +SPARK_PYTHON_TASK = {'python_file': 'test.py', 'parameters': ['--param', '123']} SPARK_SUBMIT_TASK = { - "parameters": [ - "--class", - "org.apache.spark.examples.SparkPi", - "dbfs:/path/to/examples.jar", - "10" - ] -} -NEW_CLUSTER = { - 'spark_version': '2.0.x-scala2.10', - 'node_type_id': 'development-node', - 'num_workers': 1 + "parameters": ["--class", "org.apache.spark.examples.SparkPi", "dbfs:/path/to/examples.jar", "10"] } +NEW_CLUSTER = {'spark_version': '2.0.x-scala2.10', 'node_type_id': 'development-node', 'num_workers': 1} EXISTING_CLUSTER_ID = 'existing-cluster-id' RUN_NAME = 'run-name' RUN_ID = 1 JOB_ID = 42 -NOTEBOOK_PARAMS = { - "dry-run": "true", - "oldest-time-to-consider": "1457570074236" -} +NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"} JAR_PARAMS = ["param1", "param2"] -RENDERED_TEMPLATED_JAR_PARAMS = [ - '/test-{0}'.format(DATE) -] -TEMPLATED_JAR_PARAMS = [ - '/test-{{ ds }}' -] +RENDERED_TEMPLATED_JAR_PARAMS = ['/test-{0}'.format(DATE)] +TEMPLATED_JAR_PARAMS = ['/test-{{ ds }}'] PYTHON_PARAMS = ["john doe", "35"] SPARK_SUBMIT_PARAMS = ["--class", "org.apache.spark.examples.SparkPi"] @@ -87,7 +61,7 @@ def test_deep_string_coerce(self): 'test_float': 1.0, 'test_dict': {'key': 'value'}, 'test_list': [1, 1.0, 'a', 'b'], - 'test_tuple': (1, 1.0, 'a', 'b') + 'test_tuple': (1, 1.0, 'a', 'b'), } expected = { @@ -95,7 +69,7 @@ def test_deep_string_coerce(self): 'test_float': '1.0', 'test_dict': {'key': 'value'}, 'test_list': ['1', '1.0', 'a', 'b'], - 'test_tuple': ['1', '1.0', 'a', 'b'] + 'test_tuple': ['1', '1.0', 'a', 'b'], } self.assertDictEqual(databricks_operator._deep_string_coerce(test_json), expected) @@ -105,14 +79,12 @@ def test_init_with_notebook_task_named_parameters(self): """ Test the initializer with the named parameters. """ - op = DatabricksSubmitRunOperator(task_id=TASK_ID, - new_cluster=NEW_CLUSTER, - notebook_task=NOTEBOOK_TASK) - expected = databricks_operator._deep_string_coerce({ - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': TASK_ID - }) + op = DatabricksSubmitRunOperator( + task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK + ) + expected = databricks_operator._deep_string_coerce( + {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} + ) self.assertDictEqual(expected, op.json) @@ -120,14 +92,12 @@ def test_init_with_spark_python_task_named_parameters(self): """ Test the initializer with the named parameters. 
""" - op = DatabricksSubmitRunOperator(task_id=TASK_ID, - new_cluster=NEW_CLUSTER, - spark_python_task=SPARK_PYTHON_TASK) - expected = databricks_operator._deep_string_coerce({ - 'new_cluster': NEW_CLUSTER, - 'spark_python_task': SPARK_PYTHON_TASK, - 'run_name': TASK_ID - }) + op = DatabricksSubmitRunOperator( + task_id=TASK_ID, new_cluster=NEW_CLUSTER, spark_python_task=SPARK_PYTHON_TASK + ) + expected = databricks_operator._deep_string_coerce( + {'new_cluster': NEW_CLUSTER, 'spark_python_task': SPARK_PYTHON_TASK, 'run_name': TASK_ID} + ) self.assertDictEqual(expected, op.json) @@ -135,14 +105,12 @@ def test_init_with_spark_submit_task_named_parameters(self): """ Test the initializer with the named parameters. """ - op = DatabricksSubmitRunOperator(task_id=TASK_ID, - new_cluster=NEW_CLUSTER, - spark_submit_task=SPARK_SUBMIT_TASK) - expected = databricks_operator._deep_string_coerce({ - 'new_cluster': NEW_CLUSTER, - 'spark_submit_task': SPARK_SUBMIT_TASK, - 'run_name': TASK_ID - }) + op = DatabricksSubmitRunOperator( + task_id=TASK_ID, new_cluster=NEW_CLUSTER, spark_submit_task=SPARK_SUBMIT_TASK + ) + expected = databricks_operator._deep_string_coerce( + {'new_cluster': NEW_CLUSTER, 'spark_submit_task': SPARK_SUBMIT_TASK, 'run_name': TASK_ID} + ) self.assertDictEqual(expected, op.json) @@ -150,33 +118,22 @@ def test_init_with_json(self): """ Test the initializer with json data. """ - json = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK - } + json = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - expected = databricks_operator._deep_string_coerce({ - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': TASK_ID - }) + expected = databricks_operator._deep_string_coerce( + {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} + ) self.assertDictEqual(expected, op.json) def test_init_with_specified_run_name(self): """ Test the initializer with a specified run_name. 
""" - json = { - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': RUN_NAME - } + json = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': RUN_NAME} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - expected = databricks_operator._deep_string_coerce({ - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': RUN_NAME - }) + expected = databricks_operator._deep_string_coerce( + {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': RUN_NAME} + ) self.assertDictEqual(expected, op.json) def test_init_with_merging(self): @@ -190,14 +147,10 @@ def test_init_with_merging(self): 'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, } - op = DatabricksSubmitRunOperator(task_id=TASK_ID, - json=json, - new_cluster=override_new_cluster) - expected = databricks_operator._deep_string_coerce({ - 'new_cluster': override_new_cluster, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': TASK_ID, - }) + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster) + expected = databricks_operator._deep_string_coerce( + {'new_cluster': override_new_cluster, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID,} + ) self.assertDictEqual(expected, op.json) def test_init_with_templating(self): @@ -208,20 +161,22 @@ def test_init_with_templating(self): dag = DAG('test', start_date=datetime.now()) op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json) op.render_template_fields(context={'ds': DATE}) - expected = databricks_operator._deep_string_coerce({ - 'new_cluster': NEW_CLUSTER, - 'notebook_task': RENDERED_TEMPLATED_NOTEBOOK_TASK, - 'run_name': TASK_ID, - }) + expected = databricks_operator._deep_string_coerce( + { + 'new_cluster': NEW_CLUSTER, + 'notebook_task': RENDERED_TEMPLATED_NOTEBOOK_TASK, + 'run_name': TASK_ID, + } + ) self.assertDictEqual(expected, op.json) def test_init_with_bad_type(self): - json = { - 'test': datetime.now() - } + json = {'test': datetime.now()} # Looks a bit weird since we have to escape regex reserved symbols. 
- exception_message = r'Type \<(type|class) \'datetime.datetime\'\> used ' + \ - r'for parameter json\[test\] is not a number or a string' + exception_message = ( + r'Type \<(type|class) \'datetime.datetime\'\> used ' + + r'for parameter json\[test\] is not a number or a string' + ) with self.assertRaisesRegex(AirflowException, exception_message): DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) @@ -241,15 +196,12 @@ def test_exec_success(self, db_mock_class): op.execute(None) - expected = databricks_operator._deep_string_coerce({ - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': TASK_ID - }) + expected = databricks_operator._deep_string_coerce( + {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID} + ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay) + DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + ) db_mock.submit_run.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) @@ -273,15 +225,12 @@ def test_exec_failure(self, db_mock_class): with self.assertRaises(AirflowException): op.execute(None) - expected = databricks_operator._deep_string_coerce({ - 'new_cluster': NEW_CLUSTER, - 'notebook_task': NOTEBOOK_TASK, - 'run_name': TASK_ID, - }) + expected = databricks_operator._deep_string_coerce( + {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID,} + ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay) + DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + ) db_mock.submit_run.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run_state.assert_called_once_with(RUN_ID) @@ -303,15 +252,12 @@ def test_on_kill(self, db_mock_class): class TestDatabricksRunNowOperator(unittest.TestCase): - def test_init_with_named_parameters(self): """ Test the initializer with the named parameters. """ op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID) - expected = databricks_operator._deep_string_coerce({ - 'job_id': 42 - }) + expected = databricks_operator._deep_string_coerce({'job_id': 42}) self.assertDictEqual(expected, op.json) @@ -324,17 +270,19 @@ def test_init_with_json(self): 'jar_params': JAR_PARAMS, 'python_params': PYTHON_PARAMS, 'spark_submit_params': SPARK_SUBMIT_PARAMS, - 'job_id': JOB_ID + 'job_id': JOB_ID, } op = DatabricksRunNowOperator(task_id=TASK_ID, json=json) - expected = databricks_operator._deep_string_coerce({ - 'notebook_params': NOTEBOOK_PARAMS, - 'jar_params': JAR_PARAMS, - 'python_params': PYTHON_PARAMS, - 'spark_submit_params': SPARK_SUBMIT_PARAMS, - 'job_id': JOB_ID - }) + expected = databricks_operator._deep_string_coerce( + { + 'notebook_params': NOTEBOOK_PARAMS, + 'jar_params': JAR_PARAMS, + 'python_params': PYTHON_PARAMS, + 'spark_submit_params': SPARK_SUBMIT_PARAMS, + 'job_id': JOB_ID, + } + ) self.assertDictEqual(expected, op.json) @@ -345,51 +293,51 @@ def test_init_with_merging(self): json dict. 
""" override_notebook_params = {'workers': 999} - json = { - 'notebook_params': NOTEBOOK_PARAMS, - 'jar_params': JAR_PARAMS - } - - op = DatabricksRunNowOperator(task_id=TASK_ID, - json=json, - job_id=JOB_ID, - notebook_params=override_notebook_params, - python_params=PYTHON_PARAMS, - spark_submit_params=SPARK_SUBMIT_PARAMS) - - expected = databricks_operator._deep_string_coerce({ - 'notebook_params': override_notebook_params, - 'jar_params': JAR_PARAMS, - 'python_params': PYTHON_PARAMS, - 'spark_submit_params': SPARK_SUBMIT_PARAMS, - 'job_id': JOB_ID - }) + json = {'notebook_params': NOTEBOOK_PARAMS, 'jar_params': JAR_PARAMS} + + op = DatabricksRunNowOperator( + task_id=TASK_ID, + json=json, + job_id=JOB_ID, + notebook_params=override_notebook_params, + python_params=PYTHON_PARAMS, + spark_submit_params=SPARK_SUBMIT_PARAMS, + ) + + expected = databricks_operator._deep_string_coerce( + { + 'notebook_params': override_notebook_params, + 'jar_params': JAR_PARAMS, + 'python_params': PYTHON_PARAMS, + 'spark_submit_params': SPARK_SUBMIT_PARAMS, + 'job_id': JOB_ID, + } + ) self.assertDictEqual(expected, op.json) def test_init_with_templating(self): - json = { - 'notebook_params': NOTEBOOK_PARAMS, - 'jar_params': TEMPLATED_JAR_PARAMS - } + json = {'notebook_params': NOTEBOOK_PARAMS, 'jar_params': TEMPLATED_JAR_PARAMS} dag = DAG('test', start_date=datetime.now()) op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json) op.render_template_fields(context={'ds': DATE}) - expected = databricks_operator._deep_string_coerce({ - 'notebook_params': NOTEBOOK_PARAMS, - 'jar_params': RENDERED_TEMPLATED_JAR_PARAMS, - 'job_id': JOB_ID - }) + expected = databricks_operator._deep_string_coerce( + { + 'notebook_params': NOTEBOOK_PARAMS, + 'jar_params': RENDERED_TEMPLATED_JAR_PARAMS, + 'job_id': JOB_ID, + } + ) self.assertDictEqual(expected, op.json) def test_init_with_bad_type(self): - json = { - 'test': datetime.now() - } + json = {'test': datetime.now()} # Looks a bit weird since we have to escape regex reserved symbols. - exception_message = r'Type \<(type|class) \'datetime.datetime\'\> used ' + \ - r'for parameter json\[test\] is not a number or a string' + exception_message = ( + r'Type \<(type|class) \'datetime.datetime\'\> used ' + + r'for parameter json\[test\] is not a number or a string' + ) with self.assertRaisesRegex(AirflowException, exception_message): DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json) @@ -398,11 +346,7 @@ def test_exec_success(self, db_mock_class): """ Test the execute function in case where the run is successful. 
""" - run = { - 'notebook_params': NOTEBOOK_PARAMS, - 'notebook_task': NOTEBOOK_TASK, - 'jar_params': JAR_PARAMS - } + run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS} op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) db_mock = db_mock_class.return_value db_mock.run_now.return_value = 1 @@ -410,17 +354,18 @@ def test_exec_success(self, db_mock_class): op.execute(None) - expected = databricks_operator._deep_string_coerce({ - 'notebook_params': NOTEBOOK_PARAMS, - 'notebook_task': NOTEBOOK_TASK, - 'jar_params': JAR_PARAMS, - 'job_id': JOB_ID - }) + expected = databricks_operator._deep_string_coerce( + { + 'notebook_params': NOTEBOOK_PARAMS, + 'notebook_task': NOTEBOOK_TASK, + 'jar_params': JAR_PARAMS, + 'job_id': JOB_ID, + } + ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay) + DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + ) db_mock.run_now.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run_state.assert_called_once_with(RUN_ID) @@ -431,11 +376,7 @@ def test_exec_failure(self, db_mock_class): """ Test the execute function in case where the run failed. """ - run = { - 'notebook_params': NOTEBOOK_PARAMS, - 'notebook_task': NOTEBOOK_TASK, - 'jar_params': JAR_PARAMS - } + run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS} op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) db_mock = db_mock_class.return_value db_mock.run_now.return_value = 1 @@ -444,16 +385,17 @@ def test_exec_failure(self, db_mock_class): with self.assertRaises(AirflowException): op.execute(None) - expected = databricks_operator._deep_string_coerce({ - 'notebook_params': NOTEBOOK_PARAMS, - 'notebook_task': NOTEBOOK_TASK, - 'jar_params': JAR_PARAMS, - 'job_id': JOB_ID - }) + expected = databricks_operator._deep_string_coerce( + { + 'notebook_params': NOTEBOOK_PARAMS, + 'notebook_task': NOTEBOOK_TASK, + 'jar_params': JAR_PARAMS, + 'job_id': JOB_ID, + } + ) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay) + DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay + ) db_mock.run_now.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run_state.assert_called_once_with(RUN_ID) @@ -461,11 +403,7 @@ def test_exec_failure(self, db_mock_class): @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook') def test_on_kill(self, db_mock_class): - run = { - 'notebook_params': NOTEBOOK_PARAMS, - 'notebook_task': NOTEBOOK_TASK, - 'jar_params': JAR_PARAMS - } + run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS} op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) db_mock = db_mock_class.return_value op.run_id = RUN_ID diff --git a/tests/providers/datadog/hooks/test_datadog.py b/tests/providers/datadog/hooks/test_datadog.py index 540c54289a653..71248d59cc42b 100644 --- a/tests/providers/datadog/hooks/test_datadog.py +++ b/tests/providers/datadog/hooks/test_datadog.py @@ -43,14 +43,12 @@ class TestDatadogHook(unittest.TestCase): - @mock.patch('airflow.providers.datadog.hooks.datadog.initialize') 
@mock.patch('airflow.providers.datadog.hooks.datadog.DatadogHook.get_connection') def setUp(self, mock_get_connection, mock_initialize): - mock_get_connection.return_value = Connection(extra=json.dumps({ - 'app_key': APP_KEY, - 'api_key': API_KEY, - })) + mock_get_connection.return_value = Connection( + extra=json.dumps({'app_key': APP_KEY, 'api_key': API_KEY,}) + ) self.hook = DatadogHook() @mock.patch('airflow.providers.datadog.hooks.datadog.initialize') @@ -59,8 +57,7 @@ def test_api_key_required(self, mock_get_connection, mock_initialize): mock_get_connection.return_value = Connection() with self.assertRaises(AirflowException) as ctx: DatadogHook() - self.assertEqual(str(ctx.exception), - 'api_key must be specified in the Datadog connection details') + self.assertEqual(str(ctx.exception), 'api_key must be specified in the Datadog connection details') def test_validate_response_valid(self): try: @@ -76,11 +73,7 @@ def test_validate_response_invalid(self): def test_send_metric(self, mock_send): mock_send.return_value = {'status': 'ok'} self.hook.send_metric( - METRIC_NAME, - DATAPOINT, - tags=TAGS, - type_=TYPE, - interval=INTERVAL, + METRIC_NAME, DATAPOINT, tags=TAGS, type_=TYPE, interval=INTERVAL, ) mock_send.assert_called_once_with( metric=METRIC_NAME, @@ -99,9 +92,7 @@ def test_query_metric(self, mock_time, mock_query): mock_query.return_value = {'status': 'ok'} self.hook.query_metric('query', 60, 30) mock_query.assert_called_once_with( - start=now - 60, - end=now - 30, - query='query', + start=now - 60, end=now - 30, query='query', ) @mock.patch('airflow.providers.datadog.hooks.datadog.api.Event.create') diff --git a/tests/providers/datadog/sensors/test_datadog.py b/tests/providers/datadog/sensors/test_datadog.py index 628906d022f0a..33952a0698339 100644 --- a/tests/providers/datadog/sensors/test_datadog.py +++ b/tests/providers/datadog/sensors/test_datadog.py @@ -26,46 +26,53 @@ from airflow.providers.datadog.sensors.datadog import DatadogSensor from airflow.utils import db -at_least_one_event = [{'alert_type': 'info', - 'comments': [], - 'date_happened': 1419436860, - 'device_name': None, - 'host': None, - 'id': 2603387619536318140, - 'is_aggregate': False, - 'priority': 'normal', - 'resource': '/api/v1/events/2603387619536318140', - 'source': 'My Apps', - 'tags': ['application:web', 'version:1'], - 'text': 'And let me tell you all about it here!', - 'title': 'Something big happened!', - 'url': '/event/jump_to?event_id=2603387619536318140'}, - {'alert_type': 'info', - 'comments': [], - 'date_happened': 1419436865, - 'device_name': None, - 'host': None, - 'id': 2603387619536318141, - 'is_aggregate': False, - 'priority': 'normal', - 'resource': '/api/v1/events/2603387619536318141', - 'source': 'My Apps', - 'tags': ['application:web', 'version:1'], - 'text': 'And let me tell you all about it here!', - 'title': 'Something big happened!', - 'url': '/event/jump_to?event_id=2603387619536318141'}] +at_least_one_event = [ + { + 'alert_type': 'info', + 'comments': [], + 'date_happened': 1419436860, + 'device_name': None, + 'host': None, + 'id': 2603387619536318140, + 'is_aggregate': False, + 'priority': 'normal', + 'resource': '/api/v1/events/2603387619536318140', + 'source': 'My Apps', + 'tags': ['application:web', 'version:1'], + 'text': 'And let me tell you all about it here!', + 'title': 'Something big happened!', + 'url': '/event/jump_to?event_id=2603387619536318140', + }, + { + 'alert_type': 'info', + 'comments': [], + 'date_happened': 1419436865, + 'device_name': None, + 'host': 
None, + 'id': 2603387619536318141, + 'is_aggregate': False, + 'priority': 'normal', + 'resource': '/api/v1/events/2603387619536318141', + 'source': 'My Apps', + 'tags': ['application:web', 'version:1'], + 'text': 'And let me tell you all about it here!', + 'title': 'Something big happened!', + 'url': '/event/jump_to?event_id=2603387619536318141', + }, +] zero_events = [] # type: List class TestDatadogSensor(unittest.TestCase): - def setUp(self): db.merge_conn( Connection( - conn_id='datadog_default', conn_type='datadog', - login='login', password='password', - extra=json.dumps({'api_key': 'api_key', 'app_key': 'app_key'}) + conn_id='datadog_default', + conn_type='datadog', + login='login', + password='password', + extra=json.dumps({'api_key': 'api_key', 'app_key': 'app_key'}), ) ) @@ -83,7 +90,8 @@ def test_sensor_ok(self, api1, api2): priority=None, sources=None, tags=None, - response_check=None) + response_check=None, + ) self.assertTrue(sensor.poke({})) @@ -101,6 +109,7 @@ def test_sensor_fail(self, api1, api2): priority=None, sources=None, tags=None, - response_check=None) + response_check=None, + ) self.assertFalse(sensor.poke({})) diff --git a/tests/providers/dingding/hooks/test_dingding.py b/tests/providers/dingding/hooks/test_dingding.py index 0041625722876..fa06210a1be07 100644 --- a/tests/providers/dingding/hooks/test_dingding.py +++ b/tests/providers/dingding/hooks/test_dingding.py @@ -33,7 +33,9 @@ def setUp(self): conn_id=self.conn_id, conn_type='http', host='https://oapi.dingtalk.com', - password='you_token_here')) + password='you_token_here', + ) + ) def test_get_endpoint_conn_id(self): hook = DingdingHook(dingding_conn_id=self.conn_id) @@ -50,13 +52,8 @@ def test_build_text_message_not_remind(self): } expect = { 'msgtype': 'text', - 'text': { - 'content': 'Airflow dingding text message remind no one' - }, - 'at': { - 'atMobiles': False, - 'isAtAll': False - } + 'text': {'content': 'Airflow dingding text message remind no one'}, + 'at': {'atMobiles': False, 'isAtAll': False}, } hook = DingdingHook(**config) message = hook._build_message() @@ -72,13 +69,8 @@ def test_build_text_message_remind_specific(self): } expect = { 'msgtype': 'text', - 'text': { - 'content': 'Airflow dingding text message remind specific users' - }, - 'at': { - 'atMobiles': ['1234', '5768'], - 'isAtAll': False - } + 'text': {'content': 'Airflow dingding text message remind specific users'}, + 'at': {'atMobiles': ['1234', '5768'], 'isAtAll': False}, } hook = DingdingHook(**config) message = hook._build_message() @@ -93,13 +85,8 @@ def test_build_text_message_remind_all(self): } expect = { 'msgtype': 'text', - 'text': { - 'content': 'Airflow dingding text message remind all user in group' - }, - 'at': { - 'atMobiles': None, - 'isAtAll': True - } + 'text': {'content': 'Airflow dingding text message remind all user in group'}, + 'at': {'atMobiles': None, 'isAtAll': True}, } hook = DingdingHook(**config) message = hook._build_message() @@ -109,7 +96,7 @@ def test_build_markdown_message_remind_specific(self): msg = { 'title': 'Airflow dingding markdown message', 'text': '# Markdown message title\ncontent content .. 
\n### sub-title\n' - '![logo](http://airflow.apache.org/_images/pin_large.png)' + '![logo](http://airflow.apache.org/_images/pin_large.png)', } config = { 'dingding_conn_id': self.conn_id, @@ -121,10 +108,7 @@ def test_build_markdown_message_remind_specific(self): expect = { 'msgtype': 'markdown', 'markdown': msg, - 'at': { - 'atMobiles': ['1234', '5678'], - 'isAtAll': False - } + 'at': {'atMobiles': ['1234', '5678'], 'isAtAll': False}, } hook = DingdingHook(**config) message = hook._build_message() @@ -134,7 +118,7 @@ def test_build_markdown_message_remind_all(self): msg = { 'title': 'Airflow dingding markdown message', 'text': '# Markdown message title\ncontent content .. \n### sub-title\n' - '![logo](http://airflow.apache.org/_images/pin_large.png)' + '![logo](http://airflow.apache.org/_images/pin_large.png)', } config = { 'dingding_conn_id': self.conn_id, @@ -142,14 +126,7 @@ def test_build_markdown_message_remind_all(self): 'message': msg, 'at_all': True, } - expect = { - 'msgtype': 'markdown', - 'markdown': msg, - 'at': { - 'atMobiles': None, - 'isAtAll': True - } - } + expect = {'msgtype': 'markdown', 'markdown': msg, 'at': {'atMobiles': None, 'isAtAll': True}} hook = DingdingHook(**config) message = hook._build_message() self.assertEqual(json.dumps(expect), message) @@ -159,17 +136,10 @@ def test_build_link_message(self): 'title': 'Airflow dingding link message', 'text': 'Airflow official documentation link', 'messageUrl': 'http://airflow.apache.org', - 'picURL': 'http://airflow.apache.org/_images/pin_large.png' - } - config = { - 'dingding_conn_id': self.conn_id, - 'message_type': 'link', - 'message': msg - } - expect = { - 'msgtype': 'link', - 'link': msg + 'picURL': 'http://airflow.apache.org/_images/pin_large.png', } + config = {'dingding_conn_id': self.conn_id, 'message_type': 'link', 'message': msg} + expect = {'msgtype': 'link', 'link': msg} hook = DingdingHook(**config) message = hook._build_message() self.assertEqual(json.dumps(expect), message) @@ -178,22 +148,15 @@ def test_build_single_action_card_message(self): msg = { 'title': 'Airflow dingding single actionCard message', 'text': 'Airflow dingding single actionCard message\n' - '![logo](http://airflow.apache.org/_images/pin_large.png)\n' - 'This is a official logo in Airflow website.', + '![logo](http://airflow.apache.org/_images/pin_large.png)\n' + 'This is a official logo in Airflow website.', 'hideAvatar': '0', 'btnOrientation': '0', 'singleTitle': 'read more', - 'singleURL': 'http://airflow.apache.org' - } - config = { - 'dingding_conn_id': self.conn_id, - 'message_type': 'actionCard', - 'message': msg - } - expect = { - 'msgtype': 'actionCard', - 'actionCard': msg + 'singleURL': 'http://airflow.apache.org', } + config = {'dingding_conn_id': self.conn_id, 'message_type': 'actionCard', 'message': msg} + expect = {'msgtype': 'actionCard', 'actionCard': msg} hook = DingdingHook(**config) message = hook._build_message() self.assertEqual(json.dumps(expect), message) @@ -202,30 +165,17 @@ def test_build_multi_action_card_message(self): msg = { 'title': 'Airflow dingding multi actionCard message', 'text': 'Airflow dingding multi actionCard message\n' - '![logo](http://airflow.apache.org/_images/pin_large.png)\n' - 'Airflow documentation and github', + '![logo](http://airflow.apache.org/_images/pin_large.png)\n' + 'Airflow documentation and github', 'hideAvatar': '0', 'btnOrientation': '0', 'btns': [ - { - 'title': 'Airflow Documentation', - 'actionURL': 'http://airflow.apache.org' - }, - { - 'title': 'Airflow Github', - 
'actionURL': 'https://github.com/apache/airflow' - } - ] - } - config = { - 'dingding_conn_id': self.conn_id, - 'message_type': 'actionCard', - 'message': msg - } - expect = { - 'msgtype': 'actionCard', - 'actionCard': msg + {'title': 'Airflow Documentation', 'actionURL': 'http://airflow.apache.org'}, + {'title': 'Airflow Github', 'actionURL': 'https://github.com/apache/airflow'}, + ], } + config = {'dingding_conn_id': self.conn_id, 'message_type': 'actionCard', 'message': msg} + expect = {'msgtype': 'actionCard', 'actionCard': msg} hook = DingdingHook(**config) message = hook._build_message() self.assertEqual(json.dumps(expect), message) @@ -236,29 +186,22 @@ def test_build_feed_card_message(self): { "title": "Airflow DAG feed card", "messageURL": "https://airflow.readthedocs.io/en/latest/ui.html", - "picURL": "http://airflow.apache.org/_images/dags.png" + "picURL": "http://airflow.apache.org/_images/dags.png", }, { "title": "Airflow tree feed card", "messageURL": "https://airflow.readthedocs.io/en/latest/ui.html", - "picURL": "http://airflow.apache.org/_images/tree.png" + "picURL": "http://airflow.apache.org/_images/tree.png", }, { "title": "Airflow graph feed card", "messageURL": "https://airflow.readthedocs.io/en/latest/ui.html", - "picURL": "http://airflow.apache.org/_images/graph.png" - } + "picURL": "http://airflow.apache.org/_images/graph.png", + }, ] } - config = { - 'dingding_conn_id': self.conn_id, - 'message_type': 'feedCard', - 'message': msg - } - expect = { - 'msgtype': 'feedCard', - 'feedCard': msg - } + config = {'dingding_conn_id': self.conn_id, 'message_type': 'feedCard', 'message': msg} + expect = {'msgtype': 'feedCard', 'feedCard': msg} hook = DingdingHook(**config) message = hook._build_message() self.assertEqual(json.dumps(expect), message) @@ -267,7 +210,7 @@ def test_send_not_support_type(self): config = { 'dingding_conn_id': self.conn_id, 'message_type': 'not_support_type', - 'message': 'Airflow dingding text message remind no one' + 'message': 'Airflow dingding text message remind no one', } hook = DingdingHook(**config) self.assertRaises(ValueError, hook.send) diff --git a/tests/providers/dingding/operators/test_dingding.py b/tests/providers/dingding/operators/test_dingding.py index 14b299109ac3e..f7ed13a775d43 100644 --- a/tests/providers/dingding/operators/test_dingding.py +++ b/tests/providers/dingding/operators/test_dingding.py @@ -32,23 +32,16 @@ class TestDingdingOperator(unittest.TestCase): 'message_type': 'text', 'message': 'Airflow dingding webhook test', 'at_mobiles': ['123', '456'], - 'at_all': False + 'at_all': False, } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) @mock.patch('airflow.providers.dingding.operators.dingding.DingdingHook') def test_execute(self, mock_hook): - operator = DingdingOperator( - task_id='dingding_task', - dag=self.dag, - **self._config - ) + operator = DingdingOperator(task_id='dingding_task', dag=self.dag, **self._config) self.assertIsNotNone(operator) self.assertEqual(self._config['dingding_conn_id'], operator.dingding_conn_id) @@ -63,6 +56,6 @@ def test_execute(self, mock_hook): self._config['message_type'], self._config['message'], self._config['at_mobiles'], - self._config['at_all'] + self._config['at_all'], ) mock_hook.return_value.send.assert_called_once_with() diff --git a/tests/providers/discord/hooks/test_discord_webhook.py 
b/tests/providers/discord/hooks/test_discord_webhook.py index b15878e6b126d..f28a10a10473f 100644 --- a/tests/providers/discord/hooks/test_discord_webhook.py +++ b/tests/providers/discord/hooks/test_discord_webhook.py @@ -34,14 +34,14 @@ class TestDiscordWebhookHook(unittest.TestCase): 'username': 'Airflow Webhook', 'avatar_url': 'https://static-cdn.avatars.com/my-avatar-path', 'tts': False, - 'proxy': 'https://proxy.proxy.com:8888' + 'proxy': 'https://proxy.proxy.com:8888', } expected_payload_dict = { 'username': _config['username'], 'avatar_url': _config['avatar_url'], 'tts': _config['tts'], - 'content': _config['message'] + 'content': _config['message'], } expected_payload = json.dumps(expected_payload_dict) @@ -52,7 +52,8 @@ def setUp(self): conn_id='default-discord-webhook', conn_type='http', host='https://discordapp.com/api/', - extra='{"webhook_endpoint": "webhooks/00000/some-discord-token_000"}') + extra='{"webhook_endpoint": "webhooks/00000/some-discord-token_000"}', + ) ) def test_get_webhook_endpoint_manual_token(self): diff --git a/tests/providers/discord/operators/test_discord_webhook.py b/tests/providers/discord/operators/test_discord_webhook.py index 0ad0dd4b93061..8cf3b64bdcec1 100644 --- a/tests/providers/discord/operators/test_discord_webhook.py +++ b/tests/providers/discord/operators/test_discord_webhook.py @@ -33,22 +33,15 @@ class TestDiscordWebhookOperator(unittest.TestCase): 'username': 'Airflow Webhook', 'avatar_url': 'https://static-cdn.avatars.com/my-avatar-path', 'tts': False, - 'proxy': 'https://proxy.proxy.com:8888' + 'proxy': 'https://proxy.proxy.com:8888', } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) def test_execute(self): - operator = DiscordWebhookOperator( - task_id='discord_webhook_task', - dag=self.dag, - **self._config - ) + operator = DiscordWebhookOperator(task_id='discord_webhook_task', dag=self.dag, **self._config) self.assertEqual(self._config['http_conn_id'], operator.http_conn_id) self.assertEqual(self._config['webhook_endpoint'], operator.webhook_endpoint) diff --git a/tests/providers/docker/hooks/test_docker.py b/tests/providers/docker/hooks/test_docker.py index 042b71d943f1f..4927480bdf755 100644 --- a/tests/providers/docker/hooks/test_docker.py +++ b/tests/providers/docker/hooks/test_docker.py @@ -39,7 +39,7 @@ def setUp(self): conn_type='docker', host='some.docker.registry.com', login='some_user', - password='some_p4$$w0rd' + password='some_p4$$w0rd', ) ) db.merge_conn( @@ -50,46 +50,34 @@ def setUp(self): port=9876, login='some_user', password='some_p4$$w0rd', - extra='{"email": "some@example.com", "reauth": "no"}' + extra='{"email": "some@example.com", "reauth": "no"}', ) ) def test_init_fails_when_no_base_url_given(self, _): with self.assertRaises(AirflowException): - DockerHook( - docker_conn_id='docker_default', - version='auto', - tls=None - ) + DockerHook(docker_conn_id='docker_default', version='auto', tls=None) def test_init_fails_when_no_api_version_given(self, _): with self.assertRaises(AirflowException): - DockerHook( - docker_conn_id='docker_default', - base_url='unix://var/run/docker.sock', - tls=None - ) + DockerHook(docker_conn_id='docker_default', base_url='unix://var/run/docker.sock', tls=None) def test_get_conn_override_defaults(self, docker_client_mock): hook = DockerHook( docker_conn_id='docker_default', base_url='https://index.docker.io/v1/', version='1.23', - 
tls='someconfig' + tls='someconfig', ) hook.get_conn() docker_client_mock.assert_called_once_with( - base_url='https://index.docker.io/v1/', - version='1.23', - tls='someconfig' + base_url='https://index.docker.io/v1/', version='1.23', tls='someconfig' ) def test_get_conn_with_standard_config(self, _): try: hook = DockerHook( - docker_conn_id='docker_default', - base_url='unix://var/run/docker.sock', - version='auto' + docker_conn_id='docker_default', base_url='unix://var/run/docker.sock', version='auto' ) client = hook.get_conn() self.assertIsNotNone(client) @@ -99,9 +87,7 @@ def test_get_conn_with_standard_config(self, _): def test_get_conn_with_extra_config(self, _): try: hook = DockerHook( - docker_conn_id='docker_with_extras', - base_url='unix://var/run/docker.sock', - version='auto' + docker_conn_id='docker_with_extras', base_url='unix://var/run/docker.sock', version='auto' ) client = hook.get_conn() self.assertIsNotNone(client) @@ -110,9 +96,7 @@ def test_get_conn_with_extra_config(self, _): def test_conn_with_standard_config_passes_parameters(self, _): hook = DockerHook( - docker_conn_id='docker_default', - base_url='unix://var/run/docker.sock', - version='auto' + docker_conn_id='docker_default', base_url='unix://var/run/docker.sock', version='auto' ) client = hook.get_conn() client.login.assert_called_once_with( # pylint: disable=no-member @@ -120,14 +104,12 @@ def test_conn_with_standard_config_passes_parameters(self, _): password='some_p4$$w0rd', registry='some.docker.registry.com', reauth=True, - email=None + email=None, ) def test_conn_with_extra_config_passes_parameters(self, _): hook = DockerHook( - docker_conn_id='docker_with_extras', - base_url='unix://var/run/docker.sock', - version='auto' + docker_conn_id='docker_with_extras', base_url='unix://var/run/docker.sock', version='auto' ) client = hook.get_conn() client.login.assert_called_once_with( # pylint: disable=no-member @@ -135,7 +117,7 @@ def test_conn_with_extra_config_passes_parameters(self, _): password='some_p4$$w0rd', registry='another.docker.registry.com:9876', reauth=False, - email='some@example.com' + email='some@example.com', ) def test_conn_with_broken_config_missing_username_fails(self, _): @@ -145,28 +127,23 @@ def test_conn_with_broken_config_missing_username_fails(self, _): conn_type='docker', host='some.docker.registry.com', password='some_p4$$w0rd', - extra='{"email": "some@example.com"}' + extra='{"email": "some@example.com"}', ) ) with self.assertRaises(AirflowException): DockerHook( docker_conn_id='docker_without_username', base_url='unix://var/run/docker.sock', - version='auto' + version='auto', ) def test_conn_with_broken_config_missing_host_fails(self, _): db.merge_conn( Connection( - conn_id='docker_without_host', - conn_type='docker', - login='some_user', - password='some_p4$$w0rd' + conn_id='docker_without_host', conn_type='docker', login='some_user', password='some_p4$$w0rd' ) ) with self.assertRaises(AirflowException): DockerHook( - docker_conn_id='docker_without_host', - base_url='unix://var/run/docker.sock', - version='auto' + docker_conn_id='docker_without_host', base_url='unix://var/run/docker.sock', version='auto' ) diff --git a/tests/providers/docker/operators/test_docker.py b/tests/providers/docker/operators/test_docker.py index 4b9c1f7835372..b3258ba0d4227 100644 --- a/tests/providers/docker/operators/test_docker.py +++ b/tests/providers/docker/operators/test_docker.py @@ -49,59 +49,66 @@ def test_execute(self, client_class_mock, tempdir_mock): client_class_mock.return_value = 
client_mock - operator = DockerOperator(api_version='1.19', command='env', environment={'UNIT': 'TEST'}, - private_environment={'PRIVATE': 'MESSAGE'}, image='ubuntu:latest', - network_mode='bridge', owner='unittest', task_id='unittest', - volumes=['/host/path:/container/path'], - working_dir='/container/path', shm_size=1000, - host_tmp_dir='/host/airflow', container_name='test_container', - tty=True) + operator = DockerOperator( + api_version='1.19', + command='env', + environment={'UNIT': 'TEST'}, + private_environment={'PRIVATE': 'MESSAGE'}, + image='ubuntu:latest', + network_mode='bridge', + owner='unittest', + task_id='unittest', + volumes=['/host/path:/container/path'], + working_dir='/container/path', + shm_size=1000, + host_tmp_dir='/host/airflow', + container_name='test_container', + tty=True, + ) operator.execute(None) - client_class_mock.assert_called_once_with(base_url='unix://var/run/docker.sock', tls=None, - version='1.19') - - client_mock.create_container.assert_called_once_with(command='env', - name='test_container', - environment={ - 'AIRFLOW_TMP_DIR': '/tmp/airflow', - 'UNIT': 'TEST', - 'PRIVATE': 'MESSAGE' - }, - host_config=host_config, - image='ubuntu:latest', - user=None, - working_dir='/container/path', - tty=True - ) - client_mock.create_host_config.assert_called_once_with(binds=['/host/path:/container/path', - '/mkdtemp:/tmp/airflow'], - network_mode='bridge', - shm_size=1000, - cpu_shares=1024, - mem_limit=None, - auto_remove=False, - dns=None, - dns_search=None, - cap_add=None) + client_class_mock.assert_called_once_with( + base_url='unix://var/run/docker.sock', tls=None, version='1.19' + ) + + client_mock.create_container.assert_called_once_with( + command='env', + name='test_container', + environment={'AIRFLOW_TMP_DIR': '/tmp/airflow', 'UNIT': 'TEST', 'PRIVATE': 'MESSAGE'}, + host_config=host_config, + image='ubuntu:latest', + user=None, + working_dir='/container/path', + tty=True, + ) + client_mock.create_host_config.assert_called_once_with( + binds=['/host/path:/container/path', '/mkdtemp:/tmp/airflow'], + network_mode='bridge', + shm_size=1000, + cpu_shares=1024, + mem_limit=None, + auto_remove=False, + dns=None, + dns_search=None, + cap_add=None, + ) tempdir_mock.assert_called_once_with(dir='/host/airflow', prefix='airflowtmp') client_mock.images.assert_called_once_with(name='ubuntu:latest') - client_mock.attach.assert_called_once_with(container='some_id', stdout=True, - stderr=True, stream=True) - client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, - decode=True) + client_mock.attach.assert_called_once_with(container='some_id', stdout=True, stderr=True, stream=True) + client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True) client_mock.wait.assert_called_once_with('some_id') - self.assertEqual(operator.cli.pull('ubuntu:latest', stream=True, - decode=True), - client_mock.pull.return_value) + self.assertEqual( + operator.cli.pull('ubuntu:latest', stream=True, decode=True), client_mock.pull.return_value + ) def test_private_environment_is_private(self): - operator = DockerOperator(private_environment={'PRIVATE': 'MESSAGE'}, - image='ubuntu:latest', - task_id='unittest') + operator = DockerOperator( + private_environment={'PRIVATE': 'MESSAGE'}, image='ubuntu:latest', task_id='unittest' + ) self.assertEqual( - operator._private_environment, {'PRIVATE': 'MESSAGE'}, - "To keep this private, it must be an underscored attribute." 
+ operator._private_environment, + {'PRIVATE': 'MESSAGE'}, + "To keep this private, it must be an underscored attribute.", ) @mock.patch('airflow.providers.docker.operators.docker.tls.TLSConfig') @@ -119,17 +126,28 @@ def test_execute_tls(self, client_class_mock, tls_class_mock): tls_mock = mock.Mock() tls_class_mock.return_value = tls_mock - operator = DockerOperator(docker_url='tcp://127.0.0.1:2376', image='ubuntu', - owner='unittest', task_id='unittest', tls_client_cert='cert.pem', - tls_ca_cert='ca.pem', tls_client_key='key.pem') + operator = DockerOperator( + docker_url='tcp://127.0.0.1:2376', + image='ubuntu', + owner='unittest', + task_id='unittest', + tls_client_cert='cert.pem', + tls_ca_cert='ca.pem', + tls_client_key='key.pem', + ) operator.execute(None) - tls_class_mock.assert_called_once_with(assert_hostname=None, ca_cert='ca.pem', - client_cert=('cert.pem', 'key.pem'), - ssl_version=None, verify=True) + tls_class_mock.assert_called_once_with( + assert_hostname=None, + ca_cert='ca.pem', + client_cert=('cert.pem', 'key.pem'), + ssl_version=None, + verify=True, + ) - client_class_mock.assert_called_once_with(base_url='https://127.0.0.1:2376', - tls=tls_mock, version=None) + client_class_mock.assert_called_once_with( + base_url='https://127.0.0.1:2376', tls=tls_mock, version=None + ) @mock.patch('airflow.providers.docker.operators.docker.APIClient') def test_execute_unicode_logs(self, client_class_mock): @@ -194,31 +212,21 @@ def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock): operator_client_mock.return_value = client_mock # Create the DockerOperator - operator = DockerOperator( - image='publicregistry/someimage', - owner='unittest', - task_id='unittest' - ) + operator = DockerOperator(image='publicregistry/someimage', owner='unittest', task_id='unittest') # Mock out the DockerHook hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook) hook_mock.get_conn.return_value = client_mock operator.get_hook = mock.Mock( - name='DockerOperator.get_hook mock', - spec=DockerOperator.get_hook, - return_value=hook_mock + name='DockerOperator.get_hook mock', spec=DockerOperator.get_hook, return_value=hook_mock ) operator.execute(None) - self.assertEqual( - operator.get_hook.call_count, 0, - 'Hook called though no docker_conn_id configured' - ) + self.assertEqual(operator.get_hook.call_count, 0, 'Hook called though no docker_conn_id configured') @mock.patch('airflow.providers.docker.operators.docker.DockerHook') @mock.patch('airflow.providers.docker.operators.docker.APIClient') - def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock, - operator_docker_hook): + def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock, operator_docker_hook): # Mock out a Docker client, so operations don't raise errors client_mock = mock.Mock(name='DockerOperator.APIClient mock', spec=APIClient) client_mock.images.return_value = [] @@ -233,7 +241,7 @@ def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock, image='publicregistry/someimage', owner='unittest', task_id='unittest', - docker_conn_id='some_conn_id' + docker_conn_id='some_conn_id', ) # Mock out the DockerHook @@ -244,17 +252,12 @@ def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock, operator.execute(None) self.assertEqual( - operator_client_mock.call_count, 0, - 'Client was called on the operator instead of the hook' - ) - self.assertEqual( - operator_docker_hook.call_count, 1, - 'Hook was not called although docker_conn_id configured' + 
operator_client_mock.call_count, 0, 'Client was called on the operator instead of the hook' ) self.assertEqual( - client_mock.pull.call_count, 1, - 'Image was not pulled using operator client' + operator_docker_hook.call_count, 1, 'Hook was not called although docker_conn_id configured' ) + self.assertEqual(client_mock.pull.call_count, 1, 'Image was not pulled using operator client') @mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory') @mock.patch('airflow.providers.docker.operators.docker.APIClient') diff --git a/tests/providers/docker/operators/test_docker_swarm.py b/tests/providers/docker/operators/test_docker_swarm.py index 039b50b850fb3..a75b38e8dcb67 100644 --- a/tests/providers/docker/operators/test_docker_swarm.py +++ b/tests/providers/docker/operators/test_docker_swarm.py @@ -27,7 +27,6 @@ class TestDockerSwarmOperator(unittest.TestCase): - @mock.patch('airflow.providers.docker.operators.docker.APIClient') @mock.patch('airflow.providers.docker.operators.docker_swarm.types') def test_execute(self, types_mock, client_class_mock): @@ -57,8 +56,15 @@ def _client_service_logs_effect(): client_class_mock.return_value = client_mock operator = DockerSwarmOperator( - api_version='1.19', command='env', environment={'UNIT': 'TEST'}, image='ubuntu:latest', - mem_limit='128m', user='unittest', task_id='unittest', auto_remove=True, tty=True, + api_version='1.19', + command='env', + environment={'UNIT': 'TEST'}, + image='ubuntu:latest', + mem_limit='128m', + user='unittest', + task_id='unittest', + auto_remove=True, + tty=True, ) operator.execute(None) @@ -66,8 +72,11 @@ def _client_service_logs_effect(): container_spec=mock_obj, restart_policy=mock_obj, resources=mock_obj ) types_mock.ContainerSpec.assert_called_once_with( - image='ubuntu:latest', command='env', user='unittest', tty=True, - env={'UNIT': 'TEST', 'AIRFLOW_TMP_DIR': '/tmp/airflow'} + image='ubuntu:latest', + command='env', + user='unittest', + tty=True, + env={'UNIT': 'TEST', 'AIRFLOW_TMP_DIR': '/tmp/airflow'}, ) types_mock.RestartPolicy.assert_called_once_with(condition='none') types_mock.Resources.assert_called_once_with(mem_limit='128m') @@ -84,7 +93,7 @@ def _client_service_logs_effect(): self.assertEqual( len(csargs), 1, 'create_service called with different number of arguments than expected' ) - self.assertEqual(csargs, (mock_obj, )) + self.assertEqual(csargs, (mock_obj,)) self.assertEqual(cskwargs['labels'], {'name': 'airflow__adhoc_airflow__unittest'}) self.assertTrue(cskwargs['name'].startswith('airflow-')) self.assertEqual(client_mock.tasks.call_count, 5) @@ -112,8 +121,9 @@ def test_no_auto_remove(self, types_mock, client_class_mock): operator.execute(None) self.assertEqual( - client_mock.remove_service.call_count, 0, - 'Docker service being removed even when `auto_remove` set to `False`' + client_mock.remove_service.call_count, + 0, + 'Docker service being removed even when `auto_remove` set to `False`', ) @mock.patch('airflow.providers.docker.operators.docker.APIClient') @@ -170,9 +180,16 @@ def _client_service_logs_effect(): client_class_mock.return_value = client_mock operator = DockerSwarmOperator( - api_version='1.19', command='env', environment={'UNIT': 'TEST'}, image='ubuntu:latest', - mem_limit='128m', user='unittest', task_id='unittest', auto_remove=True, tty=True, - enable_logging=True + api_version='1.19', + command='env', + environment={'UNIT': 'TEST'}, + image='ubuntu:latest', + mem_limit='128m', + user='unittest', + task_id='unittest', + auto_remove=True, + tty=True, + 
enable_logging=True, ) operator.execute(None) diff --git a/tests/providers/elasticsearch/hooks/test_elasticsearch.py b/tests/providers/elasticsearch/hooks/test_elasticsearch.py index 4fa7216122aeb..854f149d81a47 100644 --- a/tests/providers/elasticsearch/hooks/test_elasticsearch.py +++ b/tests/providers/elasticsearch/hooks/test_elasticsearch.py @@ -25,15 +25,10 @@ class TestElasticsearchHookConn(unittest.TestCase): - def setUp(self): super().setUp() - self.connection = Connection( - host='localhost', - port=9200, - schema='http' - ) + self.connection = Connection(host='localhost', port=9200, schema='http') class UnitTestElasticsearchHook(ElasticsearchHook): conn_name_attr = 'elasticsearch_conn_id' @@ -46,13 +41,10 @@ class UnitTestElasticsearchHook(ElasticsearchHook): def test_get_conn(self, mock_connect): self.db_hook.test_conn_id = 'non_default' # pylint: disable=attribute-defined-outside-init self.db_hook.get_conn() - mock_connect.assert_called_with(host='localhost', port=9200, - scheme='http', user=None, - password=None) + mock_connect.assert_called_with(host='localhost', port=9200, scheme='http', user=None, password=None) class TestElasticsearchHook(unittest.TestCase): - def setUp(self): super().setUp() diff --git a/tests/providers/elasticsearch/log/elasticmock/__init__.py b/tests/providers/elasticsearch/log/elasticmock/__init__.py index eb61f2eb77a33..ed6bc7d783360 100644 --- a/tests/providers/elasticsearch/log/elasticmock/__init__.py +++ b/tests/providers/elasticsearch/log/elasticmock/__init__.py @@ -51,9 +51,7 @@ def _get_elasticmock(hosts=None, *args, **kwargs): # pylint: disable=unused-argument host = _normalize_hosts(hosts)[0] - elastic_key = '{0}:{1}'.format( - host.get('host', 'localhost'), host.get('port', 9200) - ) + elastic_key = '{0}:{1}'.format(host.get('host', 'localhost'), host.get('port', 9200)) if elastic_key in ELASTIC_INSTANCES: connection = ELASTIC_INSTANCES.get(elastic_key) @@ -65,10 +63,12 @@ def _get_elasticmock(hosts=None, *args, **kwargs): # pylint: disable=unused-arg def elasticmock(function): """Elasticmock decorator""" + @wraps(function) def decorated(*args, **kwargs): ELASTIC_INSTANCES.clear() with patch('elasticsearch.Elasticsearch', _get_elasticmock): result = function(*args, **kwargs) return result + return decorated diff --git a/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py b/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py index dd58dbfffa27e..7b5a2c7b68697 100644 --- a/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py +++ b/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py @@ -64,20 +64,30 @@ def info(self, params=None): return { 'status': 200, 'cluster_name': 'elasticmock', - 'version': - { - 'lucene_version': '4.10.4', - 'build_hash': '00f95f4ffca6de89d68b7ccaf80d148f1f70e4d4', - 'number': '1.7.5', - 'build_timestamp': '2016-02-02T09:55:30Z', - 'build_snapshot': False - }, + 'version': { + 'lucene_version': '4.10.4', + 'build_hash': '00f95f4ffca6de89d68b7ccaf80d148f1f70e4d4', + 'number': '1.7.5', + 'build_timestamp': '2016-02-02T09:55:30Z', + 'build_snapshot': False, + }, 'name': 'Nightwatch', - 'tagline': 'You Know, for Search' + 'tagline': 'You Know, for Search', } - @query_params('consistency', 'op_type', 'parent', 'refresh', 'replication', - 'routing', 'timeout', 'timestamp', 'ttl', 'version', 'version_type') + @query_params( + 'consistency', + 'op_type', + 'parent', + 'refresh', + 'replication', + 'routing', + 'timeout', + 'timestamp', + 'ttl', + 'version', + 
'version_type', + ) def index(self, index, doc_type, body, id=None, params=None): if index not in self.__documents_dict: self.__documents_dict[index] = [] @@ -87,21 +97,11 @@ def index(self, index, doc_type, body, id=None, params=None): version = 1 - self.__documents_dict[index].append({ - '_type': doc_type, - '_id': id, - '_source': body, - '_index': index, - '_version': version - }) + self.__documents_dict[index].append( + {'_type': doc_type, '_id': id, '_source': body, '_index': index, '_version': version} + ) - return { - '_type': doc_type, - '_id': id, - 'created': True, - '_version': version, - '_index': index - } + return {'_type': doc_type, '_id': id, 'created': True, '_version': version, '_index': index} @query_params('parent', 'preference', 'realtime', 'refresh', 'routing') def exists(self, index, doc_type, id, params=None): @@ -113,9 +113,19 @@ def exists(self, index, doc_type, id, params=None): break return result - @query_params('_source', '_source_exclude', '_source_include', 'fields', - 'parent', 'preference', 'realtime', 'refresh', 'routing', 'version', - 'version_type') + @query_params( + '_source', + '_source_exclude', + '_source_include', + 'fields', + 'parent', + 'preference', + 'realtime', + 'refresh', + 'routing', + 'version', + 'version_type', + ) def get(self, index, id, doc_type='_all', params=None): result = None if index in self.__documents_dict: @@ -124,12 +134,7 @@ def get(self, index, id, doc_type='_all', params=None): if result: result['found'] = True else: - error_data = { - '_index': index, - '_type': doc_type, - '_id': id, - 'found': False - } + error_data = {'_index': index, '_type': doc_type, '_id': id, 'found': False} raise NotFoundError(404, json.dumps(error_data)) return result @@ -142,21 +147,57 @@ def find_document(self, doc_type, id, index, result): break return result - @query_params('_source', '_source_exclude', '_source_include', 'parent', - 'preference', 'realtime', 'refresh', 'routing', 'version', - 'version_type') + @query_params( + '_source', + '_source_exclude', + '_source_include', + 'parent', + 'preference', + 'realtime', + 'refresh', + 'routing', + 'version', + 'version_type', + ) def get_source(self, index, doc_type, id, params=None): document = self.get(index=index, doc_type=doc_type, id=id, params=params) return document.get('_source') - @query_params('_source', '_source_exclude', '_source_include', - 'allow_no_indices', 'analyze_wildcard', 'analyzer', 'default_operator', - 'df', 'expand_wildcards', 'explain', 'fielddata_fields', 'fields', - 'from_', 'ignore_unavailable', 'lenient', 'lowercase_expanded_terms', - 'preference', 'q', 'request_cache', 'routing', 'scroll', 'search_type', - 'size', 'sort', 'stats', 'suggest_field', 'suggest_mode', - 'suggest_size', 'suggest_text', 'terminate_after', 'timeout', - 'track_scores', 'version') + @query_params( + '_source', + '_source_exclude', + '_source_include', + 'allow_no_indices', + 'analyze_wildcard', + 'analyzer', + 'default_operator', + 'df', + 'expand_wildcards', + 'explain', + 'fielddata_fields', + 'fields', + 'from_', + 'ignore_unavailable', + 'lenient', + 'lowercase_expanded_terms', + 'preference', + 'q', + 'request_cache', + 'routing', + 'scroll', + 'search_type', + 'size', + 'sort', + 'stats', + 'suggest_field', + 'suggest_mode', + 'suggest_size', + 'suggest_text', + 'terminate_after', + 'timeout', + 'track_scores', + 'version', + ) def count(self, index=None, doc_type=None, body=None, params=None): searchable_indexes = self._normalize_index_to_list(index) searchable_doc_types = 
self._normalize_doc_type_to_list(doc_type) @@ -164,47 +205,63 @@ def count(self, index=None, doc_type=None, body=None, params=None): i = 0 for searchable_index in searchable_indexes: for document in self.__documents_dict[searchable_index]: - if searchable_doc_types\ - and document.get('_type') not in searchable_doc_types: + if searchable_doc_types and document.get('_type') not in searchable_doc_types: continue i += 1 - result = { - 'count': i, - '_shards': { - 'successful': 1, - 'failed': 0, - 'total': 1 - } - } + result = {'count': i, '_shards': {'successful': 1, 'failed': 0, 'total': 1}} return result - @query_params('_source', '_source_exclude', '_source_include', - 'allow_no_indices', 'analyze_wildcard', 'analyzer', 'default_operator', - 'df', 'expand_wildcards', 'explain', 'fielddata_fields', 'fields', - 'from_', 'ignore_unavailable', 'lenient', 'lowercase_expanded_terms', - 'preference', 'q', 'request_cache', 'routing', 'scroll', 'search_type', - 'size', 'sort', 'stats', 'suggest_field', 'suggest_mode', - 'suggest_size', 'suggest_text', 'terminate_after', 'timeout', - 'track_scores', 'version') + @query_params( + '_source', + '_source_exclude', + '_source_include', + 'allow_no_indices', + 'analyze_wildcard', + 'analyzer', + 'default_operator', + 'df', + 'expand_wildcards', + 'explain', + 'fielddata_fields', + 'fields', + 'from_', + 'ignore_unavailable', + 'lenient', + 'lowercase_expanded_terms', + 'preference', + 'q', + 'request_cache', + 'routing', + 'scroll', + 'search_type', + 'size', + 'sort', + 'stats', + 'suggest_field', + 'suggest_mode', + 'suggest_size', + 'suggest_text', + 'terminate_after', + 'timeout', + 'track_scores', + 'version', + ) def search(self, index=None, doc_type=None, body=None, params=None): searchable_indexes = self._normalize_index_to_list(index) matches = self._find_match(index, doc_type, body) result = { - 'hits': { - 'total': len(matches), - 'max_score': 1.0 - }, + 'hits': {'total': len(matches), 'max_score': 1.0}, '_shards': { # Simulate indexes with 1 shard each 'successful': len(searchable_indexes), 'failed': 0, - 'total': len(searchable_indexes) + 'total': len(searchable_indexes), }, 'took': 1, - 'timed_out': False + 'timed_out': False, } hits = [] @@ -215,8 +272,9 @@ def search(self, index=None, doc_type=None, body=None, params=None): return result - @query_params('consistency', 'parent', 'refresh', 'replication', 'routing', - 'timeout', 'version', 'version_type') + @query_params( + 'consistency', 'parent', 'refresh', 'replication', 'routing', 'timeout', 'version', 'version_type' + ) def delete(self, index, doc_type, id, params=None): found = False @@ -241,8 +299,7 @@ def delete(self, index, doc_type, id, params=None): else: raise NotFoundError(404, json.dumps(result_dict)) - @query_params('allow_no_indices', 'expand_wildcards', 'ignore_unavailable', - 'preference', 'routing') + @query_params('allow_no_indices', 'expand_wildcards', 'ignore_unavailable', 'preference', 'routing') def suggest(self, body, index=None): if index is not None and index not in self.__documents_dict: raise NotFoundError(404, 'IndexMissingException[[{0}] missing]'.format(index)) @@ -250,20 +307,13 @@ def suggest(self, body, index=None): result_dict = {} for key, value in body.items(): text = value.get('text') - suggestion = int(text) + 1 if isinstance(text, int) \ - else '{0}_suggestion'.format(text) + suggestion = int(text) + 1 if isinstance(text, int) else '{0}_suggestion'.format(text) result_dict[key] = [ { 'text': text, 'length': 1, - 'options': [ - { - 'text': suggestion, 
- 'freq': 1, - 'score': 1.0 - } - ], - 'offset': 0 + 'options': [{'text': suggestion, 'freq': 1, 'score': 1.0}], + 'offset': 0, } ] return result_dict @@ -314,9 +364,7 @@ def _normalize_index_to_list(self, index): # Check index(es) exists for searchable_index in searchable_indexes: if searchable_index not in self.__documents_dict: - raise NotFoundError(404, - 'IndexMissingException[[{0}] missing]' - .format(searchable_index)) + raise NotFoundError(404, 'IndexMissingException[[{0}] missing]'.format(searchable_index)) return searchable_indexes @@ -334,4 +382,6 @@ def _normalize_doc_type_to_list(doc_type): raise ValueError("Invalid param 'index'") return searchable_doc_types + + # pylint: enable=redefined-builtin diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py b/tests/providers/elasticsearch/log/test_es_task_handler.py index bb82886c80369..deec5403665af 100644 --- a/tests/providers/elasticsearch/log/test_es_task_handler.py +++ b/tests/providers/elasticsearch/log/test_es_task_handler.py @@ -61,7 +61,7 @@ def setUp(self): self.end_of_log_mark, self.write_stdout, self.json_format, - self.json_fields + self.json_fields, ) self.es = elasticsearch.Elasticsearch( # pylint: disable=invalid-name @@ -70,11 +70,9 @@ def setUp(self): self.index_name = 'test_index' self.doc_type = 'log' self.test_message = 'some random stuff' - self.body = {'message': self.test_message, 'log_id': self.LOG_ID, - 'offset': 1} + self.body = {'message': self.test_message, 'log_id': self.LOG_ID, 'offset': 1} - self.es.index(index=self.index_name, doc_type=self.doc_type, - body=self.body, id=1) + self.es.index(index=self.index_name, doc_type=self.doc_type, body=self.body, id=1) self.dag = DAG(self.DAG_ID, start_date=self.EXECUTION_DATE) task = DummyOperator(task_id=self.TASK_ID, dag=self.dag) @@ -105,16 +103,14 @@ def test_client_with_config(self): self.write_stdout, self.json_format, self.json_fields, - es_conf + es_conf, ) def test_read(self): ts = pendulum.now() - logs, metadatas = self.es_task_handler.read(self.ti, - 1, - {'offset': 0, - 'last_log_timestamp': str(ts), - 'end_of_log': False}) + logs, metadatas = self.es_task_handler.read( + self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} + ) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) @@ -125,20 +121,17 @@ def test_read(self): def test_read_with_match_phrase_query(self): similar_log_id = '{task_id}-{dag_id}-2016-01-01T00:00:00+00:00-1'.format( - dag_id=TestElasticsearchTaskHandler.DAG_ID, - task_id=TestElasticsearchTaskHandler.TASK_ID) + dag_id=TestElasticsearchTaskHandler.DAG_ID, task_id=TestElasticsearchTaskHandler.TASK_ID + ) another_test_message = 'another message' another_body = {'message': another_test_message, 'log_id': similar_log_id, 'offset': 1} - self.es.index(index=self.index_name, doc_type=self.doc_type, - body=another_body, id=1) + self.es.index(index=self.index_name, doc_type=self.doc_type, body=another_body, id=1) ts = pendulum.now() - logs, metadatas = self.es_task_handler.read(self.ti, - 1, - {'offset': 0, - 'last_log_timestamp': str(ts), - 'end_of_log': False}) + logs, metadatas = self.es_task_handler.read( + self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} + ) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) self.assertEqual(self.test_message, logs[0]) @@ -155,8 +148,7 @@ def test_read_with_none_meatadata(self): self.assertEqual(self.test_message, logs[0]) self.assertFalse(metadatas[0]['end_of_log']) 
self.assertEqual('1', metadatas[0]['offset']) - self.assertTrue( - timezone.parse(metadatas[0]['last_log_timestamp']) < pendulum.now()) + self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) < pendulum.now()) def test_read_nonexistent_log(self): ts = pendulum.now() @@ -164,11 +156,9 @@ def test_read_nonexistent_log(self): # and doc_type regardless of match filters, so we delete the log entry instead # of making a new TaskInstance to query. self.es.delete(index=self.index_name, doc_type=self.doc_type, id=1) - logs, metadatas = self.es_task_handler.read(self.ti, - 1, - {'offset': 0, - 'last_log_timestamp': str(ts), - 'end_of_log': False}) + logs, metadatas = self.es_task_handler.read( + self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} + ) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) self.assertEqual([''], logs) @@ -207,11 +197,9 @@ def test_read_timeout(self): ts = pendulum.now().subtract(minutes=5) self.es.delete(index=self.index_name, doc_type=self.doc_type, id=1) - logs, metadatas = self.es_task_handler.read(self.ti, - 1, - {'offset': 0, - 'last_log_timestamp': str(ts), - 'end_of_log': False}) + logs, metadatas = self.es_task_handler.read( + self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False} + ) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) self.assertEqual([''], logs) @@ -222,12 +210,11 @@ def test_read_timeout(self): def test_read_as_download_logs(self): ts = pendulum.now() - logs, metadatas = self.es_task_handler.read(self.ti, - 1, - {'offset': 0, - 'last_log_timestamp': str(ts), - 'download_logs': True, - 'end_of_log': False}) + logs, metadatas = self.es_task_handler.read( + self.ti, + 1, + {'offset': 0, 'last_log_timestamp': str(ts), 'download_logs': True, 'end_of_log': False}, + ) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) self.assertEqual(self.test_message, logs[0]) @@ -268,9 +255,9 @@ def test_close(self): self.es_task_handler.set_context(self.ti) self.es_task_handler.close() - with open(os.path.join(self.local_log_location, - self.filename_template.format(try_number=1)), - 'r') as log_file: + with open( + os.path.join(self.local_log_location, self.filename_template.format(try_number=1)), 'r' + ) as log_file: # end_of_log_mark may contain characters like '\n' which is needed to # have the log uploaded but will not be stored in elasticsearch. 
# so apply the strip() to log_file.read() @@ -282,9 +269,9 @@ def test_close_no_mark_end(self): self.ti.raw = True self.es_task_handler.set_context(self.ti) self.es_task_handler.close() - with open(os.path.join(self.local_log_location, - self.filename_template.format(try_number=1)), - 'r') as log_file: + with open( + os.path.join(self.local_log_location, self.filename_template.format(try_number=1)), 'r' + ) as log_file: self.assertNotIn(self.end_of_log_mark, log_file.read()) self.assertTrue(self.es_task_handler.closed) @@ -292,18 +279,18 @@ def test_close_closed(self): self.es_task_handler.closed = True self.es_task_handler.set_context(self.ti) self.es_task_handler.close() - with open(os.path.join(self.local_log_location, - self.filename_template.format(try_number=1)), - 'r') as log_file: + with open( + os.path.join(self.local_log_location, self.filename_template.format(try_number=1)), 'r' + ) as log_file: self.assertEqual(0, len(log_file.read())) def test_close_with_no_handler(self): self.es_task_handler.set_context(self.ti) self.es_task_handler.handler = None self.es_task_handler.close() - with open(os.path.join(self.local_log_location, - self.filename_template.format(try_number=1)), - 'r') as log_file: + with open( + os.path.join(self.local_log_location, self.filename_template.format(try_number=1)), 'r' + ) as log_file: self.assertEqual(0, len(log_file.read())) self.assertTrue(self.es_task_handler.closed) @@ -311,24 +298,26 @@ def test_close_with_no_stream(self): self.es_task_handler.set_context(self.ti) self.es_task_handler.handler.stream = None self.es_task_handler.close() - with open(os.path.join(self.local_log_location, - self.filename_template.format(try_number=1)), - 'r') as log_file: + with open( + os.path.join(self.local_log_location, self.filename_template.format(try_number=1)), 'r' + ) as log_file: self.assertIn(self.end_of_log_mark, log_file.read()) self.assertTrue(self.es_task_handler.closed) self.es_task_handler.set_context(self.ti) self.es_task_handler.handler.stream.close() self.es_task_handler.close() - with open(os.path.join(self.local_log_location, - self.filename_template.format(try_number=1)), - 'r') as log_file: + with open( + os.path.join(self.local_log_location, self.filename_template.format(try_number=1)), 'r' + ) as log_file: self.assertIn(self.end_of_log_mark, log_file.read()) self.assertTrue(self.es_task_handler.closed) def test_render_log_id(self): - expected_log_id = 'dag_for_testing_file_task_handler-' \ - 'task_for_testing_file_log_handler-2016-01-01T00:00:00+00:00-1' + expected_log_id = ( + 'dag_for_testing_file_task_handler-' + 'task_for_testing_file_log_handler-2016-01-01T00:00:00+00:00-1' + ) log_id = self.es_task_handler._render_log_id(self.ti, 1) self.assertEqual(expected_log_id, log_id) @@ -340,7 +329,7 @@ def test_render_log_id(self): self.end_of_log_mark, self.write_stdout, self.json_format, - self.json_fields + self.json_fields, ) log_id = self.es_task_handler._render_log_id(self.ti, 1) self.assertEqual(expected_log_id, log_id) @@ -349,12 +338,14 @@ def test_clean_execution_date(self): clean_execution_date = self.es_task_handler._clean_execution_date(datetime(2016, 7, 8, 9, 10, 11, 12)) self.assertEqual('2016_07_08T09_10_11_000012', clean_execution_date) - @parameterized.expand([ - # Common case - ('localhost:5601/{log_id}', 'https://localhost:5601/' + quote(LOG_ID.replace('T', ' '))), - # Ignore template if "{log_id}"" is missing in the URL - ('localhost:5601', 'https://localhost:5601'), - ]) + @parameterized.expand( + [ + # Common case + 
('localhost:5601/{log_id}', 'https://localhost:5601/' + quote(LOG_ID.replace('T', ' '))), + # Ignore template if "{log_id}"" is missing in the URL + ('localhost:5601', 'https://localhost:5601'), + ] + ) def test_get_external_log_url(self, es_frontend, expected_url): es_task_handler = ElasticsearchTaskHandler( self.local_log_location, @@ -364,7 +355,7 @@ def test_get_external_log_url(self, es_frontend, expected_url): self.write_stdout, self.json_format, self.json_fields, - frontend=es_frontend + frontend=es_frontend, ) url = es_task_handler.get_external_log_url(self.ti, self.ti.try_number) self.assertEqual(expected_url, url) diff --git a/tests/providers/exasol/hooks/test_exasol.py b/tests/providers/exasol/hooks/test_exasol.py index 19444175d9a0e..0a6cdfc34e8b3 100644 --- a/tests/providers/exasol/hooks/test_exasol.py +++ b/tests/providers/exasol/hooks/test_exasol.py @@ -27,16 +27,11 @@ class TestExasolHookConn(unittest.TestCase): - def setUp(self): super(TestExasolHookConn, self).setUp() self.connection = models.Connection( - login='login', - password='password', - host='host', - port=1234, - schema='schema', + login='login', password='password', host='host', port=1234, schema='schema', ) self.db_hook = ExasolHook() @@ -67,7 +62,6 @@ def test_get_conn_extra_args(self, mock_pyexasol): class TestExasolHook(unittest.TestCase): - def setUp(self): super(TestExasolHook, self).setUp() diff --git a/tests/providers/exasol/operators/test_exasol.py b/tests/providers/exasol/operators/test_exasol.py index a7ea52d5daa78..68e3d121b48bc 100644 --- a/tests/providers/exasol/operators/test_exasol.py +++ b/tests/providers/exasol/operators/test_exasol.py @@ -24,43 +24,20 @@ class TestExasol(unittest.TestCase): - @mock.patch('airflow.providers.exasol.hooks.exasol.ExasolHook.run') def test_overwrite_autocommit(self, mock_run): - operator = ExasolOperator( - task_id='TEST', - sql='SELECT 1', - autocommit=True - ) + operator = ExasolOperator(task_id='TEST', sql='SELECT 1', autocommit=True) operator.execute({}) - mock_run.assert_called_once_with( - 'SELECT 1', - autocommit=True, - parameters=None - ) + mock_run.assert_called_once_with('SELECT 1', autocommit=True, parameters=None) @mock.patch('airflow.providers.exasol.hooks.exasol.ExasolHook.run') def test_pass_parameters(self, mock_run): - operator = ExasolOperator( - task_id='TEST', - sql='SELECT {value!s}', - parameters={'value': 1} - ) + operator = ExasolOperator(task_id='TEST', sql='SELECT {value!s}', parameters={'value': 1}) operator.execute({}) - mock_run.assert_called_once_with( - 'SELECT {value!s}', - autocommit=False, - parameters={'value': 1} - ) + mock_run.assert_called_once_with('SELECT {value!s}', autocommit=False, parameters={'value': 1}) @mock.patch('airflow.providers.exasol.operators.exasol.ExasolHook') def test_overwrite_schema(self, mock_hook): - operator = ExasolOperator( - task_id='TEST', - sql='SELECT 1', - schema='dummy' - ) + operator = ExasolOperator(task_id='TEST', sql='SELECT 1', schema='dummy') operator.execute({}) - mock_hook.assert_called_once_with( - exasol_conn_id='exasol_default', schema='dummy' - ) + mock_hook.assert_called_once_with(exasol_conn_id='exasol_default', schema='dummy') diff --git a/tests/providers/facebook/ads/hooks/test_ads.py b/tests/providers/facebook/ads/hooks/test_ads.py index 25f63ae3221fd..fc98720407d22 100644 --- a/tests/providers/facebook/ads/hooks/test_ads.py +++ b/tests/providers/facebook/ads/hooks/test_ads.py @@ -21,12 +21,7 @@ from airflow.providers.facebook.ads.hooks.ads import FacebookAdsReportingHook 
API_VERSION = "api_version" -EXTRAS = { - "account_id": "act_12345", - "app_id": "12345", - "app_secret": "1fg444", - "access_token": "Ab35gf7E" -} +EXTRAS = {"account_id": "act_12345", "app_id": "12345", "app_secret": "1fg444", "access_token": "Ab35gf7E"} FIELDS = [ "campaign_name", "campaign_id", @@ -34,10 +29,7 @@ "clicks", "impressions", ] -PARAMS = { - "level": "ad", - "date_preset": "yesterday" -} +PARAMS = {"level": "ad", "date_preset": "yesterday"} @pytest.fixture() @@ -53,23 +45,25 @@ class TestFacebookAdsReportingHook: def test_get_service(self, mock_api, mock_hook): mock_hook._get_service() api = mock_api.init - api.assert_called_once_with(app_id=EXTRAS["app_id"], - app_secret=EXTRAS["app_secret"], - access_token=EXTRAS["access_token"], - account_id=EXTRAS["account_id"], - api_version=API_VERSION) + api.assert_called_once_with( + app_id=EXTRAS["app_id"], + app_secret=EXTRAS["app_secret"], + access_token=EXTRAS["access_token"], + account_id=EXTRAS["account_id"], + api_version=API_VERSION, + ) @mock.patch("airflow.providers.facebook.ads.hooks.ads.AdAccount") @mock.patch("airflow.providers.facebook.ads.hooks.ads.FacebookAdsApi") def test_bulk_facebook_report(self, mock_client, mock_ad_account, mock_hook): mock_client = mock_client.init() ad_account = mock_ad_account().get_insights - ad_account.return_value.api_get.return_value = {"async_status": "Job Completed", - "report_run_id": "12345", - "async_percent_completion": 100} + ad_account.return_value.api_get.return_value = { + "async_status": "Job Completed", + "report_run_id": "12345", + "async_percent_completion": 100, + } mock_hook.bulk_facebook_report(params=PARAMS, fields=FIELDS) - mock_ad_account.assert_has_calls([ - mock.call(mock_client.get_default_account_id(), api=mock_client) - ]) + mock_ad_account.assert_has_calls([mock.call(mock_client.get_default_account_id(), api=mock_client)]) ad_account.assert_called_once_with(params=PARAMS, fields=FIELDS, is_async=True) ad_account.return_value.api_get.assert_has_calls([mock.call(), mock.call()]) diff --git a/tests/providers/ftp/hooks/test_ftp.py b/tests/providers/ftp/hooks/test_ftp.py index 5067819f6b05e..0359e9d0be952 100644 --- a/tests/providers/ftp/hooks/test_ftp.py +++ b/tests/providers/ftp/hooks/test_ftp.py @@ -23,7 +23,6 @@ class TestFTPHook(unittest.TestCase): - def setUp(self): super().setUp() self.path = '/some/path' @@ -123,21 +122,18 @@ def test_retrieve_file_with_callback(self): class TestIntegrationFTPHook(unittest.TestCase): - def setUp(self): super().setUp() from airflow.models import Connection from airflow.utils import db db.merge_conn( - Connection( - conn_id='ftp_passive', conn_type='ftp', - host='localhost', extra='{"passive": true}')) + Connection(conn_id='ftp_passive', conn_type='ftp', host='localhost', extra='{"passive": true}') + ) db.merge_conn( - Connection( - conn_id='ftp_active', conn_type='ftp', - host='localhost', extra='{"passive": false}')) + Connection(conn_id='ftp_active', conn_type='ftp', host='localhost', extra='{"passive": false}') + ) def _test_mode(self, hook_type, connection_id, expected_mode): hook = hook_type(connection_id) @@ -147,19 +143,23 @@ def _test_mode(self, hook_type, connection_id, expected_mode): @mock.patch("ftplib.FTP") def test_ftp_passive_mode(self, mock_ftp): from airflow.providers.ftp.hooks.ftp import FTPHook + self._test_mode(FTPHook, "ftp_passive", True) @mock.patch("ftplib.FTP") def test_ftp_active_mode(self, mock_ftp): from airflow.providers.ftp.hooks.ftp import FTPHook + self._test_mode(FTPHook, "ftp_active", False) 
@mock.patch("ftplib.FTP_TLS") def test_ftps_passive_mode(self, mock_ftp): from airflow.providers.ftp.hooks.ftp import FTPSHook + self._test_mode(FTPSHook, "ftp_passive", True) @mock.patch("ftplib.FTP_TLS") def test_ftps_active_mode(self, mock_ftp): from airflow.providers.ftp.hooks.ftp import FTPSHook + self._test_mode(FTPSHook, "ftp_active", False) diff --git a/tests/providers/ftp/sensors/test_ftp.py b/tests/providers/ftp/sensors/test_ftp.py index 58b048a379435..df10d60d13bcd 100644 --- a/tests/providers/ftp/sensors/test_ftp.py +++ b/tests/providers/ftp/sensors/test_ftp.py @@ -25,16 +25,15 @@ class TestFTPSensor(unittest.TestCase): - @mock.patch('airflow.providers.ftp.sensors.ftp.FTPHook', spec=FTPHook) def test_poke(self, mock_hook): - op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp", - task_id="test_task") + op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp", task_id="test_task") mock_hook.return_value.__enter__.return_value.get_mod_time.side_effect = [ error_perm("550: Can't check for file existence"), error_perm("550: Directory or file does not exist"), - error_perm("550 - Directory or file does not exist"), None + error_perm("550 - Directory or file does not exist"), + None, ] self.assertFalse(op.poke(None)) @@ -44,11 +43,11 @@ def test_poke(self, mock_hook): @mock.patch('airflow.providers.ftp.sensors.ftp.FTPHook', spec=FTPHook) def test_poke_fails_due_error(self, mock_hook): - op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp", - task_id="test_task") + op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp", task_id="test_task") - mock_hook.return_value.__enter__.return_value.get_mod_time.side_effect = \ - error_perm("530: Login authentication failed") + mock_hook.return_value.__enter__.return_value.get_mod_time.side_effect = error_perm( + "530: Login authentication failed" + ) with self.assertRaises(error_perm) as context: op.execute(None) @@ -57,11 +56,11 @@ def test_poke_fails_due_error(self, mock_hook): @mock.patch('airflow.providers.ftp.sensors.ftp.FTPHook', spec=FTPHook) def test_poke_fail_on_transient_error(self, mock_hook): - op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp", - task_id="test_task") + op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp", task_id="test_task") - mock_hook.return_value.__enter__.return_value\ - .get_mod_time.side_effect = error_perm("434: Host unavailable") + mock_hook.return_value.__enter__.return_value.get_mod_time.side_effect = error_perm( + "434: Host unavailable" + ) with self.assertRaises(error_perm) as context: op.execute(None) @@ -70,11 +69,13 @@ def test_poke_fail_on_transient_error(self, mock_hook): @mock.patch('airflow.providers.ftp.sensors.ftp.FTPHook', spec=FTPHook) def test_poke_ignore_transient_error(self, mock_hook): - op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp", - task_id="test_task", fail_on_transient_errors=False) + op = FTPSensor( + path="foobar.json", ftp_conn_id="bob_ftp", task_id="test_task", fail_on_transient_errors=False + ) mock_hook.return_value.__enter__.return_value.get_mod_time.side_effect = [ - error_perm("434: Host unavailable"), None + error_perm("434: Host unavailable"), + None, ] self.assertFalse(op.poke(None)) diff --git a/tests/providers/google/ads/hooks/test_ads.py b/tests/providers/google/ads/hooks/test_ads.py index a7b5433c136e7..a2c1a4f061881 100644 --- a/tests/providers/google/ads/hooks/test_ads.py +++ b/tests/providers/google/ads/hooks/test_ads.py @@ -43,18 +43,14 @@ def test_get_customer_service(self, mock_client, mock_hook): mock_hook._get_customer_service() 
client = mock_client.load_from_dict client.assert_called_once_with(mock_hook.google_ads_config) - client.return_value.get_service.assert_called_once_with( - "CustomerService", version=API_VERSION - ) + client.return_value.get_service.assert_called_once_with("CustomerService", version=API_VERSION) @mock.patch("airflow.providers.google.ads.hooks.ads.GoogleAdsClient") def test_get_service(self, mock_client, mock_hook): mock_hook._get_service() client = mock_client.load_from_dict client.assert_called_once_with(mock_hook.google_ads_config) - client.return_value.get_service.assert_called_once_with( - "GoogleAdsService", version=API_VERSION - ) + client.return_value.get_service.assert_called_once_with("GoogleAdsService", version=API_VERSION) @mock.patch("airflow.providers.google.ads.hooks.ads.GoogleAdsClient") def test_search(self, mock_client, mock_hook): diff --git a/tests/providers/google/ads/operators/test_ads.py b/tests/providers/google/ads/operators/test_ads.py index 7757b1c440da1..330cfa3c8e229 100644 --- a/tests/providers/google/ads/operators/test_ads.py +++ b/tests/providers/google/ads/operators/test_ads.py @@ -61,12 +61,9 @@ def test_execute(self, mocks_csv_writer, mock_tempfile, mock_gcs_hook, mock_ads_ ) op.execute({}) - mock_ads_hook.assert_called_once_with( - gcp_conn_id=gcp_conn_id, google_ads_conn_id=google_ads_conn_id - ) + mock_ads_hook.assert_called_once_with(gcp_conn_id=gcp_conn_id, google_ads_conn_id=google_ads_conn_id) mock_gcs_hook.assert_called_once_with( - gcp_conn_id=gcp_conn_id, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=gcp_conn_id, impersonation_chain=IMPERSONATION_CHAIN, ) mock_ads_hook.return_value.list_accessible_customers.assert_called_once_with() diff --git a/tests/providers/google/ads/transfers/test_ads_to_gcs.py b/tests/providers/google/ads/transfers/test_ads_to_gcs.py index 9d0d1ed0f5f4b..3823f3f4c3129 100644 --- a/tests/providers/google/ads/transfers/test_ads_to_gcs.py +++ b/tests/providers/google/ads/transfers/test_ads_to_gcs.py @@ -19,7 +19,13 @@ from airflow.providers.google.ads.transfers.ads_to_gcs import GoogleAdsToGcsOperator from tests.providers.google.ads.operators.test_ads import ( - BUCKET, CLIENT_IDS, FIELDS_TO_EXTRACT, GCS_OBJ_PATH, IMPERSONATION_CHAIN, QUERY, gcp_conn_id, + BUCKET, + CLIENT_IDS, + FIELDS_TO_EXTRACT, + GCS_OBJ_PATH, + IMPERSONATION_CHAIN, + QUERY, + gcp_conn_id, google_ads_conn_id, ) @@ -40,15 +46,12 @@ def test_execute(self, mock_gcs_hook, mock_ads_hook): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute({}) - mock_ads_hook.assert_called_once_with( - gcp_conn_id=gcp_conn_id, google_ads_conn_id=google_ads_conn_id - ) + mock_ads_hook.assert_called_once_with(gcp_conn_id=gcp_conn_id, google_ads_conn_id=google_ads_conn_id) mock_ads_hook.return_value.search.assert_called_once_with( client_ids=CLIENT_IDS, query=QUERY, page_size=10000 ) mock_gcs_hook.assert_called_once_with( - gcp_conn_id=gcp_conn_id, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=gcp_conn_id, impersonation_chain=IMPERSONATION_CHAIN, ) mock_gcs_hook.return_value.upload.assert_called_once_with( bucket_name=BUCKET, object_name=GCS_OBJ_PATH, filename=mock.ANY, gzip=False diff --git a/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py b/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py index b9787001bfe94..6fe0ce5e21178 100644 --- a/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py +++ b/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py @@ -28,7 
+28,6 @@ class TestSecretManagerClient(TestCase): - @mock.patch(INTERNAL_CLIENT_MODULE + ".SecretManagerServiceClient") @mock.patch(INTERNAL_CLIENT_MODULE + ".ClientInfo") def test_auth(self, mock_client_info, mock_secrets_client): @@ -37,13 +36,8 @@ def test_auth(self, mock_client_info, mock_secrets_client): mock_secrets_client.return_value = mock.MagicMock() secrets_client = _SecretManagerClient(credentials="credentials") _ = secrets_client.client - mock_client_info.assert_called_with( - client_library_version='airflow_v' + version - ) - mock_secrets_client.assert_called_with( - credentials='credentials', - client_info=mock_client_info_mock - ) + mock_client_info.assert_called_with(client_library_version='airflow_v' + version) + mock_secrets_client.assert_called_with(credentials='credentials', client_info=mock_client_info_mock) @mock.patch(INTERNAL_CLIENT_MODULE + ".SecretManagerServiceClient") @mock.patch(INTERNAL_CLIENT_MODULE + ".ClientInfo") @@ -102,8 +96,9 @@ def test_get_existing_key_with_version(self, mock_client_info, mock_secrets_clie test_response.payload.data = "result".encode("UTF-8") mock_client.access_secret_version.return_value = test_response secrets_client = _SecretManagerClient(credentials="credentials") - secret = secrets_client.get_secret(secret_id="existing", project_id="project_id", - secret_version="test-version") + secret = secrets_client.get_secret( + secret_id="existing", project_id="project_id", secret_version="test-version" + ) mock_client.secret_version_path.assert_called_once_with("project_id", 'existing', 'test-version') self.assertEqual("result", secret) mock_client.access_secret_version.assert_called_once_with('full-path') diff --git a/tests/providers/google/cloud/hooks/test_automl.py b/tests/providers/google/cloud/hooks/test_automl.py index 07bdd70efd2ae..102821bc46578 100644 --- a/tests/providers/google/cloud/hooks/test_automl.py +++ b/tests/providers/google/cloud/hooks/test_automl.py @@ -67,9 +67,7 @@ def setUp(self) -> None: @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient") def test_get_conn(self, mock_automl_client, mock_client_info): self.hook.get_conn() - mock_automl_client.assert_called_once_with( - credentials=CREDENTIALS, client_info=CLIENT_INFO - ) + mock_automl_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO) @mock.patch( "airflow.providers.google.cloud.hooks.automl.GoogleBaseHook.client_info", @@ -78,23 +76,17 @@ def test_get_conn(self, mock_automl_client, mock_client_info): @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient") def test_prediction_client(self, mock_prediction_client, mock_client_info): client = self.hook.prediction_client # pylint: disable=unused-variable # noqa - mock_prediction_client.assert_called_once_with( - credentials=CREDENTIALS, client_info=CLIENT_INFO - ) + mock_prediction_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_model") def test_create_model(self, mock_create_model): - self.hook.create_model( - model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID - ) + self.hook.create_model(model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) mock_create_model.assert_called_once_with( parent=LOCATION_PATH, model=MODEL, retry=None, timeout=None, metadata=None ) - @mock.patch( - "airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict" - ) + 
@mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict") def test_batch_predict(self, mock_batch_predict): self.hook.batch_predict( model_id=MODEL_ID, @@ -117,33 +109,19 @@ def test_batch_predict(self, mock_batch_predict): @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict") def test_predict(self, mock_predict): self.hook.predict( - model_id=MODEL_ID, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - payload=PAYLOAD, + model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID, payload=PAYLOAD, ) mock_predict.assert_called_once_with( - name=MODEL_PATH, - payload=PAYLOAD, - params=None, - retry=None, - timeout=None, - metadata=None, + name=MODEL_PATH, payload=PAYLOAD, params=None, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset") def test_create_dataset(self, mock_create_dataset): - self.hook.create_dataset( - dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID - ) + self.hook.create_dataset(dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) mock_create_dataset.assert_called_once_with( - parent=LOCATION_PATH, - dataset=DATASET, - retry=None, - timeout=None, - metadata=None, + parent=LOCATION_PATH, dataset=DATASET, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data") @@ -156,11 +134,7 @@ def test_import_dataset(self, mock_import_data): ) mock_import_data.assert_called_once_with( - name=DATASET_PATH, - input_config=INPUT_CONFIG, - retry=None, - timeout=None, - metadata=None, + name=DATASET_PATH, input_config=INPUT_CONFIG, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs") @@ -179,9 +153,7 @@ def test_list_column_specs(self, mock_list_column_specs): page_size=page_size, ) - parent = AutoMlClient.table_spec_path( - GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec - ) + parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec) mock_list_column_specs.assert_called_once_with( parent=parent, field_mask=MASK, @@ -194,23 +166,15 @@ def test_list_column_specs(self, mock_list_column_specs): @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model") def test_get_model(self, mock_get_model): - self.hook.get_model( - model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID - ) + self.hook.get_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - mock_get_model.assert_called_once_with( - name=MODEL_PATH, retry=None, timeout=None, metadata=None - ) + mock_get_model.assert_called_once_with(name=MODEL_PATH, retry=None, timeout=None, metadata=None) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model") def test_delete_model(self, mock_delete_model): - self.hook.delete_model( - model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID - ) + self.hook.delete_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - mock_delete_model.assert_called_once_with( - name=MODEL_PATH, retry=None, timeout=None, metadata=None - ) + mock_delete_model.assert_called_once_with(name=MODEL_PATH, retry=None, timeout=None, metadata=None) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset") def test_update_dataset(self, mock_update_dataset): @@ -273,9 +237,7 @@ def test_list_datasets(self, 
mock_list_datasets): @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset") def test_delete_dataset(self, mock_delete_dataset): - self.hook.delete_dataset( - dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID - ) + self.hook.delete_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) mock_delete_dataset.assert_called_once_with( name=DATASET_PATH, retry=None, timeout=None, metadata=None diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index ef4c70fa43512..dbd14d842957c 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -26,8 +26,13 @@ from airflow import AirflowException from airflow.providers.google.cloud.hooks.bigquery import ( - BigQueryCursor, BigQueryHook, _api_resource_configs_duplication_check, _cleanse_time_partitioning, - _split_tablename, _validate_src_fmt_configs, _validate_value, + BigQueryCursor, + BigQueryHook, + _api_resource_configs_duplication_check, + _cleanse_time_partitioning, + _split_tablename, + _validate_src_fmt_configs, + _validate_value, ) PROJECT_ID = "bq-project" @@ -59,9 +64,7 @@ class TestBigQueryHookMethods(_BigQueryBaseTestClass): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryConnection") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook._authorize") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build") - def test_bigquery_client_creation( - self, mock_build, mock_authorize, mock_bigquery_connection - ): + def test_bigquery_client_creation(self, mock_build, mock_authorize, mock_bigquery_connection): result = self.hook.get_conn() mock_build.assert_called_once_with( 'bigquery', 'v2', http=mock_authorize.return_value, cache_discovery=False @@ -72,7 +75,7 @@ def test_bigquery_client_creation( hook=self.hook, use_legacy_sql=self.hook.use_legacy_sql, location=self.hook.location, - num_retries=self.hook.num_retries + num_retries=self.hook.num_retries, ) self.assertEqual(mock_bigquery_connection.return_value, result) @@ -81,14 +84,14 @@ def test_bigquery_bigquery_conn_id_deprecation_warning( self, mock_base_hook_init, ): bigquery_conn_id = "bigquery conn id" - warning_message = "The bigquery_conn_id parameter has been deprecated. " \ - "You should pass the gcp_conn_id parameter." + warning_message = ( + "The bigquery_conn_id parameter has been deprecated. " + "You should pass the gcp_conn_id parameter." 
+ ) with self.assertWarns(DeprecationWarning) as warn: BigQueryHook(bigquery_conn_id=bigquery_conn_id) mock_base_hook_init.assert_called_once_with( - delegate_to=None, - gcp_conn_id='bigquery conn id', - impersonation_chain=None, + delegate_to=None, gcp_conn_id='bigquery conn id', impersonation_chain=None, ) self.assertEqual(warning_message, str(warn.warning)) @@ -124,10 +127,7 @@ def test_bigquery_table_exists_false(self, mock_client): def test_bigquery_table_partition_exists_true(self, mock_client): mock_client.return_value.list_partitions.return_value = [PARTITION_ID] result = self.hook.table_partition_exists( - project_id=PROJECT_ID, - dataset_id=DATASET_ID, - table_id=TABLE_ID, - partition_id=PARTITION_ID + project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID, partition_id=PARTITION_ID ) mock_client.return_value.list_partitions.assert_called_once_with(TABLE_REFERENCE) mock_client.assert_called_once_with(project_id=PROJECT_ID) @@ -137,10 +137,7 @@ def test_bigquery_table_partition_exists_true(self, mock_client): def test_bigquery_table_partition_exists_false_no_table(self, mock_client): mock_client.return_value.get_table.side_effect = NotFound("Dataset not found") result = self.hook.table_partition_exists( - project_id=PROJECT_ID, - dataset_id=DATASET_ID, - table_id=TABLE_ID, - partition_id=PARTITION_ID + project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID, partition_id=PARTITION_ID ) mock_client.return_value.list_partitions.assert_called_once_with(TABLE_REFERENCE) mock_client.assert_called_once_with(project_id=PROJECT_ID) @@ -150,10 +147,7 @@ def test_bigquery_table_partition_exists_false_no_table(self, mock_client): def test_bigquery_table_partition_exists_false_no_partition(self, mock_client): mock_client.return_value.list_partitions.return_value = [] result = self.hook.table_partition_exists( - project_id=PROJECT_ID, - dataset_id=DATASET_ID, - table_id=TABLE_ID, - partition_id=PARTITION_ID + project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID, partition_id=PARTITION_ID ) mock_client.return_value.list_partitions.assert_called_once_with(TABLE_REFERENCE) mock_client.assert_called_once_with(project_id=PROJECT_ID) @@ -172,32 +166,35 @@ def test_invalid_schema_update_options(self, mock_get_service): with self.assertRaisesRegex( Exception, r"\['THIS IS NOT VALID'\] contains invalid schema update options.Please only use one or more of " - r"the following options: \['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'\]" + r"the following options: \['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'\]", ): self.hook.run_load( "test.test", "test_schema.json", ["test_data.json"], - schema_update_options=["THIS IS NOT VALID"] + schema_update_options=["THIS IS NOT VALID"], ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") def test_invalid_schema_update_and_write_disposition(self, mock_get_service): - with self.assertRaisesRegex(Exception, "schema_update_options is only allowed if" - " write_disposition is 'WRITE_APPEND' or 'WRITE_TRUNCATE'."): + with self.assertRaisesRegex( + Exception, + "schema_update_options is only allowed if" + " write_disposition is 'WRITE_APPEND' or 'WRITE_TRUNCATE'.", + ): self.hook.run_load( "test.test", "test_schema.json", ["test_data.json"], schema_update_options=['ALLOW_FIELD_ADDITION'], - write_disposition='WRITE_EMPTY' + write_disposition='WRITE_EMPTY', ) @mock.patch( "airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.poll_job_complete", - side_effect=[False, True] + side_effect=[False, 
True], ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") def test_cancel_queries(self, mock_client, mock_poll_job_complete): @@ -212,7 +209,9 @@ def test_cancel_queries(self, mock_client, mock_poll_job_complete): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") - def test_run_query_sql_dialect_default(self, mock_insert, _,): + def test_run_query_sql_dialect_default( + self, mock_insert, _, + ): self.hook.run_query('query') _, kwargs = mock_insert.call_args self.assertIs(kwargs['configuration']['query']['useLegacySql'], True) @@ -227,107 +226,103 @@ def test_run_query_sql_dialect(self, mock_insert, _): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_query_sql_dialect_legacy_with_query_params(self, mock_insert, _): - params = [{ - 'name': "param_name", - 'parameterType': {'type': "STRING"}, - 'parameterValue': {'value': "param_value"} - }] + params = [ + { + 'name': "param_name", + 'parameterType': {'type': "STRING"}, + 'parameterValue': {'value': "param_value"}, + } + ] self.hook.run_query('query', use_legacy_sql=False, query_params=params) _, kwargs = mock_insert.call_args self.assertIs(kwargs['configuration']['query']['useLegacySql'], False) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") def test_run_query_sql_dialect_legacy_with_query_params_fails(self, _): - params = [{ - 'name': "param_name", - 'parameterType': {'type': "STRING"}, - 'parameterValue': {'value': "param_value"} - }] + params = [ + { + 'name': "param_name", + 'parameterType': {'type': "STRING"}, + 'parameterValue': {'value': "param_value"}, + } + ] with self.assertRaisesRegex(ValueError, "Query parameters are not allowed when using legacy SQL"): self.hook.run_query('query', use_legacy_sql=True, query_params=params) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") def test_run_query_without_sql_fails(self, _): with self.assertRaisesRegex( - TypeError, - r"`BigQueryBaseCursor.run_query` missing 1 required positional argument: `sql`" + TypeError, r"`BigQueryBaseCursor.run_query` missing 1 required positional argument: `sql`" ): self.hook.run_query(sql=None) - @parameterized.expand([ - (['ALLOW_FIELD_ADDITION'], 'WRITE_APPEND'), - (['ALLOW_FIELD_RELAXATION'], 'WRITE_APPEND'), - (['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'], 'WRITE_APPEND'), - (['ALLOW_FIELD_ADDITION'], 'WRITE_TRUNCATE'), - (['ALLOW_FIELD_RELAXATION'], 'WRITE_TRUNCATE'), - (['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'], 'WRITE_TRUNCATE'), - ]) + @parameterized.expand( + [ + (['ALLOW_FIELD_ADDITION'], 'WRITE_APPEND'), + (['ALLOW_FIELD_RELAXATION'], 'WRITE_APPEND'), + (['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'], 'WRITE_APPEND'), + (['ALLOW_FIELD_ADDITION'], 'WRITE_TRUNCATE'), + (['ALLOW_FIELD_RELAXATION'], 'WRITE_TRUNCATE'), + (['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'], 'WRITE_TRUNCATE'), + ] + ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_query_schema_update_options( - self, - schema_update_options, - write_disposition, - mock_insert, - mock_get_service, + self, schema_update_options, write_disposition, mock_insert, mock_get_service, 
): self.hook.run_query( sql='query', destination_dataset_table='my_dataset.my_table', schema_update_options=schema_update_options, - write_disposition=write_disposition + write_disposition=write_disposition, ) _, kwargs = mock_insert.call_args - self.assertEqual( - kwargs['configuration']['query']['schemaUpdateOptions'], - schema_update_options - ) - self.assertEqual( - kwargs['configuration']['query']['writeDisposition'], - write_disposition - ) - - @parameterized.expand([ - ( - ['INCORRECT_OPTION'], - None, - r"\['INCORRECT_OPTION'\] contains invalid schema update options\. " - r"Please only use one or more of the following options: " - r"\['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'\]" - ), - ( - ['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION', 'INCORRECT_OPTION'], - None, - r"\['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION', 'INCORRECT_OPTION'\] contains invalid " - r"schema update options\. Please only use one or more of the following options: " - r"\['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'\]" - ), - ( - ['ALLOW_FIELD_ADDITION'], - None, - r"schema_update_options is only allowed if write_disposition is " - r"'WRITE_APPEND' or 'WRITE_TRUNCATE'"), - ]) + self.assertEqual(kwargs['configuration']['query']['schemaUpdateOptions'], schema_update_options) + self.assertEqual(kwargs['configuration']['query']['writeDisposition'], write_disposition) + + @parameterized.expand( + [ + ( + ['INCORRECT_OPTION'], + None, + r"\['INCORRECT_OPTION'\] contains invalid schema update options\. " + r"Please only use one or more of the following options: " + r"\['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'\]", + ), + ( + ['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION', 'INCORRECT_OPTION'], + None, + r"\['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION', 'INCORRECT_OPTION'\] contains invalid " + r"schema update options\. Please only use one or more of the following options: " + r"\['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'\]", + ), + ( + ['ALLOW_FIELD_ADDITION'], + None, + r"schema_update_options is only allowed if write_disposition is " + r"'WRITE_APPEND' or 'WRITE_TRUNCATE'", + ), + ] + ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") def test_run_query_schema_update_options_incorrect( - self, - schema_update_options, - write_disposition, - expected_regex, - mock_get_service, + self, schema_update_options, write_disposition, expected_regex, mock_get_service, ): with self.assertRaisesRegex(ValueError, expected_regex): self.hook.run_query( sql='query', destination_dataset_table='my_dataset.my_table', schema_update_options=schema_update_options, - write_disposition=write_disposition + write_disposition=write_disposition, ) @parameterized.expand([(True,), (False,)]) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") - def test_api_resource_configs(self, bool_val, mock_insert, _,): + def test_api_resource_configs( + self, bool_val, mock_insert, _, + ): self.hook.run_query('query', api_resource_configs={'query': {'useQueryCache': bool_val}}) _, kwargs = mock_insert.call_args self.assertIs(kwargs["configuration"]['query']['useQueryCache'], bool_val) @@ -339,12 +334,11 @@ def test_api_resource_configs_duplication_warning(self, mock_get_service): ValueError, r"Values of useLegacySql param are duplicated\. 
api_resource_configs contained useLegacySql " r"param in `query` config and useLegacySql was also provided with arg to run_query\(\) method\. " - r"Please remove duplicates\." + r"Please remove duplicates\.", ): - self.hook.run_query('query', - use_legacy_sql=True, - api_resource_configs={'query': {'useLegacySql': False}} - ) + self.hook.run_query( + 'query', use_legacy_sql=True, api_resource_configs={'query': {'useLegacySql': False}} + ) def test_validate_value(self): with self.assertRaisesRegex( @@ -358,12 +352,11 @@ def test_duplication_check(self): ValueError, r"Values of key_one param are duplicated. api_resource_configs contained key_one param in" r" `query` config and key_one was also provided with arg to run_query\(\) method. " - r"Please remove duplicates." + r"Please remove duplicates.", ): key_one = True _api_resource_configs_duplication_check("key_one", key_one, {"key_one": False}) - self.assertIsNone(_api_resource_configs_duplication_check( - "key_one", key_one, {"key_one": True})) + self.assertIsNone(_api_resource_configs_duplication_check("key_one", key_one, {"key_one": True})) def test_validate_src_fmt_configs(self): source_format = "test_format" @@ -375,23 +368,23 @@ def test_validate_src_fmt_configs(self): ): # This config should raise a value error. src_fmt_configs = {"test_config_unknown": "val"} - _validate_src_fmt_configs(source_format, - src_fmt_configs, - valid_configs, - backward_compatibility_configs) + _validate_src_fmt_configs( + source_format, src_fmt_configs, valid_configs, backward_compatibility_configs + ) src_fmt_configs = {"test_config_known": "val"} - src_fmt_configs = _validate_src_fmt_configs(source_format, src_fmt_configs, valid_configs, - backward_compatibility_configs) - assert "test_config_known" in src_fmt_configs, \ - "src_fmt_configs should contain al known src_fmt_configs" + src_fmt_configs = _validate_src_fmt_configs( + source_format, src_fmt_configs, valid_configs, backward_compatibility_configs + ) + assert ( + "test_config_known" in src_fmt_configs + ), "src_fmt_configs should contain al known src_fmt_configs" - assert "compatibility_val" in src_fmt_configs, \ - "_validate_src_fmt_configs should add backward_compatibility config" + assert ( + "compatibility_val" in src_fmt_configs + ), "_validate_src_fmt_configs should add backward_compatibility config" - @parameterized.expand( - [("AVRO",), ("PARQUET",), ("NEWLINE_DELIMITED_JSON",), ("DATASTORE_BACKUP",)] - ) + @parameterized.expand([("AVRO",), ("PARQUET",), ("NEWLINE_DELIMITED_JSON",), ("DATASTORE_BACKUP",)]) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_load_with_non_csv_as_src_fmt(self, fmt, _): @@ -400,7 +393,7 @@ def test_run_load_with_non_csv_as_src_fmt(self, fmt, _): destination_project_dataset_table='my_dataset.my_table', source_uris=[], source_format=fmt, - autodetect=True + autodetect=True, ) except ValueError: self.fail("run_load() raised ValueError unexpectedly!") @@ -411,11 +404,7 @@ def test_run_extract(self, mock_insert): destination_cloud_storage_uris = ["gs://bucket/file.csv"] expected_configuration = { "extract": { - "sourceTable": { - "projectId": PROJECT_ID, - "datasetId": DATASET_ID, - "tableId": TABLE_ID, - }, + "sourceTable": {"projectId": PROJECT_ID, "datasetId": DATASET_ID, "tableId": TABLE_ID,}, "compression": "NONE", "destinationUris": destination_cloud_storage_uris, "destinationFormat": "CSV", @@ -426,7 +415,7 @@ def test_run_extract(self, mock_insert): self.hook.run_extract( 
source_project_dataset_table=source_project_dataset_table, - destination_cloud_storage_uris=destination_cloud_storage_uris + destination_cloud_storage_uris=destination_cloud_storage_uris, ) mock_insert.assert_called_once_with(configuration=expected_configuration, project_id=PROJECT_ID) @@ -444,15 +433,13 @@ def test_list_rows(self, mock_client, mock_schema, mock_table): location=LOCATION, ) mock_table.from_api_repr.assert_called_once_with({"tableReference": TABLE_REFERENCE_REPR}) - mock_schema.has_calls( - [mock.call(x, "") for x in ["field_1", "field_2"]] - ) + mock_schema.has_calls([mock.call(x, "") for x in ["field_1", "field_2"]]) mock_client.return_value.list_rows.assert_called_once_with( table=mock_table.from_api_repr.return_value, max_results=10, selected_fields=mock.ANY, page_token='page123', - start_index=5 + start_index=5, ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Table") @@ -462,18 +449,13 @@ def test_run_table_delete(self, mock_client, mock_table): self.hook.run_table_delete(source_project_dataset_table, ignore_if_missing=False) mock_table.from_string.assert_called_once_with(source_project_dataset_table) mock_client.return_value.delete_table.assert_called_once_with( - table=mock_table.from_string.return_value, - not_found_ok=False + table=mock_table.from_string.return_value, not_found_ok=False ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset_tables") def test_table_upsert_create_new_table(self, mock_get, mock_create): - table_resource = { - "tableReference": { - "tableId": TABLE_ID - } - } + table_resource = {"tableReference": {"tableId": TABLE_ID}} mock_get.return_value = [] self.hook.run_table_upsert(dataset_id=DATASET_ID, table_resource=table_resource) @@ -484,11 +466,7 @@ def test_table_upsert_create_new_table(self, mock_get, mock_create): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.update_table") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset_tables") def test_table_upsert_already_exists(self, mock_get, mock_update): - table_resource = { - "tableReference": { - "tableId": TABLE_ID - } - } + table_resource = {"tableReference": {"tableId": TABLE_ID}} mock_get.return_value = [{"tableId": TABLE_ID}] self.hook.run_table_upsert(dataset_id=DATASET_ID, table_resource=table_resource) @@ -504,11 +482,7 @@ def test_run_grant_dataset_view_access_granting(self, mock_update, mock_get): view_access = AccessEntry( role=None, entity_type="view", - entity_id={ - 'projectId': PROJECT_ID, - 'datasetId': view_dataset, - 'tableId': view_table - } + entity_id={'projectId': PROJECT_ID, 'datasetId': view_dataset, 'tableId': view_table}, ) dataset = Dataset(DatasetReference.from_string(DATASET_ID, PROJECT_ID)) @@ -516,17 +490,13 @@ def test_run_grant_dataset_view_access_granting(self, mock_update, mock_get): mock_get.return_value = dataset self.hook.run_grant_dataset_view_access( - source_dataset=DATASET_ID, - view_dataset=view_dataset, - view_table=view_table + source_dataset=DATASET_ID, view_dataset=view_dataset, view_table=view_table ) mock_get.assert_called_once_with(project_id=PROJECT_ID, dataset_id=DATASET_ID) assert view_access in dataset.access_entries mock_update.assert_called_once_with( - fields=["access"], - dataset_resource=dataset.to_api_repr(), - project_id=PROJECT_ID, + fields=["access"], dataset_resource=dataset.to_api_repr(), project_id=PROJECT_ID, ) 
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset") @@ -537,11 +507,7 @@ def test_run_grant_dataset_view_access_already_granted(self, mock_update, mock_g view_access = AccessEntry( role=None, entity_type="view", - entity_id={ - 'projectId': PROJECT_ID, - 'datasetId': view_dataset, - 'tableId': view_table - } + entity_id={'projectId': PROJECT_ID, 'datasetId': view_dataset, 'tableId': view_table}, ) dataset = Dataset(DatasetReference.from_string(DATASET_ID, PROJECT_ID)) @@ -549,9 +515,7 @@ def test_run_grant_dataset_view_access_already_granted(self, mock_update, mock_g mock_get.return_value = dataset self.hook.run_grant_dataset_view_access( - source_dataset=DATASET_ID, - view_dataset=view_dataset, - view_table=view_table + source_dataset=DATASET_ID, view_dataset=view_dataset, view_table=view_table ) mock_get.assert_called_once_with(project_id=PROJECT_ID, dataset_id=DATASET_ID) @@ -572,18 +536,13 @@ def test_get_dataset_tables_list(self, mock_client): result = self.hook.get_dataset_tables_list(dataset_id=DATASET_ID, project_id=PROJECT_ID) mock_client.return_value.list_tables.assert_called_once_with( - dataset=dataset_reference, - max_results=None + dataset=dataset_reference, max_results=None ) self.assertEqual(table_list, result) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") def test_poll_job_complete(self, mock_client): - self.hook.poll_job_complete( - job_id=JOB_ID, - location=LOCATION, - project_id=PROJECT_ID - ) + self.hook.poll_job_complete(job_id=JOB_ID, location=LOCATION, project_id=PROJECT_ID) mock_client.assert_called_once_with(location=LOCATION, project_id=PROJECT_ID) mock_client.return_value.get_job.assert_called_once_with(job_id=JOB_ID) mock_client.return_value.get_job.return_value.done.assert_called_once_with(retry=DEFAULT_RETRY) @@ -615,8 +574,10 @@ def test_cancel_query_cancel_timeout( assert poll_job_complete.call_count == 13 assert mock_sleep.call_count == 11 mock_logger_info.has_call( - mock.call(f"Stopping polling due to timeout. Job with id {JOB_ID} " - "has not completed cancel and may or may not finish.") + mock.call( + f"Stopping polling due to timeout. Job with id {JOB_ID} " + "has not completed cancel and may or may not finish." + ) ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") @@ -644,7 +605,7 @@ def test_get_schema(self, mock_client): {'name': 'id', 'type': 'STRING', 'mode': 'REQUIRED'}, {'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'}, ] - } + }, } mock_client.return_value.get_table.return_value = Table.from_api_repr(table) @@ -659,18 +620,14 @@ def test_invalid_source_format(self, mock_get_service): with self.assertRaisesRegex( Exception, r"JSON is not a valid source format. 
Please use one of the following types: \['CSV', " - r"'NEWLINE_DELIMITED_JSON', 'AVRO', 'GOOGLE_SHEETS', 'DATASTORE_BACKUP', 'PARQUET'\]" + r"'NEWLINE_DELIMITED_JSON', 'AVRO', 'GOOGLE_SHEETS', 'DATASTORE_BACKUP', 'PARQUET'\]", ): - self.hook.run_load( - "test.test", "test_schema.json", ["test_data.json"], source_format="json" - ) + self.hook.run_load("test.test", "test_schema.json", ["test_data.json"], source_format="json") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Table") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") def test_insert_all_succeed(self, mock_client, mock_table): - rows = [ - {"json": {"a_key": "a_value_0"}} - ] + rows = [{"json": {"a_key": "a_value_0"}}] self.hook.insert_all( project_id=PROJECT_ID, @@ -685,67 +642,50 @@ def test_insert_all_succeed(self, mock_client, mock_table): table=mock_table.from_api_repr.return_value, rows=rows, ignore_unknown_values=True, - skip_invalid_rows=True + skip_invalid_rows=True, ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") def test_insert_all_fail(self, mock_client): - rows = [ - {"json": {"a_key": "a_value_0"}} - ] + rows = [{"json": {"a_key": "a_value_0"}}] mock_client.return_value.insert_rows.return_value = ["some", "errors"] with self.assertRaisesRegex(AirflowException, "insert error"): - self.hook.insert_all(project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID, - rows=rows, fail_on_error=True) + self.hook.insert_all( + project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID, rows=rows, fail_on_error=True + ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_query_with_arg(self, mock_insert): self.hook.run_query( sql='select 1', destination_dataset_table='my_dataset.my_table', - labels={'label1': 'test1', 'label2': 'test2'} + labels={'label1': 'test1', 'label2': 'test2'}, ) _, kwargs = mock_insert.call_args - self.assertEqual( - kwargs["configuration"]['labels'], - {'label1': 'test1', 'label2': 'test2'} - ) + self.assertEqual(kwargs["configuration"]['labels'], {'label1': 'test1', 'label2': 'test2'}) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.QueryJob") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") def test_insert_job(self, mock_client, mock_query_job): - job_conf = { - "query": { - "query": "SELECT * FROM test", - "useLegacySql": "False", - } - } + job_conf = {"query": {"query": "SELECT * FROM test", "useLegacySql": "False",}} mock_query_job._JOB_TYPE = "query" self.hook.insert_job( - configuration=job_conf, - job_id=JOB_ID, - project_id=PROJECT_ID, - location=LOCATION, + configuration=job_conf, job_id=JOB_ID, project_id=PROJECT_ID, location=LOCATION, ) mock_client.assert_called_once_with( - project_id=PROJECT_ID, - location=LOCATION, + project_id=PROJECT_ID, location=LOCATION, ) mock_query_job.from_api_repr.assert_called_once_with( { 'configuration': job_conf, - 'jobReference': { - 'jobId': JOB_ID, - 'projectId': PROJECT_ID, - 'location': LOCATION - } + 'jobReference': {'jobId': JOB_ID, 'projectId': PROJECT_ID, 'location': LOCATION}, }, - mock_client.return_value + mock_client.return_value, ) mock_query_job.from_api_repr.return_value.result.assert_called_once_with() @@ -755,13 +695,15 @@ def test_internal_need_default_project(self): with self.assertRaisesRegex(Exception, "INTERNAL: No default project is specified"): _split_tablename("dataset.table", None) - @parameterized.expand([ - ("project", "dataset", "table", "dataset.table"), - ("alternative", 
"dataset", "table", "alternative:dataset.table"), - ("alternative", "dataset", "table", "alternative.dataset.table"), - ("alt1:alt", "dataset", "table", "alt1:alt.dataset.table"), - ("alt1:alt", "dataset", "table", "alt1:alt:dataset.table"), - ]) + @parameterized.expand( + [ + ("project", "dataset", "table", "dataset.table"), + ("alternative", "dataset", "table", "alternative:dataset.table"), + ("alternative", "dataset", "table", "alternative.dataset.table"), + ("alt1:alt", "dataset", "table", "alt1:alt.dataset.table"), + ("alt1:alt", "dataset", "table", "alt1:alt:dataset.table"), + ] + ) def test_split_tablename(self, project_expected, dataset_expected, table_expected, table_input): default_project_id = "project" project, dataset, table = _split_tablename(table_input, default_project_id) @@ -769,26 +711,32 @@ def test_split_tablename(self, project_expected, dataset_expected, table_expecte self.assertEqual(dataset_expected, dataset) self.assertEqual(table_expected, table) - @parameterized.expand([ - ("alt1:alt2:alt3:dataset.table", None, "Use either : or . to specify project got {}"), - ( - "alt1.alt.dataset.table", None, - r"Expect format of \(\.
, got {}", - ), - ( - "alt1:alt2:alt.dataset.table", "var_x", - "Format exception for var_x: Use either : or . to specify project got {}", - ), - ( - "alt1:alt2:alt:dataset.table", "var_x", - "Format exception for var_x: Use either : or . to specify project got {}", - ), - ( - "alt1.alt.dataset.table", "var_x", - r"Format exception for var_x: Expect format of " - r"\(.
, got {}", - ), - ]) + @parameterized.expand( + [ + ("alt1:alt2:alt3:dataset.table", None, "Use either : or . to specify project got {}"), + ( + "alt1.alt.dataset.table", + None, + r"Expect format of \(\.
, got {}", + ), + ( + "alt1:alt2:alt.dataset.table", + "var_x", + "Format exception for var_x: Use either : or . to specify project got {}", + ), + ( + "alt1:alt2:alt:dataset.table", + "var_x", + "Format exception for var_x: Use either : or . to specify project got {}", + ), + ( + "alt1.alt.dataset.table", + "var_x", + r"Format exception for var_x: Expect format of " + r"\(.
, got {}", + ), + ] + ) def test_invalid_syntax(self, table_input, var_name, exception_message): default_project_id = "project" with self.assertRaisesRegex(Exception, exception_message.format(table_input)): @@ -801,21 +749,16 @@ class TestTableOperations(_BigQueryBaseTestClass): def test_create_view(self, mock_bq_client, mock_table): view = { 'query': 'SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*`', - "useLegacySql": False + "useLegacySql": False, } self.hook.create_empty_table( project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID, view=view, retry=DEFAULT_RETRY ) - body = { - 'tableReference': TABLE_REFERENCE_REPR, - 'view': view - } + body = {'tableReference': TABLE_REFERENCE_REPR, 'view': view} mock_table.from_api_repr.assert_called_once_with(body) mock_bq_client.return_value.create_table.assert_called_once_with( - table=mock_table.from_api_repr.return_value, - exists_ok=True, - retry=DEFAULT_RETRY, + table=mock_table.from_api_repr.return_value, exists_ok=True, retry=DEFAULT_RETRY, ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Table") @@ -829,15 +772,13 @@ def test_patch_table(self, mock_client, mock_table): {'name': 'id', 'type': 'STRING', 'mode': 'REQUIRED'}, {'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'}, {'name': 'balance', 'type': 'FLOAT', 'mode': 'NULLABLE'}, - {'name': 'new_field', 'type': 'STRING', 'mode': 'NULLABLE'} + {'name': 'new_field', 'type': 'STRING', 'mode': 'NULLABLE'}, ] - time_partitioning_patched = { - 'expirationMs': 10000000 - } + time_partitioning_patched = {'expirationMs': 10000000} require_partition_filter_patched = True view_patched = { 'query': "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 500", - 'useLegacySql': False + 'useLegacySql': False, } self.hook.patch_table( @@ -847,10 +788,11 @@ def test_patch_table(self, mock_client, mock_table): description=description_patched, expiration_time=expiration_time_patched, friendly_name=friendly_name_patched, - labels=labels_patched, schema=schema_patched, + labels=labels_patched, + schema=schema_patched, time_partitioning=time_partitioning_patched, require_partition_filter=require_partition_filter_patched, - view=view_patched + view=view_patched, ) body = { @@ -858,9 +800,7 @@ def test_patch_table(self, mock_client, mock_table): "expirationTime": expiration_time_patched, "friendlyName": friendly_name_patched, "labels": labels_patched, - "schema": { - "fields": schema_patched - }, + "schema": {"fields": schema_patched}, "timePartitioning": time_partitioning_patched, "view": view_patched, "requirePartitionFilter": require_partition_filter_patched, @@ -870,30 +810,18 @@ def test_patch_table(self, mock_client, mock_table): mock_table.from_api_repr.assert_called_once_with(body) mock_client.return_value.update_table.assert_called_once_with( - table=mock_table.from_api_repr.return_value, - fields=fields + table=mock_table.from_api_repr.return_value, fields=fields ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Table") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") def test_create_empty_table_succeed(self, mock_bq_client, mock_table): - self.hook.create_empty_table( - project_id=PROJECT_ID, - dataset_id=DATASET_ID, - table_id=TABLE_ID) + self.hook.create_empty_table(project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID) - body = { - 'tableReference': { - 'tableId': TABLE_ID, - 'projectId': PROJECT_ID, - 'datasetId': DATASET_ID, - } - } + body = {'tableReference': {'tableId': TABLE_ID, 'projectId': 
PROJECT_ID, 'datasetId': DATASET_ID,}} mock_table.from_api_repr.assert_called_once_with(body) mock_bq_client.return_value.create_table.assert_called_once_with( - table=mock_table.from_api_repr.return_value, - exists_ok=True, - retry=DEFAULT_RETRY + table=mock_table.from_api_repr.return_value, exists_ok=True, retry=DEFAULT_RETRY ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Table") @@ -913,28 +841,18 @@ def test_create_empty_table_with_extras_succeed(self, mock_bq_client, mock_table table_id=TABLE_ID, schema_fields=schema_fields, time_partitioning=time_partitioning, - cluster_fields=cluster_fields + cluster_fields=cluster_fields, ) body = { - 'tableReference': { - 'tableId': TABLE_ID, - 'projectId': PROJECT_ID, - 'datasetId': DATASET_ID, - }, - 'schema': { - 'fields': schema_fields - }, + 'tableReference': {'tableId': TABLE_ID, 'projectId': PROJECT_ID, 'datasetId': DATASET_ID,}, + 'schema': {'fields': schema_fields}, 'timePartitioning': time_partitioning, - 'clustering': { - 'fields': cluster_fields - } + 'clustering': {'fields': cluster_fields}, } mock_table.from_api_repr.assert_called_once_with(body) mock_bq_client.return_value.create_table.assert_called_once_with( - table=mock_table.from_api_repr.return_value, - exists_ok=True, - retry=DEFAULT_RETRY + table=mock_table.from_api_repr.return_value, exists_ok=True, retry=DEFAULT_RETRY ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") @@ -946,10 +864,10 @@ def test_get_tables_list(self, mock_client): "tableReference": { "projectId": "your-project", "datasetId": "your_dataset", - "tableId": "table1" + "tableId": "table1", }, "type": "TABLE", - "creationTime": "1565781859261" + "creationTime": "1565781859261", }, { "kind": "bigquery#table", @@ -957,11 +875,11 @@ def test_get_tables_list(self, mock_client): "tableReference": { "projectId": "your-project", "datasetId": "your_dataset", - "tableId": "table2" + "tableId": "table2", }, "type": "TABLE", - "creationTime": "1565782713480" - } + "creationTime": "1565782713480", + }, ] table_list_response = [Table.from_api_repr(t) for t in table_list] mock_client.return_value.list_tables.return_value = table_list_response @@ -970,9 +888,7 @@ def test_get_tables_list(self, mock_client): result = self.hook.get_dataset_tables(dataset_id=DATASET_ID, project_id=PROJECT_ID) mock_client.return_value.list_tables.assert_called_once_with( - dataset=dataset_reference, - max_results=None, - retry=DEFAULT_RETRY, + dataset=dataset_reference, max_results=None, retry=DEFAULT_RETRY, ) for res, exp in zip(result, table_list): assert res["tableId"] == exp["tableReference"]["tableId"] @@ -989,7 +905,7 @@ def test_execute_with_parameters(self, mock_insert, _): 'query': "SELECT 'bar'", 'priority': 'INTERACTIVE', 'useLegacySql': True, - 'schemaUpdateOptions': [] + 'schemaUpdateOptions': [], } } mock_insert.assert_called_once_with(configuration=conf, project_id=PROJECT_ID) @@ -1007,10 +923,10 @@ def test_execute_many(self, mock_insert, _): 'query': "SELECT 'bar'", 'priority': 'INTERACTIVE', 'useLegacySql': True, - 'schemaUpdateOptions': [] + 'schemaUpdateOptions': [], } }, - project_id=PROJECT_ID + project_id=PROJECT_ID, ), mock.call( configuration={ @@ -1018,11 +934,11 @@ def test_execute_many(self, mock_insert, _): 'query': "SELECT 'baz'", 'priority': 'INTERACTIVE', 'useLegacySql': True, - 'schemaUpdateOptions': [] + 'schemaUpdateOptions': [], } }, - project_id=PROJECT_ID - ) + project_id=PROJECT_ID, + ), ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") 
@@ -1053,8 +969,7 @@ def test_fetchone(self, mock_next, mock_get_service): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") @mock.patch( - "airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor.fetchone", - side_effect=[1, 2, 3, None] + "airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor.fetchone", side_effect=[1, 2, 3, None] ) def test_fetchall(self, mock_fetchone, mock_get_service): bq_cursor = self.hook.get_cursor() @@ -1103,17 +1018,11 @@ def test_next(self, mock_get_service): mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults mock_execute = mock_get_query_results.return_value.execute mock_execute.return_value = { - "rows": [ - {"f": [{"v": "one"}, {"v": 1}]}, - {"f": [{"v": "two"}, {"v": 2}]}, - ], + "rows": [{"f": [{"v": "one"}, {"v": 1}]}, {"f": [{"v": "two"}, {"v": 2}]},], "pageToken": None, "schema": { - "fields": [ - {"name": "field_1", "type": "STRING"}, - {"name": "field_2", "type": "INTEGER"}, - ] - } + "fields": [{"name": "field_1", "type": "STRING"}, {"name": "field_2", "type": "INTEGER"},] + }, } bq_cursor = self.hook.get_cursor() @@ -1126,8 +1035,9 @@ def test_next(self, mock_get_service): result = bq_cursor.next() self.assertEqual(['two', 2], result) - mock_get_query_results.assert_called_once_with(jobId=JOB_ID, location=LOCATION, pageToken=None, - projectId='bq-project') + mock_get_query_results.assert_called_once_with( + jobId=JOB_ID, location=LOCATION, pageToken=None, projectId='bq-project' + ) mock_execute.assert_called_once_with(num_retries=bq_cursor.num_retries) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") @@ -1143,8 +1053,9 @@ def test_next_no_rows(self, mock_flush_results, mock_get_service): result = bq_cursor.next() self.assertIsNone(result) - mock_get_query_results.assert_called_once_with(jobId=JOB_ID, location=None, pageToken=None, - projectId='bq-project') + mock_get_query_results.assert_called_once_with( + jobId=JOB_ID, location=None, pageToken=None, projectId='bq-project' + ) mock_execute.assert_called_once_with(num_retries=bq_cursor.num_retries) assert mock_flush_results.call_count == 1 @@ -1190,17 +1101,13 @@ def test_create_empty_dataset_with_params(self, mock_client, mock_dataset): self.hook.create_empty_dataset(project_id=PROJECT_ID, dataset_id=DATASET_ID, location=LOCATION) expected_body = { "location": LOCATION, - "datasetReference": { - "datasetId": DATASET_ID, - "projectId": PROJECT_ID - } + "datasetReference": {"datasetId": DATASET_ID, "projectId": PROJECT_ID}, } api_repr = mock_dataset.from_api_repr api_repr.assert_called_once_with(expected_body) mock_client.return_value.create_dataset.assert_called_once_with( - dataset=api_repr.return_value, - exists_ok=True + dataset=api_repr.return_value, exists_ok=True ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Dataset") @@ -1208,18 +1115,14 @@ def test_create_empty_dataset_with_params(self, mock_client, mock_dataset): def test_create_empty_dataset_with_object(self, mock_client, mock_dataset): dataset = { "location": "LOCATION", - "datasetReference": { - "datasetId": "DATASET_ID", - "projectId": "PROJECT_ID" - } + "datasetReference": {"datasetId": "DATASET_ID", "projectId": "PROJECT_ID"}, } self.hook.create_empty_dataset(dataset_reference=dataset) api_repr = mock_dataset.from_api_repr api_repr.assert_called_once_with(dataset) mock_client.return_value.create_dataset.assert_called_once_with( - dataset=api_repr.return_value, - exists_ok=True + 
dataset=api_repr.return_value, exists_ok=True ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Dataset") @@ -1227,10 +1130,7 @@ def test_create_empty_dataset_with_object(self, mock_client, mock_dataset): def test_create_empty_dataset_use_values_from_object(self, mock_client, mock_dataset): dataset = { "location": "LOCATION", - "datasetReference": { - "datasetId": "DATASET_ID", - "projectId": "PROJECT_ID" - } + "datasetReference": {"datasetId": "DATASET_ID", "projectId": "PROJECT_ID"}, } self.hook.create_empty_dataset( dataset_reference=dataset, @@ -1242,8 +1142,7 @@ def test_create_empty_dataset_use_values_from_object(self, mock_client, mock_dat api_repr = mock_dataset.from_api_repr api_repr.assert_called_once_with(dataset) mock_client.return_value.create_dataset.assert_called_once_with( - dataset=api_repr.return_value, - exists_ok=True + dataset=api_repr.return_value, exists_ok=True ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") @@ -1252,10 +1151,7 @@ def test_get_dataset(self, mock_client): "kind": "bigquery#dataset", "location": "US", "id": "your-project:dataset_2_test", - "datasetReference": { - "projectId": "your-project", - "datasetId": "dataset_2_test" - } + "datasetReference": {"projectId": "your-project", "datasetId": "dataset_2_test"}, } expected_result = Dataset.from_api_repr(_expected_result) mock_client.return_value.get_dataset.return_value = expected_result @@ -1274,20 +1170,14 @@ def test_get_datasets_list(self, mock_client): "kind": "bigquery#dataset", "location": "US", "id": "your-project:dataset_2_test", - "datasetReference": { - "projectId": "your-project", - "datasetId": "dataset_2_test" - } + "datasetReference": {"projectId": "your-project", "datasetId": "dataset_2_test"}, }, { "kind": "bigquery#dataset", "location": "US", "id": "your-project:dataset_1_test", - "datasetReference": { - "projectId": "your-project", - "datasetId": "dataset_1_test" - } - } + "datasetReference": {"projectId": "your-project", "datasetId": "dataset_1_test"}, + }, ] return_value = [DatasetListItem(d) for d in datasets] mock_client.return_value.list_datasets.return_value = return_value @@ -1321,27 +1211,14 @@ def test_delete_dataset(self, mock_client): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") def test_patch_dataset(self, mock_get_service): - dataset_resource = { - "access": [ - { - "role": "WRITER", - "groupByEmail": "cloud-logs@google.com" - } - ] - } + dataset_resource = {"access": [{"role": "WRITER", "groupByEmail": "cloud-logs@google.com"}]} method = mock_get_service.return_value.datasets.return_value.patch self.hook.patch_dataset( - dataset_id=DATASET_ID, - project_id=PROJECT_ID, - dataset_resource=dataset_resource + dataset_id=DATASET_ID, project_id=PROJECT_ID, dataset_resource=dataset_resource ) - method.assert_called_once_with( - projectId=PROJECT_ID, - datasetId=DATASET_ID, - body=dataset_resource - ) + method.assert_called_once_with(projectId=PROJECT_ID, datasetId=DATASET_ID, body=dataset_resource) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Dataset") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") @@ -1350,10 +1227,7 @@ def test_update_dataset(self, mock_client, mock_dataset): "kind": "bigquery#dataset", "location": "US", "id": "your-project:dataset_2_test", - "datasetReference": { - "projectId": "your-project", - "datasetId": "dataset_2_test" - } + "datasetReference": {"projectId": "your-project", "datasetId": "dataset_2_test"}, } method = 
mock_client.return_value.update_dataset @@ -1365,14 +1239,12 @@ def test_update_dataset(self, mock_client, mock_dataset): dataset_id=DATASET_ID, project_id=PROJECT_ID, dataset_resource=dataset_resource, - fields=["location"] + fields=["location"], ) mock_dataset.from_api_repr.assert_called_once_with(dataset_resource) method.assert_called_once_with( - dataset=dataset, - fields=["location"], - retry=DEFAULT_RETRY, + dataset=dataset, fields=["location"], retry=DEFAULT_RETRY, ) assert result == dataset @@ -1381,9 +1253,7 @@ class TestTimePartitioningInRunJob(_BigQueryBaseTestClass): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_load_default(self, mock_insert): self.hook.run_load( - destination_project_dataset_table='my_dataset.my_table', - schema_fields=[], - source_uris=[], + destination_project_dataset_table='my_dataset.my_table', schema_fields=[], source_uris=[], ) _, kwargs = mock_insert.call_args @@ -1402,17 +1272,13 @@ def test_run_load_with_arg(self, mock_insert): destination_project_dataset_table=f"{DATASET_ID}.{TABLE_ID}", schema_fields=[], source_uris=[], - time_partitioning={'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000} + time_partitioning={'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}, ) configuration = { 'load': { 'autodetect': False, 'createDisposition': 'CREATE_IF_NEEDED', - 'destinationTable': { - 'projectId': PROJECT_ID, - 'datasetId': DATASET_ID, - 'tableId': TABLE_ID - }, + 'destinationTable': {'projectId': PROJECT_ID, 'datasetId': DATASET_ID, 'tableId': TABLE_ID}, 'sourceFormat': 'CSV', 'sourceUris': [], 'writeDisposition': 'WRITE_EMPTY', @@ -1421,7 +1287,8 @@ def test_run_load_with_arg(self, mock_insert): 'skipLeadingRows': 0, 'fieldDelimiter': ',', 'quote': None, - 'allowQuotedNewlines': False, 'encoding': 'UTF-8' + 'allowQuotedNewlines': False, + 'encoding': 'UTF-8', } } mock_insert.assert_called_once_with(configuration=configuration, project_id=PROJECT_ID) @@ -1431,8 +1298,7 @@ def test_run_query_with_arg(self, mock_insert): self.hook.run_query( sql='select 1', destination_dataset_table=f"{DATASET_ID}.{TABLE_ID}", - time_partitioning={'type': 'DAY', - 'field': 'test_field', 'expirationMs': 1000} + time_partitioning={'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}, ) configuration = { @@ -1442,15 +1308,11 @@ def test_run_query_with_arg(self, mock_insert): 'useLegacySql': True, 'timePartitioning': {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}, 'schemaUpdateOptions': [], - 'destinationTable': { - 'projectId': PROJECT_ID, - 'datasetId': DATASET_ID, - 'tableId': TABLE_ID - }, + 'destinationTable': {'projectId': PROJECT_ID, 'datasetId': DATASET_ID, 'tableId': TABLE_ID}, 'allowLargeResults': False, 'flattenResults': None, 'writeDisposition': 'WRITE_EMPTY', - 'createDisposition': 'CREATE_IF_NEEDED' + 'createDisposition': 'CREATE_IF_NEEDED', } } @@ -1458,22 +1320,15 @@ def test_run_query_with_arg(self, mock_insert): def test_dollar_makes_partition(self): tp_out = _cleanse_time_partitioning('test.teast$20170101', {}) - expect = { - 'type': 'DAY' - } + expect = {'type': 'DAY'} self.assertEqual(tp_out, expect) def test_extra_time_partitioning_options(self): tp_out = _cleanse_time_partitioning( - 'test.teast', - {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000} + 'test.teast', {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000} ) - expect = { - 'type': 'DAY', - 'field': 'test_field', - 'expirationMs': 1000 - } + expect = {'type': 'DAY', 'field': 
'test_field', 'expirationMs': 1000} self.assertEqual(tp_out, expect) @@ -1481,9 +1336,7 @@ class TestClusteringInRunJob(_BigQueryBaseTestClass): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_load_default(self, mock_insert): self.hook.run_load( - destination_project_dataset_table='my_dataset.my_table', - schema_fields=[], - source_uris=[], + destination_project_dataset_table='my_dataset.my_table', schema_fields=[], source_uris=[], ) _, kwargs = mock_insert.call_args @@ -1496,14 +1349,11 @@ def test_run_load_with_arg(self, mock_insert): schema_fields=[], source_uris=[], cluster_fields=['field1', 'field2'], - time_partitioning={'type': 'DAY'} + time_partitioning={'type': 'DAY'}, ) _, kwargs = mock_insert.call_args - self.assertEqual( - kwargs["configuration"]['load']['clustering'], - {'fields': ['field1', 'field2']} - ) + self.assertEqual(kwargs["configuration"]['load']['clustering'], {'fields': ['field1', 'field2']}) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_query_default(self, mock_insert): @@ -1518,14 +1368,11 @@ def test_run_query_with_arg(self, mock_insert): sql='select 1', destination_dataset_table='my_dataset.my_table', cluster_fields=['field1', 'field2'], - time_partitioning={'type': 'DAY'} + time_partitioning={'type': 'DAY'}, ) _, kwargs = mock_insert.call_args - self.assertEqual( - kwargs["configuration"]['query']['clustering'], - {'fields': ['field1', 'field2']} - ) + self.assertEqual(kwargs["configuration"]['query']['clustering'], {'fields': ['field1', 'field2']}) class TestBigQueryHookLegacySql(_BigQueryBaseTestClass): @@ -1540,7 +1387,7 @@ def test_hook_uses_legacy_sql_by_default(self, mock_insert, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id', - return_value=(CREDENTIALS, PROJECT_ID) + return_value=(CREDENTIALS, PROJECT_ID), ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") @@ -1567,10 +1414,11 @@ def test_run_with_configuration_location(self, mock_client, mock_job): self.hook.run_with_configuration(conf) mock_client.assert_called_once_with(project_id=PROJECT_ID, location=location) mock_job.from_api_repr.assert_called_once_with( - {"configuration": conf, - "jobReference": {"jobId": mock.ANY, "projectId": PROJECT_ID, "location": location} - }, - mock_client.return_value + { + "configuration": conf, + "jobReference": {"jobId": mock.ANY, "projectId": PROJECT_ID, "location": location}, + }, + mock_client.return_value, ) @@ -1578,12 +1426,8 @@ class TestBigQueryWithKMS(_BigQueryBaseTestClass): @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Table") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") def test_create_empty_table_with_kms(self, mock_bq_client, mock_table): - schema_fields = [ - {"name": "id", "type": "STRING", "mode": "REQUIRED"} - ] - encryption_configuration = { - "kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c" - } + schema_fields = [{"name": "id", "type": "STRING", "mode": "REQUIRED"}] + encryption_configuration = {"kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c"} self.hook.create_empty_table( project_id=PROJECT_ID, @@ -1600,9 +1444,7 @@ def test_create_empty_table_with_kms(self, mock_bq_client, mock_table): } mock_table.from_api_repr.assert_called_once_with(body) 
mock_bq_client.return_value.create_table.assert_called_once_with( - table=mock_table.from_api_repr.return_value, - exists_ok=True, - retry=DEFAULT_RETRY, + table=mock_table.from_api_repr.return_value, exists_ok=True, retry=DEFAULT_RETRY, ) # pylint: disable=too-many-locals @@ -1622,12 +1464,8 @@ def test_create_external_table_with_kms(self, mock_create): allow_jagged_rows = False encoding = "UTF-8" labels = {'label1': 'test1', 'label2': 'test2'} - schema_fields = [ - {'mode': 'REQUIRED', 'name': 'id', 'type': 'STRING', 'description': None} - ] - encryption_configuration = { - "kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c" - } + schema_fields = [{'mode': 'REQUIRED', 'name': 'id', 'type': 'STRING', 'description': None}] + encryption_configuration = {"kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c"} self.hook.create_external_table( external_project_dataset_table=external_project_dataset_table, @@ -1645,7 +1483,7 @@ def test_create_external_table_with_kms(self, mock_create): allow_quoted_newlines=allow_quoted_newlines, labels=labels, schema_fields=schema_fields, - encryption_configuration=encryption_configuration + encryption_configuration=encryption_configuration, ) body = { @@ -1663,100 +1501,83 @@ def test_create_external_table_with_kms(self, mock_create): 'quote': quote_character, 'allowQuotedNewlines': allow_quoted_newlines, 'allowJaggedRows': allow_jagged_rows, - 'encoding': encoding - } - }, - 'tableReference': { - 'projectId': PROJECT_ID, - 'datasetId': DATASET_ID, - 'tableId': TABLE_ID, + 'encoding': encoding, + }, }, + 'tableReference': {'projectId': PROJECT_ID, 'datasetId': DATASET_ID, 'tableId': TABLE_ID,}, 'labels': labels, "encryptionConfiguration": encryption_configuration, } mock_create.assert_called_once_with( - table_resource=body, - project_id=PROJECT_ID, - location=None, - exists_ok=True, + table_resource=body, project_id=PROJECT_ID, location=None, exists_ok=True, ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_query_with_kms(self, mock_insert): - encryption_configuration = { - "kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c" - } - self.hook.run_query( - sql='query', - encryption_configuration=encryption_configuration - ) + encryption_configuration = {"kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c"} + self.hook.run_query(sql='query', encryption_configuration=encryption_configuration) _, kwargs = mock_insert.call_args self.assertIs( - kwargs["configuration"]['query']['destinationEncryptionConfiguration'], - encryption_configuration + kwargs["configuration"]['query']['destinationEncryptionConfiguration'], encryption_configuration ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_copy_with_kms(self, mock_insert): - encryption_configuration = { - "kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c" - } + encryption_configuration = {"kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c"} self.hook.run_copy( source_project_dataset_tables='p.d.st', destination_project_dataset_table='p.d.dt', - encryption_configuration=encryption_configuration + encryption_configuration=encryption_configuration, ) _, kwargs = mock_insert.call_args self.assertIs( - kwargs["configuration"]['copy']['destinationEncryptionConfiguration'], - encryption_configuration + kwargs["configuration"]['copy']['destinationEncryptionConfiguration'], encryption_configuration ) 
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_run_load_with_kms(self, mock_insert): - encryption_configuration = { - "kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c" - } + encryption_configuration = {"kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c"} self.hook.run_load( destination_project_dataset_table='p.d.dt', source_uris=['abc.csv'], autodetect=True, - encryption_configuration=encryption_configuration + encryption_configuration=encryption_configuration, ) _, kwargs = mock_insert.call_args self.assertIs( - kwargs["configuration"]['load']['destinationEncryptionConfiguration'], - encryption_configuration + kwargs["configuration"]['load']['destinationEncryptionConfiguration'], encryption_configuration ) class TestBigQueryBaseCursorMethodsDeprecationWarning(unittest.TestCase): - @parameterized.expand([ - ("create_empty_table",), - ("create_empty_dataset",), - ("get_dataset_tables",), - ("delete_dataset",), - ("create_external_table",), - ("patch_table",), - ("insert_all",), - ("update_dataset",), - ("patch_dataset",), - ("get_dataset_tables_list",), - ("get_datasets_list",), - ("get_dataset",), - ("run_grant_dataset_view_access",), - ("run_table_upsert",), - ("run_table_delete",), - ("get_tabledata",), - ("get_schema",), - ("poll_job_complete",), - ("cancel_query",), - ("run_with_configuration",), - ("run_load",), - ("run_copy",), - ("run_extract",), - ("run_query",), - ]) + @parameterized.expand( + [ + ("create_empty_table",), + ("create_empty_dataset",), + ("get_dataset_tables",), + ("delete_dataset",), + ("create_external_table",), + ("patch_table",), + ("insert_all",), + ("update_dataset",), + ("patch_dataset",), + ("get_dataset_tables_list",), + ("get_datasets_list",), + ("get_dataset",), + ("run_grant_dataset_view_access",), + ("run_table_upsert",), + ("run_table_delete",), + ("get_tabledata",), + ("get_schema",), + ("poll_job_complete",), + ("cancel_query",), + ("run_with_configuration",), + ("run_load",), + ("run_copy",), + ("run_extract",), + ("run_query",), + ] + ) @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook") def test_deprecation_warning(self, func_name, mock_bq_hook): args, kwargs = [1], {"param1": "val1"} diff --git a/tests/providers/google/cloud/hooks/test_bigquery_dts.py b/tests/providers/google/cloud/hooks/test_bigquery_dts.py index ab841733c6557..170988a23f06e 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery_dts.py +++ b/tests/providers/google/cloud/hooks/test_bigquery_dts.py @@ -78,9 +78,7 @@ def test_disable_auto_scheduling(self): "DataTransferServiceClient.create_transfer_config" ) def test_create_transfer_config(self, service_mock): - self.hook.create_transfer_config( - transfer_config=TRANSFER_CONFIG, project_id=PROJECT_ID - ) + self.hook.create_transfer_config(transfer_config=TRANSFER_CONFIG, project_id=PROJECT_ID) parent = DataTransferServiceClient.project_path(PROJECT_ID) expected_config = deepcopy(TRANSFER_CONFIG) expected_config.schedule_options.disable_auto_scheduling = True @@ -98,29 +96,19 @@ def test_create_transfer_config(self, service_mock): "DataTransferServiceClient.delete_transfer_config" ) def test_delete_transfer_config(self, service_mock): - self.hook.delete_transfer_config( - transfer_config_id=TRANSFER_CONFIG_ID, project_id=PROJECT_ID - ) + self.hook.delete_transfer_config(transfer_config_id=TRANSFER_CONFIG_ID, project_id=PROJECT_ID) - name = DataTransferServiceClient.project_transfer_config_path( - PROJECT_ID, TRANSFER_CONFIG_ID - 
) - service_mock.assert_called_once_with( - name=name, metadata=None, retry=None, timeout=None - ) + name = DataTransferServiceClient.project_transfer_config_path(PROJECT_ID, TRANSFER_CONFIG_ID) + service_mock.assert_called_once_with(name=name, metadata=None, retry=None, timeout=None) @mock.patch( "airflow.providers.google.cloud.hooks.bigquery_dts." "DataTransferServiceClient.start_manual_transfer_runs" ) def test_start_manual_transfer_runs(self, service_mock): - self.hook.start_manual_transfer_runs( - transfer_config_id=TRANSFER_CONFIG_ID, project_id=PROJECT_ID - ) + self.hook.start_manual_transfer_runs(transfer_config_id=TRANSFER_CONFIG_ID, project_id=PROJECT_ID) - parent = DataTransferServiceClient.project_transfer_config_path( - PROJECT_ID, TRANSFER_CONFIG_ID - ) + parent = DataTransferServiceClient.project_transfer_config_path(PROJECT_ID, TRANSFER_CONFIG_ID) service_mock.assert_called_once_with( parent=parent, requested_time_range=None, diff --git a/tests/providers/google/cloud/hooks/test_bigquery_system.py b/tests/providers/google/cloud/hooks/test_bigquery_system.py index 1bac71eeb9120..78978c1a938df 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery_system.py +++ b/tests/providers/google/cloud/hooks/test_bigquery_system.py @@ -33,6 +33,7 @@ def setUp(self): def test_output_is_dataframe_with_valid_query(self): import pandas as pd + df = self.instance.get_pandas_df('select 1') self.assertIsInstance(df, pd.DataFrame) @@ -46,12 +47,10 @@ def test_succeeds_with_explicit_legacy_query(self): self.assertEqual(df.iloc(0)[0][0], 1) def test_succeeds_with_explicit_std_query(self): - df = self.instance.get_pandas_df( - 'select * except(b) from (select 1 a, 2 b)', dialect='standard') + df = self.instance.get_pandas_df('select * except(b) from (select 1 a, 2 b)', dialect='standard') self.assertEqual(df.iloc(0)[0][0], 1) def test_throws_exception_with_incompatible_syntax(self): with self.assertRaises(Exception) as context: - self.instance.get_pandas_df( - 'select * except(b) from (select 1 a, 2 b)', dialect='legacy') + self.instance.get_pandas_df('select * except(b) from (select 1 a, 2 b)', dialect='legacy') self.assertIn('Reason: ', str(context.exception), "") diff --git a/tests/providers/google/cloud/hooks/test_bigtable.py b/tests/providers/google/cloud/hooks/test_bigtable.py index c603053bbad56..f05441dc108a4 100644 --- a/tests/providers/google/cloud/hooks/test_bigtable.py +++ b/tests/providers/google/cloud/hooks/test_bigtable.py @@ -27,7 +27,8 @@ from airflow.providers.google.cloud.hooks.bigtable import BigtableHook from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, ) @@ -48,15 +49,16 @@ class TestBigtableHookNoDefaultProjectId(unittest.TestCase): - def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_no_default_project_id): + with mock.patch( + 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_no_default_project_id, + ): self.bigtable_hook_no_default_project_id = BigtableHook(gcp_conn_id='test') @mock.patch( "airflow.providers.google.cloud.hooks.bigtable.BigtableHook.client_info", - new_callable=mock.PropertyMock + new_callable=mock.PropertyMock, ) @mock.patch("airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_credentials") 
@mock.patch("airflow.providers.google.cloud.hooks.bigtable.Client") @@ -66,7 +68,7 @@ def test_bigtable_client_creation(self, mock_client, mock_get_creds, mock_client project=GCP_PROJECT_ID_HOOK_UNIT_TEST, credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value, - admin=True + admin=True, ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.bigtable_hook_no_default_project_id._client, result) @@ -77,8 +79,8 @@ def test_get_instance_overridden_project_id(self, get_client): instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = True res = self.bigtable_hook_no_default_project_id.get_instance( - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, - instance_id=CBT_INSTANCE) + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=CBT_INSTANCE + ) instance_method.assert_called_once_with('instance') instance_exists_method.assert_called_once_with() get_client.assert_called_once_with(project_id='example-project') @@ -91,7 +93,8 @@ def test_delete_instance_overridden_project_id(self, get_client): instance_exists_method.return_value = True delete_method = instance_method.return_value.delete res = self.bigtable_hook_no_default_project_id.delete_instance( - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=CBT_INSTANCE) + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=CBT_INSTANCE + ) instance_method.assert_called_once_with('instance') instance_exists_method.assert_called_once_with() delete_method.assert_called_once_with() @@ -108,7 +111,8 @@ def test_create_instance_overridden_project_id(self, get_client, instance_create project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=CBT_INSTANCE, main_cluster_id=CBT_CLUSTER, - main_cluster_zone=CBT_ZONE) + main_cluster_zone=CBT_ZONE, + ) get_client.assert_called_once_with(project_id='example-project') instance_create.assert_called_once_with(clusters=mock.ANY) self.assertEqual(res.instance_id, 'instance') @@ -124,7 +128,7 @@ def test_update_instance_overridden_project_id(self, get_client, instance_update instance_id=CBT_INSTANCE, instance_display_name=CBT_INSTANCE_DISPLAY_NAME, instance_type=CBT_INSTANCE_TYPE, - instance_labels=CBT_INSTANCE_LABELS + instance_labels=CBT_INSTANCE_LABELS, ) get_client.assert_called_once_with(project_id='example-project') instance_update.assert_called_once_with() @@ -137,24 +141,24 @@ def test_delete_table_overridden_project_id(self, get_client): table_delete_method = instance_method.return_value.table.return_value.delete instance_exists_method.return_value = True self.bigtable_hook_no_default_project_id.delete_table( - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, - instance_id=CBT_INSTANCE, - table_id=CBT_TABLE) + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=CBT_INSTANCE, table_id=CBT_TABLE + ) get_client.assert_called_once_with(project_id='example-project') instance_exists_method.assert_called_once_with() table_delete_method.assert_called_once_with() class TestBigtableHookDefaultProjectId(unittest.TestCase): - def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_default_project_id): + with mock.patch( + 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_default_project_id, + ): self.bigtable_hook_default_project_id = BigtableHook(gcp_conn_id='test') @mock.patch( "airflow.providers.google.cloud.hooks.bigtable.BigtableHook.client_info", - new_callable=mock.PropertyMock + 
new_callable=mock.PropertyMock, ) @mock.patch("airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_credentials") @mock.patch("airflow.providers.google.cloud.hooks.bigtable.Client") @@ -164,7 +168,7 @@ def test_bigtable_client_creation(self, mock_client, mock_get_creds, mock_client project=GCP_PROJECT_ID_HOOK_UNIT_TEST, credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value, - admin=True + admin=True, ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.bigtable_hook_default_project_id._client, result) @@ -172,7 +176,7 @@ def test_bigtable_client_creation(self, mock_client, mock_get_creds, mock_client @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client') def test_get_instance(self, get_client, mock_project_id): @@ -180,8 +184,7 @@ def test_get_instance(self, get_client, mock_project_id): instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = True res = self.bigtable_hook_default_project_id.get_instance( - instance_id=CBT_INSTANCE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + instance_id=CBT_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) instance_method.assert_called_once_with('instance') instance_exists_method.assert_called_once_with() @@ -194,8 +197,8 @@ def test_get_instance_overridden_project_id(self, get_client): instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = True res = self.bigtable_hook_default_project_id.get_instance( - project_id='new-project', - instance_id=CBT_INSTANCE) + project_id='new-project', instance_id=CBT_INSTANCE + ) instance_method.assert_called_once_with('instance') instance_exists_method.assert_called_once_with() get_client.assert_called_once_with(project_id='new-project') @@ -204,7 +207,7 @@ def test_get_instance_overridden_project_id(self, get_client): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client') def test_get_instance_no_instance(self, get_client, mock_project_id): @@ -212,8 +215,7 @@ def test_get_instance_no_instance(self, get_client, mock_project_id): instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = False res = self.bigtable_hook_default_project_id.get_instance( - instance_id=CBT_INSTANCE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + instance_id=CBT_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) instance_method.assert_called_once_with('instance') instance_exists_method.assert_called_once_with() @@ -223,7 +225,7 @@ def test_get_instance_no_instance(self, get_client, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client') def test_delete_instance(self, get_client, mock_project_id): @@ -232,8 +234,7 @@ def test_delete_instance(self, get_client, mock_project_id): instance_exists_method.return_value = True 
delete_method = instance_method.return_value.delete res = self.bigtable_hook_default_project_id.delete_instance( - instance_id=CBT_INSTANCE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + instance_id=CBT_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) instance_method.assert_called_once_with('instance') instance_exists_method.assert_called_once_with() @@ -248,7 +249,8 @@ def test_delete_instance_overridden_project_id(self, get_client): instance_exists_method.return_value = True delete_method = instance_method.return_value.delete res = self.bigtable_hook_default_project_id.delete_instance( - project_id='new-project', instance_id=CBT_INSTANCE) + project_id='new-project', instance_id=CBT_INSTANCE + ) instance_method.assert_called_once_with('instance') instance_exists_method.assert_called_once_with() delete_method.assert_called_once_with() @@ -258,7 +260,7 @@ def test_delete_instance_overridden_project_id(self, get_client): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client') def test_delete_instance_no_instance(self, get_client, mock_project_id): @@ -267,8 +269,7 @@ def test_delete_instance_no_instance(self, get_client, mock_project_id): instance_exists_method.return_value = False delete_method = instance_method.return_value.delete self.bigtable_hook_default_project_id.delete_instance( - instance_id=CBT_INSTANCE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + instance_id=CBT_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) instance_method.assert_called_once_with('instance') instance_exists_method.assert_called_once_with() @@ -278,7 +279,7 @@ def test_delete_instance_no_instance(self, get_client, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('google.cloud.bigtable.instance.Instance.create') @mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client') @@ -308,9 +309,7 @@ def test_create_instance_with_one_replica_cluster( self, get_client, instance_create, cluster, mock_project_id ): operation = mock.Mock() - operation.result_return_value = Instance( - instance_id=CBT_INSTANCE, client=get_client - ) + operation.result_return_value = Instance(instance_id=CBT_INSTANCE, client=get_client) instance_create.return_value = operation res = self.bigtable_hook_default_project_id.create_instance( @@ -325,14 +324,12 @@ def test_create_instance_with_one_replica_cluster( ) cluster.assert_has_calls( [ - unittest.mock.call( - CBT_CLUSTER, CBT_ZONE, 1, enums.StorageType.SSD - ), + unittest.mock.call(CBT_CLUSTER, CBT_ZONE, 1, enums.StorageType.SSD), unittest.mock.call( CBT_REPLICA_CLUSTER_ID, CBT_REPLICA_CLUSTER_ZONE, 1, enums.StorageType.SSD ), ], - any_order=True + any_order=True, ) get_client.assert_called_once_with(project_id='example-project') instance_create.assert_called_once_with(clusters=mock.ANY) @@ -350,9 +347,7 @@ def test_create_instance_with_multiple_replica_clusters( self, get_client, instance_create, cluster, mock_project_id ): operation = mock.Mock() - operation.result_return_value = Instance( - instance_id=CBT_INSTANCE, client=get_client - ) + operation.result_return_value = Instance(instance_id=CBT_INSTANCE, client=get_client) 
instance_create.return_value = operation res = self.bigtable_hook_default_project_id.create_instance( @@ -366,20 +361,12 @@ def test_create_instance_with_multiple_replica_clusters( ) cluster.assert_has_calls( [ - unittest.mock.call( - CBT_CLUSTER, CBT_ZONE, 1, enums.StorageType.SSD - ), - unittest.mock.call( - 'replica-1', 'us-west1-a', 1, enums.StorageType.SSD - ), - unittest.mock.call( - 'replica-2', 'us-central1-f', 1, enums.StorageType.SSD - ), - unittest.mock.call( - 'replica-3', 'us-east1-d', 1, enums.StorageType.SSD - ), + unittest.mock.call(CBT_CLUSTER, CBT_ZONE, 1, enums.StorageType.SSD), + unittest.mock.call('replica-1', 'us-west1-a', 1, enums.StorageType.SSD), + unittest.mock.call('replica-2', 'us-central1-f', 1, enums.StorageType.SSD), + unittest.mock.call('replica-3', 'us-east1-d', 1, enums.StorageType.SSD), ], - any_order=True + any_order=True, ) get_client.assert_called_once_with(project_id='example-project') instance_create.assert_called_once_with(clusters=mock.ANY) @@ -388,7 +375,7 @@ def test_create_instance_with_multiple_replica_clusters( @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('google.cloud.bigtable.instance.Instance.update') @mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client') @@ -417,7 +404,8 @@ def test_create_instance_overridden_project_id(self, get_client, instance_create project_id='new-project', instance_id=CBT_INSTANCE, main_cluster_id=CBT_CLUSTER, - main_cluster_zone=CBT_ZONE) + main_cluster_zone=CBT_ZONE, + ) get_client.assert_called_once_with(project_id='new-project') instance_create.assert_called_once_with(clusters=mock.ANY) self.assertEqual(res.instance_id, 'instance') @@ -425,7 +413,7 @@ def test_create_instance_overridden_project_id(self, get_client, instance_create @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client') def test_delete_table(self, get_client, mock_project_id): @@ -434,9 +422,7 @@ def test_delete_table(self, get_client, mock_project_id): table_delete_method = instance_method.return_value.table.return_value.delete instance_exists_method.return_value = True self.bigtable_hook_default_project_id.delete_table( - instance_id=CBT_INSTANCE, - table_id=CBT_TABLE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + instance_id=CBT_INSTANCE, table_id=CBT_TABLE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) get_client.assert_called_once_with(project_id='example-project') instance_exists_method.assert_called_once_with() @@ -449,9 +435,8 @@ def test_delete_table_overridden_project_id(self, get_client): table_delete_method = instance_method.return_value.table.return_value.delete instance_exists_method.return_value = True self.bigtable_hook_default_project_id.delete_table( - project_id='new-project', - instance_id=CBT_INSTANCE, - table_id=CBT_TABLE) + project_id='new-project', instance_id=CBT_INSTANCE, table_id=CBT_TABLE + ) get_client.assert_called_once_with(project_id='new-project') instance_exists_method.assert_called_once_with() table_delete_method.assert_called_once_with() @@ -463,12 +448,8 @@ def test_create_table(self, get_client, create): instance_exists_method = instance_method.return_value.exists 
instance_exists_method.return_value = True client = mock.Mock(Client) - instance = google.cloud.bigtable.instance.Instance( - instance_id=CBT_INSTANCE, - client=client) - self.bigtable_hook_default_project_id.create_table( - instance=instance, - table_id=CBT_TABLE) + instance = google.cloud.bigtable.instance.Instance(instance_id=CBT_INSTANCE, client=client) + self.bigtable_hook_default_project_id.create_table(instance=instance, table_id=CBT_TABLE) get_client.assert_not_called() create.assert_called_once_with([], {}) @@ -479,13 +460,10 @@ def test_update_cluster(self, get_client, update): instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = True client = mock.Mock(Client) - instance = google.cloud.bigtable.instance.Instance( - instance_id=CBT_INSTANCE, - client=client) + instance = google.cloud.bigtable.instance.Instance(instance_id=CBT_INSTANCE, client=client) self.bigtable_hook_default_project_id.update_cluster( - instance=instance, - cluster_id=CBT_CLUSTER, - nodes=4) + instance=instance, cluster_id=CBT_CLUSTER, nodes=4 + ) get_client.assert_not_called() update.assert_called_once_with() @@ -497,11 +475,10 @@ def test_list_column_families(self, get_client, list_column_families): instance_exists_method.return_value = True client = mock.Mock(Client) get_client.return_value = client - instance = google.cloud.bigtable.instance.Instance( - instance_id=CBT_INSTANCE, - client=client) + instance = google.cloud.bigtable.instance.Instance(instance_id=CBT_INSTANCE, client=client) self.bigtable_hook_default_project_id.get_column_families_for_table( - instance=instance, table_id=CBT_TABLE) + instance=instance, table_id=CBT_TABLE + ) get_client.assert_not_called() list_column_families.assert_called_once_with() @@ -512,10 +489,9 @@ def test_get_cluster_states(self, get_client, get_cluster_states): instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = True client = mock.Mock(Client) - instance = google.cloud.bigtable.instance.Instance( - instance_id=CBT_INSTANCE, - client=client) + instance = google.cloud.bigtable.instance.Instance(instance_id=CBT_INSTANCE, client=client) self.bigtable_hook_default_project_id.get_cluster_states_for_table( - instance=instance, table_id=CBT_TABLE) + instance=instance, table_id=CBT_TABLE + ) get_client.assert_not_called() get_cluster_states.assert_called_once_with() diff --git a/tests/providers/google/cloud/hooks/test_cloud_build.py b/tests/providers/google/cloud/hooks/test_cloud_build.py index 6a149b4205144..e5bd46bd6716d 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_build.py +++ b/tests/providers/google/cloud/hooks/test_cloud_build.py @@ -27,7 +27,8 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_build import CloudBuildHook from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, ) @@ -70,19 +71,15 @@ def test_cloud_build_client_creation(self, mock_build, mock_authorize): def test_build_immediately_complete(self, get_conn_mock): service_mock = get_conn_mock.return_value - service_mock.projects.return_value\ - .builds.return_value\ - .create.return_value\ - .execute.return_value = TEST_BUILD + service_mock.projects.return_value.builds.return_value.create.return_value.execute.return_value = ( + TEST_BUILD + ) - 
service_mock.projects.return_value.\ - builds.return_value.\ - get.return_value.\ - execute.return_value = TEST_BUILD + service_mock.projects.return_value.builds.return_value.get.return_value.execute.return_value = ( + TEST_BUILD + ) - service_mock.operations.return_value.\ - get.return_value.\ - execute.return_value = TEST_DONE_OPERATION + service_mock.operations.return_value.get.return_value.execute.return_value = TEST_DONE_OPERATION result = self.hook.create_build(body={}, project_id=TEST_PROJECT_ID) @@ -117,7 +114,7 @@ def test_waiting_operation(self, _, get_conn_mock): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.cloud_build.CloudBuildHook.get_conn") @mock.patch("airflow.providers.google.cloud.hooks.cloud_build.time.sleep") @@ -157,7 +154,7 @@ def test_cloud_build_client_creation(self, mock_build, mock_authorize): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.cloud_build.CloudBuildHook.get_conn") def test_build_immediately_complete(self, get_conn_mock, mock_project_id): @@ -184,7 +181,7 @@ def test_build_immediately_complete(self, get_conn_mock, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.cloud_build.CloudBuildHook.get_conn") @mock.patch("airflow.providers.google.cloud.hooks.cloud_build.time.sleep") @@ -211,7 +208,7 @@ def test_waiting_operation(self, _, get_conn_mock, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.cloud_build.CloudBuildHook.get_conn") @mock.patch("airflow.providers.google.cloud.hooks.cloud_build.time.sleep") @@ -251,7 +248,7 @@ def test_cloud_build_client_creation(self, mock_build, mock_authorize): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.cloud_build.CloudBuildHook.get_conn") def test_create_build(self, mock_get_conn, mock_project_id): diff --git a/tests/providers/google/cloud/hooks/test_cloud_memorystore.py b/tests/providers/google/cloud/hooks/test_cloud_memorystore.py index ac5aa99ab619c..c4a763afa570e 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_memorystore.py +++ b/tests/providers/google/cloud/hooks/test_cloud_memorystore.py @@ -27,7 +27,8 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_memorystore import CloudMemorystoreHook from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, ) @@ -62,11 +63,9 @@ def setUp(self,): @mock.patch( 
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_create_instance_when_exists(self, mock_get_conn, mock_project_id): mock_get_conn.return_value.get_instance.return_value = Instance(name=TEST_NAME) result = self.hook.create_instance( # pylint: disable=no-value-for-parameter @@ -85,11 +84,9 @@ def test_create_instance_when_exists(self, mock_get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_create_instance_when_not_exists(self, mock_get_conn, mock_project_id): mock_get_conn.return_value.get_instance.side_effect = [ NotFound("Instnace not found"), @@ -126,11 +123,9 @@ def test_create_instance_when_not_exists(self, mock_get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_delete_instance(self, mock_get_conn, mock_project_id): self.hook.delete_instance( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -146,11 +141,9 @@ def test_delete_instance(self, mock_get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_get_instance(self, mock_get_conn, mock_project_id): self.hook.get_instance( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -166,11 +159,9 @@ def test_get_instance(self, mock_get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_list_instances(self, mock_get_conn, mock_project_id): self.hook.list_instances( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -190,11 +181,9 @@ def test_list_instances(self, mock_get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - 
return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_update_instance(self, mock_get_conn, mock_project_id): self.hook.update_instance( # pylint: disable=no-value-for-parameter update_mask=TEST_UPDATE_MASK, @@ -222,9 +211,7 @@ def setUp(self,): ): self.hook = CloudMemorystoreHook(gcp_conn_id="test") - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_create_instance_when_exists(self, mock_get_conn): mock_get_conn.return_value.get_instance.return_value = Instance(name=TEST_NAME) result = self.hook.create_instance( @@ -244,12 +231,11 @@ def test_create_instance_when_exists(self, mock_get_conn): ) self.assertEqual(Instance(name=TEST_NAME), result) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_create_instance_when_not_exists(self, mock_get_conn): mock_get_conn.return_value.get_instance.side_effect = [ - NotFound("Instnace not found"), Instance(name=TEST_NAME) + NotFound("Instnace not found"), + Instance(name=TEST_NAME), ] mock_get_conn.return_value.create_instance.return_value.result.return_value = Instance(name=TEST_NAME) result = self.hook.create_instance( @@ -294,11 +280,9 @@ def test_create_instance_when_not_exists(self, mock_get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_create_instance_without_project_id(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException): self.hook.create_instance( @@ -311,9 +295,7 @@ def test_create_instance_without_project_id(self, mock_get_conn, mock_project_id metadata=TEST_METADATA, ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_delete_instance(self, mock_get_conn): self.hook.delete_instance( location=TEST_LOCATION, @@ -330,11 +312,9 @@ def test_delete_instance(self, mock_get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_delete_instance_without_project_id(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException): self.hook.delete_instance( @@ -346,9 +326,7 @@ def test_delete_instance_without_project_id(self, mock_get_conn, mock_project_id metadata=TEST_METADATA, ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" - ) + 
@mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_get_instance(self, mock_get_conn): self.hook.get_instance( location=TEST_LOCATION, @@ -365,11 +343,9 @@ def test_get_instance(self, mock_get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_get_instance_without_project_id(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException): self.hook.get_instance( @@ -381,9 +357,7 @@ def test_get_instance_without_project_id(self, mock_get_conn, mock_project_id): metadata=TEST_METADATA, ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_list_instances(self, mock_get_conn): self.hook.list_instances( location=TEST_LOCATION, @@ -404,11 +378,9 @@ def test_list_instances(self, mock_get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_list_instances_without_project_id(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException): self.hook.list_instances( @@ -420,9 +392,7 @@ def test_list_instances_without_project_id(self, mock_get_conn, mock_project_id) metadata=TEST_METADATA, ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_update_instance(self, mock_get_conn): self.hook.update_instance( update_mask=TEST_UPDATE_MASK, @@ -430,7 +400,7 @@ def test_update_instance(self, mock_get_conn): retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, - project_id=TEST_PROJECT_ID + project_id=TEST_PROJECT_ID, ) mock_get_conn.return_value.update_instance.assert_called_once_with( update_mask=TEST_UPDATE_MASK, @@ -443,11 +413,9 @@ def test_update_instance(self, mock_get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn") def test_update_instance_without_project_id(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException): self.hook.update_instance( # pylint: disable=no-value-for-parameter diff --git a/tests/providers/google/cloud/hooks/test_cloud_sql.py b/tests/providers/google/cloud/hooks/test_cloud_sql.py index f4063d8e02bcc..416c2c4e49437 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_sql.py +++ b/tests/providers/google/cloud/hooks/test_cloud_sql.py @@ -31,48 +31,55 @@ from airflow.models import Connection from airflow.providers.google.cloud.hooks.cloud_sql 
import CloudSQLDatabaseHook, CloudSQLHook from tests.providers.google.cloud.utils.base_gcp_mock import ( - mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, + mock_base_gcp_hook_default_project_id, + mock_base_gcp_hook_no_default_project_id, ) class TestGcpSqlHookDefaultProjectId(unittest.TestCase): - def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_default_project_id): + with mock.patch( + 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_default_project_id, + ): self.cloudsql_hook = CloudSQLHook(api_version='v1', gcp_conn_id='test') - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) def test_instance_import_exception(self, mock_get_credentials): self.cloudsql_hook.get_conn = mock.Mock( side_effect=HttpError(resp=httplib2.Response({'status': 400}), content=b'Error content') ) with self.assertRaises(AirflowException) as cm: self.cloudsql_hook.import_instance( # pylint: disable=no-value-for-parameter - instance='instance', - body={}) + instance='instance', body={} + ) err = cm.exception self.assertIn("Importing instance ", str(err)) self.assertEqual(1, mock_get_credentials.call_count) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) def test_instance_export_exception(self, mock_get_credentials): self.cloudsql_hook.get_conn = mock.Mock( - side_effect=HttpError(resp=httplib2.Response({'status': 400}), - content=b'Error content') + side_effect=HttpError(resp=httplib2.Response({'status': 400}), content=b'Error content') ) with self.assertRaises(HttpError) as cm: self.cloudsql_hook.export_instance( # pylint: disable=no-value-for-parameter - instance='instance', - body={}) + instance='instance', body={} + ) err = cm.exception self.assertEqual(400, err.resp.status) self.assertEqual(1, mock_get_credentials.call_count) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') def test_instance_import(self, wait_for_operation_to_complete, get_conn, mock_get_credentials): @@ -81,17 +88,20 @@ def test_instance_import(self, wait_for_operation_to_complete, get_conn, mock_ge execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook.import_instance( # pylint: disable=no-value-for-parameter - instance='instance', - body={}) + instance='instance', body={} + ) import_method.assert_called_once_with(body={}, instance='instance', project='example-project') 
execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id' + ) self.assertEqual(1, mock_get_credentials.call_count) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') def test_instance_export(self, wait_for_operation_to_complete, get_conn, mock_get_credentials): @@ -100,39 +110,38 @@ def test_instance_export(self, wait_for_operation_to_complete, get_conn, mock_ge execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook.export_instance( # pylint: disable=no-value-for-parameter - instance='instance', - body={}) + instance='instance', body={} + ) export_method.assert_called_once_with(body={}, instance='instance', project='example-project') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id' + ) self.assertEqual(1, mock_get_credentials.call_count) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') - def test_instance_export_with_in_progress_retry( - self, wait_for_operation_to_complete, get_conn - ): + def test_instance_export_with_in_progress_retry(self, wait_for_operation_to_complete, get_conn): export_method = get_conn.return_value.instances.return_value.export execute_method = export_method.return_value.execute execute_method.side_effect = [ - HttpError(resp=type('', (object,), {"status": 429, })(), content=b'Internal Server Error'), - {"name": "operation_id"} + HttpError(resp=type('', (object,), {"status": 429,})(), content=b'Internal Server Error'), + {"name": "operation_id"}, ] wait_for_operation_to_complete.return_value = None - self.cloudsql_hook.export_instance( - project_id='example-project', - instance='instance', - body={}) + self.cloudsql_hook.export_instance(project_id='example-project', instance='instance', body={}) self.assertEqual(2, export_method.call_count) self.assertEqual(2, execute_method.call_count) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id' + ) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) 
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') def test_get_instance(self, wait_for_operation_to_complete, get_conn, mock_get_credentials): @@ -140,8 +149,7 @@ def test_get_instance(self, wait_for_operation_to_complete, get_conn, mock_get_c execute_method = get_method.return_value.execute execute_method.return_value = {"name": "instance"} wait_for_operation_to_complete.return_value = None - res = self.cloudsql_hook.get_instance( # pylint: disable=no-value-for-parameter - instance='instance') + res = self.cloudsql_hook.get_instance(instance='instance') # pylint: disable=no-value-for-parameter self.assertIsNotNone(res) self.assertEqual('instance', res['name']) get_method.assert_called_once_with(instance='instance', project='example-project') @@ -149,8 +157,10 @@ def test_get_instance(self, wait_for_operation_to_complete, get_conn, mock_get_c wait_for_operation_to_complete.assert_not_called() self.assertEqual(1, mock_get_credentials.call_count) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') def test_create_instance(self, wait_for_operation_to_complete, get_conn, mock_get_credentials): @@ -158,8 +168,7 @@ def test_create_instance(self, wait_for_operation_to_complete, get_conn, mock_ge execute_method = insert_method.return_value.execute execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None - self.cloudsql_hook.create_instance( # pylint: disable=no-value-for-parameter - body={}) + self.cloudsql_hook.create_instance(body={}) # pylint: disable=no-value-for-parameter insert_method.assert_called_once_with(body={}, project='example-project') execute_method.assert_called_once_with(num_retries=5) @@ -168,21 +177,23 @@ def test_create_instance(self, wait_for_operation_to_complete, get_conn, mock_ge ) self.assertEqual(1, mock_get_credentials.call_count) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') - def test_create_instance_with_in_progress_retry(self, wait_for_operation_to_complete, get_conn, - mock_get_credentials): + def test_create_instance_with_in_progress_retry( + self, wait_for_operation_to_complete, get_conn, mock_get_credentials + ): insert_method = get_conn.return_value.instances.return_value.insert execute_method = insert_method.return_value.execute execute_method.side_effect = [ - HttpError(resp=type('', (object,), {"status": 429, })(), content=b'Internal Server Error'), - {"name": "operation_id"} + HttpError(resp=type('', (object,), {"status": 429,})(), 
content=b'Internal Server Error'), + {"name": "operation_id"}, ] wait_for_operation_to_complete.return_value = None - self.cloudsql_hook.create_instance( # pylint: disable=no-value-for-parameter - body={}) + self.cloudsql_hook.create_instance(body={}) # pylint: disable=no-value-for-parameter self.assertEqual(1, mock_get_credentials.call_count) self.assertEqual(2, insert_method.call_count) @@ -191,22 +202,25 @@ def test_create_instance_with_in_progress_retry(self, wait_for_operation_to_comp operation_name='operation_id', project_id='example-project' ) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') - def test_patch_instance_with_in_progress_retry(self, wait_for_operation_to_complete, get_conn, - mock_get_credentials): + def test_patch_instance_with_in_progress_retry( + self, wait_for_operation_to_complete, get_conn, mock_get_credentials + ): patch_method = get_conn.return_value.instances.return_value.patch execute_method = patch_method.return_value.execute execute_method.side_effect = [ - HttpError(resp=type('', (object,), {"status": 429, })(), content=b'Internal Server Error'), - {"name": "operation_id"} + HttpError(resp=type('', (object,), {"status": 429,})(), content=b'Internal Server Error'), + {"name": "operation_id"}, ] wait_for_operation_to_complete.return_value = None self.cloudsql_hook.patch_instance( # pylint: disable=no-value-for-parameter - instance='instance', - body={}) + instance='instance', body={} + ) self.assertEqual(1, mock_get_credentials.call_count) self.assertEqual(2, patch_method.call_count) @@ -215,8 +229,10 @@ def test_patch_instance_with_in_progress_retry(self, wait_for_operation_to_compl operation_name='operation_id', project_id='example-project' ) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') def test_patch_instance(self, wait_for_operation_to_complete, get_conn, mock_get_credentials): @@ -225,8 +241,8 @@ def test_patch_instance(self, wait_for_operation_to_complete, get_conn, mock_get execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook.patch_instance( # pylint: disable=no-value-for-parameter - instance='instance', - body={}) + instance='instance', body={} + ) patch_method.assert_called_once_with(body={}, instance='instance', project='example-project') execute_method.assert_called_once_with(num_retries=5) @@ -235,8 +251,10 @@ def test_patch_instance(self, wait_for_operation_to_complete, get_conn, mock_get ) self.assertEqual(1, mock_get_credentials.call_count) - 
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') def test_delete_instance(self, wait_for_operation_to_complete, get_conn, mock_get_credentials): @@ -244,8 +262,7 @@ def test_delete_instance(self, wait_for_operation_to_complete, get_conn, mock_ge execute_method = delete_method.return_value.execute execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None - self.cloudsql_hook.delete_instance( # pylint: disable=no-value-for-parameter - instance='instance') + self.cloudsql_hook.delete_instance(instance='instance') # pylint: disable=no-value-for-parameter delete_method.assert_called_once_with(instance='instance', project='example-project') execute_method.assert_called_once_with(num_retries=5) @@ -254,21 +271,23 @@ def test_delete_instance(self, wait_for_operation_to_complete, get_conn, mock_ge ) self.assertEqual(1, mock_get_credentials.call_count) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') - def test_delete_instance_with_in_progress_retry(self, wait_for_operation_to_complete, get_conn, - mock_get_credentials): + def test_delete_instance_with_in_progress_retry( + self, wait_for_operation_to_complete, get_conn, mock_get_credentials + ): delete_method = get_conn.return_value.instances.return_value.delete execute_method = delete_method.return_value.execute execute_method.side_effect = [ - HttpError(resp=type('', (object,), {"status": 429, })(), content=b'Internal Server Error'), - {"name": "operation_id"} + HttpError(resp=type('', (object,), {"status": 429,})(), content=b'Internal Server Error'), + {"name": "operation_id"}, ] wait_for_operation_to_complete.return_value = None - self.cloudsql_hook.delete_instance( # pylint: disable=no-value-for-parameter - instance='instance') + self.cloudsql_hook.delete_instance(instance='instance') # pylint: disable=no-value-for-parameter self.assertEqual(1, mock_get_credentials.call_count) self.assertEqual(2, delete_method.call_count) @@ -277,8 +296,10 @@ def test_delete_instance_with_in_progress_retry(self, wait_for_operation_to_comp operation_name='operation_id', project_id='example-project' ) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') 
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') def test_get_database(self, wait_for_operation_to_complete, get_conn, mock_get_credentials): @@ -287,19 +308,21 @@ def test_get_database(self, wait_for_operation_to_complete, get_conn, mock_get_c execute_method.return_value = {"name": "database"} wait_for_operation_to_complete.return_value = None res = self.cloudsql_hook.get_database( # pylint: disable=no-value-for-parameter - database='database', - instance='instance') + database='database', instance='instance' + ) self.assertIsNotNone(res) self.assertEqual('database', res['name']) - get_method.assert_called_once_with(instance='instance', - database='database', - project='example-project') + get_method.assert_called_once_with( + instance='instance', database='database', project='example-project' + ) execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_not_called() self.assertEqual(1, mock_get_credentials.call_count) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') def test_create_database(self, wait_for_operation_to_complete, get_conn, mock_get_credentials): @@ -308,8 +331,8 @@ def test_create_database(self, wait_for_operation_to_complete, get_conn, mock_ge execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook.create_database( # pylint: disable=no-value-for-parameter - instance='instance', - body={}) + instance='instance', body={} + ) insert_method.assert_called_once_with(body={}, instance='instance', project='example-project') execute_method.assert_called_once_with(num_retries=5) @@ -318,22 +341,25 @@ def test_create_database(self, wait_for_operation_to_complete, get_conn, mock_ge ) self.assertEqual(1, mock_get_credentials.call_count) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') - def test_create_database_with_in_progress_retry(self, wait_for_operation_to_complete, get_conn, - mock_get_credentials): + def test_create_database_with_in_progress_retry( + self, wait_for_operation_to_complete, get_conn, mock_get_credentials + ): insert_method = get_conn.return_value.databases.return_value.insert execute_method = insert_method.return_value.execute execute_method.side_effect = [ - HttpError(resp=type('', (object,), {"status": 429, })(), content=b'Internal Server Error'), - {"name": "operation_id"} + HttpError(resp=type('', (object,), {"status": 429,})(), content=b'Internal Server Error'), + {"name": "operation_id"}, ] wait_for_operation_to_complete.return_value = None 
self.cloudsql_hook.create_database( # pylint: disable=no-value-for-parameter - instance='instance', - body={}) + instance='instance', body={} + ) self.assertEqual(1, mock_get_credentials.call_count) self.assertEqual(2, insert_method.call_count) @@ -342,8 +368,10 @@ def test_create_database_with_in_progress_retry(self, wait_for_operation_to_comp operation_name='operation_id', project_id='example-project' ) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') def test_patch_database(self, wait_for_operation_to_complete, get_conn, mock_get_credentials): @@ -352,37 +380,37 @@ def test_patch_database(self, wait_for_operation_to_complete, get_conn, mock_get execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook.patch_database( # pylint: disable=no-value-for-parameter - instance='instance', - database='database', - body={}) - - patch_method.assert_called_once_with(body={}, - database='database', - instance='instance', - project='example-project') + instance='instance', database='database', body={} + ) + + patch_method.assert_called_once_with( + body={}, database='database', instance='instance', project='example-project' + ) execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with( operation_name='operation_id', project_id='example-project' ) self.assertEqual(1, mock_get_credentials.call_count) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') - def test_patch_database_with_in_progress_retry(self, wait_for_operation_to_complete, get_conn, - mock_get_credentials): + def test_patch_database_with_in_progress_retry( + self, wait_for_operation_to_complete, get_conn, mock_get_credentials + ): patch_method = get_conn.return_value.databases.return_value.patch execute_method = patch_method.return_value.execute execute_method.side_effect = [ - HttpError(resp=type('', (object,), {"status": 429, })(), content=b'Internal Server Error'), - {"name": "operation_id"} + HttpError(resp=type('', (object,), {"status": 429,})(), content=b'Internal Server Error'), + {"name": "operation_id"}, ] wait_for_operation_to_complete.return_value = None self.cloudsql_hook.patch_database( # pylint: disable=no-value-for-parameter - instance='instance', - database='database', - body={}) + instance='instance', database='database', body={} + ) self.assertEqual(1, mock_get_credentials.call_count) self.assertEqual(2, patch_method.call_count) @@ -391,8 +419,10 @@ def test_patch_database_with_in_progress_retry(self, wait_for_operation_to_compl operation_name='operation_id', 
project_id='example-project' ) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') def test_delete_database(self, wait_for_operation_to_complete, get_conn, mock_get_credentials): @@ -401,34 +431,37 @@ def test_delete_database(self, wait_for_operation_to_complete, get_conn, mock_ge execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook.delete_database( # pylint: disable=no-value-for-parameter - instance='instance', - database='database') + instance='instance', database='database' + ) - delete_method.assert_called_once_with(database='database', - instance='instance', - project='example-project') + delete_method.assert_called_once_with( + database='database', instance='instance', project='example-project' + ) execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with( operation_name='operation_id', project_id='example-project' ) self.assertEqual(1, mock_get_credentials.call_count) - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), 'example-project')) + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id', + return_value=(mock.MagicMock(), 'example-project'), + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') - def test_delete_database_with_in_progress_retry(self, wait_for_operation_to_complete, get_conn, - mock_get_credentials): + def test_delete_database_with_in_progress_retry( + self, wait_for_operation_to_complete, get_conn, mock_get_credentials + ): delete_method = get_conn.return_value.databases.return_value.delete execute_method = delete_method.return_value.execute execute_method.side_effect = [ - HttpError(resp=type('', (object,), {"status": 429, })(), content=b'Internal Server Error'), - {"name": "operation_id"} + HttpError(resp=type('', (object,), {"status": 429,})(), content=b'Internal Server Error'), + {"name": "operation_id"}, ] wait_for_operation_to_complete.return_value = None self.cloudsql_hook.delete_database( # pylint: disable=no-value-for-parameter - instance='instance', - database='database') + instance='instance', database='database' + ) self.assertEqual(1, mock_get_credentials.call_count) self.assertEqual(2, delete_method.call_count) @@ -440,14 +473,16 @@ def test_delete_database_with_in_progress_retry(self, wait_for_operation_to_comp class TestGcpSqlHookNoDefaultProjectID(unittest.TestCase): def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_no_default_project_id): + with mock.patch( + 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_no_default_project_id, + ): self.cloudsql_hook_no_default_project_id = CloudSQLHook(api_version='v1', gcp_conn_id='test') @mock.patch( 
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') @@ -459,18 +494,18 @@ def test_instance_import_overridden_project_id( execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook_no_default_project_id.import_instance( - project_id='example-project', - instance='instance', - body={}) + project_id='example-project', instance='instance', body={} + ) import_method.assert_called_once_with(body={}, instance='instance', project='example-project') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id' + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') @@ -482,18 +517,18 @@ def test_instance_export_overridden_project_id( execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook_no_default_project_id.export_instance( - project_id='example-project', - instance='instance', - body={}) + project_id='example-project', instance='instance', body={} + ) export_method.assert_called_once_with(body={}, instance='instance', project='example-project') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id' + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') @@ -505,8 +540,8 @@ def test_get_instance_overridden_project_id( execute_method.return_value = {"name": "instance"} wait_for_operation_to_complete.return_value = None res = self.cloudsql_hook_no_default_project_id.get_instance( - project_id='example-project', - instance='instance') + project_id='example-project', instance='instance' + ) self.assertIsNotNone(res) self.assertEqual('instance', res['name']) get_method.assert_called_once_with(instance='instance', project='example-project') @@ -516,7 +551,7 @@ def test_get_instance_overridden_project_id( @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') @@ -527,9 +562,7 @@ def test_create_instance_overridden_project_id( execute_method = 
insert_method.return_value.execute execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None - self.cloudsql_hook_no_default_project_id.create_instance( - project_id='example-project', - body={}) + self.cloudsql_hook_no_default_project_id.create_instance(project_id='example-project', body={}) insert_method.assert_called_once_with(body={}, project='example-project') execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with( @@ -539,7 +572,7 @@ def test_create_instance_overridden_project_id( @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') @@ -551,9 +584,8 @@ def test_patch_instance_overridden_project_id( execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook_no_default_project_id.patch_instance( - project_id='example-project', - instance='instance', - body={}) + project_id='example-project', instance='instance', body={} + ) patch_method.assert_called_once_with(body={}, instance='instance', project='example-project') execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with( @@ -563,7 +595,7 @@ def test_patch_instance_overridden_project_id( @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') @@ -575,8 +607,8 @@ def test_delete_instance_overridden_project_id( execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook_no_default_project_id.delete_instance( - project_id='example-project', - instance='instance') + project_id='example-project', instance='instance' + ) delete_method.assert_called_once_with(instance='instance', project='example-project') execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with( @@ -586,7 +618,7 @@ def test_delete_instance_overridden_project_id( @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') @@ -598,21 +630,20 @@ def test_get_database_overridden_project_id( execute_method.return_value = {"name": "database"} wait_for_operation_to_complete.return_value = None res = self.cloudsql_hook_no_default_project_id.get_database( - project_id='example-project', - database='database', - instance='instance') + project_id='example-project', database='database', instance='instance' + ) self.assertIsNotNone(res) self.assertEqual('database', res['name']) - get_method.assert_called_once_with(instance='instance', - database='database', - project='example-project') + get_method.assert_called_once_with( + instance='instance', database='database', 
project='example-project' + ) execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_not_called() @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') @@ -624,9 +655,8 @@ def test_create_database_overridden_project_id( execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook_no_default_project_id.create_database( - project_id='example-project', - instance='instance', - body={}) + project_id='example-project', instance='instance', body={} + ) insert_method.assert_called_once_with(body={}, instance='instance', project='example-project') execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with( @@ -636,7 +666,7 @@ def test_create_database_overridden_project_id( @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') @@ -648,14 +678,11 @@ def test_patch_database_overridden_project_id( execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook_no_default_project_id.patch_database( - project_id='example-project', - instance='instance', - database='database', - body={}) - patch_method.assert_called_once_with(body={}, - database='database', - instance='instance', - project='example-project') + project_id='example-project', instance='instance', database='database', body={} + ) + patch_method.assert_called_once_with( + body={}, database='database', instance='instance', project='example-project' + ) execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with( operation_name='operation_id', project_id='example-project' @@ -664,7 +691,7 @@ def test_patch_database_overridden_project_id( @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete') @@ -676,12 +703,11 @@ def test_delete_database_overridden_project_id( execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None self.cloudsql_hook_no_default_project_id.delete_database( - project_id='example-project', - instance='instance', - database='database') - delete_method.assert_called_once_with(database='database', - instance='instance', - project='example-project') + project_id='example-project', instance='instance', database='database' + ) + delete_method.assert_called_once_with( + database='database', instance='instance', project='example-project' + ) execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with( operation_name='operation_id', project_id='example-project' @@ -689,47 +715,44 @@ def 
test_delete_database_overridden_project_id( class TestCloudSqlDatabaseHook(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection') def test_cloudsql_database_hook_validate_ssl_certs_no_ssl(self, get_connection): connection = Connection() - connection.set_extra(json.dumps({ - "location": "test", - "instance": "instance", - "database_type": "postgres" - })) + connection.set_extra( + json.dumps({"location": "test", "instance": "instance", "database_type": "postgres"}) + ) get_connection.return_value = connection - hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection' + ) hook.validate_ssl_certs() - @parameterized.expand([ - [{}], - [{"sslcert": "cert_file.pem"}], - [{"sslkey": "cert_key.pem"}], - [{"sslrootcert": "root_cert_file.pem"}], - [{"sslcert": "cert_file.pem", "sslkey": "cert_key.pem"}], - [{"sslrootcert": "root_cert_file.pem", "sslkey": "cert_key.pem"}], - [{"sslrootcert": "root_cert_file.pem", "sslcert": "cert_file.pem"}], - ]) + @parameterized.expand( + [ + [{}], + [{"sslcert": "cert_file.pem"}], + [{"sslkey": "cert_key.pem"}], + [{"sslrootcert": "root_cert_file.pem"}], + [{"sslcert": "cert_file.pem", "sslkey": "cert_key.pem"}], + [{"sslrootcert": "root_cert_file.pem", "sslkey": "cert_key.pem"}], + [{"sslrootcert": "root_cert_file.pem", "sslcert": "cert_file.pem"}], + ] + ) @mock.patch('os.path.isfile') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection') def test_cloudsql_database_hook_validate_ssl_certs_missing_cert_params( - self, cert_dict, get_connection, mock_is_file): + self, cert_dict, get_connection, mock_is_file + ): mock_is_file.side_effects = True connection = Connection() - extras = { - "location": "test", - "instance": "instance", - "database_type": "postgres", - "use_ssl": "True" - } + extras = {"location": "test", "instance": "instance", "database_type": "postgres", "use_ssl": "True"} extras.update(cert_dict) connection.set_extra(json.dumps(extras)) get_connection.return_value = connection - hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection' + ) with self.assertRaises(AirflowException) as cm: hook.validate_ssl_certs() err = cm.exception @@ -740,38 +763,49 @@ def test_cloudsql_database_hook_validate_ssl_certs_missing_cert_params( def test_cloudsql_database_hook_validate_ssl_certs_with_ssl(self, get_connection, mock_is_file): connection = Connection() mock_is_file.return_value = True - connection.set_extra(json.dumps({ - "location": "test", - "instance": "instance", - "database_type": "postgres", - "use_ssl": "True", - "sslcert": "cert_file.pem", - "sslrootcert": "rootcert_file.pem", - "sslkey": "key_file.pem", - })) + connection.set_extra( + json.dumps( + { + "location": "test", + "instance": "instance", + "database_type": "postgres", + "use_ssl": "True", + "sslcert": "cert_file.pem", + "sslrootcert": "rootcert_file.pem", + "sslkey": "key_file.pem", + } + ) + ) get_connection.return_value = connection - hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id='cloudsql_connection', 
default_gcp_project_id='google_connection' + ) hook.validate_ssl_certs() @mock.patch('os.path.isfile') @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection') def test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable( - self, get_connection, mock_is_file): + self, get_connection, mock_is_file + ): connection = Connection() mock_is_file.return_value = False - connection.set_extra(json.dumps({ - "location": "test", - "instance": "instance", - "database_type": "postgres", - "use_ssl": "True", - "sslcert": "cert_file.pem", - "sslrootcert": "rootcert_file.pem", - "sslkey": "key_file.pem", - })) + connection.set_extra( + json.dumps( + { + "location": "test", + "instance": "instance", + "database_type": "postgres", + "use_ssl": "True", + "sslcert": "cert_file.pem", + "sslrootcert": "rootcert_file.pem", + "sslkey": "key_file.pem", + } + ) + ) get_connection.return_value = connection - hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection' + ) with self.assertRaises(AirflowException) as cm: hook.validate_ssl_certs() err = cm.exception @@ -780,16 +814,21 @@ def test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable( @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection') def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_connection): connection = Connection() - connection.set_extra(json.dumps({ - "location": "test", - "instance": "very_long_instance_name_that_will_be_too_long_to_build_socket_length", - "database_type": "postgres", - "use_proxy": "True", - "use_tcp": "False" - })) + connection.set_extra( + json.dumps( + { + "location": "test", + "instance": "very_long_instance_name_that_will_be_too_long_to_build_socket_length", + "database_type": "postgres", + "use_proxy": "True", + "use_tcp": "False", + } + ) + ) get_connection.return_value = connection - hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection' + ) with self.assertRaises(AirflowException) as cm: hook.validate_socket_path_length() err = cm.exception @@ -798,27 +837,34 @@ def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_c @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection') def test_cloudsql_database_hook_validate_socket_path_length_not_too_long(self, get_connection): connection = Connection() - connection.set_extra(json.dumps({ - "location": "test", - "instance": "short_instance_name", - "database_type": "postgres", - "use_proxy": "True", - "use_tcp": "False" - })) + connection.set_extra( + json.dumps( + { + "location": "test", + "instance": "short_instance_name", + "database_type": "postgres", + "use_proxy": "True", + "use_tcp": "False", + } + ) + ) get_connection.return_value = connection - hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection' + ) hook.validate_socket_path_length() - @parameterized.expand([ - ["http://:password@host:80/database"], - 
["http://user:@host:80/database"], - ["http://user:password@/database"], - ["http://user:password@host:80/"], - ["http://user:password@/"], - ["http://host:80/database"], - ["http://host:80/"], - ]) + @parameterized.expand( + [ + ["http://:password@host:80/database"], + ["http://user:@host:80/database"], + ["http://user:password@/database"], + ["http://user:password@host:80/"], + ["http://user:password@/"], + ["http://host:80/database"], + ["http://host:80/"], + ] + ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection') def test_cloudsql_database_hook_create_connection_missing_fields(self, uri, get_connection): connection = Connection(uri=uri) @@ -827,12 +873,13 @@ def test_cloudsql_database_hook_create_connection_missing_fields(self, uri, get_ "instance": "instance", "database_type": "postgres", 'use_proxy': "True", - 'use_tcp': "False" + 'use_tcp': "False", } connection.set_extra(json.dumps(params)) get_connection.return_value = connection - hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection' + ) with self.assertRaises(AirflowException) as cm: hook.create_connection() err = cm.exception @@ -841,14 +888,13 @@ def test_cloudsql_database_hook_create_connection_missing_fields(self, uri, get_ @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection') def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, get_connection): connection = Connection(uri="http://user:password@host:80/database") - connection.set_extra(json.dumps({ - "location": "test", - "instance": "instance", - "database_type": "postgres", - })) + connection.set_extra( + json.dumps({"location": "test", "instance": "instance", "database_type": "postgres",}) + ) get_connection.return_value = connection - hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection' + ) with self.assertRaises(ValueError) as cm: hook.get_sqlproxy_runner() err = cm.exception @@ -857,16 +903,21 @@ def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, get_connectio @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection') def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection): connection = Connection(uri="http://user:password@host:80/database") - connection.set_extra(json.dumps({ - "location": "test", - "instance": "instance", - "database_type": "postgres", - 'use_proxy': "True", - 'use_tcp': "False" - })) + connection.set_extra( + json.dumps( + { + "location": "test", + "instance": "instance", + "database_type": "postgres", + 'use_proxy': "True", + 'use_tcp': "False", + } + ) + ) get_connection.return_value = connection - hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection' + ) hook.create_connection() proxy_runner = hook.get_sqlproxy_runner() self.assertIsNotNone(proxy_runner) @@ -874,21 +925,19 @@ def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection): 
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection') def test_cloudsql_database_hook_get_database_hook(self, get_connection): connection = Connection(uri="http://user:password@host:80/database") - connection.set_extra(json.dumps({ - "location": "test", - "instance": "instance", - "database_type": "postgres", - })) + connection.set_extra( + json.dumps({"location": "test", "instance": "instance", "database_type": "postgres",}) + ) get_connection.return_value = connection - hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection', - default_gcp_project_id='google_connection') + hook = CloudSQLDatabaseHook( + gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection' + ) connection = hook.create_connection() db_hook = hook.get_database_hook(connection=connection) self.assertIsNotNone(db_hook) class TestCloudSqlDatabaseQueryHook(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection') def setUp(self, m): super().setUp() @@ -901,13 +950,10 @@ def setUp(self, m): host='host', schema='schema', extra='{"database_type":"postgres", "location":"my_location", ' - '"instance":"my_instance", "use_proxy": true, ' - '"project_id":"my_project"}' - ) - self.connection = Connection( - conn_id='my_gcp_connection', - conn_type='google_cloud_platform', + '"instance":"my_instance", "use_proxy": true, ' + '"project_id":"my_project"}', ) + self.connection = Connection(conn_id='my_gcp_connection', conn_type='google_cloud_platform',) scopes = [ "https://www.googleapis.com/auth/pubsub", "https://www.googleapis.com/auth/datastore", @@ -919,16 +965,14 @@ def setUp(self, m): conn_extra = { "extra__google_cloud_platform__scope": ",".join(scopes), "extra__google_cloud_platform__project": "your-gcp-project", - "extra__google_cloud_platform__key_path": - '/var/local/google_cloud_default.json' + "extra__google_cloud_platform__key_path": '/var/local/google_cloud_default.json', } conn_extra_json = json.dumps(conn_extra) self.connection.set_extra(conn_extra_json) m.side_effect = [self.sql_connection, self.connection] self.db_hook = CloudSQLDatabaseHook( - gcp_cloudsql_conn_id='my_gcp_sql_connection', - gcp_conn_id='my_gcp_connection' + gcp_cloudsql_conn_id='my_gcp_sql_connection', gcp_conn_id='my_gcp_connection' ) def test_get_sqlproxy_runner(self): @@ -945,11 +989,13 @@ def test_get_sqlproxy_runner(self): @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_not_too_long_unix_socket_path(self, get_connection): - uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ - "project_id=example-project&location=europe-west1&" \ - "instance=" \ - "test_db_with_longname_but_with_limit_of_UNIX_socket&" \ - "use_proxy=True&sql_proxy_use_tcp=False" + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" + "project_id=example-project&location=europe-west1&" + "instance=" + "test_db_with_longname_but_with_limit_of_UNIX_socket&" + "use_proxy=True&sql_proxy_use_tcp=False" + ) get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = hook.create_connection() @@ -968,17 +1014,21 @@ def _verify_postgres_connection(self, get_connection, uri): @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_correct_parameters_postgres(self, get_connection): - uri = 
"gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ - "project_id=example-project&location=europe-west1&instance=testdb&" \ - "use_proxy=False&use_ssl=False" + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" + "project_id=example-project&location=europe-west1&instance=testdb&" + "use_proxy=False&use_ssl=False" + ) self._verify_postgres_connection(get_connection, uri) @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_correct_parameters_postgres_ssl(self, get_connection): - uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ - "project_id=example-project&location=europe-west1&instance=testdb&" \ - "use_proxy=False&use_ssl=True&sslcert=/bin/bash&" \ - "sslkey=/bin/bash&sslrootcert=/bin/bash" + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" + "project_id=example-project&location=europe-west1&instance=testdb&" + "use_proxy=False&use_ssl=True&sslcert=/bin/bash&" + "sslkey=/bin/bash&sslrootcert=/bin/bash" + ) connection = self._verify_postgres_connection(get_connection, uri) self.assertEqual('/bin/bash', connection.extra_dejson['sslkey']) self.assertEqual('/bin/bash', connection.extra_dejson['sslcert']) @@ -986,9 +1036,11 @@ def test_hook_with_correct_parameters_postgres_ssl(self, get_connection): @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection): - uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ - "project_id=example-project&location=europe-west1&instance=testdb&" \ - "use_proxy=True&sql_proxy_use_tcp=False" + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" + "project_id=example-project&location=europe-west1&instance=testdb&" + "use_proxy=True&sql_proxy_use_tcp=False" + ) get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = hook.create_connection() @@ -1000,9 +1052,11 @@ def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_correct_parameters_project_id_missing(self, get_connection): - uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \ - "location=europe-west1&instance=testdb&" \ - "use_proxy=False&use_ssl=False" + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" + "location=europe-west1&instance=testdb&" + "use_proxy=False&use_ssl=False" + ) self.verify_mysql_connection(get_connection, uri) def verify_mysql_connection(self, get_connection, uri): @@ -1017,9 +1071,11 @@ def verify_mysql_connection(self, get_connection, uri): @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connection): - uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ - "project_id=example-project&location=europe-west1&instance=testdb&" \ - "use_proxy=True&sql_proxy_use_tcp=True" + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" + "project_id=example-project&location=europe-west1&instance=testdb&" + "use_proxy=True&sql_proxy_use_tcp=True" + ) get_connection.side_effect = [Connection(uri=uri)] hook = 
CloudSQLDatabaseHook() connection = hook.create_connection() @@ -1030,17 +1086,21 @@ def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connection): @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_correct_parameters_mysql(self, get_connection): - uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \ - "project_id=example-project&location=europe-west1&instance=testdb&" \ - "use_proxy=False&use_ssl=False" + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" + "project_id=example-project&location=europe-west1&instance=testdb&" + "use_proxy=False&use_ssl=False" + ) self.verify_mysql_connection(get_connection, uri) @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_correct_parameters_mysql_ssl(self, get_connection): - uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \ - "project_id=example-project&location=europe-west1&instance=testdb&" \ - "use_proxy=False&use_ssl=True&sslcert=/bin/bash&" \ - "sslkey=/bin/bash&sslrootcert=/bin/bash" + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" + "project_id=example-project&location=europe-west1&instance=testdb&" + "use_proxy=False&use_ssl=True&sslcert=/bin/bash&" + "sslkey=/bin/bash&sslrootcert=/bin/bash" + ) connection = self.verify_mysql_connection(get_connection, uri) self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['cert']) self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['key']) @@ -1048,25 +1108,28 @@ def test_hook_with_correct_parameters_mysql_ssl(self, get_connection): @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection): - uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \ - "project_id=example-project&location=europe-west1&instance=testdb&" \ - "use_proxy=True&sql_proxy_use_tcp=False" + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" + "project_id=example-project&location=europe-west1&instance=testdb&" + "use_proxy=True&sql_proxy_use_tcp=False" + ) get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = hook.create_connection() self.assertEqual('mysql', connection.conn_type) self.assertEqual('localhost', connection.host) self.assertIn('/tmp', connection.extra_dejson['unix_socket']) - self.assertIn('example-project:europe-west1:testdb', - connection.extra_dejson['unix_socket']) + self.assertIn('example-project:europe-west1:testdb', connection.extra_dejson['unix_socket']) self.assertIsNone(connection.port) self.assertEqual('testdb', connection.schema) @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_correct_parameters_mysql_tcp(self, get_connection): - uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \ - "project_id=example-project&location=europe-west1&instance=testdb&" \ - "use_proxy=True&sql_proxy_use_tcp=True" + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" + "project_id=example-project&location=europe-west1&instance=testdb&" + "use_proxy=True&sql_proxy_use_tcp=True" + ) get_connection.side_effect = [Connection(uri=uri)] hook = CloudSQLDatabaseHook() connection = 
hook.create_connection() diff --git a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py index 14a3286efb587..f89c7966ea4e3 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py +++ b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py @@ -28,12 +28,25 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( - DESCRIPTION, FILTER_JOB_NAMES, FILTER_PROJECT_ID, METADATA, OPERATIONS, PROJECT_ID, STATUS, - TIME_TO_SLEEP_IN_SECONDS, TRANSFER_JOB, TRANSFER_JOB_FIELD_MASK, TRANSFER_JOBS, - CloudDataTransferServiceHook, GcpTransferJobsStatus, GcpTransferOperationStatus, gen_job_name, + DESCRIPTION, + FILTER_JOB_NAMES, + FILTER_PROJECT_ID, + METADATA, + OPERATIONS, + PROJECT_ID, + STATUS, + TIME_TO_SLEEP_IN_SECONDS, + TRANSFER_JOB, + TRANSFER_JOB_FIELD_MASK, + TRANSFER_JOBS, + CloudDataTransferServiceHook, + GcpTransferJobsStatus, + GcpTransferOperationStatus, + gen_job_name, ) from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, ) @@ -87,7 +100,6 @@ class GCPRequestMock: class TestGCPTransferServiceHookWithPassedName(unittest.TestCase): - def setUp(self): with mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', @@ -106,22 +118,22 @@ def setUp(self): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' '.CloudDataTransferServiceHook.get_conn' ) # pylint: disable=unused-argument - def test_pass_name_on_create_job(self, - get_conn: MagicMock, - project_id: PropertyMock, - get_transfer_job: MagicMock, - enable_transfer_job: MagicMock - ): + def test_pass_name_on_create_job( + self, + get_conn: MagicMock, + project_id: PropertyMock, + get_transfer_job: MagicMock, + enable_transfer_job: MagicMock, + ): body = _with_name(TEST_BODY, TEST_CLEAR_JOB_NAME) - get_conn.side_effect \ - = HttpError(GCPRequestMock(), TEST_HTTP_ERR_CONTENT) + get_conn.side_effect = HttpError(GCPRequestMock(), TEST_HTTP_ERR_CONTENT) with self.assertRaises(HttpError): @@ -138,17 +150,12 @@ def test_pass_name_on_create_job(self, class TestJobNames(unittest.TestCase): - def setUp(self) -> None: self.re_suffix = re.compile("^[0-9]{10}$") def test_new_suffix(self): - for job_name in ["jobNames/new_job", - "jobNames/new_job_h", - "jobNames/newJob"]: - self.assertIsNotNone( - self.re_suffix.match(gen_job_name(job_name).split("_")[-1]) - ) + for job_name in ["jobNames/new_job", "jobNames/new_job_h", "jobNames/newJob"]: + self.assertIsNotNone(self.re_suffix.match(gen_job_name(job_name).split("_")[-1])) class TestGCPTransferServiceHookWithPassedProjectId(unittest.TestCase): @@ -175,7 +182,7 @@ def test_gct_client_creation(self, mock_build, mock_authorize): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -207,7 +214,7 @@ def test_get_transfer_job(self, get_conn): @mock.patch( 
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -235,7 +242,7 @@ def test_list_transfer_job(self, get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -302,7 +309,7 @@ def test_get_transfer_operation(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -354,11 +361,13 @@ def test_resume_transfer_operation(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.time.sleep') - @mock.patch('airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.' - 'CloudDataTransferServiceHook.list_transfer_operations') + @mock.patch( + 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.' + 'CloudDataTransferServiceHook.list_transfer_operations' + ) def test_wait_for_transfer_job(self, mock_list, mock_sleep, mock_project_id): mock_list.side_effect = [ [{METADATA: {STATUS: GcpTransferOperationStatus.IN_PROGRESS}}], @@ -370,7 +379,7 @@ def test_wait_for_transfer_job(self, mock_list, mock_sleep, mock_project_id): calls = [ mock.call(request_filter={FILTER_PROJECT_ID: TEST_PROJECT_ID, FILTER_JOB_NAMES: [job_name]}), - mock.call(request_filter={FILTER_PROJECT_ID: TEST_PROJECT_ID, FILTER_JOB_NAMES: [job_name]}) + mock.call(request_filter={FILTER_PROJECT_ID: TEST_PROJECT_ID, FILTER_JOB_NAMES: [job_name]}), ] mock_list.assert_has_calls(calls, any_order=True) @@ -379,16 +388,14 @@ def test_wait_for_transfer_job(self, mock_list, mock_sleep, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.time.sleep') @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' '.CloudDataTransferServiceHook.get_conn' ) - def test_wait_for_transfer_job_failed( - self, mock_get_conn, mock_sleep, mock_project_id - ): + def test_wait_for_transfer_job_failed(self, mock_get_conn, mock_sleep, mock_project_id): list_method = mock_get_conn.return_value.transferOperations.return_value.list list_execute_method = list_method.return_value.execute list_execute_method.return_value = { @@ -406,7 +413,7 @@ def test_wait_for_transfer_job_failed( @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch('airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.time.sleep') @mock.patch( @@ -493,8 +500,9 @@ def test_operations_contain_expected_statuses_red_path(self, statuses, expected_ def test_operations_contain_expected_statuses_green_path(self, statuses, expected_statuses): operations = [{NAME: TEST_TRANSFER_OPERATION_NAME, METADATA: {STATUS: status}} for 
status in statuses] - result = \ - CloudDataTransferServiceHook.operations_contain_expected_statuses(operations, expected_statuses) + result = CloudDataTransferServiceHook.operations_contain_expected_statuses( + operations, expected_statuses + ) self.assertTrue(result) @@ -523,7 +531,7 @@ def test_gct_client_creation(self, mock_build, mock_authorize): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -541,7 +549,7 @@ def test_create_transfer_job(self, get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -552,7 +560,8 @@ def test_get_transfer_job(self, get_conn, mock_project_id): execute_method = get_method.return_value.execute execute_method.return_value = TEST_TRANSFER_JOB res = self.gct_hook.get_transfer_job( # pylint: disable=no-value-for-parameter - job_name=TEST_TRANSFER_JOB_NAME) + job_name=TEST_TRANSFER_JOB_NAME + ) self.assertIsNotNone(res) self.assertEqual(TEST_TRANSFER_JOB_NAME, res[NAME]) get_method.assert_called_once_with(jobName=TEST_TRANSFER_JOB_NAME, projectId='example-project') @@ -561,7 +570,7 @@ def test_get_transfer_job(self, get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -592,7 +601,7 @@ def test_list_transfer_job(self, get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -660,7 +669,7 @@ def test_get_transfer_operation(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -725,7 +734,7 @@ def test_gct_client_creation(self, mock_build, mock_authorize): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -748,7 +757,7 @@ def test_create_transfer_job(self, get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -760,7 +769,8 @@ def test_get_transfer_job(self, get_conn, mock_project_id): execute_method.return_value = TEST_TRANSFER_JOB with self.assertRaises(AirflowException) as e: self.gct_hook.get_transfer_job( # pylint: disable=no-value-for-parameter - 
job_name=TEST_TRANSFER_JOB_NAME) + job_name=TEST_TRANSFER_JOB_NAME + ) self.assertEqual( 'The project id must be passed either as keyword project_id ' 'parameter or as project_id extra in GCP connection definition. ' @@ -771,7 +781,7 @@ def test_get_transfer_job(self, get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -787,7 +797,8 @@ def test_list_transfer_job(self, get_conn, mock_project_id): with self.assertRaises(AirflowException) as e: self.gct_hook.list_transfer_job( - request_filter=_without_key(TEST_TRANSFER_JOB_FILTER, FILTER_PROJECT_ID)) + request_filter=_without_key(TEST_TRANSFER_JOB_FILTER, FILTER_PROJECT_ID) + ) self.assertEqual( 'The project id must be passed either as `project_id` key in `filter` parameter or as ' @@ -798,7 +809,7 @@ def test_list_transfer_job(self, get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -820,7 +831,7 @@ def test_list_transfer_operation_multiple_page(self, get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -844,7 +855,7 @@ def test_update_transfer_job(self, get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' @@ -853,7 +864,8 @@ def test_update_transfer_job(self, get_conn, mock_project_id): def test_delete_transfer_job(self, get_conn, mock_project_id): # pylint: disable=unused-argument with self.assertRaises(AirflowException) as e: self.gct_hook.delete_transfer_job( # pylint: disable=no-value-for-parameter - job_name=TEST_TRANSFER_JOB_NAME) + job_name=TEST_TRANSFER_JOB_NAME + ) self.assertEqual( 'The project id must be passed either as keyword project_id parameter or as project_id extra in ' @@ -864,7 +876,7 @@ def test_delete_transfer_job(self, get_conn, mock_project_id): # pylint: disabl @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch( 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service' diff --git a/tests/providers/google/cloud/hooks/test_compute.py b/tests/providers/google/cloud/hooks/test_compute.py index 7d6ae29dcb4a5..8d9f2e13605f1 100644 --- a/tests/providers/google/cloud/hooks/test_compute.py +++ b/tests/providers/google/cloud/hooks/test_compute.py @@ -26,7 +26,8 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook, GceOperationStatus from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, ) @@ -38,10 +39,11 @@ class 
TestGcpComputeHookNoDefaultProjectId(unittest.TestCase): - def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_no_default_project_id): + with mock.patch( + 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_no_default_project_id, + ): self.gce_hook_no_project_id = ComputeEngineHook(gcp_conn_id='test') @mock.patch("airflow.providers.google.cloud.hooks.compute.ComputeEngineHook._authorize") @@ -63,15 +65,14 @@ def test_start_instance_overridden_project_id(self, wait_for_operation_to_comple execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook_no_project_id.start_instance( - project_id='example-project', - zone=GCE_ZONE, - resource_id=GCE_INSTANCE) + project_id='example-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE + ) self.assertIsNone(res) start_method.assert_called_once_with(instance='instance', project='example-project', zone='zone') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id', zone='zone' + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -83,15 +84,14 @@ def test_stop_instance_overridden_project_id(self, wait_for_operation_to_complet execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook_no_project_id.stop_instance( - project_id='example-project', - zone=GCE_ZONE, - resource_id=GCE_INSTANCE) + project_id='example-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE + ) self.assertIsNone(res) stop_method.assert_called_once_with(instance='instance', project='example-project', zone='zone') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id', zone='zone' + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -103,17 +103,16 @@ def test_set_machine_type_overridden_project_id(self, wait_for_operation_to_comp execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook_no_project_id.set_machine_type( - body={}, - project_id='example-project', - zone=GCE_ZONE, - resource_id=GCE_INSTANCE) + body={}, project_id='example-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE + ) self.assertIsNone(res) - set_machine_type_method.assert_called_once_with(body={}, instance='instance', - project='example-project', zone='zone') + set_machine_type_method.assert_called_once_with( + body={}, instance='instance', project='example-project', zone='zone' + ) execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id', zone='zone' + ) 
@mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -125,8 +124,7 @@ def test_get_instance_template_overridden_project_id(self, wait_for_operation_to execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook_no_project_id.get_instance_template( - resource_id=GCE_INSTANCE_TEMPLATE, - project_id='example-project' + resource_id=GCE_INSTANCE_TEMPLATE, project_id='example-project' ) self.assertIsNotNone(res) get_method.assert_called_once_with(instanceTemplate='instance-template', project='example-project') @@ -143,15 +141,14 @@ def test_insert_instance_template_overridden_project_id(self, wait_for_operation execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook_no_project_id.insert_instance_template( - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, - body={}, - request_id=GCE_REQUEST_ID + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, body={}, request_id=GCE_REQUEST_ID ) self.assertIsNone(res) insert_method.assert_called_once_with(body={}, project='example-project', requestId='request_id') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id' + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -163,14 +160,12 @@ def test_get_instance_group_manager_overridden_project_id(self, wait_for_operati execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook_no_project_id.get_instance_group_manager( - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, - zone=GCE_ZONE, - resource_id=GCE_INSTANCE_GROUP_MANAGER + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER ) self.assertIsNotNone(res) - get_method.assert_called_once_with(instanceGroupManager='instance_group_manager', - project='example-project', - zone='zone') + get_method.assert_called_once_with( + instanceGroupManager='instance_group_manager', project='example-project', zone='zone' + ) execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_not_called() @@ -178,8 +173,9 @@ def test_get_instance_group_manager_overridden_project_id(self, wait_for_operati @mock.patch( 'airflow.providers.google.cloud.hooks.compute.ComputeEngineHook._wait_for_operation_to_complete' ) - def test_patch_instance_group_manager_overridden_project_id(self, - wait_for_operation_to_complete, get_conn): + def test_patch_instance_group_manager_overridden_project_id( + self, wait_for_operation_to_complete, get_conn + ): patch_method = get_conn.return_value.instanceGroupManagers.return_value.patch execute_method = patch_method.return_value.execute execute_method.return_value = {"name": "operation_id"} @@ -189,7 +185,7 @@ def test_patch_instance_group_manager_overridden_project_id(self, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER, body={}, - request_id=GCE_REQUEST_ID + request_id=GCE_REQUEST_ID, ) self.assertIsNone(res) patch_method.assert_called_once_with( @@ -197,24 +193,26 @@ def test_patch_instance_group_manager_overridden_project_id(self, instanceGroupManager='instance_group_manager', project='example-project', requestId='request_id', - zone='zone' + zone='zone', ) 
execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id', - project_id='example-project', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + operation_name='operation_id', project_id='example-project', zone='zone' + ) class TestGcpComputeHookDefaultProjectId(unittest.TestCase): def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_default_project_id): + with mock.patch( + 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_default_project_id, + ): self.gce_hook = ComputeEngineHook(gcp_conn_id='test') @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -226,16 +224,14 @@ def test_start_instance(self, wait_for_operation_to_complete, get_conn, mock_pro execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook.start_instance( - zone=GCE_ZONE, - resource_id=GCE_INSTANCE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + zone=GCE_ZONE, resource_id=GCE_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) self.assertIsNone(res) start_method.assert_called_once_with(instance='instance', project='example-project', zone='zone') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id', zone='zone' + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -246,21 +242,18 @@ def test_start_instance_overridden_project_id(self, wait_for_operation_to_comple execute_method = start_method.return_value.execute execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None - res = self.gce_hook.start_instance( - project_id='new-project', - zone=GCE_ZONE, - resource_id=GCE_INSTANCE) + res = self.gce_hook.start_instance(project_id='new-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE) self.assertIsNone(res) start_method.assert_called_once_with(instance='instance', project='new-project', zone='zone') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='new-project', - operation_name='operation_id', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + project_id='new-project', operation_name='operation_id', zone='zone' + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -272,16 +265,14 @@ def test_stop_instance(self, wait_for_operation_to_complete, get_conn, mock_proj execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook.stop_instance( - zone=GCE_ZONE, - resource_id=GCE_INSTANCE, - 
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + zone=GCE_ZONE, resource_id=GCE_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) self.assertIsNone(res) stop_method.assert_called_once_with(instance='instance', project='example-project', zone='zone') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id', zone='zone' + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -292,21 +283,18 @@ def test_stop_instance_overridden_project_id(self, wait_for_operation_to_complet execute_method = stop_method.return_value.execute execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None - res = self.gce_hook.stop_instance( - project_id='new-project', - zone=GCE_ZONE, - resource_id=GCE_INSTANCE) + res = self.gce_hook.stop_instance(project_id='new-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE) self.assertIsNone(res) stop_method.assert_called_once_with(instance='instance', project='new-project', zone='zone') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='new-project', - operation_name='operation_id', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + project_id='new-project', operation_name='operation_id', zone='zone' + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -317,16 +305,13 @@ def test_set_machine_type_instance(self, wait_for_operation_to_complete, get_con execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook.set_machine_type( - body={}, - zone=GCE_ZONE, - resource_id=GCE_INSTANCE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + body={}, zone=GCE_ZONE, resource_id=GCE_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) self.assertIsNone(res) execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id', zone='zone' + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -337,20 +322,18 @@ def test_set_machine_type_instance_overridden_project_id(self, wait_for_operatio execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook.set_machine_type( - project_id='new-project', - body={}, - zone=GCE_ZONE, - resource_id=GCE_INSTANCE) + project_id='new-project', body={}, zone=GCE_ZONE, resource_id=GCE_INSTANCE + ) self.assertIsNone(res) execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='new-project', - operation_name='operation_id', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + project_id='new-project', operation_name='operation_id', zone='zone' + ) @mock.patch( 
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -362,8 +345,8 @@ def test_get_instance_template(self, wait_for_operation_to_complete, get_conn, m execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook.get_instance_template( - resource_id=GCE_INSTANCE_TEMPLATE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) + resource_id=GCE_INSTANCE_TEMPLATE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST + ) self.assertIsNotNone(res) get_method.assert_called_once_with(instanceTemplate='instance-template', project='example-project') execute_method.assert_called_once_with(num_retries=5) @@ -378,9 +361,7 @@ def test_get_instance_template_overridden_project_id(self, wait_for_operation_to execute_method = get_method.return_value.execute execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None - res = self.gce_hook.get_instance_template( - project_id='new-project', - resource_id=GCE_INSTANCE_TEMPLATE) + res = self.gce_hook.get_instance_template(project_id='new-project', resource_id=GCE_INSTANCE_TEMPLATE) self.assertIsNotNone(res) get_method.assert_called_once_with(instanceTemplate='instance-template', project='new-project') execute_method.assert_called_once_with(num_retries=5) @@ -389,7 +370,7 @@ def test_get_instance_template_overridden_project_id(self, wait_for_operation_to @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -401,15 +382,14 @@ def test_insert_instance_template(self, wait_for_operation_to_complete, get_conn execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook.insert_instance_template( - body={}, - request_id=GCE_REQUEST_ID, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + body={}, request_id=GCE_REQUEST_ID, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) self.assertIsNone(res) insert_method.assert_called_once_with(body={}, project='example-project', requestId='request_id') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='example-project', - operation_name='operation_id') + wait_for_operation_to_complete.assert_called_once_with( + project_id='example-project', operation_name='operation_id' + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -421,20 +401,19 @@ def test_insert_instance_template_overridden_project_id(self, wait_for_operation execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook.insert_instance_template( - project_id='new-project', - body={}, - request_id=GCE_REQUEST_ID + project_id='new-project', body={}, request_id=GCE_REQUEST_ID ) self.assertIsNone(res) insert_method.assert_called_once_with(body={}, project='new-project', requestId='request_id') execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(project_id='new-project', - 
operation_name='operation_id') + wait_for_operation_to_complete.assert_called_once_with( + project_id='new-project', operation_name='operation_id' + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -446,14 +425,12 @@ def test_get_instance_group_manager(self, wait_for_operation_to_complete, get_co execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook.get_instance_group_manager( - zone=GCE_ZONE, - resource_id=GCE_INSTANCE_GROUP_MANAGER, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) self.assertIsNotNone(res) - get_method.assert_called_once_with(instanceGroupManager='instance_group_manager', - project='example-project', - zone='zone') + get_method.assert_called_once_with( + instanceGroupManager='instance_group_manager', project='example-project', zone='zone' + ) execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_not_called() @@ -467,21 +444,19 @@ def test_get_instance_group_manager_overridden_project_id(self, wait_for_operati execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gce_hook.get_instance_group_manager( - project_id='new-project', - zone=GCE_ZONE, - resource_id=GCE_INSTANCE_GROUP_MANAGER + project_id='new-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER ) self.assertIsNotNone(res) - get_method.assert_called_once_with(instanceGroupManager='instance_group_manager', - project='new-project', - zone='zone') + get_method.assert_called_once_with( + instanceGroupManager='instance_group_manager', project='new-project', zone='zone' + ) execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_not_called() @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -505,20 +480,20 @@ def test_patch_instance_group_manager(self, wait_for_operation_to_complete, get_ instanceGroupManager='instance_group_manager', project='example-project', requestId='request_id', - zone='zone' + zone='zone', ) execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id', - project_id='example-project', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + operation_name='operation_id', project_id='example-project', zone='zone' + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( 'airflow.providers.google.cloud.hooks.compute.ComputeEngineHook._wait_for_operation_to_complete' ) - def test_patch_instance_group_manager_overridden_project_id(self, - wait_for_operation_to_complete, - get_conn): + def test_patch_instance_group_manager_overridden_project_id( + self, wait_for_operation_to_complete, get_conn + ): patch_method = get_conn.return_value.instanceGroupManagers.return_value.patch execute_method = patch_method.return_value.execute 
execute_method.return_value = {"name": "operation_id"} @@ -528,7 +503,7 @@ def test_patch_instance_group_manager_overridden_project_id(self, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER, body={}, - request_id=GCE_REQUEST_ID + request_id=GCE_REQUEST_ID, ) self.assertIsNone(res) patch_method.assert_called_once_with( @@ -536,12 +511,12 @@ def test_patch_instance_group_manager_overridden_project_id(self, instanceGroupManager='instance_group_manager', project='new-project', requestId='request_id', - zone='zone' + zone='zone', ) execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id', - project_id='new-project', - zone='zone') + wait_for_operation_to_complete.assert_called_once_with( + operation_name='operation_id', project_id='new-project', zone='zone' + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -556,16 +531,13 @@ def test_wait_for_operation_to_complete_no_zone(self, mock_operation_status, moc # Test success mock_get_conn.return_value = service mock_operation_status.return_value = {'status': GceOperationStatus.DONE, 'error': None} - self.gce_hook._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name, - zone=None - ) + self.gce_hook._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name, zone=None + ) - mock_operation_status.assert_called_once_with(service=service, - operation_name=operation_name, - project_id=project_id, - num_retries=num_retries - ) + mock_operation_status.assert_called_once_with( + service=service, operation_name=operation_name, project_id=project_id, num_retries=num_retries + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch( @@ -578,17 +550,17 @@ def test_wait_for_operation_to_complete_no_zone_error(self, mock_operation_statu # Test error mock_get_conn.return_value = service - mock_operation_status.return_value = {'status': GceOperationStatus.DONE, - 'error': {'errors': "some nasty errors"}, - 'httpErrorStatusCode': 400, - 'httpErrorMessage': 'sample msg' - } + mock_operation_status.return_value = { + 'status': GceOperationStatus.DONE, + 'error': {'errors': "some nasty errors"}, + 'httpErrorStatusCode': 400, + 'httpErrorMessage': 'sample msg', + } with self.assertRaises(AirflowException): - self.gce_hook._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name, - zone=None - ) + self.gce_hook._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name, zone=None + ) @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn') @mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook._check_zone_operation_status') @@ -602,9 +574,8 @@ def test_wait_for_operation_to_complete_with_zone(self, mock_operation_status, m # Test success mock_get_conn.return_value = service mock_operation_status.return_value = {'status': GceOperationStatus.DONE, 'error': None} - self.gce_hook._wait_for_operation_to_complete(project_id=project_id, - operation_name=operation_name, - zone=zone - ) + self.gce_hook._wait_for_operation_to_complete( + project_id=project_id, operation_name=operation_name, zone=zone + ) mock_operation_status.assert_called_once_with(service, operation_name, project_id, zone, num_retries) diff --git a/tests/providers/google/cloud/hooks/test_datacatalog.py 
b/tests/providers/google/cloud/hooks/test_datacatalog.py index 232ebe32a0574..cb1844408bf05 100644 --- a/tests/providers/google/cloud/hooks/test_datacatalog.py +++ b/tests/providers/google/cloud/hooks/test_datacatalog.py @@ -26,7 +26,8 @@ from airflow import AirflowException from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook from tests.providers.google.cloud.utils.base_gcp_mock import ( - mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, + mock_base_gcp_hook_default_project_id, + mock_base_gcp_hook_no_default_project_id, ) TEST_GCP_CONN_ID: str = "test-gcp-conn-id" @@ -89,9 +90,7 @@ def setUp(self,) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_lookup_entry_with_linked_resource(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.lookup_entry( linked_resource=TEST_LINKED_RESOURCE, @@ -110,9 +109,7 @@ def test_lookup_entry_with_linked_resource(self, mock_get_conn, mock_get_creds_a "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_lookup_entry_with_sql_resource(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.lookup_entry( sql_resource=TEST_SQL_RESOURCE, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA @@ -125,9 +122,7 @@ def test_lookup_entry_with_sql_resource(self, mock_get_conn, mock_get_creds_and_ "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_lookup_entry_without_resource(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex( AirflowException, re.escape("At least one of linked_resource, sql_resource should be set.") @@ -138,9 +133,7 @@ def test_lookup_entry_without_resource(self, mock_get_conn, mock_get_creds_and_p "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_search_catalog(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.search_catalog( scope=TEST_SCOPE, @@ -174,9 +167,7 @@ def setUp(self,) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: 
self.hook.create_entry( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -200,9 +191,7 @@ def test_create_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_entry_group( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -225,9 +214,7 @@ def test_create_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_tag( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -251,9 +238,7 @@ def test_create_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag_protobuff(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_tag( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -277,9 +262,7 @@ def test_create_tag_protobuff(self, mock_get_conn, mock_get_creds_and_project_id "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_tag_template( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -302,9 +285,7 @@ def test_create_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_tag_template_field( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -328,9 +309,7 @@ def test_create_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - 
"airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.delete_entry( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -351,9 +330,7 @@ def test_delete_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.delete_entry_group( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -373,9 +350,7 @@ def test_delete_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.delete_tag( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -397,9 +372,7 @@ def test_delete_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.delete_tag_template( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -421,9 +394,7 @@ def test_delete_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.delete_tag_template_field( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -446,9 +417,7 @@ def test_delete_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.get_entry( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -469,9 +438,7 @@ def 
test_get_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.get_entry_group( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -493,9 +460,7 @@ def test_get_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.get_tag_template( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -515,9 +480,7 @@ def test_get_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_list_tags(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.list_tags( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -540,9 +503,7 @@ def test_list_tags(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_tag_for_template_name(self, mock_get_conn, mock_get_creds_and_project_id) -> None: tag_1 = mock.MagicMock(template=TEST_TAG_TEMPLATE_PATH.format("invalid-project")) tag_2 = mock.MagicMock(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)) @@ -570,9 +531,7 @@ def test_get_tag_for_template_name(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_rename_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.rename_tag_template_field( # pylint: disable=no-value-for-parameter location=TEST_LOCATION, @@ -595,9 +554,7 @@ def test_rename_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - 
"airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.update_entry( # pylint: disable=no-value-for-parameter entry=TEST_ENTRY, @@ -621,9 +578,7 @@ def test_update_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.update_tag( # pylint: disable=no-value-for-parameter tag=deepcopy(TEST_TAG), @@ -648,9 +603,7 @@ def test_update_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.update_tag_template( # pylint: disable=no-value-for-parameter tag_template=TEST_TAG_TEMPLATE, @@ -673,9 +626,7 @@ def test_update_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, TEST_PROJECT_ID_1), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.update_tag_template_field( # pylint: disable=no-value-for-parameter tag_template_field=TEST_TAG_TEMPLATE_FIELD, @@ -709,9 +660,7 @@ def setUp(self,) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_entry( location=TEST_LOCATION, @@ -736,9 +685,7 @@ def test_create_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_entry_group( location=TEST_LOCATION, @@ -762,9 +709,7 @@ def test_create_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) 
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_tag( location=TEST_LOCATION, @@ -789,9 +734,7 @@ def test_create_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag_protobuff(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_tag( location=TEST_LOCATION, @@ -816,9 +759,7 @@ def test_create_tag_protobuff(self, mock_get_conn, mock_get_creds_and_project_id "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_tag_template( location=TEST_LOCATION, @@ -842,9 +783,7 @@ def test_create_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.create_tag_template_field( location=TEST_LOCATION, @@ -869,9 +808,7 @@ def test_create_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.delete_entry( location=TEST_LOCATION, @@ -893,9 +830,7 @@ def test_delete_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.delete_entry_group( location=TEST_LOCATION, @@ -916,9 +851,7 @@ def test_delete_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) 
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.delete_tag( location=TEST_LOCATION, @@ -941,9 +874,7 @@ def test_delete_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.delete_tag_template( location=TEST_LOCATION, @@ -966,9 +897,7 @@ def test_delete_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.delete_tag_template_field( location=TEST_LOCATION, @@ -992,9 +921,7 @@ def test_delete_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.get_entry( location=TEST_LOCATION, @@ -1016,9 +943,7 @@ def test_get_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.get_entry_group( location=TEST_LOCATION, @@ -1041,9 +966,7 @@ def test_get_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.get_tag_template( location=TEST_LOCATION, @@ -1064,9 +987,7 @@ def test_get_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> 
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_list_tags(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.list_tags( location=TEST_LOCATION, @@ -1090,9 +1011,7 @@ def test_list_tags(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_tag_for_template_name(self, mock_get_conn, mock_get_creds_and_project_id) -> None: tag_1 = mock.MagicMock(template=TEST_TAG_TEMPLATE_PATH.format("invalid-project")) tag_2 = mock.MagicMock(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)) @@ -1121,9 +1040,7 @@ def test_get_tag_for_template_name(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_rename_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.rename_tag_template_field( location=TEST_LOCATION, @@ -1147,9 +1064,7 @@ def test_rename_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.update_entry( entry=TEST_ENTRY, @@ -1174,9 +1089,7 @@ def test_update_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.update_tag( tag=deepcopy(TEST_TAG), @@ -1202,9 +1115,7 @@ def test_update_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.update_tag_template( tag_template=TEST_TAG_TEMPLATE, @@ -1228,9 +1139,7 @@ def test_update_tag_template(self, 
mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: self.hook.update_tag_template_field( tag_template_field=TEST_TAG_TEMPLATE_FIELD, @@ -1271,9 +1180,7 @@ def setUp(self,) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.create_entry( # pylint: disable=no-value-for-parameter @@ -1290,9 +1197,7 @@ def test_create_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.create_entry_group( # pylint: disable=no-value-for-parameter @@ -1308,9 +1213,7 @@ def test_create_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): @@ -1329,9 +1232,7 @@ def test_create_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag_protobuff(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): @@ -1350,9 +1251,7 @@ def test_create_tag_protobuff(self, mock_get_conn, mock_get_creds_and_project_id "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with 
self.assertRaisesRegex(AirflowException, TEST_MESSAGE): @@ -1369,9 +1268,7 @@ def test_create_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_create_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): @@ -1389,9 +1286,7 @@ def test_create_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): @@ -1408,9 +1303,7 @@ def test_delete_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.delete_entry_group( # pylint: disable=no-value-for-parameter @@ -1425,9 +1318,7 @@ def test_delete_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.delete_tag( # pylint: disable=no-value-for-parameter @@ -1444,9 +1335,7 @@ def test_delete_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.delete_tag_template( # pylint: disable=no-value-for-parameter @@ -1462,9 +1351,7 @@ def test_delete_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - 
"airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_delete_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.delete_tag_template_field( # pylint: disable=no-value-for-parameter @@ -1481,9 +1368,7 @@ def test_delete_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.get_entry( # pylint: disable=no-value-for-parameter @@ -1499,9 +1384,7 @@ def test_get_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.get_entry_group( # pylint: disable=no-value-for-parameter @@ -1517,9 +1400,7 @@ def test_get_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.get_tag_template( # pylint: disable=no-value-for-parameter @@ -1534,9 +1415,7 @@ def test_get_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_list_tags(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.list_tags( # pylint: disable=no-value-for-parameter @@ -1553,9 +1432,7 @@ def test_list_tags(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_get_tag_for_template_name(self, mock_get_conn, mock_get_creds_and_project_id) -> 
None: tag_1 = mock.MagicMock(template=TEST_TAG_TEMPLATE_PATH.format("invalid-project")) tag_2 = mock.MagicMock(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)) @@ -1576,9 +1453,7 @@ def test_get_tag_for_template_name(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_rename_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.rename_tag_template_field( # pylint: disable=no-value-for-parameter @@ -1595,9 +1470,7 @@ def test_rename_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.update_entry( # pylint: disable=no-value-for-parameter @@ -1615,9 +1488,7 @@ def test_update_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.update_tag( # pylint: disable=no-value-for-parameter @@ -1636,9 +1507,7 @@ def test_update_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.update_tag_template( # pylint: disable=no-value-for-parameter @@ -1655,9 +1524,7 @@ def test_update_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id", return_value=(TEST_CREDENTIALS, None), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn") def test_update_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None: with self.assertRaisesRegex(AirflowException, TEST_MESSAGE): self.hook.update_tag_template_field( # pylint: disable=no-value-for-parameter diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py 
b/tests/providers/google/cloud/hooks/test_dataflow.py index 50b323ae48acb..92ebe26cee959 100644 --- a/tests/providers/google/cloud/hooks/test_dataflow.py +++ b/tests/providers/google/cloud/hooks/test_dataflow.py @@ -28,8 +28,13 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataflow import ( - DEFAULT_DATAFLOW_LOCATION, DataflowHook, DataflowJobStatus, DataflowJobType, _DataflowJobsController, - _DataflowRunner, _fallback_to_project_id_from_variables, + DEFAULT_DATAFLOW_LOCATION, + DataflowHook, + DataflowJobStatus, + DataflowJobType, + _DataflowJobsController, + _DataflowRunner, + _fallback_to_project_id_from_variables, ) TASK_ID = 'test-dataflow-operator' @@ -39,29 +44,21 @@ TEST_TEMPLATE = 'gs://dataflow-templates/wordcount/template_file' PARAMETERS = { 'inputFile': 'gs://dataflow-samples/shakespeare/kinglear.txt', - 'output': 'gs://test/output/my_output' + 'output': 'gs://test/output/my_output', } PY_FILE = 'apache_beam.examples.wordcount' JAR_FILE = 'unitest.jar' JOB_CLASS = 'com.example.UnitTest' PY_OPTIONS = ['-m'] -DATAFLOW_VARIABLES_PY = { - 'project': 'test', - 'staging_location': 'gs://test/staging', - 'labels': {'foo': 'bar'} -} +DATAFLOW_VARIABLES_PY = {'project': 'test', 'staging_location': 'gs://test/staging', 'labels': {'foo': 'bar'}} DATAFLOW_VARIABLES_JAVA = { 'project': 'test', 'stagingLocation': 'gs://test/staging', - 'labels': {'foo': 'bar'} + 'labels': {'foo': 'bar'}, } RUNTIME_ENV = { 'additionalExperiments': ['exp_flag1', 'exp_flag2'], - 'additionalUserLabels': { - 'name': 'wrench', - 'mass': '1.3kg', - 'count': '3' - }, + 'additionalUserLabels': {'name': 'wrench', 'mass': '1.3kg', 'count': '3'}, 'bypassTempDirValidation': {}, 'ipConfiguration': 'WORKER_IP_PRIVATE', 'kmsKeyName': ( @@ -87,7 +84,6 @@ class TestFallbackToVariables(unittest.TestCase): - def test_support_project_id_parameter(self): mock_instance = mock.MagicMock() @@ -123,7 +119,7 @@ def test_fn(self, *args, **kwargs): with self.assertRaisesRegex( AirflowException, "The mutually exclusive parameter `project_id` and `project` key in `variables` parameter are " - "both present\\. Please remove one\\." + "both present\\. 
Please remove one\\.", ): FixtureFallback().test_fn(variables={'project': "TEST"}, project_id="TEST2") @@ -136,26 +132,20 @@ def test_fn(self, *args, **kwargs): mock_instance(*args, **kwargs) with self.assertRaisesRegex( - AirflowException, - "You must use keyword arguments in this methods rather than positional" + AirflowException, "You must use keyword arguments in this methods rather than positional" ): FixutureFallback().test_fn({'project': "TEST"}, "TEST2") def mock_init( - self, - gcp_conn_id, - delegate_to=None, - impersonation_chain=None, + self, gcp_conn_id, delegate_to=None, impersonation_chain=None, ): # pylint: disable=unused-argument pass class TestDataflowHook(unittest.TestCase): - def setUp(self): - with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), - new=mock_init): + with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init): self.dataflow_hook = DataflowHook(gcp_conn_id='test') @mock.patch("airflow.providers.google.cloud.hooks.dataflow.DataflowHook._authorize") @@ -171,9 +161,7 @@ def test_dataflow_client_creation(self, mock_build, mock_authorize): @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) - def test_start_python_dataflow( - self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid - ): + def test_start_python_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid): mock_uuid.return_value = MOCK_UUID mock_conn.return_value = None dataflow_instance = mock_dataflow.return_value @@ -181,17 +169,20 @@ def test_start_python_dataflow( dataflowjob_instance = mock_dataflowjob.return_value dataflowjob_instance.wait_for_done.return_value = None self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY, - dataflow=PY_FILE, py_options=PY_OPTIONS, + job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY, dataflow=PY_FILE, py_options=PY_OPTIONS, ) - expected_cmd = ["python3", '-m', PY_FILE, - '--region=us-central1', - '--runner=DataflowRunner', '--project=test', - '--labels=foo=bar', - '--staging_location=gs://test/staging', - '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)] - self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), - sorted(expected_cmd)) + expected_cmd = [ + "python3", + '-m', + PY_FILE, + '--region=us-central1', + '--runner=DataflowRunner', + '--project=test', + '--labels=foo=bar', + '--staging_location=gs://test/staging', + '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID), + ] + self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd)) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @@ -209,18 +200,20 @@ def test_start_python_dataflow_with_custom_region_as_variable( variables = copy.deepcopy(DATAFLOW_VARIABLES_PY) variables['region'] = TEST_LOCATION self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=variables, - dataflow=PY_FILE, py_options=PY_OPTIONS, + job_name=JOB_NAME, variables=variables, dataflow=PY_FILE, py_options=PY_OPTIONS, ) - expected_cmd = ["python3", '-m', PY_FILE, - f'--region={TEST_LOCATION}', - '--runner=DataflowRunner', - '--project=test', - '--labels=foo=bar', - '--staging_location=gs://test/staging', - '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)] - self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), - 
sorted(expected_cmd)) + expected_cmd = [ + "python3", + '-m', + PY_FILE, + f'--region={TEST_LOCATION}', + '--runner=DataflowRunner', + '--project=test', + '--labels=foo=bar', + '--staging_location=gs://test/staging', + '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID), + ] + self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd)) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @@ -236,19 +229,24 @@ def test_start_python_dataflow_with_custom_region_as_paramater( dataflowjob_instance = mock_dataflowjob.return_value dataflowjob_instance.wait_for_done.return_value = None self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY, - dataflow=PY_FILE, py_options=PY_OPTIONS, - location=TEST_LOCATION + job_name=JOB_NAME, + variables=DATAFLOW_VARIABLES_PY, + dataflow=PY_FILE, + py_options=PY_OPTIONS, + location=TEST_LOCATION, ) - expected_cmd = ["python3", '-m', PY_FILE, - f'--region={TEST_LOCATION}', - '--runner=DataflowRunner', - '--project=test', - '--labels=foo=bar', - '--staging_location=gs://test/staging', - '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)] - self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), - sorted(expected_cmd)) + expected_cmd = [ + "python3", + '-m', + PY_FILE, + f'--region={TEST_LOCATION}', + '--runner=DataflowRunner', + '--project=test', + '--labels=foo=bar', + '--staging_location=gs://test/staging', + '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID), + ] + self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd)) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @@ -267,25 +265,31 @@ def test_start_python_dataflow_with_multiple_extra_packages( variables['extra-package'] = ['a.whl', 'b.whl'] self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=variables, - dataflow=PY_FILE, py_options=PY_OPTIONS, + job_name=JOB_NAME, variables=variables, dataflow=PY_FILE, py_options=PY_OPTIONS, ) - expected_cmd = ["python3", '-m', PY_FILE, - '--extra-package=a.whl', - '--extra-package=b.whl', - '--region=us-central1', - '--runner=DataflowRunner', '--project=test', - '--labels=foo=bar', - '--staging_location=gs://test/staging', - '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)] + expected_cmd = [ + "python3", + '-m', + PY_FILE, + '--extra-package=a.whl', + '--extra-package=b.whl', + '--region=us-central1', + '--runner=DataflowRunner', + '--project=test', + '--labels=foo=bar', + '--staging_location=gs://test/staging', + '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID), + ] self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd)) - @parameterized.expand([ - ('default_to_python3', 'python3'), - ('major_version_2', 'python2'), - ('major_version_3', 'python3'), - ('minor_version', 'python3.6') - ]) + @parameterized.expand( + [ + ('default_to_python3', 'python3'), + ('major_version_2', 'python2'), + ('major_version_3', 'python3'), + ('minor_version', 'python3.6'), + ] + ) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) @@ -301,24 +305,28 @@ def test_start_python_dataflow_with_custom_interpreter( dataflowjob_instance = mock_dataflowjob.return_value dataflowjob_instance.wait_for_done.return_value = None 
self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY, - dataflow=PY_FILE, py_options=PY_OPTIONS, + job_name=JOB_NAME, + variables=DATAFLOW_VARIABLES_PY, + dataflow=PY_FILE, + py_options=PY_OPTIONS, py_interpreter=py_interpreter, ) - expected_cmd = [py_interpreter, '-m', PY_FILE, - '--region=us-central1', - '--runner=DataflowRunner', '--project=test', - '--labels=foo=bar', - '--staging_location=gs://test/staging', - '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)] - self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), - sorted(expected_cmd)) - - @parameterized.expand([ - (['foo-bar'], False), - (['foo-bar'], True), - ([], True), - ]) + expected_cmd = [ + py_interpreter, + '-m', + PY_FILE, + '--region=us-central1', + '--runner=DataflowRunner', + '--project=test', + '--labels=foo=bar', + '--staging_location=gs://test/staging', + '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID), + ] + self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd)) + + @parameterized.expand( + [(['foo-bar'], False), (['foo-bar'], True), ([], True),] + ) @mock.patch(DATAFLOW_STRING.format('prepare_virtualenv')) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @@ -342,19 +350,25 @@ def test_start_python_dataflow_with_non_empty_py_requirements_and_without_system dataflowjob_instance.wait_for_done.return_value = None mock_virtualenv.return_value = '/dummy_dir/bin/python' self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY, - dataflow=PY_FILE, py_options=PY_OPTIONS, + job_name=JOB_NAME, + variables=DATAFLOW_VARIABLES_PY, + dataflow=PY_FILE, + py_options=PY_OPTIONS, py_requirements=current_py_requirements, - py_system_site_packages=current_py_system_site_packages + py_system_site_packages=current_py_system_site_packages, ) - expected_cmd = ['/dummy_dir/bin/python', '-m', PY_FILE, - '--region=us-central1', - '--runner=DataflowRunner', '--project=test', - '--labels=foo=bar', - '--staging_location=gs://test/staging', - '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)] - self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), - sorted(expected_cmd)) + expected_cmd = [ + '/dummy_dir/bin/python', + '-m', + PY_FILE, + '--region=us-central1', + '--runner=DataflowRunner', + '--project=test', + '--labels=foo=bar', + '--staging_location=gs://test/staging', + '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID), + ] + self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd)) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @@ -371,17 +385,18 @@ def test_start_python_dataflow_with_empty_py_requirements_and_without_system_pac dataflowjob_instance.wait_for_done.return_value = None with self.assertRaisesRegex(AirflowException, "Invalid method invocation."): self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY, - dataflow=PY_FILE, py_options=PY_OPTIONS, - py_requirements=[] + job_name=JOB_NAME, + variables=DATAFLOW_VARIABLES_PY, + dataflow=PY_FILE, + py_options=PY_OPTIONS, + py_requirements=[], ) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) 
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) - def test_start_java_dataflow(self, mock_conn, - mock_dataflow, mock_dataflowjob, mock_uuid): + def test_start_java_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid): mock_uuid.return_value = MOCK_UUID mock_conn.return_value = None dataflow_instance = mock_dataflow.return_value @@ -389,17 +404,21 @@ def test_start_java_dataflow(self, mock_conn, dataflowjob_instance = mock_dataflowjob.return_value dataflowjob_instance.wait_for_done.return_value = None self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, - jar=JAR_FILE) - expected_cmd = ['java', '-jar', JAR_FILE, - '--region=us-central1', - '--runner=DataflowRunner', '--project=test', - '--stagingLocation=gs://test/staging', - '--labels={"foo":"bar"}', - '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID)] + job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, jar=JAR_FILE + ) + expected_cmd = [ + 'java', + '-jar', + JAR_FILE, + '--region=us-central1', + '--runner=DataflowRunner', + '--project=test', + '--stagingLocation=gs://test/staging', + '--labels={"foo":"bar"}', + '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID), + ] self.assertListEqual( - sorted(expected_cmd), - sorted(mock_dataflow.call_args[1]["cmd"]), + sorted(expected_cmd), sorted(mock_dataflow.call_args[1]["cmd"]), ) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @@ -419,18 +438,22 @@ def test_start_java_dataflow_with_multiple_values_in_variables( variables['mock-option'] = ['a.whl', 'b.whl'] self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=variables, - jar=JAR_FILE) - expected_cmd = ['java', '-jar', JAR_FILE, - '--mock-option=a.whl', - '--mock-option=b.whl', - '--region=us-central1', - '--runner=DataflowRunner', '--project=test', - '--stagingLocation=gs://test/staging', - '--labels={"foo":"bar"}', - '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID)] - self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), - sorted(expected_cmd)) + job_name=JOB_NAME, variables=variables, jar=JAR_FILE + ) + expected_cmd = [ + 'java', + '-jar', + JAR_FILE, + '--mock-option=a.whl', + '--mock-option=b.whl', + '--region=us-central1', + '--runner=DataflowRunner', + '--project=test', + '--stagingLocation=gs://test/staging', + '--labels={"foo":"bar"}', + '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID), + ] + self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd)) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @@ -450,18 +473,21 @@ def test_start_java_dataflow_with_custom_region_as_variable( variables['region'] = TEST_LOCATION self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=variables, - jar=JAR_FILE) - expected_cmd = ['java', '-jar', JAR_FILE, - f'--region={TEST_LOCATION}', - '--runner=DataflowRunner', - '--project=test', - '--stagingLocation=gs://test/staging', - '--labels={"foo":"bar"}', - '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID)] + job_name=JOB_NAME, variables=variables, jar=JAR_FILE + ) + expected_cmd = [ + 'java', + '-jar', + JAR_FILE, + f'--region={TEST_LOCATION}', + '--runner=DataflowRunner', + '--project=test', + '--stagingLocation=gs://test/staging', + '--labels={"foo":"bar"}', + '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID), + ] self.assertListEqual( - sorted(expected_cmd), - 
sorted(mock_dataflow.call_args[1]["cmd"]), + sorted(expected_cmd), sorted(mock_dataflow.call_args[1]["cmd"]), ) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @@ -482,26 +508,28 @@ def test_start_java_dataflow_with_custom_region_as_parameter( variables['region'] = TEST_LOCATION self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=variables, - jar=JAR_FILE) - expected_cmd = ['java', '-jar', JAR_FILE, - f'--region={TEST_LOCATION}', - '--runner=DataflowRunner', - '--project=test', - '--stagingLocation=gs://test/staging', - '--labels={"foo":"bar"}', - '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID)] + job_name=JOB_NAME, variables=variables, jar=JAR_FILE + ) + expected_cmd = [ + 'java', + '-jar', + JAR_FILE, + f'--region={TEST_LOCATION}', + '--runner=DataflowRunner', + '--project=test', + '--stagingLocation=gs://test/staging', + '--labels={"foo":"bar"}', + '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID), + ] self.assertListEqual( - sorted(expected_cmd), - sorted(mock_dataflow.call_args[1]["cmd"]), + sorted(expected_cmd), sorted(mock_dataflow.call_args[1]["cmd"]), ) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController')) @mock.patch(DATAFLOW_STRING.format('_DataflowRunner')) @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) - def test_start_java_dataflow_with_job_class( - self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid): + def test_start_java_dataflow_with_job_class(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid): mock_uuid.return_value = MOCK_UUID mock_conn.return_value = None dataflow_instance = mock_dataflow.return_value @@ -509,27 +537,34 @@ def test_start_java_dataflow_with_job_class( dataflowjob_instance = mock_dataflowjob.return_value dataflowjob_instance.wait_for_done.return_value = None self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, - jar=JAR_FILE, job_class=JOB_CLASS) - expected_cmd = ['java', '-cp', JAR_FILE, JOB_CLASS, - '--region=us-central1', - '--runner=DataflowRunner', '--project=test', - '--stagingLocation=gs://test/staging', - '--labels={"foo":"bar"}', - '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID)] - self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), - sorted(expected_cmd)) - - @parameterized.expand([ - (JOB_NAME, JOB_NAME, False), - ('test-example', 'test_example', False), - ('test-dataflow-pipeline-12345678', JOB_NAME, True), - ('test-example-12345678', 'test_example', True), - ('df-job-1', 'df-job-1', False), - ('df-job', 'df-job', False), - ('dfjob', 'dfjob', False), - ('dfjob1', 'dfjob1', False), - ]) + job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, jar=JAR_FILE, job_class=JOB_CLASS + ) + expected_cmd = [ + 'java', + '-cp', + JAR_FILE, + JOB_CLASS, + '--region=us-central1', + '--runner=DataflowRunner', + '--project=test', + '--stagingLocation=gs://test/staging', + '--labels={"foo":"bar"}', + '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID), + ] + self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd)) + + @parameterized.expand( + [ + (JOB_NAME, JOB_NAME, False), + ('test-example', 'test_example', False), + ('test-dataflow-pipeline-12345678', JOB_NAME, True), + ('test-example-12345678', 'test_example', True), + ('df-job-1', 'df-job-1', False), + ('df-job', 'df-job', False), + ('dfjob', 'dfjob', False), + ('dfjob1', 'dfjob1', False), + ] + ) @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), 
return_value=MOCK_UUID) def test_valid_dataflow_job_name(self, expected_result, job_name, append_job_name, mock_uuid4): job_name = self.dataflow_hook._build_dataflow_job_name( @@ -538,24 +573,16 @@ def test_valid_dataflow_job_name(self, expected_result, job_name, append_job_nam self.assertEqual(expected_result, job_name) - @parameterized.expand([ - ("1dfjob@", ), - ("dfjob@", ), - ("df^jo", ) - ]) + @parameterized.expand([("1dfjob@",), ("dfjob@",), ("df^jo",)]) def test_build_dataflow_job_name_with_invalid_value(self, job_name): self.assertRaises( - ValueError, - self.dataflow_hook._build_dataflow_job_name, - job_name=job_name, append_job_name=False + ValueError, self.dataflow_hook._build_dataflow_job_name, job_name=job_name, append_job_name=False ) class TestDataflowTemplateHook(unittest.TestCase): - def setUp(self): - with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), - new=mock_init): + with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init): self.dataflow_hook = DataflowHook(gcp_conn_id='test') @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID) @@ -564,27 +591,23 @@ def setUp(self): def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid): launch_method = ( - mock_conn.return_value. - projects.return_value. - locations.return_value. - templates.return_value. - launch + mock_conn.return_value.projects.return_value.locations.return_value.templates.return_value.launch ) launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}} - variables = { - 'zone': 'us-central1-f', - 'tempLocation': 'gs://test/temp' - } + variables = {'zone': 'us-central1-f', 'tempLocation': 'gs://test/temp'} self.dataflow_hook.start_template_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables=copy.deepcopy(variables), parameters=PARAMETERS, - dataflow_template=TEST_TEMPLATE, project_id=TEST_PROJECT + job_name=JOB_NAME, + variables=copy.deepcopy(variables), + parameters=PARAMETERS, + dataflow_template=TEST_TEMPLATE, + project_id=TEST_PROJECT, ) launch_method.assert_called_once_with( body={ 'jobName': 'test-dataflow-pipeline-12345678', 'parameters': PARAMETERS, - 'environment': variables + 'environment': variables, }, gcsPath='gs://dataflow-templates/wordcount/template_file', projectId=TEST_PROJECT, @@ -598,7 +621,7 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid): num_retries=5, poll_sleep=10, project_number=TEST_PROJECT, - location=DEFAULT_DATAFLOW_LOCATION + location=DEFAULT_DATAFLOW_LOCATION, ) mock_controller.return_value.wait_for_done.assert_called_once() @@ -609,11 +632,7 @@ def test_start_template_dataflow_with_custom_region_as_variable( self, mock_conn, mock_controller, mock_uuid ): launch_method = ( - mock_conn.return_value. - projects.return_value. - locations.return_value. - templates.return_value. 
- launch + mock_conn.return_value.projects.return_value.locations.return_value.templates.return_value.launch ) launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}} self.dataflow_hook.start_template_dataflow( # pylint: disable=no-value-for-parameter @@ -621,14 +640,11 @@ def test_start_template_dataflow_with_custom_region_as_variable( variables={'region': TEST_LOCATION}, parameters=PARAMETERS, dataflow_template=TEST_TEMPLATE, - project_id=TEST_PROJECT + project_id=TEST_PROJECT, ) launch_method.assert_called_once_with( - projectId=TEST_PROJECT, - location=TEST_LOCATION, - gcsPath=TEST_TEMPLATE, - body=mock.ANY, + projectId=TEST_PROJECT, location=TEST_LOCATION, gcsPath=TEST_TEMPLATE, body=mock.ANY, ) mock_controller.assert_called_once_with( @@ -649,25 +665,21 @@ def test_start_template_dataflow_with_custom_region_as_parameter( self, mock_conn, mock_controller, mock_uuid ): launch_method = ( - mock_conn.return_value. - projects.return_value. - locations.return_value. - templates.return_value. - launch + mock_conn.return_value.projects.return_value.locations.return_value.templates.return_value.launch ) launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}} self.dataflow_hook.start_template_dataflow( # pylint: disable=no-value-for-parameter - job_name=JOB_NAME, variables={}, parameters=PARAMETERS, - dataflow_template=TEST_TEMPLATE, location=TEST_LOCATION, project_id=TEST_PROJECT + job_name=JOB_NAME, + variables={}, + parameters=PARAMETERS, + dataflow_template=TEST_TEMPLATE, + location=TEST_LOCATION, + project_id=TEST_PROJECT, ) launch_method.assert_called_once_with( - body={ - 'jobName': UNIQUE_JOB_NAME, - 'parameters': PARAMETERS, - 'environment': {} - }, + body={'jobName': UNIQUE_JOB_NAME, 'parameters': PARAMETERS, 'environment': {}}, gcsPath='gs://dataflow-templates/wordcount/template_file', projectId=TEST_PROJECT, location=TEST_LOCATION, @@ -692,12 +704,13 @@ def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflow dataflowjob_instance = mock_dataflowjob.return_value dataflowjob_instance.wait_for_done.return_value = None + # fmt: off method = (mock_conn.return_value .projects.return_value .locations.return_value .templates.return_value .launch) - + # fmt: on method.return_value.execute.return_value = {'job': {'id': TEST_JOB_ID}} self.dataflow_hook.start_template_dataflow( # pylint: disable=no-value-for-parameter job_name=JOB_NAME, @@ -706,15 +719,9 @@ def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflow dataflow_template=TEST_TEMPLATE, project_id=TEST_PROJECT, ) - body = {"jobName": mock.ANY, - "parameters": PARAMETERS, - "environment": RUNTIME_ENV - } + body = {"jobName": mock.ANY, "parameters": PARAMETERS, "environment": RUNTIME_ENV} method.assert_called_once_with( - projectId=TEST_PROJECT, - location=DEFAULT_DATAFLOW_LOCATION, - gcsPath=TEST_TEMPLATE, - body=body, + projectId=TEST_PROJECT, location=DEFAULT_DATAFLOW_LOCATION, gcsPath=TEST_TEMPLATE, body=body, ) mock_dataflowjob.assert_called_once_with( dataflow=mock_conn.return_value, @@ -723,7 +730,7 @@ def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflow name='test-dataflow-pipeline-{}'.format(MOCK_UUID), num_retries=5, poll_sleep=10, - project_number=TEST_PROJECT + project_number=TEST_PROJECT, ) mock_uuid.assert_called_once_with() @@ -731,10 +738,7 @@ def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflow @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn')) def 
test_cancel_job(self, mock_get_conn, jobs_controller): self.dataflow_hook.cancel_job( - job_name=UNIQUE_JOB_NAME, - job_id=TEST_JOB_ID, - project_id=TEST_PROJECT, - location=TEST_LOCATION + job_name=UNIQUE_JOB_NAME, job_id=TEST_JOB_ID, project_id=TEST_PROJECT, location=TEST_LOCATION ) jobs_controller.assert_called_once_with( dataflow=mock_get_conn.return_value, @@ -742,56 +746,47 @@ def test_cancel_job(self, mock_get_conn, jobs_controller): location=TEST_LOCATION, name=UNIQUE_JOB_NAME, poll_sleep=10, - project_number=TEST_PROJECT + project_number=TEST_PROJECT, ) jobs_controller.cancel() class TestDataflowJob(unittest.TestCase): - def setUp(self): self.mock_dataflow = MagicMock() def test_dataflow_job_init_with_job_id(self): mock_jobs = MagicMock() - self.mock_dataflow.projects.return_value.locations.return_value. \ - jobs.return_value = mock_jobs + self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value = mock_jobs _DataflowJobsController( - self.mock_dataflow, TEST_PROJECT, - TEST_LOCATION, 10, UNIQUE_JOB_NAME, TEST_JOB_ID).get_jobs() - mock_jobs.get.assert_called_once_with(projectId=TEST_PROJECT, location=TEST_LOCATION, - jobId=TEST_JOB_ID) + self.mock_dataflow, TEST_PROJECT, TEST_LOCATION, 10, UNIQUE_JOB_NAME, TEST_JOB_ID + ).get_jobs() + mock_jobs.get.assert_called_once_with( + projectId=TEST_PROJECT, location=TEST_LOCATION, jobId=TEST_JOB_ID + ) def test_dataflow_job_init_without_job_id(self): job = {"id": TEST_JOB_ID, "name": UNIQUE_JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DONE} - mock_list = ( - self.mock_dataflow.projects.return_value. - locations.return_value. - jobs.return_value.list - ) - ( - mock_list.return_value. - execute.return_value - ) = {'jobs': [job]} + mock_list = self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.list + (mock_list.return_value.execute.return_value) = {'jobs': [job]} + # fmt: off ( self.mock_dataflow.projects.return_value. locations.return_value. jobs.return_value. list_next.return_value ) = None + # fmt: on _DataflowJobsController( - self.mock_dataflow, TEST_PROJECT, - TEST_LOCATION, 10, UNIQUE_JOB_NAME).get_jobs() + self.mock_dataflow, TEST_PROJECT, TEST_LOCATION, 10, UNIQUE_JOB_NAME + ).get_jobs() - mock_list.assert_called_once_with( - projectId=TEST_PROJECT, - location=TEST_LOCATION - ) + mock_list.assert_called_once_with(projectId=TEST_PROJECT, location=TEST_LOCATION) def test_dataflow_job_wait_for_multiple_jobs(self): job = {"id": TEST_JOB_ID, "name": UNIQUE_JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DONE} - + # fmt: off ( self.mock_dataflow.projects.return_value. locations.return_value. @@ -807,7 +802,7 @@ def test_dataflow_job_wait_for_multiple_jobs(self): jobs.return_value. list_next.return_value ) = None - + # fmt: on dataflow_job = _DataflowJobsController( dataflow=self.mock_dataflow, project_number=TEST_PROJECT, @@ -816,19 +811,22 @@ def test_dataflow_job_wait_for_multiple_jobs(self): poll_sleep=10, job_id=TEST_JOB_ID, num_retries=20, - multiple_jobs=True + multiple_jobs=True, ) dataflow_job.wait_for_done() - self.mock_dataflow.projects.return_value.locations.return_value. \ - jobs.return_value.list.assert_called_once_with(location=TEST_LOCATION, projectId=TEST_PROJECT) + # fmt: off + self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.\ + list.assert_called_once_with(location=TEST_LOCATION, projectId=TEST_PROJECT) - self.mock_dataflow.projects.return_value.locations.return_value. 
\ - jobs.return_value.list.return_value.execute.assert_called_once_with(num_retries=20) + self.mock_dataflow.projects.return_value.locations.return_value.jobs.return_value.list\ + .return_value.execute.assert_called_once_with(num_retries=20) + # fmt: on self.assertEqual(dataflow_job.get_jobs(), [job, job]) def test_dataflow_job_wait_for_multiple_jobs_and_one_failed(self): + # fmt: off ( self.mock_dataflow.projects.return_value. locations.return_value. @@ -847,7 +845,7 @@ def test_dataflow_job_wait_for_multiple_jobs_and_one_failed(self): jobs.return_value. list_next.return_value ) = None - + # fmt: on dataflow_job = _DataflowJobsController( dataflow=self.mock_dataflow, project_number=TEST_PROJECT, @@ -856,12 +854,13 @@ def test_dataflow_job_wait_for_multiple_jobs_and_one_failed(self): poll_sleep=0, job_id=None, num_retries=20, - multiple_jobs=True + multiple_jobs=True, ) with self.assertRaisesRegex(Exception, 'Google Cloud Dataflow job name-2 has failed\\.'): dataflow_job.wait_for_done() def test_dataflow_job_wait_for_multiple_jobs_and_one_cancelled(self): + # fmt: off ( self.mock_dataflow.projects.return_value. locations.return_value. @@ -880,7 +879,7 @@ def test_dataflow_job_wait_for_multiple_jobs_and_one_cancelled(self): jobs.return_value. list_next.return_value ) = None - + # fmt: on dataflow_job = _DataflowJobsController( dataflow=self.mock_dataflow, project_number=TEST_PROJECT, @@ -889,12 +888,13 @@ def test_dataflow_job_wait_for_multiple_jobs_and_one_cancelled(self): poll_sleep=0, job_id=None, num_retries=20, - multiple_jobs=True + multiple_jobs=True, ) with self.assertRaisesRegex(Exception, 'Google Cloud Dataflow job name-2 was cancelled\\.'): dataflow_job.wait_for_done() def test_dataflow_job_wait_for_multiple_jobs_and_one_unknown(self): + # fmt: off ( self.mock_dataflow.projects.return_value. locations.return_value. @@ -913,7 +913,7 @@ def test_dataflow_job_wait_for_multiple_jobs_and_one_unknown(self): jobs.return_value. list_next.return_value ) = None - + # fmt: on dataflow_job = _DataflowJobsController( dataflow=self.mock_dataflow, project_number=TEST_PROJECT, @@ -922,12 +922,13 @@ def test_dataflow_job_wait_for_multiple_jobs_and_one_unknown(self): poll_sleep=0, job_id=None, num_retries=20, - multiple_jobs=True + multiple_jobs=True, ) with self.assertRaisesRegex(Exception, 'Google Cloud Dataflow job name-2 was unknown state: unknown'): dataflow_job.wait_for_done() def test_dataflow_job_wait_for_multiple_jobs_and_streaming_jobs(self): + # fmt: off mock_jobs_list = ( self.mock_dataflow.projects.return_value. locations.return_value. @@ -950,7 +951,7 @@ def test_dataflow_job_wait_for_multiple_jobs_and_streaming_jobs(self): jobs.return_value. list_next.return_value ) = None - + # fmt: on dataflow_job = _DataflowJobsController( dataflow=self.mock_dataflow, project_number=TEST_PROJECT, @@ -959,7 +960,7 @@ def test_dataflow_job_wait_for_multiple_jobs_and_streaming_jobs(self): poll_sleep=0, job_id=None, num_retries=20, - multiple_jobs=True + multiple_jobs=True, ) dataflow_job.wait_for_done() @@ -967,7 +968,7 @@ def test_dataflow_job_wait_for_multiple_jobs_and_streaming_jobs(self): def test_dataflow_job_wait_for_single_jobs(self): job = {"id": TEST_JOB_ID, "name": UNIQUE_JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DONE} - + # fmt: off self.mock_dataflow.projects.return_value.locations.return_value. \ jobs.return_value.get.return_value.execute.return_value = job @@ -977,7 +978,7 @@ def test_dataflow_job_wait_for_single_jobs(self): jobs.return_value. 
list_next.return_value ) = None - + # fmt: on dataflow_job = _DataflowJobsController( dataflow=self.mock_dataflow, project_number=TEST_PROJECT, @@ -986,10 +987,10 @@ def test_dataflow_job_wait_for_single_jobs(self): poll_sleep=10, job_id=TEST_JOB_ID, num_retries=20, - multiple_jobs=False + multiple_jobs=False, ) dataflow_job.wait_for_done() - + # fmt: off self.mock_dataflow.projects.return_value.locations.return_value. \ jobs.return_value.get.assert_called_once_with( jobId=TEST_JOB_ID, @@ -999,10 +1000,11 @@ def test_dataflow_job_wait_for_single_jobs(self): self.mock_dataflow.projects.return_value.locations.return_value. \ jobs.return_value.get.return_value.execute.assert_called_once_with(num_retries=20) - + # fmt: on self.assertEqual(dataflow_job.get_jobs(), [job]) def test_dataflow_job_is_job_running_with_no_job(self): + # fmt: off mock_jobs_list = ( self.mock_dataflow.projects.return_value. locations.return_value. @@ -1018,7 +1020,7 @@ def test_dataflow_job_is_job_running_with_no_job(self): jobs.return_value. list_next.return_value ) = None - + # fmt: on dataflow_job = _DataflowJobsController( dataflow=self.mock_dataflow, project_number=TEST_PROJECT, @@ -1027,7 +1029,7 @@ def test_dataflow_job_is_job_running_with_no_job(self): poll_sleep=0, job_id=None, num_retries=20, - multiple_jobs=True + multiple_jobs=True, ) result = dataflow_job.is_job_running() @@ -1035,9 +1037,11 @@ def test_dataflow_job_is_job_running_with_no_job(self): def test_dataflow_job_cancel_job(self): job = { - "id": TEST_JOB_ID, "name": UNIQUE_JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_RUNNING + "id": TEST_JOB_ID, + "name": UNIQUE_JOB_NAME, + "currentState": DataflowJobStatus.JOB_STATE_RUNNING, } - + # fmt: off get_method = ( self.mock_dataflow.projects.return_value. locations.return_value. @@ -1052,7 +1056,7 @@ def test_dataflow_job_cancel_job(self): jobs.return_value. list_next.return_value ) = None - + # fmt: on dataflow_job = _DataflowJobsController( dataflow=self.mock_dataflow, project_number=TEST_PROJECT, @@ -1061,36 +1065,32 @@ def test_dataflow_job_cancel_job(self): poll_sleep=10, job_id=TEST_JOB_ID, num_retries=20, - multiple_jobs=False + multiple_jobs=False, ) dataflow_job.cancel() - get_method.assert_called_once_with( - jobId=TEST_JOB_ID, - location=TEST_LOCATION, - projectId=TEST_PROJECT - ) + get_method.assert_called_once_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT) get_method.return_value.execute.assert_called_once_with(num_retries=20) self.mock_dataflow.new_batch_http_request.assert_called_once_with() mock_batch = self.mock_dataflow.new_batch_http_request.return_value + # fmt: off mock_update = ( self.mock_dataflow.projects.return_value. locations.return_value. jobs.return_value. 
update ) + # fmt: on mock_update.assert_called_once_with( body={'requestedState': 'JOB_STATE_CANCELLED'}, jobId='test-job-id', location=TEST_LOCATION, projectId='test-project', ) - mock_batch.add.assert_called_once_with( - mock_update.return_value - ) + mock_batch.add.assert_called_once_with(mock_update.return_value) mock_batch.execute.assert_called_once() @@ -1161,13 +1161,15 @@ def test_dataflow_job_cancel_job(self): class TestDataflow(unittest.TestCase): - - @parameterized.expand([ - (APACHE_BEAM_V_2_14_0_JAVA_SDK_LOG, ), - (APACHE_BEAM_V_2_22_0_JAVA_SDK_LOG, ), - (APACHE_BEAM_V_2_14_0_PYTHON_SDK_LOG, ), - (APACHE_BEAM_V_2_22_0_PYTHON_SDK_LOG, ), - ], name_func=lambda func, num, p: f"{func.__name__}_{num}") + @parameterized.expand( + [ + (APACHE_BEAM_V_2_14_0_JAVA_SDK_LOG,), + (APACHE_BEAM_V_2_22_0_JAVA_SDK_LOG,), + (APACHE_BEAM_V_2_14_0_PYTHON_SDK_LOG,), + (APACHE_BEAM_V_2_22_0_PYTHON_SDK_LOG,), + ], + name_func=lambda func, num, p: f"{func.__name__}_{num}", + ) def test_data_flow_valid_job_id(self, log): echos = ";".join([f"echo {shlex.quote(line)}" for line in log.split("\n")]) cmd = ["bash", "-c", echos] diff --git a/tests/providers/google/cloud/hooks/test_datafusion.py b/tests/providers/google/cloud/hooks/test_datafusion.py index bcc30be01382b..403e5e51f1ddd 100644 --- a/tests/providers/google/cloud/hooks/test_datafusion.py +++ b/tests/providers/google/cloud/hooks/test_datafusion.py @@ -50,13 +50,10 @@ def hook(): class TestDataFusionHook: @staticmethod def mock_endpoint(get_conn_mock): - return get_conn_mock.return_value.projects.return_value.\ - locations.return_value.instances.return_value + return get_conn_mock.return_value.projects.return_value.locations.return_value.instances.return_value def test_name(self, hook): - expected = ( - f"projects/{PROJECT_ID}/locations/{LOCATION}/instances/{INSTANCE_NAME}" - ) + expected = f"projects/{PROJECT_ID}/locations/{LOCATION}/instances/{INSTANCE_NAME}" assert hook._name(PROJECT_ID, LOCATION, INSTANCE_NAME) == expected def test_parent(self, hook): @@ -68,52 +65,37 @@ def test_parent(self, hook): def test_get_conn(self, mock_authorize, mock_build, hook): mock_authorize.return_value = "test" hook.get_conn() - mock_build.assert_called_once_with( - "datafusion", hook.api_version, http="test", cache_discovery=False - ) + mock_build.assert_called_once_with("datafusion", hook.api_version, http="test", cache_discovery=False) @mock.patch(HOOK_STR.format("DataFusionHook.get_conn")) def test_restart_instance(self, get_conn_mock, hook): method_mock = self.mock_endpoint(get_conn_mock).restart method_mock.return_value.execute.return_value = "value" - result = hook.restart_instance( - instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID - ) + result = hook.restart_instance(instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID) assert result == "value" - method_mock.assert_called_once_with( - name=hook._name(PROJECT_ID, LOCATION, INSTANCE_NAME) - ) + method_mock.assert_called_once_with(name=hook._name(PROJECT_ID, LOCATION, INSTANCE_NAME)) @mock.patch(HOOK_STR.format("DataFusionHook.get_conn")) def test_delete_instance(self, get_conn_mock, hook): method_mock = self.mock_endpoint(get_conn_mock).delete method_mock.return_value.execute.return_value = "value" - result = hook.delete_instance( - instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID - ) + result = hook.delete_instance(instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID) assert result == "value" - 
method_mock.assert_called_once_with( - name=hook._name(PROJECT_ID, LOCATION, INSTANCE_NAME) - ) + method_mock.assert_called_once_with(name=hook._name(PROJECT_ID, LOCATION, INSTANCE_NAME)) @mock.patch(HOOK_STR.format("DataFusionHook.get_conn")) def test_create_instance(self, get_conn_mock, hook): method_mock = self.mock_endpoint(get_conn_mock).create method_mock.return_value.execute.return_value = "value" result = hook.create_instance( - instance_name=INSTANCE_NAME, - instance=INSTANCE, - location=LOCATION, - project_id=PROJECT_ID, + instance_name=INSTANCE_NAME, instance=INSTANCE, location=LOCATION, project_id=PROJECT_ID, ) assert result == "value" method_mock.assert_called_once_with( - parent=hook._parent(PROJECT_ID, LOCATION), - body=INSTANCE, - instanceId=INSTANCE_NAME, + parent=hook._parent(PROJECT_ID, LOCATION), body=INSTANCE, instanceId=INSTANCE_NAME, ) @mock.patch(HOOK_STR.format("DataFusionHook.get_conn")) @@ -130,23 +112,17 @@ def test_patch_instance(self, get_conn_mock, hook): assert result == "value" method_mock.assert_called_once_with( - name=hook._name(PROJECT_ID, LOCATION, INSTANCE_NAME), - body=INSTANCE, - updateMask="instance.name", + name=hook._name(PROJECT_ID, LOCATION, INSTANCE_NAME), body=INSTANCE, updateMask="instance.name", ) @mock.patch(HOOK_STR.format("DataFusionHook.get_conn")) def test_get_instance(self, get_conn_mock, hook): method_mock = self.mock_endpoint(get_conn_mock).get method_mock.return_value.execute.return_value = "value" - result = hook.get_instance( - instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID - ) + result = hook.get_instance(instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID) assert result == "value" - method_mock.assert_called_once_with( - name=hook._name(PROJECT_ID, LOCATION, INSTANCE_NAME) - ) + method_mock.assert_called_once_with(name=hook._name(PROJECT_ID, LOCATION, INSTANCE_NAME)) @mock.patch("google.auth.transport.requests.Request") @mock.patch(HOOK_STR.format("DataFusionHook._get_credentials")) @@ -164,21 +140,15 @@ def test_cdap_request(self, get_credentials_mock, mock_request, hook): get_credentials_mock.return_value.before_request.assert_called_once_with( request=request, method=method, url=url, headers=headers ) - request.assert_called_once_with( - method=method, url=url, headers=headers, body=json.dumps(body) - ) + request.assert_called_once_with(method=method, url=url, headers=headers, body=json.dumps(body)) assert result == request.return_value @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request")) def test_create_pipeline(self, mock_request, hook): mock_request.return_value.status = 200 - hook.create_pipeline( - pipeline_name=PIPELINE_NAME, pipeline=PIPELINE, instance_url=INSTANCE_URL - ) + hook.create_pipeline(pipeline_name=PIPELINE_NAME, pipeline=PIPELINE, instance_url=INSTANCE_URL) mock_request.assert_called_once_with( - url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}", - method="PUT", - body=PIPELINE, + url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}", method="PUT", body=PIPELINE, ) @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request")) @@ -186,9 +156,7 @@ def test_delete_pipeline(self, mock_request, hook): mock_request.return_value.status = 200 hook.delete_pipeline(pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL) mock_request.assert_called_once_with( - url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}", - method="DELETE", - body=None, + url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}", method="DELETE", 
body=None, ) @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request")) @@ -208,28 +176,24 @@ def test_start_pipeline(self, mock_wait_for_pipeline_state, mock_request, hook): run_id = 1234 mock_request.return_value = mock.MagicMock(status=200, data='[{{"runId":{}}}]'.format(run_id)) - hook.start_pipeline( - pipeline_name=PIPELINE_NAME, - instance_url=INSTANCE_URL, - runtime_args=RUNTIME_ARGS - ) - body = [{ - "appId": PIPELINE_NAME, - "programType": "workflow", - "programId": "DataPipelineWorkflow", - "runtimeargs": RUNTIME_ARGS - }] + hook.start_pipeline(pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL, runtime_args=RUNTIME_ARGS) + body = [ + { + "appId": PIPELINE_NAME, + "programType": "workflow", + "programId": "DataPipelineWorkflow", + "runtimeargs": RUNTIME_ARGS, + } + ] mock_request.assert_called_once_with( - url=f"{INSTANCE_URL}/v3/namespaces/default/start", - method="POST", - body=body + url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body ) mock_wait_for_pipeline_state.assert_called_once_with( instance_url=INSTANCE_URL, namespace="default", pipeline_name=PIPELINE_NAME, pipeline_id=run_id, - success_states=SUCCESS_STATES + [PipelineStates.RUNNING] + success_states=SUCCESS_STATES + [PipelineStates.RUNNING], ) @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request")) @@ -238,6 +202,6 @@ def test_stop_pipeline(self, mock_request, hook): hook.stop_pipeline(pipeline_name=PIPELINE_NAME, instance_url=INSTANCE_URL) mock_request.assert_called_once_with( url=f"{INSTANCE_URL}/v3/namespaces/default/apps/{PIPELINE_NAME}/" - f"workflows/DataPipelineWorkflow/stop", + f"workflows/DataPipelineWorkflow/stop", method="POST", ) diff --git a/tests/providers/google/cloud/hooks/test_dataprep.py b/tests/providers/google/cloud/hooks/test_dataprep.py index a3e01d2e3470d..110bd9f4075ce 100644 --- a/tests/providers/google/cloud/hooks/test_dataprep.py +++ b/tests/providers/google/cloud/hooks/test_dataprep.py @@ -47,10 +47,7 @@ def test_mock_should_be_called_once_with_params(self, mock_get_request, mock_hoo mock_hook.get_jobs_for_job_group(job_id=JOB_ID) mock_get_request.assert_called_once_with( f"{URL}/{JOB_ID}/jobs", - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {TOKEN}", - }, + headers={"Content-Type": "application/json", "Authorization": f"Bearer {TOKEN}",}, ) @patch( @@ -72,13 +69,7 @@ def test_should_not_retry_after_success(self, mock_get_request, mock_hook): @patch( "airflow.providers.google.cloud.hooks.dataprep.requests.get", - side_effect=[ - HTTPError(), - HTTPError(), - HTTPError(), - HTTPError(), - mock.MagicMock(), - ], + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), mock.MagicMock(),], ) def test_should_retry_after_four_errors(self, mock_get_request, mock_hook): mock_hook.get_jobs_for_job_group.retry.sleep = mock.Mock() # pylint: disable=no-member diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py index 08f9d608d0f9a..5205953fed0a1 100644 --- a/tests/providers/google/cloud/hooks/test_dataproc.py +++ b/tests/providers/google/cloud/hooks/test_dataproc.py @@ -47,48 +47,36 @@ def mock_init(*args, **kwargs): class TestDataprocHook(unittest.TestCase): def setUp(self): - with mock.patch( - BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init - ): + with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init): self.hook = DataprocHook(gcp_conn_id="test") @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials")) 
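For reference, a minimal standalone sketch of the `# fmt: off` / `# fmt: on` pragmas introduced in the Dataflow hunks above: black leaves the region between the pragmas untouched, which is how the hand-aligned chained mock access keeps its one-segment-per-line layout. The snippet below is illustrative and not part of this patch.

from unittest import mock

mock_dataflow = mock.MagicMock()

# fmt: off
# black skips everything between the pragmas, so this chained attribute
# assignment keeps its vertical layout after formatting.
(
    mock_dataflow.projects.return_value.
    locations.return_value.
    jobs.return_value.
    list_next.return_value
) = None
# fmt: on

# The configured chain is what calling code sees when it walks the mock:
assert mock_dataflow.projects().locations().jobs().list_next() is None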
@mock.patch( - DATAPROC_STRING.format("DataprocHook.client_info"), - new_callable=mock.PropertyMock, + DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock, ) @mock.patch(DATAPROC_STRING.format("ClusterControllerClient")) - def test_get_cluster_client( - self, mock_client, mock_client_info, mock_get_credentials - ): + def test_get_cluster_client(self, mock_client, mock_client_info, mock_get_credentials): self.hook.get_cluster_client(location=GCP_LOCATION) mock_client.assert_called_once_with( credentials=mock_get_credentials.return_value, client_info=mock_client_info.return_value, - client_options={ - "api_endpoint": "{}-dataproc.googleapis.com:443".format(GCP_LOCATION) - }, + client_options={"api_endpoint": "{}-dataproc.googleapis.com:443".format(GCP_LOCATION)}, ) @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials")) @mock.patch( - DATAPROC_STRING.format("DataprocHook.client_info"), - new_callable=mock.PropertyMock, + DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock, ) @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient")) - def test_get_template_client( - self, mock_client, mock_client_info, mock_get_credentials - ): + def test_get_template_client(self, mock_client, mock_client_info, mock_get_credentials): _ = self.hook.get_template_client mock_client.assert_called_once_with( - credentials=mock_get_credentials.return_value, - client_info=mock_client_info.return_value, + credentials=mock_get_credentials.return_value, client_info=mock_client_info.return_value, ) @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials")) @mock.patch( - DATAPROC_STRING.format("DataprocHook.client_info"), - new_callable=mock.PropertyMock, + DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock, ) @mock.patch(DATAPROC_STRING.format("JobControllerClient")) def test_get_job_client(self, mock_client, mock_client_info, mock_get_credentials): @@ -96,16 +84,12 @@ def test_get_job_client(self, mock_client, mock_client_info, mock_get_credential mock_client.assert_called_once_with( credentials=mock_get_credentials.return_value, client_info=mock_client_info.return_value, - client_options={ - "api_endpoint": "{}-dataproc.googleapis.com:443".format(GCP_LOCATION) - }, + client_options={"api_endpoint": "{}-dataproc.googleapis.com:443".format(GCP_LOCATION)}, ) @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client")) def test_create_cluster(self, mock_client): - self.hook.create_cluster( - project_id=GCP_PROJECT, region=GCP_LOCATION, cluster=CLUSTER - ) + self.hook.create_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster=CLUSTER) mock_client.assert_called_once_with(location=GCP_LOCATION) mock_client.return_value.create_cluster.assert_called_once_with( project_id=GCP_PROJECT, @@ -119,9 +103,7 @@ def test_create_cluster(self, mock_client): @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client")) def test_delete_cluster(self, mock_client): - self.hook.delete_cluster( - project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME - ) + self.hook.delete_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME) mock_client.assert_called_once_with(location=GCP_LOCATION) mock_client.return_value.delete_cluster.assert_called_once_with( project_id=GCP_PROJECT, @@ -136,9 +118,7 @@ def test_delete_cluster(self, mock_client): @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client")) def test_diagnose_cluster(self, mock_client): 
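The Dataproc tests in these hunks rely on stacked `@mock.patch` decorators; a small sketch of the ordering rule they depend on, using illustrative targets (`os.path.*`) rather than the hook methods patched here.

import unittest
from unittest import mock


class DecoratorOrderExample(unittest.TestCase):
    # Decorators are applied bottom-up, so the lowest patch produces the
    # first mock parameter after "self" and the topmost patch the last one.
    @mock.patch("os.path.exists")    # outermost -> last parameter
    @mock.patch("os.path.getsize")   # innermost -> first parameter
    def test_order(self, mock_getsize, mock_exists):
        mock_getsize.return_value = 42
        mock_exists.return_value = True
        self.assertEqual(mock_getsize("any/path"), 42)
        self.assertTrue(mock_exists("any/path"))


if __name__ == "__main__":
    unittest.main()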
- self.hook.diagnose_cluster( - project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME - ) + self.hook.diagnose_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME) mock_client.assert_called_once_with(location=GCP_LOCATION) mock_client.return_value.diagnose_cluster.assert_called_once_with( project_id=GCP_PROJECT, @@ -152,9 +132,7 @@ def test_diagnose_cluster(self, mock_client): @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client")) def test_get_cluster(self, mock_client): - self.hook.get_cluster( - project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME - ) + self.hook.get_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME) mock_client.assert_called_once_with(location=GCP_LOCATION) mock_client.return_value.get_cluster.assert_called_once_with( project_id=GCP_PROJECT, @@ -169,9 +147,7 @@ def test_get_cluster(self, mock_client): def test_list_clusters(self, mock_client): filter_ = "filter" - self.hook.list_clusters( - project_id=GCP_PROJECT, region=GCP_LOCATION, filter_=filter_ - ) + self.hook.list_clusters(project_id=GCP_PROJECT, region=GCP_LOCATION, filter_=filter_) mock_client.assert_called_once_with(location=GCP_LOCATION) mock_client.return_value.list_clusters.assert_called_once_with( project_id=GCP_PROJECT, @@ -211,9 +187,7 @@ def test_update_cluster(self, mock_client): def test_create_workflow_template(self, mock_client): template = {"test": "test"} mock_client.region_path.return_value = PARENT - self.hook.create_workflow_template( - location=GCP_LOCATION, template=template, project_id=GCP_PROJECT - ) + self.hook.create_workflow_template(location=GCP_LOCATION, template=template, project_id=GCP_PROJECT) mock_client.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION) mock_client.create_workflow_template.assert_called_once_with( parent=PARENT, template=template, retry=None, timeout=None, metadata=None @@ -226,9 +200,7 @@ def test_instantiate_workflow_template(self, mock_client): self.hook.instantiate_workflow_template( location=GCP_LOCATION, template_name=template_name, project_id=GCP_PROJECT ) - mock_client.workflow_template_path.assert_called_once_with( - GCP_PROJECT, GCP_LOCATION, template_name - ) + mock_client.workflow_template_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION, template_name) mock_client.instantiate_workflow_template.assert_called_once_with( name=NAME, version=None, @@ -248,12 +220,7 @@ def test_instantiate_inline_workflow_template(self, mock_client): ) mock_client.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION) mock_client.instantiate_inline_workflow_template.assert_called_once_with( - parent=PARENT, - template=template, - request_id=None, - retry=None, - timeout=None, - metadata=None, + parent=PARENT, template=template, request_id=None, retry=None, timeout=None, metadata=None, ) @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job")) @@ -264,10 +231,7 @@ def test_wait_for_job(self, mock_get_job): ] with self.assertRaises(AirflowException): self.hook.wait_for_job( - job_id=JOB_ID, - location=GCP_LOCATION, - project_id=GCP_PROJECT, - wait_time=0, + job_id=JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, wait_time=0, ) calls = [ mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT), @@ -308,18 +272,14 @@ def test_submit(self, mock_submit_job, mock_wait_for_job): mock_submit_job.return_value.reference.job_id = JOB_ID with self.assertWarns(DeprecationWarning): self.hook.submit(project_id=GCP_PROJECT, job=JOB, 
region=GCP_LOCATION) - mock_submit_job.assert_called_once_with( - location=GCP_LOCATION, project_id=GCP_PROJECT, job=JOB - ) + mock_submit_job.assert_called_once_with(location=GCP_LOCATION, project_id=GCP_PROJECT, job=JOB) mock_wait_for_job.assert_called_once_with( location=GCP_LOCATION, project_id=GCP_PROJECT, job_id=JOB_ID ) @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client")) def test_cancel_job(self, mock_client): - self.hook.cancel_job( - location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT - ) + self.hook.cancel_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT) mock_client.assert_called_once_with(location=GCP_LOCATION) mock_client.return_value.cancel_job.assert_called_once_with( region=GCP_LOCATION, @@ -373,9 +333,7 @@ def test_add_labels(self): def test_add_variables(self): variables = ["variable"] self.builder.add_variables(variables) - self.assertEqual( - variables, self.builder.job["job"][self.job_type]["script_variables"] - ) + self.assertEqual(variables, self.builder.job["job"][self.job_type]["script_variables"]) def test_add_args(self): args = ["args"] @@ -385,30 +343,22 @@ def test_add_args(self): def test_add_query(self): query = ["query"] self.builder.add_query(query) - self.assertEqual( - {"queries": [query]}, self.builder.job["job"][self.job_type]["query_list"] - ) + self.assertEqual({"queries": [query]}, self.builder.job["job"][self.job_type]["query_list"]) def test_add_query_uri(self): query_uri = "query_uri" self.builder.add_query_uri(query_uri) - self.assertEqual( - query_uri, self.builder.job["job"][self.job_type]["query_file_uri"] - ) + self.assertEqual(query_uri, self.builder.job["job"][self.job_type]["query_file_uri"]) def test_add_jar_file_uris(self): jar_file_uris = ["jar_file_uris"] self.builder.add_jar_file_uris(jar_file_uris) - self.assertEqual( - jar_file_uris, self.builder.job["job"][self.job_type]["jar_file_uris"] - ) + self.assertEqual(jar_file_uris, self.builder.job["job"][self.job_type]["jar_file_uris"]) def test_add_archive_uris(self): archive_uris = ["archive_uris"] self.builder.add_archive_uris(archive_uris) - self.assertEqual( - archive_uris, self.builder.job["job"][self.job_type]["archive_uris"] - ) + self.assertEqual(archive_uris, self.builder.job["job"][self.job_type]["archive_uris"]) def test_add_file_uris(self): file_uris = ["file_uris"] @@ -418,9 +368,7 @@ def test_add_file_uris(self): def test_add_python_file_uris(self): python_file_uris = ["python_file_uris"] self.builder.add_python_file_uris(python_file_uris) - self.assertEqual( - python_file_uris, self.builder.job["job"][self.job_type]["python_file_uris"] - ) + self.assertEqual(python_file_uris, self.builder.job["job"][self.job_type]["python_file_uris"]) def test_set_main_error(self): with self.assertRaises(Exception): @@ -434,16 +382,12 @@ def test_set_main_class(self): def test_set_main_jar(self): main = "main" self.builder.set_main(main_class=None, main_jar=main) - self.assertEqual( - main, self.builder.job["job"][self.job_type]["main_jar_file_uri"] - ) + self.assertEqual(main, self.builder.job["job"][self.job_type]["main_jar_file_uri"]) def test_set_python_main(self): main = "main" self.builder.set_python_main(main) - self.assertEqual( - main, self.builder.job["job"][self.job_type]["main_python_file_uri"] - ) + self.assertEqual(main, self.builder.job["job"][self.job_type]["main_python_file_uri"]) @mock.patch(DATAPROC_STRING.format("uuid.uuid4")) def test_set_job_name(self, mock_uuid): diff --git 
a/tests/providers/google/cloud/hooks/test_datastore.py b/tests/providers/google/cloud/hooks/test_datastore.py index 51591a848533c..fa4a52017f4f3 100644 --- a/tests/providers/google/cloud/hooks/test_datastore.py +++ b/tests/providers/google/cloud/hooks/test_datastore.py @@ -29,10 +29,7 @@ def mock_init( - self, - gcp_conn_id, - delegate_to=None, - impersonation_chain=None, + self, gcp_conn_id, delegate_to=None, impersonation_chain=None, ): # pylint: disable=unused-argument pass @@ -49,8 +46,9 @@ def setUp(self): def test_get_conn(self, mock_build, mock_authorize): conn = self.datastore_hook.get_conn() - mock_build.assert_called_once_with('datastore', 'v1', http=mock_authorize.return_value, - cache_discovery=False) + mock_build.assert_called_once_with( + 'datastore', 'v1', http=mock_authorize.return_value, cache_discovery=False + ) self.assertEqual(conn, mock_build.return_value) self.assertEqual(conn, self.datastore_hook.connection) @@ -64,14 +62,16 @@ def test_allocate_ids(self, mock_get_conn): projects = self.datastore_hook.connection.projects projects.assert_called_once_with() allocate_ids = projects.return_value.allocateIds - allocate_ids.assert_called_once_with(projectId=GCP_PROJECT_ID, - body={'keys': partial_keys}) + allocate_ids.assert_called_once_with(projectId=GCP_PROJECT_ID, body={'keys': partial_keys}) execute = allocate_ids.return_value.execute execute.assert_called_once_with(num_retries=mock.ANY) self.assertEqual(keys, execute.return_value['keys']) - @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', - new_callable=mock.PropertyMock, return_value=None) + @patch( + 'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', + new_callable=mock.PropertyMock, + return_value=None, + ) @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn') def test_allocate_ids_no_project_id(self, mock_get_conn, mock_project_id): self.datastore_hook.connection = mock_get_conn.return_value @@ -79,7 +79,8 @@ def test_allocate_ids_no_project_id(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException) as err: self.datastore_hook.allocate_ids( # pylint: disable=no-value-for-parameter - partial_keys=partial_keys) + partial_keys=partial_keys + ) self.assertIn("project_id", str(err.exception)) @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn') @@ -87,22 +88,22 @@ def test_begin_transaction(self, mock_get_conn): self.datastore_hook.connection = mock_get_conn.return_value transaction = self.datastore_hook.begin_transaction( - project_id=GCP_PROJECT_ID, - transaction_options={}, + project_id=GCP_PROJECT_ID, transaction_options={}, ) projects = self.datastore_hook.connection.projects projects.assert_called_once_with() begin_transaction = projects.return_value.beginTransaction - begin_transaction.assert_called_once_with( - projectId=GCP_PROJECT_ID, body={'transactionOptions': {}} - ) + begin_transaction.assert_called_once_with(projectId=GCP_PROJECT_ID, body={'transactionOptions': {}}) execute = begin_transaction.return_value.execute execute.assert_called_once_with(num_retries=mock.ANY) self.assertEqual(transaction, execute.return_value['transaction']) - @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', - new_callable=mock.PropertyMock, return_value=None) + @patch( + 'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', + new_callable=mock.PropertyMock, + return_value=None, + ) 
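The Datastore tests below patch the hook's `project_id` property with `new_callable=mock.PropertyMock, return_value=None` to drive the missing-project error path. A minimal, self-contained sketch of that pattern, assuming an invented `Hook` class rather than the Airflow hook itself:

from unittest import mock


class Hook:
    """Stand-in for a hook whose project_id normally comes from the connection."""

    @property
    def project_id(self):
        return "real-project"

    def export(self):
        # Mirrors the guard the tests exercise: fail fast without a project id.
        if self.project_id is None:
            raise ValueError("project_id is required")
        return self.project_id


# PropertyMock replaces the property on the class, so reads of
# hook.project_id return None for the duration of the patch.
with mock.patch.object(Hook, "project_id", new_callable=mock.PropertyMock, return_value=None):
    hook = Hook()
    try:
        hook.export()
    except ValueError as err:
        assert "project_id" in str(err)
    else:
        raise AssertionError("expected ValueError when project_id is None")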
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn') def test_begin_transaction_no_project_id(self, mock_get_conn, mock_project_id): self.datastore_hook.connection = mock_get_conn.return_value @@ -125,8 +126,11 @@ def test_commit(self, mock_get_conn): execute.assert_called_once_with(num_retries=mock.ANY) self.assertEqual(resp, execute.return_value) - @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', - new_callable=mock.PropertyMock, return_value=None) + @patch( + 'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', + new_callable=mock.PropertyMock, + return_value=None, + ) @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn') def test_commit_no_project_id(self, mock_get_conn, mock_project_id): self.datastore_hook.connection = mock_get_conn.return_value @@ -143,27 +147,26 @@ def test_lookup(self, mock_get_conn): read_consistency = 'ENUM' transaction = 'transaction' - resp = self.datastore_hook.lookup(keys=keys, - read_consistency=read_consistency, - transaction=transaction, - project_id=GCP_PROJECT_ID - ) + resp = self.datastore_hook.lookup( + keys=keys, read_consistency=read_consistency, transaction=transaction, project_id=GCP_PROJECT_ID + ) projects = self.datastore_hook.connection.projects projects.assert_called_once_with() lookup = projects.return_value.lookup - lookup.assert_called_once_with(projectId=GCP_PROJECT_ID, - body={ - 'keys': keys, - 'readConsistency': read_consistency, - 'transaction': transaction - }) + lookup.assert_called_once_with( + projectId=GCP_PROJECT_ID, + body={'keys': keys, 'readConsistency': read_consistency, 'transaction': transaction}, + ) execute = lookup.return_value.execute execute.assert_called_once_with(num_retries=mock.ANY) self.assertEqual(resp, execute.return_value) - @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', - new_callable=mock.PropertyMock, return_value=None) + @patch( + 'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', + new_callable=mock.PropertyMock, + return_value=None, + ) @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn') def test_lookup_no_project_id(self, mock_get_conn, mock_project_id): self.datastore_hook.connection = mock_get_conn.return_value @@ -172,10 +175,9 @@ def test_lookup_no_project_id(self, mock_get_conn, mock_project_id): transaction = 'transaction' with self.assertRaises(AirflowException) as err: - self.datastore_hook.lookup(keys=keys, # pylint: disable=no-value-for-parameter - read_consistency=read_consistency, - transaction=transaction, - ) + self.datastore_hook.lookup( # pylint: disable=no-value-for-parameter + keys=keys, read_consistency=read_consistency, transaction=transaction, + ) self.assertIn("project_id", str(err.exception)) @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn') @@ -188,13 +190,15 @@ def test_rollback(self, mock_get_conn): projects = self.datastore_hook.connection.projects projects.assert_called_once_with() rollback = projects.return_value.rollback - rollback.assert_called_once_with(projectId=GCP_PROJECT_ID, - body={'transaction': transaction}) + rollback.assert_called_once_with(projectId=GCP_PROJECT_ID, body={'transaction': transaction}) execute = rollback.return_value.execute execute.assert_called_once_with(num_retries=mock.ANY) - @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', - new_callable=mock.PropertyMock, return_value=None) + @patch( + 
'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', + new_callable=mock.PropertyMock, + return_value=None, + ) @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn') def test_rollback_no_project_id(self, mock_get_conn, mock_project_id): self.datastore_hook.connection = mock_get_conn.return_value @@ -219,8 +223,11 @@ def test_run_query(self, mock_get_conn): execute.assert_called_once_with(num_retries=mock.ANY) self.assertEqual(resp, execute.return_value['batch']) - @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', - new_callable=mock.PropertyMock, return_value=None) + @patch( + 'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', + new_callable=mock.PropertyMock, + return_value=None, + ) @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn') def test_run_query_no_project_id(self, mock_get_conn, mock_project_id): self.datastore_hook.connection = mock_get_conn.return_value @@ -265,11 +272,13 @@ def test_delete_operation(self, mock_get_conn): self.assertEqual(resp, execute.return_value) @patch('airflow.providers.google.cloud.hooks.datastore.time.sleep') - @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_operation', - side_effect=[ - {'metadata': {'common': {'state': 'PROCESSING'}}}, - {'metadata': {'common': {'state': 'NOT PROCESSING'}}} - ]) + @patch( + 'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_operation', + side_effect=[ + {'metadata': {'common': {'state': 'PROCESSING'}}}, + {'metadata': {'common': {'state': 'NOT PROCESSING'}}}, + ], + ) def test_poll_operation_until_done(self, mock_get_operation, mock_time_sleep): name = 'name' polling_interval_in_seconds = 10 @@ -288,30 +297,34 @@ def test_export_to_storage_bucket(self, mock_get_conn): entity_filter = {} labels = {} - resp = self.datastore_hook.export_to_storage_bucket(bucket=bucket, - namespace=namespace, - entity_filter=entity_filter, - labels=labels, - project_id=GCP_PROJECT_ID - ) + resp = self.datastore_hook.export_to_storage_bucket( + bucket=bucket, + namespace=namespace, + entity_filter=entity_filter, + labels=labels, + project_id=GCP_PROJECT_ID, + ) projects = self.datastore_hook.admin_connection.projects projects.assert_called_once_with() export = projects.return_value.export - export.assert_called_once_with(projectId=GCP_PROJECT_ID, - body={ - 'outputUrlPrefix': 'gs://' + '/'.join( - filter(None, [bucket, namespace]) - ), - 'entityFilter': entity_filter, - 'labels': labels, - }) + export.assert_called_once_with( + projectId=GCP_PROJECT_ID, + body={ + 'outputUrlPrefix': 'gs://' + '/'.join(filter(None, [bucket, namespace])), + 'entityFilter': entity_filter, + 'labels': labels, + }, + ) execute = export.return_value.execute execute.assert_called_once_with(num_retries=mock.ANY) self.assertEqual(resp, execute.return_value) - @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', - new_callable=mock.PropertyMock, return_value=None) + @patch( + 'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', + new_callable=mock.PropertyMock, + return_value=None, + ) @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn') def test_export_to_storage_bucket_no_project_id(self, mock_get_conn, mock_project_id): self.datastore_hook.admin_connection = mock_get_conn.return_value @@ -322,10 +335,7 @@ def test_export_to_storage_bucket_no_project_id(self, mock_get_conn, mock_projec with 
self.assertRaises(AirflowException) as err: self.datastore_hook.export_to_storage_bucket( # pylint: disable=no-value-for-parameter - bucket=bucket, - namespace=namespace, - entity_filter=entity_filter, - labels=labels, + bucket=bucket, namespace=namespace, entity_filter=entity_filter, labels=labels, ) self.assertIn("project_id", str(err.exception)) @@ -338,31 +348,35 @@ def test_import_from_storage_bucket(self, mock_get_conn): entity_filter = {} labels = {} - resp = self.datastore_hook.import_from_storage_bucket(bucket=bucket, - file=file, - namespace=namespace, - entity_filter=entity_filter, - labels=labels, - project_id=GCP_PROJECT_ID - ) + resp = self.datastore_hook.import_from_storage_bucket( + bucket=bucket, + file=file, + namespace=namespace, + entity_filter=entity_filter, + labels=labels, + project_id=GCP_PROJECT_ID, + ) projects = self.datastore_hook.admin_connection.projects projects.assert_called_once_with() import_ = projects.return_value.import_ - import_.assert_called_once_with(projectId=GCP_PROJECT_ID, - body={ - 'inputUrl': 'gs://' + '/'.join( - filter(None, [bucket, namespace, file]) - ), - 'entityFilter': entity_filter, - 'labels': labels, - }) + import_.assert_called_once_with( + projectId=GCP_PROJECT_ID, + body={ + 'inputUrl': 'gs://' + '/'.join(filter(None, [bucket, namespace, file])), + 'entityFilter': entity_filter, + 'labels': labels, + }, + ) execute = import_.return_value.execute execute.assert_called_once_with(num_retries=mock.ANY) self.assertEqual(resp, execute.return_value) - @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', - new_callable=mock.PropertyMock, return_value=None) + @patch( + 'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id', + new_callable=mock.PropertyMock, + return_value=None, + ) @patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn') def test_import_from_storage_bucket_no_project_id(self, mock_get_conn, mock_project_id): self.datastore_hook.admin_connection = mock_get_conn.return_value @@ -374,10 +388,6 @@ def test_import_from_storage_bucket_no_project_id(self, mock_get_conn, mock_proj with self.assertRaises(AirflowException) as err: self.datastore_hook.import_from_storage_bucket( # pylint: disable=no-value-for-parameter - bucket=bucket, - file=file, - namespace=namespace, - entity_filter=entity_filter, - labels=labels, + bucket=bucket, file=file, namespace=namespace, entity_filter=entity_filter, labels=labels, ) self.assertIn("project_id", str(err.exception)) diff --git a/tests/providers/google/cloud/hooks/test_dlp.py b/tests/providers/google/cloud/hooks/test_dlp.py index 380e849018aef..f29be16401d9a 100644 --- a/tests/providers/google/cloud/hooks/test_dlp.py +++ b/tests/providers/google/cloud/hooks/test_dlp.py @@ -52,15 +52,9 @@ STORED_INFO_TYPE_ORGANIZATION_PATH = "organizations/{}/storedInfoTypes/{}".format( ORGANIZATION_ID, STORED_INFO_TYPE_ID ) -DEIDENTIFY_TEMPLATE_PROJECT_PATH = "projects/{}/deidentifyTemplates/{}".format( - PROJECT_ID, TEMPLATE_ID -) -INSPECT_TEMPLATE_PROJECT_PATH = "projects/{}/inspectTemplates/{}".format( - PROJECT_ID, TEMPLATE_ID -) -STORED_INFO_TYPE_PROJECT_PATH = "projects/{}/storedInfoTypes/{}".format( - PROJECT_ID, STORED_INFO_TYPE_ID -) +DEIDENTIFY_TEMPLATE_PROJECT_PATH = "projects/{}/deidentifyTemplates/{}".format(PROJECT_ID, TEMPLATE_ID) +INSPECT_TEMPLATE_PROJECT_PATH = "projects/{}/inspectTemplates/{}".format(PROJECT_ID, TEMPLATE_ID) +STORED_INFO_TYPE_PROJECT_PATH = "projects/{}/storedInfoTypes/{}".format(PROJECT_ID, 
STORED_INFO_TYPE_ID) JOB_TRIGGER_PATH = "projects/{}/jobTriggers/{}".format(PROJECT_ID, TRIGGER_ID) @@ -73,16 +67,14 @@ def setUp(self): self.hook = CloudDLPHook(gcp_conn_id="test") @mock.patch( - "airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.client_info", - new_callable=mock.PropertyMock + "airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.client_info", new_callable=mock.PropertyMock ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook._get_credentials") @mock.patch("airflow.providers.google.cloud.hooks.dlp.DlpServiceClient") def test_dlp_service_client_creation(self, mock_client, mock_get_creds, mock_client_info): result = self.hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.hook._client, result) @@ -103,7 +95,7 @@ def test_cancel_dlp_job_without_dlp_job_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_cancel_dlp_job_without_parent(self, _, mock_project_id): @@ -113,7 +105,7 @@ def test_cancel_dlp_job_without_parent(self, _, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_create_deidentify_template_with_org_id(self, get_conn, mock_project_id): @@ -148,7 +140,7 @@ def test_create_deidentify_template_with_project_id(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_create_deidentify_template_without_parent(self, _, mock_project_id): @@ -158,9 +150,7 @@ def test_create_deidentify_template_without_parent(self, _, mock_project_id): @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_create_dlp_job(self, get_conn): get_conn.return_value.create_dlp_job.return_value = API_RESPONSE - result = self.hook.create_dlp_job( - project_id=PROJECT_ID, wait_until_finished=False - ) + result = self.hook.create_dlp_job(project_id=PROJECT_ID, wait_until_finished=False) self.assertIs(result, API_RESPONSE) get_conn.return_value.create_dlp_job.assert_called_once_with( @@ -176,7 +166,7 @@ def test_create_dlp_job(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_create_dlp_job_without_project_id(self, mock_get_conn, mock_project_id): @@ -199,7 +189,7 @@ def test_create_dlp_job_with_wait_until_finished(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_create_inspect_template_with_org_id(self, get_conn, mock_project_id): @@ -234,7 +224,7 @@ def 
test_create_inspect_template_with_project_id(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_create_inspect_template_without_parent(self, _, mock_project_id): @@ -248,22 +238,15 @@ def test_create_job_trigger(self, get_conn): self.assertIs(result, API_RESPONSE) get_conn.return_value.create_job_trigger.assert_called_once_with( - parent=PROJECT_PATH, - job_trigger=None, - trigger_id=None, - retry=None, - timeout=None, - metadata=None, + parent=PROJECT_PATH, job_trigger=None, trigger_id=None, retry=None, timeout=None, metadata=None, ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_create_job_trigger_without_parent(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException): self.hook.create_job_trigger() # pylint: disable=no-value-for-parameter @@ -271,7 +254,7 @@ def test_create_job_trigger_without_parent(self, mock_get_conn, mock_project_id) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_create_stored_info_type_with_org_id(self, get_conn, mock_project_id): @@ -306,11 +289,9 @@ def test_create_stored_info_type_with_project_id(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_create_stored_info_type_without_parent(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException): self.hook.create_stored_info_type() @@ -336,11 +317,9 @@ def test_deidentify_content(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_deidentify_content_without_parent(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException): self.hook.deidentify_content() # pylint: disable=no-value-for-parameter @@ -348,34 +327,22 @@ def test_deidentify_content_without_parent(self, mock_get_conn, mock_project_id) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_deidentify_template_with_org_id(self, get_conn, mock_project_id): - self.hook.delete_deidentify_template( - template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID - ) + self.hook.delete_deidentify_template(template_id=TEMPLATE_ID, 
organization_id=ORGANIZATION_ID) get_conn.return_value.delete_deidentify_template.assert_called_once_with( - name=DEIDENTIFY_TEMPLATE_ORGANIZATION_PATH, - retry=None, - timeout=None, - metadata=None, + name=DEIDENTIFY_TEMPLATE_ORGANIZATION_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_deidentify_template_with_project_id(self, get_conn): - self.hook.delete_deidentify_template( - template_id=TEMPLATE_ID, project_id=PROJECT_ID - ) + self.hook.delete_deidentify_template(template_id=TEMPLATE_ID, project_id=PROJECT_ID) get_conn.return_value.delete_deidentify_template.assert_called_once_with( - name=DEIDENTIFY_TEMPLATE_PROJECT_PATH, - retry=None, - timeout=None, - metadata=None, + name=DEIDENTIFY_TEMPLATE_PROJECT_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") @@ -386,7 +353,7 @@ def test_delete_deidentify_template_without_template_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_deidentify_template_without_parent(self, mock_get_conn, mock_project_id): @@ -409,7 +376,7 @@ def test_delete_dlp_job_without_dlp_job_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_dlp_job_without_parent(self, mock_get_conn, mock_project_id): @@ -419,32 +386,22 @@ def test_delete_dlp_job_without_parent(self, mock_get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_inspect_template_with_org_id(self, get_conn, mock_project_id): - self.hook.delete_inspect_template( - template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID - ) + self.hook.delete_inspect_template(template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID) get_conn.return_value.delete_inspect_template.assert_called_once_with( - name=INSPECT_TEMPLATE_ORGANIZATION_PATH, - retry=None, - timeout=None, - metadata=None, + name=INSPECT_TEMPLATE_ORGANIZATION_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_inspect_template_with_project_id(self, get_conn): - self.hook.delete_inspect_template( - template_id=TEMPLATE_ID, project_id=PROJECT_ID - ) + self.hook.delete_inspect_template(template_id=TEMPLATE_ID, project_id=PROJECT_ID) get_conn.return_value.delete_inspect_template.assert_called_once_with( - name=INSPECT_TEMPLATE_PROJECT_PATH, - retry=None, - timeout=None, - metadata=None, + name=INSPECT_TEMPLATE_PROJECT_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") @@ -455,7 +412,7 @@ def test_delete_inspect_template_without_template_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) 
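Most of the churn in this file is black collapsing multi-line `assert_called_once_with(...)` calls onto one line. That reflow is behavior-preserving because mock compares the recorded call's arguments rather than the source layout; a tiny sketch with an illustrative client method:

from unittest import mock

client = mock.MagicMock()
client.delete_stored_info_type(
    name="organizations/1/storedInfoTypes/2", retry=None, timeout=None, metadata=None
)

# Keyword arguments are compared as a mapping, so neither their order nor
# how the call is wrapped across lines affects the assertion.
client.delete_stored_info_type.assert_called_once_with(
    metadata=None, timeout=None, retry=None, name="organizations/1/storedInfoTypes/2",
)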
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_inspect_template_without_parent(self, mock_get_conn, mock_project_id): @@ -478,7 +435,7 @@ def test_delete_job_trigger_without_trigger_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_job_trigger_without_parent(self, mock_get_conn, mock_project_id): @@ -488,7 +445,7 @@ def test_delete_job_trigger_without_parent(self, mock_get_conn, mock_project_id) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_stored_info_type_with_org_id(self, get_conn, mock_project_id): @@ -497,23 +454,15 @@ def test_delete_stored_info_type_with_org_id(self, get_conn, mock_project_id): ) get_conn.return_value.delete_stored_info_type.assert_called_once_with( - name=STORED_INFO_TYPE_ORGANIZATION_PATH, - retry=None, - timeout=None, - metadata=None, + name=STORED_INFO_TYPE_ORGANIZATION_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_stored_info_type_with_project_id(self, get_conn): - self.hook.delete_stored_info_type( - stored_info_type_id=STORED_INFO_TYPE_ID, project_id=PROJECT_ID - ) + self.hook.delete_stored_info_type(stored_info_type_id=STORED_INFO_TYPE_ID, project_id=PROJECT_ID) get_conn.return_value.delete_stored_info_type.assert_called_once_with( - name=STORED_INFO_TYPE_PROJECT_PATH, - retry=None, - timeout=None, - metadata=None, + name=STORED_INFO_TYPE_PROJECT_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") @@ -524,7 +473,7 @@ def test_delete_stored_info_type_without_stored_info_type_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_delete_stored_info_type_without_parent(self, mock_get_conn, mock_project_id): @@ -534,36 +483,26 @@ def test_delete_stored_info_type_without_parent(self, mock_get_conn, mock_projec @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_deidentify_template_with_org_id(self, get_conn, mock_project_id): get_conn.return_value.get_deidentify_template.return_value = API_RESPONSE - result = self.hook.get_deidentify_template( - template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID - ) + result = self.hook.get_deidentify_template(template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID) self.assertIs(result, API_RESPONSE) get_conn.return_value.get_deidentify_template.assert_called_once_with( - name=DEIDENTIFY_TEMPLATE_ORGANIZATION_PATH, - retry=None, - timeout=None, - metadata=None, + name=DEIDENTIFY_TEMPLATE_ORGANIZATION_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_deidentify_template_with_project_id(self, 
get_conn): get_conn.return_value.get_deidentify_template.return_value = API_RESPONSE - result = self.hook.get_deidentify_template( - template_id=TEMPLATE_ID, project_id=PROJECT_ID - ) + result = self.hook.get_deidentify_template(template_id=TEMPLATE_ID, project_id=PROJECT_ID) self.assertIs(result, API_RESPONSE) get_conn.return_value.get_deidentify_template.assert_called_once_with( - name=DEIDENTIFY_TEMPLATE_PROJECT_PATH, - retry=None, - timeout=None, - metadata=None, + name=DEIDENTIFY_TEMPLATE_PROJECT_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") @@ -574,7 +513,7 @@ def test_get_deidentify_template_without_template_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_deidentify_template_without_parent(self, mock_get_conn, mock_project_id): @@ -599,7 +538,7 @@ def test_get_dlp_job_without_dlp_job_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_dlp_job_without_parent(self, mock_get_conn, mock_project_id): @@ -609,36 +548,26 @@ def test_get_dlp_job_without_parent(self, mock_get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_inspect_template_with_org_id(self, get_conn, mock_project_id): get_conn.return_value.get_inspect_template.return_value = API_RESPONSE - result = self.hook.get_inspect_template( - template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID - ) + result = self.hook.get_inspect_template(template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID) self.assertIs(result, API_RESPONSE) get_conn.return_value.get_inspect_template.assert_called_once_with( - name=INSPECT_TEMPLATE_ORGANIZATION_PATH, - retry=None, - timeout=None, - metadata=None, + name=INSPECT_TEMPLATE_ORGANIZATION_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_inspect_template_with_project_id(self, get_conn): get_conn.return_value.get_inspect_template.return_value = API_RESPONSE - result = self.hook.get_inspect_template( - template_id=TEMPLATE_ID, project_id=PROJECT_ID - ) + result = self.hook.get_inspect_template(template_id=TEMPLATE_ID, project_id=PROJECT_ID) self.assertIs(result, API_RESPONSE) get_conn.return_value.get_inspect_template.assert_called_once_with( - name=INSPECT_TEMPLATE_PROJECT_PATH, - retry=None, - timeout=None, - metadata=None, + name=INSPECT_TEMPLATE_PROJECT_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") @@ -649,7 +578,7 @@ def test_get_inspect_template_without_template_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_inspect_template_without_parent(self, mock_get_conn, mock_project_id): @@ 
-659,9 +588,7 @@ def test_get_inspect_template_without_parent(self, mock_get_conn, mock_project_i @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_job_trigger(self, get_conn): get_conn.return_value.get_job_trigger.return_value = API_RESPONSE - result = self.hook.get_job_trigger( - job_trigger_id=TRIGGER_ID, project_id=PROJECT_ID - ) + result = self.hook.get_job_trigger(job_trigger_id=TRIGGER_ID, project_id=PROJECT_ID) self.assertIs(result, API_RESPONSE) get_conn.return_value.get_job_trigger.assert_called_once_with( @@ -676,7 +603,7 @@ def test_get_job_trigger_without_trigger_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_job_trigger_without_parent(self, mock_get_conn, mock_project_id): @@ -686,7 +613,7 @@ def test_get_job_trigger_without_parent(self, mock_get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_stored_info_type_with_org_id(self, get_conn, mock_project_id): @@ -697,10 +624,7 @@ def test_get_stored_info_type_with_org_id(self, get_conn, mock_project_id): self.assertIs(result, API_RESPONSE) get_conn.return_value.get_stored_info_type.assert_called_once_with( - name=STORED_INFO_TYPE_ORGANIZATION_PATH, - retry=None, - timeout=None, - metadata=None, + name=STORED_INFO_TYPE_ORGANIZATION_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") @@ -712,10 +636,7 @@ def test_get_stored_info_type_with_project_id(self, get_conn): self.assertIs(result, API_RESPONSE) get_conn.return_value.get_stored_info_type.assert_called_once_with( - name=STORED_INFO_TYPE_PROJECT_PATH, - retry=None, - timeout=None, - metadata=None, + name=STORED_INFO_TYPE_PROJECT_PATH, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") @@ -726,7 +647,7 @@ def test_get_stored_info_type_without_stored_info_type_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_get_stored_info_type_without_parent(self, mock_get_get_conn, mock_project_id): @@ -752,7 +673,7 @@ def test_inspect_content(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_inspect_content_without_parent(self, mock_get_conn, mock_project_id): @@ -762,7 +683,7 @@ def test_inspect_content_without_parent(self, mock_get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_list_deidentify_templates_with_org_id(self, get_conn, mock_project_id): @@ -770,12 +691,7 @@ def test_list_deidentify_templates_with_org_id(self, 
get_conn, mock_project_id): self.assertIsInstance(result, list) get_conn.return_value.list_deidentify_templates.assert_called_once_with( - parent=ORGANIZATION_PATH, - page_size=None, - order_by=None, - retry=None, - timeout=None, - metadata=None, + parent=ORGANIZATION_PATH, page_size=None, order_by=None, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") @@ -784,18 +700,13 @@ def test_list_deidentify_templates_with_project_id(self, get_conn): self.assertIsInstance(result, list) get_conn.return_value.list_deidentify_templates.assert_called_once_with( - parent=PROJECT_PATH, - page_size=None, - order_by=None, - retry=None, - timeout=None, - metadata=None, + parent=PROJECT_PATH, page_size=None, order_by=None, retry=None, timeout=None, metadata=None, ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_list_deidentify_templates_without_parent(self, mock_get_conn, mock_project_id): @@ -821,7 +732,7 @@ def test_list_dlp_jobs(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_list_dlp_jobs_without_parent(self, mock_get_conn, mock_project_id): @@ -841,7 +752,7 @@ def test_list_info_types(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_list_inspect_templates_with_org_id(self, get_conn, mock_project_id): @@ -849,12 +760,7 @@ def test_list_inspect_templates_with_org_id(self, get_conn, mock_project_id): self.assertIsInstance(result, list) get_conn.return_value.list_inspect_templates.assert_called_once_with( - parent=ORGANIZATION_PATH, - page_size=None, - order_by=None, - retry=None, - timeout=None, - metadata=None, + parent=ORGANIZATION_PATH, page_size=None, order_by=None, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") @@ -863,18 +769,13 @@ def test_list_inspect_templates_with_project_id(self, get_conn): self.assertIsInstance(result, list) get_conn.return_value.list_inspect_templates.assert_called_once_with( - parent=PROJECT_PATH, - page_size=None, - order_by=None, - retry=None, - timeout=None, - metadata=None, + parent=PROJECT_PATH, page_size=None, order_by=None, retry=None, timeout=None, metadata=None, ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_list_inspect_templates_without_parent(self, mock_get_conn, mock_project_id): @@ -899,7 +800,7 @@ def test_list_job_triggers(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_list_job_triggers_without_parent(self, mock_get_conn, mock_project_id): @@ -909,7 +810,7 @@ def 
test_list_job_triggers_without_parent(self, mock_get_conn, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_list_stored_info_types_with_org_id(self, get_conn, mock_project_id): @@ -917,12 +818,7 @@ def test_list_stored_info_types_with_org_id(self, get_conn, mock_project_id): self.assertIsInstance(result, list) get_conn.return_value.list_stored_info_types.assert_called_once_with( - parent=ORGANIZATION_PATH, - page_size=None, - order_by=None, - retry=None, - timeout=None, - metadata=None, + parent=ORGANIZATION_PATH, page_size=None, order_by=None, retry=None, timeout=None, metadata=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") @@ -931,18 +827,13 @@ def test_list_stored_info_types_with_project_id(self, get_conn): self.assertIsInstance(result, list) get_conn.return_value.list_stored_info_types.assert_called_once_with( - parent=PROJECT_PATH, - page_size=None, - order_by=None, - retry=None, - timeout=None, - metadata=None, + parent=PROJECT_PATH, page_size=None, order_by=None, retry=None, timeout=None, metadata=None, ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_list_stored_info_types_without_parent(self, mock_get_conn, mock_project_id): @@ -969,7 +860,7 @@ def test_redact_image(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_redact_image_without_parent(self, mock_get_conn, mock_project_id): @@ -997,11 +888,9 @@ def test_reidentify_content(self, get_conn): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_reidentify_content_without_parent(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException): self.hook.reidentify_content() # pylint: disable=no-value-for-parameter @@ -1009,7 +898,7 @@ def test_reidentify_content_without_parent(self, mock_get_conn, mock_project_id) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_deidentify_template_with_org_id(self, get_conn, mock_project_id): @@ -1031,9 +920,7 @@ def test_update_deidentify_template_with_org_id(self, get_conn, mock_project_id) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_deidentify_template_with_project_id(self, get_conn): get_conn.return_value.update_deidentify_template.return_value = API_RESPONSE - result = self.hook.update_deidentify_template( - template_id=TEMPLATE_ID, project_id=PROJECT_ID - ) + result = self.hook.update_deidentify_template(template_id=TEMPLATE_ID, project_id=PROJECT_ID) self.assertIs(result, 
API_RESPONSE) get_conn.return_value.update_deidentify_template.assert_called_once_with( @@ -1048,14 +935,12 @@ def test_update_deidentify_template_with_project_id(self, get_conn): @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_deidentify_template_without_template_id(self, _): with self.assertRaises(AirflowException): - self.hook.update_deidentify_template( - template_id=None, organization_id=ORGANIZATION_ID - ) + self.hook.update_deidentify_template(template_id=None, organization_id=ORGANIZATION_ID) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_deidentify_template_without_parent(self, mock_get_conn, mock_project_id): @@ -1065,14 +950,12 @@ def test_update_deidentify_template_without_parent(self, mock_get_conn, mock_pro @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_inspect_template_with_org_id(self, get_conn, mock_project_id): get_conn.return_value.update_inspect_template.return_value = API_RESPONSE - result = self.hook.update_inspect_template( - template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID - ) + result = self.hook.update_inspect_template(template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID) self.assertIs(result, API_RESPONSE) get_conn.return_value.update_inspect_template.assert_called_once_with( @@ -1087,9 +970,7 @@ def test_update_inspect_template_with_org_id(self, get_conn, mock_project_id): @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_inspect_template_with_project_id(self, get_conn): get_conn.return_value.update_inspect_template.return_value = API_RESPONSE - result = self.hook.update_inspect_template( - template_id=TEMPLATE_ID, project_id=PROJECT_ID - ) + result = self.hook.update_inspect_template(template_id=TEMPLATE_ID, project_id=PROJECT_ID) self.assertIs(result, API_RESPONSE) get_conn.return_value.update_inspect_template.assert_called_once_with( @@ -1104,14 +985,12 @@ def test_update_inspect_template_with_project_id(self, get_conn): @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_inspect_template_without_template_id(self, _): with self.assertRaises(AirflowException): - self.hook.update_inspect_template( - template_id=None, organization_id=ORGANIZATION_ID - ) + self.hook.update_inspect_template(template_id=None, organization_id=ORGANIZATION_ID) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_inspect_template_without_parent(self, mock_get_conn, mock_project_id): @@ -1121,9 +1000,7 @@ def test_update_inspect_template_without_parent(self, mock_get_conn, mock_projec @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_job_trigger(self, get_conn): get_conn.return_value.update_job_trigger.return_value = API_RESPONSE - result = self.hook.update_job_trigger( - job_trigger_id=TRIGGER_ID, project_id=PROJECT_ID - ) + result = self.hook.update_job_trigger(job_trigger_id=TRIGGER_ID, 
project_id=PROJECT_ID) self.assertIs(result, API_RESPONSE) get_conn.return_value.update_job_trigger.assert_called_once_with( @@ -1143,7 +1020,7 @@ def test_update_job_trigger_without_job_trigger_id(self, _): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_job_trigger_without_parent(self, mock_get_conn, mock_project_id): @@ -1153,7 +1030,7 @@ def test_update_job_trigger_without_parent(self, mock_get_conn, mock_project_id) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_stored_info_type_with_org_id(self, get_conn, mock_project_id): @@ -1189,23 +1066,17 @@ def test_update_stored_info_type_with_project_id(self, get_conn): metadata=None, ) - @mock.patch( - "airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn" - ) + @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_stored_info_type_without_stored_info_type_id(self, _): with self.assertRaises(AirflowException): - self.hook.update_stored_info_type( - stored_info_type_id=None, organization_id=ORGANIZATION_ID - ) + self.hook.update_stored_info_type(stored_info_type_id=None, organization_id=ORGANIZATION_ID) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None - ) - @mock.patch( - "airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn" + return_value=None, ) + @mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn") def test_update_stored_info_type_without_parent(self, mock_get_conn, mock_project_id): with self.assertRaises(AirflowException): self.hook.update_stored_info_type(stored_info_type_id=STORED_INFO_TYPE_ID) diff --git a/tests/providers/google/cloud/hooks/test_functions.py b/tests/providers/google/cloud/hooks/test_functions.py index e656aaeee9e0c..5b0fcbb054a5a 100644 --- a/tests/providers/google/cloud/hooks/test_functions.py +++ b/tests/providers/google/cloud/hooks/test_functions.py @@ -24,7 +24,9 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.functions import CloudFunctionsHook from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, get_open_mock, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + get_open_mock, + mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, ) @@ -33,10 +35,11 @@ class TestFunctionHookNoDefaultProjectId(unittest.TestCase): - def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_no_default_project_id): + with mock.patch( + 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_no_default_project_id, + ): self.gcf_function_hook_no_project_id = CloudFunctionsHook(gcp_conn_id='test', api_version='v1') @mock.patch("airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook._authorize") @@ -54,19 +57,17 @@ def test_gcf_client_creation(self, mock_build, mock_authorize): 'airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook._wait_for_operation_to_complete' ) def 
test_create_new_function_overridden_project_id(self, wait_for_operation_to_complete, get_conn): - create_method = get_conn.return_value.projects.return_value.locations. \ - return_value.functions.return_value.create + create_method = ( + get_conn.return_value.projects.return_value.locations.return_value.functions.return_value.create + ) execute_method = create_method.return_value.execute execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gcf_function_hook_no_project_id.create_new_function( - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, - location=GCF_LOCATION, - body={} + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, location=GCF_LOCATION, body={} ) self.assertIsNone(res) - create_method.assert_called_once_with(body={}, - location='projects/example-project/locations/location') + create_method.assert_called_once_with(body={}, location='projects/example-project/locations/location') execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id') @@ -75,32 +76,34 @@ def test_create_new_function_overridden_project_id(self, wait_for_operation_to_c def test_upload_function_zip_overridden_project_id(self, get_conn, requests_put): mck, open_module = get_open_mock() with mock.patch('{}.open'.format(open_module), mck): + # fmt: off generate_upload_url_method = get_conn.return_value.projects.return_value.locations. \ return_value.functions.return_value.generateUploadUrl + # fmt: on execute_method = generate_upload_url_method.return_value.execute execute_method.return_value = {"uploadUrl": "http://uploadHere"} requests_put.return_value = None res = self.gcf_function_hook_no_project_id.upload_function_zip( - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, - location=GCF_LOCATION, - zip_path="/tmp/path.zip" + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, location=GCF_LOCATION, zip_path="/tmp/path.zip" ) self.assertEqual("http://uploadHere", res) generate_upload_url_method.assert_called_once_with( - parent='projects/example-project/locations/location') + parent='projects/example-project/locations/location' + ) execute_method.assert_called_once_with(num_retries=5) requests_put.assert_called_once_with( data=mock.ANY, - headers={'Content-type': 'application/zip', - 'x-goog-content-length-range': '0,104857600'}, - url='http://uploadHere' + headers={'Content-type': 'application/zip', 'x-goog-content-length-range': '0,104857600'}, + url='http://uploadHere', ) class TestFunctionHookDefaultProjectId(unittest.TestCase): def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_default_project_id): + with mock.patch( + 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_default_project_id, + ): self.gcf_function_hook = CloudFunctionsHook(gcp_conn_id='test', api_version='v1') @mock.patch("airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook._authorize") @@ -116,26 +119,24 @@ def test_gcf_client_creation(self, mock_build, mock_authorize): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook.get_conn') @mock.patch( 'airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook._wait_for_operation_to_complete' ) def 
test_create_new_function(self, wait_for_operation_to_complete, get_conn, mock_project_id): - create_method = get_conn.return_value.projects.return_value.locations.\ - return_value.functions.return_value.create + create_method = ( + get_conn.return_value.projects.return_value.locations.return_value.functions.return_value.create + ) execute_method = create_method.return_value.execute execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gcf_function_hook.create_new_function( - location=GCF_LOCATION, - body={}, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + location=GCF_LOCATION, body={}, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) self.assertIsNone(res) - create_method.assert_called_once_with(body={}, - location='projects/example-project/locations/location') + create_method.assert_called_once_with(body={}, location='projects/example-project/locations/location') execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id') @@ -144,31 +145,28 @@ def test_create_new_function(self, wait_for_operation_to_complete, get_conn, moc 'airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook._wait_for_operation_to_complete' ) def test_create_new_function_override_project_id(self, wait_for_operation_to_complete, get_conn): - create_method = get_conn.return_value.projects.return_value.locations. \ - return_value.functions.return_value.create + create_method = ( + get_conn.return_value.projects.return_value.locations.return_value.functions.return_value.create + ) execute_method = create_method.return_value.execute execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gcf_function_hook.create_new_function( - project_id='new-project', - location=GCF_LOCATION, - body={} + project_id='new-project', location=GCF_LOCATION, body={} ) self.assertIsNone(res) - create_method.assert_called_once_with(body={}, - location='projects/new-project/locations/location') + create_method.assert_called_once_with(body={}, location='projects/new-project/locations/location') execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id') @mock.patch('airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook.get_conn') def test_get_function(self, get_conn): - get_method = get_conn.return_value.projects.return_value.locations. \ - return_value.functions.return_value.get + get_method = ( + get_conn.return_value.projects.return_value.locations.return_value.functions.return_value.get + ) execute_method = get_method.return_value.execute execute_method.return_value = {"name": "function"} - res = self.gcf_function_hook.get_function( - name=GCF_FUNCTION - ) + res = self.gcf_function_hook.get_function(name=GCF_FUNCTION) self.assertIsNotNone(res) self.assertEqual('function', res['name']) get_method.assert_called_once_with(name='function') @@ -179,8 +177,9 @@ def test_get_function(self, get_conn): 'airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook._wait_for_operation_to_complete' ) def test_delete_function(self, wait_for_operation_to_complete, get_conn): - delete_method = get_conn.return_value.projects.return_value.locations. 
\ - return_value.functions.return_value.delete + delete_method = ( + get_conn.return_value.projects.return_value.locations.return_value.functions.return_value.delete + ) execute_method = delete_method.return_value.execute wait_for_operation_to_complete.return_value = None execute_method.return_value = {"name": "operation_id"} @@ -196,54 +195,49 @@ def test_delete_function(self, wait_for_operation_to_complete, get_conn): 'airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook._wait_for_operation_to_complete' ) def test_update_function(self, wait_for_operation_to_complete, get_conn): - patch_method = get_conn.return_value.projects.return_value.locations. \ - return_value.functions.return_value.patch + patch_method = ( + get_conn.return_value.projects.return_value.locations.return_value.functions.return_value.patch + ) execute_method = patch_method.return_value.execute execute_method.return_value = {"name": "operation_id"} wait_for_operation_to_complete.return_value = None res = self.gcf_function_hook.update_function( # pylint: disable=assignment-from-no-return - update_mask=['a', 'b', 'c'], - name=GCF_FUNCTION, - body={} + update_mask=['a', 'b', 'c'], name=GCF_FUNCTION, body={} ) self.assertIsNone(res) - patch_method.assert_called_once_with( - body={}, - name='function', - updateMask='a,b,c' - ) + patch_method.assert_called_once_with(body={}, name='function', updateMask='a,b,c') execute_method.assert_called_once_with(num_retries=5) wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id') @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('requests.put') @mock.patch('airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook.get_conn') def test_upload_function_zip(self, get_conn, requests_put, mock_project_id): mck, open_module = get_open_mock() with mock.patch('{}.open'.format(open_module), mck): + # fmt: off generate_upload_url_method = get_conn.return_value.projects.return_value.locations. 
\ return_value.functions.return_value.generateUploadUrl + # fmt: on execute_method = generate_upload_url_method.return_value.execute execute_method.return_value = {"uploadUrl": "http://uploadHere"} requests_put.return_value = None res = self.gcf_function_hook.upload_function_zip( - location=GCF_LOCATION, - zip_path="/tmp/path.zip", - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + location=GCF_LOCATION, zip_path="/tmp/path.zip", project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) self.assertEqual("http://uploadHere", res) generate_upload_url_method.assert_called_once_with( - parent='projects/example-project/locations/location') + parent='projects/example-project/locations/location' + ) execute_method.assert_called_once_with(num_retries=5) requests_put.assert_called_once_with( data=mock.ANY, - headers={'Content-type': 'application/zip', - 'x-goog-content-length-range': '0,104857600'}, - url='http://uploadHere' + headers={'Content-type': 'application/zip', 'x-goog-content-length-range': '0,104857600'}, + url='http://uploadHere', ) @mock.patch('requests.put') @@ -251,47 +245,47 @@ def test_upload_function_zip(self, get_conn, requests_put, mock_project_id): def test_upload_function_zip_overridden_project_id(self, get_conn, requests_put): mck, open_module = get_open_mock() with mock.patch('{}.open'.format(open_module), mck): + # fmt: off generate_upload_url_method = get_conn.return_value.projects.return_value.locations. \ return_value.functions.return_value.generateUploadUrl + # fmt: on execute_method = generate_upload_url_method.return_value.execute execute_method.return_value = {"uploadUrl": "http://uploadHere"} requests_put.return_value = None res = self.gcf_function_hook.upload_function_zip( - project_id='new-project', - location=GCF_LOCATION, - zip_path="/tmp/path.zip" + project_id='new-project', location=GCF_LOCATION, zip_path="/tmp/path.zip" ) self.assertEqual("http://uploadHere", res) generate_upload_url_method.assert_called_once_with( - parent='projects/new-project/locations/location') + parent='projects/new-project/locations/location' + ) execute_method.assert_called_once_with(num_retries=5) requests_put.assert_called_once_with( data=mock.ANY, - headers={'Content-type': 'application/zip', - 'x-goog-content-length-range': '0,104857600'}, - url='http://uploadHere' + headers={'Content-type': 'application/zip', 'x-goog-content-length-range': '0,104857600'}, + url='http://uploadHere', ) @mock.patch('airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook.get_conn') def test_call_function(self, mock_get_conn): payload = {'executionId': 'wh41ppcyoa6l', 'result': 'Hello World!'} - call = mock_get_conn.return_value.projects.return_value.\ + # fmt: off + call = mock_get_conn.return_value.projects.return_value. 
\ locations.return_value.functions.return_value.call + # fmt: on call.return_value.execute.return_value = payload function_id = "function1234" input_data = {'key': 'value'} name = "projects/{project_id}/locations/{location}/functions/{function_id}".format( - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, - location=GCF_LOCATION, - function_id=function_id + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, location=GCF_LOCATION, function_id=function_id ) result = self.gcf_function_hook.call_function( function_id=function_id, location=GCF_LOCATION, input_data=input_data, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) call.assert_called_once_with(body=input_data, name=name) @@ -300,8 +294,10 @@ def test_call_function(self, mock_get_conn): @mock.patch('airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook.get_conn') def test_call_function_error(self, mock_get_conn): payload = {'error': 'Something very bad'} + # fmt: off call = mock_get_conn.return_value.projects.return_value. \ locations.return_value.functions.return_value.call + # fmt: on call.return_value.execute.return_value = payload function_id = "function1234" @@ -311,5 +307,5 @@ def test_call_function_error(self, mock_get_conn): function_id=function_id, location=GCF_LOCATION, input_data=input_data, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py index 8a844839520cc..bcf02f229860c 100644 --- a/tests/providers/google/cloud/hooks/test_gcs.py +++ b/tests/providers/google/cloud/hooks/test_gcs.py @@ -48,24 +48,17 @@ def test_parse_gcs_url(self): Test GCS url parsing """ - self.assertEqual( - gcs._parse_gcs_url('gs://bucket/path/to/blob'), - ('bucket', 'path/to/blob')) + self.assertEqual(gcs._parse_gcs_url('gs://bucket/path/to/blob'), ('bucket', 'path/to/blob')) # invalid URI - self.assertRaises(AirflowException, gcs._parse_gcs_url, - 'gs:/bucket/path/to/blob') - self.assertRaises(AirflowException, gcs._parse_gcs_url, - 'http://google.com/aaa') + self.assertRaises(AirflowException, gcs._parse_gcs_url, 'gs:/bucket/path/to/blob') + self.assertRaises(AirflowException, gcs._parse_gcs_url, 'http://google.com/aaa') # trailing slash - self.assertEqual( - gcs._parse_gcs_url('gs://bucket/path/to/blob/'), - ('bucket', 'path/to/blob/')) + self.assertEqual(gcs._parse_gcs_url('gs://bucket/path/to/blob/'), ('bucket', 'path/to/blob/')) # bucket only - self.assertEqual( - gcs._parse_gcs_url('gs://bucket/'), ('bucket', '')) + self.assertEqual(gcs._parse_gcs_url('gs://bucket/'), ('bucket', '')) class TestFallbackObjectUrlToObjectNameAndBucketName(unittest.TestCase): @@ -73,16 +66,12 @@ def setUp(self) -> None: self.assertion_on_body = mock.MagicMock() @_fallback_object_url_to_object_name_and_bucket_name() - def test_method( - _, - bucket_name=None, - object_name=None, - object_url=None - ): + def test_method(_, bucket_name=None, object_name=None, object_url=None): assert object_name == "OBJECT_NAME" assert bucket_name == "BUCKET_NAME" assert object_url is None self.assertion_on_body() + self.test_method = test_method def test_should_url(self): @@ -95,23 +84,21 @@ def test_should_support_bucket_and_object(self): def test_should_raise_exception_on_missing(self): with self.assertRaisesRegex( - TypeError, - re.escape( - "test_method() missing 2 required positional arguments: 'bucket_name' and 'object_name'" - )): + TypeError, + re.escape( + "test_method() missing 2 required 
positional arguments: 'bucket_name' and 'object_name'" + ), + ): self.test_method(None) self.assertion_on_body.assert_not_called() def test_should_raise_exception_on_mutually_exclusive(self): - with self.assertRaisesRegex( - AirflowException, - re.escape("The mutually exclusive parameters.") - ): + with self.assertRaisesRegex(AirflowException, re.escape("The mutually exclusive parameters.")): self.test_method( None, bucket_name="BUCKET_NAME", object_name="OBJECT_NAME", - object_url="gs://BUCKET_NAME/OBJECT_NAME" + object_url="gs://BUCKET_NAME/OBJECT_NAME", ) self.assertion_on_body.assert_not_called() @@ -119,35 +106,30 @@ def test_should_raise_exception_on_mutually_exclusive(self): class TestGCSHook(unittest.TestCase): def setUp(self): with mock.patch( - GCS_STRING.format('GoogleBaseHook.__init__'), - new=mock_base_gcp_hook_default_project_id, + GCS_STRING.format('GoogleBaseHook.__init__'), new=mock_base_gcp_hook_default_project_id, ): - self.gcs_hook = gcs.GCSHook( - gcp_conn_id='test') + self.gcs_hook = gcs.GCSHook(gcp_conn_id='test') @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.client_info', new_callable=mock.PropertyMock, - return_value="CLIENT_INFO" + return_value="CLIENT_INFO", ) @mock.patch( BASE_STRING.format("GoogleBaseHook._get_credentials_and_project_id"), - return_value=("CREDENTIALS", "PROJECT_ID") + return_value=("CREDENTIALS", "PROJECT_ID"), ) @mock.patch(GCS_STRING.format('GoogleBaseHook.get_connection')) @mock.patch('google.cloud.storage.Client') - def test_storage_client_creation(self, - mock_client, - mock_get_connetion, - mock_get_creds_and_project_id, - mock_client_info): + def test_storage_client_creation( + self, mock_client, mock_get_connetion, mock_get_creds_and_project_id, mock_client_info + ): hook = gcs.GCSHook() result = hook.get_conn() # test that Storage Client is called with required arguments mock_client.assert_called_once_with( - client_info="CLIENT_INFO", - credentials="CREDENTIALS", - project="PROJECT_ID") + client_info="CLIENT_INFO", credentials="CREDENTIALS", project="PROJECT_ID" + ) self.assertEqual(mock_client.return_value, result) @mock.patch(GCS_STRING.format('GCSHook.get_conn')) @@ -193,13 +175,13 @@ def test_is_updated_after(self, mock_service): test_object = 'test_object' # Given - mock_service.return_value.bucket.return_value.get_blob\ - .return_value.updated = datetime(2019, 8, 28, 14, 7, 20, 700000, dateutil.tz.tzutc()) + mock_service.return_value.bucket.return_value.get_blob.return_value.updated = datetime( + 2019, 8, 28, 14, 7, 20, 700000, dateutil.tz.tzutc() + ) # When response = self.gcs_hook.is_updated_after( - bucket_name=test_bucket, object_name=test_object, - ts=datetime(2018, 1, 1, 1, 1, 1) + bucket_name=test_bucket, object_name=test_object, ts=datetime(2018, 1, 1, 1, 1, 1) ) # Then @@ -211,13 +193,13 @@ def test_is_updated_before(self, mock_service): test_object = 'test_object' # Given - mock_service.return_value.bucket.return_value.get_blob \ - .return_value.updated = datetime(2019, 8, 28, 14, 7, 20, 700000, dateutil.tz.tzutc()) + mock_service.return_value.bucket.return_value.get_blob.return_value.updated = datetime( + 2019, 8, 28, 14, 7, 20, 700000, dateutil.tz.tzutc() + ) # When response = self.gcs_hook.is_updated_before( - bucket_name=test_bucket, object_name=test_object, - ts=datetime(2020, 1, 1, 1, 1, 1) + bucket_name=test_bucket, object_name=test_object, ts=datetime(2020, 1, 1, 1, 1, 1) ) # Then @@ -229,14 +211,16 @@ def test_is_updated_between(self, mock_service): test_object = 'test_object' 
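Note on the pattern in the hunks above: backslash-continued mock attribute chains (for example the old mock_service.return_value.bucket.return_value.get_blob \ ... .updated form) are rewritten as a single expression, wrapped in parentheses where it would exceed the line-length limit, because that is the shape Black keeps stable; chains that deliberately keep the backslash are fenced with "# fmt: off" / "# fmt: on" instead. A minimal standalone sketch of the two equivalent spellings, with illustrative names rather than the real hook:

    from unittest import mock

    get_conn = mock.MagicMock()

    # Old spelling: attribute chain split with a backslash continuation.
    create_method_old = get_conn.return_value.projects.return_value.locations. \
        return_value.functions.return_value.create

    # New spelling: one parenthesized expression, as used throughout this patch.
    create_method_new = (
        get_conn.return_value.projects.return_value.locations.return_value.functions.return_value.create
    )

    # MagicMock caches auto-created children, so both spellings resolve to the same child mock,
    # which is why reformatting the chain does not change what the tests assert on.
    assert create_method_old is create_method_new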
# Given - mock_service.return_value.bucket.return_value.get_blob \ - .return_value.updated = datetime(2019, 8, 28, 14, 7, 20, 700000, dateutil.tz.tzutc()) + mock_service.return_value.bucket.return_value.get_blob.return_value.updated = datetime( + 2019, 8, 28, 14, 7, 20, 700000, dateutil.tz.tzutc() + ) # When response = self.gcs_hook.is_updated_between( - bucket_name=test_bucket, object_name=test_object, + bucket_name=test_bucket, + object_name=test_object, min_ts=datetime(2018, 1, 1, 1, 1, 1), - max_ts=datetime(2020, 1, 1, 1, 1, 1) + max_ts=datetime(2020, 1, 1, 1, 1, 1), ) # Then @@ -248,13 +232,13 @@ def test_is_older_than_with_true_cond(self, mock_service): test_object = 'test_object' # Given - mock_service.return_value.bucket.return_value.get_blob \ - .return_value.updated = datetime(2020, 1, 28, 14, 7, 20, 700000, dateutil.tz.tzutc()) + mock_service.return_value.bucket.return_value.get_blob.return_value.updated = datetime( + 2020, 1, 28, 14, 7, 20, 700000, dateutil.tz.tzutc() + ) # When response = self.gcs_hook.is_older_than( - bucket_name=test_bucket, object_name=test_object, - seconds=86400 # 24hr + bucket_name=test_bucket, object_name=test_object, seconds=86400 # 24hr ) # Then @@ -266,12 +250,13 @@ def test_is_older_than_with_false_cond(self, mock_service): test_object = 'test_object' # Given + # fmt: off mock_service.return_value.bucket.return_value.get_blob \ .return_value.updated = timezone.utcnow() + timedelta(days=2) + # fmt: on # When response = self.gcs_hook.is_older_than( - bucket_name=test_bucket, object_name=test_object, - seconds=86400 # 24hr + bucket_name=test_bucket, object_name=test_object, seconds=86400 # 24hr ) # Then self.assertFalse(response) @@ -286,9 +271,7 @@ def test_copy(self, mock_service, mock_bucket): destination_bucket_instance = mock_bucket source_blob = mock_bucket.blob(source_object) - destination_blob = storage.Blob( - bucket=destination_bucket_instance, - name=destination_object) + destination_blob = storage.Blob(bucket=destination_bucket_instance, name=destination_object) # Given bucket_mock = mock_service.return_value.bucket @@ -301,15 +284,13 @@ def test_copy(self, mock_service, mock_bucket): source_bucket=source_bucket, source_object=source_object, destination_bucket=destination_bucket, - destination_object=destination_object + destination_object=destination_object, ) # Then self.assertEqual(response, None) copy_method.assert_called_once_with( - blob=source_blob, - destination_bucket=destination_bucket_instance, - new_name=destination_object + blob=source_blob, destination_bucket=destination_bucket_instance, new_name=destination_object ) def test_copy_fail_same_source_and_destination(self): @@ -319,16 +300,17 @@ def test_copy_fail_same_source_and_destination(self): destination_object = 'test-source-object' with self.assertRaises(ValueError) as e: - self.gcs_hook.copy(source_bucket=source_bucket, - source_object=source_object, - destination_bucket=destination_bucket, - destination_object=destination_object) + self.gcs_hook.copy( + source_bucket=source_bucket, + source_object=source_object, + destination_bucket=destination_bucket, + destination_object=destination_object, + ) self.assertEqual( str(e.exception), 'Either source/destination bucket or source/destination object ' - 'must be different, not both the same: bucket=%s, object=%s' % - (source_bucket, source_object) + 'must be different, not both the same: bucket=%s, object=%s' % (source_bucket, source_object), ) def test_copy_empty_source_bucket(self): @@ -338,15 +320,14 @@ def 
test_copy_empty_source_bucket(self): destination_object = 'test-dest-object' with self.assertRaises(ValueError) as e: - self.gcs_hook.copy(source_bucket=source_bucket, - source_object=source_object, - destination_bucket=destination_bucket, - destination_object=destination_object) + self.gcs_hook.copy( + source_bucket=source_bucket, + source_object=source_object, + destination_bucket=destination_bucket, + destination_object=destination_object, + ) - self.assertEqual( - str(e.exception), - 'source_bucket and source_object cannot be empty.' - ) + self.assertEqual(str(e.exception), 'source_bucket and source_object cannot be empty.') def test_copy_empty_source_object(self): source_bucket = 'test-source-object' @@ -355,15 +336,14 @@ def test_copy_empty_source_object(self): destination_object = 'test-dest-object' with self.assertRaises(ValueError) as e: - self.gcs_hook.copy(source_bucket=source_bucket, - source_object=source_object, - destination_bucket=destination_bucket, - destination_object=destination_object) + self.gcs_hook.copy( + source_bucket=source_bucket, + source_object=source_object, + destination_bucket=destination_bucket, + destination_object=destination_object, + ) - self.assertEqual( - str(e.exception), - 'source_bucket and source_object cannot be empty.' - ) + self.assertEqual(str(e.exception), 'source_bucket and source_object cannot be empty.') @mock.patch('google.cloud.storage.Bucket') @mock.patch(GCS_STRING.format('GCSHook.get_conn')) @@ -387,12 +367,12 @@ def test_rewrite(self, mock_service, mock_bucket): source_bucket=source_bucket, source_object=source_object, destination_bucket=destination_bucket, - destination_object=destination_object) + destination_object=destination_object, + ) # Then self.assertEqual(response, None) - rewrite_method.assert_called_once_with( - source=source_blob) + rewrite_method.assert_called_once_with(source=source_blob) def test_rewrite_empty_source_bucket(self): source_bucket = None @@ -401,15 +381,14 @@ def test_rewrite_empty_source_bucket(self): destination_object = 'test-dest-object' with self.assertRaises(ValueError) as e: - self.gcs_hook.rewrite(source_bucket=source_bucket, - source_object=source_object, - destination_bucket=destination_bucket, - destination_object=destination_object) + self.gcs_hook.rewrite( + source_bucket=source_bucket, + source_object=source_object, + destination_bucket=destination_bucket, + destination_object=destination_object, + ) - self.assertEqual( - str(e.exception), - 'source_bucket and source_object cannot be empty.' - ) + self.assertEqual(str(e.exception), 'source_bucket and source_object cannot be empty.') def test_rewrite_empty_source_object(self): source_bucket = 'test-source-object' @@ -418,15 +397,14 @@ def test_rewrite_empty_source_object(self): destination_object = 'test-dest-object' with self.assertRaises(ValueError) as e: - self.gcs_hook.rewrite(source_bucket=source_bucket, - source_object=source_object, - destination_bucket=destination_bucket, - destination_object=destination_object) + self.gcs_hook.rewrite( + source_bucket=source_bucket, + source_object=source_object, + destination_bucket=destination_bucket, + destination_object=destination_object, + ) - self.assertEqual( - str(e.exception), - 'source_bucket and source_object cannot be empty.' 
- ) + self.assertEqual(str(e.exception), 'source_bucket and source_object cannot be empty.') @mock.patch('google.cloud.storage.Bucket') @mock.patch(GCS_STRING.format('GCSHook.get_conn')) @@ -441,8 +419,8 @@ def test_delete(self, mock_service, mock_bucket): delete_method.return_value = blob_to_be_deleted response = self.gcs_hook.delete( # pylint: disable=assignment-from-no-return - bucket_name=test_bucket, - object_name=test_object) + bucket_name=test_bucket, object_name=test_object + ) self.assertIsNone(response) @mock.patch(GCS_STRING.format('GCSHook.get_conn')) @@ -467,9 +445,10 @@ def test_delete_bucket(self, mock_service): mock_service.return_value.bucket.assert_called_once_with(test_bucket) mock_service.return_value.bucket.return_value.delete.assert_called_once() - @mock.patch(GCS_STRING.format('GCSHook.get_conn'), **{ - 'return_value.bucket.return_value.delete.side_effect': exceptions.NotFound(message="Not Found") - }) + @mock.patch( + GCS_STRING.format('GCSHook.get_conn'), + **{'return_value.bucket.return_value.delete.side_effect': exceptions.NotFound(message="Not Found")}, + ) def test_delete_nonexisting_bucket(self, mock_service): test_bucket = "test bucket" @@ -488,8 +467,7 @@ def test_object_get_size(self, mock_service): get_blob_method = bucket_method.return_value.get_blob get_blob_method.return_value.size = returned_file_size - response = self.gcs_hook.get_size(bucket_name=test_bucket, - object_name=test_object) + response = self.gcs_hook.get_size(bucket_name=test_bucket, object_name=test_object) self.assertEqual(response, returned_file_size) @@ -503,8 +481,7 @@ def test_object_get_crc32c(self, mock_service): get_blob_method = bucket_method.return_value.get_blob get_blob_method.return_value.crc32c = returned_file_crc32c - response = self.gcs_hook.get_crc32c(bucket_name=test_bucket, - object_name=test_object) + response = self.gcs_hook.get_crc32c(bucket_name=test_bucket, object_name=test_object) self.assertEqual(response, returned_file_crc32c) @@ -518,8 +495,7 @@ def test_object_get_md5hash(self, mock_service): get_blob_method = bucket_method.return_value.get_blob get_blob_method.return_value.md5_hash = returned_file_md5hash - response = self.gcs_hook.get_md5hash(bucket_name=test_bucket, - object_name=test_object) + response = self.gcs_hook.get_md5hash(bucket_name=test_bucket, object_name=test_object) self.assertEqual(response, returned_file_md5hash) @@ -546,7 +522,7 @@ def test_create_bucket(self, mock_service, mock_bucket): storage_class=test_storage_class, location=test_location, labels=test_labels, - project_id=test_project + project_id=test_project, ) self.assertEqual(response, sample_bucket.id) @@ -583,7 +559,7 @@ def test_create_bucket_with_resource(self, mock_service, mock_bucket): storage_class=test_storage_class, location=test_location, labels=test_labels, - project_id=test_project + project_id=test_project, ) self.assertEqual(response, sample_bucket.id) @@ -602,21 +578,18 @@ def test_compose(self, mock_service, mock_blob): test_source_objects = ['test_object_1', 'test_object_2', 'test_object_3'] test_destination_object = 'test_object_composed' - mock_service.return_value.bucket.return_value\ - .blob.return_value = mock_blob(blob_name=mock.ANY) - method = mock_service.return_value.bucket.return_value.blob\ - .return_value.compose + mock_service.return_value.bucket.return_value.blob.return_value = mock_blob(blob_name=mock.ANY) + method = mock_service.return_value.bucket.return_value.blob.return_value.compose self.gcs_hook.compose( bucket_name=test_bucket, 
source_objects=test_source_objects, - destination_object=test_destination_object + destination_object=test_destination_object, ) method.assert_called_once_with( - sources=[ - mock_blob(blob_name=source_object) for source_object in test_source_objects - ]) + sources=[mock_blob(blob_name=source_object) for source_object in test_source_objects] + ) @mock.patch(GCS_STRING.format('GCSHook.get_conn')) def test_compose_with_empty_source_objects(self, mock_service): # pylint: disable=unused-argument @@ -628,13 +601,10 @@ def test_compose_with_empty_source_objects(self, mock_service): # pylint: disab self.gcs_hook.compose( bucket_name=test_bucket, source_objects=test_source_objects, - destination_object=test_destination_object + destination_object=test_destination_object, ) - self.assertEqual( - str(e.exception), - 'source_objects cannot be empty.' - ) + self.assertEqual(str(e.exception), 'source_objects cannot be empty.') @mock.patch(GCS_STRING.format('GCSHook.get_conn')) def test_compose_without_bucket(self, mock_service): # pylint: disable=unused-argument @@ -646,13 +616,10 @@ def test_compose_without_bucket(self, mock_service): # pylint: disable=unused-a self.gcs_hook.compose( bucket_name=test_bucket, source_objects=test_source_objects, - destination_object=test_destination_object + destination_object=test_destination_object, ) - self.assertEqual( - str(e.exception), - 'bucket_name and destination_object cannot be empty.' - ) + self.assertEqual(str(e.exception), 'bucket_name and destination_object cannot be empty.') @mock.patch(GCS_STRING.format('GCSHook.get_conn')) def test_compose_without_destination_object(self, mock_service): # pylint: disable=unused-argument @@ -664,13 +631,10 @@ def test_compose_without_destination_object(self, mock_service): # pylint: disa self.gcs_hook.compose( bucket_name=test_bucket, source_objects=test_source_objects, - destination_object=test_destination_object + destination_object=test_destination_object, ) - self.assertEqual( - str(e.exception), - 'bucket_name and destination_object cannot be empty.' 
- ) + self.assertEqual(str(e.exception), 'bucket_name and destination_object cannot be empty.') @mock.patch(GCS_STRING.format('GCSHook.get_conn')) def test_download_as_string(self, mock_service): @@ -678,13 +642,10 @@ def test_download_as_string(self, mock_service): test_object = 'test_object' test_object_bytes = io.BytesIO(b"input") - download_method = mock_service.return_value.bucket.return_value \ - .blob.return_value.download_as_string + download_method = mock_service.return_value.bucket.return_value.blob.return_value.download_as_string download_method.return_value = test_object_bytes - response = self.gcs_hook.download(bucket_name=test_bucket, - object_name=test_object, - filename=None) + response = self.gcs_hook.download(bucket_name=test_bucket, object_name=test_object, filename=None) self.assertEqual(response, test_object_bytes) download_method.assert_called_once_with() @@ -696,16 +657,18 @@ def test_download_to_file(self, mock_service): test_object_bytes = io.BytesIO(b"input") test_file = 'test_file' - download_filename_method = mock_service.return_value.bucket.return_value \ - .blob.return_value.download_to_filename + download_filename_method = ( + mock_service.return_value.bucket.return_value.blob.return_value.download_to_filename + ) download_filename_method.return_value = None - download_as_a_string_method = mock_service.return_value.bucket.return_value \ - .blob.return_value.download_as_string + download_as_a_string_method = ( + mock_service.return_value.bucket.return_value.blob.return_value.download_as_string + ) download_as_a_string_method.return_value = test_object_bytes - response = self.gcs_hook.download(bucket_name=test_bucket, - object_name=test_object, - filename=test_file) + response = self.gcs_hook.download( + bucket_name=test_bucket, object_name=test_object, filename=test_file + ) self.assertEqual(response, test_file) download_filename_method.assert_called_once_with(test_file) @@ -718,36 +681,36 @@ def test_provide_file(self, mock_service, mock_temp_file): test_object_bytes = io.BytesIO(b"input") test_file = 'test_file' - download_filename_method = mock_service.return_value.bucket.return_value \ - .blob.return_value.download_to_filename + download_filename_method = ( + mock_service.return_value.bucket.return_value.blob.return_value.download_to_filename + ) download_filename_method.return_value = None - download_as_a_string_method = mock_service.return_value.bucket.return_value \ - .blob.return_value.download_as_string + download_as_a_string_method = ( + mock_service.return_value.bucket.return_value.blob.return_value.download_as_string + ) download_as_a_string_method.return_value = test_object_bytes mock_temp_file.return_value.__enter__.return_value = mock.MagicMock() mock_temp_file.return_value.__enter__.return_value.name = test_file - with self.gcs_hook.provide_file( - bucket_name=test_bucket, - object_name=test_object) as response: + with self.gcs_hook.provide_file(bucket_name=test_bucket, object_name=test_object) as response: self.assertEqual(test_file, response.name) download_filename_method.assert_called_once_with(test_file) - mock_temp_file.assert_has_calls([ - mock.call(suffix='test_object'), - mock.call().__enter__(), - mock.call().__enter__().flush(), - mock.call().__exit__(None, None, None) - ]) + mock_temp_file.assert_has_calls( + [ + mock.call(suffix='test_object'), + mock.call().__enter__(), + mock.call().__enter__().flush(), + mock.call().__exit__(None, None, None), + ] + ) class TestGCSHookUpload(unittest.TestCase): def setUp(self): with 
mock.patch(BASE_STRING.format('GoogleBaseHook.__init__')): - self.gcs_hook = gcs.GCSHook( - gcp_conn_id='test' - ) + self.gcs_hook = gcs.GCSHook(gcp_conn_id='test') # generate a 384KiB test file (larger than the minimum 256KiB multipart chunk size) self.testfile = tempfile.NamedTemporaryFile(delete=False) @@ -764,16 +727,12 @@ def test_upload_file(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' - upload_method = mock_service.return_value.bucket.return_value\ - .blob.return_value.upload_from_filename + upload_method = mock_service.return_value.bucket.return_value.blob.return_value.upload_from_filename - self.gcs_hook.upload(test_bucket, - test_object, - filename=self.testfile.name) + self.gcs_hook.upload(test_bucket, test_object, filename=self.testfile.name) upload_method.assert_called_once_with( - filename=self.testfile.name, - content_type='application/octet-stream' + filename=self.testfile.name, content_type='application/octet-stream' ) @mock.patch(GCS_STRING.format('GCSHook.get_conn')) @@ -781,10 +740,7 @@ def test_upload_file_gzip(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' - self.gcs_hook.upload(test_bucket, - test_object, - filename=self.testfile.name, - gzip=True) + self.gcs_hook.upload(test_bucket, test_object, filename=self.testfile.name, gzip=True) self.assertFalse(os.path.exists(self.testfile.name + '.gz')) @mock.patch(GCS_STRING.format('GCSHook.get_conn')) @@ -792,34 +748,22 @@ def test_upload_data_str(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' - upload_method = mock_service.return_value.bucket.return_value\ - .blob.return_value.upload_from_string + upload_method = mock_service.return_value.bucket.return_value.blob.return_value.upload_from_string - self.gcs_hook.upload(test_bucket, - test_object, - data=self.testdata_str) + self.gcs_hook.upload(test_bucket, test_object, data=self.testdata_str) - upload_method.assert_called_once_with( - self.testdata_str, - content_type='text/plain' - ) + upload_method.assert_called_once_with(self.testdata_str, content_type='text/plain') @mock.patch(GCS_STRING.format('GCSHook.get_conn')) def test_upload_data_bytes(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' - upload_method = mock_service.return_value.bucket.return_value\ - .blob.return_value.upload_from_string + upload_method = mock_service.return_value.bucket.return_value.blob.return_value.upload_from_string - self.gcs_hook.upload(test_bucket, - test_object, - data=self.testdata_bytes) + self.gcs_hook.upload(test_bucket, test_object, data=self.testdata_bytes) - upload_method.assert_called_once_with( - self.testdata_bytes, - content_type='text/plain' - ) + upload_method.assert_called_once_with(self.testdata_bytes, content_type='text/plain') @mock.patch(GCS_STRING.format('BytesIO')) @mock.patch(GCS_STRING.format('gz.GzipFile')) @@ -831,13 +775,9 @@ def test_upload_data_str_gzip(self, mock_service, mock_gzip, mock_bytes_io): gzip_ctx = mock_gzip.return_value.__enter__.return_value data = mock_bytes_io.return_value.getvalue.return_value - upload_method = mock_service.return_value.bucket.return_value\ - .blob.return_value.upload_from_string + upload_method = mock_service.return_value.bucket.return_value.blob.return_value.upload_from_string - self.gcs_hook.upload(test_bucket, - test_object, - data=self.testdata_str, - gzip=True) + self.gcs_hook.upload(test_bucket, test_object, data=self.testdata_str, gzip=True) byte_str = bytes(self.testdata_str, encoding) 
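The gzip upload tests above patch gz.GzipFile and BytesIO and then assert on the writes, i.e. they check that string and bytes payloads are compressed in memory before being handed to upload_from_string. For reference, the pattern being mocked looks roughly like the following sketch; this is an approximation for readers, not the hook's actual implementation:

    import gzip as gz
    import io

    data = "input"
    out = io.BytesIO()
    # Compress the payload into the in-memory buffer, mirroring the calls the test asserts on:
    # gz.GzipFile(fileobj=..., mode="w") followed by a single write of the encoded bytes.
    with gz.GzipFile(fileobj=out, mode="w") as gzip_file:
        gzip_file.write(bytes(data, "utf-8"))
    compressed = out.getvalue()

    # Round-trip to show the original payload is preserved.
    assert gz.decompress(compressed) == b"input"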
mock_gzip.assert_called_once_with(fileobj=mock_bytes_io.return_value, mode="w") @@ -853,13 +793,9 @@ def test_upload_data_bytes_gzip(self, mock_service, mock_gzip, mock_bytes_io): gzip_ctx = mock_gzip.return_value.__enter__.return_value data = mock_bytes_io.return_value.getvalue.return_value - upload_method = mock_service.return_value.bucket.return_value \ - .blob.return_value.upload_from_string + upload_method = mock_service.return_value.bucket.return_value.blob.return_value.upload_from_string - self.gcs_hook.upload(test_bucket, - test_object, - data=self.testdata_bytes, - gzip=True) + self.gcs_hook.upload(test_bucket, test_object, data=self.testdata_bytes, gzip=True) mock_gzip.assert_called_once_with(fileobj=mock_bytes_io.return_value, mode="w") gzip_ctx.write.assert_called_once_with(self.testdata_bytes) @@ -869,19 +805,21 @@ def test_upload_data_bytes_gzip(self, mock_service, mock_gzip, mock_bytes_io): def test_upload_exceptions(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' - both_params_excep = "'filename' and 'data' parameter provided. Please " \ - "specify a single parameter, either 'filename' for " \ - "local file uploads or 'data' for file content uploads." - no_params_excep = "'filename' and 'data' parameter missing. " \ - "One is required to upload to gcs." + both_params_excep = ( + "'filename' and 'data' parameter provided. Please " + "specify a single parameter, either 'filename' for " + "local file uploads or 'data' for file content uploads." + ) + no_params_excep = "'filename' and 'data' parameter missing. One is required to upload to gcs." with self.assertRaises(ValueError) as cm: self.gcs_hook.upload(test_bucket, test_object) self.assertEqual(no_params_excep, str(cm.exception)) with self.assertRaises(ValueError) as cm: - self.gcs_hook.upload(test_bucket, test_object, - filename=self.testfile.name, data=self.testdata_str) + self.gcs_hook.upload( + test_bucket, test_object, filename=self.testfile.name, data=self.testdata_str + ) self.assertEqual(both_params_excep, str(cm.exception)) diff --git a/tests/providers/google/cloud/hooks/test_gdm.py b/tests/providers/google/cloud/hooks/test_gdm.py index 67e00e2167c8d..452780664cb0d 100644 --- a/tests/providers/google/cloud/hooks/test_gdm.py +++ b/tests/providers/google/cloud/hooks/test_gdm.py @@ -24,10 +24,7 @@ def mock_init( - self, - gcp_conn_id, - delegate_to=None, - impersonation_chain=None, + self, gcp_conn_id, delegate_to=None, impersonation_chain=None, ): # pylint: disable=unused-argument pass @@ -37,11 +34,9 @@ def mock_init( class TestDeploymentManagerHook(unittest.TestCase): - def setUp(self): with mock.patch( - "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__", - new=mock_init, + "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__", new=mock_init, ): self.gdm_hook = GoogleDeploymentManagerHook(gcp_conn_id="test") @@ -60,31 +55,29 @@ def test_list_deployments(self, mock_get_conn): None, ] - deployments = self.gdm_hook.list_deployments(project_id=TEST_PROJECT, - deployment_filter='filter', - order_by='name') + deployments = self.gdm_hook.list_deployments( + project_id=TEST_PROJECT, deployment_filter='filter', order_by='name' + ) mock_get_conn.assert_called_once_with() mock_get_conn.return_value.deployments.return_value.list.assert_called_once_with( - project=TEST_PROJECT, - filter='filter', - orderBy='name', + project=TEST_PROJECT, filter='filter', orderBy='name', ) 
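In the test_list_deployments hunk above, pagination is simulated by giving list_next a side_effect sequence that ends with None; returning None is what terminates the hook's list/list_next loop, which is why the test expects list_next to be called exactly twice. A self-contained sketch of the same technique, with a hypothetical list_all consumer standing in for the hook (names and payloads are illustrative, not the real client):

    from unittest import mock

    def list_all(conn):
        # Hypothetical consumer mirroring the discovery-client style loop:
        # list() -> execute() -> list_next() until list_next returns None.
        results = []
        request = conn.deployments().list(project="example-project")
        while request is not None:
            response = request.execute()
            results.extend(response.get("deployments", []))
            request = conn.deployments().list_next(previous_request=request, previous_response=response)
        return results

    conn = mock.MagicMock()
    page_one = {"deployments": [{"id": "deployment1"}]}
    page_two = {"deployments": [{"id": "deployment2"}]}
    page_two_request = mock.MagicMock(**{"execute.return_value": page_two})
    conn.deployments.return_value.list.return_value.execute.return_value = page_one
    # Ending the side_effect sequence with None stops the pagination loop.
    conn.deployments.return_value.list_next.side_effect = [page_two_request, None]

    assert list_all(conn) == [{"id": "deployment1"}, {"id": "deployment2"}]
    assert conn.deployments.return_value.list_next.call_count == 2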
self.assertEqual(mock_get_conn.return_value.deployments.return_value.list_next.call_count, 2) - self.assertEqual(deployments, [{'id': 'deployment1', 'name': 'test-deploy1'}, - {'id': 'deployment2', 'name': 'test-deploy2'}]) + self.assertEqual( + deployments, + [{'id': 'deployment1', 'name': 'test-deploy1'}, {'id': 'deployment2', 'name': 'test-deploy2'}], + ) @mock.patch("airflow.providers.google.cloud.hooks.gdm.GoogleDeploymentManagerHook.get_conn") def test_delete_deployment(self, mock_get_conn): self.gdm_hook.delete_deployment(project_id=TEST_PROJECT, deployment=TEST_DEPLOYMENT) mock_get_conn.assert_called_once_with() mock_get_conn.return_value.deployments().delete.assert_called_once_with( - project=TEST_PROJECT, - deployment=TEST_DEPLOYMENT, - deletePolicy=None + project=TEST_PROJECT, deployment=TEST_DEPLOYMENT, deletePolicy=None ) @mock.patch("airflow.providers.google.cloud.hooks.gdm.GoogleDeploymentManagerHook.get_conn") @@ -99,7 +92,5 @@ def test_delete_deployment_delete_fails(self, mock_get_conn): mock_get_conn.assert_called_once_with() mock_get_conn.return_value.deployments().delete.assert_called_once_with( - project=TEST_PROJECT, - deployment=TEST_DEPLOYMENT, - deletePolicy=None + project=TEST_PROJECT, deployment=TEST_DEPLOYMENT, deletePolicy=None ) diff --git a/tests/providers/google/cloud/hooks/test_kms.py b/tests/providers/google/cloud/hooks/test_kms.py index a1b1db29eb283..1088f60f26797 100644 --- a/tests/providers/google/cloud/hooks/test_kms.py +++ b/tests/providers/google/cloud/hooks/test_kms.py @@ -46,10 +46,7 @@ def mock_init( - self, - gcp_conn_id, - delegate_to=None, - impersonation_chain=None, + self, gcp_conn_id, delegate_to=None, impersonation_chain=None, ): # pylint: disable=unused-argument pass @@ -57,22 +54,19 @@ def mock_init( class TestCloudKMSHook(unittest.TestCase): def setUp(self): with mock.patch( - "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__", - new=mock_init, + "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__", new=mock_init, ): self.kms_hook = CloudKMSHook(gcp_conn_id="test") @mock.patch( - "airflow.providers.google.cloud.hooks.kms.CloudKMSHook.client_info", - new_callable=mock.PropertyMock, + "airflow.providers.google.cloud.hooks.kms.CloudKMSHook.client_info", new_callable=mock.PropertyMock, ) @mock.patch("airflow.providers.google.cloud.hooks.kms.CloudKMSHook._get_credentials") @mock.patch("airflow.providers.google.cloud.hooks.kms.KeyManagementServiceClient") def test_kms_client_creation(self, mock_client, mock_get_creds, mock_client_info): result = self.kms_hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value, + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value, ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.kms_hook._conn, result) diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py index e84c41f0648dd..6a655658960e2 100644 --- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py @@ -37,7 +37,7 @@ def setUp(self): @mock.patch( "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.client_info", - new_callable=mock.PropertyMock + new_callable=mock.PropertyMock, ) @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook._get_credentials") 
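The project_id patches used throughout these test modules rely on new_callable=PropertyMock, which is the standard way to stub a read-only property: the PropertyMock replaces the descriptor on the class, so attribute reads return the configured value (None in the "without parent" tests). A minimal sketch with an illustrative class, not the real GoogleBaseHook:

    from unittest import mock
    from unittest.mock import PropertyMock

    class Hook:
        @property
        def project_id(self):
            return "real-project"

    # new_callable=PropertyMock swaps the property descriptor on the class, so every
    # read of hook.project_id during the patch returns the configured value.
    with mock.patch.object(Hook, "project_id", new_callable=PropertyMock, return_value=None):
        assert Hook().project_id is None

    # Outside the patch, the real property is restored.
    assert Hook().project_id == "real-project"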
@mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.container_v1.ClusterManagerClient") @@ -45,8 +45,7 @@ def test_gke_cluster_client_creation(self, mock_client, mock_get_creds, mock_cli result = self.gke_hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.gke_hook._client, result) @@ -60,38 +59,37 @@ def setUp(self): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.ParseDict") - @mock.patch( - "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") + @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") def test_delete_cluster(self, wait_mock, convert_mock, mock_project_id): retry_mock, timeout_mock = mock.Mock(), mock.Mock() client_delete = self.gke_hook._client.delete_cluster = mock.Mock() - self.gke_hook.delete_cluster(name=CLUSTER_NAME, project_id=TEST_GCP_PROJECT_ID, - retry=retry_mock, - timeout=timeout_mock) + self.gke_hook.delete_cluster( + name=CLUSTER_NAME, project_id=TEST_GCP_PROJECT_ID, retry=retry_mock, timeout=timeout_mock + ) - client_delete.assert_called_once_with(project_id=TEST_GCP_PROJECT_ID, - zone=GKE_ZONE, - cluster_id=CLUSTER_NAME, - retry=retry_mock, - timeout=timeout_mock) + client_delete.assert_called_once_with( + project_id=TEST_GCP_PROJECT_ID, + zone=GKE_ZONE, + cluster_id=CLUSTER_NAME, + retry=retry_mock, + timeout=timeout_mock, + ) wait_mock.assert_called_once_with(client_delete.return_value) convert_mock.assert_not_called() @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) - @mock.patch( - "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.log") + @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.log") @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.ParseDict") - @mock.patch( - "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") + @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") def test_delete_cluster_not_found(self, wait_mock, convert_mock, log_mock, mock_project_id): from google.api_core.exceptions import NotFound @@ -107,11 +105,10 @@ def test_delete_cluster_not_found(self, wait_mock, convert_mock, log_mock, mock_ @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.ParseDict") - @mock.patch( - "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") + @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") def test_delete_cluster_error(self, wait_mock, convert_mock, mock_project_id): # To force an error self.gke_hook._client.delete_cluster.side_effect = AirflowException('400') @@ -130,11 +127,10 @@ def setUp(self): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + 
return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.ParseDict") - @mock.patch( - "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") + @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") def test_create_cluster_proto(self, wait_mock, convert_mock, mock_project_id): mock_cluster_proto = Cluster() mock_cluster_proto.name = CLUSTER_NAME @@ -143,26 +139,27 @@ def test_create_cluster_proto(self, wait_mock, convert_mock, mock_project_id): client_create = self.gke_hook._client.create_cluster = mock.Mock() - self.gke_hook.create_cluster(cluster=mock_cluster_proto, - project_id=TEST_GCP_PROJECT_ID, - retry=retry_mock, - timeout=timeout_mock) + self.gke_hook.create_cluster( + cluster=mock_cluster_proto, project_id=TEST_GCP_PROJECT_ID, retry=retry_mock, timeout=timeout_mock + ) - client_create.assert_called_once_with(project_id=TEST_GCP_PROJECT_ID, - zone=GKE_ZONE, - cluster=mock_cluster_proto, - retry=retry_mock, timeout=timeout_mock) + client_create.assert_called_once_with( + project_id=TEST_GCP_PROJECT_ID, + zone=GKE_ZONE, + cluster=mock_cluster_proto, + retry=retry_mock, + timeout=timeout_mock, + ) wait_mock.assert_called_once_with(client_create.return_value) convert_mock.assert_not_called() @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.ParseDict") - @mock.patch( - "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") + @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") def test_create_cluster_dict(self, wait_mock, convert_mock, mock_project_id): mock_cluster_dict = {'name': CLUSTER_NAME} retry_mock, timeout_mock = mock.Mock(), mock.Mock() @@ -170,24 +167,22 @@ def test_create_cluster_dict(self, wait_mock, convert_mock, mock_project_id): client_create = self.gke_hook._client.create_cluster = mock.Mock() proto_mock = convert_mock.return_value = mock.Mock() - self.gke_hook.create_cluster(cluster=mock_cluster_dict, - project_id=TEST_GCP_PROJECT_ID, - retry=retry_mock, - timeout=timeout_mock) + self.gke_hook.create_cluster( + cluster=mock_cluster_dict, project_id=TEST_GCP_PROJECT_ID, retry=retry_mock, timeout=timeout_mock + ) - client_create.assert_called_once_with(project_id=TEST_GCP_PROJECT_ID, - zone=GKE_ZONE, - cluster=proto_mock, - retry=retry_mock, timeout=timeout_mock) - wait_mock.assert_called_once_with(client_create.return_value) - convert_mock.assert_called_once_with( - {'name': 'test-cluster'}, - Cluster() + client_create.assert_called_once_with( + project_id=TEST_GCP_PROJECT_ID, + zone=GKE_ZONE, + cluster=proto_mock, + retry=retry_mock, + timeout=timeout_mock, ) + wait_mock.assert_called_once_with(client_create.return_value) + convert_mock.assert_called_once_with({'name': 'test-cluster'}, Cluster()) @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.ParseDict") - @mock.patch( - "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") + @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") def test_create_cluster_error(self, wait_mock, convert_mock): # to force an error mock_cluster_proto = None @@ -199,13 +194,11 @@ def test_create_cluster_error(self, wait_mock, convert_mock): @mock.patch( 
'airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook._get_credentials_and_project_id', - return_value=(mock.MagicMock(), TEST_GCP_PROJECT_ID) + return_value=(mock.MagicMock(), TEST_GCP_PROJECT_ID), ) - @mock.patch( - "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.log") + @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.log") @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.ParseDict") - @mock.patch( - "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") + @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.wait_for_operation") def test_create_cluster_already_exists(self, wait_mock, convert_mock, log_mock, mock_get_credentials): from google.api_core.exceptions import AlreadyExists @@ -229,25 +222,25 @@ def test_get_cluster(self): client_get = self.gke_hook._client.get_cluster = mock.Mock() - self.gke_hook.get_cluster(name=CLUSTER_NAME, - project_id=TEST_GCP_PROJECT_ID, - retry=retry_mock, - timeout=timeout_mock) + self.gke_hook.get_cluster( + name=CLUSTER_NAME, project_id=TEST_GCP_PROJECT_ID, retry=retry_mock, timeout=timeout_mock + ) - client_get.assert_called_once_with(project_id=TEST_GCP_PROJECT_ID, - zone=GKE_ZONE, - cluster_id=CLUSTER_NAME, - retry=retry_mock, timeout=timeout_mock) + client_get.assert_called_once_with( + project_id=TEST_GCP_PROJECT_ID, + zone=GKE_ZONE, + cluster_id=CLUSTER_NAME, + retry=retry_mock, + timeout=timeout_mock, + ) class TestGKEHook(unittest.TestCase): - def setUp(self): self.gke_hook = GKEHook(location=GKE_ZONE) self.gke_hook._client = mock.Mock() - @mock.patch('airflow.providers.google.cloud.hooks.kubernetes_engine.container_v1.' - 'ClusterManagerClient') + @mock.patch('airflow.providers.google.cloud.hooks.kubernetes_engine.container_v1.' 
'ClusterManagerClient') @mock.patch('airflow.providers.google.common.hooks.base_google.ClientInfo') @mock.patch('airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook._get_credentials') def test_get_client(self, mock_get_credentials, mock_client_info, mock_client): @@ -255,14 +248,15 @@ def test_get_client(self, mock_get_credentials, mock_client_info, mock_client): self.gke_hook.get_conn() assert mock_get_credentials.called mock_client.assert_called_once_with( - credentials=mock_get_credentials.return_value, - client_info=mock_client_info.return_value) + credentials=mock_get_credentials.return_value, client_info=mock_client_info.return_value + ) def test_get_operation(self): self.gke_hook._client.get_operation = mock.Mock() self.gke_hook.get_operation('TEST_OP', project_id=TEST_GCP_PROJECT_ID) self.gke_hook._client.get_operation.assert_called_once_with( - project_id=TEST_GCP_PROJECT_ID, zone=GKE_ZONE, operation_id='TEST_OP') + project_id=TEST_GCP_PROJECT_ID, zone=GKE_ZONE, operation_id='TEST_OP' + ) def test_append_label(self): key = 'test-key' @@ -281,6 +275,7 @@ def test_append_label_replace(self): @mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.time.sleep") def test_wait_for_response_done(self, time_mock): from google.cloud.container_v1.gapic.enums import Operation + mock_op = mock.Mock() mock_op.status = Operation.Status.DONE self.gke_hook.wait_for_operation(mock_op) diff --git a/tests/providers/google/cloud/hooks/test_life_sciences.py b/tests/providers/google/cloud/hooks/test_life_sciences.py index 40b7905c97fae..fb2c8fd1f6efe 100644 --- a/tests/providers/google/cloud/hooks/test_life_sciences.py +++ b/tests/providers/google/cloud/hooks/test_life_sciences.py @@ -26,12 +26,17 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.life_sciences import LifeSciencesHook from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, ) -TEST_OPERATION = {"name": 'operation-name', "metadata": {"@type": 'anytype'}, - "done": True, "response": "response"} +TEST_OPERATION = { + "name": 'operation-name', + "metadata": {"@type": 'anytype'}, + "done": True, + "response": "response", +} TEST_WAITING_OPERATION = {"done": False, "response": "response"} TEST_DONE_OPERATION = {"done": True, "response": "response"} @@ -41,7 +46,6 @@ class TestLifeSciencesHookWithPassedProjectId(unittest.TestCase): - def setUp(self): with mock.patch( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__", @@ -51,8 +55,7 @@ def setUp(self): def test_location_path(self): path = 'projects/life-science-project-id/locations/test-location' - path2 = self.hook._location_path(project_id=TEST_PROJECT_ID, - location=TEST_LOCATION) + path2 = self.hook._location_path(project_id=TEST_PROJECT_ID, location=TEST_LOCATION) self.assertEqual(path, path2) @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.LifeSciencesHook._authorize") @@ -68,12 +71,12 @@ def test_life_science_client_creation(self, mock_build, mock_authorize): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.LifeSciencesHook.get_conn") def 
test_run_pipeline_immediately_complete(self, get_conn_mock, mock_project_id): service_mock = get_conn_mock.return_value - + # fmt: off service_mock.projects.return_value \ .locations.return_value \ .pipelines.return_value \ @@ -86,27 +89,27 @@ def test_run_pipeline_immediately_complete(self, get_conn_mock, mock_project_id) .get.return_value \ .execute.return_value = TEST_DONE_OPERATION - result = self.hook.run_pipeline(body={}, - location=TEST_LOCATION, - project_id=TEST_PROJECT_ID) - parent = self.hook._location_path(project_id=TEST_PROJECT_ID, - location=TEST_LOCATION) - service_mock.projects.return_value.locations.return_value\ + result = self.hook.run_pipeline(body={}, # pylint: disable=no-value-for-parameter + location=TEST_LOCATION) + parent = self.hook. \ + _location_path(location=TEST_LOCATION) # pylint: disable=no-value-for-parameter + service_mock.projects.return_value.locations.return_value \ .pipelines.return_value.run \ .assert_called_once_with(body={}, parent=parent) + # fmt: on self.assertEqual(result, TEST_OPERATION) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.LifeSciencesHook.get_conn") @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.time.sleep") def test_waiting_operation(self, _, get_conn_mock, mock_project_id): service_mock = get_conn_mock.return_value - + # fmt: off service_mock.projects.return_value \ .locations.return_value \ .pipelines.return_value \ @@ -122,21 +125,20 @@ def test_waiting_operation(self, _, get_conn_mock, mock_project_id): .get.return_value \ .execute = execute_mock - result = self.hook.run_pipeline(body={}, - location=TEST_LOCATION, - project_id=TEST_PROJECT_ID) + # fmt: on + result = self.hook.run_pipeline(body={}, location=TEST_LOCATION, project_id=TEST_PROJECT_ID) self.assertEqual(result, TEST_OPERATION) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.LifeSciencesHook.get_conn") @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.time.sleep") def test_error_operation(self, _, get_conn_mock, mock_project_id): service_mock = get_conn_mock.return_value - + # fmt: off service_mock.projects.return_value \ .locations.return_value \ .pipelines.return_value \ @@ -149,16 +151,12 @@ def test_error_operation(self, _, get_conn_mock, mock_project_id): .operations.return_value \ .get.return_value \ .execute = execute_mock - + # fmt: on with self.assertRaisesRegex(AirflowException, "error"): - self.hook.run_pipeline(body={}, - location=TEST_LOCATION, - project_id=TEST_PROJECT_ID - ) + self.hook.run_pipeline(body={}, location=TEST_LOCATION, project_id=TEST_PROJECT_ID) class TestLifeSciencesHookWithDefaultProjectIdFromConnection(unittest.TestCase): - def setUp(self): with mock.patch( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__", @@ -179,12 +177,13 @@ def test_life_science_client_creation(self, mock_build, mock_authorize): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) 
@mock.patch("airflow.providers.google.cloud.hooks.life_sciences.LifeSciencesHook.get_conn") def test_run_pipeline_immediately_complete(self, get_conn_mock, mock_project_id): service_mock = get_conn_mock.return_value + # fmt: off service_mock.projects.return_value \ .locations.return_value \ .pipelines.return_value \ @@ -197,55 +196,55 @@ def test_run_pipeline_immediately_complete(self, get_conn_mock, mock_project_id) .get.return_value \ .execute.return_value = TEST_DONE_OPERATION - result = self.hook.run_pipeline(body={}, # pylint: disable=no-value-for-parameter - location=TEST_LOCATION) - parent = self.hook.\ - _location_path(location=TEST_LOCATION) # pylint: disable=no-value-for-parameter - service_mock.projects.return_value.locations.return_value\ + result = self.hook.run_pipeline(body={}, location=TEST_LOCATION, project_id=TEST_PROJECT_ID) + parent = self.hook._location_path(project_id=TEST_PROJECT_ID, location=TEST_LOCATION) + service_mock.projects.return_value.locations.return_value \ .pipelines.return_value.run \ .assert_called_once_with(body={}, parent=parent) + # fmt: on self.assertEqual(result, TEST_OPERATION) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.LifeSciencesHook.get_conn") @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.time.sleep") def test_waiting_operation(self, _, get_conn_mock, mock_project_id): service_mock = get_conn_mock.return_value + # fmt: off service_mock.projects.return_value \ .locations.return_value \ .pipelines.return_value \ .run.return_value \ .execute.return_value = TEST_OPERATION - execute_mock = mock.Mock( - **{"side_effect": [TEST_WAITING_OPERATION, TEST_DONE_OPERATION]} - ) + execute_mock = mock.Mock(**{"side_effect": [TEST_WAITING_OPERATION, TEST_DONE_OPERATION]}) service_mock.projects.return_value \ .locations.return_value \ .operations.return_value \ .get.return_value \ .execute = execute_mock + # fmt: on - result = self.hook.run_pipeline(body={}, # pylint: disable=no-value-for-parameter - location=TEST_LOCATION) + # pylint: disable=no-value-for-parameter + result = self.hook.run_pipeline(body={}, location=TEST_LOCATION) self.assertEqual(result, TEST_OPERATION) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.LifeSciencesHook.get_conn") @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.time.sleep") def test_error_operation(self, _, get_conn_mock, mock_project_id): service_mock = get_conn_mock.return_value + # fmt: off service_mock.projects.return_value \ .locations.return_value \ .pipelines.return_value \ @@ -258,14 +257,13 @@ def test_error_operation(self, _, get_conn_mock, mock_project_id): .operations.return_value \ .get.return_value \ .execute = execute_mock + # fmt: on with self.assertRaisesRegex(AirflowException, "error"): - self.hook.run_pipeline(body={}, # pylint: disable=no-value-for-parameter - location=TEST_LOCATION) + self.hook.run_pipeline(body={}, location=TEST_LOCATION) # pylint: disable=no-value-for-parameter class TestLifeSciencesHookWithoutProjectId(unittest.TestCase): - def setUp(self): with mock.patch( 
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__", @@ -286,13 +284,12 @@ def test_life_science_client_creation(self, mock_build, mock_authorize): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.cloud.hooks.life_sciences.LifeSciencesHook.get_conn") def test_run_pipeline(self, get_conn_mock, mock_project_id): # pylint: disable=unused-argument with self.assertRaises(AirflowException) as e: - self.hook.run_pipeline(body={}, # pylint: disable=no-value-for-parameter - location=TEST_LOCATION) + self.hook.run_pipeline(body={}, location=TEST_LOCATION) # pylint: disable=no-value-for-parameter self.assertEqual( "The project id must be passed either as keyword project_id parameter or as project_id extra in " diff --git a/tests/providers/google/cloud/hooks/test_mlengine.py b/tests/providers/google/cloud/hooks/test_mlengine.py index e32e5f636c173..5599f29e9f8a2 100644 --- a/tests/providers/google/cloud/hooks/test_mlengine.py +++ b/tests/providers/google/cloud/hooks/test_mlengine.py @@ -25,7 +25,8 @@ from airflow.providers.google.cloud.hooks import mlengine as hook from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + mock_base_gcp_hook_default_project_id, ) @@ -40,30 +41,22 @@ def test_mle_engine_client_creation(self, mock_build, mock_authorize): result = self.hook.get_conn() self.assertEqual(mock_build.return_value, result) - mock_build.assert_called_with( - 'ml', 'v1', http=mock_authorize.return_value, cache_discovery=False - ) + mock_build.assert_called_with('ml', 'v1', http=mock_authorize.return_value, cache_discovery=False) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_create_version(self, mock_get_conn): project_id = 'test-project' model_name = 'test-model' version_name = 'test-version' - version = { - 'name': version_name, - 'labels': {'other-label': 'test-value'} - } + version = {'name': version_name, 'labels': {'other-label': 'test-value'}} version_with_airflow_version = { 'name': 'test-version', - 'labels': { - 'other-label': 'test-value', - 'airflow-version': hook._AIRFLOW_VERSION - } + 'labels': {'other-label': 'test-value', 'airflow-version': hook._AIRFLOW_VERSION}, } operation_path = 'projects/{}/operations/test-operation'.format(project_id) model_path = 'projects/{}/models/{}'.format(project_id, model_name) operation_done = {'name': operation_path, 'done': True} - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -79,23 +72,25 @@ def test_create_version(self, mock_get_conn): get.return_value. 
execute.return_value ) = {'name': operation_path, 'done': True} - + # fmt: on create_version_response = self.hook.create_version( - project_id=project_id, - model_name=model_name, - version_spec=deepcopy(version) + project_id=project_id, model_name=model_name, version_spec=deepcopy(version) ) self.assertEqual(create_version_response, operation_done) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().versions().create( - body=version_with_airflow_version, - parent=model_path - ), - mock.call().projects().models().versions().create().execute(), - mock.call().projects().operations().get(name=version_name), - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call() + .projects() + .models() + .versions() + .create(body=version_with_airflow_version, parent=model_path), + mock.call().projects().models().versions().create().execute(), + mock.call().projects().operations().get(name=version_name), + ], + any_order=True, + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_create_version_with_labels(self, mock_get_conn): @@ -105,12 +100,12 @@ def test_create_version_with_labels(self, mock_get_conn): version = {'name': version_name} version_with_airflow_version = { 'name': 'test-version', - 'labels': {'airflow-version': hook._AIRFLOW_VERSION} + 'labels': {'airflow-version': hook._AIRFLOW_VERSION}, } operation_path = 'projects/{}/operations/test-operation'.format(project_id) model_path = 'projects/{}/models/{}'.format(project_id, model_name) operation_done = {'name': operation_path, 'done': True} - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -132,17 +127,22 @@ def test_create_version_with_labels(self, mock_get_conn): model_name=model_name, version_spec=deepcopy(version) ) + # fmt: on self.assertEqual(create_version_response, operation_done) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().versions().create( - body=version_with_airflow_version, - parent=model_path - ), - mock.call().projects().models().versions().create().execute(), - mock.call().projects().operations().get(name=version_name), - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call() + .projects() + .models() + .versions() + .create(body=version_with_airflow_version, parent=model_path), + mock.call().projects().models().versions().create().execute(), + mock.call().projects().operations().get(name=version_name), + ], + any_order=True, + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_set_default_version(self, mock_get_conn): @@ -152,7 +152,7 @@ def test_set_default_version(self, mock_get_conn): operation_path = 'projects/{}/operations/test-operation'.format(project_id) version_path = 'projects/{}/models/{}/versions/{}'.format(project_id, model_name, version_name) operation_done = {'name': operation_path, 'done': True} - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -161,19 +161,20 @@ def test_set_default_version(self, mock_get_conn): setDefault.return_value. 
execute.return_value ) = operation_done - + # fmt: on set_default_version_response = self.hook.set_default_version( - project_id=project_id, - model_name=model_name, - version_name=version_name + project_id=project_id, model_name=model_name, version_name=version_name ) self.assertEqual(set_default_version_response, operation_done) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().versions().setDefault(body={}, name=version_path), - mock.call().projects().models().versions().setDefault().execute() - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().versions().setDefault(body={}, name=version_path), + mock.call().projects().models().versions().setDefault().execute(), + ], + any_order=True, + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.time.sleep") @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") @@ -183,18 +184,16 @@ def test_list_versions(self, mock_get_conn, mock_sleep): model_path = 'projects/{}/models/{}'.format(project_id, model_name) version_names = ['ver_{}'.format(ix) for ix in range(3)] response_bodies = [ - { - 'nextPageToken': "TOKEN-{}".format(ix), - 'versions': [ver] - } for ix, ver in enumerate(version_names)] + {'nextPageToken': "TOKEN-{}".format(ix), 'versions': [ver]} + for ix, ver in enumerate(version_names) + ] response_bodies[-1].pop('nextPageToken') - pages_requests = [ - mock.Mock(**{'execute.return_value': body}) for body in response_bodies - ] + pages_requests = [mock.Mock(**{'execute.return_value': body}) for body in response_bodies] versions_mock = mock.Mock( **{'list.return_value': pages_requests[0], 'list_next.side_effect': pages_requests[1:] + [None]} ) + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -204,16 +203,23 @@ def test_list_versions(self, mock_get_conn, mock_sleep): list_versions_response = self.hook.list_versions( project_id=project_id, model_name=model_name) - + # fmt: on self.assertEqual(list_versions_response, version_names) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().versions().list(pageSize=100, parent=model_path), - mock.call().projects().models().versions().list().execute(), - ] + [ - mock.call().projects().models().versions().list_next( - previous_request=pages_requests[i], previous_response=response_bodies[i] - ) for i in range(3) - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().versions().list(pageSize=100, parent=model_path), + mock.call().projects().models().versions().list().execute(), + ] + + [ + mock.call() + .projects() + .models() + .versions() + .list_next(previous_request=pages_requests[i], previous_response=response_bodies[i]) + for i in range(3) + ], + any_order=True, + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_delete_version(self, mock_get_conn): @@ -225,7 +231,7 @@ def test_delete_version(self, mock_get_conn): version = {'name': operation_path} operation_not_done = {'name': operation_path, 'done': False} operation_done = {'name': operation_path, 'done': True} - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -242,18 +248,21 @@ def test_delete_version(self, mock_get_conn): delete.return_value. 
execute.return_value ) = version - + # fmt: on delete_version_response = self.hook.delete_version( - project_id=project_id, model_name=model_name, - version_name=version_name) + project_id=project_id, model_name=model_name, version_name=version_name + ) self.assertEqual(delete_version_response, operation_done) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().versions().delete(name=version_path), - mock.call().projects().models().versions().delete().execute(), - mock.call().projects().operations().get(name=operation_path), - mock.call().projects().operations().get().execute() - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().versions().delete(name=version_path), + mock.call().projects().models().versions().delete().execute(), + mock.call().projects().operations().get(name=operation_path), + mock.call().projects().operations().get().execute(), + ], + any_order=True, + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_create_model(self, mock_get_conn): @@ -264,10 +273,10 @@ def test_create_model(self, mock_get_conn): } model_with_airflow_version = { 'name': model_name, - 'labels': {'airflow-version': hook._AIRFLOW_VERSION} + 'labels': {'airflow-version': hook._AIRFLOW_VERSION}, } project_path = 'projects/{}'.format(project_id) - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -275,16 +284,16 @@ def test_create_model(self, mock_get_conn): create.return_value. execute.return_value ) = model - - create_model_response = self.hook.create_model( - project_id=project_id, model=deepcopy(model) - ) + # fmt: on + create_model_response = self.hook.create_model(project_id=project_id, model=deepcopy(model)) self.assertEqual(create_model_response, model) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path), - mock.call().projects().models().create().execute() - ]) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path), + mock.call().projects().models().create().execute(), + ] + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_create_model_idempotency(self, mock_get_conn): @@ -295,10 +304,10 @@ def test_create_model_idempotency(self, mock_get_conn): } model_with_airflow_version = { 'name': model_name, - 'labels': {'airflow-version': hook._AIRFLOW_VERSION} + 'labels': {'airflow-version': hook._AIRFLOW_VERSION}, } project_path = 'projects/{}'.format(project_id) - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -338,38 +347,34 @@ def test_create_model_idempotency(self, mock_get_conn): get.return_value. 
execute.return_value ) = deepcopy(model) - - create_model_response = self.hook.create_model( - project_id=project_id, model=deepcopy(model) - ) + # fmt: on + create_model_response = self.hook.create_model(project_id=project_id, model=deepcopy(model)) self.assertEqual(create_model_response, model) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path), - mock.call().projects().models().create().execute(), - ]) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().get(name='projects/test-project/models/test-model'), - mock.call().projects().models().get().execute() - ]) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path), + mock.call().projects().models().create().execute(), + ] + ) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().get(name='projects/test-project/models/test-model'), + mock.call().projects().models().get().execute(), + ] + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_create_model_with_labels(self, mock_get_conn): project_id = 'test-project' model_name = 'test-model' - model = { - 'name': model_name, - 'labels': {'other-label': 'test-value'} - } + model = {'name': model_name, 'labels': {'other-label': 'test-value'}} model_with_airflow_version = { 'name': model_name, - 'labels': { - 'other-label': 'test-value', - 'airflow-version': hook._AIRFLOW_VERSION - } + 'labels': {'other-label': 'test-value', 'airflow-version': hook._AIRFLOW_VERSION}, } project_path = 'projects/{}'.format(project_id) - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -381,12 +386,14 @@ def test_create_model_with_labels(self, mock_get_conn): create_model_response = self.hook.create_model( project_id=project_id, model=deepcopy(model) ) - + # fmt: on self.assertEqual(create_model_response, model) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path), - mock.call().projects().models().create().execute() - ]) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path), + mock.call().projects().models().create().execute(), + ] + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_get_model(self, mock_get_conn): @@ -394,7 +401,7 @@ def test_get_model(self, mock_get_conn): model_name = 'test-model' model = {'model': model_name} model_path = 'projects/{}/models/{}'.format(project_id, model_name) - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -402,16 +409,16 @@ def test_get_model(self, mock_get_conn): get.return_value. 
execute.return_value ) = model - - get_model_response = self.hook.get_model( - project_id=project_id, model_name=model_name - ) + # fmt: on + get_model_response = self.hook.get_model(project_id=project_id, model_name=model_name) self.assertEqual(get_model_response, model) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().get(name=model_path), - mock.call().projects().models().get().execute() - ]) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().get(name=model_path), + mock.call().projects().models().get().execute(), + ] + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_delete_model(self, mock_get_conn): @@ -419,6 +426,7 @@ def test_delete_model(self, mock_get_conn): model_name = 'test-model' model = {'model': model_name} model_path = 'projects/{}/models/{}'.format(project_id, model_name) + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -426,16 +434,16 @@ def test_delete_model(self, mock_get_conn): delete.return_value. execute.return_value ) = model + # fmt: on + self.hook.delete_model(project_id=project_id, model_name=model_name) - self.hook.delete_model( - project_id=project_id, model_name=model_name + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().delete(name=model_path), + mock.call().projects().models().delete().execute(), + ] ) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().delete(name=model_path), - mock.call().projects().models().delete().execute() - ]) - @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.log") @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_delete_model_when_not_exists(self, mock_get_conn, mock_log): @@ -443,9 +451,9 @@ def test_delete_model_when_not_exists(self, mock_get_conn, mock_log): model_name = 'test-model' model_path = 'projects/{}/models/{}'.format(project_id, model_name) http_error = HttpError( - resp=mock.MagicMock(status=404, reason="Model not found."), - content=b'Model not found.' + resp=mock.MagicMock(status=404, reason="Model not found."), content=b'Model not found.' ) + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -453,15 +461,15 @@ def test_delete_model_when_not_exists(self, mock_get_conn, mock_log): delete.return_value. 
execute.side_effect ) = [http_error] + # fmt: on + self.hook.delete_model(project_id=project_id, model_name=model_name) - self.hook.delete_model( - project_id=project_id, model_name=model_name + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().delete(name=model_path), + mock.call().projects().models().delete().execute(), + ] ) - - mock_get_conn.assert_has_calls([ - mock.call().projects().models().delete(name=model_path), - mock.call().projects().models().delete().execute() - ]) mock_log.error.assert_called_once_with('Model was not found: %s', http_error) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.time.sleep") @@ -473,11 +481,14 @@ def test_delete_model_with_contents(self, mock_get_conn, mock_sleep): operation_path = 'projects/{}/operations/test-operation'.format(project_id) operation_done = {'name': operation_path, 'done': True} version_names = ["AAA", "BBB", "CCC"] - versions = [{ - 'name': 'projects/{}/models/{}/versions/{}'.format(project_id, model_name, version_name), - "isDefault": i == 0 - } for i, version_name in enumerate(version_names)] - + versions = [ + { + 'name': 'projects/{}/models/{}/versions/{}'.format(project_id, model_name, version_name), + "isDefault": i == 0, + } + for i, version_name in enumerate(version_names) + ] + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -500,21 +511,25 @@ def test_delete_model_with_contents(self, mock_get_conn, mock_sleep): versions.return_value. list_next.return_value ) = None - - self.hook.delete_model( - project_id=project_id, model_name=model_name, delete_contents=True - ) + # fmt: on + self.hook.delete_model(project_id=project_id, model_name=model_name, delete_contents=True) mock_get_conn.assert_has_calls( [ mock.call().projects().models().delete(name=model_path), - mock.call().projects().models().delete().execute() - ] + [ - mock.call().projects().models().versions().delete( + mock.call().projects().models().delete().execute(), + ] + + [ + mock.call() + .projects() + .models() + .versions() + .delete( name='projects/{}/models/{}/versions/{}'.format(project_id, model_name, version_name), - ) for version_name in version_names + ) + for version_name in version_names ], - any_order=True + any_order=True, ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.time.sleep") @@ -531,7 +546,7 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep): new_job_with_airflow_version = { 'jobId': job_id, 'foo': 4815162342, - 'labels': {'airflow-version': hook._AIRFLOW_VERSION} + 'labels': {'airflow-version': hook._AIRFLOW_VERSION}, } job_succeeded = { @@ -542,7 +557,7 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep): 'jobId': job_id, 'state': 'QUEUED', } - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -557,17 +572,18 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep): get.return_value. 
execute.side_effect ) = [job_queued, job_succeeded] - - create_job_response = self.hook.create_job( - project_id=project_id, job=deepcopy(new_job) - ) + # fmt: on + create_job_response = self.hook.create_job(project_id=project_id, job=deepcopy(new_job)) self.assertEqual(create_job_response, job_succeeded) - mock_get_conn.assert_has_calls([ - mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path), - mock.call().projects().jobs().get(name=job_path), - mock.call().projects().jobs().get().execute() - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path), + mock.call().projects().jobs().get(name=job_path), + mock.call().projects().jobs().get().execute(), + ], + any_order=True, + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.time.sleep") @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") @@ -576,18 +592,11 @@ def test_create_mlengine_job_with_labels(self, mock_get_conn, mock_sleep): job_id = 'test-job-id' project_path = 'projects/{}'.format(project_id) job_path = 'projects/{}/jobs/{}'.format(project_id, job_id) - new_job = { - 'jobId': job_id, - 'foo': 4815162342, - 'labels': {'other-label': 'test-value'} - } + new_job = {'jobId': job_id, 'foo': 4815162342, 'labels': {'other-label': 'test-value'}} new_job_with_airflow_version = { 'jobId': job_id, 'foo': 4815162342, - 'labels': { - 'other-label': 'test-value', - 'airflow-version': hook._AIRFLOW_VERSION - } + 'labels': {'other-label': 'test-value', 'airflow-version': hook._AIRFLOW_VERSION}, } job_succeeded = { @@ -598,7 +607,7 @@ def test_create_mlengine_job_with_labels(self, mock_get_conn, mock_sleep): 'jobId': job_id, 'state': 'QUEUED', } - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -617,13 +626,16 @@ def test_create_mlengine_job_with_labels(self, mock_get_conn, mock_sleep): create_job_response = self.hook.create_job( project_id=project_id, job=deepcopy(new_job) ) - + # fmt: on self.assertEqual(create_job_response, job_succeeded) - mock_get_conn.assert_has_calls([ - mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path), - mock.call().projects().jobs().get(name=job_path), - mock.call().projects().jobs().get().execute() - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path), + mock.call().projects().jobs().get(name=job_path), + mock.call().projects().jobs().get().execute(), + ], + any_order=True, + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_create_mlengine_job_reuse_existing_job_by_default(self, mock_get_conn): @@ -637,7 +649,7 @@ def test_create_mlengine_job_reuse_existing_job_by_default(self, mock_get_conn): 'state': 'SUCCEEDED', } error_job_exists = HttpError(resp=mock.MagicMock(status=409), content=b'Job already exists') - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -652,17 +664,19 @@ def test_create_mlengine_job_reuse_existing_job_by_default(self, mock_get_conn): get.return_value. 
execute.return_value ) = job_succeeded - - create_job_response = self.hook.create_job( - project_id=project_id, job=job_succeeded) + # fmt: on + create_job_response = self.hook.create_job(project_id=project_id, job=job_succeeded) self.assertEqual(create_job_response, job_succeeded) - mock_get_conn.assert_has_calls([ - mock.call().projects().jobs().create(body=job_succeeded, parent=project_path), - mock.call().projects().jobs().create().execute(), - mock.call().projects().jobs().get(name=job_path), - mock.call().projects().jobs().get().execute() - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().jobs().create(body=job_succeeded, parent=project_path), + mock.call().projects().jobs().create().execute(), + mock.call().projects().jobs().get(name=job_path), + mock.call().projects().jobs().get().execute(), + ], + any_order=True, + ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_create_mlengine_job_check_existing_job_failed(self, mock_get_conn): @@ -672,20 +686,17 @@ def test_create_mlengine_job_check_existing_job_failed(self, mock_get_conn): 'jobId': job_id, 'foo': 4815162342, 'state': 'SUCCEEDED', - 'someInput': { - 'input': 'someInput' - } + 'someInput': {'input': 'someInput'}, } different_job = { 'jobId': job_id, 'foo': 4815162342, 'state': 'SUCCEEDED', - 'someInput': { - 'input': 'someDifferentInput' - } + 'someInput': {'input': 'someDifferentInput'}, } error_job_exists = HttpError(resp=mock.MagicMock(status=409), content=b'Job already exists') + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -701,14 +712,12 @@ def test_create_mlengine_job_check_existing_job_failed(self, mock_get_conn): execute.return_value ) = different_job + # fmt: on def check_input(existing_job): - return existing_job.get('someInput', None) == \ - my_job['someInput'] + return existing_job.get('someInput', None) == my_job['someInput'] with self.assertRaises(HttpError): - self.hook.create_job( - project_id=project_id, job=my_job, - use_existing_job_fn=check_input) + self.hook.create_job(project_id=project_id, job=my_job, use_existing_job_fn=check_input) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_create_mlengine_job_check_existing_job_success(self, mock_get_conn): @@ -718,12 +727,10 @@ def test_create_mlengine_job_check_existing_job_success(self, mock_get_conn): 'jobId': job_id, 'foo': 4815162342, 'state': 'SUCCEEDED', - 'someInput': { - 'input': 'someInput' - } + 'someInput': {'input': 'someInput'}, } error_job_exists = HttpError(resp=mock.MagicMock(status=409), content=b'Job already exists') - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -739,12 +746,13 @@ def test_create_mlengine_job_check_existing_job_success(self, mock_get_conn): execute.return_value ) = my_job + # fmt: on def check_input(existing_job): return existing_job.get('someInput', None) == my_job['someInput'] create_job_response = self.hook.create_job( - project_id=project_id, job=my_job, - use_existing_job_fn=check_input) + project_id=project_id, job=my_job, use_existing_job_fn=check_input + ) self.assertEqual(create_job_response, my_job) @@ -755,7 +763,7 @@ def test_cancel_mlengine_job(self, mock_get_conn): job_path = 'projects/{}/jobs/{}'.format(project_id, job_id) job_cancelled = {} - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -763,13 +771,11 @@ def test_cancel_mlengine_job(self, mock_get_conn): cancel.return_value. 
execute.return_value ) = job_cancelled - + # fmt: on cancel_job_response = self.hook.cancel_job(job_id=job_id, project_id=project_id) self.assertEqual(cancel_job_response, job_cancelled) - mock_get_conn.assert_has_calls([ - mock.call().projects().jobs().cancel(name=job_path), - ], any_order=True) + mock_get_conn.assert_has_calls([mock.call().projects().jobs().cancel(name=job_path),], any_order=True) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_cancel_mlengine_job_nonexistent_job(self, mock_get_conn): @@ -778,7 +784,7 @@ def test_cancel_mlengine_job_nonexistent_job(self, mock_get_conn): job_cancelled = {} error_job_does_not_exist = HttpError(resp=mock.MagicMock(status=404), content=b'Job does not exist') - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -793,7 +799,7 @@ def test_cancel_mlengine_job_nonexistent_job(self, mock_get_conn): cancel.return_value. execute.return_value ) = job_cancelled - + # fmt: on with self.assertRaises(HttpError): self.hook.cancel_job(job_id=job_id, project_id=project_id) @@ -805,9 +811,9 @@ def test_cancel_mlengine_job_completed_job(self, mock_get_conn): job_cancelled = {} error_job_already_completed = HttpError( - resp=mock.MagicMock(status=400), - content=b'Job already completed') - + resp=mock.MagicMock(status=400), content=b'Job already completed' + ) + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -822,13 +828,11 @@ def test_cancel_mlengine_job_completed_job(self, mock_get_conn): cancel.return_value. execute.return_value ) = job_cancelled - + # fmt: on cancel_job_response = self.hook.cancel_job(job_id=job_id, project_id=project_id) self.assertEqual(cancel_job_response, job_cancelled) - mock_get_conn.assert_has_calls([ - mock.call().projects().jobs().cancel(name=job_path), - ], any_order=True) + mock_get_conn.assert_has_calls([mock.call().projects().jobs().cancel(name=job_path),], any_order=True) class TestMLEngineHookWithDefaultProjectId(unittest.TestCase): @@ -843,7 +847,7 @@ def setUp(self) -> None: @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_create_version(self, mock_get_conn, mock_project_id): @@ -853,7 +857,7 @@ def test_create_version(self, mock_get_conn, mock_project_id): operation_path = 'projects/{}/operations/test-operation'.format(GCP_PROJECT_ID_HOOK_UNIT_TEST) model_path = 'projects/{}/models/{}'.format(GCP_PROJECT_ID_HOOK_UNIT_TEST, model_name) operation_done = {'name': operation_path, 'done': True} - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -869,24 +873,25 @@ def test_create_version(self, mock_get_conn, mock_project_id): get.return_value. 
execute.return_value ) = {'name': operation_path, 'done': True} - + # fmt: on create_version_response = self.hook.create_version( - model_name=model_name, - version_spec=version, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST + model_name=model_name, version_spec=version, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST ) self.assertEqual(create_version_response, operation_done) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().versions().create(body=version, parent=model_path), - mock.call().projects().models().versions().create().execute(), - mock.call().projects().operations().get(name=version_name), - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().versions().create(body=version, parent=model_path), + mock.call().projects().models().versions().create().execute(), + mock.call().projects().operations().get(name=version_name), + ], + any_order=True, + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_set_default_version(self, mock_get_conn, mock_project_id): @@ -897,7 +902,7 @@ def test_set_default_version(self, mock_get_conn, mock_project_id): GCP_PROJECT_ID_HOOK_UNIT_TEST, model_name, version_name ) operation_done = {'name': operation_path, 'done': True} - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -906,23 +911,24 @@ def test_set_default_version(self, mock_get_conn, mock_project_id): setDefault.return_value. execute.return_value ) = operation_done - + # fmt: on set_default_version_response = self.hook.set_default_version( - model_name=model_name, - version_name=version_name, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + model_name=model_name, version_name=version_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) self.assertEqual(set_default_version_response, operation_done) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().versions().setDefault(body={}, name=version_path), - mock.call().projects().models().versions().setDefault().execute() - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().versions().setDefault(body={}, name=version_path), + mock.call().projects().models().versions().setDefault().execute(), + ], + any_order=True, + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.time.sleep") @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") @@ -931,47 +937,54 @@ def test_list_versions(self, mock_get_conn, mock_sleep, mock_project_id): model_path = 'projects/{}/models/{}'.format(GCP_PROJECT_ID_HOOK_UNIT_TEST, model_name) version_names = ['ver_{}'.format(ix) for ix in range(3)] response_bodies = [ - { - 'nextPageToken': "TOKEN-{}".format(ix), - 'versions': [ver] - } for ix, ver in enumerate(version_names)] + {'nextPageToken': "TOKEN-{}".format(ix), 'versions': [ver]} + for ix, ver in enumerate(version_names) + ] response_bodies[-1].pop('nextPageToken') - pages_requests = [ - mock.Mock(**{'execute.return_value': body}) for body in response_bodies - ] + pages_requests = [mock.Mock(**{'execute.return_value': body}) for body in response_bodies] 
versions_mock = mock.Mock( **{'list.return_value': pages_requests[0], 'list_next.side_effect': pages_requests[1:] + [None]} ) + # fmt: off ( mock_get_conn.return_value. projects.return_value. models.return_value. versions.return_value ) = versions_mock - + # fmt: on list_versions_response = self.hook.list_versions( - model_name=model_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) + model_name=model_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST + ) self.assertEqual(list_versions_response, version_names) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().versions().list(pageSize=100, parent=model_path), - mock.call().projects().models().versions().list().execute(), - ] + [ - mock.call().projects().models().versions().list_next( - previous_request=pages_requests[i], previous_response=response_bodies[i] - ) for i in range(3) - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().versions().list(pageSize=100, parent=model_path), + mock.call().projects().models().versions().list().execute(), + ] + + [ + mock.call() + .projects() + .models() + .versions() + .list_next(previous_request=pages_requests[i], previous_response=response_bodies[i]) + for i in range(3) + ], + any_order=True, + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_delete_version(self, mock_get_conn, mock_project_id): model_name = 'test-model' version_name = 'test-version' + # fmt: off operation_path = 'projects/{}/operations/test-operation'.format( GCP_PROJECT_ID_HOOK_UNIT_TEST ) @@ -998,22 +1011,26 @@ def test_delete_version(self, mock_get_conn, mock_project_id): delete.return_value. execute.return_value ) = version - + # fmt: on delete_version_response = self.hook.delete_version( - model_name=model_name, version_name=version_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) + model_name=model_name, version_name=version_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST + ) self.assertEqual(delete_version_response, operation_done) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().versions().delete(name=version_path), - mock.call().projects().models().versions().delete().execute(), - mock.call().projects().operations().get(name=operation_path), - mock.call().projects().operations().get().execute() - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().versions().delete(name=version_path), + mock.call().projects().models().versions().delete().execute(), + mock.call().projects().operations().get(name=operation_path), + mock.call().projects().operations().get().execute(), + ], + any_order=True, + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_create_model(self, mock_get_conn, mock_project_id): @@ -1022,7 +1039,7 @@ def test_create_model(self, mock_get_conn, mock_project_id): 'name': model_name, } project_path = 'projects/{}'.format(GCP_PROJECT_ID_HOOK_UNIT_TEST) - + # fmt: off ( mock_get_conn.return_value. projects.return_value. 
@@ -1030,26 +1047,28 @@ def test_create_model(self, mock_get_conn, mock_project_id): create.return_value. execute.return_value ) = model - + # fmt: on create_model_response = self.hook.create_model(model=model, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) self.assertEqual(create_model_response, model) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().create(body=model, parent=project_path), - mock.call().projects().models().create().execute() - ]) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().create(body=model, parent=project_path), + mock.call().projects().models().create().execute(), + ] + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_get_model(self, mock_get_conn, mock_project_id): model_name = 'test-model' model = {'model': model_name} model_path = 'projects/{}/models/{}'.format(GCP_PROJECT_ID_HOOK_UNIT_TEST, model_name) - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -1057,26 +1076,30 @@ def test_get_model(self, mock_get_conn, mock_project_id): get.return_value. execute.return_value ) = model - + # fmt: on get_model_response = self.hook.get_model( - model_name=model_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) + model_name=model_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST + ) self.assertEqual(get_model_response, model) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().get(name=model_path), - mock.call().projects().models().get().execute() - ]) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().get(name=model_path), + mock.call().projects().models().get().execute(), + ] + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_delete_model(self, mock_get_conn, mock_project_id): model_name = 'test-model' model = {'model': model_name} model_path = 'projects/{}/models/{}'.format(GCP_PROJECT_ID_HOOK_UNIT_TEST, model_name) + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -1084,18 +1107,20 @@ def test_delete_model(self, mock_get_conn, mock_project_id): delete.return_value. execute.return_value ) = model - + # fmt: on self.hook.delete_model(model_name=model_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) - mock_get_conn.assert_has_calls([ - mock.call().projects().models().delete(name=model_path), - mock.call().projects().models().delete().execute() - ]) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().models().delete(name=model_path), + mock.call().projects().models().delete().execute(), + ] + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.time.sleep") @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") @@ -1115,7 +1140,7 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep, mock_project_id): 'jobId': job_id, 'state': 'QUEUED', } - + # fmt: off ( mock_get_conn.return_value. 
projects.return_value. @@ -1130,20 +1155,23 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep, mock_project_id): get.return_value. execute.side_effect ) = [job_queued, job_succeeded] - + # fmt: on create_job_response = self.hook.create_job(job=new_job, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) self.assertEqual(create_job_response, job_succeeded) - mock_get_conn.assert_has_calls([ - mock.call().projects().jobs().create(body=new_job, parent=project_path), - mock.call().projects().jobs().get(name=job_path), - mock.call().projects().jobs().get().execute() - ], any_order=True) + mock_get_conn.assert_has_calls( + [ + mock.call().projects().jobs().create(body=new_job, parent=project_path), + mock.call().projects().jobs().get(name=job_path), + mock.call().projects().jobs().get().execute(), + ], + any_order=True, + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn") def test_cancel_mlengine_job(self, mock_get_conn, mock_project_id): @@ -1151,7 +1179,7 @@ def test_cancel_mlengine_job(self, mock_get_conn, mock_project_id): job_path = 'projects/{}/jobs/{}'.format(GCP_PROJECT_ID_HOOK_UNIT_TEST, job_id) job_cancelled = {} - + # fmt: off ( mock_get_conn.return_value. projects.return_value. @@ -1159,10 +1187,8 @@ def test_cancel_mlengine_job(self, mock_get_conn, mock_project_id): cancel.return_value. execute.return_value ) = job_cancelled - + # fmt: on cancel_job_response = self.hook.cancel_job(job_id=job_id, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) self.assertEqual(cancel_job_response, job_cancelled) - mock_get_conn.assert_has_calls([ - mock.call().projects().jobs().cancel(name=job_path), - ], any_order=True) + mock_get_conn.assert_has_calls([mock.call().projects().jobs().cancel(name=job_path),], any_order=True) diff --git a/tests/providers/google/cloud/hooks/test_natural_language.py b/tests/providers/google/cloud/hooks/test_natural_language.py index d4f7239f3da88..203a1fc679bdf 100644 --- a/tests/providers/google/cloud/hooks/test_natural_language.py +++ b/tests/providers/google/cloud/hooks/test_natural_language.py @@ -42,23 +42,21 @@ def setUp(self): @mock.patch( "airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.client_info", - new_callable=mock.PropertyMock + new_callable=mock.PropertyMock, + ) + @mock.patch( + "airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook." "_get_credentials" ) - @mock.patch("airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook." 
- "_get_credentials") @mock.patch("airflow.providers.google.cloud.hooks.natural_language.LanguageServiceClient") def test_language_service_client_creation(self, mock_client, mock_get_creds, mock_client_info): result = self.hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.hook._conn, result) - @mock.patch( - "airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn", - ) + @mock.patch("airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn",) def test_analyze_entities(self, get_conn): get_conn.return_value.analyze_entities.return_value = API_RESPONSE result = self.hook.analyze_entities(document=DOCUMENT, encoding_type=ENCODING_TYPE) @@ -69,9 +67,7 @@ def test_analyze_entities(self, get_conn): document=DOCUMENT, encoding_type=ENCODING_TYPE, retry=None, timeout=None, metadata=None ) - @mock.patch( - "airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn", - ) + @mock.patch("airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn",) def test_analyze_entity_sentiment(self, get_conn): get_conn.return_value.analyze_entity_sentiment.return_value = API_RESPONSE result = self.hook.analyze_entity_sentiment(document=DOCUMENT, encoding_type=ENCODING_TYPE) @@ -82,9 +78,7 @@ def test_analyze_entity_sentiment(self, get_conn): document=DOCUMENT, encoding_type=ENCODING_TYPE, retry=None, timeout=None, metadata=None ) - @mock.patch( - "airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn", - ) + @mock.patch("airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn",) def test_analyze_sentiment(self, get_conn): get_conn.return_value.analyze_sentiment.return_value = API_RESPONSE result = self.hook.analyze_sentiment(document=DOCUMENT, encoding_type=ENCODING_TYPE) @@ -95,9 +89,7 @@ def test_analyze_sentiment(self, get_conn): document=DOCUMENT, encoding_type=ENCODING_TYPE, retry=None, timeout=None, metadata=None ) - @mock.patch( - "airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn", - ) + @mock.patch("airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn",) def test_analyze_syntax(self, get_conn): get_conn.return_value.analyze_syntax.return_value = API_RESPONSE result = self.hook.analyze_syntax(document=DOCUMENT, encoding_type=ENCODING_TYPE) @@ -108,9 +100,7 @@ def test_analyze_syntax(self, get_conn): document=DOCUMENT, encoding_type=ENCODING_TYPE, retry=None, timeout=None, metadata=None ) - @mock.patch( - "airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn", - ) + @mock.patch("airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn",) def test_annotate_text(self, get_conn): get_conn.return_value.annotate_text.return_value = API_RESPONSE result = self.hook.annotate_text(document=DOCUMENT, encoding_type=ENCODING_TYPE, features=None) @@ -126,9 +116,7 @@ def test_annotate_text(self, get_conn): metadata=None, ) - @mock.patch( - "airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn", - ) + @mock.patch("airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn",) def 
test_classify_text(self, get_conn): get_conn.return_value.classify_text.return_value = API_RESPONSE result = self.hook.classify_text(document=DOCUMENT) diff --git a/tests/providers/google/cloud/hooks/test_pubsub.py b/tests/providers/google/cloud/hooks/test_pubsub.py index dc637fb40459c..1b1ddf21bbe08 100644 --- a/tests/providers/google/cloud/hooks/test_pubsub.py +++ b/tests/providers/google/cloud/hooks/test_pubsub.py @@ -39,12 +39,10 @@ TEST_SUBSCRIPTION = 'test-subscription' TEST_UUID = 'abc123-xzy789' TEST_MESSAGES = [ - { - 'data': b'Hello, World!', - 'attributes': {'type': 'greeting'} - }, + {'data': b'Hello, World!', 'attributes': {'type': 'greeting'}}, {'data': b'Knock, knock'}, - {'attributes': {'foo': ''}}] + {'attributes': {'foo': ''}}, +] EXPANDED_TOPIC = 'projects/{}/topics/{}'.format(TEST_PROJECT, TEST_TOPIC) EXPANDED_SUBSCRIPTION = 'projects/{}/subscriptions/{}'.format(TEST_PROJECT, TEST_SUBSCRIPTION) @@ -52,18 +50,14 @@ def mock_init( - self, - gcp_conn_id, - delegate_to=None, - impersonation_chain=None, + self, gcp_conn_id, delegate_to=None, impersonation_chain=None, ): # pylint: disable=unused-argument pass class TestPubSubHook(unittest.TestCase): def setUp(self): - with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), - new=mock_init): + with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init): self.pubsub_hook = PubSubHook(gcp_conn_id='test') def _generate_messages(self, count) -> List[ReceivedMessage]: @@ -81,30 +75,30 @@ def _generate_messages(self, count) -> List[ReceivedMessage]: for i in range(1, count + 1) ] - @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubHook.client_info", - new_callable=mock.PropertyMock) + @mock.patch( + "airflow.providers.google.cloud.hooks.pubsub.PubSubHook.client_info", new_callable=mock.PropertyMock + ) @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubHook._get_credentials") @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PublisherClient") def test_publisher_client_creation(self, mock_client, mock_get_creds, mock_client_info): self.assertIsNone(self.pubsub_hook._client) result = self.pubsub_hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.pubsub_hook._client, result) - @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubHook.client_info", - new_callable=mock.PropertyMock) + @mock.patch( + "airflow.providers.google.cloud.hooks.pubsub.PubSubHook.client_info", new_callable=mock.PropertyMock + ) @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubHook._get_credentials") @mock.patch("airflow.providers.google.cloud.hooks.pubsub.SubscriberClient") def test_subscriber_client_creation(self, mock_client, mock_get_creds, mock_client_info): self.assertIsNone(self.pubsub_hook._client) result = self.pubsub_hook.subscriber_client mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) self.assertEqual(mock_client.return_value, result) @@ -119,19 +113,14 @@ def test_create_nonexistent_topic(self, mock_service): kms_key_name=None, retry=None, timeout=None, - metadata=None + metadata=None, ) @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) def 
test_delete_topic(self, mock_service): delete_method = mock_service.return_value.delete_topic self.pubsub_hook.delete_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC) - delete_method.assert_called_once_with( - topic=EXPANDED_TOPIC, - retry=None, - timeout=None, - metadata=None - ) + delete_method.assert_called_once_with(topic=EXPANDED_TOPIC, retry=None, timeout=None, metadata=None) @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) def test_delete_nonexisting_topic_failifnotexists(self, mock_service): @@ -208,7 +197,7 @@ def test_create_subscription_different_project_topic(self, mock_service): project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, - subscription_project_id='a-different-project' + subscription_project_id='a-different-project', ) expected_subscription = 'projects/{}/subscriptions/{}'.format( 'a-different-project', TEST_SUBSCRIPTION @@ -238,10 +227,7 @@ def test_delete_subscription(self, mock_service): self.pubsub_hook.delete_subscription(project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION) delete_method = mock_service.delete_subscription delete_method.assert_called_once_with( - subscription=EXPANDED_SUBSCRIPTION, - retry=None, - timeout=None, - metadata=None + subscription=EXPANDED_SUBSCRIPTION, retry=None, timeout=None, metadata=None ) @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) @@ -267,8 +253,9 @@ def test_delete_subscription_api_call_error(self, mock_service): @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) @mock.patch(PUBSUB_STRING.format('uuid4'), new_callable=mock.Mock(return_value=lambda: TEST_UUID)) - def test_create_subscription_without_subscription_name(self, mock_uuid, - mock_service): # noqa # pylint: disable=unused-argument,line-too-long + def test_create_subscription_without_subscription_name( + self, mock_uuid, mock_service + ): # noqa # pylint: disable=unused-argument,line-too-long create_method = mock_service.create_subscription expected_name = EXPANDED_SUBSCRIPTION.replace(TEST_SUBSCRIPTION, 'sub-%s' % TEST_UUID) @@ -326,7 +313,7 @@ def test_create_subscription_with_filter(self, mock_service): project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, - filter_='attributes.domain="com"' + filter_='attributes.domain="com"', ) create_method.assert_called_once_with( name=EXPANDED_SUBSCRIPTION, @@ -438,12 +425,15 @@ def test_pull_no_messages(self, mock_service): ) self.assertListEqual([], response) - @parameterized.expand([ - (exception,) for exception in [ - HttpError(resp={'status': '404'}, content=EMPTY_CONTENT), - GoogleAPICallError("API Call Error") + @parameterized.expand( + [ + (exception,) + for exception in [ + HttpError(resp={'status': '404'}, content=EMPTY_CONTENT), + GoogleAPICallError("API Call Error"), + ] ] - ]) + ) @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_pull_fails_on_exception(self, exception, mock_service): pull_method = mock_service.pull @@ -465,16 +455,14 @@ def test_acknowledge_by_ack_ids(self, mock_service): ack_method = mock_service.acknowledge self.pubsub_hook.acknowledge( - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - ack_ids=['1', '2', '3'] + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ack_ids=['1', '2', '3'] ) ack_method.assert_called_once_with( subscription=EXPANDED_SUBSCRIPTION, ack_ids=['1', '2', '3'], retry=None, timeout=None, - metadata=None + metadata=None, ) @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) @@ -482,9 +470,7 @@ def 
test_acknowledge_by_message_objects(self, mock_service): ack_method = mock_service.acknowledge self.pubsub_hook.acknowledge( - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - messages=self._generate_messages(3), + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, messages=self._generate_messages(3), ) ack_method.assert_called_once_with( subscription=EXPANDED_SUBSCRIPTION, @@ -494,12 +480,15 @@ def test_acknowledge_by_message_objects(self, mock_service): metadata=None, ) - @parameterized.expand([ - (exception,) for exception in [ - HttpError(resp={'status': '404'}, content=EMPTY_CONTENT), - GoogleAPICallError("API Call Error") + @parameterized.expand( + [ + (exception,) + for exception in [ + HttpError(resp={'status': '404'}, content=EMPTY_CONTENT), + GoogleAPICallError("API Call Error"), + ] ] - ]) + ) @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_acknowledge_fails_on_exception(self, exception, mock_service): ack_method = mock_service.acknowledge @@ -507,44 +496,47 @@ def test_acknowledge_fails_on_exception(self, exception, mock_service): with self.assertRaises(PubSubException): self.pubsub_hook.acknowledge( - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - ack_ids=['1', '2', '3'] + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ack_ids=['1', '2', '3'] ) ack_method.assert_called_once_with( subscription=EXPANDED_SUBSCRIPTION, ack_ids=['1', '2', '3'], retry=None, timeout=None, - metadata=None + metadata=None, ) - @parameterized.expand([ - (messages,) for messages in [ - [{"data": b'test'}], - [{"data": b''}], - [{"data": b'test', "attributes": {"weight": "100kg"}}], - [{"data": b'', "attributes": {"weight": "100kg"}}], - [{"attributes": {"weight": "100kg"}}], + @parameterized.expand( + [ + (messages,) + for messages in [ + [{"data": b'test'}], + [{"data": b''}], + [{"data": b'test', "attributes": {"weight": "100kg"}}], + [{"data": b'', "attributes": {"weight": "100kg"}}], + [{"attributes": {"weight": "100kg"}}], + ] ] - ]) + ) def test_messages_validation_positive(self, messages): PubSubHook._validate_messages(messages) - @parameterized.expand([ - ([("wrong type",)], "Wrong message type. Must be a dictionary."), - ([{"wrong_key": b'test'}], "Wrong message. Dictionary must contain 'data' or 'attributes'."), - ([{"data": 'wrong string'}], "Wrong message. 'data' must be send as a bytestring"), - ([{"data": None}], "Wrong message. 'data' must be send as a bytestring"), - ( - [{"attributes": None}], - "Wrong message. If 'data' is not provided 'attributes' must be a non empty dictionary." - ), - ( - [{"attributes": "wrong string"}], - "Wrong message. If 'data' is not provided 'attributes' must be a non empty dictionary." - ) - ]) + @parameterized.expand( + [ + ([("wrong type",)], "Wrong message type. Must be a dictionary."), + ([{"wrong_key": b'test'}], "Wrong message. Dictionary must contain 'data' or 'attributes'."), + ([{"data": 'wrong string'}], "Wrong message. 'data' must be send as a bytestring"), + ([{"data": None}], "Wrong message. 'data' must be send as a bytestring"), + ( + [{"attributes": None}], + "Wrong message. If 'data' is not provided 'attributes' must be a non empty dictionary.", + ), + ( + [{"attributes": "wrong string"}], + "Wrong message. 
If 'data' is not provided 'attributes' must be a non empty dictionary.", + ), + ] + ) def test_messages_validation_negative(self, messages, error_message): with self.assertRaises(PubSubException) as e: PubSubHook._validate_messages(messages) diff --git a/tests/providers/google/cloud/hooks/test_secret_manager.py b/tests/providers/google/cloud/hooks/test_secret_manager.py index 0510028d8c7ef..5c0e9aa2922cd 100644 --- a/tests/providers/google/cloud/hooks/test_secret_manager.py +++ b/tests/providers/google/cloud/hooks/test_secret_manager.py @@ -24,7 +24,8 @@ from airflow.providers.google.cloud.hooks.secret_manager import SecretsManagerHook from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + mock_base_gcp_hook_default_project_id, ) BASE_PACKAGE = 'airflow.providers.google.common.hooks.base_google.' @@ -33,10 +34,11 @@ class TestSecretsManagerHook(unittest.TestCase): - @patch(INTERNAL_CLIENT_PACKAGE + "._SecretManagerClient.client", return_value=MagicMock()) - @patch(SECRETS_HOOK_PACKAGE + 'SecretsManagerHook._get_credentials_and_project_id', - return_value=(MagicMock(), GCP_PROJECT_ID_HOOK_UNIT_TEST)) + @patch( + SECRETS_HOOK_PACKAGE + 'SecretsManagerHook._get_credentials_and_project_id', + return_value=(MagicMock(), GCP_PROJECT_ID_HOOK_UNIT_TEST), + ) @patch(BASE_PACKAGE + 'GoogleBaseHook.__init__', new=mock_base_gcp_hook_default_project_id) def test_get_missing_key(self, mock_get_credentials, mock_client): mock_client.secret_version_path.return_value = "full-path" @@ -49,8 +51,10 @@ def test_get_missing_key(self, mock_get_credentials, mock_client): self.assertIsNone(secret) @patch(INTERNAL_CLIENT_PACKAGE + "._SecretManagerClient.client", return_value=MagicMock()) - @patch(SECRETS_HOOK_PACKAGE + 'SecretsManagerHook._get_credentials_and_project_id', - return_value=(MagicMock(), GCP_PROJECT_ID_HOOK_UNIT_TEST)) + @patch( + SECRETS_HOOK_PACKAGE + 'SecretsManagerHook._get_credentials_and_project_id', + return_value=(MagicMock(), GCP_PROJECT_ID_HOOK_UNIT_TEST), + ) @patch(BASE_PACKAGE + 'GoogleBaseHook.__init__', new=mock_base_gcp_hook_default_project_id) def test_get_existing_key(self, mock_get_credentials, mock_client): mock_client.secret_version_path.return_value = "full-path" diff --git a/tests/providers/google/cloud/hooks/test_spanner.py b/tests/providers/google/cloud/hooks/test_spanner.py index 71d8b9bce7d15..937104ed0c37b 100644 --- a/tests/providers/google/cloud/hooks/test_spanner.py +++ b/tests/providers/google/cloud/hooks/test_spanner.py @@ -23,7 +23,8 @@ from airflow.providers.google.cloud.hooks.spanner import SpannerHook from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, ) @@ -33,15 +34,15 @@ class TestGcpSpannerHookDefaultProjectId(unittest.TestCase): - def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_default_project_id): + with mock.patch( + 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_default_project_id, + ): self.spanner_hook_default_project_id = SpannerHook(gcp_conn_id='test') @mock.patch( - "airflow.providers.google.cloud.hooks.spanner.SpannerHook.client_info", - new_callable=mock.PropertyMock + 
"airflow.providers.google.cloud.hooks.spanner.SpannerHook.client_info", new_callable=mock.PropertyMock ) @mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_credentials") @mock.patch("airflow.providers.google.cloud.hooks.spanner.Client") @@ -50,7 +51,7 @@ def test_spanner_client_creation(self, mock_client, mock_get_creds, mock_client_ mock_client.assert_called_once_with( project=GCP_PROJECT_ID_HOOK_UNIT_TEST, credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + client_info=mock_client_info.return_value, ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.spanner_hook_default_project_id._client, result) @@ -60,8 +61,9 @@ def test_get_existing_instance(self, get_client): instance_method = get_client.return_value.instance instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = True - res = self.spanner_hook_default_project_id.get_instance(instance_id=SPANNER_INSTANCE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) + res = self.spanner_hook_default_project_id.get_instance( + instance_id=SPANNER_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with(instance_id='instance') self.assertIsNotNone(res) @@ -71,8 +73,9 @@ def test_get_existing_instance_overridden_project_id(self, get_client): instance_method = get_client.return_value.instance instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = True - res = self.spanner_hook_default_project_id.get_instance(instance_id=SPANNER_INSTANCE, - project_id='new-project') + res = self.spanner_hook_default_project_id.get_instance( + instance_id=SPANNER_INSTANCE, project_id='new-project' + ) get_client.assert_called_once_with(project_id='new-project') instance_method.assert_called_once_with(instance_id='instance') self.assertIsNotNone(res) @@ -80,7 +83,7 @@ def test_get_existing_instance_overridden_project_id(self, get_client): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client') def test_create_instance(self, get_client, mock_project_id): @@ -99,7 +102,8 @@ def test_create_instance(self, get_client, mock_project_id): instance_id='instance', configuration_name='configuration', display_name='database-name', - node_count=1) + node_count=1, + ) self.assertIsNone(res) @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client') @@ -112,19 +116,21 @@ def test_create_instance_overridden_project_id(self, get_client): instance_id=SPANNER_INSTANCE, configuration_name=SPANNER_CONFIGURATION, node_count=1, - display_name=SPANNER_DATABASE) + display_name=SPANNER_DATABASE, + ) get_client.assert_called_once_with(project_id='new-project') instance_method.assert_called_once_with( instance_id='instance', configuration_name='configuration', display_name='database-name', - node_count=1) + node_count=1, + ) self.assertIsNone(res) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client') def test_update_instance(self, 
get_client, mock_project_id): @@ -142,8 +148,11 @@ def test_update_instance(self, get_client, mock_project_id): ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with( - instance_id='instance', configuration_name='configuration', display_name='database-name', - node_count=2) + instance_id='instance', + configuration_name='configuration', + display_name='database-name', + node_count=2, + ) update_method.assert_called_once_with() self.assertIsNone(res) @@ -159,18 +168,22 @@ def test_update_instance_overridden_project_id(self, get_client): instance_id=SPANNER_INSTANCE, configuration_name=SPANNER_CONFIGURATION, node_count=2, - display_name=SPANNER_DATABASE) + display_name=SPANNER_DATABASE, + ) get_client.assert_called_once_with(project_id='new-project') instance_method.assert_called_once_with( - instance_id='instance', configuration_name='configuration', display_name='database-name', - node_count=2) + instance_id='instance', + configuration_name='configuration', + display_name='database-name', + node_count=2, + ) update_method.assert_called_once_with() self.assertIsNone(res) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client') def test_delete_instance(self, get_client, mock_project_id): @@ -180,12 +193,10 @@ def test_delete_instance(self, get_client, mock_project_id): delete_method = instance_method.return_value.delete delete_method.return_value = False res = self.spanner_hook_default_project_id.delete_instance( - instance_id=SPANNER_INSTANCE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + instance_id=SPANNER_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) get_client.assert_called_once_with(project_id='example-project') - instance_method.assert_called_once_with( - 'instance') + instance_method.assert_called_once_with('instance') delete_method.assert_called_once_with() self.assertIsNone(res) @@ -197,18 +208,17 @@ def test_delete_instance_overridden_project_id(self, get_client): delete_method = instance_method.return_value.delete delete_method.return_value = False res = self.spanner_hook_default_project_id.delete_instance( - project_id='new-project', - instance_id=SPANNER_INSTANCE) + project_id='new-project', instance_id=SPANNER_INSTANCE + ) get_client.assert_called_once_with(project_id='new-project') - instance_method.assert_called_once_with( - 'instance') + instance_method.assert_called_once_with('instance') delete_method.assert_called_once_with() self.assertIsNone(res) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client') def test_get_database(self, get_client, mock_project_id): @@ -238,9 +248,8 @@ def test_get_database_overridden_project_id(self, get_client): database_exists_method = instance_method.return_value.exists database_exists_method.return_value = True res = self.spanner_hook_default_project_id.get_database( - project_id='new-project', - instance_id=SPANNER_INSTANCE, - database_id=SPANNER_DATABASE) + project_id='new-project', instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE + ) get_client.assert_called_once_with(project_id='new-project') 
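As an aside, a minimal self-contained sketch of the PropertyMock pattern that recurs throughout these hook tests (patching a project_id property with new_callable=PropertyMock so no real credentials are read). The FakeHook class and the 'example-project' value below are hypothetical stand-ins for illustration only; they are not part of this patch:

import unittest
from unittest import mock
from unittest.mock import PropertyMock


class FakeHook:
    """Hypothetical stand-in for a Google-style hook exposing a project_id property."""

    @property
    def project_id(self):
        raise RuntimeError("would normally resolve real credentials")

    def resource_path(self, instance_id):
        return 'projects/{}/instances/{}'.format(self.project_id, instance_id)


class TestPropertyMockPattern(unittest.TestCase):
    @mock.patch.object(FakeHook, 'project_id', new_callable=PropertyMock, return_value='example-project')
    def test_resource_path_uses_patched_project_id(self, mock_project_id):
        hook = FakeHook()
        # The property is replaced on the class, so the RuntimeError above is never hit.
        self.assertEqual(hook.resource_path('instance'), 'projects/example-project/instances/instance')
        # The PropertyMock records each read of the property as a call with no arguments.
        mock_project_id.assert_called_once_with()


if __name__ == '__main__':
    unittest.main()

Patching the property at the class level is what lets the surrounding tests assert how the hook composes resource names without ever touching a real GCP project.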
instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name') @@ -250,7 +259,7 @@ def test_get_database_overridden_project_id(self, get_client): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client') def test_create_database(self, get_client, mock_project_id): @@ -282,7 +291,8 @@ def test_create_database_overridden_project_id(self, get_client): project_id='new-project', instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE, - ddl_statements=[]) + ddl_statements=[], + ) get_client.assert_called_once_with(project_id='new-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name', ddl_statements=[]) @@ -292,7 +302,7 @@ def test_create_database_overridden_project_id(self, get_client): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client') def test_update_database(self, get_client, mock_project_id): @@ -324,7 +334,8 @@ def test_update_database_overridden_project_id(self, get_client): project_id='new-project', instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE, - ddl_statements=[]) + ddl_statements=[], + ) get_client.assert_called_once_with(project_id='new-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name') @@ -334,7 +345,7 @@ def test_update_database_overridden_project_id(self, get_client): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client') def test_delete_database(self, get_client, mock_project_id): @@ -367,9 +378,8 @@ def test_delete_database_overridden_project_id(self, get_client): database_exists_method = database_method.return_value.exists database_exists_method.return_value = True res = self.spanner_hook_default_project_id.delete_database( - project_id='new-project', - instance_id=SPANNER_INSTANCE, - database_id=SPANNER_DATABASE) + project_id='new-project', instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE + ) get_client.assert_called_once_with(project_id='new-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name') @@ -380,7 +390,7 @@ def test_delete_database_overridden_project_id(self, get_client): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client') def test_execute_dml(self, get_client, mock_project_id): @@ -409,10 +419,8 @@ def test_execute_dml_overridden_project_id(self, get_client): database_method = instance_method.return_value.database run_in_transaction_method = 
database_method.return_value.run_in_transaction res = self.spanner_hook_default_project_id.execute_dml( - project_id='new-project', - instance_id=SPANNER_INSTANCE, - database_id=SPANNER_DATABASE, - queries='') + project_id='new-project', instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE, queries='' + ) get_client.assert_called_once_with(project_id='new-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name') @@ -421,19 +429,19 @@ def test_execute_dml_overridden_project_id(self, get_client): class TestGcpSpannerHookNoDefaultProjectID(unittest.TestCase): - def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_no_default_project_id): + with mock.patch( + 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_no_default_project_id, + ): self.spanner_hook_no_default_project_id = SpannerHook(gcp_conn_id='test') @mock.patch( - "airflow.providers.google.cloud.hooks.spanner.SpannerHook.client_info", - new_callable=mock.PropertyMock + "airflow.providers.google.cloud.hooks.spanner.SpannerHook.client_info", new_callable=mock.PropertyMock ) @mock.patch( "airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_credentials", - return_value="CREDENTIALS" + return_value="CREDENTIALS", ) @mock.patch("airflow.providers.google.cloud.hooks.spanner.Client") def test_spanner_client_creation(self, mock_client, mock_get_creds, mock_client_info): @@ -441,7 +449,7 @@ def test_spanner_client_creation(self, mock_client, mock_get_creds, mock_client_ mock_client.assert_called_once_with( project=GCP_PROJECT_ID_HOOK_UNIT_TEST, credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + client_info=mock_client_info.return_value, ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.spanner_hook_no_default_project_id._client, result) @@ -451,8 +459,9 @@ def test_get_existing_instance_overridden_project_id(self, get_client): instance_method = get_client.return_value.instance instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = True - res = self.spanner_hook_no_default_project_id.get_instance(instance_id=SPANNER_INSTANCE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) + res = self.spanner_hook_no_default_project_id.get_instance( + instance_id=SPANNER_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with(instance_id='instance') self.assertIsNotNone(res) @@ -462,8 +471,9 @@ def test_get_non_existing_instance(self, get_client): instance_method = get_client.return_value.instance instance_exists_method = instance_method.return_value.exists instance_exists_method.return_value = False - res = self.spanner_hook_no_default_project_id.get_instance(instance_id=SPANNER_INSTANCE, - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST) + res = self.spanner_hook_no_default_project_id.get_instance( + instance_id=SPANNER_INSTANCE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with(instance_id='instance') self.assertIsNone(res) @@ -478,13 +488,15 @@ def test_create_instance_overridden_project_id(self, get_client): instance_id=SPANNER_INSTANCE, configuration_name=SPANNER_CONFIGURATION, node_count=1, - display_name=SPANNER_DATABASE) + 
display_name=SPANNER_DATABASE, + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with( instance_id='instance', configuration_name='configuration', display_name='database-name', - node_count=1) + node_count=1, + ) self.assertIsNone(res) @mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client') @@ -499,11 +511,15 @@ def test_update_instance_overridden_project_id(self, get_client): instance_id=SPANNER_INSTANCE, configuration_name=SPANNER_CONFIGURATION, node_count=2, - display_name=SPANNER_DATABASE) + display_name=SPANNER_DATABASE, + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with( - instance_id='instance', configuration_name='configuration', display_name='database-name', - node_count=2) + instance_id='instance', + configuration_name='configuration', + display_name='database-name', + node_count=2, + ) update_method.assert_called_once_with() self.assertIsNone(res) @@ -515,11 +531,10 @@ def test_delete_instance_overridden_project_id(self, get_client): delete_method = instance_method.return_value.delete delete_method.return_value = False res = self.spanner_hook_no_default_project_id.delete_instance( - project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, - instance_id=SPANNER_INSTANCE) + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=SPANNER_INSTANCE + ) get_client.assert_called_once_with(project_id='example-project') - instance_method.assert_called_once_with( - 'instance') + instance_method.assert_called_once_with('instance') delete_method.assert_called_once_with() self.assertIsNone(res) @@ -534,7 +549,8 @@ def test_get_database_overridden_project_id(self, get_client): res = self.spanner_hook_no_default_project_id.get_database( project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=SPANNER_INSTANCE, - database_id=SPANNER_DATABASE) + database_id=SPANNER_DATABASE, + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name') @@ -552,7 +568,8 @@ def test_create_database_overridden_project_id(self, get_client): project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE, - ddl_statements=[]) + ddl_statements=[], + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name', ddl_statements=[]) @@ -570,7 +587,8 @@ def test_update_database_overridden_project_id(self, get_client): project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE, - ddl_statements=[]) + ddl_statements=[], + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name') @@ -589,7 +607,8 @@ def test_update_database_overridden_project_id_and_operation(self, get_client): instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE, operation_id="operation", - ddl_statements=[]) + ddl_statements=[], + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name') @@ -608,7 +627,8 @@ def test_delete_database_overridden_project_id(self, get_client): res = 
self.spanner_hook_no_default_project_id.delete_database( project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=SPANNER_INSTANCE, - database_id=SPANNER_DATABASE) + database_id=SPANNER_DATABASE, + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name') @@ -628,7 +648,8 @@ def test_delete_database_missing_database(self, get_client): self.spanner_hook_no_default_project_id.delete_database( project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=SPANNER_INSTANCE, - database_id=SPANNER_DATABASE) + database_id=SPANNER_DATABASE, + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name') @@ -646,7 +667,8 @@ def test_execute_dml_overridden_project_id(self, get_client): project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE, - queries='') + queries='', + ) get_client.assert_called_once_with(project_id='example-project') instance_method.assert_called_once_with(instance_id='instance') database_method.assert_called_once_with(database_id='database-name') diff --git a/tests/providers/google/cloud/hooks/test_speech_to_text.py b/tests/providers/google/cloud/hooks/test_speech_to_text.py index aa5737c99794f..e64860fc9153a 100644 --- a/tests/providers/google/cloud/hooks/test_speech_to_text.py +++ b/tests/providers/google/cloud/hooks/test_speech_to_text.py @@ -39,15 +39,14 @@ def setUp(self): @patch( "airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook.client_info", - new_callable=PropertyMock + new_callable=PropertyMock, ) @patch("airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook._get_credentials") @patch("airflow.providers.google.cloud.hooks.speech_to_text.SpeechClient") def test_speech_client_creation(self, mock_client, mock_get_creds, mock_client_info): result = self.gcp_speech_to_text_hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.gcp_speech_to_text_hook._client, result) diff --git a/tests/providers/google/cloud/hooks/test_stackdriver.py b/tests/providers/google/cloud/hooks/test_stackdriver.py index 7340b243bd25a..dd363e74c73c7 100644 --- a/tests/providers/google/cloud/hooks/test_stackdriver.py +++ b/tests/providers/google/cloud/hooks/test_stackdriver.py @@ -32,90 +32,67 @@ TEST_ALERT_POLICY_1 = { "combiner": "OR", "name": "projects/sd-project/alertPolicies/12345", - "creationRecord": { - "mutatedBy": "user123", - "mutateTime": "2020-01-01T00:00:00.000000Z" - }, + "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"}, "enabled": True, "displayName": "test display", "conditions": [ { "conditionThreshold": { "comparison": "COMPARISON_GT", - "aggregations": [ - { - "alignmentPeriod": "60s", - "perSeriesAligner": "ALIGN_RATE" - } - ] + "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}], }, "displayName": "Condition display", - "name": "projects/sd-project/alertPolicies/123/conditions/456" + "name": "projects/sd-project/alertPolicies/123/conditions/456", } - ] + ], } TEST_ALERT_POLICY_2 = { "combiner": "OR", "name": 
"projects/sd-project/alertPolicies/6789", - "creationRecord": { - "mutatedBy": "user123", - "mutateTime": "2020-01-01T00:00:00.000000Z" - }, + "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"}, "enabled": False, "displayName": "test display", "conditions": [ { "conditionThreshold": { "comparison": "COMPARISON_GT", - "aggregations": [ - { - "alignmentPeriod": "60s", - "perSeriesAligner": "ALIGN_RATE" - } - ] + "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}], }, "displayName": "Condition display", - "name": "projects/sd-project/alertPolicies/456/conditions/789" + "name": "projects/sd-project/alertPolicies/456/conditions/789", } - ] + ], } TEST_NOTIFICATION_CHANNEL_1 = { "displayName": "sd", "enabled": True, - "labels": { - "auth_token": "top-secret", - "channel_name": "#channel" - }, + "labels": {"auth_token": "top-secret", "channel_name": "#channel"}, "name": "projects/sd-project/notificationChannels/12345", - "type": "slack" + "type": "slack", } TEST_NOTIFICATION_CHANNEL_2 = { "displayName": "sd", "enabled": False, - "labels": { - "auth_token": "top-secret", - "channel_name": "#channel" - }, + "labels": {"auth_token": "top-secret", "channel_name": "#channel"}, "name": "projects/sd-project/notificationChannels/6789", - "type": "slack" + "type": "slack", } class TestStackdriverHookMethods(unittest.TestCase): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id', - return_value=(CREDENTIALS, PROJECT_ID) + return_value=(CREDENTIALS, PROJECT_ID), ) @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_policy_client') def test_stackdriver_list_alert_policies(self, mock_policy_client, mock_get_creds_and_project_id): method = mock_policy_client.return_value.list_alert_policies hook = stackdriver.StackdriverHook() hook.list_alert_policies( - filter_=TEST_FILTER, - project_id=PROJECT_ID, + filter_=TEST_FILTER, project_id=PROJECT_ID, ) method.assert_called_once_with( name='projects/{project}'.format(project=PROJECT_ID), @@ -124,12 +101,12 @@ def test_stackdriver_list_alert_policies(self, mock_policy_client, mock_get_cred timeout=DEFAULT, order_by=None, page_size=None, - metadata=None + metadata=None, ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id', - return_value=(CREDENTIALS, PROJECT_ID) + return_value=(CREDENTIALS, PROJECT_ID), ) @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_policy_client') def test_stackdriver_enable_alert_policy(self, mock_policy_client, mock_get_creds_and_project_id): @@ -142,8 +119,7 @@ def test_stackdriver_enable_alert_policy(self, mock_policy_client, mock_get_cred mock_policy_client.return_value.list_alert_policies.return_value = alert_policies hook.enable_alert_policies( - filter_=TEST_FILTER, - project_id=PROJECT_ID, + filter_=TEST_FILTER, project_id=PROJECT_ID, ) mock_policy_client.return_value.list_alert_policies.assert_called_once_with( name='projects/{project}'.format(project=PROJECT_ID), @@ -177,11 +153,10 @@ def test_stackdriver_disable_alert_policy(self, mock_policy_client, mock_get_cre mock_policy_client.return_value.list_alert_policies.return_value = [ alert_policy_enabled, - alert_policy_disabled + alert_policy_disabled, ] hook.disable_alert_policies( - filter_=TEST_FILTER, - project_id=PROJECT_ID, + filter_=TEST_FILTER, project_id=PROJECT_ID, ) 
mock_policy_client.return_value.list_alert_policies.assert_called_once_with( name='projects/{project}'.format(project=PROJECT_ID), @@ -209,8 +184,9 @@ def test_stackdriver_disable_alert_policy(self, mock_policy_client, mock_get_cre ) @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_policy_client') @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_channel_client') - def test_stackdriver_upsert_alert_policy(self, mock_channel_client, mock_policy_client, - mock_get_creds_and_project_id): + def test_stackdriver_upsert_alert_policy( + self, mock_channel_client, mock_policy_client, mock_get_creds_and_project_id + ): hook = stackdriver.StackdriverHook() existing_alert_policy = ParseDict(TEST_ALERT_POLICY_1, monitoring_v3.types.alert_pb2.AlertPolicy()) alert_policy_to_create = ParseDict(TEST_ALERT_POLICY_2, monitoring_v3.types.alert_pb2.AlertPolicy()) @@ -254,10 +230,7 @@ def test_stackdriver_upsert_alert_policy(self, mock_channel_client, mock_policy_ existing_alert_policy.ClearField('creation_record') existing_alert_policy.ClearField('mutation_record') mock_policy_client.return_value.update_alert_policy.assert_called_once_with( - alert_policy=existing_alert_policy, - retry=DEFAULT, - timeout=DEFAULT, - metadata=None + alert_policy=existing_alert_policy, retry=DEFAULT, timeout=DEFAULT, metadata=None ) @mock.patch( @@ -267,14 +240,9 @@ def test_stackdriver_upsert_alert_policy(self, mock_channel_client, mock_policy_ @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_policy_client') def test_stackdriver_delete_alert_policy(self, mock_policy_client, mock_get_creds_and_project_id): hook = stackdriver.StackdriverHook() - hook.delete_alert_policy( - name='test-alert', - ) + hook.delete_alert_policy(name='test-alert',) mock_policy_client.return_value.delete_alert_policy.assert_called_once_with( - name='test-alert', - retry=DEFAULT, - timeout=DEFAULT, - metadata=None, + name='test-alert', retry=DEFAULT, timeout=DEFAULT, metadata=None, ) @mock.patch( @@ -285,8 +253,7 @@ def test_stackdriver_delete_alert_policy(self, mock_policy_client, mock_get_cred def test_stackdriver_list_notification_channel(self, mock_channel_client, mock_get_creds_and_project_id): hook = stackdriver.StackdriverHook() hook.list_notification_channels( - filter_=TEST_FILTER, - project_id=PROJECT_ID, + filter_=TEST_FILTER, project_id=PROJECT_ID, ) mock_channel_client.return_value.list_notification_channels.assert_called_once_with( name='projects/{project}'.format(project=PROJECT_ID), @@ -295,29 +262,31 @@ def test_stackdriver_list_notification_channel(self, mock_channel_client, mock_g page_size=None, retry=DEFAULT, timeout=DEFAULT, - metadata=None + metadata=None, ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id', - return_value=(CREDENTIALS, PROJECT_ID) + return_value=(CREDENTIALS, PROJECT_ID), ) @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_channel_client') - def test_stackdriver_enable_notification_channel(self, mock_channel_client, - mock_get_creds_and_project_id): + def test_stackdriver_enable_notification_channel( + self, mock_channel_client, mock_get_creds_and_project_id + ): hook = stackdriver.StackdriverHook() - notification_channel_enabled = ParseDict(TEST_NOTIFICATION_CHANNEL_1, - monitoring_v3.types.notification_pb2.NotificationChannel()) - notification_channel_disabled = ParseDict(TEST_NOTIFICATION_CHANNEL_2, - 
monitoring_v3.types.notification_pb2.NotificationChannel()) + notification_channel_enabled = ParseDict( + TEST_NOTIFICATION_CHANNEL_1, monitoring_v3.types.notification_pb2.NotificationChannel() + ) + notification_channel_disabled = ParseDict( + TEST_NOTIFICATION_CHANNEL_2, monitoring_v3.types.notification_pb2.NotificationChannel() + ) mock_channel_client.return_value.list_notification_channels.return_value = [ notification_channel_enabled, - notification_channel_disabled + notification_channel_disabled, ] hook.enable_notification_channels( - filter_=TEST_FILTER, - project_id=PROJECT_ID, + filter_=TEST_FILTER, project_id=PROJECT_ID, ) notification_channel_disabled.enabled.value = True # pylint: disable=no-member @@ -333,24 +302,26 @@ def test_stackdriver_enable_notification_channel(self, mock_channel_client, @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id', - return_value=(CREDENTIALS, PROJECT_ID) + return_value=(CREDENTIALS, PROJECT_ID), ) @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_channel_client') - def test_stackdriver_disable_notification_channel(self, mock_channel_client, - mock_get_creds_and_project_id): + def test_stackdriver_disable_notification_channel( + self, mock_channel_client, mock_get_creds_and_project_id + ): hook = stackdriver.StackdriverHook() - notification_channel_enabled = ParseDict(TEST_NOTIFICATION_CHANNEL_1, - monitoring_v3.types.notification_pb2.NotificationChannel()) - notification_channel_disabled = ParseDict(TEST_NOTIFICATION_CHANNEL_2, - monitoring_v3.types.notification_pb2.NotificationChannel()) + notification_channel_enabled = ParseDict( + TEST_NOTIFICATION_CHANNEL_1, monitoring_v3.types.notification_pb2.NotificationChannel() + ) + notification_channel_disabled = ParseDict( + TEST_NOTIFICATION_CHANNEL_2, monitoring_v3.types.notification_pb2.NotificationChannel() + ) mock_channel_client.return_value.list_notification_channels.return_value = [ notification_channel_enabled, - notification_channel_disabled + notification_channel_disabled, ] hook.disable_notification_channels( - filter_=TEST_FILTER, - project_id=PROJECT_ID, + filter_=TEST_FILTER, project_id=PROJECT_ID, ) notification_channel_enabled.enabled.value = False # pylint: disable=no-member @@ -366,16 +337,16 @@ def test_stackdriver_disable_notification_channel(self, mock_channel_client, @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id', - return_value=(CREDENTIALS, PROJECT_ID) + return_value=(CREDENTIALS, PROJECT_ID), ) @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_channel_client') def test_stackdriver_upsert_channel(self, mock_channel_client, mock_get_creds_and_project_id): hook = stackdriver.StackdriverHook() - existing_notification_channel = ParseDict(TEST_NOTIFICATION_CHANNEL_1, - monitoring_v3.types.notification_pb2.NotificationChannel()) + existing_notification_channel = ParseDict( + TEST_NOTIFICATION_CHANNEL_1, monitoring_v3.types.notification_pb2.NotificationChannel() + ) notification_channel_to_be_created = ParseDict( - TEST_NOTIFICATION_CHANNEL_2, - monitoring_v3.types.notification_pb2.NotificationChannel() + TEST_NOTIFICATION_CHANNEL_2, monitoring_v3.types.notification_pb2.NotificationChannel() ) mock_channel_client.return_value.list_notification_channels.return_value = [ existing_notification_channel @@ -394,10 +365,7 @@ def test_stackdriver_upsert_channel(self, mock_channel_client, 
mock_get_creds_an metadata=None, ) mock_channel_client.return_value.update_notification_channel.assert_called_once_with( - notification_channel=existing_notification_channel, - retry=DEFAULT, - timeout=DEFAULT, - metadata=None + notification_channel=existing_notification_channel, retry=DEFAULT, timeout=DEFAULT, metadata=None ) notification_channel_to_be_created.ClearField('name') mock_channel_client.return_value.create_notification_channel.assert_called_once_with( @@ -405,23 +373,19 @@ def test_stackdriver_upsert_channel(self, mock_channel_client, mock_get_creds_an notification_channel=notification_channel_to_be_created, retry=DEFAULT, timeout=DEFAULT, - metadata=None + metadata=None, ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id', - return_value=(CREDENTIALS, PROJECT_ID) + return_value=(CREDENTIALS, PROJECT_ID), ) @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_channel_client') - def test_stackdriver_delete_notification_channel(self, mock_channel_client, - mock_get_creds_and_project_id): + def test_stackdriver_delete_notification_channel( + self, mock_channel_client, mock_get_creds_and_project_id + ): hook = stackdriver.StackdriverHook() - hook.delete_notification_channel( - name='test-channel', - ) + hook.delete_notification_channel(name='test-channel',) mock_channel_client.return_value.delete_notification_channel.assert_called_once_with( - name='test-channel', - retry=DEFAULT, - timeout=DEFAULT, - metadata=None + name='test-channel', retry=DEFAULT, timeout=DEFAULT, metadata=None ) diff --git a/tests/providers/google/cloud/hooks/test_tasks.py b/tests/providers/google/cloud/hooks/test_tasks.py index 77eef26f3e3f0..1eb532e2a121a 100644 --- a/tests/providers/google/cloud/hooks/test_tasks.py +++ b/tests/providers/google/cloud/hooks/test_tasks.py @@ -32,9 +32,7 @@ QUEUE_ID = "test-queue" FULL_QUEUE_PATH = "projects/test-project/locations/asia-east2/queues/test-queue" TASK_NAME = "test-task" -FULL_TASK_PATH = ( - "projects/test-project/locations/asia-east2/queues/test-queue/tasks/test-task" -) +FULL_TASK_PATH = "projects/test-project/locations/asia-east2/queues/test-queue/tasks/test-task" class TestCloudTasksHook(unittest.TestCase): @@ -47,15 +45,14 @@ def setUp(self): @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.client_info", - new_callable=mock.PropertyMock + new_callable=mock.PropertyMock, ) @mock.patch("airflow.providers.google.cloud.hooks.tasks.CloudTasksHook._get_credentials") @mock.patch("airflow.providers.google.cloud.hooks.tasks.CloudTasksClient") def test_cloud_tasks_client_creation(self, mock_client, mock_get_creds, mock_client_info): result = self.hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.hook._client, result) @@ -66,10 +63,7 @@ def test_cloud_tasks_client_creation(self, mock_client, mock_get_creds, mock_cli ) def test_create_queue(self, get_conn): result = self.hook.create_queue( - location=LOCATION, - task_queue=Queue(), - queue_name=QUEUE_ID, - project_id=PROJECT_ID, + location=LOCATION, task_queue=Queue(), queue_name=QUEUE_ID, project_id=PROJECT_ID, ) self.assertIs(result, API_RESPONSE) @@ -88,10 +82,7 @@ def test_create_queue(self, get_conn): ) def test_update_queue(self, get_conn): result = 
self.hook.update_queue( - task_queue=Queue(state=3), - location=LOCATION, - queue_name=QUEUE_ID, - project_id=PROJECT_ID, + task_queue=Queue(state=3), location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID, ) self.assertIs(result, API_RESPONSE) @@ -109,9 +100,7 @@ def test_update_queue(self, get_conn): **{"return_value.get_queue.return_value": API_RESPONSE}, # type: ignore ) def test_get_queue(self, get_conn): - result = self.hook.get_queue( - location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID - ) + result = self.hook.get_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) self.assertIs(result, API_RESPONSE) @@ -129,12 +118,7 @@ def test_list_queues(self, get_conn): self.assertEqual(result, list(API_RESPONSE)) get_conn.return_value.list_queues.assert_called_once_with( - parent=FULL_LOCATION_PATH, - filter_=None, - page_size=None, - retry=None, - timeout=None, - metadata=None, + parent=FULL_LOCATION_PATH, filter_=None, page_size=None, retry=None, timeout=None, metadata=None, ) @mock.patch( @@ -142,9 +126,7 @@ def test_list_queues(self, get_conn): **{"return_value.delete_queue.return_value": API_RESPONSE}, # type: ignore ) def test_delete_queue(self, get_conn): - result = self.hook.delete_queue( - location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID - ) + result = self.hook.delete_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) self.assertEqual(result, None) @@ -157,9 +139,7 @@ def test_delete_queue(self, get_conn): **{"return_value.purge_queue.return_value": API_RESPONSE}, # type: ignore ) def test_purge_queue(self, get_conn): - result = self.hook.purge_queue( - location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID - ) + result = self.hook.purge_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) self.assertEqual(result, API_RESPONSE) @@ -172,9 +152,7 @@ def test_purge_queue(self, get_conn): **{"return_value.pause_queue.return_value": API_RESPONSE}, # type: ignore ) def test_pause_queue(self, get_conn): - result = self.hook.pause_queue( - location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID - ) + result = self.hook.pause_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) self.assertEqual(result, API_RESPONSE) @@ -187,9 +165,7 @@ def test_pause_queue(self, get_conn): **{"return_value.resume_queue.return_value": API_RESPONSE}, # type: ignore ) def test_resume_queue(self, get_conn): - result = self.hook.resume_queue( - location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID - ) + result = self.hook.resume_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) self.assertEqual(result, API_RESPONSE) @@ -203,11 +179,7 @@ def test_resume_queue(self, get_conn): ) def test_create_task(self, get_conn): result = self.hook.create_task( - location=LOCATION, - queue_name=QUEUE_ID, - task=Task(), - project_id=PROJECT_ID, - task_name=TASK_NAME, + location=LOCATION, queue_name=QUEUE_ID, task=Task(), project_id=PROJECT_ID, task_name=TASK_NAME, ) self.assertEqual(result, API_RESPONSE) @@ -227,20 +199,13 @@ def test_create_task(self, get_conn): ) def test_get_task(self, get_conn): result = self.hook.get_task( - location=LOCATION, - queue_name=QUEUE_ID, - task_name=TASK_NAME, - project_id=PROJECT_ID, + location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, project_id=PROJECT_ID, ) self.assertEqual(result, API_RESPONSE) get_conn.return_value.get_task.assert_called_once_with( - name=FULL_TASK_PATH, - response_view=None, - retry=None, - timeout=None, - metadata=None, + 
name=FULL_TASK_PATH, response_view=None, retry=None, timeout=None, metadata=None, ) @mock.patch( @@ -248,9 +213,7 @@ def test_get_task(self, get_conn): **{"return_value.list_tasks.return_value": API_RESPONSE}, # type: ignore ) def test_list_tasks(self, get_conn): - result = self.hook.list_tasks( - location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID - ) + result = self.hook.list_tasks(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) self.assertEqual(result, list(API_RESPONSE)) @@ -269,10 +232,7 @@ def test_list_tasks(self, get_conn): ) def test_delete_task(self, get_conn): result = self.hook.delete_task( - location=LOCATION, - queue_name=QUEUE_ID, - task_name=TASK_NAME, - project_id=PROJECT_ID, + location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, project_id=PROJECT_ID, ) self.assertEqual(result, None) @@ -287,18 +247,11 @@ def test_delete_task(self, get_conn): ) def test_run_task(self, get_conn): result = self.hook.run_task( - location=LOCATION, - queue_name=QUEUE_ID, - task_name=TASK_NAME, - project_id=PROJECT_ID, + location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, project_id=PROJECT_ID, ) self.assertEqual(result, API_RESPONSE) get_conn.return_value.run_task.assert_called_once_with( - name=FULL_TASK_PATH, - response_view=None, - retry=None, - timeout=None, - metadata=None, + name=FULL_TASK_PATH, response_view=None, retry=None, timeout=None, metadata=None, ) diff --git a/tests/providers/google/cloud/hooks/test_text_to_speech.py b/tests/providers/google/cloud/hooks/test_text_to_speech.py index b741e44915769..8ce7b6102f18e 100644 --- a/tests/providers/google/cloud/hooks/test_text_to_speech.py +++ b/tests/providers/google/cloud/hooks/test_text_to_speech.py @@ -39,15 +39,14 @@ def setUp(self): @patch( "airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook.client_info", - new_callable=PropertyMock + new_callable=PropertyMock, ) @patch("airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook._get_credentials") @patch("airflow.providers.google.cloud.hooks.text_to_speech.TextToSpeechClient") def test_text_to_speech_client_creation(self, mock_client, mock_get_creds, mock_client_info): result = self.gcp_text_to_speech_hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.gcp_text_to_speech_hook._client, result) diff --git a/tests/providers/google/cloud/hooks/test_translate.py b/tests/providers/google/cloud/hooks/test_translate.py index cffb99361c009..8cd7fca06b36a 100644 --- a/tests/providers/google/cloud/hooks/test_translate.py +++ b/tests/providers/google/cloud/hooks/test_translate.py @@ -36,15 +36,14 @@ def setUp(self): @mock.patch( "airflow.providers.google.cloud.hooks.translate.CloudTranslateHook.client_info", - new_callable=mock.PropertyMock + new_callable=mock.PropertyMock, ) @mock.patch("airflow.providers.google.cloud.hooks.translate.CloudTranslateHook._get_credentials") @mock.patch("airflow.providers.google.cloud.hooks.translate.Client") def test_translate_client_creation(self, mock_client, mock_get_creds, mock_client_info): result = self.hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) 
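Finally, a small illustration of the parameterized.expand decorator these tests rely on for table-driven cases (as in the Pub/Sub message-validation and Vision tests above). It uses the same third-party parameterized library those tests already import; validate_message is a hypothetical miniature, not the actual PubSubHook._validate_messages implementation:

import unittest

from parameterized import parameterized


def validate_message(message):
    """Hypothetical miniature of the validation the Pub/Sub tests above exercise."""
    if not isinstance(message, dict):
        raise ValueError("Wrong message type. Must be a dictionary.")
    if 'data' not in message and 'attributes' not in message:
        raise ValueError("Wrong message. Dictionary must contain 'data' or 'attributes'.")


class TestValidateMessage(unittest.TestCase):
    @parameterized.expand(
        [
            ({'data': b'test'},),
            ({'attributes': {'weight': '100kg'}},),
        ]
    )
    def test_valid_messages(self, message):
        validate_message(message)

    @parameterized.expand(
        [
            (('wrong type',), "Wrong message type. Must be a dictionary."),
            ({'wrong_key': b'test'}, "Wrong message. Dictionary must contain 'data' or 'attributes'."),
        ]
    )
    def test_invalid_messages(self, message, error_message):
        with self.assertRaises(ValueError) as ctx:
            validate_message(message)
        self.assertEqual(str(ctx.exception), error_message)


if __name__ == '__main__':
    unittest.main()

Each tuple in the expanded list is unpacked into the test method's positional arguments, so one test body covers every row of the table.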
self.assertEqual(mock_client.return_value, result) self.assertEqual(self.hook._client, result) diff --git a/tests/providers/google/cloud/hooks/test_video_intelligence.py b/tests/providers/google/cloud/hooks/test_video_intelligence.py index 26cf267167d72..82d30e3a8f86f 100644 --- a/tests/providers/google/cloud/hooks/test_video_intelligence.py +++ b/tests/providers/google/cloud/hooks/test_video_intelligence.py @@ -42,7 +42,7 @@ def setUp(self): @mock.patch( "airflow.providers.google.cloud.hooks.video_intelligence.CloudVideoIntelligenceHook.client_info", - new_callable=mock.PropertyMock + new_callable=mock.PropertyMock, ) @mock.patch( "airflow.providers.google.cloud.hooks.video_intelligence.CloudVideoIntelligenceHook._get_credentials" @@ -51,8 +51,7 @@ def setUp(self): def test_video_intelligence_service_client_creation(self, mock_client, mock_get_creds, mock_client_info): result = self.hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.hook._conn, result) diff --git a/tests/providers/google/cloud/hooks/test_vision.py b/tests/providers/google/cloud/hooks/test_vision.py index 7c3b95d74e195..bcbcebaac98c2 100644 --- a/tests/providers/google/cloud/hooks/test_vision.py +++ b/tests/providers/google/cloud/hooks/test_vision.py @@ -21,7 +21,9 @@ from google.cloud.vision import enums from google.cloud.vision_v1 import ProductSearchClient from google.cloud.vision_v1.proto.image_annotator_pb2 import ( - AnnotateImageResponse, EntityAnnotation, SafeSearchAnnotation, + AnnotateImageResponse, + EntityAnnotation, + SafeSearchAnnotation, ) from google.cloud.vision_v1.proto.product_search_service_pb2 import Product, ProductSet, ReferenceImage from google.protobuf.json_format import MessageToDict @@ -60,7 +62,7 @@ { 'image': {'source': {'image_uri': "gs://bucket-name/object-name"}}, 'features': [{'type': enums.Feature.Type.LOGO_DETECTION}], - } + }, ] REFERENCE_IMAGE_NAME_TEST = "projects/{}/locations/{}/products/{}/referenceImages/{}".format( PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST, REFERENCE_IMAGE_ID_TEST @@ -81,15 +83,14 @@ def setUp(self): @mock.patch( "airflow.providers.google.cloud.hooks.vision.CloudVisionHook.client_info", - new_callable=mock.PropertyMock + new_callable=mock.PropertyMock, ) @mock.patch("airflow.providers.google.cloud.hooks.vision.CloudVisionHook._get_credentials") @mock.patch("airflow.providers.google.cloud.hooks.vision.ProductSearchClient") def test_product_search_client_creation(self, mock_client, mock_get_creds, mock_client_info): result = self.hook.get_conn() mock_client.assert_called_once_with( - credentials=mock_get_creds.return_value, - client_info=mock_client_info.return_value + credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value ) self.assertEqual(mock_client.return_value, result) self.assertEqual(self.hook._client, result) @@ -253,10 +254,7 @@ def test_update_productset_no_explicit_name_and_missing_params_for_constructed_n ) err = cm.exception self.assertTrue(err) - self.assertIn( - ERR_UNABLE_TO_CREATE.format(label='ProductSet', id_label='productset_id'), - str(err) - ) + self.assertIn(ERR_UNABLE_TO_CREATE.format(label='ProductSet', id_label='productset_id'), str(err)) update_product_set_method.assert_not_called() @parameterized.expand([(None, None), (None, PRODUCTSET_ID_TEST), (LOC_ID_TEST, 
None)]) @@ -323,11 +321,13 @@ def test_update_productset_explicit_name_different_from_constructed(self, get_co # self.assertIn("The required parameter 'project_id' is missing", str(err)) self.assertTrue(err) self.assertIn( - ERR_DIFF_NAMES.format(explicit_name=explicit_ps_name, - constructed_name=template_ps_name, - label="ProductSet", id_label="productset_id" - ), - str(err) + ERR_DIFF_NAMES.format( + explicit_name=explicit_ps_name, + constructed_name=template_ps_name, + label="ProductSet", + id_label="productset_id", + ), + str(err), ) update_product_set_method.assert_not_called() @@ -347,7 +347,7 @@ def test_delete_productset(self, get_conn): @mock.patch( 'airflow.providers.google.cloud.hooks.vision.CloudVisionHook.get_conn', - **{'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST} + **{'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST}, ) def test_create_reference_image_explicit_id(self, get_conn): # Given @@ -375,7 +375,7 @@ def test_create_reference_image_explicit_id(self, get_conn): @mock.patch( 'airflow.providers.google.cloud.hooks.vision.CloudVisionHook.get_conn', - **{'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST} + **{'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST}, ) def test_create_reference_image_autogenerated_id(self, get_conn): # Given @@ -601,8 +601,7 @@ def test_update_product_no_explicit_name_and_missing_params_for_constructed_name err = cm.exception self.assertTrue(err) self.assertIn( - ERR_UNABLE_TO_CREATE.format(label='Product', id_label='product_id'), - str(err), + ERR_UNABLE_TO_CREATE.format(label='Product', id_label='product_id'), str(err), ) update_product_method.assert_not_called() @@ -663,10 +662,12 @@ def test_update_product_explicit_name_different_from_constructed(self, get_conn) err = cm.exception self.assertTrue(err) self.assertIn( - ERR_DIFF_NAMES.format(explicit_name=explicit_p_name, - constructed_name=template_p_name, - label="Product", id_label="product_id" - ), + ERR_DIFF_NAMES.format( + explicit_name=explicit_p_name, + constructed_name=template_p_name, + label="Product", + id_label="product_id", + ), str(err), ) update_product_method.assert_not_called() diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler.py b/tests/providers/google/cloud/log/test_gcs_task_handler.py index 0ac5a542f0388..d17cb500d0996 100644 --- a/tests/providers/google/cloud/log/test_gcs_task_handler.py +++ b/tests/providers/google/cloud/log/test_gcs_task_handler.py @@ -164,7 +164,7 @@ def test_failed_write_to_remote_on_close(self, mock_blob, mock_client, mock_cred 'INFO:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Previous ' 'log discarded: sequence item 0: expected str instance, bytes found', 'ERROR:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Could ' - 'not write logs to gs://bucket/remote/log/location/1.log: Failed to connect' + 'not write logs to gs://bucket/remote/log/location/1.log: Failed to connect', ], ) mock_blob.assert_has_calls( diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler_system.py b/tests/providers/google/cloud/log/test_gcs_task_handler_system.py index 7bedca6dc1a21..aecced21d9031 100644 --- a/tests/providers/google/cloud/log/test_gcs_task_handler_system.py +++ b/tests/providers/google/cloud/log/test_gcs_task_handler_system.py @@ -31,14 +31,15 @@ from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_connections, clear_db_runs from 
tests.test_utils.gcp_system_helpers import ( - GoogleSystemTest, provide_gcp_context, resolve_full_gcp_key_path, + GoogleSystemTest, + provide_gcp_context, + resolve_full_gcp_key_path, ) @pytest.mark.system("google") @pytest.mark.credential_file(GCP_GCS_KEY) class TestGCSTaskHandlerSystemTest(GoogleSystemTest): - @classmethod def setUpClass(cls) -> None: unique_suffix = ''.join(random.sample(string.ascii_lowercase, 16)) @@ -55,6 +56,7 @@ def setUp(self) -> None: def tearDown(self) -> None: from airflow.config_templates import airflow_local_settings + importlib.reload(airflow_local_settings) settings.configure_logging() clear_db_runs() @@ -68,14 +70,10 @@ def test_should_read_logs(self, session): AIRFLOW__LOGGING__REMOTE_LOG_CONN_ID="google_cloud_default", AIRFLOW__CORE__LOAD_EXAMPLES="false", AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__, - GOOGLE_APPLICATION_CREDENTIALS=resolve_full_gcp_key_path(GCP_GCS_KEY) + GOOGLE_APPLICATION_CREDENTIALS=resolve_full_gcp_key_path(GCP_GCS_KEY), ): - self.assertEqual(0, subprocess.Popen( - ["airflow", "dags", "trigger", "example_complex"] - ).wait()) - self.assertEqual(0, subprocess.Popen( - ["airflow", "scheduler", "--num-runs", "1"] - ).wait()) + self.assertEqual(0, subprocess.Popen(["airflow", "dags", "trigger", "example_complex"]).wait()) + self.assertEqual(0, subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait()) ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first() dag = DagBag(dag_folder=example_complex.__file__).dags['example_complex'] @@ -84,12 +82,15 @@ def test_should_read_logs(self, session): self.assert_remote_logs("INFO - Task exited with return code 0", ti) def assert_remote_logs(self, expected_message, ti): - with provide_gcp_context(GCP_GCS_KEY), conf_vars({ - ('logging', 'remote_logging'): 'True', - ('logging', 'remote_base_log_folder'): f"gs://{self.bucket_name}/path/to/logs", - ('logging', 'remote_log_conn_id'): "google_cloud_default", - }): + with provide_gcp_context(GCP_GCS_KEY), conf_vars( + { + ('logging', 'remote_logging'): 'True', + ('logging', 'remote_base_log_folder'): f"gs://{self.bucket_name}/path/to/logs", + ('logging', 'remote_log_conn_id'): "google_cloud_default", + } + ): from airflow.config_templates import airflow_local_settings + importlib.reload(airflow_local_settings) settings.configure_logging() diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py index 565f2f50e2904..0ed1c252e2d3b 100644 --- a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py +++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py @@ -36,17 +36,13 @@ def _create_list_response(messages, token): class TestStackdriverLoggingHandlerStandalone(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id') @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client') def test_should_pass_message_to_client(self, mock_client, mock_get_creds_and_project_id): mock_get_creds_and_project_id.return_value = ('creds', 'project_id') transport_type = mock.MagicMock() - stackdriver_task_handler = StackdriverTaskHandler( - transport=transport_type, - labels={"key": 'value'} - ) + stackdriver_task_handler = StackdriverTaskHandler(transport=transport_type, labels={"key": 'value'}) logger = logging.getLogger("logger") logger.addHandler(stackdriver_task_handler) @@ -57,19 +53,13 @@ def 
test_should_pass_message_to_client(self, mock_client, mock_get_creds_and_pro transport_type.return_value.send.assert_called_once_with( mock.ANY, 'test-message', labels={"key": 'value'}, resource=Resource(type='global', labels={}) ) - mock_client.assert_called_once_with( - credentials='creds', - client_info=mock.ANY, - project="project_id" - ) + mock_client.assert_called_once_with(credentials='creds', client_info=mock.ANY, project="project_id") class TestStackdriverLoggingHandlerTask(unittest.TestCase): def setUp(self) -> None: self.transport_mock = mock.MagicMock() - self.stackdriver_task_handler = StackdriverTaskHandler( - transport=self.transport_mock - ) + self.stackdriver_task_handler = StackdriverTaskHandler(transport=self.transport_mock) self.logger = logging.getLogger("logger") date = datetime(2016, 1, 1) @@ -95,7 +85,7 @@ def test_should_set_labels(self, mock_client, mock_get_creds_and_project_id): 'task_id': 'task_for_testing_file_log_handler', 'dag_id': 'dag_for_testing_file_task_handler', 'execution_date': '2016-01-01T00:00:00+00:00', - 'try_number': '1' + 'try_number': '1', } resource = Resource(type='global', labels={}) self.transport_mock.return_value.send.assert_called_once_with( @@ -107,8 +97,7 @@ def test_should_set_labels(self, mock_client, mock_get_creds_and_project_id): def test_should_append_labels(self, mock_client, mock_get_creds_and_project_id): mock_get_creds_and_project_id.return_value = ('creds', 'project_id') self.stackdriver_task_handler = StackdriverTaskHandler( - transport=self.transport_mock, - labels={"product.googleapis.com/task_id": "test-value"} + transport=self.transport_mock, labels={"product.googleapis.com/task_id": "test-value"} ) self.stackdriver_task_handler.set_context(self.ti) self.logger.addHandler(self.stackdriver_task_handler) @@ -121,7 +110,7 @@ def test_should_append_labels(self, mock_client, mock_get_creds_and_project_id): 'dag_id': 'dag_for_testing_file_task_handler', 'execution_date': '2016-01-01T00:00:00+00:00', 'try_number': '1', - 'product.googleapis.com/task_id': 'test-value' + 'product.googleapis.com/task_id': 'test-value', } resource = Resource(type='global', labels={}) self.transport_mock.return_value.send.assert_called_once_with( @@ -131,7 +120,7 @@ def test_should_append_labels(self, mock_client, mock_get_creds_and_project_id): @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id') @mock.patch( 'airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client', - **{'return_value.project': 'asf-project'} # type: ignore + **{'return_value.project': 'asf-project'}, # type: ignore ) def test_should_read_logs_for_all_try(self, mock_client, mock_get_creds_and_project_id): mock_client.return_value.list_entries.return_value = _create_list_response(["MSG1", "MSG2"], None) @@ -140,11 +129,11 @@ def test_should_read_logs_for_all_try(self, mock_client, mock_get_creds_and_proj logs, metadata = self.stackdriver_task_handler.read(self.ti) mock_client.return_value.list_entries.assert_called_once_with( filter_='resource.type="global"\n' - 'logName="projects/asf-project/logs/airflow"\n' - 'labels.task_id="task_for_testing_file_log_handler"\n' - 'labels.dag_id="dag_for_testing_file_task_handler"\n' - 'labels.execution_date="2016-01-01T00:00:00+00:00"', - page_token=None + 'logName="projects/asf-project/logs/airflow"\n' + 'labels.task_id="task_for_testing_file_log_handler"\n' + 'labels.dag_id="dag_for_testing_file_task_handler"\n' + 'labels.execution_date="2016-01-01T00:00:00+00:00"', 
+ page_token=None, ) self.assertEqual(['MSG1\nMSG2'], logs) self.assertEqual([{'end_of_log': True}], metadata) @@ -152,7 +141,7 @@ def test_should_read_logs_for_all_try(self, mock_client, mock_get_creds_and_proj @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id') @mock.patch( 'airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client', - **{'return_value.project': 'asf-project'} # type: ignore + **{'return_value.project': 'asf-project'}, # type: ignore ) def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_and_project_id): mock_client.return_value.list_entries.return_value = _create_list_response(["MSG1", "MSG2"], None) @@ -161,11 +150,11 @@ def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_ logs, metadata = self.stackdriver_task_handler.read(self.ti) mock_client.return_value.list_entries.assert_called_once_with( filter_='resource.type="global"\n' - 'logName="projects/asf-project/logs/airflow"\n' - 'labels.task_id="K\\"OT"\n' - 'labels.dag_id="dag_for_testing_file_task_handler"\n' - 'labels.execution_date="2016-01-01T00:00:00+00:00"', - page_token=None + 'logName="projects/asf-project/logs/airflow"\n' + 'labels.task_id="K\\"OT"\n' + 'labels.dag_id="dag_for_testing_file_task_handler"\n' + 'labels.execution_date="2016-01-01T00:00:00+00:00"', + page_token=None, ) self.assertEqual(['MSG1\nMSG2'], logs) self.assertEqual([{'end_of_log': True}], metadata) @@ -173,7 +162,7 @@ def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_ @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id') @mock.patch( 'airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client', - **{'return_value.project': 'asf-project'} # type: ignore + **{'return_value.project': 'asf-project'}, # type: ignore ) def test_should_read_logs_for_single_try(self, mock_client, mock_get_creds_and_project_id): mock_client.return_value.list_entries.return_value = _create_list_response(["MSG1", "MSG2"], None) @@ -182,12 +171,12 @@ def test_should_read_logs_for_single_try(self, mock_client, mock_get_creds_and_p logs, metadata = self.stackdriver_task_handler.read(self.ti, 3) mock_client.return_value.list_entries.assert_called_once_with( filter_='resource.type="global"\n' - 'logName="projects/asf-project/logs/airflow"\n' - 'labels.task_id="task_for_testing_file_log_handler"\n' - 'labels.dag_id="dag_for_testing_file_task_handler"\n' - 'labels.execution_date="2016-01-01T00:00:00+00:00"\n' - 'labels.try_number="3"', - page_token=None + 'logName="projects/asf-project/logs/airflow"\n' + 'labels.task_id="task_for_testing_file_log_handler"\n' + 'labels.dag_id="dag_for_testing_file_task_handler"\n' + 'labels.execution_date="2016-01-01T00:00:00+00:00"\n' + 'labels.try_number="3"', + page_token=None, ) self.assertEqual(['MSG1\nMSG2'], logs) self.assertEqual([{'end_of_log': True}], metadata) @@ -201,17 +190,13 @@ def test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_ ] mock_get_creds_and_project_id.return_value = ('creds', 'project_id') logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3) - mock_client.return_value.list_entries.assert_called_once_with( - filter_=mock.ANY, page_token=None - ) + mock_client.return_value.list_entries.assert_called_once_with(filter_=mock.ANY, page_token=None) self.assertEqual(['MSG1\nMSG2'], logs) self.assertEqual([{'end_of_log': False, 'next_page_token': 
'TOKEN1'}], metadata1) mock_client.return_value.list_entries.return_value.next_page_token = None logs, metadata2 = self.stackdriver_task_handler.read(self.ti, 3, metadata1[0]) - mock_client.return_value.list_entries.assert_called_with( - filter_=mock.ANY, page_token="TOKEN1" - ) + mock_client.return_value.list_entries.assert_called_with(filter_=mock.ANY, page_token="TOKEN1") self.assertEqual(['MSG3\nMSG4'], logs) self.assertEqual([{'end_of_log': True}], metadata2) @@ -232,7 +217,7 @@ def test_should_read_logs_with_download(self, mock_client, mock_get_creds_and_pr @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id') @mock.patch( 'airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client', - **{'return_value.project': 'asf-project'} # type: ignore + **{'return_value.project': 'asf-project'}, # type: ignore ) def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_creds_and_project_id): mock_get_creds_and_project_id.return_value = ('creds', 'project_id') @@ -245,8 +230,7 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred }, ) self.stackdriver_task_handler = StackdriverTaskHandler( - transport=self.transport_mock, - resource=resource + transport=self.transport_mock, resource=resource ) entry = mock.MagicMock(payload={"message": "TEXT"}) @@ -257,14 +241,14 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred logs, metadata = self.stackdriver_task_handler.read(self.ti) mock_client.return_value.list_entries.assert_called_once_with( filter_='resource.type="cloud_composer_environment"\n' - 'logName="projects/asf-project/logs/airflow"\n' - 'resource.labels."environment.name"="test-instancce"\n' - 'resource.labels.location="europpe-west-3"\n' - 'resource.labels.project_id="asf-project"\n' - 'labels.task_id="task_for_testing_file_log_handler"\n' - 'labels.dag_id="dag_for_testing_file_task_handler"\n' - 'labels.execution_date="2016-01-01T00:00:00+00:00"', - page_token=None + 'logName="projects/asf-project/logs/airflow"\n' + 'resource.labels."environment.name"="test-instancce"\n' + 'resource.labels.location="europpe-west-3"\n' + 'resource.labels.project_id="asf-project"\n' + 'labels.task_id="task_for_testing_file_log_handler"\n' + 'labels.dag_id="dag_for_testing_file_task_handler"\n' + 'labels.execution_date="2016-01-01T00:00:00+00:00"', + page_token=None, ) self.assertEqual(['TEXT\nTEXT'], logs) self.assertEqual([{'end_of_log': True}], metadata) @@ -274,9 +258,7 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred def test_should_use_credentials(self, mock_client, mock_get_creds_and_project_id): mock_get_creds_and_project_id.return_value = ('creds', 'project_id') - stackdriver_task_handler = StackdriverTaskHandler( - gcp_key_path="KEY_PATH", - ) + stackdriver_task_handler = StackdriverTaskHandler(gcp_key_path="KEY_PATH",) client = stackdriver_task_handler._client @@ -286,15 +268,11 @@ def test_should_use_credentials(self, mock_client, mock_get_creds_and_project_id scopes=frozenset( { 'https://www.googleapis.com/auth/logging.write', - 'https://www.googleapis.com/auth/logging.read' + 'https://www.googleapis.com/auth/logging.read', } - ) - ) - mock_client.assert_called_once_with( - credentials='creds', - client_info=mock.ANY, - project="project_id" + ), ) + mock_client.assert_called_once_with(credentials='creds', client_info=mock.ANY, project="project_id") self.assertEqual(mock_client.return_value, client) 
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id') @@ -303,9 +281,7 @@ def test_should_return_valid_external_url(self, mock_client, mock_get_creds_and_ mock_get_creds_and_project_id.return_value = ('creds', 'project_id') mock_client.return_value.project = 'project_id' - stackdriver_task_handler = StackdriverTaskHandler( - gcp_key_path="KEY_PATH", - ) + stackdriver_task_handler = StackdriverTaskHandler(gcp_key_path="KEY_PATH",) url = stackdriver_task_handler.get_external_log_url(self.ti, self.ti.try_number) @@ -318,10 +294,12 @@ def test_should_return_valid_external_url(self, mock_client, mock_get_creds_and_ self.assertIn('global', parsed_qs['resource']) filter_params = parsed_qs['advancedFilter'][0].split('\n') - expected_filter = ['resource.type="global"', - 'logName="projects/project_id/logs/airflow"', - f'labels.task_id="{self.ti.task_id}"', - f'labels.dag_id="{self.dag.dag_id}"', - f'labels.execution_date="{self.ti.execution_date.isoformat()}"', - f'labels.try_number="{self.ti.try_number}"'] + expected_filter = [ + 'resource.type="global"', + 'logName="projects/project_id/logs/airflow"', + f'labels.task_id="{self.ti.task_id}"', + f'labels.dag_id="{self.dag.dag_id}"', + f'labels.execution_date="{self.ti.execution_date.isoformat()}"', + f'labels.try_number="{self.ti.try_number}"', + ] self.assertCountEqual(expected_filter, filter_params) diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py b/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py index d35ffd9b7ddc8..3b643acba2020 100644 --- a/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py +++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py @@ -37,13 +37,13 @@ @pytest.mark.system("google") @pytest.mark.credential_file(GCP_STACKDDRIVER) class TestStackdriverLoggingHandlerSystemTest(unittest.TestCase): - def setUp(self) -> None: clear_db_runs() self.log_name = 'stackdriver-tests-'.join(random.sample(string.ascii_lowercase, 16)) def tearDown(self) -> None: from airflow.config_templates import airflow_local_settings + importlib.reload(airflow_local_settings) settings.configure_logging() clear_db_runs() @@ -56,14 +56,10 @@ def test_should_support_key_auth(self, session): AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER=f"stackdriver://{self.log_name}", AIRFLOW__LOGGING__GOOGLE_KEY_PATH=resolve_full_gcp_key_path(GCP_STACKDDRIVER), AIRFLOW__CORE__LOAD_EXAMPLES="false", - AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__ + AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__, ): - self.assertEqual(0, subprocess.Popen( - ["airflow", "dags", "trigger", "example_complex"] - ).wait()) - self.assertEqual(0, subprocess.Popen( - ["airflow", "scheduler", "--num-runs", "1"] - ).wait()) + self.assertEqual(0, subprocess.Popen(["airflow", "dags", "trigger", "example_complex"]).wait()) + self.assertEqual(0, subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait()) ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first() self.assert_remote_logs("INFO - Task exited with return code 0", ti) @@ -76,24 +72,23 @@ def test_should_support_adc(self, session): AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER=f"stackdriver://{self.log_name}", AIRFLOW__CORE__LOAD_EXAMPLES="false", AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__, - GOOGLE_APPLICATION_CREDENTIALS=resolve_full_gcp_key_path(GCP_STACKDDRIVER) + 
GOOGLE_APPLICATION_CREDENTIALS=resolve_full_gcp_key_path(GCP_STACKDDRIVER), ): - self.assertEqual(0, subprocess.Popen( - ["airflow", "dags", "trigger", "example_complex"] - ).wait()) - self.assertEqual(0, subprocess.Popen( - ["airflow", "scheduler", "--num-runs", "1"] - ).wait()) + self.assertEqual(0, subprocess.Popen(["airflow", "dags", "trigger", "example_complex"]).wait()) + self.assertEqual(0, subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait()) ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first() self.assert_remote_logs("INFO - Task exited with return code 0", ti) def assert_remote_logs(self, expected_message, ti): - with provide_gcp_context(GCP_STACKDDRIVER), conf_vars({ - ('logging', 'remote_logging'): 'True', - ('logging', 'remote_base_log_folder'): f"stackdriver://{self.log_name}", - }): + with provide_gcp_context(GCP_STACKDDRIVER), conf_vars( + { + ('logging', 'remote_logging'): 'True', + ('logging', 'remote_base_log_folder'): f"stackdriver://{self.log_name}", + } + ): from airflow.config_templates import airflow_local_settings + importlib.reload(airflow_local_settings) settings.configure_logging() diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py index 73b081b5f9983..262a903e28bf2 100644 --- a/tests/providers/google/cloud/operators/test_automl.py +++ b/tests/providers/google/cloud/operators/test_automl.py @@ -23,10 +23,19 @@ from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient from airflow.providers.google.cloud.operators.automl import ( - AutoMLBatchPredictOperator, AutoMLCreateDatasetOperator, AutoMLDeleteDatasetOperator, - AutoMLDeleteModelOperator, AutoMLDeployModelOperator, AutoMLGetModelOperator, AutoMLImportDataOperator, - AutoMLListDatasetOperator, AutoMLPredictOperator, AutoMLTablesListColumnSpecsOperator, - AutoMLTablesListTableSpecsOperator, AutoMLTablesUpdateDatasetOperator, AutoMLTrainModelOperator, + AutoMLBatchPredictOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLDeployModelOperator, + AutoMLGetModelOperator, + AutoMLImportDataOperator, + AutoMLListDatasetOperator, + AutoMLPredictOperator, + AutoMLTablesListColumnSpecsOperator, + AutoMLTablesListTableSpecsOperator, + AutoMLTablesUpdateDatasetOperator, + AutoMLTrainModelOperator, ) CREDENTIALS = "test-creds" @@ -58,10 +67,7 @@ class TestAutoMLTrainModelOperator(unittest.TestCase): def test_execute(self, mock_hook, mock_xcom): mock_hook.return_value.extract_object_id.return_value = MODEL_ID op = AutoMLTrainModelOperator( - model=MODEL, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - task_id=TASK_ID, + model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID, ) op.execute(context=None) mock_hook.return_value.create_model.assert_called_once_with( @@ -129,10 +135,7 @@ class TestAutoMLCreateImportOperator(unittest.TestCase): def test_execute(self, mock_hook, mock_xcom): mock_hook.return_value.extract_object_id.return_value = DATASET_ID op = AutoMLCreateDatasetOperator( - dataset=DATASET, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - task_id=TASK_ID, + dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID, ) op.execute(context=None) mock_hook.return_value.create_dataset.assert_called_once_with( @@ -185,18 +188,11 @@ def test_execute(self, mock_hook): dataset["name"] = DATASET_ID op = AutoMLTablesUpdateDatasetOperator( - dataset=dataset, - 
update_mask=MASK, - location=GCP_LOCATION, - task_id=TASK_ID, + dataset=dataset, update_mask=MASK, location=GCP_LOCATION, task_id=TASK_ID, ) op.execute(context=None) mock_hook.return_value.update_dataset.assert_called_once_with( - dataset=dataset, - metadata=None, - retry=None, - timeout=None, - update_mask=MASK, + dataset=dataset, metadata=None, retry=None, timeout=None, update_mask=MASK, ) @@ -204,10 +200,7 @@ class TestAutoMLGetModelOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute(self, mock_hook): op = AutoMLGetModelOperator( - model_id=MODEL_ID, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - task_id=TASK_ID, + model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID, ) op.execute(context=None) mock_hook.return_value.get_model.assert_called_once_with( @@ -224,10 +217,7 @@ class TestAutoMLDeleteModelOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute(self, mock_hook): op = AutoMLDeleteModelOperator( - model_id=MODEL_ID, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - task_id=TASK_ID, + model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID, ) op.execute(context=None) mock_hook.return_value.delete_model.assert_called_once_with( @@ -316,16 +306,10 @@ class TestAutoMLDatasetListOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator.xcom_push") @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute(self, mock_hook, mock_xcom): - op = AutoMLListDatasetOperator( - location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID - ) + op = AutoMLListDatasetOperator(location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID) op.execute(context=None) mock_hook.return_value.list_datasets.assert_called_once_with( - location=GCP_LOCATION, - metadata=None, - project_id=GCP_PROJECT_ID, - retry=None, - timeout=None, + location=GCP_LOCATION, metadata=None, project_id=GCP_PROJECT_ID, retry=None, timeout=None, ) mock_xcom.assert_called_once_with(None, key="dataset_id_list", value=[]) @@ -334,10 +318,7 @@ class TestAutoMLDatasetDeleteOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute(self, mock_hook): op = AutoMLDeleteDatasetOperator( - dataset_id=DATASET_ID, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - task_id=TASK_ID, + dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID, ) op.execute(context=None) mock_hook.return_value.delete_dataset.assert_called_once_with( diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 246c8ec282a6a..0f393e25418d5 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -29,12 +29,24 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, TaskFail, TaskInstance, XCom from airflow.providers.google.cloud.operators.bigquery import ( - BigQueryCheckOperator, BigQueryConsoleIndexableLink, BigQueryConsoleLink, - BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryCreateExternalTableOperator, - BigQueryDeleteDatasetOperator, BigQueryDeleteTableOperator, BigQueryExecuteQueryOperator, - BigQueryGetDataOperator, 
BigQueryGetDatasetOperator, BigQueryGetDatasetTablesOperator, - BigQueryInsertJobOperator, BigQueryIntervalCheckOperator, BigQueryPatchDatasetOperator, - BigQueryUpdateDatasetOperator, BigQueryUpsertTableOperator, BigQueryValueCheckOperator, + BigQueryCheckOperator, + BigQueryConsoleIndexableLink, + BigQueryConsoleLink, + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryCreateExternalTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryDeleteTableOperator, + BigQueryExecuteQueryOperator, + BigQueryGetDataOperator, + BigQueryGetDatasetOperator, + BigQueryGetDatasetTablesOperator, + BigQueryInsertJobOperator, + BigQueryIntervalCheckOperator, + BigQueryPatchDatasetOperator, + BigQueryUpdateDatasetOperator, + BigQueryUpsertTableOperator, + BigQueryValueCheckOperator, ) from airflow.serialization.serialized_objects import SerializedDAG from airflow.settings import Session @@ -51,44 +63,34 @@ TEST_SOURCE_FORMAT = 'CSV' DEFAULT_DATE = datetime(2015, 1, 1) TEST_DAG_ID = 'test-bigquery-operators' -TEST_TABLE_RESOURCES = { - "tableReference": { - "tableId": TEST_TABLE_ID - }, - "expirationTime": 1234567 -} +TEST_TABLE_RESOURCES = {"tableReference": {"tableId": TEST_TABLE_ID}, "expirationTime": 1234567} VIEW_DEFINITION = { "query": "SELECT * FROM `{}.{}`".format(TEST_DATASET, TEST_TABLE_ID), - "useLegacySql": False + "useLegacySql": False, } class TestBigQueryCreateEmptyTableOperator(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_execute(self, mock_hook): operator = BigQueryCreateEmptyTableOperator( - task_id=TASK_ID, - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - table_id=TEST_TABLE_ID + task_id=TASK_ID, dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, table_id=TEST_TABLE_ID ) operator.execute(None) - mock_hook.return_value.create_empty_table \ - .assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - table_id=TEST_TABLE_ID, - schema_fields=None, - time_partitioning={}, - cluster_fields=None, - labels=None, - view=None, - encryption_configuration=None, - table_resource=None, - exists_ok=False, - ) + mock_hook.return_value.create_empty_table.assert_called_once_with( + dataset_id=TEST_DATASET, + project_id=TEST_GCP_PROJECT_ID, + table_id=TEST_TABLE_ID, + schema_fields=None, + time_partitioning={}, + cluster_fields=None, + labels=None, + view=None, + encryption_configuration=None, + table_resource=None, + exists_ok=False, + ) @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_create_view(self, mock_hook): @@ -97,49 +99,33 @@ def test_create_view(self, mock_hook): dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, table_id=TEST_TABLE_ID, - view=VIEW_DEFINITION + view=VIEW_DEFINITION, ) operator.execute(None) - mock_hook.return_value.create_empty_table \ - .assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - table_id=TEST_TABLE_ID, - schema_fields=None, - time_partitioning={}, - cluster_fields=None, - labels=None, - view=VIEW_DEFINITION, - encryption_configuration=None, - table_resource=None, - exists_ok=False, - ) + mock_hook.return_value.create_empty_table.assert_called_once_with( + dataset_id=TEST_DATASET, + project_id=TEST_GCP_PROJECT_ID, + table_id=TEST_TABLE_ID, + schema_fields=None, + time_partitioning={}, + cluster_fields=None, + labels=None, + view=VIEW_DEFINITION, + encryption_configuration=None, + table_resource=None, + exists_ok=False, + ) 
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_create_clustered_empty_table(self, mock_hook): schema_fields = [ - { - "name": "emp_name", - "type": "STRING", - "mode": "REQUIRED" - }, - { - "name": "date_hired", - "type": "DATE", - "mode": "REQUIRED" - }, - { - "name": "date_birth", - "type": "DATE", - "mode": "NULLABLE" - } + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "date_hired", "type": "DATE", "mode": "REQUIRED"}, + {"name": "date_birth", "type": "DATE", "mode": "NULLABLE"}, ] - time_partitioning = { - "type": "DAY", - "field": "date_hired" - } + time_partitioning = {"type": "DAY", "field": "date_hired"} cluster_fields = ["date_birth"] operator = BigQueryCreateEmptyTableOperator( task_id=TASK_ID, @@ -148,63 +134,56 @@ def test_create_clustered_empty_table(self, mock_hook): table_id=TEST_TABLE_ID, schema_fields=schema_fields, time_partitioning=time_partitioning, - cluster_fields=cluster_fields + cluster_fields=cluster_fields, ) operator.execute(None) - mock_hook.return_value.create_empty_table \ - .assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - table_id=TEST_TABLE_ID, - schema_fields=schema_fields, - time_partitioning=time_partitioning, - cluster_fields=cluster_fields, - labels=None, - view=None, - encryption_configuration=None, - table_resource=None, - exists_ok=False, - ) + mock_hook.return_value.create_empty_table.assert_called_once_with( + dataset_id=TEST_DATASET, + project_id=TEST_GCP_PROJECT_ID, + table_id=TEST_TABLE_ID, + schema_fields=schema_fields, + time_partitioning=time_partitioning, + cluster_fields=cluster_fields, + labels=None, + view=None, + encryption_configuration=None, + table_resource=None, + exists_ok=False, + ) class TestBigQueryCreateExternalTableOperator(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_execute(self, mock_hook): operator = BigQueryCreateExternalTableOperator( task_id=TASK_ID, - destination_project_dataset_table='{}.{}'.format( - TEST_DATASET, TEST_TABLE_ID - ), + destination_project_dataset_table='{}.{}'.format(TEST_DATASET, TEST_TABLE_ID), schema_fields=[], bucket=TEST_GCS_BUCKET, source_objects=TEST_GCS_DATA, - source_format=TEST_SOURCE_FORMAT + source_format=TEST_SOURCE_FORMAT, ) operator.execute(None) - mock_hook.return_value \ - .create_external_table \ - .assert_called_once_with( - external_project_dataset_table='{}.{}'.format( - TEST_DATASET, TEST_TABLE_ID - ), - schema_fields=[], - source_uris=['gs://{}/{}'.format(TEST_GCS_BUCKET, source_object) - for source_object in TEST_GCS_DATA], - source_format=TEST_SOURCE_FORMAT, - compression='NONE', - skip_leading_rows=0, - field_delimiter=',', - max_bad_records=0, - quote_character=None, - allow_quoted_newlines=False, - allow_jagged_rows=False, - src_fmt_configs={}, - labels=None, - encryption_configuration=None - ) + mock_hook.return_value.create_external_table.assert_called_once_with( + external_project_dataset_table='{}.{}'.format(TEST_DATASET, TEST_TABLE_ID), + schema_fields=[], + source_uris=[ + 'gs://{}/{}'.format(TEST_GCS_BUCKET, source_object) for source_object in TEST_GCS_DATA + ], + source_format=TEST_SOURCE_FORMAT, + compression='NONE', + skip_leading_rows=0, + field_delimiter=',', + max_bad_records=0, + quote_character=None, + allow_quoted_newlines=False, + allow_jagged_rows=False, + src_fmt_configs={}, + labels=None, + encryption_configuration=None, + ) class TestBigQueryDeleteDatasetOperator(unittest.TestCase): @@ 
-214,17 +193,13 @@ def test_execute(self, mock_hook): task_id=TASK_ID, dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, - delete_contents=TEST_DELETE_CONTENTS + delete_contents=TEST_DELETE_CONTENTS, ) operator.execute(None) - mock_hook.return_value \ - .delete_dataset \ - .assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - delete_contents=TEST_DELETE_CONTENTS - ) + mock_hook.return_value.delete_dataset.assert_called_once_with( + dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, delete_contents=TEST_DELETE_CONTENTS + ) class TestBigQueryCreateEmptyDatasetOperator(unittest.TestCase): @@ -234,7 +209,7 @@ def test_execute(self, mock_hook): task_id=TASK_ID, dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, - location=TEST_DATASET_LOCATION + location=TEST_DATASET_LOCATION, ) operator.execute(None) @@ -251,18 +226,13 @@ class TestBigQueryGetDatasetOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_execute(self, mock_hook): operator = BigQueryGetDatasetOperator( - task_id=TASK_ID, - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID + task_id=TASK_ID, dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID ) operator.execute(None) - mock_hook.return_value \ - .get_dataset \ - .assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID - ) + mock_hook.return_value.get_dataset.assert_called_once_with( + dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID + ) class TestBigQueryPatchDatasetOperator(unittest.TestCase): @@ -273,17 +243,13 @@ def test_execute(self, mock_hook): dataset_resource=dataset_resource, task_id=TASK_ID, dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID + project_id=TEST_GCP_PROJECT_ID, ) operator.execute(None) - mock_hook.return_value \ - .patch_dataset \ - .assert_called_once_with( - dataset_resource=dataset_resource, - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID - ) + mock_hook.return_value.patch_dataset.assert_called_once_with( + dataset_resource=dataset_resource, dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID + ) class TestBigQueryUpdateDatasetOperator(unittest.TestCase): @@ -294,33 +260,28 @@ def test_execute(self, mock_hook): dataset_resource=dataset_resource, task_id=TASK_ID, dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID + project_id=TEST_GCP_PROJECT_ID, ) operator.execute(None) - mock_hook.return_value \ - .update_dataset \ - .assert_called_once_with( - dataset_resource=dataset_resource, - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - fields=list(dataset_resource.keys()) - ) + mock_hook.return_value.update_dataset.assert_called_once_with( + dataset_resource=dataset_resource, + dataset_id=TEST_DATASET, + project_id=TEST_GCP_PROJECT_ID, + fields=list(dataset_resource.keys()), + ) class TestBigQueryOperator(unittest.TestCase): def setUp(self): - self.dagbag = models.DagBag( - dag_folder='/dev/null', include_examples=True) + self.dagbag = models.DagBag(dag_folder='/dev/null', include_examples=True) self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=self.args) def tearDown(self): session = Session() - session.query(models.TaskInstance).filter_by( - dag_id=TEST_DAG_ID).delete() - session.query(TaskFail).filter_by( - dag_id=TEST_DAG_ID).delete() + session.query(models.TaskInstance).filter_by(dag_id=TEST_DAG_ID).delete() + session.query(TaskFail).filter_by(dag_id=TEST_DAG_ID).delete() session.commit() 
session.close() @@ -348,40 +309,35 @@ def test_execute(self, mock_hook): time_partitioning=None, api_resource_configs=None, cluster_fields=None, - encryption_configuration=encryption_configuration + encryption_configuration=encryption_configuration, ) operator.execute(MagicMock()) - mock_hook.return_value \ - .run_query \ - .assert_called_once_with( - sql='Select * from test_table', - destination_dataset_table=None, - write_disposition='WRITE_EMPTY', - allow_large_results=False, - flatten_results=None, - udf_config=None, - maximum_billing_tier=None, - maximum_bytes_billed=None, - create_disposition='CREATE_IF_NEEDED', - schema_update_options=(), - query_params=None, - labels=None, - priority='INTERACTIVE', - time_partitioning=None, - api_resource_configs=None, - cluster_fields=None, - encryption_configuration=encryption_configuration - ) + mock_hook.return_value.run_query.assert_called_once_with( + sql='Select * from test_table', + destination_dataset_table=None, + write_disposition='WRITE_EMPTY', + allow_large_results=False, + flatten_results=None, + udf_config=None, + maximum_billing_tier=None, + maximum_bytes_billed=None, + create_disposition='CREATE_IF_NEEDED', + schema_update_options=(), + query_params=None, + labels=None, + priority='INTERACTIVE', + time_partitioning=None, + api_resource_configs=None, + cluster_fields=None, + encryption_configuration=encryption_configuration, + ) @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_execute_list(self, mock_hook): operator = BigQueryExecuteQueryOperator( task_id=TASK_ID, - sql=[ - 'Select * from test_table', - 'Select * from other_test_table', - ], + sql=['Select * from test_table', 'Select * from other_test_table',], destination_dataset_table=None, write_disposition='WRITE_EMPTY', allow_large_results=False, @@ -403,9 +359,8 @@ def test_execute_list(self, mock_hook): ) operator.execute(MagicMock()) - mock_hook.return_value \ - .run_query \ - .assert_has_calls([ + mock_hook.return_value.run_query.assert_has_calls( + [ mock.call( sql='Select * from test_table', destination_dataset_table=None, @@ -444,7 +399,8 @@ def test_execute_list(self, mock_hook): cluster_fields=None, encryption_configuration=None, ), - ]) + ] + ) @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_execute_bad_type(self, mock_hook): @@ -480,31 +436,29 @@ def test_bigquery_operator_defaults(self, mock_hook): sql='Select * from test_table', dag=self.dag, default_args=self.args, - schema_update_options=None + schema_update_options=None, ) operator.execute(MagicMock()) - mock_hook.return_value \ - .run_query \ - .assert_called_once_with( - sql='Select * from test_table', - destination_dataset_table=None, - write_disposition='WRITE_EMPTY', - allow_large_results=False, - flatten_results=None, - udf_config=None, - maximum_billing_tier=None, - maximum_bytes_billed=None, - create_disposition='CREATE_IF_NEEDED', - schema_update_options=None, - query_params=None, - labels=None, - priority='INTERACTIVE', - time_partitioning=None, - api_resource_configs=None, - cluster_fields=None, - encryption_configuration=None - ) + mock_hook.return_value.run_query.assert_called_once_with( + sql='Select * from test_table', + destination_dataset_table=None, + write_disposition='WRITE_EMPTY', + allow_large_results=False, + flatten_results=None, + udf_config=None, + maximum_billing_tier=None, + maximum_bytes_billed=None, + create_disposition='CREATE_IF_NEEDED', + schema_update_options=None, + query_params=None, + labels=None, + 
priority='INTERACTIVE', + time_partitioning=None, + api_resource_configs=None, + cluster_fields=None, + encryption_configuration=None, + ) self.assertTrue(isinstance(operator.sql, str)) ti = TaskInstance(task=operator, execution_date=DEFAULT_DATE) ti.render_templates() @@ -513,8 +467,7 @@ def test_bigquery_operator_defaults(self, mock_hook): def test_bigquery_operator_extra_serialized_field_when_single_query(self): with self.dag: BigQueryExecuteQueryOperator( - task_id=TASK_ID, - sql='SELECT * FROM test_table', + task_id=TASK_ID, sql='SELECT * FROM test_table', ) serialized_dag = SerializedDAG.to_dict(self.dag) self.assertIn("sql", serialized_dag["dag"]["tasks"][0]) @@ -530,7 +483,7 @@ def test_bigquery_operator_extra_serialized_field_when_single_query(self): # Check Serialized version of operator link self.assertEqual( serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], - [{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}] + [{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}], ) # Check DeSerialized version of operator link @@ -550,16 +503,16 @@ def test_bigquery_operator_extra_serialized_field_when_single_query(self): def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self): with self.dag: BigQueryExecuteQueryOperator( - task_id=TASK_ID, - sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'], + task_id=TASK_ID, sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'], ) serialized_dag = SerializedDAG.to_dict(self.dag) self.assertIn("sql", serialized_dag["dag"]["tasks"][0]) dag = SerializedDAG.from_dict(serialized_dag) simple_task = dag.task_dict[TASK_ID] - self.assertEqual(getattr(simple_task, "sql"), - ['SELECT * FROM test_table', 'SELECT * FROM test_table2']) + self.assertEqual( + getattr(simple_task, "sql"), ['SELECT * FROM test_table', 'SELECT * FROM test_table2'] + ) ######################################################### # Verify Operator Links work with Serialized Operator @@ -578,8 +531,8 @@ def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self): 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': { 'index': 1 } - } - ] + }, + ], ) # Check DeSerialized version of operator link @@ -590,8 +543,7 @@ def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self): ti.xcom_push(key='job_id', value=job_id) self.assertEqual( - {'BigQuery Console #1', 'BigQuery Console #2'}, - simple_task.operator_extra_link_dict.keys() + {'BigQuery Console #1', 'BigQuery Console #2'}, simple_task.operator_extra_link_dict.keys() ) self.assertEqual( @@ -608,33 +560,25 @@ def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self): @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_bigquery_operator_extra_link_when_missing_job_id(self, mock_hook, session): bigquery_task = BigQueryExecuteQueryOperator( - task_id=TASK_ID, - sql='SELECT * FROM test_table', - dag=self.dag, + task_id=TASK_ID, sql='SELECT * FROM test_table', dag=self.dag, ) self.dag.clear() session.query(XCom).delete() self.assertEqual( - '', - bigquery_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name), + '', bigquery_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name), ) @provide_session @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_bigquery_operator_extra_link_when_single_query(self, mock_hook, session): bigquery_task = BigQueryExecuteQueryOperator( - 
task_id=TASK_ID, - sql='SELECT * FROM test_table', - dag=self.dag, + task_id=TASK_ID, sql='SELECT * FROM test_table', dag=self.dag, ) self.dag.clear() session.query(XCom).delete() - ti = TaskInstance( - task=bigquery_task, - execution_date=DEFAULT_DATE, - ) + ti = TaskInstance(task=bigquery_task, execution_date=DEFAULT_DATE,) job_id = '12345' ti.xcom_push(key='job_id', value=job_id) @@ -645,32 +589,25 @@ def test_bigquery_operator_extra_link_when_single_query(self, mock_hook, session ) self.assertEqual( - '', - bigquery_task.get_extra_links(datetime(2019, 1, 1), BigQueryConsoleLink.name), + '', bigquery_task.get_extra_links(datetime(2019, 1, 1), BigQueryConsoleLink.name), ) @provide_session @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_bigquery_operator_extra_link_when_multiple_query(self, mock_hook, session): bigquery_task = BigQueryExecuteQueryOperator( - task_id=TASK_ID, - sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'], - dag=self.dag, + task_id=TASK_ID, sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'], dag=self.dag, ) self.dag.clear() session.query(XCom).delete() - ti = TaskInstance( - task=bigquery_task, - execution_date=DEFAULT_DATE, - ) + ti = TaskInstance(task=bigquery_task, execution_date=DEFAULT_DATE,) job_id = ['123', '45'] ti.xcom_push(key='job_id', value=job_id) self.assertEqual( - {'BigQuery Console #1', 'BigQuery Console #2'}, - bigquery_task.operator_extra_link_dict.keys() + {'BigQuery Console #1', 'BigQuery Console #2'}, bigquery_task.operator_extra_link_dict.keys() ) self.assertEqual( @@ -685,29 +622,27 @@ def test_bigquery_operator_extra_link_when_multiple_query(self, mock_hook, sessi class TestBigQueryGetDataOperator(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_execute(self, mock_hook): max_results = 100 selected_fields = 'DATE' - operator = BigQueryGetDataOperator(task_id=TASK_ID, - dataset_id=TEST_DATASET, - table_id=TEST_TABLE_ID, - max_results=max_results, - selected_fields=selected_fields, - location=TEST_DATASET_LOCATION, - ) + operator = BigQueryGetDataOperator( + task_id=TASK_ID, + dataset_id=TEST_DATASET, + table_id=TEST_TABLE_ID, + max_results=max_results, + selected_fields=selected_fields, + location=TEST_DATASET_LOCATION, + ) operator.execute(None) - mock_hook.return_value \ - .list_rows \ - .assert_called_once_with( - dataset_id=TEST_DATASET, - table_id=TEST_TABLE_ID, - max_results=max_results, - selected_fields=selected_fields, - location=TEST_DATASET_LOCATION, - ) + mock_hook.return_value.list_rows.assert_called_once_with( + dataset_id=TEST_DATASET, + table_id=TEST_TABLE_ID, + max_results=max_results, + selected_fields=selected_fields, + location=TEST_DATASET_LOCATION, + ) class TestBigQueryTableDeleteOperator(unittest.TestCase): @@ -719,13 +654,12 @@ def test_execute(self, mock_hook): operator = BigQueryDeleteTableOperator( task_id=TASK_ID, deletion_dataset_table=deletion_dataset_table, - ignore_if_missing=ignore_if_missing + ignore_if_missing=ignore_if_missing, ) operator.execute(None) mock_hook.return_value.delete_table.assert_called_once_with( - table_id=deletion_dataset_table, - not_found_ok=ignore_if_missing + table_id=deletion_dataset_table, not_found_ok=ignore_if_missing ) @@ -733,42 +667,41 @@ class TestBigQueryGetDatasetTablesOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook') def test_execute(self, mock_hook): operator = BigQueryGetDatasetTablesOperator( - 
task_id=TASK_ID, - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - max_results=2 + task_id=TASK_ID, dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, max_results=2 ) operator.execute(None) mock_hook.return_value.get_dataset_tables.assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - max_results=2, + dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, max_results=2, ) class TestBigQueryConnIdDeprecationWarning(unittest.TestCase): - @parameterized.expand([ - (BigQueryCheckOperator, dict(sql='Select * from test_table', task_id=TASK_ID)), - (BigQueryValueCheckOperator, dict(sql='Select * from test_table', pass_value=95, task_id=TASK_ID)), - (BigQueryIntervalCheckOperator, - dict(table=TEST_TABLE_ID, metrics_thresholds={'COUNT(*)': 1.5}, task_id=TASK_ID)), - (BigQueryGetDataOperator, dict(dataset_id=TEST_DATASET, table_id=TEST_TABLE_ID, task_id=TASK_ID)), - (BigQueryExecuteQueryOperator, dict(sql='Select * from test_table', task_id=TASK_ID)), - (BigQueryDeleteDatasetOperator, dict(dataset_id=TEST_DATASET, task_id=TASK_ID)), - (BigQueryCreateEmptyDatasetOperator, dict(dataset_id=TEST_DATASET, task_id=TASK_ID)), - (BigQueryDeleteTableOperator, dict(deletion_dataset_table=TEST_DATASET, task_id=TASK_ID)) - ]) + @parameterized.expand( + [ + (BigQueryCheckOperator, dict(sql='Select * from test_table', task_id=TASK_ID)), + ( + BigQueryValueCheckOperator, + dict(sql='Select * from test_table', pass_value=95, task_id=TASK_ID), + ), + ( + BigQueryIntervalCheckOperator, + dict(table=TEST_TABLE_ID, metrics_thresholds={'COUNT(*)': 1.5}, task_id=TASK_ID), + ), + (BigQueryGetDataOperator, dict(dataset_id=TEST_DATASET, table_id=TEST_TABLE_ID, task_id=TASK_ID)), + (BigQueryExecuteQueryOperator, dict(sql='Select * from test_table', task_id=TASK_ID)), + (BigQueryDeleteDatasetOperator, dict(dataset_id=TEST_DATASET, task_id=TASK_ID)), + (BigQueryCreateEmptyDatasetOperator, dict(dataset_id=TEST_DATASET, task_id=TASK_ID)), + (BigQueryDeleteTableOperator, dict(deletion_dataset_table=TEST_DATASET, task_id=TASK_ID)), + ] + ) def test_bigquery_conn_id_deprecation_warning(self, operator_class, kwargs): bigquery_conn_id = 'google_cloud_default' with self.assertWarnsRegex( DeprecationWarning, - "The bigquery_conn_id parameter has been deprecated. You should pass the gcp_conn_id parameter." + "The bigquery_conn_id parameter has been deprecated. 
You should pass the gcp_conn_id parameter.", ): - operator = operator_class( - bigquery_conn_id=bigquery_conn_id, - **kwargs - ) + operator = operator_class(bigquery_conn_id=bigquery_conn_id, **kwargs) self.assertEqual(bigquery_conn_id, operator.gcp_conn_id) @@ -783,13 +716,9 @@ def test_execute(self, mock_hook): ) operator.execute(None) - mock_hook.return_value \ - .run_table_upsert \ - .assert_called_once_with( - dataset_id=TEST_DATASET, - project_id=TEST_GCP_PROJECT_ID, - table_resource=TEST_TABLE_RESOURCES - ) + mock_hook.return_value.run_table_upsert.assert_called_once_with( + dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, table_resource=TEST_TABLE_RESOURCES + ) class TestBigQueryInsertJobOperator: @@ -801,12 +730,7 @@ def test_execute_success(self, mock_hook, mock_md5): real_job_id = f"{job_id}_{hash_}" mock_md5.return_value.hexdigest.return_value = hash_ - configuration = { - "query": { - "query": "SELECT * FROM any", - "useLegacySql": False, - } - } + configuration = {"query": {"query": "SELECT * FROM any", "useLegacySql": False,}} mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) op = BigQueryInsertJobOperator( @@ -814,7 +738,7 @@ def test_execute_success(self, mock_hook, mock_md5): configuration=configuration, location=TEST_DATASET_LOCATION, job_id=job_id, - project_id=TEST_GCP_PROJECT_ID + project_id=TEST_GCP_PROJECT_ID, ) result = op.execute({}) @@ -835,12 +759,7 @@ def test_execute_failure(self, mock_hook, mock_md5): real_job_id = f"{job_id}_{hash_}" mock_md5.return_value.hexdigest.return_value = hash_ - configuration = { - "query": { - "query": "SELECT * FROM any", - "useLegacySql": False, - } - } + configuration = {"query": {"query": "SELECT * FROM any", "useLegacySql": False,}} mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=True) op = BigQueryInsertJobOperator( @@ -848,7 +767,7 @@ def test_execute_failure(self, mock_hook, mock_md5): configuration=configuration, location=TEST_DATASET_LOCATION, job_id=job_id, - project_id=TEST_GCP_PROJECT_ID + project_id=TEST_GCP_PROJECT_ID, ) with pytest.raises(AirflowException): op.execute({}) @@ -861,17 +780,10 @@ def test_execute_reattach(self, mock_hook, mock_md5): real_job_id = f"{job_id}_{hash_}" mock_md5.return_value.hexdigest.return_value = hash_ - configuration = { - "query": { - "query": "SELECT * FROM any", - "useLegacySql": False, - } - } + configuration = {"query": {"query": "SELECT * FROM any", "useLegacySql": False,}} mock_hook.return_value.insert_job.return_value.result.side_effect = Conflict("any") - job = MagicMock( - job_id=real_job_id, error_result=False, state="PENDING", done=lambda: False, - ) + job = MagicMock(job_id=real_job_id, error_result=False, state="PENDING", done=lambda: False,) mock_hook.return_value.get_job.return_value = job op = BigQueryInsertJobOperator( @@ -880,14 +792,12 @@ def test_execute_reattach(self, mock_hook, mock_md5): location=TEST_DATASET_LOCATION, job_id=job_id, project_id=TEST_GCP_PROJECT_ID, - reattach_states={"PENDING"} + reattach_states={"PENDING"}, ) result = op.execute({}) mock_hook.return_value.get_job.assert_called_once_with( - location=TEST_DATASET_LOCATION, - job_id=real_job_id, - project_id=TEST_GCP_PROJECT_ID, + location=TEST_DATASET_LOCATION, job_id=real_job_id, project_id=TEST_GCP_PROJECT_ID, ) job.result.assert_called_once_with() @@ -903,16 +813,9 @@ def test_execute_force_rerun(self, mock_hook, mock_uuid, mock_md5): real_job_id = f"{job_id}_{hash_}" 
mock_md5.return_value.hexdigest.return_value = hash_ - configuration = { - "query": { - "query": "SELECT * FROM any", - "useLegacySql": False, - } - } + configuration = {"query": {"query": "SELECT * FROM any", "useLegacySql": False,}} - job = MagicMock( - job_id=real_job_id, error_result=False, - ) + job = MagicMock(job_id=real_job_id, error_result=False,) mock_hook.return_value.insert_job.return_value = job op = BigQueryInsertJobOperator( @@ -942,17 +845,10 @@ def test_execute_no_force_rerun(self, mock_hook, mock_md5): real_job_id = f"{job_id}_{hash_}" mock_md5.return_value.hexdigest.return_value = hash_ - configuration = { - "query": { - "query": "SELECT * FROM any", - "useLegacySql": False, - } - } + configuration = {"query": {"query": "SELECT * FROM any", "useLegacySql": False,}} mock_hook.return_value.insert_job.return_value.result.side_effect = Conflict("any") - job = MagicMock( - job_id=real_job_id, error_result=False, state="DONE", done=lambda: True, - ) + job = MagicMock(job_id=real_job_id, error_result=False, state="DONE", done=lambda: True,) mock_hook.return_value.get_job.return_value = job op = BigQueryInsertJobOperator( @@ -961,7 +857,7 @@ def test_execute_no_force_rerun(self, mock_hook, mock_md5): location=TEST_DATASET_LOCATION, job_id=job_id, project_id=TEST_GCP_PROJECT_ID, - reattach_states={"PENDING"} + reattach_states={"PENDING"}, ) # No force rerun with pytest.raises(AirflowException): diff --git a/tests/providers/google/cloud/operators/test_bigquery_dts.py b/tests/providers/google/cloud/operators/test_bigquery_dts.py index 0827ffdab85de..d06dc85ae8db6 100644 --- a/tests/providers/google/cloud/operators/test_bigquery_dts.py +++ b/tests/providers/google/cloud/operators/test_bigquery_dts.py @@ -20,7 +20,8 @@ import mock from airflow.providers.google.cloud.operators.bigquery_dts import ( - BigQueryCreateDataTransferOperator, BigQueryDataTransferServiceStartTransferRunsOperator, + BigQueryCreateDataTransferOperator, + BigQueryDataTransferServiceStartTransferRunsOperator, BigQueryDeleteDataTransferConfigOperator, ) diff --git a/tests/providers/google/cloud/operators/test_bigquery_dts_system.py b/tests/providers/google/cloud/operators/test_bigquery_dts_system.py index 0ce5677e49c78..93e57394d2180 100644 --- a/tests/providers/google/cloud/operators/test_bigquery_dts_system.py +++ b/tests/providers/google/cloud/operators/test_bigquery_dts_system.py @@ -18,7 +18,10 @@ import pytest from airflow.providers.google.cloud.example_dags.example_bigquery_dts import ( - BUCKET_URI, GCP_DTS_BQ_DATASET, GCP_DTS_BQ_TABLE, GCP_PROJECT_ID, + BUCKET_URI, + GCP_DTS_BQ_DATASET, + GCP_DTS_BQ_TABLE, + GCP_PROJECT_ID, ) from tests.providers.google.cloud.utils.gcp_authenticator import GCP_BIGQUERY_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context @@ -32,8 +35,7 @@ def create_dataset(self, project_id: str, dataset: str, table: str): table_name = f"{dataset_name}.{table}" self.execute_with_ctx( - ["bq", "--location", "us", "mk", "--dataset", dataset_name], - key=GCP_BIGQUERY_KEY + ["bq", "--location", "us", "mk", "--dataset", dataset_name], key=GCP_BIGQUERY_KEY ) self.execute_with_ctx(["bq", "mk", "--table", table_name, ""], key=GCP_BIGQUERY_KEY) @@ -50,7 +52,8 @@ def upload_data(self, dataset: str, table: str, gcs_file: str): "CSV", table_name, gcs_file, - ], key=GCP_BIGQUERY_KEY + ], + key=GCP_BIGQUERY_KEY, ) def delete_dataset(self, project_id: str, dataset: str): @@ -61,17 +64,13 @@ def delete_dataset(self, project_id: str, dataset: str): 
def setUp(self): super().setUp() self.create_dataset( - project_id=GCP_PROJECT_ID, - dataset=GCP_DTS_BQ_DATASET, - table=GCP_DTS_BQ_TABLE, + project_id=GCP_PROJECT_ID, dataset=GCP_DTS_BQ_DATASET, table=GCP_DTS_BQ_TABLE, ) self.upload_data(dataset=GCP_DTS_BQ_DATASET, table=GCP_DTS_BQ_TABLE, gcs_file=BUCKET_URI) @provide_gcp_context(GCP_BIGQUERY_KEY) def tearDown(self): - self.delete_dataset( - project_id=GCP_PROJECT_ID, dataset=GCP_DTS_BQ_DATASET - ) + self.delete_dataset(project_id=GCP_PROJECT_ID, dataset=GCP_DTS_BQ_DATASET) super().tearDown() @provide_gcp_context(GCP_BIGQUERY_KEY) diff --git a/tests/providers/google/cloud/operators/test_bigtable.py b/tests/providers/google/cloud/operators/test_bigtable.py index ee467eb5b2561..aafef5cb5297b 100644 --- a/tests/providers/google/cloud/operators/test_bigtable.py +++ b/tests/providers/google/cloud/operators/test_bigtable.py @@ -28,8 +28,12 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.operators.bigtable import ( - BigtableCreateInstanceOperator, BigtableCreateTableOperator, BigtableDeleteInstanceOperator, - BigtableDeleteTableOperator, BigtableUpdateClusterOperator, BigtableUpdateInstanceOperator, + BigtableCreateInstanceOperator, + BigtableCreateTableOperator, + BigtableDeleteInstanceOperator, + BigtableDeleteTableOperator, + BigtableUpdateClusterOperator, + BigtableUpdateInstanceOperator, ) PROJECT_ID = 'test_project_id' @@ -53,15 +57,18 @@ class TestBigtableInstanceCreate(unittest.TestCase): - @parameterized.expand([ - ('instance_id', PROJECT_ID, '', CLUSTER_ID, CLUSTER_ZONE), - ('main_cluster_id', PROJECT_ID, INSTANCE_ID, '', CLUSTER_ZONE), - ('main_cluster_zone', PROJECT_ID, INSTANCE_ID, CLUSTER_ID, ''), - ], testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0]) - @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') - def test_empty_attribute(self, missing_attribute, project_id, instance_id, - main_cluster_id, - main_cluster_zone, mock_hook): + @parameterized.expand( + [ + ('instance_id', PROJECT_ID, '', CLUSTER_ID, CLUSTER_ZONE), + ('main_cluster_id', PROJECT_ID, INSTANCE_ID, '', CLUSTER_ZONE), + ('main_cluster_zone', PROJECT_ID, INSTANCE_ID, CLUSTER_ID, ''), + ], + testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0], + ) + @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') + def test_empty_attribute( + self, missing_attribute, project_id, instance_id, main_cluster_id, main_cluster_zone, mock_hook + ): with self.assertRaises(AirflowException) as e: BigtableCreateInstanceOperator( project_id=project_id, @@ -69,7 +76,7 @@ def test_empty_attribute(self, missing_attribute, project_id, instance_id, main_cluster_id=main_cluster_id, main_cluster_zone=main_cluster_zone, task_id="id", - gcp_conn_id=GCP_CONN_ID + gcp_conn_id=GCP_CONN_ID, ) err = e.exception self.assertEqual(str(err), 'Empty parameter: {}'.format(missing_attribute)) @@ -91,8 +98,7 @@ def test_create_instance_that_exists(self, mock_hook): op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_instance.assert_not_called() @@ -111,8 +117,7 @@ def test_create_instance_that_exists_empty_project_id(self, mock_hook): op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, 
impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_instance.assert_not_called() @@ -130,14 +135,14 @@ def test_different_error_reraised(self, mock_hook): ) mock_hook.return_value.create_instance.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.GoogleAPICallError('error')) + side_effect=google.api_core.exceptions.GoogleAPICallError('error') + ) with self.assertRaises(google.api_core.exceptions.GoogleAPICallError): op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_instance.assert_called_once_with( cluster_nodes=None, @@ -152,7 +157,7 @@ def test_different_error_reraised(self, mock_hook): replica_clusters=None, replica_cluster_id=None, replica_cluster_zone=None, - timeout=None + timeout=None, ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') @@ -169,8 +174,7 @@ def test_create_instance_that_doesnt_exists(self, mock_hook): ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_instance.assert_called_once_with( cluster_nodes=None, @@ -185,7 +189,7 @@ def test_create_instance_that_doesnt_exists(self, mock_hook): replica_clusters=None, replica_cluster_id=None, replica_cluster_zone=None, - timeout=None + timeout=None, ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') @@ -203,8 +207,7 @@ def test_create_instance_with_replicas_that_doesnt_exists(self, mock_hook): ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_instance.assert_called_once_with( cluster_nodes=None, @@ -219,7 +222,7 @@ def test_create_instance_with_replicas_that_doesnt_exists(self, mock_hook): replica_clusters=REPLICATE_CLUSTERS, replica_cluster_id=None, replica_cluster_zone=None, - timeout=None + timeout=None, ) @@ -238,8 +241,7 @@ def test_delete_execute(self, mock_hook): ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_instance.assert_called_once_with( project_id=PROJECT_ID, @@ -247,7 +249,7 @@ def test_delete_execute(self, mock_hook): instance_display_name=INSTANCE_DISPLAY_NAME, instance_type=INSTANCE_TYPE, instance_labels=INSTANCE_LABELS, - timeout=None + timeout=None, ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') @@ -263,8 +265,7 @@ def test_update_execute_empty_project_id(self, mock_hook): ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_instance.assert_called_once_with( project_id=None, @@ -272,12 +273,13 @@ def test_update_execute_empty_project_id(self, mock_hook): instance_display_name=INSTANCE_DISPLAY_NAME, instance_type=INSTANCE_TYPE, instance_labels=INSTANCE_LABELS, - timeout=None + timeout=None, ) - @parameterized.expand([ - ('instance_id', PROJECT_ID, ''), - ], testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0]) + 
@parameterized.expand( + [('instance_id', PROJECT_ID, ''),], + testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0], + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') def test_empty_attribute(self, missing_attribute, project_id, instance_id, mock_hook): with self.assertRaises(AirflowException) as e: @@ -287,7 +289,7 @@ def test_empty_attribute(self, missing_attribute, project_id, instance_id, mock_ instance_display_name=INSTANCE_DISPLAY_NAME, instance_type=INSTANCE_TYPE, instance_labels=INSTANCE_LABELS, - task_id="id" + task_id="id", ) err = e.exception self.assertEqual(str(err), 'Empty parameter: {}'.format(missing_attribute)) @@ -311,12 +313,10 @@ def test_update_instance_that_doesnt_exists(self, mock_hook): op.execute(None) err = e.exception - self.assertEqual(str(err), "Dependency: instance '{}' does not exist.".format( - INSTANCE_ID)) + self.assertEqual(str(err), "Dependency: instance '{}' does not exist.".format(INSTANCE_ID)) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_instance.assert_not_called() @@ -337,12 +337,10 @@ def test_update_instance_that_doesnt_exists_empty_project_id(self, mock_hook): op.execute(None) err = e.exception - self.assertEqual(str(err), "Dependency: instance '{}' does not exist.".format( - INSTANCE_ID)) + self.assertEqual(str(err), "Dependency: instance '{}' does not exist.".format(INSTANCE_ID)) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_instance.assert_not_called() @@ -359,14 +357,14 @@ def test_different_error_reraised(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_instance.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.GoogleAPICallError('error')) + side_effect=google.api_core.exceptions.GoogleAPICallError('error') + ) with self.assertRaises(google.api_core.exceptions.GoogleAPICallError): op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_instance.assert_called_once_with( project_id=PROJECT_ID, @@ -374,19 +372,21 @@ def test_different_error_reraised(self, mock_hook): instance_display_name=INSTANCE_DISPLAY_NAME, instance_type=INSTANCE_TYPE, instance_labels=INSTANCE_LABELS, - timeout=None + timeout=None, ) class TestBigtableClusterUpdate(unittest.TestCase): - @parameterized.expand([ - ('instance_id', PROJECT_ID, '', CLUSTER_ID, NODES), - ('cluster_id', PROJECT_ID, INSTANCE_ID, '', NODES), - ('nodes', PROJECT_ID, INSTANCE_ID, CLUSTER_ID, ''), - ], testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0]) - @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') - def test_empty_attribute(self, missing_attribute, project_id, instance_id, - cluster_id, nodes, mock_hook): + @parameterized.expand( + [ + ('instance_id', PROJECT_ID, '', CLUSTER_ID, NODES), + ('cluster_id', PROJECT_ID, INSTANCE_ID, '', NODES), + ('nodes', PROJECT_ID, INSTANCE_ID, CLUSTER_ID, ''), + ], + testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0], + ) + 
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') + def test_empty_attribute(self, missing_attribute, project_id, instance_id, cluster_id, nodes, mock_hook): with self.assertRaises(AirflowException) as e: BigtableUpdateClusterOperator( project_id=project_id, @@ -394,7 +394,7 @@ def test_empty_attribute(self, missing_attribute, project_id, instance_id, cluster_id=cluster_id, nodes=nodes, task_id="id", - gcp_conn_id=GCP_CONN_ID + gcp_conn_id=GCP_CONN_ID, ) err = e.exception self.assertEqual(str(err), 'Empty parameter: {}'.format(missing_attribute)) @@ -417,17 +417,14 @@ def test_updating_cluster_but_instance_does_not_exists(self, mock_hook): op.execute(None) err = e.exception - self.assertEqual(str(err), "Dependency: instance '{}' does not exist.".format( - INSTANCE_ID)) + self.assertEqual(str(err), "Dependency: instance '{}' does not exist.".format(INSTANCE_ID)) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_cluster.assert_not_called() @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') - def test_updating_cluster_but_instance_does_not_exists_empty_project_id(self, - mock_hook): + def test_updating_cluster_but_instance_does_not_exists_empty_project_id(self, mock_hook): mock_hook.return_value.get_instance.return_value = None with self.assertRaises(AirflowException) as e: @@ -442,11 +439,9 @@ def test_updating_cluster_but_instance_does_not_exists_empty_project_id(self, op.execute(None) err = e.exception - self.assertEqual(str(err), "Dependency: instance '{}' does not exist.".format( - INSTANCE_ID)) + self.assertEqual(str(err), "Dependency: instance '{}' does not exist.".format(INSTANCE_ID)) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_cluster.assert_not_called() @@ -454,7 +449,8 @@ def test_updating_cluster_but_instance_does_not_exists_empty_project_id(self, def test_updating_cluster_that_does_not_exists(self, mock_hook): instance = mock_hook.return_value.get_instance.return_value = mock.Mock(Instance) mock_hook.return_value.update_cluster.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.NotFound("Cluster not found.")) + side_effect=google.api_core.exceptions.NotFound("Cluster not found.") + ) with self.assertRaises(AirflowException) as e: op = BigtableUpdateClusterOperator( @@ -471,21 +467,21 @@ def test_updating_cluster_that_does_not_exists(self, mock_hook): err = e.exception self.assertEqual( str(err), - "Dependency: cluster '{}' does not exist for instance '{}'.".format( - CLUSTER_ID, INSTANCE_ID) + "Dependency: cluster '{}' does not exist for instance '{}'.".format(CLUSTER_ID, INSTANCE_ID), ) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_cluster.assert_called_once_with( - instance=instance, cluster_id=CLUSTER_ID, nodes=NODES) + instance=instance, cluster_id=CLUSTER_ID, nodes=NODES + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') def test_updating_cluster_that_does_not_exists_empty_project_id(self, mock_hook): instance = mock_hook.return_value.get_instance.return_value = mock.Mock(Instance) 
mock_hook.return_value.update_cluster.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.NotFound("Cluster not found.")) + side_effect=google.api_core.exceptions.NotFound("Cluster not found.") + ) with self.assertRaises(AirflowException) as e: op = BigtableUpdateClusterOperator( @@ -501,15 +497,14 @@ def test_updating_cluster_that_does_not_exists_empty_project_id(self, mock_hook) err = e.exception self.assertEqual( str(err), - "Dependency: cluster '{}' does not exist for instance '{}'.".format( - CLUSTER_ID, INSTANCE_ID) + "Dependency: cluster '{}' does not exist for instance '{}'.".format(CLUSTER_ID, INSTANCE_ID), ) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_cluster.assert_called_once_with( - instance=instance, cluster_id=CLUSTER_ID, nodes=NODES) + instance=instance, cluster_id=CLUSTER_ID, nodes=NODES + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') def test_different_error_reraised(self, mock_hook): @@ -524,17 +519,18 @@ def test_different_error_reraised(self, mock_hook): ) instance = mock_hook.return_value.get_instance.return_value = mock.Mock(Instance) mock_hook.return_value.update_cluster.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.GoogleAPICallError('error')) + side_effect=google.api_core.exceptions.GoogleAPICallError('error') + ) with self.assertRaises(google.api_core.exceptions.GoogleAPICallError): op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_cluster.assert_called_once_with( - instance=instance, cluster_id=CLUSTER_ID, nodes=NODES) + instance=instance, cluster_id=CLUSTER_ID, nodes=NODES + ) class TestBigtableInstanceDelete(unittest.TestCase): @@ -549,12 +545,11 @@ def test_delete_execute(self, mock_hook): ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_instance.assert_called_once_with( - project_id=PROJECT_ID, - instance_id=INSTANCE_ID) + project_id=PROJECT_ID, instance_id=INSTANCE_ID + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') def test_delete_execute_empty_project_id(self, mock_hook): @@ -566,24 +561,20 @@ def test_delete_execute_empty_project_id(self, mock_hook): ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_instance.assert_called_once_with( - project_id=None, - instance_id=INSTANCE_ID) + project_id=None, instance_id=INSTANCE_ID + ) - @parameterized.expand([ - ('instance_id', PROJECT_ID, ''), - ], testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0]) + @parameterized.expand( + [('instance_id', PROJECT_ID, ''),], + testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0], + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') def test_empty_attribute(self, missing_attribute, project_id, instance_id, mock_hook): with self.assertRaises(AirflowException) as e: - BigtableDeleteInstanceOperator( - project_id=project_id, - 
instance_id=instance_id, - task_id="id" - ) + BigtableDeleteInstanceOperator(project_id=project_id, instance_id=instance_id, task_id="id") err = e.exception self.assertEqual(str(err), 'Empty parameter: {}'.format(missing_attribute)) mock_hook.assert_not_called() @@ -598,15 +589,15 @@ def test_deleting_instance_that_doesnt_exists(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_instance.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.NotFound("Instance not found.")) + side_effect=google.api_core.exceptions.NotFound("Instance not found.") + ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_instance.assert_called_once_with( - project_id=PROJECT_ID, - instance_id=INSTANCE_ID) + project_id=PROJECT_ID, instance_id=INSTANCE_ID + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') def test_deleting_instance_that_doesnt_exists_empty_project_id(self, mock_hook): @@ -617,15 +608,15 @@ def test_deleting_instance_that_doesnt_exists_empty_project_id(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_instance.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.NotFound("Instance not found.")) + side_effect=google.api_core.exceptions.NotFound("Instance not found.") + ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_instance.assert_called_once_with( - project_id=None, - instance_id=INSTANCE_ID) + project_id=None, instance_id=INSTANCE_ID + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') def test_different_error_reraised(self, mock_hook): @@ -637,18 +628,18 @@ def test_different_error_reraised(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_instance.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.GoogleAPICallError('error')) + side_effect=google.api_core.exceptions.GoogleAPICallError('error') + ) with self.assertRaises(google.api_core.exceptions.GoogleAPICallError): op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_instance.assert_called_once_with( - project_id=PROJECT_ID, - instance_id=INSTANCE_ID) + project_id=PROJECT_ID, instance_id=INSTANCE_ID + ) class TestBigtableTableDelete(unittest.TestCase): @@ -664,28 +655,25 @@ def test_delete_execute(self, mock_hook): ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_table.assert_called_once_with( - project_id=PROJECT_ID, - instance_id=INSTANCE_ID, - table_id=TABLE_ID) + project_id=PROJECT_ID, instance_id=INSTANCE_ID, table_id=TABLE_ID + ) - @parameterized.expand([ - ('instance_id', PROJECT_ID, '', TABLE_ID), - ('table_id', PROJECT_ID, INSTANCE_ID, ''), - ], testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0]) + @parameterized.expand( + [('instance_id', PROJECT_ID, '', TABLE_ID), ('table_id', PROJECT_ID, INSTANCE_ID, ''),], + 
testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0], + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') - def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, - mock_hook): + def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, mock_hook): with self.assertRaises(AirflowException) as e: BigtableDeleteTableOperator( project_id=project_id, instance_id=instance_id, table_id=table_id, task_id="id", - gcp_conn_id=GCP_CONN_ID + gcp_conn_id=GCP_CONN_ID, ) err = e.exception self.assertEqual(str(err), 'Empty parameter: {}'.format(missing_attribute)) @@ -703,16 +691,15 @@ def test_deleting_table_that_doesnt_exists(self, mock_hook): ) mock_hook.return_value.delete_table.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.NotFound("Table not found.")) + side_effect=google.api_core.exceptions.NotFound("Table not found.") + ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_table.assert_called_once_with( - project_id=PROJECT_ID, - instance_id=INSTANCE_ID, - table_id=TABLE_ID) + project_id=PROJECT_ID, instance_id=INSTANCE_ID, table_id=TABLE_ID + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') def test_deleting_table_that_doesnt_exists_empty_project_id(self, mock_hook): @@ -725,16 +712,15 @@ def test_deleting_table_that_doesnt_exists_empty_project_id(self, mock_hook): ) mock_hook.return_value.delete_table.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.NotFound("Table not found.")) + side_effect=google.api_core.exceptions.NotFound("Table not found.") + ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_table.assert_called_once_with( - project_id=None, - instance_id=INSTANCE_ID, - table_id=TABLE_ID) + project_id=None, instance_id=INSTANCE_ID, table_id=TABLE_ID + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') def test_deleting_table_when_instance_doesnt_exists(self, mock_hook): @@ -751,11 +737,9 @@ def test_deleting_table_when_instance_doesnt_exists(self, mock_hook): with self.assertRaises(AirflowException) as e: op.execute(None) err = e.exception - self.assertEqual(str(err), "Dependency: instance '{}' does not exist.".format( - INSTANCE_ID)) + self.assertEqual(str(err), "Dependency: instance '{}' does not exist.".format(INSTANCE_ID)) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_table.assert_not_called() @@ -770,19 +754,18 @@ def test_different_error_reraised(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_table.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.GoogleAPICallError('error')) + side_effect=google.api_core.exceptions.GoogleAPICallError('error') + ) with self.assertRaises(google.api_core.exceptions.GoogleAPICallError): op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) 
mock_hook.return_value.delete_table.assert_called_once_with( - project_id=PROJECT_ID, - instance_id=INSTANCE_ID, - table_id=TABLE_ID) + project_id=PROJECT_ID, instance_id=INSTANCE_ID, table_id=TABLE_ID + ) class TestBigtableTableCreate(unittest.TestCase): @@ -801,29 +784,28 @@ def test_create_execute(self, mock_hook): instance = mock_hook.return_value.get_instance.return_value = mock.Mock(Instance) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_table.assert_called_once_with( instance=instance, table_id=TABLE_ID, initial_split_keys=INITIAL_SPLIT_KEYS, - column_families=EMPTY_COLUMN_FAMILIES) + column_families=EMPTY_COLUMN_FAMILIES, + ) - @parameterized.expand([ - ('instance_id', PROJECT_ID, '', TABLE_ID), - ('table_id', PROJECT_ID, INSTANCE_ID, ''), - ], testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0]) + @parameterized.expand( + [('instance_id', PROJECT_ID, '', TABLE_ID), ('table_id', PROJECT_ID, INSTANCE_ID, ''),], + testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0], + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') - def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, - mock_hook): + def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, mock_hook): with self.assertRaises(AirflowException) as e: BigtableCreateTableOperator( project_id=project_id, instance_id=instance_id, table_id=table_id, task_id="id", - gcp_conn_id=GCP_CONN_ID + gcp_conn_id=GCP_CONN_ID, ) err = e.exception self.assertEqual(str(err), 'Empty parameter: {}'.format(missing_attribute)) @@ -847,12 +829,10 @@ def test_instance_not_exists(self, mock_hook): err = e.exception self.assertEqual( str(err), - "Dependency: instance '{}' does not exist in project '{}'.".format( - INSTANCE_ID, PROJECT_ID) + "Dependency: instance '{}' does not exist in project '{}'.".format(INSTANCE_ID, PROJECT_ID), ) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') @@ -868,22 +848,22 @@ def test_creating_table_that_exists(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) - mock_hook.return_value.get_column_families_for_table.return_value = \ - EMPTY_COLUMN_FAMILIES + mock_hook.return_value.get_column_families_for_table.return_value = EMPTY_COLUMN_FAMILIES instance = mock_hook.return_value.get_instance.return_value = mock.Mock(Instance) mock_hook.return_value.create_table.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.AlreadyExists("Table already exists.")) + side_effect=google.api_core.exceptions.AlreadyExists("Table already exists.") + ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_table.assert_called_once_with( instance=instance, table_id=TABLE_ID, initial_split_keys=INITIAL_SPLIT_KEYS, - column_families=EMPTY_COLUMN_FAMILIES) + column_families=EMPTY_COLUMN_FAMILIES, + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') def test_creating_table_that_exists_empty_project_id(self, mock_hook): @@ -897,26 
+877,25 @@ def test_creating_table_that_exists_empty_project_id(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) - mock_hook.return_value.get_column_families_for_table.return_value = \ - EMPTY_COLUMN_FAMILIES + mock_hook.return_value.get_column_families_for_table.return_value = EMPTY_COLUMN_FAMILIES instance = mock_hook.return_value.get_instance.return_value = mock.Mock(Instance) mock_hook.return_value.create_table.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.AlreadyExists("Table already exists.")) + side_effect=google.api_core.exceptions.AlreadyExists("Table already exists.") + ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_table.assert_called_once_with( instance=instance, table_id=TABLE_ID, initial_split_keys=INITIAL_SPLIT_KEYS, - column_families=EMPTY_COLUMN_FAMILIES) + column_families=EMPTY_COLUMN_FAMILIES, + ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') - def test_creating_table_that_exists_with_different_column_families_ids_in_the_table( - self, mock_hook): + def test_creating_table_that_exists_with_different_column_families_ids_in_the_table(self, mock_hook): op = BigtableCreateTableOperator( project_id=PROJECT_ID, instance_id=INSTANCE_ID, @@ -928,26 +907,23 @@ def test_creating_table_that_exists_with_different_column_families_ids_in_the_ta impersonation_chain=IMPERSONATION_CHAIN, ) - mock_hook.return_value.get_column_families_for_table.return_value = { - "existing_family": None} + mock_hook.return_value.get_column_families_for_table.return_value = {"existing_family": None} mock_hook.return_value.create_table.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.AlreadyExists("Table already exists.")) + side_effect=google.api_core.exceptions.AlreadyExists("Table already exists.") + ) with self.assertRaises(AirflowException) as e: op.execute(None) err = e.exception self.assertEqual( - str(err), - "Table '{}' already exists with different Column Families.".format(TABLE_ID) + str(err), "Table '{}' already exists with different Column Families.".format(TABLE_ID) ) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) @mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook') - def test_creating_table_that_exists_with_different_column_families_gc_rule_in__table( - self, mock_hook): + def test_creating_table_that_exists_with_different_column_families_gc_rule_in__table(self, mock_hook): op = BigtableCreateTableOperator( project_id=PROJECT_ID, instance_id=INSTANCE_ID, @@ -962,20 +938,17 @@ def test_creating_table_that_exists_with_different_column_families_gc_rule_in__t cf_mock = mock.Mock() cf_mock.gc_rule = mock.Mock(return_value=MaxVersionsGCRule(2)) - mock_hook.return_value.get_column_families_for_table.return_value = { - "cf-id": cf_mock - } + mock_hook.return_value.get_column_families_for_table.return_value = {"cf-id": cf_mock} mock_hook.return_value.create_table.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.AlreadyExists("Table already exists.")) + side_effect=google.api_core.exceptions.AlreadyExists("Table already exists.") + ) with self.assertRaises(AirflowException) as e: op.execute(None) err = e.exception self.assertEqual( - str(err), - "Table '{}' already exists with 
different Column Families.".format(TABLE_ID) + str(err), "Table '{}' already exists with different Column Families.".format(TABLE_ID) ) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) diff --git a/tests/providers/google/cloud/operators/test_bigtable_system.py b/tests/providers/google/cloud/operators/test_bigtable_system.py index 9196dc252d7fe..b987731e25393 100644 --- a/tests/providers/google/cloud/operators/test_bigtable_system.py +++ b/tests/providers/google/cloud/operators/test_bigtable_system.py @@ -29,16 +29,24 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_BIGTABLE_KEY) class BigTableExampleDagsSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_BIGTABLE_KEY) def test_run_example_dag_gcs_bigtable(self): self.run_dag('example_gcp_bigtable_operators', CLOUD_DAG_FOLDER) @provide_gcp_context(GCP_BIGTABLE_KEY) def tearDown(self): - self.execute_with_ctx([ - 'gcloud', 'bigtable', '--project', GCP_PROJECT_ID, - '--quiet', '--verbosity=none', - 'instances', 'delete', CBT_INSTANCE - ], key=GCP_BIGTABLE_KEY) + self.execute_with_ctx( + [ + 'gcloud', + 'bigtable', + '--project', + GCP_PROJECT_ID, + '--quiet', + '--verbosity=none', + 'instances', + 'delete', + CBT_INSTANCE, + ], + key=GCP_BIGTABLE_KEY, + ) super().tearDown() diff --git a/tests/providers/google/cloud/operators/test_cloud_build.py b/tests/providers/google/cloud/operators/test_cloud_build.py index b00cf243be853..4dbd285144f80 100644 --- a/tests/providers/google/cloud/operators/test_cloud_build.py +++ b/tests/providers/google/cloud/operators/test_cloud_build.py @@ -164,9 +164,7 @@ def test_storage_source_replace(self, hook_mock): } hook_mock.create_build(body=expected_result, project_id=TEST_PROJECT_ID) - @mock.patch( - "airflow.providers.google.cloud.operators.cloud_build.CloudBuildHook", - ) + @mock.patch("airflow.providers.google.cloud.operators.cloud_build.CloudBuildHook",) def test_repo_source_replace(self, hook_mock): hook_mock.return_value.create_build.return_value = TEST_CREATE_BODY current_body = { @@ -212,25 +210,20 @@ def test_repo_source_replace(self, hook_mock): def test_load_templated_yaml(self): dag = DAG(dag_id='example_cloudbuild_operator', start_date=TEST_DEFAULT_DATE) with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w+t') as build: - build.writelines(""" + build.writelines( + """ steps: - name: 'ubuntu' args: ['echo', 'Hello {{ params.name }}!'] - """) + """ + ) build.seek(0) body_path = build.name operator = CloudBuildCreateBuildOperator( - body=body_path, - task_id="task-id", dag=dag, - params={'name': 'airflow'} + body=body_path, task_id="task-id", dag=dag, params={'name': 'airflow'} ) operator.prepare_template() ti = TaskInstance(operator, TEST_DEFAULT_DATE) ti.render_templates() - expected_body = {'steps': [ - {'name': 'ubuntu', - 'args': ['echo', 'Hello airflow!'] - } - ] - } + expected_body = {'steps': [{'name': 'ubuntu', 'args': ['echo', 'Hello airflow!']}]} self.assertEqual(expected_body, operator.body) diff --git a/tests/providers/google/cloud/operators/test_cloud_build_system.py b/tests/providers/google/cloud/operators/test_cloud_build_system.py index 951ef967f90d2..55e9edabdb03d 100644 --- a/tests/providers/google/cloud/operators/test_cloud_build_system.py +++ b/tests/providers/google/cloud/operators/test_cloud_build_system.py @@ -31,6 +31,7 @@ class CloudBuildExampleDagsSystemTest(GoogleSystemTest): It use a real service. 
""" + helper = GCPCloudBuildTestHelper() @provide_gcp_context(GCP_CLOUD_BUILD_KEY, project_id=GoogleSystemTest._project_id()) diff --git a/tests/providers/google/cloud/operators/test_cloud_memorystore.py b/tests/providers/google/cloud/operators/test_cloud_memorystore.py index a6ea94ea8a0f4..4db32b68e75a6 100644 --- a/tests/providers/google/cloud/operators/test_cloud_memorystore.py +++ b/tests/providers/google/cloud/operators/test_cloud_memorystore.py @@ -23,11 +23,16 @@ from google.cloud.redis_v1.types import Instance from airflow.providers.google.cloud.operators.cloud_memorystore import ( - CloudMemorystoreCreateInstanceAndImportOperator, CloudMemorystoreCreateInstanceOperator, - CloudMemorystoreDeleteInstanceOperator, CloudMemorystoreExportInstanceOperator, - CloudMemorystoreFailoverInstanceOperator, CloudMemorystoreGetInstanceOperator, - CloudMemorystoreImportOperator, CloudMemorystoreListInstancesOperator, - CloudMemorystoreScaleInstanceOperator, CloudMemorystoreUpdateInstanceOperator, + CloudMemorystoreCreateInstanceAndImportOperator, + CloudMemorystoreCreateInstanceOperator, + CloudMemorystoreDeleteInstanceOperator, + CloudMemorystoreExportInstanceOperator, + CloudMemorystoreFailoverInstanceOperator, + CloudMemorystoreGetInstanceOperator, + CloudMemorystoreImportOperator, + CloudMemorystoreListInstancesOperator, + CloudMemorystoreScaleInstanceOperator, + CloudMemorystoreUpdateInstanceOperator, ) TEST_GCP_CONN_ID = "test-gcp-conn-id" @@ -68,8 +73,7 @@ def test_assert_valid_hook_call(self, mock_hook): ) task.execute(mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.create_instance.assert_called_once_with( location=TEST_LOCATION, @@ -98,8 +102,7 @@ def test_assert_valid_hook_call(self, mock_hook): ) task.execute(mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_instance.assert_called_once_with( location=TEST_LOCATION, @@ -128,8 +131,7 @@ def test_assert_valid_hook_call(self, mock_hook): ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.export_instance.assert_called_once_with( location=TEST_LOCATION, @@ -159,8 +161,7 @@ def test_assert_valid_hook_call(self, mock_hook): ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.failover_instance.assert_called_once_with( location=TEST_LOCATION, @@ -189,8 +190,7 @@ def test_assert_valid_hook_call(self, mock_hook): ) task.execute(mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.get_instance.assert_called_once_with( location=TEST_LOCATION, @@ -219,8 +219,7 @@ def test_assert_valid_hook_call(self, mock_hook): ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - 
gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.import_instance.assert_called_once_with( location=TEST_LOCATION, @@ -249,8 +248,7 @@ def test_assert_valid_hook_call(self, mock_hook): ) task.execute(mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.list_instances.assert_called_once_with( location=TEST_LOCATION, @@ -280,8 +278,7 @@ def test_assert_valid_hook_call(self, mock_hook): ) task.execute(mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.update_instance.assert_called_once_with( update_mask=TEST_UPDATE_MASK, @@ -312,8 +309,7 @@ def test_assert_valid_hook_call(self, mock_hook): ) task.execute(mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.update_instance.assert_called_once_with( update_mask={"paths": ["memory_size_gb"]}, @@ -346,10 +342,7 @@ def test_assert_valid_hook_call(self, mock_hook): task.execute(mock.MagicMock()) mock_hook.assert_has_calls( [ - mock.call( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, - ), + mock.call(gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN,), mock.call().create_instance( location=TEST_LOCATION, instance_id=TEST_INSTANCE_ID, diff --git a/tests/providers/google/cloud/operators/test_cloud_sql.py b/tests/providers/google/cloud/operators/test_cloud_sql.py index 4809877884030..3211585445b76 100644 --- a/tests/providers/google/cloud/operators/test_cloud_sql.py +++ b/tests/providers/google/cloud/operators/test_cloud_sql.py @@ -27,9 +27,14 @@ from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.google.cloud.operators.cloud_sql import ( - CloudSQLCreateInstanceDatabaseOperator, CloudSQLCreateInstanceOperator, - CloudSQLDeleteInstanceDatabaseOperator, CloudSQLDeleteInstanceOperator, CloudSQLExecuteQueryOperator, - CloudSQLExportInstanceOperator, CloudSQLImportInstanceOperator, CloudSQLInstancePatchOperator, + CloudSQLCreateInstanceDatabaseOperator, + CloudSQLCreateInstanceOperator, + CloudSQLDeleteInstanceDatabaseOperator, + CloudSQLDeleteInstanceOperator, + CloudSQLExecuteQueryOperator, + CloudSQLExportInstanceOperator, + CloudSQLImportInstanceOperator, + CloudSQLInstancePatchOperator, CloudSQLPatchInstanceDatabaseOperator, ) @@ -45,7 +50,7 @@ "binaryLogEnabled": True, "enabled": True, "replicationLogArchivingEnabled": True, - "startTime": "05:00" + "startTime": "05:00", }, "activationPolicy": "ALWAYS", "authorizedGaeApplications": [], @@ -59,33 +64,25 @@ { "value": "192.168.100.0/24", "name": "network1", - "expirationTime": "2012-11-15T16:19:00.094Z" + "expirationTime": "2012-11-15T16:19:00.094Z", }, ], "privateNetwork": "/vpc/resource/link", - "requireSsl": True + "requireSsl": True, }, "locationPreference": { "zone": "europe-west4-a", - "followGaeApplication": "/app/engine/application/to/follow" - }, - "maintenanceWindow": { - "hour": 5, - "day": 7, - "updateTrack": 
"canary" + "followGaeApplication": "/app/engine/application/to/follow", }, + "maintenanceWindow": {"hour": 5, "day": 7, "updateTrack": "canary"}, "pricingPlan": "PER_USE", "replicationType": "ASYNCHRONOUS", "storageAutoResize": False, "storageAutoResizeLimit": 0, - "userLabels": { - "my-key": "my-value" - } + "userLabels": {"my-key": "my-value"}, }, "databaseVersion": "MYSQL_5_7", - "failoverReplica": { - "name": "replica-1" - }, + "failoverReplica": {"name": "replica-1"}, "masterInstanceName": "master-instance-1", "onPremisesConfiguration": {}, "region": "europe-west4", @@ -100,47 +97,32 @@ "password": "secret_pass", "sslCipher": "list-of-ciphers", "username": "user", - "verifyServerCertificate": True + "verifyServerCertificate": True, }, - } + }, } PATCH_BODY = { "name": INSTANCE_NAME, - "settings": { - "tier": "db-n1-standard-2", - "dataDiskType": "PD_HDD" - }, - "region": "europe-west4" + "settings": {"tier": "db-n1-standard-2", "dataDiskType": "PD_HDD"}, + "region": "europe-west4", } DATABASE_INSERT_BODY = { - "name": DB_NAME, # The name of the database in the Cloud SQL instance. - # This does not include the project ID or instance name. - - "project": PROJECT_ID, # The project ID of the project containing the Cloud SQL - # database. The Google apps domain is prefixed if - # applicable. - + "name": DB_NAME, # The name of the database in the Cloud SQL instance. + # This does not include the project ID or instance name. + "project": PROJECT_ID, # The project ID of the project containing the Cloud SQL + # database. The Google apps domain is prefixed if + # applicable. "instance": INSTANCE_NAME, # The name of the Cloud SQL instance. - # This does not include the project ID. -} -DATABASE_PATCH_BODY = { - "charset": "utf16", - "collation": "utf16_general_ci" + # This does not include the project ID. 
} +DATABASE_PATCH_BODY = {"charset": "utf16", "collation": "utf16_general_ci"} EXPORT_BODY = { "exportContext": { "fileType": "CSV", "uri": "gs://bucketName/fileName", "databases": [], - "sqlExportOptions": { - "tables": [ - "table1", "table2" - ], - "schemaOnly": False - }, - "csvExportOptions": { - "selectQuery": "SELECT * FROM TABLE" - } + "sqlExportOptions": {"tables": ["table1", "table2"], "schemaOnly": False}, + "csvExportOptions": {"selectQuery": "SELECT * FROM TABLE"}, } } IMPORT_BODY = { @@ -149,12 +131,7 @@ "uri": "gs://bucketName/fileName", "database": "db1", "importUser": "", - "csvImportOptions": { - "table": "my_table", - "columns": [ - "col1", "col2" - ] - } + "csvImportOptions": {"table": "my_table", "columns": ["col1", "col2"]}, } } @@ -169,64 +146,54 @@ def test_instance_create(self, mock_hook, _check_if_instance_exists): _check_if_instance_exists.return_value = False mock_hook.return_value.create_instance.return_value = True op = CloudSQLCreateInstanceOperator( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - body=CREATE_BODY, - task_id="id" - ) - result = op.execute(context={ # pylint: disable=assignment-from-no-return - 'task_instance': mock.Mock() - }) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + project_id=PROJECT_ID, instance=INSTANCE_NAME, body=CREATE_BODY, task_id="id" + ) + result = op.execute( + context={'task_instance': mock.Mock()} # pylint: disable=assignment-from-no-return + ) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.create_instance.assert_called_once_with( - project_id=PROJECT_ID, - body=CREATE_BODY + project_id=PROJECT_ID, body=CREATE_BODY ) self.assertIsNone(result) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLCreateInstanceOperator._check_if_instance_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLCreateInstanceOperator._check_if_instance_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_create_missing_project_id(self, mock_hook, _check_if_instance_exists): _check_if_instance_exists.return_value = False mock_hook.return_value.create_instance.return_value = True - op = CloudSQLCreateInstanceOperator( - instance=INSTANCE_NAME, - body=CREATE_BODY, - task_id="id" - ) - result = op.execute(context={ # pylint: disable=assignment-from-no-return - 'task_instance': mock.Mock() - }) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) - mock_hook.return_value.create_instance.assert_called_once_with( - project_id=None, - body=CREATE_BODY + op = CloudSQLCreateInstanceOperator(instance=INSTANCE_NAME, body=CREATE_BODY, task_id="id") + result = op.execute( + context={'task_instance': mock.Mock()} # pylint: disable=assignment-from-no-return + ) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, ) + mock_hook.return_value.create_instance.assert_called_once_with(project_id=None, body=CREATE_BODY) self.assertIsNone(result) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLCreateInstanceOperator._check_if_instance_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLCreateInstanceOperator._check_if_instance_exists" + ) 
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_create_idempotent(self, mock_hook, _check_if_instance_exists): _check_if_instance_exists.return_value = True mock_hook.return_value.create_instance.return_value = True op = CloudSQLCreateInstanceOperator( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - body=CREATE_BODY, - task_id="id" - ) - result = op.execute(context={ # pylint: disable=assignment-from-no-return - 'task_instance': mock.Mock() - }) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + project_id=PROJECT_ID, instance=INSTANCE_NAME, body=CREATE_BODY, task_id="id" + ) + result = op.execute( + context={'task_instance': mock.Mock()} # pylint: disable=assignment-from-no-return + ) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.create_instance.assert_not_called() self.assertIsNone(result) @@ -234,10 +201,7 @@ def test_instance_create_idempotent(self, mock_hook, _check_if_instance_exists): def test_create_should_throw_ex_when_empty_project_id(self, mock_hook): with self.assertRaises(AirflowException) as cm: op = CloudSQLCreateInstanceOperator( - project_id="", - body=CREATE_BODY, - instance=INSTANCE_NAME, - task_id="id" + project_id="", body=CREATE_BODY, instance=INSTANCE_NAME, task_id="id" ) op.execute(None) err = cm.exception @@ -248,10 +212,7 @@ def test_create_should_throw_ex_when_empty_project_id(self, mock_hook): def test_create_should_throw_ex_when_empty_body(self, mock_hook): with self.assertRaises(AirflowException) as cm: op = CloudSQLCreateInstanceOperator( - project_id=PROJECT_ID, - body={}, - instance=INSTANCE_NAME, - task_id="id" + project_id=PROJECT_ID, body={}, instance=INSTANCE_NAME, task_id="id" ) op.execute(None) err = cm.exception @@ -262,10 +223,7 @@ def test_create_should_throw_ex_when_empty_body(self, mock_hook): def test_create_should_throw_ex_when_empty_instance(self, mock_hook): with self.assertRaises(AirflowException) as cm: op = CloudSQLCreateInstanceOperator( - project_id=PROJECT_ID, - body=CREATE_BODY, - instance="", - task_id="id" + project_id=PROJECT_ID, body=CREATE_BODY, instance="", task_id="id" ) op.execute(None) err = cm.exception @@ -281,23 +239,23 @@ def test_create_should_validate_list_type(self, mock_hook): "ipConfiguration": { "authorizedNetworks": {} # Should be a list, not a dict. # Testing if the validation catches this. 
- } - } + }, + }, } with self.assertRaises(AirflowException) as cm: op = CloudSQLCreateInstanceOperator( - project_id=PROJECT_ID, - body=wrong_list_type_body, - instance=INSTANCE_NAME, - task_id="id" + project_id=PROJECT_ID, body=wrong_list_type_body, instance=INSTANCE_NAME, task_id="id" ) op.execute(None) err = cm.exception - self.assertIn("The field 'settings.ipConfiguration.authorizedNetworks' " - "should be of list type according to the specification", str(err)) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + self.assertIn( + "The field 'settings.ipConfiguration.authorizedNetworks' " + "should be of list type according to the specification", + str(err), + ) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_create_should_validate_non_empty_fields(self, mock_hook): @@ -306,206 +264,176 @@ def test_create_should_validate_non_empty_fields(self, mock_hook): "settings": { "tier": "", # Field can't be empty (defined in CLOUD_SQL_VALIDATION). # Testing if the validation catches this. - } + }, } with self.assertRaises(AirflowException) as cm: op = CloudSQLCreateInstanceOperator( - project_id=PROJECT_ID, - body=empty_tier_body, - instance=INSTANCE_NAME, - task_id="id" + project_id=PROJECT_ID, body=empty_tier_body, instance=INSTANCE_NAME, task_id="id" ) op.execute(None) err = cm.exception - self.assertIn("The body field 'settings.tier' can't be empty. " - "Please provide a value.", str(err)) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + self.assertIn("The body field 'settings.tier' can't be empty. 
" "Please provide a value.", str(err)) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_patch(self, mock_hook): mock_hook.return_value.patch_instance.return_value = True op = CloudSQLInstancePatchOperator( - project_id=PROJECT_ID, - body=PATCH_BODY, - instance=INSTANCE_NAME, - task_id="id" + project_id=PROJECT_ID, body=PATCH_BODY, instance=INSTANCE_NAME, task_id="id" ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.patch_instance.assert_called_once_with( - project_id=PROJECT_ID, - body=PATCH_BODY, - instance=INSTANCE_NAME + project_id=PROJECT_ID, body=PATCH_BODY, instance=INSTANCE_NAME ) self.assertTrue(result) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_patch_missing_project_id(self, mock_hook): mock_hook.return_value.patch_instance.return_value = True - op = CloudSQLInstancePatchOperator( - body=PATCH_BODY, - instance=INSTANCE_NAME, - task_id="id" - ) + op = CloudSQLInstancePatchOperator(body=PATCH_BODY, instance=INSTANCE_NAME, task_id="id") result = op.execute(None) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.patch_instance.assert_called_once_with( - project_id=None, - body=PATCH_BODY, - instance=INSTANCE_NAME + project_id=None, body=PATCH_BODY, instance=INSTANCE_NAME ) self.assertTrue(result) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLInstancePatchOperator._check_if_instance_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLInstancePatchOperator._check_if_instance_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") - def test_instance_patch_should_bubble_up_ex_if_not_exists(self, mock_hook, - _check_if_instance_exists): + def test_instance_patch_should_bubble_up_ex_if_not_exists(self, mock_hook, _check_if_instance_exists): _check_if_instance_exists.return_value = False with self.assertRaises(AirflowException) as cm: op = CloudSQLInstancePatchOperator( - project_id=PROJECT_ID, - body=PATCH_BODY, - instance=INSTANCE_NAME, - task_id="id" + project_id=PROJECT_ID, body=PATCH_BODY, instance=INSTANCE_NAME, task_id="id" ) op.execute(None) err = cm.exception self.assertIn('specify another instance to patch', str(err)) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.patch_instance.assert_not_called() - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLDeleteInstanceOperator._check_if_instance_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLDeleteInstanceOperator._check_if_instance_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def 
test_instance_delete(self, mock_hook, _check_if_instance_exists): _check_if_instance_exists.return_value = True - op = CloudSQLDeleteInstanceOperator( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - task_id="id" - ) + op = CloudSQLDeleteInstanceOperator(project_id=PROJECT_ID, instance=INSTANCE_NAME, task_id="id") result = op.execute(None) self.assertTrue(result) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.delete_instance.assert_called_once_with( - project_id=PROJECT_ID, - instance=INSTANCE_NAME + project_id=PROJECT_ID, instance=INSTANCE_NAME ) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLDeleteInstanceOperator._check_if_instance_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLDeleteInstanceOperator._check_if_instance_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_delete_missing_project_id(self, mock_hook, _check_if_instance_exists): _check_if_instance_exists.return_value = True - op = CloudSQLDeleteInstanceOperator( - instance=INSTANCE_NAME, - task_id="id" - ) + op = CloudSQLDeleteInstanceOperator(instance=INSTANCE_NAME, task_id="id") result = op.execute(None) self.assertTrue(result) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.delete_instance.assert_called_once_with( - project_id=None, - instance=INSTANCE_NAME + project_id=None, instance=INSTANCE_NAME ) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLDeleteInstanceOperator._check_if_instance_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLDeleteInstanceOperator._check_if_instance_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_delete_should_abort_and_succeed_if_not_exists( - self, - mock_hook, - _check_if_instance_exists): + self, mock_hook, _check_if_instance_exists + ): _check_if_instance_exists.return_value = False - op = CloudSQLDeleteInstanceOperator( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - task_id="id" - ) + op = CloudSQLDeleteInstanceOperator(project_id=PROJECT_ID, instance=INSTANCE_NAME, task_id="id") result = op.execute(None) self.assertTrue(result) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.delete_instance.assert_not_called() - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLCreateInstanceDatabaseOperator._check_if_db_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLCreateInstanceDatabaseOperator._check_if_db_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_db_create(self, mock_hook, _check_if_db_exists): _check_if_db_exists.return_value = False op = CloudSQLCreateInstanceDatabaseOperator( - project_id=PROJECT_ID, - 
instance=INSTANCE_NAME, - body=DATABASE_INSERT_BODY, - task_id="id" + project_id=PROJECT_ID, instance=INSTANCE_NAME, body=DATABASE_INSERT_BODY, task_id="id" ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.create_database.assert_called_once_with( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - body=DATABASE_INSERT_BODY + project_id=PROJECT_ID, instance=INSTANCE_NAME, body=DATABASE_INSERT_BODY ) self.assertTrue(result) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLCreateInstanceDatabaseOperator._check_if_db_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLCreateInstanceDatabaseOperator._check_if_db_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_db_create_missing_project_id(self, mock_hook, _check_if_db_exists): _check_if_db_exists.return_value = False op = CloudSQLCreateInstanceDatabaseOperator( - instance=INSTANCE_NAME, - body=DATABASE_INSERT_BODY, - task_id="id" + instance=INSTANCE_NAME, body=DATABASE_INSERT_BODY, task_id="id" ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.create_database.assert_called_once_with( - project_id=None, - instance=INSTANCE_NAME, - body=DATABASE_INSERT_BODY + project_id=None, instance=INSTANCE_NAME, body=DATABASE_INSERT_BODY ) self.assertTrue(result) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLCreateInstanceDatabaseOperator._check_if_db_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLCreateInstanceDatabaseOperator._check_if_db_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") - def test_instance_db_create_should_abort_and_succeed_if_exists( - self, mock_hook, _check_if_db_exists): + def test_instance_db_create_should_abort_and_succeed_if_exists(self, mock_hook, _check_if_db_exists): _check_if_db_exists.return_value = True op = CloudSQLCreateInstanceDatabaseOperator( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - body=DATABASE_INSERT_BODY, - task_id="id" + project_id=PROJECT_ID, instance=INSTANCE_NAME, body=DATABASE_INSERT_BODY, task_id="id" ) result = op.execute(None) self.assertTrue(result) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.create_database.assert_not_called() - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLPatchInstanceDatabaseOperator._check_if_db_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLPatchInstanceDatabaseOperator._check_if_db_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_db_patch(self, mock_hook, _check_if_db_exists): _check_if_db_exists.return_value = True @@ -514,48 +442,42 @@ def 
test_instance_db_patch(self, mock_hook, _check_if_db_exists): instance=INSTANCE_NAME, database=DB_NAME, body=DATABASE_PATCH_BODY, - task_id="id" + task_id="id", ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.patch_database.assert_called_once_with( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - database=DB_NAME, - body=DATABASE_PATCH_BODY + project_id=PROJECT_ID, instance=INSTANCE_NAME, database=DB_NAME, body=DATABASE_PATCH_BODY ) self.assertTrue(result) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLPatchInstanceDatabaseOperator._check_if_db_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLPatchInstanceDatabaseOperator._check_if_db_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_db_patch_missing_project_id(self, mock_hook, _check_if_db_exists): _check_if_db_exists.return_value = True op = CloudSQLPatchInstanceDatabaseOperator( - instance=INSTANCE_NAME, - database=DB_NAME, - body=DATABASE_PATCH_BODY, - task_id="id" + instance=INSTANCE_NAME, database=DB_NAME, body=DATABASE_PATCH_BODY, task_id="id" ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.patch_database.assert_called_once_with( - project_id=None, - instance=INSTANCE_NAME, - database=DB_NAME, - body=DATABASE_PATCH_BODY + project_id=None, instance=INSTANCE_NAME, database=DB_NAME, body=DATABASE_PATCH_BODY ) self.assertTrue(result) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLPatchInstanceDatabaseOperator._check_if_db_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLPatchInstanceDatabaseOperator._check_if_db_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") - def test_instance_db_patch_should_throw_ex_if_not_exists( - self, mock_hook, _check_if_db_exists): + def test_instance_db_patch_should_throw_ex_if_not_exists(self, mock_hook, _check_if_db_exists): _check_if_db_exists.return_value = False with self.assertRaises(AirflowException) as cm: op = CloudSQLPatchInstanceDatabaseOperator( @@ -563,15 +485,15 @@ def test_instance_db_patch_should_throw_ex_if_not_exists( instance=INSTANCE_NAME, database=DB_NAME, body=DATABASE_PATCH_BODY, - task_id="id" + task_id="id", ) op.execute(None) err = cm.exception self.assertIn("Cloud SQL instance with ID", str(err)) self.assertIn("does not contain database", str(err)) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.patch_database.assert_not_called() @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") @@ -582,7 +504,7 @@ def test_instance_db_patch_should_throw_ex_when_empty_database(self, mock_hook): instance=INSTANCE_NAME, database="", body=DATABASE_INSERT_BODY, - task_id="id" + task_id="id", ) 
op.execute(None) err = cm.exception @@ -590,104 +512,84 @@ def test_instance_db_patch_should_throw_ex_when_empty_database(self, mock_hook): mock_hook.assert_not_called() mock_hook.return_value.patch_database.assert_not_called() - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLDeleteInstanceDatabaseOperator._check_if_db_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLDeleteInstanceDatabaseOperator._check_if_db_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_db_delete(self, mock_hook, _check_if_db_exists): _check_if_db_exists.return_value = True op = CloudSQLDeleteInstanceDatabaseOperator( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - database=DB_NAME, - task_id="id" + project_id=PROJECT_ID, instance=INSTANCE_NAME, database=DB_NAME, task_id="id" ) result = op.execute(None) self.assertTrue(result) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.delete_database.assert_called_once_with( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - database=DB_NAME + project_id=PROJECT_ID, instance=INSTANCE_NAME, database=DB_NAME ) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLDeleteInstanceDatabaseOperator._check_if_db_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLDeleteInstanceDatabaseOperator._check_if_db_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_db_delete_missing_project_id(self, mock_hook, _check_if_db_exists): _check_if_db_exists.return_value = True - op = CloudSQLDeleteInstanceDatabaseOperator( - instance=INSTANCE_NAME, - database=DB_NAME, - task_id="id" - ) + op = CloudSQLDeleteInstanceDatabaseOperator(instance=INSTANCE_NAME, database=DB_NAME, task_id="id") result = op.execute(None) self.assertTrue(result) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.delete_database.assert_called_once_with( - project_id=None, - instance=INSTANCE_NAME, - database=DB_NAME + project_id=None, instance=INSTANCE_NAME, database=DB_NAME ) - @mock.patch("airflow.providers.google.cloud.operators.cloud_sql" - ".CloudSQLDeleteInstanceDatabaseOperator._check_if_db_exists") + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_sql" + ".CloudSQLDeleteInstanceDatabaseOperator._check_if_db_exists" + ) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") - def test_instance_db_delete_should_abort_and_succeed_if_not_exists( - self, mock_hook, _check_if_db_exists): + def test_instance_db_delete_should_abort_and_succeed_if_not_exists(self, mock_hook, _check_if_db_exists): _check_if_db_exists.return_value = False op = CloudSQLDeleteInstanceDatabaseOperator( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - database=DB_NAME, - task_id="id" + project_id=PROJECT_ID, instance=INSTANCE_NAME, database=DB_NAME, task_id="id" ) result = op.execute(None) self.assertTrue(result) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", 
- impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.delete_database.assert_not_called() @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_export(self, mock_hook): mock_hook.return_value.export_instance.return_value = True op = CloudSQLExportInstanceOperator( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - body=EXPORT_BODY, - task_id="id" + project_id=PROJECT_ID, instance=INSTANCE_NAME, body=EXPORT_BODY, task_id="id" ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.export_instance.assert_called_once_with( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - body=EXPORT_BODY + project_id=PROJECT_ID, instance=INSTANCE_NAME, body=EXPORT_BODY ) self.assertTrue(result) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_export_missing_project_id(self, mock_hook): mock_hook.return_value.export_instance.return_value = True - op = CloudSQLExportInstanceOperator( - instance=INSTANCE_NAME, - body=EXPORT_BODY, - task_id="id" - ) + op = CloudSQLExportInstanceOperator(instance=INSTANCE_NAME, body=EXPORT_BODY, task_id="id") result = op.execute(None) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.export_instance.assert_called_once_with( - project_id=None, - instance=INSTANCE_NAME, - body=EXPORT_BODY + project_id=None, instance=INSTANCE_NAME, body=EXPORT_BODY ) self.assertTrue(result) @@ -695,44 +597,32 @@ def test_instance_export_missing_project_id(self, mock_hook): def test_instance_import(self, mock_hook): mock_hook.return_value.export_instance.return_value = True op = CloudSQLImportInstanceOperator( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - body=IMPORT_BODY, - task_id="id" + project_id=PROJECT_ID, instance=INSTANCE_NAME, body=IMPORT_BODY, task_id="id" ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", impersonation_chain=None, + ) mock_hook.return_value.import_instance.assert_called_once_with( - project_id=PROJECT_ID, - instance=INSTANCE_NAME, - body=IMPORT_BODY + project_id=PROJECT_ID, instance=INSTANCE_NAME, body=IMPORT_BODY ) self.assertTrue(result) @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_import_missing_project_id(self, mock_hook): mock_hook.return_value.export_instance.return_value = True - op = CloudSQLImportInstanceOperator( - instance=INSTANCE_NAME, - body=IMPORT_BODY, - task_id="id" - ) + op = CloudSQLImportInstanceOperator(instance=INSTANCE_NAME, body=IMPORT_BODY, task_id="id") result = op.execute(None) - mock_hook.assert_called_once_with(api_version="v1beta4", - gcp_conn_id="google_cloud_default", - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version="v1beta4", gcp_conn_id="google_cloud_default", 
impersonation_chain=None, + ) mock_hook.return_value.import_instance.assert_called_once_with( - project_id=None, - instance=INSTANCE_NAME, - body=IMPORT_BODY + project_id=None, instance=INSTANCE_NAME, body=IMPORT_BODY ) self.assertTrue(result) class TestCloudSqlQueryValidation(unittest.TestCase): - @staticmethod def _setup_connections(get_connections, uri): gcp_connection = mock.MagicMock() @@ -740,72 +630,108 @@ def _setup_connections(get_connections, uri): gcp_connection.extra_dejson.get.return_value = 'empty_project' cloudsql_connection = Connection(uri=uri) cloudsql_connection2 = Connection(uri=uri) - get_connections.side_effect = [[gcp_connection], [cloudsql_connection], - [cloudsql_connection2]] - - @parameterized.expand([ - ('project_id', '', 'instance_name', 'mysql', False, False, - 'SELECT * FROM TEST', - "The required extra 'location' is empty"), - ('project_id', 'location', '', 'postgres', False, False, - 'SELECT * FROM TEST', - "The required extra 'instance' is empty"), - ('project_id', 'location', 'instance_name', 'wrong', False, False, - 'SELECT * FROM TEST', - "Invalid database type 'wrong'. Must be one of ['postgres', 'mysql']"), - ('project_id', 'location', 'instance_name', 'postgres', True, True, - 'SELECT * FROM TEST', - "Cloud SQL Proxy does not support SSL connections. SSL is not needed as" - " Cloud SQL Proxy provides encryption on its own"), - ('project_id', 'location', 'instance_name', 'postgres', False, True, - 'SELECT * FROM TEST', - "SSL connections requires sslcert to be set"), - ]) + get_connections.side_effect = [[gcp_connection], [cloudsql_connection], [cloudsql_connection2]] + + @parameterized.expand( + [ + ( + 'project_id', + '', + 'instance_name', + 'mysql', + False, + False, + 'SELECT * FROM TEST', + "The required extra 'location' is empty", + ), + ( + 'project_id', + 'location', + '', + 'postgres', + False, + False, + 'SELECT * FROM TEST', + "The required extra 'instance' is empty", + ), + ( + 'project_id', + 'location', + 'instance_name', + 'wrong', + False, + False, + 'SELECT * FROM TEST', + "Invalid database type 'wrong'. Must be one of ['postgres', 'mysql']", + ), + ( + 'project_id', + 'location', + 'instance_name', + 'postgres', + True, + True, + 'SELECT * FROM TEST', + "Cloud SQL Proxy does not support SSL connections. SSL is not needed as" + " Cloud SQL Proxy provides encryption on its own", + ), + ( + 'project_id', + 'location', + 'instance_name', + 'postgres', + False, + True, + 'SELECT * FROM TEST', + "SSL connections requires sslcert to be set", + ), + ] + ) @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") - def test_create_operator_with_wrong_parameters(self, - project_id, - location, - instance_name, - database_type, - use_proxy, - use_ssl, - sql, - message, - get_connections): - uri = \ - "gcpcloudsql://user:password@127.0.0.1:3200/testdb?" \ - "database_type={database_type}&" \ - "project_id={project_id}&location={location}&instance={instance_name}&" \ + def test_create_operator_with_wrong_parameters( + self, + project_id, + location, + instance_name, + database_type, + use_proxy, + use_ssl, + sql, + message, + get_connections, + ): + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?" 
+ "database_type={database_type}&" + "project_id={project_id}&location={location}&instance={instance_name}&" "use_proxy={use_proxy}&use_ssl={use_ssl}".format( database_type=database_type, project_id=project_id, location=location, instance_name=instance_name, use_proxy=use_proxy, - use_ssl=use_ssl) + use_ssl=use_ssl, + ) + ) self._setup_connections(get_connections, uri) with self.assertRaises(AirflowException) as cm: - op = CloudSQLExecuteQueryOperator( - sql=sql, - task_id='task_id' - ) + op = CloudSQLExecuteQueryOperator(sql=sql, task_id='task_id') op.execute(None) err = cm.exception self.assertIn(message, str(err)) @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections") def test_create_operator_with_too_long_unix_socket_path(self, get_connections): - uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \ - "project_id=example-project&location=europe-west1&" \ - "instance=" \ - "test_db_with_long_name_a_bit_above" \ - "_the_limit_of_UNIX_socket_asdadadasadasd&" \ - "use_proxy=True&sql_proxy_use_tcp=False" - self._setup_connections(get_connections, uri) - operator = CloudSQLExecuteQueryOperator( - sql=['SELECT * FROM TABLE'], - task_id='task_id' + uri = ( + "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" + "project_id=example-project&location=europe-west1&" + "instance=" + "test_db_with_long_name_a_bit_above" + "_the_limit_of_UNIX_socket_asdadadasadasd&" + "use_proxy=True&sql_proxy_use_tcp=False" ) + self._setup_connections(get_connections, uri) + operator = CloudSQLExecuteQueryOperator(sql=['SELECT * FROM TABLE'], task_id='task_id') with self.assertRaises(AirflowException) as cm: operator.execute(None) err = cm.exception diff --git a/tests/providers/google/cloud/operators/test_cloud_sql_system.py b/tests/providers/google/cloud/operators/test_cloud_sql_system.py index bc44c5ef87ed1..dbf5feed9360d 100644 --- a/tests/providers/google/cloud/operators/test_cloud_sql_system.py +++ b/tests/providers/google/cloud/operators/test_cloud_sql_system.py @@ -25,7 +25,10 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_sql import CloudSqlProxyRunner from tests.providers.google.cloud.operators.test_cloud_sql_system_helper import ( - QUERY_SUFFIX, TEARDOWN_LOCK_FILE, TEARDOWN_LOCK_FILE_QUERY, CloudSqlQueryTestHelper, + QUERY_SUFFIX, + TEARDOWN_LOCK_FILE, + TEARDOWN_LOCK_FILE_QUERY, + CloudSqlQueryTestHelper, ) from tests.providers.google.cloud.utils.gcp_authenticator import GCP_CLOUDSQL_KEY, GcpAuthenticator from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context @@ -62,21 +65,23 @@ def test_run_example_dag_cloudsql(self): "can remove 'random.txt' file from /files/airflow-breeze-config/ folder and restart " "breeze environment. This will generate random name of the database for next run " "(the problem is that Cloud SQL keeps names of deleted instances in " - "short-term cache).") + "short-term cache)." 
+ ) raise e @pytest.mark.system("google.cloud") @pytest.mark.credential_file(GCP_CLOUDSQL_KEY) class CloudSqlProxySystemTest(GoogleSystemTest): - @classmethod @provide_gcp_context(GCP_CLOUDSQL_KEY) def setUpClass(cls): SQL_QUERY_TEST_HELPER.set_ip_addresses_in_env() if os.path.exists(TEARDOWN_LOCK_FILE_QUERY): - print("Skip creating and setting up instances as they were created manually " - "(helps to iterate on tests)") + print( + "Skip creating and setting up instances as they were created manually " + "(helps to iterate on tests)" + ) else: helper = CloudSqlQueryTestHelper() gcp_authenticator = GcpAuthenticator(gcp_key=GCP_CLOUDSQL_KEY) @@ -101,13 +106,14 @@ def tearDownClass(cls): @staticmethod def generate_unique_path(): - return ''.join( - random.choice(string.ascii_letters + string.digits) for _ in range(8)) + return ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(8)) def test_start_proxy_fail_no_parameters(self): - runner = CloudSqlProxyRunner(path_prefix='/tmp/' + self.generate_unique_path(), - project_id=GCP_PROJECT_ID, - instance_specification='a') + runner = CloudSqlProxyRunner( + path_prefix='/tmp/' + self.generate_unique_path(), + project_id=GCP_PROJECT_ID, + instance_specification='a', + ) with self.assertRaises(AirflowException) as cm: runner.start_proxy() err = cm.exception @@ -119,9 +125,11 @@ def test_start_proxy_fail_no_parameters(self): self.assertIsNone(runner.sql_proxy_process) def test_start_proxy_with_all_instances(self): - runner = CloudSqlProxyRunner(path_prefix='/tmp/' + self.generate_unique_path(), - project_id=GCP_PROJECT_ID, - instance_specification='') + runner = CloudSqlProxyRunner( + path_prefix='/tmp/' + self.generate_unique_path(), + project_id=GCP_PROJECT_ID, + instance_specification='', + ) try: runner.start_proxy() time.sleep(1) @@ -131,9 +139,11 @@ def test_start_proxy_with_all_instances(self): @provide_gcp_context(GCP_CLOUDSQL_KEY) def test_start_proxy_with_all_instances_generated_credential_file(self): - runner = CloudSqlProxyRunner(path_prefix='/tmp/' + self.generate_unique_path(), - project_id=GCP_PROJECT_ID, - instance_specification='') + runner = CloudSqlProxyRunner( + path_prefix='/tmp/' + self.generate_unique_path(), + project_id=GCP_PROJECT_ID, + instance_specification='', + ) try: runner.start_proxy() time.sleep(1) @@ -142,10 +152,12 @@ def test_start_proxy_with_all_instances_generated_credential_file(self): self.assertIsNone(runner.sql_proxy_process) def test_start_proxy_with_all_instances_specific_version(self): - runner = CloudSqlProxyRunner(path_prefix='/tmp/' + self.generate_unique_path(), - project_id=GCP_PROJECT_ID, - instance_specification='', - sql_proxy_version='v1.13') + runner = CloudSqlProxyRunner( + path_prefix='/tmp/' + self.generate_unique_path(), + project_id=GCP_PROJECT_ID, + instance_specification='', + sql_proxy_version='v1.13', + ) try: runner.start_proxy() time.sleep(1) diff --git a/tests/providers/google/cloud/operators/test_cloud_sql_system_helper.py b/tests/providers/google/cloud/operators/test_cloud_sql_system_helper.py index 6249e7abd77f4..cc22de9c80392 100755 --- a/tests/providers/google/cloud/operators/test_cloud_sql_system_helper.py +++ b/tests/providers/google/cloud/operators/test_cloud_sql_system_helper.py @@ -32,30 +32,24 @@ GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') GCP_LOCATION = os.environ.get('GCP_LOCATION', 'europe-west1') -GCSQL_POSTGRES_SERVER_CA_FILE = os.environ.get('GCSQL_POSTGRES_SERVER_CA_FILE', - ".key/postgres-server-ca.pem") 
-GCSQL_POSTGRES_CLIENT_CERT_FILE = os.environ.get('GCSQL_POSTGRES_CLIENT_CERT_FILE', - ".key/postgres-client-cert.pem") -GCSQL_POSTGRES_CLIENT_KEY_FILE = os.environ.get('GCSQL_POSTGRES_CLIENT_KEY_FILE', - ".key/postgres-client-key.pem") -GCSQL_POSTGRES_PUBLIC_IP_FILE = os.environ.get('GCSQL_POSTGRES_PUBLIC_IP_FILE', - ".key/postgres-ip.env") +GCSQL_POSTGRES_SERVER_CA_FILE = os.environ.get('GCSQL_POSTGRES_SERVER_CA_FILE', ".key/postgres-server-ca.pem") +GCSQL_POSTGRES_CLIENT_CERT_FILE = os.environ.get( + 'GCSQL_POSTGRES_CLIENT_CERT_FILE', ".key/postgres-client-cert.pem" +) +GCSQL_POSTGRES_CLIENT_KEY_FILE = os.environ.get( + 'GCSQL_POSTGRES_CLIENT_KEY_FILE', ".key/postgres-client-key.pem" +) +GCSQL_POSTGRES_PUBLIC_IP_FILE = os.environ.get('GCSQL_POSTGRES_PUBLIC_IP_FILE', ".key/postgres-ip.env") GCSQL_POSTGRES_USER = os.environ.get('GCSQL_POSTGRES_USER', 'postgres_user') -GCSQL_POSTGRES_DATABASE_NAME = os.environ.get('GCSQL_POSTGRES_DATABASE_NAME', - 'postgresdb') -GCSQL_MYSQL_CLIENT_CERT_FILE = os.environ.get('GCSQL_MYSQL_CLIENT_CERT_FILE', - ".key/mysql-client-cert.pem") -GCSQL_MYSQL_CLIENT_KEY_FILE = os.environ.get('GCSQL_MYSQL_CLIENT_KEY_FILE', - ".key/mysql-client-key.pem") -GCSQL_MYSQL_SERVER_CA_FILE = os.environ.get('GCSQL_MYSQL_SERVER_CA_FILE', - ".key/mysql-server-ca.pem") -GCSQL_MYSQL_PUBLIC_IP_FILE = os.environ.get('GCSQL_MYSQL_PUBLIC_IP_FILE', - ".key/mysql-ip.env") +GCSQL_POSTGRES_DATABASE_NAME = os.environ.get('GCSQL_POSTGRES_DATABASE_NAME', 'postgresdb') +GCSQL_MYSQL_CLIENT_CERT_FILE = os.environ.get('GCSQL_MYSQL_CLIENT_CERT_FILE', ".key/mysql-client-cert.pem") +GCSQL_MYSQL_CLIENT_KEY_FILE = os.environ.get('GCSQL_MYSQL_CLIENT_KEY_FILE', ".key/mysql-client-key.pem") +GCSQL_MYSQL_SERVER_CA_FILE = os.environ.get('GCSQL_MYSQL_SERVER_CA_FILE', ".key/mysql-server-ca.pem") +GCSQL_MYSQL_PUBLIC_IP_FILE = os.environ.get('GCSQL_MYSQL_PUBLIC_IP_FILE', ".key/mysql-ip.env") GCSQL_MYSQL_USER = os.environ.get('GCSQL_MYSQL_USER', 'mysql_user') GCSQL_MYSQL_DATABASE_NAME = os.environ.get('GCSQL_MYSQL_DATABASE_NAME', 'mysqldb') -GCSQL_MYSQL_EXPORT_URI = os.environ.get('GCSQL_MYSQL_EXPORT_URI', - 'gs://bucketName/fileName') +GCSQL_MYSQL_EXPORT_URI = os.environ.get('GCSQL_MYSQL_EXPORT_URI', 'gs://bucketName/fileName') DB_VERSION_MYSQL = 'MYSQL_5_7' DV_VERSION_POSTGRES = 'POSTGRES_9_6' @@ -96,35 +90,43 @@ def get_mysql_instance_name(instance_suffix=''): class CloudSqlQueryTestHelper(LoggingCommandExecutor): - - def create_instances(self, instance_suffix='', - failover_instance_suffix=None, - master_instance_suffix=None): - thread_mysql = Thread(target=lambda: self.__create_instance( - get_mysql_instance_name(instance_suffix), DB_VERSION_MYSQL, - master_instance_name=get_mysql_instance_name(master_instance_suffix), - failover_replica_name=get_mysql_instance_name(failover_instance_suffix) - )) - thread_postgres = Thread(target=lambda: self.__create_instance( - get_postgres_instance_name(instance_suffix), DV_VERSION_POSTGRES, - master_instance_name=get_postgres_instance_name(master_instance_suffix), - failover_replica_name=get_postgres_instance_name(failover_instance_suffix) - )) + def create_instances( + self, instance_suffix='', failover_instance_suffix=None, master_instance_suffix=None + ): + thread_mysql = Thread( + target=lambda: self.__create_instance( + get_mysql_instance_name(instance_suffix), + DB_VERSION_MYSQL, + master_instance_name=get_mysql_instance_name(master_instance_suffix), + failover_replica_name=get_mysql_instance_name(failover_instance_suffix), + ) + ) + thread_postgres = Thread( + 
target=lambda: self.__create_instance( + get_postgres_instance_name(instance_suffix), + DV_VERSION_POSTGRES, + master_instance_name=get_postgres_instance_name(master_instance_suffix), + failover_replica_name=get_postgres_instance_name(failover_instance_suffix), + ) + ) thread_mysql.start() thread_postgres.start() thread_mysql.join() thread_postgres.join() - def delete_instances(self, instance_suffix='', - master_instance_suffix=None): - thread_mysql = Thread(target=lambda: self.__delete_instance( - get_mysql_instance_name(instance_suffix), - master_instance_name=get_mysql_instance_name(master_instance_suffix) - )) - thread_postgres = Thread(target=lambda: self.__delete_instance( - get_postgres_instance_name(instance_suffix), - master_instance_name=get_mysql_instance_name(master_instance_suffix) - )) + def delete_instances(self, instance_suffix='', master_instance_suffix=None): + thread_mysql = Thread( + target=lambda: self.__delete_instance( + get_mysql_instance_name(instance_suffix), + master_instance_name=get_mysql_instance_name(master_instance_suffix), + ) + ) + thread_postgres = Thread( + target=lambda: self.__delete_instance( + get_postgres_instance_name(instance_suffix), + master_instance_name=get_mysql_instance_name(master_instance_suffix), + ) + ) thread_mysql.start() thread_postgres.start() thread_mysql.join() @@ -132,66 +134,109 @@ def delete_instances(self, instance_suffix='', def get_ip_addresses(self, instance_suffix): with open(GCSQL_MYSQL_PUBLIC_IP_FILE, "w") as file: - ip_address = self.__get_ip_address(get_mysql_instance_name(instance_suffix), - 'GCSQL_MYSQL_PUBLIC_IP') + ip_address = self.__get_ip_address( + get_mysql_instance_name(instance_suffix), 'GCSQL_MYSQL_PUBLIC_IP' + ) file.write(ip_address) with open(GCSQL_POSTGRES_PUBLIC_IP_FILE, "w") as file: - ip_address = self.__get_ip_address(get_postgres_instance_name(instance_suffix), - 'GCSQL_POSTGRES_PUBLIC_IP') + ip_address = self.__get_ip_address( + get_postgres_instance_name(instance_suffix), 'GCSQL_POSTGRES_PUBLIC_IP' + ) file.write(ip_address) def raise_database_exception(self, database): - raise Exception("The {database} instance does not exist. Make sure to run " - "`python {f} --action=before-tests` before running the test" - " (and remember to run `python {f} --action=after-tests` " - "after you are done." - .format(f=__file__, database=database)) + raise Exception( + "The {database} instance does not exist. 
Make sure to run " + "`python {f} --action=before-tests` before running the test" + " (and remember to run `python {f} --action=after-tests` " + "after you are done.".format(f=__file__, database=database) + ) def check_if_instances_are_up(self, instance_suffix=''): res_postgres = self.execute_cmd( - ['gcloud', 'sql', 'instances', 'describe', - get_postgres_instance_name(instance_suffix), - "--project={}".format(GCP_PROJECT_ID)]) + [ + 'gcloud', + 'sql', + 'instances', + 'describe', + get_postgres_instance_name(instance_suffix), + "--project={}".format(GCP_PROJECT_ID), + ] + ) if res_postgres != 0: self.raise_database_exception('postgres') res_postgres = self.execute_cmd( - ['gcloud', 'sql', 'instances', 'describe', - get_postgres_instance_name(instance_suffix), - "--project={}".format(GCP_PROJECT_ID)]) + [ + 'gcloud', + 'sql', + 'instances', + 'describe', + get_postgres_instance_name(instance_suffix), + "--project={}".format(GCP_PROJECT_ID), + ] + ) if res_postgres != 0: self.raise_database_exception('mysql') def authorize_address(self, instance_suffix=''): ip_address = self.__get_my_public_ip() self.log.info('Authorizing access from IP: %s', ip_address) - postgres_thread = Thread(target=lambda: self.execute_cmd( - ['gcloud', 'sql', 'instances', 'patch', - get_postgres_instance_name(instance_suffix), '--quiet', - "--authorized-networks={}".format(ip_address), - "--project={}".format(GCP_PROJECT_ID)])) - mysql_thread = Thread(target=lambda: self.execute_cmd( - ['gcloud', 'sql', 'instances', 'patch', - get_mysql_instance_name(instance_suffix), '--quiet', - "--authorized-networks={}".format(ip_address), - "--project={}".format(GCP_PROJECT_ID)])) + postgres_thread = Thread( + target=lambda: self.execute_cmd( + [ + 'gcloud', + 'sql', + 'instances', + 'patch', + get_postgres_instance_name(instance_suffix), + '--quiet', + "--authorized-networks={}".format(ip_address), + "--project={}".format(GCP_PROJECT_ID), + ] + ) + ) + mysql_thread = Thread( + target=lambda: self.execute_cmd( + [ + 'gcloud', + 'sql', + 'instances', + 'patch', + get_mysql_instance_name(instance_suffix), + '--quiet', + "--authorized-networks={}".format(ip_address), + "--project={}".format(GCP_PROJECT_ID), + ] + ) + ) postgres_thread.start() mysql_thread.start() postgres_thread.join() mysql_thread.join() def setup_instances(self, instance_suffix=''): - mysql_thread = Thread(target=lambda: self.__setup_instance_and_certs( - get_mysql_instance_name(instance_suffix), DB_VERSION_MYSQL, - server_ca_file_mysql, - client_key_file_mysql, client_cert_file_mysql, GCSQL_MYSQL_DATABASE_NAME, - GCSQL_MYSQL_USER - )) - postgres_thread = Thread(target=lambda: self.__setup_instance_and_certs( - get_postgres_instance_name(instance_suffix), DV_VERSION_POSTGRES, - server_ca_file_postgres, - client_key_file_postgres, client_cert_file_postgres, - GCSQL_POSTGRES_DATABASE_NAME, GCSQL_POSTGRES_USER - )) + mysql_thread = Thread( + target=lambda: self.__setup_instance_and_certs( + get_mysql_instance_name(instance_suffix), + DB_VERSION_MYSQL, + server_ca_file_mysql, + client_key_file_mysql, + client_cert_file_mysql, + GCSQL_MYSQL_DATABASE_NAME, + GCSQL_MYSQL_USER, + ) + ) + postgres_thread = Thread( + target=lambda: self.__setup_instance_and_certs( + get_postgres_instance_name(instance_suffix), + DV_VERSION_POSTGRES, + server_ca_file_postgres, + client_key_file_postgres, + client_cert_file_postgres, + GCSQL_POSTGRES_DATABASE_NAME, + GCSQL_POSTGRES_USER, + ) + ) mysql_thread.start() postgres_thread.start() mysql_thread.join() @@ -202,37 +247,47 @@ def 
setup_instances(self, instance_suffix=''): def delete_service_account_acls(self): self.__delete_service_accounts_acls() - def __create_instance(self, instance_name, db_version, - failover_replica_name=None, - master_instance_name=None): - self.log.info('Creating a test %s instance "%s"... Failover(%s), Master(%s)', - db_version, instance_name, failover_replica_name, master_instance_name) + def __create_instance( + self, instance_name, db_version, failover_replica_name=None, master_instance_name=None + ): + self.log.info( + 'Creating a test %s instance "%s"... Failover(%s), Master(%s)', + db_version, + instance_name, + failover_replica_name, + master_instance_name, + ) try: create_instance_opcode = self.__create_sql_instance( - instance_name, db_version, + instance_name, + db_version, failover_replica_name=failover_replica_name, - master_instance_name=master_instance_name) + master_instance_name=master_instance_name, + ) if create_instance_opcode: # return code 1, some error occurred operation_name = self.__get_operation_name(instance_name) self.log.info('Waiting for operation: %s ...', operation_name) self.__wait_for_create(operation_name) self.log.info('... Done.') - self.log.info('... Done creating a test %s instance "%s"!\n', - db_version, instance_name) + self.log.info('... Done creating a test %s instance "%s"!\n', db_version, instance_name) except Exception as ex: - self.log.error('Exception occurred. ' - 'Aborting creating a test instance.\n\n%s', ex) + self.log.error('Exception occurred. ' 'Aborting creating a test instance.\n\n%s', ex) raise ex def __delete_service_accounts_acls(self): export_bucket_split = urlsplit(GCSQL_MYSQL_EXPORT_URI) export_bucket_name = export_bucket_split[1] # netloc (bucket) - self.log.info('Deleting temporary service accounts from bucket "%s"...', - export_bucket_name) - all_permissions = self.check_output(['gsutil', 'iam', 'get', - "gs://{}".format(export_bucket_name), - "--project={}".format(GCP_PROJECT_ID)]) + self.log.info('Deleting temporary service accounts from bucket "%s"...', export_bucket_name) + all_permissions = self.check_output( + [ + 'gsutil', + 'iam', + 'get', + "gs://{}".format(export_bucket_name), + "--project={}".format(GCP_PROJECT_ID), + ] + ) all_permissions_dejson = json.loads(all_permissions.decode("utf-8")) for binding in all_permissions_dejson['bindings']: if binding['role'] != 'roles/storage.legacyBucketWriter': @@ -248,11 +303,11 @@ def __delete_service_accounts_acls(self): if member_type != 'serviceAccount': self.log.warning( - "Skip removing member %s as the type %s is not service account", - member, member_type + "Skip removing member %s as the type %s is not service account", member, member_type ) - self.execute_cmd(['gsutil', 'acl', 'ch', '-d', member_email, - "gs://{}".format(export_bucket_name)]) + self.execute_cmd( + ['gsutil', 'acl', 'ch', '-d', member_email, "gs://{}".format(export_bucket_name)] + ) @staticmethod def set_ip_addresses_in_env(): @@ -266,13 +321,19 @@ def __set_ip_address_in_env(file_name): env, ip_address = file.read().split("=") os.environ[env] = ip_address - def __setup_instance_and_certs(self, instance_name, db_version, server_ca_file, - client_key_file, client_cert_file, db_name, - db_username): + def __setup_instance_and_certs( + self, + instance_name, + db_version, + server_ca_file, + client_key_file, + client_cert_file, + db_name, + db_username, + ): self.log.info('Setting up a test %s instance "%s"...', db_version, instance_name) try: - self.__remove_keys_and_certs([server_ca_file, 
client_key_file, - client_cert_file]) + self.__remove_keys_and_certs([server_ca_file, client_key_file, client_cert_file]) self.__wait_for_operations(instance_name) self.__write_to_file(server_ca_file, self.__get_server_ca_cert(instance_name)) @@ -282,43 +343,51 @@ def __setup_instance_and_certs(self, instance_name, db_version, server_ca_file, self.__wait_for_operations(instance_name) self.__create_client_cert(instance_name, client_key_file, client_cert_name) self.__wait_for_operations(instance_name) - self.__write_to_file(client_cert_file, - self.__get_client_cert(instance_name, client_cert_name)) + self.__write_to_file(client_cert_file, self.__get_client_cert(instance_name, client_cert_name)) self.__wait_for_operations(instance_name) self.__wait_for_operations(instance_name) self.__create_user(instance_name, db_username) self.__wait_for_operations(instance_name) self.__delete_db(instance_name, db_name) self.__create_db(instance_name, db_name) - self.log.info('... Done setting up a test %s instance "%s"!\n', - db_version, instance_name) + self.log.info('... Done setting up a test %s instance "%s"!\n', db_version, instance_name) except Exception as ex: - self.log.error('Exception occurred. ' - 'Aborting setting up test instance and certs.\n\n%s', ex) + self.log.error('Exception occurred. ' 'Aborting setting up test instance and certs.\n\n%s', ex) raise ex - def __delete_instance(self, - instance_name: str, - master_instance_name: Optional[str]) -> None: + def __delete_instance(self, instance_name: str, master_instance_name: Optional[str]) -> None: if master_instance_name is not None: self.__wait_for_operations(master_instance_name) self.__wait_for_operations(instance_name) self.log.info('Deleting Cloud SQL instance "%s"...', instance_name) - self.execute_cmd(['gcloud', 'sql', 'instances', 'delete', - instance_name, '--quiet']) + self.execute_cmd(['gcloud', 'sql', 'instances', 'delete', instance_name, '--quiet']) self.log.info('... Done.') def __get_my_public_ip(self): - return self.check_output( - ['curl', 'https://ipinfo.io/ip']).decode('utf-8').strip() - - def __create_sql_instance(self, instance_name: str, db_version: str, master_instance_name: Optional[str], - failover_replica_name: Optional[str]) -> int: - cmd = ['gcloud', 'sql', 'instances', 'create', instance_name, - '--region', GCP_LOCATION, - '--project', GCP_PROJECT_ID, - '--database-version', db_version, - '--tier', 'db-f1-micro'] + return self.check_output(['curl', 'https://ipinfo.io/ip']).decode('utf-8').strip() + + def __create_sql_instance( + self, + instance_name: str, + db_version: str, + master_instance_name: Optional[str], + failover_replica_name: Optional[str], + ) -> int: + cmd = [ + 'gcloud', + 'sql', + 'instances', + 'create', + instance_name, + '--region', + GCP_LOCATION, + '--project', + GCP_PROJECT_ID, + '--database-version', + db_version, + '--tier', + 'db-f1-micro', + ] if master_instance_name: cmd.extend(['--master-instance-name', master_instance_name]) self.__wait_for_operations(master_instance_name) @@ -332,39 +401,57 @@ def __create_sql_instance(self, instance_name: str, db_version: str, master_inst def __get_server_ca_cert(self, instance_name: str) -> bytes: self.log.info('Getting server CA cert for "%s"...', instance_name) output = self.check_output( - ['gcloud', 'sql', 'instances', 'describe', instance_name, - '--format=value(serverCaCert.cert)']) + ['gcloud', 'sql', 'instances', 'describe', instance_name, '--format=value(serverCaCert.cert)'] + ) self.log.info('... 
Done.') return output def __get_client_cert(self, instance_name: str, client_cert_name: str) -> bytes: self.log.info('Getting client cert for "%s"...', instance_name) output = self.check_output( - ['gcloud', 'sql', 'ssl', 'client-certs', 'describe', client_cert_name, '-i', - instance_name, '--format=get(cert)']) + [ + 'gcloud', + 'sql', + 'ssl', + 'client-certs', + 'describe', + client_cert_name, + '-i', + instance_name, + '--format=get(cert)', + ] + ) self.log.info('... Done.') return output def __create_user(self, instance_name: str, username: str) -> None: - self.log.info('Creating user "%s" in Cloud SQL instance "%s"...', username, - instance_name) - self.execute_cmd(['gcloud', 'sql', 'users', 'create', username, '-i', - instance_name, '--host', '%', '--password', 'JoxHlwrPzwch0gz9', - '--quiet']) + self.log.info('Creating user "%s" in Cloud SQL instance "%s"...', username, instance_name) + self.execute_cmd( + [ + 'gcloud', + 'sql', + 'users', + 'create', + username, + '-i', + instance_name, + '--host', + '%', + '--password', + 'JoxHlwrPzwch0gz9', + '--quiet', + ] + ) self.log.info('... Done.') def __delete_db(self, instance_name: str, db_name: str) -> None: - self.log.info('Deleting database "%s" in Cloud SQL instance "%s"...', db_name, - instance_name) - self.execute_cmd(['gcloud', 'sql', 'databases', 'delete', db_name, '-i', - instance_name, '--quiet']) + self.log.info('Deleting database "%s" in Cloud SQL instance "%s"...', db_name, instance_name) + self.execute_cmd(['gcloud', 'sql', 'databases', 'delete', db_name, '-i', instance_name, '--quiet']) self.log.info('... Done.') def __create_db(self, instance_name: str, db_name: str) -> None: - self.log.info('Creating database "%s" in Cloud SQL instance "%s"...', db_name, - instance_name) - self.execute_cmd(['gcloud', 'sql', 'databases', 'create', db_name, '-i', - instance_name, '--quiet']) + self.log.info('Creating database "%s" in Cloud SQL instance "%s"...', db_name, instance_name) + self.execute_cmd(['gcloud', 'sql', 'databases', 'create', db_name, '-i', instance_name, '--quiet']) self.log.info('... Done.') def __write_to_file(self, filepath: str, content: bytes) -> None: @@ -396,8 +483,19 @@ def __remove_keys_and_certs(self, filepaths): def __delete_client_cert(self, instance_name, common_name): self.log.info('Deleting client key and cert for "%s"...', instance_name) - self.execute_cmd(['gcloud', 'sql', 'ssl', 'client-certs', 'delete', common_name, - '--instance', instance_name, '--quiet']) + self.execute_cmd( + [ + 'gcloud', + 'sql', + 'ssl', + 'client-certs', + 'delete', + common_name, + '--instance', + instance_name, + '--quiet', + ] + ) self.log.info('... Done.') def __create_client_cert(self, instance_name, client_key_file, common_name): @@ -406,13 +504,25 @@ def __create_client_cert(self, instance_name, client_key_file, common_name): os.remove(client_key_file) except OSError: pass - self.execute_cmd(['gcloud', 'sql', 'ssl', 'client-certs', 'create', common_name, - client_key_file, '-i', instance_name]) + self.execute_cmd( + [ + 'gcloud', + 'sql', + 'ssl', + 'client-certs', + 'create', + common_name, + client_key_file, + '-i', + instance_name, + ] + ) self.log.info('... 
Done.') def __get_operation_name(self, instance_name: str) -> str: op_name_bytes = self.check_output( - ['gcloud', 'sql', 'operations', 'list', '--instance', instance_name, '--format=get(name)']) + ['gcloud', 'sql', 'operations', 'list', '--instance', instance_name, '--format=get(name)'] + ) return op_name_bytes.decode('utf-8').strip() def __print_operations(self, operations): @@ -431,31 +541,49 @@ def __wait_for_operations(self, instance_name: str) -> None: break def __get_ip_address(self, instance_name: str, env_var: str) -> str: - ip_address = self.check_output( - ['gcloud', 'sql', 'instances', 'describe', instance_name, - '--format=get(ipAddresses[0].ipAddress)'] - ).decode('utf-8').strip() + ip_address = ( + self.check_output( + [ + 'gcloud', + 'sql', + 'instances', + 'describe', + instance_name, + '--format=get(ipAddresses[0].ipAddress)', + ] + ) + .decode('utf-8') + .strip() + ) os.environ[env_var] = ip_address return "{}={}".format(env_var, ip_address) def __get_operations(self, instance_name: str) -> str: op_name_bytes = self.check_output( - ['gcloud', 'sql', 'operations', 'list', '-i', - instance_name, '--format=get(NAME,TYPE,STATUS)']) + ['gcloud', 'sql', 'operations', 'list', '-i', instance_name, '--format=get(NAME,TYPE,STATUS)'] + ) return op_name_bytes.decode('utf-8').strip() def __wait_for_create(self, operation_name: str) -> None: - self.execute_cmd(['gcloud', 'beta', 'sql', 'operations', 'wait', - '--project', GCP_PROJECT_ID, operation_name]) + self.execute_cmd( + ['gcloud', 'beta', 'sql', 'operations', 'wait', '--project', GCP_PROJECT_ID, operation_name] + ) if __name__ == '__main__': - parser = argparse.ArgumentParser( - description='Create or delete Cloud SQL instances for system tests.') - parser.add_argument('--action', required=True, - choices=('create', 'delete', 'setup-instances', - 'create-query', 'delete-query', - 'delete-service-accounts-acls')) + parser = argparse.ArgumentParser(description='Create or delete Cloud SQL instances for system tests.') + parser.add_argument( + '--action', + required=True, + choices=( + 'create', + 'delete', + 'setup-instances', + 'create-query', + 'delete-query', + 'delete-service-accounts-acls', + ), + ) action = parser.parse_args().action helper = CloudSqlQueryTestHelper() @@ -467,8 +595,7 @@ def __wait_for_create(self, operation_name: str) -> None: gcp_authenticator.gcp_authenticate() if action == 'create': helper.create_instances(failover_instance_suffix='-failover-replica') - helper.create_instances(instance_suffix="-read-replica", - master_instance_suffix='') + helper.create_instances(instance_suffix="-read-replica", master_instance_suffix='') helper.create_instances(instance_suffix="2") helper.create_instances(instance_suffix=QUERY_SUFFIX) helper.setup_instances(instance_suffix=QUERY_SUFFIX) diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py index 3495546135f24..4ea070f4b2104 100644 --- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py +++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py @@ -29,17 +29,37 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, TaskInstance from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( - ACCESS_KEY_ID, AWS_ACCESS_KEY, AWS_S3_DATA_SOURCE, BUCKET_NAME, FILTER_JOB_NAMES, GCS_DATA_SINK, - GCS_DATA_SOURCE, HTTP_DATA_SOURCE, LIST_URL, NAME, SCHEDULE, 
SCHEDULE_END_DATE, SCHEDULE_START_DATE, - SECRET_ACCESS_KEY, START_TIME_OF_DAY, STATUS, TRANSFER_SPEC, + ACCESS_KEY_ID, + AWS_ACCESS_KEY, + AWS_S3_DATA_SOURCE, + BUCKET_NAME, + FILTER_JOB_NAMES, + GCS_DATA_SINK, + GCS_DATA_SOURCE, + HTTP_DATA_SOURCE, + LIST_URL, + NAME, + SCHEDULE, + SCHEDULE_END_DATE, + SCHEDULE_START_DATE, + SECRET_ACCESS_KEY, + START_TIME_OF_DAY, + STATUS, + TRANSFER_SPEC, ) from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import ( - CloudDataTransferServiceCancelOperationOperator, CloudDataTransferServiceCreateJobOperator, - CloudDataTransferServiceDeleteJobOperator, CloudDataTransferServiceGCSToGCSOperator, - CloudDataTransferServiceGetOperationOperator, CloudDataTransferServiceListOperationsOperator, - CloudDataTransferServicePauseOperationOperator, CloudDataTransferServiceResumeOperationOperator, - CloudDataTransferServiceS3ToGCSOperator, CloudDataTransferServiceUpdateJobOperator, - TransferJobPreprocessor, TransferJobValidator, + CloudDataTransferServiceCancelOperationOperator, + CloudDataTransferServiceCreateJobOperator, + CloudDataTransferServiceDeleteJobOperator, + CloudDataTransferServiceGCSToGCSOperator, + CloudDataTransferServiceGetOperationOperator, + CloudDataTransferServiceListOperationsOperator, + CloudDataTransferServicePauseOperationOperator, + CloudDataTransferServiceResumeOperationOperator, + CloudDataTransferServiceS3ToGCSOperator, + CloudDataTransferServiceUpdateJobOperator, + TransferJobPreprocessor, + TransferJobValidator, ) from airflow.utils import timezone @@ -169,12 +189,15 @@ def test_should_not_change_time_for_dict(self): def test_should_set_default_schedule(self): body = {} TransferJobPreprocessor(body=body, default_schedule=True).process_body() - self.assertEqual(body, { - SCHEDULE: { - SCHEDULE_END_DATE: {'day': 15, 'month': 10, 'year': 2018}, - SCHEDULE_START_DATE: {'day': 15, 'month': 10, 'year': 2018} - } - }) + self.assertEqual( + body, + { + SCHEDULE: { + SCHEDULE_END_DATE: {'day': 15, 'month': 10, 'year': 2018}, + SCHEDULE_START_DATE: {'day': 15, 'month': 10, 'year': 2018}, + } + }, + ) class TestTransferJobValidator(unittest.TestCase): @@ -236,16 +259,12 @@ def test_job_create_gcs(self, mock_hook): body = deepcopy(VALID_TRANSFER_JOB_GCS) del body['name'] op = CloudDataTransferServiceCreateJobOperator( - body=body, - task_id=TASK_ID, - google_impersonation_chain=IMPERSONATION_CHAIN, + body=body, task_id=TASK_ID, google_impersonation_chain=IMPERSONATION_CHAIN, ) result = op.execute(None) mock_hook.assert_called_once_with( - api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=IMPERSONATION_CHAIN, + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_transfer_job.assert_called_once_with(body=VALID_TRANSFER_JOB_GCS_RAW) @@ -264,17 +283,13 @@ def test_job_create_aws(self, aws_hook, mock_hook): body = deepcopy(VALID_TRANSFER_JOB_AWS) del body['name'] op = CloudDataTransferServiceCreateJobOperator( - body=body, - task_id=TASK_ID, - google_impersonation_chain=IMPERSONATION_CHAIN, + body=body, task_id=TASK_ID, google_impersonation_chain=IMPERSONATION_CHAIN, ) result = op.execute(None) mock_hook.assert_called_once_with( - api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=IMPERSONATION_CHAIN, + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_transfer_job.assert_called_once_with(body=VALID_TRANSFER_JOB_AWS_RAW) 
@@ -333,17 +348,12 @@ def test_job_update(self, mock_hook): body = {'transferJob': {'description': 'example-name'}, 'updateTransferJobFieldMask': DESCRIPTION} op = CloudDataTransferServiceUpdateJobOperator( - job_name=JOB_NAME, - body=body, - task_id=TASK_ID, - google_impersonation_chain=IMPERSONATION_CHAIN, + job_name=JOB_NAME, body=body, task_id=TASK_ID, google_impersonation_chain=IMPERSONATION_CHAIN, ) result = op.execute(None) mock_hook.assert_called_once_with( - api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=IMPERSONATION_CHAIN, + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_transfer_job.assert_called_once_with(job_name=JOB_NAME, body=body) self.assertEqual(result, VALID_TRANSFER_JOB_GCS) @@ -383,9 +393,7 @@ def test_job_delete(self, mock_hook): ) op.execute(None) mock_hook.assert_called_once_with( - api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=IMPERSONATION_CHAIN, + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_transfer_job.assert_called_once_with( job_name=JOB_NAME, project_id=GCP_PROJECT_ID @@ -433,15 +441,11 @@ class TestGpcStorageTransferOperationsGetOperator(unittest.TestCase): def test_operation_get(self, mock_hook): mock_hook.return_value.get_transfer_operation.return_value = VALID_OPERATION op = CloudDataTransferServiceGetOperationOperator( - operation_name=OPERATION_NAME, - task_id=TASK_ID, - google_impersonation_chain=IMPERSONATION_CHAIN, + operation_name=OPERATION_NAME, task_id=TASK_ID, google_impersonation_chain=IMPERSONATION_CHAIN, ) result = op.execute(None) mock_hook.assert_called_once_with( - api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=IMPERSONATION_CHAIN, + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.get_transfer_operation.assert_called_once_with(operation_name=OPERATION_NAME) self.assertEqual(result, VALID_OPERATION) @@ -482,15 +486,11 @@ class TestGcpStorageTransferOperationListOperator(unittest.TestCase): def test_operation_list(self, mock_hook): mock_hook.return_value.list_transfer_operations.return_value = [VALID_TRANSFER_JOB_GCS] op = CloudDataTransferServiceListOperationsOperator( - request_filter=TEST_FILTER, - task_id=TASK_ID, - google_impersonation_chain=IMPERSONATION_CHAIN, + request_filter=TEST_FILTER, task_id=TASK_ID, google_impersonation_chain=IMPERSONATION_CHAIN, ) result = op.execute(None) mock_hook.assert_called_once_with( - api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=IMPERSONATION_CHAIN, + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.list_transfer_operations.assert_called_once_with(request_filter=TEST_FILTER) self.assertEqual(result, [VALID_TRANSFER_JOB_GCS]) @@ -527,15 +527,11 @@ class TestGcpStorageTransferOperationsPauseOperator(unittest.TestCase): ) def test_operation_pause(self, mock_hook): op = CloudDataTransferServicePauseOperationOperator( - operation_name=OPERATION_NAME, - task_id='task-id', - google_impersonation_chain=IMPERSONATION_CHAIN, + operation_name=OPERATION_NAME, task_id='task-id', google_impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(None) mock_hook.assert_called_once_with( - api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=IMPERSONATION_CHAIN, + 
api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.pause_transfer_operation.assert_called_once_with(operation_name=OPERATION_NAME) @@ -580,15 +576,11 @@ class TestGcpStorageTransferOperationsResumeOperator(unittest.TestCase): ) def test_operation_resume(self, mock_hook): op = CloudDataTransferServiceResumeOperationOperator( - operation_name=OPERATION_NAME, - task_id=TASK_ID, - google_impersonation_chain=IMPERSONATION_CHAIN, + operation_name=OPERATION_NAME, task_id=TASK_ID, google_impersonation_chain=IMPERSONATION_CHAIN, ) result = op.execute(None) # pylint: disable=assignment-from-no-return mock_hook.assert_called_once_with( - api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=IMPERSONATION_CHAIN, + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.resume_transfer_operation.assert_called_once_with( operation_name=OPERATION_NAME @@ -636,15 +628,11 @@ class TestGcpStorageTransferOperationsCancelOperator(unittest.TestCase): ) def test_operation_cancel(self, mock_hook): op = CloudDataTransferServiceCancelOperationOperator( - operation_name=OPERATION_NAME, - task_id=TASK_ID, - google_impersonation_chain=IMPERSONATION_CHAIN, + operation_name=OPERATION_NAME, task_id=TASK_ID, google_impersonation_chain=IMPERSONATION_CHAIN, ) result = op.execute(None) # pylint: disable=assignment-from-no-return mock_hook.assert_called_once_with( - api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=IMPERSONATION_CHAIN, + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.cancel_transfer_operation.assert_called_once_with( operation_name=OPERATION_NAME diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service_system.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service_system.py index a7027e25d7c4a..5e4a895954395 100644 --- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service_system.py +++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service_system.py @@ -18,7 +18,9 @@ import pytest from airflow.providers.google.cloud.example_dags.example_cloud_storage_transfer_service_gcp import ( - GCP_PROJECT_ID, GCP_TRANSFER_FIRST_TARGET_BUCKET, GCP_TRANSFER_SECOND_TARGET_BUCKET, + GCP_PROJECT_ID, + GCP_TRANSFER_FIRST_TARGET_BUCKET, + GCP_TRANSFER_SECOND_TARGET_BUCKET, ) from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_TRANSFER_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context diff --git a/tests/providers/google/cloud/operators/test_compute.py b/tests/providers/google/cloud/operators/test_compute.py index b89173aa604e1..d200d9fbd9f15 100644 --- a/tests/providers/google/cloud/operators/test_compute.py +++ b/tests/providers/google/cloud/operators/test_compute.py @@ -29,8 +29,10 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, TaskInstance from airflow.providers.google.cloud.operators.compute import ( - ComputeEngineCopyInstanceTemplateOperator, ComputeEngineInstanceGroupUpdateManagerTemplateOperator, - ComputeEngineSetMachineTypeOperator, ComputeEngineStartInstanceOperator, + ComputeEngineCopyInstanceTemplateOperator, + ComputeEngineInstanceGroupUpdateManagerTemplateOperator, + ComputeEngineSetMachineTypeOperator, + ComputeEngineStartInstanceOperator, 
ComputeEngineStopInstanceOperator, ) from airflow.utils import timezone @@ -53,15 +55,12 @@ class TestGceInstanceStart(unittest.TestCase): def test_instance_start(self, mock_hook): mock_hook.return_value.start_instance.return_value = True op = ComputeEngineStartInstanceOperator( - project_id=GCP_PROJECT_ID, - zone=GCE_ZONE, - resource_id=RESOURCE_ID, - task_id='id' + project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=RESOURCE_ID, task_id='id' ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.start_instance.assert_called_once_with( zone=GCE_ZONE, resource_id=RESOURCE_ID, project_id=GCP_PROJECT_ID ) @@ -72,9 +71,7 @@ def test_instance_start(self, mock_hook): @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_instance_start_with_templates(self, _): dag_id = 'test_dag_id' - args = { - 'start_date': DEFAULT_DATE - } + args = {'start_date': DEFAULT_DATE} self.dag = DAG(dag_id, default_args=args) # pylint: disable=attribute-defined-outside-init op = ComputeEngineStartInstanceOperator( project_id='{{ dag.dag_id }}', @@ -83,7 +80,7 @@ def test_instance_start_with_templates(self, _): gcp_conn_id='{{ dag.dag_id }}', api_version='{{ dag.dag_id }}', task_id='id', - dag=self.dag + dag=self.dag, ) ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() @@ -97,10 +94,7 @@ def test_instance_start_with_templates(self, _): def test_start_should_throw_ex_when_missing_project_id(self, mock_hook): with self.assertRaises(AirflowException) as cm: op = ComputeEngineStartInstanceOperator( - project_id="", - zone=GCE_ZONE, - resource_id=RESOURCE_ID, - task_id='id' + project_id="", zone=GCE_ZONE, resource_id=RESOURCE_ID, task_id='id' ) op.execute(None) err = cm.exception @@ -109,21 +103,14 @@ def test_start_should_throw_ex_when_missing_project_id(self, mock_hook): @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_start_should_not_throw_ex_when_project_id_none(self, _): - op = ComputeEngineStartInstanceOperator( - zone=GCE_ZONE, - resource_id=RESOURCE_ID, - task_id='id' - ) + op = ComputeEngineStartInstanceOperator(zone=GCE_ZONE, resource_id=RESOURCE_ID, task_id='id') op.execute(None) @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_start_should_throw_ex_when_missing_zone(self, mock_hook): with self.assertRaises(AirflowException) as cm: op = ComputeEngineStartInstanceOperator( - project_id=GCP_PROJECT_ID, - zone="", - resource_id=RESOURCE_ID, - task_id='id' + project_id=GCP_PROJECT_ID, zone="", resource_id=RESOURCE_ID, task_id='id' ) op.execute(None) err = cm.exception @@ -134,10 +121,7 @@ def test_start_should_throw_ex_when_missing_zone(self, mock_hook): def test_start_should_throw_ex_when_missing_resource_id(self, mock_hook): with self.assertRaises(AirflowException) as cm: op = ComputeEngineStartInstanceOperator( - project_id=GCP_PROJECT_ID, - zone=GCE_ZONE, - resource_id="", - task_id='id' + project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id="", task_id='id' ) op.execute(None) err = cm.exception @@ -149,15 +133,12 @@ class TestGceInstanceStop(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_instance_stop(self, mock_hook): op = ComputeEngineStopInstanceOperator( - 
project_id=GCP_PROJECT_ID, - zone=GCE_ZONE, - resource_id=RESOURCE_ID, - task_id='id' + project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=RESOURCE_ID, task_id='id' ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.stop_instance.assert_called_once_with( zone=GCE_ZONE, resource_id=RESOURCE_ID, project_id=GCP_PROJECT_ID ) @@ -167,9 +148,7 @@ def test_instance_stop(self, mock_hook): @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_instance_stop_with_templates(self, _): dag_id = 'test_dag_id' - args = { - 'start_date': DEFAULT_DATE - } + args = {'start_date': DEFAULT_DATE} self.dag = DAG(dag_id, default_args=args) # pylint: disable=attribute-defined-outside-init op = ComputeEngineStopInstanceOperator( project_id='{{ dag.dag_id }}', @@ -178,7 +157,7 @@ def test_instance_stop_with_templates(self, _): gcp_conn_id='{{ dag.dag_id }}', api_version='{{ dag.dag_id }}', task_id='id', - dag=self.dag + dag=self.dag, ) ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() @@ -192,10 +171,7 @@ def test_instance_stop_with_templates(self, _): def test_stop_should_throw_ex_when_missing_project_id(self, mock_hook): with self.assertRaises(AirflowException) as cm: op = ComputeEngineStopInstanceOperator( - project_id="", - zone=GCE_ZONE, - resource_id=RESOURCE_ID, - task_id='id' + project_id="", zone=GCE_ZONE, resource_id=RESOURCE_ID, task_id='id' ) op.execute(None) err = cm.exception @@ -204,15 +180,11 @@ def test_stop_should_throw_ex_when_missing_project_id(self, mock_hook): @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_stop_should_not_throw_ex_when_project_id_none(self, mock_hook): - op = ComputeEngineStopInstanceOperator( - zone=GCE_ZONE, - resource_id=RESOURCE_ID, - task_id='id' - ) + op = ComputeEngineStopInstanceOperator(zone=GCE_ZONE, resource_id=RESOURCE_ID, task_id='id') op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.stop_instance.assert_called_once_with( zone=GCE_ZONE, resource_id=RESOURCE_ID, project_id=None ) @@ -221,10 +193,7 @@ def test_stop_should_not_throw_ex_when_project_id_none(self, mock_hook): def test_stop_should_throw_ex_when_missing_zone(self, mock_hook): with self.assertRaises(AirflowException) as cm: op = ComputeEngineStopInstanceOperator( - project_id=GCP_PROJECT_ID, - zone="", - resource_id=RESOURCE_ID, - task_id='id' + project_id=GCP_PROJECT_ID, zone="", resource_id=RESOURCE_ID, task_id='id' ) op.execute(None) err = cm.exception @@ -235,10 +204,7 @@ def test_stop_should_throw_ex_when_missing_zone(self, mock_hook): def test_stop_should_throw_ex_when_missing_resource_id(self, mock_hook): with self.assertRaises(AirflowException) as cm: op = ComputeEngineStopInstanceOperator( - project_id=GCP_PROJECT_ID, - zone=GCE_ZONE, - resource_id="", - task_id='id' + project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id="", task_id='id' ) op.execute(None) err = cm.exception @@ -255,17 +221,14 @@ def test_set_machine_type(self, mock_hook): zone=GCE_ZONE, resource_id=RESOURCE_ID, body=SET_MACHINE_TYPE_BODY, - task_id='id' + task_id='id', ) 
op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.set_machine_type.assert_called_once_with( - zone=GCE_ZONE, - resource_id=RESOURCE_ID, - body=SET_MACHINE_TYPE_BODY, - project_id=GCP_PROJECT_ID + zone=GCE_ZONE, resource_id=RESOURCE_ID, body=SET_MACHINE_TYPE_BODY, project_id=GCP_PROJECT_ID ) # Setting all of the operator's input parameters as templated dag_ids @@ -273,9 +236,7 @@ def test_set_machine_type(self, mock_hook): @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_set_machine_type_with_templates(self, _): dag_id = 'test_dag_id' - args = { - 'start_date': DEFAULT_DATE - } + args = {'start_date': DEFAULT_DATE} self.dag = DAG(dag_id, default_args=args) # pylint: disable=attribute-defined-outside-init op = ComputeEngineSetMachineTypeOperator( project_id='{{ dag.dag_id }}', @@ -285,7 +246,7 @@ def test_set_machine_type_with_templates(self, _): gcp_conn_id='{{ dag.dag_id }}', api_version='{{ dag.dag_id }}', task_id='id', - dag=self.dag + dag=self.dag, ) ti = TaskInstance(op, DEFAULT_DATE) ti.render_templates() @@ -303,7 +264,7 @@ def test_set_machine_type_should_throw_ex_when_missing_project_id(self, mock_hoo zone=GCE_ZONE, resource_id=RESOURCE_ID, body=SET_MACHINE_TYPE_BODY, - task_id='id' + task_id='id', ) op.execute(None) err = cm.exception @@ -313,20 +274,14 @@ def test_set_machine_type_should_throw_ex_when_missing_project_id(self, mock_hoo @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_set_machine_type_should_not_throw_ex_when_project_id_none(self, mock_hook): op = ComputeEngineSetMachineTypeOperator( - zone=GCE_ZONE, - resource_id=RESOURCE_ID, - body=SET_MACHINE_TYPE_BODY, - task_id='id' + zone=GCE_ZONE, resource_id=RESOURCE_ID, body=SET_MACHINE_TYPE_BODY, task_id='id' ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.set_machine_type.assert_called_once_with( - zone=GCE_ZONE, - resource_id=RESOURCE_ID, - body=SET_MACHINE_TYPE_BODY, - project_id=None + zone=GCE_ZONE, resource_id=RESOURCE_ID, body=SET_MACHINE_TYPE_BODY, project_id=None ) @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') @@ -337,7 +292,7 @@ def test_set_machine_type_should_throw_ex_when_missing_zone(self, mock_hook): zone="", resource_id=RESOURCE_ID, body=SET_MACHINE_TYPE_BODY, - task_id='id' + task_id='id', ) op.execute(None) err = cm.exception @@ -352,7 +307,7 @@ def test_set_machine_type_should_throw_ex_when_missing_resource_id(self, mock_ho zone=GCE_ZONE, resource_id="", body=SET_MACHINE_TYPE_BODY, - task_id='id' + task_id='id', ) op.execute(None) err = cm.exception @@ -363,66 +318,67 @@ def test_set_machine_type_should_throw_ex_when_missing_resource_id(self, mock_ho def test_set_machine_type_should_throw_ex_when_missing_machine_type(self, mock_hook): with self.assertRaises(AirflowException) as cm: op = ComputeEngineSetMachineTypeOperator( - project_id=GCP_PROJECT_ID, - zone=GCE_ZONE, - resource_id=RESOURCE_ID, - body={}, - task_id='id' + project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=RESOURCE_ID, body={}, task_id='id' ) 
op.execute(None) err = cm.exception - self.assertIn( - "The required body field 'machineType' is missing. Please add it.", str(err)) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) - - MOCK_OP_RESPONSE = "{'kind': 'compute#operation', 'id': '8529919847974922736', " \ - "'name': " \ - "'operation-1538578207537-577542784f769-7999ab71-94f9ec1d', " \ - "'zone': 'https://www.googleapis.com/compute/v1/projects/example" \ - "-project/zones/europe-west3-b', 'operationType': " \ - "'setMachineType', 'targetLink': " \ - "'https://www.googleapis.com/compute/v1/projects/example-project" \ - "/zones/europe-west3-b/instances/pa-1', 'targetId': " \ - "'2480086944131075860', 'status': 'DONE', 'user': " \ - "'service-account@example-project.iam.gserviceaccount.com', " \ - "'progress': 100, 'insertTime': '2018-10-03T07:50:07.951-07:00', "\ - "'startTime': '2018-10-03T07:50:08.324-07:00', 'endTime': " \ - "'2018-10-03T07:50:08.484-07:00', 'error': {'errors': [{'code': " \ - "'UNSUPPORTED_OPERATION', 'message': \"Machine type with name " \ - "'machine-type-1' does not exist in zone 'europe-west3-b'.\"}]}, "\ - "'httpErrorStatusCode': 400, 'httpErrorMessage': 'BAD REQUEST', " \ - "'selfLink': " \ - "'https://www.googleapis.com/compute/v1/projects/example-project" \ - "/zones/europe-west3-b/operations/operation-1538578207537" \ - "-577542784f769-7999ab71-94f9ec1d'} " - - @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook' - '._check_zone_operation_status') - @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook' - '._execute_set_machine_type') + self.assertIn("The required body field 'machineType' is missing. Please add it.", str(err)) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) + + MOCK_OP_RESPONSE = ( + "{'kind': 'compute#operation', 'id': '8529919847974922736', " + "'name': " + "'operation-1538578207537-577542784f769-7999ab71-94f9ec1d', " + "'zone': 'https://www.googleapis.com/compute/v1/projects/example" + "-project/zones/europe-west3-b', 'operationType': " + "'setMachineType', 'targetLink': " + "'https://www.googleapis.com/compute/v1/projects/example-project" + "/zones/europe-west3-b/instances/pa-1', 'targetId': " + "'2480086944131075860', 'status': 'DONE', 'user': " + "'service-account@example-project.iam.gserviceaccount.com', " + "'progress': 100, 'insertTime': '2018-10-03T07:50:07.951-07:00', " + "'startTime': '2018-10-03T07:50:08.324-07:00', 'endTime': " + "'2018-10-03T07:50:08.484-07:00', 'error': {'errors': [{'code': " + "'UNSUPPORTED_OPERATION', 'message': \"Machine type with name " + "'machine-type-1' does not exist in zone 'europe-west3-b'.\"}]}, " + "'httpErrorStatusCode': 400, 'httpErrorMessage': 'BAD REQUEST', " + "'selfLink': " + "'https://www.googleapis.com/compute/v1/projects/example-project" + "/zones/europe-west3-b/operations/operation-1538578207537" + "-577542784f769-7999ab71-94f9ec1d'} " + ) + + @mock.patch( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineHook' '._check_zone_operation_status' + ) + @mock.patch( + 'airflow.providers.google.cloud.operators.compute.ComputeEngineHook' '._execute_set_machine_type' + ) @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook.get_conn') def test_set_machine_type_should_handle_and_trim_gce_error( - self, get_conn, _execute_set_machine_type, _check_zone_operation_status): + self, get_conn, 
_execute_set_machine_type, _check_zone_operation_status + ): get_conn.return_value = {} _execute_set_machine_type.return_value = {"name": "test-operation"} - _check_zone_operation_status.return_value = ast.literal_eval( - self.MOCK_OP_RESPONSE) + _check_zone_operation_status.return_value = ast.literal_eval(self.MOCK_OP_RESPONSE) with self.assertRaises(AirflowException) as cm: op = ComputeEngineSetMachineTypeOperator( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=RESOURCE_ID, body=SET_MACHINE_TYPE_BODY, - task_id='id' + task_id='id', ) op.execute(None) err = cm.exception _check_zone_operation_status.assert_called_once_with( - {}, "test-operation", GCP_PROJECT_ID, GCE_ZONE, mock.ANY) + {}, "test-operation", GCP_PROJECT_ID, GCE_ZONE, mock.ANY + ) _execute_set_machine_type.assert_called_once_with( - GCE_ZONE, RESOURCE_ID, SET_MACHINE_TYPE_BODY, GCP_PROJECT_ID) + GCE_ZONE, RESOURCE_ID, SET_MACHINE_TYPE_BODY, GCP_PROJECT_ID + ) # Checking the full message was sometimes failing due to different order # of keys in the serialized JSON self.assertIn("400 BAD REQUEST: {", str(err)) # checking the square bracket trim @@ -445,41 +401,20 @@ def test_set_machine_type_should_handle_and_trim_gce_error( { "kind": "compute#networkInterface", "network": "https://www.googleapis.com/compute/v1/" - "projects/project/global/networks/default", - "accessConfigs": [ - { - "kind": "compute#accessConfig", - "type": "ONE_TO_ONE_NAT", - } - ] + "projects/project/global/networks/default", + "accessConfigs": [{"kind": "compute#accessConfig", "type": "ONE_TO_ONE_NAT",}], }, { "network": "https://www.googleapis.com/compute/v1/" - "projects/project/global/networks/default", - "accessConfigs": [ - { - "kind": "compute#accessConfig", - "networkTier": "PREMIUM" - } - ] - } - ], - "disks": [ - { - "kind": "compute#attachedDisk", - "type": "PERSISTENT", - "licenses": [ - "A String", - ] - } + "projects/project/global/networks/default", + "accessConfigs": [{"kind": "compute#accessConfig", "networkTier": "PREMIUM"}], + }, ], - "metadata": { - "kind": "compute#metadata", - "fingerprint": "GDPUYxlwHe4=" - }, + "disks": [{"kind": "compute#attachedDisk", "type": "PERSISTENT", "licenses": ["A String",]}], + "metadata": {"kind": "compute#metadata", "fingerprint": "GDPUYxlwHe4="}, }, "selfLink": "https://www.googleapis.com/compute/v1/projects/project" - "/global/instanceTemplates/instance-template-test" + "/global/instanceTemplates/instance-template-test", } GCE_INSTANCE_TEMPLATE_BODY_INSERT = { @@ -490,31 +425,17 @@ def test_set_machine_type_should_handle_and_trim_gce_error( "networkInterfaces": [ { "network": "https://www.googleapis.com/compute/v1/" - "projects/project/global/networks/default", - "accessConfigs": [ - { - "type": "ONE_TO_ONE_NAT", - } - ] + "projects/project/global/networks/default", + "accessConfigs": [{"type": "ONE_TO_ONE_NAT",}], }, { "network": "https://www.googleapis.com/compute/v1/" - "projects/project/global/networks/default", - "accessConfigs": [ - { - "networkTier": "PREMIUM" - } - ] - } - ], - "disks": [ - { - "type": "PERSISTENT", - } + "projects/project/global/networks/default", + "accessConfigs": [{"networkTier": "PREMIUM"}], + }, ], - "metadata": { - "fingerprint": "GDPUYxlwHe4=" - }, + "disks": [{"type": "PERSISTENT",}], + "metadata": {"fingerprint": "GDPUYxlwHe4="}, }, } @@ -528,22 +449,20 @@ def test_successful_copy_template(self, mock_hook): mock_hook.return_value.get_instance_template.side_effect = [ HttpError(resp=httplib2.Response({'status': 404}), content=EMPTY_CONTENT), 
GCE_INSTANCE_TEMPLATE_BODY_GET, - GCE_INSTANCE_TEMPLATE_BODY_GET_NEW + GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, ] op = ComputeEngineCopyInstanceTemplateOperator( project_id=GCP_PROJECT_ID, resource_id=GCE_INSTANCE_TEMPLATE_NAME, task_id='id', - body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME} + body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME}, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.insert_instance_template.assert_called_once_with( - project_id=GCP_PROJECT_ID, - body=GCE_INSTANCE_TEMPLATE_BODY_INSERT, - request_id=None + project_id=GCP_PROJECT_ID, body=GCE_INSTANCE_TEMPLATE_BODY_INSERT, request_id=None ) self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result) @@ -552,39 +471,35 @@ def test_successful_copy_template_missing_project_id(self, mock_hook): mock_hook.return_value.get_instance_template.side_effect = [ HttpError(resp=httplib2.Response({'status': 404}), content=EMPTY_CONTENT), GCE_INSTANCE_TEMPLATE_BODY_GET, - GCE_INSTANCE_TEMPLATE_BODY_GET_NEW + GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, ] op = ComputeEngineCopyInstanceTemplateOperator( resource_id=GCE_INSTANCE_TEMPLATE_NAME, task_id='id', - body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME} + body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME}, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.insert_instance_template.assert_called_once_with( - project_id=None, - body=GCE_INSTANCE_TEMPLATE_BODY_INSERT, - request_id=None + project_id=None, body=GCE_INSTANCE_TEMPLATE_BODY_INSERT, request_id=None ) self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result) @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_idempotent_copy_template_when_already_copied(self, mock_hook): - mock_hook.return_value.get_instance_template.side_effect = [ - GCE_INSTANCE_TEMPLATE_BODY_GET_NEW - ] + mock_hook.return_value.get_instance_template.side_effect = [GCE_INSTANCE_TEMPLATE_BODY_GET_NEW] op = ComputeEngineCopyInstanceTemplateOperator( project_id=GCP_PROJECT_ID, resource_id=GCE_INSTANCE_TEMPLATE_NAME, task_id='id', - body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME} + body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME}, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.insert_instance_template.assert_not_called() self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result) @@ -593,19 +508,19 @@ def test_successful_copy_template_with_request_id(self, mock_hook): mock_hook.return_value.get_instance_template.side_effect = [ HttpError(resp=httplib2.Response({'status': 404}), content=EMPTY_CONTENT), GCE_INSTANCE_TEMPLATE_BODY_GET, - GCE_INSTANCE_TEMPLATE_BODY_GET_NEW + GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, ] op = ComputeEngineCopyInstanceTemplateOperator( project_id=GCP_PROJECT_ID, resource_id=GCE_INSTANCE_TEMPLATE_NAME, request_id=GCE_INSTANCE_TEMPLATE_REQUEST_ID, task_id='id', - 
body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME} + body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME}, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.insert_instance_template.assert_called_once_with( project_id=GCP_PROJECT_ID, body=GCE_INSTANCE_TEMPLATE_BODY_INSERT, @@ -618,27 +533,24 @@ def test_successful_copy_template_with_description_fields(self, mock_hook): mock_hook.return_value.get_instance_template.side_effect = [ HttpError(resp=httplib2.Response({'status': 404}), content=EMPTY_CONTENT), GCE_INSTANCE_TEMPLATE_BODY_GET, - GCE_INSTANCE_TEMPLATE_BODY_GET_NEW + GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, ] op = ComputeEngineCopyInstanceTemplateOperator( project_id=GCP_PROJECT_ID, resource_id=GCE_INSTANCE_TEMPLATE_NAME, request_id=GCE_INSTANCE_TEMPLATE_REQUEST_ID, task_id='id', - body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME, - "description": "New description"} + body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME, "description": "New description"}, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) body_insert = deepcopy(GCE_INSTANCE_TEMPLATE_BODY_INSERT) body_insert["description"] = "New description" mock_hook.return_value.insert_instance_template.assert_called_once_with( - project_id=GCP_PROJECT_ID, - body=body_insert, - request_id=GCE_INSTANCE_TEMPLATE_REQUEST_ID, + project_id=GCP_PROJECT_ID, body=body_insert, request_id=GCE_INSTANCE_TEMPLATE_REQUEST_ID, ) self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result) @@ -647,29 +559,27 @@ def test_copy_with_some_validation_warnings(self, mock_hook): mock_hook.return_value.get_instance_template.side_effect = [ HttpError(resp=httplib2.Response({'status': 404}), content=EMPTY_CONTENT), GCE_INSTANCE_TEMPLATE_BODY_GET, - GCE_INSTANCE_TEMPLATE_BODY_GET_NEW + GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, ] op = ComputeEngineCopyInstanceTemplateOperator( project_id=GCP_PROJECT_ID, resource_id=GCE_INSTANCE_TEMPLATE_NAME, task_id='id', - body_patch={"name": GCE_INSTANCE_TEMPLATE_NEW_NAME, - "some_wrong_field": "test", - "properties": { - "some_other_wrong_field": "test" - }} + body_patch={ + "name": GCE_INSTANCE_TEMPLATE_NEW_NAME, + "some_wrong_field": "test", + "properties": {"some_other_wrong_field": "test"}, + }, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) body_insert = deepcopy(GCE_INSTANCE_TEMPLATE_BODY_INSERT) body_insert["some_wrong_field"] = "test" body_insert["properties"]["some_other_wrong_field"] = "test" mock_hook.return_value.insert_instance_template.assert_called_once_with( - project_id=GCP_PROJECT_ID, - body=body_insert, - request_id=None, + project_id=GCP_PROJECT_ID, body=body_insert, request_id=None, ) self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result) @@ -678,7 +588,7 @@ def test_successful_copy_template_with_updated_nested_fields(self, mock_hook): mock_hook.return_value.get_instance_template.side_effect = [ 
HttpError(resp=httplib2.Response({'status': 404}), content=EMPTY_CONTENT), GCE_INSTANCE_TEMPLATE_BODY_GET, - GCE_INSTANCE_TEMPLATE_BODY_GET_NEW + GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, ] op = ComputeEngineCopyInstanceTemplateOperator( project_id=GCP_PROJECT_ID, @@ -686,21 +596,17 @@ def test_successful_copy_template_with_updated_nested_fields(self, mock_hook): task_id='id', body_patch={ "name": GCE_INSTANCE_TEMPLATE_NEW_NAME, - "properties": { - "machineType": "n1-standard-2", - } - } + "properties": {"machineType": "n1-standard-2",}, + }, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) body_insert = deepcopy(GCE_INSTANCE_TEMPLATE_BODY_INSERT) body_insert["properties"]["machineType"] = "n1-standard-2" mock_hook.return_value.insert_instance_template.assert_called_once_with( - project_id=GCP_PROJECT_ID, - body=body_insert, - request_id=None + project_id=GCP_PROJECT_ID, body=body_insert, request_id=None ) self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result) @@ -709,7 +615,7 @@ def test_successful_copy_template_with_smaller_array_fields(self, mock_hook): mock_hook.return_value.get_instance_template.side_effect = [ HttpError(resp=httplib2.Response({'status': 404}), content=EMPTY_CONTENT), GCE_INSTANCE_TEMPLATE_BODY_GET, - GCE_INSTANCE_TEMPLATE_BODY_GET_NEW + GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, ] op = ComputeEngineCopyInstanceTemplateOperator( project_id=GCP_PROJECT_ID, @@ -722,39 +628,27 @@ def test_successful_copy_template_with_smaller_array_fields(self, mock_hook): "networkInterfaces": [ { "network": "https://www.googleapis.com/compute/v1/" - "projects/project/global/networks/default", - "accessConfigs": [ - { - "type": "ONE_TO_ONE_NAT", - "natIP": "8.8.8.8" - } - ] + "projects/project/global/networks/default", + "accessConfigs": [{"type": "ONE_TO_ONE_NAT", "natIP": "8.8.8.8"}], } - ] - } - } + ], + }, + }, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) body_insert = deepcopy(GCE_INSTANCE_TEMPLATE_BODY_INSERT) body_insert["properties"]["networkInterfaces"] = [ { "network": "https://www.googleapis.com/compute/v1/" - "projects/project/global/networks/default", - "accessConfigs": [ - { - "type": "ONE_TO_ONE_NAT", - "natIP": "8.8.8.8" - } - ] + "projects/project/global/networks/default", + "accessConfigs": [{"type": "ONE_TO_ONE_NAT", "natIP": "8.8.8.8"}], } ] mock_hook.return_value.insert_instance_template.assert_called_once_with( - project_id=GCP_PROJECT_ID, - body=body_insert, - request_id=None + project_id=GCP_PROJECT_ID, body=body_insert, request_id=None ) self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result) @@ -763,7 +657,7 @@ def test_successful_copy_template_with_bigger_array_fields(self, mock_hook): mock_hook.return_value.get_instance_template.side_effect = [ HttpError(resp=httplib2.Response({'status': 404}), content=EMPTY_CONTENT), GCE_INSTANCE_TEMPLATE_BODY_GET, - GCE_INSTANCE_TEMPLATE_BODY_GET_NEW + GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, ] op = ComputeEngineCopyInstanceTemplateOperator( project_id=GCP_PROJECT_ID, @@ -773,50 +667,28 @@ def test_successful_copy_template_with_bigger_array_fields(self, mock_hook): "name": 
GCE_INSTANCE_TEMPLATE_NEW_NAME, "properties": { "disks": [ - { - "kind": "compute#attachedDisk", - "type": "SCRATCH", - "licenses": [ - "Updated String", - ] - }, + {"kind": "compute#attachedDisk", "type": "SCRATCH", "licenses": ["Updated String",]}, { "kind": "compute#attachedDisk", "type": "PERSISTENT", - "licenses": [ - "Another String", - ] - } + "licenses": ["Another String",], + }, ], - } - } + }, + }, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) body_insert = deepcopy(GCE_INSTANCE_TEMPLATE_BODY_INSERT) body_insert["properties"]["disks"] = [ - { - "kind": "compute#attachedDisk", - "type": "SCRATCH", - "licenses": [ - "Updated String", - ] - }, - { - "kind": "compute#attachedDisk", - "type": "PERSISTENT", - "licenses": [ - "Another String", - ] - } + {"kind": "compute#attachedDisk", "type": "SCRATCH", "licenses": ["Updated String",]}, + {"kind": "compute#attachedDisk", "type": "PERSISTENT", "licenses": ["Another String",]}, ] mock_hook.return_value.insert_instance_template.assert_called_once_with( - project_id=GCP_PROJECT_ID, - body=body_insert, - request_id=None, + project_id=GCP_PROJECT_ID, body=body_insert, request_id=None, ) self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result) @@ -825,7 +697,7 @@ def test_missing_name(self, mock_hook): mock_hook.return_value.get_instance_template.side_effect = [ HttpError(resp=httplib2.Response({'status': 404}), content=EMPTY_CONTENT), GCE_INSTANCE_TEMPLATE_BODY_GET, - GCE_INSTANCE_TEMPLATE_BODY_GET_NEW + GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, ] with self.assertRaises(AirflowException) as cm: op = ComputeEngineCopyInstanceTemplateOperator( @@ -833,31 +705,34 @@ def test_missing_name(self, mock_hook): resource_id=GCE_INSTANCE_TEMPLATE_NAME, request_id=GCE_INSTANCE_TEMPLATE_REQUEST_ID, task_id='id', - body_patch={"description": "New description"} + body_patch={"description": "New description"}, ) op.execute(None) err = cm.exception - self.assertIn("should contain at least name for the new operator " - "in the 'name' field", str(err)) + self.assertIn("should contain at least name for the new operator " "in the 'name' field", str(err)) mock_hook.assert_not_called() GCE_INSTANCE_GROUP_MANAGER_NAME = "instance-group-test" -GCE_INSTANCE_TEMPLATE_SOURCE_URL = \ - "https://www.googleapis.com/compute/beta/projects/project" \ +GCE_INSTANCE_TEMPLATE_SOURCE_URL = ( + "https://www.googleapis.com/compute/beta/projects/project" "/global/instanceTemplates/instance-template-test" +) -GCE_INSTANCE_TEMPLATE_OTHER_URL = \ - "https://www.googleapis.com/compute/beta/projects/project" \ +GCE_INSTANCE_TEMPLATE_OTHER_URL = ( + "https://www.googleapis.com/compute/beta/projects/project" "/global/instanceTemplates/instance-template-other" +) -GCE_INSTANCE_TEMPLATE_NON_EXISTING_URL = \ - "https://www.googleapis.com/compute/beta/projects/project" \ +GCE_INSTANCE_TEMPLATE_NON_EXISTING_URL = ( + "https://www.googleapis.com/compute/beta/projects/project" "/global/instanceTemplates/instance-template-non-existing" +) -GCE_INSTANCE_TEMPLATE_DESTINATION_URL = \ - "https://www.googleapis.com/compute/beta/projects/project" \ +GCE_INSTANCE_TEMPLATE_DESTINATION_URL = ( + "https://www.googleapis.com/compute/beta/projects/project" "/global/instanceTemplates/instance-template-new" +) GCE_INSTANCE_GROUP_MANAGER_GET = { "kind": "compute#instanceGroupManager", @@ 
-867,17 +742,8 @@ def test_missing_name(self, mock_hook): "zone": "https://www.googleapis.com/compute/beta/projects/project/zones/zone", "instanceTemplate": GCE_INSTANCE_TEMPLATE_SOURCE_URL, "versions": [ - { - "name": "v1", - "instanceTemplate": GCE_INSTANCE_TEMPLATE_SOURCE_URL, - "targetSize": { - "calculated": 1 - } - }, - { - "name": "v2", - "instanceTemplate": GCE_INSTANCE_TEMPLATE_OTHER_URL, - } + {"name": "v1", "instanceTemplate": GCE_INSTANCE_TEMPLATE_SOURCE_URL, "targetSize": {"calculated": 1}}, + {"name": "v2", "instanceTemplate": GCE_INSTANCE_TEMPLATE_OTHER_URL,}, ], "instanceGroup": GCE_INSTANCE_TEMPLATE_SOURCE_URL, "baseInstanceName": GCE_INSTANCE_GROUP_MANAGER_NAME, @@ -891,23 +757,14 @@ def test_missing_name(self, mock_hook): "deleting": 0, "abandoning": 0, "restarting": 0, - "refreshing": 0 - }, - "pendingActions": { - "creating": 0, - "deleting": 0, - "recreating": 0, - "restarting": 0 + "refreshing": 0, }, + "pendingActions": {"creating": 0, "deleting": 0, "recreating": 0, "restarting": 0}, "targetSize": 1, "selfLink": "https://www.googleapis.com/compute/beta/projects/project/zones/" - "zone/instanceGroupManagers/" + GCE_INSTANCE_GROUP_MANAGER_NAME, - "autoHealingPolicies": [ - { - "initialDelaySec": 300 - } - ], - "serviceAccount": "198907790164@cloudservices.gserviceaccount.com" + "zone/instanceGroupManagers/" + GCE_INSTANCE_GROUP_MANAGER_NAME, + "autoHealingPolicies": [{"initialDelaySec": 300}], + "serviceAccount": "198907790164@cloudservices.gserviceaccount.com", } GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH = { @@ -916,14 +773,9 @@ def test_missing_name(self, mock_hook): { "name": "v1", "instanceTemplate": GCE_INSTANCE_TEMPLATE_DESTINATION_URL, - "targetSize": { - "calculated": 1 - } + "targetSize": {"calculated": 1}, }, - { - "name": "v2", - "instanceTemplate": GCE_INSTANCE_TEMPLATE_OTHER_URL, - } + {"name": "v2", "instanceTemplate": GCE_INSTANCE_TEMPLATE_OTHER_URL,}, ], } @@ -932,63 +784,61 @@ def test_missing_name(self, mock_hook): GCE_INSTANCE_GROUP_MANAGER_UPDATE_POLICY = { "type": "OPPORTUNISTIC", "minimalAction": "RESTART", - "maxSurge": { - "fixed": 1 - }, - "maxUnavailable": { - "percent": 10 - }, - "minReadySec": 1800 + "maxSurge": {"fixed": 1}, + "maxUnavailable": {"percent": 10}, + "minReadySec": 1800, } class TestGceInstanceGroupManagerUpdate(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_successful_instance_group_update(self, mock_hook): - mock_hook.return_value.get_instance_group_manager.return_value = \ - deepcopy(GCE_INSTANCE_GROUP_MANAGER_GET) + mock_hook.return_value.get_instance_group_manager.return_value = deepcopy( + GCE_INSTANCE_GROUP_MANAGER_GET + ) op = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, task_id='id', source_template=GCE_INSTANCE_TEMPLATE_SOURCE_URL, - destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL + destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='beta', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='beta', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.patch_instance_group_manager.assert_called_once_with( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, body=GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH, - request_id=None + 
request_id=None, ) self.assertTrue(result) @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_successful_instance_group_update_missing_project_id(self, mock_hook): - mock_hook.return_value.get_instance_group_manager.return_value = \ - deepcopy(GCE_INSTANCE_GROUP_MANAGER_GET) + mock_hook.return_value.get_instance_group_manager.return_value = deepcopy( + GCE_INSTANCE_GROUP_MANAGER_GET + ) op = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, task_id='id', source_template=GCE_INSTANCE_TEMPLATE_SOURCE_URL, - destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL + destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='beta', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='beta', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.patch_instance_group_manager.assert_called_once_with( project_id=None, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, body=GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH, - request_id=None + request_id=None, ) self.assertTrue(result) @@ -996,29 +846,27 @@ def test_successful_instance_group_update_missing_project_id(self, mock_hook): def test_successful_instance_group_update_no_instance_template_field(self, mock_hook): instance_group_manager_no_template = deepcopy(GCE_INSTANCE_GROUP_MANAGER_GET) del instance_group_manager_no_template['instanceTemplate'] - mock_hook.return_value.get_instance_group_manager.return_value = \ - instance_group_manager_no_template + mock_hook.return_value.get_instance_group_manager.return_value = instance_group_manager_no_template op = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, task_id='id', source_template=GCE_INSTANCE_TEMPLATE_SOURCE_URL, - destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL + destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='beta', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) - expected_patch_no_instance_template = \ - deepcopy(GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH) + mock_hook.assert_called_once_with( + api_version='beta', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) + expected_patch_no_instance_template = deepcopy(GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH) del expected_patch_no_instance_template['instanceTemplate'] mock_hook.return_value.patch_instance_group_manager.assert_called_once_with( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, body=expected_patch_no_instance_template, - request_id=None + request_id=None, ) self.assertTrue(result) @@ -1026,36 +874,35 @@ def test_successful_instance_group_update_no_instance_template_field(self, mock_ def test_successful_instance_group_update_no_versions_field(self, mock_hook): instance_group_manager_no_versions = deepcopy(GCE_INSTANCE_GROUP_MANAGER_GET) del instance_group_manager_no_versions['versions'] - mock_hook.return_value.get_instance_group_manager.return_value = \ - instance_group_manager_no_versions + mock_hook.return_value.get_instance_group_manager.return_value = instance_group_manager_no_versions op = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( project_id=GCP_PROJECT_ID, 
zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, task_id='id', source_template=GCE_INSTANCE_TEMPLATE_SOURCE_URL, - destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL + destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='beta', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) - expected_patch_no_versions = \ - deepcopy(GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH) + mock_hook.assert_called_once_with( + api_version='beta', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) + expected_patch_no_versions = deepcopy(GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH) del expected_patch_no_versions['versions'] mock_hook.return_value.patch_instance_group_manager.assert_called_once_with( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, body=expected_patch_no_versions, - request_id=None + request_id=None, ) self.assertTrue(result) @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_successful_instance_group_update_with_update_policy(self, mock_hook): - mock_hook.return_value.get_instance_group_manager.return_value = \ - deepcopy(GCE_INSTANCE_GROUP_MANAGER_GET) + mock_hook.return_value.get_instance_group_manager.return_value = deepcopy( + GCE_INSTANCE_GROUP_MANAGER_GET + ) op = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, @@ -1063,28 +910,28 @@ def test_successful_instance_group_update_with_update_policy(self, mock_hook): task_id='id', update_policy=GCE_INSTANCE_GROUP_MANAGER_UPDATE_POLICY, source_template=GCE_INSTANCE_TEMPLATE_SOURCE_URL, - destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL + destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='beta', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) - expected_patch_with_update_policy = \ - deepcopy(GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH) + mock_hook.assert_called_once_with( + api_version='beta', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) + expected_patch_with_update_policy = deepcopy(GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH) expected_patch_with_update_policy['updatePolicy'] = GCE_INSTANCE_GROUP_MANAGER_UPDATE_POLICY mock_hook.return_value.patch_instance_group_manager.assert_called_once_with( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, body=expected_patch_with_update_policy, - request_id=None + request_id=None, ) self.assertTrue(result) @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_successful_instance_group_update_with_request_id(self, mock_hook): - mock_hook.return_value.get_instance_group_manager.return_value = \ - deepcopy(GCE_INSTANCE_GROUP_MANAGER_GET) + mock_hook.return_value.get_instance_group_manager.return_value = deepcopy( + GCE_INSTANCE_GROUP_MANAGER_GET + ) op = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, @@ -1095,15 +942,15 @@ def test_successful_instance_group_update_with_request_id(self, mock_hook): destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='beta', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='beta', gcp_conn_id='google_cloud_default', 
impersonation_chain=None, + ) mock_hook.return_value.patch_instance_group_manager.assert_called_once_with( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, body=GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH, - request_id=GCE_INSTANCE_GROUP_MANAGER_REQUEST_ID + request_id=GCE_INSTANCE_GROUP_MANAGER_REQUEST_ID, ) self.assertTrue(result) @@ -1117,26 +964,27 @@ def test_try_to_use_api_v1(self, _): task_id='id', api_version='v1', source_template=GCE_INSTANCE_TEMPLATE_SOURCE_URL, - destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL + destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL, ) err = cm.exception self.assertIn("Use beta api version or above", str(err)) @mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook') def test_try_to_use_non_existing_template(self, mock_hook): - mock_hook.return_value.get_instance_group_manager.return_value = \ - deepcopy(GCE_INSTANCE_GROUP_MANAGER_GET) + mock_hook.return_value.get_instance_group_manager.return_value = deepcopy( + GCE_INSTANCE_GROUP_MANAGER_GET + ) op = ComputeEngineInstanceGroupUpdateManagerTemplateOperator( project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER_NAME, task_id='id', source_template=GCE_INSTANCE_TEMPLATE_NON_EXISTING_URL, - destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL + destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL, ) result = op.execute(None) - mock_hook.assert_called_once_with(api_version='beta', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='beta', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.patch_instance_group_manager.assert_not_called() self.assertTrue(result) diff --git a/tests/providers/google/cloud/operators/test_compute_system_helper.py b/tests/providers/google/cloud/operators/test_compute_system_helper.py index 99b1956df1e81..2663ac87f0826 100755 --- a/tests/providers/google/cloud/operators/test_compute_system_helper.py +++ b/tests/providers/google/cloud/operators/test_compute_system_helper.py @@ -24,76 +24,166 @@ GCE_INSTANCE = os.environ.get('GCE_INSTANCE', 'testinstance') GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project') -GCE_INSTANCE_GROUP_MANAGER_NAME = os.environ.get('GCE_INSTANCE_GROUP_MANAGER_NAME', - 'instance-group-test') +GCE_INSTANCE_GROUP_MANAGER_NAME = os.environ.get('GCE_INSTANCE_GROUP_MANAGER_NAME', 'instance-group-test') GCE_ZONE = os.environ.get('GCE_ZONE', 'europe-west1-b') -GCE_TEMPLATE_NAME = os.environ.get('GCE_TEMPLATE_NAME', - 'instance-template-test') -GCE_NEW_TEMPLATE_NAME = os.environ.get('GCE_NEW_TEMPLATE_NAME', - 'instance-template-test-new') +GCE_TEMPLATE_NAME = os.environ.get('GCE_TEMPLATE_NAME', 'instance-template-test') +GCE_NEW_TEMPLATE_NAME = os.environ.get('GCE_NEW_TEMPLATE_NAME', 'instance-template-test-new') class GCPComputeTestHelper(LoggingCommandExecutor): - def delete_instance(self): - self.execute_cmd([ - 'gcloud', 'beta', 'compute', '--project', GCP_PROJECT_ID, - '--quiet', '--verbosity=none', - 'instances', 'delete', GCE_INSTANCE, '--zone', GCE_ZONE, - ]) + self.execute_cmd( + [ + 'gcloud', + 'beta', + 'compute', + '--project', + GCP_PROJECT_ID, + '--quiet', + '--verbosity=none', + 'instances', + 'delete', + GCE_INSTANCE, + '--zone', + GCE_ZONE, + ] + ) def create_instance(self): - self.execute_cmd([ - 'gcloud', 'beta', 'compute', '--project', GCP_PROJECT_ID, '--quiet', - 'instances', 'create', GCE_INSTANCE, - 
'--zone', GCE_ZONE - ]) + self.execute_cmd( + [ + 'gcloud', + 'beta', + 'compute', + '--project', + GCP_PROJECT_ID, + '--quiet', + 'instances', + 'create', + GCE_INSTANCE, + '--zone', + GCE_ZONE, + ] + ) def delete_instance_group_and_template(self, silent=False): - self.execute_cmd([ - 'gcloud', 'beta', 'compute', '--project', GCP_PROJECT_ID, - '--quiet', '--verbosity=none', - 'instance-groups', 'managed', 'delete', GCE_INSTANCE_GROUP_MANAGER_NAME, - '--zone', GCE_ZONE - ], silent=silent) - self.execute_cmd([ - 'gcloud', 'beta', 'compute', '--project', GCP_PROJECT_ID, - '--quiet', '--verbosity=none', - 'instance-templates', 'delete', GCE_NEW_TEMPLATE_NAME - ], silent=silent) - self.execute_cmd([ - 'gcloud', 'beta', 'compute', - '--project', GCP_PROJECT_ID, - '--quiet', '--verbosity=none', - 'instance-templates', 'delete', GCE_TEMPLATE_NAME - ], silent=silent) + self.execute_cmd( + [ + 'gcloud', + 'beta', + 'compute', + '--project', + GCP_PROJECT_ID, + '--quiet', + '--verbosity=none', + 'instance-groups', + 'managed', + 'delete', + GCE_INSTANCE_GROUP_MANAGER_NAME, + '--zone', + GCE_ZONE, + ], + silent=silent, + ) + self.execute_cmd( + [ + 'gcloud', + 'beta', + 'compute', + '--project', + GCP_PROJECT_ID, + '--quiet', + '--verbosity=none', + 'instance-templates', + 'delete', + GCE_NEW_TEMPLATE_NAME, + ], + silent=silent, + ) + self.execute_cmd( + [ + 'gcloud', + 'beta', + 'compute', + '--project', + GCP_PROJECT_ID, + '--quiet', + '--verbosity=none', + 'instance-templates', + 'delete', + GCE_TEMPLATE_NAME, + ], + silent=silent, + ) def create_instance_group_and_template(self): - self.execute_cmd([ - 'gcloud', 'beta', 'compute', '--project', GCP_PROJECT_ID, '--quiet', - 'instance-templates', 'create', GCE_TEMPLATE_NAME - ]) - self.execute_cmd([ - 'gcloud', 'beta', 'compute', '--project', GCP_PROJECT_ID, '--quiet', - 'instance-groups', 'managed', 'create', GCE_INSTANCE_GROUP_MANAGER_NAME, - '--template', GCE_TEMPLATE_NAME, - '--zone', GCE_ZONE, '--size=1' - ]) - self.execute_cmd([ - 'gcloud', 'beta', 'compute', '--project', GCP_PROJECT_ID, '--quiet', - 'instance-groups', 'managed', 'wait-until-stable', - GCE_INSTANCE_GROUP_MANAGER_NAME, - '--zone', GCE_ZONE - ]) + self.execute_cmd( + [ + 'gcloud', + 'beta', + 'compute', + '--project', + GCP_PROJECT_ID, + '--quiet', + 'instance-templates', + 'create', + GCE_TEMPLATE_NAME, + ] + ) + self.execute_cmd( + [ + 'gcloud', + 'beta', + 'compute', + '--project', + GCP_PROJECT_ID, + '--quiet', + 'instance-groups', + 'managed', + 'create', + GCE_INSTANCE_GROUP_MANAGER_NAME, + '--template', + GCE_TEMPLATE_NAME, + '--zone', + GCE_ZONE, + '--size=1', + ] + ) + self.execute_cmd( + [ + 'gcloud', + 'beta', + 'compute', + '--project', + GCP_PROJECT_ID, + '--quiet', + 'instance-groups', + 'managed', + 'wait-until-stable', + GCE_INSTANCE_GROUP_MANAGER_NAME, + '--zone', + GCE_ZONE, + ] + ) if __name__ == '__main__': parser = argparse.ArgumentParser( - description='Create or delete GCE instances/instance groups for system tests.') - parser.add_argument('--action', dest='action', required=True, - choices=('create-instance', 'delete-instance', - 'create-instance-group', 'delete-instance-group', - 'before-tests', 'after-tests')) + description='Create or delete GCE instances/instance groups for system tests.' 
+ ) + parser.add_argument( + '--action', + dest='action', + required=True, + choices=( + 'create-instance', + 'delete-instance', + 'create-instance-group', + 'delete-instance-group', + 'before-tests', + 'after-tests', + ), + ) action = parser.parse_args().action helper = GCPComputeTestHelper() diff --git a/tests/providers/google/cloud/operators/test_datacatalog.py b/tests/providers/google/cloud/operators/test_datacatalog.py index e303b1e63717c..44c60742aadc0 100644 --- a/tests/providers/google/cloud/operators/test_datacatalog.py +++ b/tests/providers/google/cloud/operators/test_datacatalog.py @@ -23,16 +23,26 @@ from google.cloud.datacatalog_v1beta1.types import Entry, EntryGroup, Tag, TagTemplate, TagTemplateField from airflow.providers.google.cloud.operators.datacatalog import ( - CloudDataCatalogCreateEntryGroupOperator, CloudDataCatalogCreateEntryOperator, - CloudDataCatalogCreateTagOperator, CloudDataCatalogCreateTagTemplateFieldOperator, - CloudDataCatalogCreateTagTemplateOperator, CloudDataCatalogDeleteEntryGroupOperator, - CloudDataCatalogDeleteEntryOperator, CloudDataCatalogDeleteTagOperator, - CloudDataCatalogDeleteTagTemplateFieldOperator, CloudDataCatalogDeleteTagTemplateOperator, - CloudDataCatalogGetEntryGroupOperator, CloudDataCatalogGetEntryOperator, - CloudDataCatalogGetTagTemplateOperator, CloudDataCatalogListTagsOperator, - CloudDataCatalogLookupEntryOperator, CloudDataCatalogRenameTagTemplateFieldOperator, - CloudDataCatalogSearchCatalogOperator, CloudDataCatalogUpdateEntryOperator, - CloudDataCatalogUpdateTagOperator, CloudDataCatalogUpdateTagTemplateFieldOperator, + CloudDataCatalogCreateEntryGroupOperator, + CloudDataCatalogCreateEntryOperator, + CloudDataCatalogCreateTagOperator, + CloudDataCatalogCreateTagTemplateFieldOperator, + CloudDataCatalogCreateTagTemplateOperator, + CloudDataCatalogDeleteEntryGroupOperator, + CloudDataCatalogDeleteEntryOperator, + CloudDataCatalogDeleteTagOperator, + CloudDataCatalogDeleteTagTemplateFieldOperator, + CloudDataCatalogDeleteTagTemplateOperator, + CloudDataCatalogGetEntryGroupOperator, + CloudDataCatalogGetEntryOperator, + CloudDataCatalogGetTagTemplateOperator, + CloudDataCatalogListTagsOperator, + CloudDataCatalogLookupEntryOperator, + CloudDataCatalogRenameTagTemplateFieldOperator, + CloudDataCatalogSearchCatalogOperator, + CloudDataCatalogUpdateEntryOperator, + CloudDataCatalogUpdateTagOperator, + CloudDataCatalogUpdateTagTemplateFieldOperator, CloudDataCatalogUpdateTagTemplateOperator, ) @@ -110,8 +120,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ti = mock.MagicMock() result = task.execute(context={"task_instance": ti}) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.create_entry.assert_called_once_with( location=TEST_LOCATION, @@ -150,8 +159,7 @@ def test_assert_valid_hook_call_when_exists(self, mock_hook) -> None: ti = mock.MagicMock() result = task.execute(context={"task_instance": ti}) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.create_entry.assert_called_once_with( location=TEST_LOCATION, @@ -197,8 +205,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ti = mock.MagicMock() result = task.execute(context={"task_instance": ti}) 
mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.create_entry_group.assert_called_once_with( location=TEST_LOCATION, @@ -236,8 +243,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ti = mock.MagicMock() result = task.execute(context={"task_instance": ti}) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.create_tag.assert_called_once_with( location=TEST_LOCATION, @@ -275,8 +281,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ti = mock.MagicMock() result = task.execute(context={"task_instance": ti}) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.create_tag_template.assert_called_once_with( location=TEST_LOCATION, @@ -313,8 +318,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ti = mock.MagicMock() result = task.execute(context={"task_instance": ti}) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.create_tag_template_field.assert_called_once_with( location=TEST_LOCATION, @@ -347,8 +351,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_entry.assert_called_once_with( location=TEST_LOCATION, @@ -377,8 +380,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_entry_group.assert_called_once_with( location=TEST_LOCATION, @@ -408,8 +410,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_tag.assert_called_once_with( location=TEST_LOCATION, @@ -440,8 +441,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_tag_template.assert_called_once_with( location=TEST_LOCATION, @@ -472,8 +472,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) 
mock_hook.return_value.delete_tag_template_field.assert_called_once_with( location=TEST_LOCATION, @@ -504,8 +503,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.get_entry.assert_called_once_with( location=TEST_LOCATION, @@ -535,8 +533,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.get_entry_group.assert_called_once_with( location=TEST_LOCATION, @@ -565,8 +562,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.get_tag_template.assert_called_once_with( location=TEST_LOCATION, @@ -596,8 +592,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.list_tags.assert_called_once_with( location=TEST_LOCATION, @@ -626,8 +621,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.lookup_entry.assert_called_once_with( linked_resource=TEST_LINKED_RESOURCE, @@ -656,8 +650,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.rename_tag_template_field.assert_called_once_with( location=TEST_LOCATION, @@ -688,8 +681,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.search_catalog.assert_called_once_with( scope=TEST_SCOPE, @@ -721,8 +713,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.update_entry.assert_called_once_with( entry=TEST_ENTRY, @@ -757,8 +748,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + 
gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.update_tag.assert_called_once_with( tag=TEST_TAG_ID, @@ -792,8 +782,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.update_tag_template.assert_called_once_with( tag_template=TEST_TAG_TEMPLATE_ID, @@ -827,8 +816,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: ) task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=TEST_GCP_CONN_ID, - impersonation_chain=TEST_IMPERSONATION_CHAIN, + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.update_tag_template_field.assert_called_once_with( tag_template_field=TEST_TAG_TEMPLATE_FIELD, diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py index 402fd04efb0df..cb5378bd8a39a 100644 --- a/tests/providers/google/cloud/operators/test_dataflow.py +++ b/tests/providers/google/cloud/operators/test_dataflow.py @@ -22,7 +22,9 @@ import mock from airflow.providers.google.cloud.operators.dataflow import ( - CheckJobRunning, DataflowCreateJavaJobOperator, DataflowCreatePythonJobOperator, + CheckJobRunning, + DataflowCreateJavaJobOperator, + DataflowCreatePythonJobOperator, DataflowTemplatedJobStartOperator, ) from airflow.version import version @@ -32,7 +34,7 @@ TEMPLATE = 'gs://dataflow-templates/wordcount/template_file' PARAMETERS = { 'inputFile': 'gs://dataflow-samples/shakespeare/kinglear.txt', - 'output': 'gs://test/output/my_output' + 'output': 'gs://test/output/my_output', } PY_FILE = 'gs://my-bucket/my-object.py' PY_INTERPRETER = 'python3' @@ -47,16 +49,13 @@ 'project': 'test', 'stagingLocation': 'gs://test/staging', 'tempLocation': 'gs://test/temp', - 'zone': 'us-central1-f' -} -ADDITIONAL_OPTIONS = { - 'output': 'gs://test/output', - 'labels': {'foo': 'bar'} + 'zone': 'us-central1-f', } +ADDITIONAL_OPTIONS = {'output': 'gs://test/output', 'labels': {'foo': 'bar'}} TEST_VERSION = 'v{}'.format(version.replace('.', '-').replace('+', '-')) EXPECTED_ADDITIONAL_OPTIONS = { 'output': 'gs://test/output', - 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION} + 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION}, } POLL_SLEEP = 30 GCS_HOOK_STRING = 'airflow.providers.google.cloud.operators.dataflow.{}' @@ -64,7 +63,6 @@ class TestDataflowPythonOperator(unittest.TestCase): - def setUp(self): self.dataflow = DataflowCreatePythonJobOperator( task_id=TASK_ID, @@ -74,7 +72,7 @@ def setUp(self): dataflow_default_options=DEFAULT_OPTIONS_PYTHON, options=ADDITIONAL_OPTIONS, poll_sleep=POLL_SLEEP, - location=TEST_LOCATION + location=TEST_LOCATION, ) def test_init(self): @@ -85,10 +83,8 @@ def test_init(self): self.assertEqual(self.dataflow.py_options, PY_OPTIONS) self.assertEqual(self.dataflow.py_interpreter, PY_INTERPRETER) self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP) - self.assertEqual(self.dataflow.dataflow_default_options, - DEFAULT_OPTIONS_PYTHON) - self.assertEqual(self.dataflow.options, - EXPECTED_ADDITIONAL_OPTIONS) + self.assertEqual(self.dataflow.dataflow_default_options, DEFAULT_OPTIONS_PYTHON) + self.assertEqual(self.dataflow.options, EXPECTED_ADDITIONAL_OPTIONS) 
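The setUp/test_init pair above pins down the constructor arguments of DataflowCreatePythonJobOperator. As a hedged illustration of how those same arguments might look inside a DAG (all values below are placeholders invented for the sketch, not taken from this patch):

from airflow import models
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator
from airflow.utils.dates import days_ago

with models.DAG(
    "example_dataflow_python", start_date=days_ago(1), schedule_interval=None,  # illustrative DAG id/schedule
) as dag:
    start_python_job = DataflowCreatePythonJobOperator(
        task_id="start_python_job",
        py_file="gs://my-bucket/wordcount.py",  # placeholder GCS path to the pipeline script
        job_name="start-python-job",
        py_options=[],
        py_interpreter="python3",
        dataflow_default_options={"project": "my-project", "staging_location": "gs://my-bucket/staging"},
        options={"output": "gs://my-bucket/output", "labels": {"foo": "bar"}},
        poll_sleep=30,
        location="europe-west3",  # placeholder region
    )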
@mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook') @@ -105,7 +101,7 @@ def test_exec(self, gcs_hook, dataflow_mock): 'project': 'test', 'staging_location': 'gs://test/staging', 'output': 'gs://test/output', - 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION} + 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION}, } gcs_provide_file.assert_called_once_with(object_url=PY_FILE) start_python_hook.assert_called_once_with( @@ -118,13 +114,12 @@ def test_exec(self, gcs_hook, dataflow_mock): py_system_site_packages=False, on_new_job_id_callback=mock.ANY, project_id=None, - location=TEST_LOCATION + location=TEST_LOCATION, ) self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow')) class TestDataflowJavaOperator(unittest.TestCase): - def setUp(self): self.dataflow = DataflowCreateJavaJobOperator( task_id=TASK_ID, @@ -134,7 +129,7 @@ def setUp(self): dataflow_default_options=DEFAULT_OPTIONS_JAVA, options=ADDITIONAL_OPTIONS, poll_sleep=POLL_SLEEP, - location=TEST_LOCATION + location=TEST_LOCATION, ) def test_init(self): @@ -142,12 +137,10 @@ def test_init(self): self.assertEqual(self.dataflow.task_id, TASK_ID) self.assertEqual(self.dataflow.job_name, JOB_NAME) self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP) - self.assertEqual(self.dataflow.dataflow_default_options, - DEFAULT_OPTIONS_JAVA) + self.assertEqual(self.dataflow.dataflow_default_options, DEFAULT_OPTIONS_JAVA) self.assertEqual(self.dataflow.job_class, JOB_CLASS) self.assertEqual(self.dataflow.jar, JAR_FILE) - self.assertEqual(self.dataflow.options, - EXPECTED_ADDITIONAL_OPTIONS) + self.assertEqual(self.dataflow.options, EXPECTED_ADDITIONAL_OPTIONS) self.assertEqual(self.dataflow.check_if_running, CheckJobRunning.WaitForRun) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') @@ -172,7 +165,7 @@ def test_exec(self, gcs_hook, dataflow_mock): multiple_jobs=None, on_new_job_id_callback=mock.ANY, project_id=None, - location=TEST_LOCATION + location=TEST_LOCATION, ) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') @@ -192,7 +185,8 @@ def test_check_job_running_exec(self, gcs_hook, dataflow_mock): gcs_provide_file.assert_not_called() start_java_hook.assert_not_called() dataflow_running.assert_called_once_with( - name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION) + name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION + ) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook') @@ -218,10 +212,11 @@ def test_check_job_not_running_exec(self, gcs_hook, dataflow_mock): multiple_jobs=None, on_new_job_id_callback=mock.ANY, project_id=None, - location=TEST_LOCATION + location=TEST_LOCATION, ) dataflow_running.assert_called_once_with( - name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION) + name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION + ) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook') @@ -248,7 +243,7 @@ def test_check_multiple_job_exec(self, gcs_hook, dataflow_mock): multiple_jobs=True, on_new_job_id_callback=mock.ANY, project_id=None, - location=TEST_LOCATION + location=TEST_LOCATION, ) dataflow_running.assert_called_once_with( name=JOB_NAME, variables=mock.ANY, project_id=None, 
location=TEST_LOCATION @@ -256,7 +251,6 @@ def test_check_multiple_job_exec(self, gcs_hook, dataflow_mock): class TestDataflowTemplateOperator(unittest.TestCase): - def setUp(self): self.dataflow = DataflowTemplatedJobStartOperator( task_id=TASK_ID, @@ -266,7 +260,7 @@ def setUp(self): options=DEFAULT_OPTIONS_TEMPLATE, dataflow_default_options={"EXTRA_OPTION": "TEST_A"}, poll_sleep=POLL_SLEEP, - location=TEST_LOCATION + location=TEST_LOCATION, ) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') @@ -283,7 +277,7 @@ def test_exec(self, dataflow_mock): 'stagingLocation': 'gs://test/staging', 'tempLocation': 'gs://test/temp', 'zone': 'us-central1-f', - 'EXTRA_OPTION': "TEST_A" + 'EXTRA_OPTION': "TEST_A", } start_template_hook.assert_called_once_with( job_name=JOB_NAME, @@ -292,5 +286,5 @@ def test_exec(self, dataflow_mock): dataflow_template=TEMPLATE, on_new_job_id_callback=mock.ANY, project_id=None, - location=TEST_LOCATION + location=TEST_LOCATION, ) diff --git a/tests/providers/google/cloud/operators/test_datafusion.py b/tests/providers/google/cloud/operators/test_datafusion.py index 1c296c519ea59..74db0b0e30901 100644 --- a/tests/providers/google/cloud/operators/test_datafusion.py +++ b/tests/providers/google/cloud/operators/test_datafusion.py @@ -19,11 +19,16 @@ from airflow import DAG from airflow.providers.google.cloud.operators.datafusion import ( - CloudDataFusionCreateInstanceOperator, CloudDataFusionCreatePipelineOperator, - CloudDataFusionDeleteInstanceOperator, CloudDataFusionDeletePipelineOperator, - CloudDataFusionGetInstanceOperator, CloudDataFusionListPipelinesOperator, - CloudDataFusionRestartInstanceOperator, CloudDataFusionStartPipelineOperator, - CloudDataFusionStopPipelineOperator, CloudDataFusionUpdateInstanceOperator, + CloudDataFusionCreateInstanceOperator, + CloudDataFusionCreatePipelineOperator, + CloudDataFusionDeleteInstanceOperator, + CloudDataFusionDeletePipelineOperator, + CloudDataFusionGetInstanceOperator, + CloudDataFusionListPipelinesOperator, + CloudDataFusionRestartInstanceOperator, + CloudDataFusionStartPipelineOperator, + CloudDataFusionStopPipelineOperator, + CloudDataFusionUpdateInstanceOperator, ) HOOK_STR = "airflow.providers.google.cloud.operators.datafusion.DataFusionHook" @@ -67,10 +72,7 @@ class TestCloudDataFusionRestartInstanceOperator: @mock.patch(HOOK_STR) def test_execute(self, mock_hook): op = CloudDataFusionRestartInstanceOperator( - task_id="test_taks", - instance_name=INSTANCE_NAME, - location=LOCATION, - project_id=PROJECT_ID, + task_id="test_taks", instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID, ) op.execute({}) @@ -93,10 +95,7 @@ def test_execute(self, mock_hook): op.execute({}) mock_hook.return_value.create_instance.assert_called_once_with( - instance_name=INSTANCE_NAME, - instance=INSTANCE, - location=LOCATION, - project_id=PROJECT_ID, + instance_name=INSTANCE_NAME, instance=INSTANCE, location=LOCATION, project_id=PROJECT_ID, ) assert mock_hook.return_value.wait_for_operation.call_count == 1 @@ -105,10 +104,7 @@ class TestCloudDataFusionDeleteInstanceOperator: @mock.patch(HOOK_STR) def test_execute(self, mock_hook): op = CloudDataFusionDeleteInstanceOperator( - task_id="test_taks", - instance_name=INSTANCE_NAME, - location=LOCATION, - project_id=PROJECT_ID, + task_id="test_taks", instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID, ) op.execute({}) @@ -122,10 +118,7 @@ class TestCloudDataFusionGetInstanceOperator: @mock.patch(HOOK_STR) def test_execute(self, 
mock_hook): op = CloudDataFusionGetInstanceOperator( - task_id="test_taks", - instance_name=INSTANCE_NAME, - location=LOCATION, - project_id=PROJECT_ID, + task_id="test_taks", instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID, ) op.execute({}) @@ -153,10 +146,7 @@ def test_execute(self, mock_hook): ) mock_hook.return_value.create_pipeline.assert_called_once_with( - instance_url=INSTANCE_URL, - pipeline_name=PIPELINE_NAME, - pipeline=PIPELINE, - namespace=NAMESPACE, + instance_url=INSTANCE_URL, pipeline_name=PIPELINE_NAME, pipeline=PIPELINE, namespace=NAMESPACE, ) @@ -179,10 +169,7 @@ def test_execute(self, mock_hook): ) mock_hook.return_value.delete_pipeline.assert_called_once_with( - instance_url=INSTANCE_URL, - pipeline_name=PIPELINE_NAME, - namespace=NAMESPACE, - version_id="1.12", + instance_url=INSTANCE_URL, pipeline_name=PIPELINE_NAME, namespace=NAMESPACE, version_id="1.12", ) @@ -198,7 +185,7 @@ def test_execute(self, mock_hook): namespace=NAMESPACE, location=LOCATION, project_id=PROJECT_ID, - runtime_args=RUNTIME_ARGS + runtime_args=RUNTIME_ARGS, ) op.dag = mock.MagicMock(spec=DAG, task_dict={}, dag_id="test") diff --git a/tests/providers/google/cloud/operators/test_dataprep.py b/tests/providers/google/cloud/operators/test_dataprep.py index 73ccee62c92ec..ce01d5b3b30c1 100644 --- a/tests/providers/google/cloud/operators/test_dataprep.py +++ b/tests/providers/google/cloud/operators/test_dataprep.py @@ -24,9 +24,7 @@ class TestDataprepGetJobsForJobGroupOperator(TestCase): - @mock.patch( - "airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook" - ) + @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") def test_execute(self, hook_mock): op = DataprepGetJobsForJobGroupOperator(job_id=JOB_ID, task_id=TASK_ID) op.execute(context={}) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index f091b2739bea0..ac705c1ff7376 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -26,11 +26,20 @@ from airflow import AirflowException from airflow.providers.google.cloud.operators.dataproc import ( - ClusterGenerator, DataprocCreateClusterOperator, DataprocDeleteClusterOperator, - DataprocInstantiateInlineWorkflowTemplateOperator, DataprocInstantiateWorkflowTemplateOperator, - DataprocScaleClusterOperator, DataprocSubmitHadoopJobOperator, DataprocSubmitHiveJobOperator, - DataprocSubmitJobOperator, DataprocSubmitPigJobOperator, DataprocSubmitPySparkJobOperator, - DataprocSubmitSparkJobOperator, DataprocSubmitSparkSqlJobOperator, DataprocUpdateClusterOperator, + ClusterGenerator, + DataprocCreateClusterOperator, + DataprocDeleteClusterOperator, + DataprocInstantiateInlineWorkflowTemplateOperator, + DataprocInstantiateWorkflowTemplateOperator, + DataprocScaleClusterOperator, + DataprocSubmitHadoopJobOperator, + DataprocSubmitHiveJobOperator, + DataprocSubmitJobOperator, + DataprocSubmitPigJobOperator, + DataprocSubmitPySparkJobOperator, + DataprocSubmitSparkJobOperator, + DataprocSubmitSparkSqlJobOperator, + DataprocUpdateClusterOperator, ) from airflow.version import version as airflow_version @@ -52,8 +61,7 @@ "cluster_name": CLUSTER_NAME, "config": { "gce_cluster_config": { - "zone_uri": "https://www.googleapis.com/compute/v1/projects/" - "project_id/zones/zone", + "zone_uri": "https://www.googleapis.com/compute/v1/projects/" "project_id/zones/zone", "metadata": {"metadata": 
"data"}, "network_uri": "network_uri", "subnetwork_uri": "subnetwork_uri", @@ -65,58 +73,41 @@ "master_config": { "num_instances": 2, "machine_type_uri": "https://www.googleapis.com/compute/v1/projects/" - "project_id/zones/zone/machineTypes/master_machine_type", - "disk_config": { - "boot_disk_type": "master_disk_type", - "boot_disk_size_gb": 128, - }, + "project_id/zones/zone/machineTypes/master_machine_type", + "disk_config": {"boot_disk_type": "master_disk_type", "boot_disk_size_gb": 128,}, "image_uri": "https://www.googleapis.com/compute/beta/projects/" - "custom_image_project_id/global/images/custom_image", + "custom_image_project_id/global/images/custom_image", }, "worker_config": { "num_instances": 2, "machine_type_uri": "https://www.googleapis.com/compute/v1/projects/" - "project_id/zones/zone/machineTypes/worker_machine_type", - "disk_config": { - "boot_disk_type": "worker_disk_type", - "boot_disk_size_gb": 256, - }, + "project_id/zones/zone/machineTypes/worker_machine_type", + "disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256,}, "image_uri": "https://www.googleapis.com/compute/beta/projects/" - "custom_image_project_id/global/images/custom_image", + "custom_image_project_id/global/images/custom_image", }, "secondary_worker_config": { "num_instances": 4, "machine_type_uri": "https://www.googleapis.com/compute/v1/projects/" - "project_id/zones/zone/machineTypes/worker_machine_type", - "disk_config": { - "boot_disk_type": "worker_disk_type", - "boot_disk_size_gb": 256, - }, + "project_id/zones/zone/machineTypes/worker_machine_type", + "disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256,}, "is_preemptible": True, }, "software_config": { "properties": {"properties": "data"}, "optional_components": ["optional_components"], }, - "lifecycle_config": { - "idle_delete_ttl": "60s", - "auto_delete_time": "2019-09-12T00:00:00.000000Z", - }, + "lifecycle_config": {"idle_delete_ttl": "60s", "auto_delete_time": "2019-09-12T00:00:00.000000Z",}, "encryption_config": {"gce_pd_kms_key_name": "customer_managed_key"}, "autoscaling_config": {"policy_uri": "autoscaling_policy"}, "config_bucket": "storage_bucket", - "initialization_actions": [ - {"executable_file": "init_actions_uris", "execution_timeout": "600s"} - ], + "initialization_actions": [{"executable_file": "init_actions_uris", "execution_timeout": "600s"}], }, "labels": {"labels": "data", "airflow-version": AIRFLOW_VERSION}, } UPDATE_MASK = { - "paths": [ - "config.worker_config.num_instances", - "config.secondary_worker_config.num_instances", - ] + "paths": ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances",] } TIMEOUT = 120 @@ -215,8 +206,7 @@ def test_execute(self, mock_hook): ) op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_cluster.assert_called_once_with( region=GCP_LOCATION, @@ -246,8 +236,7 @@ def test_execute_if_cluster_exists(self, mock_hook): ) op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_cluster.assert_called_once_with( region=GCP_LOCATION, @@ -281,7 +270,7 @@ def test_execute_if_cluster_exists_do_not_use(self, mock_hook): timeout=TIMEOUT, metadata=METADATA, request_id=REQUEST_ID, - 
use_if_exists=False + use_if_exists=False, ) with self.assertRaises(AlreadyExists): op.execute(context={}) @@ -309,14 +298,10 @@ def test_execute_if_cluster_exists_in_error_state(self, mock_hook): op.execute(context={}) mock_hook.return_value.diagnose_cluster.assert_called_once_with( - region=GCP_LOCATION, - project_id=GCP_PROJECT, - cluster_name=CLUSTER_NAME, + region=GCP_LOCATION, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME, ) mock_hook.return_value.delete_cluster.assert_called_once_with( - region=GCP_LOCATION, - project_id=GCP_PROJECT, - cluster_name=CLUSTER_NAME, + region=GCP_LOCATION, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME, ) @mock.patch(DATAPROC_PATH.format("exponential_sleep_generator")) @@ -353,18 +338,14 @@ def test_execute_if_cluster_exists_in_deleting_state( mock_get_cluster.assert_has_calls(calls) mock_create_cluster.assert_has_calls(calls) mock_hook.return_value.diagnose_cluster.assert_called_once_with( - region=GCP_LOCATION, - project_id=GCP_PROJECT, - cluster_name=CLUSTER_NAME, + region=GCP_LOCATION, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME, ) class TestDataprocClusterScaleOperator(unittest.TestCase): def test_deprecation_warning(self): with self.assertWarns(DeprecationWarning) as warning: - DataprocScaleClusterOperator( - task_id=TASK_ID, cluster_name=CLUSTER_NAME, project_id=GCP_PROJECT - ) + DataprocScaleClusterOperator(task_id=TASK_ID, cluster_name=CLUSTER_NAME, project_id=GCP_PROJECT) assert_warning("DataprocUpdateClusterOperator", warning) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -390,8 +371,7 @@ def test_execute(self, mock_hook): op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_cluster.assert_called_once_with( project_id=GCP_PROJECT, @@ -420,8 +400,7 @@ def test_execute(self, mock_hook): ) op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_cluster.assert_called_once_with( region=GCP_LOCATION, @@ -458,8 +437,7 @@ def test_execute(self, mock_hook): op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, @@ -495,8 +473,7 @@ def test_execute(self, mock_hook): ) op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_cluster.assert_called_once_with( location=GCP_LOCATION, @@ -535,8 +512,7 @@ def test_execute(self, mock_hook): ) op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.instantiate_workflow_template.assert_called_once_with( template_name=template_id, @@ -570,8 +546,7 @@ def test_execute(self, mock_hook): ) op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) 
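Several of the Dataproc tests above assert that constructing a deprecated operator raises a DeprecationWarning pointing at its replacement. A self-contained sketch of that pattern, with invented stand-in classes (only the assertWarns mechanics mirror the real tests):

import unittest
import warnings


class NewOperator:
    def __init__(self, task_id):
        self.task_id = task_id


class OldOperator(NewOperator):
    """Deprecated alias that warns and delegates to the replacement."""

    def __init__(self, task_id):
        warnings.warn("Use NewOperator instead", DeprecationWarning, stacklevel=2)
        super().__init__(task_id=task_id)


class TestDeprecation(unittest.TestCase):
    def test_deprecation_warning(self):
        with self.assertWarns(DeprecationWarning) as warning:
            OldOperator(task_id="legacy_task")
        # Mirrors assert_warning(...): the message should name the replacement class.
        assert "NewOperator" in str(warning.warning)


if __name__ == "__main__":
    unittest.main()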
mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( template=template, @@ -589,10 +564,7 @@ class TestDataProcHiveOperator(unittest.TestCase): variables = {"key": "value"} job_id = "uuid_id" job = { - "reference": { - "project_id": GCP_PROJECT, - "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id, - }, + "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id,}, "placement": {"cluster_name": "cluster-1"}, "labels": {"airflow-version": AIRFLOW_VERSION}, "hive_job": {"query_list": {"queries": [query]}, "script_variables": variables}, @@ -602,9 +574,7 @@ class TestDataProcHiveOperator(unittest.TestCase): def test_deprecation_warning(self, mock_hook): with self.assertWarns(DeprecationWarning) as warning: DataprocSubmitHiveJobOperator( - task_id=TASK_ID, - region=GCP_LOCATION, - query="query", + task_id=TASK_ID, region=GCP_LOCATION, query="query", ) assert_warning("DataprocSubmitJobOperator", warning) @@ -626,8 +596,7 @@ def test_execute(self, mock_hook, mock_uuid): ) op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION @@ -658,10 +627,7 @@ class TestDataProcPigOperator(unittest.TestCase): variables = {"key": "value"} job_id = "uuid_id" job = { - "reference": { - "project_id": GCP_PROJECT, - "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id, - }, + "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id,}, "placement": {"cluster_name": "cluster-1"}, "labels": {"airflow-version": AIRFLOW_VERSION}, "pig_job": {"query_list": {"queries": [query]}, "script_variables": variables}, @@ -671,9 +637,7 @@ class TestDataProcPigOperator(unittest.TestCase): def test_deprecation_warning(self, mock_hook): with self.assertWarns(DeprecationWarning) as warning: DataprocSubmitPigJobOperator( - task_id=TASK_ID, - region=GCP_LOCATION, - query="query", + task_id=TASK_ID, region=GCP_LOCATION, query="query", ) assert_warning("DataprocSubmitJobOperator", warning) @@ -695,8 +659,7 @@ def test_execute(self, mock_hook, mock_uuid): ) op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION @@ -727,25 +690,17 @@ class TestDataProcSparkSqlOperator(unittest.TestCase): variables = {"key": "value"} job_id = "uuid_id" job = { - "reference": { - "project_id": GCP_PROJECT, - "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id, - }, + "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id,}, "placement": {"cluster_name": "cluster-1"}, "labels": {"airflow-version": AIRFLOW_VERSION}, - "spark_sql_job": { - "query_list": {"queries": [query]}, - "script_variables": variables, - }, + "spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables,}, } @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_deprecation_warning(self, mock_hook): with self.assertWarns(DeprecationWarning) as warning: DataprocSubmitSparkSqlJobOperator( - task_id=TASK_ID, - region=GCP_LOCATION, - query="query", + task_id=TASK_ID, region=GCP_LOCATION, query="query", ) 
assert_warning("DataprocSubmitJobOperator", warning) @@ -767,8 +722,7 @@ def test_execute(self, mock_hook, mock_uuid): ) op.execute(context={}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION @@ -799,10 +753,7 @@ class TestDataProcSparkOperator(unittest.TestCase): jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"] job_id = "uuid_id" job = { - "reference": { - "project_id": GCP_PROJECT, - "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id, - }, + "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id,}, "placement": {"cluster_name": "cluster-1"}, "labels": {"airflow-version": AIRFLOW_VERSION}, "spark_job": {"jar_file_uris": jars, "main_class": main_class}, @@ -812,10 +763,7 @@ class TestDataProcSparkOperator(unittest.TestCase): def test_deprecation_warning(self, mock_hook): with self.assertWarns(DeprecationWarning) as warning: DataprocSubmitSparkJobOperator( - task_id=TASK_ID, - region=GCP_LOCATION, - main_class=self.main_class, - dataproc_jars=self.jars, + task_id=TASK_ID, region=GCP_LOCATION, main_class=self.main_class, dataproc_jars=self.jars, ) assert_warning("DataprocSubmitJobOperator", warning) @@ -842,10 +790,7 @@ class TestDataProcHadoopOperator(unittest.TestCase): jar = "file:///usr/lib/spark/examples/jars/spark-examples.jar" job_id = "uuid_id" job = { - "reference": { - "project_id": GCP_PROJECT, - "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id, - }, + "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id,}, "placement": {"cluster_name": "cluster-1"}, "labels": {"airflow-version": AIRFLOW_VERSION}, "hadoop_job": {"main_jar_file_uri": jar, "args": args}, @@ -855,10 +800,7 @@ class TestDataProcHadoopOperator(unittest.TestCase): def test_deprecation_warning(self, mock_hook): with self.assertWarns(DeprecationWarning) as warning: DataprocSubmitHadoopJobOperator( - task_id=TASK_ID, - region=GCP_LOCATION, - main_jar=self.jar, - arguments=self.args, + task_id=TASK_ID, region=GCP_LOCATION, main_jar=self.jar, arguments=self.args, ) assert_warning("DataprocSubmitJobOperator", warning) @@ -884,10 +826,7 @@ class TestDataProcPySparkOperator(unittest.TestCase): uri = "gs://{}/{}" job_id = "uuid_id" job = { - "reference": { - "project_id": GCP_PROJECT, - "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id, - }, + "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id,}, "placement": {"cluster_name": "cluster-1"}, "labels": {"airflow-version": AIRFLOW_VERSION}, "pyspark_job": {"main_python_file_uri": uri}, @@ -897,9 +836,7 @@ class TestDataProcPySparkOperator(unittest.TestCase): def test_deprecation_warning(self, mock_hook): with self.assertWarns(DeprecationWarning) as warning: DataprocSubmitPySparkJobOperator( - task_id=TASK_ID, - region=GCP_LOCATION, - main=self.uri, + task_id=TASK_ID, region=GCP_LOCATION, main=self.uri, ) assert_warning("DataprocSubmitJobOperator", warning) diff --git a/tests/providers/google/cloud/operators/test_dataproc_system.py b/tests/providers/google/cloud/operators/test_dataproc_system.py index 863ad5b5ce094..816ee4817570a 100644 --- a/tests/providers/google/cloud/operators/test_dataproc_system.py +++ b/tests/providers/google/cloud/operators/test_dataproc_system.py @@ -57,12 +57,8 @@ class 
DataprocExampleDagsTest(GoogleSystemTest): def setUp(self): super().setUp() self.create_gcs_bucket(BUCKET) - self.upload_content_to_gcs( - lines=pyspark_file, bucket=PYSPARK_URI, filename=PYSPARK_MAIN - ) - self.upload_content_to_gcs( - lines=sparkr_file, bucket=SPARKR_URI, filename=SPARKR_MAIN - ) + self.upload_content_to_gcs(lines=pyspark_file, bucket=PYSPARK_URI, filename=PYSPARK_MAIN) + self.upload_content_to_gcs(lines=sparkr_file, bucket=SPARKR_URI, filename=SPARKR_MAIN) @provide_gcp_context(GCP_DATAPROC_KEY) def tearDown(self): diff --git a/tests/providers/google/cloud/operators/test_datastore.py b/tests/providers/google/cloud/operators/test_datastore.py index 89a2a589f7f8a..e49dded7d832e 100644 --- a/tests/providers/google/cloud/operators/test_datastore.py +++ b/tests/providers/google/cloud/operators/test_datastore.py @@ -18,9 +18,14 @@ from unittest import mock from airflow.providers.google.cloud.operators.datastore import ( - CloudDatastoreAllocateIdsOperator, CloudDatastoreBeginTransactionOperator, CloudDatastoreCommitOperator, - CloudDatastoreDeleteOperationOperator, CloudDatastoreExportEntitiesOperator, - CloudDatastoreGetOperationOperator, CloudDatastoreImportEntitiesOperator, CloudDatastoreRollbackOperator, + CloudDatastoreAllocateIdsOperator, + CloudDatastoreBeginTransactionOperator, + CloudDatastoreCommitOperator, + CloudDatastoreDeleteOperationOperator, + CloudDatastoreExportEntitiesOperator, + CloudDatastoreGetOperationOperator, + CloudDatastoreImportEntitiesOperator, + CloudDatastoreRollbackOperator, CloudDatastoreRunQueryOperator, ) @@ -37,67 +42,43 @@ class TestCloudDatastoreExportEntitiesOperator: @mock.patch(HOOK_PATH) def test_execute(self, mock_hook): - mock_hook.return_value.export_to_storage_bucket.return_value = { - "name": OPERATION_ID - } + mock_hook.return_value.export_to_storage_bucket.return_value = {"name": OPERATION_ID} mock_hook.return_value.poll_operation_until_done.return_value = { "metadata": {"common": {"state": "SUCCESSFUL"}} } op = CloudDatastoreExportEntitiesOperator( - task_id="test_task", - datastore_conn_id=CONN_ID, - project_id=PROJECT_ID, - bucket=BUCKET, + task_id="test_task", datastore_conn_id=CONN_ID, project_id=PROJECT_ID, bucket=BUCKET, ) op.execute({}) mock_hook.assert_called_once_with(CONN_ID, None, impersonation_chain=None) mock_hook.return_value.export_to_storage_bucket.assert_called_once_with( - project_id=PROJECT_ID, - bucket=BUCKET, - entity_filter=None, - labels=None, - namespace=None, + project_id=PROJECT_ID, bucket=BUCKET, entity_filter=None, labels=None, namespace=None, ) - mock_hook.return_value.poll_operation_until_done.assert_called_once_with( - OPERATION_ID, 10 - ) + mock_hook.return_value.poll_operation_until_done.assert_called_once_with(OPERATION_ID, 10) class TestCloudDatastoreImportEntitiesOperator: @mock.patch(HOOK_PATH) def test_execute(self, mock_hook): - mock_hook.return_value.import_from_storage_bucket.return_value = { - "name": OPERATION_ID - } + mock_hook.return_value.import_from_storage_bucket.return_value = {"name": OPERATION_ID} mock_hook.return_value.poll_operation_until_done.return_value = { "metadata": {"common": {"state": "SUCCESSFUL"}} } op = CloudDatastoreImportEntitiesOperator( - task_id="test_task", - datastore_conn_id=CONN_ID, - project_id=PROJECT_ID, - bucket=BUCKET, - file=FILE, + task_id="test_task", datastore_conn_id=CONN_ID, project_id=PROJECT_ID, bucket=BUCKET, file=FILE, ) op.execute({}) mock_hook.assert_called_once_with(CONN_ID, None, impersonation_chain=None) 
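The export-entities test just above follows a call flow that is easy to see in isolation: the hook is mocked, the operator delegates to it, and the test asserts on the recorded calls. A compact, runnable sketch of that flow with stand-in names (none of these are the real Airflow classes):

from unittest import mock

OPERATION_ID = "operation-123"  # placeholder


def export_entities(hook, project_id, bucket):
    """Stands in for what the export operator does with its hook."""
    operation = hook.export_to_storage_bucket(
        project_id=project_id, bucket=bucket, entity_filter=None, labels=None, namespace=None,
    )
    hook.poll_operation_until_done(operation["name"], 10)


hook = mock.MagicMock()
hook.export_to_storage_bucket.return_value = {"name": OPERATION_ID}

export_entities(hook, project_id="my-project", bucket="my-backup-bucket")

hook.export_to_storage_bucket.assert_called_once_with(
    project_id="my-project", bucket="my-backup-bucket", entity_filter=None, labels=None, namespace=None,
)
hook.poll_operation_until_done.assert_called_once_with(OPERATION_ID, 10)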
mock_hook.return_value.import_from_storage_bucket.assert_called_once_with( - project_id=PROJECT_ID, - bucket=BUCKET, - file=FILE, - entity_filter=None, - labels=None, - namespace=None, + project_id=PROJECT_ID, bucket=BUCKET, file=FILE, entity_filter=None, labels=None, namespace=None, ) - mock_hook.return_value.export_to_storage_bucketassert_called_once_with( - OPERATION_ID, 10 - ) + mock_hook.return_value.export_to_storage_bucketassert_called_once_with(OPERATION_ID, 10) class TestCloudDatastoreAllocateIds: @@ -105,10 +86,7 @@ class TestCloudDatastoreAllocateIds: def test_execute(self, mock_hook): partial_keys = [1, 2, 3] op = CloudDatastoreAllocateIdsOperator( - task_id="test_task", - gcp_conn_id=CONN_ID, - project_id=PROJECT_ID, - partial_keys=partial_keys, + task_id="test_task", gcp_conn_id=CONN_ID, project_id=PROJECT_ID, partial_keys=partial_keys, ) op.execute({}) @@ -122,10 +100,7 @@ class TestCloudDatastoreBeginTransaction: @mock.patch(HOOK_PATH) def test_execute(self, mock_hook): op = CloudDatastoreBeginTransactionOperator( - task_id="test_task", - gcp_conn_id=CONN_ID, - project_id=PROJECT_ID, - transaction_options=BODY, + task_id="test_task", gcp_conn_id=CONN_ID, project_id=PROJECT_ID, transaction_options=BODY, ) op.execute({}) @@ -144,31 +119,23 @@ def test_execute(self, mock_hook): op.execute({}) mock_hook.assert_called_once_with(gcp_conn_id=CONN_ID, impersonation_chain=None) - mock_hook.return_value.commit.assert_called_once_with( - project_id=PROJECT_ID, body=BODY - ) + mock_hook.return_value.commit.assert_called_once_with(project_id=PROJECT_ID, body=BODY) class TestCloudDatastoreDeleteOperation: @mock.patch(HOOK_PATH) def test_execute(self, mock_hook): - op = CloudDatastoreDeleteOperationOperator( - task_id="test_task", gcp_conn_id=CONN_ID, name=TRANSACTION - ) + op = CloudDatastoreDeleteOperationOperator(task_id="test_task", gcp_conn_id=CONN_ID, name=TRANSACTION) op.execute({}) mock_hook.assert_called_once_with(gcp_conn_id=CONN_ID, impersonation_chain=None) - mock_hook.return_value.delete_operation.assert_called_once_with( - name=TRANSACTION - ) + mock_hook.return_value.delete_operation.assert_called_once_with(name=TRANSACTION) class TestCloudDatastoreGetOperation: @mock.patch(HOOK_PATH) def test_execute(self, mock_hook): - op = CloudDatastoreGetOperationOperator( - task_id="test_task", gcp_conn_id=CONN_ID, name=TRANSACTION - ) + op = CloudDatastoreGetOperationOperator(task_id="test_task", gcp_conn_id=CONN_ID, name=TRANSACTION) op.execute({}) mock_hook.assert_called_once_with(gcp_conn_id=CONN_ID, impersonation_chain=None) @@ -179,10 +146,7 @@ class TestCloudDatastoreRollback: @mock.patch(HOOK_PATH) def test_execute(self, mock_hook): op = CloudDatastoreRollbackOperator( - task_id="test_task", - gcp_conn_id=CONN_ID, - project_id=PROJECT_ID, - transaction=TRANSACTION, + task_id="test_task", gcp_conn_id=CONN_ID, project_id=PROJECT_ID, transaction=TRANSACTION, ) op.execute({}) @@ -201,6 +165,4 @@ def test_execute(self, mock_hook): op.execute({}) mock_hook.assert_called_once_with(gcp_conn_id=CONN_ID, impersonation_chain=None) - mock_hook.return_value.run_query.assert_called_once_with( - project_id=PROJECT_ID, body=BODY - ) + mock_hook.return_value.run_query.assert_called_once_with(project_id=PROJECT_ID, body=BODY) diff --git a/tests/providers/google/cloud/operators/test_datastore_system.py b/tests/providers/google/cloud/operators/test_datastore_system.py index 7cbeaf1c0f367..549379211c857 100644 --- a/tests/providers/google/cloud/operators/test_datastore_system.py +++ 
b/tests/providers/google/cloud/operators/test_datastore_system.py @@ -28,7 +28,6 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_DATASTORE_KEY) class GcpDatastoreSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_DATASTORE_KEY) def setUp(self): super().setUp() diff --git a/tests/providers/google/cloud/operators/test_dlp.py b/tests/providers/google/cloud/operators/test_dlp.py index f4aecce1a2f0d..8e6f15c62a8dd 100644 --- a/tests/providers/google/cloud/operators/test_dlp.py +++ b/tests/providers/google/cloud/operators/test_dlp.py @@ -26,18 +26,36 @@ import mock from airflow.providers.google.cloud.operators.dlp import ( - CloudDLPCancelDLPJobOperator, CloudDLPCreateDeidentifyTemplateOperator, CloudDLPCreateDLPJobOperator, - CloudDLPCreateInspectTemplateOperator, CloudDLPCreateJobTriggerOperator, - CloudDLPCreateStoredInfoTypeOperator, CloudDLPDeidentifyContentOperator, - CloudDLPDeleteDeidentifyTemplateOperator, CloudDLPDeleteDLPJobOperator, - CloudDLPDeleteInspectTemplateOperator, CloudDLPDeleteJobTriggerOperator, - CloudDLPDeleteStoredInfoTypeOperator, CloudDLPGetDeidentifyTemplateOperator, CloudDLPGetDLPJobOperator, - CloudDLPGetDLPJobTriggerOperator, CloudDLPGetInspectTemplateOperator, CloudDLPGetStoredInfoTypeOperator, - CloudDLPInspectContentOperator, CloudDLPListDeidentifyTemplatesOperator, CloudDLPListDLPJobsOperator, - CloudDLPListInfoTypesOperator, CloudDLPListInspectTemplatesOperator, CloudDLPListJobTriggersOperator, - CloudDLPListStoredInfoTypesOperator, CloudDLPRedactImageOperator, CloudDLPReidentifyContentOperator, - CloudDLPUpdateDeidentifyTemplateOperator, CloudDLPUpdateInspectTemplateOperator, - CloudDLPUpdateJobTriggerOperator, CloudDLPUpdateStoredInfoTypeOperator, + CloudDLPCancelDLPJobOperator, + CloudDLPCreateDeidentifyTemplateOperator, + CloudDLPCreateDLPJobOperator, + CloudDLPCreateInspectTemplateOperator, + CloudDLPCreateJobTriggerOperator, + CloudDLPCreateStoredInfoTypeOperator, + CloudDLPDeidentifyContentOperator, + CloudDLPDeleteDeidentifyTemplateOperator, + CloudDLPDeleteDLPJobOperator, + CloudDLPDeleteInspectTemplateOperator, + CloudDLPDeleteJobTriggerOperator, + CloudDLPDeleteStoredInfoTypeOperator, + CloudDLPGetDeidentifyTemplateOperator, + CloudDLPGetDLPJobOperator, + CloudDLPGetDLPJobTriggerOperator, + CloudDLPGetInspectTemplateOperator, + CloudDLPGetStoredInfoTypeOperator, + CloudDLPInspectContentOperator, + CloudDLPListDeidentifyTemplatesOperator, + CloudDLPListDLPJobsOperator, + CloudDLPListInfoTypesOperator, + CloudDLPListInspectTemplatesOperator, + CloudDLPListJobTriggersOperator, + CloudDLPListStoredInfoTypesOperator, + CloudDLPRedactImageOperator, + CloudDLPReidentifyContentOperator, + CloudDLPUpdateDeidentifyTemplateOperator, + CloudDLPUpdateInspectTemplateOperator, + CloudDLPUpdateJobTriggerOperator, + CloudDLPUpdateStoredInfoTypeOperator, ) GCP_CONN_ID = "google_cloud_default" @@ -56,15 +74,10 @@ def test_cancel_dlp_job(self, mock_hook): operator = CloudDLPCancelDLPJobOperator(dlp_job_id=DLP_JOB_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.cancel_dlp_job.assert_called_once_with( - dlp_job_id=DLP_JOB_ID, - project_id=None, - retry=None, - timeout=None, - metadata=None, + dlp_job_id=DLP_JOB_ID, project_id=None, retry=None, timeout=None, metadata=None, ) @@ -72,13 +85,10 @@ class TestCloudDLPCreateDeidentifyTemplateOperator(unittest.TestCase): 
@mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_create_deidentify_template(self, mock_hook): mock_hook.return_value.create_deidentify_template.return_value = mock.MagicMock() - operator = CloudDLPCreateDeidentifyTemplateOperator( - organization_id=ORGANIZATION_ID, task_id="id" - ) + operator = CloudDLPCreateDeidentifyTemplateOperator(organization_id=ORGANIZATION_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_deidentify_template.assert_called_once_with( organization_id=ORGANIZATION_ID, @@ -98,8 +108,7 @@ def test_create_dlp_job(self, mock_hook): operator = CloudDLPCreateDLPJobOperator(project_id=PROJECT_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_dlp_job.assert_called_once_with( project_id=PROJECT_ID, @@ -117,13 +126,10 @@ class TestCloudDLPCreateInspectTemplateOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_create_inspect_template(self, mock_hook): mock_hook.return_value.create_inspect_template.return_value = mock.MagicMock() - operator = CloudDLPCreateInspectTemplateOperator( - organization_id=ORGANIZATION_ID, task_id="id" - ) + operator = CloudDLPCreateInspectTemplateOperator(organization_id=ORGANIZATION_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_inspect_template.assert_called_once_with( organization_id=ORGANIZATION_ID, @@ -143,16 +149,10 @@ def test_create_job_trigger(self, mock_hook): operator = CloudDLPCreateJobTriggerOperator(project_id=PROJECT_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_job_trigger.assert_called_once_with( - project_id=PROJECT_ID, - job_trigger=None, - trigger_id=None, - retry=None, - timeout=None, - metadata=None, + project_id=PROJECT_ID, job_trigger=None, trigger_id=None, retry=None, timeout=None, metadata=None, ) @@ -160,13 +160,10 @@ class TestCloudDLPCreateStoredInfoTypeOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_create_stored_info_type(self, mock_hook): mock_hook.return_value.create_stored_info_type.return_value = mock.MagicMock() - operator = CloudDLPCreateStoredInfoTypeOperator( - organization_id=ORGANIZATION_ID, task_id="id" - ) + operator = CloudDLPCreateStoredInfoTypeOperator(organization_id=ORGANIZATION_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_stored_info_type.assert_called_once_with( organization_id=ORGANIZATION_ID, @@ -183,13 +180,10 @@ class TestCloudDLPDeidentifyContentOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_deidentify_content(self, mock_hook): mock_hook.return_value.deidentify_content.return_value = mock.MagicMock() - operator = 
CloudDLPDeidentifyContentOperator( - project_id=PROJECT_ID, task_id="id" - ) + operator = CloudDLPDeidentifyContentOperator(project_id=PROJECT_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.deidentify_content.assert_called_once_with( project_id=PROJECT_ID, @@ -213,8 +207,7 @@ def test_delete_deidentify_template(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.delete_deidentify_template.assert_called_once_with( template_id=TEMPLATE_ID, @@ -230,20 +223,13 @@ class TestCloudDLPDeleteDlpJobOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_delete_dlp_job(self, mock_hook): mock_hook.return_value.delete_dlp_job.return_value = mock.MagicMock() - operator = CloudDLPDeleteDLPJobOperator( - dlp_job_id=DLP_JOB_ID, project_id=PROJECT_ID, task_id="id" - ) + operator = CloudDLPDeleteDLPJobOperator(dlp_job_id=DLP_JOB_ID, project_id=PROJECT_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.delete_dlp_job.assert_called_once_with( - dlp_job_id=DLP_JOB_ID, - project_id=PROJECT_ID, - retry=None, - timeout=None, - metadata=None, + dlp_job_id=DLP_JOB_ID, project_id=PROJECT_ID, retry=None, timeout=None, metadata=None, ) @@ -256,8 +242,7 @@ def test_delete_inspect_template(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.delete_inspect_template.assert_called_once_with( template_id=TEMPLATE_ID, @@ -278,15 +263,10 @@ def test_delete_job_trigger(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.delete_job_trigger.assert_called_once_with( - job_trigger_id=TRIGGER_ID, - project_id=PROJECT_ID, - retry=None, - timeout=None, - metadata=None, + job_trigger_id=TRIGGER_ID, project_id=PROJECT_ID, retry=None, timeout=None, metadata=None, ) @@ -295,14 +275,11 @@ class TestCloudDLPDeleteStoredInfoTypeOperator(unittest.TestCase): def test_delete_stored_info_type(self, mock_hook): mock_hook.return_value.delete_stored_info_type.return_value = mock.MagicMock() operator = CloudDLPDeleteStoredInfoTypeOperator( - stored_info_type_id=STORED_INFO_TYPE_ID, - organization_id=ORGANIZATION_ID, - task_id="id", + stored_info_type_id=STORED_INFO_TYPE_ID, organization_id=ORGANIZATION_ID, task_id="id", ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.delete_stored_info_type.assert_called_once_with( stored_info_type_id=STORED_INFO_TYPE_ID, @@ -323,8 +300,7 @@ def test_get_deidentify_template(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) 
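All of the DLP tests above rely on the same @mock.patch mechanics: the hook class is patched by import path and the decorator injects the mock as the first test argument. A minimal runnable version of that mechanism, with a stub hook defined locally so the patch target exists (the names are stand-ins, not the real CloudDLPHook or operators):

import unittest
from unittest import mock


class StubDLPHook:
    """Local stand-in so the patch target below resolves; never actually used in the test."""

    def __init__(self, gcp_conn_id, impersonation_chain=None):
        self.gcp_conn_id = gcp_conn_id

    def cancel_dlp_job(self, dlp_job_id, project_id=None, retry=None, timeout=None, metadata=None):
        raise RuntimeError("unreached: the class is replaced by the patch during the test")


def cancel_job(dlp_job_id):
    """Stands in for an operator's execute(): build the hook, delegate the API call."""
    hook = StubDLPHook(gcp_conn_id="google_cloud_default", impersonation_chain=None)
    hook.cancel_dlp_job(dlp_job_id=dlp_job_id, project_id=None, retry=None, timeout=None, metadata=None)


class TestCancelJob(unittest.TestCase):
    @mock.patch(f"{__name__}.StubDLPHook")
    def test_cancel_dlp_job(self, mock_hook):
        cancel_job(dlp_job_id="job-123")
        mock_hook.assert_called_once_with(
            gcp_conn_id="google_cloud_default", impersonation_chain=None,
        )
        mock_hook.return_value.cancel_dlp_job.assert_called_once_with(
            dlp_job_id="job-123", project_id=None, retry=None, timeout=None, metadata=None,
        )


if __name__ == "__main__":
    unittest.main()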
mock_hook.return_value.get_deidentify_template.assert_called_once_with( template_id=TEMPLATE_ID, @@ -340,20 +316,13 @@ class TestCloudDLPGetDlpJobOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_get_dlp_job(self, mock_hook): mock_hook.return_value.get_dlp_job.return_value = mock.MagicMock() - operator = CloudDLPGetDLPJobOperator( - dlp_job_id=DLP_JOB_ID, project_id=PROJECT_ID, task_id="id" - ) + operator = CloudDLPGetDLPJobOperator(dlp_job_id=DLP_JOB_ID, project_id=PROJECT_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.get_dlp_job.assert_called_once_with( - dlp_job_id=DLP_JOB_ID, - project_id=PROJECT_ID, - retry=None, - timeout=None, - metadata=None, + dlp_job_id=DLP_JOB_ID, project_id=PROJECT_ID, retry=None, timeout=None, metadata=None, ) @@ -366,8 +335,7 @@ def test_get_inspect_template(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.get_inspect_template.assert_called_once_with( template_id=TEMPLATE_ID, @@ -388,15 +356,10 @@ def test_get_job_trigger(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.get_job_trigger.assert_called_once_with( - job_trigger_id=TRIGGER_ID, - project_id=PROJECT_ID, - retry=None, - timeout=None, - metadata=None, + job_trigger_id=TRIGGER_ID, project_id=PROJECT_ID, retry=None, timeout=None, metadata=None, ) @@ -405,14 +368,11 @@ class TestCloudDLPGetStoredInfoTypeOperator(unittest.TestCase): def test_get_stored_info_type(self, mock_hook): mock_hook.return_value.get_stored_info_type.return_value = mock.MagicMock() operator = CloudDLPGetStoredInfoTypeOperator( - stored_info_type_id=STORED_INFO_TYPE_ID, - organization_id=ORGANIZATION_ID, - task_id="id", + stored_info_type_id=STORED_INFO_TYPE_ID, organization_id=ORGANIZATION_ID, task_id="id", ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.get_stored_info_type.assert_called_once_with( stored_info_type_id=STORED_INFO_TYPE_ID, @@ -431,8 +391,7 @@ def test_inspect_content(self, mock_hook): operator = CloudDLPInspectContentOperator(project_id=PROJECT_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.inspect_content.assert_called_once_with( project_id=PROJECT_ID, @@ -449,13 +408,10 @@ class TestCloudDLPListDeidentifyTemplatesOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_list_deidentify_templates(self, mock_hook): mock_hook.return_value.list_deidentify_templates.return_value = mock.MagicMock() - operator = CloudDLPListDeidentifyTemplatesOperator( - organization_id=ORGANIZATION_ID, task_id="id" - ) + operator = CloudDLPListDeidentifyTemplatesOperator(organization_id=ORGANIZATION_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - 
gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.list_deidentify_templates.assert_called_once_with( organization_id=ORGANIZATION_ID, @@ -475,8 +431,7 @@ def test_list_dlp_jobs(self, mock_hook): operator = CloudDLPListDLPJobsOperator(project_id=PROJECT_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.list_dlp_jobs.assert_called_once_with( project_id=PROJECT_ID, @@ -497,15 +452,10 @@ def test_list_info_types(self, mock_hook): operator = CloudDLPListInfoTypesOperator(task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.list_info_types.assert_called_once_with( - language_code=None, - results_filter=None, - retry=None, - timeout=None, - metadata=None, + language_code=None, results_filter=None, retry=None, timeout=None, metadata=None, ) @@ -513,13 +463,10 @@ class TestCloudDLPListInspectTemplatesOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_list_inspect_templates(self, mock_hook): mock_hook.return_value.list_inspect_templates.return_value = mock.MagicMock() - operator = CloudDLPListInspectTemplatesOperator( - organization_id=ORGANIZATION_ID, task_id="id" - ) + operator = CloudDLPListInspectTemplatesOperator(organization_id=ORGANIZATION_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.list_inspect_templates.assert_called_once_with( organization_id=ORGANIZATION_ID, @@ -539,8 +486,7 @@ def test_list_job_triggers(self, mock_hook): operator = CloudDLPListJobTriggersOperator(project_id=PROJECT_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.list_job_triggers.assert_called_once_with( project_id=PROJECT_ID, @@ -557,13 +503,10 @@ class TestCloudDLPListStoredInfoTypesOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_list_stored_info_types(self, mock_hook): mock_hook.return_value.list_stored_info_types.return_value = mock.MagicMock() - operator = CloudDLPListStoredInfoTypesOperator( - organization_id=ORGANIZATION_ID, task_id="id" - ) + operator = CloudDLPListStoredInfoTypesOperator(organization_id=ORGANIZATION_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.list_stored_info_types.assert_called_once_with( organization_id=ORGANIZATION_ID, @@ -583,8 +526,7 @@ def test_redact_image(self, mock_hook): operator = CloudDLPRedactImageOperator(project_id=PROJECT_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.redact_image.assert_called_once_with( project_id=PROJECT_ID, @@ -602,13 +544,10 @@ class 
TestCloudDLPReidentifyContentOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_reidentify_content(self, mock_hook): mock_hook.return_value.reidentify_content.return_value = mock.MagicMock() - operator = CloudDLPReidentifyContentOperator( - project_id=PROJECT_ID, task_id="id" - ) + operator = CloudDLPReidentifyContentOperator(project_id=PROJECT_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.reidentify_content.assert_called_once_with( project_id=PROJECT_ID, @@ -632,8 +571,7 @@ def test_update_deidentify_template(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.update_deidentify_template.assert_called_once_with( template_id=TEMPLATE_ID, @@ -656,8 +594,7 @@ def test_update_inspect_template(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.update_inspect_template.assert_called_once_with( template_id=TEMPLATE_ID, @@ -675,13 +612,10 @@ class TestCloudDLPUpdateJobTriggerOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.dlp.CloudDLPHook") def test_update_job_trigger(self, mock_hook): mock_hook.return_value.update_job_trigger.return_value = mock.MagicMock() - operator = CloudDLPUpdateJobTriggerOperator( - job_trigger_id=TRIGGER_ID, task_id="id" - ) + operator = CloudDLPUpdateJobTriggerOperator(job_trigger_id=TRIGGER_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.update_job_trigger.assert_called_once_with( job_trigger_id=TRIGGER_ID, @@ -699,14 +633,11 @@ class TestCloudDLPUpdateStoredInfoTypeOperator(unittest.TestCase): def test_update_stored_info_type(self, mock_hook): mock_hook.return_value.update_stored_info_type.return_value = mock.MagicMock() operator = CloudDLPUpdateStoredInfoTypeOperator( - stored_info_type_id=STORED_INFO_TYPE_ID, - organization_id=ORGANIZATION_ID, - task_id="id", + stored_info_type_id=STORED_INFO_TYPE_ID, organization_id=ORGANIZATION_ID, task_id="id", ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.update_stored_info_type.assert_called_once_with( stored_info_type_id=STORED_INFO_TYPE_ID, diff --git a/tests/providers/google/cloud/operators/test_functions.py b/tests/providers/google/cloud/operators/test_functions.py index 5470f3e9759a2..9a4b5ab884899 100644 --- a/tests/providers/google/cloud/operators/test_functions.py +++ b/tests/providers/google/cloud/operators/test_functions.py @@ -25,7 +25,9 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.operators.functions import ( - FUNCTION_NAME_PATTERN, CloudFunctionDeleteFunctionOperator, CloudFunctionDeployFunctionOperator, + FUNCTION_NAME_PATTERN, + CloudFunctionDeleteFunctionOperator, + CloudFunctionDeployFunctionOperator, CloudFunctionInvokeFunctionOperator, ) from airflow.version 
import version @@ -37,9 +39,7 @@ GCP_LOCATION = 'test_region' GCF_SOURCE_ARCHIVE_URL = 'gs://folder/file.zip' GCF_ENTRYPOINT = 'helloWorld' -FUNCTION_NAME = 'projects/{}/locations/{}/functions/{}'.format(GCP_PROJECT_ID, - GCP_LOCATION, - GCF_ENTRYPOINT) +FUNCTION_NAME = 'projects/{}/locations/{}/functions/{}'.format(GCP_PROJECT_ID, GCP_LOCATION, GCF_ENTRYPOINT) GCF_RUNTIME = 'nodejs6' VALID_RUNTIMES = ['nodejs6', 'nodejs8', 'python37'] VALID_BODY = { @@ -47,7 +47,7 @@ "entryPoint": GCF_ENTRYPOINT, "runtime": GCF_RUNTIME, "httpsTrigger": {}, - "sourceArchiveUrl": GCF_SOURCE_ARCHIVE_URL + "sourceArchiveUrl": GCF_SOURCE_ARCHIVE_URL, } @@ -61,8 +61,7 @@ def _prepare_test_bodies(): body_values = [ ({}, "The required parameter 'body' is missing"), (body_no_name, "The required body field 'name' is missing"), - (body_empty_entry_point, - "The body field 'entryPoint' of value '' does not match"), + (body_empty_entry_point, "The body field 'entryPoint' of value '' does not match"), (body_empty_runtime, "The body field 'runtime' of value '' does not match"), ] return body_values @@ -75,10 +74,7 @@ def test_body_empty_or_missing_fields(self, body, message, mock_hook): mock_hook.return_value.upload_function_zip.return_value = 'https://uploadUrl' with self.assertRaises(AirflowException) as cm: op = CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=body, - task_id="id" + project_id="test_project_id", location="test_region", body=body, task_id="id" ) op.execute(None) err = cm.exception @@ -87,29 +83,23 @@ def test_body_empty_or_missing_fields(self, body, message, mock_hook): @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_deploy_execute(self, mock_hook): mock_hook.return_value.get_function.side_effect = mock.Mock( - side_effect=HttpError(resp=MOCK_RESP_404, content=b'not found')) + side_effect=HttpError(resp=MOCK_RESP_404, content=b'not found') + ) mock_hook.return_value.create_new_function.return_value = True op = CloudFunctionDeployFunctionOperator( - project_id=GCP_PROJECT_ID, - location=GCP_LOCATION, - body=deepcopy(VALID_BODY), - task_id="id" + project_id=GCP_PROJECT_ID, location=GCP_LOCATION, body=deepcopy(VALID_BODY), task_id="id" ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.get_function.assert_called_once_with( 'projects/test_project_id/locations/test_region/functions/helloWorld' ) expected_body = deepcopy(VALID_BODY) - expected_body['labels'] = { - 'airflow-version': 'v' + version.replace('.', '-').replace('+', '-') - } + expected_body['labels'] = {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')} mock_hook.return_value.create_new_function.assert_called_once_with( - project_id='test_project_id', - location='test_region', - body=expected_body + project_id='test_project_id', location='test_region', body=expected_body ) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') @@ -117,56 +107,45 @@ def test_update_function_if_exists(self, mock_hook): mock_hook.return_value.get_function.return_value = True mock_hook.return_value.update_function.return_value = True op = CloudFunctionDeployFunctionOperator( - project_id=GCP_PROJECT_ID, - location=GCP_LOCATION, - body=deepcopy(VALID_BODY), - task_id="id" + 
project_id=GCP_PROJECT_ID, location=GCP_LOCATION, body=deepcopy(VALID_BODY), task_id="id" ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.get_function.assert_called_once_with( 'projects/test_project_id/locations/test_region/functions/helloWorld' ) expected_body = deepcopy(VALID_BODY) - expected_body['labels'] = { - 'airflow-version': 'v' + version.replace('.', '-').replace('+', '-') - } + expected_body['labels'] = {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')} mock_hook.return_value.update_function.assert_called_once_with( 'projects/test_project_id/locations/test_region/functions/helloWorld', - expected_body, expected_body.keys()) + expected_body, + expected_body.keys(), + ) mock_hook.return_value.create_new_function.assert_not_called() @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_empty_project_id_is_ok(self, mock_hook): - mock_hook.return_value.get_function.side_effect = \ - HttpError(resp=MOCK_RESP_404, content=b'not found') + mock_hook.return_value.get_function.side_effect = HttpError(resp=MOCK_RESP_404, content=b'not found') operator = CloudFunctionDeployFunctionOperator( - location="test_region", - body=deepcopy(VALID_BODY), - task_id="id" + location="test_region", body=deepcopy(VALID_BODY), task_id="id" ) operator.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) new_body = deepcopy(VALID_BODY) - new_body['labels'] = { - 'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')} + new_body['labels'] = {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')} mock_hook.return_value.create_new_function.assert_called_once_with( - project_id=None, - location="test_region", - body=new_body) + project_id=None, location="test_region", body=new_body + ) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_empty_location(self, mock_hook): with self.assertRaises(AirflowException) as cm: CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="", - body=None, - task_id="id" + project_id="test_project_id", location="", body=None, task_id="id" ) err = cm.exception self.assertIn("The required parameter 'location' is missing", str(err)) @@ -175,101 +154,79 @@ def test_empty_location(self, mock_hook): def test_empty_body(self, mock_hook): with self.assertRaises(AirflowException) as cm: CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=None, - task_id="id" + project_id="test_project_id", location="test_region", body=None, task_id="id" ) err = cm.exception self.assertIn("The required parameter 'body' is missing", str(err)) - @parameterized.expand([ - (runtime,) for runtime in VALID_RUNTIMES - ]) + @parameterized.expand([(runtime,) for runtime in VALID_RUNTIMES]) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_correct_runtime_field(self, runtime, mock_hook): mock_hook.return_value.create_new_function.return_value = True body = deepcopy(VALID_BODY) body['runtime'] = runtime op = 
CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=body, - task_id="id" + project_id="test_project_id", location="test_region", body=body, task_id="id" ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.reset_mock() - @parameterized.expand([ - (network,) for network in [ - "network-01", - "n-0-2-3-4", - "projects/PROJECT/global/networks/network-01" - "projects/PRÓJECT/global/networks/netwórk-01" + @parameterized.expand( + [ + (network,) + for network in [ + "network-01", + "n-0-2-3-4", + "projects/PROJECT/global/networks/network-01", + "projects/PRÓJECT/global/networks/netwórk-01", + ] ] - ]) + ) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_valid_network_field(self, network, mock_hook): mock_hook.return_value.create_new_function.return_value = True body = deepcopy(VALID_BODY) body['network'] = network op = CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=body, - task_id="id" + project_id="test_project_id", location="test_region", body=body, task_id="id" ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.reset_mock() - @parameterized.expand([ - (labels,) for labels in [ - {}, - {"label": 'value-01'}, - {"label_324234_a_b_c": 'value-01_93'}, - ] - ]) + @parameterized.expand( + [(labels,) for labels in [{}, {"label": 'value-01'}, {"label_324234_a_b_c": 'value-01_93'},]] + ) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_valid_labels_field(self, labels, mock_hook): mock_hook.return_value.create_new_function.return_value = True body = deepcopy(VALID_BODY) body['labels'] = labels op = CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=body, - task_id="id" + project_id="test_project_id", location="test_region", body=body, task_id="id" ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.reset_mock() @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_validation_disabled(self, mock_hook): mock_hook.return_value.create_new_function.return_value = True - body = { - "name": "function_name", - "some_invalid_body_field": "some_invalid_body_field_value" - } + body = {"name": "function_name", "some_invalid_body_field": "some_invalid_body_field_value"} op = CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=body, - validate_body=False, - task_id="id" + project_id="test_project_id", location="test_region", body=body, validate_body=False, task_id="id" ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) 
mock_hook.reset_mock() @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') @@ -279,36 +236,30 @@ def test_body_validation_simple(self, mock_hook): body['name'] = '' with self.assertRaises(AirflowException) as cm: op = CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=body, - task_id="id" + project_id="test_project_id", location="test_region", body=body, task_id="id" ) op.execute(None) err = cm.exception - self.assertIn("The body field 'name' of value '' does not match", - str(err)) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + self.assertIn("The body field 'name' of value '' does not match", str(err)) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.reset_mock() - @parameterized.expand([ - ('name', '', - "The body field 'name' of value '' does not match"), - ('description', '', "The body field 'description' of value '' does not match"), - ('entryPoint', '', "The body field 'entryPoint' of value '' does not match"), - ('availableMemoryMb', '0', - "The available memory has to be greater than 0"), - ('availableMemoryMb', '-1', - "The available memory has to be greater than 0"), - ('availableMemoryMb', 'ss', - "invalid literal for int() with base 10: 'ss'"), - ('network', '', "The body field 'network' of value '' does not match"), - ('maxInstances', '0', "The max instances parameter has to be greater than 0"), - ('maxInstances', '-1', "The max instances parameter has to be greater than 0"), - ('maxInstances', 'ss', "invalid literal for int() with base 10: 'ss'"), - ]) + @parameterized.expand( + [ + ('name', '', "The body field 'name' of value '' does not match"), + ('description', '', "The body field 'description' of value '' does not match"), + ('entryPoint', '', "The body field 'entryPoint' of value '' does not match"), + ('availableMemoryMb', '0', "The available memory has to be greater than 0"), + ('availableMemoryMb', '-1', "The available memory has to be greater than 0"), + ('availableMemoryMb', 'ss', "invalid literal for int() with base 10: 'ss'"), + ('network', '', "The body field 'network' of value '' does not match"), + ('maxInstances', '0', "The max instances parameter has to be greater than 0"), + ('maxInstances', '-1', "The max instances parameter has to be greater than 0"), + ('maxInstances', 'ss', "invalid literal for int() with base 10: 'ss'"), + ] + ) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_invalid_field_values(self, key, value, message, mock_hook): mock_hook.return_value.create_new_function.return_value = True @@ -316,52 +267,76 @@ def test_invalid_field_values(self, key, value, message, mock_hook): body[key] = value with self.assertRaises(AirflowException) as cm: op = CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=body, - task_id="id" + project_id="test_project_id", location="test_region", body=body, task_id="id" ) op.execute(None) err = cm.exception self.assertIn(message, str(err)) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.reset_mock() - @parameterized.expand([ - ({'sourceArchiveUrl': ''}, - "The body field 
'source_code.sourceArchiveUrl' of value '' does not match"), - ({'sourceArchiveUrl': '', 'zip_path': '/path/to/file'}, - "Only one of 'sourceArchiveUrl' in body or 'zip_path' argument allowed."), - ({'sourceArchiveUrl': 'gs://url', 'zip_path': '/path/to/file'}, - "Only one of 'sourceArchiveUrl' in body or 'zip_path' argument allowed."), - ({'sourceArchiveUrl': '', 'sourceUploadUrl': ''}, - "Parameter 'sourceUploadUrl' is empty in the body and argument " - "'zip_path' is missing or empty."), - ({'sourceArchiveUrl': 'gs://adasda', 'sourceRepository': ''}, - "The field 'source_code.sourceRepository' should be of dictionary type"), - ({'sourceUploadUrl': '', 'sourceRepository': ''}, - "Parameter 'sourceUploadUrl' is empty in the body and argument 'zip_path' " - "is missing or empty."), - ({'sourceArchiveUrl': '', 'sourceUploadUrl': '', 'sourceRepository': ''}, - "Parameter 'sourceUploadUrl' is empty in the body and argument 'zip_path' " - "is missing or empty."), - ({'sourceArchiveUrl': 'gs://url', 'sourceUploadUrl': 'https://url'}, - "The mutually exclusive fields 'sourceUploadUrl' and 'sourceArchiveUrl' " - "belonging to the union 'source_code' are both present. Please remove one"), - ({'sourceUploadUrl': 'https://url', 'zip_path': '/path/to/file'}, - "Only one of 'sourceUploadUrl' in body " - "or 'zip_path' argument allowed. Found both."), - ({'sourceUploadUrl': ''}, "Parameter 'sourceUploadUrl' is empty in the body " - "and argument 'zip_path' is missing or empty."), - ({'sourceRepository': ''}, "The field 'source_code.sourceRepository' " - "should be of dictionary type"), - ({'sourceRepository': {}}, "The required body field " - "'source_code.sourceRepository.url' is missing"), - ({'sourceRepository': {'url': ''}}, - "The body field 'source_code.sourceRepository.url' of value '' does not match"), - ] + @parameterized.expand( + [ + ( + {'sourceArchiveUrl': ''}, + "The body field 'source_code.sourceArchiveUrl' of value '' does not match", + ), + ( + {'sourceArchiveUrl': '', 'zip_path': '/path/to/file'}, + "Only one of 'sourceArchiveUrl' in body or 'zip_path' argument allowed.", + ), + ( + {'sourceArchiveUrl': 'gs://url', 'zip_path': '/path/to/file'}, + "Only one of 'sourceArchiveUrl' in body or 'zip_path' argument allowed.", + ), + ( + {'sourceArchiveUrl': '', 'sourceUploadUrl': ''}, + "Parameter 'sourceUploadUrl' is empty in the body and argument " + "'zip_path' is missing or empty.", + ), + ( + {'sourceArchiveUrl': 'gs://adasda', 'sourceRepository': ''}, + "The field 'source_code.sourceRepository' should be of dictionary type", + ), + ( + {'sourceUploadUrl': '', 'sourceRepository': ''}, + "Parameter 'sourceUploadUrl' is empty in the body and argument 'zip_path' " + "is missing or empty.", + ), + ( + {'sourceArchiveUrl': '', 'sourceUploadUrl': '', 'sourceRepository': ''}, + "Parameter 'sourceUploadUrl' is empty in the body and argument 'zip_path' " + "is missing or empty.", + ), + ( + {'sourceArchiveUrl': 'gs://url', 'sourceUploadUrl': 'https://url'}, + "The mutually exclusive fields 'sourceUploadUrl' and 'sourceArchiveUrl' " + "belonging to the union 'source_code' are both present. Please remove one", + ), + ( + {'sourceUploadUrl': 'https://url', 'zip_path': '/path/to/file'}, + "Only one of 'sourceUploadUrl' in body or 'zip_path' argument allowed. 
Found both.", + ), + ( + {'sourceUploadUrl': ''}, + "Parameter 'sourceUploadUrl' is empty in the body " + "and argument 'zip_path' is missing or empty.", + ), + ( + {'sourceRepository': ''}, + "The field 'source_code.sourceRepository' should be of dictionary type", + ), + ( + {'sourceRepository': {}}, + "The required body field 'source_code.sourceRepository.url' is missing", + ), + ( + {'sourceRepository': {'url': ''}}, + "The body field 'source_code.sourceRepository.url' of value '' does not match", + ), + ] ) def test_invalid_source_code_union_field(self, source_code, message): body = deepcopy(VALID_BODY) @@ -375,29 +350,30 @@ def test_invalid_source_code_union_field(self, source_code, message): location="test_region", body=body, task_id="id", - zip_path=zip_path + zip_path=zip_path, ) op.execute(None) err = cm.exception self.assertIn(message, str(err)) + # fmt: off @parameterized.expand([ ({'sourceArchiveUrl': 'gs://url'}, 'test_project_id'), ({'zip_path': '/path/to/file', 'sourceUploadUrl': None}, 'test_project_id'), ({'zip_path': '/path/to/file', 'sourceUploadUrl': None}, None), - ({'sourceUploadUrl': - 'https://source.developers.google.com/projects/a/repos/b/revisions/c/paths/d'}, - 'test_project_id'), - ({'sourceRepository': - {'url': 'https://source.developers.google.com/projects/a/' - 'repos/b/revisions/c/paths/d'}}, + ({'sourceUploadUrl': 'https://source.developers.google.com/projects/a/repos/b/revisions/c/paths/d'}, 'test_project_id'), + ({'sourceRepository': { + 'url': + 'https://source.developers.google.com/projects/a/repos/b/revisions/c/paths/d' + }}, 'test_project_id'), ]) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_valid_source_code_union_field(self, source_code, project_id, mock_hook): mock_hook.return_value.upload_function_zip.return_value = 'https://uploadUrl' mock_hook.return_value.get_function.side_effect = mock.Mock( - side_effect=HttpError(resp=MOCK_RESP_404, content=b'not found')) + side_effect=HttpError(resp=MOCK_RESP_404, content=b'not found') + ) mock_hook.return_value.create_new_function.return_value = True body = deepcopy(VALID_BODY) body.pop('sourceUploadUrl', None) @@ -412,53 +388,63 @@ def test_valid_source_code_union_field(self, source_code, project_id, mock_hook) location="test_region", body=body, task_id="id", - zip_path=zip_path + zip_path=zip_path, ) else: op = CloudFunctionDeployFunctionOperator( - location="test_region", - body=body, - task_id="id", - zip_path=zip_path + location="test_region", body=body, task_id="id", zip_path=zip_path ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) if zip_path: mock_hook.return_value.upload_function_zip.assert_called_once_with( - project_id=project_id, - location='test_region', - zip_path='/path/to/file' + project_id=project_id, location='test_region', zip_path='/path/to/file' ) mock_hook.return_value.get_function.assert_called_once_with( 'projects/test_project_id/locations/test_region/functions/helloWorld' ) mock_hook.return_value.create_new_function.assert_called_once_with( - project_id=project_id, - location='test_region', - body=body + project_id=project_id, location='test_region', body=body ) mock_hook.reset_mock() - @parameterized.expand([ - ({'eventTrigger': {}}, - "The required body field 'trigger.eventTrigger.eventType' is missing"), - 
({'eventTrigger': {'eventType': 'providers/test/eventTypes/a.b'}}, - "The required body field 'trigger.eventTrigger.resource' is missing"), - ({'eventTrigger': {'eventType': 'providers/test/eventTypes/a.b', 'resource': ''}}, - "The body field 'trigger.eventTrigger.resource' of value '' does not match"), - ({'eventTrigger': {'eventType': 'providers/test/eventTypes/a.b', - 'resource': 'res', - 'service': ''}}, - "The body field 'trigger.eventTrigger.service' of value '' does not match"), - ({'eventTrigger': {'eventType': 'providers/test/eventTypes/a.b', - 'resource': 'res', - 'service': 'service_name', - 'failurePolicy': {'retry': ''}}}, - "The field 'trigger.eventTrigger.failurePolicy.retry' " - "should be of dictionary type") - ] + # fmt: on + + @parameterized.expand( + [ + ({'eventTrigger': {}}, "The required body field 'trigger.eventTrigger.eventType' is missing"), + ( + {'eventTrigger': {'eventType': 'providers/test/eventTypes/a.b'}}, + "The required body field 'trigger.eventTrigger.resource' is missing", + ), + ( + {'eventTrigger': {'eventType': 'providers/test/eventTypes/a.b', 'resource': ''}}, + "The body field 'trigger.eventTrigger.resource' of value '' does not match", + ), + ( + { + 'eventTrigger': { + 'eventType': 'providers/test/eventTypes/a.b', + 'resource': 'res', + 'service': '', + } + }, + "The body field 'trigger.eventTrigger.service' of value '' does not match", + ), + ( + { + 'eventTrigger': { + 'eventType': 'providers/test/eventTypes/a.b', + 'resource': 'res', + 'service': 'service_name', + 'failurePolicy': {'retry': ''}, + } + }, + "The field 'trigger.eventTrigger.failurePolicy.retry' should be of dictionary type", + ), + ] ) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_invalid_trigger_union_field(self, trigger, message, mock_hook): @@ -469,61 +455,73 @@ def test_invalid_trigger_union_field(self, trigger, message, mock_hook): body.update(trigger) with self.assertRaises(AirflowException) as cm: op = CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=body, - task_id="id", + project_id="test_project_id", location="test_region", body=body, task_id="id", ) op.execute(None) err = cm.exception self.assertIn(message, str(err)) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.reset_mock() - @parameterized.expand([ - ({'httpsTrigger': {}},), - ({'eventTrigger': {'eventType': 'providers/test/eventTypes/a.b', - 'resource': 'res'}},), - ({'eventTrigger': {'eventType': 'providers/test/eventTypes/a.b', - 'resource': 'res', - 'service': 'service_name'}},), - ({'eventTrigger': {'eventType': 'providers/test/eventTypes/ą.b', - 'resource': 'reś', - 'service': 'service_namę'}},), - ({'eventTrigger': {'eventType': 'providers/test/eventTypes/a.b', - 'resource': 'res', - 'service': 'service_name', - 'failurePolicy': {'retry': {}}}},) - ]) + @parameterized.expand( + [ + ({'httpsTrigger': {}},), + ({'eventTrigger': {'eventType': 'providers/test/eventTypes/a.b', 'resource': 'res'}},), + ( + { + 'eventTrigger': { + 'eventType': 'providers/test/eventTypes/a.b', + 'resource': 'res', + 'service': 'service_name', + } + }, + ), + ( + { + 'eventTrigger': { + 'eventType': 'providers/test/eventTypes/ą.b', + 'resource': 'reś', + 'service': 'service_namę', + } + }, + ), + ( + { + 'eventTrigger': { + 
'eventType': 'providers/test/eventTypes/a.b', + 'resource': 'res', + 'service': 'service_name', + 'failurePolicy': {'retry': {}}, + } + }, + ), + ] + ) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_valid_trigger_union_field(self, trigger, mock_hook): mock_hook.return_value.upload_function_zip.return_value = 'https://uploadUrl' mock_hook.return_value.get_function.side_effect = mock.Mock( - side_effect=HttpError(resp=MOCK_RESP_404, content=b'not found')) + side_effect=HttpError(resp=MOCK_RESP_404, content=b'not found') + ) mock_hook.return_value.create_new_function.return_value = True body = deepcopy(VALID_BODY) body.pop('httpsTrigger', None) body.pop('eventTrigger', None) body.update(trigger) op = CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=body, - task_id="id", + project_id="test_project_id", location="test_region", body=body, task_id="id", ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.get_function.assert_called_once_with( 'projects/test_project_id/locations/test_region/functions/helloWorld' ) mock_hook.return_value.create_new_function.assert_called_once_with( - project_id='test_project_id', - location='test_region', - body=body + project_id='test_project_id', location='test_region', body=body ) mock_hook.reset_mock() @@ -533,46 +531,40 @@ def test_extra_parameter(self, mock_hook): body = deepcopy(VALID_BODY) body['extra_parameter'] = 'extra' op = CloudFunctionDeployFunctionOperator( - project_id="test_project_id", - location="test_region", - body=body, - task_id="id" + project_id="test_project_id", location="test_region", body=body, task_id="id" ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.reset_mock() class TestGcfFunctionDelete(unittest.TestCase): - _FUNCTION_NAME = 'projects/project_name/locations/project_location/functions' \ - '/function_name' + _FUNCTION_NAME = 'projects/project_name/locations/project_location/functions/function_name' _DELETE_FUNCTION_EXPECTED = { '@type': 'type.googleapis.com/google.cloud.functions.v1.CloudFunction', 'name': _FUNCTION_NAME, 'sourceArchiveUrl': 'gs://functions/hello.zip', - 'httpsTrigger': { - 'url': 'https://project_location-project_name.cloudfunctions.net' - '/function_name'}, - 'status': 'ACTIVE', 'entryPoint': 'entry_point', 'timeout': '60s', + 'httpsTrigger': {'url': 'https://project_location-project_name.cloudfunctions.net/function_name'}, + 'status': 'ACTIVE', + 'entryPoint': 'entry_point', + 'timeout': '60s', 'availableMemoryMb': 256, 'serviceAccountEmail': 'project_name@appspot.gserviceaccount.com', 'updateTime': '2018-08-23T00:00:00Z', - 'versionId': '1', 'runtime': 'nodejs6'} + 'versionId': '1', + 'runtime': 'nodejs6', + } @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_delete_execute(self, mock_hook): - mock_hook.return_value.delete_function.return_value = \ - self._DELETE_FUNCTION_EXPECTED - op = CloudFunctionDeleteFunctionOperator( - name=self._FUNCTION_NAME, - task_id="id" - ) + 
mock_hook.return_value.delete_function.return_value = self._DELETE_FUNCTION_EXPECTED + op = CloudFunctionDeleteFunctionOperator(name=self._FUNCTION_NAME, task_id="id") result = op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.delete_function.assert_called_once_with( 'projects/project_name/locations/project_location/functions/function_name' ) @@ -581,73 +573,59 @@ def test_delete_execute(self, mock_hook): @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_correct_name(self, mock_hook): op = CloudFunctionDeleteFunctionOperator( - name="projects/project_name/locations/project_location/functions" - "/function_name", - task_id="id" + name="projects/project_name/locations/project_location/functions" "/function_name", task_id="id" ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_invalid_name(self, mock_hook): with self.assertRaises(AttributeError) as cm: - op = CloudFunctionDeleteFunctionOperator( - name="invalid_name", - task_id="id" - ) + op = CloudFunctionDeleteFunctionOperator(name="invalid_name", task_id="id") op.execute(None) err = cm.exception - self.assertEqual(str(err), 'Parameter name must match pattern: {}'.format( - FUNCTION_NAME_PATTERN)) + self.assertEqual(str(err), 'Parameter name must match pattern: {}'.format(FUNCTION_NAME_PATTERN)) mock_hook.assert_not_called() @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_empty_name(self, mock_hook): - mock_hook.return_value.delete_function.return_value = \ - self._DELETE_FUNCTION_EXPECTED + mock_hook.return_value.delete_function.return_value = self._DELETE_FUNCTION_EXPECTED with self.assertRaises(AttributeError) as cm: - CloudFunctionDeleteFunctionOperator( - name="", - task_id="id" - ) + CloudFunctionDeleteFunctionOperator(name="", task_id="id") err = cm.exception self.assertEqual(str(err), 'Empty parameter: name') mock_hook.assert_not_called() @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_gcf_error_silenced_when_function_doesnt_exist(self, mock_hook): - op = CloudFunctionDeleteFunctionOperator( - name=self._FUNCTION_NAME, - task_id="id" - ) + op = CloudFunctionDeleteFunctionOperator(name=self._FUNCTION_NAME, task_id="id") mock_hook.return_value.delete_function.side_effect = mock.Mock( - side_effect=HttpError(resp=MOCK_RESP_404, content=b'not found')) + side_effect=HttpError(resp=MOCK_RESP_404, content=b'not found') + ) op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.delete_function.assert_called_once_with( 'projects/project_name/locations/project_location/functions/function_name' ) @mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook') def test_non_404_gcf_error_bubbled_up(self, mock_hook): - op = 
CloudFunctionDeleteFunctionOperator( - name=self._FUNCTION_NAME, - task_id="id" - ) + op = CloudFunctionDeleteFunctionOperator(name=self._FUNCTION_NAME, task_id="id") resp = type('', (object,), {"status": 500})() mock_hook.return_value.delete_function.side_effect = mock.Mock( - side_effect=HttpError(resp=resp, content=b'error')) + side_effect=HttpError(resp=resp, content=b'error') + ) with self.assertRaises(HttpError): op.execute(None) - mock_hook.assert_called_once_with(api_version='v1', - gcp_conn_id='google_cloud_default', - impersonation_chain=None,) + mock_hook.assert_called_once_with( + api_version='v1', gcp_conn_id='google_cloud_default', impersonation_chain=None, + ) mock_hook.return_value.delete_function.assert_called_once_with( 'projects/project_name/locations/project_location/functions/function_name' ) @@ -678,20 +656,11 @@ def test_execute(self, mock_gcf_hook, mock_xcom): ) op.execute(None) mock_gcf_hook.assert_called_once_with( - api_version=api_version, - gcp_conn_id=gcp_conn_id, - impersonation_chain=impersonation_chain, + api_version=api_version, gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, ) mock_gcf_hook.return_value.call_function.assert_called_once_with( - function_id=function_id, - input_data=payload, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID + function_id=function_id, input_data=payload, location=GCP_LOCATION, project_id=GCP_PROJECT_ID ) - mock_xcom.assert_called_once_with( - context=None, - key='execution_id', - value=exec_id - ) + mock_xcom.assert_called_once_with(context=None, key='execution_id', value=exec_id) diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py index c04888497daa7..dbafab7b3d534 100644 --- a/tests/providers/google/cloud/operators/test_gcs.py +++ b/tests/providers/google/cloud/operators/test_gcs.py @@ -21,9 +21,14 @@ import mock from airflow.providers.google.cloud.operators.gcs import ( - GCSBucketCreateAclEntryOperator, GCSCreateBucketOperator, GCSDeleteBucketOperator, - GCSDeleteObjectsOperator, GCSFileTransformOperator, GCSListObjectsOperator, - GCSObjectCreateAclEntryOperator, GCSSynchronizeBucketsOperator, + GCSBucketCreateAclEntryOperator, + GCSCreateBucketOperator, + GCSDeleteBucketOperator, + GCSDeleteObjectsOperator, + GCSFileTransformOperator, + GCSListObjectsOperator, + GCSObjectCreateAclEntryOperator, + GCSSynchronizeBucketsOperator, ) TASK_ID = "test-gcs-operator" @@ -43,11 +48,7 @@ def test_execute(self, mock_hook): operator = GCSCreateBucketOperator( task_id=TASK_ID, bucket_name=TEST_BUCKET, - resource={ - "lifecycle": { - "rule": [{"action": {"type": "Delete"}, "condition": {"age": 7}}] - } - }, + resource={"lifecycle": {"rule": [{"action": {"type": "Delete"}, "condition": {"age": 7}}]}}, storage_class="MULTI_REGIONAL", location="EU", labels={"env": "prod"}, @@ -61,11 +62,7 @@ def test_execute(self, mock_hook): location="EU", labels={"env": "prod"}, project_id=TEST_PROJECT, - resource={ - "lifecycle": { - "rule": [{"action": {"type": "Delete"}, "condition": {"age": 7}}] - } - }, + resource={"lifecycle": {"rule": [{"action": {"type": "Delete"}, "condition": {"age": 7}}]}}, ) @@ -112,9 +109,7 @@ def test_object_create_acl(self, mock_hook): class TestGoogleCloudStorageDeleteOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook") def test_delete_objects(self, mock_hook): - operator = GCSDeleteObjectsOperator( - task_id=TASK_ID, bucket_name=TEST_BUCKET, objects=MOCK_FILES[0:2] - ) + operator = 
GCSDeleteObjectsOperator(task_id=TASK_ID, bucket_name=TEST_BUCKET, objects=MOCK_FILES[0:2]) operator.execute(None) mock_hook.return_value.list.assert_not_called() @@ -129,14 +124,10 @@ def test_delete_objects(self, mock_hook): @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook") def test_delete_prefix(self, mock_hook): mock_hook.return_value.list.return_value = MOCK_FILES[1:3] - operator = GCSDeleteObjectsOperator( - task_id=TASK_ID, bucket_name=TEST_BUCKET, prefix=PREFIX - ) + operator = GCSDeleteObjectsOperator(task_id=TASK_ID, bucket_name=TEST_BUCKET, prefix=PREFIX) operator.execute(None) - mock_hook.return_value.list.assert_called_once_with( - bucket_name=TEST_BUCKET, prefix=PREFIX - ) + mock_hook.return_value.list.assert_called_once_with(bucket_name=TEST_BUCKET, prefix=PREFIX) mock_hook.return_value.delete.assert_has_calls( calls=[ mock.call(bucket_name=TEST_BUCKET, object_name=MOCK_FILES[1]), @@ -205,31 +196,24 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempfile): ) mock_subprocess.Popen.assert_called_once_with( - args=[transform_script, source, destination], - stdout="pipe", - stderr="stdout", - close_fds=True, + args=[transform_script, source, destination], stdout="pipe", stderr="stdout", close_fds=True, ) mock_hook.return_value.upload.assert_called_with( - bucket_name=destination_bucket, - object_name=destination_object, - filename=destination, + bucket_name=destination_bucket, object_name=destination_object, filename=destination, ) class TestGCSDeleteBucketOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook") def test_delete_bucket(self, mock_hook): - operator = GCSDeleteBucketOperator( - task_id=TASK_ID, bucket_name=TEST_BUCKET) + operator = GCSDeleteBucketOperator(task_id=TASK_ID, bucket_name=TEST_BUCKET) operator.execute(None) mock_hook.return_value.delete_bucket.assert_called_once_with(bucket_name=TEST_BUCKET, force=True) class TestGoogleCloudStorageSync(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook') def test_execute(self, mock_hook): task = GCSSynchronizeBucketsOperator( diff --git a/tests/providers/google/cloud/operators/test_gcs_system_helper.py b/tests/providers/google/cloud/operators/test_gcs_system_helper.py index f1d00f2006472..5cf66974787b3 100644 --- a/tests/providers/google/cloud/operators/test_gcs_system_helper.py +++ b/tests/providers/google/cloud/operators/test_gcs_system_helper.py @@ -37,7 +37,8 @@ def create_test_file(): # Create script for transform operator with open(PATH_TO_TRANSFORM_SCRIPT, "w+") as file: - file.write("""import sys + file.write( + """import sys source = sys.argv[1] destination = sys.argv[2] @@ -46,7 +47,8 @@ def create_test_file(): lines = [l.upper() for l in src.readlines()] print(lines) dest.writelines(lines) - """) + """ + ) @staticmethod def remove_test_files(): diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py index b0cdb48eec282..70a31605399a2 100644 --- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py @@ -27,7 +27,9 @@ from airflow.models import Connection from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator from airflow.providers.google.cloud.operators.kubernetes_engine import ( - GKECreateClusterOperator, GKEDeleteClusterOperator, GKEStartPodOperator, + GKECreateClusterOperator, + 
GKEDeleteClusterOperator, + GKEStartPodOperator, ) TEST_GCP_PROJECT_ID = 'test-id' @@ -37,9 +39,7 @@ PROJECT_BODY = {'name': 'test-name'} PROJECT_BODY_CREATE_DICT = {'name': 'test-name', 'initial_node_count': 1} -PROJECT_BODY_CREATE_CLUSTER = type( - "Cluster", (object,), {"name": "test-name", "initial_node_count": 1} -)() +PROJECT_BODY_CREATE_CLUSTER = type("Cluster", (object,), {"name": "test-name", "initial_node_count": 1})() TASK_NAME = 'test-task-name' NAMESPACE = ('default',) @@ -51,23 +51,21 @@ class TestGoogleCloudPlatformContainerOperator(unittest.TestCase): - - @parameterized.expand( - (body,) for body in [PROJECT_BODY_CREATE_DICT, PROJECT_BODY_CREATE_CLUSTER] - ) + @parameterized.expand((body,) for body in [PROJECT_BODY_CREATE_DICT, PROJECT_BODY_CREATE_CLUSTER]) @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook') def test_create_execute(self, body, mock_hook): - operator = GKECreateClusterOperator(project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - body=body, - task_id=PROJECT_TASK_ID) + operator = GKECreateClusterOperator( + project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, body=body, task_id=PROJECT_TASK_ID + ) operator.execute(None) mock_hook.return_value.create_cluster.assert_called_once_with( - cluster=body, project_id=TEST_GCP_PROJECT_ID) + cluster=body, project_id=TEST_GCP_PROJECT_ID + ) @parameterized.expand( - (body,) for body in [ + (body,) + for body in [ None, {'missing_name': 'test-name', 'initial_node_count': 1}, {'name': 'test-name', 'missing_initial_node_count': 1}, @@ -78,100 +76,98 @@ def test_create_execute(self, body, mock_hook): @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook') def test_create_execute_error_body(self, body, mock_hook): with self.assertRaises(AirflowException): - GKECreateClusterOperator(project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - body=body, - task_id=PROJECT_TASK_ID) + GKECreateClusterOperator( + project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, body=body, task_id=PROJECT_TASK_ID + ) # pylint: disable=missing-kwoa @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook') def test_create_execute_error_project_id(self, mock_hook): with self.assertRaises(AirflowException): - GKECreateClusterOperator(location=PROJECT_LOCATION, - body=PROJECT_BODY, - task_id=PROJECT_TASK_ID) + GKECreateClusterOperator(location=PROJECT_LOCATION, body=PROJECT_BODY, task_id=PROJECT_TASK_ID) # pylint: disable=no-value-for-parameter @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook') def test_create_execute_error_location(self, mock_hook): with self.assertRaises(AirflowException): - GKECreateClusterOperator(project_id=TEST_GCP_PROJECT_ID, - body=PROJECT_BODY, - task_id=PROJECT_TASK_ID) + GKECreateClusterOperator( + project_id=TEST_GCP_PROJECT_ID, body=PROJECT_BODY, task_id=PROJECT_TASK_ID + ) @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook') def test_delete_execute(self, mock_hook): - operator = GKEDeleteClusterOperator(project_id=TEST_GCP_PROJECT_ID, - name=CLUSTER_NAME, - location=PROJECT_LOCATION, - task_id=PROJECT_TASK_ID) + operator = GKEDeleteClusterOperator( + project_id=TEST_GCP_PROJECT_ID, + name=CLUSTER_NAME, + location=PROJECT_LOCATION, + task_id=PROJECT_TASK_ID, + ) operator.execute(None) mock_hook.return_value.delete_cluster.assert_called_once_with( - name=CLUSTER_NAME, project_id=TEST_GCP_PROJECT_ID) + name=CLUSTER_NAME, project_id=TEST_GCP_PROJECT_ID + ) # 
pylint: disable=no-value-for-parameter @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook') def test_delete_execute_error_project_id(self, mock_hook): with self.assertRaises(AirflowException): - GKEDeleteClusterOperator(location=PROJECT_LOCATION, - name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID) + GKEDeleteClusterOperator(location=PROJECT_LOCATION, name=CLUSTER_NAME, task_id=PROJECT_TASK_ID) # pylint: disable=missing-kwoa @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook') def test_delete_execute_error_cluster_name(self, mock_hook): with self.assertRaises(AirflowException): - GKEDeleteClusterOperator(project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - task_id=PROJECT_TASK_ID) + GKEDeleteClusterOperator( + project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, task_id=PROJECT_TASK_ID + ) # pylint: disable=missing-kwoa @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook') def test_delete_execute_error_location(self, mock_hook): with self.assertRaises(AirflowException): - GKEDeleteClusterOperator(project_id=TEST_GCP_PROJECT_ID, - name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID) + GKEDeleteClusterOperator( + project_id=TEST_GCP_PROJECT_ID, name=CLUSTER_NAME, task_id=PROJECT_TASK_ID + ) class TestGKEPodOperator(unittest.TestCase): def setUp(self): - self.gke_op = GKEStartPodOperator(project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE) + self.gke_op = GKEStartPodOperator( + project_id=TEST_GCP_PROJECT_ID, + location=PROJECT_LOCATION, + cluster_name=CLUSTER_NAME, + task_id=PROJECT_TASK_ID, + name=TASK_NAME, + namespace=NAMESPACE, + image=IMAGE, + ) def test_template_fields(self): - self.assertTrue(set(KubernetesPodOperator.template_fields).issubset( - GKEStartPodOperator.template_fields)) + self.assertTrue( + set(KubernetesPodOperator.template_fields).issubset(GKEStartPodOperator.template_fields) + ) # pylint: disable=unused-argument @mock.patch.dict(os.environ, {}) @mock.patch( "airflow.hooks.base_hook.BaseHook.get_connections", - return_value=[Connection( - extra=json.dumps({ - "extra__google_cloud_platform__keyfile_dict": '{"private_key": "r4nd0m_k3y"}' - }) - )] + return_value=[ + Connection( + extra=json.dumps( + {"extra__google_cloud_platform__keyfile_dict": '{"private_key": "r4nd0m_k3y"}'} + ) + ) + ], ) - @mock.patch( - 'airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.execute') - @mock.patch( - 'airflow.providers.google.cloud.operators.kubernetes_engine.GoogleBaseHook') - @mock.patch( - 'airflow.providers.google.cloud.operators.kubernetes_engine.execute_in_subprocess') + @mock.patch('airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.execute') + @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GoogleBaseHook') + @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.execute_in_subprocess') @mock.patch('tempfile.NamedTemporaryFile') - def test_execute( - self, file_mock, mock_execute_in_subprocess, mock_gcp_hook, exec_mock, get_con_mock - ): - type(file_mock.return_value.__enter__.return_value).name = PropertyMock(side_effect=[ - FILE_NAME, '/path/to/new-file' - ]) + def test_execute(self, file_mock, mock_execute_in_subprocess, mock_gcp_hook, exec_mock, get_con_mock): + type(file_mock.return_value.__enter__.return_value).name = PropertyMock( + side_effect=[FILE_NAME, 
'/path/to/new-file'] + ) self.gke_op.execute(None) @@ -179,10 +175,15 @@ def test_execute( mock_execute_in_subprocess.assert_called_once_with( [ - 'gcloud', 'container', 'clusters', 'get-credentials', + 'gcloud', + 'container', + 'clusters', + 'get-credentials', CLUSTER_NAME, - '--zone', PROJECT_LOCATION, - '--project', TEST_GCP_PROJECT_ID, + '--zone', + PROJECT_LOCATION, + '--project', + TEST_GCP_PROJECT_ID, ] ) @@ -192,26 +193,25 @@ def test_execute( @mock.patch.dict(os.environ, {}) @mock.patch( "airflow.hooks.base_hook.BaseHook.get_connections", - return_value=[Connection( - extra=json.dumps({ - "extra__google_cloud_platform__keyfile_dict": '{"private_key": "r4nd0m_k3y"}' - }) - )] + return_value=[ + Connection( + extra=json.dumps( + {"extra__google_cloud_platform__keyfile_dict": '{"private_key": "r4nd0m_k3y"}'} + ) + ) + ], ) - @mock.patch( - 'airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.execute') - @mock.patch( - 'airflow.providers.google.cloud.operators.kubernetes_engine.GoogleBaseHook') - @mock.patch( - 'airflow.providers.google.cloud.operators.kubernetes_engine.execute_in_subprocess') + @mock.patch('airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.execute') + @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GoogleBaseHook') + @mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.execute_in_subprocess') @mock.patch('tempfile.NamedTemporaryFile') def test_execute_with_internal_ip( self, file_mock, mock_execute_in_subprocess, mock_gcp_hook, exec_mock, get_con_mock ): self.gke_op.use_internal_ip = True - type(file_mock.return_value.__enter__.return_value).name = PropertyMock(side_effect=[ - FILE_NAME, '/path/to/new-file' - ]) + type(file_mock.return_value.__enter__.return_value).name = PropertyMock( + side_effect=[FILE_NAME, '/path/to/new-file'] + ) self.gke_op.execute(None) @@ -219,11 +219,16 @@ def test_execute_with_internal_ip( mock_execute_in_subprocess.assert_called_once_with( [ - 'gcloud', 'container', 'clusters', 'get-credentials', + 'gcloud', + 'container', + 'clusters', + 'get-credentials', CLUSTER_NAME, - '--zone', PROJECT_LOCATION, - '--project', TEST_GCP_PROJECT_ID, - '--internal-ip' + '--zone', + PROJECT_LOCATION, + '--project', + TEST_GCP_PROJECT_ID, + '--internal-ip', ] ) diff --git a/tests/providers/google/cloud/operators/test_life_sciences.py b/tests/providers/google/cloud/operators/test_life_sciences.py index beb2c92594370..09f790e0040fd 100644 --- a/tests/providers/google/cloud/operators/test_life_sciences.py +++ b/tests/providers/google/cloud/operators/test_life_sciences.py @@ -23,32 +23,25 @@ from airflow.providers.google.cloud.operators.life_sciences import LifeSciencesRunPipelineOperator -TEST_BODY = { - "pipeline": { - "actions": [{}], - "resources": {}, - "environment": {}, - "timeout": '3.5s' - } -} +TEST_BODY = {"pipeline": {"actions": [{}], "resources": {}, "environment": {}, "timeout": '3.5s'}} -TEST_OPERATION = {"name": 'operation-name', "metadata": {"@type": 'anytype'}, - "done": True, "response": "response"} +TEST_OPERATION = { + "name": 'operation-name', + "metadata": {"@type": 'anytype'}, + "done": True, + "response": "response", +} TEST_PROJECT_ID = "life-science-project-id" TEST_LOCATION = 'test-location' class TestLifeSciencesRunPipelineOperator(unittest.TestCase): - @mock.patch("airflow.providers.google.cloud.operators.life_sciences.LifeSciencesHook") def test_executes(self, mock_hook): mock_instance = mock_hook.return_value 
mock_instance.run_pipeline.return_value = TEST_OPERATION operator = LifeSciencesRunPipelineOperator( - task_id='task-id', - body=TEST_BODY, - location=TEST_LOCATION, - project_id=TEST_PROJECT_ID + task_id='task-id', body=TEST_BODY, location=TEST_LOCATION, project_id=TEST_PROJECT_ID ) result = operator.execute(None) self.assertEqual(result, TEST_OPERATION) @@ -57,10 +50,6 @@ def test_executes(self, mock_hook): def test_executes_without_project_id(self, mock_hook): mock_instance = mock_hook.return_value mock_instance.run_pipeline.return_value = TEST_OPERATION - operator = LifeSciencesRunPipelineOperator( - task_id='task-id', - body=TEST_BODY, - location=TEST_LOCATION, - ) + operator = LifeSciencesRunPipelineOperator(task_id='task-id', body=TEST_BODY, location=TEST_LOCATION,) result = operator.execute(None) self.assertEqual(result, TEST_OPERATION) diff --git a/tests/providers/google/cloud/operators/test_life_sciences_system.py b/tests/providers/google/cloud/operators/test_life_sciences_system.py index cc0b269edb125..bf01794bbe516 100644 --- a/tests/providers/google/cloud/operators/test_life_sciences_system.py +++ b/tests/providers/google/cloud/operators/test_life_sciences_system.py @@ -27,16 +27,11 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_LIFE_SCIENCES_KEY) class CloudLifeSciencesExampleDagsSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_LIFE_SCIENCES_KEY) def setUp(self): super().setUp() self.create_gcs_bucket(BUCKET, LOCATION) - self.upload_content_to_gcs( - lines=f"{os.urandom(1 * 1024 * 1024)}", - bucket=BUCKET, - filename=FILENAME - ) + self.upload_content_to_gcs(lines=f"{os.urandom(1 * 1024 * 1024)}", bucket=BUCKET, filename=FILENAME) @provide_gcp_context(GCP_LIFE_SCIENCES_KEY) def test_run_example_dag_function(self): diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py index 77f65eea61718..30ad775390fc6 100644 --- a/tests/providers/google/cloud/operators/test_mlengine.py +++ b/tests/providers/google/cloud/operators/test_mlengine.py @@ -27,11 +27,19 @@ from airflow.models import TaskInstance from airflow.models.dag import DAG from airflow.providers.google.cloud.operators.mlengine import ( - AIPlatformConsoleLink, MLEngineCreateModelOperator, MLEngineCreateVersionOperator, - MLEngineDeleteModelOperator, MLEngineDeleteVersionOperator, MLEngineGetModelOperator, - MLEngineListVersionsOperator, MLEngineManageModelOperator, MLEngineManageVersionOperator, - MLEngineSetDefaultVersionOperator, MLEngineStartBatchPredictionJobOperator, - MLEngineStartTrainingJobOperator, MLEngineTrainingCancelJobOperator, + AIPlatformConsoleLink, + MLEngineCreateModelOperator, + MLEngineCreateVersionOperator, + MLEngineDeleteModelOperator, + MLEngineDeleteVersionOperator, + MLEngineGetModelOperator, + MLEngineListVersionsOperator, + MLEngineManageModelOperator, + MLEngineManageVersionOperator, + MLEngineSetDefaultVersionOperator, + MLEngineStartBatchPredictionJobOperator, + MLEngineStartTrainingJobOperator, + MLEngineTrainingCancelJobOperator, ) from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.dates import days_ago @@ -51,7 +59,7 @@ TEST_VERSION = { 'name': 'v1', 'deploymentUri': 'gs://some-bucket/jobs/test_training/model.pb', - 'runtimeVersion': '1.6' + 'runtimeVersion': '1.6', } @@ -69,9 +77,9 @@ class TestMLEngineBatchPredictionOperator(unittest.TestCase): 'outputPath': 'gs://fake-output-path', 'predictionCount': 5000, 'errorCount': 0, - 'nodeHours': 2.78 + 
'nodeHours': 2.78, }, - 'state': 'SUCCEEDED' + 'state': 'SUCCEEDED', } BATCH_PREDICTION_DEFAULT_ARGS = { 'project_id': 'test-project', @@ -80,35 +88,29 @@ class TestMLEngineBatchPredictionOperator(unittest.TestCase): 'region': 'us-east1', 'data_format': 'TEXT', 'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'], - 'output_path': - 'gs://12_legal_bucket_underscore_number/legal-output-path', - 'task_id': 'test-prediction' + 'output_path': 'gs://12_legal_bucket_underscore_number/legal-output-path', + 'task_id': 'test-prediction', } def setUp(self): super().setUp() self.dag = DAG( 'test_dag', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'end_date': DEFAULT_DATE, - }, - schedule_interval='@daily') + default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE, 'end_date': DEFAULT_DATE,}, + schedule_interval='@daily', + ) @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success_with_model(self, mock_hook): input_with_model = self.INPUT_MISSING_ORIGIN.copy() - input_with_model['modelName'] = \ - 'projects/test-project/models/test_model' + input_with_model['modelName'] = 'projects/test-project/models/test_model' success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() success_message['predictionInput'] = input_with_model hook_instance = mock_hook.return_value hook_instance.get_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': 404 - }), content=b'some bytes') + resp=httplib2.Response({'status': 404}), content=b'some bytes' + ) hook_instance.create_job.return_value = success_message prediction_task = MLEngineStartBatchPredictionJobOperator( @@ -121,35 +123,35 @@ def test_success_with_model(self, mock_hook): model_name=input_with_model['modelName'].split('/')[-1], labels={'some': 'labels'}, dag=self.dag, - task_id='test-prediction') + task_id='test-prediction', + ) prediction_output = prediction_task.execute(None) - mock_hook.assert_called_once_with('google_cloud_default', None, - impersonation_chain=None,) + mock_hook.assert_called_once_with( + 'google_cloud_default', None, impersonation_chain=None, + ) hook_instance.create_job.assert_called_once_with( project_id='test-project', job={ 'jobId': 'test_prediction', 'labels': {'some': 'labels'}, - 'predictionInput': input_with_model + 'predictionInput': input_with_model, }, - use_existing_job_fn=ANY + use_existing_job_fn=ANY, ) self.assertEqual(success_message['predictionOutput'], prediction_output) @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success_with_version(self, mock_hook): input_with_version = self.INPUT_MISSING_ORIGIN.copy() - input_with_version['versionName'] = \ - 'projects/test-project/models/test_model/versions/test_version' + input_with_version['versionName'] = 'projects/test-project/models/test_model/versions/test_version' success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() success_message['predictionInput'] = input_with_version hook_instance = mock_hook.return_value hook_instance.get_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': 404 - }), content=b'some bytes') + resp=httplib2.Response({'status': 404}), content=b'some bytes' + ) hook_instance.create_job.return_value = success_message prediction_task = MLEngineStartBatchPredictionJobOperator( @@ -162,18 +164,17 @@ def test_success_with_version(self, mock_hook): model_name=input_with_version['versionName'].split('/')[-3], version_name=input_with_version['versionName'].split('/')[-1], dag=self.dag, - task_id='test-prediction') 
+ task_id='test-prediction', + ) prediction_output = prediction_task.execute(None) - mock_hook.assert_called_once_with('google_cloud_default', None, - impersonation_chain=None,) + mock_hook.assert_called_once_with( + 'google_cloud_default', None, impersonation_chain=None, + ) hook_instance.create_job.assert_called_once_with( project_id='test-project', - job={ - 'jobId': 'test_prediction', - 'predictionInput': input_with_version - }, - use_existing_job_fn=ANY + job={'jobId': 'test_prediction', 'predictionInput': input_with_version}, + use_existing_job_fn=ANY, ) self.assertEqual(success_message['predictionOutput'], prediction_output) @@ -186,9 +187,8 @@ def test_success_with_uri(self, mock_hook): hook_instance = mock_hook.return_value hook_instance.get_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': 404 - }), content=b'some bytes') + resp=httplib2.Response({'status': 404}), content=b'some bytes' + ) hook_instance.create_job.return_value = success_message prediction_task = MLEngineStartBatchPredictionJobOperator( @@ -200,18 +200,17 @@ def test_success_with_uri(self, mock_hook): output_path=input_with_uri['outputPath'], uri=input_with_uri['uri'], dag=self.dag, - task_id='test-prediction') + task_id='test-prediction', + ) prediction_output = prediction_task.execute(None) - mock_hook.assert_called_once_with('google_cloud_default', None, - impersonation_chain=None,) + mock_hook.assert_called_once_with( + 'google_cloud_default', None, impersonation_chain=None, + ) hook_instance.create_job.assert_called_once_with( project_id='test-project', - job={ - 'jobId': 'test_prediction', - 'predictionInput': input_with_uri - }, - use_existing_job_fn=ANY + job={'jobId': 'test_prediction', 'predictionInput': input_with_uri}, + use_existing_job_fn=ANY, ) self.assertEqual(success_message['predictionOutput'], prediction_output) @@ -222,9 +221,9 @@ def test_invalid_model_origin(self): task_args['model_name'] = 'fake_model' with self.assertRaises(AirflowException) as context: MLEngineStartBatchPredictionJobOperator(**task_args).execute(None) - self.assertEqual('Ambiguous model origin: Both uri and ' - 'model/version name are provided.', - str(context.exception)) + self.assertEqual( + 'Ambiguous model origin: Both uri and ' 'model/version name are provided.', str(context.exception) + ) # Test that both uri and model/version is given task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() @@ -233,18 +232,19 @@ def test_invalid_model_origin(self): task_args['version_name'] = 'fake_version' with self.assertRaises(AirflowException) as context: MLEngineStartBatchPredictionJobOperator(**task_args).execute(None) - self.assertEqual('Ambiguous model origin: Both uri and ' - 'model/version name are provided.', - str(context.exception)) + self.assertEqual( + 'Ambiguous model origin: Both uri and ' 'model/version name are provided.', str(context.exception) + ) # Test that a version is given without a model task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() task_args['version_name'] = 'bare_version' with self.assertRaises(AirflowException) as context: MLEngineStartBatchPredictionJobOperator(**task_args).execute(None) - self.assertEqual('Missing model: Batch prediction expects a model ' - 'name when a version name is provided.', - str(context.exception)) + self.assertEqual( + 'Missing model: Batch prediction expects a model ' 'name when a version name is provided.', + str(context.exception), + ) # Test that none of uri, model, model/version is given task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() @@ 
-253,21 +253,19 @@ def test_invalid_model_origin(self): self.assertEqual( 'Missing model origin: Batch prediction expects a ' 'model, a model & version combination, or a URI to a savedModel.', - str(context.exception)) + str(context.exception), + ) @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_http_error(self, mock_hook): http_error_code = 403 input_with_model = self.INPUT_MISSING_ORIGIN.copy() - input_with_model['modelName'] = \ - 'projects/experimental/models/test_model' + input_with_model['modelName'] = 'projects/experimental/models/test_model' hook_instance = mock_hook.return_value hook_instance.create_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': http_error_code - }), - content=b'Forbidden') + resp=httplib2.Response({'status': http_error_code}), content=b'Forbidden' + ) with self.assertRaises(HttpError) as context: prediction_task = MLEngineStartBatchPredictionJobOperator( @@ -279,26 +277,23 @@ def test_http_error(self, mock_hook): output_path=input_with_model['outputPath'], model_name=input_with_model['modelName'].split('/')[-1], dag=self.dag, - task_id='test-prediction') + task_id='test-prediction', + ) prediction_task.execute(None) - mock_hook.assert_called_once_with('google_cloud_default', None, - impersonation_chain=None,) + mock_hook.assert_called_once_with( + 'google_cloud_default', None, impersonation_chain=None, + ) hook_instance.create_job.assert_called_once_with( - 'test-project', { - 'jobId': 'test_prediction', - 'predictionInput': input_with_model - }, ANY) + 'test-project', {'jobId': 'test_prediction', 'predictionInput': input_with_model}, ANY + ) self.assertEqual(http_error_code, context.exception.resp.status) @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_failed_job_error(self, mock_hook): hook_instance = mock_hook.return_value - hook_instance.create_job.return_value = { - 'state': 'FAILED', - 'errorMessage': 'A failure message' - } + hook_instance.create_job.return_value = {'state': 'FAILED', 'errorMessage': 'A failure message'} task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() task_args['uri'] = 'a uri' @@ -319,7 +314,7 @@ class TestMLEngineTrainingOperator(unittest.TestCase): 'scale_tier': 'STANDARD_1', 'labels': {'some': 'labels'}, 'task_id': 'test-training', - 'start_date': days_ago(1) + 'start_date': days_ago(1), } TRAINING_INPUT = { 'jobId': 'test_training', @@ -329,8 +324,8 @@ class TestMLEngineTrainingOperator(unittest.TestCase): 'packageUris': ['gs://some-bucket/package1'], 'pythonModule': 'trainer', 'args': '--some_arg=\'aaa\'', - 'region': 'us-east1' - } + 'region': 'us-east1', + }, } def setUp(self): @@ -343,19 +338,17 @@ def test_success_create_training_job(self, mock_hook): hook_instance = mock_hook.return_value hook_instance.create_job.return_value = success_response - training_op = MLEngineStartTrainingJobOperator( - **self.TRAINING_DEFAULT_ARGS) + training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS) training_op.execute(MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', - delegate_to=None, - impersonation_chain=None, + gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None, ) # Make sure only 'create_job' is invoked on hook instance self.assertEqual(len(hook_instance.mock_calls), 1) hook_instance.create_job.assert_called_once_with( - project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY) + project_id='test-project', job=self.TRAINING_INPUT, 
use_existing_job_fn=ANY + ) @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success_create_training_job_with_optional_args(self, mock_hook): @@ -373,43 +366,39 @@ def test_success_create_training_job_with_optional_args(self, mock_hook): runtime_version='1.6', python_version='3.5', job_dir='gs://some-bucket/jobs/test_training', - **self.TRAINING_DEFAULT_ARGS) + **self.TRAINING_DEFAULT_ARGS, + ) training_op.execute(MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', - delegate_to=None, - impersonation_chain=None, + gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None, ) # Make sure only 'create_job' is invoked on hook instance self.assertEqual(len(hook_instance.mock_calls), 1) hook_instance.create_job.assert_called_once_with( - project_id='test-project', job=training_input, use_existing_job_fn=ANY) + project_id='test-project', job=training_input, use_existing_job_fn=ANY + ) @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_http_error(self, mock_hook): http_error_code = 403 hook_instance = mock_hook.return_value hook_instance.create_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': http_error_code - }), - content=b'Forbidden') + resp=httplib2.Response({'status': http_error_code}), content=b'Forbidden' + ) with self.assertRaises(HttpError) as context: - training_op = MLEngineStartTrainingJobOperator( - **self.TRAINING_DEFAULT_ARGS) + training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS) training_op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', - delegate_to=None, - impersonation_chain=None, + gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None, ) # Make sure only 'create_job' is invoked on hook instance self.assertEqual(len(hook_instance.mock_calls), 1) hook_instance.create_job.assert_called_once_with( - project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY) + project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY + ) self.assertEqual(http_error_code, context.exception.resp.status) @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') @@ -421,30 +410,24 @@ def test_failed_job_error(self, mock_hook): hook_instance.create_job.return_value = failure_response with self.assertRaises(RuntimeError) as context: - training_op = MLEngineStartTrainingJobOperator( - **self.TRAINING_DEFAULT_ARGS) + training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS) training_op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', - delegate_to=None, - impersonation_chain=None, + gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None, ) # Make sure only 'create_job' is invoked on hook instance self.assertEqual(len(hook_instance.mock_calls), 1) hook_instance.create_job.assert_called_once_with( - project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY) + project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY + ) self.assertEqual('A failure message', str(context.exception)) @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_console_extra_link(self, mock_hook): - training_op = MLEngineStartTrainingJobOperator( - **self.TRAINING_DEFAULT_ARGS) + training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS) - ti = TaskInstance( - task=training_op, - 
execution_date=DEFAULT_DATE, - ) + ti = TaskInstance(task=training_op, execution_date=DEFAULT_DATE,) job_id = self.TRAINING_DEFAULT_ARGS['job_id'] project_id = self.TRAINING_DEFAULT_ARGS['project_id'] @@ -460,8 +443,7 @@ def test_console_extra_link(self, mock_hook): ) self.assertEqual( - '', - training_op.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name), + '', training_op.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name), ) def test_console_extra_link_serialized_field(self): @@ -474,7 +456,7 @@ def test_console_extra_link_serialized_field(self): # Check Serialized version of operator link self.assertEqual( serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], - [{"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}}] + [{"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}}], ) # Check DeSerialized version of operator link @@ -487,10 +469,7 @@ def test_console_extra_link_serialized_field(self): "project_id": project_id, } - ti = TaskInstance( - task=training_op, - execution_date=DEFAULT_DATE, - ) + ti = TaskInstance(task=training_op, execution_date=DEFAULT_DATE,) ti.xcom_push(key='gcp_metadata', value=gcp_metadata) self.assertEqual( @@ -499,8 +478,7 @@ def test_console_extra_link_serialized_field(self): ) self.assertEqual( - '', - simple_task.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name), + '', simple_task.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name), ) @@ -509,7 +487,7 @@ class TestMLEngineTrainingCancelJobOperator(unittest.TestCase): TRAINING_DEFAULT_ARGS = { 'project_id': 'test-project', 'job_id': 'test_training', - 'task_id': 'test-training' + 'task_id': 'test-training', } @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') @@ -518,49 +496,42 @@ def test_success_cancel_training_job(self, mock_hook): hook_instance = mock_hook.return_value hook_instance.cancel_job.return_value = success_response - cancel_training_op = MLEngineTrainingCancelJobOperator( - **self.TRAINING_DEFAULT_ARGS) + cancel_training_op = MLEngineTrainingCancelJobOperator(**self.TRAINING_DEFAULT_ARGS) cancel_training_op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', - delegate_to=None, - impersonation_chain=None, + gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None, ) # Make sure only 'cancel_job' is invoked on hook instance self.assertEqual(len(hook_instance.mock_calls), 1) hook_instance.cancel_job.assert_called_once_with( - project_id=self.TRAINING_DEFAULT_ARGS['project_id'], job_id=self.TRAINING_DEFAULT_ARGS['job_id']) + project_id=self.TRAINING_DEFAULT_ARGS['project_id'], job_id=self.TRAINING_DEFAULT_ARGS['job_id'] + ) @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_http_error(self, mock_hook): http_error_code = 403 hook_instance = mock_hook.return_value hook_instance.cancel_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': http_error_code - }), - content=b'Forbidden') + resp=httplib2.Response({'status': http_error_code}), content=b'Forbidden' + ) with self.assertRaises(HttpError) as context: - cancel_training_op = MLEngineTrainingCancelJobOperator( - **self.TRAINING_DEFAULT_ARGS) + cancel_training_op = MLEngineTrainingCancelJobOperator(**self.TRAINING_DEFAULT_ARGS) cancel_training_op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', - delegate_to=None, - 
impersonation_chain=None, + gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None, ) # Make sure only 'create_job' is invoked on hook instance self.assertEqual(len(hook_instance.mock_calls), 1) hook_instance.cancel_job.assert_called_once_with( - project_id=self.TRAINING_DEFAULT_ARGS['project_id'], job_id=self.TRAINING_DEFAULT_ARGS['job_id']) + project_id=self.TRAINING_DEFAULT_ARGS['project_id'], job_id=self.TRAINING_DEFAULT_ARGS['job_id'] + ) self.assertEqual(http_error_code, context.exception.resp.status) class TestMLEngineModelOperator(unittest.TestCase): - @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success_create_model(self, mock_hook): task = MLEngineManageModelOperator( @@ -616,14 +587,13 @@ def test_fail(self, mock_hook): model=TEST_MODEL, operation="invalid", gcp_conn_id=TEST_GCP_CONN_ID, - delegate_to=TEST_DELEGATE_TO + delegate_to=TEST_DELEGATE_TO, ) with self.assertRaises(ValueError): task.execute(None) class TestMLEngineCreateModelOperator(unittest.TestCase): - @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success_create_model(self, mock_hook): task = MLEngineCreateModelOperator( @@ -648,7 +618,6 @@ def test_success_create_model(self, mock_hook): class TestMLEngineGetModelOperator(unittest.TestCase): - @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success_get_model(self, mock_hook): task = MLEngineGetModelOperator( @@ -674,7 +643,6 @@ def test_success_get_model(self, mock_hook): class TestMLEngineDeleteModelOperator(unittest.TestCase): - @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success_delete_model(self, mock_hook): task = MLEngineDeleteModelOperator( @@ -684,7 +652,7 @@ def test_success_delete_model(self, mock_hook): gcp_conn_id=TEST_GCP_CONN_ID, delegate_to=TEST_DELEGATE_TO, impersonation_chain=TEST_IMPERSONATION_CHAIN, - delete_contents=True + delete_contents=True, ) task.execute(None) @@ -703,7 +671,7 @@ class TestMLEngineVersionOperator(unittest.TestCase): VERSION_DEFAULT_ARGS = { 'project_id': 'test-project', 'model_name': 'test-model', - 'task_id': 'test-version' + 'task_id': 'test-version', } @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') @@ -712,24 +680,20 @@ def test_success_create_version(self, mock_hook): hook_instance = mock_hook.return_value hook_instance.create_version.return_value = success_response - training_op = MLEngineManageVersionOperator( - version=TEST_VERSION, - **self.VERSION_DEFAULT_ARGS) + training_op = MLEngineManageVersionOperator(version=TEST_VERSION, **self.VERSION_DEFAULT_ARGS) training_op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', - delegate_to=None, - impersonation_chain=None, + gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None, ) # Make sure only 'create_version' is invoked on hook instance self.assertEqual(len(hook_instance.mock_calls), 1) hook_instance.create_version.assert_called_once_with( - project_id='test-project', model_name='test-model', version_spec=TEST_VERSION) + project_id='test-project', model_name='test-model', version_spec=TEST_VERSION + ) class TestMLEngineCreateVersion(unittest.TestCase): - @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success(self, mock_hook): task = MLEngineCreateVersionOperator( @@ -750,9 +714,7 @@ def test_success(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) 
mock_hook.return_value.create_version.assert_called_once_with( - project_id=TEST_PROJECT_ID, - model_name=TEST_MODEL_NAME, - version_spec=TEST_VERSION + project_id=TEST_PROJECT_ID, model_name=TEST_MODEL_NAME, version_spec=TEST_VERSION ) def test_missing_model_name(self): @@ -779,7 +741,6 @@ def test_missing_version(self): class TestMLEngineSetDefaultVersion(unittest.TestCase): - @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success(self, mock_hook): task = MLEngineSetDefaultVersionOperator( @@ -800,9 +761,7 @@ def test_success(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.set_default_version.assert_called_once_with( - project_id=TEST_PROJECT_ID, - model_name=TEST_MODEL_NAME, - version_name=TEST_VERSION_NAME + project_id=TEST_PROJECT_ID, model_name=TEST_MODEL_NAME, version_name=TEST_VERSION_NAME ) def test_missing_model_name(self): @@ -829,7 +788,6 @@ def test_missing_version_name(self): class TestMLEngineListVersions(unittest.TestCase): - @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success(self, mock_hook): task = MLEngineListVersionsOperator( @@ -849,8 +807,7 @@ def test_success(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.list_versions.assert_called_once_with( - project_id=TEST_PROJECT_ID, - model_name=TEST_MODEL_NAME, + project_id=TEST_PROJECT_ID, model_name=TEST_MODEL_NAME, ) def test_missing_model_name(self): @@ -865,7 +822,6 @@ def test_missing_model_name(self): class TestMLEngineDeleteVersion(unittest.TestCase): - @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') def test_success(self, mock_hook): task = MLEngineDeleteVersionOperator( @@ -886,9 +842,7 @@ def test_success(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.delete_version.assert_called_once_with( - project_id=TEST_PROJECT_ID, - model_name=TEST_MODEL_NAME, - version_name=TEST_VERSION_NAME + project_id=TEST_PROJECT_ID, model_name=TEST_MODEL_NAME, version_name=TEST_VERSION_NAME ) def test_missing_version_name(self): diff --git a/tests/providers/google/cloud/operators/test_mlengine_system.py b/tests/providers/google/cloud/operators/test_mlengine_system.py index e42b3ffae3752..c7ab84fa76b0a 100644 --- a/tests/providers/google/cloud/operators/test_mlengine_system.py +++ b/tests/providers/google/cloud/operators/test_mlengine_system.py @@ -20,7 +20,11 @@ import pytest from airflow.providers.google.cloud.example_dags.example_mlengine import ( - JOB_DIR, PREDICTION_OUTPUT, SAVED_MODEL_PATH, SUMMARY_STAGING, SUMMARY_TMP, + JOB_DIR, + PREDICTION_OUTPUT, + SAVED_MODEL_PATH, + SUMMARY_STAGING, + SUMMARY_TMP, ) from tests.providers.google.cloud.utils.gcp_authenticator import GCP_AI_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context @@ -33,7 +37,6 @@ @pytest.mark.credential_file(GCP_AI_KEY) class MlEngineExampleDagTest(GoogleSystemTest): - @provide_gcp_context(GCP_AI_KEY) def setUp(self): super().setUp() diff --git a/tests/providers/google/cloud/operators/test_mlengine_utils.py b/tests/providers/google/cloud/operators/test_mlengine_utils.py index b6d2286b4e122..87dfc9591b781 100644 --- a/tests/providers/google/cloud/operators/test_mlengine_utils.py +++ b/tests/providers/google/cloud/operators/test_mlengine_utils.py @@ -44,9 +44,9 @@ class TestCreateEvaluateOps(unittest.TestCase): 'outputPath': 'gs://fake-output-path', 'predictionCount': 5000, 'errorCount': 0, - 
'nodeHours': 2.78 + 'nodeHours': 2.78, }, - 'state': 'SUCCEEDED' + 'state': 'SUCCEEDED', } def setUp(self): @@ -62,10 +62,12 @@ def setUp(self): 'model_name': 'test_model', 'version_name': 'test_version', }, - schedule_interval='@daily') + schedule_interval='@daily', + ) self.metric_fn = lambda x: (0.1,) self.metric_fn_encoded = mlengine_operator_utils.base64.b64encode( - mlengine_operator_utils.dill.dumps(self.metric_fn, recurse=True)).decode() + mlengine_operator_utils.dill.dumps(self.metric_fn, recurse=True) + ).decode() def test_successful_run(self): input_with_model = self.INPUT_MISSING_ORIGIN.copy() @@ -88,15 +90,14 @@ def test_successful_run(self): hook_instance = mock_mlengine_hook.return_value hook_instance.create_job.return_value = success_message result = pred.execute(None) - mock_mlengine_hook.assert_called_once_with('google_cloud_default', None, - impersonation_chain=None,) + mock_mlengine_hook.assert_called_once_with( + 'google_cloud_default', None, impersonation_chain=None, + ) hook_instance.create_job.assert_called_once_with( project_id='test-project', - job={ - 'jobId': 'eval_test_prediction', - 'predictionInput': input_with_model, - }, - use_existing_job_fn=ANY) + job={'jobId': 'eval_test_prediction', 'predictionInput': input_with_model,}, + use_existing_job_fn=ANY, + ) self.assertEqual(success_message['predictionOutput'], result) with patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') as mock_dataflow_hook: @@ -104,7 +105,8 @@ def test_successful_run(self): hook_instance.start_python_dataflow.return_value = None summary.execute(None) mock_dataflow_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', delegate_to=None, poll_sleep=10) + gcp_conn_id='google_cloud_default', delegate_to=None, poll_sleep=10 + ) hook_instance.start_python_dataflow.assert_called_once_with( job_name='{{task.task_id}}', variables={ @@ -128,7 +130,8 @@ def test_successful_run(self): hook_instance.download.return_value = '{"err": 0.9, "count": 9}' result = validate.execute({}) hook_instance.download.assert_called_once_with( - 'legal-bucket', 'fake-output-path/prediction.summary.json') + 'legal-bucket', 'fake-output-path/prediction.summary.json' + ) self.assertEqual('err=0.9', result) def test_failures(self): @@ -142,7 +145,8 @@ def create_test_dag(dag_id): 'project_id': 'test-project', 'region': 'us-east1', }, - schedule_interval='@daily') + schedule_interval='@daily', + ) return dag input_with_model = self.INPUT_MISSING_ORIGIN.copy() @@ -158,26 +162,35 @@ def create_test_dag(dag_id): with self.assertRaisesRegex(AirflowException, 'Missing model origin'): mlengine_operator_utils.create_evaluate_ops( - dag=create_test_dag('test_dag_1'), **other_params_but_models) + dag=create_test_dag('test_dag_1'), **other_params_but_models + ) with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'): mlengine_operator_utils.create_evaluate_ops( - dag=create_test_dag('test_dag_2'), model_uri='abc', model_name='cde', - **other_params_but_models) + dag=create_test_dag('test_dag_2'), + model_uri='abc', + model_name='cde', + **other_params_but_models, + ) with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'): mlengine_operator_utils.create_evaluate_ops( - dag=create_test_dag('test_dag_3'), model_uri='abc', version_name='vvv', - **other_params_but_models) + dag=create_test_dag('test_dag_3'), + model_uri='abc', + version_name='vvv', + **other_params_but_models, + ) with self.assertRaisesRegex(AirflowException, '`metric_fn` param must be callable'): params 
= other_params_but_models.copy() params['metric_fn_and_keys'] = (None, ['abc']) mlengine_operator_utils.create_evaluate_ops( - dag=create_test_dag('test_dag_4'), model_uri='gs://blah', **params) + dag=create_test_dag('test_dag_4'), model_uri='gs://blah', **params + ) with self.assertRaisesRegex(AirflowException, '`validate_fn` param must be callable'): params = other_params_but_models.copy() params['validate_fn'] = None mlengine_operator_utils.create_evaluate_ops( - dag=create_test_dag('test_dag_5'), model_uri='gs://blah', **params) + dag=create_test_dag('test_dag_5'), model_uri='gs://blah', **params + ) diff --git a/tests/providers/google/cloud/operators/test_natural_language.py b/tests/providers/google/cloud/operators/test_natural_language.py index b1a580fa4e3e3..d948c5837a59c 100644 --- a/tests/providers/google/cloud/operators/test_natural_language.py +++ b/tests/providers/google/cloud/operators/test_natural_language.py @@ -19,14 +19,19 @@ import unittest from google.cloud.language_v1.proto.language_service_pb2 import ( - AnalyzeEntitiesResponse, AnalyzeEntitySentimentResponse, AnalyzeSentimentResponse, ClassifyTextResponse, + AnalyzeEntitiesResponse, + AnalyzeEntitySentimentResponse, + AnalyzeSentimentResponse, + ClassifyTextResponse, Document, ) from mock import patch from airflow.providers.google.cloud.operators.natural_language import ( - CloudNaturalLanguageAnalyzeEntitiesOperator, CloudNaturalLanguageAnalyzeEntitySentimentOperator, - CloudNaturalLanguageAnalyzeSentimentOperator, CloudNaturalLanguageClassifyTextOperator, + CloudNaturalLanguageAnalyzeEntitiesOperator, + CloudNaturalLanguageAnalyzeEntitySentimentOperator, + CloudNaturalLanguageAnalyzeSentimentOperator, + CloudNaturalLanguageClassifyTextOperator, ) DOCUMENT = Document( diff --git a/tests/providers/google/cloud/operators/test_pubsub.py b/tests/providers/google/cloud/operators/test_pubsub.py index f3d2eaa776653..93b509f1dff76 100644 --- a/tests/providers/google/cloud/operators/test_pubsub.py +++ b/tests/providers/google/cloud/operators/test_pubsub.py @@ -24,8 +24,12 @@ from google.protobuf.json_format import MessageToDict, ParseDict from airflow.providers.google.cloud.operators.pubsub import ( - PubSubCreateSubscriptionOperator, PubSubCreateTopicOperator, PubSubDeleteSubscriptionOperator, - PubSubDeleteTopicOperator, PubSubPublishMessageOperator, PubSubPullOperator, + PubSubCreateSubscriptionOperator, + PubSubCreateTopicOperator, + PubSubDeleteSubscriptionOperator, + PubSubDeleteTopicOperator, + PubSubPublishMessageOperator, + PubSubPullOperator, ) TASK_ID = 'test-task-id' @@ -33,22 +37,17 @@ TEST_TOPIC = 'test-topic' TEST_SUBSCRIPTION = 'test-subscription' TEST_MESSAGES = [ - { - 'data': b'Hello, World!', - 'attributes': {'type': 'greeting'} - }, + {'data': b'Hello, World!', 'attributes': {'type': 'greeting'}}, {'data': b'Knock, knock'}, - {'attributes': {'foo': ''}}] + {'attributes': {'foo': ''}}, +] class TestPubSubTopicCreateOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_failifexists(self, mock_hook): operator = PubSubCreateTopicOperator( - task_id=TASK_ID, - project_id=TEST_PROJECT, - topic=TEST_TOPIC, - fail_if_exists=True + task_id=TASK_ID, project_id=TEST_PROJECT, topic=TEST_TOPIC, fail_if_exists=True ) operator.execute(None) @@ -67,10 +66,7 @@ def test_failifexists(self, mock_hook): @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_succeedifexists(self, mock_hook): operator = PubSubCreateTopicOperator( - 
task_id=TASK_ID, - project_id=TEST_PROJECT, - topic=TEST_TOPIC, - fail_if_exists=False + task_id=TASK_ID, project_id=TEST_PROJECT, topic=TEST_TOPIC, fail_if_exists=False ) operator.execute(None) @@ -83,18 +79,14 @@ def test_succeedifexists(self, mock_hook): kms_key_name=None, retry=None, timeout=None, - metadata=None + metadata=None, ) class TestPubSubTopicDeleteOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_execute(self, mock_hook): - operator = PubSubDeleteTopicOperator( - task_id=TASK_ID, - project_id=TEST_PROJECT, - topic=TEST_TOPIC - ) + operator = PubSubDeleteTopicOperator(task_id=TASK_ID, project_id=TEST_PROJECT, topic=TEST_TOPIC) operator.execute(None) mock_hook.return_value.delete_topic.assert_called_once_with( @@ -103,7 +95,7 @@ def test_execute(self, mock_hook): fail_if_not_exists=False, retry=None, timeout=None, - metadata=None + metadata=None, ) @@ -111,10 +103,7 @@ class TestPubSubSubscriptionCreateOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_execute(self, mock_hook): operator = PubSubCreateSubscriptionOperator( - task_id=TASK_ID, - project_id=TEST_PROJECT, - topic=TEST_TOPIC, - subscription=TEST_SUBSCRIPTION + task_id=TASK_ID, project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION ) mock_hook.return_value.create_subscription.return_value = TEST_SUBSCRIPTION response = operator.execute(None) @@ -148,7 +137,7 @@ def test_execute_different_project_ids(self, mock_hook): topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, subscription_project_id=another_project, - task_id=TASK_ID + task_id=TASK_ID, ) mock_hook.return_value.create_subscription.return_value = TEST_SUBSCRIPTION response = operator.execute(None) @@ -170,16 +159,14 @@ def test_execute_different_project_ids(self, mock_hook): retry_policy=None, retry=None, timeout=None, - metadata=None + metadata=None, ) self.assertEqual(response, TEST_SUBSCRIPTION) @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_execute_no_subscription(self, mock_hook): operator = PubSubCreateSubscriptionOperator( - task_id=TASK_ID, - project_id=TEST_PROJECT, - topic=TEST_TOPIC + task_id=TASK_ID, project_id=TEST_PROJECT, topic=TEST_TOPIC ) mock_hook.return_value.create_subscription.return_value = TEST_SUBSCRIPTION response = operator.execute(None) @@ -210,9 +197,7 @@ class TestPubSubSubscriptionDeleteOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_execute(self, mock_hook): operator = PubSubDeleteSubscriptionOperator( - task_id=TASK_ID, - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION + task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION ) operator.execute(None) @@ -230,10 +215,7 @@ class TestPubSubPublishOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_publish(self, mock_hook): operator = PubSubPublishMessageOperator( - task_id=TASK_ID, - project_id=TEST_PROJECT, - topic=TEST_TOPIC, - messages=TEST_MESSAGES, + task_id=TASK_ID, project_id=TEST_PROJECT, topic=TEST_TOPIC, messages=TEST_MESSAGES, ) operator.execute(None) @@ -259,17 +241,12 @@ def _generate_messages(self, count): ] def _generate_dicts(self, count): - return [ - MessageToDict(m) - for m in self._generate_messages(count) - ] + return [MessageToDict(m) for m in self._generate_messages(count)] 
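The `_generate_dicts` helper above leans on `MessageToDict` to turn protobuf `ReceivedMessage` objects into the plain, JSON-serializable dicts that the pull-operator tests compare against. Below is a minimal standalone sketch of that conversion; it is not part of the patch, the ack id and payload are invented, and it assumes the google-cloud-pubsub 1.x protobuf types (taking `PubsubMessage` from the same `types` module is my assumption).

```python
from google.cloud.pubsub_v1.types import PubsubMessage, ReceivedMessage
from google.protobuf.json_format import MessageToDict

# Build one received message by hand; the values are illustrative only.
received = ReceivedMessage(
    ack_id="1",
    message=PubsubMessage(data=b"Hello, World!", attributes={"type": "greeting"}),
)

# MessageToDict emits camelCase keys and base64-encodes the bytes payload,
# so the result is JSON-serializable and easy to compare with assertEqual.
as_dict = MessageToDict(received)
# Roughly: {'ackId': '1',
#           'message': {'data': 'SGVsbG8sIFdvcmxkIQ==',
#                       'attributes': {'type': 'greeting'}}}
```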
@mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_execute_no_messages(self, mock_hook): operator = PubSubPullOperator( - task_id=TASK_ID, - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, + task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ) mock_hook.return_value.pull.return_value = [] @@ -278,10 +255,7 @@ def test_execute_no_messages(self, mock_hook): @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') def test_execute_with_ack_messages(self, mock_hook): operator = PubSubPullOperator( - task_id=TASK_ID, - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - ack_messages=True, + task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ack_messages=True, ) generated_messages = self._generate_messages(5) @@ -290,9 +264,7 @@ def test_execute_with_ack_messages(self, mock_hook): self.assertEqual(generated_dicts, operator.execute({})) mock_hook.return_value.acknowledge.assert_called_once_with( - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - messages=generated_messages, + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, messages=generated_messages, ) @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook') @@ -301,8 +273,7 @@ def test_execute_with_messages_callback(self, mock_hook): messages_callback_return_value = 'asdfg' def messages_callback( - pulled_messages: List[ReceivedMessage], - context: Dict[str, Any], + pulled_messages: List[ReceivedMessage], context: Dict[str, Any], ): assert pulled_messages == generated_messages @@ -325,10 +296,7 @@ def messages_callback( response = operator.execute({}) mock_hook.return_value.pull.assert_called_once_with( - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - max_messages=5, - return_immediately=True + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=5, return_immediately=True ) messages_callback.assert_called_once() diff --git a/tests/providers/google/cloud/operators/test_spanner.py b/tests/providers/google/cloud/operators/test_spanner.py index 7f0c977690c17..1e3c2b9f1becf 100644 --- a/tests/providers/google/cloud/operators/test_spanner.py +++ b/tests/providers/google/cloud/operators/test_spanner.py @@ -22,9 +22,12 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.operators.spanner import ( - SpannerDeleteDatabaseInstanceOperator, SpannerDeleteInstanceOperator, - SpannerDeployDatabaseInstanceOperator, SpannerDeployInstanceOperator, - SpannerQueryDatabaseInstanceOperator, SpannerUpdateDatabaseInstanceOperator, + SpannerDeleteDatabaseInstanceOperator, + SpannerDeleteInstanceOperator, + SpannerDeployDatabaseInstanceOperator, + SpannerDeployInstanceOperator, + SpannerQueryDatabaseInstanceOperator, + SpannerUpdateDatabaseInstanceOperator, ) PROJECT_ID = 'project-id' @@ -50,19 +53,18 @@ def test_instance_create(self, mock_hook): configuration_name=CONFIG_NAME, node_count=int(NODE_COUNT), display_name=DISPLAY_NAME, - task_id="id" + task_id="id", ) result = op.execute(None) # pylint: disable=assignment-from-no-return mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.create_instance.assert_called_once_with( project_id=PROJECT_ID, instance_id=INSTANCE_ID, configuration_name=CONFIG_NAME, node_count=int(NODE_COUNT), - display_name=DISPLAY_NAME + display_name=DISPLAY_NAME, ) 
mock_hook.return_value.update_instance.assert_not_called() self.assertIsNone(result) @@ -75,19 +77,18 @@ def test_instance_create_missing_project_id(self, mock_hook): configuration_name=CONFIG_NAME, node_count=int(NODE_COUNT), display_name=DISPLAY_NAME, - task_id="id" + task_id="id", ) result = op.execute(None) # pylint: disable=assignment-from-no-return mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.create_instance.assert_called_once_with( project_id=None, instance_id=INSTANCE_ID, configuration_name=CONFIG_NAME, node_count=int(NODE_COUNT), - display_name=DISPLAY_NAME + display_name=DISPLAY_NAME, ) mock_hook.return_value.update_instance.assert_not_called() self.assertIsNone(result) @@ -101,19 +102,18 @@ def test_instance_update(self, mock_hook): configuration_name=CONFIG_NAME, node_count=int(NODE_COUNT), display_name=DISPLAY_NAME, - task_id="id" + task_id="id", ) result = op.execute(None) # pylint: disable=assignment-from-no-return mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.update_instance.assert_called_once_with( project_id=PROJECT_ID, instance_id=INSTANCE_ID, configuration_name=CONFIG_NAME, node_count=int(NODE_COUNT), - display_name=DISPLAY_NAME + display_name=DISPLAY_NAME, ) mock_hook.return_value.create_instance.assert_not_called() self.assertIsNone(result) @@ -126,19 +126,18 @@ def test_instance_update_missing_project_id(self, mock_hook): configuration_name=CONFIG_NAME, node_count=int(NODE_COUNT), display_name=DISPLAY_NAME, - task_id="id" + task_id="id", ) result = op.execute(None) # pylint: disable=assignment-from-no-return mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.update_instance.assert_called_once_with( project_id=None, instance_id=INSTANCE_ID, configuration_name=CONFIG_NAME, node_count=int(NODE_COUNT), - display_name=DISPLAY_NAME + display_name=DISPLAY_NAME, ) mock_hook.return_value.create_instance.assert_not_called() self.assertIsNone(result) @@ -152,23 +151,20 @@ def test_instance_create_aborts_and_succeeds_if_instance_exists(self, mock_hook) configuration_name=CONFIG_NAME, node_count=int(NODE_COUNT), display_name=DISPLAY_NAME, - task_id="id" + task_id="id", ) result = op.execute(None) # pylint: disable=assignment-from-no-return mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.create_instance.assert_not_called() self.assertIsNone(result) - @parameterized.expand([ - ("", INSTANCE_ID, "project_id"), - (PROJECT_ID, "", "instance_id"), - ]) + @parameterized.expand( + [("", INSTANCE_ID, "project_id"), (PROJECT_ID, "", "instance_id"),] + ) @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") - def test_instance_create_ex_if_param_missing(self, project_id, instance_id, - exp_msg, mock_hook): + def test_instance_create_ex_if_param_missing(self, project_id, instance_id, exp_msg, mock_hook): with self.assertRaises(AirflowException) as cm: SpannerDeployInstanceOperator( project_id=project_id, @@ -176,7 +172,7 @@ def test_instance_create_ex_if_param_missing(self, project_id, instance_id, 
configuration_name=CONFIG_NAME, node_count=int(NODE_COUNT), display_name=DISPLAY_NAME, - task_id="id" + task_id="id", ) err = cm.exception self.assertIn("The required parameter '{}' is empty".format(exp_msg), str(err)) @@ -185,15 +181,10 @@ def test_instance_create_ex_if_param_missing(self, project_id, instance_id, @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") def test_instance_delete(self, mock_hook): mock_hook.return_value.get_instance.return_value = {"name": INSTANCE_ID} - op = SpannerDeleteInstanceOperator( - project_id=PROJECT_ID, - instance_id=INSTANCE_ID, - task_id="id" - ) + op = SpannerDeleteInstanceOperator(project_id=PROJECT_ID, instance_id=INSTANCE_ID, task_id="id") result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.delete_instance.assert_called_once_with( project_id=PROJECT_ID, instance_id=INSTANCE_ID @@ -203,51 +194,34 @@ def test_instance_delete(self, mock_hook): @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") def test_instance_delete_missing_project_id(self, mock_hook): mock_hook.return_value.get_instance.return_value = {"name": INSTANCE_ID} - op = SpannerDeleteInstanceOperator( - instance_id=INSTANCE_ID, - task_id="id" - ) + op = SpannerDeleteInstanceOperator(instance_id=INSTANCE_ID, task_id="id") result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.delete_instance.assert_called_once_with( - project_id=None, - instance_id=INSTANCE_ID + project_id=None, instance_id=INSTANCE_ID ) self.assertTrue(result) @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") - def test_instance_delete_aborts_and_succeeds_if_instance_does_not_exist(self, - mock_hook): + def test_instance_delete_aborts_and_succeeds_if_instance_does_not_exist(self, mock_hook): mock_hook.return_value.get_instance.return_value = None - op = SpannerDeleteInstanceOperator( - project_id=PROJECT_ID, - instance_id=INSTANCE_ID, - task_id="id" - ) + op = SpannerDeleteInstanceOperator(project_id=PROJECT_ID, instance_id=INSTANCE_ID, task_id="id") result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.delete_instance.assert_not_called() self.assertTrue(result) - @parameterized.expand([ - ("", INSTANCE_ID, "project_id"), - (PROJECT_ID, "", "instance_id"), - ]) + @parameterized.expand( + [("", INSTANCE_ID, "project_id"), (PROJECT_ID, "", "instance_id"),] + ) @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") - def test_instance_delete_ex_if_param_missing(self, project_id, instance_id, exp_msg, - mock_hook): + def test_instance_delete_ex_if_param_missing(self, project_id, instance_id, exp_msg, mock_hook): with self.assertRaises(AirflowException) as cm: - SpannerDeleteInstanceOperator( - project_id=project_id, - instance_id=instance_id, - task_id="id" - ) + SpannerDeleteInstanceOperator(project_id=project_id, instance_id=instance_id, task_id="id") err = cm.exception self.assertIn("The required parameter '{}' is empty".format(exp_msg), str(err)) mock_hook.assert_not_called() @@ -260,17 +234,14 @@ def test_instance_query(self, mock_hook): 
instance_id=INSTANCE_ID, database_id=DB_ID, query=INSERT_QUERY, - task_id="id" + task_id="id", ) result = op.execute(None) # pylint: disable=assignment-from-no-return mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.execute_dml.assert_called_once_with( - project_id=PROJECT_ID, instance_id=INSTANCE_ID, - database_id=DB_ID, - queries=[INSERT_QUERY] + project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, queries=[INSERT_QUERY] ) self.assertIsNone(result) @@ -278,38 +249,36 @@ def test_instance_query(self, mock_hook): def test_instance_query_missing_project_id(self, mock_hook): mock_hook.return_value.execute_sql.return_value = None op = SpannerQueryDatabaseInstanceOperator( - instance_id=INSTANCE_ID, - database_id=DB_ID, - query=INSERT_QUERY, - task_id="id" + instance_id=INSTANCE_ID, database_id=DB_ID, query=INSERT_QUERY, task_id="id" ) result = op.execute(None) # pylint: disable=assignment-from-no-return mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.execute_dml.assert_called_once_with( - project_id=None, instance_id=INSTANCE_ID, - database_id=DB_ID, queries=[INSERT_QUERY] + project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID, queries=[INSERT_QUERY] ) self.assertIsNone(result) - @parameterized.expand([ - ("", INSTANCE_ID, DB_ID, INSERT_QUERY, "project_id"), - (PROJECT_ID, "", DB_ID, INSERT_QUERY, "instance_id"), - (PROJECT_ID, INSTANCE_ID, "", INSERT_QUERY, "database_id"), - (PROJECT_ID, INSTANCE_ID, DB_ID, "", "query"), - ]) + @parameterized.expand( + [ + ("", INSTANCE_ID, DB_ID, INSERT_QUERY, "project_id"), + (PROJECT_ID, "", DB_ID, INSERT_QUERY, "instance_id"), + (PROJECT_ID, INSTANCE_ID, "", INSERT_QUERY, "database_id"), + (PROJECT_ID, INSTANCE_ID, DB_ID, "", "query"), + ] + ) @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") - def test_instance_query_ex_if_param_missing(self, project_id, instance_id, - database_id, query, exp_msg, mock_hook): + def test_instance_query_ex_if_param_missing( + self, project_id, instance_id, database_id, query, exp_msg, mock_hook + ): with self.assertRaises(AirflowException) as cm: SpannerQueryDatabaseInstanceOperator( project_id=project_id, instance_id=instance_id, database_id=database_id, query=query, - task_id="id" + task_id="id", ) err = cm.exception self.assertIn("The required parameter '{}' is empty".format(exp_msg), str(err)) @@ -323,12 +292,11 @@ def test_instance_query_dml(self, mock_hook): instance_id=INSTANCE_ID, database_id=DB_ID, query=INSERT_QUERY, - task_id="id" + task_id="id", ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.execute_dml.assert_called_once_with( project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, queries=[INSERT_QUERY] @@ -342,16 +310,17 @@ def test_instance_query_dml_list(self, mock_hook): instance_id=INSTANCE_ID, database_id=DB_ID, query=[INSERT_QUERY, INSERT_QUERY_2], - task_id="id" + task_id="id", ) op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) 
mock_hook.return_value.execute_dml.assert_called_once_with( - project_id=PROJECT_ID, instance_id=INSTANCE_ID, - database_id=DB_ID, queries=[INSERT_QUERY, INSERT_QUERY_2] + project_id=PROJECT_ID, + instance_id=INSTANCE_ID, + database_id=DB_ID, + queries=[INSERT_QUERY, INSERT_QUERY_2], ) @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") @@ -362,16 +331,14 @@ def test_database_create(self, mock_hook): instance_id=INSTANCE_ID, database_id=DB_ID, ddl_statements=DDL_STATEMENTS, - task_id="id" + task_id="id", ) result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.create_database.assert_called_once_with( - project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, - ddl_statements=DDL_STATEMENTS + project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, ddl_statements=DDL_STATEMENTS ) mock_hook.return_value.update_database.assert_not_called() self.assertTrue(result) @@ -380,19 +347,14 @@ def test_database_create(self, mock_hook): def test_database_create_missing_project_id(self, mock_hook): mock_hook.return_value.get_database.return_value = None op = SpannerDeployDatabaseInstanceOperator( - instance_id=INSTANCE_ID, - database_id=DB_ID, - ddl_statements=DDL_STATEMENTS, - task_id="id" + instance_id=INSTANCE_ID, database_id=DB_ID, ddl_statements=DDL_STATEMENTS, task_id="id" ) result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.create_database.assert_called_once_with( - project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID, - ddl_statements=DDL_STATEMENTS + project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID, ddl_statements=DDL_STATEMENTS ) mock_hook.return_value.update_database.assert_not_called() self.assertTrue(result) @@ -405,34 +367,34 @@ def test_database_create_with_pre_existing_db(self, mock_hook): instance_id=INSTANCE_ID, database_id=DB_ID, ddl_statements=DDL_STATEMENTS, - task_id="id" + task_id="id", ) result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.create_database.assert_not_called() mock_hook.return_value.update_database.assert_not_called() self.assertTrue(result) - @parameterized.expand([ - ("", INSTANCE_ID, DB_ID, DDL_STATEMENTS, 'project_id'), - (PROJECT_ID, "", DB_ID, DDL_STATEMENTS, 'instance_id'), - (PROJECT_ID, INSTANCE_ID, "", DDL_STATEMENTS, 'database_id'), - ]) + @parameterized.expand( + [ + ("", INSTANCE_ID, DB_ID, DDL_STATEMENTS, 'project_id'), + (PROJECT_ID, "", DB_ID, DDL_STATEMENTS, 'instance_id'), + (PROJECT_ID, INSTANCE_ID, "", DDL_STATEMENTS, 'database_id'), + ] + ) @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") - def test_database_create_ex_if_param_missing(self, - project_id, instance_id, - database_id, ddl_statements, - exp_msg, mock_hook): + def test_database_create_ex_if_param_missing( + self, project_id, instance_id, database_id, ddl_statements, exp_msg, mock_hook + ): with self.assertRaises(AirflowException) as cm: SpannerDeployDatabaseInstanceOperator( project_id=project_id, instance_id=instance_id, database_id=database_id, ddl_statements=ddl_statements, - task_id="id" + task_id="id", ) err = 
cm.exception self.assertIn("The required parameter '{}' is empty".format(exp_msg), str(err)) @@ -446,16 +408,18 @@ def test_database_update(self, mock_hook): instance_id=INSTANCE_ID, database_id=DB_ID, ddl_statements=DDL_STATEMENTS, - task_id="id" + task_id="id", ) result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.update_database.assert_called_once_with( - project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, - ddl_statements=DDL_STATEMENTS, operation_id=None + project_id=PROJECT_ID, + instance_id=INSTANCE_ID, + database_id=DB_ID, + ddl_statements=DDL_STATEMENTS, + operation_id=None, ) self.assertTrue(result) @@ -463,38 +427,39 @@ def test_database_update(self, mock_hook): def test_database_update_missing_project_id(self, mock_hook): mock_hook.return_value.get_database.return_value = {"name": DB_ID} op = SpannerUpdateDatabaseInstanceOperator( - instance_id=INSTANCE_ID, - database_id=DB_ID, - ddl_statements=DDL_STATEMENTS, - task_id="id" + instance_id=INSTANCE_ID, database_id=DB_ID, ddl_statements=DDL_STATEMENTS, task_id="id" ) result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.update_database.assert_called_once_with( - project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID, - ddl_statements=DDL_STATEMENTS, operation_id=None + project_id=None, + instance_id=INSTANCE_ID, + database_id=DB_ID, + ddl_statements=DDL_STATEMENTS, + operation_id=None, ) self.assertTrue(result) - @parameterized.expand([ - ("", INSTANCE_ID, DB_ID, DDL_STATEMENTS, 'project_id'), - (PROJECT_ID, "", DB_ID, DDL_STATEMENTS, 'instance_id'), - (PROJECT_ID, INSTANCE_ID, "", DDL_STATEMENTS, 'database_id'), - ]) + @parameterized.expand( + [ + ("", INSTANCE_ID, DB_ID, DDL_STATEMENTS, 'project_id'), + (PROJECT_ID, "", DB_ID, DDL_STATEMENTS, 'instance_id'), + (PROJECT_ID, INSTANCE_ID, "", DDL_STATEMENTS, 'database_id'), + ] + ) @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") - def test_database_update_ex_if_param_missing(self, project_id, instance_id, - database_id, ddl_statements, - exp_msg, mock_hook): + def test_database_update_ex_if_param_missing( + self, project_id, instance_id, database_id, ddl_statements, exp_msg, mock_hook + ): with self.assertRaises(AirflowException) as cm: SpannerUpdateDatabaseInstanceOperator( project_id=project_id, instance_id=instance_id, database_id=database_id, ddl_statements=ddl_statements, - task_id="id" + task_id="id", ) err = cm.exception self.assertIn("The required parameter '{}' is empty".format(exp_msg), str(err)) @@ -509,30 +474,28 @@ def test_database_update_ex_if_database_not_exist(self, mock_hook): instance_id=INSTANCE_ID, database_id=DB_ID, ddl_statements=DDL_STATEMENTS, - task_id="id" + task_id="id", ) op.execute(None) err = cm.exception - self.assertIn("The Cloud Spanner database 'db1' in project 'project-id' and " - "instance 'instance-id' is missing", str(err)) + self.assertIn( + "The Cloud Spanner database 'db1' in project 'project-id' and " + "instance 'instance-id' is missing", + str(err), + ) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) 
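The parameterized `*_ex_if_param_missing` tests above all share one decorator pattern: `@parameterized.expand` sits on top, `@mock.patch` sits below it, and the generated test receives the expanded tuple values first with the patched mock appended as the final argument. A small self-contained sketch of that argument order follows; it is not taken from the patch, the patch target and tuple values are invented, and it assumes the third-party `parameterized` package these tests already use.

```python
import unittest
from unittest import mock

from parameterized import parameterized


class ParamOrderExample(unittest.TestCase):
    # parameterized.expand on top, mock.patch below -- the same stacking as the
    # Spanner tests.  Each tuple fills the leading parameters; the mock created
    # by @mock.patch arrives last.
    @parameterized.expand(
        [
            ("", "instance-id", "project_id"),
            ("project-id", "", "instance_id"),
        ]
    )
    @mock.patch("os.getcwd")  # arbitrary patch target, used only for illustration
    def test_param_order(self, project_id, instance_id, exp_msg, mock_getcwd):
        # The empty value marks which required parameter is "missing"; exp_msg
        # names it, mirroring the assertions in the tests above.
        missing = project_id if exp_msg == "project_id" else instance_id
        self.assertEqual("", missing)
        mock_getcwd.assert_not_called()
```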
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") def test_database_delete(self, mock_hook): mock_hook.return_value.get_database.return_value = {"name": DB_ID} op = SpannerDeleteDatabaseInstanceOperator( - project_id=PROJECT_ID, - instance_id=INSTANCE_ID, - database_id=DB_ID, - task_id="id" + project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, task_id="id" ) result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.delete_database.assert_called_once_with( project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID @@ -542,15 +505,10 @@ def test_database_delete(self, mock_hook): @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") def test_database_delete_missing_project_id(self, mock_hook): mock_hook.return_value.get_database.return_value = {"name": DB_ID} - op = SpannerDeleteDatabaseInstanceOperator( - instance_id=INSTANCE_ID, - database_id=DB_ID, - task_id="id" - ) + op = SpannerDeleteDatabaseInstanceOperator(instance_id=INSTANCE_ID, database_id=DB_ID, task_id="id") result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.delete_database.assert_called_once_with( project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID @@ -558,39 +516,36 @@ def test_database_delete_missing_project_id(self, mock_hook): self.assertTrue(result) @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") - def test_database_delete_exits_and_succeeds_if_database_does_not_exist(self, - mock_hook): + def test_database_delete_exits_and_succeeds_if_database_does_not_exist(self, mock_hook): mock_hook.return_value.get_database.return_value = None op = SpannerDeleteDatabaseInstanceOperator( - project_id=PROJECT_ID, - instance_id=INSTANCE_ID, - database_id=DB_ID, - task_id="id" + project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, task_id="id" ) result = op.execute(None) mock_hook.assert_called_once_with( - gcp_conn_id="google_cloud_default", - impersonation_chain=None, + gcp_conn_id="google_cloud_default", impersonation_chain=None, ) mock_hook.return_value.delete_database.assert_not_called() self.assertTrue(result) - @parameterized.expand([ - ("", INSTANCE_ID, DB_ID, DDL_STATEMENTS, 'project_id'), - (PROJECT_ID, "", DB_ID, DDL_STATEMENTS, 'instance_id'), - (PROJECT_ID, INSTANCE_ID, "", DDL_STATEMENTS, 'database_id'), - ]) + @parameterized.expand( + [ + ("", INSTANCE_ID, DB_ID, DDL_STATEMENTS, 'project_id'), + (PROJECT_ID, "", DB_ID, DDL_STATEMENTS, 'instance_id'), + (PROJECT_ID, INSTANCE_ID, "", DDL_STATEMENTS, 'database_id'), + ] + ) @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook") - def test_database_delete_ex_if_param_missing(self, project_id, instance_id, - database_id, ddl_statements, - exp_msg, mock_hook): + def test_database_delete_ex_if_param_missing( + self, project_id, instance_id, database_id, ddl_statements, exp_msg, mock_hook + ): with self.assertRaises(AirflowException) as cm: SpannerDeleteDatabaseInstanceOperator( project_id=project_id, instance_id=instance_id, database_id=database_id, ddl_statements=ddl_statements, - task_id="id" + task_id="id", ) err = cm.exception self.assertIn("The required parameter '{}' is empty".format(exp_msg), str(err)) diff --git 
a/tests/providers/google/cloud/operators/test_spanner_system.py b/tests/providers/google/cloud/operators/test_spanner_system.py index 94ea6009cbb0f..45afac99d954e 100644 --- a/tests/providers/google/cloud/operators/test_spanner_system.py +++ b/tests/providers/google/cloud/operators/test_spanner_system.py @@ -29,14 +29,22 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_SPANNER_KEY) class CloudSpannerExampleDagsTest(GoogleSystemTest): - @provide_gcp_context(GCP_SPANNER_KEY) def tearDown(self): - self.execute_with_ctx([ - 'gcloud', 'spanner', '--project', GCP_PROJECT_ID, - '--quiet', '--verbosity=none', - 'instances', 'delete', GCP_SPANNER_INSTANCE_ID - ], key=GCP_SPANNER_KEY) + self.execute_with_ctx( + [ + 'gcloud', + 'spanner', + '--project', + GCP_PROJECT_ID, + '--quiet', + '--verbosity=none', + 'instances', + 'delete', + GCP_SPANNER_INSTANCE_ID, + ], + key=GCP_SPANNER_KEY, + ) super().tearDown() @provide_gcp_context(GCP_SPANNER_KEY) diff --git a/tests/providers/google/cloud/operators/test_speech_to_text.py b/tests/providers/google/cloud/operators/test_speech_to_text.py index 8af4b014a7995..2315bf242d981 100644 --- a/tests/providers/google/cloud/operators/test_speech_to_text.py +++ b/tests/providers/google/cloud/operators/test_speech_to_text.py @@ -45,8 +45,7 @@ def test_recognize_speech_green_path(self, mock_hook): ).execute(context={"task_instance": Mock()}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.recognize_speech.assert_called_once_with( config=CONFIG, audio=AUDIO, retry=None, timeout=None diff --git a/tests/providers/google/cloud/operators/test_speech_to_text_system.py b/tests/providers/google/cloud/operators/test_speech_to_text_system.py index 0396593748f4e..b7bdf23b888ec 100644 --- a/tests/providers/google/cloud/operators/test_speech_to_text_system.py +++ b/tests/providers/google/cloud/operators/test_speech_to_text_system.py @@ -26,7 +26,6 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_GCS_KEY) class GCPTextToSpeechExampleDagSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_GCS_KEY) def setUp(self): super().setUp() diff --git a/tests/providers/google/cloud/operators/test_stackdriver.py b/tests/providers/google/cloud/operators/test_stackdriver.py index e5594dc113551..387f857428967 100644 --- a/tests/providers/google/cloud/operators/test_stackdriver.py +++ b/tests/providers/google/cloud/operators/test_stackdriver.py @@ -23,11 +23,16 @@ from google.api_core.gapic_v1.method import DEFAULT from airflow.providers.google.cloud.operators.stackdriver import ( - StackdriverDeleteAlertOperator, StackdriverDeleteNotificationChannelOperator, - StackdriverDisableAlertPoliciesOperator, StackdriverDisableNotificationChannelsOperator, - StackdriverEnableAlertPoliciesOperator, StackdriverEnableNotificationChannelsOperator, - StackdriverListAlertPoliciesOperator, StackdriverListNotificationChannelsOperator, - StackdriverUpsertAlertOperator, StackdriverUpsertNotificationChannelOperator, + StackdriverDeleteAlertOperator, + StackdriverDeleteNotificationChannelOperator, + StackdriverDisableAlertPoliciesOperator, + StackdriverDisableNotificationChannelsOperator, + StackdriverEnableAlertPoliciesOperator, + StackdriverEnableNotificationChannelsOperator, + StackdriverListAlertPoliciesOperator, + StackdriverListNotificationChannelsOperator, + StackdriverUpsertAlertOperator, + 
StackdriverUpsertNotificationChannelOperator, ) TEST_TASK_ID = 'test-stackdriver-operator' @@ -35,85 +40,60 @@ TEST_ALERT_POLICY_1 = { "combiner": "OR", "name": "projects/sd-project/alertPolicies/12345", - "creationRecord": { - "mutatedBy": "user123", - "mutateTime": "2020-01-01T00:00:00.000000Z" - }, + "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"}, "enabled": True, "displayName": "test display", "conditions": [ { "conditionThreshold": { "comparison": "COMPARISON_GT", - "aggregations": [ - { - "alignmentPeriod": "60s", - "perSeriesAligner": "ALIGN_RATE" - } - ] + "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}], }, "displayName": "Condition display", - "name": "projects/sd-project/alertPolicies/123/conditions/456" + "name": "projects/sd-project/alertPolicies/123/conditions/456", } - ] + ], } TEST_ALERT_POLICY_2 = { "combiner": "OR", "name": "projects/sd-project/alertPolicies/6789", - "creationRecord": { - "mutatedBy": "user123", - "mutateTime": "2020-01-01T00:00:00.000000Z" - }, + "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"}, "enabled": False, "displayName": "test display", "conditions": [ { "conditionThreshold": { "comparison": "COMPARISON_GT", - "aggregations": [ - { - "alignmentPeriod": "60s", - "perSeriesAligner": "ALIGN_RATE" - } - ] + "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}], }, "displayName": "Condition display", - "name": "projects/sd-project/alertPolicies/456/conditions/789" + "name": "projects/sd-project/alertPolicies/456/conditions/789", } - ] + ], } TEST_NOTIFICATION_CHANNEL_1 = { "displayName": "sd", "enabled": True, - "labels": { - "auth_token": "top-secret", - "channel_name": "#channel" - }, + "labels": {"auth_token": "top-secret", "channel_name": "#channel"}, "name": "projects/sd-project/notificationChannels/12345", - "type": "slack" + "type": "slack", } TEST_NOTIFICATION_CHANNEL_2 = { "displayName": "sd", "enabled": False, - "labels": { - "auth_token": "top-secret", - "channel_name": "#channel" - }, + "labels": {"auth_token": "top-secret", "channel_name": "#channel"}, "name": "projects/sd-project/notificationChannels/6789", - "type": "slack" + "type": "slack", } class TestStackdriverListAlertPoliciesOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook') def test_execute(self, mock_hook): - operator = StackdriverListAlertPoliciesOperator( - task_id=TEST_TASK_ID, - filter_=TEST_FILTER - ) + operator = StackdriverListAlertPoliciesOperator(task_id=TEST_TASK_ID, filter_=TEST_FILTER) operator.execute(None) mock_hook.return_value.list_alert_policies.assert_called_once_with( project_id=None, @@ -130,34 +110,20 @@ def test_execute(self, mock_hook): class TestStackdriverEnableAlertPoliciesOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook') def test_execute(self, mock_hook): - operator = StackdriverEnableAlertPoliciesOperator( - task_id=TEST_TASK_ID, - filter_=TEST_FILTER - ) + operator = StackdriverEnableAlertPoliciesOperator(task_id=TEST_TASK_ID, filter_=TEST_FILTER) operator.execute(None) mock_hook.return_value.enable_alert_policies.assert_called_once_with( - project_id=None, - filter_=TEST_FILTER, - retry=DEFAULT, - timeout=DEFAULT, - metadata=None + project_id=None, filter_=TEST_FILTER, retry=DEFAULT, timeout=DEFAULT, metadata=None ) class TestStackdriverDisableAlertPoliciesOperator(unittest.TestCase): 
@mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook') def test_execute(self, mock_hook): - operator = StackdriverDisableAlertPoliciesOperator( - task_id=TEST_TASK_ID, - filter_=TEST_FILTER - ) + operator = StackdriverDisableAlertPoliciesOperator(task_id=TEST_TASK_ID, filter_=TEST_FILTER) operator.execute(None) mock_hook.return_value.disable_alert_policies.assert_called_once_with( - project_id=None, - filter_=TEST_FILTER, - retry=DEFAULT, - timeout=DEFAULT, - metadata=None + project_id=None, filter_=TEST_FILTER, retry=DEFAULT, timeout=DEFAULT, metadata=None ) @@ -165,8 +131,7 @@ class TestStackdriverUpsertAlertsOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook') def test_execute(self, mock_hook): operator = StackdriverUpsertAlertOperator( - task_id=TEST_TASK_ID, - alerts=json.dumps({"policies": [TEST_ALERT_POLICY_1, TEST_ALERT_POLICY_2]}) + task_id=TEST_TASK_ID, alerts=json.dumps({"policies": [TEST_ALERT_POLICY_1, TEST_ALERT_POLICY_2]}) ) operator.execute(None) mock_hook.return_value.upsert_alert.assert_called_once_with( @@ -174,33 +139,24 @@ def test_execute(self, mock_hook): project_id=None, retry=DEFAULT, timeout=DEFAULT, - metadata=None + metadata=None, ) class TestStackdriverDeleteAlertOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook') def test_execute(self, mock_hook): - operator = StackdriverDeleteAlertOperator( - task_id=TEST_TASK_ID, - name='test-alert', - ) + operator = StackdriverDeleteAlertOperator(task_id=TEST_TASK_ID, name='test-alert',) operator.execute(None) mock_hook.return_value.delete_alert_policy.assert_called_once_with( - name='test-alert', - retry=DEFAULT, - timeout=DEFAULT, - metadata=None + name='test-alert', retry=DEFAULT, timeout=DEFAULT, metadata=None ) class TestStackdriverListNotificationChannelsOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook') def test_execute(self, mock_hook): - operator = StackdriverListNotificationChannelsOperator( - task_id=TEST_TASK_ID, - filter_=TEST_FILTER - ) + operator = StackdriverListNotificationChannelsOperator(task_id=TEST_TASK_ID, filter_=TEST_FILTER) operator.execute(None) mock_hook.return_value.list_notification_channels.assert_called_once_with( project_id=None, @@ -217,34 +173,20 @@ def test_execute(self, mock_hook): class TestStackdriverEnableNotificationChannelsOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook') def test_execute(self, mock_hook): - operator = StackdriverEnableNotificationChannelsOperator( - task_id=TEST_TASK_ID, - filter_=TEST_FILTER - ) + operator = StackdriverEnableNotificationChannelsOperator(task_id=TEST_TASK_ID, filter_=TEST_FILTER) operator.execute(None) mock_hook.return_value.enable_notification_channels.assert_called_once_with( - project_id=None, - filter_=TEST_FILTER, - retry=DEFAULT, - timeout=DEFAULT, - metadata=None + project_id=None, filter_=TEST_FILTER, retry=DEFAULT, timeout=DEFAULT, metadata=None ) class TestStackdriverDisableNotificationChannelsOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook') def test_execute(self, mock_hook): - operator = StackdriverDisableNotificationChannelsOperator( - task_id=TEST_TASK_ID, - filter_=TEST_FILTER - ) + operator = StackdriverDisableNotificationChannelsOperator(task_id=TEST_TASK_ID, filter_=TEST_FILTER) 
operator.execute(None) mock_hook.return_value.disable_notification_channels.assert_called_once_with( - project_id=None, - filter_=TEST_FILTER, - retry=DEFAULT, - timeout=DEFAULT, - metadata=None + project_id=None, filter_=TEST_FILTER, retry=DEFAULT, timeout=DEFAULT, metadata=None ) @@ -253,7 +195,7 @@ class TestStackdriverUpsertChannelOperator(unittest.TestCase): def test_execute(self, mock_hook): operator = StackdriverUpsertNotificationChannelOperator( task_id=TEST_TASK_ID, - channels=json.dumps({"channels": [TEST_NOTIFICATION_CHANNEL_1, TEST_NOTIFICATION_CHANNEL_2]}) + channels=json.dumps({"channels": [TEST_NOTIFICATION_CHANNEL_1, TEST_NOTIFICATION_CHANNEL_2]}), ) operator.execute(None) mock_hook.return_value.upsert_channel.assert_called_once_with( @@ -261,21 +203,15 @@ def test_execute(self, mock_hook): project_id=None, retry=DEFAULT, timeout=DEFAULT, - metadata=None + metadata=None, ) class TestStackdriverDeleteNotificationChannelOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook') def test_execute(self, mock_hook): - operator = StackdriverDeleteNotificationChannelOperator( - task_id=TEST_TASK_ID, - name='test-channel', - ) + operator = StackdriverDeleteNotificationChannelOperator(task_id=TEST_TASK_ID, name='test-channel',) operator.execute(None) mock_hook.return_value.delete_notification_channel.assert_called_once_with( - name='test-channel', - retry=DEFAULT, - timeout=DEFAULT, - metadata=None + name='test-channel', retry=DEFAULT, timeout=DEFAULT, metadata=None ) diff --git a/tests/providers/google/cloud/operators/test_tasks.py b/tests/providers/google/cloud/operators/test_tasks.py index e6d87f5eddfb3..d6b26bd6a5457 100644 --- a/tests/providers/google/cloud/operators/test_tasks.py +++ b/tests/providers/google/cloud/operators/test_tasks.py @@ -22,10 +22,18 @@ from google.cloud.tasks_v2.types import Queue, Task from airflow.providers.google.cloud.operators.tasks import ( - CloudTasksQueueCreateOperator, CloudTasksQueueDeleteOperator, CloudTasksQueueGetOperator, - CloudTasksQueuePauseOperator, CloudTasksQueuePurgeOperator, CloudTasksQueueResumeOperator, - CloudTasksQueuesListOperator, CloudTasksQueueUpdateOperator, CloudTasksTaskCreateOperator, - CloudTasksTaskDeleteOperator, CloudTasksTaskGetOperator, CloudTasksTaskRunOperator, + CloudTasksQueueCreateOperator, + CloudTasksQueueDeleteOperator, + CloudTasksQueueGetOperator, + CloudTasksQueuePauseOperator, + CloudTasksQueuePurgeOperator, + CloudTasksQueueResumeOperator, + CloudTasksQueuesListOperator, + CloudTasksQueueUpdateOperator, + CloudTasksTaskCreateOperator, + CloudTasksTaskDeleteOperator, + CloudTasksTaskGetOperator, + CloudTasksTaskRunOperator, CloudTasksTasksListOperator, ) @@ -36,22 +44,17 @@ QUEUE_ID = "test-queue" FULL_QUEUE_PATH = "projects/test-project/locations/asia-east2/queues/test-queue" TASK_NAME = "test-task" -FULL_TASK_PATH = ( - "projects/test-project/locations/asia-east2/queues/test-queue/tasks/test-task" -) +FULL_TASK_PATH = "projects/test-project/locations/asia-east2/queues/test-queue/tasks/test-task" class TestCloudTasksQueueCreate(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_create_queue(self, mock_hook): mock_hook.return_value.create_queue.return_value = mock.MagicMock() - operator = CloudTasksQueueCreateOperator( - location=LOCATION, task_queue=Queue(), task_id="id" - ) + operator = CloudTasksQueueCreateOperator(location=LOCATION, task_queue=Queue(), task_id="id") 
operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_queue.assert_called_once_with( location=LOCATION, @@ -68,13 +71,10 @@ class TestCloudTasksQueueUpdate(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_update_queue(self, mock_hook): mock_hook.return_value.update_queue.return_value = mock.MagicMock() - operator = CloudTasksQueueUpdateOperator( - task_queue=Queue(name=FULL_QUEUE_PATH), task_id="id" - ) + operator = CloudTasksQueueUpdateOperator(task_queue=Queue(name=FULL_QUEUE_PATH), task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.update_queue.assert_called_once_with( task_queue=Queue(name=FULL_QUEUE_PATH), @@ -92,21 +92,13 @@ class TestCloudTasksQueueGet(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_get_queue(self, mock_hook): mock_hook.return_value.get_queue.return_value = mock.MagicMock() - operator = CloudTasksQueueGetOperator( - location=LOCATION, queue_name=QUEUE_ID, task_id="id" - ) + operator = CloudTasksQueueGetOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.get_queue.assert_called_once_with( - location=LOCATION, - queue_name=QUEUE_ID, - project_id=None, - retry=None, - timeout=None, - metadata=None, + location=LOCATION, queue_name=QUEUE_ID, project_id=None, retry=None, timeout=None, metadata=None, ) @@ -117,8 +109,7 @@ def test_list_queues(self, mock_hook): operator = CloudTasksQueuesListOperator(location=LOCATION, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.list_queues.assert_called_once_with( location=LOCATION, @@ -135,21 +126,13 @@ class TestCloudTasksQueueDelete(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_delete_queue(self, mock_hook): mock_hook.return_value.delete_queue.return_value = mock.MagicMock() - operator = CloudTasksQueueDeleteOperator( - location=LOCATION, queue_name=QUEUE_ID, task_id="id" - ) + operator = CloudTasksQueueDeleteOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.delete_queue.assert_called_once_with( - location=LOCATION, - queue_name=QUEUE_ID, - project_id=None, - retry=None, - timeout=None, - metadata=None, + location=LOCATION, queue_name=QUEUE_ID, project_id=None, retry=None, timeout=None, metadata=None, ) @@ -157,21 +140,13 @@ class TestCloudTasksQueuePurge(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_delete_queue(self, mock_hook): mock_hook.return_value.purge_queue.return_value = mock.MagicMock() - operator = CloudTasksQueuePurgeOperator( - location=LOCATION, queue_name=QUEUE_ID, task_id="id" - ) + operator = 
CloudTasksQueuePurgeOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.purge_queue.assert_called_once_with( - location=LOCATION, - queue_name=QUEUE_ID, - project_id=None, - retry=None, - timeout=None, - metadata=None, + location=LOCATION, queue_name=QUEUE_ID, project_id=None, retry=None, timeout=None, metadata=None, ) @@ -179,21 +154,13 @@ class TestCloudTasksQueuePause(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_pause_queue(self, mock_hook): mock_hook.return_value.pause_queue.return_value = mock.MagicMock() - operator = CloudTasksQueuePauseOperator( - location=LOCATION, queue_name=QUEUE_ID, task_id="id" - ) + operator = CloudTasksQueuePauseOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.pause_queue.assert_called_once_with( - location=LOCATION, - queue_name=QUEUE_ID, - project_id=None, - retry=None, - timeout=None, - metadata=None, + location=LOCATION, queue_name=QUEUE_ID, project_id=None, retry=None, timeout=None, metadata=None, ) @@ -201,21 +168,13 @@ class TestCloudTasksQueueResume(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_resume_queue(self, mock_hook): mock_hook.return_value.resume_queue.return_value = mock.MagicMock() - operator = CloudTasksQueueResumeOperator( - location=LOCATION, queue_name=QUEUE_ID, task_id="id" - ) + operator = CloudTasksQueueResumeOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.resume_queue.assert_called_once_with( - location=LOCATION, - queue_name=QUEUE_ID, - project_id=None, - retry=None, - timeout=None, - metadata=None, + location=LOCATION, queue_name=QUEUE_ID, project_id=None, retry=None, timeout=None, metadata=None, ) @@ -228,8 +187,7 @@ def test_create_task(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_task.assert_called_once_with( location=LOCATION, @@ -253,8 +211,7 @@ def test_get_task(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.get_task.assert_called_once_with( location=LOCATION, @@ -272,13 +229,10 @@ class TestCloudTasksTasksList(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook") def test_list_tasks(self, mock_hook): mock_hook.return_value.list_tasks.return_value = mock.MagicMock() - operator = CloudTasksTasksListOperator( - location=LOCATION, queue_name=QUEUE_ID, task_id="id" - ) + operator = CloudTasksTasksListOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id") operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + 
gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.list_tasks.assert_called_once_with( location=LOCATION, @@ -301,8 +255,7 @@ def test_delete_task(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.delete_task.assert_called_once_with( location=LOCATION, @@ -324,8 +277,7 @@ def test_run_task(self, mock_hook): ) operator.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.run_task.assert_called_once_with( location=LOCATION, diff --git a/tests/providers/google/cloud/operators/test_text_to_speech.py b/tests/providers/google/cloud/operators/test_text_to_speech.py index bc547e99052de..5f722b5120380 100644 --- a/tests/providers/google/cloud/operators/test_text_to_speech.py +++ b/tests/providers/google/cloud/operators/test_text_to_speech.py @@ -57,12 +57,10 @@ def test_synthesize_text_green_path(self, mock_text_to_speech_hook, mock_gcp_hoo ).execute(context={"task_instance": Mock()}) mock_text_to_speech_hook.assert_called_once_with( - gcp_conn_id="gcp-conn-id", - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id="gcp-conn-id", impersonation_chain=IMPERSONATION_CHAIN, ) mock_gcp_hook.assert_called_once_with( - google_cloud_storage_conn_id="gcp-conn-id", - impersonation_chain=IMPERSONATION_CHAIN, + google_cloud_storage_conn_id="gcp-conn-id", impersonation_chain=IMPERSONATION_CHAIN, ) mock_text_to_speech_hook.return_value.synthesize_speech.assert_called_once_with( input_data=INPUT, voice=VOICE, audio_config=AUDIO_CONFIG, retry=None, timeout=None diff --git a/tests/providers/google/cloud/operators/test_text_to_speech_system.py b/tests/providers/google/cloud/operators/test_text_to_speech_system.py index 85b823a941d17..7104a12884d26 100644 --- a/tests/providers/google/cloud/operators/test_text_to_speech_system.py +++ b/tests/providers/google/cloud/operators/test_text_to_speech_system.py @@ -26,7 +26,6 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_GCS_KEY) class GCPTextToSpeechExampleDagSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_GCS_KEY) def setUp(self): super().setUp() diff --git a/tests/providers/google/cloud/operators/test_translate.py b/tests/providers/google/cloud/operators/test_translate.py index 74f19554a628f..cb69608a34837 100644 --- a/tests/providers/google/cloud/operators/test_translate.py +++ b/tests/providers/google/cloud/operators/test_translate.py @@ -49,8 +49,7 @@ def test_minimal_green_path(self, mock_hook): ) return_value = op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.translate.assert_called_once_with( values=['zażółć gęślą jaźń'], diff --git a/tests/providers/google/cloud/operators/test_translate_speech.py b/tests/providers/google/cloud/operators/test_translate_speech.py index be9b2b1e30e4b..d71d536b30e5f 100644 --- a/tests/providers/google/cloud/operators/test_translate_speech.py +++ b/tests/providers/google/cloud/operators/test_translate_speech.py @@ -20,7 +20,9 @@ import mock from google.cloud.speech_v1.proto.cloud_speech_pb2 import ( - RecognizeResponse, SpeechRecognitionAlternative, SpeechRecognitionResult, + RecognizeResponse, + 
SpeechRecognitionAlternative, + SpeechRecognitionResult, ) from airflow.exceptions import AirflowException @@ -35,11 +37,11 @@ class TestCloudTranslateSpeech(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.operators.translate_speech.CloudTranslateHook') def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook): mock_speech_hook.return_value.recognize_speech.return_value = RecognizeResponse( - results=[SpeechRecognitionResult( - alternatives=[SpeechRecognitionAlternative( - transcript='test speech recognition result' - )] - )] + results=[ + SpeechRecognitionResult( + alternatives=[SpeechRecognitionAlternative(transcript='test speech recognition result')] + ) + ] ) mock_translate_hook.return_value.translate.return_value = [ { @@ -64,17 +66,14 @@ def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook): return_value = op.execute(context=None) mock_speech_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_translate_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_speech_hook.return_value.recognize_speech.assert_called_once_with( - audio={"uri": "gs://bucket/object"}, - config={"encoding": "LINEAR16"}, + audio={"uri": "gs://bucket/object"}, config={"encoding": "LINEAR16"}, ) mock_translate_hook.return_value.translate.assert_called_once_with( @@ -118,17 +117,14 @@ def test_bad_recognition_response(self, mock_translate_hook, mock_speech_hook): self.assertIn("it should contain 'alternatives' field", str(err)) mock_speech_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_translate_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_speech_hook.return_value.recognize_speech.assert_called_once_with( - audio={"uri": "gs://bucket/object"}, - config={"encoding": "LINEAR16"}, + audio={"uri": "gs://bucket/object"}, config={"encoding": "LINEAR16"}, ) mock_translate_hook.return_value.translate.assert_not_called() diff --git a/tests/providers/google/cloud/operators/test_translate_speech_system.py b/tests/providers/google/cloud/operators/test_translate_speech_system.py index 86856680adf6d..ad9ed2ad12d85 100644 --- a/tests/providers/google/cloud/operators/test_translate_speech_system.py +++ b/tests/providers/google/cloud/operators/test_translate_speech_system.py @@ -26,7 +26,6 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_GCS_KEY) class GCPTextToSpeechExampleDagSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_GCS_KEY) def setUp(self): super().setUp() diff --git a/tests/providers/google/cloud/operators/test_video_intelligence.py b/tests/providers/google/cloud/operators/test_video_intelligence.py index 84aa453bf24f3..f3a7578b77ad9 100644 --- a/tests/providers/google/cloud/operators/test_video_intelligence.py +++ b/tests/providers/google/cloud/operators/test_video_intelligence.py @@ -23,7 +23,8 @@ from google.cloud.videointelligence_v1.proto.video_intelligence_pb2 import AnnotateVideoResponse from airflow.providers.google.cloud.operators.video_intelligence import ( - CloudVideoIntelligenceDetectVideoExplicitContentOperator, CloudVideoIntelligenceDetectVideoLabelsOperator, + 
CloudVideoIntelligenceDetectVideoExplicitContentOperator, + CloudVideoIntelligenceDetectVideoLabelsOperator, CloudVideoIntelligenceDetectVideoShotsOperator, ) @@ -52,8 +53,7 @@ def test_detect_video_labels_green_path(self, mock_hook): ).execute(context={"task_instance": mock.Mock()}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.annotate_video.assert_called_once_with( input_uri=INPUT_URI, @@ -62,7 +62,7 @@ def test_detect_video_labels_green_path(self, mock_hook): video_context=None, location=None, retry=None, - timeout=None + timeout=None, ) @mock.patch("airflow.providers.google.cloud.operators.video_intelligence.CloudVideoIntelligenceHook") @@ -79,8 +79,7 @@ def test_detect_video_explicit_content_green_path(self, mock_hook): ).execute(context={"task_instance": mock.Mock()}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.annotate_video.assert_called_once_with( input_uri=INPUT_URI, @@ -89,7 +88,7 @@ def test_detect_video_explicit_content_green_path(self, mock_hook): video_context=None, location=None, retry=None, - timeout=None + timeout=None, ) @mock.patch("airflow.providers.google.cloud.operators.video_intelligence.CloudVideoIntelligenceHook") @@ -106,8 +105,7 @@ def test_detect_video_shots_green_path(self, mock_hook): ).execute(context={"task_instance": mock.Mock()}) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.annotate_video.assert_called_once_with( input_uri=INPUT_URI, @@ -116,5 +114,5 @@ def test_detect_video_shots_green_path(self, mock_hook): video_context=None, location=None, retry=None, - timeout=None + timeout=None, ) diff --git a/tests/providers/google/cloud/operators/test_video_intelligence_system.py b/tests/providers/google/cloud/operators/test_video_intelligence_system.py index 795267c3f831a..f7296753d9000 100644 --- a/tests/providers/google/cloud/operators/test_video_intelligence_system.py +++ b/tests/providers/google/cloud/operators/test_video_intelligence_system.py @@ -30,16 +30,12 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_AI_KEY) class CloudVideoIntelligenceExampleDagsTest(GoogleSystemTest): - @provide_gcp_context(GCP_AI_KEY) def setUp(self): self.create_gcs_bucket(GCP_BUCKET_NAME, location="europe-north1") self.execute_with_ctx( - cmd=[ - "bash", - "-c", - f"curl {GCP_VIDEO_SOURCE_URL} | gsutil cp - gs://{GCP_BUCKET_NAME}/video.mp4" - ], key=GCP_GCS_KEY + cmd=["bash", "-c", f"curl {GCP_VIDEO_SOURCE_URL} | gsutil cp - gs://{GCP_BUCKET_NAME}/video.mp4"], + key=GCP_GCS_KEY, ) super().setUp() diff --git a/tests/providers/google/cloud/operators/test_vision.py b/tests/providers/google/cloud/operators/test_vision.py index f6380c2b99a78..bd19fb6674785 100644 --- a/tests/providers/google/cloud/operators/test_vision.py +++ b/tests/providers/google/cloud/operators/test_vision.py @@ -23,14 +23,23 @@ from google.cloud.vision_v1.types import Product, ProductSet, ReferenceImage from airflow.providers.google.cloud.operators.vision import ( - CloudVisionAddProductToProductSetOperator, CloudVisionCreateProductOperator, - CloudVisionCreateProductSetOperator, CloudVisionCreateReferenceImageOperator, - 
CloudVisionDeleteProductOperator, CloudVisionDeleteProductSetOperator, - CloudVisionDeleteReferenceImageOperator, CloudVisionDetectImageLabelsOperator, - CloudVisionDetectImageSafeSearchOperator, CloudVisionDetectTextOperator, CloudVisionGetProductOperator, - CloudVisionGetProductSetOperator, CloudVisionImageAnnotateOperator, - CloudVisionRemoveProductFromProductSetOperator, CloudVisionTextDetectOperator, - CloudVisionUpdateProductOperator, CloudVisionUpdateProductSetOperator, + CloudVisionAddProductToProductSetOperator, + CloudVisionCreateProductOperator, + CloudVisionCreateProductSetOperator, + CloudVisionCreateReferenceImageOperator, + CloudVisionDeleteProductOperator, + CloudVisionDeleteProductSetOperator, + CloudVisionDeleteReferenceImageOperator, + CloudVisionDetectImageLabelsOperator, + CloudVisionDetectImageSafeSearchOperator, + CloudVisionDetectTextOperator, + CloudVisionGetProductOperator, + CloudVisionGetProductSetOperator, + CloudVisionImageAnnotateOperator, + CloudVisionRemoveProductFromProductSetOperator, + CloudVisionTextDetectOperator, + CloudVisionUpdateProductOperator, + CloudVisionUpdateProductSetOperator, ) PRODUCTSET_TEST = ProductSet(display_name='Test Product Set') @@ -42,7 +51,7 @@ ANNOTATE_REQUEST_TEST = {'image': {'source': {'image_uri': 'https://foo.com/image.jpg'}}} ANNOTATE_REQUEST_BATCH_TEST = [ {'image': {'source': {'image_uri': 'https://foo.com/image1.jpg'}}}, - {'image': {'source': {'image_uri': 'https://foo.com/image2.jpg'}}} + {'image': {'source': {'image_uri': 'https://foo.com/image2.jpg'}}}, ] LOCATION_TEST = 'europe-west1' GCP_CONN_ID = 'google_cloud_default' @@ -58,8 +67,7 @@ def test_minimal_green_path(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_product_set.assert_called_once_with( location=LOCATION_TEST, @@ -95,8 +103,7 @@ def test_minimal_green_path(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.update_product_set.assert_called_once_with( location=LOCATION_TEST, @@ -119,8 +126,7 @@ def test_minimal_green_path(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.get_product_set.assert_called_once_with( location=LOCATION_TEST, @@ -141,8 +147,7 @@ def test_minimal_green_path(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.delete_product_set.assert_called_once_with( location=LOCATION_TEST, @@ -161,8 +166,7 @@ def test_minimal_green_path(self, mock_hook): op = CloudVisionCreateProductOperator(location=LOCATION_TEST, product=PRODUCT_TEST, task_id='id') op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_product.assert_called_once_with( location=LOCATION_TEST, @@ -196,8 +200,7 @@ def test_minimal_green_path(self, mock_hook): op = CloudVisionGetProductOperator(location=LOCATION_TEST, product_id=PRODUCT_ID_TEST, task_id='id') op.execute(context=None) 
mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.get_product.assert_called_once_with( location=LOCATION_TEST, @@ -216,8 +219,7 @@ def test_minimal_green_path(self, mock_hook): op = CloudVisionUpdateProductOperator(location=LOCATION_TEST, product=PRODUCT_TEST, task_id='id') op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.update_product.assert_called_once_with( location=LOCATION_TEST, @@ -240,8 +242,7 @@ def test_minimal_green_path(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.delete_product.assert_called_once_with( location=LOCATION_TEST, @@ -254,9 +255,7 @@ def test_minimal_green_path(self, mock_hook): class TestCloudVisionReferenceImageCreate(unittest.TestCase): - @mock.patch( - 'airflow.providers.google.cloud.operators.vision.CloudVisionHook', - ) + @mock.patch('airflow.providers.google.cloud.operators.vision.CloudVisionHook',) def test_minimal_green_path(self, mock_hook): mock_hook.return_value.create_reference_image.return_value = {} op = CloudVisionCreateReferenceImageOperator( @@ -267,8 +266,7 @@ def test_minimal_green_path(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_reference_image.assert_called_once_with( location=LOCATION_TEST, @@ -283,7 +281,7 @@ def test_minimal_green_path(self, mock_hook): @mock.patch( 'airflow.providers.google.cloud.operators.vision.CloudVisionHook', - **{'return_value.create_reference_image.side_effect': AlreadyExists("MESSAGe")} + **{'return_value.create_reference_image.side_effect': AlreadyExists("MESSAGe")}, ) def test_already_exists(self, mock_hook): # Exception AlreadyExists not raised, caught in the operator's execute() - idempotence @@ -295,8 +293,7 @@ def test_already_exists(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.create_reference_image.assert_called_once_with( location=LOCATION_TEST, @@ -311,9 +308,7 @@ def test_already_exists(self, mock_hook): class TestCloudVisionReferenceImageDelete(unittest.TestCase): - @mock.patch( - 'airflow.providers.google.cloud.operators.vision.CloudVisionHook', - ) + @mock.patch('airflow.providers.google.cloud.operators.vision.CloudVisionHook',) def test_minimal_green_path(self, mock_hook): mock_hook.return_value.delete_reference_image.return_value = {} op = CloudVisionDeleteReferenceImageOperator( @@ -324,8 +319,7 @@ def test_minimal_green_path(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.delete_reference_image.assert_called_once_with( location=LOCATION_TEST, @@ -349,8 +343,7 @@ def test_minimal_green_path(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) 
mock_hook.return_value.add_product_to_product_set.assert_called_once_with( product_set_id=PRODUCTSET_ID_TEST, @@ -374,8 +367,7 @@ def test_minimal_green_path(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.remove_product_from_product_set.assert_called_once_with( product_set_id=PRODUCTSET_ID_TEST, @@ -394,8 +386,7 @@ def test_minimal_green_path_for_one_image(self, mock_hook): op = CloudVisionImageAnnotateOperator(request=ANNOTATE_REQUEST_TEST, task_id='id') op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.annotate_image.assert_called_once_with( request=ANNOTATE_REQUEST_TEST, retry=None, timeout=None @@ -406,8 +397,7 @@ def test_minimal_green_path_for_batch(self, mock_hook): op = CloudVisionImageAnnotateOperator(request=ANNOTATE_REQUEST_BATCH_TEST, task_id='id') op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.batch_annotate_images.assert_called_once_with( requests=ANNOTATE_REQUEST_BATCH_TEST, retry=None, timeout=None @@ -420,8 +410,7 @@ def test_minimal_green_path(self, mock_hook): op = CloudVisionDetectTextOperator(image=DETECT_TEST_IMAGE, task_id="id") op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.text_detection.assert_called_once_with( image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None, additional_properties=None @@ -435,16 +424,13 @@ def test_additional_params(self, mock_hook): language_hints="pl", web_detection_params={'param': 'test'}, additional_properties={ - 'image_context': { - 'additional_property_1': 'add_1' - }, - 'additional_property_2': 'add_2' - } + 'image_context': {'additional_property_1': 'add_1'}, + 'additional_property_2': 'add_2', + }, ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.text_detection.assert_called_once_with( image=DETECT_TEST_IMAGE, @@ -456,11 +442,9 @@ def test_additional_params(self, mock_hook): 'image_context': { 'language_hints': 'pl', 'additional_property_1': 'add_1', - 'web_detection_params': { - 'param': 'test' - } - } - } + 'web_detection_params': {'param': 'test'}, + }, + }, ) @@ -470,8 +454,7 @@ def test_minimal_green_path(self, mock_hook): op = CloudVisionTextDetectOperator(image=DETECT_TEST_IMAGE, task_id="id") op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.document_text_detection.assert_called_once_with( image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None, additional_properties=None @@ -484,8 +467,7 @@ def test_minimal_green_path(self, mock_hook): op = CloudVisionDetectImageLabelsOperator(image=DETECT_TEST_IMAGE, task_id="id") op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) 
mock_hook.return_value.label_detection.assert_called_once_with( image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None, additional_properties=None @@ -498,8 +480,7 @@ def test_minimal_green_path(self, mock_hook): op = CloudVisionDetectImageSafeSearchOperator(image=DETECT_TEST_IMAGE, task_id="id") op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, ) mock_hook.return_value.safe_search_detection.assert_called_once_with( image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None, additional_properties=None diff --git a/tests/providers/google/cloud/secrets/test_secret_manager.py b/tests/providers/google/cloud/secrets/test_secret_manager.py index 377316518508d..329b1339f521a 100644 --- a/tests/providers/google/cloud/secrets/test_secret_manager.py +++ b/tests/providers/google/cloud/secrets/test_secret_manager.py @@ -52,24 +52,28 @@ def test_default_valid_and_sep(self, mock_client_callable, mock_get_creds): backend = CloudSecretManagerBackend() self.assertTrue(backend._is_valid_prefix_and_sep()) - @parameterized.expand([ - ("colon:", "not:valid", ":"), - ("slash/", "not/valid", "/"), - ("space_with_char", "a b", ""), - ("space_only", "", " ") - ]) + @parameterized.expand( + [ + ("colon:", "not:valid", ":"), + ("slash/", "not/valid", "/"), + ("space_with_char", "a b", ""), + ("space_only", "", " "), + ] + ) def test_raise_exception_with_invalid_prefix_sep(self, _, prefix, sep): with self.assertRaises(AirflowException): CloudSecretManagerBackend(connections_prefix=prefix, sep=sep) - @parameterized.expand([ - ("dash-", "valid1", "-", True), - ("underscore_", "isValid", "_", True), - ("empty_string", "", "", True), - ("space_prefix", " ", "", False), - ("space_sep", "", " ", False), - ("colon:", "not:valid", ":", False) - ]) + @parameterized.expand( + [ + ("dash-", "valid1", "-", True), + ("underscore_", "isValid", "_", True), + ("empty_string", "", "", True), + ("space_prefix", " ", "", False), + ("space_sep", "", " ", False), + ("colon:", "not:valid", ":", False), + ] + ) @mock.patch(MODULE_NAME + ".get_credentials_and_project_id") @mock.patch(CLIENT_MODULE_NAME + ".SecretManagerServiceClient") def test_is_valid_prefix_and_sep(self, _, prefix, sep, is_valid, mock_client_callable, mock_get_creds): @@ -81,11 +85,7 @@ def test_is_valid_prefix_and_sep(self, _, prefix, sep, is_valid, mock_client_cal backend.sep = sep self.assertEqual(backend._is_valid_prefix_and_sep(), is_valid) - @parameterized.expand([ - "airflow-connections", - "connections", - "airflow" - ]) + @parameterized.expand(["airflow-connections", "connections", "airflow"]) @mock.patch(MODULE_NAME + ".get_credentials_and_project_id") @mock.patch(CLIENT_MODULE_NAME + ".SecretManagerServiceClient") def test_get_conn_uri(self, connections_prefix, mock_client_callable, mock_get_creds): @@ -101,9 +101,7 @@ def test_get_conn_uri(self, connections_prefix, mock_client_callable, mock_get_c secret_id = secrets_manager_backend.build_path(connections_prefix, CONN_ID, SEP) returned_uri = secrets_manager_backend.get_conn_uri(conn_id=CONN_ID) self.assertEqual(CONN_URI, returned_uri) - mock_client.secret_version_path.assert_called_once_with( - PROJECT_ID, secret_id, "latest" - ) + mock_client.secret_version_path.assert_called_once_with(PROJECT_ID, secret_id, "latest") @mock.patch(MODULE_NAME + ".get_credentials_and_project_id") @mock.patch(MODULE_NAME + ".CloudSecretManagerBackend.get_conn_uri") @@ -129,15 +127,10 @@ def 
test_get_conn_uri_non_existent_key(self, mock_client_callable, mock_get_cred self.assertIsNone(secrets_manager_backend.get_conn_uri(conn_id=CONN_ID)) self.assertEqual([], secrets_manager_backend.get_connections(conn_id=CONN_ID)) self.assertRegex( - log_output.output[0], - f"GCP API Call Error \\(NotFound\\): Secret ID {secret_id} not found" + log_output.output[0], f"GCP API Call Error \\(NotFound\\): Secret ID {secret_id} not found" ) - @parameterized.expand([ - "airflow-variables", - "variables", - "airflow" - ]) + @parameterized.expand(["airflow-variables", "variables", "airflow"]) @mock.patch(MODULE_NAME + ".get_credentials_and_project_id") @mock.patch(CLIENT_MODULE_NAME + ".SecretManagerServiceClient") def test_get_variable(self, variables_prefix, mock_client_callable, mock_get_creds): @@ -153,15 +146,9 @@ def test_get_variable(self, variables_prefix, mock_client_callable, mock_get_cre secret_id = secrets_manager_backend.build_path(variables_prefix, VAR_KEY, SEP) returned_uri = secrets_manager_backend.get_variable(VAR_KEY) self.assertEqual(VAR_VALUE, returned_uri) - mock_client.secret_version_path.assert_called_once_with( - PROJECT_ID, secret_id, "latest" - ) + mock_client.secret_version_path.assert_called_once_with(PROJECT_ID, secret_id, "latest") - @parameterized.expand([ - "airflow-variables", - "variables", - "airflow" - ]) + @parameterized.expand(["airflow-variables", "variables", "airflow"]) @mock.patch(MODULE_NAME + ".get_credentials_and_project_id") @mock.patch(CLIENT_MODULE_NAME + ".SecretManagerServiceClient") def test_get_variable_override_project_id(self, variables_prefix, mock_client_callable, mock_get_creds): @@ -173,14 +160,13 @@ def test_get_variable_override_project_id(self, variables_prefix, mock_client_ca test_response.payload.data = VAR_VALUE.encode("UTF-8") mock_client.access_secret_version.return_value = test_response - secrets_manager_backend = CloudSecretManagerBackend(variables_prefix=variables_prefix, - project_id=OVERRIDDEN_PROJECT_ID) + secrets_manager_backend = CloudSecretManagerBackend( + variables_prefix=variables_prefix, project_id=OVERRIDDEN_PROJECT_ID + ) secret_id = secrets_manager_backend.build_path(variables_prefix, VAR_KEY, SEP) returned_uri = secrets_manager_backend.get_variable(VAR_KEY) self.assertEqual(VAR_VALUE, returned_uri) - mock_client.secret_version_path.assert_called_once_with( - OVERRIDDEN_PROJECT_ID, secret_id, "latest" - ) + mock_client.secret_version_path.assert_called_once_with(OVERRIDDEN_PROJECT_ID, secret_id, "latest") @mock.patch(MODULE_NAME + ".get_credentials_and_project_id") @mock.patch(CLIENT_MODULE_NAME + ".SecretManagerServiceClient") @@ -196,6 +182,5 @@ def test_get_variable_non_existent_key(self, mock_client_callable, mock_get_cred with self.assertLogs(secrets_manager_backend.client.log, level="ERROR") as log_output: self.assertIsNone(secrets_manager_backend.get_variable(VAR_KEY)) self.assertRegex( - log_output.output[0], - f"GCP API Call Error \\(NotFound\\): Secret ID {secret_id} not found" + log_output.output[0], f"GCP API Call Error \\(NotFound\\): Secret ID {secret_id} not found" ) diff --git a/tests/providers/google/cloud/sensors/test_bigquery.py b/tests/providers/google/cloud/sensors/test_bigquery.py index 856b2a75ed2b5..52af97295390a 100644 --- a/tests/providers/google/cloud/sensors/test_bigquery.py +++ b/tests/providers/google/cloud/sensors/test_bigquery.py @@ -18,7 +18,8 @@ from unittest import TestCase, mock from airflow.providers.google.cloud.sensors.bigquery import ( - BigQueryTableExistenceSensor, 
BigQueryTablePartitionExistenceSensor, + BigQueryTableExistenceSensor, + BigQueryTablePartitionExistenceSensor, ) TEST_PROJECT_ID = "test_project" @@ -53,9 +54,7 @@ def test_passing_arguments_to_hook(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) mock_hook.return_value.table_exists.assert_called_once_with( - project_id=TEST_PROJECT_ID, - dataset_id=TEST_DATASET_ID, - table_id=TEST_TABLE_ID + project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID ) @@ -86,5 +85,5 @@ def test_passing_arguments_to_hook(self, mock_hook): project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, - partition_id=TEST_PARTITION_ID + partition_id=TEST_PARTITION_ID, ) diff --git a/tests/providers/google/cloud/sensors/test_bigtable.py b/tests/providers/google/cloud/sensors/test_bigtable.py index f12071be91d23..bf6015122ad5f 100644 --- a/tests/providers/google/cloud/sensors/test_bigtable.py +++ b/tests/providers/google/cloud/sensors/test_bigtable.py @@ -35,13 +35,12 @@ class BigtableWaitForTableReplicationTest(unittest.TestCase): - @parameterized.expand([ - ('instance_id', PROJECT_ID, '', TABLE_ID), - ('table_id', PROJECT_ID, INSTANCE_ID, ''), - ], testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0]) + @parameterized.expand( + [('instance_id', PROJECT_ID, '', TABLE_ID), ('table_id', PROJECT_ID, INSTANCE_ID, ''),], + testcase_func_name=lambda f, n, p: 'test_empty_attribute.empty_' + p.args[0], + ) @mock.patch('airflow.providers.google.cloud.sensors.bigtable.BigtableHook') - def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, - mock_hook): + def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, mock_hook): with self.assertRaises(AirflowException) as e: BigtableTableReplicationCompletedSensor( project_id=project_id, @@ -69,15 +68,15 @@ def test_wait_no_instance(self, mock_hook): ) self.assertFalse(op.poke(None)) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) @mock.patch('airflow.providers.google.cloud.sensors.bigtable.BigtableHook') def test_wait_no_table(self, mock_hook): mock_hook.return_value.get_instance.return_value = mock.Mock(Instance) mock_hook.return_value.get_cluster_states_for_table.side_effect = mock.Mock( - side_effect=google.api_core.exceptions.NotFound("Table not found.")) + side_effect=google.api_core.exceptions.NotFound("Table not found.") + ) op = BigtableTableReplicationCompletedSensor( project_id=PROJECT_ID, @@ -89,16 +88,13 @@ def test_wait_no_table(self, mock_hook): ) self.assertFalse(op.poke(None)) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) @mock.patch('airflow.providers.google.cloud.sensors.bigtable.BigtableHook') def test_wait_not_ready(self, mock_hook): mock_hook.return_value.get_instance.return_value = mock.Mock(Instance) - mock_hook.return_value.get_cluster_states_for_table.return_value = { - "cl-id": ClusterState(0) - } + mock_hook.return_value.get_cluster_states_for_table.return_value = {"cl-id": ClusterState(0)} op = BigtableTableReplicationCompletedSensor( project_id=PROJECT_ID, instance_id=INSTANCE_ID, @@ -109,16 +105,13 @@ def test_wait_not_ready(self, mock_hook): ) self.assertFalse(op.poke(None)) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - 
impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) @mock.patch('airflow.providers.google.cloud.sensors.bigtable.BigtableHook') def test_wait_ready(self, mock_hook): mock_hook.return_value.get_instance.return_value = mock.Mock(Instance) - mock_hook.return_value.get_cluster_states_for_table.return_value = { - "cl-id": ClusterState(4) - } + mock_hook.return_value.get_cluster_states_for_table.return_value = {"cl-id": ClusterState(4)} op = BigtableTableReplicationCompletedSensor( project_id=PROJECT_ID, instance_id=INSTANCE_ID, @@ -129,6 +122,5 @@ def test_wait_ready(self, mock_hook): ) self.assertTrue(op.poke(None)) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) diff --git a/tests/providers/google/cloud/sensors/test_gcs.py b/tests/providers/google/cloud/sensors/test_gcs.py index 5eef6541fb2cf..2bd9f60c0cd44 100644 --- a/tests/providers/google/cloud/sensors/test_gcs.py +++ b/tests/providers/google/cloud/sensors/test_gcs.py @@ -23,8 +23,11 @@ from airflow.exceptions import AirflowSensorTimeout from airflow.models.dag import DAG, AirflowException from airflow.providers.google.cloud.sensors.gcs import ( - GCSObjectExistenceSensor, GCSObjectsWtihPrefixExistenceSensor, GCSObjectUpdateSensor, - GCSUploadSessionCompleteSensor, ts_function, + GCSObjectExistenceSensor, + GCSObjectsWtihPrefixExistenceSensor, + GCSObjectUpdateSensor, + GCSUploadSessionCompleteSensor, + ts_function, ) TEST_BUCKET = "TEST_BUCKET" @@ -43,8 +46,7 @@ DEFAULT_DATE = datetime(2015, 1, 1) -MOCK_DATE_ARRAY = [datetime(2019, 2, 24, 12, 0, 0) - i * timedelta(seconds=10) - for i in range(25)] +MOCK_DATE_ARRAY = [datetime(2019, 2, 24, 12, 0, 0) - i * timedelta(seconds=10) for i in range(25)] def next_time_side_effect(): @@ -86,17 +88,13 @@ class TestTsFunction(TestCase): def test_should_support_datetime(self): context = { 'dag': DAG(dag_id=TEST_DAG_ID, schedule_interval=timedelta(days=5)), - 'execution_date': datetime(2019, 2, 14, 0, 0) + 'execution_date': datetime(2019, 2, 14, 0, 0), } result = ts_function(context) self.assertEqual(datetime(2019, 2, 19, 0, 0, tzinfo=timezone.utc), result) def test_should_support_cron(self): - dag = DAG( - dag_id=TEST_DAG_ID, - start_date=datetime(2019, 2, 19, 0, 0), - schedule_interval='@weekly' - ) + dag = DAG(dag_id=TEST_DAG_ID, start_date=datetime(2019, 2, 19, 0, 0), schedule_interval='@weekly') context = { 'dag': dag, @@ -174,7 +172,8 @@ def test_execute(self, mock_hook): google_cloud_conn_id=TEST_GCP_CONN_ID, delegate_to=TEST_DELEGATE_TO, impersonation_chain=TEST_IMPERSONATION_CHAIN, - poke_interval=0) + poke_interval=0, + ) generated_messages = ['test-prefix/obj%s' % i for i in range(5)] mock_hook.return_value.list.return_value = generated_messages @@ -191,20 +190,15 @@ def test_execute(self, mock_hook): @mock.patch('airflow.providers.google.cloud.sensors.gcs.GCSHook') def test_execute_timeout(self, mock_hook): task = GCSObjectsWtihPrefixExistenceSensor( - task_id="task-id", - bucket=TEST_BUCKET, - prefix=TEST_PREFIX, - poke_interval=0, - timeout=1) + task_id="task-id", bucket=TEST_BUCKET, prefix=TEST_PREFIX, poke_interval=0, timeout=1 + ) mock_hook.return_value.list.return_value = [] with self.assertRaises(AirflowSensorTimeout): task.execute(mock.MagicMock) - mock_hook.return_value.list.assert_called_once_with( - TEST_BUCKET, prefix=TEST_PREFIX) + 
mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix=TEST_PREFIX) class TestGCSUploadSessionCompleteSensor(TestCase): - def setUp(self): args = { 'owner': 'airflow', @@ -225,7 +219,7 @@ def setUp(self): google_cloud_conn_id=TEST_GCP_CONN_ID, delegate_to=TEST_DELEGATE_TO, impersonation_chain=TEST_IMPERSONATION_CHAIN, - dag=self.dag + dag=self.dag, ) self.last_mocked_date = datetime(2019, 4, 24, 0, 0, 0) @@ -256,7 +250,7 @@ def test_files_deleted_between_pokes_allow_delete(self): poke_interval=10, min_objects=1, allow_delete=True, - dag=self.dag + dag=self.dag, ) self.sensor.is_bucket_updated({'a', 'b'}) self.assertEqual(self.sensor.inactivity_seconds, 0) diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py index c2305a0e8216f..496196f2faef3 100644 --- a/tests/providers/google/cloud/sensors/test_pubsub.py +++ b/tests/providers/google/cloud/sensors/test_pubsub.py @@ -48,18 +48,11 @@ def _generate_messages(self, count): ] def _generate_dicts(self, count): - return [ - MessageToDict(m) - for m in self._generate_messages(count) - ] + return [MessageToDict(m) for m in self._generate_messages(count)] @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook') def test_poke_no_messages(self, mock_hook): - operator = PubSubPullSensor( - task_id=TASK_ID, - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - ) + operator = PubSubPullSensor(task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION,) mock_hook.return_value.pull.return_value = [] self.assertEqual(False, operator.poke({})) @@ -67,10 +60,7 @@ def test_poke_no_messages(self, mock_hook): @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook') def test_poke_with_ack_messages(self, mock_hook): operator = PubSubPullSensor( - task_id=TASK_ID, - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - ack_messages=True, + task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ack_messages=True, ) generated_messages = self._generate_messages(5) @@ -79,18 +69,13 @@ def test_poke_with_ack_messages(self, mock_hook): self.assertEqual(True, operator.poke({})) mock_hook.return_value.acknowledge.assert_called_once_with( - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - messages=generated_messages, + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, messages=generated_messages, ) @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook') def test_execute(self, mock_hook): operator = PubSubPullSensor( - task_id=TASK_ID, - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - poke_interval=0, + task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, poke_interval=0, ) generated_messages = self._generate_messages(5) @@ -99,10 +84,7 @@ def test_execute(self, mock_hook): response = operator.execute({}) mock_hook.return_value.pull.assert_called_once_with( - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - max_messages=5, - return_immediately=True + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=5, return_immediately=True ) self.assertEqual(generated_dicts, response) @@ -124,7 +106,7 @@ def test_execute_timeout(self, mock_hook): project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=5, - return_immediately=False + return_immediately=False, ) @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook') @@ -133,8 +115,7 @@ def test_execute_with_messages_callback(self, mock_hook): 
messages_callback_return_value = 'asdfg' def messages_callback( - pulled_messages: List[ReceivedMessage], - context: Dict[str, Any], + pulled_messages: List[ReceivedMessage], context: Dict[str, Any], ): assert pulled_messages == generated_messages @@ -158,10 +139,7 @@ def messages_callback( response = operator.execute({}) mock_hook.return_value.pull.assert_called_once_with( - project_id=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - max_messages=5, - return_immediately=True + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=5, return_immediately=True ) messages_callback.assert_called_once() diff --git a/tests/providers/google/cloud/transfers/test_adls_to_gcs.py b/tests/providers/google/cloud/transfers/test_adls_to_gcs.py index 0c14e7e33388b..76ca0b3630afe 100644 --- a/tests/providers/google/cloud/transfers/test_adls_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_adls_to_gcs.py @@ -25,8 +25,13 @@ TASK_ID = 'test-adls-gcs-operator' ADLS_PATH_1 = '*' GCS_PATH = 'gs://test/' -MOCK_FILES = ["test/TEST1.csv", "test/TEST2.csv", "test/path/TEST3.csv", - "test/path/PARQUET.parquet", "test/path/PIC.png"] +MOCK_FILES = [ + "test/TEST1.csv", + "test/TEST2.csv", + "test/path/TEST3.csv", + "test/path/PARQUET.parquet", + "test/path/PIC.png", +] AZURE_CONN_ID = 'azure_data_lake_default' GCS_CONN_ID = 'google_cloud_default' IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] @@ -42,7 +47,7 @@ def test_init(self): dest_gcs=GCS_PATH, replace=False, azure_data_lake_conn_id=AZURE_CONN_ID, - gcp_conn_id=GCS_CONN_ID + gcp_conn_id=GCS_CONN_ID, ) self.assertEqual(operator.task_id, TASK_ID) @@ -54,8 +59,7 @@ def test_init(self): @mock.patch('airflow.providers.google.cloud.transfers.adls_to_gcs.AzureDataLakeHook') @mock.patch('airflow.providers.microsoft.azure.operators.adls_list.AzureDataLakeHook') - @mock.patch( - 'airflow.providers.google.cloud.transfers.adls_to_gcs.GCSHook') + @mock.patch('airflow.providers.google.cloud.transfers.adls_to_gcs.GCSHook') def test_execute(self, gcs_mock_hook, adls_one_mock_hook, adls_two_mock_hook): """Test the execute function when the run is successful.""" @@ -77,21 +81,16 @@ def test_execute(self, gcs_mock_hook, adls_one_mock_hook, adls_two_mock_hook): gcs_mock_hook.return_value.upload.assert_has_calls( [ mock.call( - bucket_name='test', - filename=mock.ANY, - object_name='test/path/PARQUET.parquet', - gzip=False + bucket_name='test', filename=mock.ANY, object_name='test/path/PARQUET.parquet', gzip=False ), mock.call( - bucket_name='test', - filename=mock.ANY, - object_name='test/path/TEST3.csv', - gzip=False + bucket_name='test', filename=mock.ANY, object_name='test/path/TEST3.csv', gzip=False ), mock.call(bucket_name='test', filename=mock.ANY, object_name='test/path/PIC.png', gzip=False), mock.call(bucket_name='test', filename=mock.ANY, object_name='test/TEST1.csv', gzip=False), - mock.call(bucket_name='test', filename=mock.ANY, object_name='test/TEST2.csv', gzip=False) - ], any_order=True + mock.call(bucket_name='test', filename=mock.ANY, object_name='test/TEST2.csv', gzip=False), + ], + any_order=True, ) adls_one_mock_hook.assert_called_once_with(azure_data_lake_conn_id=AZURE_CONN_ID) @@ -107,8 +106,7 @@ def test_execute(self, gcs_mock_hook, adls_one_mock_hook, adls_two_mock_hook): @mock.patch('airflow.providers.google.cloud.transfers.adls_to_gcs.AzureDataLakeHook') @mock.patch('airflow.providers.microsoft.azure.operators.adls_list.AzureDataLakeHook') - @mock.patch( - 
'airflow.providers.google.cloud.transfers.adls_to_gcs.GCSHook') + @mock.patch('airflow.providers.google.cloud.transfers.adls_to_gcs.GCSHook') def test_execute_with_gzip(self, gcs_mock_hook, adls_one_mock_hook, adls_two_mock_hook): """Test the execute function when the run is successful.""" @@ -119,7 +117,7 @@ def test_execute_with_gzip(self, gcs_mock_hook, adls_one_mock_hook, adls_two_moc replace=False, azure_data_lake_conn_id=AZURE_CONN_ID, google_cloud_storage_conn_id=GCS_CONN_ID, - gzip=True + gzip=True, ) adls_one_mock_hook.return_value.list.return_value = MOCK_FILES @@ -130,21 +128,16 @@ def test_execute_with_gzip(self, gcs_mock_hook, adls_one_mock_hook, adls_two_moc gcs_mock_hook.return_value.upload.assert_has_calls( [ mock.call( - bucket_name='test', - filename=mock.ANY, - object_name='test/path/PARQUET.parquet', - gzip=True + bucket_name='test', filename=mock.ANY, object_name='test/path/PARQUET.parquet', gzip=True ), mock.call( - bucket_name='test', - filename=mock.ANY, - object_name='test/path/TEST3.csv', - gzip=True + bucket_name='test', filename=mock.ANY, object_name='test/path/TEST3.csv', gzip=True ), mock.call(bucket_name='test', filename=mock.ANY, object_name='test/path/PIC.png', gzip=True), mock.call(bucket_name='test', filename=mock.ANY, object_name='test/TEST1.csv', gzip=True), - mock.call(bucket_name='test', filename=mock.ANY, object_name='test/TEST2.csv', gzip=True) - ], any_order=True + mock.call(bucket_name='test', filename=mock.ANY, object_name='test/TEST2.csv', gzip=True), + ], + any_order=True, ) # we expect MOCK_FILES to be uploaded diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py index 049befdbfcd47..c0d8f8df39b35 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py @@ -30,10 +30,8 @@ class TestBigQueryToBigQueryOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.transfers.bigquery_to_bigquery.BigQueryHook') def test_execute(self, mock_hook): - source_project_dataset_tables = '{}.{}'.format( - TEST_DATASET, TEST_TABLE_ID) - destination_project_dataset_table = '{}.{}'.format( - TEST_DATASET + '_new', TEST_TABLE_ID) + source_project_dataset_tables = '{}.{}'.format(TEST_DATASET, TEST_TABLE_ID) + destination_project_dataset_table = '{}.{}'.format(TEST_DATASET + '_new', TEST_TABLE_ID) write_disposition = 'WRITE_EMPTY' create_disposition = 'CREATE_IF_NEEDED' labels = {'k1': 'v1'} @@ -46,19 +44,15 @@ def test_execute(self, mock_hook): write_disposition=write_disposition, create_disposition=create_disposition, labels=labels, - encryption_configuration=encryption_configuration + encryption_configuration=encryption_configuration, ) operator.execute(None) - mock_hook.return_value \ - .get_conn.return_value \ - .cursor.return_value \ - .run_copy \ - .assert_called_once_with( - source_project_dataset_tables=source_project_dataset_tables, - destination_project_dataset_table=destination_project_dataset_table, - write_disposition=write_disposition, - create_disposition=create_disposition, - labels=labels, - encryption_configuration=encryption_configuration - ) + mock_hook.return_value.get_conn.return_value.cursor.return_value.run_copy.assert_called_once_with( + source_project_dataset_tables=source_project_dataset_tables, + destination_project_dataset_table=destination_project_dataset_table, + write_disposition=write_disposition, + 
create_disposition=create_disposition, + labels=labels, + encryption_configuration=encryption_configuration, + ) diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py index e95d62365b1f9..8bd53e0355c83 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py @@ -30,8 +30,7 @@ class TestBigQueryToCloudStorageOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook') def test_execute(self, mock_hook): - source_project_dataset_table = '{}.{}'.format( - TEST_DATASET, TEST_TABLE_ID) + source_project_dataset_table = '{}.{}'.format(TEST_DATASET, TEST_TABLE_ID) destination_cloud_storage_uris = ['gs://some-bucket/some-file.txt'] compression = 'NONE' export_format = 'CSV' @@ -47,20 +46,16 @@ def test_execute(self, mock_hook): export_format=export_format, field_delimiter=field_delimiter, print_header=print_header, - labels=labels + labels=labels, ) operator.execute(None) - mock_hook.return_value \ - .get_conn.return_value \ - .cursor.return_value \ - .run_extract \ - .assert_called_once_with( - source_project_dataset_table=source_project_dataset_table, - destination_cloud_storage_uris=destination_cloud_storage_uris, - compression=compression, - export_format=export_format, - field_delimiter=field_delimiter, - print_header=print_header, - labels=labels - ) + mock_hook.return_value.get_conn.return_value.cursor.return_value.run_extract.assert_called_once_with( + source_project_dataset_table=source_project_dataset_table, + destination_cloud_storage_uris=destination_cloud_storage_uris, + compression=compression, + export_format=export_format, + field_delimiter=field_delimiter, + print_header=print_header, + labels=labels, + ) diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs_system.py index fbb78ced3c390..9595386af4292 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs_system.py @@ -27,7 +27,6 @@ @pytest.mark.system("google.cloud") @pytest.mark.credential_file(GCP_BIGQUERY_KEY) class BigQueryExampleDagsSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_BIGQUERY_KEY) def setUp(self): super().setUp() diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py b/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py index f21f249a29274..015a2870b1426 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_mysql.py @@ -39,14 +39,13 @@ def test_execute_good_request_to_bq(self, mock_hook): ) operator.execute(None) - mock_hook.return_value \ - .get_conn.return_value \ - .cursor.return_value \ - .get_tabledata \ + # fmt: off + mock_hook.return_value.get_conn.return_value.cursor.return_value.get_tabledata\ .assert_called_once_with( dataset_id=TEST_DATASET, table_id=TEST_TABLE_ID, max_results=1000, selected_fields=None, - start_index=0 + start_index=0, ) + # fmt: on diff --git a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py index 7eca09d5a186e..557cc6d05f3e5 100644 --- a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py 
@@ -28,9 +28,7 @@ class TestCassandraToGCS(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.NamedTemporaryFile") - @mock.patch( - "airflow.providers.google.cloud.transfers.cassandra_to_gcs.GCSHook.upload" - ) + @mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.GCSHook.upload") @mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraHook") def test_execute(self, mock_hook, mock_upload, mock_tempfile): test_bucket = "test-bucket" @@ -50,10 +48,20 @@ def test_execute(self, mock_hook, mock_upload, mock_tempfile): operator.execute(None) mock_hook.return_value.get_conn.assert_called_once_with() - call_schema = call(bucket_name=test_bucket, object_name=schema, - filename=TMP_FILE_NAME, mime_type="application/json", gzip=gzip) - call_data = call(bucket_name=test_bucket, object_name=filename, - filename=TMP_FILE_NAME, mime_type="application/json", gzip=gzip) + call_schema = call( + bucket_name=test_bucket, + object_name=schema, + filename=TMP_FILE_NAME, + mime_type="application/json", + gzip=gzip, + ) + call_data = call( + bucket_name=test_bucket, + object_name=filename, + filename=TMP_FILE_NAME, + mime_type="application/json", + gzip=gzip, + ) mock_upload.assert_has_calls([call_schema, call_data], any_order=True) def test_convert_value(self): diff --git a/tests/providers/google/cloud/transfers/test_facebook_ads_to_gcs.py b/tests/providers/google/cloud/transfers/test_facebook_ads_to_gcs.py index d3a4df0933253..ac392c21a89a1 100644 --- a/tests/providers/google/cloud/transfers/test_facebook_ads_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_facebook_ads_to_gcs.py @@ -31,44 +31,32 @@ "clicks", "impressions", ] -PARAMS = { - "level": "ad", - "date_preset": "yesterday" -} +PARAMS = {"level": "ad", "date_preset": "yesterday"} FACEBOOK_RETURN_VALUE = [ - { - "campaign_name": "abcd", - "campaign_id": "abcd", - "ad_id": "abcd", - "clicks": "2", - "impressions": "2", - } + {"campaign_name": "abcd", "campaign_id": "abcd", "ad_id": "abcd", "clicks": "2", "impressions": "2",} ] class TestFacebookAdsReportToGcsOperator: - @mock.patch("airflow.providers.google.cloud.transfers.facebook_ads_to_gcs.FacebookAdsReportingHook") @mock.patch("airflow.providers.google.cloud.transfers.facebook_ads_to_gcs.GCSHook") def test_execute(self, mock_gcs_hook, mock_ads_hook): mock_ads_hook.return_value.bulk_facebook_report.return_value = FACEBOOK_RETURN_VALUE - op = FacebookAdsReportToGcsOperator(facebook_conn_id=FACEBOOK_ADS_CONN_ID, - fields=FIELDS, - params=PARAMS, - object_name=GCS_OBJ_PATH, - bucket_name=GCS_BUCKET, - task_id="run_operator", - impersonation_chain=IMPERSONATION_CHAIN,) + op = FacebookAdsReportToGcsOperator( + facebook_conn_id=FACEBOOK_ADS_CONN_ID, + fields=FIELDS, + params=PARAMS, + object_name=GCS_OBJ_PATH, + bucket_name=GCS_BUCKET, + task_id="run_operator", + impersonation_chain=IMPERSONATION_CHAIN, + ) op.execute({}) - mock_ads_hook.assert_called_once_with(facebook_conn_id=FACEBOOK_ADS_CONN_ID, - api_version=API_VERSION) - mock_ads_hook.return_value.bulk_facebook_report.assert_called_once_with(params=PARAMS, - fields=FIELDS) + mock_ads_hook.assert_called_once_with(facebook_conn_id=FACEBOOK_ADS_CONN_ID, api_version=API_VERSION) + mock_ads_hook.return_value.bulk_facebook_report.assert_called_once_with(params=PARAMS, fields=FIELDS) mock_gcs_hook.assert_called_once_with( - gcp_conn_id=GCS_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCS_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + ) + 
mock_gcs_hook.return_value.upload.assert_called_once_with( + bucket_name=GCS_BUCKET, object_name=GCS_OBJ_PATH, filename=mock.ANY, gzip=False ) - mock_gcs_hook.return_value.upload.assert_called_once_with(bucket_name=GCS_BUCKET, - object_name=GCS_OBJ_PATH, - filename=mock.ANY, - gzip=False) diff --git a/tests/providers/google/cloud/transfers/test_facebook_ads_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_facebook_ads_to_gcs_system.py index 143768c03f921..120f520093368 100644 --- a/tests/providers/google/cloud/transfers/test_facebook_ads_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_facebook_ads_to_gcs_system.py @@ -32,16 +32,11 @@ FACEBOOK_CREDENTIALS_PATH = os.path.join(CREDENTIALS_DIR, FACEBOOK_KEY) CONNECTION_TYPE = os.environ.get('CONNECTION_TYPE', 'facebook_social') FACEBOOK_CONNECTION_ID = os.environ.get('FACEBOOK_CONNECTION_ID', 'facebook_default') -CONFIG_REQUIRED_FIELDS = ["app_id", - "app_secret", - "access_token", - "account_id"] +CONFIG_REQUIRED_FIELDS = ["app_id", "app_secret", "access_token", "account_id"] @contextmanager -def provide_facebook_connection( - key_file_path: str -): +def provide_facebook_connection(key_file_path: str): """ Context manager that provides a temporary value of AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT connection. It build a new connection that includes path to provided service json, @@ -51,20 +46,14 @@ def provide_facebook_connection( :type key_file_path: str """ if not key_file_path.endswith(".json"): - raise AirflowException( - "Use a JSON key file." - ) + raise AirflowException("Use a JSON key file.") with open(key_file_path, 'r') as credentials: creds = json.load(credentials) missing_keys = CONFIG_REQUIRED_FIELDS - creds.keys() if missing_keys: message = "{missing_keys} fields are missing".format(missing_keys=missing_keys) raise AirflowException(message) - conn = Connection( - conn_id=FACEBOOK_CONNECTION_ID, - conn_type=CONNECTION_TYPE, - extra=json.dumps(creds) - ) + conn = Connection(conn_id=FACEBOOK_CONNECTION_ID, conn_type=CONNECTION_TYPE, extra=json.dumps(creds)) with patch_environ({f"AIRFLOW_CONN_{conn.conn_id.upper()}": conn.get_uri()}): yield @@ -73,7 +62,6 @@ def provide_facebook_connection( @pytest.mark.credential_file(GCP_BIGQUERY_KEY) @pytest.mark.system("google.cloud") class FacebookAdsToGcsExampleDagsSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_BIGQUERY_KEY) @provide_facebook_connection(FACEBOOK_CREDENTIALS_PATH) def test_dag_example(self): diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py index b327cecdc4e05..a61889e339dab 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py @@ -30,41 +30,40 @@ class TestGoogleCloudStorageToBigQueryOperator(unittest.TestCase): - @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook') def test_execute_explicit_project_legacy(self, bq_hook): - operator = GCSToBigQueryOperator(task_id=TASK_ID, - bucket=TEST_BUCKET, - source_objects=TEST_SOURCE_OBJECTS, - destination_project_dataset_table=TEST_EXPLICIT_DEST, - max_id_key=MAX_ID_KEY) + operator = GCSToBigQueryOperator( + task_id=TASK_ID, + bucket=TEST_BUCKET, + source_objects=TEST_SOURCE_OBJECTS, + destination_project_dataset_table=TEST_EXPLICIT_DEST, + max_id_key=MAX_ID_KEY, + ) # using legacy SQL bq_hook.return_value.get_conn.return_value.cursor.return_value.use_legacy_sql = True operator.execute(None) 
- bq_hook.return_value \ - .get_conn.return_value \ - .cursor.return_value \ - .execute \ - .assert_called_once_with("SELECT MAX(id) FROM [test-project.dataset.table]") + bq_hook.return_value.get_conn.return_value.cursor.return_value.execute.assert_called_once_with( + "SELECT MAX(id) FROM [test-project.dataset.table]" + ) @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook') def test_execute_explicit_project(self, bq_hook): - operator = GCSToBigQueryOperator(task_id=TASK_ID, - bucket=TEST_BUCKET, - source_objects=TEST_SOURCE_OBJECTS, - destination_project_dataset_table=TEST_EXPLICIT_DEST, - max_id_key=MAX_ID_KEY) + operator = GCSToBigQueryOperator( + task_id=TASK_ID, + bucket=TEST_BUCKET, + source_objects=TEST_SOURCE_OBJECTS, + destination_project_dataset_table=TEST_EXPLICIT_DEST, + max_id_key=MAX_ID_KEY, + ) # using non-legacy SQL bq_hook.return_value.get_conn.return_value.cursor.return_value.use_legacy_sql = False operator.execute(None) - bq_hook.return_value \ - .get_conn.return_value \ - .cursor.return_value \ - .execute \ - .assert_called_once_with("SELECT MAX(id) FROM `test-project.dataset.table`") + bq_hook.return_value.get_conn.return_value.cursor.return_value.execute.assert_called_once_with( + "SELECT MAX(id) FROM `test-project.dataset.table`" + ) diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery_system.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery_system.py index ab91cdde0aaed..13a75bffb08d9 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery_system.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery_system.py @@ -25,7 +25,6 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_BIGQUERY_KEY) class TestGoogleCloudStorageToBigQueryExample(GoogleSystemTest): - @provide_gcp_context(GCP_BIGQUERY_KEY) def test_run_example_dag_gcs_to_bigquery_operator(self): self.run_dag('example_gcs_to_bigquery_operator', CLOUD_DAG_FOLDER) diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py index a51b101147151..ce9ef0e43bff9 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py @@ -65,34 +65,36 @@ class TestGoogleCloudStorageToCloudStorageOperator(unittest.TestCase): @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_execute_no_prefix(self, mock_hook): operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_PREFIX, - destination_bucket=DESTINATION_BUCKET) + destination_bucket=DESTINATION_BUCKET, + ) operator.execute(None) - mock_hook.return_value.list.assert_called_once_with( - TEST_BUCKET, prefix="", delimiter="test_object" - ) + mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix="", delimiter="test_object") @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_execute_no_suffix(self, mock_hook): operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_SUFFIX, - destination_bucket=DESTINATION_BUCKET) + destination_bucket=DESTINATION_BUCKET, + ) operator.execute(None) - mock_hook.return_value.list.assert_called_once_with( - TEST_BUCKET, prefix="test_object", delimiter="" - ) + 
mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix="test_object", delimiter="") @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_execute_wildcard_with_replace_flag_false(self, mock_hook): operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_SUFFIX, destination_bucket=DESTINATION_BUCKET, - replace=False) + replace=False, + ) operator.execute(None) mock_calls = [ @@ -104,31 +106,31 @@ def test_execute_wildcard_with_replace_flag_false(self, mock_hook): @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_execute_prefix_and_suffix(self, mock_hook): operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_MIDDLE, - destination_bucket=DESTINATION_BUCKET) + destination_bucket=DESTINATION_BUCKET, + ) operator.execute(None) - mock_hook.return_value.list.assert_called_once_with( - TEST_BUCKET, prefix="test", delimiter="object" - ) + mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix="test", delimiter="object") # copy with wildcard @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_execute_wildcard_with_destination_object(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_FILES_LIST operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_FILENAME, destination_bucket=DESTINATION_BUCKET, - destination_object=DESTINATION_OBJECT_PREFIX) + destination_object=DESTINATION_OBJECT_PREFIX, + ) operator.execute(None) mock_calls = [ - mock.call(TEST_BUCKET, 'test_object/file1.txt', - DESTINATION_BUCKET, 'foo/bar/file1.txt'), - mock.call(TEST_BUCKET, 'test_object/file2.txt', - DESTINATION_BUCKET, 'foo/bar/file2.txt'), + mock.call(TEST_BUCKET, 'test_object/file1.txt', DESTINATION_BUCKET, 'foo/bar/file1.txt'), + mock.call(TEST_BUCKET, 'test_object/file2.txt', DESTINATION_BUCKET, 'foo/bar/file2.txt'), ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls) @@ -136,19 +138,21 @@ def test_execute_wildcard_with_destination_object(self, mock_hook): def test_execute_wildcard_with_destination_object_retained_prefix(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_FILES_LIST operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_FILENAME, destination_bucket=DESTINATION_BUCKET, - destination_object='{}/{}'.format(DESTINATION_OBJECT_PREFIX, - SOURCE_OBJECT_WILDCARD_SUFFIX[:-1]) + destination_object='{}/{}'.format(DESTINATION_OBJECT_PREFIX, SOURCE_OBJECT_WILDCARD_SUFFIX[:-1]), ) operator.execute(None) mock_calls_retained = [ - mock.call(TEST_BUCKET, 'test_object/file1.txt', - DESTINATION_BUCKET, 'foo/bar/test_object/file1.txt'), - mock.call(TEST_BUCKET, 'test_object/file2.txt', - DESTINATION_BUCKET, 'foo/bar/test_object/file2.txt'), + mock.call( + TEST_BUCKET, 'test_object/file1.txt', DESTINATION_BUCKET, 'foo/bar/test_object/file1.txt' + ), + mock.call( + TEST_BUCKET, 'test_object/file2.txt', DESTINATION_BUCKET, 'foo/bar/test_object/file2.txt' + ), ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls_retained) @@ -156,16 +160,16 @@ def test_execute_wildcard_with_destination_object_retained_prefix(self, mock_hoo def 
test_execute_wildcard_without_destination_object(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_FILES_LIST operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_FILENAME, - destination_bucket=DESTINATION_BUCKET) + destination_bucket=DESTINATION_BUCKET, + ) operator.execute(None) mock_calls_none = [ - mock.call(TEST_BUCKET, 'test_object/file1.txt', - DESTINATION_BUCKET, 'test_object/file1.txt'), - mock.call(TEST_BUCKET, 'test_object/file2.txt', - DESTINATION_BUCKET, 'test_object/file2.txt'), + mock.call(TEST_BUCKET, 'test_object/file1.txt', DESTINATION_BUCKET, 'test_object/file1.txt'), + mock.call(TEST_BUCKET, 'test_object/file2.txt', DESTINATION_BUCKET, 'test_object/file2.txt'), ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls_none) @@ -173,17 +177,17 @@ def test_execute_wildcard_without_destination_object(self, mock_hook): def test_execute_wildcard_empty_destination_object(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_FILES_LIST operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_FILENAME, destination_bucket=DESTINATION_BUCKET, - destination_object='') + destination_object='', + ) operator.execute(None) mock_calls_empty = [ - mock.call(TEST_BUCKET, 'test_object/file1.txt', - DESTINATION_BUCKET, '/file1.txt'), - mock.call(TEST_BUCKET, 'test_object/file2.txt', - DESTINATION_BUCKET, '/file2.txt'), + mock.call(TEST_BUCKET, 'test_object/file1.txt', DESTINATION_BUCKET, '/file1.txt'), + mock.call(TEST_BUCKET, 'test_object/file2.txt', DESTINATION_BUCKET, '/file2.txt'), ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls_empty) @@ -191,17 +195,17 @@ def test_execute_wildcard_empty_destination_object(self, mock_hook): def test_execute_last_modified_time(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_FILES_LIST operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_FILENAME, destination_bucket=DESTINATION_BUCKET, - last_modified_time=None) + last_modified_time=None, + ) operator.execute(None) mock_calls_none = [ - mock.call(TEST_BUCKET, 'test_object/file1.txt', - DESTINATION_BUCKET, 'test_object/file1.txt'), - mock.call(TEST_BUCKET, 'test_object/file2.txt', - DESTINATION_BUCKET, 'test_object/file2.txt'), + mock.call(TEST_BUCKET, 'test_object/file1.txt', DESTINATION_BUCKET, 'test_object/file1.txt'), + mock.call(TEST_BUCKET, 'test_object/file2.txt', DESTINATION_BUCKET, 'test_object/file2.txt'), ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls_none) @@ -210,17 +214,17 @@ def test_wc_with_last_modified_time_with_all_true_cond(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_FILES_LIST mock_hook.return_value.is_updated_after.side_effect = [True, True, True] operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_FILENAME, destination_bucket=DESTINATION_BUCKET, - last_modified_time=MOD_TIME_1) + last_modified_time=MOD_TIME_1, + ) operator.execute(None) mock_calls_none = [ - mock.call(TEST_BUCKET, 'test_object/file1.txt', - DESTINATION_BUCKET, 'test_object/file1.txt'), - mock.call(TEST_BUCKET, 'test_object/file2.txt', - DESTINATION_BUCKET, 'test_object/file2.txt'), + 
mock.call(TEST_BUCKET, 'test_object/file1.txt', DESTINATION_BUCKET, 'test_object/file1.txt'), + mock.call(TEST_BUCKET, 'test_object/file2.txt', DESTINATION_BUCKET, 'test_object/file2.txt'), ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls_none) @@ -229,31 +233,33 @@ def test_wc_with_last_modified_time_with_one_true_cond(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_FILES_LIST mock_hook.return_value.is_updated_after.side_effect = [True, False, False] operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_FILENAME, destination_bucket=DESTINATION_BUCKET, - last_modified_time=MOD_TIME_1) + last_modified_time=MOD_TIME_1, + ) operator.execute(None) mock_hook.return_value.rewrite.assert_called_once_with( - TEST_BUCKET, 'test_object/file1.txt', - DESTINATION_BUCKET, 'test_object/file1.txt') + TEST_BUCKET, 'test_object/file1.txt', DESTINATION_BUCKET, 'test_object/file1.txt' + ) @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_wc_with_no_last_modified_time(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_FILES_LIST operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_WILDCARD_FILENAME, destination_bucket=DESTINATION_BUCKET, - last_modified_time=None) + last_modified_time=None, + ) operator.execute(None) mock_calls_none = [ - mock.call(TEST_BUCKET, 'test_object/file1.txt', - DESTINATION_BUCKET, 'test_object/file1.txt'), - mock.call(TEST_BUCKET, 'test_object/file2.txt', - DESTINATION_BUCKET, 'test_object/file2.txt'), + mock.call(TEST_BUCKET, 'test_object/file1.txt', DESTINATION_BUCKET, 'test_object/file1.txt'), + mock.call(TEST_BUCKET, 'test_object/file2.txt', DESTINATION_BUCKET, 'test_object/file2.txt'), ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls_none) @@ -261,67 +267,81 @@ def test_wc_with_no_last_modified_time(self, mock_hook): def test_no_prefix_with_last_modified_time_with_true_cond(self, mock_hook): mock_hook.return_value.is_updated_after.return_value = True operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_NO_WILDCARD, destination_bucket=DESTINATION_BUCKET, destination_object=SOURCE_OBJECT_NO_WILDCARD, - last_modified_time=MOD_TIME_1) + last_modified_time=MOD_TIME_1, + ) operator.execute(None) mock_hook.return_value.rewrite.assert_called_once_with( - TEST_BUCKET, 'test_object.txt', DESTINATION_BUCKET, 'test_object.txt') + TEST_BUCKET, 'test_object.txt', DESTINATION_BUCKET, 'test_object.txt' + ) @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_no_prefix_with_maximum_modified_time_with_true_cond(self, mock_hook): mock_hook.return_value.is_updated_before.return_value = True operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_NO_WILDCARD, destination_bucket=DESTINATION_BUCKET, destination_object=SOURCE_OBJECT_NO_WILDCARD, - maximum_modified_time=MOD_TIME_1) + maximum_modified_time=MOD_TIME_1, + ) operator.execute(None) mock_hook.return_value.rewrite.assert_called_once_with( - TEST_BUCKET, 'test_object.txt', DESTINATION_BUCKET, 'test_object.txt') + TEST_BUCKET, 'test_object.txt', DESTINATION_BUCKET, 'test_object.txt' + ) 
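The reformatted assertions above all follow the same unittest.mock pattern: patch the hook class at the path where the operator module looks it up, run execute(), and then assert on calls recorded under the mock's return_value. As a rough, hedged illustration of that pattern only — using hypothetical stand-in classes (FakeHook, FakeCopyOperator) rather than the real Airflow GCSToGCSOperator and GCSHook — a minimal self-contained sketch might look like this:

# Minimal sketch of the patch-the-hook-and-assert-on-return_value pattern.
# FakeHook and FakeCopyOperator are invented stand-ins, not Airflow classes.
import unittest
from unittest import mock


class FakeHook:
    """Stand-in for a storage hook; the real tests patch GCSHook instead."""

    def rewrite(self, src_bucket, src_object, dst_bucket, dst_object):
        raise NotImplementedError  # never reached; the test replaces this class


class FakeCopyOperator:
    """Stand-in for the copy operator under test."""

    def __init__(self, source_bucket, source_object, destination_bucket):
        self.source_bucket = source_bucket
        self.source_object = source_object
        self.destination_bucket = destination_bucket

    def execute(self, context):
        hook = FakeHook()  # under @mock.patch this resolves to the mock class
        hook.rewrite(
            self.source_bucket, self.source_object, self.destination_bucket, self.source_object
        )


class TestFakeCopyOperator(unittest.TestCase):
    # Patch where the class is *looked up*, exactly as the tests above patch
    # 'airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook'.
    @mock.patch(f"{__name__}.FakeHook")
    def test_execute(self, mock_hook):
        op = FakeCopyOperator(
            source_bucket="src-bucket",
            source_object="test_object.txt",
            destination_bucket="dst-bucket",
        )
        op.execute(None)
        # Calls made on the instantiated hook are recorded on return_value,
        # mirroring mock_hook.return_value.rewrite.assert_called_once_with(...) above.
        mock_hook.return_value.rewrite.assert_called_once_with(
            "src-bucket", "test_object.txt", "dst-bucket", "test_object.txt"
        )


if __name__ == "__main__":
    unittest.main()

The point the reformatting preserves is that every hook call flows through the patched class's return_value, so the one-line assert_called_once_with form and the wrapped multi-line form are equivalent.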
@mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_exe_last_modified_time_and_maximum_modified_time_with_true_cond(self, mock_hook): mock_hook.return_value.is_updated_between.return_value = True operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_NO_WILDCARD, destination_bucket=DESTINATION_BUCKET, destination_object=SOURCE_OBJECT_NO_WILDCARD, last_modified_time=MOD_TIME_1, - maximum_modified_time=MOD_TIME_2) + maximum_modified_time=MOD_TIME_2, + ) operator.execute(None) mock_hook.return_value.rewrite.assert_called_once_with( - TEST_BUCKET, 'test_object.txt', DESTINATION_BUCKET, 'test_object.txt') + TEST_BUCKET, 'test_object.txt', DESTINATION_BUCKET, 'test_object.txt' + ) @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_execute_no_prefix_with_no_last_modified_time(self, mock_hook): operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_NO_WILDCARD, destination_bucket=DESTINATION_BUCKET, destination_object=SOURCE_OBJECT_NO_WILDCARD, - last_modified_time=None) + last_modified_time=None, + ) operator.execute(None) mock_hook.return_value.rewrite.assert_called_once_with( - TEST_BUCKET, 'test_object.txt', DESTINATION_BUCKET, 'test_object.txt') + TEST_BUCKET, 'test_object.txt', DESTINATION_BUCKET, 'test_object.txt' + ) @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_no_prefix_with_last_modified_time_with_false_cond(self, mock_hook): mock_hook.return_value.is_updated_after.return_value = False operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_NO_WILDCARD, destination_bucket=DESTINATION_BUCKET, destination_object=SOURCE_OBJECT_NO_WILDCARD, - last_modified_time=MOD_TIME_1) + last_modified_time=MOD_TIME_1, + ) operator.execute(None) mock_hook.return_value.rewrite.assert_not_called() @@ -330,31 +350,37 @@ def test_no_prefix_with_last_modified_time_with_false_cond(self, mock_hook): def test_executes_with_is_older_than_with_true_cond(self, mock_hook): mock_hook.return_value.is_older_than.return_value = True operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_NO_WILDCARD, destination_bucket=DESTINATION_BUCKET, destination_object=SOURCE_OBJECT_NO_WILDCARD, last_modified_time=MOD_TIME_1, maximum_modified_time=MOD_TIME_2, - is_older_than=3600) + is_older_than=3600, + ) operator.execute(None) mock_hook.return_value.rewrite.assert_called_once_with( - TEST_BUCKET, 'test_object.txt', DESTINATION_BUCKET, 'test_object.txt') + TEST_BUCKET, 'test_object.txt', DESTINATION_BUCKET, 'test_object.txt' + ) @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_execute_more_than_1_wildcard(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_FILES_LIST operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_MULTIPLE_WILDCARDS, destination_bucket=DESTINATION_BUCKET, - destination_object=DESTINATION_OBJECT_PREFIX) + destination_object=DESTINATION_OBJECT_PREFIX, + ) total_wildcards = operator.source_object.count(WILDCARD) - error_msg = "Only one wildcard '[*]' is allowed in source_object parameter. 
" \ - "Found {}".format(total_wildcards) + error_msg = "Only one wildcard '[*]' is allowed in source_object parameter. " "Found {}".format( + total_wildcards + ) with self.assertRaisesRegex(AirflowException, error_msg): operator.execute(None) @@ -363,16 +389,17 @@ def test_execute_more_than_1_wildcard(self, mock_hook): def test_execute_with_empty_destination_bucket(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_FILES_LIST operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_object=SOURCE_OBJECT_NO_WILDCARD, destination_bucket=None, - destination_object=DESTINATION_OBJECT_PREFIX) + destination_object=DESTINATION_OBJECT_PREFIX, + ) with mock.patch.object(operator.log, 'warning') as mock_warn: operator.execute(None) mock_warn.assert_called_once_with( - 'destination_bucket is None. Defaulting it to source_bucket (%s)', - TEST_BUCKET + 'destination_bucket is None. Defaulting it to source_bucket (%s)', TEST_BUCKET ) self.assertEqual(operator.destination_bucket, operator.source_bucket) @@ -380,30 +407,29 @@ def test_execute_with_empty_destination_bucket(self, mock_hook): @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_executes_with_empty_source_objects(self, mock_hook): operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, - source_objects=SOURCE_OBJECTS_NO_FILE) + task_id=TASK_ID, source_bucket=TEST_BUCKET, source_objects=SOURCE_OBJECTS_NO_FILE + ) operator.execute(None) - mock_hook.return_value.list.assert_called_once_with( - TEST_BUCKET, prefix='', delimiter=None - ) + mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix='', delimiter=None) @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_raises_exception_with_two_empty_list_inside_source_objects(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_OBJECTS_LIST operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, - source_objects=SOURCE_OBJECTS_TWO_EMPTY_STRING) + task_id=TASK_ID, source_bucket=TEST_BUCKET, source_objects=SOURCE_OBJECTS_TWO_EMPTY_STRING + ) - with self.assertRaisesRegex(AirflowException, - "You can't have two empty strings inside source_object"): + with self.assertRaisesRegex( + AirflowException, "You can't have two empty strings inside source_object" + ): operator.execute(None) @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_executes_with_single_item_in_source_objects(self, mock_hook): operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, - source_objects=SOURCE_OBJECTS_SINGLE_FILE) + task_id=TASK_ID, source_bucket=TEST_BUCKET, source_objects=SOURCE_OBJECTS_SINGLE_FILE + ) operator.execute(None) mock_hook.return_value.list.assert_called_once_with( TEST_BUCKET, prefix=SOURCE_OBJECTS_SINGLE_FILE[0], delimiter=None @@ -412,42 +438,44 @@ def test_executes_with_single_item_in_source_objects(self, mock_hook): @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_executes_with_multiple_items_in_source_objects(self, mock_hook): operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, - source_objects=SOURCE_OBJECTS_MULTIPLE_FILES) + task_id=TASK_ID, source_bucket=TEST_BUCKET, source_objects=SOURCE_OBJECTS_MULTIPLE_FILES + ) operator.execute(None) mock_hook.return_value.list.assert_has_calls( [ mock.call(TEST_BUCKET, prefix='test_object/file1.txt', delimiter=None), - mock.call(TEST_BUCKET, 
prefix='test_object/file2.txt', delimiter=None) + mock.call(TEST_BUCKET, prefix='test_object/file2.txt', delimiter=None), ], - any_order=True + any_order=True, ) @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_executes_with_a_delimiter(self, mock_hook): operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, - source_objects=SOURCE_OBJECTS_NO_FILE, delimiter=DELIMITER) - operator.execute(None) - mock_hook.return_value.list.assert_called_once_with( - TEST_BUCKET, prefix='', delimiter=DELIMITER + task_id=TASK_ID, + source_bucket=TEST_BUCKET, + source_objects=SOURCE_OBJECTS_NO_FILE, + delimiter=DELIMITER, ) + operator.execute(None) + mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix='', delimiter=DELIMITER) # COPY @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') def test_executes_with_delimiter_and_destination_object(self, mock_hook): mock_hook.return_value.list.return_value = ['test_object/file3.json'] operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_objects=SOURCE_OBJECTS_LIST, destination_bucket=DESTINATION_BUCKET, destination_object=DESTINATION_OBJECT, - delimiter=DELIMITER) + delimiter=DELIMITER, + ) operator.execute(None) mock_calls = [ - mock.call(TEST_BUCKET, 'test_object/file3.json', - DESTINATION_BUCKET, DESTINATION_OBJECT), + mock.call(TEST_BUCKET, 'test_object/file3.json', DESTINATION_BUCKET, DESTINATION_OBJECT), ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls) @@ -455,19 +483,18 @@ def test_executes_with_delimiter_and_destination_object(self, mock_hook): def test_executes_with_different_delimiter_and_destination_object(self, mock_hook): mock_hook.return_value.list.return_value = ['test_object/file1.txt', 'test_object/file2.txt'] operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_objects=SOURCE_OBJECTS_LIST, destination_bucket=DESTINATION_BUCKET, destination_object=DESTINATION_OBJECT, - delimiter='.txt') + delimiter='.txt', + ) operator.execute(None) mock_calls = [ - mock.call(TEST_BUCKET, 'test_object/file1.txt', - DESTINATION_BUCKET, DESTINATION_OBJECT), - mock.call(TEST_BUCKET, 'test_object/file2.txt', - DESTINATION_BUCKET, DESTINATION_OBJECT), - + mock.call(TEST_BUCKET, 'test_object/file1.txt', DESTINATION_BUCKET, DESTINATION_OBJECT), + mock.call(TEST_BUCKET, 'test_object/file2.txt', DESTINATION_BUCKET, DESTINATION_OBJECT), ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls) @@ -475,16 +502,14 @@ def test_executes_with_different_delimiter_and_destination_object(self, mock_hoo def test_executes_with_no_destination_bucket_and_no_destination_object(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_OBJECTS_LIST operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, - source_objects=SOURCE_OBJECTS_LIST) + task_id=TASK_ID, source_bucket=TEST_BUCKET, source_objects=SOURCE_OBJECTS_LIST + ) operator.execute(None) mock_calls = [ - mock.call(TEST_BUCKET, 'test_object/file1.txt', - TEST_BUCKET, 'test_object/file1.txt'), - mock.call(TEST_BUCKET, 'test_object/file2.txt', - TEST_BUCKET, 'test_object/file2.txt'), - mock.call(TEST_BUCKET, 'test_object/file3.json', - TEST_BUCKET, 'test_object/file3.json'), ] + mock.call(TEST_BUCKET, 'test_object/file1.txt', TEST_BUCKET, 'test_object/file1.txt'), + mock.call(TEST_BUCKET, 'test_object/file2.txt', TEST_BUCKET, 
'test_object/file2.txt'), + mock.call(TEST_BUCKET, 'test_object/file3.json', TEST_BUCKET, 'test_object/file3.json'), + ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls) @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook') @@ -492,23 +517,17 @@ def test_wc_with_last_modified_time_with_all_true_cond_no_file(self, mock_hook): mock_hook.return_value.list.return_value = SOURCE_OBJECTS_LIST mock_hook.return_value.is_updated_after.side_effect = [True, True, True] operator = GCSToGCSOperator( - task_id=TASK_ID, source_bucket=TEST_BUCKET, + task_id=TASK_ID, + source_bucket=TEST_BUCKET, source_objects=SOURCE_OBJECTS_NO_FILE, destination_bucket=DESTINATION_BUCKET, - last_modified_time=MOD_TIME_1) + last_modified_time=MOD_TIME_1, + ) operator.execute(None) mock_calls_none = [ - mock.call( - TEST_BUCKET, 'test_object/file1.txt', - DESTINATION_BUCKET, 'test_object/file1.txt' - ), - mock.call( - TEST_BUCKET, 'test_object/file2.txt', - DESTINATION_BUCKET, 'test_object/file2.txt' - ), - mock.call( - TEST_BUCKET, 'test_object/file3.json', - DESTINATION_BUCKET, 'test_object/file3.json' - ), ] + mock.call(TEST_BUCKET, 'test_object/file1.txt', DESTINATION_BUCKET, 'test_object/file1.txt'), + mock.call(TEST_BUCKET, 'test_object/file2.txt', DESTINATION_BUCKET, 'test_object/file2.txt'), + mock.call(TEST_BUCKET, 'test_object/file3.json', DESTINATION_BUCKET, 'test_object/file3.json'), + ] mock_hook.return_value.rewrite.assert_has_calls(mock_calls_none) diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs_system.py index 6d765df8a90f2..13b3a7afeb372 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs_system.py @@ -19,7 +19,12 @@ import pytest from airflow.providers.google.cloud.example_dags.example_gcs_to_gcs import ( - BUCKET_1_DST, BUCKET_1_SRC, BUCKET_2_DST, BUCKET_2_SRC, BUCKET_3_DST, BUCKET_3_SRC, + BUCKET_1_DST, + BUCKET_1_SRC, + BUCKET_2_DST, + BUCKET_2_SRC, + BUCKET_3_DST, + BUCKET_3_SRC, ) from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context @@ -28,7 +33,6 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_GCS_KEY) class GcsToGcsExampleDagsSystemTest(GoogleSystemTest): - def create_buckets(self): """Create a buckets in Google Cloud Storage service with sample content.""" @@ -45,7 +49,8 @@ def create_buckets(self): "bash", "-c", "cat /dev/urandom | head -c $((1 * 1024 * 1024)) | gsutil cp - {}".format(first_parent), - ], key=GCP_GCS_KEY + ], + key=GCP_GCS_KEY, ) self.execute_with_ctx( @@ -53,7 +58,8 @@ def create_buckets(self): "bash", "-c", "cat /dev/urandom | head -c $((1 * 1024 * 1024)) | gsutil cp - {}".format(second_parent), - ], key=GCP_GCS_KEY + ], + key=GCP_GCS_KEY, ) self.upload_to_gcs(first_parent, f"gs://{BUCKET_1_SRC}/file.bin") diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_local.py b/tests/providers/google/cloud/transfers/test_gcs_to_local.py index 4763cae0c7603..1a926321fa215 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_local.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_local.py @@ -35,10 +35,7 @@ class TestGoogleCloudStorageDownloadOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_local.GCSHook") def test_execute(self, mock_hook): operator = 
GCSToLocalFilesystemOperator( - task_id=TASK_ID, - bucket=TEST_BUCKET, - object_name=TEST_OBJECT, - filename=LOCAL_FILE_PATH, + task_id=TASK_ID, bucket=TEST_BUCKET, object_name=TEST_OBJECT, filename=LOCAL_FILE_PATH, ) operator.execute(None) diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py b/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py index 0a42919c8dbaa..8b9b022b86204 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py @@ -63,9 +63,7 @@ def test_execute_copy_single_file(self, sftp_hook, gcs_hook): ) task.execute({}) gcs_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=DELEGATE_TO, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, ) sftp_hook.assert_called_once_with(SFTP_CONN_ID) @@ -74,9 +72,7 @@ def test_execute_copy_single_file(self, sftp_hook, gcs_hook): self.assertEqual(kwargs["object_name"], SOURCE_OBJECT_NO_WILDCARD) args, kwargs = sftp_hook.return_value.store_file.call_args - self.assertEqual( - args[0], os.path.join(DESTINATION_SFTP, SOURCE_OBJECT_NO_WILDCARD) - ) + self.assertEqual(args[0], os.path.join(DESTINATION_SFTP, SOURCE_OBJECT_NO_WILDCARD)) gcs_hook.return_value.delete.assert_not_called() @@ -96,9 +92,7 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook): ) task.execute(None) gcs_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=DELEGATE_TO, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, ) sftp_hook.assert_called_once_with(SFTP_CONN_ID) @@ -107,13 +101,9 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook): self.assertEqual(kwargs["object_name"], SOURCE_OBJECT_NO_WILDCARD) args, kwargs = sftp_hook.return_value.store_file.call_args - self.assertEqual( - args[0], os.path.join(DESTINATION_SFTP, SOURCE_OBJECT_NO_WILDCARD) - ) + self.assertEqual(args[0], os.path.join(DESTINATION_SFTP, SOURCE_OBJECT_NO_WILDCARD)) - gcs_hook.return_value.delete.assert_called_once_with( - TEST_BUCKET, SOURCE_OBJECT_NO_WILDCARD - ) + gcs_hook.return_value.delete.assert_called_once_with(TEST_BUCKET, SOURCE_OBJECT_NO_WILDCARD) @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_sftp.GCSHook") @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_sftp.SFTPHook") @@ -131,9 +121,7 @@ def test_execute_copy_with_wildcard(self, sftp_hook, gcs_hook): ) operator.execute(None) - gcs_hook.return_value.list.assert_called_with( - TEST_BUCKET, delimiter=".txt", prefix="test_object" - ) + gcs_hook.return_value.list.assert_called_with(TEST_BUCKET, delimiter=".txt", prefix="test_object") call_one, call_two = gcs_hook.return_value.download.call_args_list self.assertEqual(call_one[1]["bucket_name"], TEST_BUCKET) @@ -158,9 +146,7 @@ def test_execute_move_with_wildcard(self, sftp_hook, gcs_hook): ) operator.execute(None) - gcs_hook.return_value.list.assert_called_with( - TEST_BUCKET, delimiter=".txt", prefix="test_object" - ) + gcs_hook.return_value.list.assert_called_with(TEST_BUCKET, delimiter=".txt", prefix="test_object") call_one, call_two = gcs_hook.return_value.delete.call_args_list self.assertEqual(call_one[0], (TEST_BUCKET, "test_object/file1.txt")) diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_sftp_system.py b/tests/providers/google/cloud/transfers/test_gcs_to_sftp_system.py index a0ab3c947b70e..e0345a8afcd98 100644 --- 
a/tests/providers/google/cloud/transfers/test_gcs_to_sftp_system.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_sftp_system.py @@ -22,7 +22,9 @@ import pytest from airflow.providers.google.cloud.example_dags.example_gcs_to_sftp import ( - BUCKET_SRC, OBJECT_SRC_1, OBJECT_SRC_2, + BUCKET_SRC, + OBJECT_SRC_1, + OBJECT_SRC_2, ) from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context @@ -30,7 +32,6 @@ @pytest.mark.credential_file(GCP_GCS_KEY) class GcsToSftpExampleDagsSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_GCS_KEY) def setUp(self): super().setUp() @@ -40,18 +41,12 @@ def setUp(self): # 2. Prepare files for bucket_src, object_source in product( - ( - BUCKET_SRC, - "{}/subdir-1".format(BUCKET_SRC), - "{}/subdir-2".format(BUCKET_SRC), - ), + (BUCKET_SRC, "{}/subdir-1".format(BUCKET_SRC), "{}/subdir-2".format(BUCKET_SRC),), (OBJECT_SRC_1, OBJECT_SRC_2), ): source_path = "gs://{}/{}".format(bucket_src, object_source) self.upload_content_to_gcs( - lines=f"{os.urandom(1 * 1024 * 1024)}", - bucket=source_path, - filename=object_source + lines=f"{os.urandom(1 * 1024 * 1024)}", bucket=source_path, filename=object_source ) @provide_gcp_context(GCP_GCS_KEY) diff --git a/tests/providers/google/cloud/transfers/test_local_to_gcs.py b/tests/providers/google/cloud/transfers/test_local_to_gcs.py index e8e00b1b5ae9d..a43274e02d8e7 100644 --- a/tests/providers/google/cloud/transfers/test_local_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_local_to_gcs.py @@ -31,17 +31,10 @@ class TestFileToGcsOperator(unittest.TestCase): - _config = { - 'bucket': 'dummy', - 'mime_type': 'application/octet-stream', - 'gzip': False - } + _config = {'bucket': 'dummy', 'mime_type': 'application/octet-stream', 'gzip': False} def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2017, 1, 1) - } + args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)} self.dag = DAG('test_dag_id', default_args=args) self.testfile1 = '/tmp/fake1.csv' with open(self.testfile1, 'wb') as f: @@ -61,7 +54,7 @@ def test_init(self): dag=self.dag, src=self.testfile1, dst='test/test1.csv', - **self._config + **self._config, ) self.assertEqual(operator.src, self.testfile1) self.assertEqual(operator.dst, 'test/test1.csv') @@ -69,8 +62,7 @@ def test_init(self): self.assertEqual(operator.mime_type, self._config['mime_type']) self.assertEqual(operator.gzip, self._config['gzip']) - @mock.patch('airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook', - autospec=True) + @mock.patch('airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook', autospec=True) def test_execute(self, mock_hook): mock_instance = mock_hook.return_value operator = LocalFilesystemToGCSOperator( @@ -78,7 +70,7 @@ def test_execute(self, mock_hook): dag=self.dag, src=self.testfile1, dst='test/test1.csv', - **self._config + **self._config, ) operator.execute(None) mock_instance.upload.assert_called_once_with( @@ -86,45 +78,36 @@ def test_execute(self, mock_hook): filename=self.testfile1, gzip=self._config['gzip'], mime_type=self._config['mime_type'], - object_name='test/test1.csv' + object_name='test/test1.csv', ) - @mock.patch('airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook', - autospec=True) + @mock.patch('airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook', autospec=True) def test_execute_multiple(self, mock_hook): mock_instance = 
mock_hook.return_value operator = LocalFilesystemToGCSOperator( - task_id='gcs_to_file_sensor', - dag=self.dag, - src=self.testfiles, - dst='test/', - **self._config + task_id='gcs_to_file_sensor', dag=self.dag, src=self.testfiles, dst='test/', **self._config ) operator.execute(None) - files_objects = zip(self.testfiles, ['test/' + os.path.basename(testfile) - for testfile in self.testfiles]) + files_objects = zip( + self.testfiles, ['test/' + os.path.basename(testfile) for testfile in self.testfiles] + ) calls = [ mock.call( bucket_name=self._config['bucket'], filename=filepath, gzip=self._config['gzip'], mime_type=self._config['mime_type'], - object_name=object_name + object_name=object_name, ) for filepath, object_name in files_objects ] mock_instance.upload.assert_has_calls(calls) - @mock.patch('airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook', - autospec=True) + @mock.patch('airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook', autospec=True) def test_execute_wildcard(self, mock_hook): mock_instance = mock_hook.return_value operator = LocalFilesystemToGCSOperator( - task_id='gcs_to_file_sensor', - dag=self.dag, - src='/tmp/fake*.csv', - dst='test/', - **self._config + task_id='gcs_to_file_sensor', dag=self.dag, src='/tmp/fake*.csv', dst='test/', **self._config ) operator.execute(None) object_names = ['test/' + os.path.basename(fp) for fp in glob('/tmp/fake*.csv')] @@ -135,14 +118,13 @@ def test_execute_wildcard(self, mock_hook): filename=filepath, gzip=self._config['gzip'], mime_type=self._config['mime_type'], - object_name=object_name + object_name=object_name, ) for filepath, object_name in files_objects ] mock_instance.upload.assert_has_calls(calls) - @mock.patch('airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook', - autospec=True) + @mock.patch('airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook', autospec=True) def test_execute_negative(self, mock_hook): mock_instance = mock_hook.return_value operator = LocalFilesystemToGCSOperator( @@ -150,7 +132,7 @@ def test_execute_negative(self, mock_hook): dag=self.dag, src='/tmp/fake*.csv', dst='test/test1.csv', - **self._config + **self._config, ) print(glob('/tmp/fake*.csv')) with pytest.raises(ValueError): diff --git a/tests/providers/google/cloud/transfers/test_local_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_local_to_gcs_system.py index e0ed840a494fc..edabb76908afd 100644 --- a/tests/providers/google/cloud/transfers/test_local_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_local_to_gcs_system.py @@ -27,7 +27,6 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_GCS_KEY) class LocalFilesystemToGCSOperatorExampleDagsTest(GoogleSystemTest): - @provide_gcp_context(GCP_GCS_KEY) def setUp(self): super().setUp() diff --git a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py index 06e971a08cf4e..6221e989de282 100644 --- a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py @@ -32,34 +32,28 @@ JSON_FILENAME = 'test_{}.ndjson' GZIP = False -ROWS = [ - ('mock_row_content_1', 42), - ('mock_row_content_2', 43), - ('mock_row_content_3', 44) -] +ROWS = [('mock_row_content_1', 42), ('mock_row_content_2', 43), ('mock_row_content_3', 44)] CURSOR_DESCRIPTION = ( ('some_str', 0, None, None, None, None, None), - ('some_num', 3, None, None, None, None, None) + ('some_num', 3, None, None, None, None, 
None), ) NDJSON_LINES = [ b'{"some_num": 42, "some_str": "mock_row_content_1"}\n', b'{"some_num": 43, "some_str": "mock_row_content_2"}\n', - b'{"some_num": 44, "some_str": "mock_row_content_3"}\n' + b'{"some_num": 44, "some_str": "mock_row_content_3"}\n', ] SCHEMA_FILENAME = 'schema_test.json' SCHEMA_JSON = [ b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ', - b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]' + b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]', ] @unittest.skipIf(PY38, "Mssql package not available when Python >= 3.8.") class TestMsSqlToGoogleCloudStorageOperator(unittest.TestCase): - def test_init(self): """Test MySqlToGoogleCloudStorageOperator instance is properly initialized.""" - op = MSSQLToGCSOperator( - task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME) + op = MSSQLToGCSOperator(task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME) self.assertEqual(op.task_id, TASK_ID) self.assertEqual(op.sql, SQL) self.assertEqual(op.bucket, BUCKET) @@ -70,11 +64,8 @@ def test_init(self): def test_exec_success_json(self, gcs_hook_mock_class, mssql_hook_mock_class): """Test successful run of execute function for JSON""" op = MSSQLToGCSOperator( - task_id=TASK_ID, - mssql_conn_id=MSSQL_CONN_ID, - sql=SQL, - bucket=BUCKET, - filename=JSON_FILENAME) + task_id=TASK_ID, mssql_conn_id=MSSQL_CONN_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME + ) mssql_hook_mock = mssql_hook_mock_class.return_value mssql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS) @@ -125,7 +116,8 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False): sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, - approx_max_file_size_bytes=len(expected_upload[JSON_FILENAME.format(0)])) + approx_max_file_size_bytes=len(expected_upload[JSON_FILENAME.format(0)]), + ) op.execute(None) @mock.patch('airflow.providers.google.cloud.transfers.mssql_to_gcs.MsSqlHook') @@ -146,11 +138,8 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab gcs_hook_mock.upload.side_effect = _assert_upload op = MSSQLToGCSOperator( - task_id=TASK_ID, - sql=SQL, - bucket=BUCKET, - filename=JSON_FILENAME, - schema_filename=SCHEMA_FILENAME) + task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, schema_filename=SCHEMA_FILENAME + ) op.execute(None) # once for the file and once for the schema diff --git a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py index abd8a18d32f43..19a969a579fe1 100644 --- a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py @@ -35,53 +35,48 @@ CSV_FILENAME = 'test_{}.csv' SCHEMA = [ {'mode': 'REQUIRED', 'name': 'some_str', 'type': 'FLOAT'}, - {'mode': 'REQUIRED', 'name': 'some_num', 'type': 'TIMESTAMP'} + {'mode': 'REQUIRED', 'name': 'some_num', 'type': 'TIMESTAMP'}, ] -ROWS = [ - ('mock_row_content_1', 42), - ('mock_row_content_2', 43), - ('mock_row_content_3', 44) -] -CURSOR_DESCRIPTION = ( - ('some_str', 0, 0, 0, 0, 0, False), - ('some_num', 1005, 0, 0, 0, 0, False) -) +ROWS = [('mock_row_content_1', 42), ('mock_row_content_2', 43), ('mock_row_content_3', 44)] +CURSOR_DESCRIPTION = (('some_str', 0, 0, 0, 0, 0, False), ('some_num', 1005, 0, 0, 0, 0, False)) NDJSON_LINES = [ b'{"some_num": 42, "some_str": "mock_row_content_1"}\n', b'{"some_num": 43, "some_str": "mock_row_content_2"}\n', - b'{"some_num": 44, "some_str": 
"mock_row_content_3"}\n' + b'{"some_num": 44, "some_str": "mock_row_content_3"}\n', ] CSV_LINES = [ - b'some_str,some_num\r\n' - b'mock_row_content_1,42\r\n', + b'some_str,some_num\r\n' b'mock_row_content_1,42\r\n', b'mock_row_content_2,43\r\n', - b'mock_row_content_3,44\r\n' + b'mock_row_content_3,44\r\n', ] CSV_LINES_PIPE_DELIMITED = [ - b'some_str|some_num\r\n' - b'mock_row_content_1|42\r\n', + b'some_str|some_num\r\n' b'mock_row_content_1|42\r\n', b'mock_row_content_2|43\r\n', - b'mock_row_content_3|44\r\n' + b'mock_row_content_3|44\r\n', ] SCHEMA_FILENAME = 'schema_test.json' SCHEMA_JSON = [ b'[{"mode": "REQUIRED", "name": "some_str", "type": "FLOAT"}, ', - b'{"mode": "REQUIRED", "name": "some_num", "type": "STRING"}]' + b'{"mode": "REQUIRED", "name": "some_num", "type": "STRING"}]', ] CUSTOM_SCHEMA_JSON = [ b'[{"mode": "REQUIRED", "name": "some_str", "type": "FLOAT"}, ', - b'{"mode": "REQUIRED", "name": "some_num", "type": "TIMESTAMP"}]' + b'{"mode": "REQUIRED", "name": "some_num", "type": "TIMESTAMP"}]', ] class TestMySqlToGoogleCloudStorageOperator(unittest.TestCase): - def test_init(self): """Test MySqlToGoogleCloudStorageOperator instance is properly initialized.""" op = MySQLToGCSOperator( - task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, - export_format='CSV', field_delimiter='|') + task_id=TASK_ID, + sql=SQL, + bucket=BUCKET, + filename=JSON_FILENAME, + export_format='CSV', + field_delimiter='|', + ) self.assertEqual(op.task_id, TASK_ID) self.assertEqual(op.sql, SQL) self.assertEqual(op.bucket, BUCKET) @@ -89,37 +84,31 @@ def test_init(self): self.assertEqual(op.export_format, 'csv') self.assertEqual(op.field_delimiter, '|') - @parameterized.expand([ - ("string", None, "string"), - (datetime.date(1970, 1, 2), None, 86400), - (datetime.date(1970, 1, 2), "DATE", "1970-01-02"), - (datetime.datetime(1970, 1, 1, 1, 0), None, 3600), - (decimal.Decimal(5), None, 5), - (b"bytes", "BYTES", "Ynl0ZXM="), - (b"\x00\x01", "INTEGER", 1), - (None, "BYTES", None) - ]) + @parameterized.expand( + [ + ("string", None, "string"), + (datetime.date(1970, 1, 2), None, 86400), + (datetime.date(1970, 1, 2), "DATE", "1970-01-02"), + (datetime.datetime(1970, 1, 1, 1, 0), None, 3600), + (decimal.Decimal(5), None, 5), + (b"bytes", "BYTES", "Ynl0ZXM="), + (b"\x00\x01", "INTEGER", 1), + (None, "BYTES", None), + ] + ) def test_convert_type(self, value, schema_type, expected): op = MySQLToGCSOperator( - task_id=TASK_ID, - mysql_conn_id=MYSQL_CONN_ID, - sql=SQL, - bucket=BUCKET, - filename=JSON_FILENAME) - self.assertEqual( - op.convert_type(value, schema_type), - expected) + task_id=TASK_ID, mysql_conn_id=MYSQL_CONN_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME + ) + self.assertEqual(op.convert_type(value, schema_type), expected) @mock.patch('airflow.providers.google.cloud.transfers.mysql_to_gcs.MySqlHook') @mock.patch('airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook') def test_exec_success_json(self, gcs_hook_mock_class, mysql_hook_mock_class): """Test successful run of execute function for JSON""" op = MySQLToGCSOperator( - task_id=TASK_ID, - mysql_conn_id=MYSQL_CONN_ID, - sql=SQL, - bucket=BUCKET, - filename=JSON_FILENAME) + task_id=TASK_ID, mysql_conn_id=MYSQL_CONN_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME + ) mysql_hook_mock = mysql_hook_mock_class.return_value mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS) @@ -152,7 +141,8 @@ def test_exec_success_csv(self, gcs_hook_mock_class, mysql_hook_mock_class): sql=SQL, export_format='CSV', 
bucket=BUCKET, - filename=CSV_FILENAME) + filename=CSV_FILENAME, + ) mysql_hook_mock = mysql_hook_mock_class.return_value mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS) @@ -186,7 +176,8 @@ def test_exec_success_csv_ensure_utc(self, gcs_hook_mock_class, mysql_hook_mock_ export_format='CSV', bucket=BUCKET, filename=CSV_FILENAME, - ensure_utc=True) + ensure_utc=True, + ) mysql_hook_mock = mysql_hook_mock_class.return_value mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS) @@ -220,7 +211,8 @@ def test_exec_success_csv_with_delimiter(self, gcs_hook_mock_class, mysql_hook_m export_format='csv', field_delimiter='|', bucket=BUCKET, - filename=CSV_FILENAME) + filename=CSV_FILENAME, + ) mysql_hook_mock = mysql_hook_mock_class.return_value mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS) @@ -271,7 +263,8 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False): sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, - approx_max_file_size_bytes=len(expected_upload[JSON_FILENAME.format(0)])) + approx_max_file_size_bytes=len(expected_upload[JSON_FILENAME.format(0)]), + ) op.execute(None) @mock.patch('airflow.providers.google.cloud.transfers.mysql_to_gcs.MySqlHook') @@ -293,11 +286,8 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab gcs_hook_mock.upload.side_effect = _assert_upload op = MySQLToGCSOperator( - task_id=TASK_ID, - sql=SQL, - bucket=BUCKET, - filename=JSON_FILENAME, - schema_filename=SCHEMA_FILENAME) + task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, schema_filename=SCHEMA_FILENAME + ) op.execute(None) # once for the file and once for the schema @@ -327,7 +317,8 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab bucket=BUCKET, filename=JSON_FILENAME, schema_filename=SCHEMA_FILENAME, - schema=SCHEMA) + schema=SCHEMA, + ) op.execute(None) # once for the file and once for the schema @@ -336,27 +327,23 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab @mock.patch('airflow.providers.google.cloud.transfers.mysql_to_gcs.MySqlHook') @mock.patch('airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook') def test_query_with_error(self, mock_gcs_hook, mock_mysql_hook): - mock_mysql_hook.return_value.get_conn.\ - return_value.cursor.return_value.execute.side_effect = ProgrammingError + mock_mysql_hook.return_value.get_conn.return_value.cursor.return_value.execute.side_effect = ( + ProgrammingError + ) op = MySQLToGCSOperator( - task_id=TASK_ID, - sql=SQL, - bucket=BUCKET, - filename=JSON_FILENAME, - schema_filename=SCHEMA_FILENAME) + task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, schema_filename=SCHEMA_FILENAME + ) with self.assertRaises(ProgrammingError): op.query() @mock.patch('airflow.providers.google.cloud.transfers.mysql_to_gcs.MySqlHook') @mock.patch('airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook') def test_execute_with_query_error(self, mock_gcs_hook, mock_mysql_hook): - mock_mysql_hook.return_value.get_conn.\ - return_value.cursor.return_value.execute.side_effect = ProgrammingError + mock_mysql_hook.return_value.get_conn.return_value.cursor.return_value.execute.side_effect = ( + ProgrammingError + ) op = MySQLToGCSOperator( - task_id=TASK_ID, - sql=SQL, - bucket=BUCKET, - filename=JSON_FILENAME, - schema_filename=SCHEMA_FILENAME) + task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, schema_filename=SCHEMA_FILENAME + ) with 
self.assertRaises(ProgrammingError): op.execute(None) diff --git a/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py b/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py index 9fc0da5a05ffa..0fa73be542df0 100644 --- a/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py @@ -35,11 +35,13 @@ NDJSON_LINES = [ b'{"some_num": 42, "some_str": "mock_row_content_1"}\n', b'{"some_num": 43, "some_str": "mock_row_content_2"}\n', - b'{"some_num": 44, "some_str": "mock_row_content_3"}\n' + b'{"some_num": 44, "some_str": "mock_row_content_3"}\n', ] SCHEMA_FILENAME = 'schema_test.json' -SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ' \ - b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]' +SCHEMA_JSON = ( + b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ' + b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]' +) @pytest.mark.backend("postgres") @@ -51,20 +53,16 @@ def setUpClass(cls): with conn.cursor() as cur: for table in TABLES: cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table)) - cur.execute("CREATE TABLE {}(some_str varchar, some_num integer);" - .format(table)) + cur.execute("CREATE TABLE {}(some_str varchar, some_num integer);".format(table)) cur.execute( - "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);", - ('mock_row_content_1', 42) + "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);", ('mock_row_content_1', 42) ) cur.execute( - "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);", - ('mock_row_content_2', 43) + "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);", ('mock_row_content_2', 43) ) cur.execute( - "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);", - ('mock_row_content_3', 44) + "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);", ('mock_row_content_3', 44) ) @classmethod @@ -77,8 +75,7 @@ def tearDownClass(cls): def test_init(self): """Test PostgresToGoogleCloudStorageOperator instance is properly initialized.""" - op = PostgresToGCSOperator( - task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=FILENAME) + op = PostgresToGCSOperator(task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=FILENAME) self.assertEqual(op.task_id, TASK_ID) self.assertEqual(op.sql, SQL) self.assertEqual(op.bucket, BUCKET) @@ -88,11 +85,8 @@ def test_init(self): def test_exec_success(self, gcs_hook_mock_class): """Test the execute function in case where the run is successful.""" op = PostgresToGCSOperator( - task_id=TASK_ID, - postgres_conn_id=POSTGRES_CONN_ID, - sql=SQL, - bucket=BUCKET, - filename=FILENAME) + task_id=TASK_ID, postgres_conn_id=POSTGRES_CONN_ID, sql=SQL, bucket=BUCKET, filename=FILENAME + ) gcs_hook_mock = gcs_hook_mock_class.return_value @@ -132,7 +126,8 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): sql=SQL, bucket=BUCKET, filename=FILENAME, - approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)])) + approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]), + ) op.execute(None) @patch('airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook') @@ -149,11 +144,8 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab gcs_hook_mock.upload.side_effect = _assert_upload op = PostgresToGCSOperator( - task_id=TASK_ID, - sql=SQL, - bucket=BUCKET, - filename=FILENAME, - schema_filename=SCHEMA_FILENAME) + task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=FILENAME, schema_filename=SCHEMA_FILENAME + ) op.execute(None) # once for the 
file and once for the schema diff --git a/tests/providers/google/cloud/transfers/test_presto_to_gcs.py b/tests/providers/google/cloud/transfers/test_presto_to_gcs.py index 27a0e9933f49e..9c5188234eb9c 100644 --- a/tests/providers/google/cloud/transfers/test_presto_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_presto_to_gcs.py @@ -103,9 +103,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): mock_presto_hook.assert_called_once_with(presto_conn_id=PRESTO_CONN_ID) mock_gcs_hook.assert_called_once_with( - delegate_to=None, - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + delegate_to=None, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) mock_gcs_hook.return_value.upload.assert_called() @@ -240,9 +238,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): mock_presto_hook.assert_called_once_with(presto_conn_id=PRESTO_CONN_ID) mock_gcs_hook.assert_called_once_with( - delegate_to=None, - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, + delegate_to=None, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) @patch("airflow.providers.google.cloud.transfers.presto_to_gcs.PrestoHook") diff --git a/tests/providers/google/cloud/transfers/test_presto_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_presto_to_gcs_system.py index 29788399a4cb5..b630b2e8044e2 100644 --- a/tests/providers/google/cloud/transfers/test_presto_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_presto_to_gcs_system.py @@ -32,6 +32,7 @@ # fool the pre-commit check that looks for old imports... # TODO remove this once we don't need to test this on 1.10 import importlib + db_module = importlib.import_module("airflow.utils.db") create_session = getattr(db_module, "create_session") @@ -148,9 +149,10 @@ def setUp(self): with suppress(Exception): self.drop_db() self.init_db() - self.execute_with_ctx([ - "bq", "rm", "--recursive", "--force", f"{self._project_id()}:{DATASET_NAME}" - ], key=GCP_BIGQUERY_KEY) + self.execute_with_ctx( + ["bq", "rm", "--recursive", "--force", f"{self._project_id()}:{DATASET_NAME}"], + key=GCP_BIGQUERY_KEY, + ) @provide_gcp_context(GCP_BIGQUERY_KEY) def test_run_example_dag(self): @@ -160,7 +162,8 @@ def test_run_example_dag(self): def tearDown(self): self.delete_gcs_bucket(GCS_BUCKET) self.drop_db() - self.execute_with_ctx([ - "bq", "rm", "--recursive", "--force", f"{self._project_id()}:{DATASET_NAME}" - ], key=GCP_BIGQUERY_KEY) + self.execute_with_ctx( + ["bq", "rm", "--recursive", "--force", f"{self._project_id()}:{DATASET_NAME}"], + key=GCP_BIGQUERY_KEY, + ) super().tearDown() diff --git a/tests/providers/google/cloud/transfers/test_s3_to_gcs.py b/tests/providers/google/cloud/transfers/test_s3_to_gcs.py index a18e10e97171b..9f4baf67e0c1a 100644 --- a/tests/providers/google/cloud/transfers/test_s3_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_s3_to_gcs.py @@ -57,8 +57,7 @@ def test_init(self): @mock.patch('airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook') @mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook') - @mock.patch( - 'airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook') + @mock.patch('airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook') def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook): """Test the execute function when the run is successful.""" @@ -80,8 +79,9 @@ def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook): [ mock.call('gcs-bucket', 
'data/TEST1.csv', mock.ANY, gzip=False), mock.call('gcs-bucket', 'data/TEST3.csv', mock.ANY, gzip=False), - mock.call('gcs-bucket', 'data/TEST2.csv', mock.ANY, gzip=False) - ], any_order=True + mock.call('gcs-bucket', 'data/TEST2.csv', mock.ANY, gzip=False), + ], + any_order=True, ) s3_one_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None) @@ -97,8 +97,7 @@ def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook): @mock.patch('airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook') @mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook') - @mock.patch( - 'airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook') + @mock.patch('airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook') def test_execute_with_gzip(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook): """Test the execute function when the run is successful.""" @@ -109,7 +108,7 @@ def test_execute_with_gzip(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_ho delimiter=S3_DELIMITER, dest_gcs_conn_id=GCS_CONN_ID, dest_gcs=GCS_PATH_PREFIX, - gzip=True + gzip=True, ) s3_one_mock_hook.return_value.list_keys.return_value = MOCK_FILES @@ -117,14 +116,13 @@ def test_execute_with_gzip(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_ho operator.execute(None) gcs_mock_hook.assert_called_once_with( - google_cloud_storage_conn_id=GCS_CONN_ID, - delegate_to=None, - impersonation_chain=None, + google_cloud_storage_conn_id=GCS_CONN_ID, delegate_to=None, impersonation_chain=None, ) gcs_mock_hook.return_value.upload.assert_has_calls( [ mock.call('gcs-bucket', 'data/TEST2.csv', mock.ANY, gzip=True), mock.call('gcs-bucket', 'data/TEST1.csv', mock.ANY, gzip=True), - mock.call('gcs-bucket', 'data/TEST3.csv', mock.ANY, gzip=True) - ], any_order=True + mock.call('gcs-bucket', 'data/TEST3.csv', mock.ANY, gzip=True), + ], + any_order=True, ) diff --git a/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py b/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py index 02cc9fdbd28ab..0305c0c751cc9 100644 --- a/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py @@ -69,9 +69,7 @@ def test_execute_copy_single_file(self, sftp_hook, gcs_hook): ) task.execute(None) gcs_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=DELEGATE_TO, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, ) sftp_hook.assert_called_once_with(SFTP_CONN_ID) @@ -104,9 +102,7 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook): ) task.execute(None) gcs_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=DELEGATE_TO, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, ) sftp_hook.assert_called_once_with(SFTP_CONN_ID) @@ -121,9 +117,7 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook): mime_type=DEFAULT_MIME_TYPE, ) - sftp_hook.return_value.delete_file.assert_called_once_with( - SOURCE_OBJECT_NO_WILDCARD - ) + sftp_hook.return_value.delete_file.assert_called_once_with(SOURCE_OBJECT_NO_WILDCARD) @mock.patch("airflow.providers.google.cloud.transfers.sftp_to_gcs.GCSHook") @mock.patch("airflow.providers.google.cloud.transfers.sftp_to_gcs.SFTPHook") @@ -197,10 +191,7 @@ def test_execute_move_with_wildcard(self, sftp_hook, gcs_hook): task.execute(None) sftp_hook.return_value.delete_file.assert_has_calls( - [ - 
mock.call("main_dir/test_object3.json"), - mock.call("main_dir/sub_dir/test_object3.json"), - ] + [mock.call("main_dir/test_object3.json"), mock.call("main_dir/sub_dir/test_object3.json"),] ) @mock.patch("airflow.providers.google.cloud.transfers.sftp_to_gcs.GCSHook") diff --git a/tests/providers/google/cloud/transfers/test_sftp_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_sftp_to_gcs_system.py index 260b3748de106..d3bab2fe99198 100644 --- a/tests/providers/google/cloud/transfers/test_sftp_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_sftp_to_gcs_system.py @@ -21,7 +21,13 @@ import pytest from airflow.providers.google.cloud.example_dags.example_sftp_to_gcs import ( - BUCKET_SRC, DIR, OBJECT_SRC_1, OBJECT_SRC_2, OBJECT_SRC_3, SUBDIR, TMP_PATH, + BUCKET_SRC, + DIR, + OBJECT_SRC_1, + OBJECT_SRC_2, + OBJECT_SRC_3, + SUBDIR, + TMP_PATH, ) from tests.providers.google.cloud.utils.gcp_authenticator import GCP_GCS_KEY from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context diff --git a/tests/providers/google/cloud/transfers/test_sheets_to_gcs.py b/tests/providers/google/cloud/transfers/test_sheets_to_gcs.py index 59dbdac4526ad..705613393ab30 100644 --- a/tests/providers/google/cloud/transfers/test_sheets_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_sheets_to_gcs.py @@ -41,9 +41,7 @@ def test_upload_data(self, mock_tempfile, mock_writer): mock_tempfile.return_value.__enter__.return_value.name = filename mock_sheet_hook = mock.MagicMock() - mock_sheet_hook.get_spreadsheet.return_value = { - "properties": {"title": SHEET_TITLE} - } + mock_sheet_hook.get_spreadsheet.return_value = {"properties": {"title": SHEET_TITLE}} expected_dest_file = f"{PATH}/{SHEET_TITLE}_{RANGE}.csv" mock_gcs_hook = mock.MagicMock() @@ -57,10 +55,7 @@ def test_upload_data(self, mock_tempfile, mock_writer): ) result = op._upload_data( - gcs_hook=mock_gcs_hook, - hook=mock_sheet_hook, - sheet_range=RANGE, - sheet_values=VALUES, + gcs_hook=mock_gcs_hook, hook=mock_sheet_hook, sheet_range=RANGE, sheet_values=VALUES, ) # Test writing to file @@ -79,9 +74,7 @@ def test_upload_data(self, mock_tempfile, mock_writer): @mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GCSHook") @mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GSheetsHook") - @mock.patch( - "airflow.providers.google.cloud.transfers.sheets_to_gcs.GoogleSheetsToGCSOperator.xcom_push" - ) + @mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GoogleSheetsToGCSOperator.xcom_push") @mock.patch( "airflow.providers.google.cloud.transfers.sheets_to_gcs.GoogleSheetsToGCSOperator._upload_data" ) @@ -103,14 +96,10 @@ def test_execute(self, mock_upload_data, mock_xcom, mock_sheet_hook, mock_gcs_ho op.execute(context) mock_sheet_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, impersonation_chain=IMPERSONATION_CHAIN, ) mock_gcs_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, impersonation_chain=IMPERSONATION_CHAIN, ) mock_sheet_hook.return_value.get_sheet_titles.assert_called_once_with( @@ -120,10 +109,7 @@ def test_execute(self, mock_upload_data, mock_xcom, mock_sheet_hook, mock_gcs_ho calls = [mock.call(spreadsheet_id=SPREADSHEET_ID, range_=r) for r in RANGES] 
mock_sheet_hook.return_value.get_values.has_calls(calls) - calls = [ - mock.call(mock_gcs_hook, mock_sheet_hook, r, v) - for r, v in zip(RANGES, data) - ] + calls = [mock.call(mock_gcs_hook, mock_sheet_hook, r, v) for r, v in zip(RANGES, data)] mock_upload_data.has_calls(calls) mock_xcom.assert_called_once_with(context, "destination_objects", [PATH, PATH]) diff --git a/tests/providers/google/cloud/transfers/test_sheets_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_sheets_to_gcs_system.py index e80209b20eaf6..529b26f078ca2 100644 --- a/tests/providers/google/cloud/transfers/test_sheets_to_gcs_system.py +++ b/tests/providers/google/cloud/transfers/test_sheets_to_gcs_system.py @@ -25,7 +25,6 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_GCS_KEY) class GoogleSheetsToGCSExampleDagsSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_GCS_KEY) def setUp(self): super().setUp() diff --git a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py index b55b6f8ac6e75..c8e654492b857 100644 --- a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py @@ -29,26 +29,31 @@ BUCKET = "TEST-BUCKET-1" FILENAME = "test_results.csv" TASK_ID = "TEST_TASK_ID" -SCHEMA = [{"name": "column_a", "type": "3"}, - {"name": "column_b", "type": "253"}, - {"name": "column_c", "type": "10"}] +SCHEMA = [ + {"name": "column_a", "type": "3"}, + {"name": "column_b", "type": "253"}, + {"name": "column_c", "type": "10"}, +] COLUMNS = ["column_a", "column_b", "column_c"] ROW = ["convert_type_return_value", "convert_type_return_value", "convert_type_return_value"] TMP_FILE_NAME = "temp-file" -INPUT_DATA = [["101", "school", "2015-01-01"], - ["102", "business", "2017-05-24"], - ["103", "non-profit", "2018-10-01"]] -OUTPUT_DATA = json.dumps({ - "column_a": "convert_type_return_value", - "column_b": "convert_type_return_value", - "column_c": "convert_type_return_value" -}).encode("utf-8") +INPUT_DATA = [ + ["101", "school", "2015-01-01"], + ["102", "business", "2017-05-24"], + ["103", "non-profit", "2018-10-01"], +] +OUTPUT_DATA = json.dumps( + { + "column_a": "convert_type_return_value", + "column_b": "convert_type_return_value", + "column_c": "convert_type_return_value", + } +).encode("utf-8") SCHEMA_FILE = "schema_file.json" APP_JSON = "application/json" class DummySQLToGCSOperator(BaseSQLToGCSOperator): - def field_to_bigquery(self, field): pass @@ -60,20 +65,15 @@ def query(self): class TestBaseSQLToGCSOperator(unittest.TestCase): - @mock.patch("airflow.providers.google.cloud.transfers.sql_to_gcs.NamedTemporaryFile") @mock.patch.object(csv.writer, "writerow") @mock.patch.object(GCSHook, "upload") @mock.patch.object(DummySQLToGCSOperator, "query") @mock.patch.object(DummySQLToGCSOperator, "field_to_bigquery") @mock.patch.object(DummySQLToGCSOperator, "convert_type") - def test_exec(self, - mock_convert_type, - mock_field_to_bigquery, - mock_query, - mock_upload, - mock_writerow, - mock_tempfile): + def test_exec( + self, mock_convert_type, mock_field_to_bigquery, mock_query, mock_upload, mock_writerow, mock_tempfile + ): cursor_mock = Mock() cursor_mock.description = [("column_a", "3"), ("column_b", "253"), ("column_c", "10")] cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA)) @@ -99,23 +99,32 @@ def test_exec(self, mock_tempfile.return_value = mock_file - operator = DummySQLToGCSOperator(sql=SQL, - bucket=BUCKET, - filename=FILENAME, - 
task_id=TASK_ID, - schema_filename=SCHEMA_FILE, - approx_max_file_size_bytes=1, - export_format="csv", - gzip=True, - schema=SCHEMA, - google_cloud_storage_conn_id='google_cloud_default') + operator = DummySQLToGCSOperator( + sql=SQL, + bucket=BUCKET, + filename=FILENAME, + task_id=TASK_ID, + schema_filename=SCHEMA_FILE, + approx_max_file_size_bytes=1, + export_format="csv", + gzip=True, + schema=SCHEMA, + google_cloud_storage_conn_id='google_cloud_default', + ) operator.execute(context=dict()) mock_query.assert_called_once() - mock_writerow.assert_has_calls([mock.call(COLUMNS), mock.call(ROW), - mock.call(COLUMNS), mock.call(ROW), - mock.call(COLUMNS), mock.call(ROW), - mock.call(COLUMNS)]) + mock_writerow.assert_has_calls( + [ + mock.call(COLUMNS), + mock.call(ROW), + mock.call(COLUMNS), + mock.call(ROW), + mock.call(COLUMNS), + mock.call(ROW), + mock.call(COLUMNS), + ] + ) mock_flush.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call(), mock.call()]) csv_call = mock.call(BUCKET, FILENAME, TMP_FILE_NAME, mime_type='text/csv', gzip=True) json_call = mock.call(BUCKET, SCHEMA_FILE, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False) @@ -131,21 +140,22 @@ def test_exec(self, cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA)) - operator = DummySQLToGCSOperator(sql=SQL, - bucket=BUCKET, - filename=FILENAME, - task_id=TASK_ID, - export_format="json", - schema=SCHEMA) + operator = DummySQLToGCSOperator( + sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="json", schema=SCHEMA + ) operator.execute(context=dict()) mock_query.assert_called_once() - mock_write.assert_has_calls([mock.call(OUTPUT_DATA), - mock.call(b"\n"), - mock.call(OUTPUT_DATA), - mock.call(b"\n"), - mock.call(OUTPUT_DATA), - mock.call(b"\n")]) + mock_write.assert_has_calls( + [ + mock.call(OUTPUT_DATA), + mock.call(b"\n"), + mock.call(OUTPUT_DATA), + mock.call(b"\n"), + mock.call(OUTPUT_DATA), + mock.call(b"\n"), + ] + ) mock_flush.assert_called_once() mock_upload.assert_called_once_with(BUCKET, FILENAME, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False) mock_close.assert_called_once() diff --git a/tests/providers/google/cloud/utils/base_gcp_mock.py b/tests/providers/google/cloud/utils/base_gcp_mock.py index d20696b2f362f..99b97c51db5b3 100644 --- a/tests/providers/google/cloud/utils/base_gcp_mock.py +++ b/tests/providers/google/cloud/utils/base_gcp_mock.py @@ -24,14 +24,9 @@ def mock_base_gcp_hook_default_project_id( - self, - gcp_conn_id='google_cloud_default', - delegate_to=None, - impersonation_chain=None, + self, gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None, ): - self.extras = { - 'extra__google_cloud_platform__project': GCP_PROJECT_ID_HOOK_UNIT_TEST - } + self.extras = {'extra__google_cloud_platform__project': GCP_PROJECT_ID_HOOK_UNIT_TEST} self._conn = gcp_conn_id self.delegate_to = delegate_to self.impersonation_chain = impersonation_chain @@ -42,10 +37,7 @@ def mock_base_gcp_hook_default_project_id( def mock_base_gcp_hook_no_default_project_id( - self, - gcp_conn_id='google_cloud_default', - delegate_to=None, - impersonation_chain=None, + self, gcp_conn_id='google_cloud_default', delegate_to=None, impersonation_chain=None, ): self.extras = {} self._conn = gcp_conn_id @@ -58,14 +50,10 @@ def mock_base_gcp_hook_no_default_project_id( GCP_CONNECTION_WITH_PROJECT_ID = Connection( - extra=json.dumps({ - 'extra__google_cloud_platform__project': GCP_PROJECT_ID_HOOK_UNIT_TEST - }) + extra=json.dumps({'extra__google_cloud_platform__project': 
GCP_PROJECT_ID_HOOK_UNIT_TEST}) ) -GCP_CONNECTION_WITHOUT_PROJECT_ID = Connection( - extra=json.dumps({}) -) +GCP_CONNECTION_WITHOUT_PROJECT_ID = Connection(extra=json.dumps({})) def get_open_mock(): diff --git a/tests/providers/google/cloud/utils/gcp_authenticator.py b/tests/providers/google/cloud/utils/gcp_authenticator.py index feb03b2923aa4..02a6baac8c780 100644 --- a/tests/providers/google/cloud/utils/gcp_authenticator.py +++ b/tests/providers/google/cloud/utils/gcp_authenticator.py @@ -23,6 +23,7 @@ from airflow import settings from airflow.exceptions import AirflowException from airflow.models import Connection + # Please keep these variables in alphabetical order. from tests.test_utils import AIRFLOW_MAIN_FOLDER from tests.utils.logging_command_executor import LoggingCommandExecutor diff --git a/tests/providers/google/cloud/utils/test_credentials_provider.py b/tests/providers/google/cloud/utils/test_credentials_provider.py index 100baaf8717a9..efbffb5f99ce3 100644 --- a/tests/providers/google/cloud/utils/test_credentials_provider.py +++ b/tests/providers/google/cloud/utils/test_credentials_provider.py @@ -28,9 +28,16 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.utils.credentials_provider import ( - _DEFAULT_SCOPES, AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT, _get_project_id_from_service_account_email, - _get_scopes, _get_target_principal_and_delegates, build_gcp_conn, get_credentials_and_project_id, - provide_gcp_conn_and_credentials, provide_gcp_connection, provide_gcp_credentials, + _DEFAULT_SCOPES, + AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT, + _get_project_id_from_service_account_email, + _get_scopes, + _get_target_principal_and_delegates, + build_gcp_conn, + get_credentials_and_project_id, + provide_gcp_conn_and_credentials, + provide_gcp_connection, + provide_gcp_credentials, ) ENV_VALUE = "test_env" @@ -47,24 +54,19 @@ class TestHelper(unittest.TestCase): def test_build_gcp_conn_path(self): value = "test" conn = build_gcp_conn(key_file_path=value) - self.assertEqual( - "google-cloud-platform://?extra__google_cloud_platform__key_path=test", conn - ) + self.assertEqual("google-cloud-platform://?extra__google_cloud_platform__key_path=test", conn) def test_build_gcp_conn_scopes(self): value = ["test", "test2"] conn = build_gcp_conn(scopes=value) self.assertEqual( - "google-cloud-platform://?extra__google_cloud_platform__scope=test%2Ctest2", - conn, + "google-cloud-platform://?extra__google_cloud_platform__scope=test%2Ctest2", conn, ) def test_build_gcp_conn_project(self): value = "test" conn = build_gcp_conn(project_id=value) - self.assertEqual( - "google-cloud-platform://?extra__google_cloud_platform__projects=test", conn - ) + self.assertEqual("google-cloud-platform://?extra__google_cloud_platform__projects=test", conn) class TestProvideGcpCredentials(unittest.TestCase): @@ -101,19 +103,14 @@ def test_provide_gcp_connection(self, mock_builder): scopes = ["scopes"] project_id = "project_id" with provide_gcp_connection(path, scopes, project_id): - mock_builder.assert_called_once_with( - key_file_path=path, scopes=scopes, project_id=project_id - ) - self.assertEqual( - os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT], TEMP_VARIABLE - ) + mock_builder.assert_called_once_with(key_file_path=path, scopes=scopes, project_id=project_id) + self.assertEqual(os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT], TEMP_VARIABLE) self.assertEqual(os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT], ENV_VALUE) class TestProvideGcpConnAndCredentials(unittest.TestCase): 
@mock.patch.dict( - os.environ, - {AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT: ENV_VALUE, CREDENTIALS: ENV_VALUE}, + os.environ, {AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT: ENV_VALUE, CREDENTIALS: ENV_VALUE}, ) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.build_gcp_conn") def test_provide_gcp_conn_and_credentials(self, mock_builder): @@ -122,12 +119,8 @@ def test_provide_gcp_conn_and_credentials(self, mock_builder): scopes = ["scopes"] project_id = "project_id" with provide_gcp_conn_and_credentials(path, scopes, project_id): - mock_builder.assert_called_once_with( - key_file_path=path, scopes=scopes, project_id=project_id - ) - self.assertEqual( - os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT], TEMP_VARIABLE - ) + mock_builder.assert_called_once_with(key_file_path=path, scopes=scopes, project_id=project_id) + self.assertEqual(os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT], TEMP_VARIABLE) self.assertEqual(os.environ[CREDENTIALS], path) self.assertEqual(os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT], ENV_VALUE) self.assertEqual(os.environ[CREDENTIALS], ENV_VALUE) @@ -146,11 +139,14 @@ def test_get_credentials_and_project_id_with_default_auth(self, mock_auth_defaul result = get_credentials_and_project_id() mock_auth_default.assert_called_once_with(scopes=None) self.assertEqual(("CREDENTIALS", "PROJECT_ID"), result) - self.assertEqual([ - 'INFO:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting ' - 'connection using `google.auth.default()` since no key file is defined for ' - 'hook.' - ], cm.output) + self.assertEqual( + [ + 'INFO:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting ' + 'connection using `google.auth.default()` since no key file is defined for ' + 'hook.' + ], + cm.output, + ) @mock.patch('google.auth.default') def test_get_credentials_and_project_id_with_default_auth_and_delegate(self, mock_auth_default): @@ -162,10 +158,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_delegate(self, moc mock_credentials.with_subject.assert_called_once_with("USER") self.assertEqual((mock_credentials.with_subject.return_value, self.test_project_id), result) - @parameterized.expand([ - (['scope1'], ), - (['scope1', 'scope2'], ) - ]) + @parameterized.expand([(['scope1'],), (['scope1', 'scope2'],)]) @mock.patch('google.auth.default') def test_get_credentials_and_project_id_with_default_auth_and_scopes(self, scopes, mock_auth_default): mock_credentials = mock.MagicMock() @@ -175,8 +168,9 @@ def test_get_credentials_and_project_id_with_default_auth_and_scopes(self, scope mock_auth_default.assert_called_once_with(scopes=scopes) self.assertEqual(mock_auth_default.return_value, result) - @mock.patch('airflow.providers.google.cloud.utils.credentials_provider.' - 'impersonated_credentials.Credentials') + @mock.patch( + 'airflow.providers.google.cloud.utils.credentials_provider.' 
'impersonated_credentials.Credentials' + ) @mock.patch('google.auth.default') def test_get_credentials_and_project_id_with_default_auth_and_target_principal( self, mock_auth_default, mock_impersonated_credentials @@ -184,9 +178,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_target_principal( mock_credentials = mock.MagicMock() mock_auth_default.return_value = (mock_credentials, self.test_project_id) - result = get_credentials_and_project_id( - target_principal=ACCOUNT_3_ANOTHER_PROJECT, - ) + result = get_credentials_and_project_id(target_principal=ACCOUNT_3_ANOTHER_PROJECT,) mock_auth_default.assert_called_once_with(scopes=None) mock_impersonated_credentials.assert_called_once_with( source_credentials=mock_credentials, @@ -196,8 +188,9 @@ def test_get_credentials_and_project_id_with_default_auth_and_target_principal( ) self.assertEqual((mock_impersonated_credentials.return_value, ANOTHER_PROJECT_ID), result) - @mock.patch('airflow.providers.google.cloud.utils.credentials_provider.' - 'impersonated_credentials.Credentials') + @mock.patch( + 'airflow.providers.google.cloud.utils.credentials_provider.' 'impersonated_credentials.Credentials' + ) @mock.patch('google.auth.default') def test_get_credentials_and_project_id_with_default_auth_and_scopes_and_target_principal( self, mock_auth_default, mock_impersonated_credentials @@ -206,8 +199,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_scopes_and_target_ mock_auth_default.return_value = (mock_credentials, self.test_project_id) result = get_credentials_and_project_id( - scopes=['scope1', 'scope2'], - target_principal=ACCOUNT_1_SAME_PROJECT, + scopes=['scope1', 'scope2'], target_principal=ACCOUNT_1_SAME_PROJECT, ) mock_auth_default.assert_called_once_with(scopes=['scope1', 'scope2']) mock_impersonated_credentials.assert_called_once_with( @@ -218,8 +210,9 @@ def test_get_credentials_and_project_id_with_default_auth_and_scopes_and_target_ ) self.assertEqual((mock_impersonated_credentials.return_value, self.test_project_id), result) - @mock.patch('airflow.providers.google.cloud.utils.credentials_provider.' - 'impersonated_credentials.Credentials') + @mock.patch( + 'airflow.providers.google.cloud.utils.credentials_provider.' 
'impersonated_credentials.Credentials' + ) @mock.patch('google.auth.default') def test_get_credentials_and_project_id_with_default_auth_and_target_principal_and_delegates( self, mock_auth_default, mock_impersonated_credentials @@ -240,60 +233,51 @@ def test_get_credentials_and_project_id_with_default_auth_and_target_principal_a ) self.assertEqual((mock_impersonated_credentials.return_value, ANOTHER_PROJECT_ID), result) - @mock.patch( - 'google.oauth2.service_account.Credentials.from_service_account_file', - ) + @mock.patch('google.oauth2.service_account.Credentials.from_service_account_file',) def test_get_credentials_and_project_id_with_service_account_file(self, mock_from_service_account_file): mock_from_service_account_file.return_value.project_id = self.test_project_id with self.assertLogs(level="DEBUG") as cm: result = get_credentials_and_project_id(key_path=self.test_key_file) mock_from_service_account_file.assert_called_once_with(self.test_key_file, scopes=None) self.assertEqual((mock_from_service_account_file.return_value, self.test_project_id), result) - self.assertEqual([ - 'DEBUG:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting ' - 'connection using JSON key file KEY_PATH.json' - ], cm.output) - - @parameterized.expand([ - ("p12", "path/to/file.p12"), - ("unknown", "incorrect_file.ext") - ]) + self.assertEqual( + [ + 'DEBUG:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting ' + 'connection using JSON key file KEY_PATH.json' + ], + cm.output, + ) + + @parameterized.expand([("p12", "path/to/file.p12"), ("unknown", "incorrect_file.ext")]) def test_get_credentials_and_project_id_with_service_account_file_and_non_valid_key(self, _, file): with self.assertRaises(AirflowException): get_credentials_and_project_id(key_path=file) - @mock.patch( - 'google.oauth2.service_account.Credentials.from_service_account_info', - ) + @mock.patch('google.oauth2.service_account.Credentials.from_service_account_info',) def test_get_credentials_and_project_id_with_service_account_info(self, mock_from_service_account_info): mock_from_service_account_info.return_value.project_id = self.test_project_id - service_account = { - 'private_key': "PRIVATE_KEY" - } + service_account = {'private_key': "PRIVATE_KEY"} with self.assertLogs(level="DEBUG") as cm: result = get_credentials_and_project_id(keyfile_dict=service_account) mock_from_service_account_info.assert_called_once_with(service_account, scopes=None) self.assertEqual((mock_from_service_account_info.return_value, self.test_project_id), result) - self.assertEqual([ - 'DEBUG:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting ' - 'connection using JSON Dict' - ], cm.output) + self.assertEqual( + [ + 'DEBUG:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting ' + 'connection using JSON Dict' + ], + cm.output, + ) - def test_get_credentials_and_project_id_with_mutually_exclusive_configuration( - self, - ): - with self.assertRaisesRegex(AirflowException, re.escape( - 'The `keyfile_dict` and `key_path` fields are mutually exclusive.' 
- )): + def test_get_credentials_and_project_id_with_mutually_exclusive_configuration(self,): + with self.assertRaisesRegex( + AirflowException, re.escape('The `keyfile_dict` and `key_path` fields are mutually exclusive.') + ): get_credentials_and_project_id(key_path='KEY.json', keyfile_dict={'private_key': 'PRIVATE_KEY'}) @mock.patch("google.auth.default", return_value=("CREDENTIALS", "PROJECT_ID")) - @mock.patch( - 'google.oauth2.service_account.Credentials.from_service_account_info', - ) - @mock.patch( - 'google.oauth2.service_account.Credentials.from_service_account_file', - ) + @mock.patch('google.oauth2.service_account.Credentials.from_service_account_info',) + @mock.patch('google.oauth2.service_account.Credentials.from_service_account_file',) def test_disable_logging(self, mock_default, mock_info, mock_file): # assert not logs with self.assertRaises(AssertionError), self.assertLogs(level="DEBUG"): @@ -302,59 +286,55 @@ def test_disable_logging(self, mock_default, mock_info, mock_file): # assert not logs with self.assertRaises(AssertionError), self.assertLogs(level="DEBUG"): get_credentials_and_project_id( - keyfile_dict={'private_key': 'PRIVATE_KEY'}, - disable_logging=True, + keyfile_dict={'private_key': 'PRIVATE_KEY'}, disable_logging=True, ) # assert not logs with self.assertRaises(AssertionError), self.assertLogs(level="DEBUG"): get_credentials_and_project_id( - key_path='KEY.json', - disable_logging=True, + key_path='KEY.json', disable_logging=True, ) class TestGetScopes(unittest.TestCase): - def test_get_scopes_with_default(self): self.assertEqual(_get_scopes(), _DEFAULT_SCOPES) - @parameterized.expand([ - ('single_scope', 'scope1', ['scope1']), - ('multiple_scopes', 'scope1,scope2', ['scope1', 'scope2']), - ]) + @parameterized.expand( + [('single_scope', 'scope1', ['scope1']), ('multiple_scopes', 'scope1,scope2', ['scope1', 'scope2']),] + ) def test_get_scopes_with_input(self, _, scopes_str, scopes): self.assertEqual(_get_scopes(scopes_str), scopes) class TestGetTargetPrincipalAndDelegates(unittest.TestCase): - def test_get_target_principal_and_delegates_no_argument(self): self.assertEqual(_get_target_principal_and_delegates(), (None, None)) - @parameterized.expand([ - ('string', ACCOUNT_1_SAME_PROJECT, (ACCOUNT_1_SAME_PROJECT, None)), - ('empty_list', [], (None, None)), - ('single_element_list', [ACCOUNT_1_SAME_PROJECT], (ACCOUNT_1_SAME_PROJECT, [])), - ('multiple_elements_list', - [ACCOUNT_1_SAME_PROJECT, ACCOUNT_2_SAME_PROJECT, ACCOUNT_3_ANOTHER_PROJECT], - (ACCOUNT_3_ANOTHER_PROJECT, [ACCOUNT_1_SAME_PROJECT, ACCOUNT_2_SAME_PROJECT])), - ]) + @parameterized.expand( + [ + ('string', ACCOUNT_1_SAME_PROJECT, (ACCOUNT_1_SAME_PROJECT, None)), + ('empty_list', [], (None, None)), + ('single_element_list', [ACCOUNT_1_SAME_PROJECT], (ACCOUNT_1_SAME_PROJECT, [])), + ( + 'multiple_elements_list', + [ACCOUNT_1_SAME_PROJECT, ACCOUNT_2_SAME_PROJECT, ACCOUNT_3_ANOTHER_PROJECT], + (ACCOUNT_3_ANOTHER_PROJECT, [ACCOUNT_1_SAME_PROJECT, ACCOUNT_2_SAME_PROJECT]), + ), + ] + ) def test_get_target_principal_and_delegates_with_input( self, _, impersonation_chain, target_principal_and_delegates ): self.assertEqual( - _get_target_principal_and_delegates(impersonation_chain), - target_principal_and_delegates + _get_target_principal_and_delegates(impersonation_chain), target_principal_and_delegates ) class TestGetProjectIdFromServiceAccountEmail(unittest.TestCase): - def test_get_project_id_from_service_account_email(self,): self.assertEqual( - 
_get_project_id_from_service_account_email(ACCOUNT_3_ANOTHER_PROJECT), - ANOTHER_PROJECT_ID, + _get_project_id_from_service_account_email(ACCOUNT_3_ANOTHER_PROJECT), ANOTHER_PROJECT_ID, ) def test_get_project_id_from_service_account_email_wrong_input(self): diff --git a/tests/providers/google/cloud/utils/test_field_sanitizer.py b/tests/providers/google/cloud/utils/test_field_sanitizer.py index 91cac96355dc6..b3868c8df5a8c 100644 --- a/tests/providers/google/cloud/utils/test_field_sanitizer.py +++ b/tests/providers/google/cloud/utils/test_field_sanitizer.py @@ -91,21 +91,21 @@ def test_sanitize_should_remove_for_multiple_fields_from_root_level(self): self.assertEqual({}, body) def test_sanitize_should_remove_all_fields_in_a_list_value(self): - body = {"fields": [ - {"kind": "compute#instanceTemplate", "name": "instance"}, - {"kind": "compute#instanceTemplate1", "name": "instance1"}, - {"kind": "compute#instanceTemplate2", "name": "instance2"}, - ]} + body = { + "fields": [ + {"kind": "compute#instanceTemplate", "name": "instance"}, + {"kind": "compute#instanceTemplate1", "name": "instance1"}, + {"kind": "compute#instanceTemplate2", "name": "instance2"}, + ] + } fields_to_sanitize = ["fields.kind"] sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize) sanitizer.sanitize(body) - self.assertEqual({"fields": [ - {"name": "instance"}, - {"name": "instance1"}, - {"name": "instance2"}, - ]}, body) + self.assertEqual( + {"fields": [{"name": "instance"}, {"name": "instance1"}, {"name": "instance2"},]}, body + ) def test_sanitize_should_remove_all_fields_in_any_nested_body(self): fields_to_sanitize = [ @@ -130,37 +130,27 @@ def test_sanitize_should_remove_all_fields_in_any_nested_body(self): "kind": "compute#attachedDisk", "type": "PERSISTENT", "mode": "READ_WRITE", - } + }, ], - "metadata": { - "kind": "compute#metadata", - "fingerprint": "GDPUYxlwHe4=" - }, - } + "metadata": {"kind": "compute#metadata", "fingerprint": "GDPUYxlwHe4="}, + }, } sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize) sanitizer.sanitize(body) - self.assertEqual({ - "name": "instance", - "properties": { - "disks": [ - { - "name": "a", - "type": "PERSISTENT", - "mode": "READ_WRITE" - }, - { - "name": "b", - "type": "PERSISTENT", - "mode": "READ_WRITE" - } - ], - "metadata": { - "fingerprint": "GDPUYxlwHe4=" - } - } - }, body) + self.assertEqual( + { + "name": "instance", + "properties": { + "disks": [ + {"name": "a", "type": "PERSISTENT", "mode": "READ_WRITE"}, + {"name": "b", "type": "PERSISTENT", "mode": "READ_WRITE"}, + ], + "metadata": {"fingerprint": "GDPUYxlwHe4="}, + }, + }, + body, + ) def test_sanitize_should_not_fail_if_specification_has_none_value(self): fields_to_sanitize = [ @@ -169,23 +159,12 @@ def test_sanitize_should_not_fail_if_specification_has_none_value(self): "properties.metadata.kind", ] - body = { - "kind": "compute#instanceTemplate", - "name": "instance", - "properties": { - "disks": None - } - } + body = {"kind": "compute#instanceTemplate", "name": "instance", "properties": {"disks": None}} sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize) sanitizer.sanitize(body) - self.assertEqual({ - "name": "instance", - "properties": { - "disks": None - } - }, body) + self.assertEqual({"name": "instance", "properties": {"disks": None}}, body) def test_sanitize_should_not_fail_if_no_specification_matches(self): fields_to_sanitize = [ @@ -193,22 +172,12 @@ def test_sanitize_should_not_fail_if_no_specification_matches(self): "properties.metadata.kind2", ] - body = { - "name": "instance", - "properties": { - 
"disks": None - } - } + body = {"name": "instance", "properties": {"disks": None}} sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize) sanitizer.sanitize(body) - self.assertEqual({ - "name": "instance", - "properties": { - "disks": None - } - }, body) + self.assertEqual({"name": "instance", "properties": {"disks": None}}, body) def test_sanitize_should_not_fail_if_type_in_body_do_not_match_with_specification(self): fields_to_sanitize = [ @@ -216,19 +185,9 @@ def test_sanitize_should_not_fail_if_type_in_body_do_not_match_with_specificatio "properties.metadata.kind2", ] - body = { - "name": "instance", - "properties": { - "disks": 1 - } - } + body = {"name": "instance", "properties": {"disks": 1}} sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize) sanitizer.sanitize(body) - self.assertEqual({ - "name": "instance", - "properties": { - "disks": 1 - } - }, body) + self.assertEqual({"name": "instance", "properties": {"disks": 1}}, body) diff --git a/tests/providers/google/cloud/utils/test_field_validator.py b/tests/providers/google/cloud/utils/test_field_validator.py index c57ae7df43e98..8f502bd535315 100644 --- a/tests/providers/google/cloud/utils/test_field_validator.py +++ b/tests/providers/google/cloud/utils/test_field_validator.py @@ -18,7 +18,9 @@ import unittest from airflow.providers.google.cloud.utils.field_validator import ( - GcpBodyFieldValidator, GcpFieldValidationException, GcpValidationSpecificationException, + GcpBodyFieldValidator, + GcpFieldValidationException, + GcpValidationSpecificationException, ) @@ -179,9 +181,12 @@ def test_validate_should_allow_type_and_optional_in_a_spec(self): def test_validate_should_fail_if_union_field_is_not_found(self): specification = [ - dict(name="an_union", type="union", optional=False, fields=[ - dict(name="variant_1", regexp=r'^.+$', optional=False, allow_empty=False), - ]) + dict( + name="an_union", + type="union", + optional=False, + fields=[dict(name="variant_1", regexp=r'^.+$', optional=False, allow_empty=False),], + ) ] body = {} @@ -189,9 +194,7 @@ def test_validate_should_fail_if_union_field_is_not_found(self): self.assertIsNone(validator.validate(body)) def test_validate_should_fail_if_there_is_no_nested_field_for_union(self): - specification = [ - dict(name="an_union", type="union", optional=False, fields=[]) - ] + specification = [dict(name="an_union", type="union", optional=False, fields=[])] body = {} validator = GcpBodyFieldValidator(specification, 'v1') @@ -201,9 +204,7 @@ def test_validate_should_fail_if_there_is_no_nested_field_for_union(self): def test_validate_should_interpret_union_with_one_field(self): specification = [ - dict(name="an_union", type="union", fields=[ - dict(name="variant_1", regexp=r'^.+$'), - ]) + dict(name="an_union", type="union", fields=[dict(name="variant_1", regexp=r'^.+$'),]) ] body = {"variant_1": "abc", "variant_2": "def"} @@ -212,10 +213,11 @@ def test_validate_should_interpret_union_with_one_field(self): def test_validate_should_fail_if_both_field_of_union_is_present(self): specification = [ - dict(name="an_union", type="union", fields=[ - dict(name="variant_1", regexp=r'^.+$'), - dict(name="variant_2", regexp=r'^.+$'), - ]) + dict( + name="an_union", + type="union", + fields=[dict(name="variant_1", regexp=r'^.+$'), dict(name="variant_2", regexp=r'^.+$'),], + ) ] body = {"variant_1": "abc", "variant_2": "def"} @@ -225,9 +227,7 @@ def test_validate_should_fail_if_both_field_of_union_is_present(self): def test_validate_should_validate_when_value_matches_regex(self): specification = [ - 
dict(name="an_union", type="union", fields=[ - dict(name="variant_1", regexp=r'[^a-z]'), - ]) + dict(name="an_union", type="union", fields=[dict(name="variant_1", regexp=r'[^a-z]'),]) ] body = {"variant_1": "12"} @@ -236,9 +236,7 @@ def test_validate_should_validate_when_value_matches_regex(self): def test_validate_should_fail_when_value_does_not_match_regex(self): specification = [ - dict(name="an_union", type="union", fields=[ - dict(name="variant_1", regexp=r'[^a-z]'), - ]) + dict(name="an_union", type="union", fields=[dict(name="variant_1", regexp=r'[^a-z]'),]) ] body = {"variant_1": "abc"} @@ -251,9 +249,7 @@ def _int_equal_to_zero(value): if int(value) != 0: raise GcpFieldValidationException("The available memory has to be equal to 0") - specification = [ - dict(name="availableMemoryMb", custom_validation=_int_equal_to_zero) - ] + specification = [dict(name="availableMemoryMb", custom_validation=_int_equal_to_zero)] body = {"availableMemoryMb": 1} validator = GcpBodyFieldValidator(specification, 'v1') @@ -265,9 +261,7 @@ def _int_equal_to_zero(value): if int(value) != 0: raise GcpFieldValidationException("The available memory has to be equal to 0") - specification = [ - dict(name="availableMemoryMb", custom_validation=_int_equal_to_zero) - ] + specification = [dict(name="availableMemoryMb", custom_validation=_int_equal_to_zero)] body = {"availableMemoryMb": 0} validator = GcpBodyFieldValidator(specification, 'v1') @@ -278,14 +272,16 @@ def test_validate_should_validate_group_of_specs(self): dict(name="name", allow_empty=False), dict(name="description", allow_empty=False, optional=True), dict(name="labels", optional=True, type="dict"), - dict(name="an_union", type="union", fields=[ - dict(name="variant_1", regexp=r'^.+$'), - dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'), - dict(name="variant_3", type="dict", fields=[ - dict(name="url", regexp=r'^.+$') - ]), - dict(name="variant_4") - ]), + dict( + name="an_union", + type="union", + fields=[ + dict(name="variant_1", regexp=r'^.+$'), + dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'), + dict(name="variant_3", type="dict", fields=[dict(name="url", regexp=r'^.+$')]), + dict(name="variant_4"), + ], + ), ] body = {"variant_1": "abc", "name": "bigquery"} diff --git a/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py b/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py index 326c14ecdd190..18981e13151df 100644 --- a/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py +++ b/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py @@ -35,9 +35,11 @@ TASK_PREFIX_SUMMARY = TASK_PREFIX + "-summary" TASK_PREFIX_VALIDATION = TASK_PREFIX + "-validation" DATA_FORMAT = "TEXT" -INPUT_PATHS = ["gs://path/to/input/file.json", - "gs://path/to/input/file2.json", - "gs://path/to/input/file3.json"] +INPUT_PATHS = [ + "gs://path/to/input/file.json", + "gs://path/to/input/file2.json", + "gs://path/to/input/file3.json", +] PREDICTION_PATH = "gs://path/to/output/predictions.json" BATCH_PREDICTION_JOB_ID = "test-batch-prediction-job-id" PROJECT_ID = "test-project-id" @@ -45,7 +47,7 @@ DATAFLOW_OPTIONS = { "project": "my-gcp-project", "zone": "us-central1-f", - "stagingLocation": "gs://bucket/tmp/dataflow/staging/" + "stagingLocation": "gs://bucket/tmp/dataflow/staging/", } MODEL_URI = "gs://path/to/model/model" MODEL_NAME = "test-model-name" @@ -55,7 +57,7 @@ "region": REGION, "model_name": MODEL_NAME, "version_name": VERSION_NAME, - "dataflow_default_options": DATAFLOW_OPTIONS 
+ "dataflow_default_options": DATAFLOW_OPTIONS, } TEST_DAG = DAG(dag_id="test-dag-id", start_date=datetime(2000, 1, 1), default_args=DAG_DEFAULT_ARGS) @@ -92,18 +94,19 @@ class TestMlengineOperatorUtils(unittest.TestCase): @mock.patch.object(PythonOperator, "set_upstream") @mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream") def test_create_evaluate_ops(self, mock_dataflow, mock_python): - result = create_evaluate_ops(task_prefix=TASK_PREFIX, - data_format=DATA_FORMAT, - input_paths=INPUT_PATHS, - prediction_path=PREDICTION_PATH, - metric_fn_and_keys=get_metric_fn_and_keys(), - validate_fn=validate_err_and_count, - batch_prediction_job_id=BATCH_PREDICTION_JOB_ID, - project_id=PROJECT_ID, - region=REGION, - dataflow_options=DATAFLOW_OPTIONS, - model_uri=MODEL_URI - ) + result = create_evaluate_ops( + task_prefix=TASK_PREFIX, + data_format=DATA_FORMAT, + input_paths=INPUT_PATHS, + prediction_path=PREDICTION_PATH, + metric_fn_and_keys=get_metric_fn_and_keys(), + validate_fn=validate_err_and_count, + batch_prediction_job_id=BATCH_PREDICTION_JOB_ID, + project_id=PROJECT_ID, + region=REGION, + dataflow_options=DATAFLOW_OPTIONS, + model_uri=MODEL_URI, + ) evaluate_prediction, evaluate_summary, evaluate_validation = result @@ -131,19 +134,20 @@ def test_create_evaluate_ops(self, mock_dataflow, mock_python): @mock.patch.object(PythonOperator, "set_upstream") @mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream") def test_create_evaluate_ops_model_and_version_name(self, mock_dataflow, mock_python): - result = create_evaluate_ops(task_prefix=TASK_PREFIX, - data_format=DATA_FORMAT, - input_paths=INPUT_PATHS, - prediction_path=PREDICTION_PATH, - metric_fn_and_keys=get_metric_fn_and_keys(), - validate_fn=validate_err_and_count, - batch_prediction_job_id=BATCH_PREDICTION_JOB_ID, - project_id=PROJECT_ID, - region=REGION, - dataflow_options=DATAFLOW_OPTIONS, - model_name=MODEL_NAME, - version_name=VERSION_NAME - ) + result = create_evaluate_ops( + task_prefix=TASK_PREFIX, + data_format=DATA_FORMAT, + input_paths=INPUT_PATHS, + prediction_path=PREDICTION_PATH, + metric_fn_and_keys=get_metric_fn_and_keys(), + validate_fn=validate_err_and_count, + batch_prediction_job_id=BATCH_PREDICTION_JOB_ID, + project_id=PROJECT_ID, + region=REGION, + dataflow_options=DATAFLOW_OPTIONS, + model_name=MODEL_NAME, + version_name=VERSION_NAME, + ) evaluate_prediction, evaluate_summary, evaluate_validation = result @@ -172,15 +176,16 @@ def test_create_evaluate_ops_model_and_version_name(self, mock_dataflow, mock_py @mock.patch.object(PythonOperator, "set_upstream") @mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream") def test_create_evaluate_ops_dag(self, mock_dataflow, mock_python): - result = create_evaluate_ops(task_prefix=TASK_PREFIX, - data_format=DATA_FORMAT, - input_paths=INPUT_PATHS, - prediction_path=PREDICTION_PATH, - metric_fn_and_keys=get_metric_fn_and_keys(), - validate_fn=validate_err_and_count, - batch_prediction_job_id=BATCH_PREDICTION_JOB_ID, - dag=TEST_DAG - ) + result = create_evaluate_ops( + task_prefix=TASK_PREFIX, + data_format=DATA_FORMAT, + input_paths=INPUT_PATHS, + prediction_path=PREDICTION_PATH, + metric_fn_and_keys=get_metric_fn_and_keys(), + validate_fn=validate_err_and_count, + batch_prediction_job_id=BATCH_PREDICTION_JOB_ID, + dag=TEST_DAG, + ) evaluate_prediction, evaluate_summary, evaluate_validation = result @@ -210,32 +215,30 @@ def test_create_evaluate_ops_dag(self, mock_dataflow, mock_python): @mock.patch.object(PythonOperator, "set_upstream") 
@mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream") def test_apply_validate_fn(self, mock_dataflow, mock_python, mock_download): - result = create_evaluate_ops(task_prefix=TASK_PREFIX, - data_format=DATA_FORMAT, - input_paths=INPUT_PATHS, - prediction_path=PREDICTION_PATH, - metric_fn_and_keys=get_metric_fn_and_keys(), - validate_fn=validate_err_and_count, - batch_prediction_job_id=BATCH_PREDICTION_JOB_ID, - project_id=PROJECT_ID, - region=REGION, - dataflow_options=DATAFLOW_OPTIONS, - model_uri=MODEL_URI - ) + result = create_evaluate_ops( + task_prefix=TASK_PREFIX, + data_format=DATA_FORMAT, + input_paths=INPUT_PATHS, + prediction_path=PREDICTION_PATH, + metric_fn_and_keys=get_metric_fn_and_keys(), + validate_fn=validate_err_and_count, + batch_prediction_job_id=BATCH_PREDICTION_JOB_ID, + project_id=PROJECT_ID, + region=REGION, + dataflow_options=DATAFLOW_OPTIONS, + model_uri=MODEL_URI, + ) _, _, evaluate_validation = result - mock_download.return_value = json.dumps({ - "err": 0.3, - "mse": 0.04, - "count": 1100 - }) + mock_download.return_value = json.dumps({"err": 0.3, "mse": 0.04, "count": 1100}) templates_dict = {"prediction_path": PREDICTION_PATH} with self.assertRaises(ValueError) as context: evaluate_validation.python_callable(templates_dict=templates_dict) - self.assertEqual("Too high err>0.2; summary={'err': 0.3, 'mse': 0.04, 'count': 1100}", - str(context.exception)) + self.assertEqual( + "Too high err>0.2; summary={'err': 0.3, 'mse': 0.04, 'count': 1100}", str(context.exception) + ) mock_download.assert_called_once_with("path", "to/output/predictions.json/prediction.summary.json") invalid_prediction_paths = ["://path/to/output/predictions.json", "gs://", ""] @@ -251,27 +254,33 @@ def test_invalid_task_prefix(self): for invalid_task_prefix_value in invalid_task_prefix_values: with self.assertRaises(AirflowException): - create_evaluate_ops(task_prefix=invalid_task_prefix_value, - data_format=DATA_FORMAT, - input_paths=INPUT_PATHS, - prediction_path=PREDICTION_PATH, - metric_fn_and_keys=get_metric_fn_and_keys(), - validate_fn=validate_err_and_count) + create_evaluate_ops( + task_prefix=invalid_task_prefix_value, + data_format=DATA_FORMAT, + input_paths=INPUT_PATHS, + prediction_path=PREDICTION_PATH, + metric_fn_and_keys=get_metric_fn_and_keys(), + validate_fn=validate_err_and_count, + ) def test_non_callable_metric_fn(self): with self.assertRaises(AirflowException): - create_evaluate_ops(task_prefix=TASK_PREFIX, - data_format=DATA_FORMAT, - input_paths=INPUT_PATHS, - prediction_path=PREDICTION_PATH, - metric_fn_and_keys=("error_and_squared_error", ['err', 'mse']), - validate_fn=validate_err_and_count) + create_evaluate_ops( + task_prefix=TASK_PREFIX, + data_format=DATA_FORMAT, + input_paths=INPUT_PATHS, + prediction_path=PREDICTION_PATH, + metric_fn_and_keys=("error_and_squared_error", ['err', 'mse']), + validate_fn=validate_err_and_count, + ) def test_non_callable_validate_fn(self): with self.assertRaises(AirflowException): - create_evaluate_ops(task_prefix=TASK_PREFIX, - data_format=DATA_FORMAT, - input_paths=INPUT_PATHS, - prediction_path=PREDICTION_PATH, - metric_fn_and_keys=get_metric_fn_and_keys(), - validate_fn="validate_err_and_count") + create_evaluate_ops( + task_prefix=TASK_PREFIX, + data_format=DATA_FORMAT, + input_paths=INPUT_PATHS, + prediction_path=PREDICTION_PATH, + metric_fn_and_keys=get_metric_fn_and_keys(), + validate_fn="validate_err_and_count", + ) diff --git a/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py 
b/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py index ae9e0e74fb7cf..78c6d5ea90598 100644 --- a/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py +++ b/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py @@ -46,34 +46,33 @@ def test_run_without_all_arguments_should_raise_exception(self): mlengine_prediction_summary.run() with self.assertRaises(SystemExit): - mlengine_prediction_summary.run([ - "--prediction_path=some/path", - ]) + mlengine_prediction_summary.run( + ["--prediction_path=some/path",] + ) with self.assertRaises(SystemExit): - mlengine_prediction_summary.run([ - "--prediction_path=some/path", - "--metric_fn_encoded=encoded_text", - ]) + mlengine_prediction_summary.run( + ["--prediction_path=some/path", "--metric_fn_encoded=encoded_text",] + ) def test_run_should_fail_for_invalid_encoded_fn(self): with self.assertRaises(binascii.Error): - mlengine_prediction_summary.run([ - "--prediction_path=some/path", - "--metric_fn_encoded=invalid_encoded_text", - "--metric_keys=a", - ]) + mlengine_prediction_summary.run( + [ + "--prediction_path=some/path", + "--metric_fn_encoded=invalid_encoded_text", + "--metric_keys=a", + ] + ) def test_run_should_fail_if_enc_fn_is_not_callable(self): non_callable_value = 1 fn_enc = base64.b64encode(dill.dumps(non_callable_value)).decode('utf-8') with self.assertRaises(ValueError): - mlengine_prediction_summary.run([ - "--prediction_path=some/path", - "--metric_fn_encoded=" + fn_enc, - "--metric_keys=a", - ]) + mlengine_prediction_summary.run( + ["--prediction_path=some/path", "--metric_fn_encoded=" + fn_enc, "--metric_keys=a",] + ) @mock.patch.object(mlengine_prediction_summary.beam.pipeline, "PipelineOptions") @mock.patch.object(mlengine_prediction_summary.beam, "Pipeline") @@ -84,11 +83,9 @@ def metric_function(): fn_enc = base64.b64encode(dill.dumps(metric_function)).decode('utf-8') - mlengine_prediction_summary.run([ - "--prediction_path=some/path", - "--metric_fn_encoded=" + fn_enc, - "--metric_keys=a", - ]) + mlengine_prediction_summary.run( + ["--prediction_path=some/path", "--metric_fn_encoded=" + fn_enc, "--metric_keys=a",] + ) pipeline_mock.assert_called_once_with([]) pipeline_obj_mock.assert_called_once() diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py index 5ea0421de513d..a9540257e13e3 100644 --- a/tests/providers/google/common/hooks/test_base_google.py +++ b/tests/providers/google/common/hooks/test_base_google.py @@ -78,23 +78,15 @@ def test_do_nothing_on_non_error(self): def test_retry_on_exception(self): message = "POST https://translation.googleapis.com/language/translate/v2: User Rate Limit Exceeded" - errors = [ - mock.MagicMock(details=mock.PropertyMock(return_value='userRateLimitExceeded')) - ] - custom_fn = NoForbiddenAfterCount( - count=5, - message=message, - errors=errors - ) + errors = [mock.MagicMock(details=mock.PropertyMock(return_value='userRateLimitExceeded'))] + custom_fn = NoForbiddenAfterCount(count=5, message=message, errors=errors) _retryable_test_with_temporary_quota_retry(custom_fn) self.assertEqual(5, custom_fn.counter) def test_raise_exception_on_non_quota_exception(self): with self.assertRaisesRegex(Forbidden, "Daily Limit Exceeded"): message = "POST https://translation.googleapis.com/language/translate/v2: Daily Limit Exceeded" - errors = [ - mock.MagicMock(details=mock.PropertyMock(return_value='dailyLimitExceeded')) - ] + errors = 
[mock.MagicMock(details=mock.PropertyMock(return_value='dailyLimitExceeded'))] _retryable_test_with_temporary_quota_retry( NoForbiddenAfterCount(5, message=message, errors=errors) @@ -116,7 +108,6 @@ def project_id(self): class TestFallbackToDefaultProjectId(unittest.TestCase): - def test_no_arguments(self): gcp_hook = FallbackToDefaultProjectIdFixtureClass(321) @@ -156,8 +147,7 @@ def test_restrict_positional_arguments(self): class TestProvideGcpCredentialFile(unittest.TestCase): def setUp(self): with mock.patch( - MODULE_NAME + '.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_default_project_id, + MODULE_NAME + '.GoogleBaseHook.__init__', new=mock_base_gcp_hook_default_project_id, ): self.instance = hook.GoogleBaseHook(gcp_conn_id="google-cloud-default") @@ -165,7 +155,7 @@ def test_provide_gcp_credential_file_decorator_key_path_and_keyfile_dict(self): key_path = '/test/key-path' self.instance.extras = { 'extra__google_cloud_platform__key_path': key_path, - 'extra__google_cloud_platform__keyfile_dict': '{"foo": "bar"}' + 'extra__google_cloud_platform__keyfile_dict': '{"foo": "bar"}', } @hook.GoogleBaseHook.provide_gcp_credential_file @@ -175,7 +165,7 @@ def assert_gcp_credential_file_in_env(_): with self.assertRaisesRegex( AirflowException, 'The `keyfile_dict` and `key_path` fields are mutually exclusive. ' - 'Please provide only one value.' + 'Please provide only one value.', ): assert_gcp_credential_file_in_env(self.instance) @@ -350,15 +340,10 @@ def test_get_credentials_and_project_id_with_default_auth(self, mock_get_creds_a self.assertEqual(('CREDENTIALS', 'PROJECT_ID'), result) @mock.patch(MODULE_NAME + '.get_credentials_and_project_id') - def test_get_credentials_and_project_id_with_service_account_file( - self, - mock_get_creds_and_proj_id - ): + def test_get_credentials_and_project_id_with_service_account_file(self, mock_get_creds_and_proj_id): mock_credentials = mock.MagicMock() mock_get_creds_and_proj_id.return_value = (mock_credentials, "PROJECT_ID") - self.instance.extras = { - 'extra__google_cloud_platform__key_path': "KEY_PATH.json" - } + self.instance.extras = {'extra__google_cloud_platform__key_path': "KEY_PATH.json"} result = self.instance._get_credentials_and_project_id() mock_get_creds_and_proj_id.assert_called_once_with( key_path='KEY_PATH.json', @@ -370,37 +355,22 @@ def test_get_credentials_and_project_id_with_service_account_file( ) self.assertEqual((mock_credentials, 'PROJECT_ID'), result) - def test_get_credentials_and_project_id_with_service_account_file_and_p12_key( - self - ): - self.instance.extras = { - 'extra__google_cloud_platform__key_path': "KEY_PATH.p12" - } + def test_get_credentials_and_project_id_with_service_account_file_and_p12_key(self): + self.instance.extras = {'extra__google_cloud_platform__key_path': "KEY_PATH.p12"} with self.assertRaises(AirflowException): self.instance._get_credentials_and_project_id() - def test_get_credentials_and_project_id_with_service_account_file_and_unknown_key( - self - ): - self.instance.extras = { - 'extra__google_cloud_platform__key_path': "KEY_PATH.unknown" - } + def test_get_credentials_and_project_id_with_service_account_file_and_unknown_key(self): + self.instance.extras = {'extra__google_cloud_platform__key_path': "KEY_PATH.unknown"} with self.assertRaises(AirflowException): self.instance._get_credentials_and_project_id() @mock.patch(MODULE_NAME + '.get_credentials_and_project_id') - def test_get_credentials_and_project_id_with_service_account_info( - self, - mock_get_creds_and_proj_id - ): + def 
test_get_credentials_and_project_id_with_service_account_info(self, mock_get_creds_and_proj_id): mock_credentials = mock.MagicMock() mock_get_creds_and_proj_id.return_value = (mock_credentials, "PROJECT_ID") - service_account = { - 'private_key': "PRIVATE_KEY" - } - self.instance.extras = { - 'extra__google_cloud_platform__keyfile_dict': json.dumps(service_account) - } + service_account = {'private_key': "PRIVATE_KEY"} + self.instance.extras = {'extra__google_cloud_platform__keyfile_dict': json.dumps(service_account)} result = self.instance._get_credentials_and_project_id() mock_get_creds_and_proj_id.assert_called_once_with( key_path=None, @@ -413,10 +383,7 @@ def test_get_credentials_and_project_id_with_service_account_info( self.assertEqual((mock_credentials, 'PROJECT_ID'), result) @mock.patch(MODULE_NAME + '.get_credentials_and_project_id') - def test_get_credentials_and_project_id_with_default_auth_and_delegate( - self, - mock_get_creds_and_proj_id - ): + def test_get_credentials_and_project_id_with_default_auth_and_delegate(self, mock_get_creds_and_proj_id): mock_credentials = mock.MagicMock() mock_get_creds_and_proj_id.return_value = (mock_credentials, "PROJECT_ID") self.instance.extras = {} @@ -440,23 +407,20 @@ def test_get_credentials_and_project_id_with_default_auth_and_unsupported_delega mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) mock_auth_default.return_value = (mock_credentials, "PROJECT_ID") - with self.assertRaisesRegex(AirflowException, re.escape( - "The `delegate_to` parameter cannot be used here as the current authentication method does not " - "support account impersonate. Please use service-account for authorization." - )): + with self.assertRaisesRegex( + AirflowException, + re.escape( + "The `delegate_to` parameter cannot be used here as the current authentication method " + "does not support account impersonate. Please use service-account for authorization." + ), + ): self.instance._get_credentials_and_project_id() - @mock.patch( - MODULE_NAME + '.get_credentials_and_project_id', - return_value=("CREDENTIALS", "PROJECT_ID") - ) + @mock.patch(MODULE_NAME + '.get_credentials_and_project_id', return_value=("CREDENTIALS", "PROJECT_ID")) def test_get_credentials_and_project_id_with_default_auth_and_overridden_project_id( - self, - mock_get_creds_and_proj_id + self, mock_get_creds_and_proj_id ): - self.instance.extras = { - 'extra__google_cloud_platform__project': "SECOND_PROJECT_ID" - } + self.instance.extras = {'extra__google_cloud_platform__project': "SECOND_PROJECT_ID"} result = self.instance._get_credentials_and_project_id() mock_get_creds_and_proj_id.assert_called_once_with( key_path=None, @@ -468,28 +432,22 @@ def test_get_credentials_and_project_id_with_default_auth_and_overridden_project ) self.assertEqual(("CREDENTIALS", 'SECOND_PROJECT_ID'), result) - def test_get_credentials_and_project_id_with_mutually_exclusive_configuration( - self, - ): + def test_get_credentials_and_project_id_with_mutually_exclusive_configuration(self,): self.instance.extras = { 'extra__google_cloud_platform__project': "PROJECT_ID", 'extra__google_cloud_platform__key_path': "KEY_PATH", 'extra__google_cloud_platform__keyfile_dict': '{"KEY": "VALUE"}', } - with self.assertRaisesRegex(AirflowException, re.escape( - 'The `keyfile_dict` and `key_path` fields are mutually exclusive.' 
- )): + with self.assertRaisesRegex( + AirflowException, re.escape('The `keyfile_dict` and `key_path` fields are mutually exclusive.') + ): self.instance._get_credentials_and_project_id() - def test_get_credentials_and_project_id_with_invalid_keyfile_dict( - self, - ): + def test_get_credentials_and_project_id_with_invalid_keyfile_dict(self,): self.instance.extras = { 'extra__google_cloud_platform__keyfile_dict': 'INVALID_DICT', } - with self.assertRaisesRegex(AirflowException, re.escape( - 'Invalid key JSON.' - )): + with self.assertRaisesRegex(AirflowException, re.escape('Invalid key JSON.')): self.instance._get_credentials_and_project_id() @unittest.skipIf(not default_creds_available, 'Default GCP credentials not available to run tests') @@ -515,16 +473,11 @@ def test_default_creds_with_scopes(self): scopes = credentials.scopes self.assertIn('https://www.googleapis.com/auth/bigquery', scopes) - self.assertIn( - 'https://www.googleapis.com/auth/devstorage.read_only', scopes) + self.assertIn('https://www.googleapis.com/auth/devstorage.read_only', scopes) - @unittest.skipIf( - not default_creds_available, - 'Default GCP credentials not available to run tests') + @unittest.skipIf(not default_creds_available, 'Default GCP credentials not available to run tests') def test_default_creds_no_scopes(self): - self.instance.extras = { - 'extra__google_cloud_platform__project': default_project - } + self.instance.extras = {'extra__google_cloud_platform__project': default_project} credentials = self.instance._get_credentials() @@ -542,8 +495,7 @@ def test_provide_gcp_credential_file_decorator_key_path(self): @hook.GoogleBaseHook.provide_gcp_credential_file def assert_gcp_credential_file_in_env(hook_instance): # pylint: disable=unused-argument - self.assertEqual(os.environ[CREDENTIALS], - key_path) + self.assertEqual(os.environ[CREDENTIALS], key_path) assert_gcp_credential_file_in_env(self.instance) @@ -552,17 +504,14 @@ def test_provide_gcp_credential_file_decorator_key_content(self, mock_file): string_file = StringIO() file_content = '{"foo": "bar"}' file_name = '/test/mock-file' - self.instance.extras = { - 'extra__google_cloud_platform__keyfile_dict': file_content - } + self.instance.extras = {'extra__google_cloud_platform__keyfile_dict': file_content} mock_file_handler = mock_file.return_value.__enter__.return_value mock_file_handler.name = file_name mock_file_handler.write = string_file.write @hook.GoogleBaseHook.provide_gcp_credential_file def assert_gcp_credential_file_in_env(hook_instance): # pylint: disable=unused-argument - self.assertEqual(os.environ[CREDENTIALS], - file_name) + self.assertEqual(os.environ[CREDENTIALS], file_name) self.assertEqual(file_content, string_file.getvalue()) assert_gcp_credential_file_in_env(self.instance) @@ -599,9 +548,7 @@ def test_num_retries_is_not_none_by_default(self, get_con_mock): Verify that if 'num_retries' in extras is not set, the default value should not be None """ - get_con_mock.return_value.extra_dejson = { - "extra__google_cloud_platform__num_retries": None - } + get_con_mock.return_value.extra_dejson = {"extra__google_cloud_platform__num_retries": None} self.assertEqual(self.instance.num_retries, 5) @mock.patch("airflow.providers.google.common.hooks.base_google.build_http") @@ -624,7 +571,7 @@ def test_authorize_assert_user_agent_is_sent(self, mock_get_credentials, mock_ht connection_type=None, headers={'user-agent': 'airflow/' + version.version}, method='GET', - redirections=5 + redirections=5, ) self.assertEqual(response, new_response) 
self.assertEqual(content, new_content) @@ -645,20 +592,21 @@ def test_authorize_assert_http_timeout_is_present(self, mock_get_credentials): http_authorized = self.instance._authorize().http self.assertNotEqual(http_authorized.timeout, None) - @parameterized.expand([ - ('string', "ACCOUNT_1", "ACCOUNT_1", None), - ('single_element_list', ["ACCOUNT_1"], "ACCOUNT_1", []), - ('multiple_elements_list', - ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"], "ACCOUNT_3", ["ACCOUNT_1", "ACCOUNT_2"]), - ]) + @parameterized.expand( + [ + ('string', "ACCOUNT_1", "ACCOUNT_1", None), + ('single_element_list', ["ACCOUNT_1"], "ACCOUNT_1", []), + ( + 'multiple_elements_list', + ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"], + "ACCOUNT_3", + ["ACCOUNT_1", "ACCOUNT_2"], + ), + ] + ) @mock.patch(MODULE_NAME + '.get_credentials_and_project_id') def test_get_credentials_and_project_id_with_impersonation_chain( - self, - _, - impersonation_chain, - target_principal, - delegates, - mock_get_creds_and_proj_id, + self, _, impersonation_chain, target_principal, delegates, mock_get_creds_and_proj_id, ): mock_credentials = mock.MagicMock() mock_get_creds_and_proj_id.return_value = (mock_credentials, PROJECT_ID) @@ -678,30 +626,27 @@ def test_get_credentials_and_project_id_with_impersonation_chain( class TestProvideAuthorizedGcloud(unittest.TestCase): def setUp(self): with mock.patch( - MODULE_NAME + '.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_default_project_id, + MODULE_NAME + '.GoogleBaseHook.__init__', new=mock_base_gcp_hook_default_project_id, ): self.instance = hook.GoogleBaseHook(gcp_conn_id="google-cloud-default") @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=mock.PropertyMock, - return_value="PROJECT_ID" + return_value="PROJECT_ID", ) @mock.patch(MODULE_NAME + '.check_output') - def test_provide_authorized_gcloud_key_path_and_keyfile_dict( - self, mock_check_output, mock_default - ): + def test_provide_authorized_gcloud_key_path_and_keyfile_dict(self, mock_check_output, mock_default): key_path = '/test/key-path' self.instance.extras = { 'extra__google_cloud_platform__key_path': key_path, - 'extra__google_cloud_platform__keyfile_dict': '{"foo": "bar"}' + 'extra__google_cloud_platform__keyfile_dict': '{"foo": "bar"}', } with self.assertRaisesRegex( AirflowException, 'The `keyfile_dict` and `key_path` fields are mutually exclusive. ' - 'Please provide only one value.' 
+ 'Please provide only one value.', ): with self.instance.provide_authorized_gcloud(): self.assertEqual(os.environ[CREDENTIALS], key_path) @@ -709,7 +654,7 @@ def test_provide_authorized_gcloud_key_path_and_keyfile_dict( @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=mock.PropertyMock, - return_value="PROJECT_ID" + return_value="PROJECT_ID", ) @mock.patch(MODULE_NAME + '.check_output') def test_provide_authorized_gcloud_key_path(self, mock_check_output, mock_project_id): @@ -721,13 +666,13 @@ def test_provide_authorized_gcloud_key_path(self, mock_check_output, mock_projec mock_check_output.has_calls( mock.call(['gcloud', 'config', 'set', 'core/project', 'PROJECT_ID']), - mock.call(['gcloud', 'auth', 'activate-service-account', '--key-file=/test/key-path']) + mock.call(['gcloud', 'auth', 'activate-service-account', '--key-file=/test/key-path']), ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=mock.PropertyMock, - return_value="PROJECT_ID" + return_value="PROJECT_ID", ) @mock.patch(MODULE_NAME + '.check_output') @mock.patch('tempfile.NamedTemporaryFile') @@ -743,15 +688,17 @@ def test_provide_authorized_gcloud_keyfile_dict(self, mock_file, mock_check_outp with self.instance.provide_authorized_gcloud(): self.assertEqual(os.environ[CREDENTIALS], file_name) - mock_check_output.has_calls([ - mock.call(['gcloud', 'config', 'set', 'core/project', 'PROJECT_ID']), - mock.call(['gcloud', 'auth', 'activate-service-account', '--key-file=/test/mock-file']) - ]) + mock_check_output.has_calls( + [ + mock.call(['gcloud', 'config', 'set', 'core/project', 'PROJECT_ID']), + mock.call(['gcloud', 'auth', 'activate-service-account', '--key-file=/test/mock-file']), + ] + ) @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=mock.PropertyMock, - return_value="PROJECT_ID" + return_value="PROJECT_ID", ) @mock.patch(MODULE_NAME + '._cloud_sdk') @mock.patch(MODULE_NAME + '.check_output') @@ -762,12 +709,14 @@ def test_provide_authorized_gcloud_via_gcloud_application_default( # This file always exists. 
mock_cloud_sdk.get_application_default_credentials_path.return_value = __file__ - file_content = json.dumps({ - "client_id": "CLIENT_ID", - "client_secret": "CLIENT_SECRET", - "refresh_token": "REFRESH_TOKEN", - "type": "authorized_user" - }) + file_content = json.dumps( + { + "client_id": "CLIENT_ID", + "client_secret": "CLIENT_SECRET", + "refresh_token": "REFRESH_TOKEN", + "type": "authorized_user", + } + ) with mock.patch(MODULE_NAME + '.open', mock.mock_open(read_data=file_content)): with self.instance.provide_authorized_gcloud(): # Do nothing @@ -778,14 +727,13 @@ def test_provide_authorized_gcloud_via_gcloud_application_default( mock.call(['gcloud', 'config', 'set', 'auth/client_id', 'CLIENT_ID']), mock.call(['gcloud', 'config', 'set', 'auth/client_secret', 'CLIENT_SECRET']), mock.call(['gcloud', 'config', 'set', 'core/project', 'PROJECT_ID']), - mock.call(['gcloud', 'auth', 'activate-refresh-token', 'CLIENT_ID', 'REFRESH_TOKEN']) + mock.call(['gcloud', 'auth', 'activate-refresh-token', 'CLIENT_ID', 'REFRESH_TOKEN']), ], - any_order=False + any_order=False, ) class TestNumRetry(unittest.TestCase): - def test_should_return_int_when_set_int_via_connection(self): instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default") instance.extras = { @@ -799,7 +747,7 @@ def test_should_return_int_when_set_int_via_connection(self): 'os.environ', AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT=( 'google-cloud-platform://?extra__google_cloud_platform__num_retries=5' - ) + ), ) def test_should_return_int_when_set_via_env_var(self): instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default") @@ -809,7 +757,7 @@ def test_should_return_int_when_set_via_env_var(self): 'os.environ', AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT=( 'google-cloud-platform://?extra__google_cloud_platform__num_retries=cat' - ) + ), ) def test_should_raise_when_invalid_value_via_env_var(self): instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default") @@ -822,7 +770,7 @@ def test_should_raise_when_invalid_value_via_env_var(self): 'os.environ', AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT=( 'google-cloud-platform://?extra__google_cloud_platform__num_retries=' - ) + ), ) def test_should_fallback_when_empty_string_in_env_var(self): instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default") diff --git a/tests/providers/google/common/hooks/test_discovery_api.py b/tests/providers/google/common/hooks/test_discovery_api.py index c9c425344a8d4..14ca31583daaa 100644 --- a/tests/providers/google/common/hooks/test_discovery_api.py +++ b/tests/providers/google/common/hooks/test_discovery_api.py @@ -26,7 +26,6 @@ class TestGoogleDiscoveryApiHook(unittest.TestCase): - def setUp(self): load_test_config() @@ -37,7 +36,7 @@ def setUp(self): host='google', schema='refresh_token', login='client_id', - password='client_secret' + password='client_secret', ) ) @@ -45,9 +44,7 @@ def setUp(self): @patch('airflow.providers.google.common.hooks.discovery_api.GoogleDiscoveryApiHook._authorize') def test_get_conn(self, mock_authorize, mock_build): google_discovery_api_hook = GoogleDiscoveryApiHook( - gcp_conn_id='google_test', - api_service_name='youtube', - api_version='v2' + gcp_conn_id='google_test', api_service_name='youtube', api_version='v2' ) google_discovery_api_hook.get_conn() @@ -56,26 +53,26 @@ def test_get_conn(self, mock_authorize, mock_build): serviceName=google_discovery_api_hook.api_service_name, version=google_discovery_api_hook.api_version, http=mock_authorize.return_value, - cache_discovery=False + cache_discovery=False, ) 
@patch('airflow.providers.google.common.hooks.discovery_api.getattr') @patch('airflow.providers.google.common.hooks.discovery_api.GoogleDiscoveryApiHook.get_conn') def test_query(self, mock_get_conn, mock_getattr): google_discovery_api_hook = GoogleDiscoveryApiHook( - gcp_conn_id='google_test', - api_service_name='analyticsreporting', - api_version='v4' + gcp_conn_id='google_test', api_service_name='analyticsreporting', api_version='v4' ) endpoint = 'analyticsreporting.reports.batchGet' data = { 'body': { - 'reportRequests': [{ - 'viewId': '180628393', - 'dateRanges': [{'startDate': '7daysAgo', 'endDate': 'today'}], - 'metrics': [{'expression': 'ga:sessions'}], - 'dimensions': [{'name': 'ga:country'}] - }] + 'reportRequests': [ + { + 'viewId': '180628393', + 'dateRanges': [{'startDate': '7daysAgo', 'endDate': 'today'}], + 'metrics': [{'expression': 'ga:sessions'}], + 'dimensions': [{'name': 'ga:country'}], + } + ] } } num_retries = 1 @@ -83,13 +80,15 @@ def test_query(self, mock_get_conn, mock_getattr): google_discovery_api_hook.query(endpoint, data, num_retries=num_retries) google_api_endpoint_name_parts = endpoint.split('.') - mock_getattr.assert_has_calls([ - call(mock_get_conn.return_value, google_api_endpoint_name_parts[1]), - call()(), - call(mock_getattr.return_value.return_value, google_api_endpoint_name_parts[2]), - call()(**data), - call()().execute(num_retries=num_retries) - ]) + mock_getattr.assert_has_calls( + [ + call(mock_get_conn.return_value, google_api_endpoint_name_parts[1]), + call()(), + call(mock_getattr.return_value.return_value, google_api_endpoint_name_parts[2]), + call()(**data), + call()().execute(num_retries=num_retries), + ] + ) @patch('airflow.providers.google.common.hooks.discovery_api.getattr') @patch('airflow.providers.google.common.hooks.discovery_api.GoogleDiscoveryApiHook.get_conn') @@ -101,22 +100,22 @@ def test_query_with_pagination(self, mock_get_conn, mock_getattr): google_api_conn_client_sub_call, google_api_conn_client_sub_call, google_api_conn_client_sub_call, - None + None, ] google_discovery_api_hook = GoogleDiscoveryApiHook( - gcp_conn_id='google_test', - api_service_name='analyticsreporting', - api_version='v4' + gcp_conn_id='google_test', api_service_name='analyticsreporting', api_version='v4' ) endpoint = 'analyticsreporting.reports.batchGet' data = { 'body': { - 'reportRequests': [{ - 'viewId': '180628393', - 'dateRanges': [{'startDate': '7daysAgo', 'endDate': 'today'}], - 'metrics': [{'expression': 'ga:sessions'}], - 'dimensions': [{'name': 'ga:country'}] - }] + 'reportRequests': [ + { + 'viewId': '180628393', + 'dateRanges': [{'startDate': '7daysAgo', 'endDate': 'today'}], + 'metrics': [{'expression': 'ga:sessions'}], + 'dimensions': [{'name': 'ga:country'}], + } + ] } } num_retries = 1 @@ -125,21 +124,23 @@ def test_query_with_pagination(self, mock_get_conn, mock_getattr): api_endpoint_name_parts = endpoint.split('.') google_api_conn_client = mock_get_conn.return_value - mock_getattr.assert_has_calls([ - call(google_api_conn_client, api_endpoint_name_parts[1]), - call()(), - call(google_api_conn_client_sub_call, api_endpoint_name_parts[2]), - call()(**data), - call()().__bool__(), - call()().execute(num_retries=num_retries), - call(google_api_conn_client, api_endpoint_name_parts[1]), - call()(), - call(google_api_conn_client_sub_call, api_endpoint_name_parts[2] + '_next'), - call()(google_api_conn_client_sub_call, google_api_conn_client_sub_call.execute.return_value), - call()().__bool__(), - call()().execute(num_retries=num_retries), - 
call(google_api_conn_client, api_endpoint_name_parts[1]), - call()(), - call(google_api_conn_client_sub_call, api_endpoint_name_parts[2] + '_next'), - call()(google_api_conn_client_sub_call, google_api_conn_client_sub_call.execute.return_value) - ]) + mock_getattr.assert_has_calls( + [ + call(google_api_conn_client, api_endpoint_name_parts[1]), + call()(), + call(google_api_conn_client_sub_call, api_endpoint_name_parts[2]), + call()(**data), + call()().__bool__(), + call()().execute(num_retries=num_retries), + call(google_api_conn_client, api_endpoint_name_parts[1]), + call()(), + call(google_api_conn_client_sub_call, api_endpoint_name_parts[2] + '_next'), + call()(google_api_conn_client_sub_call, google_api_conn_client_sub_call.execute.return_value), + call()().__bool__(), + call()().execute(num_retries=num_retries), + call(google_api_conn_client, api_endpoint_name_parts[1]), + call()(), + call(google_api_conn_client_sub_call, api_endpoint_name_parts[2] + '_next'), + call()(google_api_conn_client_sub_call, google_api_conn_client_sub_call.execute.return_value), + ] + ) diff --git a/tests/providers/google/common/utils/test_id_token_credentials.py b/tests/providers/google/common/utils/test_id_token_credentials.py index f1eb242b1a365..a0f66b96adbb5 100644 --- a/tests/providers/google/common/utils/test_id_token_credentials.py +++ b/tests/providers/google/common/utils/test_id_token_credentials.py @@ -25,7 +25,8 @@ from google.auth.environment_vars import CREDENTIALS from airflow.providers.google.common.utils.id_token_credentials import ( - IDTokenCredentialsAdapter, get_default_id_token_credentials, + IDTokenCredentialsAdapter, + get_default_id_token_credentials, ) diff --git a/tests/providers/google/firebase/hooks/test_firestore.py b/tests/providers/google/firebase/hooks/test_firestore.py index 18bfcae095497..a539bb485102d 100644 --- a/tests/providers/google/firebase/hooks/test_firestore.py +++ b/tests/providers/google/firebase/hooks/test_firestore.py @@ -27,7 +27,8 @@ from airflow.exceptions import AirflowException from airflow.providers.google.firebase.hooks.firestore import CloudFirestoreHook from tests.providers.google.cloud.utils.base_gcp_mock import ( - GCP_PROJECT_ID_HOOK_UNIT_TEST, mock_base_gcp_hook_default_project_id, + GCP_PROJECT_ID_HOOK_UNIT_TEST, + mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, ) @@ -36,7 +37,9 @@ "collectionIds": ["test-collection"], } -TEST_OPERATION = {"name": "operation-name", } +TEST_OPERATION = { + "name": "operation-name", +} TEST_WAITING_OPERATION = {"done": False, "response": "response"} TEST_DONE_OPERATION = {"done": True, "response": "response"} TEST_ERROR_OPERATION = {"done": True, "response": "response", "error": "error"} @@ -58,9 +61,7 @@ def setUp(self): @mock.patch("airflow.providers.google.firebase.hooks.firestore.build_from_document") def test_client_creation(self, mock_build_from_document, mock_build, mock_authorize): result = self.hook.get_conn() - mock_build.assert_called_once_with( - 'firestore', 'v1', cache_discovery=False - ) + mock_build.assert_called_once_with('firestore', 'v1', cache_discovery=False) mock_build_from_document.assert_called_once_with( mock_build.return_value._rootDesc, http=mock_authorize.return_value ) @@ -75,14 +76,9 @@ def test_mmediately_complete(self, get_conn_mock): mock_operation_get = ( service_mock.projects.return_value.databases.return_value.operations.return_value.get ) - ( - mock_export_documents.return_value - .execute.return_value - ) = TEST_OPERATION + 
(mock_export_documents.return_value.execute.return_value) = TEST_OPERATION - ( - mock_operation_get.return_value.execute.return_value - ) = TEST_DONE_OPERATION + (mock_operation_get.return_value.execute.return_value) = TEST_DONE_OPERATION self.hook.export_documents(body=EXPORT_DOCUMENT_BODY, project_id=TEST_PROJECT_ID) @@ -99,10 +95,7 @@ def test_waiting_operation(self, _, get_conn_mock): mock_operation_get = ( service_mock.projects.return_value.databases.return_value.operations.return_value.get ) - ( - mock_export_documents.return_value - .execute.return_value - ) = TEST_OPERATION + (mock_export_documents.return_value.execute.return_value) = TEST_OPERATION execute_mock = mock.Mock( **{"side_effect": [TEST_WAITING_OPERATION, TEST_DONE_OPERATION, TEST_DONE_OPERATION]} @@ -124,10 +117,7 @@ def test_error_operation(self, _, get_conn_mock): mock_operation_get = ( service_mock.projects.return_value.databases.return_value.operations.return_value.get ) - ( - mock_export_documents.return_value - .execute.return_value - ) = TEST_OPERATION + (mock_export_documents.return_value.execute.return_value) = TEST_OPERATION execute_mock = mock.Mock(**{"side_effect": [TEST_WAITING_OPERATION, TEST_ERROR_OPERATION]}) mock_operation_get.return_value.execute = execute_mock @@ -150,9 +140,7 @@ def setUp(self): @mock.patch("airflow.providers.google.firebase.hooks.firestore.build_from_document") def test_client_creation(self, mock_build_from_document, mock_build, mock_authorize): result = self.hook.get_conn() - mock_build.assert_called_once_with( - 'firestore', 'v1', cache_discovery=False - ) + mock_build.assert_called_once_with('firestore', 'v1', cache_discovery=False) mock_build_from_document.assert_called_once_with( mock_build.return_value._rootDesc, http=mock_authorize.return_value ) @@ -162,7 +150,7 @@ def test_client_creation(self, mock_build_from_document, mock_build, mock_author @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.firebase.hooks.firestore.CloudFirestoreHook.get_conn") def test_immediately_complete(self, get_conn_mock, mock_project_id): @@ -172,10 +160,7 @@ def test_immediately_complete(self, get_conn_mock, mock_project_id): mock_operation_get = ( service_mock.projects.return_value.databases.return_value.operations.return_value.get ) - ( - mock_export_documents.return_value - .execute.return_value - ) = TEST_OPERATION + (mock_export_documents.return_value.execute.return_value) = TEST_OPERATION mock_operation_get.return_value.execute.return_value = TEST_DONE_OPERATION @@ -188,7 +173,7 @@ def test_immediately_complete(self, get_conn_mock, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.firebase.hooks.firestore.CloudFirestoreHook.get_conn") @mock.patch("airflow.providers.google.firebase.hooks.firestore.time.sleep") @@ -199,10 +184,7 @@ def test_waiting_operation(self, _, get_conn_mock, mock_project_id): mock_operation_get = ( service_mock.projects.return_value.databases.return_value.operations.return_value.get ) - ( - mock_export_documents.return_value - .execute.return_value - ) = TEST_OPERATION + (mock_export_documents.return_value.execute.return_value) = TEST_OPERATION execute_mock = 
mock.Mock( **{"side_effect": [TEST_WAITING_OPERATION, TEST_DONE_OPERATION, TEST_DONE_OPERATION]} @@ -218,7 +200,7 @@ def test_waiting_operation(self, _, get_conn_mock, mock_project_id): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, ) @mock.patch("airflow.providers.google.firebase.hooks.firestore.CloudFirestoreHook.get_conn") @mock.patch("airflow.providers.google.firebase.hooks.firestore.time.sleep") @@ -229,10 +211,7 @@ def test_error_operation(self, _, get_conn_mock, mock_project_id): mock_operation_get = ( service_mock.projects.return_value.databases.return_value.operations.return_value.get ) - ( - mock_export_documents.return_value - .execute.return_value - ) = TEST_OPERATION + (mock_export_documents.return_value.execute.return_value) = TEST_OPERATION execute_mock = mock.Mock(**{"side_effect": [TEST_WAITING_OPERATION, TEST_ERROR_OPERATION]}) mock_operation_get.return_value.execute = execute_mock @@ -253,7 +232,7 @@ def setUp(self): @mock.patch( 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id', new_callable=PropertyMock, - return_value=None + return_value=None, ) @mock.patch("airflow.providers.google.firebase.hooks.firestore.CloudFirestoreHook.get_conn") def test_create_build(self, mock_get_conn, mock_project_id): diff --git a/tests/providers/google/firebase/operators/test_firestore_system.py b/tests/providers/google/firebase/operators/test_firestore_system.py index 9ec2bb71f8382..9a42a1d130378 100644 --- a/tests/providers/google/firebase/operators/test_firestore_system.py +++ b/tests/providers/google/firebase/operators/test_firestore_system.py @@ -18,7 +18,8 @@ import pytest from airflow.providers.google.firebase.example_dags.example_firestore import ( - DATASET_NAME, EXPORT_DESTINATION_URL, + DATASET_NAME, + EXPORT_DESTINATION_URL, ) from tests.providers.google.cloud.utils.gcp_authenticator import G_FIREBASE_KEY from tests.test_utils.gcp_system_helpers import FIREBASE_DAG_FOLDER, GoogleSystemTest, provide_gcp_context diff --git a/tests/providers/google/marketing_platform/hooks/test_analytics.py b/tests/providers/google/marketing_platform/hooks/test_analytics.py index c073e2916c881..b5e7368e6fc33 100644 --- a/tests/providers/google/marketing_platform/hooks/test_analytics.py +++ b/tests/providers/google/marketing_platform/hooks/test_analytics.py @@ -42,37 +42,25 @@ def setUp(self): @mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__") def test_init(self, mock_base_init): hook = GoogleAnalyticsHook( - API_VERSION, - GCP_CONN_ID, - delegate_to=DELEGATE_TO, - impersonation_chain=IMPERSONATION_CHAIN, + API_VERSION, GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, ) mock_base_init.assert_called_once_with( - GCP_CONN_ID, - delegate_to=DELEGATE_TO, - impersonation_chain=IMPERSONATION_CHAIN, + GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, ) self.assertEqual(hook.api_version, API_VERSION) @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "analytics.GoogleAnalyticsHook._authorize" + "airflow.providers.google.marketing_platform.hooks." 
"analytics.GoogleAnalyticsHook._authorize" ) @mock.patch("airflow.providers.google.marketing_platform.hooks.analytics.build") def test_gen_conn(self, mock_build, mock_authorize): result = self.hook.get_conn() mock_build.assert_called_once_with( - "analytics", - API_VERSION, - http=mock_authorize.return_value, - cache_discovery=False, + "analytics", API_VERSION, http=mock_authorize.return_value, cache_discovery=False, ) self.assertEqual(mock_build.return_value, result) - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "analytics.GoogleAnalyticsHook.get_conn" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks." "analytics.GoogleAnalyticsHook.get_conn") def test_list_accounts(self, get_conn_mock): mock_accounts = get_conn_mock.return_value.management.return_value.accounts mock_list = mock_accounts.return_value.list @@ -81,10 +69,7 @@ def test_list_accounts(self, get_conn_mock): list_accounts = self.hook.list_accounts() self.assertEqual(list_accounts, ["a", "b"]) - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "analytics.GoogleAnalyticsHook.get_conn" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks." "analytics.GoogleAnalyticsHook.get_conn") def test_list_accounts_for_multiple_pages(self, get_conn_mock): mock_accounts = get_conn_mock.return_value.management.return_value.accounts mock_list = mock_accounts.return_value.list @@ -96,10 +81,7 @@ def test_list_accounts_for_multiple_pages(self, get_conn_mock): list_accounts = self.hook.list_accounts() self.assertEqual(list_accounts, ["a", "b"]) - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "analytics.GoogleAnalyticsHook.get_conn" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks." "analytics.GoogleAnalyticsHook.get_conn") def test_get_ad_words_links_call(self, get_conn_mock): num_retries = 5 self.hook.get_ad_words_link( @@ -107,7 +89,7 @@ def test_get_ad_words_links_call(self, get_conn_mock): web_property_id=WEB_PROPERTY_ID, web_property_ad_words_link_id=WEB_PROPERTY_AD_WORDS_LINK_ID, ) - + # fmt: off get_conn_mock.return_value.management.return_value.webPropertyAdWordsLinks.\ return_value.get.return_value.execute.assert_called_once_with( num_retries=num_retries @@ -119,49 +101,31 @@ def test_get_ad_words_links_call(self, get_conn_mock): webPropertyId=WEB_PROPERTY_ID, webPropertyAdWordsLinkId=WEB_PROPERTY_AD_WORDS_LINK_ID, ) + # fmt: on - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "analytics.GoogleAnalyticsHook.get_conn" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks." "analytics.GoogleAnalyticsHook.get_conn") def test_list_ad_words_links(self, get_conn_mock): - mock_ads_links = ( - get_conn_mock.return_value.management.return_value.webPropertyAdWordsLinks - ) + mock_ads_links = get_conn_mock.return_value.management.return_value.webPropertyAdWordsLinks mock_list = mock_ads_links.return_value.list mock_execute = mock_list.return_value.execute mock_execute.return_value = {"items": ["a", "b"], "totalResults": 2} - list_ads_links = self.hook.list_ad_words_links( - account_id=ACCOUNT_ID, web_property_id=WEB_PROPERTY_ID - ) + list_ads_links = self.hook.list_ad_words_links(account_id=ACCOUNT_ID, web_property_id=WEB_PROPERTY_ID) self.assertEqual(list_ads_links, ["a", "b"]) - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "analytics.GoogleAnalyticsHook.get_conn" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks." 
"analytics.GoogleAnalyticsHook.get_conn") def test_list_ad_words_links_for_multiple_pages(self, get_conn_mock): - mock_ads_links = ( - get_conn_mock.return_value.management.return_value.webPropertyAdWordsLinks - ) + mock_ads_links = get_conn_mock.return_value.management.return_value.webPropertyAdWordsLinks mock_list = mock_ads_links.return_value.list mock_execute = mock_list.return_value.execute mock_execute.side_effect = [ {"items": ["a"], "totalResults": 2}, {"items": ["b"], "totalResults": 2}, ] - list_ads_links = self.hook.list_ad_words_links( - account_id=ACCOUNT_ID, web_property_id=WEB_PROPERTY_ID - ) + list_ads_links = self.hook.list_ad_words_links(account_id=ACCOUNT_ID, web_property_id=WEB_PROPERTY_ID) self.assertEqual(list_ads_links, ["a", "b"]) - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "analytics.GoogleAnalyticsHook.get_conn" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." "analytics.MediaFileUpload" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks." "analytics.GoogleAnalyticsHook.get_conn") + @mock.patch("airflow.providers.google.marketing_platform.hooks." "analytics.MediaFileUpload") def test_upload_data(self, media_mock, get_conn_mock): temp_name = "temp/file" self.hook.upload_data( @@ -172,9 +136,8 @@ def test_upload_data(self, media_mock, get_conn_mock): resumable_upload=True, ) - media_mock.assert_called_once_with( - temp_name, mimetype="application/octet-stream", resumable=True - ) + media_mock.assert_called_once_with(temp_name, mimetype="application/octet-stream", resumable=True) + # fmt: off get_conn_mock.return_value.management.return_value.uploads.return_value.uploadData.\ assert_called_once_with( accountId=ACCOUNT_ID, @@ -182,11 +145,9 @@ def test_upload_data(self, media_mock, get_conn_mock): customDataSourceId=DATA_SOURCE, media_body=media_mock.return_value, ) + # fmt: on - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "analytics.GoogleAnalyticsHook.get_conn" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks." "analytics.GoogleAnalyticsHook.get_conn") def test_delete_upload_data(self, get_conn_mock): body = {"key": "temp/file"} self.hook.delete_upload_data( @@ -195,7 +156,7 @@ def test_delete_upload_data(self, get_conn_mock): custom_data_source_id=DATA_SOURCE, delete_request_body=body, ) - + # fmt: off get_conn_mock.return_value.management.return_value.uploads.return_value.deleteUploadData.\ assert_called_once_with( accountId=ACCOUNT_ID, @@ -203,22 +164,16 @@ def test_delete_upload_data(self, get_conn_mock): customDataSourceId=DATA_SOURCE, body=body, ) + # fmt: on - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "analytics.GoogleAnalyticsHook.get_conn" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks." 
"analytics.GoogleAnalyticsHook.get_conn") def test_list_upload(self, get_conn_mock): - uploads = ( - get_conn_mock.return_value.management.return_value.uploads.return_value - ) + uploads = get_conn_mock.return_value.management.return_value.uploads.return_value uploads.list.return_value.execute.return_value = { "items": ["a", "b"], "totalResults": 2, } result = self.hook.list_uploads( - account_id=ACCOUNT_ID, - web_property_id=WEB_PROPERTY_ID, - custom_data_source_id=DATA_SOURCE, + account_id=ACCOUNT_ID, web_property_id=WEB_PROPERTY_ID, custom_data_source_id=DATA_SOURCE, ) self.assertEqual(result, ["a", "b"]) diff --git a/tests/providers/google/marketing_platform/hooks/test_campaign_manager.py b/tests/providers/google/marketing_platform/hooks/test_campaign_manager.py index 0ab28638b6a06..d51bb95039423 100644 --- a/tests/providers/google/marketing_platform/hooks/test_campaign_manager.py +++ b/tests/providers/google/marketing_platform/hooks/test_campaign_manager.py @@ -36,24 +36,17 @@ def setUp(self): "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__", new=mock_base_gcp_hook_default_project_id, ): - self.hook = GoogleCampaignManagerHook( - gcp_conn_id=GCP_CONN_ID, api_version=API_VERSION - ) + self.hook = GoogleCampaignManagerHook(gcp_conn_id=GCP_CONN_ID, api_version=API_VERSION) @mock.patch( "airflow.providers.google.marketing_platform.hooks." "campaign_manager.GoogleCampaignManagerHook._authorize" ) - @mock.patch( - "airflow.providers.google.marketing_platform.hooks.campaign_manager.build" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks.campaign_manager.build") def test_gen_conn(self, mock_build, mock_authorize): result = self.hook.get_conn() mock_build.assert_called_once_with( - "dfareporting", - API_VERSION, - http=mock_authorize.return_value, - cache_discovery=False, + "dfareporting", API_VERSION, http=mock_authorize.return_value, cache_discovery=False, ) self.assertEqual(mock_build.return_value, result) @@ -83,12 +76,12 @@ def test_get_report(self, get_conn_mock): file_id = "FILE_ID" return_value = "TEST" - get_conn_mock.return_value.reports.return_value.files.return_value.\ + # fmt: off + get_conn_mock.return_value.reports.return_value.files.return_value. 
\ get.return_value.execute.return_value = return_value + # fmt: on - result = self.hook.get_report( - profile_id=PROFILE_ID, report_id=REPORT_ID, file_id=file_id - ) + result = self.hook.get_report(profile_id=PROFILE_ID, report_id=REPORT_ID, file_id=file_id) get_conn_mock.return_value.reports.return_value.files.return_value.get.assert_called_once_with( profileId=PROFILE_ID, reportId=REPORT_ID, fileId=file_id @@ -108,9 +101,7 @@ def test_get_report_file(self, get_conn_mock): return_value ) - result = self.hook.get_report_file( - profile_id=PROFILE_ID, report_id=REPORT_ID, file_id=file_id - ) + result = self.hook.get_report_file(profile_id=PROFILE_ID, report_id=REPORT_ID, file_id=file_id) get_conn_mock.return_value.reports.return_value.files.return_value.get_media.assert_called_once_with( profileId=PROFILE_ID, reportId=REPORT_ID, fileId=file_id @@ -150,9 +141,7 @@ def test_list_reports(self, get_conn_mock): items = ["item"] return_value = {"nextPageToken": None, "items": items} - get_conn_mock.return_value.reports.return_value.list.return_value.execute.return_value = ( - return_value - ) + get_conn_mock.return_value.reports.return_value.list.return_value.execute.return_value = return_value request_mock = mock.MagicMock() request_mock.execute.return_value = {"nextPageToken": None, "items": items} @@ -189,13 +178,9 @@ def test_patch_report(self, get_conn_mock): update_mask = {"test": "test"} return_value = "TEST" - get_conn_mock.return_value.reports.return_value.patch.return_value.execute.return_value = ( - return_value - ) + get_conn_mock.return_value.reports.return_value.patch.return_value.execute.return_value = return_value - result = self.hook.patch_report( - profile_id=PROFILE_ID, report_id=REPORT_ID, update_mask=update_mask - ) + result = self.hook.patch_report(profile_id=PROFILE_ID, report_id=REPORT_ID, update_mask=update_mask) get_conn_mock.return_value.reports.return_value.patch.assert_called_once_with( profileId=PROFILE_ID, reportId=REPORT_ID, body=update_mask @@ -211,13 +196,9 @@ def test_run_report(self, get_conn_mock): synchronous = True return_value = "TEST" - get_conn_mock.return_value.reports.return_value.run.return_value.execute.return_value = ( - return_value - ) + get_conn_mock.return_value.reports.return_value.run.return_value.execute.return_value = return_value - result = self.hook.run_report( - profile_id=PROFILE_ID, report_id=REPORT_ID, synchronous=synchronous - ) + result = self.hook.run_report(profile_id=PROFILE_ID, report_id=REPORT_ID, synchronous=synchronous) get_conn_mock.return_value.reports.return_value.run.assert_called_once_with( profileId=PROFILE_ID, reportId=REPORT_ID, synchronous=synchronous diff --git a/tests/providers/google/marketing_platform/hooks/test_display_video.py b/tests/providers/google/marketing_platform/hooks/test_display_video.py index f61f118fc035d..1193d280b530f 100644 --- a/tests/providers/google/marketing_platform/hooks/test_display_video.py +++ b/tests/providers/google/marketing_platform/hooks/test_display_video.py @@ -36,16 +36,11 @@ def setUp(self): "airflow.providers.google.marketing_platform.hooks." "display_video.GoogleDisplayVideo360Hook._authorize" ) - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." "display_video.build" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks." 
"display_video.build") def test_gen_conn(self, mock_build, mock_authorize): result = self.hook.get_conn() mock_build.assert_called_once_with( - "doubleclickbidmanager", - API_VERSION, - http=mock_authorize.return_value, - cache_discovery=False, + "doubleclickbidmanager", API_VERSION, http=mock_authorize.return_value, cache_discovery=False, ) self.assertEqual(mock_build.return_value, result) @@ -53,16 +48,11 @@ def test_gen_conn(self, mock_build, mock_authorize): "airflow.providers.google.marketing_platform.hooks." "display_video.GoogleDisplayVideo360Hook._authorize" ) - @mock.patch( - "airflow.providers.google.marketing_platform.hooks." "display_video.build" - ) + @mock.patch("airflow.providers.google.marketing_platform.hooks." "display_video.build") def test_get_conn_to_display_video(self, mock_build, mock_authorize): result = self.hook.get_conn_to_display_video() mock_build.assert_called_once_with( - "displayvideo", - API_VERSION, - http=mock_authorize.return_value, - cache_discovery=False, + "displayvideo", API_VERSION, http=mock_authorize.return_value, cache_discovery=False, ) self.assertEqual(mock_build.return_value, result) @@ -80,9 +70,7 @@ def test_create_query(self, get_conn_mock): result = self.hook.create_query(query=body) - get_conn_mock.return_value.queries.return_value.createquery.assert_called_once_with( - body=body - ) + get_conn_mock.return_value.queries.return_value.createquery.assert_called_once_with(body=body) self.assertEqual(return_value, result) @@ -100,9 +88,7 @@ def test_delete_query(self, get_conn_mock): self.hook.delete_query(query_id=query_id) - get_conn_mock.return_value.queries.return_value.deletequery.assert_called_once_with( - queryId=query_id - ) + get_conn_mock.return_value.queries.return_value.deletequery.assert_called_once_with(queryId=query_id) @mock.patch( "airflow.providers.google.marketing_platform.hooks." @@ -118,9 +104,7 @@ def test_get_query(self, get_conn_mock): result = self.hook.get_query(query_id=query_id) - get_conn_mock.return_value.queries.return_value.getquery.assert_called_once_with( - queryId=query_id - ) + get_conn_mock.return_value.queries.return_value.getquery.assert_called_once_with(queryId=query_id) self.assertEqual(return_value, result) @@ -164,12 +148,10 @@ def test_download_line_items_should_be_called_once(self, get_conn_mock): "filterType": "filter_type", "filterIds": [], "format": "format", - "fileSpec": "file_spec" + "fileSpec": "file_spec", } self.hook.download_line_items(request_body=request_body) - get_conn_mock.return_value\ - .lineitems.return_value\ - .downloadlineitems.assert_called_once() + get_conn_mock.return_value.lineitems.return_value.downloadlineitems.assert_called_once() @mock.patch( "airflow.providers.google.marketing_platform.hooks." @@ -180,13 +162,13 @@ def test_download_line_items_should_be_called_with_params(self, get_conn_mock): "filterType": "filter_type", "filterIds": [], "format": "format", - "fileSpec": "file_spec" + "fileSpec": "file_spec", } self.hook.download_line_items(request_body=request_body) - get_conn_mock.return_value \ - .lineitems.return_value \ - .downloadlineitems.assert_called_once_with(body=request_body) + get_conn_mock.return_value.lineitems.return_value.downloadlineitems.assert_called_once_with( + body=request_body + ) @mock.patch( "airflow.providers.google.marketing_platform.hooks." 
@@ -199,12 +181,13 @@ def test_download_line_items_should_return_equal_values(self, get_conn_mock): "filterType": "filter_type", "filterIds": [], "format": "format", - "fileSpec": "file_spec" + "fileSpec": "file_spec", } - get_conn_mock.return_value \ - .lineitems.return_value \ + # fmt: off + get_conn_mock.return_value.lineitems.return_value \ .downloadlineitems.return_value.execute.return_value = response + # fmt: on result = self.hook.download_line_items(request_body) self.assertEqual(line_item, result) @@ -217,9 +200,7 @@ def test_upload_line_items_should_be_called_once(self, get_conn_mock): line_items = ["this", "is", "super", "awesome", "test"] self.hook.upload_line_items(line_items) - get_conn_mock.return_value \ - .lineitems.return_value \ - .uploadlineitems.assert_called_once() + get_conn_mock.return_value.lineitems.return_value.uploadlineitems.assert_called_once() @mock.patch( "airflow.providers.google.marketing_platform.hooks." @@ -235,25 +216,21 @@ def test_upload_line_items_should_be_called_with_params(self, get_conn_mock): self.hook.upload_line_items(line_items) - get_conn_mock.return_value \ - .lineitems.return_value \ - .uploadlineitems.assert_called_once_with(body=request_body) + get_conn_mock.return_value.lineitems.return_value.uploadlineitems.assert_called_once_with( + body=request_body + ) @mock.patch( "airflow.providers.google.marketing_platform.hooks." "display_video.GoogleDisplayVideo360Hook.get_conn" ) def test_upload_line_items_should_return_equal_values(self, get_conn_mock): - line_items = { - "lineItems": "string", - "format": "string", - "dryRun": False - } + line_items = {"lineItems": "string", "format": "string", "dryRun": False} return_value = "TEST" - get_conn_mock.return_value \ - .lineitems.return_value \ - .uploadlineitems.return_value \ - .execute.return_value = return_value + # fmt: off + get_conn_mock.return_value.lineitems.return_value \ + .uploadlineitems.return_value.execute.return_value = return_value + # fmt: on result = self.hook.upload_line_items(line_items) self.assertEqual(return_value, result) @@ -262,9 +239,7 @@ def test_upload_line_items_should_return_equal_values(self, get_conn_mock): "airflow.providers.google.marketing_platform.hooks." "display_video.GoogleDisplayVideo360Hook.get_conn_to_display_video" ) - def test_create_sdf_download_tasks_called_with_params( - self, get_conn_to_display_video - ): + def test_create_sdf_download_tasks_called_with_params(self, get_conn_to_display_video): body_request = { "version": "version", "partnerId": "partner_id", @@ -302,9 +277,7 @@ def test_create_sdf_download_tasks_called_once(self, get_conn_to_display_video): "airflow.providers.google.marketing_platform.hooks." "display_video.GoogleDisplayVideo360Hook.get_conn_to_display_video" ) - def test_create_sdf_download_tasks_return_equal_values( - self, get_conn_to_display_video - ): + def test_create_sdf_download_tasks_return_equal_values(self, get_conn_to_display_video): response = ["name"] body_request = { "version": "version", @@ -315,10 +288,12 @@ def test_create_sdf_download_tasks_return_equal_values( "inventorySourceFilter": "inventory_source_filter", } - get_conn_to_display_video.return_value.\ - sdfdownloadtasks.return_value.\ - create.return_value\ + # fmt: off + get_conn_to_display_video.return_value. \ + sdfdownloadtasks.return_value. 
\ + create.return_value \ .execute.return_value = response + # fmt: on result = self.hook.create_sdf_download_operation(body_request=body_request) self.assertEqual(response, result) @@ -330,10 +305,12 @@ def test_create_sdf_download_tasks_return_equal_values( def test_get_sdf_download_tasks_called_with_params(self, get_conn_to_display_video): operation_name = "operation_name" self.hook.get_sdf_download_operation(operation_name=operation_name) - get_conn_to_display_video.return_value.\ - sdfdownloadtasks.return_value.\ - operation.return_value.\ + # fmt: off + get_conn_to_display_video.return_value. \ + sdfdownloadtasks.return_value. \ + operation.return_value. \ get.assert_called_once_with(name=operation_name) + # fmt: on @mock.patch( "airflow.providers.google.marketing_platform.hooks." @@ -342,10 +319,12 @@ def test_get_sdf_download_tasks_called_with_params(self, get_conn_to_display_vid def test_get_sdf_download_tasks_called_once(self, get_conn_to_display_video): operation_name = "name" self.hook.get_sdf_download_operation(operation_name=operation_name) - get_conn_to_display_video.return_value.\ - sdfdownloadtasks.return_value.\ - operation.return_value.\ + # fmt: off + get_conn_to_display_video.return_value. \ + sdfdownloadtasks.return_value. \ + operation.return_value. \ get.assert_called_once() + # fmt: on @mock.patch( "airflow.providers.google.marketing_platform.hooks." @@ -355,9 +334,9 @@ def get_sdf_download_tasks_return_equal_values(self, get_conn_to_display_video): operation_name = "operation" response = "reposonse" - get_conn_to_display_video.return_value.\ - sdfdownloadtasks.return_value.\ - operation.return_value.get = response + get_conn_to_display_video.return_value.sdfdownloadtasks.return_value.operation.return_value.get = ( + response + ) result = self.hook.get_sdf_download_operation(operation_name=operation_name) @@ -371,9 +350,7 @@ def test_download_media_called_once(self, get_conn_to_display_video): resource_name = "resource_name" self.hook.download_media(resource_name=resource_name) - get_conn_to_display_video.return_value.\ - media.return_value.\ - download_media.assert_called_once() + get_conn_to_display_video.return_value.media.return_value.download_media.assert_called_once() @mock.patch( "airflow.providers.google.marketing_platform.hooks." @@ -383,6 +360,6 @@ def test_download_media_called_once_with_params(self, get_conn_to_display_video) resource_name = "resource_name" self.hook.download_media(resource_name=resource_name) - get_conn_to_display_video.return_value.\ - media.return_value.\ - download_media.assert_called_once_with(resource_name=resource_name) + get_conn_to_display_video.return_value.media.return_value.download_media.assert_called_once_with( + resource_name=resource_name + ) diff --git a/tests/providers/google/marketing_platform/hooks/test_search_ads.py b/tests/providers/google/marketing_platform/hooks/test_search_ads.py index babb5f19ef349..f75b5593ee85e 100644 --- a/tests/providers/google/marketing_platform/hooks/test_search_ads.py +++ b/tests/providers/google/marketing_platform/hooks/test_search_ads.py @@ -33,23 +33,18 @@ def setUp(self): self.hook = GoogleSearchAdsHook(gcp_conn_id=GCP_CONN_ID) @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "search_ads.GoogleSearchAdsHook._authorize" + "airflow.providers.google.marketing_platform.hooks." 
"search_ads.GoogleSearchAdsHook._authorize" ) @mock.patch("airflow.providers.google.marketing_platform.hooks.search_ads.build") def test_gen_conn(self, mock_build, mock_authorize): result = self.hook.get_conn() mock_build.assert_called_once_with( - "doubleclicksearch", - API_VERSION, - http=mock_authorize.return_value, - cache_discovery=False, + "doubleclicksearch", API_VERSION, http=mock_authorize.return_value, cache_discovery=False, ) self.assertEqual(mock_build.return_value, result) @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "search_ads.GoogleSearchAdsHook.get_conn" + "airflow.providers.google.marketing_platform.hooks." "search_ads.GoogleSearchAdsHook.get_conn" ) def test_insert(self, get_conn_mock): report = {"report": "test"} @@ -61,35 +56,27 @@ def test_insert(self, get_conn_mock): result = self.hook.insert_report(report=report) - get_conn_mock.return_value.reports.return_value.request.assert_called_once_with( - body=report - ) + get_conn_mock.return_value.reports.return_value.request.assert_called_once_with(body=report) self.assertEqual(return_value, result) @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "search_ads.GoogleSearchAdsHook.get_conn" + "airflow.providers.google.marketing_platform.hooks." "search_ads.GoogleSearchAdsHook.get_conn" ) def test_get(self, get_conn_mock): report_id = "REPORT_ID" return_value = "TEST" - get_conn_mock.return_value.reports.return_value.get.return_value.execute.return_value = ( - return_value - ) + get_conn_mock.return_value.reports.return_value.get.return_value.execute.return_value = return_value result = self.hook.get(report_id=report_id) - get_conn_mock.return_value.reports.return_value.get.assert_called_once_with( - reportId=report_id - ) + get_conn_mock.return_value.reports.return_value.get.assert_called_once_with(reportId=report_id) self.assertEqual(return_value, result) @mock.patch( - "airflow.providers.google.marketing_platform.hooks." - "search_ads.GoogleSearchAdsHook.get_conn" + "airflow.providers.google.marketing_platform.hooks." 
"search_ads.GoogleSearchAdsHook.get_conn" ) def test_get_file(self, get_conn_mock): report_fragment = 42 @@ -100,9 +87,7 @@ def test_get_file(self, get_conn_mock): return_value ) - result = self.hook.get_file( - report_fragment=report_fragment, report_id=report_id - ) + result = self.hook.get_file(report_fragment=report_fragment, report_id=report_id) get_conn_mock.return_value.reports.return_value.getFile.assert_called_once_with( reportFragment=report_fragment, reportId=report_id diff --git a/tests/providers/google/marketing_platform/operators/test_analytics.py b/tests/providers/google/marketing_platform/operators/test_analytics.py index 393fb3749a6e4..4cbd3cd292c28 100644 --- a/tests/providers/google/marketing_platform/operators/test_analytics.py +++ b/tests/providers/google/marketing_platform/operators/test_analytics.py @@ -20,9 +20,12 @@ from unittest import mock from airflow.providers.google.marketing_platform.operators.analytics import ( - GoogleAnalyticsDataImportUploadOperator, GoogleAnalyticsDeletePreviousDataUploadsOperator, - GoogleAnalyticsGetAdsLinkOperator, GoogleAnalyticsListAccountsOperator, - GoogleAnalyticsModifyFileHeadersDataImportOperator, GoogleAnalyticsRetrieveAdsLinksListOperator, + GoogleAnalyticsDataImportUploadOperator, + GoogleAnalyticsDeletePreviousDataUploadsOperator, + GoogleAnalyticsGetAdsLinkOperator, + GoogleAnalyticsListAccountsOperator, + GoogleAnalyticsModifyFileHeadersDataImportOperator, + GoogleAnalyticsRetrieveAdsLinksListOperator, ) WEB_PROPERTY_AD_WORDS_LINK_ID = "AAIIRRFFLLOOWW" @@ -37,10 +40,7 @@ class TestGoogleAnalyticsListAccountsOperator(unittest.TestCase): - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "analytics.GoogleAnalyticsHook" - ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "analytics.GoogleAnalyticsHook") def test_execute(self, hook_mock): op = GoogleAnalyticsListAccountsOperator( api_version=API_VERSION, @@ -54,10 +54,7 @@ def test_execute(self, hook_mock): class TestGoogleAnalyticsRetrieveAdsLinksListOperator(unittest.TestCase): - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "analytics.GoogleAnalyticsHook" - ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "analytics.GoogleAnalyticsHook") def test_execute(self, hook_mock): op = GoogleAnalyticsRetrieveAdsLinksListOperator( account_id=ACCOUNT_ID, @@ -71,9 +68,7 @@ def test_execute(self, hook_mock): hook_mock.assert_called_once() hook_mock.return_value.list_ad_words_links.assert_called_once() hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - api_version=API_VERSION, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, api_version=API_VERSION, impersonation_chain=IMPERSONATION_CHAIN, ) hook_mock.return_value.list_ad_words_links.assert_called_once_with( account_id=ACCOUNT_ID, web_property_id=WEB_PROPERTY_ID @@ -81,10 +76,7 @@ def test_execute(self, hook_mock): class TestGoogleAnalyticsGetAdsLinkOperator(unittest.TestCase): - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "analytics.GoogleAnalyticsHook" - ) + @mock.patch("airflow.providers.google.marketing_platform.operators." 
"analytics.GoogleAnalyticsHook") def test_execute(self, hook_mock): op = GoogleAnalyticsGetAdsLinkOperator( account_id=ACCOUNT_ID, @@ -99,9 +91,7 @@ def test_execute(self, hook_mock): hook_mock.assert_called_once() hook_mock.return_value.get_ad_words_link.assert_called_once() hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - api_version=API_VERSION, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, api_version=API_VERSION, impersonation_chain=IMPERSONATION_CHAIN, ) hook_mock.return_value.get_ad_words_link.assert_called_once_with( account_id=ACCOUNT_ID, @@ -111,16 +101,9 @@ def test_execute(self, hook_mock): class TestGoogleAnalyticsDataImportUploadOperator(unittest.TestCase): - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "analytics.GoogleAnalyticsHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." "analytics.GCSHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators.analytics.NamedTemporaryFile" - ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "analytics.GoogleAnalyticsHook") + @mock.patch("airflow.providers.google.marketing_platform.operators." "analytics.GCSHook") + @mock.patch("airflow.providers.google.marketing_platform.operators.analytics.NamedTemporaryFile") def test_execute(self, mock_tempfile, gcs_hook_mock, ga_hook_mock): filename = "file/" mock_tempfile.return_value.__enter__.return_value.name = filename @@ -139,9 +122,7 @@ def test_execute(self, mock_tempfile, gcs_hook_mock, ga_hook_mock): op.execute(context=None) gcs_hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, impersonation_chain=IMPERSONATION_CHAIN, ) ga_hook_mock.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, @@ -160,10 +141,7 @@ def test_execute(self, mock_tempfile, gcs_hook_mock, ga_hook_mock): class TestGoogleAnalyticsDeletePreviousDataUploadsOperator: - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "analytics.GoogleAnalyticsHook" - ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "analytics.GoogleAnalyticsHook") def test_execute(self, mock_hook): mock_hook.return_value.list_uploads.return_value = [ {"id": 1}, @@ -181,23 +159,15 @@ def test_execute(self, mock_hook): ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) mock_hook.return_value.list_uploads.assert_called_once_with( - account_id=ACCOUNT_ID, - web_property_id=WEB_PROPERTY_ID, - custom_data_source_id=DATA_SOURCE, + account_id=ACCOUNT_ID, web_property_id=WEB_PROPERTY_ID, custom_data_source_id=DATA_SOURCE, ) mock_hook.return_value.delete_upload_data.assert_called_once_with( - ACCOUNT_ID, - WEB_PROPERTY_ID, - DATA_SOURCE, - {"customDataImportUids": [1, 2, 3]}, + ACCOUNT_ID, WEB_PROPERTY_ID, DATA_SOURCE, {"customDataImportUids": [1, 2, 3]}, ) @@ -244,16 +214,12 @@ def test_modify_column_headers(self): with open(tmp.name) as f: assert expected_data == f.read() - @mock.patch( - "airflow.providers.google.marketing_platform.operators." "analytics.GCSHook" - ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "analytics.GCSHook") @mock.patch( "airflow.providers.google.marketing_platform.operators." 
"analytics.GoogleAnalyticsModifyFileHeadersDataImportOperator._modify_column_headers" ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators.analytics.NamedTemporaryFile" - ) + @mock.patch("airflow.providers.google.marketing_platform.operators.analytics.NamedTemporaryFile") def test_execute(self, mock_tempfile, mock_modify, mock_hook): mapping = {"a": "b"} filename = "file/" @@ -270,9 +236,7 @@ def test_execute(self, mock_tempfile, mock_modify, mock_hook): op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.download.assert_called_once_with( @@ -283,6 +247,4 @@ def test_execute(self, mock_tempfile, mock_modify, mock_hook): tmp_file_location=filename, custom_dimension_header_mapping=mapping ) - mock_hook.return_value.upload( - bucket_name=BUCKET, object_name=BUCKET_OBJECT_NAME, filename=filename - ) + mock_hook.return_value.upload(bucket_name=BUCKET, object_name=BUCKET_OBJECT_NAME, filename=filename) diff --git a/tests/providers/google/marketing_platform/operators/test_campaign_manager.py b/tests/providers/google/marketing_platform/operators/test_campaign_manager.py index 04ad3048bcf92..e9e2145676a45 100644 --- a/tests/providers/google/marketing_platform/operators/test_campaign_manager.py +++ b/tests/providers/google/marketing_platform/operators/test_campaign_manager.py @@ -20,9 +20,12 @@ from unittest import TestCase, mock from airflow.providers.google.marketing_platform.operators.campaign_manager import ( - GoogleCampaignManagerBatchInsertConversionsOperator, GoogleCampaignManagerBatchUpdateConversionsOperator, - GoogleCampaignManagerDeleteReportOperator, GoogleCampaignManagerDownloadReportOperator, - GoogleCampaignManagerInsertReportOperator, GoogleCampaignManagerRunReportOperator, + GoogleCampaignManagerBatchInsertConversionsOperator, + GoogleCampaignManagerBatchUpdateConversionsOperator, + GoogleCampaignManagerDeleteReportOperator, + GoogleCampaignManagerDownloadReportOperator, + GoogleCampaignManagerInsertReportOperator, + GoogleCampaignManagerRunReportOperator, ) API_VERSION = "api_version" @@ -34,40 +37,24 @@ "floodlightConfigurationId": 1234, "gclid": "971nc2849184c1914019v1c34c14", "ordinal": "0", - "customVariables": [ - { - "kind": "dfareporting#customFloodlightVariable", - "type": "U10", - "value": "value", - } - ], + "customVariables": [{"kind": "dfareporting#customFloodlightVariable", "type": "U10", "value": "value",}], } class TestGoogleCampaignManagerDeleteReportOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.GoogleCampaignManagerHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.BaseOperator" + "airflow.providers.google.marketing_platform.operators." "campaign_manager.GoogleCampaignManagerHook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." 
"campaign_manager.BaseOperator") def test_execute(self, mock_base_op, hook_mock): profile_id = "PROFILE_ID" report_id = "REPORT_ID" op = GoogleCampaignManagerDeleteReportOperator( - profile_id=profile_id, - report_id=report_id, - api_version=API_VERSION, - task_id="test_task", + profile_id=profile_id, report_id=report_id, api_version=API_VERSION, task_id="test_task", ) op.execute(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.delete_report.assert_called_once_with( profile_id=profile_id, report_id=report_id @@ -75,38 +62,19 @@ def test_execute(self, mock_base_op, hook_mock): class TestGoogleCampaignManagerGetReportOperator(TestCase): + @mock.patch("airflow.providers.google.marketing_platform.operators." "campaign_manager.http") + @mock.patch("airflow.providers.google.marketing_platform.operators." "campaign_manager.tempfile") @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.http" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.tempfile" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.GoogleCampaignManagerHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.GCSHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.BaseOperator" + "airflow.providers.google.marketing_platform.operators." "campaign_manager.GoogleCampaignManagerHook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "campaign_manager.GCSHook") + @mock.patch("airflow.providers.google.marketing_platform.operators." "campaign_manager.BaseOperator") @mock.patch( "airflow.providers.google.marketing_platform.operators." 
"campaign_manager.GoogleCampaignManagerDownloadReportOperator.xcom_push" ) def test_execute( - self, - xcom_mock, - mock_base_op, - gcs_hook_mock, - hook_mock, - tempfile_mock, - http_mock, + self, xcom_mock, mock_base_op, gcs_hook_mock, hook_mock, tempfile_mock, http_mock, ): profile_id = "PROFILE_ID" report_id = "REPORT_ID" @@ -119,9 +87,7 @@ def test_execute( None, True, ) - tempfile_mock.NamedTemporaryFile.return_value.__enter__.return_value.name = ( - temp_file_name - ) + tempfile_mock.NamedTemporaryFile.return_value.__enter__.return_value.name = temp_file_name op = GoogleCampaignManagerDownloadReportOperator( profile_id=profile_id, report_id=report_id, @@ -133,18 +99,13 @@ def test_execute( ) op.execute(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.get_report_file.assert_called_once_with( profile_id=profile_id, report_id=report_id, file_id=file_id ) gcs_hook_mock.assert_called_once_with( - google_cloud_storage_conn_id=GCP_CONN_ID, - delegate_to=None, - impersonation_chain=None, + google_cloud_storage_conn_id=GCP_CONN_ID, delegate_to=None, impersonation_chain=None, ) gcs_hook_mock.return_value.upload.assert_called_once_with( bucket_name=bucket_name, @@ -153,20 +114,14 @@ def test_execute( filename=temp_file_name, mime_type="text/csv", ) - xcom_mock.assert_called_once_with( - None, key="report_name", value=report_name + ".gz" - ) + xcom_mock.assert_called_once_with(None, key="report_name", value=report_name + ".gz") class TestGoogleCampaignManagerInsertReportOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.GoogleCampaignManagerHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.BaseOperator" + "airflow.providers.google.marketing_platform.operators." "campaign_manager.GoogleCampaignManagerHook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "campaign_manager.BaseOperator") @mock.patch( "airflow.providers.google.marketing_platform.operators." 
"campaign_manager.GoogleCampaignManagerInsertReportOperator.xcom_push" @@ -179,21 +134,13 @@ def test_execute(self, xcom_mock, mock_base_op, hook_mock): hook_mock.return_value.insert_report.return_value = {"id": report_id} op = GoogleCampaignManagerInsertReportOperator( - profile_id=profile_id, - report=report, - api_version=API_VERSION, - task_id="test_task", + profile_id=profile_id, report=report, api_version=API_VERSION, task_id="test_task", ) op.execute(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, - ) - hook_mock.return_value.insert_report.assert_called_once_with( - profile_id=profile_id, report=report + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) + hook_mock.return_value.insert_report.assert_called_once_with(profile_id=profile_id, report=report) xcom_mock.assert_called_once_with(None, key="report_id", value=report_id) def test_prepare_template(self): @@ -203,10 +150,7 @@ def test_prepare_template(self): f.write(json.dumps(report)) f.flush() op = GoogleCampaignManagerInsertReportOperator( - profile_id=profile_id, - report=f.name, - api_version=API_VERSION, - task_id="test_task", + profile_id=profile_id, report=f.name, api_version=API_VERSION, task_id="test_task", ) op.prepare_template() @@ -216,13 +160,9 @@ def test_prepare_template(self): class TestGoogleCampaignManagerRunReportOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.GoogleCampaignManagerHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.BaseOperator" + "airflow.providers.google.marketing_platform.operators." "campaign_manager.GoogleCampaignManagerHook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "campaign_manager.BaseOperator") @mock.patch( "airflow.providers.google.marketing_platform.operators." "campaign_manager.GoogleCampaignManagerRunReportOperator.xcom_push" @@ -244,10 +184,7 @@ def test_execute(self, xcom_mock, mock_base_op, hook_mock): ) op.execute(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.run_report.assert_called_once_with( profile_id=profile_id, report_id=report_id, synchronous=synchronous @@ -257,13 +194,9 @@ def test_execute(self, xcom_mock, mock_base_op, hook_mock): class TestGoogleCampaignManagerBatchInsertConversionsOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.GoogleCampaignManagerHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.BaseOperator" + "airflow.providers.google.marketing_platform.operators." "campaign_manager.GoogleCampaignManagerHook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "campaign_manager.BaseOperator") def test_execute(self, mock_base_op, hook_mock): profile_id = "PROFILE_ID" op = GoogleCampaignManagerBatchInsertConversionsOperator( @@ -287,13 +220,9 @@ def test_execute(self, mock_base_op, hook_mock): class TestGoogleCampaignManagerBatchUpdateConversionOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.operators." 
- "campaign_manager.GoogleCampaignManagerHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "campaign_manager.BaseOperator" + "airflow.providers.google.marketing_platform.operators." "campaign_manager.GoogleCampaignManagerHook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "campaign_manager.BaseOperator") def test_execute(self, mock_base_op, hook_mock): profile_id = "PROFILE_ID" op = GoogleCampaignManagerBatchUpdateConversionsOperator( diff --git a/tests/providers/google/marketing_platform/operators/test_display_video.py b/tests/providers/google/marketing_platform/operators/test_display_video.py index c048c27399b3d..096b01972765c 100644 --- a/tests/providers/google/marketing_platform/operators/test_display_video.py +++ b/tests/providers/google/marketing_platform/operators/test_display_video.py @@ -21,10 +21,14 @@ from unittest import TestCase, mock from airflow.providers.google.marketing_platform.operators.display_video import ( - GoogleDisplayVideo360CreateReportOperator, GoogleDisplayVideo360CreateSDFDownloadTaskOperator, - GoogleDisplayVideo360DeleteReportOperator, GoogleDisplayVideo360DownloadLineItemsOperator, - GoogleDisplayVideo360DownloadReportOperator, GoogleDisplayVideo360RunReportOperator, - GoogleDisplayVideo360SDFtoGCSOperator, GoogleDisplayVideo360UploadLineItemsOperator, + GoogleDisplayVideo360CreateReportOperator, + GoogleDisplayVideo360CreateSDFDownloadTaskOperator, + GoogleDisplayVideo360DeleteReportOperator, + GoogleDisplayVideo360DownloadLineItemsOperator, + GoogleDisplayVideo360DownloadReportOperator, + GoogleDisplayVideo360RunReportOperator, + GoogleDisplayVideo360SDFtoGCSOperator, + GoogleDisplayVideo360UploadLineItemsOperator, ) API_VERSION = "api_version" @@ -39,13 +43,9 @@ class TestGoogleDisplayVideo360CreateReportOperator(TestCase): "display_video.GoogleDisplayVideo360CreateReportOperator.xcom_push" ) @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360Hook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.BaseOperator" + "airflow.providers.google.marketing_platform.operators." "display_video.GoogleDisplayVideo360Hook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "display_video.BaseOperator") def test_execute(self, mock_base_op, hook_mock, xcom_mock): body = {"body": "test"} query_id = "TEST" @@ -55,10 +55,7 @@ def test_execute(self, mock_base_op, hook_mock, xcom_mock): ) op.execute(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.create_query.assert_called_once_with(query=body) xcom_mock.assert_called_once_with(None, key="report_id", value=query_id) @@ -79,13 +76,9 @@ def test_prepare_template(self): class TestGoogleDisplayVideo360DeleteReportOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360Hook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.BaseOperator" + "airflow.providers.google.marketing_platform.operators." "display_video.GoogleDisplayVideo360Hook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." 
"display_video.BaseOperator") def test_execute(self, mock_base_op, hook_mock): query_id = "QUERY_ID" op = GoogleDisplayVideo360DeleteReportOperator( @@ -93,50 +86,26 @@ def test_execute(self, mock_base_op, hook_mock): ) op.execute(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.delete_query.assert_called_once_with(query_id=query_id) class TestGoogleDisplayVideo360GetReportOperator(TestCase): - @mock.patch( - "airflow.providers.google.marketing_platform.operators." "display_video.shutil" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.urllib.request" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.tempfile" - ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "display_video.shutil") + @mock.patch("airflow.providers.google.marketing_platform.operators." "display_video.urllib.request") + @mock.patch("airflow.providers.google.marketing_platform.operators." "display_video.tempfile") @mock.patch( "airflow.providers.google.marketing_platform.operators." "display_video.GoogleDisplayVideo360DownloadReportOperator.xcom_push" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "display_video.GCSHook") @mock.patch( - "airflow.providers.google.marketing_platform.operators." "display_video.GCSHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360Hook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.BaseOperator" + "airflow.providers.google.marketing_platform.operators." "display_video.GoogleDisplayVideo360Hook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." 
"display_video.BaseOperator") def test_execute( - self, - mock_base_op, - mock_hook, - mock_gcs_hook, - mock_xcom, - mock_temp, - mock_reuqest, - mock_shutil, + self, mock_base_op, mock_hook, mock_gcs_hook, mock_xcom, mock_temp, mock_reuqest, mock_shutil, ): report_id = "REPORT_ID" bucket_name = "BUCKET" @@ -144,10 +113,7 @@ def test_execute( filename = "test" mock_temp.NamedTemporaryFile.return_value.__enter__.return_value.name = filename mock_hook.return_value.get_query.return_value = { - "metadata": { - "running": False, - "googleCloudStoragePathForLatestReport": "test", - } + "metadata": {"running": False, "googleCloudStoragePathForLatestReport": "test",} } op = GoogleDisplayVideo360DownloadReportOperator( report_id=report_id, @@ -158,17 +124,12 @@ def test_execute( ) op.execute(context=None) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) mock_hook.return_value.get_query.assert_called_once_with(query_id=report_id) mock_gcs_hook.assert_called_once_with( - google_cloud_storage_conn_id=GCP_CONN_ID, - delegate_to=None, - impersonation_chain=None, + google_cloud_storage_conn_id=GCP_CONN_ID, delegate_to=None, impersonation_chain=None, ) mock_gcs_hook.return_value.upload.assert_called_once_with( bucket_name=bucket_name, @@ -177,53 +138,33 @@ def test_execute( mime_type="text/csv", object_name=report_name + ".gz", ) - mock_xcom.assert_called_once_with( - None, key="report_name", value=report_name + ".gz" - ) + mock_xcom.assert_called_once_with(None, key="report_name", value=report_name + ".gz") class TestGoogleDisplayVideo360RunReportOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360Hook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.BaseOperator" + "airflow.providers.google.marketing_platform.operators." "display_video.GoogleDisplayVideo360Hook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "display_video.BaseOperator") def test_execute(self, mock_base_op, hook_mock): report_id = "QUERY_ID" params = {"param": "test"} op = GoogleDisplayVideo360RunReportOperator( - report_id=report_id, - params=params, - api_version=API_VERSION, - task_id="test_task", + report_id=report_id, params=params, api_version=API_VERSION, task_id="test_task", ) op.execute(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, - ) - hook_mock.return_value.run_query.assert_called_once_with( - query_id=report_id, params=params + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) + hook_mock.return_value.run_query.assert_called_once_with(query_id=report_id, params=params) class TestGoogleDisplayVideo360DownloadLineItemsOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360Hook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." "display_video.GCSHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.tempfile" + "airflow.providers.google.marketing_platform.operators." "display_video.GoogleDisplayVideo360Hook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." 
"display_video.GCSHook") + @mock.patch("airflow.providers.google.marketing_platform.operators." "display_video.tempfile") def test_execute(self, mock_temp, gcs_hook_mock, hook_mock): request_body = { "filterType": "filter_type", @@ -260,9 +201,7 @@ def test_execute(self, mock_temp, gcs_hook_mock, hook_mock): ) gcs_hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=DELEGATE_TO, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, ) hook_mock.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, @@ -270,32 +209,22 @@ def test_execute(self, mock_temp, gcs_hook_mock, hook_mock): delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, ) - hook_mock.return_value.download_line_items.assert_called_once_with( - request_body=request_body - ) + hook_mock.return_value.download_line_items.assert_called_once_with(request_body=request_body) class TestGoogleDisplayVideo360UploadLineItemsOperator(TestCase): + @mock.patch("airflow.providers.google.marketing_platform.operators." "display_video.tempfile") @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.tempfile" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360Hook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." "display_video.GCSHook" + "airflow.providers.google.marketing_platform.operators." "display_video.GoogleDisplayVideo360Hook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "display_video.GCSHook") def test_execute(self, gcs_hook_mock, hook_mock, mock_tempfile): filename = "filename" object_name = "object_name" bucket_name = "bucket_name" line_items = "holy_hand_grenade" gcs_hook_mock.return_value.download.return_value = line_items - mock_tempfile.NamedTemporaryFile.return_value.__enter__.return_value.name = ( - filename - ) + mock_tempfile.NamedTemporaryFile.return_value.__enter__.return_value.name = filename op = GoogleDisplayVideo360UploadLineItemsOperator( bucket_name=bucket_name, @@ -313,32 +242,22 @@ def test_execute(self, gcs_hook_mock, hook_mock, mock_tempfile): ) gcs_hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=DELEGATE_TO, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=None, ) gcs_hook_mock.return_value.download.assert_called_once_with( bucket_name=bucket_name, object_name=object_name, filename=filename, ) hook_mock.return_value.upload_line_items.assert_called_once() - hook_mock.return_value.upload_line_items.assert_called_once_with( - line_items=line_items - ) + hook_mock.return_value.upload_line_items.assert_called_once_with(line_items=line_items) class TestGoogleDisplayVideo360SDFtoGCSOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360Hook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." "display_video.GCSHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.tempfile" + "airflow.providers.google.marketing_platform.operators." "display_video.GoogleDisplayVideo360Hook" ) + @mock.patch("airflow.providers.google.marketing_platform.operators." "display_video.GCSHook") + @mock.patch("airflow.providers.google.marketing_platform.operators." 
"display_video.tempfile") def test_execute(self, mock_temp, gcs_mock_hook, mock_hook): operation_name = "operation_name" operation = {"key": "value"} @@ -389,24 +308,18 @@ def test_execute(self, mock_temp, gcs_mock_hook, mock_hook): gcs_mock_hook.assert_called_once() gcs_mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=DELEGATE_TO, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN, ) gcs_mock_hook.return_value.upload.assert_called_once() gcs_mock_hook.return_value.upload.assert_called_once_with( - bucket_name=bucket_name, - object_name=object_name, - filename=filename, - gzip=gzip, + bucket_name=bucket_name, object_name=object_name, filename=filename, gzip=gzip, ) class TestGoogleDisplayVideo360CreateSDFDownloadTaskOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.operators." - "display_video.GoogleDisplayVideo360Hook" + "airflow.providers.google.marketing_platform.operators." "display_video.GoogleDisplayVideo360Hook" ) def test_execute(self, mock_hook): body_request = { @@ -416,10 +329,7 @@ def test_execute(self, mock_hook): } op = GoogleDisplayVideo360CreateSDFDownloadTaskOperator( - body_request=body_request, - api_version=API_VERSION, - gcp_conn_id=GCP_CONN_ID, - task_id="test_task", + body_request=body_request, api_version=API_VERSION, gcp_conn_id=GCP_CONN_ID, task_id="test_task", ) op.execute(context=None) diff --git a/tests/providers/google/marketing_platform/operators/test_display_video_system.py b/tests/providers/google/marketing_platform/operators/test_display_video_system.py index 64ee8d8b15883..a29bbb3f9297b 100644 --- a/tests/providers/google/marketing_platform/operators/test_display_video_system.py +++ b/tests/providers/google/marketing_platform/operators/test_display_video_system.py @@ -25,14 +25,13 @@ SCOPES = [ "https://www.googleapis.com/auth/doubleclickbidmanager", "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/display-video" + "https://www.googleapis.com/auth/display-video", ] @pytest.mark.system("google.marketing_platform") @pytest.mark.credential_file(GMP_KEY) class DisplayVideoSystemTest(GoogleSystemTest): - def setUp(self): super().setUp() self.create_gcs_bucket(BUCKET) diff --git a/tests/providers/google/marketing_platform/operators/test_search_ads.py b/tests/providers/google/marketing_platform/operators/test_search_ads.py index 4af921c2bdee4..ae0435ebaf76b 100644 --- a/tests/providers/google/marketing_platform/operators/test_search_ads.py +++ b/tests/providers/google/marketing_platform/operators/test_search_ads.py @@ -20,7 +20,8 @@ from unittest import TestCase, mock from airflow.providers.google.marketing_platform.operators.search_ads import ( - GoogleSearchAdsDownloadReportOperator, GoogleSearchAdsInsertReportOperator, + GoogleSearchAdsDownloadReportOperator, + GoogleSearchAdsInsertReportOperator, ) API_VERSION = "api_version" @@ -28,14 +29,8 @@ class TestGoogleSearchAdsInsertReportOperator(TestCase): - @mock.patch( - "airflow.providers.google.marketing_platform." - "operators.search_ads.GoogleSearchAdsHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform." - "operators.search_ads.BaseOperator" - ) + @mock.patch("airflow.providers.google.marketing_platform." "operators.search_ads.GoogleSearchAdsHook") + @mock.patch("airflow.providers.google.marketing_platform." "operators.search_ads.BaseOperator") @mock.patch( "airflow.providers.google.marketing_platform." 
"operators.search_ads.GoogleSearchAdsInsertReportOperator.xcom_push" @@ -44,15 +39,10 @@ def test_execute(self, xcom_mock, mock_base_op, hook_mock): report = {"report": "test"} report_id = "TEST" hook_mock.return_value.insert_report.return_value = {"id": report_id} - op = GoogleSearchAdsInsertReportOperator( - report=report, api_version=API_VERSION, task_id="test_task" - ) + op = GoogleSearchAdsInsertReportOperator(report=report, api_version=API_VERSION, task_id="test_task") op.execute(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.insert_report.assert_called_once_with(report=report) xcom_mock.assert_called_once_with(None, key="report_id", value=report_id) @@ -72,29 +62,15 @@ def test_prepare_template(self): class TestGoogleSearchAdsDownloadReportOperator(TestCase): - @mock.patch( - "airflow.providers.google.marketing_platform." - "operators.search_ads.NamedTemporaryFile" - ) - @mock.patch( - "airflow.providers.google.marketing_platform." - "operators.search_ads.GCSHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform." - "operators.search_ads.GoogleSearchAdsHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform." - "operators.search_ads.BaseOperator" - ) + @mock.patch("airflow.providers.google.marketing_platform." "operators.search_ads.NamedTemporaryFile") + @mock.patch("airflow.providers.google.marketing_platform." "operators.search_ads.GCSHook") + @mock.patch("airflow.providers.google.marketing_platform." "operators.search_ads.GoogleSearchAdsHook") + @mock.patch("airflow.providers.google.marketing_platform." "operators.search_ads.BaseOperator") @mock.patch( "airflow.providers.google.marketing_platform." 
"operators.search_ads.GoogleSearchAdsDownloadReportOperator.xcom_push" ) - def test_execute( - self, xcom_mock, mock_base_op, hook_mock, gcs_hook_mock, tempfile_mock - ): + def test_execute(self, xcom_mock, mock_base_op, hook_mock, gcs_hook_mock, tempfile_mock): report_id = "REPORT_ID" file_name = "TEST" temp_file_name = "TEMP" @@ -114,23 +90,11 @@ def test_execute( ) op.execute(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, - ) - hook_mock.return_value.get_file.assert_called_once_with( - report_fragment=0, report_id=report_id - ) - tempfile_mock.return_value.__enter__.return_value.write.assert_called_once_with( - data + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) + hook_mock.return_value.get_file.assert_called_once_with(report_fragment=0, report_id=report_id) + tempfile_mock.return_value.__enter__.return_value.write.assert_called_once_with(data) gcs_hook_mock.return_value.upload.assert_called_once_with( - bucket_name=bucket_name, - gzip=True, - object_name=file_name + ".csv.gz", - filename=temp_file_name, - ) - xcom_mock.assert_called_once_with( - None, key="file_name", value=file_name + ".csv.gz" + bucket_name=bucket_name, gzip=True, object_name=file_name + ".csv.gz", filename=temp_file_name, ) + xcom_mock.assert_called_once_with(None, key="file_name", value=file_name + ".csv.gz") diff --git a/tests/providers/google/marketing_platform/operators/test_search_ads_system.py b/tests/providers/google/marketing_platform/operators/test_search_ads_system.py index ae3b974d2c5fc..9c6dfe99b134f 100644 --- a/tests/providers/google/marketing_platform/operators/test_search_ads_system.py +++ b/tests/providers/google/marketing_platform/operators/test_search_ads_system.py @@ -24,14 +24,13 @@ # Requires the following scope: SCOPES = [ "https://www.googleapis.com/auth/doubleclicksearch", - "https://www.googleapis.com/auth/cloud-platform" + "https://www.googleapis.com/auth/cloud-platform", ] @pytest.mark.system("google.marketing_platform") @pytest.mark.credential_file(GMP_KEY) class SearchAdsSystemTest(GoogleSystemTest): - def setUp(self): super().setUp() self.create_gcs_bucket(GCS_BUCKET) diff --git a/tests/providers/google/marketing_platform/sensors/test_campaign_manager.py b/tests/providers/google/marketing_platform/sensors/test_campaign_manager.py index 72fbaa24c3a43..5e968072ad515 100644 --- a/tests/providers/google/marketing_platform/sensors/test_campaign_manager.py +++ b/tests/providers/google/marketing_platform/sensors/test_campaign_manager.py @@ -27,13 +27,9 @@ class TestGoogleCampaignManagerDeleteReportOperator(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.sensors." - "campaign_manager.GoogleCampaignManagerHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.sensors." - "campaign_manager.BaseSensorOperator" + "airflow.providers.google.marketing_platform.sensors." "campaign_manager.GoogleCampaignManagerHook" ) + @mock.patch("airflow.providers.google.marketing_platform.sensors." 
"campaign_manager.BaseSensorOperator") def test_execute(self, mock_base_op, hook_mock): profile_id = "PROFILE_ID" report_id = "REPORT_ID" @@ -50,10 +46,7 @@ def test_execute(self, mock_base_op, hook_mock): ) result = op.poke(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.get_report.assert_called_once_with( profile_id=profile_id, report_id=report_id, file_id=file_id diff --git a/tests/providers/google/marketing_platform/sensors/test_display_video.py b/tests/providers/google/marketing_platform/sensors/test_display_video.py index a9bee5578976f..84ddf28ff3c3b 100644 --- a/tests/providers/google/marketing_platform/sensors/test_display_video.py +++ b/tests/providers/google/marketing_platform/sensors/test_display_video.py @@ -19,7 +19,8 @@ from unittest import TestCase, mock from airflow.providers.google.marketing_platform.sensors.display_video import ( - GoogleDisplayVideo360GetSDFDownloadOperationSensor, GoogleDisplayVideo360ReportSensor, + GoogleDisplayVideo360GetSDFDownloadOperationSensor, + GoogleDisplayVideo360ReportSensor, ) API_VERSION = "api_version" @@ -28,13 +29,9 @@ class TestGoogleDisplayVideo360ReportSensor(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.sensors." - "display_video.GoogleDisplayVideo360Hook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.sensors." - "display_video.BaseSensorOperator" + "airflow.providers.google.marketing_platform.sensors." "display_video.GoogleDisplayVideo360Hook" ) + @mock.patch("airflow.providers.google.marketing_platform.sensors." "display_video.BaseSensorOperator") def test_poke(self, mock_base_op, hook_mock): report_id = "REPORT_ID" op = GoogleDisplayVideo360ReportSensor( @@ -42,23 +39,16 @@ def test_poke(self, mock_base_op, hook_mock): ) op.poke(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.get_query.assert_called_once_with(query_id=report_id) class TestGoogleDisplayVideo360Sensor(TestCase): @mock.patch( - "airflow.providers.google.marketing_platform.sensors." - "display_video.GoogleDisplayVideo360Hook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.sensors." - "display_video.BaseSensorOperator" + "airflow.providers.google.marketing_platform.sensors." "display_video.GoogleDisplayVideo360Hook" ) + @mock.patch("airflow.providers.google.marketing_platform.sensors." 
"display_video.BaseSensorOperator") def test_poke(self, mock_base_op, hook_mock): operation_name = "operation_name" op = GoogleDisplayVideo360GetSDFDownloadOperationSensor( @@ -66,10 +56,7 @@ def test_poke(self, mock_base_op, hook_mock): ) op.poke(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.get_sdf_download_operation.assert_called_once_with( operation_name=operation_name diff --git a/tests/providers/google/marketing_platform/sensors/test_search_ads.py b/tests/providers/google/marketing_platform/sensors/test_search_ads.py index 38274348a5acf..bd6b056d43b4d 100644 --- a/tests/providers/google/marketing_platform/sensors/test_search_ads.py +++ b/tests/providers/google/marketing_platform/sensors/test_search_ads.py @@ -24,24 +24,13 @@ class TestSearchAdsReportSensor(TestCase): - @mock.patch( - "airflow.providers.google.marketing_platform.sensors." - "search_ads.GoogleSearchAdsHook" - ) - @mock.patch( - "airflow.providers.google.marketing_platform.sensors." - "search_ads.BaseSensorOperator" - ) + @mock.patch("airflow.providers.google.marketing_platform.sensors." "search_ads.GoogleSearchAdsHook") + @mock.patch("airflow.providers.google.marketing_platform.sensors." "search_ads.BaseSensorOperator") def test_poke(self, mock_base_op, hook_mock): report_id = "REPORT_ID" - op = GoogleSearchAdsReportSensor( - report_id=report_id, api_version=API_VERSION, task_id="test_task" - ) + op = GoogleSearchAdsReportSensor(report_id=report_id, api_version=API_VERSION, task_id="test_task") op.poke(context=None) hook_mock.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - api_version=API_VERSION, - impersonation_chain=None, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, api_version=API_VERSION, impersonation_chain=None, ) hook_mock.return_value.get.assert_called_once_with(report_id=report_id) diff --git a/tests/providers/google/suite/hooks/test_drive.py b/tests/providers/google/suite/hooks/test_drive.py index 6c85b837d581a..867cec26d01cb 100644 --- a/tests/providers/google/suite/hooks/test_drive.py +++ b/tests/providers/google/suite/hooks/test_drive.py @@ -37,7 +37,7 @@ def tearDown(self) -> None: @mock.patch( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook._authorize", - return_value="AUTHORIZE" + return_value="AUTHORIZE", ) @mock.patch("airflow.providers.google.suite.hooks.drive.build") def test_get_conn(self, mock_discovery_build, mock_authorize): @@ -190,7 +190,7 @@ def test_upload_file_to_root_directory( @mock.patch("airflow.providers.google.suite.hooks.drive.GoogleDriveHook.get_conn") @mock.patch( "airflow.providers.google.suite.hooks.drive.GoogleDriveHook._ensure_folders_exists", - return_value="PARENT_ID" + return_value="PARENT_ID", ) def test_upload_file_to_subdirectory( self, mock_ensure_folders_exists, mock_get_conn, mock_media_file_upload diff --git a/tests/providers/google/suite/hooks/test_sheets.py b/tests/providers/google/suite/hooks/test_sheets.py index 00615c22c95d0..3524af0e55882 100644 --- a/tests/providers/google/suite/hooks/test_sheets.py +++ b/tests/providers/google/suite/hooks/test_sheets.py @@ -46,8 +46,10 @@ class TestGSheetsHook(unittest.TestCase): def setUp(self): - with mock.patch('airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', - new=mock_base_gcp_hook_default_project_id): + with mock.patch( + 
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__', + new=mock_base_gcp_hook_default_project_id, + ): self.hook = GSheetsHook(gcp_conn_id=GCP_CONN_ID) @mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook._authorize") @@ -69,7 +71,8 @@ def test_get_values(self, get_conn): range_=RANGE_, major_dimension=MAJOR_DIMENSION, value_render_option=VALUE_RENDER_OPTION, - date_time_render_option=DATE_TIME_RENDER_OPTION) + date_time_render_option=DATE_TIME_RENDER_OPTION, + ) self.assertIs(result, VALUES) execute_method.assert_called_once_with(num_retries=NUM_RETRIES) get_method.assert_called_once_with( @@ -77,7 +80,7 @@ def test_get_values(self, get_conn): range=RANGE_, majorDimension=MAJOR_DIMENSION, valueRenderOption=VALUE_RENDER_OPTION, - dateTimeRenderOption=DATE_TIME_RENDER_OPTION + dateTimeRenderOption=DATE_TIME_RENDER_OPTION, ) @mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook.get_conn") @@ -90,7 +93,8 @@ def test_batch_get_values(self, get_conn): ranges=RANGES, major_dimension=MAJOR_DIMENSION, value_render_option=VALUE_RENDER_OPTION, - date_time_render_option=DATE_TIME_RENDER_OPTION) + date_time_render_option=DATE_TIME_RENDER_OPTION, + ) self.assertIs(result, API_RESPONSE) execute_method.assert_called_once_with(num_retries=NUM_RETRIES) batch_get_method.assert_called_once_with( @@ -98,7 +102,7 @@ def test_batch_get_values(self, get_conn): ranges=RANGES, majorDimension=MAJOR_DIMENSION, valueRenderOption=VALUE_RENDER_OPTION, - dateTimeRenderOption=DATE_TIME_RENDER_OPTION + dateTimeRenderOption=DATE_TIME_RENDER_OPTION, ) @mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook.get_conn") @@ -114,12 +118,9 @@ def test_update_values(self, get_conn): value_input_option=VALUE_INPUT_OPTION, include_values_in_response=INCLUDE_VALUES_IN_RESPONSE, value_render_option=VALUE_RENDER_OPTION, - date_time_render_option=DATE_TIME_RENDER_OPTION) - body = { - "range": RANGE_, - "majorDimension": MAJOR_DIMENSION, - "values": VALUES - } + date_time_render_option=DATE_TIME_RENDER_OPTION, + ) + body = {"range": RANGE_, "majorDimension": MAJOR_DIMENSION, "values": VALUES} self.assertIs(result, API_RESPONSE) execute_method.assert_called_once_with(num_retries=NUM_RETRIES) update_method.assert_called_once_with( @@ -129,7 +130,7 @@ def test_update_values(self, get_conn): includeValuesInResponse=INCLUDE_VALUES_IN_RESPONSE, responseValueRenderOption=VALUE_RENDER_OPTION, responseDateTimeRenderOption=DATE_TIME_RENDER_OPTION, - body=body + body=body, ) @mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook.get_conn") @@ -145,28 +146,22 @@ def test_batch_update_values(self, get_conn): value_input_option=VALUE_INPUT_OPTION, include_values_in_response=INCLUDE_VALUES_IN_RESPONSE, value_render_option=VALUE_RENDER_OPTION, - date_time_render_option=DATE_TIME_RENDER_OPTION) + date_time_render_option=DATE_TIME_RENDER_OPTION, + ) data = [] for idx, range_ in enumerate(RANGES): - value_range = { - "range": range_, - "majorDimension": MAJOR_DIMENSION, - "values": VALUES_BATCH[idx] - } + value_range = {"range": range_, "majorDimension": MAJOR_DIMENSION, "values": VALUES_BATCH[idx]} data.append(value_range) body = { "valueInputOption": VALUE_INPUT_OPTION, "data": data, "includeValuesInResponse": INCLUDE_VALUES_IN_RESPONSE, "responseValueRenderOption": VALUE_RENDER_OPTION, - "responseDateTimeRenderOption": DATE_TIME_RENDER_OPTION + "responseDateTimeRenderOption": DATE_TIME_RENDER_OPTION, } self.assertIs(result, API_RESPONSE) 
execute_method.assert_called_once_with(num_retries=NUM_RETRIES) - batch_update_method.assert_called_once_with( - spreadsheetId=SPREADHSEET_ID, - body=body - ) + batch_update_method.assert_called_once_with(spreadsheetId=SPREADHSEET_ID, body=body) @mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook.get_conn") def test_batch_update_values_with_bad_data(self, get_conn): @@ -182,7 +177,8 @@ def test_batch_update_values_with_bad_data(self, get_conn): value_input_option=VALUE_INPUT_OPTION, include_values_in_response=INCLUDE_VALUES_IN_RESPONSE, value_render_option=VALUE_RENDER_OPTION, - date_time_render_option=DATE_TIME_RENDER_OPTION) + date_time_render_option=DATE_TIME_RENDER_OPTION, + ) batch_update_method.assert_not_called() execute_method.assert_not_called() err = cm.exception @@ -202,12 +198,9 @@ def test_append_values(self, get_conn): insert_data_option=INSERT_DATA_OPTION, include_values_in_response=INCLUDE_VALUES_IN_RESPONSE, value_render_option=VALUE_RENDER_OPTION, - date_time_render_option=DATE_TIME_RENDER_OPTION) - body = { - "range": RANGE_, - "majorDimension": MAJOR_DIMENSION, - "values": VALUES - } + date_time_render_option=DATE_TIME_RENDER_OPTION, + ) + body = {"range": RANGE_, "majorDimension": MAJOR_DIMENSION, "values": VALUES} self.assertIs(result, API_RESPONSE) execute_method.assert_called_once_with(num_retries=NUM_RETRIES) append_method.assert_called_once_with( @@ -218,7 +211,7 @@ def test_append_values(self, get_conn): includeValuesInResponse=INCLUDE_VALUES_IN_RESPONSE, responseValueRenderOption=VALUE_RENDER_OPTION, responseDateTimeRenderOption=DATE_TIME_RENDER_OPTION, - body=body + body=body, ) @mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook.get_conn") @@ -230,10 +223,7 @@ def test_clear_values(self, get_conn): self.assertIs(result, API_RESPONSE) execute_method.assert_called_once_with(num_retries=NUM_RETRIES) - clear_method.assert_called_once_with( - spreadsheetId=SPREADHSEET_ID, - range=RANGE_ - ) + clear_method.assert_called_once_with(spreadsheetId=SPREADHSEET_ID, range=RANGE_) @mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook.get_conn") def test_batch_clear_values(self, get_conn): @@ -244,10 +234,7 @@ def test_batch_clear_values(self, get_conn): body = {"ranges": RANGES} self.assertIs(result, API_RESPONSE) execute_method.assert_called_once_with(num_retries=NUM_RETRIES) - batch_clear_method.assert_called_once_with( - spreadsheetId=SPREADHSEET_ID, - body=body - ) + batch_clear_method.assert_called_once_with(spreadsheetId=SPREADHSEET_ID, body=body) @mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook.get_conn") def test_get_spreadsheet(self, mock_get_conn): diff --git a/tests/providers/google/suite/operators/test_sheets.py b/tests/providers/google/suite/operators/test_sheets.py index eb558aa16320e..ed55308a1f420 100644 --- a/tests/providers/google/suite/operators/test_sheets.py +++ b/tests/providers/google/suite/operators/test_sheets.py @@ -41,9 +41,7 @@ def test_execute(self, mock_xcom, mock_hook): ) op.execute(context) - mock_hook.return_value.create_spreadsheet.assert_called_once_with( - spreadsheet=spreadsheet - ) + mock_hook.return_value.create_spreadsheet.assert_called_once_with(spreadsheet=spreadsheet) calls = [ mock.call(context, "spreadsheet_id", SPREADSHEET_ID), diff --git a/tests/providers/google/suite/transfers/test_gcs_to_gdrive.py b/tests/providers/google/suite/transfers/test_gcs_to_gdrive.py index 9344f30ffa2e8..65fe2689c08fc 100644 --- 
a/tests/providers/google/suite/transfers/test_gcs_to_gdrive.py +++ b/tests/providers/google/suite/transfers/test_gcs_to_gdrive.py @@ -57,11 +57,7 @@ def test_should_copy_single_file(self, mock_named_temporary_file, mock_gdrive, m mock_gdrive.assert_has_calls( [ - mock.call( - delegate_to=None, - gcp_conn_id="google_cloud_default", - impersonation_chain=None, - ), + mock.call(delegate_to=None, gcp_conn_id="google_cloud_default", impersonation_chain=None,), mock.call().upload_file( local_location="TMP1", remote_location="copied_sales/2017/january-backup.avro" ), diff --git a/tests/providers/google/suite/transfers/test_gcs_to_sheets.py b/tests/providers/google/suite/transfers/test_gcs_to_sheets.py index ff79152471f59..9d8cf8cbbbacb 100644 --- a/tests/providers/google/suite/transfers/test_gcs_to_sheets.py +++ b/tests/providers/google/suite/transfers/test_gcs_to_sheets.py @@ -50,14 +50,10 @@ def test_execute(self, mock_reader, mock_tempfile, mock_sheet_hook, mock_gcs_hoo op.execute(None) mock_sheet_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, impersonation_chain=IMPERSONATION_CHAIN, ) mock_gcs_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - delegate_to=None, - impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, delegate_to=None, impersonation_chain=IMPERSONATION_CHAIN, ) mock_gcs_hook.return_value.download.assert_called_once_with( @@ -67,7 +63,5 @@ def test_execute(self, mock_reader, mock_tempfile, mock_sheet_hook, mock_gcs_hoo mock_reader.assert_called_once_with(file_handle) mock_sheet_hook.return_value.update_values.assert_called_once_with( - spreadsheet_id=SPREADSHEET_ID, - range_="Sheet1", - values=VALUES, + spreadsheet_id=SPREADSHEET_ID, range_="Sheet1", values=VALUES, ) diff --git a/tests/providers/google/suite/transfers/test_gcs_to_sheets_system.py b/tests/providers/google/suite/transfers/test_gcs_to_sheets_system.py index 025f3998df309..0ab198275b7cf 100644 --- a/tests/providers/google/suite/transfers/test_gcs_to_sheets_system.py +++ b/tests/providers/google/suite/transfers/test_gcs_to_sheets_system.py @@ -31,7 +31,6 @@ @pytest.mark.backend("mysql", "postgres") @pytest.mark.credential_file(GCP_GCS_KEY) class GoogleSheetsToGCSExampleDagsSystemTest(GoogleSystemTest): - @provide_gcp_context(GCP_GCS_KEY) def setUp(self): super().setUp() diff --git a/tests/providers/grpc/hooks/test_grpc.py b/tests/providers/grpc/hooks/test_grpc.py index 286fd47e33498..3017a6df77718 100644 --- a/tests/providers/grpc/hooks/test_grpc.py +++ b/tests/providers/grpc/hooks/test_grpc.py @@ -26,21 +26,16 @@ def get_airflow_connection(auth_type="NO_AUTH", credential_pem_file=None, scopes=None): - extra = \ - '{{"extra__grpc__auth_type": "{auth_type}",' \ - '"extra__grpc__credential_pem_file": "{credential_pem_file}",' \ - '"extra__grpc__scopes": "{scopes}"}}' \ - .format(auth_type=auth_type, - credential_pem_file=credential_pem_file, - scopes=scopes) - - return Connection( - conn_id='grpc_default', - conn_type='grpc', - host='test:8080', - extra=extra + extra = ( + '{{"extra__grpc__auth_type": "{auth_type}",' + '"extra__grpc__credential_pem_file": "{credential_pem_file}",' + '"extra__grpc__scopes": "{scopes}"}}'.format( + auth_type=auth_type, credential_pem_file=credential_pem_file, scopes=scopes + ) ) + return Connection(conn_id='grpc_default', conn_type='grpc', host='test:8080', extra=extra) + def get_airflow_connection_with_port(): return Connection( @@ -48,7 +43,7 @@ 
def get_airflow_connection_with_port(): conn_type='grpc', host='test.com', port=1234, - extra='{"extra__grpc__auth_type": "NO_AUTH"}' + extra='{"extra__grpc__auth_type": "NO_AUTH"}', ) @@ -105,15 +100,10 @@ def test_connection_with_port(self, mock_get_connection, mock_insecure_channel): @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') @mock.patch('grpc.ssl_channel_credentials') @mock.patch('grpc.secure_channel') - def test_connection_with_ssl(self, - mock_secure_channel, - mock_channel_credentials, - mock_get_connection, - mock_open): - conn = get_airflow_connection( - auth_type="SSL", - credential_pem_file="pem" - ) + def test_connection_with_ssl( + self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open + ): + conn = get_airflow_connection(auth_type="SSL", credential_pem_file="pem") mock_get_connection.return_value = conn mock_open.return_value = StringIO('credential') hook = GrpcHook("grpc_default") @@ -127,25 +117,17 @@ def test_connection_with_ssl(self, mock_open.assert_called_once_with("pem") mock_channel_credentials.assert_called_once_with('credential') - mock_secure_channel.assert_called_once_with( - expected_url, - mock_credential_object - ) + mock_secure_channel.assert_called_once_with(expected_url, mock_credential_object) self.assertEqual(channel, mocked_channel) @mock.patch('airflow.providers.grpc.hooks.grpc.open') @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') @mock.patch('grpc.ssl_channel_credentials') @mock.patch('grpc.secure_channel') - def test_connection_with_tls(self, - mock_secure_channel, - mock_channel_credentials, - mock_get_connection, - mock_open): - conn = get_airflow_connection( - auth_type="TLS", - credential_pem_file="pem" - ) + def test_connection_with_tls( + self, mock_secure_channel, mock_channel_credentials, mock_get_connection, mock_open + ): + conn = get_airflow_connection(auth_type="TLS", credential_pem_file="pem") mock_get_connection.return_value = conn mock_open.return_value = StringIO('credential') hook = GrpcHook("grpc_default") @@ -159,24 +141,17 @@ def test_connection_with_tls(self, mock_open.assert_called_once_with("pem") mock_channel_credentials.assert_called_once_with('credential') - mock_secure_channel.assert_called_once_with( - expected_url, - mock_credential_object - ) + mock_secure_channel.assert_called_once_with(expected_url, mock_credential_object) self.assertEqual(channel, mocked_channel) @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') @mock.patch('google.auth.jwt.OnDemandCredentials.from_signing_credentials') @mock.patch('google.auth.default') @mock.patch('google.auth.transport.grpc.secure_authorized_channel') - def test_connection_with_jwt(self, - mock_secure_channel, - mock_google_default_auth, - mock_google_cred, - mock_get_connection): - conn = get_airflow_connection( - auth_type="JWT_GOOGLE" - ) + def test_connection_with_jwt( + self, mock_secure_channel, mock_google_default_auth, mock_google_cred, mock_get_connection + ): + conn = get_airflow_connection(auth_type="JWT_GOOGLE") mock_get_connection.return_value = conn hook = GrpcHook("grpc_default") mocked_channel = self.channel_mock.return_value @@ -189,26 +164,17 @@ def test_connection_with_jwt(self, expected_url = "test:8080" mock_google_cred.assert_called_once_with(mock_credential_object) - mock_secure_channel.assert_called_once_with( - mock_credential_object, - None, - expected_url - ) + mock_secure_channel.assert_called_once_with(mock_credential_object, None, expected_url) self.assertEqual(channel, 
mocked_channel) @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') @mock.patch('google.auth.transport.requests.Request') @mock.patch('google.auth.default') @mock.patch('google.auth.transport.grpc.secure_authorized_channel') - def test_connection_with_google_oauth(self, - mock_secure_channel, - mock_google_default_auth, - mock_google_auth_request, - mock_get_connection): - conn = get_airflow_connection( - auth_type="OATH_GOOGLE", - scopes="grpc,gcs" - ) + def test_connection_with_google_oauth( + self, mock_secure_channel, mock_google_default_auth, mock_google_auth_request, mock_get_connection + ): + conn = get_airflow_connection(auth_type="OATH_GOOGLE", scopes="grpc,gcs") mock_get_connection.return_value = conn hook = GrpcHook("grpc_default") mocked_channel = self.channel_mock.return_value @@ -221,11 +187,7 @@ def test_connection_with_google_oauth(self, expected_url = "test:8080" mock_google_default_auth.assert_called_once_with(scopes=["grpc", "gcs"]) - mock_secure_channel.assert_called_once_with( - mock_credential_object, - "request", - expected_url - ) + mock_secure_channel.assert_called_once_with(mock_credential_object, "request", expected_url) self.assertEqual(channel, mocked_channel) @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') @@ -260,10 +222,9 @@ def test_connection_type_not_supported(self, mock_get_connection): @mock.patch('grpc.intercept_channel') @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') @mock.patch('grpc.insecure_channel') - def test_connection_with_interceptors(self, - mock_insecure_channel, - mock_get_connection, - mock_intercept_channel): + def test_connection_with_interceptors( + self, mock_insecure_channel, mock_get_connection, mock_intercept_channel + ): conn = get_airflow_connection() mock_get_connection.return_value = conn mocked_channel = self.channel_mock.return_value diff --git a/tests/providers/grpc/operators/test_grpc.py b/tests/providers/grpc/operators/test_grpc.py index 292e2ca4c8335..70b17fe9bc885 100644 --- a/tests/providers/grpc/operators/test_grpc.py +++ b/tests/providers/grpc/operators/test_grpc.py @@ -37,10 +37,7 @@ def custom_conn_func(self, connection): @mock.patch('airflow.providers.grpc.operators.grpc.GrpcHook') def test_with_interceptors(self, mock_hook): operator = GrpcOperator( - stub_class=StubClass, - call_func="stream_call", - interceptors=[], - task_id="test_grpc", + stub_class=StubClass, call_func="stream_call", interceptors=[], task_id="test_grpc", ) operator.execute({}) @@ -57,7 +54,8 @@ def test_with_custom_connection_func(self, mock_hook): operator.execute({}) mock_hook.assert_called_once_with( - "grpc_default", interceptors=None, custom_connection_func=self.custom_conn_func) + "grpc_default", interceptors=None, custom_connection_func=self.custom_conn_func + ) @mock.patch('airflow.providers.grpc.operators.grpc.GrpcHook') def test_execute_with_log(self, mock_hook): @@ -65,10 +63,7 @@ def test_execute_with_log(self, mock_hook): mock_hook.return_value = mocked_hook mocked_hook.configure_mock(**{'run.return_value': ["value1", "value2"]}) operator = GrpcOperator( - stub_class=StubClass, - call_func="stream_call", - log_response=True, - task_id="test_grpc", + stub_class=StubClass, call_func="stream_call", log_response=True, task_id="test_grpc", ) with mock.patch.object(operator.log, 'info') as mock_info: @@ -87,10 +82,7 @@ def test_execute_with_callback(self, mock_hook): mock_hook.return_value = mocked_hook mocked_hook.configure_mock(**{'run.return_value': ["value1", "value2"]}) operator = 
GrpcOperator( - stub_class=StubClass, - call_func="stream_call", - task_id="test_grpc", - response_callback=callback + stub_class=StubClass, call_func="stream_call", task_id="test_grpc", response_callback=callback ) with mock.patch.object(operator.log, 'info') as mock_info: diff --git a/tests/providers/hashicorp/_internal_client/test_vault_client.py b/tests/providers/hashicorp/_internal_client/test_vault_client.py index 585f19c02474f..3316e8ba7bad1 100644 --- a/tests/providers/hashicorp/_internal_client/test_vault_client.py +++ b/tests/providers/hashicorp/_internal_client/test_vault_client.py @@ -26,7 +26,6 @@ class TestVaultClient(TestCase): - @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_version_wrong(self, mock_hvac): mock_client = mock.MagicMock() @@ -53,8 +52,9 @@ def test_version_one_init(self, mock_hvac): def test_approle(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="approle", role_id="role", url="http://localhost:8180", - secret_id="pass") + vault_client = _VaultClient( + auth_type="approle", role_id="role", url="http://localhost:8180", secret_id="pass" + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') client.auth_approle.assert_called_with(role_id="role", secret_id="pass") @@ -65,8 +65,13 @@ def test_approle(self, mock_hvac): def test_approle_different_auth_mount_point(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="approle", role_id="role", url="http://localhost:8180", - secret_id="pass", auth_mount_point="other") + vault_client = _VaultClient( + auth_type="approle", + role_id="role", + url="http://localhost:8180", + secret_id="pass", + auth_mount_point="other", + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') client.auth_approle.assert_called_with(role_id="role", secret_id="pass", mount_point="other") @@ -78,23 +83,19 @@ def test_approle_missing_role(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client with self.assertRaisesRegex(VaultError, "requires 'role_id'"): - _VaultClient( - auth_type="approle", - url="http://localhost:8180", - secret_id="pass") + _VaultClient(auth_type="approle", url="http://localhost:8180", secret_id="pass") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_aws_iam(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="aws_iam", role_id="role", url="http://localhost:8180", - key_id="user", secret_id='pass') + vault_client = _VaultClient( + auth_type="aws_iam", role_id="role", url="http://localhost:8180", key_id="user", secret_id='pass' + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') client.auth_aws_iam.assert_called_with( - access_key='user', - secret_key='pass', - role="role", + access_key='user', secret_key='pass', role="role", ) client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -103,15 +104,18 @@ def test_aws_iam(self, mock_hvac): def test_aws_iam_different_auth_mount_point(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="aws_iam", role_id="role", url="http://localhost:8180", - key_id="user", secret_id='pass', 
auth_mount_point="other") + vault_client = _VaultClient( + auth_type="aws_iam", + role_id="role", + url="http://localhost:8180", + key_id="user", + secret_id='pass', + auth_mount_point="other", + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') client.auth_aws_iam.assert_called_with( - access_key='user', - secret_key='pass', - role="role", - mount_point='other' + access_key='user', secret_key='pass', role="role", mount_point='other' ) client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -120,15 +124,18 @@ def test_aws_iam_different_auth_mount_point(self, mock_hvac): def test_azure(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="azure", azure_tenant_id="tenant_id", azure_resource="resource", - url="http://localhost:8180", key_id="user", secret_id='pass') + vault_client = _VaultClient( + auth_type="azure", + azure_tenant_id="tenant_id", + azure_resource="resource", + url="http://localhost:8180", + key_id="user", + secret_id='pass', + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') client.auth.azure.configure.assert_called_with( - tenant_id="tenant_id", - resource="resource", - client_id="user", - client_secret="pass", + tenant_id="tenant_id", resource="resource", client_id="user", client_secret="pass", ) client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -137,9 +144,15 @@ def test_azure(self, mock_hvac): def test_azure_different_auth_mount_point(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="azure", azure_tenant_id="tenant_id", azure_resource="resource", - url="http://localhost:8180", key_id="user", secret_id='pass', - auth_mount_point="other") + vault_client = _VaultClient( + auth_type="azure", + azure_tenant_id="tenant_id", + azure_resource="resource", + url="http://localhost:8180", + key_id="user", + secret_id='pass', + auth_mount_point="other", + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') client.auth.azure.configure.assert_called_with( @@ -147,7 +160,7 @@ def test_azure_different_auth_mount_point(self, mock_hvac): resource="resource", client_id="user", client_secret="pass", - mount_point="other" + mount_point="other", ) client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -162,7 +175,8 @@ def test_azure_missing_resource(self, mock_hvac): azure_tenant_id="tenant_id", url="http://localhost:8180", key_id="user", - secret_id='pass') + secret_id='pass', + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_azure_missing_tenant_id(self, mock_hvac): @@ -174,7 +188,8 @@ def test_azure_missing_tenant_id(self, mock_hvac): azure_resource='resource', url="http://localhost:8180", key_id="user", - secret_id='pass') + secret_id='pass', + ) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @@ -184,20 +199,17 @@ def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes): mock_hvac.Client.return_value = mock_client mock_get_scopes.return_value = ['scope1', 'scope2'] mock_get_credentials.return_value = ("credentials", "project_id") - vault_client = 
_VaultClient(auth_type="gcp", gcp_key_path="path.json", gcp_scopes="scope1,scope2", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="gcp", gcp_key_path="path.json", gcp_scopes="scope1,scope2", url="http://localhost:8180" + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( - key_path="path.json", - keyfile_dict=None, - scopes=['scope1', 'scope2'] + key_path="path.json", keyfile_dict=None, scopes=['scope1', 'scope2'] ) mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + client.auth.gcp.configure.assert_called_with(credentials="credentials",) client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -209,21 +221,21 @@ def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, m mock_hvac.Client.return_value = mock_client mock_get_scopes.return_value = ['scope1', 'scope2'] mock_get_credentials.return_value = ("credentials", "project_id") - vault_client = _VaultClient(auth_type="gcp", gcp_key_path="path.json", gcp_scopes="scope1,scope2", - url="http://localhost:8180", auth_mount_point="other") + vault_client = _VaultClient( + auth_type="gcp", + gcp_key_path="path.json", + gcp_scopes="scope1,scope2", + url="http://localhost:8180", + auth_mount_point="other", + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( - key_path="path.json", - keyfile_dict=None, - scopes=['scope1', 'scope2'] + key_path="path.json", keyfile_dict=None, scopes=['scope1', 'scope2'] ) mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth.gcp.configure.assert_called_with( - credentials="credentials", - mount_point="other" - ) + client.auth.gcp.configure.assert_called_with(credentials="credentials", mount_point="other") client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -235,21 +247,20 @@ def test_gcp_dict(self, mock_hvac, mock_get_credentials, mock_get_scopes): mock_hvac.Client.return_value = mock_client mock_get_scopes.return_value = ['scope1', 'scope2'] mock_get_credentials.return_value = ("credentials", "project_id") - vault_client = _VaultClient(auth_type="gcp", gcp_keyfile_dict={"key": "value"}, - gcp_scopes="scope1,scope2", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="gcp", + gcp_keyfile_dict={"key": "value"}, + gcp_scopes="scope1,scope2", + url="http://localhost:8180", + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( - key_path=None, - keyfile_dict={"key": "value"}, - scopes=['scope1', 'scope2'] + key_path=None, keyfile_dict={"key": "value"}, scopes=['scope1', 'scope2'] ) mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + client.auth.gcp.configure.assert_called_with(credentials="credentials",) client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -257,13 +268,12 @@ def test_gcp_dict(self, mock_hvac, mock_get_credentials, mock_get_scopes): def test_github(self, mock_hvac): mock_client = 
mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="github", - token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="github", token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", url="http://localhost:8180" + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth.github.login.assert_called_with( - token="s.7AU0I51yv1Q1lxOIg1F3ZRAS") + client.auth.github.login.assert_called_with(token="s.7AU0I51yv1Q1lxOIg1F3ZRAS") client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -271,13 +281,15 @@ def test_github(self, mock_hvac): def test_github_different_auth_mount_point(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="github", - token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", - url="http://localhost:8180", auth_mount_point="other") + vault_client = _VaultClient( + auth_type="github", + token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", + url="http://localhost:8180", + auth_mount_point="other", + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth.github.login.assert_called_with( - token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", mount_point="other") + client.auth.github.login.assert_called_with(token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", mount_point="other") client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -292,15 +304,14 @@ def test_github_missing_token(self, mock_hvac): def test_kubernetes_default_path(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="kubernetes", - kubernetes_role="kube_role", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="kubernetes", kubernetes_role="kube_role", url="http://localhost:8180" + ) with patch("builtins.open", mock_open(read_data="data")) as mock_file: client = vault_client.client mock_file.assert_called_with("/var/run/secrets/kubernetes.io/serviceaccount/token") mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth_kubernetes.assert_called_with( - role="kube_role", jwt="data") + client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data") client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -308,16 +319,17 @@ def test_kubernetes_default_path(self, mock_hvac): def test_kubernetes(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="kubernetes", - kubernetes_role="kube_role", - kubernetes_jwt_path="path", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="kubernetes", + kubernetes_role="kube_role", + kubernetes_jwt_path="path", + url="http://localhost:8180", + ) with patch("builtins.open", mock_open(read_data="data")) as mock_file: client = vault_client.client mock_file.assert_called_with("path") mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth_kubernetes.assert_called_with( - role="kube_role", jwt="data") + client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data") client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -325,17 +337,18 @@ def test_kubernetes(self, mock_hvac): def test_kubernetes_different_auth_mount_point(self, mock_hvac): mock_client = 
mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="kubernetes", - kubernetes_role="kube_role", - kubernetes_jwt_path="path", - auth_mount_point="other", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="kubernetes", + kubernetes_role="kube_role", + kubernetes_jwt_path="path", + auth_mount_point="other", + url="http://localhost:8180", + ) with patch("builtins.open", mock_open(read_data="data")) as mock_file: client = vault_client.client mock_file.assert_called_with("path") mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth_kubernetes.assert_called_with( - role="kube_role", jwt="data", mount_point="other") + client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data", mount_point="other") client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -344,32 +357,30 @@ def test_kubernetes_missing_role(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client with self.assertRaisesRegex(VaultError, "requires 'kubernetes_role'"): - _VaultClient(auth_type="kubernetes", - kubernetes_jwt_path="path", - url="http://localhost:8180") + _VaultClient(auth_type="kubernetes", kubernetes_jwt_path="path", url="http://localhost:8180") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_kubernetes_kubernetes_jwt_path_none(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client with self.assertRaisesRegex(VaultError, "requires 'kubernetes_jwt_path'"): - _VaultClient(auth_type="kubernetes", - kubernetes_role='kube_role', - kubernetes_jwt_path=None, - url="http://localhost:8180") + _VaultClient( + auth_type="kubernetes", + kubernetes_role='kube_role', + kubernetes_jwt_path=None, + url="http://localhost:8180", + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_ldap(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="ldap", - username="user", - password="pass", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="ldap", username="user", password="pass", url="http://localhost:8180" + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth.ldap.login.assert_called_with( - username="user", password="pass") + client.auth.ldap.login.assert_called_with(username="user", password="pass") client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -377,15 +388,16 @@ def test_ldap(self, mock_hvac): def test_ldap_different_auth_mount_point(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="ldap", - username="user", - password="pass", - auth_mount_point="other", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="ldap", + username="user", + password="pass", + auth_mount_point="other", + url="http://localhost:8180", + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth.ldap.login.assert_called_with( - username="user", password="pass", mount_point="other") + client.auth.ldap.login.assert_called_with(username="user", password="pass", mount_point="other") client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -407,16 +419,12 @@ def 
test_radius_missing_secret(self, mock_hvac): def test_radius(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="radius", - radius_host="radhost", - radius_secret="pass", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="radius", radius_host="radhost", radius_secret="pass", url="http://localhost:8180" + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth.radius.configure.assert_called_with( - host="radhost", - secret="pass", - port=None) + client.auth.radius.configure.assert_called_with(host="radhost", secret="pass", port=None) client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -424,18 +432,18 @@ def test_radius(self, mock_hvac): def test_radius_different_auth_mount_point(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="radius", - radius_host="radhost", - radius_secret="pass", - auth_mount_point="other", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="radius", + radius_host="radhost", + radius_secret="pass", + auth_mount_point="other", + url="http://localhost:8180", + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') client.auth.radius.configure.assert_called_with( - host="radhost", - secret="pass", - port=None, - mount_point="other") + host="radhost", secret="pass", port=None, mount_point="other" + ) client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -443,17 +451,16 @@ def test_radius_different_auth_mount_point(self, mock_hvac): def test_radius_port(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="radius", - radius_host="radhost", - radius_port=8110, - radius_secret="pass", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="radius", + radius_host="radhost", + radius_port=8110, + radius_secret="pass", + url="http://localhost:8180", + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth.radius.configure.assert_called_with( - host="radhost", - secret="pass", - port=8110) + client.auth.radius.configure.assert_called_with(host="radhost", secret="pass", port=8110) client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -468,8 +475,9 @@ def test_token_missing_token(self, mock_hvac): def test_token(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="token", - token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="token", token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", url="http://localhost:8180" + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') client.is_authenticated.assert_called_with() @@ -483,8 +491,9 @@ def test_token_path(self, mock_hvac): mock_hvac.Client.return_value = mock_client with open('/tmp/test_token.txt', 'w+') as the_file: the_file.write('s.7AU0I51yv1Q1lxOIg1F3ZRAS') - vault_client = _VaultClient(auth_type="token", - token_path="/tmp/test_token.txt", url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="token", token_path="/tmp/test_token.txt", url="http://localhost:8180" + 
) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') client.is_authenticated.assert_called_with() @@ -509,12 +518,12 @@ def test_default_auth_type(self, mock_hvac): def test_userpass(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="userpass", - username="user", password="pass", url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="userpass", username="user", password="pass", url="http://localhost:8180" + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth_userpass.assert_called_with( - username="user", password="pass") + client.auth_userpass.assert_called_with(username="user", password="pass") client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -522,15 +531,16 @@ def test_userpass(self, mock_hvac): def test_userpass_different_auth_mount_point(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="userpass", - username="user", - password="pass", - auth_mount_point="other", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="userpass", + username="user", + password="pass", + auth_mount_point="other", + url="http://localhost:8180", + ) client = vault_client.client mock_hvac.Client.assert_called_with(url='http://localhost:8180') - client.auth_userpass.assert_called_with( - username="user", password="pass", mount_point="other") + client.auth_userpass.assert_called_with(username="user", password="pass", mount_point="other") client.is_authenticated.assert_called_with() self.assertEqual(2, vault_client.kv_engine_version) @@ -540,12 +550,14 @@ def test_get_non_existing_key_v2(self, mock_hvac): mock_hvac.Client.return_value = mock_client # Response does not contain the requested key mock_client.secrets.kv.v2.read_secret_version.side_effect = InvalidPath() - vault_client = _VaultClient(auth_type="token", token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="token", token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", url="http://localhost:8180" + ) secret = vault_client.get_secret(secret_path="missing") self.assertIsNone(secret) mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point='secret', path='missing', version=None) + mount_point='secret', path='missing', version=None + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_non_existing_key_v2_different_auth(self, mock_hvac): @@ -558,12 +570,14 @@ def test_get_non_existing_key_v2_different_auth(self, mock_hvac): radius_host="radhost", radius_port=8110, radius_secret="pass", - url="http://localhost:8180") + url="http://localhost:8180", + ) secret = vault_client.get_secret(secret_path="missing") self.assertIsNone(secret) self.assertEqual("secret", vault_client.mount_point) mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point='secret', path='missing', version=None) + mount_point='secret', path='missing', version=None + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_non_existing_key_v1(self, mock_hvac): @@ -577,11 +591,11 @@ def test_get_non_existing_key_v1(self, mock_hvac): radius_port=8110, radius_secret="pass", kv_engine_version=1, - url="http://localhost:8180") + url="http://localhost:8180", + ) secret = 
vault_client.get_secret(secret_path="missing") self.assertIsNone(secret) - mock_client.secrets.kv.v1.read_secret.assert_called_once_with( - mount_point='secret', path='missing') + mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point='secret', path='missing') @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_existing_key_v2(self, mock_hvac): @@ -595,13 +609,16 @@ def test_get_existing_key_v2(self, mock_hvac): 'lease_duration': 0, 'data': { 'data': {'secret_key': 'secret_value'}, - 'metadata': {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'metadata': { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None + 'auth': None, } vault_client = _VaultClient( @@ -609,11 +626,13 @@ def test_get_existing_key_v2(self, mock_hvac): radius_host="radhost", radius_port=8110, radius_secret="pass", - url="http://localhost:8180") + url="http://localhost:8180", + ) secret = vault_client.get_secret(secret_path="missing") self.assertEqual({'secret_key': 'secret_value'}, secret) mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point='secret', path='missing', version=None) + mount_point='secret', path='missing', version=None + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_existing_key_v2_version(self, mock_hvac): @@ -627,13 +646,16 @@ def test_get_existing_key_v2_version(self, mock_hvac): 'lease_duration': 0, 'data': { 'data': {'secret_key': 'secret_value'}, - 'metadata': {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'metadata': { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None + 'auth': None, } vault_client = _VaultClient( @@ -641,11 +663,13 @@ def test_get_existing_key_v2_version(self, mock_hvac): radius_host="radhost", radius_port=8110, radius_secret="pass", - url="http://localhost:8180") + url="http://localhost:8180", + ) secret = vault_client.get_secret(secret_path="missing", secret_version=1) self.assertEqual({'secret_key': 'secret_value'}, secret) mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point='secret', path='missing', version=1) + mount_point='secret', path='missing', version=1 + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_existing_key_v1(self, mock_hvac): @@ -660,7 +684,8 @@ def test_get_existing_key_v1(self, mock_hvac): 'data': {'value': 'world'}, 'wrap_info': None, 'warnings': None, - 'auth': None} + 'auth': None, + } vault_client = _VaultClient( auth_type="radius", @@ -668,11 +693,11 @@ def test_get_existing_key_v1(self, mock_hvac): radius_port=8110, radius_secret="pass", kv_engine_version=1, - url="http://localhost:8180") + url="http://localhost:8180", + ) secret = vault_client.get_secret(secret_path="missing") self.assertEqual({'value': 'world'}, secret) - mock_client.secrets.kv.v1.read_secret.assert_called_once_with( - mount_point='secret', path='missing') + mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point='secret', path='missing') @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_existing_key_v1_different_auth_mount_point(self, mock_hvac): @@ 
-687,7 +712,8 @@ def test_get_existing_key_v1_different_auth_mount_point(self, mock_hvac): 'data': {'value': 'world'}, 'wrap_info': None, 'warnings': None, - 'auth': None} + 'auth': None, + } vault_client = _VaultClient( auth_type="radius", @@ -696,18 +722,22 @@ def test_get_existing_key_v1_different_auth_mount_point(self, mock_hvac): radius_secret="pass", kv_engine_version=1, auth_mount_point="other", - url="http://localhost:8180") + url="http://localhost:8180", + ) secret = vault_client.get_secret(secret_path="missing") self.assertEqual({'value': 'world'}, secret) - mock_client.secrets.kv.v1.read_secret.assert_called_once_with( - mount_point='secret', path='missing') + mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point='secret', path='missing') @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_existing_key_v1_version(self, mock_hvac): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client - vault_client = _VaultClient(auth_type="token", token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", - url="http://localhost:8180", kv_engine_version=1) + vault_client = _VaultClient( + auth_type="token", + token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", + url="http://localhost:8180", + kv_engine_version=1, + ) with self.assertRaisesRegex(VaultError, "Secret version"): vault_client.get_secret(secret_path="missing", secret_version=1) @@ -721,18 +751,23 @@ def test_get_secret_metadata_v2(self, mock_hvac): 'renewable': False, 'lease_duration': 0, 'metadata': [ - {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}, - {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 2}, - ] + { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 2, + }, + ], } - vault_client = _VaultClient(auth_type="token", token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", - url="http://localhost:8180") + vault_client = _VaultClient( + auth_type="token", token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", url="http://localhost:8180" + ) metadata = vault_client.get_secret_metadata(secret_path="missing") self.assertEqual( { @@ -741,18 +776,25 @@ def test_get_secret_metadata_v2(self, mock_hvac): 'renewable': False, 'lease_duration': 0, 'metadata': [ - {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}, - {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 2}, - ] - }, metadata) + { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 2, + }, + ], + }, + metadata, + ) mock_client.secrets.kv.v2.read_secret_metadata.assert_called_once_with( - mount_point='secret', path='missing') + mount_point='secret', path='missing' + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_secret_metadata_v1(self, mock_hvac): @@ -765,9 +807,11 @@ def test_get_secret_metadata_v1(self, mock_hvac): radius_port=8110, radius_secret="pass", kv_engine_version=1, - url="http://localhost:8180") - with self.assertRaisesRegex(VaultError, "Metadata might only be used with" - " version 2 of the KV engine."): + url="http://localhost:8180", + ) + with 
self.assertRaisesRegex( + VaultError, "Metadata might only be used with" " version 2 of the KV engine." + ): vault_client.get_secret_metadata(secret_path="missing") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -782,20 +826,24 @@ def test_get_secret_including_metadata_v2(self, mock_hvac): 'lease_duration': 0, 'data': { 'data': {'secret_key': 'secret_value'}, - 'metadata': {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'metadata': { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None + 'auth': None, } vault_client = _VaultClient( auth_type="radius", radius_host="radhost", radius_port=8110, radius_secret="pass", - url="http://localhost:8180") + url="http://localhost:8180", + ) metadata = vault_client.get_secret_including_metadata(secret_path="missing") self.assertEqual( { @@ -805,16 +853,22 @@ def test_get_secret_including_metadata_v2(self, mock_hvac): 'lease_duration': 0, 'data': { 'data': {'secret_key': 'secret_value'}, - 'metadata': {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'metadata': { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None - }, metadata) + 'auth': None, + }, + metadata, + ) mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point='secret', path='missing', version=None) + mount_point='secret', path='missing', version=None + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_secret_including_metadata_v1(self, mock_hvac): @@ -827,9 +881,11 @@ def test_get_secret_including_metadata_v1(self, mock_hvac): radius_port=8110, radius_secret="pass", kv_engine_version=1, - url="http://localhost:8180") - with self.assertRaisesRegex(VaultError, "Metadata might only be used with" - " version 2 of the KV engine."): + url="http://localhost:8180", + ) + with self.assertRaisesRegex( + VaultError, "Metadata might only be used with" " version 2 of the KV engine." 
+ ): vault_client.get_secret_including_metadata(secret_path="missing") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -842,13 +898,12 @@ def test_create_or_update_secret_v2(self, mock_hvac): radius_host="radhost", radius_port=8110, radius_secret="pass", - url="http://localhost:8180") - vault_client.create_or_update_secret( - secret_path="path", - secret={'key': 'value'} + url="http://localhost:8180", ) + vault_client.create_or_update_secret(secret_path="path", secret={'key': 'value'}) mock_client.secrets.kv.v2.create_or_update_secret.assert_called_once_with( - mount_point='secret', secret_path='path', secret={'key': 'value'}, cas=None) + mount_point='secret', secret_path='path', secret={'key': 'value'}, cas=None + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_create_or_update_secret_v2_method(self, mock_hvac): @@ -860,13 +915,10 @@ def test_create_or_update_secret_v2_method(self, mock_hvac): radius_host="radhost", radius_port=8110, radius_secret="pass", - url="http://localhost:8180") + url="http://localhost:8180", + ) with self.assertRaisesRegex(VaultError, "The method parameter is only valid for version 1"): - vault_client.create_or_update_secret( - secret_path="path", - secret={'key': 'value'}, - method="post" - ) + vault_client.create_or_update_secret(secret_path="path", secret={'key': 'value'}, method="post") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_create_or_update_secret_v2_cas(self, mock_hvac): @@ -878,14 +930,12 @@ def test_create_or_update_secret_v2_cas(self, mock_hvac): radius_host="radhost", radius_port=8110, radius_secret="pass", - url="http://localhost:8180") - vault_client.create_or_update_secret( - secret_path="path", - secret={'key': 'value'}, - cas=10 + url="http://localhost:8180", ) + vault_client.create_or_update_secret(secret_path="path", secret={'key': 'value'}, cas=10) mock_client.secrets.kv.v2.create_or_update_secret.assert_called_once_with( - mount_point='secret', secret_path='path', secret={'key': 'value'}, cas=10) + mount_point='secret', secret_path='path', secret={'key': 'value'}, cas=10 + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_create_or_update_secret_v1(self, mock_hvac): @@ -898,13 +948,12 @@ def test_create_or_update_secret_v1(self, mock_hvac): radius_port=8110, radius_secret="pass", kv_engine_version=1, - url="http://localhost:8180") - vault_client.create_or_update_secret( - secret_path="path", - secret={'key': 'value'} + url="http://localhost:8180", ) + vault_client.create_or_update_secret(secret_path="path", secret={'key': 'value'}) mock_client.secrets.kv.v1.create_or_update_secret.assert_called_once_with( - mount_point='secret', secret_path='path', secret={'key': 'value'}, method=None) + mount_point='secret', secret_path='path', secret={'key': 'value'}, method=None + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_create_or_update_secret_v1_cas(self, mock_hvac): @@ -917,13 +966,10 @@ def test_create_or_update_secret_v1_cas(self, mock_hvac): radius_port=8110, radius_secret="pass", kv_engine_version=1, - url="http://localhost:8180") + url="http://localhost:8180", + ) with self.assertRaisesRegex(VaultError, "The cas parameter is only valid for version 2"): - vault_client.create_or_update_secret( - secret_path="path", - secret={'key': 'value'}, - cas=10 - ) + vault_client.create_or_update_secret(secret_path="path", secret={'key': 'value'}, cas=10) 
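The hunks above all reformat assertions on the same small `_VaultClient` surface (constructor kwargs, `get_secret`, `create_or_update_secret`). For orientation, a minimal sketch of the call pattern these tests assert on, assuming a reachable Vault server at the URL used in the fixtures; the token, path, and secret values are the placeholder fixture values from the tests and are illustrative only.

from airflow.providers.hashicorp._internal_client.vault_client import _VaultClient

# Token auth against a local Vault server; the URL and token below are the
# placeholder fixture values from the tests above, not real credentials.
vault_client = _VaultClient(
    auth_type="token",
    token="s.7AU0I51yv1Q1lxOIg1F3ZRAS",
    url="http://localhost:8180",
    kv_engine_version=2,  # KV v2, the default the surrounding tests assert on
)

# Write then read back a secret under the default "secret" mount point.
# `cas` (KV v2 only) and `method` (KV v1 only) are the version-specific
# kwargs the neighbouring tests exercise.
vault_client.create_or_update_secret(secret_path="path", secret={"key": "value"})
secret = vault_client.get_secret(secret_path="path")
print(secret)

`_VaultClient` is the internal helper; the public entry point covered by the next file in this diff, `VaultHook`, forwards the same keyword arguments from an Airflow connection.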
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_create_or_update_secret_v1_post(self, mock_hvac): @@ -936,11 +982,9 @@ def test_create_or_update_secret_v1_post(self, mock_hvac): radius_port=8110, radius_secret="pass", kv_engine_version=1, - url="http://localhost:8180") - vault_client.create_or_update_secret( - secret_path="path", - secret={'key': 'value'}, - method="post" + url="http://localhost:8180", ) + vault_client.create_or_update_secret(secret_path="path", secret={'key': 'value'}, method="post") mock_client.secrets.kv.v1.create_or_update_secret.assert_called_once_with( - mount_point='secret', secret_path='path', secret={'key': 'value'}, method="post") + mount_point='secret', secret_path='path', secret={'key': 'value'}, method="post" + ) diff --git a/tests/providers/hashicorp/hooks/test_vault.py b/tests/providers/hashicorp/hooks/test_vault.py index 35e665bc9dac8..decc9805c7ed2 100644 --- a/tests/providers/hashicorp/hooks/test_vault.py +++ b/tests/providers/hashicorp/hooks/test_vault.py @@ -26,14 +26,10 @@ class TestVaultHook(TestCase): - @staticmethod - def get_mock_connection(conn_type="vault", - schema="secret", - host="localhost", - port=8180, - user="user", - password="pass"): + def get_mock_connection( + conn_type="vault", schema="secret", host="localhost", port=8180, user="user", password="pass" + ): mock_connection = mock.MagicMock() type(mock_connection).conn_type = PropertyMock(return_value=conn_type) type(mock_connection).host = PropertyMock(return_value=host) @@ -51,10 +47,7 @@ def test_version_not_int(self, mock_hvac, mock_get_connection): mock_connection = self.get_mock_connection() mock_get_connection.return_value = mock_connection - connection_dict = { - "auth_type": "userpass", - "kv_engine_version": "text" - } + connection_dict = {"auth_type": "userpass", "kv_engine_version": "text"} mock_connection.extra_dejson.get.side_effect = connection_dict.get kwargs = { @@ -71,10 +64,7 @@ def test_version_as_string(self, mock_hvac, mock_get_connection): mock_connection = self.get_mock_connection() mock_get_connection.return_value = mock_connection - connection_dict = { - "auth_type": "userpass", - "kv_engine_version": "2" - } + connection_dict = {"auth_type": "userpass", "kv_engine_version": "2"} mock_connection.extra_dejson.get.side_effect = connection_dict.get kwargs = { @@ -115,10 +105,7 @@ def test_custom_auth_mount_point_init_params(self, mock_hvac, mock_get_connectio } mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_mount_point": "custom" - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_mount_point": "custom"} test_hook = VaultHook(**kwargs) self.assertEqual("secret", test_hook.vault_client.mount_point) self.assertEqual("custom", test_hook.vault_client.auth_mount_point) @@ -131,10 +118,7 @@ def test_version_one_init(self, mock_hvac, mock_get_connection): mock_connection = self.get_mock_connection() mock_get_connection.return_value = mock_connection - connection_dict = { - "auth_type": "userpass", - "kv_engine_version": 1 - } + connection_dict = {"auth_type": "userpass", "kv_engine_version": 1} mock_connection.extra_dejson.get.side_effect = connection_dict.get kwargs = { @@ -151,10 +135,7 @@ def test_custom_auth_mount_point_dejson(self, mock_hvac, mock_get_connection): mock_connection = self.get_mock_connection() mock_get_connection.return_value = mock_connection - connection_dict = { - "auth_type": "userpass", - "auth_mount_point": "custom" - } + 
connection_dict = {"auth_type": "userpass", "auth_mount_point": "custom"} mock_connection.extra_dejson.get.side_effect = connection_dict.get kwargs = { @@ -199,7 +180,7 @@ def test_vaults_protocol(self, mock_hvac, mock_get_connection): "vault_conn_id": "vault_conn_id", "auth_type": "approle", "role_id": "role", - "kv_engine_version": 2 + "kv_engine_version": 2, } test_hook = VaultHook(**kwargs) @@ -225,7 +206,7 @@ def test_http_protocol(self, mock_hvac, mock_get_connection): "vault_conn_id": "vault_conn_id", "auth_type": "approle", "role_id": "role", - "kv_engine_version": 2 + "kv_engine_version": 2, } test_hook = VaultHook(**kwargs) @@ -251,7 +232,7 @@ def test_https_protocol(self, mock_hvac, mock_get_connection): "vault_conn_id": "vault_conn_id", "auth_type": "approle", "role_id": "role", - "kv_engine_version": 2 + "kv_engine_version": 2, } test_hook = VaultHook(**kwargs) @@ -277,7 +258,7 @@ def test_approle_init_params(self, mock_hvac, mock_get_connection): "vault_conn_id": "vault_conn_id", "auth_type": "approle", "role_id": "role", - "kv_engine_version": 2 + "kv_engine_version": 2, } test_hook = VaultHook(**kwargs) @@ -325,20 +306,14 @@ def test_aws_iam_init_params(self, mock_hvac, mock_get_connection): connection_dict = {} mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "aws_iam", - "role_id": "role" - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "aws_iam", "role_id": "role"} test_hook = VaultHook(**kwargs) mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') test_client.auth_aws_iam.assert_called_with( - access_key='user', - secret_key='pass', - role="role", + access_key='user', secret_key='pass', role="role", ) test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -351,10 +326,7 @@ def test_aws_iam_dejson(self, mock_hvac, mock_get_connection): mock_connection = self.get_mock_connection() mock_get_connection.return_value = mock_connection - connection_dict = { - "auth_type": "aws_iam", - "role_id": "role" - } + connection_dict = {"auth_type": "aws_iam", "role_id": "role"} mock_connection.extra_dejson.get.side_effect = connection_dict.get kwargs = { @@ -366,9 +338,7 @@ def test_aws_iam_dejson(self, mock_hvac, mock_get_connection): test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') test_client.auth_aws_iam.assert_called_with( - access_key='user', - secret_key='pass', - role="role", + access_key='user', secret_key='pass', role="role", ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @@ -394,10 +364,7 @@ def test_azure_init_params(self, mock_hvac, mock_get_connection): test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') test_client.auth.azure.configure.assert_called_with( - tenant_id="tenant_id", - resource="resource", - client_id="user", - client_secret="pass", + tenant_id="tenant_id", resource="resource", client_id="user", client_secret="pass", ) test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -426,10 +393,7 @@ def test_azure_dejson(self, mock_hvac, mock_get_connection): test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') test_client.auth.azure.configure.assert_called_with( - 
tenant_id="tenant_id", - resource="resource", - client_id="user", - client_secret="pass", + tenant_id="tenant_id", resource="resource", client_id="user", client_secret="pass", ) test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -438,8 +402,7 @@ def test_azure_dejson(self, mock_hvac, mock_get_connection): @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_init_params(self, mock_hvac, mock_get_connection, - mock_get_credentials, mock_get_scopes): + def test_gcp_init_params(self, mock_hvac, mock_get_connection, mock_get_credentials, mock_get_scopes): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_connection = self.get_mock_connection() @@ -462,14 +425,10 @@ def test_gcp_init_params(self, mock_hvac, mock_get_connection, mock_get_connection.assert_called_with("vault_conn_id") mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( - key_path="path.json", - keyfile_dict=None, - scopes=['scope1', 'scope2'] + key_path="path.json", keyfile_dict=None, scopes=['scope1', 'scope2'] ) mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + test_client.auth.gcp.configure.assert_called_with(credentials="credentials",) test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -477,8 +436,7 @@ def test_gcp_init_params(self, mock_hvac, mock_get_connection, @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_dejson(self, mock_hvac, mock_get_connection, - mock_get_credentials, mock_get_scopes): + def test_gcp_dejson(self, mock_hvac, mock_get_connection, mock_get_credentials, mock_get_scopes): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_connection = self.get_mock_connection() @@ -502,14 +460,10 @@ def test_gcp_dejson(self, mock_hvac, mock_get_connection, mock_get_connection.assert_called_with("vault_conn_id") mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( - key_path="path.json", - keyfile_dict=None, - scopes=['scope1', 'scope2'] + key_path="path.json", keyfile_dict=None, scopes=['scope1', 'scope2'] ) mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + test_client.auth.gcp.configure.assert_called_with(credentials="credentials",) test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -517,8 +471,7 @@ def test_gcp_dejson(self, mock_hvac, mock_get_connection, @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_gcp_dict_dejson(self, mock_hvac, mock_get_connection, - mock_get_credentials, mock_get_scopes): + def 
test_gcp_dict_dejson(self, mock_hvac, mock_get_connection, mock_get_credentials, mock_get_scopes): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_connection = self.get_mock_connection() @@ -542,14 +495,10 @@ def test_gcp_dict_dejson(self, mock_hvac, mock_get_connection, mock_get_connection.assert_called_with("vault_conn_id") mock_get_scopes.assert_called_with("scope1,scope2") mock_get_credentials.assert_called_with( - key_path=None, - keyfile_dict={'key': 'value'}, - scopes=['scope1', 'scope2'] + key_path=None, keyfile_dict={'key': 'value'}, scopes=['scope1', 'scope2'] ) mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.gcp.configure.assert_called_with( - credentials="credentials", - ) + test_client.auth.gcp.configure.assert_called_with(credentials="credentials",) test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -573,8 +522,7 @@ def test_github_init_params(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.github.login.assert_called_with( - token="pass") + test_client.auth.github.login.assert_called_with(token="pass") test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -599,8 +547,7 @@ def test_github_dejson(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.github.login.assert_called_with( - token="pass") + test_client.auth.github.login.assert_called_with(token="pass") test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -627,8 +574,7 @@ def test_kubernetes_default_path(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") mock_file.assert_called_with("/var/run/secrets/kubernetes.io/serviceaccount/token") mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth_kubernetes.assert_called_with( - role="kube_role", jwt="data") + test_client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data") test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -656,8 +602,7 @@ def test_kubernetes_init_params(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") mock_file.assert_called_with("path") mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth_kubernetes.assert_called_with( - role="kube_role", jwt="data") + test_client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data") test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -684,8 +629,7 @@ def test_kubernetes_dejson(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") mock_file.assert_called_with("path") mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth_kubernetes.assert_called_with( - role="kube_role", jwt="data") + test_client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data") test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -709,8 +653,7 @@ 
def test_ldap_init_params(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.ldap.login.assert_called_with( - username="user", password="pass") + test_client.auth.ldap.login.assert_called_with(username="user", password="pass") test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -735,8 +678,7 @@ def test_ldap_dejson(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.ldap.login.assert_called_with( - username="user", password="pass") + test_client.auth.ldap.login.assert_called_with(username="user", password="pass") test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -761,10 +703,7 @@ def test_radius_init_params(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.radius.configure.assert_called_with( - host="radhost", - secret="pass", - port=None) + test_client.auth.radius.configure.assert_called_with(host="radhost", secret="pass", port=None) test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -790,10 +729,7 @@ def test_radius_init_params_port(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.radius.configure.assert_called_with( - host="radhost", - secret="pass", - port=8123) + test_client.auth.radius.configure.assert_called_with(host="radhost", secret="pass", port=8123) test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -820,10 +756,7 @@ def test_radius_dejson(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth.radius.configure.assert_called_with( - host="radhost", - secret="pass", - port=8123) + test_client.auth.radius.configure.assert_called_with(host="radhost", secret="pass", port=8123) test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -858,11 +791,7 @@ def test_token_init_params(self, mock_hvac, mock_get_connection): mock_get_connection.return_value = mock_connection connection_dict = {} mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "token", - "kv_engine_version": 2 - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "token", "kv_engine_version": 2} test_hook = VaultHook(**kwargs) mock_get_connection.assert_called_with("vault_conn_id") @@ -909,18 +838,13 @@ def test_userpass_init_params(self, mock_hvac, mock_get_connection): connection_dict = {} mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "userpass", - "kv_engine_version": 2 - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "userpass", 
"kv_engine_version": 2} test_hook = VaultHook(**kwargs) mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth_userpass.assert_called_with( - username="user", password="pass") + test_client.auth_userpass.assert_called_with(username="user", password="pass") test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -945,8 +869,7 @@ def test_userpass_dejson(self, mock_hvac, mock_get_connection): mock_get_connection.assert_called_with("vault_conn_id") test_client = test_hook.get_conn() mock_hvac.Client.assert_called_with(url='http://localhost:8180') - test_client.auth_userpass.assert_called_with( - username="user", password="pass") + test_client.auth_userpass.assert_called_with(username="user", password="pass") test_client.is_authenticated.assert_called_with() self.assertEqual(2, test_hook.vault_client.kv_engine_version) @@ -967,27 +890,27 @@ def test_get_existing_key_v2(self, mock_hvac, mock_get_connection): 'lease_duration': 0, 'data': { 'data': {'secret_key': 'secret_value'}, - 'metadata': {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'metadata': { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None + 'auth': None, } mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "token", - "kv_engine_version": 2 - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "token", "kv_engine_version": 2} test_hook = VaultHook(**kwargs) secret = test_hook.get_secret(secret_path="missing") self.assertEqual({'secret_key': 'secret_value'}, secret) mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point='secret', path='missing', version=None) + mount_point='secret', path='missing', version=None + ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -1006,27 +929,27 @@ def test_get_existing_key_v2_version(self, mock_hvac, mock_get_connection): 'lease_duration': 0, 'data': { 'data': {'secret_key': 'secret_value'}, - 'metadata': {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'metadata': { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None + 'auth': None, } mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "token", - "kv_engine_version": 2 - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "token", "kv_engine_version": 2} test_hook = VaultHook(**kwargs) secret = test_hook.get_secret(secret_path="missing", secret_version=1) self.assertEqual({'secret_key': 'secret_value'}, secret) mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point='secret', path='missing', version=1) + mount_point='secret', path='missing', version=1 + ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -1046,20 +969,16 @@ def test_get_existing_key_v1(self, 
mock_hvac, mock_get_connection): 'data': {'value': 'world'}, 'wrap_info': None, 'warnings': None, - 'auth': None} + 'auth': None, + } mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "token", - "kv_engine_version": 1 - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "token", "kv_engine_version": 1} test_hook = VaultHook(**kwargs) secret = test_hook.get_secret(secret_path="missing") self.assertEqual({'value': 'world'}, secret) - mock_client.secrets.kv.v1.read_secret.assert_called_once_with( - mount_point='secret', path='missing') + mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point='secret', path='missing') @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -1077,23 +996,23 @@ def test_get_secret_metadata_v2(self, mock_hvac, mock_get_connection): 'renewable': False, 'lease_duration': 0, 'metadata': [ - {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}, - {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 2}, - ] + { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 2, + }, + ], } mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "token", - "kv_engine_version": 2 - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "token", "kv_engine_version": 2} test_hook = VaultHook(**kwargs) metadata = test_hook.get_secret_metadata(secret_path="missing") @@ -1104,18 +1023,25 @@ def test_get_secret_metadata_v2(self, mock_hvac, mock_get_connection): 'renewable': False, 'lease_duration': 0, 'metadata': [ - {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}, - {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 2}, - ] - }, metadata) + { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 2, + }, + ], + }, + metadata, + ) mock_client.secrets.kv.v2.read_secret_metadata.assert_called_once_with( - mount_point='secret', path='missing') + mount_point='secret', path='missing' + ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -1134,21 +1060,20 @@ def test_get_secret_including_metadata_v2(self, mock_hvac, mock_get_connection): 'lease_duration': 0, 'data': { 'data': {'secret_key': 'secret_value'}, - 'metadata': {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'metadata': { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None + 'auth': None, } mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "token", - "kv_engine_version": 2 - } + kwargs = {"vault_conn_id": 
"vault_conn_id", "auth_type": "token", "kv_engine_version": 2} test_hook = VaultHook(**kwargs) metadata = test_hook.get_secret_including_metadata(secret_path="missing") @@ -1160,16 +1085,22 @@ def test_get_secret_including_metadata_v2(self, mock_hvac, mock_get_connection): 'lease_duration': 0, 'data': { 'data': {'secret_key': 'secret_value'}, - 'metadata': {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'metadata': { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None - }, metadata) + 'auth': None, + }, + metadata, + ) mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point='secret', path='missing', version=None) + mount_point='secret', path='missing', version=None + ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -1182,19 +1113,13 @@ def test_create_or_update_secret_v2(self, mock_hvac, mock_get_connection): connection_dict = {} mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "token", - "kv_engine_version": 2 - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "token", "kv_engine_version": 2} test_hook = VaultHook(**kwargs) - test_hook.create_or_update_secret( - secret_path="path", - secret={'key': 'value'} - ) + test_hook.create_or_update_secret(secret_path="path", secret={'key': 'value'}) mock_client.secrets.kv.v2.create_or_update_secret.assert_called_once_with( - mount_point='secret', secret_path='path', secret={'key': 'value'}, cas=None) + mount_point='secret', secret_path='path', secret={'key': 'value'}, cas=None + ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -1207,20 +1132,13 @@ def test_create_or_update_secret_v2_cas(self, mock_hvac, mock_get_connection): connection_dict = {} mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "token", - "kv_engine_version": 2 - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "token", "kv_engine_version": 2} test_hook = VaultHook(**kwargs) - test_hook.create_or_update_secret( - secret_path="path", - secret={'key': 'value'}, - cas=10 - ) + test_hook.create_or_update_secret(secret_path="path", secret={'key': 'value'}, cas=10) mock_client.secrets.kv.v2.create_or_update_secret.assert_called_once_with( - mount_point='secret', secret_path='path', secret={'key': 'value'}, cas=10) + mount_point='secret', secret_path='path', secret={'key': 'value'}, cas=10 + ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -1233,19 +1151,13 @@ def test_create_or_update_secret_v1(self, mock_hvac, mock_get_connection): connection_dict = {} mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "token", - "kv_engine_version": 1 - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "token", "kv_engine_version": 1} test_hook = VaultHook(**kwargs) - test_hook.create_or_update_secret( - secret_path="path", - secret={'key': 'value'} - ) + 
test_hook.create_or_update_secret(secret_path="path", secret={'key': 'value'}) mock_client.secrets.kv.v1.create_or_update_secret.assert_called_once_with( - mount_point='secret', secret_path='path', secret={'key': 'value'}, method=None) + mount_point='secret', secret_path='path', secret={'key': 'value'}, method=None + ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -1258,17 +1170,10 @@ def test_create_or_update_secret_v1_post(self, mock_hvac, mock_get_connection): connection_dict = {} mock_connection.extra_dejson.get.side_effect = connection_dict.get - kwargs = { - "vault_conn_id": "vault_conn_id", - "auth_type": "token", - "kv_engine_version": 1 - } + kwargs = {"vault_conn_id": "vault_conn_id", "auth_type": "token", "kv_engine_version": 1} test_hook = VaultHook(**kwargs) - test_hook.create_or_update_secret( - secret_path="path", - secret={'key': 'value'}, - method="post" - ) + test_hook.create_or_update_secret(secret_path="path", secret={'key': 'value'}, method="post") mock_client.secrets.kv.v1.create_or_update_secret.assert_called_once_with( - mount_point='secret', secret_path='path', secret={'key': 'value'}, method="post") + mount_point='secret', secret_path='path', secret={'key': 'value'}, method="post" + ) diff --git a/tests/providers/hashicorp/secrets/test_vault.py b/tests/providers/hashicorp/secrets/test_vault.py index 5a4449b357352..88ca07d2f8635 100644 --- a/tests/providers/hashicorp/secrets/test_vault.py +++ b/tests/providers/hashicorp/secrets/test_vault.py @@ -23,7 +23,6 @@ class TestVaultSecrets(TestCase): - @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_conn_uri(self, mock_hvac): mock_client = mock.MagicMock() @@ -35,13 +34,16 @@ def test_get_conn_uri(self, mock_hvac): 'lease_duration': 0, 'data': { 'data': {'conn_uri': 'postgresql://airflow:airflow@host:5432/airflow'}, - 'metadata': {'created_time': '2020-03-16T21:01:43.331126Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'metadata': { + 'created_time': '2020-03-16T21:01:43.331126Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None + 'auth': None, } kwargs = { @@ -49,7 +51,7 @@ def test_get_conn_uri(self, mock_hvac): "mount_point": "airflow", "auth_type": "token", "url": "http://127.0.0.1:8200", - "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS" + "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", } test_client = VaultBackend(**kwargs) @@ -68,7 +70,8 @@ def test_get_conn_uri_engine_version_1(self, mock_hvac): 'data': {'conn_uri': 'postgresql://airflow:airflow@host:5432/airflow'}, 'wrap_info': None, 'warnings': None, - 'auth': None} + 'auth': None, + } kwargs = { "connections_path": "connections", @@ -76,13 +79,14 @@ def test_get_conn_uri_engine_version_1(self, mock_hvac): "auth_type": "token", "url": "http://127.0.0.1:8200", "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", - "kv_engine_version": 1 + "kv_engine_version": 1, } test_client = VaultBackend(**kwargs) returned_uri = test_client.get_conn_uri(conn_id="test_postgres") mock_client.secrets.kv.v1.read_secret.assert_called_once_with( - mount_point='airflow', path='connections/test_postgres') + mount_point='airflow', path='connections/test_postgres' + ) self.assertEqual('postgresql://airflow:airflow@host:5432/airflow', returned_uri) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -97,7 +101,8 @@ def 
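The VaultHook and VaultBackend tests in these files all follow the same mocking idiom: a MagicMock stands in for the hvac client, nested attribute access such as secrets.kv.v2.read_secret_version is created on the fly, and the call is later verified with assert_called_once_with. A minimal, Airflow-free sketch of that idiom (the values here are illustrative only):

from unittest import mock

# A MagicMock records arbitrary nested attribute access, so the whole hvac
# client surface can be faked without importing hvac.
mock_client = mock.MagicMock()
mock_client.secrets.kv.v2.read_secret_version.return_value = {
    'data': {'data': {'secret_key': 'secret_value'}}
}

response = mock_client.secrets.kv.v2.read_secret_version(
    mount_point='secret', path='missing', version=None
)
assert response['data']['data'] == {'secret_key': 'secret_value'}
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
    mount_point='secret', path='missing', version=None
)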
test_get_conn_uri_engine_version_1_custom_auth_mount_point(self, mock_hvac): 'data': {'conn_uri': 'postgresql://airflow:airflow@host:5432/airflow'}, 'wrap_info': None, 'warnings': None, - 'auth': None} + 'auth': None, + } kwargs = { "connections_path": "connections", @@ -106,19 +111,18 @@ def test_get_conn_uri_engine_version_1_custom_auth_mount_point(self, mock_hvac): "auth_type": "token", "url": "http://127.0.0.1:8200", "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", - "kv_engine_version": 1 + "kv_engine_version": 1, } test_client = VaultBackend(**kwargs) self.assertEqual("custom", test_client.vault_client.auth_mount_point) returned_uri = test_client.get_conn_uri(conn_id="test_postgres") mock_client.secrets.kv.v1.read_secret.assert_called_once_with( - mount_point='airflow', path='connections/test_postgres') + mount_point='airflow', path='connections/test_postgres' + ) self.assertEqual('postgresql://airflow:airflow@host:5432/airflow', returned_uri) - @mock.patch.dict('os.environ', { - 'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow', - }) + @mock.patch.dict('os.environ', {'AIRFLOW_CONN_TEST_MYSQL': 'mysql://airflow:airflow@host:5432/airflow',}) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_conn_uri_non_existent_key(self, mock_hvac): """ @@ -135,13 +139,14 @@ def test_get_conn_uri_non_existent_key(self, mock_hvac): "mount_point": "airflow", "auth_type": "token", "url": "http://127.0.0.1:8200", - "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS" + "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", } test_client = VaultBackend(**kwargs) self.assertIsNone(test_client.get_conn_uri(conn_id="test_mysql")) mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point='airflow', path='connections/test_mysql', version=None) + mount_point='airflow', path='connections/test_mysql', version=None + ) self.assertEqual([], test_client.get_connections(conn_id="test_mysql")) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -153,14 +158,18 @@ def test_get_variable_value(self, mock_hvac): 'lease_id': '', 'renewable': False, 'lease_duration': 0, - 'data': {'data': {'value': 'world'}, - 'metadata': {'created_time': '2020-03-28T02:10:54.301784Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'data': { + 'data': {'value': 'world'}, + 'metadata': { + 'created_time': '2020-03-28T02:10:54.301784Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None + 'auth': None, } kwargs = { @@ -168,7 +177,7 @@ def test_get_variable_value(self, mock_hvac): "mount_point": "airflow", "auth_type": "token", "url": "http://127.0.0.1:8200", - "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS" + "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", } test_client = VaultBackend(**kwargs) @@ -187,7 +196,8 @@ def test_get_variable_value_engine_version_1(self, mock_hvac): 'data': {'value': 'world'}, 'wrap_info': None, 'warnings': None, - 'auth': None} + 'auth': None, + } kwargs = { "variables_path": "variables", @@ -195,18 +205,17 @@ def test_get_variable_value_engine_version_1(self, mock_hvac): "auth_type": "token", "url": "http://127.0.0.1:8200", "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", - "kv_engine_version": 1 + "kv_engine_version": 1, } test_client = VaultBackend(**kwargs) returned_uri = test_client.get_variable("hello") mock_client.secrets.kv.v1.read_secret.assert_called_once_with( - mount_point='airflow', path='variables/hello') + mount_point='airflow', path='variables/hello' + ) 
self.assertEqual('world', returned_uri) - @mock.patch.dict('os.environ', { - 'AIRFLOW_VAR_HELLO': 'world', - }) + @mock.patch.dict('os.environ', {'AIRFLOW_VAR_HELLO': 'world',}) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_variable_value_non_existent_key(self, mock_hvac): """ @@ -223,13 +232,14 @@ def test_get_variable_value_non_existent_key(self, mock_hvac): "mount_point": "airflow", "auth_type": "token", "url": "http://127.0.0.1:8200", - "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS" + "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", } test_client = VaultBackend(**kwargs) self.assertIsNone(test_client.get_variable("hello")) mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point='airflow', path='variables/hello', version=None) + mount_point='airflow', path='variables/hello', version=None + ) self.assertIsNone(test_client.get_variable("hello")) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -243,7 +253,7 @@ def test_auth_failure_raises_error(self, mock_hvac): "mount_point": "airflow", "auth_type": "token", "url": "http://127.0.0.1:8200", - "token": "test_wrong_token" + "token": "test_wrong_token", } with self.assertRaisesRegex(VaultError, "Vault Authentication Error!"): @@ -270,14 +280,18 @@ def test_get_config_value(self, mock_hvac): 'lease_id': '', 'renewable': False, 'lease_duration': 0, - 'data': {'data': {'value': 'sqlite:////Users/airflow/airflow/airflow.db'}, - 'metadata': {'created_time': '2020-03-28T02:10:54.301784Z', - 'deletion_time': '', - 'destroyed': False, - 'version': 1}}, + 'data': { + 'data': {'value': 'sqlite:////Users/airflow/airflow/airflow.db'}, + 'metadata': { + 'created_time': '2020-03-28T02:10:54.301784Z', + 'deletion_time': '', + 'destroyed': False, + 'version': 1, + }, + }, 'wrap_info': None, 'warnings': None, - 'auth': None + 'auth': None, } kwargs = { @@ -285,7 +299,7 @@ def test_get_config_value(self, mock_hvac): "mount_point": "secret", "auth_type": "token", "url": "http://127.0.0.1:8200", - "token": "s.FnL7qg0YnHZDpf4zKKuFy0UK" + "token": "s.FnL7qg0YnHZDpf4zKKuFy0UK", } test_client = VaultBackend(**kwargs) diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py index 382461573c43e..bfc3b7f69ce80 100644 --- a/tests/providers/http/hooks/test_http.py +++ b/tests/providers/http/hooks/test_http.py @@ -30,21 +30,11 @@ def get_airflow_connection(unused_conn_id=None): - return Connection( - conn_id='http_default', - conn_type='http', - host='test:8080/', - extra='{"bareer": "test"}' - ) + return Connection(conn_id='http_default', conn_type='http', host='test:8080/', extra='{"bareer": "test"}') def get_airflow_connection_with_port(unused_conn_id=None): - return Connection( - conn_id='http_default', - conn_type='http', - host='test.com', - port=1234 - ) + return Connection(conn_id='http_default', conn_type='http', host='test.com', port=1234) class TestHttpHook(unittest.TestCase): @@ -61,15 +51,9 @@ def setUp(self): @requests_mock.mock() def test_raise_for_status_with_200(self, m): - m.get( - 'http://test:8080/v1/test', - status_code=200, - text='{"status":{"status": 200}}', - reason='OK' - ) + m.get('http://test:8080/v1/test', status_code=200, text='{"status":{"status": 200}}', reason='OK') with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection ): resp = self.get_hook.run('v1/test') 
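The HttpHook tests in this file combine two pieces: airflow.hooks.base_hook.BaseHook.get_connection is patched so no metadata database is needed, and requests_mock serves the HTTP response so no real server is contacted. A condensed sketch of that setup, assuming apache-airflow with the http provider and the requests_mock package are installed (the connection values mirror the fixtures in these tests):

import requests_mock
from unittest import mock

from airflow.models import Connection
from airflow.providers.http.hooks.http import HttpHook


def fake_connection(unused_conn_id=None):
    # Mirrors the get_airflow_connection fixture used by these tests.
    return Connection(conn_id='http_default', conn_type='http', host='test:8080/')


with requests_mock.Mocker() as m:
    # Register the canned response that the hook's request will hit.
    m.get('http://test:8080/v1/test', status_code=200, text='{"status": 200}')
    with mock.patch(
        'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=fake_connection
    ):
        response = HttpHook(method='GET').run('v1/test')

assert response.text == '{"status": 200}'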
self.assertEqual(resp.text, '{"status":{"status": 200}}') @@ -81,8 +65,7 @@ def test_get_request_with_port(self, mock_requests, request_mock, mock_session): from requests.exceptions import MissingSchema with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection_with_port + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection_with_port ): expected_url = 'http://test.com:1234/some/endpoint' for endpoint in ['some/endpoint', '/some/endpoint']: @@ -93,10 +76,7 @@ def test_get_request_with_port(self, mock_requests, request_mock, mock_session): pass request_mock.assert_called_once_with( - mock.ANY, - expected_url, - headers=mock.ANY, - params=mock.ANY + mock.ANY, expected_url, headers=mock.ANY, params=mock.ANY ) request_mock.reset_mock() @@ -108,12 +88,11 @@ def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, m) 'http://test:8080/v1/test', status_code=404, text='{"status":{"status": 404}}', - reason='Bad request' + reason='Bad request', ) with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection ): resp = self.get_hook.run('v1/test', extra_options={'check_response': False}) self.assertEqual(resp.text, '{"status":{"status": 404}}') @@ -121,8 +100,7 @@ def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, m) @requests_mock.mock() def test_hook_contains_header_from_extra_field(self, mock_requests): with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection ): expected_conn = get_airflow_connection() conn = self.get_hook.get_conn() @@ -133,21 +111,16 @@ def test_hook_contains_header_from_extra_field(self, mock_requests): @mock.patch('requests.Request') def test_hook_with_method_in_lowercase(self, mock_requests, request_mock): from requests.exceptions import InvalidURL, MissingSchema + with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection_with_port + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection_with_port ): data = "test params" try: self.get_lowercase_hook.run('v1/test', data=data) except (MissingSchema, InvalidURL): pass - request_mock.assert_called_once_with( - mock.ANY, - mock.ANY, - headers=mock.ANY, - params=data - ) + request_mock.assert_called_once_with(mock.ANY, mock.ANY, headers=mock.ANY, params=data) @requests_mock.mock() def test_hook_uses_provided_header(self, mock_requests): @@ -162,8 +135,7 @@ def test_hook_has_no_header_from_extra(self, mock_requests): @requests_mock.mock() def test_hooks_header_from_extra_is_overridden(self, mock_requests): with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection ): conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"}) self.assertEqual(conn.headers.get('bareer'), 'newT0k3n') @@ -171,15 +143,11 @@ def test_hooks_header_from_extra_is_overridden(self, mock_requests): @requests_mock.mock() def test_post_request(self, mock_requests): mock_requests.post( - 'http://test:8080/v1/test', - status_code=200, - text='{"status":{"status": 200}}', - reason='OK' + 'http://test:8080/v1/test', status_code=200, text='{"status":{"status": 
200}}', reason='OK' ) with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection ): resp = self.post_hook.run('v1/test') self.assertEqual(resp.status_code, 200) @@ -190,12 +158,11 @@ def test_post_request_with_error_code(self, mock_requests): 'http://test:8080/v1/test', status_code=418, text='{"status":{"status": 418}}', - reason='I\'m a teapot' + reason='I\'m a teapot', ) with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection ): with self.assertRaises(AirflowException): self.post_hook.run('v1/test') @@ -206,12 +173,11 @@ def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, m 'http://test:8080/v1/test', status_code=418, text='{"status":{"status": 418}}', - reason='I\'m a teapot' + reason='I\'m a teapot', ) with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection ): resp = self.post_hook.run('v1/test', extra_options={'check_response': False}) self.assertEqual(resp.status_code, 418) @@ -222,9 +188,7 @@ def test_retry_on_conn_error(self, mocked_session): retry_args = dict( wait=tenacity.wait_none(), stop=tenacity.stop_after_attempt(7), - retry=tenacity.retry_if_exception_type( - requests.exceptions.ConnectionError - ) + retry=tenacity.retry_if_exception_type(requests.exceptions.ConnectionError), ) def send_and_raise(unused_request, **kwargs): @@ -233,53 +197,36 @@ def send_and_raise(unused_request, **kwargs): mocked_session().send.side_effect = send_and_raise # The job failed for some reason with self.assertRaises(tenacity.RetryError): - self.get_hook.run_with_advanced_retry( - endpoint='v1/test', - _retry_args=retry_args - ) - self.assertEqual( - self.get_hook._retry_obj.stop.max_attempt_number + 1, - mocked_session.call_count - ) + self.get_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args) + self.assertEqual(self.get_hook._retry_obj.stop.max_attempt_number + 1, mocked_session.call_count) @requests_mock.mock() def test_run_with_advanced_retry(self, m): - m.get( - 'http://test:8080/v1/test', - status_code=200, - reason='OK' - ) + m.get('http://test:8080/v1/test', status_code=200, reason='OK') retry_args = dict( wait=tenacity.wait_none(), stop=tenacity.stop_after_attempt(3), retry=tenacity.retry_if_exception_type(Exception), - reraise=True + reraise=True, ) with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection ): - response = self.get_hook.run_with_advanced_retry( - endpoint='v1/test', - _retry_args=retry_args - ) + response = self.get_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args) self.assertIsInstance(response, requests.Response) def test_header_from_extra_and_run_method_are_merged(self): - def run_and_return(unused_session, prepped_request, unused_extra_options, **kwargs): return prepped_request # The job failed for some reason with mock.patch( - 'airflow.providers.http.hooks.http.HttpHook.run_and_check', - side_effect=run_and_return + 'airflow.providers.http.hooks.http.HttpHook.run_and_check', side_effect=run_and_return ): with mock.patch( - 
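The retry policy that test_retry_on_conn_error and test_run_with_advanced_retry pass to run_with_advanced_retry via _retry_args is plain tenacity configuration. A standalone sketch of an equivalent policy, assuming only the tenacity package (flaky_call is a hypothetical stand-in for the HTTP request):

import tenacity

attempts = {'count': 0}


def flaky_call():
    # Hypothetical stand-in for the HTTP request; fails twice, then succeeds.
    attempts['count'] += 1
    if attempts['count'] < 3:
        raise ConnectionError("transient failure")
    return "ok"


retryer = tenacity.Retrying(
    wait=tenacity.wait_none(),
    stop=tenacity.stop_after_attempt(3),
    retry=tenacity.retry_if_exception_type(ConnectionError),
    reraise=True,
)
# The Retrying object is callable: it keeps invoking flaky_call until it
# succeeds or the stop condition is reached.
assert retryer(flaky_call) == "ok"
assert attempts['count'] == 3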
'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection ): prepared_request = self.get_hook.run('v1/test', headers={'some_other_header': 'test'}) actual = dict(prepared_request.headers) @@ -288,8 +235,7 @@ def run_and_return(unused_session, prepped_request, unused_extra_options, **kwar @mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection') def test_http_connection(self, mock_get_connection): - conn = Connection(conn_id='http_default', conn_type='http', - host='localhost', schema='http') + conn = Connection(conn_id='http_default', conn_type='http', host='localhost', schema='http') mock_get_connection.return_value = conn hook = HttpHook() hook.get_conn({}) @@ -297,8 +243,7 @@ def test_http_connection(self, mock_get_connection): @mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection') def test_https_connection(self, mock_get_connection): - conn = Connection(conn_id='http_default', conn_type='http', - host='localhost', schema='https') + conn = Connection(conn_id='http_default', conn_type='http', host='localhost', schema='https') mock_get_connection.return_value = conn hook = HttpHook() hook.get_conn({}) @@ -306,8 +251,7 @@ def test_https_connection(self, mock_get_connection): @mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection') def test_host_encoded_http_connection(self, mock_get_connection): - conn = Connection(conn_id='http_default', conn_type='http', - host='http://localhost') + conn = Connection(conn_id='http_default', conn_type='http', host='http://localhost') mock_get_connection.return_value = conn hook = HttpHook() hook.get_conn({}) @@ -315,8 +259,7 @@ def test_host_encoded_http_connection(self, mock_get_connection): @mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection') def test_host_encoded_https_connection(self, mock_get_connection): - conn = Connection(conn_id='http_default', conn_type='http', - host='https://localhost') + conn = Connection(conn_id='http_default', conn_type='http', host='https://localhost') mock_get_connection.return_value = conn hook = HttpHook() hook.get_conn({}) @@ -334,10 +277,9 @@ def test_connection_without_host(self, mock_get_connection): hook.get_conn({}) self.assertEqual(hook.base_url, 'http://') - @parameterized.expand([ - 'GET', - 'POST', - ]) + @parameterized.expand( + ['GET', 'POST',] + ) @requests_mock.mock() def test_json_request(self, method, mock_requests): obj1 = {'a': 1, 'b': 'abc', 'c': [1, 2, {"d": 10}]} @@ -345,15 +287,10 @@ def test_json_request(self, method, mock_requests): def match_obj1(request): return request.json() == obj1 - mock_requests.request( - method=method, - url='//test:8080/v1/test', - additional_matcher=match_obj1 - ) + mock_requests.request(method=method, url='//test:8080/v1/test', additional_matcher=match_obj1) with mock.patch( - 'airflow.hooks.base_hook.BaseHook.get_connection', - side_effect=get_airflow_connection + 'airflow.hooks.base_hook.BaseHook.get_connection', side_effect=get_airflow_connection ): # will raise NoMockAddress exception if obj1 != request.json() HttpHook(method=method).run('v1/test', json=obj1) diff --git a/tests/providers/http/operators/test_http.py b/tests/providers/http/operators/test_http.py index 5e293e5a38f76..023c95a540dfe 100644 --- a/tests/providers/http/operators/test_http.py +++ b/tests/providers/http/operators/test_http.py @@ -27,7 +27,6 @@ @mock.patch.dict('os.environ', 
AIRFLOW_CONN_HTTP_EXAMPLE='http://www.example.com') class TestSimpleHttpOp(unittest.TestCase): - @requests_mock.mock() def test_response_in_logs(self, m): """ @@ -46,10 +45,7 @@ def test_response_in_logs(self, m): with mock.patch.object(operator.log, 'info') as mock_info: operator.execute(None) - calls = [ - mock.call('Example.com fake response'), - mock.call('Example.com fake response') - ] + calls = [mock.call('Example.com fake response'), mock.call('Example.com fake response')] mock_info.has_calls(calls) @requests_mock.mock() @@ -69,15 +65,12 @@ def response_check(response): endpoint='/', http_conn_id='HTTP_EXAMPLE', log_response=True, - response_check=response_check + response_check=response_check, ) with mock.patch.object(operator.log, 'info') as mock_info: self.assertRaises(AirflowException, operator.execute, None) - calls = [ - mock.call('Calling HTTP method'), - mock.call('invalid response') - ] + calls = [mock.call('Calling HTTP method'), mock.call('invalid response')] mock_info.assert_has_calls(calls, any_order=True) @requests_mock.mock() @@ -88,7 +81,7 @@ def test_filters_response(self, m): method='GET', endpoint='/', http_conn_id='HTTP_EXAMPLE', - response_filter=lambda response: response.json() + response_filter=lambda response: response.json(), ) result = operator.execute(None) assert result == {'value': 5} diff --git a/tests/providers/http/operators/test_http_system.py b/tests/providers/http/operators/test_http_system.py index bd5a9eaf8a1b7..f86a898acf79e 100644 --- a/tests/providers/http/operators/test_http_system.py +++ b/tests/providers/http/operators/test_http_system.py @@ -22,14 +22,11 @@ from tests.test_utils import AIRFLOW_MAIN_FOLDER from tests.test_utils.system_tests_class import SystemTest -HTTP_DAG_FOLDER = os.path.join( - AIRFLOW_MAIN_FOLDER, "airflow", "providers", "http", "example_dags" -) +HTTP_DAG_FOLDER = os.path.join(AIRFLOW_MAIN_FOLDER, "airflow", "providers", "http", "example_dags") @pytest.mark.backend("mysql", "postgres") @pytest.mark.system("http") class HttpExampleDagsSystemTest(SystemTest): - def test_run_example_dag_http(self): self.run_dag('example_http_operator', HTTP_DAG_FOLDER) diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py index 3f7cb5004b746..14c61c4ce3262 100644 --- a/tests/providers/http/sensors/test_http.py +++ b/tests/providers/http/sensors/test_http.py @@ -35,10 +35,7 @@ class TestHttpSensor(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=args) @patch("airflow.providers.http.hooks.http.requests.Session.send") @@ -60,7 +57,8 @@ def resp_check(_): request_params={}, response_check=resp_check, timeout=5, - poke_interval=1) + poke_interval=1, + ) with self.assertRaisesRegex(AirflowException, 'AirflowException raised here!'): task.execute(context={}) @@ -78,17 +76,15 @@ def resp_check(_): method='HEAD', response_check=resp_check, timeout=5, - poke_interval=1) + poke_interval=1, + ) task.execute(context={}) args, kwargs = mock_session_send.call_args received_request = args[0] - prep_request = requests.Request( - 'HEAD', - 'https://www.httpbin.org', - {}).prepare() + prep_request = requests.Request('HEAD', 'https://www.httpbin.org', {}).prepare() self.assertEqual(prep_request.url, received_request.url) self.assertTrue(prep_request.method, received_request.method) @@ -112,16 +108,14 @@ def resp_check(_, execution_date): 
response_check=resp_check, timeout=5, poke_interval=1, - dag=self.dag) + dag=self.dag, + ) task_instance = TaskInstance(task=task, execution_date=DEFAULT_DATE) task.execute(task_instance.get_template_context()) @patch("airflow.providers.http.hooks.http.requests.Session.send") - def test_logging_head_error_request( - self, - mock_session_send - ): + def test_logging_head_error_request(self, mock_session_send): def resp_check(_): return True @@ -140,7 +134,7 @@ def resp_check(_): method='HEAD', response_check=resp_check, timeout=5, - poke_interval=1 + poke_interval=1, ) with mock.patch.object(task.hook.log, 'error') as mock_errors: @@ -176,9 +170,7 @@ def send(self, *args, **kwargs): def prepare_request(self, request): if 'date' in request.params: - self.response._content += ( - '/' + request.params['date'] - ).encode('ascii', 'ignore') + self.response._content += ('/' + request.params['date']).encode('ascii', 'ignore') return self.response @@ -196,7 +188,8 @@ def test_get(self): endpoint='/search', data={"client": "ubuntu", "q": "airflow"}, headers={}, - dag=self.dag) + dag=self.dag, + ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @mock.patch('requests.Session', FakeSession) @@ -208,7 +201,8 @@ def test_get_response_check(self): data={"client": "ubuntu", "q": "airflow"}, response_check=lambda response: ("apache/airflow" in response.text), headers={}, - dag=self.dag) + dag=self.dag, + ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @mock.patch('requests.Session', FakeSession) @@ -220,9 +214,10 @@ def test_sensor(self): request_params={"client": "ubuntu", "q": "airflow", 'date': '{{ds}}'}, headers={}, response_check=lambda response: ( - "apache/airflow/" + DEFAULT_DATE.strftime('%Y-%m-%d') - in response.text), + "apache/airflow/" + DEFAULT_DATE.strftime('%Y-%m-%d') in response.text + ), poke_interval=5, timeout=15, - dag=self.dag) + dag=self.dag, + ) sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/providers/imap/hooks/test_imap.py b/tests/providers/imap/hooks/test_imap.py index ef3b998911bbd..b37a63ecbd7e1 100644 --- a/tests/providers/imap/hooks/test_imap.py +++ b/tests/providers/imap/hooks/test_imap.py @@ -38,10 +38,11 @@ def _create_fake_imap(mock_imaplib, with_mail=False, attachment_name='test1.csv' if with_mail: mock_conn.select.return_value = ('OK', []) mock_conn.search.return_value = ('OK', [b'1']) - mail_string = \ - 'Content-Type: multipart/mixed; boundary=123\r\n--123\r\n' \ - 'Content-Disposition: attachment; filename="{}";' \ + mail_string = ( + 'Content-Type: multipart/mixed; boundary=123\r\n--123\r\n' + 'Content-Disposition: attachment; filename="{}";' 'Content-Transfer-Encoding: base64\r\nSWQsTmFtZQoxLEZlbGl4\r\n--123--'.format(attachment_name) + ) mock_conn.fetch.return_value = ('OK', [(b'', mail_string.encode('utf-8'))]) mock_conn.close.return_value = ('OK', []) @@ -58,7 +59,7 @@ def setUp(self): conn_type='imap', host='imap_server_address', login='imap_user', - password='imap_password' + password='imap_password', ) ) @@ -96,10 +97,7 @@ def test_has_mail_attachment_with_regex_found(self, mock_imaplib): _create_fake_imap(mock_imaplib, with_mail=True) with ImapHook() as imap_hook: - has_attachment_in_inbox = imap_hook.has_mail_attachment( - name=r'test(\d+).csv', - check_regex=True - ) + has_attachment_in_inbox = imap_hook.has_mail_attachment(name=r'test(\d+).csv', check_regex=True) self.assertTrue(has_attachment_in_inbox) @@ -108,10 +106,7 @@ def 
test_has_mail_attachment_with_regex_not_found(self, mock_imaplib): _create_fake_imap(mock_imaplib, with_mail=True) with ImapHook() as imap_hook: - has_attachment_in_inbox = imap_hook.has_mail_attachment( - name=r'test_(\d+).csv', - check_regex=True - ) + has_attachment_in_inbox = imap_hook.has_mail_attachment(name=r'test_(\d+).csv', check_regex=True) self.assertFalse(has_attachment_in_inbox) @@ -121,10 +116,7 @@ def test_has_mail_attachment_with_mail_filter(self, mock_imaplib): mail_filter = '(SINCE "01-Jan-2019")' with ImapHook() as imap_hook: - imap_hook.has_mail_attachment( - name='test1.csv', - mail_filter=mail_filter - ) + imap_hook.has_mail_attachment(name='test1.csv', mail_filter=mail_filter) mock_imaplib.IMAP4_SSL.return_value.search.assert_called_once_with(None, mail_filter) @@ -150,8 +142,7 @@ def test_retrieve_mail_attachments_with_regex_found(self, mock_imaplib): with ImapHook() as imap_hook: attachments_in_inbox = imap_hook.retrieve_mail_attachments( - name=r'test(\d+).csv', - check_regex=True + name=r'test(\d+).csv', check_regex=True ) self.assertEqual(attachments_in_inbox, [('test1.csv', b'SWQsTmFtZQoxLEZlbGl4')]) @@ -161,20 +152,19 @@ def test_retrieve_mail_attachments_with_regex_not_found(self, mock_imaplib): _create_fake_imap(mock_imaplib, with_mail=True) with ImapHook() as imap_hook: - self.assertRaises(AirflowException, - imap_hook.retrieve_mail_attachments, - name=r'test_(\d+).csv', - check_regex=True) + self.assertRaises( + AirflowException, + imap_hook.retrieve_mail_attachments, + name=r'test_(\d+).csv', + check_regex=True, + ) @patch(imaplib_string) def test_retrieve_mail_attachments_latest_only(self, mock_imaplib): _create_fake_imap(mock_imaplib, with_mail=True) with ImapHook() as imap_hook: - attachments_in_inbox = imap_hook.retrieve_mail_attachments( - name='test1.csv', - latest_only=True - ) + attachments_in_inbox = imap_hook.retrieve_mail_attachments(name='test1.csv', latest_only=True) self.assertEqual(attachments_in_inbox, [('test1.csv', b'SWQsTmFtZQoxLEZlbGl4')]) @@ -184,10 +174,7 @@ def test_retrieve_mail_attachments_with_mail_filter(self, mock_imaplib): mail_filter = '(SINCE "01-Jan-2019")' with ImapHook() as imap_hook: - imap_hook.retrieve_mail_attachments( - name='test1.csv', - mail_filter=mail_filter - ) + imap_hook.retrieve_mail_attachments(name='test1.csv', mail_filter=mail_filter) mock_imaplib.IMAP4_SSL.return_value.search.assert_called_once_with(None, mail_filter) @@ -208,8 +195,9 @@ def test_download_mail_attachments_not_found(self, mock_imaplib, mock_open_metho _create_fake_imap(mock_imaplib, with_mail=True) with ImapHook() as imap_hook: - self.assertRaises(AirflowException, - imap_hook.download_mail_attachments, 'test1.txt', 'test_directory') + self.assertRaises( + AirflowException, imap_hook.download_mail_attachments, 'test1.txt', 'test_directory' + ) mock_open_method.assert_not_called() mock_open_method.return_value.write.assert_not_called() @@ -221,9 +209,7 @@ def test_download_mail_attachments_with_regex_found(self, mock_imaplib, mock_ope with ImapHook() as imap_hook: imap_hook.download_mail_attachments( - name=r'test(\d+).csv', - local_output_directory='test_directory', - check_regex=True + name=r'test(\d+).csv', local_output_directory='test_directory', check_regex=True ) mock_open_method.assert_called_once_with('test_directory/test1.csv', 'wb') @@ -235,11 +221,13 @@ def test_download_mail_attachments_with_regex_not_found(self, mock_imaplib, mock _create_fake_imap(mock_imaplib, with_mail=True) with ImapHook() as imap_hook: - 
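The ImapHook download tests above assert on a patched open call rather than writing real files, checking the path, mode and written payload. A minimal, Airflow-free sketch of that kind of setup using mock_open (save_attachment is a hypothetical stand-in for the hook's download logic):

from unittest import mock


def save_attachment(directory, name, payload):
    # Hypothetical stand-in for the hook's download logic.
    with open(f'{directory}/{name}', 'wb') as file:
        file.write(payload)


with mock.patch('builtins.open', mock.mock_open()) as mock_open_method:
    save_attachment('test_directory', 'test1.csv', b'SWQsTmFtZQoxLEZlbGl4')

# Both the open() call and the write on the returned handle are recorded.
mock_open_method.assert_called_once_with('test_directory/test1.csv', 'wb')
mock_open_method.return_value.write.assert_called_once_with(b'SWQsTmFtZQoxLEZlbGl4')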
self.assertRaises(AirflowException, - imap_hook.download_mail_attachments, - name=r'test_(\d+).csv', - local_output_directory='test_directory', - check_regex=True) + self.assertRaises( + AirflowException, + imap_hook.download_mail_attachments, + name=r'test_(\d+).csv', + local_output_directory='test_directory', + check_regex=True, + ) mock_open_method.assert_not_called() mock_open_method.return_value.write.assert_not_called() @@ -251,9 +239,7 @@ def test_download_mail_attachments_with_latest_only(self, mock_imaplib, mock_ope with ImapHook() as imap_hook: imap_hook.download_mail_attachments( - name='test1.csv', - local_output_directory='test_directory', - latest_only=True + name='test1.csv', local_output_directory='test_directory', latest_only=True ) mock_open_method.assert_called_once_with('test_directory/test1.csv', 'wb') @@ -265,10 +251,7 @@ def test_download_mail_attachments_with_escaping_chars(self, mock_imaplib, mock_ _create_fake_imap(mock_imaplib, with_mail=True, attachment_name='../test1.csv') with ImapHook() as imap_hook: - imap_hook.download_mail_attachments( - name='../test1.csv', - local_output_directory='test_directory' - ) + imap_hook.download_mail_attachments(name='../test1.csv', local_output_directory='test_directory') mock_open_method.assert_not_called() mock_open_method.return_value.write.assert_not_called() @@ -280,10 +263,7 @@ def test_download_mail_attachments_with_symlink(self, mock_imaplib, mock_open_me _create_fake_imap(mock_imaplib, with_mail=True, attachment_name='symlink') with ImapHook() as imap_hook: - imap_hook.download_mail_attachments( - name='symlink', - local_output_directory='test_directory' - ) + imap_hook.download_mail_attachments(name='symlink', local_output_directory='test_directory') assert mock_is_symlink.call_count == 1 mock_open_method.assert_not_called() @@ -297,9 +277,7 @@ def test_download_mail_attachments_with_mail_filter(self, mock_imaplib, mock_ope with ImapHook() as imap_hook: imap_hook.download_mail_attachments( - name='test1.csv', - local_output_directory='test_directory', - mail_filter=mail_filter + name='test1.csv', local_output_directory='test_directory', mail_filter=mail_filter ) mock_imaplib.IMAP4_SSL.return_value.search.assert_called_once_with(None, mail_filter) diff --git a/tests/providers/imap/sensors/test_imap_attachment.py b/tests/providers/imap/sensors/test_imap_attachment.py index 094229c08db2e..5af9679d3ea06 100644 --- a/tests/providers/imap/sensors/test_imap_attachment.py +++ b/tests/providers/imap/sensors/test_imap_attachment.py @@ -25,7 +25,6 @@ class TestImapAttachmentSensor(unittest.TestCase): - def setUp(self): self.kwargs = dict( attachment_name='test_file', @@ -33,7 +32,7 @@ def setUp(self): mail_folder='INBOX', mail_filter='All', task_id='test_task', - dag=None + dag=None, ) @parameterized.expand([(True,), (False,)]) @@ -49,5 +48,5 @@ def test_poke(self, has_attachment_return_value, mock_imap_hook): name=self.kwargs['attachment_name'], check_regex=self.kwargs['check_regex'], mail_folder=self.kwargs['mail_folder'], - mail_filter=self.kwargs['mail_filter'] + mail_filter=self.kwargs['mail_filter'], ) diff --git a/tests/providers/jdbc/hooks/test_jdbc.py b/tests/providers/jdbc/hooks/test_jdbc.py index 090133e8e528b..e0585c53b4795 100644 --- a/tests/providers/jdbc/hooks/test_jdbc.py +++ b/tests/providers/jdbc/hooks/test_jdbc.py @@ -25,22 +25,27 @@ from airflow.providers.jdbc.hooks.jdbc import JdbcHook from airflow.utils import db -jdbc_conn_mock = Mock( - name="jdbc_conn" -) +jdbc_conn_mock = Mock(name="jdbc_conn") class 
TestJdbcHook(unittest.TestCase): def setUp(self): db.merge_conn( Connection( - conn_id='jdbc_default', conn_type='jdbc', - host='jdbc://localhost/', port=443, - extra=json.dumps({"extra__jdbc__drv_path": "/path1/test.jar,/path2/t.jar2", - "extra__jdbc__drv_clsname": "com.driver.main"}))) + conn_id='jdbc_default', + conn_type='jdbc', + host='jdbc://localhost/', + port=443, + extra=json.dumps( + { + "extra__jdbc__drv_path": "/path1/test.jar,/path2/t.jar2", + "extra__jdbc__drv_clsname": "com.driver.main", + } + ), + ) + ) - @patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect", autospec=True, - return_value=jdbc_conn_mock) + @patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect", autospec=True, return_value=jdbc_conn_mock) def test_jdbc_conn_connection(self, jdbc_mock): jdbc_hook = JdbcHook() jdbc_conn = jdbc_hook.get_conn() diff --git a/tests/providers/jdbc/operators/test_jdbc.py b/tests/providers/jdbc/operators/test_jdbc.py index 50c62f422094f..e16ac446e6c9b 100644 --- a/tests/providers/jdbc/operators/test_jdbc.py +++ b/tests/providers/jdbc/operators/test_jdbc.py @@ -23,13 +23,8 @@ class TestJdbcOperator(unittest.TestCase): - def setUp(self): - self.kwargs = dict( - sql='sql', - task_id='test_jdbc_operator', - dag=None - ) + self.kwargs = dict(sql='sql', task_id='test_jdbc_operator', dag=None) @patch('airflow.providers.jdbc.operators.jdbc.JdbcHook') def test_execute(self, mock_jdbc_hook): @@ -38,4 +33,5 @@ def test_execute(self, mock_jdbc_hook): mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id) mock_jdbc_hook.return_value.run.assert_called_once_with( - jdbc_operator.sql, jdbc_operator.autocommit, parameters=jdbc_operator.parameters) + jdbc_operator.sql, jdbc_operator.autocommit, parameters=jdbc_operator.parameters + ) diff --git a/tests/providers/jenkins/hooks/test_jenkins.py b/tests/providers/jenkins/hooks/test_jenkins.py index 044ac46be7a08..8300159898e08 100644 --- a/tests/providers/jenkins/hooks/test_jenkins.py +++ b/tests/providers/jenkins/hooks/test_jenkins.py @@ -23,7 +23,6 @@ class TestJenkinsHook(unittest.TestCase): - @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') def test_client_created_default_http(self, get_connection_mock): """tests `init` method to validate http client creation when all parameters are passed """ @@ -31,10 +30,14 @@ def test_client_created_default_http(self, get_connection_mock): connection_host = 'http://test.com' connection_port = 8080 - get_connection_mock.return_value = mock. \ - Mock(connection_id=default_connection_id, - login='test', password='test', extra='', - host=connection_host, port=connection_port) + get_connection_mock.return_value = mock.Mock( + connection_id=default_connection_id, + login='test', + password='test', + extra='', + host=connection_host, + port=connection_port, + ) complete_url = f'http://{connection_host}:{connection_port}/' hook = JenkinsHook(default_connection_id) @@ -49,10 +52,14 @@ def test_client_created_default_https(self, get_connection_mock): connection_host = 'http://test.com' connection_port = 8080 - get_connection_mock.return_value = mock. 
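The JDBC hook test above registers its fixture with db.merge_conn, packing the driver path and class name into the connection's JSON extra field; the Jira and Azure Data Explorer tests further down do the same. Hooks read these values back through the connection's extra_dejson property. A small sketch of that round trip, assuming apache-airflow is installed (the conn_id is illustrative):

import json

from airflow.models import Connection

conn = Connection(
    conn_id='jdbc_example',  # illustrative id, not one used in these tests
    conn_type='jdbc',
    host='jdbc://localhost/',
    port=443,
    extra=json.dumps(
        {
            "extra__jdbc__drv_path": "/path1/test.jar,/path2/t.jar2",
            "extra__jdbc__drv_clsname": "com.driver.main",
        }
    ),
)
# extra_dejson decodes the JSON extra field back into a dict.
assert conn.extra_dejson["extra__jdbc__drv_clsname"] == "com.driver.main"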
\ - Mock(connection_id=default_connection_id, - login='test', password='test', extra='true', - host=connection_host, port=connection_port) + get_connection_mock.return_value = mock.Mock( + connection_id=default_connection_id, + login='test', + password='test', + extra='true', + host=connection_host, + port=connection_port, + ) complete_url = f'https://{connection_host}:{connection_port}/' hook = JenkinsHook(default_connection_id) diff --git a/tests/providers/jenkins/operators/test_jenkins_job_trigger.py b/tests/providers/jenkins/operators/test_jenkins_job_trigger.py index 9af19ba22f687..60ec909ac8cf9 100644 --- a/tests/providers/jenkins/operators/test_jenkins_job_trigger.py +++ b/tests/providers/jenkins/operators/test_jenkins_job_trigger.py @@ -28,29 +28,31 @@ class TestJenkinsOperator(unittest.TestCase): - @parameterized.expand([ - ("dict params", {'a_param': 'blip', 'another_param': '42'},), - ("string params", '{"second_param": "beep", "third_param": "153"}',), - ("list params", ['final_one', 'bop', 'real_final', 'eggs'],), - ]) + @parameterized.expand( + [ + ("dict params", {'a_param': 'blip', 'another_param': '42'},), + ("string params", '{"second_param": "beep", "third_param": "153"}',), + ("list params", ['final_one', 'bop', 'real_final', 'eggs'],), + ] + ) def test_execute(self, _, parameters): jenkins_mock = Mock(spec=jenkins.Jenkins, auth='secret') - jenkins_mock.get_build_info.return_value = \ - {'result': 'SUCCESS', - 'url': 'http://aaa.fake-url.com/congratulation/its-a-job'} - jenkins_mock.build_job_url.return_value = \ - 'http://www.jenkins.url/somewhere/in/the/universe' + jenkins_mock.get_build_info.return_value = { + 'result': 'SUCCESS', + 'url': 'http://aaa.fake-url.com/congratulation/its-a-job', + } + jenkins_mock.build_job_url.return_value = 'http://www.jenkins.url/somewhere/in/the/universe' hook_mock = Mock(spec=JenkinsHook) hook_mock.get_jenkins_server.return_value = jenkins_mock - with patch.object(JenkinsJobTriggerOperator, "get_hook") as get_hook_mocked,\ - patch( - 'airflow.providers.jenkins.operators.jenkins_job_trigger.jenkins_request_with_headers') \ - as mock_make_request: - mock_make_request.side_effect = \ - [{'body': '', 'headers': {'Location': 'http://what-a-strange.url/18'}}, - {'body': '{"executable":{"number":"1"}}', 'headers': {}}] + with patch.object(JenkinsJobTriggerOperator, "get_hook") as get_hook_mocked, patch( + 'airflow.providers.jenkins.operators.jenkins_job_trigger.jenkins_request_with_headers' + ) as mock_make_request: + mock_make_request.side_effect = [ + {'body': '', 'headers': {'Location': 'http://what-a-strange.url/18'}}, + {'body': '{"executable":{"number":"1"}}', 'headers': {}}, + ] get_hook_mocked.return_value = hook_mock operator = JenkinsJobTriggerOperator( dag=None, @@ -59,39 +61,40 @@ def test_execute(self, _, parameters): task_id="operator_test", job_name="a_job_on_jenkins", parameters=parameters, - sleep_time=1) + sleep_time=1, + ) operator.execute(None) self.assertEqual(jenkins_mock.get_build_info.call_count, 1) - jenkins_mock.get_build_info.assert_called_once_with(name='a_job_on_jenkins', - number='1') - - @parameterized.expand([ - ("dict params", {'a_param': 'blip', 'another_param': '42'},), - ("string params", '{"second_param": "beep", "third_param": "153"}',), - ("list params", ['final_one', 'bop', 'real_final', 'eggs'],), - ]) + jenkins_mock.get_build_info.assert_called_once_with(name='a_job_on_jenkins', number='1') + + @parameterized.expand( + [ + ("dict params", {'a_param': 'blip', 'another_param': '42'},), + ("string 
params", '{"second_param": "beep", "third_param": "153"}',), + ("list params", ['final_one', 'bop', 'real_final', 'eggs'],), + ] + ) def test_execute_job_polling_loop(self, _, parameters): jenkins_mock = Mock(spec=jenkins.Jenkins, auth='secret') jenkins_mock.get_job_info.return_value = {'nextBuildNumber': '1'} - jenkins_mock.get_build_info.side_effect = \ - [{'result': None}, - {'result': 'SUCCESS', - 'url': 'http://aaa.fake-url.com/congratulation/its-a-job'}] - jenkins_mock.build_job_url.return_value = \ - 'http://www.jenkins.url/somewhere/in/the/universe' + jenkins_mock.get_build_info.side_effect = [ + {'result': None}, + {'result': 'SUCCESS', 'url': 'http://aaa.fake-url.com/congratulation/its-a-job'}, + ] + jenkins_mock.build_job_url.return_value = 'http://www.jenkins.url/somewhere/in/the/universe' hook_mock = Mock(spec=JenkinsHook) hook_mock.get_jenkins_server.return_value = jenkins_mock - with patch.object(JenkinsJobTriggerOperator, "get_hook") as get_hook_mocked,\ - patch( - 'airflow.providers.jenkins.operators.jenkins_job_trigger.jenkins_request_with_headers') \ - as mock_make_request: - mock_make_request.side_effect = \ - [{'body': '', 'headers': {'Location': 'http://what-a-strange.url/18'}}, - {'body': '{"executable":{"number":"1"}}', 'headers': {}}] + with patch.object(JenkinsJobTriggerOperator, "get_hook") as get_hook_mocked, patch( + 'airflow.providers.jenkins.operators.jenkins_job_trigger.jenkins_request_with_headers' + ) as mock_make_request: + mock_make_request.side_effect = [ + {'body': '', 'headers': {'Location': 'http://what-a-strange.url/18'}}, + {'body': '{"executable":{"number":"1"}}', 'headers': {}}, + ] get_hook_mocked.return_value = hook_mock operator = JenkinsJobTriggerOperator( dag=None, @@ -100,35 +103,38 @@ def test_execute_job_polling_loop(self, _, parameters): jenkins_connection_id="fake_jenkins_connection", # The hook is mocked, this connection won't be used parameters=parameters, - sleep_time=1) + sleep_time=1, + ) operator.execute(None) self.assertEqual(jenkins_mock.get_build_info.call_count, 2) - @parameterized.expand([ - ("dict params", {'a_param': 'blip', 'another_param': '42'},), - ("string params", '{"second_param": "beep", "third_param": "153"}',), - ("list params", ['final_one', 'bop', 'real_final', 'eggs'],), - ]) + @parameterized.expand( + [ + ("dict params", {'a_param': 'blip', 'another_param': '42'},), + ("string params", '{"second_param": "beep", "third_param": "153"}',), + ("list params", ['final_one', 'bop', 'real_final', 'eggs'],), + ] + ) def test_execute_job_failure(self, _, parameters): jenkins_mock = Mock(spec=jenkins.Jenkins, auth='secret') jenkins_mock.get_job_info.return_value = {'nextBuildNumber': '1'} jenkins_mock.get_build_info.return_value = { 'result': 'FAILURE', - 'url': 'http://aaa.fake-url.com/congratulation/its-a-job'} - jenkins_mock.build_job_url.return_value = \ - 'http://www.jenkins.url/somewhere/in/the/universe' + 'url': 'http://aaa.fake-url.com/congratulation/its-a-job', + } + jenkins_mock.build_job_url.return_value = 'http://www.jenkins.url/somewhere/in/the/universe' hook_mock = Mock(spec=JenkinsHook) hook_mock.get_jenkins_server.return_value = jenkins_mock - with patch.object(JenkinsJobTriggerOperator, "get_hook") as get_hook_mocked,\ - patch( - 'airflow.providers.jenkins.operators.jenkins_job_trigger.jenkins_request_with_headers') \ - as mock_make_request: - mock_make_request.side_effect = \ - [{'body': '', 'headers': {'Location': 'http://what-a-strange.url/18'}}, - {'body': '{"executable":{"number":"1"}}', 'headers': 
{}}] + with patch.object(JenkinsJobTriggerOperator, "get_hook") as get_hook_mocked, patch( + 'airflow.providers.jenkins.operators.jenkins_job_trigger.jenkins_request_with_headers' + ) as mock_make_request: + mock_make_request.side_effect = [ + {'body': '', 'headers': {'Location': 'http://what-a-strange.url/18'}}, + {'body': '{"executable":{"number":"1"}}', 'headers': {}}, + ] get_hook_mocked.return_value = hook_mock operator = JenkinsJobTriggerOperator( dag=None, @@ -137,7 +143,8 @@ def test_execute_job_failure(self, _, parameters): parameters=parameters, jenkins_connection_id="fake_jenkins_connection", # The hook is mocked, this connection won't be used - sleep_time=1) + sleep_time=1, + ) self.assertRaises(AirflowException, operator.execute, None) @@ -152,7 +159,8 @@ def test_build_job_request_settings(self): dag=None, task_id="build_job_test", job_name="a_job_on_jenkins", - jenkins_connection_id="fake_jenkins_connection") + jenkins_connection_id="fake_jenkins_connection", + ) operator.build_job(jenkins_mock) mock_request = mock_make_request.call_args_list[0][0][1] diff --git a/tests/providers/jira/hooks/test_jira.py b/tests/providers/jira/hooks/test_jira.py index acad7411d4283..06c50c551e106 100644 --- a/tests/providers/jira/hooks/test_jira.py +++ b/tests/providers/jira/hooks/test_jira.py @@ -24,21 +24,22 @@ from airflow.providers.jira.hooks.jira import JiraHook from airflow.utils import db -jira_client_mock = Mock( - name="jira_client" -) +jira_client_mock = Mock(name="jira_client") class TestJiraHook(unittest.TestCase): def setUp(self): db.merge_conn( Connection( - conn_id='jira_default', conn_type='jira', - host='https://localhost/jira/', port=443, - extra='{"verify": "False", "project": "AIRFLOW"}')) + conn_id='jira_default', + conn_type='jira', + host='https://localhost/jira/', + port=443, + extra='{"verify": "False", "project": "AIRFLOW"}', + ) + ) - @patch("airflow.providers.jira.hooks.jira.JIRA", autospec=True, - return_value=jira_client_mock) + @patch("airflow.providers.jira.hooks.jira.JIRA", autospec=True, return_value=jira_client_mock) def test_jira_client_connection(self, jira_mock): jira_hook = JiraHook() diff --git a/tests/providers/jira/operators/test_jira.py b/tests/providers/jira/operators/test_jira.py index 0f991ffadc9f2..4705f4b7b2725 100644 --- a/tests/providers/jira/operators/test_jira.py +++ b/tests/providers/jira/operators/test_jira.py @@ -26,73 +26,60 @@ from airflow.utils import db, timezone DEFAULT_DATE = timezone.datetime(2017, 1, 1) -jira_client_mock = Mock( - name="jira_client_for_test" -) +jira_client_mock = Mock(name="jira_client_for_test") minimal_test_ticket = { "id": "911539", "self": "https://sandbox.localhost/jira/rest/api/2/issue/911539", "key": "TEST-1226", - "fields": { - "labels": [ - "test-label-1", - "test-label-2" - ], - "description": "this is a test description", - } + "fields": {"labels": ["test-label-1", "test-label-2"], "description": "this is a test description",}, } class TestJiraOperator(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG('test_dag_id', default_args=args) self.dag = dag db.merge_conn( Connection( - conn_id='jira_default', conn_type='jira', - host='https://localhost/jira/', port=443, - extra='{"verify": "False", "project": "AIRFLOW"}')) + conn_id='jira_default', + conn_type='jira', + host='https://localhost/jira/', + port=443, + extra='{"verify": "False", "project": "AIRFLOW"}', + ) + ) - 
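The Jenkins job trigger tests above use parameterized.expand to run the same test body once per parameter style (dict, JSON string, list). A minimal sketch of that idiom, assuming the parameterized package is installed:

import unittest

from parameterized import parameterized


class ParamsExample(unittest.TestCase):
    @parameterized.expand(
        [
            ("dict params", {'a_param': 'blip'}),
            ("string params", '{"second_param": "beep"}'),
            ("list params", ['final_one', 'bop']),
        ]
    )
    def test_accepts_parameters(self, _, parameters):
        # Each tuple above produces its own test case; the first element only
        # names the generated test, the rest become positional arguments.
        self.assertTrue(parameters)


if __name__ == "__main__":
    unittest.main()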
@patch("airflow.providers.jira.hooks.jira.JIRA", - autospec=True, return_value=jira_client_mock) + @patch("airflow.providers.jira.hooks.jira.JIRA", autospec=True, return_value=jira_client_mock) def test_issue_search(self, jira_mock): jql_str = 'issuekey=TEST-1226' jira_mock.return_value.search_issues.return_value = minimal_test_ticket - jira_ticket_search_operator = JiraOperator(task_id='search-ticket-test', - jira_method="search_issues", - jira_method_args={ - 'jql_str': jql_str, - 'maxResults': '1' - }, - dag=self.dag) + jira_ticket_search_operator = JiraOperator( + task_id='search-ticket-test', + jira_method="search_issues", + jira_method_args={'jql_str': jql_str, 'maxResults': '1'}, + dag=self.dag, + ) - jira_ticket_search_operator.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) + jira_ticket_search_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) self.assertTrue(jira_mock.called) self.assertTrue(jira_mock.return_value.search_issues.called) - @patch("airflow.providers.jira.hooks.jira.JIRA", - autospec=True, return_value=jira_client_mock) + @patch("airflow.providers.jira.hooks.jira.JIRA", autospec=True, return_value=jira_client_mock) def test_update_issue(self, jira_mock): jira_mock.return_value.add_comment.return_value = True - add_comment_operator = JiraOperator(task_id='add_comment_test', - jira_method="add_comment", - jira_method_args={ - 'issue': minimal_test_ticket.get("key"), - 'body': 'this is test comment' - }, - dag=self.dag) + add_comment_operator = JiraOperator( + task_id='add_comment_test', + jira_method="add_comment", + jira_method_args={'issue': minimal_test_ticket.get("key"), 'body': 'this is test comment'}, + dag=self.dag, + ) - add_comment_operator.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) + add_comment_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) self.assertTrue(jira_mock.called) self.assertTrue(jira_mock.return_value.add_comment.called) diff --git a/tests/providers/jira/sensors/test_jira.py b/tests/providers/jira/sensors/test_jira.py index 782eabbd875d5..e4c437da71d42 100644 --- a/tests/providers/jira/sensors/test_jira.py +++ b/tests/providers/jira/sensors/test_jira.py @@ -26,40 +26,32 @@ from airflow.utils import db, timezone DEFAULT_DATE = timezone.datetime(2017, 1, 1) -jira_client_mock = Mock( - name="jira_client_for_test" -) +jira_client_mock = Mock(name="jira_client_for_test") minimal_test_ticket = { "id": "911539", "self": "https://sandbox.localhost/jira/rest/api/2/issue/911539", "key": "TEST-1226", - "fields": { - "labels": [ - "test-label-1", - "test-label-2" - ], - "description": "this is a test description", - } + "fields": {"labels": ["test-label-1", "test-label-2"], "description": "this is a test description",}, } class TestJiraSensor(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG('test_dag_id', default_args=args) self.dag = dag db.merge_conn( Connection( - conn_id='jira_default', conn_type='jira', - host='https://localhost/jira/', port=443, - extra='{"verify": "False", "project": "AIRFLOW"}')) + conn_id='jira_default', + conn_type='jira', + host='https://localhost/jira/', + port=443, + extra='{"verify": "False", "project": "AIRFLOW"}', + ) + ) - @patch("airflow.providers.jira.hooks.jira.JIRA", - autospec=True, return_value=jira_client_mock) + @patch("airflow.providers.jira.hooks.jira.JIRA", 
autospec=True, return_value=jira_client_mock) def test_issue_label_set(self, jira_mock): jira_mock.return_value.issue.return_value = minimal_test_ticket @@ -70,10 +62,10 @@ def test_issue_label_set(self, jira_mock): field_checker_func=TestJiraSensor.field_checker_func, timeout=518400, poke_interval=10, - dag=self.dag) + dag=self.dag, + ) - ticket_label_sensor.run(start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, ignore_ti_state=True) + ticket_label_sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) self.assertTrue(jira_mock.called) self.assertTrue(jira_mock.return_value.issue.called) diff --git a/tests/providers/microsoft/azure/hooks/test_adx.py b/tests/providers/microsoft/azure/hooks/test_adx.py index 557171345bc2a..e3a3ab3dbd29f 100644 --- a/tests/providers/microsoft/azure/hooks/test_adx.py +++ b/tests/providers/microsoft/azure/hooks/test_adx.py @@ -36,42 +36,48 @@ class TestAzureDataExplorerHook(unittest.TestCase): def tearDown(self): super().tearDown() with create_session() as session: - session.query(Connection).filter( - Connection.conn_id == ADX_TEST_CONN_ID).delete() + session.query(Connection).filter(Connection.conn_id == ADX_TEST_CONN_ID).delete() def test_conn_missing_method(self): db.merge_conn( - Connection(conn_id=ADX_TEST_CONN_ID, - conn_type='azure_data_explorer', - login='client_id', - password='client secret', - host='https://help.kusto.windows.net', - extra=json.dumps({}))) + Connection( + conn_id=ADX_TEST_CONN_ID, + conn_type='azure_data_explorer', + login='client_id', + password='client secret', + host='https://help.kusto.windows.net', + extra=json.dumps({}), + ) + ) with self.assertRaises(AirflowException) as e: AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) - self.assertIn('missing required parameter: `auth_method`', - str(e.exception)) + self.assertIn('missing required parameter: `auth_method`', str(e.exception)) def test_conn_unknown_method(self): db.merge_conn( - Connection(conn_id=ADX_TEST_CONN_ID, - conn_type='azure_data_explorer', - login='client_id', - password='client secret', - host='https://help.kusto.windows.net', - extra=json.dumps({'auth_method': 'AAD_OTHER'}))) + Connection( + conn_id=ADX_TEST_CONN_ID, + conn_type='azure_data_explorer', + login='client_id', + password='client secret', + host='https://help.kusto.windows.net', + extra=json.dumps({'auth_method': 'AAD_OTHER'}), + ) + ) with self.assertRaises(AirflowException) as e: AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) - self.assertIn('Unknown authentication method: AAD_OTHER', - str(e.exception)) + self.assertIn('Unknown authentication method: AAD_OTHER', str(e.exception)) def test_conn_missing_cluster(self): db.merge_conn( - Connection(conn_id=ADX_TEST_CONN_ID, - conn_type='azure_data_explorer', - login='client_id', - password='client secret', - extra=json.dumps({}))) + Connection( + conn_id=ADX_TEST_CONN_ID, + conn_type='azure_data_explorer', + login='client_id', + password='client secret', + extra=json.dumps({}), + ) + ) with self.assertRaises(AirflowException) as e: AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) self.assertIn('Host connection option is required', str(e.exception)) @@ -80,90 +86,97 @@ def test_conn_missing_cluster(self): def test_conn_method_aad_creds(self, mock_init): mock_init.return_value = None db.merge_conn( - Connection(conn_id=ADX_TEST_CONN_ID, - conn_type='azure_data_explorer', - login='client_id', - password='client secret', - host='https://help.kusto.windows.net', - 
extra=json.dumps({ - 'tenant': 'tenant', - 'auth_method': 'AAD_CREDS' - }))) + Connection( + conn_id=ADX_TEST_CONN_ID, + conn_type='azure_data_explorer', + login='client_id', + password='client secret', + host='https://help.kusto.windows.net', + extra=json.dumps({'tenant': 'tenant', 'auth_method': 'AAD_CREDS'}), + ) + ) AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) assert mock_init.called_with( KustoConnectionStringBuilder.with_aad_user_password_authentication( - 'https://help.kusto.windows.net', 'client_id', 'client secret', - 'tenant')) + 'https://help.kusto.windows.net', 'client_id', 'client secret', 'tenant' + ) + ) @mock.patch.object(KustoClient, '__init__') def test_conn_method_aad_app(self, mock_init): mock_init.return_value = None db.merge_conn( - Connection(conn_id=ADX_TEST_CONN_ID, - conn_type='azure_data_explorer', - login='app_id', - password='app key', - host='https://help.kusto.windows.net', - extra=json.dumps({ - 'tenant': 'tenant', - 'auth_method': 'AAD_APP' - }))) + Connection( + conn_id=ADX_TEST_CONN_ID, + conn_type='azure_data_explorer', + login='app_id', + password='app key', + host='https://help.kusto.windows.net', + extra=json.dumps({'tenant': 'tenant', 'auth_method': 'AAD_APP'}), + ) + ) AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) assert mock_init.called_with( - KustoConnectionStringBuilder. - with_aad_application_key_authentication( - 'https://help.kusto.windows.net', 'app_id', 'app key', - 'tenant')) + KustoConnectionStringBuilder.with_aad_application_key_authentication( + 'https://help.kusto.windows.net', 'app_id', 'app key', 'tenant' + ) + ) @mock.patch.object(KustoClient, '__init__') def test_conn_method_aad_app_cert(self, mock_init): mock_init.return_value = None db.merge_conn( - Connection(conn_id=ADX_TEST_CONN_ID, - conn_type='azure_data_explorer', - login='client_id', - host='https://help.kusto.windows.net', - extra=json.dumps({ - 'tenant': 'tenant', - 'auth_method': 'AAD_APP_CERT', - 'certificate': 'PEM', - 'thumbprint': 'thumbprint' - }))) + Connection( + conn_id=ADX_TEST_CONN_ID, + conn_type='azure_data_explorer', + login='client_id', + host='https://help.kusto.windows.net', + extra=json.dumps( + { + 'tenant': 'tenant', + 'auth_method': 'AAD_APP_CERT', + 'certificate': 'PEM', + 'thumbprint': 'thumbprint', + } + ), + ) + ) AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) assert mock_init.called_with( - KustoConnectionStringBuilder. 
- with_aad_application_certificate_authentication( - 'https://help.kusto.windows.net', 'client_id', 'PEM', - 'thumbprint', 'tenant')) + KustoConnectionStringBuilder.with_aad_application_certificate_authentication( + 'https://help.kusto.windows.net', 'client_id', 'PEM', 'thumbprint', 'tenant' + ) + ) @mock.patch.object(KustoClient, '__init__') def test_conn_method_aad_device(self, mock_init): mock_init.return_value = None db.merge_conn( - Connection(conn_id=ADX_TEST_CONN_ID, - conn_type='azure_data_explorer', - host='https://help.kusto.windows.net', - extra=json.dumps({'auth_method': 'AAD_DEVICE'}))) + Connection( + conn_id=ADX_TEST_CONN_ID, + conn_type='azure_data_explorer', + host='https://help.kusto.windows.net', + extra=json.dumps({'auth_method': 'AAD_DEVICE'}), + ) + ) AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) assert mock_init.called_with( - KustoConnectionStringBuilder.with_aad_device_authentication( - 'https://help.kusto.windows.net')) + KustoConnectionStringBuilder.with_aad_device_authentication('https://help.kusto.windows.net') + ) @mock.patch.object(KustoClient, 'execute') def test_run_query(self, mock_execute): mock_execute.return_value = None db.merge_conn( - Connection(conn_id=ADX_TEST_CONN_ID, - conn_type='azure_data_explorer', - host='https://help.kusto.windows.net', - extra=json.dumps({'auth_method': 'AAD_DEVICE'}))) - hook = AzureDataExplorerHook( - azure_data_explorer_conn_id=ADX_TEST_CONN_ID) - hook.run_query('Database', - 'Logs | schema', - options={'option1': 'option_value'}) + Connection( + conn_id=ADX_TEST_CONN_ID, + conn_type='azure_data_explorer', + host='https://help.kusto.windows.net', + extra=json.dumps({'auth_method': 'AAD_DEVICE'}), + ) + ) + hook = AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID) + hook.run_query('Database', 'Logs | schema', options={'option1': 'option_value'}) properties = ClientRequestProperties() properties.set_option('option1', 'option_value') - assert mock_execute.called_with('Database', - 'Logs | schema', - properties=properties) + assert mock_execute.called_with('Database', 'Logs | schema', properties=properties) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py b/tests/providers/microsoft/azure/hooks/test_azure_batch.py index 7ea8b742f7893..2a6a0c2c0571b 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py @@ -46,30 +46,38 @@ def setUp(self): # connect with vm configuration db.merge_conn( - Connection(conn_id=self.test_vm_conn_id, - conn_type="azure_batch", - extra=json.dumps({ - "account_name": self.test_account_name, - "account_key": self.test_account_key, - "account_url": self.test_account_url, - "vm_publisher": self.test_vm_publisher, - "vm_offer": self.test_vm_offer, - "vm_sku": self.test_vm_sku, - "node_agent_sku_id": self.test_node_agent_sku - })) + Connection( + conn_id=self.test_vm_conn_id, + conn_type="azure_batch", + extra=json.dumps( + { + "account_name": self.test_account_name, + "account_key": self.test_account_key, + "account_url": self.test_account_url, + "vm_publisher": self.test_vm_publisher, + "vm_offer": self.test_vm_offer, + "vm_sku": self.test_vm_sku, + "node_agent_sku_id": self.test_node_agent_sku, + } + ), + ) ) # connect with cloud service db.merge_conn( - Connection(conn_id=self.test_cloud_conn_id, - conn_type="azure_batch", - extra=json.dumps({ - "account_name": self.test_account_name, - "account_key": self.test_account_key, - "account_url": self.test_account_url, 
- "os_family": self.test_cloud_os_family, - "os_version": self.test_cloud_os_version, - "node_agent_sku_id": self.test_node_agent_sku - })) + Connection( + conn_id=self.test_cloud_conn_id, + conn_type="azure_batch", + extra=json.dumps( + { + "account_name": self.test_account_name, + "account_key": self.test_account_key, + "account_url": self.test_account_url, + "os_family": self.test_cloud_os_family, + "os_version": self.test_cloud_os_version, + "node_agent_sku_id": self.test_node_agent_sku, + } + ), + ) ) def test_connection_and_client(self): @@ -79,41 +87,32 @@ def test_connection_and_client(self): def test_configure_pool_with_vm_config(self): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) - pool = hook.configure_pool(pool_id='mypool', - vm_size="test_vm_size", - target_dedicated_nodes=1, - ) + pool = hook.configure_pool(pool_id='mypool', vm_size="test_vm_size", target_dedicated_nodes=1,) self.assertIsInstance(pool, batch_models.PoolAddParameter) def test_configure_pool_with_cloud_config(self): hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) - pool = hook.configure_pool(pool_id='mypool', - vm_size="test_vm_size", - target_dedicated_nodes=1, - ) + pool = hook.configure_pool(pool_id='mypool', vm_size="test_vm_size", target_dedicated_nodes=1,) self.assertIsInstance(pool, batch_models.PoolAddParameter) def test_configure_pool_with_latest_vm(self): - with mock.patch("airflow.providers.microsoft.azure.hooks." - "azure_batch.AzureBatchHook._get_latest_verified_image_vm_and_sku")\ - as mock_getvm: + with mock.patch( + "airflow.providers.microsoft.azure.hooks." + "azure_batch.AzureBatchHook._get_latest_verified_image_vm_and_sku" + ) as mock_getvm: hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) getvm_instance = mock_getvm getvm_instance.return_value = ['test-image', 'test-sku'] - pool = hook.configure_pool(pool_id='mypool', - vm_size="test_vm_size", - use_latest_image_and_sku=True, - ) + pool = hook.configure_pool( + pool_id='mypool', vm_size="test_vm_size", use_latest_image_and_sku=True, + ) self.assertIsInstance(pool, batch_models.PoolAddParameter) @mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient") def test_create_pool_with_vm_config(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) mock_instance = mock_batch.return_value.pool.add - pool = hook.configure_pool(pool_id='mypool', - vm_size="test_vm_size", - target_dedicated_nodes=1, - ) + pool = hook.configure_pool(pool_id='mypool', vm_size="test_vm_size", target_dedicated_nodes=1,) hook.create_pool(pool=pool) mock_instance.assert_called_once_with(pool) @@ -121,10 +120,7 @@ def test_create_pool_with_vm_config(self, mock_batch): def test_create_pool_with_cloud_config(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) mock_instance = mock_batch.return_value.pool.add - pool = hook.configure_pool(pool_id='mypool', - vm_size="test_vm_size", - target_dedicated_nodes=1, - ) + pool = hook.configure_pool(pool_id='mypool', vm_size="test_vm_size", target_dedicated_nodes=1,) hook.create_pool(pool=pool) mock_instance.assert_called_once_with(pool) @@ -137,8 +133,7 @@ def test_wait_for_all_nodes(self, mock_batch): def test_job_configuration_and_create_job(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) mock_instance = mock_batch.return_value.job.add - job = hook.configure_job(job_id='myjob', - pool_id='mypool') + job = hook.configure_job(job_id='myjob', pool_id='mypool') 
hook.create_job(job) self.assertIsInstance(job, batch_models.JobAddParameter) mock_instance.assert_called_once_with(job) @@ -147,10 +142,8 @@ def test_job_configuration_and_create_job(self, mock_batch): def test_add_single_task_to_job(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) mock_instance = mock_batch.return_value.task.add - task = hook.configure_task(task_id="mytask", - command_line="echo hello") - hook.add_single_task_to_job(job_id='myjob', - task=task) + task = hook.configure_task(task_id="mytask", command_line="echo hello") + hook.add_single_task_to_job(job_id='myjob', task=task) self.assertIsInstance(task, batch_models.TaskAddParameter) mock_instance.assert_called_once_with(job_id="myjob", task=task) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py index 185c161ce63f7..a38c93ca26ec1 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py @@ -21,7 +21,11 @@ from unittest.mock import patch from azure.mgmt.containerinstance.models import ( - Container, ContainerGroup, Logs, ResourceRequests, ResourceRequirements, + Container, + ContainerGroup, + Logs, + ResourceRequests, + ResourceRequirements, ) from airflow.models import Connection @@ -30,7 +34,6 @@ class TestAzureContainerInstanceHook(unittest.TestCase): - def setUp(self): db.merge_conn( Connection( @@ -38,16 +41,14 @@ def setUp(self): conn_type='azure_container_instances', login='login', password='key', - extra=json.dumps({'tenantId': 'tenant_id', - 'subscriptionId': 'subscription_id'}) + extra=json.dumps({'tenantId': 'tenant_id', 'subscriptionId': 'subscription_id'}), ) ) - self.resources = ResourceRequirements(requests=ResourceRequests( - memory_in_gb='4', - cpu='1')) - with patch('azure.common.credentials.ServicePrincipalCredentials.__init__', - autospec=True, return_value=None): + self.resources = ResourceRequirements(requests=ResourceRequests(memory_in_gb='4', cpu='1')) + with patch( + 'azure.common.credentials.ServicePrincipalCredentials.__init__', autospec=True, return_value=None + ): with patch('azure.mgmt.containerinstance.ContainerInstanceManagementClient'): self.hook = AzureContainerInstanceHook(conn_id='azure_container_instance_test') @@ -79,16 +80,20 @@ def test_delete(self, delete_mock): @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.list_by_resource_group') def test_exists_with_existing(self, list_mock): - list_mock.return_value = [ContainerGroup(os_type='Linux', - containers=[Container(name='test1', - image='hello-world', - resources=self.resources)])] + list_mock.return_value = [ + ContainerGroup( + os_type='Linux', + containers=[Container(name='test1', image='hello-world', resources=self.resources)], + ) + ] self.assertFalse(self.hook.exists('test', 'test1')) @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.list_by_resource_group') def test_exists_with_not_existing(self, list_mock): - list_mock.return_value = [ContainerGroup(os_type='Linux', - containers=[Container(name='test1', - image='hello-world', - resources=self.resources)])] + list_mock.return_value = [ + ContainerGroup( + os_type='Linux', + containers=[Container(name='test1', image='hello-world', resources=self.resources)], + ) + ] self.assertFalse(self.hook.exists('test', 'not found')) diff --git 
a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py index 801721cc6ff62..703b4ce5de8e3 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py @@ -24,7 +24,6 @@ class TestAzureContainerRegistryHook(unittest.TestCase): - def test_get_conn(self): db.merge_conn( Connection( diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py index 43f76e3651d7d..ba5fb3720fff2 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py @@ -24,21 +24,12 @@ class TestAzureContainerVolumeHook(unittest.TestCase): - def test_get_file_volume(self): - db.merge_conn( - Connection( - conn_id='wasb_test_key', - conn_type='wasb', - login='login', - password='key' - ) - ) + db.merge_conn(Connection(conn_id='wasb_test_key', conn_type='wasb', login='login', password='key')) hook = AzureContainerVolumeHook(wasb_conn_id='wasb_test_key') - volume = hook.get_file_volume(mount_name='mount', - share_name='share', - storage_account_name='storage', - read_only=True) + volume = hook.get_file_volume( + mount_name='mount', share_name='share', storage_account_name='storage', read_only=True + ) self.assertIsNotNone(volume) self.assertEqual(volume.name, 'mount') self.assertEqual(volume.azure_file.share_name, 'share') diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py index 1374dc09946f9..2e5c266ca7d09 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py @@ -49,8 +49,12 @@ def setUp(self): conn_type='azure_cosmos', login=self.test_end_point, password=self.test_master_key, - extra=json.dumps({'database_name': self.test_database_default, - 'collection_name': self.test_collection_default}) + extra=json.dumps( + { + 'database_name': self.test_database_default, + 'collection_name': self.test_collection_default, + } + ), ) ) @@ -82,9 +86,9 @@ def test_create_container_exception(self, mock_cosmos): def test_create_container(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') hook.create_collection(self.test_collection_name, self.test_database_name) - expected_calls = [mock.call().CreateContainer( - 'dbs/test_database_name', - {'id': self.test_collection_name})] + expected_calls = [ + mock.call().CreateContainer('dbs/test_database_name', {'id': self.test_collection_name}) + ] mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls) @@ -92,9 +96,9 @@ def test_create_container(self, mock_cosmos): def test_create_container_default(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') hook.create_collection(self.test_collection_name) - expected_calls = [mock.call().CreateContainer( - 'dbs/test_database_default', - {'id': self.test_collection_name})] + expected_calls = [ + mock.call().CreateContainer('dbs/test_database_default', {'id': self.test_collection_name}) + ] mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls) @@ -104,9 +108,12 @@ def 
test_upsert_document_default(self, mock_cosmos): mock_cosmos.return_value.CreateItem.return_value = {'id': test_id} hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') returned_item = hook.upsert_document({'id': test_id}) - expected_calls = [mock.call().CreateItem( - 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, - {'id': test_id})] + expected_calls = [ + mock.call().CreateItem( + 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, + {'id': test_id}, + ) + ] mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls) logging.getLogger().info(returned_item) @@ -121,11 +128,15 @@ def test_upsert_document(self, mock_cosmos): {'data1': 'somedata'}, database_name=self.test_database_name, collection_name=self.test_collection_name, - document_id=test_id) + document_id=test_id, + ) - expected_calls = [mock.call().CreateItem( - 'dbs/' + self.test_database_name + '/colls/' + self.test_collection_name, - {'data1': 'somedata', 'id': test_id})] + expected_calls = [ + mock.call().CreateItem( + 'dbs/' + self.test_database_name + '/colls/' + self.test_collection_name, + {'data1': 'somedata', 'id': test_id}, + ) + ] mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls) @@ -140,20 +151,25 @@ def test_insert_documents(self, mock_cosmos): documents = [ {'id': test_id1, 'data': 'data1'}, {'id': test_id2, 'data': 'data2'}, - {'id': test_id3, 'data': 'data3'}] + {'id': test_id3, 'data': 'data3'}, + ] hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') returned_item = hook.insert_documents(documents) expected_calls = [ mock.call().CreateItem( 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, - {'data': 'data1', 'id': test_id1}), + {'data': 'data1', 'id': test_id1}, + ), mock.call().CreateItem( 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, - {'data': 'data2', 'id': test_id2}), + {'data': 'data2', 'id': test_id2}, + ), mock.call().CreateItem( 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, - {'data': 'data3', 'id': test_id3})] + {'data': 'data3', 'id': test_id3}, + ), + ] logging.getLogger().info(returned_item) mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls, any_order=True) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py index 4e5b781a289db..2a8c4604331ef 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py @@ -26,7 +26,6 @@ class TestAzureDataLakeHook(unittest.TestCase): - def setUp(self): db.merge_conn( Connection( @@ -34,8 +33,7 @@ def setUp(self): conn_type='azure_data_lake', login='client_id', password='client secret', - extra=json.dumps({"tenant": "tenant", - "account_name": "accountname"}) + extra=json.dumps({"tenant": "tenant", "account_name": "accountname"}), ) ) @@ -44,67 +42,94 @@ def test_conn(self, mock_lib): from azure.datalake.store import core from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') self.assertIsNone(hook._conn) self.assertEqual(hook.conn_id, 'adl_test_key') 
self.assertIsInstance(hook.get_conn(), core.AzureDLFileSystem) assert mock_lib.auth.called - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.core.AzureDLFileSystem', - autospec=True) + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.azure_data_lake.core.AzureDLFileSystem', autospec=True + ) @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) def test_check_for_blob(self, mock_lib, mock_filesystem): from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') hook.check_for_file('file_path') mock_filesystem.glob.called - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.multithread.ADLUploader', - autospec=True) + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.azure_data_lake.multithread.ADLUploader', autospec=True + ) @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) def test_upload_file(self, mock_lib, mock_uploader): from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') - hook.upload_file(local_path='tests/hooks/test_adl_hook.py', - remote_path='/test_adl_hook.py', - nthreads=64, overwrite=True, - buffersize=4194304, blocksize=4194304) - mock_uploader.assert_called_once_with(hook.get_conn(), - lpath='tests/hooks/test_adl_hook.py', - rpath='/test_adl_hook.py', - nthreads=64, overwrite=True, - buffersize=4194304, blocksize=4194304) - - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.multithread.ADLDownloader', - autospec=True) + hook.upload_file( + local_path='tests/hooks/test_adl_hook.py', + remote_path='/test_adl_hook.py', + nthreads=64, + overwrite=True, + buffersize=4194304, + blocksize=4194304, + ) + mock_uploader.assert_called_once_with( + hook.get_conn(), + lpath='tests/hooks/test_adl_hook.py', + rpath='/test_adl_hook.py', + nthreads=64, + overwrite=True, + buffersize=4194304, + blocksize=4194304, + ) + + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.azure_data_lake.multithread.ADLDownloader', autospec=True + ) @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) def test_download_file(self, mock_lib, mock_downloader): from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') - hook.download_file(local_path='test_adl_hook.py', - remote_path='/test_adl_hook.py', - nthreads=64, overwrite=True, - buffersize=4194304, blocksize=4194304) - mock_downloader.assert_called_once_with(hook.get_conn(), - lpath='test_adl_hook.py', - rpath='/test_adl_hook.py', - nthreads=64, overwrite=True, - buffersize=4194304, blocksize=4194304) - - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.core.AzureDLFileSystem', - autospec=True) + hook.download_file( + local_path='test_adl_hook.py', + remote_path='/test_adl_hook.py', + nthreads=64, + overwrite=True, + buffersize=4194304, + blocksize=4194304, + ) + mock_downloader.assert_called_once_with( + hook.get_conn(), + lpath='test_adl_hook.py', + rpath='/test_adl_hook.py', + nthreads=64, + overwrite=True, + buffersize=4194304, + blocksize=4194304, + ) + + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.azure_data_lake.core.AzureDLFileSystem', autospec=True + ) @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) def 
test_list_glob(self, mock_lib, mock_fs): from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') hook.list('file_path/*') mock_fs.return_value.glob.assert_called_once_with('file_path/*') - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.core.AzureDLFileSystem', - autospec=True) + @mock.patch( + 'airflow.providers.microsoft.azure.hooks.azure_data_lake.core.AzureDLFileSystem', autospec=True + ) @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) def test_list_walk(self, mock_lib, mock_fs): from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') hook.list('file_path/some_folder/') mock_fs.return_value.walk.assert_called_once_with('file_path/some_folder/') diff --git a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py index 16048b394ab26..9937da1902fb3 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py @@ -36,23 +36,20 @@ class TestAzureFileshareHook(unittest.TestCase): - def setUp(self): + db.merge_conn(Connection(conn_id='wasb_test_key', conn_type='wasb', login='login', password='key')) db.merge_conn( Connection( - conn_id='wasb_test_key', conn_type='wasb', - login='login', password='key' - ) - ) - db.merge_conn( - Connection( - conn_id='wasb_test_sas_token', conn_type='wasb', - login='login', extra=json.dumps({'sas_token': 'token'}) + conn_id='wasb_test_sas_token', + conn_type='wasb', + login='login', + extra=json.dumps({'sas_token': 'token'}), ) ) def test_key_and_connection(self): from azure.storage.file import FileService + hook = AzureFileShareHook(wasb_conn_id='wasb_test_key') self.assertEqual(hook.conn_id, 'wasb_test_key') self.assertIsNone(hook._conn) @@ -60,34 +57,28 @@ def test_key_and_connection(self): def test_sas_token(self): from azure.storage.file import FileService + hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token') self.assertEqual(hook.conn_id, 'wasb_test_sas_token') self.assertIsInstance(hook.get_conn(), FileService) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) def test_check_for_file(self, mock_service): mock_instance = mock_service.return_value mock_instance.exists.return_value = True hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token') self.assertTrue(hook.check_for_file('share', 'directory', 'file', timeout=3)) - mock_instance.exists.assert_called_once_with( - 'share', 'directory', 'file', timeout=3 - ) + mock_instance.exists.assert_called_once_with('share', 'directory', 'file', timeout=3) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) def test_check_for_directory(self, mock_service): mock_instance = mock_service.return_value mock_instance.exists.return_value = True hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token') self.assertTrue(hook.check_for_directory('share', 'directory', timeout=3)) - mock_instance.exists.assert_called_once_with( - 'share', 'directory', timeout=3 - ) + 
mock_instance.exists.assert_called_once_with('share', 'directory', timeout=3) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) def test_load_file(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token') @@ -96,8 +87,7 @@ def test_load_file(self, mock_service): 'share', 'directory', 'file', 'path', max_connections=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) def test_load_string(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token') @@ -106,8 +96,7 @@ def test_load_string(self, mock_service): 'share', 'directory', 'file', 'big string', timeout=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) def test_load_stream(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token') @@ -116,28 +105,21 @@ def test_load_stream(self, mock_service): 'share', 'directory', 'file', 'stream', 42, timeout=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) def test_list_directories_and_files(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token') hook.list_directories_and_files('share', 'directory', timeout=1) - mock_instance.list_directories_and_files.assert_called_once_with( - 'share', 'directory', timeout=1 - ) + mock_instance.list_directories_and_files.assert_called_once_with('share', 'directory', timeout=1) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) def test_create_directory(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token') hook.create_directory('share', 'directory', timeout=1) - mock_instance.create_directory.assert_called_once_with( - 'share', 'directory', timeout=1 - ) + mock_instance.create_directory.assert_called_once_with('share', 'directory', timeout=1) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) def test_get_file(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token') @@ -146,8 +128,7 @@ def test_get_file(self, mock_service): 'share', 'directory', 'file', 'path', max_connections=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) def test_get_file_to_stream(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token') diff --git 
a/tests/providers/microsoft/azure/hooks/test_base_azure.py b/tests/providers/microsoft/azure/hooks/test_base_azure.py index 46d84e5e05e47..84950979a8494 100644 --- a/tests/providers/microsoft/azure/hooks/test_base_azure.py +++ b/tests/providers/microsoft/azure/hooks/test_base_azure.py @@ -22,49 +22,46 @@ class TestBaseAzureHook(unittest.TestCase): - @patch('airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file') - @patch('airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection', - return_value=Connection( - conn_id='azure_default', - extra='{ "key_path": "key_file.json" }' - )) + @patch( + 'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection', + return_value=Connection(conn_id='azure_default', extra='{ "key_path": "key_file.json" }'), + ) def test_get_conn_with_key_path(self, mock_connection, mock_get_client_from_auth_file): mock_sdk_client = Mock() auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn() mock_get_client_from_auth_file.assert_called_once_with( - client_class=mock_sdk_client, - auth_path=mock_connection.return_value.extra_dejson['key_path'] + client_class=mock_sdk_client, auth_path=mock_connection.return_value.extra_dejson['key_path'] ) assert auth_sdk_client == mock_get_client_from_auth_file.return_value @patch('airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict') - @patch('airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection', - return_value=Connection( - conn_id='azure_default', - extra='{ "key_json": { "test": "test" } }' - )) + @patch( + 'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection', + return_value=Connection(conn_id='azure_default', extra='{ "key_json": { "test": "test" } }'), + ) def test_get_conn_with_key_json(self, mock_connection, mock_get_client_from_json_dict): mock_sdk_client = Mock() auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn() mock_get_client_from_json_dict.assert_called_once_with( - client_class=mock_sdk_client, - config_dict=mock_connection.return_value.extra_dejson['key_json'] + client_class=mock_sdk_client, config_dict=mock_connection.return_value.extra_dejson['key_json'] ) assert auth_sdk_client == mock_get_client_from_json_dict.return_value @patch('airflow.providers.microsoft.azure.hooks.base_azure.ServicePrincipalCredentials') - @patch('airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection', - return_value=Connection( - conn_id='azure_default', - login='my_login', - password='my_password', - extra='{ "tenantId": "my_tenant", "subscriptionId": "my_subscription" }' - )) + @patch( + 'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection', + return_value=Connection( + conn_id='azure_default', + login='my_login', + password='my_password', + extra='{ "tenantId": "my_tenant", "subscriptionId": "my_subscription" }', + ), + ) def test_get_conn_with_credentials(self, mock_connection, mock_spc): mock_sdk_client = Mock() @@ -73,10 +70,10 @@ def test_get_conn_with_credentials(self, mock_connection, mock_spc): mock_spc.assert_called_once_with( client_id=mock_connection.return_value.login, secret=mock_connection.return_value.password, - tenant=mock_connection.return_value.extra_dejson['tenantId'] + tenant=mock_connection.return_value.extra_dejson['tenantId'], ) mock_sdk_client.assert_called_once_with( credentials=mock_spc.return_value, - subscription_id=mock_connection.return_value.extra_dejson['subscriptionId'] + 
subscription_id=mock_connection.return_value.extra_dejson['subscriptionId'], ) assert auth_sdk_client == mock_sdk_client.return_value diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py index 87b1c44b0c8e7..d8a7dfa8bf3dc 100644 --- a/tests/providers/microsoft/azure/hooks/test_wasb.py +++ b/tests/providers/microsoft/azure/hooks/test_wasb.py @@ -31,73 +31,61 @@ class TestWasbHook(unittest.TestCase): - def setUp(self): + db.merge_conn(Connection(conn_id='wasb_test_key', conn_type='wasb', login='login', password='key')) db.merge_conn( Connection( - conn_id='wasb_test_key', conn_type='wasb', - login='login', password='key' - ) - ) - db.merge_conn( - Connection( - conn_id='wasb_test_sas_token', conn_type='wasb', - login='login', extra=json.dumps({'sas_token': 'token'}) + conn_id='wasb_test_sas_token', + conn_type='wasb', + login='login', + extra=json.dumps({'sas_token': 'token'}), ) ) def test_key(self): from azure.storage.blob import BlockBlobService + hook = WasbHook(wasb_conn_id='wasb_test_key') self.assertEqual(hook.conn_id, 'wasb_test_key') self.assertIsInstance(hook.connection, BlockBlobService) def test_sas_token(self): from azure.storage.blob import BlockBlobService + hook = WasbHook(wasb_conn_id='wasb_test_sas_token') self.assertEqual(hook.conn_id, 'wasb_test_sas_token') self.assertIsInstance(hook.connection, BlockBlobService) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_check_for_blob(self, mock_service): mock_instance = mock_service.return_value mock_instance.exists.return_value = True hook = WasbHook(wasb_conn_id='wasb_test_sas_token') self.assertTrue(hook.check_for_blob('container', 'blob', timeout=3)) - mock_instance.exists.assert_called_once_with( - 'container', 'blob', timeout=3 - ) + mock_instance.exists.assert_called_once_with('container', 'blob', timeout=3) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_check_for_blob_empty(self, mock_service): mock_service.return_value.exists.return_value = False hook = WasbHook(wasb_conn_id='wasb_test_sas_token') self.assertFalse(hook.check_for_blob('container', 'blob')) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_check_for_prefix(self, mock_service): mock_instance = mock_service.return_value mock_instance.list_blobs.return_value = iter(['blob_1']) hook = WasbHook(wasb_conn_id='wasb_test_sas_token') - self.assertTrue(hook.check_for_prefix('container', 'prefix', - timeout=3)) - mock_instance.list_blobs.assert_called_once_with( - 'container', 'prefix', num_results=1, timeout=3 - ) + self.assertTrue(hook.check_for_prefix('container', 'prefix', timeout=3)) + mock_instance.list_blobs.assert_called_once_with('container', 'prefix', num_results=1, timeout=3) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_check_for_prefix_empty(self, mock_service): mock_instance = mock_service.return_value mock_instance.list_blobs.return_value = iter([]) hook = 
WasbHook(wasb_conn_id='wasb_test_sas_token') self.assertFalse(hook.check_for_prefix('container', 'prefix')) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_load_file(self, mock_service): mock_instance = mock_service.return_value hook = WasbHook(wasb_conn_id='wasb_test_sas_token') @@ -106,8 +94,7 @@ def test_load_file(self, mock_service): 'container', 'blob', 'path', max_connections=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_load_string(self, mock_service): mock_instance = mock_service.return_value hook = WasbHook(wasb_conn_id='wasb_test_sas_token') @@ -116,44 +103,32 @@ def test_load_string(self, mock_service): 'container', 'blob', 'big string', max_connections=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_get_file(self, mock_service): mock_instance = mock_service.return_value hook = WasbHook(wasb_conn_id='wasb_test_sas_token') hook.get_file('path', 'container', 'blob', max_connections=1) - mock_instance.get_blob_to_path.assert_called_once_with( - 'container', 'blob', 'path', max_connections=1 - ) + mock_instance.get_blob_to_path.assert_called_once_with('container', 'blob', 'path', max_connections=1) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_read_file(self, mock_service): mock_instance = mock_service.return_value hook = WasbHook(wasb_conn_id='wasb_test_sas_token') hook.read_file('container', 'blob', max_connections=1) - mock_instance.get_blob_to_text.assert_called_once_with( - 'container', 'blob', max_connections=1 - ) + mock_instance.get_blob_to_text.assert_called_once_with('container', 'blob', max_connections=1) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_delete_single_blob(self, mock_service): mock_instance = mock_service.return_value hook = WasbHook(wasb_conn_id='wasb_test_sas_token') hook.delete_file('container', 'blob', is_prefix=False) - mock_instance.delete_blob.assert_called_once_with( - 'container', 'blob', delete_snapshots='include' - ) + mock_instance.delete_blob.assert_called_once_with('container', 'blob', delete_snapshots='include') - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_delete_multiple_blobs(self, mock_service): mock_instance = mock_service.return_value Blob = namedtuple('Blob', ['name']) - mock_instance.list_blobs.return_value = iter( - [Blob('blob_prefix/blob1'), Blob('blob_prefix/blob2')] - ) + mock_instance.list_blobs.return_value = iter([Blob('blob_prefix/blob1'), Blob('blob_prefix/blob2')]) hook = WasbHook(wasb_conn_id='wasb_test_sas_token') hook.delete_file('container', 'blob_prefix', is_prefix=True) mock_instance.delete_blob.assert_any_call( @@ -163,38 +138,27 @@ def test_delete_multiple_blobs(self, mock_service): 'container', 
'blob_prefix/blob2', delete_snapshots='include' ) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_delete_nonexisting_blob_fails(self, mock_service): mock_instance = mock_service.return_value mock_instance.exists.return_value = False hook = WasbHook(wasb_conn_id='wasb_test_sas_token') with self.assertRaises(Exception) as context: - hook.delete_file( - 'container', 'nonexisting_blob', - is_prefix=False, ignore_if_missing=False - ) + hook.delete_file('container', 'nonexisting_blob', is_prefix=False, ignore_if_missing=False) self.assertIsInstance(context.exception, AirflowException) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_delete_multiple_nonexisting_blobs_fails(self, mock_service): mock_instance = mock_service.return_value mock_instance.list_blobs.return_value = iter([]) hook = WasbHook(wasb_conn_id='wasb_test_sas_token') with self.assertRaises(Exception) as context: - hook.delete_file( - 'container', 'nonexisting_blob_prefix', - is_prefix=True, ignore_if_missing=False - ) + hook.delete_file('container', 'nonexisting_blob_prefix', is_prefix=True, ignore_if_missing=False) self.assertIsInstance(context.exception, AirflowException) - @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True) def test_get_blobs_list(self, mock_service): mock_instance = mock_service.return_value hook = WasbHook(wasb_conn_id='wasb_test_sas_token') hook.get_blobs_list('container', 'prefix', num_results=1, timeout=3) - mock_instance.list_blobs.assert_called_once_with( - 'container', 'prefix', num_results=1, timeout=3 - ) + mock_instance.list_blobs.assert_called_once_with('container', 'prefix', num_results=1, timeout=3) diff --git a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py index 6141533dd2b37..9a51ec509c165 100644 --- a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py +++ b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py @@ -30,7 +30,6 @@ class TestWasbTaskHandler(unittest.TestCase): - def setUp(self): super().setUp() self.wasb_log_folder = 'wasb://container/remote/log/location' @@ -43,7 +42,7 @@ def setUp(self): wasb_log_folder=self.wasb_log_folder, wasb_container=self.container_name, filename_template=self.filename_template, - delete_local_copy=True + delete_local_copy=True, ) date = datetime(2020, 8, 10) @@ -62,11 +61,7 @@ def test_hook(self, mock_service): @conf_vars({('logging', 'remote_log_conn_id'): 'wasb_default'}) def test_hook_raises(self): handler = WasbTaskHandler( - self.local_log_location, - self.wasb_log_folder, - self.container_name, - self.filename_template, - True + self.local_log_location, self.wasb_log_folder, self.container_name, self.filename_template, True ) with mock.patch.object(handler.log, 'error') as mock_error: with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") as mock_hook: @@ -77,7 +72,8 @@ def test_hook_raises(self): mock_error.assert_called_once_with( 'Could not create an WasbHook with connection id "%s". 
' 'Please make sure that airflow[azure] is installed and ' - 'the Wasb connection exists.', "wasb_default" + 'the Wasb connection exists.', + "wasb_default", ) def test_set_context_raw(self): @@ -95,30 +91,24 @@ def test_wasb_log_exists(self, mock_hook): instance.check_for_blob.return_value = True self.wasb_task_handler.wasb_log_exists(self.remote_log_location) mock_hook.return_value.check_for_blob.assert_called_once_with( - self.container_name, - self.remote_log_location + self.container_name, self.remote_log_location ) @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") def test_wasb_read(self, mock_hook): mock_hook.return_value.read_file.return_value = 'Log line' - self.assertEqual( - self.wasb_task_handler.wasb_read(self.remote_log_location), - "Log line" - ) + self.assertEqual(self.wasb_task_handler.wasb_read(self.remote_log_location), "Log line") self.assertEqual( self.wasb_task_handler.read(self.ti), - (['*** Reading remote log from wasb://container/remote/log/location/1.log.\n' - 'Log line\n'], [{'end_of_log': True}]) + ( + ['*** Reading remote log from wasb://container/remote/log/location/1.log.\nLog line\n'], + [{'end_of_log': True}], + ), ) def test_wasb_read_raises(self): handler = WasbTaskHandler( - self.local_log_location, - self.wasb_log_folder, - self.container_name, - self.filename_template, - True + self.local_log_location, self.wasb_log_folder, self.container_name, self.filename_template, True ) with mock.patch.object(handler.log, 'error') as mock_error: with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") as mock_hook: @@ -138,9 +128,7 @@ def test_write_log(self, mock_log_exists, mock_wasb_read, mock_hook): mock_wasb_read.return_value = "" self.wasb_task_handler.wasb_write('text', self.remote_log_location) mock_hook.return_value.load_string.assert_called_once_with( - "text", - self.container_name, - self.remote_log_location + "text", self.container_name, self.remote_log_location ) @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") @@ -151,27 +139,19 @@ def test_write_on_existing_log(self, mock_log_exists, mock_wasb_read, mock_hook) mock_wasb_read.return_value = "old log" self.wasb_task_handler.wasb_write('text', self.remote_log_location) mock_hook.return_value.load_string.assert_called_once_with( - "old log\ntext", - self.container_name, - self.remote_log_location + "old log\ntext", self.container_name, self.remote_log_location ) @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") def test_write_when_append_is_false(self, mock_hook): self.wasb_task_handler.wasb_write('text', self.remote_log_location, False) mock_hook.return_value.load_string.assert_called_once_with( - "text", - self.container_name, - self.remote_log_location + "text", self.container_name, self.remote_log_location ) def test_write_raises(self): handler = WasbTaskHandler( - self.local_log_location, - self.wasb_log_folder, - self.container_name, - self.filename_template, - True + self.local_log_location, self.wasb_log_folder, self.container_name, self.filename_template, True ) with mock.patch.object(handler.log, 'error') as mock_error: with mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") as mock_hook: diff --git a/tests/providers/microsoft/azure/operators/test_adls_list.py b/tests/providers/microsoft/azure/operators/test_adls_list.py index d5467ddc37be3..fadd60a1f0483 100644 --- a/tests/providers/microsoft/azure/operators/test_adls_list.py +++ b/tests/providers/microsoft/azure/operators/test_adls_list.py @@ -24,21 
+24,22 @@ TASK_ID = 'test-adls-list-operator' TEST_PATH = 'test/*' -MOCK_FILES = ["test/TEST1.csv", "test/TEST2.csv", "test/path/TEST3.csv", - "test/path/PARQUET.parquet", "test/path/PIC.png"] +MOCK_FILES = [ + "test/TEST1.csv", + "test/TEST2.csv", + "test/path/TEST3.csv", + "test/path/PARQUET.parquet", + "test/path/PIC.png", +] class TestAzureDataLakeStorageListOperator(unittest.TestCase): - @mock.patch('airflow.providers.microsoft.azure.operators.adls_list.AzureDataLakeHook') def test_execute(self, mock_hook): mock_hook.return_value.list.return_value = MOCK_FILES - operator = AzureDataLakeStorageListOperator(task_id=TASK_ID, - path=TEST_PATH) + operator = AzureDataLakeStorageListOperator(task_id=TASK_ID, path=TEST_PATH) files = operator.execute(None) - mock_hook.return_value.list.assert_called_once_with( - path=TEST_PATH - ) + mock_hook.return_value.list.assert_called_once_with(path=TEST_PATH) self.assertEqual(sorted(files), sorted(MOCK_FILES)) diff --git a/tests/providers/microsoft/azure/operators/test_adx.py b/tests/providers/microsoft/azure/operators/test_adx.py index 519ade6a8bf75..751980de82d05 100644 --- a/tests/providers/microsoft/azure/operators/test_adx.py +++ b/tests/providers/microsoft/azure/operators/test_adx.py @@ -33,25 +33,21 @@ 'task_id': 'test_azure_data_explorer_query_operator', 'query': 'Logs | schema', 'database': 'Database', - 'options': { - 'option1': 'option_value' - } + 'options': {'option1': 'option_value'}, } MOCK_RESULT = { 'name': 'getschema', 'kind': 'PrimaryResult', - 'data': [{ - 'ColumnName': 'Source', - 'ColumnOrdinal': 0, - 'DataType': 'System.String', - 'ColumnType': 'string' - }, { - 'ColumnName': 'Timestamp', - 'ColumnOrdinal': 1, - 'DataType': 'System.DateTime', - 'ColumnType': 'datetime' - }] + 'data': [ + {'ColumnName': 'Source', 'ColumnOrdinal': 0, 'DataType': 'System.String', 'ColumnType': 'string'}, + { + 'ColumnName': 'Timestamp', + 'ColumnOrdinal': 1, + 'DataType': 'System.DateTime', + 'ColumnType': 'datetime', + }, + ], } @@ -61,17 +57,10 @@ class MockResponse: class TestAzureDataExplorerQueryOperator(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'provide_context': True - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE, 'provide_context': True} - self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', - default_args=args, - schedule_interval='@once') - self.operator = AzureDataExplorerQueryOperator(dag=self.dag, - **MOCK_DATA) + self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args, schedule_interval='@once') + self.operator = AzureDataExplorerQueryOperator(dag=self.dag, **MOCK_DATA) def test_init(self): self.assertEqual(self.operator.task_id, MOCK_DATA['task_id']) @@ -83,9 +72,9 @@ def test_init(self): @mock.patch.object(AzureDataExplorerHook, 'get_conn') def test_run_query(self, mock_conn, mock_run_query): self.operator.execute(None) - mock_run_query.assert_called_once_with(MOCK_DATA['query'], - MOCK_DATA['database'], - MOCK_DATA['options']) + mock_run_query.assert_called_once_with( + MOCK_DATA['query'], MOCK_DATA['database'], MOCK_DATA['options'] + ) @mock.patch.object(AzureDataExplorerHook, 'run_query', return_value=MockResponse()) @mock.patch.object(AzureDataExplorerHook, 'get_conn') @@ -93,5 +82,4 @@ def test_xcom_push_and_pull(self, mock_conn, mock_run_query): ti = TaskInstance(task=self.operator, execution_date=timezone.utcnow()) ti.run() - self.assertEqual(ti.xcom_pull(task_ids=MOCK_DATA['task_id']), - MOCK_RESULT) + 
self.assertEqual(ti.xcom_pull(task_ids=MOCK_DATA['task_id']), MOCK_RESULT) diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py b/tests/providers/microsoft/azure/operators/test_azure_batch.py index d33e561dfda10..3ef4f69785979 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_batch.py +++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py @@ -55,30 +55,38 @@ def setUp(self, mock_batch, mock_hook): # connect with vm configuration db.merge_conn( - Connection(conn_id=self.test_vm_conn_id, - conn_type="azure_batch", - extra=json.dumps({ - "account_name": self.test_account_name, - "account_key": self.test_account_key, - "account_url": self.test_account_url, - "vm_publisher": self.test_vm_publisher, - "vm_offer": self.test_vm_offer, - "vm_sku": self.test_vm_sku, - "node_agent_sku_id": self.test_node_agent_sku - })) + Connection( + conn_id=self.test_vm_conn_id, + conn_type="azure_batch", + extra=json.dumps( + { + "account_name": self.test_account_name, + "account_key": self.test_account_key, + "account_url": self.test_account_url, + "vm_publisher": self.test_vm_publisher, + "vm_offer": self.test_vm_offer, + "vm_sku": self.test_vm_sku, + "node_agent_sku_id": self.test_node_agent_sku, + } + ), + ) ) # connect with cloud service db.merge_conn( - Connection(conn_id=self.test_cloud_conn_id, - conn_type="azure_batch", - extra=json.dumps({ - "account_name": self.test_account_name, - "account_key": self.test_account_key, - "account_url": self.test_account_url, - "os_family": self.test_cloud_os_family, - "os_version": self.test_cloud_os_version, - "node_agent_sku_id": self.test_node_agent_sku - })) + Connection( + conn_id=self.test_cloud_conn_id, + conn_type="azure_batch", + extra=json.dumps( + { + "account_name": self.test_account_name, + "account_key": self.test_account_key, + "account_url": self.test_account_url, + "os_family": self.test_cloud_os_family, + "os_version": self.test_cloud_os_version, + "node_agent_sku_id": self.test_node_agent_sku, + } + ), + ) ) self.operator = AzureBatchOperator( task_id=TASK_ID, @@ -88,7 +96,7 @@ def setUp(self, mock_batch, mock_hook): batch_task_id=BATCH_TASK_ID, batch_task_command_line="echo hello", azure_batch_conn_id=self.test_vm_conn_id, - timeout=2 + timeout=2, ) self.batch_client = mock_batch.return_value self.assertEqual(self.batch_client, self.operator.hook.connection) diff --git a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py index 5b03fb759fd0b..dd15558c0277d 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py +++ b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py @@ -35,24 +35,21 @@ def make_mock_cg(container_state, events=None): See https://docs.microsoft.com/en-us/rest/api/container-instances/containergroups """ events = events or [] - instance_view_dict = {"current_state": container_state, - "events": events} - instance_view = namedtuple("InstanceView", - instance_view_dict.keys())(*instance_view_dict.values()) + instance_view_dict = {"current_state": container_state, "events": events} + instance_view = namedtuple("InstanceView", instance_view_dict.keys())(*instance_view_dict.values()) container_dict = {"instance_view": instance_view} container = namedtuple("Container", container_dict.keys())(*container_dict.values()) container_g_dict = {"containers": [container]} - container_g = namedtuple("ContainerGroup", - 
container_g_dict.keys())(*container_g_dict.values()) + container_g = namedtuple("ContainerGroup", container_g_dict.keys())(*container_g_dict.values()) return container_g class TestACIOperator(unittest.TestCase): - - @mock.patch("airflow.providers.microsoft.azure.operators." - "azure_container_instances.AzureContainerInstanceHook") + @mock.patch( + "airflow.providers.microsoft.azure.operators." "azure_container_instances.AzureContainerInstanceHook" + ) def test_execute(self, aci_mock): expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test') expected_cg = make_mock_cg(expected_c_state) @@ -60,18 +57,19 @@ def test_execute(self, aci_mock): aci_mock.return_value.get_state.return_value = expected_cg aci_mock.return_value.exists.return_value = False - aci = AzureContainerInstancesOperator(ci_conn_id=None, - registry_conn_id=None, - resource_group='resource-group', - name='container-name', - image='container-image', - region='region', - task_id='task') + aci = AzureContainerInstancesOperator( + ci_conn_id=None, + registry_conn_id=None, + resource_group='resource-group', + name='container-name', + image='container-image', + region='region', + task_id='task', + ) aci.execute(None) self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1) - (called_rg, called_cn, called_cg), _ = \ - aci_mock.return_value.create_or_update.call_args + (called_rg, called_cn, called_cg), _ = aci_mock.return_value.create_or_update.call_args self.assertEqual(called_rg, 'resource-group') self.assertEqual(called_cn, 'container-name') @@ -87,8 +85,9 @@ def test_execute(self, aci_mock): self.assertEqual(aci_mock.return_value.delete.call_count, 1) - @mock.patch("airflow.providers.microsoft.azure.operators." - "azure_container_instances.AzureContainerInstanceHook") + @mock.patch( + "airflow.providers.microsoft.azure.operators." "azure_container_instances.AzureContainerInstanceHook" + ) def test_execute_with_failures(self, aci_mock): expected_c_state = ContainerState(state='Terminated', exit_code=1, detail_status='test') expected_cg = make_mock_cg(expected_c_state) @@ -96,20 +95,23 @@ def test_execute_with_failures(self, aci_mock): aci_mock.return_value.get_state.return_value = expected_cg aci_mock.return_value.exists.return_value = False - aci = AzureContainerInstancesOperator(ci_conn_id=None, - registry_conn_id=None, - resource_group='resource-group', - name='container-name', - image='container-image', - region='region', - task_id='task') + aci = AzureContainerInstancesOperator( + ci_conn_id=None, + registry_conn_id=None, + resource_group='resource-group', + name='container-name', + image='container-image', + region='region', + task_id='task', + ) with self.assertRaises(AirflowException): aci.execute(None) self.assertEqual(aci_mock.return_value.delete.call_count, 1) - @mock.patch("airflow.providers.microsoft.azure.operators." - "azure_container_instances.AzureContainerInstanceHook") + @mock.patch( + "airflow.providers.microsoft.azure.operators." 
"azure_container_instances.AzureContainerInstanceHook" + ) def test_execute_with_tags(self, aci_mock): expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test') expected_cg = make_mock_cg(expected_c_state) @@ -118,19 +120,20 @@ def test_execute_with_tags(self, aci_mock): aci_mock.return_value.get_state.return_value = expected_cg aci_mock.return_value.exists.return_value = False - aci = AzureContainerInstancesOperator(ci_conn_id=None, - registry_conn_id=None, - resource_group='resource-group', - name='container-name', - image='container-image', - region='region', - task_id='task', - tags=tags) + aci = AzureContainerInstancesOperator( + ci_conn_id=None, + registry_conn_id=None, + resource_group='resource-group', + name='container-name', + image='container-image', + region='region', + task_id='task', + tags=tags, + ) aci.execute(None) self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1) - (called_rg, called_cn, called_cg), _ = \ - aci_mock.return_value.create_or_update.call_args + (called_rg, called_cn, called_cg), _ = aci_mock.return_value.create_or_update.call_args self.assertEqual(called_rg, 'resource-group') self.assertEqual(called_cn, 'container-name') @@ -147,8 +150,9 @@ def test_execute_with_tags(self, aci_mock): self.assertEqual(aci_mock.return_value.delete.call_count, 1) - @mock.patch("airflow.providers.microsoft.azure.operators." - "azure_container_instances.AzureContainerInstanceHook") + @mock.patch( + "airflow.providers.microsoft.azure.operators." "azure_container_instances.AzureContainerInstanceHook" + ) def test_execute_with_messages_logs(self, aci_mock): events = [Event(message="test"), Event(message="messages")] expected_c_state1 = ContainerState(state='Running', exit_code=0, detail_status='test') @@ -156,18 +160,19 @@ def test_execute_with_messages_logs(self, aci_mock): expected_c_state2 = ContainerState(state='Terminated', exit_code=0, detail_status='test') expected_cg2 = make_mock_cg(expected_c_state2, events) - aci_mock.return_value.get_state.side_effect = [expected_cg1, - expected_cg2] + aci_mock.return_value.get_state.side_effect = [expected_cg1, expected_cg2] aci_mock.return_value.get_logs.return_value = ["test", "logs"] aci_mock.return_value.exists.return_value = False - aci = AzureContainerInstancesOperator(ci_conn_id=None, - registry_conn_id=None, - resource_group='resource-group', - name='container-name', - image='container-image', - region='region', - task_id='task') + aci = AzureContainerInstancesOperator( + ci_conn_id=None, + registry_conn_id=None, + resource_group='resource-group', + name='container-name', + image='container-image', + region='region', + task_id='task', + ) aci.execute(None) self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1) @@ -179,10 +184,12 @@ def test_execute_with_messages_logs(self, aci_mock): def test_name_checker(self): valid_names = ['test-dash', 'name-with-length---63' * 3] - invalid_names = ['test_underscore', - 'name-with-length---84' * 4, - 'name-ending-with-dash-', - '-name-starting-with-dash'] + invalid_names = [ + 'test_underscore', + 'name-with-length---84' * 4, + 'name-ending-with-dash-', + '-name-starting-with-dash', + ] for name in invalid_names: with self.assertRaises(AirflowException): AzureContainerInstancesOperator._check_name(name) diff --git a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py index 7fabd15f20724..f15cc3be39de3 100644 --- 
a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py +++ b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py @@ -44,8 +44,9 @@ def setUp(self): conn_type='azure_cosmos', login=self.test_end_point, password=self.test_master_key, - extra=json.dumps({'database_name': self.test_database_name, - 'collection_name': self.test_collection_name}) + extra=json.dumps( + {'database_name': self.test_database_name, 'collection_name': self.test_collection_name} + ), ) ) @@ -58,11 +59,15 @@ def test_insert_document(self, cosmos_mock): collection_name=self.test_collection_name, document={'id': test_id, 'data': 'sometestdata'}, azure_cosmos_conn_id='azure_cosmos_test_key_id', - task_id='azure_cosmos_sensor') + task_id='azure_cosmos_sensor', + ) - expected_calls = [mock.call().CreateItem( - 'dbs/' + self.test_database_name + '/colls/' + self.test_collection_name, - {'data': 'sometestdata', 'id': test_id})] + expected_calls = [ + mock.call().CreateItem( + 'dbs/' + self.test_database_name + '/colls/' + self.test_collection_name, + {'data': 'sometestdata', 'id': test_id}, + ) + ] op.execute(None) cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) diff --git a/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py b/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py index 1289312e9fe61..b6fe94331f651 100644 --- a/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py +++ b/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py @@ -34,46 +34,27 @@ class TestWasbDeleteBlobOperator(unittest.TestCase): } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2017, 1, 1) - } + args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)} self.dag = DAG('test_dag_id', default_args=args) def test_init(self): - operator = WasbDeleteBlobOperator( - task_id='wasb_operator_1', - dag=self.dag, - **self._config - ) - self.assertEqual(operator.container_name, - self._config['container_name']) + operator = WasbDeleteBlobOperator(task_id='wasb_operator_1', dag=self.dag, **self._config) + self.assertEqual(operator.container_name, self._config['container_name']) self.assertEqual(operator.blob_name, self._config['blob_name']) self.assertEqual(operator.is_prefix, False) self.assertEqual(operator.ignore_if_missing, False) operator = WasbDeleteBlobOperator( - task_id='wasb_operator_2', - dag=self.dag, - is_prefix=True, - ignore_if_missing=True, - **self._config + task_id='wasb_operator_2', dag=self.dag, is_prefix=True, ignore_if_missing=True, **self._config ) self.assertEqual(operator.is_prefix, True) self.assertEqual(operator.ignore_if_missing, True) - @mock.patch('airflow.providers.microsoft.azure.operators.wasb_delete_blob.WasbHook', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.operators.wasb_delete_blob.WasbHook', autospec=True) def test_execute(self, mock_hook): mock_instance = mock_hook.return_value operator = WasbDeleteBlobOperator( - task_id='wasb_operator', - dag=self.dag, - is_prefix=True, - ignore_if_missing=True, - **self._config + task_id='wasb_operator', dag=self.dag, is_prefix=True, ignore_if_missing=True, **self._config ) operator.execute(None) - mock_instance.delete_file.assert_called_once_with( - 'container', 'blob', True, True - ) + mock_instance.delete_file.assert_called_once_with('container', 'blob', True, True) diff --git a/tests/providers/microsoft/azure/sensors/test_wasb.py b/tests/providers/microsoft/azure/sensors/test_wasb.py index 
74dc0f17f0692..32444cddaa14a 100644 --- a/tests/providers/microsoft/azure/sensors/test_wasb.py +++ b/tests/providers/microsoft/azure/sensors/test_wasb.py @@ -35,18 +35,11 @@ class TestWasbBlobSensor(unittest.TestCase): } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2017, 1, 1) - } + args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)} self.dag = DAG('test_dag_id', default_args=args) def test_init(self): - sensor = WasbBlobSensor( - task_id='wasb_sensor_1', - dag=self.dag, - **self._config - ) + sensor = WasbBlobSensor(task_id='wasb_sensor_1', dag=self.dag, **self._config) self.assertEqual(sensor.container_name, self._config['container_name']) self.assertEqual(sensor.blob_name, self._config['blob_name']) self.assertEqual(sensor.wasb_conn_id, self._config['wasb_conn_id']) @@ -54,27 +47,18 @@ def test_init(self): self.assertEqual(sensor.timeout, self._config['timeout']) sensor = WasbBlobSensor( - task_id='wasb_sensor_2', - dag=self.dag, - check_options={'timeout': 2}, - **self._config + task_id='wasb_sensor_2', dag=self.dag, check_options={'timeout': 2}, **self._config ) self.assertEqual(sensor.check_options, {'timeout': 2}) - @mock.patch('airflow.providers.microsoft.azure.sensors.wasb.WasbHook', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.sensors.wasb.WasbHook', autospec=True) def test_poke(self, mock_hook): mock_instance = mock_hook.return_value sensor = WasbBlobSensor( - task_id='wasb_sensor', - dag=self.dag, - check_options={'timeout': 2}, - **self._config + task_id='wasb_sensor', dag=self.dag, check_options={'timeout': 2}, **self._config ) sensor.poke(None) - mock_instance.check_for_blob.assert_called_once_with( - 'container', 'blob', timeout=2 - ) + mock_instance.check_for_blob.assert_called_once_with('container', 'blob', timeout=2) class TestWasbPrefixSensor(unittest.TestCase): @@ -86,18 +70,11 @@ class TestWasbPrefixSensor(unittest.TestCase): } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2017, 1, 1) - } + args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)} self.dag = DAG('test_dag_id', default_args=args) def test_init(self): - sensor = WasbPrefixSensor( - task_id='wasb_sensor_1', - dag=self.dag, - **self._config - ) + sensor = WasbPrefixSensor(task_id='wasb_sensor_1', dag=self.dag, **self._config) self.assertEqual(sensor.container_name, self._config['container_name']) self.assertEqual(sensor.prefix, self._config['prefix']) self.assertEqual(sensor.wasb_conn_id, self._config['wasb_conn_id']) @@ -105,24 +82,15 @@ def test_init(self): self.assertEqual(sensor.timeout, self._config['timeout']) sensor = WasbPrefixSensor( - task_id='wasb_sensor_2', - dag=self.dag, - check_options={'timeout': 2}, - **self._config + task_id='wasb_sensor_2', dag=self.dag, check_options={'timeout': 2}, **self._config ) self.assertEqual(sensor.check_options, {'timeout': 2}) - @mock.patch('airflow.providers.microsoft.azure.sensors.wasb.WasbHook', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.sensors.wasb.WasbHook', autospec=True) def test_poke(self, mock_hook): mock_instance = mock_hook.return_value sensor = WasbPrefixSensor( - task_id='wasb_sensor', - dag=self.dag, - check_options={'timeout': 2}, - **self._config + task_id='wasb_sensor', dag=self.dag, check_options={'timeout': 2}, **self._config ) sensor.poke(None) - mock_instance.check_for_prefix.assert_called_once_with( - 'container', 'prefix', timeout=2 - ) + 
mock_instance.check_for_prefix.assert_called_once_with('container', 'prefix', timeout=2) diff --git a/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py b/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py index cefc5f81e4dcb..2cd40766c6e3f 100644 --- a/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py +++ b/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py @@ -37,45 +37,28 @@ class TestFileToWasbOperator(unittest.TestCase): } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2017, 1, 1) - } + args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)} self.dag = DAG('test_dag_id', default_args=args) def test_init(self): - operator = FileToWasbOperator( - task_id='wasb_operator_1', - dag=self.dag, - **self._config - ) + operator = FileToWasbOperator(task_id='wasb_operator_1', dag=self.dag, **self._config) self.assertEqual(operator.file_path, self._config['file_path']) - self.assertEqual(operator.container_name, - self._config['container_name']) + self.assertEqual(operator.container_name, self._config['container_name']) self.assertEqual(operator.blob_name, self._config['blob_name']) self.assertEqual(operator.wasb_conn_id, self._config['wasb_conn_id']) self.assertEqual(operator.load_options, {}) self.assertEqual(operator.retries, self._config['retries']) operator = FileToWasbOperator( - task_id='wasb_operator_2', - dag=self.dag, - load_options={'timeout': 2}, - **self._config + task_id='wasb_operator_2', dag=self.dag, load_options={'timeout': 2}, **self._config ) self.assertEqual(operator.load_options, {'timeout': 2}) - @mock.patch('airflow.providers.microsoft.azure.transfers.file_to_wasb.WasbHook', - autospec=True) + @mock.patch('airflow.providers.microsoft.azure.transfers.file_to_wasb.WasbHook', autospec=True) def test_execute(self, mock_hook): mock_instance = mock_hook.return_value operator = FileToWasbOperator( - task_id='wasb_sensor', - dag=self.dag, - load_options={'timeout': 2}, - **self._config + task_id='wasb_sensor', dag=self.dag, load_options={'timeout': 2}, **self._config ) operator.execute(None) - mock_instance.load_file.assert_called_once_with( - 'file', 'container', 'blob', timeout=2 - ) + mock_instance.load_file.assert_called_once_with('file', 'container', 'blob', timeout=2) diff --git a/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py b/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py index 5ccb92292dafe..dda285cec6551 100644 --- a/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py +++ b/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py @@ -45,7 +45,7 @@ def test_write_temp_file(self): encoding = 'utf-8' cursor_description = [ ('id', "", 39, None, 38, 0, 0), - ('description', "", 60, 240, None, None, 1) + ('description', "", 60, 240, None, None, 1), ] cursor_rows = [[1, 'description 1'], [2, 'description 2']] mock_cursor = MagicMock() @@ -61,7 +61,8 @@ def test_write_temp_file(self): azure_data_lake_conn_id=azure_data_lake_conn_id, azure_data_lake_path=azure_data_lake_path, delimiter=delimiter, - encoding=encoding) + encoding=encoding, + ) with TemporaryDirectory(prefix='airflow_oracle_to_azure_op_') as temp: op._write_temp_file(mock_cursor, os.path.join(temp, filename)) @@ -81,10 +82,8 @@ def test_write_temp_file(self): self.assertEqual(row[1], cursor_rows[rownum - 1][1]) rownum = rownum + 1 - @mock.patch(mock_module_path + '.OracleHook', - autospec=True) - 
@mock.patch(mock_module_path + '.AzureDataLakeHook', - autospec=True) + @mock.patch(mock_module_path + '.OracleHook', autospec=True) + @mock.patch(mock_module_path + '.AzureDataLakeHook', autospec=True) def test_execute(self, mock_data_lake_hook, mock_oracle_hook): task_id = "some_test_id" sql = "some_sql" @@ -97,7 +96,7 @@ def test_execute(self, mock_data_lake_hook, mock_oracle_hook): encoding = 'latin-1' cursor_description = [ ('id', "", 39, None, 38, 0, 0), - ('description', "", 60, 240, None, None, 1) + ('description', "", 60, 240, None, None, 1), ] cursor_rows = [[1, 'description 1'], [2, 'description 2']] cursor_mock = MagicMock() @@ -116,10 +115,10 @@ def test_execute(self, mock_data_lake_hook, mock_oracle_hook): azure_data_lake_conn_id=azure_data_lake_conn_id, azure_data_lake_path=azure_data_lake_path, delimiter=delimiter, - encoding=encoding) + encoding=encoding, + ) op.execute(None) mock_oracle_hook.assert_called_once_with(oracle_conn_id=oracle_conn_id) - mock_data_lake_hook.assert_called_once_with( - azure_data_lake_conn_id=azure_data_lake_conn_id) + mock_data_lake_hook.assert_called_once_with(azure_data_lake_conn_id=azure_data_lake_conn_id) diff --git a/tests/providers/microsoft/mssql/operators/test_mssql.py b/tests/providers/microsoft/mssql/operators/test_mssql.py index 03d71269d5009..7bf73cc780db8 100644 --- a/tests/providers/microsoft/mssql/operators/test_mssql.py +++ b/tests/providers/microsoft/mssql/operators/test_mssql.py @@ -28,8 +28,8 @@ from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook from airflow.providers.microsoft.mssql.operators.mssql import MsSqlOperator -ODBC_CONN = Connection(conn_id='test-odbc', conn_type='odbc', ) -PYMSSQL_CONN = Connection(conn_id='test-pymssql', conn_type='anything', ) +ODBC_CONN = Connection(conn_id='test-odbc', conn_type='odbc',) +PYMSSQL_CONN = Connection(conn_id='test-pymssql', conn_type='anything',) class TestMsSqlOperator: diff --git a/tests/providers/microsoft/winrm/hooks/test_winrm.py b/tests/providers/microsoft/winrm/hooks/test_winrm.py index 95b85b726dc51..a7b7570bc8235 100644 --- a/tests/providers/microsoft/winrm/hooks/test_winrm.py +++ b/tests/providers/microsoft/winrm/hooks/test_winrm.py @@ -26,7 +26,6 @@ class TestWinRMHook(unittest.TestCase): - @patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol') def test_get_conn_exists(self, mock_protocol): winrm_hook = WinRMHook() @@ -48,12 +47,13 @@ def test_get_conn_error(self, mock_protocol): WinRMHook(remote_host='host').get_conn() @patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol', autospec=True) - @patch('airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection', - return_value=Connection( - login='username', - password='password', - host='remote_host', - extra="""{ + @patch( + 'airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection', + return_value=Connection( + login='username', + password='password', + host='remote_host', + extra="""{ "endpoint": "endpoint", "remote_port": 123, "transport": "plaintext", @@ -70,8 +70,9 @@ def test_get_conn_error(self, mock_protocol): "message_encryption": "auto", "credssp_disable_tlsv1_2": "true", "send_cbt": "false" - }""" - )) + }""", + ), + ) def test_get_conn_from_connection(self, mock_get_connection, mock_protocol): connection = mock_get_connection.return_value winrm_hook = WinRMHook(ssh_conn_id='conn_id') @@ -96,7 +97,7 @@ def test_get_conn_from_connection(self, mock_get_connection, mock_protocol): 
kerberos_hostname_override=str(connection.extra_dejson['kerberos_hostname_override']), message_encryption=str(connection.extra_dejson['message_encryption']), credssp_disable_tlsv1_2=str(connection.extra_dejson['credssp_disable_tlsv1_2']).lower() == 'true', - send_cbt=str(connection.extra_dejson['send_cbt']).lower() == 'true' + send_cbt=str(connection.extra_dejson['send_cbt']).lower() == 'true', ) @patch('airflow.providers.microsoft.winrm.hooks.winrm.getpass.getuser', return_value='user') @@ -114,5 +115,6 @@ def test_get_conn_no_endpoint(self, mock_protocol): winrm_hook.get_conn() - self.assertEqual('http://{0}:{1}/wsman'.format(winrm_hook.remote_host, winrm_hook.remote_port), - winrm_hook.endpoint) + self.assertEqual( + 'http://{0}:{1}/wsman'.format(winrm_hook.remote_host, winrm_hook.remote_port), winrm_hook.endpoint + ) diff --git a/tests/providers/microsoft/winrm/operators/test_winrm.py b/tests/providers/microsoft/winrm/operators/test_winrm.py index 0dac784580f2b..ecc14651b2bc1 100644 --- a/tests/providers/microsoft/winrm/operators/test_winrm.py +++ b/tests/providers/microsoft/winrm/operators/test_winrm.py @@ -25,20 +25,14 @@ class TestWinRMOperator(unittest.TestCase): def test_no_winrm_hook_no_ssh_conn_id(self): - op = WinRMOperator(task_id='test_task_id', - winrm_hook=None, - ssh_conn_id=None) + op = WinRMOperator(task_id='test_task_id', winrm_hook=None, ssh_conn_id=None) exception_msg = "Cannot operate without winrm_hook or ssh_conn_id." with self.assertRaisesRegex(AirflowException, exception_msg): op.execute(None) @mock.patch('airflow.providers.microsoft.winrm.operators.winrm.WinRMHook') def test_no_command(self, mock_hook): - op = WinRMOperator( - task_id='test_task_id', - winrm_hook=mock_hook, - command=None - ) + op = WinRMOperator(task_id='test_task_id', winrm_hook=mock_hook, command=None) exception_msg = "No command specified so nothing to execute here." 
with self.assertRaisesRegex(AirflowException, exception_msg): op.execute(None) diff --git a/tests/providers/mongo/hooks/test_mongo.py b/tests/providers/mongo/hooks/test_mongo.py index 0724e50800a94..8b5fa95d64ae1 100644 --- a/tests/providers/mongo/hooks/test_mongo.py +++ b/tests/providers/mongo/hooks/test_mongo.py @@ -34,6 +34,7 @@ class MongoHookTest(MongoHook): Extending hook so that a mockmongo collection object can be passed in to get_collection() """ + def __init__(self, conn_id='mongo_default', *args, **kwargs): super().__init__(conn_id=conn_id, *args, **kwargs) @@ -47,8 +48,13 @@ def setUp(self): self.conn = self.hook.get_conn() db.merge_conn( Connection( - conn_id='mongo_default_with_srv', conn_type='mongo', - host='mongo', port='27017', extra='{"srv": true}')) + conn_id='mongo_default_with_srv', + conn_type='mongo', + host='mongo', + port='27017', + extra='{"srv": true}', + ) + ) @unittest.skipIf(mongomock is None, 'mongomock package not present') def test_get_conn(self): @@ -73,10 +79,7 @@ def test_insert_one(self): @unittest.skipIf(mongomock is None, 'mongomock package not present') def test_insert_many(self): collection = mongomock.MongoClient().db.collection - objs = [ - {'test_insert_many_1': 'test_value'}, - {'test_insert_many_2': 'test_value'} - ] + objs = [{'test_insert_many_1': 'test_value'}, {'test_insert_many_2': 'test_value'}] self.hook.insert_many(collection, objs) @@ -259,25 +262,14 @@ def test_find_many(self): def test_aggregate(self): collection = mongomock.MongoClient().db.collection objs = [ - { - 'test_id': '1', - 'test_status': 'success' - }, - { - 'test_id': '2', - 'test_status': 'failure' - }, - { - 'test_id': '3', - 'test_status': 'success' - } + {'test_id': '1', 'test_status': 'success'}, + {'test_id': '2', 'test_status': 'failure'}, + {'test_id': '3', 'test_status': 'success'}, ] collection.insert(objs) - aggregate_query = [ - {"$match": {'test_status': 'success'}} - ] + aggregate_query = [{"$match": {'test_status': 'success'}}] results = self.hook.aggregate(collection, aggregate_query) self.assertEqual(len(list(results)), 2) diff --git a/tests/providers/mongo/sensors/test_mongo.py b/tests/providers/mongo/sensors/test_mongo.py index 4b15f1b59f23f..688d19abcf593 100644 --- a/tests/providers/mongo/sensors/test_mongo.py +++ b/tests/providers/mongo/sensors/test_mongo.py @@ -32,17 +32,12 @@ @pytest.mark.integration("mongo") class TestMongoSensor(unittest.TestCase): - def setUp(self): db.merge_conn( - Connection( - conn_id='mongo_test', conn_type='mongo', - host='mongo', port='27017', schema='test')) - - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + Connection(conn_id='mongo_test', conn_type='mongo', host='mongo', port='27017', schema='test') + ) + + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) hook = MongoHook('mongo_test') @@ -53,7 +48,7 @@ def setUp(self): mongo_conn_id='mongo_test', dag=self.dag, collection='foo', - query={'bar': 'baz'} + query={'bar': 'baz'}, ) def test_poke(self): diff --git a/tests/providers/mysql/hooks/test_mysql.py b/tests/providers/mysql/hooks/test_mysql.py index 5d26bd166c7a9..02a00d5433347 100644 --- a/tests/providers/mysql/hooks/test_mysql.py +++ b/tests/providers/mysql/hooks/test_mysql.py @@ -31,24 +31,15 @@ from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.utils import timezone -SSL_DICT = { - 'cert': '/tmp/client-cert.pem', - 'ca': '/tmp/server-ca.pem', - 'key': '/tmp/client-key.pem' -} +SSL_DICT = {'cert': '/tmp/client-cert.pem', 
'ca': '/tmp/server-ca.pem', 'key': '/tmp/client-key.pem'} class TestMySqlHookConn(unittest.TestCase): - def setUp(self): super().setUp() self.connection = Connection( - conn_type='mysql', - login='login', - password='password', - host='host', - schema='schema', + conn_type='mysql', login='login', password='password', host='host', schema='schema', ) self.db_hook = MySqlHook() @@ -162,13 +153,17 @@ def test_get_conn_rds_iam(self, mock_client, mock_connect): self.connection.extra = '{"iam":true}' mock_client.return_value.generate_db_auth_token.return_value = 'aws_token' self.db_hook.get_conn() - mock_connect.assert_called_once_with(user='login', passwd='aws_token', host='host', - db='schema', port=3306, - read_default_group='enable-cleartext-plugin') + mock_connect.assert_called_once_with( + user='login', + passwd='aws_token', + host='host', + db='schema', + port=3306, + read_default_group='enable-cleartext-plugin', + ) class TestMySqlHookConnMySqlConnectorPython(unittest.TestCase): - def setUp(self): super().setUp() @@ -177,7 +172,7 @@ def setUp(self): password='password', host='host', schema='schema', - extra='{"client": "mysql-connector-python"}' + extra='{"client": "mysql-connector-python"}', ) self.db_hook = MySqlHook() @@ -217,7 +212,6 @@ def test_get_conn_allow_local_infile(self, mock_connect): class TestMySqlHook(unittest.TestCase): - def setUp(self): super().setUp() @@ -275,26 +269,27 @@ def test_run_multi_queries(self): self.assertEqual(len(args), 1) self.assertEqual(args[0], sql[i]) self.assertEqual(kwargs, {}) - calls = [ - mock.call(sql[0]), - mock.call(sql[1]) - ] + calls = [mock.call(sql[0]), mock.call(sql[1])] self.cur.execute.assert_has_calls(calls, any_order=True) self.conn.commit.assert_not_called() def test_bulk_load(self): self.db_hook.bulk_load('table', '/tmp/file') - self.cur.execute.assert_called_once_with(""" + self.cur.execute.assert_called_once_with( + """ LOAD DATA LOCAL INFILE '/tmp/file' INTO TABLE table - """) + """ + ) def test_bulk_dump(self): self.db_hook.bulk_dump('table', '/tmp/file') - self.cur.execute.assert_called_once_with(""" + self.cur.execute.assert_called_once_with( + """ SELECT * INTO OUTFILE '/tmp/file' FROM table - """) + """ + ) def test_serialize_cell(self): self.assertEqual('foo', self.db_hook._serialize_cell('foo', None)) @@ -306,16 +301,18 @@ def test_bulk_load_custom(self): 'IGNORE', """FIELDS TERMINATED BY ';' OPTIONALLY ENCLOSED BY '"' - IGNORE 1 LINES""" + IGNORE 1 LINES""", ) - self.cur.execute.assert_called_once_with(""" + self.cur.execute.assert_called_once_with( + """ LOAD DATA LOCAL INFILE '/tmp/file' IGNORE INTO TABLE table FIELDS TERMINATED BY ';' OPTIONALLY ENCLOSED BY '"' IGNORE 1 LINES - """) + """ + ) DEFAULT_DATE = timezone.datetime(2015, 1, 1) @@ -340,10 +337,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): @pytest.mark.backend("mysql") class TestMySql(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=args) self.dag = dag @@ -353,30 +347,37 @@ def tearDown(self): for table in drop_tables: conn.execute("DROP TABLE IF EXISTS {}".format(table)) - @parameterized.expand([("mysqlclient",), ("mysql-connector-python",), ]) + @parameterized.expand( + [("mysqlclient",), ("mysql-connector-python",),] + ) def test_mysql_hook_test_bulk_load(self, client): with MySqlContext(client): records = ("foo", "bar", "baz") import tempfile + with tempfile.NamedTemporaryFile() as f: 
f.write("\n".join(records).encode('utf8')) f.flush() hook = MySqlHook('airflow_db') with hook.get_conn() as conn: - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS test_airflow ( dummy VARCHAR(50) ) - """) + """ + ) conn.execute("TRUNCATE TABLE test_airflow") hook.bulk_load("test_airflow", f.name) conn.execute("SELECT dummy FROM test_airflow") results = tuple(result[0] for result in conn.fetchall()) self.assertEqual(sorted(results), sorted(records)) - @parameterized.expand([("mysqlclient",), ("mysql-connector-python",), ]) + @parameterized.expand( + [("mysqlclient",), ("mysql-connector-python",),] + ) def test_mysql_hook_test_bulk_dump(self, client): with MySqlContext(client): hook = MySqlHook('airflow_db') @@ -385,10 +386,11 @@ def test_mysql_hook_test_bulk_dump(self, client): # Confirm that no error occurs hook.bulk_dump("INFORMATION_SCHEMA.TABLES", os.path.join(priv[0], "TABLES_{}".format(client))) else: - self.skipTest("Skip test_mysql_hook_test_bulk_load " - "since file output is not permitted") + self.skipTest("Skip test_mysql_hook_test_bulk_load " "since file output is not permitted") - @parameterized.expand([("mysqlclient",), ("mysql-connector-python",), ]) + @parameterized.expand( + [("mysqlclient",), ("mysql-connector-python",),] + ) @mock.patch('airflow.providers.mysql.hooks.mysql.MySqlHook.get_conn') def test_mysql_hook_test_bulk_dump_mock(self, client, mock_get_conn): with MySqlContext(client): @@ -401,9 +403,12 @@ def test_mysql_hook_test_bulk_dump_mock(self, client, mock_get_conn): hook.bulk_dump(table, tmp_file) from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces + assert mock_execute.call_count == 1 query = """ SELECT * INTO OUTFILE '{tmp_file}' FROM {table} - """.format(tmp_file=tmp_file, table=table) + """.format( + tmp_file=tmp_file, table=table + ) assert_equal_ignore_multiple_spaces(self, mock_execute.call_args[0][0], query) diff --git a/tests/providers/mysql/operators/test_mysql.py b/tests/providers/mysql/operators/test_mysql.py index 7ea33592c851b..3e69e7a75cb7e 100644 --- a/tests/providers/mysql/operators/test_mysql.py +++ b/tests/providers/mysql/operators/test_mysql.py @@ -35,10 +35,7 @@ @pytest.mark.backend("mysql") class TestMySql(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=args) self.dag = dag @@ -48,7 +45,9 @@ def tearDown(self): for table in drop_tables: conn.execute("DROP TABLE IF EXISTS {}".format(table)) - @parameterized.expand([("mysqlclient",), ("mysql-connector-python",), ]) + @parameterized.expand( + [("mysqlclient",), ("mysql-connector-python",),] + ) def test_mysql_operator_test(self, client): with MySqlContext(client): sql = """ @@ -56,13 +55,12 @@ def test_mysql_operator_test(self, client): dummy VARCHAR(50) ); """ - op = MySqlOperator( - task_id='basic_mysql', - sql=sql, - dag=self.dag) + op = MySqlOperator(task_id='basic_mysql', sql=sql, dag=self.dag) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @parameterized.expand([("mysqlclient",), ("mysql-connector-python",), ]) + @parameterized.expand( + [("mysqlclient",), ("mysql-connector-python",),] + ) def test_mysql_operator_test_multi(self, client): with MySqlContext(client): sql = [ @@ -70,14 +68,12 @@ def test_mysql_operator_test_multi(self, client): "TRUNCATE TABLE test_airflow", "INSERT INTO test_airflow VALUES ('X')", ] - op = MySqlOperator( - 
task_id='mysql_operator_test_multi', - sql=sql, - dag=self.dag, - ) + op = MySqlOperator(task_id='mysql_operator_test_multi', sql=sql, dag=self.dag,) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @parameterized.expand([("mysqlclient",), ("mysql-connector-python",), ]) + @parameterized.expand( + [("mysqlclient",), ("mysql-connector-python",),] + ) def test_overwrite_schema(self, client): """ Verifies option to overwrite connection schema @@ -85,15 +81,12 @@ def test_overwrite_schema(self, client): with MySqlContext(client): sql = "SELECT 1;" op = MySqlOperator( - task_id='test_mysql_operator_test_schema_overwrite', - sql=sql, - dag=self.dag, - database="foobar", + task_id='test_mysql_operator_test_schema_overwrite', sql=sql, dag=self.dag, database="foobar", ) from _mysql_exceptions import OperationalError + try: - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) except OperationalError as e: assert "Unknown database 'foobar'" in str(e) diff --git a/tests/providers/mysql/transfers/test_presto_to_mysql.py b/tests/providers/mysql/transfers/test_presto_to_mysql.py index 052f55789f97e..0df726a112dc2 100644 --- a/tests/providers/mysql/transfers/test_presto_to_mysql.py +++ b/tests/providers/mysql/transfers/test_presto_to_mysql.py @@ -24,13 +24,8 @@ class TestPrestoToMySqlTransfer(TestHiveEnvironment): - def setUp(self): - self.kwargs = dict( - sql='sql', - mysql_table='mysql_table', - task_id='test_presto_to_mysql_transfer', - ) + self.kwargs = dict(sql='sql', mysql_table='mysql_table', task_id='test_presto_to_mysql_transfer',) super().setUp() @patch('airflow.providers.mysql.transfers.presto_to_mysql.MySqlHook') @@ -40,7 +35,8 @@ def test_execute(self, mock_presto_hook, mock_mysql_hook): mock_presto_hook.return_value.get_records.assert_called_once_with(self.kwargs['sql']) mock_mysql_hook.return_value.insert_rows.assert_called_once_with( - table=self.kwargs['mysql_table'], rows=mock_presto_hook.return_value.get_records.return_value) + table=self.kwargs['mysql_table'], rows=mock_presto_hook.return_value.get_records.return_value + ) @patch('airflow.providers.mysql.transfers.presto_to_mysql.MySqlHook') @patch('airflow.providers.mysql.transfers.presto_to_mysql.PrestoHook') @@ -52,11 +48,12 @@ def test_execute_with_mysql_preoperator(self, mock_presto_hook, mock_mysql_hook) mock_presto_hook.return_value.get_records.assert_called_once_with(self.kwargs['sql']) mock_mysql_hook.return_value.run.assert_called_once_with(self.kwargs['mysql_preoperator']) mock_mysql_hook.return_value.insert_rows.assert_called_once_with( - table=self.kwargs['mysql_table'], rows=mock_presto_hook.return_value.get_records.return_value) + table=self.kwargs['mysql_table'], rows=mock_presto_hook.return_value.get_records.return_value + ) @unittest.skipIf( - 'AIRFLOW_RUNALL_TESTS' not in os.environ, - "Skipped because AIRFLOW_RUNALL_TESTS is not set") + 'AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set" + ) def test_presto_to_mysql(self): op = PrestoToMySqlOperator( task_id='presto_to_mysql_check', @@ -67,6 +64,6 @@ def test_presto_to_mysql(self): """, mysql_table='test_static_babynames', mysql_preoperator='TRUNCATE TABLE test_static_babynames;', - dag=self.dag) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git 
a/tests/providers/mysql/transfers/test_s3_to_mysql.py b/tests/providers/mysql/transfers/test_s3_to_mysql.py index 6b8596beb5925..d63bcc5699481 100644 --- a/tests/providers/mysql/transfers/test_s3_to_mysql.py +++ b/tests/providers/mysql/transfers/test_s3_to_mysql.py @@ -27,7 +27,6 @@ class TestS3ToMySqlTransfer(unittest.TestCase): - def setUp(self): configuration.conf.load_test_config() @@ -37,7 +36,7 @@ def setUp(self): conn_type='s3', schema='test', extra='{"aws_access_key_id": "aws_access_key_id", "aws_secret_access_key":' - ' "aws_secret_access_key"}' + ' "aws_secret_access_key"}', ) ) db.merge_conn( @@ -47,7 +46,7 @@ def setUp(self): host='some.host.com', schema='test_db', login='user', - password='password' + password='password', ) ) @@ -62,7 +61,7 @@ def setUp(self): IGNORE 1 LINES """, 'task_id': 'task_id', - 'dag': None + 'dag': None, } @patch('airflow.providers.mysql.transfers.s3_to_mysql.S3Hook.download_file') @@ -71,14 +70,12 @@ def setUp(self): def test_execute(self, mock_remove, mock_bulk_load_custom, mock_download_file): S3ToMySqlOperator(**self.s3_to_mysql_transfer_kwargs).execute({}) - mock_download_file.assert_called_once_with( - key=self.s3_to_mysql_transfer_kwargs['s3_source_key'] - ) + mock_download_file.assert_called_once_with(key=self.s3_to_mysql_transfer_kwargs['s3_source_key']) mock_bulk_load_custom.assert_called_once_with( table=self.s3_to_mysql_transfer_kwargs['mysql_table'], tmp_file=mock_download_file.return_value, duplicate_key_handling=self.s3_to_mysql_transfer_kwargs['mysql_duplicate_key_handling'], - extra_options=self.s3_to_mysql_transfer_kwargs['mysql_extra_options'] + extra_options=self.s3_to_mysql_transfer_kwargs['mysql_extra_options'], ) mock_remove.assert_called_once_with(mock_download_file.return_value) @@ -88,23 +85,23 @@ def test_execute(self, mock_remove, mock_bulk_load_custom, mock_download_file): def test_execute_exception(self, mock_remove, mock_bulk_load_custom, mock_download_file): mock_bulk_load_custom.side_effect = Exception - self.assertRaises(Exception, S3ToMySqlOperator( - **self.s3_to_mysql_transfer_kwargs).execute, {}) + self.assertRaises(Exception, S3ToMySqlOperator(**self.s3_to_mysql_transfer_kwargs).execute, {}) - mock_download_file.assert_called_once_with( - key=self.s3_to_mysql_transfer_kwargs['s3_source_key'] - ) + mock_download_file.assert_called_once_with(key=self.s3_to_mysql_transfer_kwargs['s3_source_key']) mock_bulk_load_custom.assert_called_once_with( table=self.s3_to_mysql_transfer_kwargs['mysql_table'], tmp_file=mock_download_file.return_value, duplicate_key_handling=self.s3_to_mysql_transfer_kwargs['mysql_duplicate_key_handling'], - extra_options=self.s3_to_mysql_transfer_kwargs['mysql_extra_options'] + extra_options=self.s3_to_mysql_transfer_kwargs['mysql_extra_options'], ) mock_remove.assert_called_once_with(mock_download_file.return_value) def tearDown(self): with create_session() as session: - (session - .query(models.Connection) - .filter(or_(models.Connection.conn_id == 's3_test', models.Connection.conn_id == 'mysql_test')) - .delete()) + ( + session.query(models.Connection) + .filter( + or_(models.Connection.conn_id == 's3_test', models.Connection.conn_id == 'mysql_test') + ) + .delete() + ) diff --git a/tests/providers/mysql/transfers/test_vertica_to_mysql.py b/tests/providers/mysql/transfers/test_vertica_to_mysql.py index e8f0171b62dea..96a4e94d76bae 100644 --- a/tests/providers/mysql/transfers/test_vertica_to_mysql.py +++ b/tests/providers/mysql/transfers/test_vertica_to_mysql.py @@ -25,65 +25,62 @@ def 
mock_get_conn(): - commit_mock = mock.MagicMock( - ) + commit_mock = mock.MagicMock() cursor_mock = mock.MagicMock( - execute=[], - fetchall=[['1', '2', '3']], - description=['a', 'b', 'c'], - iterate=[['1', '2', '3']], - ) - conn_mock = mock.MagicMock( - commit=commit_mock, - cursor=cursor_mock, + execute=[], fetchall=[['1', '2', '3']], description=['a', 'b', 'c'], iterate=[['1', '2', '3']], ) + conn_mock = mock.MagicMock(commit=commit_mock, cursor=cursor_mock,) return conn_mock class TestVerticaToMySqlTransfer(unittest.TestCase): def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': datetime.datetime(2017, 1, 1) - } + args = {'owner': 'airflow', 'start_date': datetime.datetime(2017, 1, 1)} self.dag = DAG('test_dag_id', default_args=args) @mock.patch( - 'airflow.providers.mysql.transfers.vertica_to_mysql.VerticaHook.get_conn', side_effect=mock_get_conn) - @mock.patch( - 'airflow.providers.mysql.transfers.vertica_to_mysql.MySqlHook.get_conn', side_effect=mock_get_conn) + 'airflow.providers.mysql.transfers.vertica_to_mysql.VerticaHook.get_conn', side_effect=mock_get_conn + ) @mock.patch( - 'airflow.providers.mysql.transfers.vertica_to_mysql.MySqlHook.insert_rows', return_value=True) + 'airflow.providers.mysql.transfers.vertica_to_mysql.MySqlHook.get_conn', side_effect=mock_get_conn + ) + @mock.patch('airflow.providers.mysql.transfers.vertica_to_mysql.MySqlHook.insert_rows', return_value=True) def test_select_insert_transfer(self, *args): """ Test check selection from vertica into memory and after that inserting into mysql """ - task = VerticaToMySqlOperator(task_id='test_task_id', - sql='select a, b, c', - mysql_table='test_table', - vertica_conn_id='test_vertica_conn_id', - mysql_conn_id='test_mysql_conn_id', - params={}, - bulk_load=False, - dag=self.dag) + task = VerticaToMySqlOperator( + task_id='test_task_id', + sql='select a, b, c', + mysql_table='test_table', + vertica_conn_id='test_vertica_conn_id', + mysql_conn_id='test_mysql_conn_id', + params={}, + bulk_load=False, + dag=self.dag, + ) task.execute(None) @mock.patch( - 'airflow.providers.mysql.transfers.vertica_to_mysql.VerticaHook.get_conn', side_effect=mock_get_conn) + 'airflow.providers.mysql.transfers.vertica_to_mysql.VerticaHook.get_conn', side_effect=mock_get_conn + ) @mock.patch( - 'airflow.providers.mysql.transfers.vertica_to_mysql.MySqlHook.get_conn', side_effect=mock_get_conn) + 'airflow.providers.mysql.transfers.vertica_to_mysql.MySqlHook.get_conn', side_effect=mock_get_conn + ) def test_select_bulk_insert_transfer(self, *args): """ Test check selection from vertica into temporary file and after that bulk inserting into mysql """ - task = VerticaToMySqlOperator(task_id='test_task_id', - sql='select a, b, c', - mysql_table='test_table', - vertica_conn_id='test_vertica_conn_id', - mysql_conn_id='test_mysql_conn_id', - params={}, - bulk_load=True, - dag=self.dag) + task = VerticaToMySqlOperator( + task_id='test_task_id', + sql='select a, b, c', + mysql_table='test_table', + vertica_conn_id='test_vertica_conn_id', + mysql_conn_id='test_mysql_conn_id', + params={}, + bulk_load=True, + dag=self.dag, + ) task.execute(None) diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py index 901002fb917fb..e39416d04c413 100644 --- a/tests/providers/odbc/hooks/test_odbc.py +++ b/tests/providers/odbc/hooks/test_odbc.py @@ -73,9 +73,7 @@ def test_driver_in_both(self): def test_dsn_in_extra(self): conn_params = dict(extra=json.dumps(dict(DSN='MyDSN', Fake_Param='Fake Param'))) hook = 
self.get_hook(conn_params=conn_params) - expected = ( - 'DSN=MyDSN;SERVER=host;DATABASE=schema;UID=login;PWD=password;Fake_Param=Fake Param;' - ) + expected = 'DSN=MyDSN;SERVER=host;DATABASE=schema;UID=login;PWD=password;Fake_Param=Fake Param;' assert hook.odbc_connection_string == expected def test_dsn_in_both(self): @@ -107,10 +105,7 @@ def test_connect_kwargs_from_hook(self): hook = self.get_hook( hook_params=dict( connect_kwargs={ - 'attrs_before': { - 1: 2, - pyodbc.SQL_TXN_ISOLATION: pyodbc.SQL_TXN_READ_UNCOMMITTED, - }, + 'attrs_before': {1: 2, pyodbc.SQL_TXN_ISOLATION: pyodbc.SQL_TXN_READ_UNCOMMITTED,}, 'readonly': True, 'autocommit': False, } @@ -126,10 +121,7 @@ def test_connect_kwargs_from_conn(self): extra = json.dumps( dict( connect_kwargs={ - 'attrs_before': { - 1: 2, - pyodbc.SQL_TXN_ISOLATION: pyodbc.SQL_TXN_READ_UNCOMMITTED, - }, + 'attrs_before': {1: 2, pyodbc.SQL_TXN_ISOLATION: pyodbc.SQL_TXN_READ_UNCOMMITTED,}, 'readonly': True, 'autocommit': True, } @@ -148,9 +140,7 @@ def test_connect_kwargs_from_conn_and_hook(self): When connect_kwargs in both hook and conn, should be merged properly. Hook beats conn. """ - conn_extra = json.dumps( - dict(connect_kwargs={'attrs_before': {1: 2, 3: 4}, 'readonly': False}) - ) + conn_extra = json.dumps(dict(connect_kwargs={'attrs_before': {1: 2, 3: 4}, 'readonly': False})) hook_params = dict( connect_kwargs={'attrs_before': {3: 5, pyodbc.SQL_TXN_ISOLATION: 0}, 'readonly': True} ) diff --git a/tests/providers/openfaas/hooks/test_openfaas.py b/tests/providers/openfaas/hooks/test_openfaas.py index 07ba2aa2b8db2..39aff73501ced 100644 --- a/tests/providers/openfaas/hooks/test_openfaas.py +++ b/tests/providers/openfaas/hooks/test_openfaas.py @@ -43,8 +43,11 @@ def setUp(self): @mock.patch.object(BaseHook, 'get_connection') @requests_mock.mock() def test_is_function_exist_false(self, mock_get_connection, m): - m.get("http://open-faas.io" + self.GET_FUNCTION + FUNCTION_NAME, - json=self.mock_response, status_code=404) + m.get( + "http://open-faas.io" + self.GET_FUNCTION + FUNCTION_NAME, + json=self.mock_response, + status_code=404, + ) mock_connection = Connection(host="http://open-faas.io") mock_get_connection.return_value = mock_connection @@ -54,8 +57,11 @@ def test_is_function_exist_false(self, mock_get_connection, m): @mock.patch.object(BaseHook, 'get_connection') @requests_mock.mock() def test_is_function_exist_true(self, mock_get_connection, m): - m.get("http://open-faas.io" + self.GET_FUNCTION + FUNCTION_NAME, - json=self.mock_response, status_code=202) + m.get( + "http://open-faas.io" + self.GET_FUNCTION + FUNCTION_NAME, + json=self.mock_response, + status_code=202, + ) mock_connection = Connection(host="http://open-faas.io") mock_get_connection.return_value = mock_connection @@ -85,8 +91,11 @@ def test_update_function_false(self, mock_get_connection, m): @mock.patch.object(BaseHook, 'get_connection') @requests_mock.mock() def test_invoke_async_function_false(self, mock_get_connection, m): - m.post("http://open-faas.io" + self.INVOKE_ASYNC_FUNCTION + FUNCTION_NAME, json=self.mock_response, - status_code=400) + m.post( + "http://open-faas.io" + self.INVOKE_ASYNC_FUNCTION + FUNCTION_NAME, + json=self.mock_response, + status_code=400, + ) mock_connection = Connection(host="http://open-faas.io") mock_get_connection.return_value = mock_connection @@ -97,8 +106,11 @@ def test_invoke_async_function_false(self, mock_get_connection, m): @mock.patch.object(BaseHook, 'get_connection') @requests_mock.mock() def 
test_invoke_async_function_true(self, mock_get_connection, m): - m.post("http://open-faas.io" + self.INVOKE_ASYNC_FUNCTION + FUNCTION_NAME, json=self.mock_response, - status_code=202) + m.post( + "http://open-faas.io" + self.INVOKE_ASYNC_FUNCTION + FUNCTION_NAME, + json=self.mock_response, + status_code=202, + ) mock_connection = Connection(host="http://open-faas.io") mock_get_connection.return_value = mock_connection self.assertEqual(self.hook.invoke_async_function({}), None) diff --git a/tests/providers/opsgenie/hooks/test_opsgenie_alert.py b/tests/providers/opsgenie/hooks/test_opsgenie_alert.py index b12df1cadeb2e..8a52b4d1e8786 100644 --- a/tests/providers/opsgenie/hooks/test_opsgenie_alert.py +++ b/tests/providers/opsgenie/hooks/test_opsgenie_alert.py @@ -42,13 +42,13 @@ class TestOpsgenieAlertHook(unittest.TestCase): {'id': 'aee8a0de-c80f-4515-a232-501c0bc9d715', 'type': 'escalation'}, {'name': 'Nightwatch Escalation', 'type': 'escalation'}, {'id': '80564037-1984-4f38-b98e-8a1f662df552', 'type': 'schedule'}, - {'name': 'First Responders Schedule', 'type': 'schedule'} + {'name': 'First Responders Schedule', 'type': 'schedule'}, ], 'visibleTo': [ {'id': '4513b7ea-3b91-438f-b7e4-e3e54af9147c', 'type': 'team'}, {'name': 'rocket_team', 'type': 'team'}, {'id': 'bb4d9938-c3c2-455d-aaab-727aa701c0d8', 'type': 'user'}, - {'username': 'trinity@opsgenie.com', 'type': 'user'} + {'username': 'trinity@opsgenie.com', 'type': 'user'}, ], 'actions': ['Restart', 'AnExampleAction'], 'tags': ['OverwriteQuietHours', 'Critical'], @@ -57,12 +57,12 @@ class TestOpsgenieAlertHook(unittest.TestCase): 'source': 'Airflow', 'priority': 'P1', 'user': 'Jesse', - 'note': 'Write this down' + 'note': 'Write this down', } _mock_success_response_body = { "result": "Request will be processed", "took": 0.302, - "requestId": "43a29c5c-3dbf-4fa4-9c26-f4f71023e120" + "requestId": "43a29c5c-3dbf-4fa4-9c26-f4f71023e120", } def setUp(self): @@ -71,7 +71,7 @@ def setUp(self): conn_id=self.conn_id, conn_type='http', host='https://api.opsgenie.com/', - password='eb243592-faa2-4ba2-a551q-1afdf565c889' + password='eb243592-faa2-4ba2-a551q-1afdf565c889', ) ) @@ -88,11 +88,7 @@ def test_get_conn_defaults_host(self): @requests_mock.mock() def test_call_with_success(self, m): hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id) - m.post( - self.opsgenie_alert_endpoint, - status_code=202, - json=self._mock_success_response_body - ) + m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body) resp = hook.execute(payload=self._payload) self.assertEqual(resp.status_code, 202) self.assertEqual(resp.json(), self._mock_success_response_body) @@ -100,33 +96,22 @@ def test_call_with_success(self, m): @requests_mock.mock() def test_api_key_set(self, m): hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id) - m.post( - self.opsgenie_alert_endpoint, - status_code=202, - json=self._mock_success_response_body - ) + m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body) resp = hook.execute(payload=self._payload) - self.assertEqual(resp.request.headers.get('Authorization'), - 'GenieKey eb243592-faa2-4ba2-a551q-1afdf565c889') + self.assertEqual( + resp.request.headers.get('Authorization'), 'GenieKey eb243592-faa2-4ba2-a551q-1afdf565c889' + ) @requests_mock.mock() def test_api_key_not_set(self, m): hook = OpsgenieAlertHook() - m.post( - self.opsgenie_alert_endpoint, - status_code=202, - json=self._mock_success_response_body - ) + m.post(self.opsgenie_alert_endpoint, 
status_code=202, json=self._mock_success_response_body) with self.assertRaises(AirflowException): hook.execute(payload=self._payload) @requests_mock.mock() def test_payload_set(self, m): hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id) - m.post( - self.opsgenie_alert_endpoint, - status_code=202, - json=self._mock_success_response_body - ) + m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body) resp = hook.execute(payload=self._payload) self.assertEqual(json.loads(resp.request.body), self._payload) diff --git a/tests/providers/opsgenie/operators/test_opsgenie_alert.py b/tests/providers/opsgenie/operators/test_opsgenie_alert.py index d309e744859d8..40faa63035429 100644 --- a/tests/providers/opsgenie/operators/test_opsgenie_alert.py +++ b/tests/providers/opsgenie/operators/test_opsgenie_alert.py @@ -39,13 +39,13 @@ class TestOpsgenieAlertOperator(unittest.TestCase): {'id': 'aee8a0de-c80f-4515-a232-501c0bc9d715', 'type': 'escalation'}, {'name': 'Nightwatch Escalation', 'type': 'escalation'}, {'id': '80564037-1984-4f38-b98e-8a1f662df552', 'type': 'schedule'}, - {'name': 'First Responders Schedule', 'type': 'schedule'} + {'name': 'First Responders Schedule', 'type': 'schedule'}, ], 'visible_to': [ {'id': '4513b7ea-3b91-438f-b7e4-e3e54af9147c', 'type': 'team'}, {'name': 'rocket_team', 'type': 'team'}, {'id': 'bb4d9938-c3c2-455d-aaab-727aa701c0d8', 'type': 'user'}, - {'username': 'trinity@opsgenie.com', 'type': 'user'} + {'username': 'trinity@opsgenie.com', 'type': 'user'}, ], 'actions': ['Restart', 'AnExampleAction'], 'tags': ['OverwriteQuietHours', 'Critical'], @@ -54,7 +54,7 @@ class TestOpsgenieAlertOperator(unittest.TestCase): 'source': 'Airflow', 'priority': 'P1', 'user': 'Jesse', - 'note': 'Write this down' + 'note': 'Write this down', } expected_payload_dict = { @@ -70,23 +70,16 @@ class TestOpsgenieAlertOperator(unittest.TestCase): 'source': _config['source'], 'priority': _config['priority'], 'user': _config['user'], - 'note': _config['note'] + 'note': _config['note'], } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) def test_build_opsgenie_payload(self): # Given / When - operator = OpsgenieAlertOperator( - task_id='opsgenie_alert_job', - dag=self.dag, - **self._config - ) + operator = OpsgenieAlertOperator(task_id='opsgenie_alert_job', dag=self.dag, **self._config) payload = operator._build_opsgenie_payload() @@ -95,11 +88,7 @@ def test_build_opsgenie_payload(self): def test_properties(self): # Given / When - operator = OpsgenieAlertOperator( - task_id='opsgenie_alert_job', - dag=self.dag, - **self._config - ) + operator = OpsgenieAlertOperator(task_id='opsgenie_alert_job', dag=self.dag, **self._config) self.assertEqual('opsgenie_default', operator.opsgenie_conn_id) self.assertEqual(self._config['message'], operator.message) diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index 0eb3ea08aa468..6cf4bc5be8182 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -35,16 +35,10 @@ @unittest.skipIf(cx_Oracle is None, 'cx_Oracle package not present') class TestOracleHookConn(unittest.TestCase): - def setUp(self): super().setUp() - self.connection = Connection( - login='login', - password='password', - host='host', - port=1521 - ) + self.connection = Connection(login='login', password='password', 
host='host', port=1521) self.db_hook = OracleHook() self.db_hook.get_connection = mock.Mock() @@ -68,9 +62,9 @@ def test_get_conn_sid(self, mock_connect): assert mock_connect.call_count == 1 args, kwargs = mock_connect.call_args self.assertEqual(args, ()) - self.assertEqual(kwargs['dsn'], - cx_Oracle.makedsn(dsn_sid['dsn'], - self.connection.port, dsn_sid['sid'])) + self.assertEqual( + kwargs['dsn'], cx_Oracle.makedsn(dsn_sid['dsn'], self.connection.port, dsn_sid['sid']) + ) @mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect') def test_get_conn_service_name(self, mock_connect): @@ -80,9 +74,12 @@ def test_get_conn_service_name(self, mock_connect): assert mock_connect.call_count == 1 args, kwargs = mock_connect.call_args self.assertEqual(args, ()) - self.assertEqual(kwargs['dsn'], cx_Oracle.makedsn( - dsn_service_name['dsn'], self.connection.port, - service_name=dsn_service_name['service_name'])) + self.assertEqual( + kwargs['dsn'], + cx_Oracle.makedsn( + dsn_service_name['dsn'], self.connection.port, service_name=dsn_service_name['service_name'] + ), + ) @mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect') def test_get_conn_encoding_without_nencoding(self, mock_connect): @@ -158,7 +155,7 @@ def test_get_conn_purity(self, mock_connect): purity = { 'new': cx_Oracle.ATTR_PURITY_NEW, 'self': cx_Oracle.ATTR_PURITY_SELF, - 'default': cx_Oracle.ATTR_PURITY_DEFAULT + 'default': cx_Oracle.ATTR_PURITY_DEFAULT, } first = True for pur in purity: @@ -204,34 +201,61 @@ def test_run_with_parameters(self): assert self.conn.commit.called def test_insert_rows_with_fields(self): - rows = [("'basestr_with_quote", None, numpy.NAN, - numpy.datetime64('2019-01-24T01:02:03'), - datetime(2019, 1, 24), 1, 10.24, 'str')] - target_fields = ['basestring', 'none', 'numpy_nan', 'numpy_datetime64', - 'datetime', 'int', 'float', 'str'] + rows = [ + ( + "'basestr_with_quote", + None, + numpy.NAN, + numpy.datetime64('2019-01-24T01:02:03'), + datetime(2019, 1, 24), + 1, + 10.24, + 'str', + ) + ] + target_fields = [ + 'basestring', + 'none', + 'numpy_nan', + 'numpy_datetime64', + 'datetime', + 'int', + 'float', + 'str', + ] self.db_hook.insert_rows('table', rows, target_fields) self.cur.execute.assert_called_once_with( "INSERT /*+ APPEND */ INTO table " "(basestring, none, numpy_nan, numpy_datetime64, datetime, int, float, str) " "VALUES ('''basestr_with_quote',NULL,NULL,'2019-01-24T01:02:03'," - "to_date('2019-01-24 00:00:00','YYYY-MM-DD HH24:MI:SS'),1,10.24,'str')") + "to_date('2019-01-24 00:00:00','YYYY-MM-DD HH24:MI:SS'),1,10.24,'str')" + ) def test_insert_rows_without_fields(self): - rows = [("'basestr_with_quote", None, numpy.NAN, - numpy.datetime64('2019-01-24T01:02:03'), - datetime(2019, 1, 24), 1, 10.24, 'str')] + rows = [ + ( + "'basestr_with_quote", + None, + numpy.NAN, + numpy.datetime64('2019-01-24T01:02:03'), + datetime(2019, 1, 24), + 1, + 10.24, + 'str', + ) + ] self.db_hook.insert_rows('table', rows) self.cur.execute.assert_called_once_with( "INSERT /*+ APPEND */ INTO table " " VALUES ('''basestr_with_quote',NULL,NULL,'2019-01-24T01:02:03'," - "to_date('2019-01-24 00:00:00','YYYY-MM-DD HH24:MI:SS'),1,10.24,'str')") + "to_date('2019-01-24 00:00:00','YYYY-MM-DD HH24:MI:SS'),1,10.24,'str')" + ) def test_bulk_insert_rows_with_fields(self): rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)] target_fields = ['col1', 'col2', 'col3'] self.db_hook.bulk_insert_rows('table', rows, target_fields) - self.cur.prepare.assert_called_once_with( - "insert into table (col1, col2, col3) values (:1, :2, 
:3)") + self.cur.prepare.assert_called_once_with("insert into table (col1, col2, col3) values (:1, :2, :3)") self.cur.executemany.assert_called_once_with(None, rows) def test_bulk_insert_rows_with_commit_every(self): @@ -252,8 +276,7 @@ def test_bulk_insert_rows_with_commit_every(self): def test_bulk_insert_rows_without_fields(self): rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)] self.db_hook.bulk_insert_rows('table', rows) - self.cur.prepare.assert_called_once_with( - "insert into table values (:1, :2, :3)") + self.cur.prepare.assert_called_once_with("insert into table values (:1, :2, :3)") self.cur.executemany.assert_called_once_with(None, rows) def test_bulk_insert_rows_no_rows(self): diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py index 729ccedca8403..829d3eb2dd871 100644 --- a/tests/providers/oracle/operators/test_oracle.py +++ b/tests/providers/oracle/operators/test_oracle.py @@ -33,8 +33,13 @@ def test_execute(self, mock_run): context = "test_context" task_id = "test_task_id" - operator = OracleOperator(sql=sql, oracle_conn_id=oracle_conn_id, parameters=parameters, - autocommit=autocommit, task_id=task_id) + operator = OracleOperator( + sql=sql, + oracle_conn_id=oracle_conn_id, + parameters=parameters, + autocommit=autocommit, + task_id=task_id, + ) operator.execute(context=context) mock_run.assert_called_once_with(sql, autocommit=autocommit, parameters=parameters) diff --git a/tests/providers/oracle/transfers/test_oracle_to_oracle.py b/tests/providers/oracle/transfers/test_oracle_to_oracle.py index 296689dccf0be..f52f59e5ecc4a 100644 --- a/tests/providers/oracle/transfers/test_oracle_to_oracle.py +++ b/tests/providers/oracle/transfers/test_oracle_to_oracle.py @@ -25,7 +25,6 @@ class TestOracleToOracleTransfer(unittest.TestCase): - @staticmethod def test_execute(): oracle_destination_conn_id = 'oracle_destination_conn_id' @@ -36,7 +35,7 @@ def test_execute(): rows_chunk = 5000 cursor_description = [ ('id', "", 39, None, 38, 0, 0), - ('description', "", 60, 240, None, None, 1) + ('description', "", 60, 240, None, None, 1), ] cursor_rows = [[1, 'description 1'], [2, 'description 2']] @@ -54,7 +53,8 @@ def test_execute(): oracle_source_conn_id=oracle_source_conn_id, source_sql=source_sql, source_sql_params=source_sql_params, - rows_chunk=rows_chunk) + rows_chunk=rows_chunk, + ) op._execute(mock_src_hook, mock_dest_hook, None) @@ -68,7 +68,5 @@ def test_execute(): ] mock_cursor.fetchmany.assert_has_calls(calls) mock_dest_hook.bulk_insert_rows.assert_called_once_with( - destination_table, - cursor_rows, - commit_every=rows_chunk, - target_fields=['id', 'description']) + destination_table, cursor_rows, commit_every=rows_chunk, target_fields=['id', 'description'] + ) diff --git a/tests/providers/pagerduty/hooks/test_pagerduty.py b/tests/providers/pagerduty/hooks/test_pagerduty.py index f614e33f73a4c..1450a9db1232a 100644 --- a/tests/providers/pagerduty/hooks/test_pagerduty.py +++ b/tests/providers/pagerduty/hooks/test_pagerduty.py @@ -31,21 +31,23 @@ class TestPagerdutyHook(unittest.TestCase): @classmethod @provide_session def setUpClass(cls, session=None): - session.add(Connection( - conn_id=DEFAULT_CONN_ID, - conn_type='http', - password="pagerduty_token", - extra='{"routing_key": "route"}', - )) + session.add( + Connection( + conn_id=DEFAULT_CONN_ID, + conn_type='http', + password="pagerduty_token", + extra='{"routing_key": "route"}', + ) + ) session.commit() @provide_session def test_without_routing_key_extra(self, session): - 
session.add(Connection( - conn_id="pagerduty_no_extra", - conn_type='http', - password="pagerduty_token_without_extra", - )) + session.add( + Connection( + conn_id="pagerduty_no_extra", conn_type='http', password="pagerduty_token_without_extra", + ) + ) session.commit() hook = PagerdutyHook(pagerduty_conn_id="pagerduty_no_extra") self.assertEqual(hook.token, 'pagerduty_token_without_extra', 'token initialised.') @@ -67,24 +69,16 @@ def test_create_event(self, mock_event_create): "message": "Event processed", "dedup_key": "samplekeyhere", } - resp = hook.create_event( - routing_key="key", - summary="test", - source="airflow_test", - severity="error", - ) + resp = hook.create_event(routing_key="key", summary="test", source="airflow_test", severity="error",) self.assertEqual(resp["status"], "success") mock_event_create.assert_called_once_with( api_key="pagerduty_token", data={ "routing_key": "key", "event_action": "trigger", - "payload": { - "severity": "error", - "source": "airflow_test", - "summary": "test", - }, - }) + "payload": {"severity": "error", "source": "airflow_test", "summary": "test",}, + }, + ) @mock.patch('airflow.providers.pagerduty.hooks.pagerduty.pypd.EventV2.create') def test_create_event_with_default_routing_key(self, mock_event_create): @@ -95,10 +89,7 @@ def test_create_event_with_default_routing_key(self, mock_event_create): "dedup_key": "samplekeyhere", } resp = hook.create_event( - summary="test", - source="airflow_test", - severity="error", - custom_details='{"foo": "bar"}', + summary="test", source="airflow_test", severity="error", custom_details='{"foo": "bar"}', ) self.assertEqual(resp["status"], "success") mock_event_create.assert_called_once_with( @@ -112,4 +103,5 @@ def test_create_event_with_default_routing_key(self, mock_event_create): "summary": "test", "custom_details": '{"foo": "bar"}', }, - }) + }, + ) diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py index cf3815b531b43..c2ebf2cbd4148 100644 --- a/tests/providers/papermill/operators/test_papermill.py +++ b/tests/providers/papermill/operators/test_papermill.py @@ -26,22 +26,19 @@ class TestPapermillOperator(unittest.TestCase): def test_execute(self, mock_papermill): in_nb = "/tmp/does_not_exist" out_nb = "/tmp/will_not_exist" - parameters = {"msg": "hello_world", - "train": 1} + parameters = {"msg": "hello_world", "train": 1} op = PapermillOperator( - input_nb=in_nb, output_nb=out_nb, parameters=parameters, + input_nb=in_nb, + output_nb=out_nb, + parameters=parameters, task_id="papermill_operator_test", - dag=None + dag=None, ) op.pre_execute(context={}) # make sure to have the inlets op.execute(context={}) mock_papermill.execute_notebook.assert_called_once_with( - in_nb, - out_nb, - parameters=parameters, - progress_bar=False, - report_mode=True + in_nb, out_nb, parameters=parameters, progress_bar=False, report_mode=True ) diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index cac19b2ba489b..381857d817e64 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -29,16 +29,10 @@ class TestPostgresHookConn(unittest.TestCase): - def setUp(self): super().setUp() - self.connection = Connection( - login='login', - password='password', - host='host', - schema='schema' - ) + self.connection = Connection(login='login', password='password', host='host', schema='schema') class UnitTestPostgresHook(PostgresHook): 
conn_name_attr = 'test_conn_id' @@ -51,24 +45,30 @@ class UnitTestPostgresHook(PostgresHook): def test_get_conn_non_default_id(self, mock_connect): self.db_hook.test_conn_id = 'non_default' # pylint: disable=attribute-defined-outside-init self.db_hook.get_conn() - mock_connect.assert_called_once_with(user='login', password='password', - host='host', dbname='schema', - port=None) + mock_connect.assert_called_once_with( + user='login', password='password', host='host', dbname='schema', port=None + ) self.db_hook.get_connection.assert_called_once_with('non_default') @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect') def test_get_conn(self, mock_connect): self.db_hook.get_conn() - mock_connect.assert_called_once_with(user='login', password='password', host='host', - dbname='schema', port=None) + mock_connect.assert_called_once_with( + user='login', password='password', host='host', dbname='schema', port=None + ) @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect') def test_get_conn_cursor(self, mock_connect): self.connection.extra = '{"cursor": "dictcursor"}' self.db_hook.get_conn() - mock_connect.assert_called_once_with(cursor_factory=psycopg2.extras.DictCursor, - user='login', password='password', host='host', - dbname='schema', port=None) + mock_connect.assert_called_once_with( + cursor_factory=psycopg2.extras.DictCursor, + user='login', + password='password', + host='host', + dbname='schema', + port=None, + ) @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect') def test_get_conn_with_invalid_cursor(self, mock_connect): @@ -100,8 +100,9 @@ def test_get_conn_rds_iam_postgres(self, mock_client, mock_connect): self.connection.extra = '{"iam":true}' mock_client.return_value.generate_db_auth_token.return_value = 'aws_token' self.db_hook.get_conn() - mock_connect.assert_called_once_with(user='login', password='aws_token', host='host', - dbname='schema', port=5432) + mock_connect.assert_called_once_with( + user='login', password='aws_token', host='host', dbname='schema', port=5432 + ) @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect') @mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_client_type') @@ -109,15 +110,17 @@ def test_get_conn_rds_iam_redshift(self, mock_client, mock_connect): self.connection.extra = '{"iam":true, "redshift":true}' self.connection.host = 'cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com' login = 'IAM:{login}'.format(login=self.connection.login) - mock_client.return_value.get_cluster_credentials.return_value = {'DbPassword': 'aws_token', - 'DbUser': login} + mock_client.return_value.get_cluster_credentials.return_value = { + 'DbPassword': 'aws_token', + 'DbUser': login, + } self.db_hook.get_conn() - mock_connect.assert_called_once_with(user=login, password='aws_token', host=self.connection.host, - dbname='schema', port=5439) + mock_connect.assert_called_once_with( + user=login, password='aws_token', host=self.connection.host, dbname='schema', port=5439 + ) class TestPostgresHook(unittest.TestCase): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.table = "test_postgres_hook_table" @@ -203,8 +206,7 @@ def test_bulk_dump(self): @pytest.mark.backend("postgres") def test_insert_rows(self): table = "table" - rows = [("hello",), - ("world",)] + rows = [("hello",), ("world",)] self.db_hook.insert_rows(table, rows) @@ -221,12 +223,10 @@ def test_insert_rows(self): @pytest.mark.backend("postgres") def 
test_insert_rows_replace(self): table = "table" - rows = [(1, "hello",), - (2, "world",)] + rows = [(1, "hello",), (2, "world",)] fields = ("id", "value") - self.db_hook.insert_rows( - table, rows, fields, replace=True, replace_index=fields[0]) + self.db_hook.insert_rows(table, rows, fields, replace=True, replace_index=fields[0]) assert self.conn.close.call_count == 1 assert self.cur.close.call_count == 1 @@ -234,9 +234,10 @@ def test_insert_rows_replace(self): commit_count = 2 # The first and last commit self.assertEqual(commit_count, self.conn.commit.call_count) - sql = "INSERT INTO {0} ({1}, {2}) VALUES (%s,%s) " \ - "ON CONFLICT ({1}) DO UPDATE SET {2} = excluded.{2}".format( - table, fields[0], fields[1]) + sql = ( + "INSERT INTO {0} ({1}, {2}) VALUES (%s,%s) " + "ON CONFLICT ({1}) DO UPDATE SET {2} = excluded.{2}".format(table, fields[0], fields[1]) + ) for row in rows: self.cur.execute.assert_any_call(sql, row) @@ -244,18 +245,15 @@ def test_insert_rows_replace(self): @pytest.mark.backend("postgres") def test_insert_rows_replace_missing_target_field_arg(self): table = "table" - rows = [(1, "hello",), - (2, "world",)] + rows = [(1, "hello",), (2, "world",)] fields = ("id", "value") - self.db_hook.insert_rows( - table, rows, replace=True, replace_index=fields[0]) + self.db_hook.insert_rows(table, rows, replace=True, replace_index=fields[0]) @pytest.mark.xfail @pytest.mark.backend("postgres") def test_insert_rows_replace_missing_replace_index_arg(self): table = "table" - rows = [(1, "hello",), - (2, "world",)] + rows = [(1, "hello",), (2, "world",)] fields = ("id", "value") self.db_hook.insert_rows(table, rows, fields, replace=True) diff --git a/tests/providers/postgres/operators/test_postgres.py b/tests/providers/postgres/operators/test_postgres.py index 9bd8880dac659..e3a531310b52c 100644 --- a/tests/providers/postgres/operators/test_postgres.py +++ b/tests/providers/postgres/operators/test_postgres.py @@ -40,6 +40,7 @@ def setUp(self): def tearDown(self): tables_to_drop = ['test_postgres_to_postgres', 'test_airflow'] from airflow.providers.postgres.hooks.postgres import PostgresHook + with PostgresHook().get_conn() as conn: with conn.cursor() as cur: for table in tables_to_drop: @@ -55,14 +56,9 @@ def test_postgres_operator_test(self): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) autocommit_task = PostgresOperator( - task_id='basic_postgres_with_autocommit', - sql=sql, - dag=self.dag, - autocommit=True) - autocommit_task.run( - start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE, - ignore_ti_state=True) + task_id='basic_postgres_with_autocommit', sql=sql, dag=self.dag, autocommit=True + ) + autocommit_task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_postgres_operator_test_multi(self): sql = [ @@ -70,8 +66,7 @@ def test_postgres_operator_test_multi(self): "TRUNCATE TABLE test_airflow", "INSERT INTO test_airflow VALUES ('X')", ] - op = PostgresOperator( - task_id='postgres_operator_test_multi', sql=sql, dag=self.dag) + op = PostgresOperator(task_id='postgres_operator_test_multi', sql=sql, dag=self.dag) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_vacuum(self): @@ -80,11 +75,7 @@ def test_vacuum(self): """ sql = "VACUUM ANALYZE;" - op = PostgresOperator( - task_id='postgres_operator_test_vacuum', - sql=sql, - dag=self.dag, - autocommit=True) + op = PostgresOperator(task_id='postgres_operator_test_vacuum', sql=sql, dag=self.dag, autocommit=True) op.run(start_date=DEFAULT_DATE, 
end_date=DEFAULT_DATE, ignore_ti_state=True) def test_overwrite_schema(self): @@ -102,8 +93,8 @@ def test_overwrite_schema(self): ) from psycopg2 import OperationalError + try: - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) except OperationalError as e: assert 'database "foobar" does not exist' in str(e) diff --git a/tests/providers/presto/hooks/test_presto.py b/tests/providers/presto/hooks/test_presto.py index ef7b7d532f9d3..389ad93572316 100644 --- a/tests/providers/presto/hooks/test_presto.py +++ b/tests/providers/presto/hooks/test_presto.py @@ -28,16 +28,10 @@ class TestPrestoHookConn(unittest.TestCase): - def setUp(self): super().setUp() - self.connection = Connection( - login='login', - password='password', - host='host', - schema='hive', - ) + self.connection = Connection(login='login', password='password', host='host', schema='hive',) class UnitTestPrestoHook(PrestoHook): conn_name_attr = 'presto_conn_id' @@ -50,13 +44,20 @@ class UnitTestPrestoHook(PrestoHook): @patch('airflow.providers.presto.hooks.presto.prestodb.dbapi.connect') def test_get_conn(self, mock_connect, mock_basic_auth): self.db_hook.get_conn() - mock_connect.assert_called_once_with(catalog='hive', host='host', port=None, http_scheme='http', - schema='hive', source='airflow', user='login', isolation_level=0, - auth=mock_basic_auth('login', 'password')) + mock_connect.assert_called_once_with( + catalog='hive', + host='host', + port=None, + http_scheme='http', + schema='hive', + source='airflow', + user='login', + isolation_level=0, + auth=mock_basic_auth('login', 'password'), + ) class TestPrestoHook(unittest.TestCase): - def setUp(self): super().setUp() @@ -79,8 +80,7 @@ def get_isolation_level(self): @patch('airflow.hooks.dbapi_hook.DbApiHook.insert_rows') def test_insert_rows(self, mock_insert_rows): table = "table" - rows = [("hello",), - ("world",)] + rows = [("hello",), ("world",)] target_fields = None commit_every = 10 self.db_hook.insert_rows(table, rows, target_fields, commit_every) diff --git a/tests/providers/qubole/operators/test_qubole.py b/tests/providers/qubole/operators/test_qubole.py index a21c9261be42d..3ae4faec9d4d7 100644 --- a/tests/providers/qubole/operators/test_qubole.py +++ b/tests/providers/qubole/operators/test_qubole.py @@ -38,16 +38,12 @@ class TestQuboleOperator(unittest.TestCase): def setUp(self): - db.merge_conn( - Connection(conn_id=DEFAULT_CONN, conn_type='HTTP')) - db.merge_conn( - Connection(conn_id=TEST_CONN, conn_type='HTTP', - host='http://localhost/api')) + db.merge_conn(Connection(conn_id=DEFAULT_CONN, conn_type='HTTP')) + db.merge_conn(Connection(conn_id=TEST_CONN, conn_type='HTTP', host='http://localhost/api')) def tearDown(self): session = settings.Session() - session.query(Connection).filter( - Connection.conn_id == TEST_CONN).delete() + session.query(Connection).filter(Connection.conn_id == TEST_CONN).delete() session.commit() session.close() @@ -70,9 +66,7 @@ def test_init_with_template_cluster_label(self): task_id=TASK_ID, dag=dag, cluster_label='{{ params.cluster_label }}', - params={ - 'cluster_label': 'default' - } + params={'cluster_label': 'default'}, ) ti = TaskInstance(task, DEFAULT_DATE) @@ -101,8 +95,9 @@ def test_position_args_parameters(self): dag = DAG(DAG_ID, start_date=DEFAULT_DATE) with dag: - task = QuboleOperator(task_id=TASK_ID, command_type='pigcmd', - parameters="key1=value1 key2=value2", dag=dag) + task = QuboleOperator( + 
task_id=TASK_ID, command_type='pigcmd', parameters="key1=value1 key2=value2", dag=dag + ) self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[1], "key1=value1") self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[2], "key2=value2") @@ -110,26 +105,27 @@ def test_position_args_parameters(self): cmd = "s3distcp --src s3n://airflow/source_hadoopcmd --dest s3n://airflow/destination_hadoopcmd" task = QuboleOperator(task_id=TASK_ID + "_1", command_type='hadoopcmd', dag=dag, sub_command=cmd) - self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[1], - "s3distcp") - self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[2], - "--src") - self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[3], - "s3n://airflow/source_hadoopcmd") - self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[4], - "--dest") - self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[5], - "s3n://airflow/destination_hadoopcmd") + self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[1], "s3distcp") + self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[2], "--src") + self.assertEqual( + task.get_hook().create_cmd_args({'run_id': 'dummy'})[3], "s3n://airflow/source_hadoopcmd" + ) + self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[4], "--dest") + self.assertEqual( + task.get_hook().create_cmd_args({'run_id': 'dummy'})[5], "s3n://airflow/destination_hadoopcmd" + ) def test_get_redirect_url(self): dag = DAG(DAG_ID, start_date=DEFAULT_DATE) with dag: - task = QuboleOperator(task_id=TASK_ID, - qubole_conn_id=TEST_CONN, - command_type='shellcmd', - parameters="param1 param2", - dag=dag) + task = QuboleOperator( + task_id=TASK_ID, + qubole_conn_id=TEST_CONN, + command_type='shellcmd', + parameters="param1 param2", + dag=dag, + ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.xcom_push('qbol_cmd_id', 12345) @@ -146,9 +142,7 @@ def test_extra_serialized_field(self): dag = DAG(DAG_ID, start_date=DEFAULT_DATE) with dag: QuboleOperator( - task_id=TASK_ID, - command_type='shellcmd', - qubole_conn_id=TEST_CONN, + task_id=TASK_ID, command_type='shellcmd', qubole_conn_id=TEST_CONN, ) serialized_dag = SerializedDAG.to_dict(dag) diff --git a/tests/providers/qubole/operators/test_qubole_check.py b/tests/providers/qubole/operators/test_qubole_check.py index e063ddced561a..84bfcb66de5b2 100644 --- a/tests/providers/qubole/operators/test_qubole_check.py +++ b/tests/providers/qubole/operators/test_qubole_check.py @@ -30,13 +30,11 @@ class TestQuboleValueCheckOperator(unittest.TestCase): - def setUp(self): self.task_id = 'test_task' self.conn_id = 'default_conn' - def __construct_operator(self, query, pass_value, tolerance=None, - results_parser_callable=None): + def __construct_operator(self, query, pass_value, tolerance=None, results_parser_callable=None): dag = DAG('test_dag', start_date=datetime(2017, 1, 1)) @@ -48,7 +46,8 @@ def __construct_operator(self, query, pass_value, tolerance=None, pass_value=pass_value, results_parser_callable=results_parser_callable, command_type='hivecmd', - tolerance=tolerance) + tolerance=tolerance, + ) def test_pass_value_template(self): pass_value_str = "2018-03-22" @@ -79,8 +78,7 @@ def test_execute_assertion_fail(self, mock_get_hook): mock_cmd = mock.Mock() mock_cmd.status = 'done' mock_cmd.id = 123 - mock_cmd.is_success = mock.Mock( - return_value=HiveCommand.is_success(mock_cmd.status)) + mock_cmd.is_success = 
mock.Mock(return_value=HiveCommand.is_success(mock_cmd.status)) mock_hook = mock.Mock() mock_hook.get_first.return_value = [11] @@ -89,8 +87,7 @@ def test_execute_assertion_fail(self, mock_get_hook): operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1) - with self.assertRaisesRegex(AirflowException, - 'Qubole Command Id: ' + str(mock_cmd.id)): + with self.assertRaisesRegex(AirflowException, 'Qubole Command Id: ' + str(mock_cmd.id)): operator.execute() mock_cmd.is_success.assert_called_once_with(mock_cmd.status) @@ -101,8 +98,7 @@ def test_execute_assert_query_fail(self, mock_get_hook): mock_cmd = mock.Mock() mock_cmd.status = 'error' mock_cmd.id = 123 - mock_cmd.is_success = mock.Mock( - return_value=HiveCommand.is_success(mock_cmd.status)) + mock_cmd.is_success = mock.Mock(return_value=HiveCommand.is_success(mock_cmd.status)) mock_hook = mock.Mock() mock_hook.get_first.return_value = [11] @@ -129,7 +125,8 @@ def test_results_parser_callable(self, mock_execute, mock_get_query_results): results_parser_callable = mock.Mock() results_parser_callable.return_value = [pass_value] - operator = self.__construct_operator('select value from tab1 limit 1;', - pass_value, None, results_parser_callable) + operator = self.__construct_operator( + 'select value from tab1 limit 1;', pass_value, None, results_parser_callable + ) operator.execute() results_parser_callable.assert_called_once_with([pass_value]) diff --git a/tests/providers/qubole/sensors/test_qubole.py b/tests/providers/qubole/sensors/test_qubole.py index 646d16a70b16d..e7bed5364798d 100644 --- a/tests/providers/qubole/sensors/test_qubole.py +++ b/tests/providers/qubole/sensors/test_qubole.py @@ -35,15 +35,13 @@ class TestQuboleSensor(unittest.TestCase): def setUp(self): - db.merge_conn( - Connection(conn_id=DEFAULT_CONN, conn_type='HTTP')) + db.merge_conn(Connection(conn_id=DEFAULT_CONN, conn_type='HTTP')) @patch('airflow.providers.qubole.sensors.qubole.QuboleFileSensor.poke') def test_file_sensore(self, patched_poke): patched_poke.return_value = True sensor = QuboleFileSensor( - task_id='test_qubole_file_sensor', - data={"files": ["s3://some_bucket/some_file"]} + task_id='test_qubole_file_sensor', data={"files": ["s3://some_bucket/some_file"]} ) self.assertTrue(sensor.poke({})) @@ -56,8 +54,8 @@ def test_partition_sensor(self, patched_poke): data={ "schema": "default", "table": "my_partitioned_table", - "columns": [{"column": "month", "values": ["1", "2"]}] - } + "columns": [{"column": "month", "values": ["1", "2"]}], + }, ) self.assertTrue(sensor.poke({})) @@ -75,7 +73,7 @@ def test_partition_sensor_error(self, patched_poke): data={ "schema": "default", "table": "my_partitioned_table", - "columns": [{"column": "month", "values": ["1", "2"]}] + "columns": [{"column": "month", "values": ["1", "2"]}], }, - dag=dag + dag=dag, ) diff --git a/tests/providers/redis/hooks/test_redis.py b/tests/providers/redis/hooks/test_redis.py index b6fe1ba898b61..4328571c86dcb 100644 --- a/tests/providers/redis/hooks/test_redis.py +++ b/tests/providers/redis/hooks/test_redis.py @@ -38,12 +38,13 @@ def test_get_conn(self): self.assertIs(hook.get_conn(), hook.get_conn(), 'Connection initialized only if None.') @mock.patch('airflow.providers.redis.hooks.redis.Redis') - @mock.patch('airflow.providers.redis.hooks.redis.RedisHook.get_connection', - return_value=Connection( - password='password', - host='remote_host', - port=1234, - extra="""{ + @mock.patch( + 'airflow.providers.redis.hooks.redis.RedisHook.get_connection', + 
return_value=Connection( + password='password', + host='remote_host', + port=1234, + extra="""{ "db": 2, "ssl": true, "ssl_cert_reqs": "required", @@ -51,8 +52,9 @@ def test_get_conn(self): "ssl_keyfile": "/path/to/key-file", "ssl_cert_file": "/path/to/cert-file", "ssl_check_hostname": true - }""" - )) + }""", + ), + ) def test_get_conn_with_extra_config(self, mock_get_connection, mock_redis): connection = mock_get_connection.return_value hook = RedisHook() @@ -68,7 +70,7 @@ def test_get_conn_with_extra_config(self, mock_get_connection, mock_redis): ssl_ca_certs=connection.extra_dejson["ssl_ca_certs"], ssl_keyfile=connection.extra_dejson["ssl_keyfile"], ssl_cert_file=connection.extra_dejson["ssl_cert_file"], - ssl_check_hostname=connection.extra_dejson["ssl_check_hostname"] + ssl_check_hostname=connection.extra_dejson["ssl_check_hostname"], ) def test_get_conn_password_stays_none(self): diff --git a/tests/providers/redis/operators/test_redis_publish.py b/tests/providers/redis/operators/test_redis_publish.py index affb82319e379..4ea9fde3284d6 100644 --- a/tests/providers/redis/operators/test_redis_publish.py +++ b/tests/providers/redis/operators/test_redis_publish.py @@ -32,12 +32,8 @@ @pytest.mark.integration("redis") class TestRedisPublishOperator(unittest.TestCase): - def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_redis_dag_id', default_args=args) @@ -50,7 +46,7 @@ def test_execute_hello(self): dag=self.dag, message='hello', channel=self.channel, - redis_conn_id='redis_default' + redis_conn_id='redis_default', ) hook = RedisHook(redis_conn_id='redis_default') diff --git a/tests/providers/redis/sensors/test_redis_key.py b/tests/providers/redis/sensors/test_redis_key.py index 1b55055ac1784..a67582a0ab1f1 100644 --- a/tests/providers/redis/sensors/test_redis_key.py +++ b/tests/providers/redis/sensors/test_redis_key.py @@ -31,19 +31,12 @@ @pytest.mark.integration("redis") class TestRedisSensor(unittest.TestCase): - def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) self.sensor = RedisKeySensor( - task_id='test_task', - redis_conn_id='redis_default', - dag=self.dag, - key='test_key' + task_id='test_task', redis_conn_id='redis_default', dag=self.dag, key='test_key' ) def test_poke(self): diff --git a/tests/providers/redis/sensors/test_redis_pub_sub.py b/tests/providers/redis/sensors/test_redis_pub_sub.py index 5b632857e81c4..207a8268c6184 100644 --- a/tests/providers/redis/sensors/test_redis_pub_sub.py +++ b/tests/providers/redis/sensors/test_redis_pub_sub.py @@ -31,12 +31,8 @@ class TestRedisPubSubSensor(unittest.TestCase): - def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) @@ -45,34 +41,35 @@ def setUp(self): @patch('airflow.providers.redis.hooks.redis.RedisHook.get_conn') def test_poke_mock_true(self, mock_redis_conn): sensor = RedisPubSubSensor( - task_id='test_task', - dag=self.dag, - channels='test', - redis_conn_id='redis_default' + task_id='test_task', dag=self.dag, channels='test', redis_conn_id='redis_default' ) - mock_redis_conn().pubsub().get_message.return_value = \ - {'type': 'message', 'channel': b'test', 'data': b'd1'} + mock_redis_conn().pubsub().get_message.return_value = { + 'type': 
'message', + 'channel': b'test', + 'data': b'd1', + } result = sensor.poke(self.mock_context) self.assertTrue(result) - context_calls = [call.xcom_push(key='message', - value={'type': 'message', 'channel': b'test', 'data': b'd1'})] + context_calls = [ + call.xcom_push(key='message', value={'type': 'message', 'channel': b'test', 'data': b'd1'}) + ] self.assertTrue(self.mock_context['ti'].method_calls == context_calls, "context call should be same") @patch('airflow.providers.redis.hooks.redis.RedisHook.get_conn') def test_poke_mock_false(self, mock_redis_conn): sensor = RedisPubSubSensor( - task_id='test_task', - dag=self.dag, - channels='test', - redis_conn_id='redis_default' + task_id='test_task', dag=self.dag, channels='test', redis_conn_id='redis_default' ) - mock_redis_conn().pubsub().get_message.return_value = \ - {'type': 'subscribe', 'channel': b'test', 'data': b'd1'} + mock_redis_conn().pubsub().get_message.return_value = { + 'type': 'subscribe', + 'channel': b'test', + 'data': b'd1', + } result = sensor.poke(self.mock_context) self.assertFalse(result) @@ -83,10 +80,7 @@ def test_poke_mock_false(self, mock_redis_conn): @pytest.mark.integration("redis") def test_poke_true(self): sensor = RedisPubSubSensor( - task_id='test_task', - dag=self.dag, - channels='test', - redis_conn_id='redis_default' + task_id='test_task', dag=self.dag, channels='test', redis_conn_id='redis_default' ) hook = RedisHook(redis_conn_id='redis_default') @@ -100,7 +94,9 @@ def test_poke_true(self): context_calls = [ call.xcom_push( key='message', - value={'type': 'message', 'pattern': None, 'channel': b'test', 'data': b'message'})] + value={'type': 'message', 'pattern': None, 'channel': b'test', 'data': b'message'}, + ) + ] self.assertTrue(self.mock_context['ti'].method_calls == context_calls, "context calls should be same") result = sensor.poke(self.mock_context) self.assertFalse(result) @@ -108,10 +104,7 @@ def test_poke_true(self): @pytest.mark.integration("redis") def test_poke_false(self): sensor = RedisPubSubSensor( - task_id='test_task', - dag=self.dag, - channels='test', - redis_conn_id='redis_default' + task_id='test_task', dag=self.dag, channels='test', redis_conn_id='redis_default' ) result = sensor.poke(self.mock_context) diff --git a/tests/providers/salesforce/hooks/test_salesforce.py b/tests/providers/salesforce/hooks/test_salesforce.py index 1c9dfdcc0ffd2..ba73ad1047475 100644 --- a/tests/providers/salesforce/hooks/test_salesforce.py +++ b/tests/providers/salesforce/hooks/test_salesforce.py @@ -189,8 +189,7 @@ def test_obect_to_df_with_timestamp_conversion(self, mock_data_frame, mock_descr obj_name = "obj_name" data_frame = self.salesforce_hook.object_to_df( - query_results=[{"attributes": {"type": obj_name}}], - coerce_to_timestamp=True, + query_results=[{"attributes": {"type": obj_name}}], coerce_to_timestamp=True, ) mock_describe_object.assert_called_once_with(obj_name) @@ -204,9 +203,7 @@ def test_obect_to_df_with_timestamp_conversion(self, mock_data_frame, mock_descr return_value=pd.DataFrame({"test": [1, 2, 3]}), ) def test_object_to_df_with_record_time(self, mock_data_frame, mock_time): - data_frame = self.salesforce_hook.object_to_df( - query_results=[], record_time_added=True - ) + data_frame = self.salesforce_hook.object_to_df(query_results=[], record_time_added=True) pd.testing.assert_frame_equal( data_frame, diff --git a/tests/providers/salesforce/hooks/test_tableau.py b/tests/providers/salesforce/hooks/test_tableau.py index f868fb9af4dd5..b416965c1c1fb 100644 --- 
a/tests/providers/salesforce/hooks/test_tableau.py +++ b/tests/providers/salesforce/hooks/test_tableau.py @@ -24,7 +24,6 @@ class TestTableauHook(unittest.TestCase): - def setUp(self): configuration.conf.load_test_config() @@ -35,7 +34,7 @@ def setUp(self): host='tableau', login='user', password='password', - extra='{"site_id": "my_site"}' + extra='{"site_id": "my_site"}', ) ) db.merge_conn( @@ -43,7 +42,7 @@ def setUp(self): conn_id='tableau_test_token', conn_type='tableau', host='tableau', - extra='{"token_name": "my_token", "personal_access_token": "my_personal_access_token"}' + extra='{"token_name": "my_token", "personal_access_token": "my_personal_access_token"}', ) ) @@ -55,11 +54,9 @@ def test_get_conn_auth_via_password_and_site_in_connection(self, mock_server, mo mock_tableau_auth.assert_called_once_with( username=tableau_hook.conn.login, password=tableau_hook.conn.password, - site_id=tableau_hook.conn.extra_dejson['site_id'] - ) - mock_server.return_value.auth.sign_in.assert_called_once_with( - mock_tableau_auth.return_value + site_id=tableau_hook.conn.extra_dejson['site_id'], ) + mock_server.return_value.auth.sign_in.assert_called_once_with(mock_tableau_auth.return_value) mock_server.return_value.auth.sign_out.assert_called_once_with() @patch('airflow.providers.salesforce.hooks.tableau.PersonalAccessTokenAuth') @@ -70,7 +67,7 @@ def test_get_conn_auth_via_token_and_site_in_init(self, mock_server, mock_tablea mock_tableau_auth.assert_called_once_with( token_name=tableau_hook.conn.extra_dejson['token_name'], personal_access_token=tableau_hook.conn.extra_dejson['personal_access_token'], - site_id=tableau_hook.site_id + site_id=tableau_hook.site_id, ) mock_server.return_value.auth.sign_in_with_personal_access_token.assert_called_once_with( mock_tableau_auth.return_value @@ -80,12 +77,7 @@ def test_get_conn_auth_via_token_and_site_in_init(self, mock_server, mock_tablea @patch('airflow.providers.salesforce.hooks.tableau.TableauAuth') @patch('airflow.providers.salesforce.hooks.tableau.Server') @patch('airflow.providers.salesforce.hooks.tableau.Pager', return_value=[1, 2, 3]) - def test_get_all( - self, - mock_pager, - mock_server, - mock_tableau_auth # pylint: disable=unused-argument - ): + def test_get_all(self, mock_pager, mock_server, mock_tableau_auth): # pylint: disable=unused-argument with TableauHook(tableau_conn_id='tableau_test_password') as tableau_hook: jobs = tableau_hook.get_all(resource_name='jobs') self.assertEqual(jobs, mock_pager.return_value) diff --git a/tests/providers/salesforce/operators/test_tableau_refresh_workbook.py b/tests/providers/salesforce/operators/test_tableau_refresh_workbook.py index b4d038244cb30..4751cc9eb10df 100644 --- a/tests/providers/salesforce/operators/test_tableau_refresh_workbook.py +++ b/tests/providers/salesforce/operators/test_tableau_refresh_workbook.py @@ -23,7 +23,6 @@ class TestTableauRefreshWorkbookOperator(unittest.TestCase): - def setUp(self): self.mocked_workbooks = [] for i in range(3): @@ -31,11 +30,7 @@ def setUp(self): mock_workbook.id = i mock_workbook.name = f'wb_{i}' self.mocked_workbooks.append(mock_workbook) - self.kwargs = { - 'site_id': 'test_site', - 'task_id': 'task', - 'dag': None - } + self.kwargs = {'site_id': 'test_site', 'task_id': 'task', 'dag': None} @patch('airflow.providers.salesforce.operators.tableau_refresh_workbook.TableauHook') def test_execute(self, mock_tableau_hook): @@ -64,7 +59,7 @@ def test_execute_blocking(self, mock_tableau_hook, mock_tableau_job_status_senso site_id=self.kwargs['site_id'], 
tableau_conn_id='tableau_default', task_id='wait_until_succeeded', - dag=None + dag=None, ) @patch('airflow.providers.salesforce.operators.tableau_refresh_workbook.TableauHook') diff --git a/tests/providers/salesforce/sensors/test_tableau_job_status.py b/tests/providers/salesforce/sensors/test_tableau_job_status.py index 67f877853b4fa..f8b7c3e4d82fd 100644 --- a/tests/providers/salesforce/sensors/test_tableau_job_status.py +++ b/tests/providers/salesforce/sensors/test_tableau_job_status.py @@ -21,19 +21,14 @@ from parameterized import parameterized from airflow.providers.salesforce.sensors.tableau_job_status import ( - TableauJobFailedException, TableauJobStatusSensor, + TableauJobFailedException, + TableauJobStatusSensor, ) class TestTableauJobStatusSensor(unittest.TestCase): - def setUp(self): - self.kwargs = { - 'job_id': 'job_2', - 'site_id': 'test_site', - 'task_id': 'task', - 'dag': None - } + self.kwargs = {'job_id': 'job_2', 'site_id': 'test_site', 'task_id': 'task', 'dag': None} @patch('airflow.providers.salesforce.sensors.tableau_job_status.TableauHook') def test_poke(self, mock_tableau_hook): diff --git a/tests/providers/samba/hooks/test_samba.py b/tests/providers/samba/hooks/test_samba.py index cf63ef662305d..35dc452c2469b 100644 --- a/tests/providers/samba/hooks/test_samba.py +++ b/tests/providers/samba/hooks/test_samba.py @@ -44,8 +44,9 @@ def test_get_conn(self, get_conn_mock): @mock.patch('airflow.providers.samba.hooks.samba.SambaHook.get_conn') @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') - def test_push_from_local_should_succeed_if_destination_has_same_name_but_not_a_file(self, base_conn_mock, - samba_hook_mock): + def test_push_from_local_should_succeed_if_destination_has_same_name_but_not_a_file( + self, base_conn_mock, samba_hook_mock + ): base_conn_mock.return_value = connection samba_hook_mock.get_conn.return_value = mock.Mock() @@ -67,8 +68,9 @@ def test_push_from_local_should_succeed_if_destination_has_same_name_but_not_a_f @mock.patch('airflow.providers.samba.hooks.samba.SambaHook.get_conn') @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') - def test_push_from_local_should_delete_file_if_exists_and_save_file(self, base_conn_mock, - samba_hook_mock): + def test_push_from_local_should_delete_file_if_exists_and_save_file( + self, base_conn_mock, samba_hook_mock + ): base_conn_mock.return_value = connection samba_hook_mock.get_conn.return_value = mock.Mock() @@ -83,8 +85,9 @@ def test_push_from_local_should_delete_file_if_exists_and_save_file(self, base_c base_conn_mock.assert_called_once_with('samba_default') samba_hook_mock.assert_called_once() - samba_hook_mock.return_value.exists.assert_has_calls([call(destination_filepath), - call(destination_folder)]) + samba_hook_mock.return_value.exists.assert_has_calls( + [call(destination_filepath), call(destination_folder)] + ) samba_hook_mock.return_value.isfile.assert_not_called() samba_hook_mock.return_value.remove.assert_not_called() samba_hook_mock.return_value.mkdir.assert_called_once_with(destination_folder) @@ -92,8 +95,9 @@ def test_push_from_local_should_delete_file_if_exists_and_save_file(self, base_c @mock.patch('airflow.providers.samba.hooks.samba.SambaHook.get_conn') @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') - def test_push_from_local_should_create_directory_if_not_exist_and_save_file(self, base_conn_mock, - samba_hook_mock): + def test_push_from_local_should_create_directory_if_not_exist_and_save_file( + self, base_conn_mock, samba_hook_mock + ): 
base_conn_mock.return_value = connection samba_hook_mock.get_conn.return_value = mock.Mock() @@ -108,8 +112,9 @@ def test_push_from_local_should_create_directory_if_not_exist_and_save_file(self base_conn_mock.assert_called_once_with('samba_default') samba_hook_mock.assert_called_once() - samba_hook_mock.return_value.exists.assert_has_calls([call(destination_filepath), - call(destination_folder)]) + samba_hook_mock.return_value.exists.assert_has_calls( + [call(destination_filepath), call(destination_folder)] + ) samba_hook_mock.return_value.isfile.assert_not_called() samba_hook_mock.return_value.remove.assert_not_called() samba_hook_mock.return_value.mkdir.assert_called_once_with(destination_folder) diff --git a/tests/providers/segment/hooks/test_segment.py b/tests/providers/segment/hooks/test_segment.py index f49790b631747..723a646391ada 100644 --- a/tests/providers/segment/hooks/test_segment.py +++ b/tests/providers/segment/hooks/test_segment.py @@ -27,7 +27,6 @@ class TestSegmentHook(unittest.TestCase): - def setUp(self): super().setUp() @@ -37,7 +36,6 @@ def setUp(self): self.conn.extra_dejson = {'write_key': self.expected_write_key} class UnitTestSegmentHook(SegmentHook): - def get_conn(self): return conn diff --git a/tests/providers/segment/operators/test_segment_track_event.py b/tests/providers/segment/operators/test_segment_track_event.py index 2c3490ac98aff..e53b66ddf6977 100644 --- a/tests/providers/segment/operators/test_segment_track_event.py +++ b/tests/providers/segment/operators/test_segment_track_event.py @@ -28,7 +28,6 @@ class TestSegmentHook(unittest.TestCase): - def setUp(self): super().setUp() @@ -38,7 +37,6 @@ def setUp(self): self.conn.extra_dejson = {'write_key': self.expected_write_key} class UnitTestSegmentHook(SegmentHook): - def get_conn(self): return conn @@ -59,7 +57,6 @@ def test_on_error(self): class TestSegmentTrackEventOperator(unittest.TestCase): - @mock.patch('airflow.providers.segment.operators.segment_track_event.SegmentHook') def test_execute(self, mock_hook): # Given @@ -68,10 +65,7 @@ def test_execute(self, mock_hook): properties = {} operator = SegmentTrackEventOperator( - task_id='segment-track', - user_id=user_id, - event=event, - properties=properties, + task_id='segment-track', user_id=user_id, event=event, properties=properties, ) # When @@ -79,7 +73,5 @@ def test_execute(self, mock_hook): # Then mock_hook.return_value.track.assert_called_once_with( - user_id=user_id, - event=event, - properties=properties, + user_id=user_id, event=event, properties=properties, ) diff --git a/tests/providers/sendgrid/utils/test_emailer.py b/tests/providers/sendgrid/utils/test_emailer.py index 856995f1f413b..af798082192d6 100644 --- a/tests/providers/sendgrid/utils/test_emailer.py +++ b/tests/providers/sendgrid/utils/test_emailer.py @@ -38,9 +38,12 @@ def setUp(self): self.expected_mail_data = { 'content': [{'type': 'text/html', 'value': self.html_content}], 'personalizations': [ - {'cc': [{'email': 'foo-cc@foo.com'}, {'email': 'bar-cc@bar.com'}], - 'to': [{'email': 'foo@foo.com'}, {'email': 'bar@bar.com'}], - 'bcc': [{'email': 'foo-bcc@foo.com'}, {'email': 'bar-bcc@bar.com'}]}], + { + 'cc': [{'email': 'foo-cc@foo.com'}, {'email': 'bar-cc@bar.com'}], + 'to': [{'email': 'foo@foo.com'}, {'email': 'bar@bar.com'}], + 'bcc': [{'email': 'foo-bcc@foo.com'}, {'email': 'bar-bcc@bar.com'}], + } + ], 'from': {'email': 'foo@bar.com'}, 'subject': 'sendgrid-send-email unit test', } @@ -48,8 +51,9 @@ def setUp(self): self.categories = ['cat1', 'cat2'] # extras 
self.expected_mail_data_extras = copy.deepcopy(self.expected_mail_data) - self.expected_mail_data_extras['personalizations'][0]['custom_args'] = ( - self.personalization_custom_args) + self.expected_mail_data_extras['personalizations'][0][ + 'custom_args' + ] = self.personalization_custom_args self.expected_mail_data_extras['categories'] = ['cat2', 'cat1'] self.expected_mail_data_extras['from'] = { 'name': 'Foo', @@ -73,39 +77,52 @@ def test_send_email_sendgrid_correct_email(self, mock_post): filename = os.path.basename(f.name) expected_mail_data = dict( self.expected_mail_data, - attachments=[{ - 'content': 'dGhpcyBpcyBzb21lIHRlc3QgZGF0YQ==', - 'content_id': '<{0}>'.format(filename), - 'disposition': 'attachment', - 'filename': filename, - 'type': 'text/plain', - }], + attachments=[ + { + 'content': 'dGhpcyBpcyBzb21lIHRlc3QgZGF0YQ==', + 'content_id': '<{0}>'.format(filename), + 'disposition': 'attachment', + 'filename': filename, + 'type': 'text/plain', + } + ], ) - send_email(self.recepients, - self.subject, - self.html_content, - cc=self.carbon_copy, - bcc=self.bcc, - files=[f.name]) + send_email( + self.recepients, + self.subject, + self.html_content, + cc=self.carbon_copy, + bcc=self.bcc, + files=[f.name], + ) mock_post.assert_called_once_with(expected_mail_data) # Test the right email is constructed. - @mock.patch.dict( - 'os.environ', - SENDGRID_MAIL_FROM='foo@bar.com', - SENDGRID_MAIL_SENDER='Foo' - ) + @mock.patch.dict('os.environ', SENDGRID_MAIL_FROM='foo@bar.com', SENDGRID_MAIL_SENDER='Foo') @mock.patch('airflow.providers.sendgrid.utils.emailer._post_sendgrid_mail') def test_send_email_sendgrid_correct_email_extras(self, mock_post): - send_email(self.recepients, self.subject, self.html_content, cc=self.carbon_copy, bcc=self.bcc, - personalization_custom_args=self.personalization_custom_args, - categories=self.categories) + send_email( + self.recepients, + self.subject, + self.html_content, + cc=self.carbon_copy, + bcc=self.bcc, + personalization_custom_args=self.personalization_custom_args, + categories=self.categories, + ) mock_post.assert_called_once_with(self.expected_mail_data_extras) @mock.patch.dict('os.environ', clear=True) @mock.patch('airflow.providers.sendgrid.utils.emailer._post_sendgrid_mail') def test_send_email_sendgrid_sender(self, mock_post): - send_email(self.recepients, self.subject, self.html_content, cc=self.carbon_copy, bcc=self.bcc, - from_email='foo@foo.bar', from_name='Foo Bar') + send_email( + self.recepients, + self.subject, + self.html_content, + cc=self.carbon_copy, + bcc=self.bcc, + from_email='foo@foo.bar', + from_name='Foo Bar', + ) mock_post.assert_called_once_with(self.expected_mail_data_sender) diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py index a22bc25a9f34c..131daf2083d9d 100644 --- a/tests/providers/sftp/hooks/test_sftp.py +++ b/tests/providers/sftp/hooks/test_sftp.py @@ -37,12 +37,9 @@ class TestSFTPHook(unittest.TestCase): - @provide_session def update_connection(self, login, session=None): - connection = (session.query(Connection). 
- filter(Connection.conn_id == "sftp_default") - .first()) + connection = session.query(Connection).filter(Connection.conn_id == "sftp_default").first() old_login = connection.login connection.login = login session.commit() @@ -73,75 +70,57 @@ def test_describe_directory(self): self.assertTrue(TMP_DIR_FOR_TESTS in output) def test_list_directory(self): - output = self.hook.list_directory( - path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) self.assertEqual(output, [SUB_DIR]) def test_create_and_delete_directory(self): new_dir_name = 'new_dir' - self.hook.create_directory(os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)) - output = self.hook.describe_directory( - os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.hook.create_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)) + output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) self.assertTrue(new_dir_name in output) - self.hook.delete_directory(os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)) - output = self.hook.describe_directory( - os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)) + output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) self.assertTrue(new_dir_name not in output) def test_create_and_delete_directories(self): base_dir = "base_dir" sub_dir = "sub_dir" new_dir_path = os.path.join(base_dir, sub_dir) - self.hook.create_directory(os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path)) - output = self.hook.describe_directory( - os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.hook.create_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path)) + output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) self.assertTrue(base_dir in output) - output = self.hook.describe_directory( - os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)) + output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)) self.assertTrue(sub_dir in output) - self.hook.delete_directory(os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path)) - self.hook.delete_directory(os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)) - output = self.hook.describe_directory( - os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path)) + self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)) + output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) self.assertTrue(new_dir_path not in output) self.assertTrue(base_dir not in output) def test_store_retrieve_and_delete_file(self): self.hook.store_file( - remote_full_path=os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), - local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS) + remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), ) - output = self.hook.list_directory( - path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS]) retrieved_file_name = 'retrieved.txt' self.hook.retrieve_file( - remote_full_path=os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), - local_full_path=os.path.join(TMP_PATH, retrieved_file_name) + 
remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=os.path.join(TMP_PATH, retrieved_file_name), ) self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH)) os.remove(os.path.join(TMP_PATH, retrieved_file_name)) - self.hook.delete_file(path=os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) - output = self.hook.list_directory( - path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.hook.delete_file(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) + output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) self.assertEqual(output, [SUB_DIR]) def test_get_mod_time(self): self.hook.store_file( - remote_full_path=os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), - local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS) + remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), ) - output = self.hook.get_mod_time(path=os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) + output = self.hook.get_mod_time(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) self.assertEqual(len(output), 14) @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') @@ -153,9 +132,7 @@ def test_no_host_key_check_default(self, get_connection): @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') def test_no_host_key_check_enabled(self, get_connection): - connection = Connection( - login='login', host='host', - extra='{"no_host_key_check": true}') + connection = Connection(login='login', host='host', extra='{"no_host_key_check": true}') get_connection.return_value = connection hook = SFTPHook() @@ -163,9 +140,7 @@ def test_no_host_key_check_enabled(self, get_connection): @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') def test_no_host_key_check_disabled(self, get_connection): - connection = Connection( - login='login', host='host', - extra='{"no_host_key_check": false}') + connection = Connection(login='login', host='host', extra='{"no_host_key_check": false}') get_connection.return_value = connection hook = SFTPHook() @@ -173,9 +148,7 @@ def test_no_host_key_check_disabled(self, get_connection): @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): - connection = Connection( - login='login', host='host', - extra='{"no_host_key_check": "foo"}') + connection = Connection(login='login', host='host', extra='{"no_host_key_check": "foo"}') get_connection.return_value = connection hook = SFTPHook() @@ -183,9 +156,7 @@ def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') def test_no_host_key_check_ignore(self, get_connection): - connection = Connection( - login='login', host='host', - extra='{"ignore_hostkey_verification": true}') + connection = Connection(login='login', host='host', extra='{"ignore_hostkey_verification": true}') get_connection.return_value = connection hook = SFTPHook() @@ -193,37 +164,39 @@ def test_no_host_key_check_ignore(self, get_connection): @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') def test_no_host_key_check_no_ignore(self, get_connection): - connection = Connection( - login='login', host='host', - extra='{"ignore_hostkey_verification": false}') + connection = Connection(login='login', 
host='host', extra='{"ignore_hostkey_verification": false}') get_connection.return_value = connection hook = SFTPHook() self.assertEqual(hook.no_host_key_check, False) - @parameterized.expand([ - (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True), - (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True), - (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False), - (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False), - ]) + @parameterized.expand( + [ + (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True), + (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True), + (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False), + (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False), + ] + ) def test_path_exists(self, path, exists): result = self.hook.path_exists(path) self.assertEqual(result, exists) - @parameterized.expand([ - ("test/path/file.bin", None, None, True), - ("test/path/file.bin", "test", None, True), - ("test/path/file.bin", "test/", None, True), - ("test/path/file.bin", None, "bin", True), - ("test/path/file.bin", "test", "bin", True), - ("test/path/file.bin", "test/", "file.bin", True), - ("test/path/file.bin", None, "file.bin", True), - ("test/path/file.bin", "diff", None, False), - ("test/path/file.bin", "test//", None, False), - ("test/path/file.bin", None, ".txt", False), - ("test/path/file.bin", "diff", ".txt", False), - ]) + @parameterized.expand( + [ + ("test/path/file.bin", None, None, True), + ("test/path/file.bin", "test", None, True), + ("test/path/file.bin", "test/", None, True), + ("test/path/file.bin", None, "bin", True), + ("test/path/file.bin", "test", "bin", True), + ("test/path/file.bin", "test/", "file.bin", True), + ("test/path/file.bin", None, "file.bin", True), + ("test/path/file.bin", "diff", None, False), + ("test/path/file.bin", "test//", None, False), + ("test/path/file.bin", None, ".txt", False), + ("test/path/file.bin", "diff", ".txt", False), + ] + ) def test_path_match(self, path, prefix, delimiter, match): result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter) self.assertEqual(result, match) diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py index a3a4433adb831..fde5b47730073 100644 --- a/tests/providers/sftp/operators/test_sftp.py +++ b/tests/providers/sftp/operators/test_sftp.py @@ -53,22 +53,19 @@ def setUp(self): self.test_remote_dir = "/tmp/tmp1" self.test_local_filename = 'test_local_file' self.test_remote_filename = 'test_remote_file' - self.test_local_filepath = '{0}/{1}'.format(self.test_dir, - self.test_local_filename) + self.test_local_filepath = '{0}/{1}'.format(self.test_dir, self.test_local_filename) # Local Filepath with Intermediate Directory - self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir, - self.test_local_filename) - self.test_remote_filepath = '{0}/{1}'.format(self.test_dir, - self.test_remote_filename) + self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir, self.test_local_filename) + self.test_remote_filepath = '{0}/{1}'.format(self.test_dir, self.test_remote_filename) # Remote Filepath with Intermediate Directory - self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir, - self.test_remote_filename) + self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir, self.test_remote_filename) @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) def test_pickle_file_transfer_put(self): - test_local_file_content = \ - b"This is local file content \n which is multiline " \ + 
test_local_file_content = ( + b"This is local file content \n which is multiline " b"continuing....with other character\nanother line here \n this is last line" + ) # create a test file locally with open(self.test_local_filepath, 'wb') as file: file.write(test_local_file_content) @@ -81,7 +78,7 @@ def test_pickle_file_transfer_put(self): remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, create_intermediate_dirs=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(put_test_task) ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) @@ -93,20 +90,22 @@ def test_pickle_file_transfer_put(self): ssh_hook=self.hook, command="cat {0}".format(self.test_remote_filepath), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(check_file_task) ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) ti3.run() self.assertEqual( ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip(), - test_local_file_content) + test_local_file_content, + ) @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) def test_file_transfer_no_intermediate_dir_error_put(self): - test_local_file_content = \ - b"This is local file content \n which is multiline " \ + test_local_file_content = ( + b"This is local file content \n which is multiline " b"continuing....with other character\nanother line here \n this is last line" + ) # create a test file locally with open(self.test_local_filepath, 'wb') as file: file.write(test_local_file_content) @@ -122,7 +121,7 @@ def test_file_transfer_no_intermediate_dir_error_put(self): remote_filepath=self.test_remote_filepath_int_dir, operation=SFTPOperation.PUT, create_intermediate_dirs=False, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(put_test_task) ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) @@ -131,9 +130,10 @@ def test_file_transfer_no_intermediate_dir_error_put(self): @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) def test_file_transfer_with_intermediate_dir_put(self): - test_local_file_content = \ - b"This is local file content \n which is multiline " \ + test_local_file_content = ( + b"This is local file content \n which is multiline " b"continuing....with other character\nanother line here \n this is last line" + ) # create a test file locally with open(self.test_local_filepath, 'wb') as file: file.write(test_local_file_content) @@ -146,7 +146,7 @@ def test_file_transfer_with_intermediate_dir_put(self): remote_filepath=self.test_remote_filepath_int_dir, operation=SFTPOperation.PUT, create_intermediate_dirs=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(put_test_task) ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) @@ -158,20 +158,21 @@ def test_file_transfer_with_intermediate_dir_put(self): ssh_hook=self.hook, command="cat {0}".format(self.test_remote_filepath_int_dir), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(check_file_task) ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) ti3.run() self.assertEqual( - ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), - test_local_file_content) + ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), test_local_file_content + ) @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) def test_json_file_transfer_put(self): - test_local_file_content = \ - b"This is local file content \n which is multiline " \ + test_local_file_content = ( + b"This is local file content 
\n which is multiline " b"continuing....with other character\nanother line here \n this is last line" + ) # create a test file locally with open(self.test_local_filepath, 'wb') as file: file.write(test_local_file_content) @@ -183,7 +184,7 @@ def test_json_file_transfer_put(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(put_test_task) ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) @@ -195,29 +196,30 @@ def test_json_file_transfer_put(self): ssh_hook=self.hook, command="cat {0}".format(self.test_remote_filepath), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(check_file_task) ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) ti3.run() self.assertEqual( ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip(), - b64encode(test_local_file_content).decode('utf-8')) + b64encode(test_local_file_content).decode('utf-8'), + ) @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) def test_pickle_file_transfer_get(self): - test_remote_file_content = \ - "This is remote file content \n which is also multiline " \ + test_remote_file_content = ( + "This is remote file content \n which is also multiline " "another line here \n this is last line. EOF" + ) # create a test file remotely create_file_task = SSHOperator( task_id="test_create_file", ssh_hook=self.hook, - command="echo '{0}' > {1}".format(test_remote_file_content, - self.test_remote_filepath), + command="echo '{0}' > {1}".format(test_remote_file_content, self.test_remote_filepath), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(create_file_task) ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) @@ -230,7 +232,7 @@ def test_pickle_file_transfer_get(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.GET, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(get_test_task) ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) @@ -244,18 +246,18 @@ def test_pickle_file_transfer_get(self): @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) def test_json_file_transfer_get(self): - test_remote_file_content = \ - "This is remote file content \n which is also multiline " \ + test_remote_file_content = ( + "This is remote file content \n which is also multiline " "another line here \n this is last line. 
EOF" + ) # create a test file remotely create_file_task = SSHOperator( task_id="test_create_file", ssh_hook=self.hook, - command="echo '{0}' > {1}".format(test_remote_file_content, - self.test_remote_filepath), + command="echo '{0}' > {1}".format(test_remote_file_content, self.test_remote_filepath), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(create_file_task) ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) @@ -268,7 +270,7 @@ def test_json_file_transfer_get(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.GET, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(get_test_task) ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) @@ -278,23 +280,22 @@ def test_json_file_transfer_get(self): content_received = None with open(self.test_local_filepath, 'r') as file: content_received = file.read() - self.assertEqual(content_received.strip(), - test_remote_file_content.encode('utf-8').decode('utf-8')) + self.assertEqual(content_received.strip(), test_remote_file_content.encode('utf-8').decode('utf-8')) @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) def test_file_transfer_no_intermediate_dir_error_get(self): - test_remote_file_content = \ - "This is remote file content \n which is also multiline " \ + test_remote_file_content = ( + "This is remote file content \n which is also multiline " "another line here \n this is last line. EOF" + ) # create a test file remotely create_file_task = SSHOperator( task_id="test_create_file", ssh_hook=self.hook, - command="echo '{0}' > {1}".format(test_remote_file_content, - self.test_remote_filepath), + command="echo '{0}' > {1}".format(test_remote_file_content, self.test_remote_filepath), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(create_file_task) ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) @@ -310,7 +311,7 @@ def test_file_transfer_no_intermediate_dir_error_get(self): local_filepath=self.test_local_filepath_int_dir, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.GET, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(get_test_task) ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) @@ -319,18 +320,18 @@ def test_file_transfer_no_intermediate_dir_error_get(self): @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) def test_file_transfer_with_intermediate_dir_error_get(self): - test_remote_file_content = \ - "This is remote file content \n which is also multiline " \ + test_remote_file_content = ( + "This is remote file content \n which is also multiline " "another line here \n this is last line. 
EOF" + ) # create a test file remotely create_file_task = SSHOperator( task_id="test_create_file", ssh_hook=self.hook, - command="echo '{0}' > {1}".format(test_remote_file_content, - self.test_remote_filepath), + command="echo '{0}' > {1}".format(test_remote_file_content, self.test_remote_filepath), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(create_file_task) ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) @@ -344,7 +345,7 @@ def test_file_transfer_with_intermediate_dir_error_get(self): remote_filepath=self.test_remote_filepath, operation=SFTPOperation.GET, create_intermediate_dirs=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(get_test_task) ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) @@ -359,14 +360,13 @@ def test_file_transfer_with_intermediate_dir_error_get(self): @mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"}) def test_arg_checking(self): # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided - with self.assertRaisesRegex(AirflowException, - "Cannot operate without ssh_hook or ssh_conn_id."): + with self.assertRaisesRegex(AirflowException, "Cannot operate without ssh_hook or ssh_conn_id."): task_0 = SFTPOperator( task_id="test_sftp_0", local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, - dag=self.dag + dag=self.dag, ) task_0.execute(None) @@ -378,7 +378,7 @@ def test_arg_checking(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, - dag=self.dag + dag=self.dag, ) try: task_1.execute(None) @@ -392,7 +392,7 @@ def test_arg_checking(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, - dag=self.dag + dag=self.dag, ) try: task_2.execute(None) @@ -408,7 +408,7 @@ def test_arg_checking(self): local_filepath=self.test_local_filepath, remote_filepath=self.test_remote_filepath, operation=SFTPOperation.PUT, - dag=self.dag + dag=self.dag, ) try: task_3.execute(None) @@ -432,7 +432,7 @@ def delete_remote_resource(self): ssh_hook=self.hook, command="rm {0}".format(self.test_remote_filepath), do_xcom_push=True, - dag=self.dag + dag=self.dag, ) self.assertIsNotNone(remove_file_task) ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow()) diff --git a/tests/providers/sftp/sensors/test_sftp.py b/tests/providers/sftp/sensors/test_sftp.py index f35941aded8b5..6115ab956eec0 100644 --- a/tests/providers/sftp/sensors/test_sftp.py +++ b/tests/providers/sftp/sensors/test_sftp.py @@ -28,49 +28,30 @@ class TestSFTPSensor(unittest.TestCase): @patch('airflow.providers.sftp.sensors.sftp.SFTPHook') def test_file_present(self, sftp_hook_mock): sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000' - sftp_sensor = SFTPSensor( - task_id='unit_test', - path='/path/to/file/1970-01-01.txt') - context = { - 'ds': '1970-01-01' - } + sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt') + context = {'ds': '1970-01-01'} output = sftp_sensor.poke(context) - sftp_hook_mock.return_value.get_mod_time.assert_called_once_with( - '/path/to/file/1970-01-01.txt') + sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt') self.assertTrue(output) @patch('airflow.providers.sftp.sensors.sftp.SFTPHook') def test_file_absent(self, sftp_hook_mock): - 
sftp_hook_mock.return_value.get_mod_time.side_effect = OSError( - SFTP_NO_SUCH_FILE, 'File missing') - sftp_sensor = SFTPSensor( - task_id='unit_test', - path='/path/to/file/1970-01-01.txt') - context = { - 'ds': '1970-01-01' - } + sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(SFTP_NO_SUCH_FILE, 'File missing') + sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt') + context = {'ds': '1970-01-01'} output = sftp_sensor.poke(context) - sftp_hook_mock.return_value.get_mod_time.assert_called_once_with( - '/path/to/file/1970-01-01.txt') + sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt') self.assertFalse(output) @patch('airflow.providers.sftp.sensors.sftp.SFTPHook') def test_sftp_failure(self, sftp_hook_mock): - sftp_hook_mock.return_value.get_mod_time.side_effect = OSError( - SFTP_FAILURE, 'SFTP failure') - sftp_sensor = SFTPSensor( - task_id='unit_test', - path='/path/to/file/1970-01-01.txt') - context = { - 'ds': '1970-01-01' - } + sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(SFTP_FAILURE, 'SFTP failure') + sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt') + context = {'ds': '1970-01-01'} with self.assertRaises(OSError): sftp_sensor.poke(context) - sftp_hook_mock.return_value.get_mod_time.assert_called_once_with( - '/path/to/file/1970-01-01.txt') + sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt') def test_hook_not_created_during_init(self): - sftp_sensor = SFTPSensor( - task_id='unit_test', - path='/path/to/file/1970-01-01.txt') + sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt') self.assertIsNone(sftp_sensor.hook) diff --git a/tests/providers/singularity/operators/test_singularity.py b/tests/providers/singularity/operators/test_singularity.py index cb262ce10abfe..eb143bda1dc65 100644 --- a/tests/providers/singularity/operators/test_singularity.py +++ b/tests/providers/singularity/operators/test_singularity.py @@ -30,30 +30,17 @@ class SingularityOperatorTestCase(unittest.TestCase): @mock.patch('airflow.providers.singularity.operators.singularity.Client') def test_execute(self, client_mock): - instance = mock.Mock(autospec=Instance, **{ - 'start.return_value': 0, - 'stop.return_value': 0, - }) + instance = mock.Mock(autospec=Instance, **{'start.return_value': 0, 'stop.return_value': 0,}) client_mock.instance.return_value = instance - client_mock.execute.return_value = {'return_code': 0, - 'message': 'message'} + client_mock.execute.return_value = {'return_code': 0, 'message': 'message'} - task = SingularityOperator( - task_id='task-id', - image="docker://busybox", - command="echo hello" - ) + task = SingularityOperator(task_id='task-id', image="docker://busybox", command="echo hello") task.execute({}) - client_mock.instance.assert_called_once_with("docker://busybox", - options=[], - args=None, - start=False) + client_mock.instance.assert_called_once_with("docker://busybox", options=[], args=None, start=False) - client_mock.execute.assert_called_once_with(mock.ANY, - "echo hello", - return_result=True) + client_mock.execute.assert_called_once_with(mock.ANY, "echo hello", return_result=True) execute_args, _ = client_mock.execute.call_args self.assertIs(execute_args[0], instance) @@ -61,75 +48,59 @@ def test_execute(self, client_mock): instance.start.assert_called_once_with() instance.stop.assert_called_once_with() - @parameterized.expand([ - ("",), - (None,), - ]) + 
@parameterized.expand( + [("",), (None,),] + ) def test_command_is_required(self, command): - task = SingularityOperator( - task_id='task-id', - image="docker://busybox", - command=command - ) + task = SingularityOperator(task_id='task-id', image="docker://busybox", command=command) with six.assertRaisesRegex(self, AirflowException, "You must define a command."): task.execute({}) @mock.patch('airflow.providers.singularity.operators.singularity.Client') def test_image_should_be_pulled_when_not_exists(self, client_mock): - instance = mock.Mock(autospec=Instance, **{ - 'start.return_value': 0, - 'stop.return_value': 0, - }) + instance = mock.Mock(autospec=Instance, **{'start.return_value': 0, 'stop.return_value': 0,}) client_mock.pull.return_value = '/tmp/busybox_latest.sif' client_mock.instance.return_value = instance - client_mock.execute.return_value = {'return_code': 0, - 'message': 'message'} + client_mock.execute.return_value = {'return_code': 0, 'message': 'message'} task = SingularityOperator( task_id='task-id', image="docker://busybox", command="echo hello", pull_folder="/tmp", - force_pull=True + force_pull=True, ) task.execute({}) client_mock.instance.assert_called_once_with( "/tmp/busybox_latest.sif", options=[], args=None, start=False ) - client_mock.pull.assert_called_once_with( - "docker://busybox", stream=True, pull_folder="/tmp" - ) - client_mock.execute.assert_called_once_with(mock.ANY, - "echo hello", - return_result=True) - - @parameterized.expand([ - (None, [], ), - ([], [], ), - (["AAA"], ['--bind', 'AAA'], ), - (["AAA", "BBB"], ['--bind', 'AAA', '--bind', 'BBB'], ), - (["AAA", "BBB", "CCC"], ['--bind', 'AAA', '--bind', 'BBB', '--bind', 'CCC'], ), - - ]) + client_mock.pull.assert_called_once_with("docker://busybox", stream=True, pull_folder="/tmp") + client_mock.execute.assert_called_once_with(mock.ANY, "echo hello", return_result=True) + + @parameterized.expand( + [ + (None, [],), + ([], [],), + (["AAA"], ['--bind', 'AAA'],), + (["AAA", "BBB"], ['--bind', 'AAA', '--bind', 'BBB'],), + (["AAA", "BBB", "CCC"], ['--bind', 'AAA', '--bind', 'BBB', '--bind', 'CCC'],), + ] + ) @mock.patch('airflow.providers.singularity.operators.singularity.Client') def test_bind_options(self, volumes, expected_options, client_mock): - instance = mock.Mock(autospec=Instance, **{ - 'start.return_value': 0, - 'stop.return_value': 0, - }) + instance = mock.Mock(autospec=Instance, **{'start.return_value': 0, 'stop.return_value': 0,}) client_mock.pull.return_value = 'docker://busybox' client_mock.instance.return_value = instance - client_mock.execute.return_value = {'return_code': 0, - 'message': 'message'} + client_mock.execute.return_value = {'return_code': 0, 'message': 'message'} task = SingularityOperator( task_id='task-id', image="docker://busybox", command="echo hello", force_pull=True, - volumes=volumes + volumes=volumes, ) task.execute({}) @@ -137,28 +108,22 @@ def test_bind_options(self, volumes, expected_options, client_mock): "docker://busybox", options=expected_options, args=None, start=False ) - @parameterized.expand([ - (None, [], ), - ("", ['--workdir', ''], ), - ("/work-dir/", ['--workdir', '/work-dir/'], ), - ]) + @parameterized.expand( + [(None, [],), ("", ['--workdir', ''],), ("/work-dir/", ['--workdir', '/work-dir/'],),] + ) @mock.patch('airflow.providers.singularity.operators.singularity.Client') def test_working_dir(self, working_dir, expected_working_dir, client_mock): - instance = mock.Mock(autospec=Instance, **{ - 'start.return_value': 0, - 'stop.return_value': 0, - }) + 
instance = mock.Mock(autospec=Instance, **{'start.return_value': 0, 'stop.return_value': 0,}) client_mock.pull.return_value = 'docker://busybox' client_mock.instance.return_value = instance - client_mock.execute.return_value = {'return_code': 0, - 'message': 'message'} + client_mock.execute.return_value = {'return_code': 0, 'message': 'message'} task = SingularityOperator( task_id='task-id', image="docker://busybox", command="echo hello", force_pull=True, - working_dir=working_dir + working_dir=working_dir, ) task.execute({}) diff --git a/tests/providers/slack/hooks/test_slack.py b/tests/providers/slack/hooks/test_slack.py index 3c7bade6990d2..0c956ccd4b765 100644 --- a/tests/providers/slack/hooks/test_slack.py +++ b/tests/providers/slack/hooks/test_slack.py @@ -26,7 +26,6 @@ class TestSlackHook(unittest.TestCase): - def test_get_token_with_token_only(self): """tests `__get_token` method when only token is provided """ # Given @@ -126,5 +125,4 @@ def test_api_call(self, mock_slack_client, mock_slack_api_call): test_api_json = {'channel': 'test_channel'} slack_hook.call("chat.postMessage", json=test_api_json) - mock_slack_api_call.assert_called_once_with( - mock_slack_client, "chat.postMessage", json=test_api_json) + mock_slack_api_call.assert_called_once_with(mock_slack_client, "chat.postMessage", json=test_api_json) diff --git a/tests/providers/slack/hooks/test_slack_webhook.py b/tests/providers/slack/hooks/test_slack_webhook.py index 4edfdfab197e9..c13b34a1c7bda 100644 --- a/tests/providers/slack/hooks/test_slack_webhook.py +++ b/tests/providers/slack/hooks/test_slack_webhook.py @@ -40,7 +40,7 @@ class TestSlackWebhookHook(unittest.TestCase): 'icon_emoji': ':hankey:', 'icon_url': 'https://airflow.apache.org/_images/pin_large.png', 'link_names': True, - 'proxy': 'https://my-horrible-proxy.proxyist.com:8080' + 'proxy': 'https://my-horrible-proxy.proxyist.com:8080', } expected_message_dict = { 'channel': _config['channel'], @@ -50,7 +50,7 @@ class TestSlackWebhookHook(unittest.TestCase): 'link_names': 1, 'attachments': _config['attachments'], 'blocks': _config['blocks'], - 'text': _config['message'] + 'text': _config['message'], } expected_message = json.dumps(expected_message_dict) expected_url = 'https://hooks.slack.com/services/T000/B000/XXX' @@ -61,19 +61,20 @@ def setUp(self): Connection( conn_id='slack-webhook-default', conn_type='http', - extra='{"webhook_token": "your_token_here"}') + extra='{"webhook_token": "your_token_here"}', + ) ) db.merge_conn( Connection( conn_id='slack-webhook-url', conn_type='http', - host='https://hooks.slack.com/services/T000/B000/XXX') + host='https://hooks.slack.com/services/T000/B000/XXX', + ) ) db.merge_conn( Connection( - conn_id='slack-webhook-host', - conn_type='http', - host='https://hooks.slack.com/services/T000/') + conn_id='slack-webhook-host', conn_type='http', host='https://hooks.slack.com/services/T000/' + ) ) def test_get_token_manual_token(self): @@ -118,10 +119,7 @@ def test_url_generated_by_http_conn_id(self, mock_request, mock_session): except MissingSchema: pass mock_request.assert_called_once_with( - self.expected_method, - self.expected_url, - headers=mock.ANY, - data=mock.ANY + self.expected_method, self.expected_url, headers=mock.ANY, data=mock.ANY ) mock_request.reset_mock() @@ -134,26 +132,19 @@ def test_url_generated_by_endpoint(self, mock_request, mock_session): except MissingSchema: pass mock_request.assert_called_once_with( - self.expected_method, - self.expected_url, - headers=mock.ANY, - data=mock.ANY + 
self.expected_method, self.expected_url, headers=mock.ANY, data=mock.ANY ) mock_request.reset_mock() @mock.patch('requests.Session') @mock.patch('requests.Request') def test_url_generated_by_http_conn_id_and_endpoint(self, mock_request, mock_session): - hook = SlackWebhookHook(http_conn_id='slack-webhook-host', - webhook_token='B000/XXX') + hook = SlackWebhookHook(http_conn_id='slack-webhook-host', webhook_token='B000/XXX') try: hook.execute() except MissingSchema: pass mock_request.assert_called_once_with( - self.expected_method, - self.expected_url, - headers=mock.ANY, - data=mock.ANY + self.expected_method, self.expected_url, headers=mock.ANY, data=mock.ANY ) mock_request.reset_mock() diff --git a/tests/providers/slack/operators/test_slack.py b/tests/providers/slack/operators/test_slack.py index 354fe09e9f387..efc292b2e406e 100644 --- a/tests/providers/slack/operators/test_slack.py +++ b/tests/providers/slack/operators/test_slack.py @@ -41,18 +41,12 @@ def setUp(self): "title": "Slack API Documentation", "title_link": "https://api.slack.com/", "text": "Optional text that appears within the attachment", - "fields": [ - { - "title": "Priority", - "value": "High", - "short": 'false' - } - ], + "fields": [{"title": "Priority", "value": "High", "short": 'false'}], "image_url": "http://my-website.com/path/to/image.jpg", "thumb_url": "http://example.com/path/to/thumb.png", "footer": "Slack API", "footer_icon": "https://platform.slack-edge.com/img/default_application_icon.png", - "ts": 123456789 + "ts": 123456789, } ] self.test_blocks = [ @@ -60,19 +54,12 @@ def setUp(self): "type": "section", "text": { "text": "A message *with some bold text* and _some italicized text_.", - "type": "mrkdwn" + "type": "mrkdwn", }, "fields": [ - { - "type": "mrkdwn", - "text": "High" - }, - { - "type": "plain_text", - "emoji": True, - "text": "String" - } - ] + {"type": "mrkdwn", "text": "High"}, + {"type": "plain_text", "emoji": True, "text": "String"}, + ], } ] self.test_attachments_in_json = json.dumps(self.test_attachments) @@ -128,9 +115,7 @@ def test_api_call_params_with_default_args(self, mock_hook): test_slack_conn_id = 'test_slack_conn_id' slack_api_post_operator = SlackAPIPostOperator( - task_id='slack', - username=self.test_username, - slack_conn_id=test_slack_conn_id, + task_id='slack', username=self.test_username, slack_conn_id=test_slack_conn_id, ) slack_api_post_operator.execute() @@ -139,10 +124,10 @@ def test_api_call_params_with_default_args(self, mock_hook): 'channel': "#general", 'username': self.test_username, 'text': 'No message has been set.\n' - 'Here is a cat video instead\n' - 'https://www.youtube.com/watch?v=J---aiyznGQ', + 'Here is a cat video instead\n' + 'https://www.youtube.com/watch?v=J---aiyznGQ', 'icon_url': "https://raw.githubusercontent.com/apache/" - "airflow/master/airflow/www/static/pin_100.png", + "airflow/master/airflow/www/static/pin_100.png", 'attachments': '[]', 'blocks': '[]', } @@ -205,10 +190,7 @@ def test_init_with_valid_params(self): def test_api_call_params_with_default_args(self, mock_hook): test_slack_conn_id = 'test_slack_conn_id' - slack_api_post_operator = SlackAPIFileOperator( - task_id='slack', - slack_conn_id=test_slack_conn_id, - ) + slack_api_post_operator = SlackAPIFileOperator(task_id='slack', slack_conn_id=test_slack_conn_id,) slack_api_post_operator.execute() @@ -217,7 +199,6 @@ def test_api_call_params_with_default_args(self, mock_hook): 'initial_comment': 'No message has been set!', 'filename': 'default_name.csv', 'filetype': 'csv', - 'content': 
'default,content,csv,file' - + 'content': 'default,content,csv,file', } self.assertEqual(expected_api_params, slack_api_post_operator.api_params) diff --git a/tests/providers/slack/operators/test_slack_webhook.py b/tests/providers/slack/operators/test_slack_webhook.py index ba859449fe65f..96fd42503e1e0 100644 --- a/tests/providers/slack/operators/test_slack_webhook.py +++ b/tests/providers/slack/operators/test_slack_webhook.py @@ -38,23 +38,16 @@ class TestSlackWebhookOperator(unittest.TestCase): 'icon_emoji': ':hankey', 'icon_url': 'https://airflow.apache.org/_images/pin_large.png', 'link_names': True, - 'proxy': 'https://my-horrible-proxy.proxyist.com:8080' + 'proxy': 'https://my-horrible-proxy.proxyist.com:8080', } def setUp(self): - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - } + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG('test_dag_id', default_args=args) def test_execute(self): # Given / When - operator = SlackWebhookOperator( - task_id='slack_webhook_job', - dag=self.dag, - **self._config - ) + operator = SlackWebhookOperator(task_id='slack_webhook_job', dag=self.dag, **self._config) self.assertEqual(self._config['http_conn_id'], operator.http_conn_id) self.assertEqual(self._config['webhook_token'], operator.webhook_token) @@ -69,13 +62,16 @@ def test_execute(self): self.assertEqual(self._config['proxy'], operator.proxy) def test_assert_templated_fields(self): - operator = SlackWebhookOperator( - task_id='slack_webhook_job', - dag=self.dag, - **self._config - ) + operator = SlackWebhookOperator(task_id='slack_webhook_job', dag=self.dag, **self._config) - template_fields = ['webhook_token', 'message', 'attachments', 'blocks', 'channel', - 'username', 'proxy'] + template_fields = [ + 'webhook_token', + 'message', + 'attachments', + 'blocks', + 'channel', + 'username', + 'proxy', + ] self.assertEqual(operator.template_fields, template_fields) diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py index 0a842f05f1a69..c6c19643961de 100644 --- a/tests/providers/snowflake/hooks/test_snowflake.py +++ b/tests/providers/snowflake/hooks/test_snowflake.py @@ -28,7 +28,6 @@ class TestSnowflakeHook(unittest.TestCase): - def setUp(self): super().setUp() @@ -39,11 +38,13 @@ def setUp(self): self.conn.login = 'user' self.conn.password = 'pw' self.conn.schema = 'public' - self.conn.extra_dejson = {'database': 'db', - 'account': 'airflow', - 'warehouse': 'af_wh', - 'region': 'af_region', - 'role': 'af_role'} + self.conn.extra_dejson = { + 'database': 'db', + 'account': 'airflow', + 'warehouse': 'af_wh', + 'region': 'af_region', + 'role': 'af_role', + } class UnitTestSnowflakeHook(SnowflakeHook): conn_name_attr = 'snowflake_conn_id' @@ -60,27 +61,20 @@ def get_connection(self, _): self.encrypted_private_key = "/tmp/test_key.p8" # Write some temporary private keys. First is not encrypted, second is with a passphrase. 
- key = rsa.generate_private_key( - backend=default_backend(), - public_exponent=65537, - key_size=2048 + key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048) + private_key = key.private_bytes( + serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption() ) - private_key = key.private_bytes(serialization.Encoding.PEM, - serialization.PrivateFormat.PKCS8, - serialization.NoEncryption()) with open(self.non_encrypted_private_key, "wb") as file: file.write(private_key) - key = rsa.generate_private_key( - backend=default_backend(), - public_exponent=65537, - key_size=2048 + key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048) + private_key = key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.BestAvailableEncryption(self.conn.password.encode()), ) - private_key = key.private_bytes(serialization.Encoding.PEM, - serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.BestAvailableEncryption( - self.conn.password.encode())) with open(self.encrypted_private_key, "wb") as file: file.write(private_key) @@ -90,20 +84,23 @@ def tearDown(self): os.remove(self.non_encrypted_private_key) def test_get_uri(self): - uri_shouldbe = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh&role=af_role' \ - '&authenticator=snowflake' + uri_shouldbe = ( + 'snowflake://user:pw@airflow/db/public?warehouse=af_wh&role=af_role&authenticator=snowflake' + ) self.assertEqual(uri_shouldbe, self.db_hook.get_uri()) def test_get_conn_params(self): - conn_params_shouldbe = {'user': 'user', - 'password': 'pw', - 'schema': 'public', - 'database': 'db', - 'account': 'airflow', - 'warehouse': 'af_wh', - 'region': 'af_region', - 'role': 'af_role', - 'authenticator': 'snowflake'} + conn_params_shouldbe = { + 'user': 'user', + 'password': 'pw', + 'schema': 'public', + 'database': 'db', + 'account': 'airflow', + 'warehouse': 'af_wh', + 'region': 'af_region', + 'role': 'af_role', + 'authenticator': 'snowflake', + } self.assertEqual(self.db_hook.snowflake_conn_id, 'snowflake_default') # pylint: disable=no-member self.assertEqual(conn_params_shouldbe, self.db_hook._get_conn_params()) @@ -111,23 +108,27 @@ def test_get_conn(self): self.assertEqual(self.db_hook.get_conn(), self.conn) def test_key_pair_auth_encrypted(self): - self.conn.extra_dejson = {'database': 'db', - 'account': 'airflow', - 'warehouse': 'af_wh', - 'region': 'af_region', - 'role': 'af_role', - 'private_key_file': self.encrypted_private_key} + self.conn.extra_dejson = { + 'database': 'db', + 'account': 'airflow', + 'warehouse': 'af_wh', + 'region': 'af_region', + 'role': 'af_role', + 'private_key_file': self.encrypted_private_key, + } params = self.db_hook._get_conn_params() self.assertTrue('private_key' in params) def test_key_pair_auth_not_encrypted(self): - self.conn.extra_dejson = {'database': 'db', - 'account': 'airflow', - 'warehouse': 'af_wh', - 'region': 'af_region', - 'role': 'af_role', - 'private_key_file': self.non_encrypted_private_key} + self.conn.extra_dejson = { + 'database': 'db', + 'account': 'airflow', + 'warehouse': 'af_wh', + 'region': 'af_region', + 'role': 'af_role', + 'private_key_file': self.non_encrypted_private_key, + } self.conn.password = '' params = self.db_hook._get_conn_params() diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py index d68208baf8149..a706a8578cac1 100644 
--- a/tests/providers/snowflake/operators/test_snowflake.py +++ b/tests/providers/snowflake/operators/test_snowflake.py @@ -33,7 +33,6 @@ class TestSnowflakeOperator(unittest.TestCase): - def setUp(self): super().setUp() args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} @@ -47,9 +46,5 @@ def test_snowflake_operator(self, mock_get_hook): dummy VARCHAR(50) ); """ - operator = SnowflakeOperator( - task_id='basic_snowflake', - sql=sql, - dag=self.dag) - operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_ti_state=True) + operator = SnowflakeOperator(task_id='basic_snowflake', sql=sql, dag=self.dag) + operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/providers/snowflake/operators/test_snowflake_system.py b/tests/providers/snowflake/operators/test_snowflake_system.py index 69ae58750b361..dd1a3620dd8bc 100644 --- a/tests/providers/snowflake/operators/test_snowflake_system.py +++ b/tests/providers/snowflake/operators/test_snowflake_system.py @@ -28,14 +28,12 @@ CREDENTIALS_DIR = os.environ.get('CREDENTIALS_DIR', '/files/airflow-breeze-config/keys') SNOWFLAKE_KEY = 'snowflake.json' SNOWFLAKE_CREDENTIALS_PATH = os.path.join(CREDENTIALS_DIR, SNOWFLAKE_KEY) -SNOWFLAKE_DAG_FOLDER = os.path.join( - AIRFLOW_MAIN_FOLDER, 'airflow', 'providers', 'snowflake', 'example_dags') +SNOWFLAKE_DAG_FOLDER = os.path.join(AIRFLOW_MAIN_FOLDER, 'airflow', 'providers', 'snowflake', 'example_dags') @pytest.mark.credential_file(SNOWFLAKE_KEY) @pytest.mark.system('snowflake') class SnowflakeExampleDagsSystemTest(SystemTest): - def setUp(self): super().setUp() @@ -61,9 +59,14 @@ def setUp(self): 'warehouse': credentials['warehouse'], 'database': credentials['database'], } - conn = Connection(conn_id='snowflake_conn_id', login=credentials['user'], - password=credentials['password'], schema=credentials['schema'], - conn_type='snowflake', extra=json.dumps(extra)) + conn = Connection( + conn_id='snowflake_conn_id', + login=credentials['user'], + password=credentials['password'], + schema=credentials['schema'], + conn_type='snowflake', + extra=json.dumps(extra), + ) db.merge_conn(conn) def test_dag_example(self): diff --git a/tests/providers/snowflake/transfers/test_s3_to_snowflake.py b/tests/providers/snowflake/transfers/test_s3_to_snowflake.py index 21bf0ef888c1b..02e6e5a205bd9 100644 --- a/tests/providers/snowflake/transfers/test_s3_to_snowflake.py +++ b/tests/providers/snowflake/transfers/test_s3_to_snowflake.py @@ -40,7 +40,7 @@ def test_execute(self, mock_run): schema=schema, columns_array=None, task_id="task_id", - dag=None + dag=None, ).execute(None) files = str(s3_keys) @@ -51,17 +51,13 @@ def test_execute(self, mock_run): files={files} file_format={file_format} """.format( - stage=stage, - files=files, - file_format=file_format + stage=stage, files=files, file_format=file_format ) copy_query = """ COPY INTO {schema}.{table} {base_sql} """.format( - schema=schema, - table=table, - base_sql=base_sql + schema=schema, table=table, base_sql=base_sql ) assert mock_run.call_count == 1 @@ -84,7 +80,7 @@ def test_execute_with_columns(self, mock_run): schema=schema, columns_array=columns_array, task_id="task_id", - dag=None + dag=None, ).execute(None) files = str(s3_keys) @@ -95,18 +91,13 @@ def test_execute_with_columns(self, mock_run): files={files} file_format={file_format} """.format( - stage=stage, - files=files, - file_format=file_format + stage=stage, files=files, file_format=file_format ) copy_query = """ COPY INTO {schema}.{table}({columns}) 
{base_sql} """.format( - schema=schema, - table=table, - columns=",".join(columns_array), - base_sql=base_sql + schema=schema, table=table, columns=",".join(columns_array), base_sql=base_sql ) assert mock_run.call_count == 1 diff --git a/tests/providers/snowflake/transfers/test_snowflake_to_slack.py b/tests/providers/snowflake/transfers/test_snowflake_to_slack.py index d4e6b527c3471..dea853047f1c9 100644 --- a/tests/providers/snowflake/transfers/test_snowflake_to_slack.py +++ b/tests/providers/snowflake/transfers/test_snowflake_to_slack.py @@ -50,7 +50,7 @@ def test_hooks_and_rendering(self, mock_slack_hook_class, mock_snowflake_hook_cl 'parameters': ['1', '2', '3'], 'slack_message': 'message: {{ ds }}, {{ xxxx }}', 'slack_token': 'test_token', - 'dag': self.example_dag + 'dag': self.example_dag, } snowflake_to_slack_operator = self._construct_operator(**operator_args) @@ -61,20 +61,22 @@ def test_hooks_and_rendering(self, mock_slack_hook_class, mock_snowflake_hook_cl snowflake_to_slack_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) # Test that the Snowflake hook is instantiated with the right parameters - mock_snowflake_hook_class.assert_called_once_with(database='test_database', - role='test_role', - schema='test_schema', - snowflake_conn_id='snowflake_connection', - warehouse='test_warehouse') + mock_snowflake_hook_class.assert_called_once_with( + database='test_database', + role='test_role', + schema='test_schema', + snowflake_conn_id='snowflake_connection', + warehouse='test_warehouse', + ) # Test that the get_pandas_df method is executed on the Snowflake hook with the prendered sql and # correct params snowflake_hook.get_pandas_df.assert_called_once_with('sql 2017-01-01', parameters=['1', '2', '3']) # Test that the Slack hook is instantiated with the right parameters - mock_slack_hook_class.assert_called_once_with(http_conn_id='slack_connection', - message='message: 2017-01-01, 1234', - webhook_token='test_token') + mock_slack_hook_class.assert_called_once_with( + http_conn_id='slack_connection', message='message: 2017-01-01, 1234', webhook_token='test_token' + ) # Test that the Slack hook's execute method gets run once slack_webhook_hook.execute.assert_called_once() diff --git a/tests/providers/sqlite/hooks/test_sqlite.py b/tests/providers/sqlite/hooks/test_sqlite.py index 833bb8837bd5a..ca7c0ae0d1be0 100644 --- a/tests/providers/sqlite/hooks/test_sqlite.py +++ b/tests/providers/sqlite/hooks/test_sqlite.py @@ -26,7 +26,6 @@ class TestSqliteHookConn(unittest.TestCase): - def setUp(self): self.connection = Connection(host='host') @@ -52,7 +51,6 @@ def test_get_conn_non_default_id(self, mock_connect): class TestSqliteHook(unittest.TestCase): - def setUp(self): self.cur = mock.MagicMock() diff --git a/tests/providers/sqlite/operators/test_sqlite.py b/tests/providers/sqlite/operators/test_sqlite.py index f9ff282b170a7..3cc93c3156ef3 100644 --- a/tests/providers/sqlite/operators/test_sqlite.py +++ b/tests/providers/sqlite/operators/test_sqlite.py @@ -40,6 +40,7 @@ def setUp(self): def tearDown(self): tables_to_drop = ['test_airflow', 'test_airflow2'] from airflow.providers.sqlite.hooks.sqlite import SqliteHook + with SqliteHook().get_conn() as conn: cur = conn.cursor() for table in tables_to_drop: @@ -59,8 +60,7 @@ def test_sqlite_operator_with_multiple_statements(self): "CREATE TABLE IF NOT EXISTS test_airflow (dummy VARCHAR(50))", "INSERT INTO test_airflow VALUES ('X')", ] - op = SqliteOperator( - task_id='sqlite_operator_with_multiple_statements', 
sql=sql, dag=self.dag) + op = SqliteOperator(task_id='sqlite_operator_with_multiple_statements', sql=sql, dag=self.dag) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_sqlite_operator_with_invalid_sql(self): @@ -70,9 +70,9 @@ def test_sqlite_operator_with_invalid_sql(self): ] from sqlite3 import OperationalError + try: - op = SqliteOperator( - task_id='sqlite_operator_with_multiple_statements', sql=sql, dag=self.dag) + op = SqliteOperator(task_id='sqlite_operator_with_multiple_statements', sql=sql, dag=self.dag) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) pytest.fail("An exception should have been thrown") except OperationalError as e: diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py index a436ef6d6db5e..fc7e55973fd1f 100644 --- a/tests/providers/ssh/hooks/test_ssh.py +++ b/tests/providers/ssh/hooks/test_ssh.py @@ -74,8 +74,7 @@ def setUpClass(cls) -> None: conn_id=cls.CONN_SSH_WITH_EXTRA, host='localhost', conn_type='ssh', - extra='{"compress" : true, "no_host_key_check" : "true", ' - '"allow_host_key_change": false}' + extra='{"compress" : true, "no_host_key_check" : "true", ' '"allow_host_key_change": false}', ) ) db.merge_conn( @@ -84,7 +83,7 @@ def setUpClass(cls) -> None: host='localhost', conn_type='ssh', extra='{"compress" : true, "no_host_key_check" : "true", ' - '"allow_host_key_change": false, "look_for_keys": false}' + '"allow_host_key_change": false, "look_for_keys": false}', ) ) db.merge_conn( @@ -92,20 +91,20 @@ def setUpClass(cls) -> None: conn_id=cls.CONN_SSH_WITH_PRIVATE_KEY_EXTRA, host='localhost', conn_type='ssh', - extra=json.dumps({ - "private_key": TEST_PRIVATE_KEY, - }) + extra=json.dumps({"private_key": TEST_PRIVATE_KEY,}), ) ) @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') def test_ssh_connection_with_password(self, ssh_mock): - hook = SSHHook(remote_host='remote_host', - port='port', - username='username', - password='password', - timeout=10, - key_file='fake.file') + hook = SSHHook( + remote_host='remote_host', + port='port', + username='username', + password='password', + timeout=10, + key_file='fake.file', + ) with hook.get_conn(): ssh_mock.return_value.connect.assert_called_once_with( @@ -117,16 +116,14 @@ def test_ssh_connection_with_password(self, ssh_mock): compress=True, port='port', sock=None, - look_for_keys=True + look_for_keys=True, ) @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') def test_ssh_connection_without_password(self, ssh_mock): - hook = SSHHook(remote_host='remote_host', - port='port', - username='username', - timeout=10, - key_file='fake.file') + hook = SSHHook( + remote_host='remote_host', port='port', username='username', timeout=10, key_file='fake.file' + ) with hook.get_conn(): ssh_mock.return_value.connect.assert_called_once_with( @@ -137,47 +134,51 @@ def test_ssh_connection_without_password(self, ssh_mock): compress=True, port='port', sock=None, - look_for_keys=True + look_for_keys=True, ) @mock.patch('airflow.providers.ssh.hooks.ssh.SSHTunnelForwarder') def test_tunnel_with_password(self, ssh_mock): - hook = SSHHook(remote_host='remote_host', - port='port', - username='username', - password='password', - timeout=10, - key_file='fake.file') + hook = SSHHook( + remote_host='remote_host', + port='port', + username='username', + password='password', + timeout=10, + key_file='fake.file', + ) with hook.get_tunnel(1234): - ssh_mock.assert_called_once_with('remote_host', - ssh_port='port', - 
ssh_username='username', - ssh_password='password', - ssh_pkey='fake.file', - ssh_proxy=None, - local_bind_address=('localhost', ), - remote_bind_address=('localhost', 1234), - logger=hook.log) + ssh_mock.assert_called_once_with( + 'remote_host', + ssh_port='port', + ssh_username='username', + ssh_password='password', + ssh_pkey='fake.file', + ssh_proxy=None, + local_bind_address=('localhost',), + remote_bind_address=('localhost', 1234), + logger=hook.log, + ) @mock.patch('airflow.providers.ssh.hooks.ssh.SSHTunnelForwarder') def test_tunnel_without_password(self, ssh_mock): - hook = SSHHook(remote_host='remote_host', - port='port', - username='username', - timeout=10, - key_file='fake.file') + hook = SSHHook( + remote_host='remote_host', port='port', username='username', timeout=10, key_file='fake.file' + ) with hook.get_tunnel(1234): - ssh_mock.assert_called_once_with('remote_host', - ssh_port='port', - ssh_username='username', - ssh_pkey='fake.file', - ssh_proxy=None, - local_bind_address=('localhost', ), - remote_bind_address=('localhost', 1234), - host_pkey_directories=[], - logger=hook.log) + ssh_mock.assert_called_once_with( + 'remote_host', + ssh_port='port', + ssh_username='username', + ssh_pkey='fake.file', + ssh_proxy=None, + local_bind_address=('localhost',), + remote_bind_address=('localhost', 1234), + host_pkey_directories=[], + logger=hook.log, + ) def test_conn_with_extra_parameters(self): ssh_hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA) @@ -201,15 +202,17 @@ def test_tunnel_with_private_key(self, ssh_mock): ) with hook.get_tunnel(1234): - ssh_mock.assert_called_once_with('remote_host', - ssh_port='port', - ssh_username='username', - ssh_pkey=TEST_PKEY, - ssh_proxy=None, - local_bind_address=('localhost',), - remote_bind_address=('localhost', 1234), - host_pkey_directories=[], - logger=hook.log) + ssh_mock.assert_called_once_with( + 'remote_host', + ssh_port='port', + ssh_username='username', + ssh_pkey=TEST_PKEY, + ssh_proxy=None, + local_bind_address=('localhost',), + remote_bind_address=('localhost', 1234), + host_pkey_directories=[], + logger=hook.log, + ) def test_ssh_connection(self): hook = SSHHook(ssh_conn_id='ssh_default') @@ -230,10 +233,7 @@ def test_tunnel(self): import socket import subprocess - subprocess_kwargs = dict( - args=["python", "-c", HELLO_SERVER_CMD], - stdout=subprocess.PIPE, - ) + subprocess_kwargs = dict(args=["python", "-c", HELLO_SERVER_CMD], stdout=subprocess.PIPE,) with subprocess.Popen(**subprocess_kwargs) as server_handle, hook.create_tunnel(2135, 2134): server_output = server_handle.stdout.read(5) self.assertEqual(b"ready", server_output) @@ -264,5 +264,5 @@ def test_ssh_connection_with_private_key_extra(self, ssh_mock): compress=True, port='port', sock=None, - look_for_keys=True + look_for_keys=True, ) diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py index b9e9a0c04551a..d80c6d735e108 100644 --- a/tests/providers/ssh/operators/test_ssh.py +++ b/tests/providers/ssh/operators/test_ssh.py @@ -39,6 +39,7 @@ class TestSSHOperator(unittest.TestCase): def setUp(self): from airflow.providers.ssh.hooks.ssh import SSHHook + hook = SSHHook(ssh_conn_id='ssh_default') hook.no_host_key_check = True args = { @@ -54,11 +55,7 @@ def test_hook_created_correctly(self): timeout = 20 ssh_id = "ssh_default" task = SSHOperator( - task_id="test", - command=COMMAND, - dag=self.dag, - timeout=timeout, - ssh_conn_id="ssh_default" + task_id="test", command=COMMAND, dag=self.dag, timeout=timeout, 
ssh_conn_id="ssh_default" ) self.assertIsNotNone(task) @@ -70,36 +67,27 @@ def test_hook_created_correctly(self): @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) def test_json_command_execution(self): task = SSHOperator( - task_id="test", - ssh_hook=self.hook, - command=COMMAND, - do_xcom_push=True, - dag=self.dag, + task_id="test", ssh_hook=self.hook, command=COMMAND, do_xcom_push=True, dag=self.dag, ) self.assertIsNotNone(task) - ti = TaskInstance( - task=task, execution_date=timezone.utcnow()) + ti = TaskInstance(task=task, execution_date=timezone.utcnow()) ti.run() self.assertIsNotNone(ti.duration) - self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), - b64encode(b'airflow').decode('utf-8')) + self.assertEqual( + ti.xcom_pull(task_ids='test', key='return_value'), b64encode(b'airflow').decode('utf-8') + ) @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) def test_pickle_command_execution(self): task = SSHOperator( - task_id="test", - ssh_hook=self.hook, - command=COMMAND, - do_xcom_push=True, - dag=self.dag, + task_id="test", ssh_hook=self.hook, command=COMMAND, do_xcom_push=True, dag=self.dag, ) self.assertIsNotNone(task) - ti = TaskInstance( - task=task, execution_date=timezone.utcnow()) + ti = TaskInstance(task=task, execution_date=timezone.utcnow()) ti.run() self.assertIsNotNone(ti.duration) self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'airflow') @@ -111,45 +99,35 @@ def test_command_execution_with_env(self): command=COMMAND, do_xcom_push=True, dag=self.dag, - environment={'TEST': 'value'} + environment={'TEST': 'value'}, ) self.assertIsNotNone(task) with conf_vars({('core', 'enable_xcom_pickling'): 'True'}): - ti = TaskInstance( - task=task, execution_date=timezone.utcnow()) + ti = TaskInstance(task=task, execution_date=timezone.utcnow()) ti.run() self.assertIsNotNone(ti.duration) self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'airflow') def test_no_output_command(self): task = SSHOperator( - task_id="test", - ssh_hook=self.hook, - command="sleep 1", - do_xcom_push=True, - dag=self.dag, + task_id="test", ssh_hook=self.hook, command="sleep 1", do_xcom_push=True, dag=self.dag, ) self.assertIsNotNone(task) with conf_vars({('core', 'enable_xcom_pickling'): 'True'}): - ti = TaskInstance( - task=task, execution_date=timezone.utcnow()) + ti = TaskInstance(task=task, execution_date=timezone.utcnow()) ti.run() self.assertIsNotNone(ti.duration) self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'') - @unittest.mock.patch('os.environ', { - 'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost" - }) + @unittest.mock.patch('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"}) def test_arg_checking(self): # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided - with self.assertRaisesRegex(AirflowException, - "Cannot operate without ssh_hook or ssh_conn_id."): - task_0 = SSHOperator(task_id="test", command=COMMAND, - timeout=TIMEOUT, dag=self.dag) + with self.assertRaisesRegex(AirflowException, "Cannot operate without ssh_hook or ssh_conn_id."): + task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT, dag=self.dag) task_0.execute(None) # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook @@ -159,7 +137,7 @@ def test_arg_checking(self): ssh_conn_id=TEST_CONN_ID, command=COMMAND, timeout=TIMEOUT, - dag=self.dag + dag=self.dag, ) try: task_1.execute(None) @@ -172,7 +150,7 @@ def test_arg_checking(self): 
ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided command=COMMAND, timeout=TIMEOUT, - dag=self.dag + dag=self.dag, ) try: task_2.execute(None) @@ -187,7 +165,7 @@ def test_arg_checking(self): ssh_conn_id=TEST_CONN_ID, command=COMMAND, timeout=TIMEOUT, - dag=self.dag + dag=self.dag, ) try: task_3.execute(None) @@ -195,12 +173,14 @@ def test_arg_checking(self): pass self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id) - @parameterized.expand([ - (COMMAND, False, False), - (COMMAND, True, True), - (COMMAND_WITH_SUDO, False, True), - (COMMAND_WITH_SUDO, True, True), - ]) + @parameterized.expand( + [ + (COMMAND, False, False), + (COMMAND, True, True), + (COMMAND_WITH_SUDO, False, True), + (COMMAND_WITH_SUDO, True, True), + ] + ) def test_get_pyt_set_correctly(self, command, get_pty_in, get_pty_out): task = SSHOperator( task_id="test", @@ -208,7 +188,7 @@ def test_get_pyt_set_correctly(self, command, get_pty_in, get_pty_out): command=command, timeout=TIMEOUT, get_pty=get_pty_in, - dag=self.dag + dag=self.dag, ) try: task.execute(None) diff --git a/tests/providers/vertica/hooks/test_vertica.py b/tests/providers/vertica/hooks/test_vertica.py index 0d69e8c4ecd02..1940e69a93601 100644 --- a/tests/providers/vertica/hooks/test_vertica.py +++ b/tests/providers/vertica/hooks/test_vertica.py @@ -26,16 +26,10 @@ class TestVerticaHookConn(unittest.TestCase): - def setUp(self): super().setUp() - self.connection = Connection( - login='login', - password='password', - host='host', - schema='vertica', - ) + self.connection = Connection(login='login', password='password', host='host', schema='vertica',) class UnitTestVerticaHook(VerticaHook): conn_name_attr = 'vertica_conn_id' @@ -47,13 +41,12 @@ class UnitTestVerticaHook(VerticaHook): @patch('airflow.providers.vertica.hooks.vertica.connect') def test_get_conn(self, mock_connect): self.db_hook.get_conn() - mock_connect.assert_called_once_with(host='host', port=5433, - database='vertica', user='login', - password="password") + mock_connect.assert_called_once_with( + host='host', port=5433, database='vertica', user='login', password="password" + ) class TestVerticaHook(unittest.TestCase): - def setUp(self): super().setUp() @@ -73,8 +66,7 @@ def get_conn(self): @patch('airflow.hooks.dbapi_hook.DbApiHook.insert_rows') def test_insert_rows(self, mock_insert_rows): table = "table" - rows = [("hello",), - ("world",)] + rows = [("hello",), ("world",)] target_fields = None commit_every = 10 self.db_hook.insert_rows(table, rows, target_fields, commit_every) diff --git a/tests/providers/vertica/operators/test_vertica.py b/tests/providers/vertica/operators/test_vertica.py index 49dfca99bf15a..0b963f30a5bd2 100644 --- a/tests/providers/vertica/operators/test_vertica.py +++ b/tests/providers/vertica/operators/test_vertica.py @@ -23,13 +23,9 @@ class TestVerticaOperator(unittest.TestCase): - @mock.patch('airflow.providers.vertica.operators.vertica.VerticaHook') def test_execute(self, mock_hook): sql = "select a, b, c" - op = VerticaOperator(task_id='test_task_id', - sql=sql) + op = VerticaOperator(task_id='test_task_id', sql=sql) op.execute(None) - mock_hook.return_value.run.assert_called_once_with( - sql=sql - ) + mock_hook.return_value.run.assert_called_once_with(sql=sql) diff --git a/tests/providers/yandex/hooks/test_yandex.py b/tests/providers/yandex/hooks/test_yandex.py index c9493b4d95bb9..55b0f3b3e9de1 100644 --- a/tests/providers/yandex/hooks/test_yandex.py +++ b/tests/providers/yandex/hooks/test_yandex.py @@ -24,11 +24,9 @@ class 
TestYandexHook(unittest.TestCase):
-
     @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
     @mock.patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
-    def test_client_created_without_exceptions(self, get_credentials_mock,
-                                                get_connection_mock):
+    def test_client_created_without_exceptions(self, get_credentials_mock, get_connection_mock):
         """tests `init` method to validate client creation
         when all parameters are passed
         """
         # Inputs to constructor
@@ -38,12 +36,12 @@ def test_client_created_without_exceptions(self, get_credentials_mock,
         extra_dejson = '{"extras": "extra"}'
         get_connection_mock['extra_dejson'] = "sdsd"
         get_connection_mock.extra_dejson = '{"extras": "extra"}'
-        get_connection_mock.return_value = mock.\
-            Mock(connection_id='yandexcloud_default', extra_dejson=extra_dejson)
+        get_connection_mock.return_value = mock.Mock(
+            connection_id='yandexcloud_default', extra_dejson=extra_dejson
+        )
 
         get_credentials_mock.return_value = {"token": 122323}
-        hook = YandexCloudBaseHook(None,
-                                   default_folder_id, default_public_ssh_key)
+        hook = YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
         self.assertIsNotNone(hook.client)
 
     @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
@@ -58,11 +56,13 @@ def test_get_credentials_raise_exception(self, get_connection_mock):
         extra_dejson = '{"extras": "extra"}'
         get_connection_mock['extra_dejson'] = "sdsd"
         get_connection_mock.extra_dejson = '{"extras": "extra"}'
-        get_connection_mock.return_value = mock.Mock(connection_id='yandexcloud_default',
-                                                     extra_dejson=extra_dejson)
+        get_connection_mock.return_value = mock.Mock(
+            connection_id='yandexcloud_default', extra_dejson=extra_dejson
+        )
 
-        self.assertRaises(AirflowException, YandexCloudBaseHook, None,
-                          default_folder_id, default_public_ssh_key)
+        self.assertRaises(
+            AirflowException, YandexCloudBaseHook, None, default_folder_id, default_public_ssh_key
+        )
 
     @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
     @mock.patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
@@ -74,11 +74,11 @@ def test_get_field(self, get_credentials_mock, get_connection_mock):
         extra_dejson = {"extra__yandexcloud__one": "value_one"}
         get_connection_mock['extra_dejson'] = "sdsd"
         get_connection_mock.extra_dejson = '{"extras": "extra"}'
-        get_connection_mock.return_value = mock.Mock(connection_id='yandexcloud_default',
-                                                     extra_dejson=extra_dejson)
+        get_connection_mock.return_value = mock.Mock(
+            connection_id='yandexcloud_default', extra_dejson=extra_dejson
+        )
 
         get_credentials_mock.return_value = {"token": 122323}
-        hook = YandexCloudBaseHook(None,
-                                   default_folder_id, default_public_ssh_key)
+        hook = YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
 
         self.assertEqual(hook._get_field('one'), 'value_one')
diff --git a/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py b/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py
index c679dadf01e38..d92d7ddf5b807 100644
--- a/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py
+++ b/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py
@@ -64,7 +64,6 @@
 
 @unittest.skipIf(yandexcloud is None, 'Skipping Yandex.Cloud hook test: no yandexcloud module')
 class TestYandexCloudDataprocHook(unittest.TestCase):
-
     def _init_hook(self):
         with patch('airflow.hooks.base_hook.BaseHook.get_connection') as get_connection_mock:
             get_connection_mock.return_value = self.connection
@@ -93,9 +92,7 @@ def test_create_dataproc_cluster_mocked(self, create_operation_mock):
     @patch('yandexcloud.SDK.create_operation_and_get_result')
     def test_delete_dataproc_cluster_mocked(self, create_operation_mock):
         self._init_hook()
-        self.hook.client.delete_cluster(
-            'my_cluster_id'
-        )
+        self.hook.client.delete_cluster('my_cluster_id')
         self.assertTrue(create_operation_mock.called)
 
     @patch('yandexcloud.SDK.create_operation_and_get_result')
@@ -119,14 +116,21 @@ def test_create_mapreduce_job_hook(self, create_operation_mock):
         self.hook.client.create_mapreduce_job(
             archive_uris=None,
             args=[
-                '-mapper', 'mapper.py', '-reducer', 'reducer.py', '-numReduceTasks', '1', '-input',
-                's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2', '-output',
-                's3a://some-out-bucket/dataproc/job/results'
+                '-mapper',
+                'mapper.py',
+                '-reducer',
+                'reducer.py',
+                '-numReduceTasks',
+                '1',
+                '-input',
+                's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2',
+                '-output',
+                's3a://some-out-bucket/dataproc/job/results',
             ],
             cluster_id='my_cluster_id',
             file_uris=[
                 's3a://some-in-bucket/jobs/sources/mapreduce-001/mapper.py',
-                's3a://some-in-bucket/jobs/sources/mapreduce-001/reducer.py'
+                's3a://some-in-bucket/jobs/sources/mapreduce-001/reducer.py',
             ],
             jar_file_uris=None,
             main_class='org.apache.hadoop.streaming.HadoopStreaming',
@@ -135,8 +139,8 @@ def test_create_mapreduce_job_hook(self, create_operation_mock):
             properties={
                 'yarn.app.mapreduce.am.resource.mb': '2048',
                 'yarn.app.mapreduce.am.command-opts': '-Xmx2048m',
-                'mapreduce.job.maps': '6'
-            }
+                'mapreduce.job.maps': '6',
+            },
         )
         self.assertTrue(create_operation_mock.called)
 
@@ -148,7 +152,7 @@ def test_create_spark_job_hook(self, create_operation_mock):
             archive_uris=['s3a://some-in-bucket/jobs/sources/data/country-codes.csv.zip'],
             args=[
                 's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2',
-                's3a://some-out-bucket/dataproc/job/results/${{JOB_ID}}'
+                's3a://some-out-bucket/dataproc/job/results/${{JOB_ID}}',
             ],
             cluster_id='my_cluster_id',
             file_uris=['s3a://some-in-bucket/jobs/sources/data/config.json'],
@@ -156,12 +160,12 @@ def test_create_spark_job_hook(self, create_operation_mock):
                 's3a://some-in-bucket/jobs/sources/java/icu4j-61.1.jar',
                 's3a://some-in-bucket/jobs/sources/java/commons-lang-2.6.jar',
                 's3a://some-in-bucket/jobs/sources/java/opencsv-4.1.jar',
-                's3a://some-in-bucket/jobs/sources/java/json-20190722.jar'
+                's3a://some-in-bucket/jobs/sources/java/json-20190722.jar',
             ],
             main_class='ru.yandex.cloud.dataproc.examples.PopulationSparkJob',
             main_jar_file_uri='s3a://data-proc-public/jobs/sources/java/dataproc-examples-1.0.jar',
             name='Spark job',
-            properties={'spark.submit.deployMode': 'cluster'}
+            properties={'spark.submit.deployMode': 'cluster'},
         )
         self.assertTrue(create_operation_mock.called)
 
@@ -173,18 +177,18 @@ def test_create_pyspark_job_hook(self, create_operation_mock):
             archive_uris=['s3a://some-in-bucket/jobs/sources/data/country-codes.csv.zip'],
             args=[
                 's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2',
-                's3a://some-out-bucket/jobs/results/${{JOB_ID}}'
+                's3a://some-out-bucket/jobs/results/${{JOB_ID}}',
             ],
             cluster_id='my_cluster_id',
             file_uris=['s3a://some-in-bucket/jobs/sources/data/config.json'],
             jar_file_uris=[
                 's3a://some-in-bucket/jobs/sources/java/dataproc-examples-1.0.jar',
                 's3a://some-in-bucket/jobs/sources/java/icu4j-61.1.jar',
-                's3a://some-in-bucket/jobs/sources/java/commons-lang-2.6.jar'
+                's3a://some-in-bucket/jobs/sources/java/commons-lang-2.6.jar',
             ],
             main_python_file_uri='s3a://some-in-bucket/jobs/sources/pyspark-001/main.py',
             name='Pyspark job',
             properties={'spark.submit.deployMode': 'cluster'},
-            python_file_uris=['s3a://some-in-bucket/jobs/sources/pyspark-001/geonames.py']
+            python_file_uris=['s3a://some-in-bucket/jobs/sources/pyspark-001/geonames.py'],
         )
         self.assertTrue(create_operation_mock.called)
diff --git a/tests/providers/yandex/operators/test_yandexcloud_dataproc.py b/tests/providers/yandex/operators/test_yandexcloud_dataproc.py
index b582cca665178..0fb9825b6b285 100644
--- a/tests/providers/yandex/operators/test_yandexcloud_dataproc.py
+++ b/tests/providers/yandex/operators/test_yandexcloud_dataproc.py
@@ -22,8 +22,12 @@
 
 from airflow.models.dag import DAG
 from airflow.providers.yandex.operators.yandexcloud_dataproc import (
-    DataprocCreateClusterOperator, DataprocCreateHiveJobOperator, DataprocCreateMapReduceJobOperator,
-    DataprocCreatePysparkJobOperator, DataprocCreateSparkJobOperator, DataprocDeleteClusterOperator,
+    DataprocCreateClusterOperator,
+    DataprocCreateHiveJobOperator,
+    DataprocCreateMapReduceJobOperator,
+    DataprocCreatePysparkJobOperator,
+    DataprocCreateSparkJobOperator,
+    DataprocDeleteClusterOperator,
 )
 
 # Airflow connection with type "yandexcloud"
@@ -67,7 +71,7 @@ def setUp(self):
                 'start_date': datetime.datetime.today(),
                 'end_date': datetime.datetime.today() + datetime.timedelta(days=1),
             },
-            schedule_interval='@daily'
+            schedule_interval='@daily',
         )
 
     @patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
@@ -110,21 +114,20 @@ def test_create_cluster(self, create_cluster_mock, *_):
                 'e6faKCxH6iDRteo4D8L8BxwzN42uZSB0nfmjkIxFTcEU3mFSXEbWByg78aoddMrAAjatyrhH1pON6P0='
             ],
             subnet_id='my_subnet_id',
-            zone='ru-central1-c'
+            zone='ru-central1-c',
+        )
+        context['task_instance'].xcom_push.assert_has_calls(
+            [
+                call(key='cluster_id', value=create_cluster_mock().response.id),
+                call(key='yandexcloud_connection_id', value=CONNECTION_ID),
+            ]
         )
-        context['task_instance'].xcom_push.assert_has_calls([
-            call(key='cluster_id', value=create_cluster_mock().response.id),
-            call(key='yandexcloud_connection_id', value=CONNECTION_ID),
-        ])
 
     @patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
     @patch('airflow.hooks.base_hook.BaseHook.get_connection')
     @patch('yandexcloud._wrappers.dataproc.Dataproc.delete_cluster')
     def test_delete_cluster_operator(self, delete_cluster_mock, *_):
-        operator = DataprocDeleteClusterOperator(
-            task_id='delete_cluster',
-            connection_id=CONNECTION_ID,
-        )
+        operator = DataprocDeleteClusterOperator(task_id='delete_cluster', connection_id=CONNECTION_ID,)
         context = {'task_instance': MagicMock()}
         context['task_instance'].xcom_pull.return_value = 'my_cluster_id'
         operator.execute(context)
@@ -135,18 +138,14 @@ def test_delete_cluster_operator(self, delete_cluster_mock, *_):
     @patch('airflow.hooks.base_hook.BaseHook.get_connection')
     @patch('yandexcloud._wrappers.dataproc.Dataproc.create_hive_job')
     def test_create_hive_job_operator(self, create_hive_job_mock, *_):
-        operator = DataprocCreateHiveJobOperator(
-            task_id='create_hive_job',
-            query='SELECT 1;',
-        )
+        operator = DataprocCreateHiveJobOperator(task_id='create_hive_job', query='SELECT 1;',)
         context = {'task_instance': MagicMock()}
         context['task_instance'].xcom_pull.return_value = 'my_cluster_id'
         operator.execute(context)
 
-        context['task_instance'].xcom_pull.assert_has_calls([
-            call(key='cluster_id'),
-            call(key='yandexcloud_connection_id'),
-        ])
+        context['task_instance'].xcom_pull.assert_has_calls(
+            [call(key='cluster_id'), call(key='yandexcloud_connection_id'),]
+        )
 
         create_hive_job_mock.assert_called_once_with(
             cluster_id='my_cluster_id',
@@ -167,14 +166,19 @@ def test_create_mapreduce_job_operator(self, create_mapreduce_job_mock, *_):
             main_class='org.apache.hadoop.streaming.HadoopStreaming',
             file_uris=[
                 's3a://some-in-bucket/jobs/sources/mapreduce-001/mapper.py',
-                's3a://some-in-bucket/jobs/sources/mapreduce-001/reducer.py'
+                's3a://some-in-bucket/jobs/sources/mapreduce-001/reducer.py',
             ],
             args=[
-                '-mapper', 'mapper.py',
-                '-reducer', 'reducer.py',
-                '-numReduceTasks', '1',
-                '-input', 's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2',
-                '-output', 's3a://some-out-bucket/dataproc/job/results'
+                '-mapper',
+                'mapper.py',
+                '-reducer',
+                'reducer.py',
+                '-numReduceTasks',
+                '1',
+                '-input',
+                's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2',
+                '-output',
+                's3a://some-out-bucket/dataproc/job/results',
             ],
             properties={
                 'yarn.app.mapreduce.am.resource.mb': '2048',
@@ -186,22 +190,28 @@ def test_create_mapreduce_job_operator(self, create_mapreduce_job_mock, *_):
         context['task_instance'].xcom_pull.return_value = 'my_cluster_id'
         operator.execute(context)
 
-        context['task_instance'].xcom_pull.assert_has_calls([
-            call(key='cluster_id'),
-            call(key='yandexcloud_connection_id'),
-        ])
+        context['task_instance'].xcom_pull.assert_has_calls(
+            [call(key='cluster_id'), call(key='yandexcloud_connection_id'),]
+        )
 
         create_mapreduce_job_mock.assert_called_once_with(
             archive_uris=None,
             args=[
-                '-mapper', 'mapper.py', '-reducer', 'reducer.py', '-numReduceTasks', '1', '-input',
-                's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2', '-output',
-                's3a://some-out-bucket/dataproc/job/results'
+                '-mapper',
+                'mapper.py',
+                '-reducer',
+                'reducer.py',
+                '-numReduceTasks',
+                '1',
+                '-input',
+                's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2',
+                '-output',
+                's3a://some-out-bucket/dataproc/job/results',
             ],
             cluster_id='my_cluster_id',
             file_uris=[
                 's3a://some-in-bucket/jobs/sources/mapreduce-001/mapper.py',
-                's3a://some-in-bucket/jobs/sources/mapreduce-001/reducer.py'
+                's3a://some-in-bucket/jobs/sources/mapreduce-001/reducer.py',
             ],
             jar_file_uris=None,
             main_class='org.apache.hadoop.streaming.HadoopStreaming',
@@ -210,8 +220,8 @@ def test_create_mapreduce_job_operator(self, create_mapreduce_job_mock, *_):
             properties={
                 'yarn.app.mapreduce.am.resource.mb': '2048',
                 'yarn.app.mapreduce.am.command-opts': '-Xmx2048m',
-                'mapreduce.job.maps': '6'
-            }
+                'mapreduce.job.maps': '6',
+            },
         )
 
     @patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
@@ -222,40 +232,33 @@ def test_create_spark_job_operator(self, create_spark_job_mock, *_):
             task_id='create_spark_job',
             main_jar_file_uri='s3a://data-proc-public/jobs/sources/java/dataproc-examples-1.0.jar',
             main_class='ru.yandex.cloud.dataproc.examples.PopulationSparkJob',
-            file_uris=[
-                's3a://some-in-bucket/jobs/sources/data/config.json',
-            ],
-            archive_uris=[
-                's3a://some-in-bucket/jobs/sources/data/country-codes.csv.zip',
-            ],
+            file_uris=['s3a://some-in-bucket/jobs/sources/data/config.json',],
+            archive_uris=['s3a://some-in-bucket/jobs/sources/data/country-codes.csv.zip',],
             jar_file_uris=[
                 's3a://some-in-bucket/jobs/sources/java/icu4j-61.1.jar',
                 's3a://some-in-bucket/jobs/sources/java/commons-lang-2.6.jar',
                 's3a://some-in-bucket/jobs/sources/java/opencsv-4.1.jar',
-                's3a://some-in-bucket/jobs/sources/java/json-20190722.jar'
+                's3a://some-in-bucket/jobs/sources/java/json-20190722.jar',
             ],
             args=[
                 's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2',
                 's3a://some-out-bucket/dataproc/job/results/${{JOB_ID}}',
             ],
-            properties={
-                'spark.submit.deployMode': 'cluster',
-            },
+            properties={'spark.submit.deployMode': 'cluster',},
         )
         context = {'task_instance': MagicMock()}
         context['task_instance'].xcom_pull.return_value = 'my_cluster_id'
         operator.execute(context)
 
-        context['task_instance'].xcom_pull.assert_has_calls([
-            call(key='cluster_id'),
-            call(key='yandexcloud_connection_id'),
-        ])
+        context['task_instance'].xcom_pull.assert_has_calls(
+            [call(key='cluster_id'), call(key='yandexcloud_connection_id'),]
+        )
 
         create_spark_job_mock.assert_called_once_with(
             archive_uris=['s3a://some-in-bucket/jobs/sources/data/country-codes.csv.zip'],
             args=[
                 's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2',
-                's3a://some-out-bucket/dataproc/job/results/${{JOB_ID}}'
+                's3a://some-out-bucket/dataproc/job/results/${{JOB_ID}}',
             ],
             cluster_id='my_cluster_id',
             file_uris=['s3a://some-in-bucket/jobs/sources/data/config.json'],
@@ -263,12 +266,12 @@ def test_create_spark_job_operator(self, create_spark_job_mock, *_):
                 's3a://some-in-bucket/jobs/sources/java/icu4j-61.1.jar',
                 's3a://some-in-bucket/jobs/sources/java/commons-lang-2.6.jar',
                 's3a://some-in-bucket/jobs/sources/java/opencsv-4.1.jar',
-                's3a://some-in-bucket/jobs/sources/java/json-20190722.jar'
+                's3a://some-in-bucket/jobs/sources/java/json-20190722.jar',
             ],
             main_class='ru.yandex.cloud.dataproc.examples.PopulationSparkJob',
             main_jar_file_uri='s3a://data-proc-public/jobs/sources/java/dataproc-examples-1.0.jar',
             name='Spark job',
-            properties={'spark.submit.deployMode': 'cluster'}
+            properties={'spark.submit.deployMode': 'cluster'},
         )
 
     @patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
@@ -278,15 +281,9 @@ def test_create_pyspark_job_operator(self, create_pyspark_job_mock, *_):
         operator = DataprocCreatePysparkJobOperator(
             task_id='create_pyspark_job',
             main_python_file_uri='s3a://some-in-bucket/jobs/sources/pyspark-001/main.py',
-            python_file_uris=[
-                's3a://some-in-bucket/jobs/sources/pyspark-001/geonames.py',
-            ],
-            file_uris=[
-                's3a://some-in-bucket/jobs/sources/data/config.json',
-            ],
-            archive_uris=[
-                's3a://some-in-bucket/jobs/sources/data/country-codes.csv.zip',
-            ],
+            python_file_uris=['s3a://some-in-bucket/jobs/sources/pyspark-001/geonames.py',],
+            file_uris=['s3a://some-in-bucket/jobs/sources/data/config.json',],
+            archive_uris=['s3a://some-in-bucket/jobs/sources/data/country-codes.csv.zip',],
             args=[
                 's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2',
                 's3a://some-out-bucket/jobs/results/${{JOB_ID}}',
@@ -296,34 +293,31 @@ def test_create_pyspark_job_operator(self, create_pyspark_job_mock, *_):
                 's3a://some-in-bucket/jobs/sources/java/icu4j-61.1.jar',
                 's3a://some-in-bucket/jobs/sources/java/commons-lang-2.6.jar',
             ],
-            properties={
-                'spark.submit.deployMode': 'cluster',
-            },
+            properties={'spark.submit.deployMode': 'cluster',},
        )
         context = {'task_instance': MagicMock()}
         context['task_instance'].xcom_pull.return_value = 'my_cluster_id'
         operator.execute(context)
 
-        context['task_instance'].xcom_pull.assert_has_calls([
-            call(key='cluster_id'),
-            call(key='yandexcloud_connection_id'),
-        ])
+        context['task_instance'].xcom_pull.assert_has_calls(
+            [call(key='cluster_id'), call(key='yandexcloud_connection_id'),]
+        )
 
         create_pyspark_job_mock.assert_called_once_with(
             archive_uris=['s3a://some-in-bucket/jobs/sources/data/country-codes.csv.zip'],
             args=[
                 's3a://some-in-bucket/jobs/sources/data/cities500.txt.bz2',
-                's3a://some-out-bucket/jobs/results/${{JOB_ID}}'
+                's3a://some-out-bucket/jobs/results/${{JOB_ID}}',
             ],
             cluster_id='my_cluster_id',
             file_uris=['s3a://some-in-bucket/jobs/sources/data/config.json'],
             jar_file_uris=[
                 's3a://some-in-bucket/jobs/sources/java/dataproc-examples-1.0.jar',
                 's3a://some-in-bucket/jobs/sources/java/icu4j-61.1.jar',
-                's3a://some-in-bucket/jobs/sources/java/commons-lang-2.6.jar'
+                's3a://some-in-bucket/jobs/sources/java/commons-lang-2.6.jar',
             ],
             main_python_file_uri='s3a://some-in-bucket/jobs/sources/pyspark-001/main.py',
             name='Pyspark job',
             properties={'spark.submit.deployMode': 'cluster'},
-            python_file_uris=['s3a://some-in-bucket/jobs/sources/pyspark-001/geonames.py']
+            python_file_uris=['s3a://some-in-bucket/jobs/sources/pyspark-001/geonames.py'],
         )
diff --git a/tests/providers/zendesk/hooks/test_zendesk.py b/tests/providers/zendesk/hooks/test_zendesk.py
index b03455131bea5..e17907c7e3ff3 100644
--- a/tests/providers/zendesk/hooks/test_zendesk.py
+++ b/tests/providers/zendesk/hooks/test_zendesk.py
@@ -26,7 +26,6 @@
 
 
 class TestZendeskHook(unittest.TestCase):
-
     @mock.patch("airflow.providers.zendesk.hooks.zendesk.time")
     def test_sleeps_for_correct_interval(self, mocked_time):
         sleep_time = 10
@@ -36,9 +35,8 @@ def test_sleeps_for_correct_interval(self, mocked_time):
         mock_response = mock.Mock()
         mock_response.headers.get.return_value = sleep_time
         conn_mock.call = mock.Mock(
-            side_effect=RateLimitError(msg="some message",
-                                       code="some code",
-                                       response=mock_response))
+            side_effect=RateLimitError(msg="some message", code="some code", response=mock_response)
+        )
 
         zendesk_hook = ZendeskHook("conn_id")
         zendesk_hook.get_conn = mock.Mock(return_value=conn_mock)
@@ -56,9 +54,7 @@ def test_returns_single_page_if_get_all_pages_false(self, _):
         zendesk_hook.get_conn()
 
         mock_conn = mock.Mock()
-        mock_call = mock.Mock(
-            return_value={'next_page': 'https://some_host/something',
-                          'path': []})
+        mock_call = mock.Mock(return_value={'next_page': 'https://some_host/something', 'path': []})
         mock_conn.call = mock_call
         zendesk_hook.get_conn = mock.Mock(return_value=mock_conn)
         zendesk_hook.call("path", get_all_pages=False)
@@ -73,9 +69,7 @@ def test_returns_multiple_pages_if_get_all_pages_true(self, _):
         zendesk_hook.get_conn()
 
         mock_conn = mock.Mock()
-        mock_call = mock.Mock(
-            return_value={'next_page': 'https://some_host/something',
-                          'path': []})
+        mock_call = mock.Mock(return_value={'next_page': 'https://some_host/something', 'path': []})
         mock_conn.call = mock_call
         zendesk_hook.get_conn = mock.Mock(return_value=mock_conn)
         zendesk_hook.call("path", get_all_pages=True)
@@ -91,8 +85,12 @@ def test_zdesk_is_inited_correctly(self, mock_zendesk):
         zendesk_hook = ZendeskHook("conn_id")
         zendesk_hook.get_connection = mock.Mock(return_value=conn_mock)
         zendesk_hook.get_conn()
-        mock_zendesk.assert_called_once_with(zdesk_url='https://conn_host', zdesk_email='conn_login',
-                                             zdesk_password='conn_pass', zdesk_token=True)
+        mock_zendesk.assert_called_once_with(
+            zdesk_url='https://conn_host',
+            zdesk_email='conn_login',
+            zdesk_password='conn_pass',
+            zdesk_token=True,
+        )
 
     @mock.patch("airflow.providers.zendesk.hooks.zendesk.Zendesk")
     def test_zdesk_sideloading_works_correctly(self, mock_zendesk):
@@ -104,14 +102,16 @@ def test_zdesk_sideloading_works_correctly(self, mock_zendesk):
 
         mock_conn = mock.Mock()
         mock_call = mock.Mock(
-            return_value={'next_page': 'https://some_host/something',
-                          'tickets': [],
-                          'users': [],
-                          'groups': []})
+            return_value={
+                'next_page': 'https://some_host/something',
+                'tickets': [],
+                'users': [],
+                'groups': [],
+            }
+        )
         mock_conn.call = mock_call
         zendesk_hook.get_conn = mock.Mock(return_value=mock_conn)
-        results = zendesk_hook.call(".../tickets.json",
-                                    query={"include": "users,groups"},
-                                    get_all_pages=False,
-                                    side_loading=True)
+        results = zendesk_hook.call(
+            ".../tickets.json", query={"include": "users,groups"}, get_all_pages=False, side_loading=True
+        )
         assert results == {'groups': [], 'users': [], 'tickets': []}
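
The reformatted tests above lean on a small set of unittest.mock idioms: a MagicMock standing in for the Airflow task_instance, call(...) objects, and assert_has_calls(...). The sketch below is illustrative only and not part of the patch; the test class name is hypothetical, and it assumes nothing beyond the Python standard library. It shows the same assertion pattern in isolation.

# Minimal sketch of the MagicMock / call / assert_has_calls pattern used by the
# dataproc operator tests above. Illustrative only; not part of the patch.
import unittest
from unittest.mock import MagicMock, call


class TestXcomAssertPattern(unittest.TestCase):
    def test_xcom_pull_calls_are_recorded_in_order(self):
        # Stand-in for the task_instance mocked in the operator tests; the
        # 'cluster_id' / 'yandexcloud_connection_id' keys mirror those tests.
        task_instance = MagicMock()
        task_instance.xcom_pull.return_value = 'my_cluster_id'

        # Code under test would normally perform these pulls internally.
        cluster_id = task_instance.xcom_pull(key='cluster_id')
        task_instance.xcom_pull(key='yandexcloud_connection_id')

        self.assertEqual(cluster_id, 'my_cluster_id')
        # Verifies that both calls were recorded, in order, with the given kwargs.
        task_instance.xcom_pull.assert_has_calls(
            [call(key='cluster_id'), call(key='yandexcloud_connection_id')]
        )


if __name__ == '__main__':
    unittest.main()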