Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Port rest client v2_alpha to async/await #6483

Merged
merged 2 commits into from
Dec 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6483.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Port synapse.rest.client.v2_alpha to async/await.
2 changes: 1 addition & 1 deletion synapse/rest/client/v2_alpha/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def on_POST(self, request):
"""

def wrapped(*args, **kwargs):
res = defer.maybeDeferred(orig, *args, **kwargs)
res = defer.ensureDeferred(orig(*args, **kwargs))
res.addErrback(_catch_incomplete_interactive_auth)
return res

Expand Down
119 changes: 51 additions & 68 deletions synapse/rest/client/v2_alpha/account.py

Large diffs are not rendered by default.

30 changes: 12 additions & 18 deletions synapse/rest/client/v2_alpha/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import logging

from twisted.internet import defer

from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request

Expand All @@ -41,29 +39,27 @@ def __init__(self, hs):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()

@defer.inlineCallbacks
def on_PUT(self, request, user_id, account_data_type):
requester = yield self.auth.get_user_by_req(request)
async def on_PUT(self, request, user_id, account_data_type):
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")

body = parse_json_object_from_request(request)

max_id = yield self.store.add_account_data_for_user(
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)

self.notifier.on_new_event("account_data_key", max_id, users=[user_id])

return 200, {}

@defer.inlineCallbacks
def on_GET(self, request, user_id, account_data_type):
requester = yield self.auth.get_user_by_req(request)
async def on_GET(self, request, user_id, account_data_type):
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")

event = yield self.store.get_global_account_data_by_type_for_user(
event = await self.store.get_global_account_data_by_type_for_user(
account_data_type, user_id
)

Expand Down Expand Up @@ -91,9 +87,8 @@ def __init__(self, hs):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()

@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, account_data_type):
requester = yield self.auth.get_user_by_req(request)
async def on_PUT(self, request, user_id, room_id, account_data_type):
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")

Expand All @@ -106,21 +101,20 @@ def on_PUT(self, request, user_id, room_id, account_data_type):
" Use /rooms/!roomId:server.name/read_markers",
)

max_id = yield self.store.add_account_data_to_room(
max_id = await self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body
)

self.notifier.on_new_event("account_data_key", max_id, users=[user_id])

return 200, {}

@defer.inlineCallbacks
def on_GET(self, request, user_id, room_id, account_data_type):
requester = yield self.auth.get_user_by_req(request)
async def on_GET(self, request, user_id, room_id, account_data_type):
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")

event = yield self.store.get_account_data_for_room_and_type(
event = await self.store.get_account_data_for_room_and_type(
user_id, room_id, account_data_type
)

Expand Down
17 changes: 6 additions & 11 deletions synapse/rest/client/v2_alpha/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import logging

from twisted.internet import defer

from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet
Expand Down Expand Up @@ -45,13 +43,12 @@ def __init__(self, hs):
self.success_html = hs.config.account_validity.account_renewed_html_content
self.failure_html = hs.config.account_validity.invalid_token_html_content

@defer.inlineCallbacks
def on_GET(self, request):
async def on_GET(self, request):
if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]

token_valid = yield self.account_activity_handler.renew_account(
token_valid = await self.account_activity_handler.renew_account(
renewal_token.decode("utf8")
)

Expand All @@ -67,7 +64,6 @@ def on_GET(self, request):
request.setHeader(b"Content-Length", b"%d" % (len(response),))
request.write(response.encode("utf8"))
finish_request(request)
defer.returnValue(None)


class AccountValiditySendMailServlet(RestServlet):
Expand All @@ -85,18 +81,17 @@ def __init__(self, hs):
self.auth = hs.get_auth()
self.account_validity = self.hs.config.account_validity

@defer.inlineCallbacks
def on_POST(self, request):
async def on_POST(self, request):
if not self.account_validity.renew_by_email_enabled:
raise AuthError(
403, "Account renewal via email is disabled on this server."
)

requester = yield self.auth.get_user_by_req(request, allow_expired=True)
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
yield self.account_activity_handler.send_renewal_email_to_user(user_id)
await self.account_activity_handler.send_renewal_email_to_user(user_id)

defer.returnValue((200, {}))
return 200, {}


def register_servlets(hs, http_server):
Expand Down
9 changes: 3 additions & 6 deletions synapse/rest/client/v2_alpha/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import logging

from twisted.internet import defer

from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
Expand Down Expand Up @@ -171,8 +169,7 @@ def on_GET(self, request, stagetype):
else:
raise SynapseError(404, "Unknown auth stage type")

@defer.inlineCallbacks
def on_POST(self, request, stagetype):
async def on_POST(self, request, stagetype):

session = parse_string(request, "session")
if not session:
Expand All @@ -186,7 +183,7 @@ def on_POST(self, request, stagetype):

authdict = {"response": response, "session": session}

success = yield self.auth_handler.add_oob_auth(
success = await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
)

Expand Down Expand Up @@ -215,7 +212,7 @@ def on_POST(self, request, stagetype):
session = request.args["session"][0]
authdict = {"session": session}

success = yield self.auth_handler.add_oob_auth(
success = await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
)

Expand Down
9 changes: 3 additions & 6 deletions synapse/rest/client/v2_alpha/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.
import logging

from twisted.internet import defer

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet

Expand All @@ -40,10 +38,9 @@ def __init__(self, hs):
self.auth = hs.get_auth()
self.store = hs.get_datastore()

@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = yield self.store.get_user_by_id(requester.user.to_string())
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = await self.store.get_user_by_id(requester.user.to_string())
change_password = bool(user["password_hash"])

response = {
Expand Down
41 changes: 17 additions & 24 deletions synapse/rest/client/v2_alpha/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import logging

from twisted.internet import defer

from synapse.api import errors
from synapse.http.servlet import (
RestServlet,
Expand All @@ -42,10 +40,9 @@ def __init__(self, hs):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()

@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
devices = yield self.device_handler.get_devices_by_user(
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
)
return 200, {"devices": devices}
Expand All @@ -67,9 +64,8 @@ def __init__(self, hs):
self.auth_handler = hs.get_auth_handler()

@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request)

try:
body = parse_json_object_from_request(request)
Expand All @@ -84,11 +80,11 @@ def on_POST(self, request):

assert_params_in_dict(body, ["devices"])

yield self.auth_handler.validate_user_via_ui_auth(
await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
)

yield self.device_handler.delete_devices(
await self.device_handler.delete_devices(
requester.user.to_string(), body["devices"]
)
return 200, {}
Expand All @@ -108,18 +104,16 @@ def __init__(self, hs):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()

@defer.inlineCallbacks
def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
device = yield self.device_handler.get_device(
async def on_GET(self, request, device_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
device = await self.device_handler.get_device(
requester.user.to_string(), device_id
)
return 200, device

@interactive_auth_handler
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
async def on_DELETE(self, request, device_id):
requester = await self.auth.get_user_by_req(request)

try:
body = parse_json_object_from_request(request)
Expand All @@ -132,19 +126,18 @@ def on_DELETE(self, request, device_id):
else:
raise

yield self.auth_handler.validate_user_via_ui_auth(
await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
)

yield self.device_handler.delete_device(requester.user.to_string(), device_id)
await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {}

@defer.inlineCallbacks
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
async def on_PUT(self, request, device_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)

body = parse_json_object_from_request(request)
yield self.device_handler.update_device(
await self.device_handler.update_device(
requester.user.to_string(), device_id, body
)
return 200, {}
Expand Down
16 changes: 6 additions & 10 deletions synapse/rest/client/v2_alpha/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import logging

from twisted.internet import defer

from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
Expand All @@ -35,10 +33,9 @@ def __init__(self, hs):
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()

@defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id):
async def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request)
requester = await self.auth.get_user_by_req(request)

if target_user != requester.user:
raise AuthError(403, "Cannot get filters for other users")
Expand All @@ -52,7 +49,7 @@ def on_GET(self, request, user_id, filter_id):
raise SynapseError(400, "Invalid filter_id")

try:
filter_collection = yield self.filtering.get_user_filter(
filter_collection = await self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id
)
except StoreError as e:
Expand All @@ -72,11 +69,10 @@ def __init__(self, hs):
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()

@defer.inlineCallbacks
def on_POST(self, request, user_id):
async def on_POST(self, request, user_id):

target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request)
requester = await self.auth.get_user_by_req(request)

if target_user != requester.user:
raise AuthError(403, "Cannot create filters for other users")
Expand All @@ -87,7 +83,7 @@ def on_POST(self, request, user_id):
content = parse_json_object_from_request(request)
set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)

filter_id = yield self.filtering.add_user_filter(
filter_id = await self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_filter=content
)

Expand Down
Loading