Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Make the http server handle coroutine-making REST servlets (#5475)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkowl authored Jun 29, 2019
1 parent c7ff297 commit f40a7dc
Show file tree
Hide file tree
Showing 12 changed files with 162 additions and 174 deletions.
1 change: 1 addition & 0 deletions changelog.d/5475.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Synapse can now handle RestServlets that return coroutines.
77 changes: 41 additions & 36 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

import cgi
import collections
import http.client
import logging

from six import PY3
from six.moves import http_client, urllib
import types
import urllib
from io import BytesIO

from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json

Expand All @@ -41,11 +42,6 @@
from synapse.util.caches import intern_dict
from synapse.util.logcontext import preserve_fn

if PY3:
from io import BytesIO
else:
from cStringIO import StringIO as BytesIO

logger = logging.getLogger(__name__)

HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
Expand Down Expand Up @@ -75,10 +71,9 @@ def wrap_json_request_handler(h):
deferred fails with any other type of error we send a 500 reponse.
"""

@defer.inlineCallbacks
def wrapped_request_handler(self, request):
async def wrapped_request_handler(self, request):
try:
yield h(self, request)
await h(self, request)
except SynapseError as e:
code = e.code
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
Expand Down Expand Up @@ -142,10 +137,12 @@ def wrap_html_request_handler(h):
where "request" must be a SynapseRequest.
"""

def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request)
d.addErrback(_return_html_error, request)
return d
async def wrapped_request_handler(self, request):
try:
return await h(self, request)
except Exception:
f = failure.Failure()
return _return_html_error(f, request)

return wrap_async_request_handler(wrapped_request_handler)

Expand All @@ -171,7 +168,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
else:
code = http_client.INTERNAL_SERVER_ERROR
code = http.client.INTERNAL_SERVER_ERROR
msg = "Internal server error"

