From 090772358a8c78362a0b916862c1f849752a93f3 Mon Sep 17 00:00:00 2001 From: richardbadman Date: Mon, 3 Jul 2023 14:14:48 +0100 Subject: [PATCH 1/5] Append region info to S3ToRedshiftOperator if present It's possible to copy from S3 into Redshift across different regions, however, currently you are unable to do so with the S3ToRedshiftOperator. This PR simply makes this possible, by checking the aws connection passed has the region set in the extras part of the connection config. If this is set, it'll use this in line with the syntax defined [here](https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html) --- .../amazon/aws/transfers/s3_to_redshift.py | 9 ++-- .../aws/transfers/test_s3_to_redshift.py | 49 +++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 0d2a059f6e727..6d84bb96a88bb 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -119,13 +119,14 @@ def __init__( if arg in self.redshift_data_api_kwargs.keys(): raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs") - def _build_copy_query(self, copy_destination: str, credentials_block: str, copy_options: str) -> str: + def _build_copy_query(self, copy_destination: str, credentials_block: str, region_info: str, copy_options: str) -> str: column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else "" return f""" COPY {copy_destination} {column_names} FROM 's3://{self.s3_bucket}/{self.s3_key}' credentials '{credentials_block}' + {region_info} {copy_options}; """ @@ -139,7 +140,9 @@ def execute(self, context: Context) -> None: else: redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) conn = S3Hook.get_connection(conn_id=self.aws_conn_id) - + region_info: str | None = "" + if conn.extra_dejson.get("region", False): 
+ region_info = f"region '{conn.extra_dejson['region']}'" if conn.extra_dejson.get("role_arn", False): credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}" else: @@ -151,7 +154,7 @@ def execute(self, context: Context) -> None: destination = f"{self.schema}.{self.table}" copy_destination = f"#{self.table}" if self.method == "UPSERT" else destination - copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options) + copy_statement = self._build_copy_query(copy_destination, credentials_block, region_info, copy_options) sql: str | Iterable[str] diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py index 33d38b94e745b..8ee68808d0948 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py @@ -324,6 +324,55 @@ def test_execute_role_arn(self, mock_run, mock_session, mock_connection, mock_ho assert mock_run.call_count == 1 assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], copy_statement) + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") + @mock.patch("airflow.models.connection.Connection") + @mock.patch("boto3.session.Session") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") + def test_different_region(self, mock_run, mock_session, mock_connection, mock_hook): + access_key = "aws_access_key_id" + secret_key = "aws_secret_access_key" + extra = {"region": "eu-central-1"} + mock_session.return_value = Session(access_key, secret_key) + mock_session.return_value.access_key = access_key + mock_session.return_value.secret_key = secret_key + mock_session.return_value.token = None + + mock_connection.return_value = Connection(extra=extra) + mock_hook.return_value = Connection(extra=extra) + + schema = "schema" + table = "table" + s3_bucket = "bucket" + s3_key = "key" + copy_options = "" + + op = 
S3ToRedshiftOperator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + copy_options=copy_options, + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + dag=None, + ) + op.execute(None) + copy_query = """ + COPY schema.table + FROM 's3://bucket/key' + credentials + 'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key' + region 'eu-central-1' + ; + """ + + assert access_key in copy_query + assert secret_key in copy_query + assert extra["region"] in copy_query + assert mock_run.call_count == 1 + assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], copy_query) + def test_template_fields_overrides(self): assert S3ToRedshiftOperator.template_fields == ( "s3_bucket", From fa3e9a49408dea8bf18464447be51d4048cae1d1 Mon Sep 17 00:00:00 2001 From: richardbadman Date: Mon, 3 Jul 2023 17:14:30 +0100 Subject: [PATCH 2/5] Update tests to make assertion checking valid Following on from discussion in PR, currently the way assertion is done is kind of redundant, as it's asserting static variables == other static variables. Instead, this now gets compared to what gets generated from the `_build_copy_query` function, this has been reflected for all applicable test cases in this file. 
--- .../aws/transfers/test_s3_to_redshift.py | 75 +++++++++++-------- 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py index 8ee68808d0948..e659b5487d8c2 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py @@ -63,17 +63,19 @@ def test_execute(self, mock_run, mock_session, mock_connection, mock_hook): dag=None, ) op.execute(None) - copy_query = """ + expected_copy_query = """ COPY schema.table FROM 's3://bucket/key' credentials 'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key' ; """ + actual_copy_query = mock_run.call_args.args[0] + assert mock_run.call_count == 1 - assert access_key in copy_query - assert secret_key in copy_query - assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], copy_query) + assert access_key in actual_copy_query + assert secret_key in actual_copy_query + assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query) @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") @mock.patch("airflow.models.connection.Connection") @@ -110,17 +112,19 @@ def test_execute_with_column_list(self, mock_run, mock_session, mock_connection, dag=None, ) op.execute(None) - copy_query = """ + expected_copy_query = """ COPY schema.table (column_1, column_2) FROM 's3://bucket/key' credentials 'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key' ; """ + actual_copy_query = mock_run.call_args.args[0] + assert mock_run.call_count == 1 - assert access_key in copy_query - assert secret_key in copy_query - assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], copy_query) + assert access_key in actual_copy_query + assert secret_key in actual_copy_query + assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query) 
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") @mock.patch("airflow.models.connection.Connection") @@ -263,18 +267,20 @@ def test_execute_sts_token(self, mock_run, mock_session, mock_connection, mock_h dag=None, ) op.execute(None) - copy_statement = """ + expected_copy_query = """ COPY schema.table FROM 's3://bucket/key' credentials 'aws_access_key_id=ASIA_aws_access_key_id;aws_secret_access_key=aws_secret_access_key;token=aws_secret_token' ; """ - assert access_key in copy_statement - assert secret_key in copy_statement - assert token in copy_statement + actual_copy_query = mock_run.call_args.args[0] + + assert access_key in actual_copy_query + assert secret_key in actual_copy_query + assert token in actual_copy_query assert mock_run.call_count == 1 - assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], copy_statement) + assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query) @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") @mock.patch("airflow.models.connection.Connection") @@ -312,17 +318,18 @@ def test_execute_role_arn(self, mock_run, mock_session, mock_connection, mock_ho dag=None, ) op.execute(None) - copy_statement = """ + expected_copy_query = """ COPY schema.table FROM 's3://bucket/key' credentials 'aws_iam_role=arn:aws:iam::112233445566:role/myRole' ; """ + actual_copy_query = mock_run.call_args.args[0] - assert extra["role_arn"] in copy_statement + assert extra["role_arn"] in actual_copy_query assert mock_run.call_count == 1 - assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], copy_statement) + assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query) @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") @mock.patch("airflow.models.connection.Connection") @@ -358,7 +365,7 @@ def test_different_region(self, mock_run, mock_session, mock_connection, mock_ho dag=None, ) op.execute(None) - copy_query = """ + 
expected_copy_query = """ COPY schema.table FROM 's3://bucket/key' credentials @@ -366,12 +373,13 @@ def test_different_region(self, mock_run, mock_session, mock_connection, mock_ho region 'eu-central-1' ; """ + actual_copy_query = mock_run.call_args.args[0] - assert access_key in copy_query - assert secret_key in copy_query - assert extra["region"] in copy_query + assert access_key in actual_copy_query + assert secret_key in actual_copy_query + assert extra["region"] in actual_copy_query assert mock_run.call_count == 1 - assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], copy_query) + assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query) def test_template_fields_overrides(self): assert S3ToRedshiftOperator.template_fields == ( @@ -469,16 +477,8 @@ def test_using_redshift_data_api(self, mock_rs, mock_run, mock_session, mock_con ), ) op.execute(None) - copy_query = """ - COPY schema.table - FROM 's3://bucket/key' - credentials - 'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key' - ; - """ + mock_run.assert_not_called() - assert access_key in copy_query - assert secret_key in copy_query mock_rs.execute_statement.assert_called_once() # test with all args besides sql @@ -492,8 +492,21 @@ def test_using_redshift_data_api(self, mock_rs, mock_run, mock_session, mock_con StatementName=statement_name, WithEvent=False, ) + + expected_copy_query = """ + COPY schema.table + FROM 's3://bucket/key' + credentials + 'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key' + ; + """ + actual_copy_query = mock_rs.execute_statement.call_args.kwargs["Sql"] + mock_rs.describe_statement.assert_called_once_with( Id="STATEMENT_ID", ) + + assert access_key in actual_copy_query + assert secret_key in actual_copy_query # test sql arg - assert_equal_ignore_multiple_spaces(mock_rs.execute_statement.call_args.kwargs["Sql"], copy_query) + assert_equal_ignore_multiple_spaces(actual_copy_query, 
expected_copy_query) From 958a048dc4d230944f18d293ef0e873447262d49 Mon Sep 17 00:00:00 2001 From: richardbadman Date: Wed, 5 Jul 2023 11:40:59 +0100 Subject: [PATCH 3/5] Remove redundant comments --- tests/providers/amazon/aws/transfers/test_s3_to_redshift.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py index e659b5487d8c2..f73acf661e697 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py @@ -481,7 +481,6 @@ def test_using_redshift_data_api(self, mock_rs, mock_run, mock_session, mock_con mock_run.assert_not_called() mock_rs.execute_statement.assert_called_once() - # test with all args besides sql _call = deepcopy(mock_rs.execute_statement.call_args.kwargs) _call.pop("Sql") assert _call == dict( @@ -508,5 +507,4 @@ def test_using_redshift_data_api(self, mock_rs, mock_run, mock_session, mock_con assert access_key in actual_copy_query assert secret_key in actual_copy_query - # test sql arg assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query) From 694a3f4d37bea0cf27e6493575c8a9d770424287 Mon Sep 17 00:00:00 2001 From: richardbadman Date: Thu, 6 Jul 2023 13:36:18 +0100 Subject: [PATCH 4/5] Follow static check formatting --- airflow/providers/amazon/aws/transfers/s3_to_redshift.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 6d84bb96a88bb..72f20a44b08c8 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -119,7 +119,9 @@ def __init__( if arg in self.redshift_data_api_kwargs.keys(): raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs") - def _build_copy_query(self, 
copy_destination: str, credentials_block: str, region_info: str, copy_options: str) -> str: + def _build_copy_query( + self, copy_destination: str, credentials_block: str, region_info: str | None, copy_options: str + ) -> str: column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else "" return f""" COPY {copy_destination} {column_names} @@ -154,7 +156,9 @@ def execute(self, context: Context) -> None: destination = f"{self.schema}.{self.table}" copy_destination = f"#{self.table}" if self.method == "UPSERT" else destination - copy_statement = self._build_copy_query(copy_destination, credentials_block, region_info, copy_options) + copy_statement = self._build_copy_query( + copy_destination, credentials_block, region_info, copy_options + ) sql: str | Iterable[str] From d3cb69b483475703cf9582daf5d941e68a9ef965 Mon Sep 17 00:00:00 2001 From: richardbadman Date: Tue, 11 Jul 2023 13:27:03 +0100 Subject: [PATCH 5/5] Remove optional parameter & None reference from method --- airflow/providers/amazon/aws/transfers/s3_to_redshift.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 72f20a44b08c8..b42b2d8cbb7f8 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -120,7 +120,7 @@ def __init__( raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs") def _build_copy_query( - self, copy_destination: str, credentials_block: str, region_info: str | None, copy_options: str + self, copy_destination: str, credentials_block: str, region_info: str, copy_options: str ) -> str: column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else "" return f""" @@ -142,7 +142,7 @@ def execute(self, context: Context) -> None: else: redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) conn = 
S3Hook.get_connection(conn_id=self.aws_conn_id) - region_info: str | None = "" + region_info = "" if conn.extra_dejson.get("region", False): region_info = f"region '{conn.extra_dejson['region']}'" if conn.extra_dejson.get("role_arn", False):