From 5bfacf81c63668ea63e7cb48f4a708a67d0ac0a2 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Thu, 19 May 2022 23:44:32 -0700 Subject: [PATCH] [Issue#22846] allow option to encode or not encode UUID when uploading from Cassandra to GCS (#23766) --- .../cloud/transfers/cassandra_to_gcs.py | 47 ++++++++++--------- .../cloud/transfers/test_cassandra_to_gcs.py | 22 ++++++--- 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py index 39f937203e74b..248a03d8e1770 100644 --- a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py @@ -79,6 +79,8 @@ class CassandraToGCSOperator(BaseOperator): :param query_timeout: (Optional) The amount of time, in seconds, used to execute the Cassandra query. If not set, the timeout value will be set in Session.execute() by Cassandra driver. If set to None, there is no timeout. + :param encode_uuid: (Optional) Option to encode UUID or not when upload from Cassandra to GCS. + Default is to encode UUID. """ template_fields: Sequence[str] = ( @@ -105,6 +107,7 @@ def __init__( delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, query_timeout: Union[float, None, NotSetType] = NOT_SET, + encode_uuid: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -120,6 +123,7 @@ def __init__( self.gzip = gzip self.impersonation_chain = impersonation_chain self.query_timeout = query_timeout + self.encode_uuid = encode_uuid # Default Cassandra to BigQuery type mapping CQL_TYPE_MAP = { @@ -256,13 +260,11 @@ def _upload_to_gcs(self, file_to_upload): gzip=self.gzip, ) - @classmethod - def generate_data_dict(cls, names: Iterable[str], values: Any) -> Dict[str, Any]: + def generate_data_dict(self, names: Iterable[str], values: Any) -> Dict[str, Any]: """Generates data structure that will be stored as file in GCS.""" - return {n: cls.convert_value(v) for n, v in zip(names, values)} + return {n: self.convert_value(v) for n, v in zip(names, values)} - @classmethod - def convert_value(cls, value: Optional[Any]) -> Optional[Any]: + def convert_value(self, value: Optional[Any]) -> Optional[Any]: """Convert value to BQ type.""" if not value: return value @@ -271,7 +273,10 @@ def convert_value(cls, value: Optional[Any]) -> Optional[Any]: elif isinstance(value, bytes): return b64encode(value).decode('ascii') elif isinstance(value, UUID): - return b64encode(value.bytes).decode('ascii') + if self.encode_uuid: + return b64encode(value.bytes).decode('ascii') + else: + return str(value) elif isinstance(value, (datetime, Date)): return str(value) elif isinstance(value, Decimal): @@ -279,51 +284,47 @@ def convert_value(cls, value: Optional[Any]) -> Optional[Any]: elif isinstance(value, Time): return str(value).split('.')[0] elif isinstance(value, (list, SortedSet)): - return cls.convert_array_types(value) + return self.convert_array_types(value) elif hasattr(value, '_fields'): - return cls.convert_user_type(value) + return self.convert_user_type(value) elif isinstance(value, tuple): - return cls.convert_tuple_type(value) + return self.convert_tuple_type(value) elif isinstance(value, OrderedMapSerializedKey): - return cls.convert_map_type(value) + return self.convert_map_type(value) else: raise AirflowException('Unexpected value: ' + str(value)) - @classmethod - def convert_array_types(cls, value: Union[List[Any], SortedSet]) -> List[Any]: + def convert_array_types(self, value: Union[List[Any], SortedSet]) -> List[Any]: """Maps convert_value over array.""" - return [cls.convert_value(nested_value) for nested_value in value] + return [self.convert_value(nested_value) for nested_value in value] - @classmethod - def convert_user_type(cls, value: Any) -> Dict[str, Any]: + def convert_user_type(self, value: Any) -> Dict[str, Any]: """ Converts a user type to RECORD that contains n fields, where n is the number of attributes. Each element in the user type class will be converted to its corresponding data type in BQ. """ names = value._fields - values = [cls.convert_value(getattr(value, name)) for name in names] - return cls.generate_data_dict(names, values) + values = [self.convert_value(getattr(value, name)) for name in names] + return self.generate_data_dict(names, values) - @classmethod - def convert_tuple_type(cls, values: Tuple[Any]) -> Dict[str, Any]: + def convert_tuple_type(self, values: Tuple[Any]) -> Dict[str, Any]: """ Converts a tuple to RECORD that contains n fields, each will be converted to its corresponding data type in bq and will be named 'field_', where index is determined by the order of the tuple elements defined in cassandra. """ names = ['field_' + str(i) for i in range(len(values))] - return cls.generate_data_dict(names, values) + return self.generate_data_dict(names, values) - @classmethod - def convert_map_type(cls, value: OrderedMapSerializedKey) -> List[Dict[str, Any]]: + def convert_map_type(self, value: OrderedMapSerializedKey) -> List[Dict[str, Any]]: """ Converts a map to a repeated RECORD that contains two fields: 'key' and 'value', each will be converted to its corresponding data type in BQ. """ converted_map = [] for k, v in zip(value.keys(), value.values()): - converted_map.append({'key': cls.convert_value(k), 'value': cls.convert_value(v)}) + converted_map.append({'key': self.convert_value(k), 'value': self.convert_value(v)}) return converted_map @classmethod diff --git a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py index b53bbb4e66cb9..ade3ea982d41a 100644 --- a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py @@ -23,6 +23,11 @@ from airflow.providers.google.cloud.transfers.cassandra_to_gcs import CassandraToGCSOperator TMP_FILE_NAME = "temp-file" +TEST_BUCKET = "test-bucket" +SCHEMA = "schema.json" +FILENAME = "data.json" +CQL = "select * from keyspace1.table1" +TASK_ID = "test-cas-to-gcs" class TestCassandraToGCS(unittest.TestCase): @@ -30,16 +35,16 @@ class TestCassandraToGCS(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.GCSHook.upload") @mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraHook") def test_execute(self, mock_hook, mock_upload, mock_tempfile): - test_bucket = "test-bucket" - schema = "schema.json" - filename = "data.json" + test_bucket = TEST_BUCKET + schema = SCHEMA + filename = FILENAME gzip = True query_timeout = 20 mock_tempfile.return_value.name = TMP_FILE_NAME operator = CassandraToGCSOperator( - task_id="test-cas-to-gcs", - cql="select * from keyspace1.table1", + task_id=TASK_ID, + cql=CQL, bucket=test_bucket, filename=filename, schema_filename=schema, @@ -70,7 +75,10 @@ def test_execute(self, mock_hook, mock_upload, mock_tempfile): mock_upload.assert_has_calls([call_schema, call_data], any_order=True) def test_convert_value(self): - op = CassandraToGCSOperator + op = CassandraToGCSOperator(task_id=TASK_ID, bucket=TEST_BUCKET, cql=CQL, filename=FILENAME) + unencoded_uuid_op = CassandraToGCSOperator( + task_id=TASK_ID, bucket=TEST_BUCKET, cql=CQL, filename=FILENAME, encode_uuid=False + ) assert op.convert_value(None) is None assert op.convert_value(1) == 1 assert op.convert_value(1.0) == 1.0 @@ -95,6 +103,8 @@ def test_convert_value(self): test_uuid = uuid.uuid4() encoded_uuid = b64encode(test_uuid.bytes).decode("ascii") assert op.convert_value(test_uuid) == encoded_uuid + unencoded_uuid = str(test_uuid) + assert unencoded_uuid_op.convert_value(test_uuid) == unencoded_uuid byte_str = b"abc" encoded_b = b64encode(byte_str).decode("ascii")