logger.error(
Expand Down Expand Up @@ -201,10 +198,9 @@ def wrap_async_request_handler(h):
logged until the deferred completes.
"""

@defer.inlineCallbacks
def wrapped_async_request_handler(self, request):
async def wrapped_async_request_handler(self, request):
with request.processing():
yield h(self, request)
await h(self, request)

# we need to preserve_fn here, because the synchronous render method won't yield for
# us (obviously)
Expand Down Expand Up @@ -270,12 +266,11 @@ def register_paths(self, method, path_patterns, callback):
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
self._async_render(request)
defer.ensureDeferred(self._async_render(request))
return NOT_DONE_YET

@wrap_json_request_handler
@defer.inlineCallbacks
def _async_render(self, request):
async def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
Expand All @@ -292,26 +287,19 @@ def _async_render(self, request):
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.

def _unquote(s):
if PY3:
# On Python 3, unquote is unicode -> unicode
return urllib.parse.unquote(s)
else:
# On Python 2, unquote is bytes -> bytes We need to encode the
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
return urllib.parse.unquote(s.encode("ascii")).decode("utf8")

kwargs = intern_dict(
{
name: _unquote(value) if value else value
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
)

callback_return = yield callback(request, **kwargs)
callback_return = callback(request, **kwargs)

# Is it synchronous? We'll allow this for now.
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await callback_return

if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
Expand Down Expand Up @@ -360,6 +348,23 @@ def _send_response(
)


class DirectServeResource(resource.Resource):
def render(self, request):
"""
Render the request, using an asynchronous render handler if it exists.
"""
render_callback_name = "_async_render_" + request.method.decode("ascii")

if hasattr(self, render_callback_name):
# Call the handler
callback = getattr(self, render_callback_name)
defer.ensureDeferred(callback(request))

return NOT_DONE_YET
else:
super().render(request)


def _options_handler(request):
"""Request handler for OPTIONS requests
Expand Down
35 changes: 12 additions & 23 deletions synapse/rest/consent/consent_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
import jinja2
from jinja2 import TemplateNotFound

from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET

from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import finish_request, wrap_html_request_handler
from synapse.http.server import (
DirectServeResource,
finish_request,
wrap_html_request_handler,
)
from synapse.http.servlet import parse_string
from synapse.types import UserID

Expand All @@ -47,7 +47,7 @@ def compare_digest(a, b):
return a == b


class ConsentResource(Resource):
class ConsentResource(DirectServeResource):
"""A twisted Resource to display a privacy policy and gather consent to it
When accessed via GET, returns the privacy policy via a template.
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, hs):
Args:
hs (synapse.server.HomeServer): homeserver
"""
Resource.__init__(self)
super().__init__()

self.hs = hs
self.store = hs.get_datastore()
Expand Down Expand Up @@ -118,18 +118,12 @@ def __init__(self, hs):

self._hmac_secret = hs.config.form_secret.encode("utf-8")

def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET

@wrap_html_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
async def _async_render_GET(self, request):
"""
Args:
request (twisted.web.http.Request):
"""

version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", required=False, default="")
userhmac = None
Expand All @@ -145,7 +139,7 @@ def _async_render_GET(self, request):
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()

u = yield self.store.get_user_by_id(qualified_user_id)
u = await self.store.get_user_by_id(qualified_user_id)
if u is None:
raise NotFoundError("Unknown user")

Expand All @@ -165,13 +159,8 @@ def _async_render_GET(self, request):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")

def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET

@wrap_html_request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
async def _async_render_POST(self, request):
"""
Args:
request (twisted.web.http.Request):
Expand All @@ -188,12 +177,12 @@ def _async_render_POST(self, request):
qualified_user_id = UserID(username, self.hs.hostname).to_string()

try:
yield self.store.user_set_consent_version(qualified_user_id, version)
await self.store.user_set_consent_version(qualified_user_id, version)
except StoreError as e:
if e.code != 404:
raise
raise NotFoundError("Unknown user")
yield self.registration_handler.post_consent_actions(qualified_user_id)
await self.registration_handler.post_consent_actions(qualified_user_id)

try:
self._render_template(request, "success.html")
Expand Down
28 changes: 10 additions & 18 deletions synapse/rest/key/v2/remote_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@
from io import BytesIO

from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET

from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.server import (
DirectServeResource,
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_json_object_from_request

logger = logging.getLogger(__name__)


class RemoteKey(Resource):
class RemoteKey(DirectServeResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
Expand Down Expand Up @@ -94,13 +96,8 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist

def render_GET(self, request):
self.async_render_GET(request)
return NOT_DONE_YET

@wrap_json_request_handler
@defer.inlineCallbacks
def async_render_GET(self, request):
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
server, = request.postpath
query = {server.decode("ascii"): {}}
Expand All @@ -114,20 +111,15 @@ def async_render_GET(self, request):
else:
raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)

yield self.query_keys(request, query, query_remote_on_cache_miss=True)

def render_POST(self, request):
self.async_render_POST(request)
return NOT_DONE_YET
await self.query_keys(request, query, query_remote_on_cache_miss=True)

@wrap_json_request_handler
@defer.inlineCallbacks
def async_render_POST(self, request):
async def _async_render_POST(self, request):
content = parse_json_object_from_request(request)

query = content["server_keys"]

yield self.query_keys(request, query, query_remote_on_cache_miss=True)
await self.query_keys(request, query, query_remote_on_cache_miss=True)

@defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False):
Expand Down
21 changes: 9 additions & 12 deletions synapse/rest/media/v1/config_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,28 @@
# limitations under the License.
#

from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET

from synapse.http.server import respond_with_json, wrap_json_request_handler
from synapse.http.server import (
DirectServeResource,
respond_with_json,
wrap_json_request_handler,
)


class MediaConfigResource(Resource):
class MediaConfigResource(DirectServeResource):
isLeaf = True

def __init__(self, hs):
Resource.__init__(self)
super().__init__()
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}

def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET

@wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
yield self.auth.get_user_by_req(request)
async def _async_render_GET(self, request):
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)

def render_OPTIONS(self, request):
Expand Down
Loading

0 comments on commit f40a7dc

Please sign in to comment.