Skip to content

Commit

Permalink
Accept custom HTTP headers when reading over HTTP(S) (#272)
Browse files Browse the repository at this point in the history
* Accept http headers

* Swapped headers var for the not global one

* Changed headers to be a local var that is used if passed in

* Headers global var copy vs assignment

* flipped if else for headers check

* Added header information to open() in http.py

* Added tests

* Added headers to seekable buffered input as well

* fix test

* add top-level test

* flake8 for http.py

* respond to code review

* update help.txt
  • Loading branch information
ampersand-five authored and mpenkov committed Jul 18, 2019
1 parent e2e1bea commit 0f17449
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 11 deletions.
6 changes: 5 additions & 1 deletion help.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ FUNCTIONS
The username for authenticating over HTTP
password: str, optional
The password for authenticating over HTTP
headers: dict, optional
Any headers to send in the request. If ``None``, the default headers are sent:
``{'Accept-Encoding': 'identity'}``. To use no headers at all,
set this variable to an empty dict, ``{}``.

WebHDFS (for details, see :mod:`smart_open.webhdfs` and :func:`smart_open.webhdfs.open`):

Expand Down Expand Up @@ -266,6 +270,6 @@ VERSION
1.8.4

FILE
/Users/misha/git/smart_open/smart_open/__init__.py
/home/misha/git/smart_open/smart_open/__init__.py


37 changes: 27 additions & 10 deletions smart_open/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"""


def open(uri, mode, kerberos=False, user=None, password=None):
def open(uri, mode, kerberos=False, user=None, password=None, headers=None):
"""Implement streamed reader from a web site.
Supports Kerberos and Basic HTTP authentication.
Expand All @@ -39,20 +39,29 @@ def open(uri, mode, kerberos=False, user=None, password=None):
The username for authenticating over HTTP
password: str, optional
The password for authenticating over HTTP
headers: dict, optional
Any headers to send in the request. If ``None``, the default headers are sent:
``{'Accept-Encoding': 'identity'}``. To use no headers at all,
set this variable to an empty dict, ``{}``.
Note
----
If neither kerberos or (user, password) are set, will connect unauthenticated.
If neither kerberos nor (user, password) are set, will connect
unauthenticated, unless authentication is supplied separately in headers.
"""
if mode == 'rb':
return BufferedInputBase(uri, mode, kerberos=kerberos, user=user, password=password)
return BufferedInputBase(
uri, mode, kerberos=kerberos,
user=user, password=password, headers=headers
)
else:
raise NotImplementedError('http support for mode %r not implemented' % mode)


class BufferedInputBase(io.BufferedIOBase):
def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE, kerberos=False, user=None, password=None):
def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
kerberos=False, user=None, password=None, headers=None):
if kerberos:
import requests_kerberos
auth = requests_kerberos.HTTPKerberosAuth()
Expand All @@ -64,7 +73,12 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE, kerberos=Fals
self.buffer_size = buffer_size
self.mode = mode

self.response = requests.get(url, auth=auth, stream=True, headers=_HEADERS)
if headers is None:
self.headers = _HEADERS.copy()
else:
self.headers = headers

self.response = requests.get(url, auth=auth, stream=True, headers=self.headers)

if not self.response.ok:
self.response.raise_for_status()
Expand Down Expand Up @@ -154,7 +168,7 @@ class SeekableBufferedInputBase(BufferedInputBase):
"""

def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
kerberos=False, user=None, password=None):
kerberos=False, user=None, password=None, headers=None):
"""
If Kerberos is True, will attempt to use the local Kerberos credentials.
Otherwise, will try to use "basic" HTTP authentication via username/password.
Expand All @@ -171,6 +185,11 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
else:
self.auth = None

if headers is None:
self.headers = _HEADERS.copy()
else:
self.headers = headers

self.buffer_size = buffer_size
self.mode = mode
self.response = self._partial_request()
Expand Down Expand Up @@ -253,10 +272,8 @@ def truncate(self, size=None):
raise io.UnsupportedOperation

def _partial_request(self, start_pos=None):
headers = _HEADERS.copy()

if start_pos is not None:
headers.update({"range": s3.make_range_string(start_pos)})
self.headers.update({"range": s3.make_range_string(start_pos)})

response = requests.get(self.url, auth=self.auth, stream=True, headers=headers)
response = requests.get(self.url, auth=self.auth, stream=True, headers=self.headers)
return response
24 changes: 24 additions & 0 deletions smart_open/tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,27 @@ def test_seek_from_end(self):
read_bytes = reader.read(size=10)
self.assertEqual(reader.tell(), len(BYTES))
self.assertEqual(BYTES[-10:], read_bytes)

@responses.activate
def test_headers_are_as_assigned(self):
    """Headers changed on one reader must not leak into a later reader's defaults."""
    responses.add_callback(responses.GET, URL, callback=request_callback)

    # First reader is created without explicit headers, so it starts from
    # the default _HEADERS; then its headers are customised in place.
    modified = smart_open.http.BufferedInputBase(URL)
    modified.headers['Accept-Encoding'] = 'compress, gzip'
    modified.headers['Other-Header'] = 'value'

    # Second reader is created the same way; the global defaults shouldn't
    # have been overwritten by the changes made through the first reader.
    pristine = smart_open.http.BufferedInputBase(URL)

    expected_default = {'Accept-Encoding': 'identity'}
    expected_modified = {'Accept-Encoding': 'compress, gzip', 'Other-Header': 'value'}

    # The fresh reader still carries the default headers ...
    self.assertEqual(pristine.headers, expected_default)
    # ... while the first reader keeps exactly what was assigned to it.
    self.assertEqual(modified.headers, expected_modified)

@responses.activate
def test_headers(self):
    """Check that headers given to the top-level http.open end up on the reader."""
    responses.add_callback(responses.GET, URL, callback=request_callback)

    custom_headers = {'Foo': 'bar'}
    reader = smart_open.http.open(URL, 'rb', headers=custom_headers)

    # The reader returned by open() should expose the header that was passed in.
    self.assertEqual(reader.headers['Foo'], 'bar')

0 comments on commit 0f17449

Please sign in to comment.