diff --git a/airflow/utils/compression.py b/airflow/utils/compression.py deleted file mode 100644 index 8f4946346d636..0000000000000 --- a/airflow/utils/compression.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import bz2 -import gzip -import shutil -from tempfile import NamedTemporaryFile - - -def uncompress_file(input_file_name, file_extension, dest_dir): - """Uncompress gz and bz2 files.""" - if file_extension.lower() not in (".gz", ".bz2"): - raise NotImplementedError( - f"Received {file_extension} format. Only gz and bz2 files can currently be uncompressed." - ) - if file_extension.lower() == ".gz": - fmodule = gzip.GzipFile - elif file_extension.lower() == ".bz2": - fmodule = bz2.BZ2File - with fmodule(input_file_name, mode="rb") as f_compressed, NamedTemporaryFile( - dir=dest_dir, mode="wb", delete=False - ) as f_uncompressed: - shutil.copyfileobj(f_compressed, f_uncompressed) - return f_uncompressed.name diff --git a/providers/src/airflow/providers/apache/hive/transfers/s3_to_hive.py b/providers/src/airflow/providers/apache/hive/transfers/s3_to_hive.py index 6285103d370bd..e56a244f71a21 100644 --- a/providers/src/airflow/providers/apache/hive/transfers/s3_to_hive.py +++ b/providers/src/airflow/providers/apache/hive/transfers/s3_to_hive.py @@ -22,6 +22,7 @@ import bz2 import gzip import os +import shutil import tempfile from collections.abc import Sequence from tempfile import NamedTemporaryFile, TemporaryDirectory @@ -31,7 +32,6 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.apache.hive.hooks.hive import HiveCliHook -from airflow.utils.compression import uncompress_file if TYPE_CHECKING: from airflow.utils.context import Context @@ -277,3 +277,20 @@ def _delete_top_row_and_compress(input_file_name, output_file_ext, dest_dir): for line in f_in: f_out.write(line) return fn_output + + +def uncompress_file(input_file_name, file_extension, dest_dir): + """Uncompress gz and bz2 files.""" + if file_extension.lower() not in (".gz", ".bz2"): + raise NotImplementedError( + f"Received {file_extension} format. Only gz and bz2 files can currently be uncompressed." + ) + if file_extension.lower() == ".gz": + fmodule = gzip.GzipFile + elif file_extension.lower() == ".bz2": + fmodule = bz2.BZ2File + with fmodule(input_file_name, mode="rb") as f_compressed, NamedTemporaryFile( + dir=dest_dir, mode="wb", delete=False + ) as f_uncompressed: + shutil.copyfileobj(f_compressed, f_uncompressed) + return f_uncompressed.name diff --git a/providers/tests/apache/hive/transfers/test_s3_to_hive.py b/providers/tests/apache/hive/transfers/test_s3_to_hive.py index 5f738f4b54dc4..eea3dd39a905a 100644 --- a/providers/tests/apache/hive/transfers/test_s3_to_hive.py +++ b/providers/tests/apache/hive/transfers/test_s3_to_hive.py @@ -30,7 +30,7 @@ import pytest from airflow.exceptions import AirflowException -from airflow.providers.apache.hive.transfers.s3_to_hive import S3ToHiveOperator +from airflow.providers.apache.hive.transfers.s3_to_hive import S3ToHiveOperator, uncompress_file boto3 = pytest.importorskip("boto3") moto = pytest.importorskip("moto") @@ -122,7 +122,7 @@ def _set_fn(self, fn, ext, header): # Helper method to fetch a file of a # certain format (file extension and header) - def _get_fn(self, ext, header): + def _get_fn(self, ext, header=None): key = self._get_key(ext, header) return self.file_names[key] @@ -279,3 +279,19 @@ def test_execute_with_select_expression(self, mock_hiveclihook): expression=select_expression, input_serialization=input_serialization, ) + + def test_uncompress_file(self): + # Testing txt file type + with pytest.raises(NotImplementedError, match="^Received .txt format. Only gz and bz2.*"): + uncompress_file( + **{"input_file_name": None, "file_extension": ".txt", "dest_dir": None}, + ) + # Testing gz file type + fn_txt = self._get_fn(".txt") + fn_gz = self._get_fn(".gz") + txt_gz = uncompress_file(fn_gz, ".gz", self.tmp_dir) + assert filecmp.cmp(txt_gz, fn_txt, shallow=False), "Uncompressed file does not match original" + # Testing bz2 file type + fn_bz2 = self._get_fn(".bz2") + txt_bz2 = uncompress_file(fn_bz2, ".bz2", self.tmp_dir) + assert filecmp.cmp(txt_bz2, fn_txt, shallow=False), "Uncompressed file does not match original" diff --git a/tests/utils/test_compression.py b/tests/utils/test_compression.py deleted file mode 100644 index 66d4806fd1142..0000000000000 --- a/tests/utils/test_compression.py +++ /dev/null @@ -1,87 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import bz2 -import errno -import filecmp -import gzip -import shutil -import tempfile - -import pytest - -from airflow.utils import compression - - -class TestCompression: - @pytest.fixture(autouse=True) - def setup_attrs(self): - self.file_names = {} - header = b"Sno\tSome,Text \n" - line1 = b"1\tAirflow Test\n" - line2 = b"2\tCompressionUtil\n" - self.tmp_dir = tempfile.mkdtemp(prefix="test_utils_compression_") - # create sample txt, gz and bz2 files - with tempfile.NamedTemporaryFile(mode="wb+", dir=self.tmp_dir, delete=False) as f_txt: - self._set_fn(f_txt.name, ".txt") - f_txt.writelines([header, line1, line2]) - - fn_gz = self._get_fn(".txt") + ".gz" - with gzip.GzipFile(filename=fn_gz, mode="wb") as f_gz: - self._set_fn(fn_gz, ".gz") - f_gz.writelines([header, line1, line2]) - - fn_bz2 = self._get_fn(".txt") + ".bz2" - with bz2.BZ2File(filename=fn_bz2, mode="wb") as f_bz2: - self._set_fn(fn_bz2, ".bz2") - f_bz2.writelines([header, line1, line2]) - - yield - - try: - shutil.rmtree(self.tmp_dir) - except OSError as e: - # ENOENT - no such file or directory - if e.errno != errno.ENOENT: - raise e - - # Helper method to create a dictionary of file names and - # file extension - def _set_fn(self, fn, ext): - self.file_names[ext] = fn - - # Helper method to fetch a file of a - # certain extension - def _get_fn(self, ext): - return self.file_names[ext] - - def test_uncompress_file(self): - # Testing txt file type - with pytest.raises(NotImplementedError, match="^Received .txt format. Only gz and bz2.*"): - compression.uncompress_file( - **{"input_file_name": None, "file_extension": ".txt", "dest_dir": None}, - ) - # Testing gz file type - fn_txt = self._get_fn(".txt") - fn_gz = self._get_fn(".gz") - txt_gz = compression.uncompress_file(fn_gz, ".gz", self.tmp_dir) - assert filecmp.cmp(txt_gz, fn_txt, shallow=False), "Uncompressed file does not match original" - # Testing bz2 file type - fn_bz2 = self._get_fn(".bz2") - txt_bz2 = compression.uncompress_file(fn_bz2, ".bz2", self.tmp_dir) - assert filecmp.cmp(txt_bz2, fn_txt, shallow=False), "Uncompressed file does not match original"