Skip to content

Commit

Permalink
Merge pull request #1808 from andrewwhitehead/fix/put-redirect
Browse files Browse the repository at this point in the history
Fix put_file when the server returns a redirect
  • Loading branch information
swcurran authored Jun 15, 2022
2 parents 592dfd0 + d5ac9dd commit c308329
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 26 deletions.
72 changes: 58 additions & 14 deletions aries_cloudagent/utils/http.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
"""HTTP utility methods."""

import asyncio

from aiohttp import BaseConnector, ClientError, ClientResponse, ClientSession
import logging
import urllib.parse

from aiohttp import (
BaseConnector,
ClientError,
ClientResponse,
ClientSession,
FormData,
)
from aiohttp.web import HTTPConflict

from ..core.error import BaseError

from .repeat import RepeatSequence


LOGGER = logging.getLogger(__name__)


class FetchError(BaseError):
"""Error raised when an HTTP fetch fails."""

Expand Down Expand Up @@ -147,7 +158,6 @@ async def put_file(
"""
(data_key, file_path) = [k for k in file_data.items()][0]
data = {**extra_data}
limit = max_attempts if retry else 1

if not session:
Expand All @@ -158,17 +168,51 @@ async def put_file(
async for attempt in RepeatSequence(limit, interval, backoff):
try:
async with attempt.timeout(request_timeout):
with open(file_path, "rb") as f:
data[data_key] = f
response: ClientResponse = await session.put(url, data=data)
if (response.status < 200 or response.status >= 300) and (
response.status != HTTPConflict.status_code
):
raise ClientError(
f"Bad response from server: {response.status}, "
f"{response.reason}"
)
formdata = FormData()
try:
fp = open(file_path, "rb")
except OSError as e:
raise PutError("Error opening file for upload") from e
if extra_data:
for k, v in extra_data.items():
formdata.add_field(k, v)
formdata.add_field(
data_key, fp, content_type="application/octet-stream"
)
response: ClientResponse = await session.put(
url, data=formdata, allow_redirects=False
)
if (
# redirect codes
response.status in (301, 302, 303, 307, 308)
and not attempt.final
):
# NOTE: a redirect counts as another upload attempt
to_url = response.headers.get("Location")
if not to_url:
raise PutError("Redirect missing target URL")
try:
parsed_to = urllib.parse.urlsplit(to_url)
parsed_from = urllib.parse.urlsplit(url)
except ValueError:
raise PutError("Invalid redirect URL")
if parsed_to.hostname != parsed_from.hostname:
raise PutError("Redirect denied: hostname mismatch")
url = to_url
LOGGER.info("Upload redirect: %s", to_url)
elif (response.status < 200 or response.status >= 300) and (
response.status != HTTPConflict.status_code
):
raise ClientError(
f"Bad response from server: {response.status}, "
f"{response.reason}"
)
else:
return await (response.json() if json else response.text())
except (ClientError, asyncio.TimeoutError) as e:
if isinstance(e, ClientError):
LOGGER.warning("Upload error: %s", e)
else:
LOGGER.warning("Upload error: request timed out")
if attempt.final:
raise PutError("Exceeded maximum put attempts") from e
raise PutError("Exceeded maximum upload attempts") from e
69 changes: 57 additions & 12 deletions aries_cloudagent/utils/tests/test_http.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
import os
import tempfile

from aiohttp import web
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
from asynctest import mock as async_mock, mock_open
from aiohttp.test_utils import AioHTTPTestCase

from ..http import fetch, fetch_stream, FetchError, put_file, PutError


class TempFile:
def __init__(self):
self.name = None

def __enter__(self):
file = tempfile.NamedTemporaryFile(delete=False)
file.write(b"test")
file.close()
self.name = file.name
return self.name

def __exit__(self, *args):
if self.name:
os.unlink(self.name)


class TestTransportUtils(AioHTTPTestCase):
async def setUpAsync(self):
self.fail_calls = 0
self.succeed_calls = 0
self.redirects = 0
await super().setUpAsync()

async def get_application(self):
Expand All @@ -19,19 +38,30 @@ async def get_application(self):
web.get("/succeed", self.succeed_route),
web.put("/fail", self.fail_route),
web.put("/succeed", self.succeed_route),
web.put("/redirect", self.redirect_route),
]
)
return app

async def fail_route(self, request):
self.fail_calls += 1
# avoid aiohttp test server issue: https://github.com/aio-libs/aiohttp/issues/3968
await request.read()
raise web.HTTPForbidden()

async def succeed_route(self, request):
self.succeed_calls += 1
ret = web.json_response([True])
return ret

async def redirect_route(self, request):
if self.redirects > 0:
self.redirects -= 1
# avoid aiohttp test server issue: https://github.com/aio-libs/aiohttp/issues/3968
await request.read()
raise web.HTTPRedirection(f"http://localhost:{self.server.port}/success")
return await self.succeed_route(request)

async def test_fetch_stream(self):
server_addr = f"http://localhost:{self.server.port}"
stream = await fetch_stream(
Expand Down Expand Up @@ -84,40 +114,55 @@ async def test_fetch_fail(self):
)
assert self.fail_calls == 2

async def test_put_file(self):
async def test_put_file_with_session(self):
server_addr = f"http://localhost:{self.server.port}"
with async_mock.patch("builtins.open", mock_open(read_data="data")):
with TempFile() as tails:
result = await put_file(
f"{server_addr}/succeed",
{"tails": "/tmp/dummy/path"},
{"tails": tails},
{"genesis": "..."},
session=self.client.session,
json=True,
)
assert result == [1]
assert result == [True]
assert self.succeed_calls == 1

async def test_put_file_default_client(self):
server_addr = f"http://localhost:{self.server.port}"
with async_mock.patch("builtins.open", mock_open(read_data="data")):
with TempFile() as tails:
result = await put_file(
f"{server_addr}/succeed",
{"tails": "/tmp/dummy/path"},
{"tails": tails},
{"genesis": "..."},
json=True,
)
assert result == [1]
assert result == [True]
assert self.succeed_calls == 1

async def test_put_file_fail(self):
server_addr = f"http://localhost:{self.server.port}"
with async_mock.patch("builtins.open", mock_open(read_data="data")):
with TempFile() as tails:
with self.assertRaises(PutError):
result = await put_file(
_ = await put_file(
f"{server_addr}/fail",
{"tails": "/tmp/dummy/path"},
{"tails": tails},
{"genesis": "..."},
max_attempts=2,
json=True,
)
assert self.fail_calls == 2

async def test_put_file_redirect(self):
server_addr = f"http://localhost:{self.server.port}"
self.redirects = 1
with TempFile() as tails:
result = await put_file(
f"{server_addr}/redirect",
{"tails": tails},
{"genesis": "..."},
max_attempts=2,
json=True,
)
assert result == [True]
assert self.succeed_calls == 1
assert self.redirects == 0

0 comments on commit c308329

Please sign in to comment.