diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py index 23d6f63b..ec71f87d 100644 --- a/smart_open/smart_open_lib.py +++ b/smart_open/smart_open_lib.py @@ -28,6 +28,7 @@ import sys import requests import io +import warnings from boto.compat import BytesIO, urlsplit, six import boto.s3.connection @@ -76,6 +77,12 @@ SYSTEM_ENCODING = sys.getdefaultencoding() +_ISSUE_146_FSTR = ( + "You have explicitly specified encoding=%(encoding)s, but smart_open does " + "not currently support decoding text via the %(scheme)s scheme. " + "Re-open the file without specifying an encoding to suppress this warning." +) + def smart_open(uri, mode="rb", **kw): """ @@ -140,11 +147,21 @@ def smart_open(uri, mode="rb", **kw): """ logger.debug('%r', locals()) + # + # This is a work-around for the problem described in Issue #144. + # If the user has explicitly specified an encoding, then assume they want + # us to open the destination in text mode, instead of the default binary. + # + # If we change the default mode to be text, and match the normal behavior + # of Py2 and 3, then the above assumption will be unnecessary. 
+ # + if kw.get('encoding') is not None and 'b' in mode: + mode = mode.replace('b', '') + # validate mode parameter if not isinstance(mode, six.string_types): raise TypeError('mode should be a string') - if isinstance(uri, six.string_types): # this method just routes the request to classes handling the specific storage # schemes, depending on the URI protocol in `uri` @@ -157,6 +174,9 @@ def smart_open(uri, mode="rb", **kw): elif parsed_uri.scheme in ("s3", "s3n", 's3u'): return s3_open_uri(parsed_uri, mode, **kw) elif parsed_uri.scheme in ("hdfs", ): + encoding = kw.pop('encoding', None) + if encoding is not None: + warnings.warn(_ISSUE_146_FSTR % {'encoding': encoding, 'scheme': parsed_uri.scheme}) if mode in ('r', 'rb'): return HdfsOpenRead(parsed_uri, **kw) if mode in ('w', 'wb'): @@ -164,6 +184,9 @@ def smart_open(uri, mode="rb", **kw): else: raise NotImplementedError("file mode %s not supported for %r scheme", mode, parsed_uri.scheme) elif parsed_uri.scheme in ("webhdfs", ): + encoding = kw.pop('encoding', None) + if encoding is not None: + warnings.warn(_ISSUE_146_FSTR % {'encoding': encoding, 'scheme': parsed_uri.scheme}) if mode in ('r', 'rb'): return WebHdfsOpenRead(parsed_uri, **kw) elif mode in ('w', 'wb'): @@ -171,6 +194,9 @@ def smart_open(uri, mode="rb", **kw): else: raise NotImplementedError("file mode %s not supported for %r scheme", mode, parsed_uri.scheme) elif parsed_uri.scheme.startswith('http'): + encoding = kw.pop('encoding', None) + if encoding is not None: + warnings.warn(_ISSUE_146_FSTR % {'encoding': encoding, 'scheme': parsed_uri.scheme}) if mode in ('r', 'rb'): return HttpOpenRead(parsed_uri, **kw) else: diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index eb9bdf1c..8f866cb0 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -333,6 +333,17 @@ def test_hdfs(self, mock_subprocess): smart_open_object.__iter__() 
mock_subprocess.Popen.assert_called_with(["hdfs", "dfs", "-text", "/tmp/test.txt"], stdout=mock_subprocess.PIPE) + @mock.patch('smart_open.smart_open_lib.subprocess') + def test_hdfs_encoding(self, mock_subprocess): + """Is a warning emitted when an encoding is specified for the hdfs scheme?""" + mock_subprocess.PIPE.return_value = "test" + with mock.patch('warnings.warn') as warn: + smart_open.smart_open("hdfs:///tmp/test.txt", encoding='utf-8') + expected = smart_open.smart_open_lib._ISSUE_146_FSTR % { + 'encoding': 'utf-8', 'scheme': 'hdfs' + } + warn.assert_called_with(expected) + @responses.activate def test_webhdfs(self): """Is webhdfs line iterator called correctly""" @@ -342,6 +353,18 @@ def test_webhdfs(self): self.assertEqual(next(iterator).decode("utf-8"), "line1") self.assertEqual(next(iterator).decode("utf-8"), "line2") + @mock.patch('smart_open.smart_open_lib.subprocess') + def test_webhdfs_encoding(self, mock_subprocess): + """Is a warning emitted when an encoding is specified for the webhdfs scheme?""" + url = "webhdfs://127.0.0.1:8440/path/file" + mock_subprocess.PIPE.return_value = "test" + with mock.patch('warnings.warn') as warn: + smart_open.smart_open(url, encoding='utf-8') + expected = smart_open.smart_open_lib._ISSUE_146_FSTR % { + 'encoding': 'utf-8', 'scheme': 'webhdfs' + } + warn.assert_called_with(expected) + @responses.activate def test_webhdfs_read(self): """Does webhdfs read method work correctly""" @@ -1015,6 +1038,46 @@ def test_gzip_read_mode(self): smart_open.s3_open_uri(uri, "r") mock_open.assert_called_with('bucket', 'key.gz', 'rb') + @mock_s3 + def test_read_encoding(self): + """Should open the file with the correct encoding, explicit text read.""" + conn = boto.connect_s3() + conn.create_bucket('test-bucket') + key = "s3://bucket/key.txt" + text = u'это знала ева, это знал адам, колеса любви едут прямо по нам' + with smart_open.smart_open(key, 'wb') as fout: + fout.write(text.encode('koi8-r')) + with smart_open.smart_open(key, 'r', encoding='koi8-r') as fin: + actual = fin.read() +
self.assertEqual(text, actual) + + @mock_s3 + def test_read_encoding_implicit_text(self): + """Should open the file with the correct encoding, implicit text read.""" + conn = boto.connect_s3() + conn.create_bucket('test-bucket') + key = "s3://bucket/key.txt" + text = u'это знала ева, это знал адам, колеса любви едут прямо по нам' + with smart_open.smart_open(key, 'wb') as fout: + fout.write(text.encode('koi8-r')) + with smart_open.smart_open(key, encoding='koi8-r') as fin: + actual = fin.read() + self.assertEqual(text, actual) + + @mock_s3 + def test_write_encoding(self): + """Should open the file for writing with the correct encoding.""" + conn = boto.connect_s3() + conn.create_bucket('test-bucket') + key = "s3://bucket/key.txt" + text = u'какая боль, какая боль, аргентина - ямайка, 5-0' + + with smart_open.smart_open(key, 'w', encoding='koi8-r') as fout: + fout.write(text) + with smart_open.smart_open(key, encoding='koi8-r') as fin: + actual = fin.read() + self.assertEqual(text, actual) + if __name__ == '__main__': logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)