Skip to content

Commit

Permalink
Merge pull request #483 from KeepSafe/forced_close
Browse files Browse the repository at this point in the history
Forced close
  • Loading branch information
asvetlov committed Sep 1, 2015
2 parents 4109b05 + 91f5f05 commit d187627
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 31 deletions.
4 changes: 4 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ CHANGES
* default headers in ClientSession are now case-insensitive

* Make '=' char and 'wss://' schema safe in urls #477

* `ClientResponse.close()` forces connection closing by default from now #479
N.B. Backward incompatible change: was `.close(force=False)
Using `force` parameter for the method is deprecated: use `.release()` instead.
44 changes: 24 additions & 20 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,10 @@ def start(self, connection, read_until_eof=False):
'Can not load response cookies: %s', exc)
return self

def close(self, force=False):
def close(self, force=True):
if not force:
warnings.warn("force parameter should be True", DeprecationWarning,
stacklevel=2)
if self._closed:
return

Expand All @@ -609,29 +612,30 @@ def close(self, force=False):
return

if self._connection is not None:
if self.content and not self.content.at_eof():
force = True
self._connection.close()
self._connection = None
self._cleanup_writer()

if force:
self._connection.close()
else:
@asyncio.coroutine
def release(self):
try:
content = self.content
if content is not None and not content.at_eof():
chunk = yield from content.readany()
while chunk is not EOF_MARKER or chunk:
chunk = yield from content.readany()
finally:
if self._connection is not None:
self._connection.release()
if self._reader is not None:
self._reader.unset_parser()
self._connection = None
self._cleanup_writer()

self._connection = None
def _cleanup_writer(self):
if self._writer is not None and not self._writer.done():
self._writer.cancel()
self._writer = None

@asyncio.coroutine
def release(self):
try:
chunk = yield from self.content.readany()
while chunk is not EOF_MARKER or chunk:
chunk = yield from self.content.readany()
finally:
self.close()
self._writer = None

@asyncio.coroutine
def wait_for_close(self):
Expand All @@ -640,7 +644,7 @@ def wait_for_close(self):
yield from self._writer
finally:
self._writer = None
self.close()
yield from self.release()

@asyncio.coroutine
def read(self, decode=False):
Expand All @@ -649,10 +653,10 @@ def read(self, decode=False):
try:
self._content = yield from self.content.read()
except:
self.close(True)
self.close()
raise
else:
self.close()
yield from self.release()

data = self._content

Expand Down
27 changes: 16 additions & 11 deletions tests/test_client_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def test_close(self):
self.response._connection = self.connection
self.response.close()
self.assertIsNone(self.response.connection)
self.assertTrue(self.connection.release.called)
self.response.close()
self.response.close()

Expand Down Expand Up @@ -76,7 +75,7 @@ def side_effect(*args, **kwargs):

res = self.loop.run_until_complete(self.response.read())
self.assertEqual(res, b'payload')
self.assertTrue(self.response.close.called)
self.assertIsNone(self.response._connection)

def test_read_and_release_connection_with_error(self):
content = self.response.content = unittest.mock.Mock()
Expand All @@ -87,7 +86,7 @@ def test_read_and_release_connection_with_error(self):
self.assertRaises(
ValueError,
self.loop.run_until_complete, self.response.read())
self.response.close.assert_called_with(True)
self.response.close.assert_called_with()

def test_release(self):
fut = asyncio.Future(loop=self.loop)
Expand All @@ -97,7 +96,7 @@ def test_release(self):
self.response.close = unittest.mock.Mock()

self.loop.run_until_complete(self.response.release())
self.assertTrue(self.response.close.called)
self.assertIsNone(self.response._connection)

def test_read_and_close(self):
self.response.read = unittest.mock.Mock()
Expand Down Expand Up @@ -133,7 +132,7 @@ def side_effect(*args, **kwargs):

res = self.loop.run_until_complete(self.response.text())
self.assertEqual(res, '{"тест": "пройден"}')
self.assertTrue(self.response.close.called)
self.assertIsNone(self.response._connection)

def test_text_custom_encoding(self):
def side_effect(*args, **kwargs):
Expand All @@ -150,7 +149,7 @@ def side_effect(*args, **kwargs):
res = self.loop.run_until_complete(
self.response.text(encoding='cp1251'))
self.assertEqual(res, '{"тест": "пройден"}')
self.assertTrue(self.response.close.called)
self.assertIsNone(self.response._connection)
self.assertFalse(self.response._get_encoding.called)

def test_text_detect_encoding(self):
Expand All @@ -166,7 +165,7 @@ def side_effect(*args, **kwargs):
self.loop.run_until_complete(self.response.read())
res = self.loop.run_until_complete(self.response.text())
self.assertEqual(res, '{"тест": "пройден"}')
self.assertTrue(self.response.close.called)
self.assertIsNone(self.response._connection)

def test_text_after_read(self):
def side_effect(*args, **kwargs):
Expand All @@ -181,7 +180,7 @@ def side_effect(*args, **kwargs):

res = self.loop.run_until_complete(self.response.text())
self.assertEqual(res, '{"тест": "пройден"}')
self.assertTrue(self.response.close.called)
self.assertIsNone(self.response._connection)

def test_json(self):
def side_effect(*args, **kwargs):
Expand All @@ -196,7 +195,7 @@ def side_effect(*args, **kwargs):

res = self.loop.run_until_complete(self.response.json())
self.assertEqual(res, {'тест': 'пройден'})
self.assertTrue(self.response.close.called)
self.assertIsNone(self.response._connection)

def test_json_custom_loader(self):
self.response.headers = {
Expand Down Expand Up @@ -237,7 +236,7 @@ def side_effect(*args, **kwargs):
res = self.loop.run_until_complete(
self.response.json(encoding='cp1251'))
self.assertEqual(res, {'тест': 'пройден'})
self.assertTrue(self.response.close.called)
self.assertIsNone(self.response._connection)
self.assertFalse(self.response._get_encoding.called)

def test_json_detect_encoding(self):
Expand All @@ -252,7 +251,7 @@ def side_effect(*args, **kwargs):

res = self.loop.run_until_complete(self.response.json())
self.assertEqual(res, {'тест': 'пройден'})
self.assertTrue(self.response.close.called)
self.assertIsNone(self.response._connection)

def test_override_flow_control(self):
class MyResponse(ClientResponse):
Expand All @@ -269,3 +268,9 @@ def test_get_encoding_unknown(self, m_chardet):

self.response.headers = {'CONTENT-TYPE': 'application/json'}
self.assertEqual(self.response._get_encoding(), 'utf-8')

def test_close_deprecated(self):
self.response._connection = self.connection
with self.assertWarns(DeprecationWarning):
self.response.close(force=False)
self.assertIsNone(self.response._connection)

0 comments on commit d187627

Please sign in to comment.