diff --git a/sydent/hs_federation/verifier.py b/sydent/hs_federation/verifier.py index 78d9af0e..1360672d 100644 --- a/sydent/hs_federation/verifier.py +++ b/sydent/hs_federation/verifier.py @@ -69,7 +69,7 @@ def _getKeysForServer(self, server_name): defer.returnValue(self.cache[server_name]['verify_keys']) client = FederationHttpClient(self.sydent) - result = yield client.get_json("matrix://%s/_matrix/key/v2/server/" % server_name) + result = yield client.get_json("matrix://%s/_matrix/key/v2/server/" % server_name, 1024 * 50) if 'verify_keys' not in result: raise SignatureVerifyException("No key found in response") diff --git a/sydent/http/auth.py b/sydent/http/auth.py index c1e3463b..4ca84392 100644 --- a/sydent/http/auth.py +++ b/sydent/http/auth.py @@ -52,7 +52,7 @@ def tokenFromRequest(request): return token -def authIfV2(sydent, request, requireTermsAgreed=True): +def authV2(sydent, request, requireTermsAgreed=True): """For v2 APIs check that the request has a valid access token associated with it :param sydent: The Sydent instance to use. @@ -67,25 +67,23 @@ def authIfV2(sydent, request, requireTermsAgreed=True): :raises MatrixRestError: If the request is v2 but could not be authed or the user has not accepted terms. """ - if request.path.startswith(b'/_matrix/identity/v2'): - token = tokenFromRequest(request) + token = tokenFromRequest(request) - if token is None: - raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized") + if token is None: + raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized") - accountStore = AccountStore(sydent) + accountStore = AccountStore(sydent) - account = accountStore.getAccountByToken(token) - if account is None: - raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized") + account = accountStore.getAccountByToken(token) + if account is None: + raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized") - if requireTermsAgreed: - terms = get_terms(sydent) - if ( - terms.getMasterVersion() is not None and - account.consentVersion != terms.getMasterVersion() - ): - raise MatrixRestError(403, "M_TERMS_NOT_SIGNED", "Terms not signed") + if requireTermsAgreed: + terms = get_terms(sydent) + if ( + terms.getMasterVersion() is not None and + account.consentVersion != terms.getMasterVersion() + ): + raise MatrixRestError(403, "M_TERMS_NOT_SIGNED", "Terms not signed") - return account - return None + return account diff --git a/sydent/http/httpclient.py b/sydent/http/httpclient.py index 7358304f..2da0c5f1 100644 --- a/sydent/http/httpclient.py +++ b/sydent/http/httpclient.py @@ -25,6 +25,7 @@ from sydent.http.matrixfederationagent import MatrixFederationAgent from sydent.http.federation_tls_options import ClientTLSOptionsFactory +from sydent.http.httpcommon import BodyExceededMaxSize, read_body_with_max_size logger = logging.getLogger(__name__) @@ -34,12 +35,15 @@ class HTTPClient(object): requests. """ @defer.inlineCallbacks - def get_json(self, uri): + def get_json(self, uri, max_size = None): """Make a GET request to an endpoint returning JSON and parse result :param uri: The URI to make a GET request to. :type uri: unicode + :param max_size: The maximum size (in bytes) to allow as a response. + :type max_size: int + :return: A deferred containing JSON parsed into a Python object. :rtype: twisted.internet.defer.Deferred[dict[any, any]] """ @@ -49,7 +53,7 @@ def get_json(self, uri): b"GET", uri.encode("utf8"), ) - body = yield readBody(response) + body = yield read_body_with_max_size(response, max_size) try: # json.loads doesn't allow bytes in Python 3.5 json_body = json.loads(body.decode("UTF-8")) @@ -94,7 +98,11 @@ def post_json_get_nothing(self, uri, post_json, opts): # Ensure the body object is read otherwise we'll leak HTTP connections # as per # https://twistedmatrix.com/documents/current/web/howto/client.html - yield readBody(response) + try: + # TODO Will this cause the server to think the request was a failure? + yield read_body_with_max_size(response, 0) + except BodyExceededMaxSize: + pass defer.returnValue(response) diff --git a/sydent/http/httpcommon.py b/sydent/http/httpcommon.py index cbcf95f8..9dde0ec1 100644 --- a/sydent/http/httpcommon.py +++ b/sydent/http/httpcommon.py @@ -15,8 +15,14 @@ # limitations under the License. import logging +from io import BytesIO import twisted.internet.ssl +from twisted.internet import defer, protocol +from twisted.internet.protocol import connectionDone +from twisted.web._newclient import ResponseDone +from twisted.web.http import PotentialDataLoss +from twisted.web.iweb import UNKNOWN_LENGTH logger = logging.getLogger(__name__) @@ -62,3 +68,98 @@ def makeTrustRoot(self): return twisted.internet._sslverify.OpenSSLCertificateAuthorities([caCert.original]) else: return twisted.internet.ssl.OpenSSLDefaultPaths() + + + +class BodyExceededMaxSize(Exception): + """The maximum allowed size of the HTTP body was exceeded.""" + + +class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): + """A protocol which immediately errors upon receiving data.""" + + def __init__(self, deferred): + self.deferred = deferred + + def _maybe_fail(self): + """ + Report a max size exceed error and disconnect the first time this is called. + """ + if not self.deferred.called: + self.deferred.errback(BodyExceededMaxSize()) + # Close the connection (forcefully) since all the data will get + # discarded anyway. + self.transport.abortConnection() + + def dataReceived(self, data) -> None: + self._maybe_fail() + + def connectionLost(self, reason) -> None: + self._maybe_fail() + + +class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): + """A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" + + def __init__(self, deferred, max_size): + self.stream = BytesIO() + self.deferred = deferred + self.length = 0 + self.max_size = max_size + + def dataReceived(self, data) -> None: + # If the deferred was called, bail early. + if self.deferred.called: + return + + self.stream.write(data) + self.length += len(data) + # The first time the maximum size is exceeded, error and cancel the + # connection. dataReceived might be called again if data was received + # in the meantime. + if self.max_size is not None and self.length >= self.max_size: + self.deferred.errback(BodyExceededMaxSize()) + # Close the connection (forcefully) since all the data will get + # discarded anyway. + self.transport.abortConnection() + + def connectionLost(self, reason = connectionDone) -> None: + # If the maximum size was already exceeded, there's nothing to do. + if self.deferred.called: + return + + if reason.check(ResponseDone): + self.deferred.callback(self.stream.getvalue()) + elif reason.check(PotentialDataLoss): + # stolen from https://github.com/twisted/treq/pull/49/files + # http://twistedmatrix.com/trac/ticket/4840 + self.deferred.callback(self.stream.getvalue()) + else: + self.deferred.errback(reason) + + +def read_body_with_max_size(response, max_size): + """ + Read a HTTP response body to a file-object. Optionally enforcing a maximum file size. + + If the maximum file size is reached, the returned Deferred will resolve to a + Failure with a BodyExceededMaxSize exception. + + Args: + response: The HTTP response to read from. + max_size: The maximum file size to allow. + + Returns: + A Deferred which resolves to the read body. + """ + d = defer.Deferred() + + # If the Content-Length header gives a size larger than the maximum allowed + # size, do not bother downloading the body. + if max_size is not None and response.length != UNKNOWN_LENGTH: + if response.length > max_size: + response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d)) + return d + + response.deliverBody(_ReadBodyWithMaxSizeProtocol(d, max_size)) + return d diff --git a/sydent/http/httpserver.py b/sydent/http/httpserver.py index 7ac377b1..245f3574 100644 --- a/sydent/http/httpserver.py +++ b/sydent/http/httpserver.py @@ -45,30 +45,19 @@ def __init__(self, sydent): v2 = self.sydent.servlets.v2 validate = Resource() + validate_v2 = Resource() email = Resource() + email_v2 = Resource() msisdn = Resource() - emailReqCode = self.sydent.servlets.emailRequestCode - emailValCode = self.sydent.servlets.emailValidate - msisdnReqCode = self.sydent.servlets.msisdnRequestCode - msisdnValCode = self.sydent.servlets.msisdnValidate - getValidated3pid = self.sydent.servlets.getValidated3pid - - lookup = self.sydent.servlets.lookup - bulk_lookup = self.sydent.servlets.bulk_lookup - - hash_details = self.sydent.servlets.hash_details - lookup_v2 = self.sydent.servlets.lookup_v2 + msisdn_v2 = Resource() threepid_v1 = Resource() threepid_v2 = Resource() - bind = self.sydent.servlets.threepidBind unbind = self.sydent.servlets.threepidUnbind pubkey = Resource() ephemeralPubkey = Resource() - pk_ed25519 = self.sydent.servlets.pubkey_ed25519 - root.putChild(b'_matrix', matrix) matrix.putChild(b'identity', identity) identity.putChild(b'api', api) @@ -78,33 +67,42 @@ def __init__(self, sydent): validate.putChild(b'email', email) validate.putChild(b'msisdn', msisdn) + validate_v2.putChild(b'email', email_v2) + validate_v2.putChild(b'msisdn', msisdn_v2) + v1.putChild(b'validate', validate) - v1.putChild(b'lookup', lookup) - v1.putChild(b'bulk_lookup', bulk_lookup) + v1.putChild(b'lookup', self.sydent.servlets.lookup) + v1.putChild(b'bulk_lookup', self.sydent.servlets.bulk_lookup) v1.putChild(b'pubkey', pubkey) pubkey.putChild(b'isvalid', self.sydent.servlets.pubkeyIsValid) - pubkey.putChild(b'ed25519:0', pk_ed25519) + pubkey.putChild(b'ed25519:0', self.sydent.servlets.pubkey_ed25519) pubkey.putChild(b'ephemeral', ephemeralPubkey) ephemeralPubkey.putChild(b'isvalid', self.sydent.servlets.ephemeralPubkeyIsValid) - threepid_v2.putChild(b'getValidated3pid', getValidated3pid) - threepid_v2.putChild(b'bind', bind) + threepid_v2.putChild(b'getValidated3pid', self.sydent.servlets.getValidated3pidV2) + threepid_v2.putChild(b'bind', self.sydent.servlets.threepidBindV2) threepid_v2.putChild(b'unbind', unbind) - threepid_v1.putChild(b'getValidated3pid', getValidated3pid) + threepid_v1.putChild(b'getValidated3pid', self.sydent.servlets.getValidated3pid) threepid_v1.putChild(b'unbind', unbind) if self.sydent.enable_v1_associations: - threepid_v1.putChild(b'bind', bind) + threepid_v1.putChild(b'bind', self.sydent.servlets.threepidBind) v1.putChild(b'3pid', threepid_v1) - email.putChild(b'requestToken', emailReqCode) - email.putChild(b'submitToken', emailValCode) + email.putChild(b'requestToken', self.sydent.servlets.emailRequestCode) + email.putChild(b'submitToken', self.sydent.servlets.emailValidate) + + email_v2.putChild(b'requestToken', self.sydent.servlets.emailRequestCodeV2) + email_v2.putChild(b'submitToken', self.sydent.servlets.emailValidateV2) + + msisdn.putChild(b'requestToken', self.sydent.servlets.msisdnRequestCode) + msisdn.putChild(b'submitToken', self.sydent.servlets.msisdnValidate) - msisdn.putChild(b'requestToken', msisdnReqCode) - msisdn.putChild(b'submitToken', msisdnValCode) + msisdn_v2.putChild(b'requestToken', self.sydent.servlets.msisdnRequestCodeV2) + msisdn_v2.putChild(b'submitToken', self.sydent.servlets.msisdnValidateV2) v1.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet) @@ -122,13 +120,13 @@ def __init__(self, sydent): account.putChild(b'logout', self.sydent.servlets.logoutServlet) # v2 versions of existing APIs - v2.putChild(b'validate', validate) + v2.putChild(b'validate', validate_v2) v2.putChild(b'pubkey', pubkey) v2.putChild(b'3pid', threepid_v2) - v2.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet) - v2.putChild(b'sign-ed25519', self.sydent.servlets.blindlySignStuffServlet) - v2.putChild(b'lookup', lookup_v2) - v2.putChild(b'hash_details', hash_details) + v2.putChild(b'store-invite', self.sydent.servlets.storeInviteServletV2) + v2.putChild(b'sign-ed25519', self.sydent.servlets.blindlySignStuffServletV2) + v2.putChild(b'lookup', self.sydent.servlets.lookup_v2) + v2.putChild(b'hash_details', self.sydent.servlets.hash_details) self.factory = Site(root) self.factory.displayTracebacks = False diff --git a/sydent/http/matrixfederationagent.py b/sydent/http/matrixfederationagent.py index f7995c9d..bc4a968f 100644 --- a/sydent/http/matrixfederationagent.py +++ b/sydent/http/matrixfederationagent.py @@ -26,11 +26,12 @@ from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody +from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent +from sydent.http.httpcommon import BodyExceededMaxSize, read_body_with_max_size from sydent.http.srvresolver import SrvResolver, pick_server_from_list from sydent.util.ttlcache import TTLCache @@ -46,6 +47,9 @@ # cap for .well-known cache period WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 +# The maximum size (in bytes) to allow a well-known file to be. +WELL_KNOWN_MAX_SIZE = 50 * 1024 # 50 KiB + logger = logging.getLogger(__name__) well_known_cache = TTLCache('well-known') @@ -316,7 +320,7 @@ def _do_get_well_known(self, server_name): logger.info("Fetching %s", uri_str) try: response = yield self._well_known_agent.request(b"GET", uri) - body = yield readBody(response) + body = yield read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE) if response.code != 200: raise Exception("Non-200 response %s" % (response.code, )) @@ -334,6 +338,7 @@ def _do_get_well_known(self, server_name): cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) defer.returnValue((None, cache_period)) + return result = parsed_body["m.server"].encode("ascii") diff --git a/sydent/http/servlets/accountservlet.py b/sydent/http/servlets/accountservlet.py index f100710d..2e6f4a02 100644 --- a/sydent/http/servlets/accountservlet.py +++ b/sydent/http/servlets/accountservlet.py @@ -18,7 +18,7 @@ from twisted.web.resource import Resource from sydent.http.servlets import jsonwrap, send_cors -from sydent.http.auth import authIfV2 +from sydent.http.auth import authV2 class AccountServlet(Resource): @@ -36,7 +36,7 @@ def render_GET(self, request): """ send_cors(request) - account = authIfV2(self.sydent, request) + account = authV2(self.sydent, request) return { "user_id": account.userId, @@ -45,4 +45,3 @@ def render_GET(self, request): def render_OPTIONS(self, request): send_cors(request) return b'' - diff --git a/sydent/http/servlets/blindlysignstuffservlet.py b/sydent/http/servlets/blindlysignstuffservlet.py index 306ce90e..d0623f67 100644 --- a/sydent/http/servlets/blindlysignstuffservlet.py +++ b/sydent/http/servlets/blindlysignstuffservlet.py @@ -22,7 +22,7 @@ import signedjson.sign from sydent.db.invite_tokens import JoinTokenStore from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError -from sydent.http.auth import authIfV2 +from sydent.http.auth import authV2 logger = logging.getLogger(__name__) @@ -30,16 +30,18 @@ class BlindlySignStuffServlet(Resource): isLeaf = True - def __init__(self, syd): + def __init__(self, syd, require_auth=False): self.sydent = syd self.server_name = syd.server_name self.tokenStore = JoinTokenStore(syd) + self.require_auth = require_auth @jsonwrap def render_POST(self, request): send_cors(request) - authIfV2(self.sydent, request) + if self.require_auth: + authV2(self.sydent, request) args = get_args(request, ("private_key", "token", "mxid")) diff --git a/sydent/http/servlets/bulklookupservlet.py b/sydent/http/servlets/bulklookupservlet.py index 11c3f5af..5b8cb01f 100644 --- a/sydent/http/servlets/bulklookupservlet.py +++ b/sydent/http/servlets/bulklookupservlet.py @@ -21,7 +21,6 @@ import logging from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError -from sydent.http.auth import authIfV2 logger = logging.getLogger(__name__) @@ -45,8 +44,6 @@ def render_POST(self, request): """ send_cors(request) - authIfV2(self.sydent, request) - args = get_args(request, ('threepids',)) threepids = args['threepids'] diff --git a/sydent/http/servlets/emailservlet.py b/sydent/http/servlets/emailservlet.py index 7f030890..077d15f5 100644 --- a/sydent/http/servlets/emailservlet.py +++ b/sydent/http/servlets/emailservlet.py @@ -28,20 +28,22 @@ from sydent.http.servlets import get_args, jsonwrap, send_cors -from sydent.http.auth import authIfV2 +from sydent.http.auth import authV2 class EmailRequestCodeServlet(Resource): isLeaf = True - def __init__(self, syd): + def __init__(self, syd, require_auth=False): self.sydent = syd + self.require_auth = require_auth @jsonwrap def render_POST(self, request): send_cors(request) - authIfV2(self.sydent, request) + if self.require_auth: + authV2(self.sydent, request) args = get_args(request, ('email', 'client_secret', 'send_attempt')) @@ -85,8 +87,9 @@ def render_OPTIONS(self, request): class EmailValidateCodeServlet(Resource): isLeaf = True - def __init__(self, syd): + def __init__(self, syd, require_auth=False): self.sydent = syd + self.require_auth = require_auth def render_GET(self, request): args = get_args(request, ('nextLink',), required=False) @@ -121,7 +124,8 @@ def render_GET(self, request): def render_POST(self, request): send_cors(request) - authIfV2(self.sydent, request) + if self.require_auth: + authV2(self.sydent, request) return self.do_validate_request(request) diff --git a/sydent/http/servlets/getvalidated3pidservlet.py b/sydent/http/servlets/getvalidated3pidservlet.py index 95fffcca..592b37d5 100644 --- a/sydent/http/servlets/getvalidated3pidservlet.py +++ b/sydent/http/servlets/getvalidated3pidservlet.py @@ -18,7 +18,7 @@ from twisted.web.resource import Resource from sydent.http.servlets import jsonwrap, get_args -from sydent.http.auth import authIfV2 +from sydent.http.auth import authV2 from sydent.db.valsession import ThreePidValSessionStore from sydent.util.stringutils import is_valid_client_secret from sydent.validators import ( @@ -32,12 +32,14 @@ class GetValidated3pidServlet(Resource): isLeaf = True - def __init__(self, syd): + def __init__(self, syd, require_auth=False): self.sydent = syd + self.require_auth = require_auth @jsonwrap def render_GET(self, request): - authIfV2(self.sydent, request) + if self.require_auth: + authV2(self.sydent, request) args = get_args(request, ('sid', 'client_secret')) diff --git a/sydent/http/servlets/hashdetailsservlet.py b/sydent/http/servlets/hashdetailsservlet.py index babff6df..0b1d0323 100644 --- a/sydent/http/servlets/hashdetailsservlet.py +++ b/sydent/http/servlets/hashdetailsservlet.py @@ -16,7 +16,7 @@ from __future__ import absolute_import from twisted.web.resource import Resource -from sydent.http.auth import authIfV2 +from sydent.http.auth import authV2 import logging @@ -48,7 +48,7 @@ def render_GET(self, request): """ send_cors(request) - authIfV2(self.sydent, request) + authV2(self.sydent, request) return { "algorithms": self.known_algorithms, diff --git a/sydent/http/servlets/logoutservlet.py b/sydent/http/servlets/logoutservlet.py index feb5cba2..cb654fce 100644 --- a/sydent/http/servlets/logoutservlet.py +++ b/sydent/http/servlets/logoutservlet.py @@ -21,7 +21,7 @@ from sydent.http.servlets import jsonwrap, send_cors from sydent.db.accounts import AccountStore -from sydent.http.auth import authIfV2, tokenFromRequest +from sydent.http.auth import authV2, tokenFromRequest logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def render_POST(self, request): """ send_cors(request) - authIfV2(self.sydent, request, False) + authV2(self.sydent, request, False) token = tokenFromRequest(request) diff --git a/sydent/http/servlets/lookupservlet.py b/sydent/http/servlets/lookupservlet.py index e336eb38..352c28e4 100644 --- a/sydent/http/servlets/lookupservlet.py +++ b/sydent/http/servlets/lookupservlet.py @@ -24,7 +24,6 @@ import signedjson.sign from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError -from sydent.http.auth import authIfV2 logger = logging.getLogger(__name__) @@ -49,8 +48,6 @@ def render_GET(self, request): """ send_cors(request) - authIfV2(self.sydent, request) - args = get_args(request, ('medium', 'address')) medium = args['medium'] diff --git a/sydent/http/servlets/lookupv2servlet.py b/sydent/http/servlets/lookupv2servlet.py index a9e9674d..c66ef58f 100644 --- a/sydent/http/servlets/lookupv2servlet.py +++ b/sydent/http/servlets/lookupv2servlet.py @@ -21,7 +21,7 @@ from sydent.http.servlets import get_args, jsonwrap, send_cors from sydent.db.threepid_associations import GlobalAssociationStore -from sydent.http.auth import authIfV2 +from sydent.http.auth import authV2 from sydent.http.servlets.hashdetailsservlet import HashDetailsServlet logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ def render_POST(self, request): """ send_cors(request) - authIfV2(self.sydent, request) + authV2(self.sydent, request) args = get_args(request, ('addresses', 'algorithm', 'pepper')) diff --git a/sydent/http/servlets/msisdnservlet.py b/sydent/http/servlets/msisdnservlet.py index 7571a0f9..40b9c4e3 100644 --- a/sydent/http/servlets/msisdnservlet.py +++ b/sydent/http/servlets/msisdnservlet.py @@ -29,7 +29,7 @@ ) from sydent.http.servlets import get_args, jsonwrap, send_cors -from sydent.http.auth import authIfV2 +from sydent.http.auth import authV2 from sydent.util.stringutils import is_valid_client_secret @@ -39,14 +39,16 @@ class MsisdnRequestCodeServlet(Resource): isLeaf = True - def __init__(self, syd): + def __init__(self, syd, require_auth=False): self.sydent = syd + self.require_auth = require_auth @jsonwrap def render_POST(self, request): send_cors(request) - authIfV2(self.sydent, request) + if self.require_auth: + authV2(self.sydent, request) args = get_args(request, ('phone_number', 'country', 'client_secret', 'send_attempt')) @@ -107,8 +109,9 @@ def render_OPTIONS(self, request): class MsisdnValidateCodeServlet(Resource): isLeaf = True - def __init__(self, syd): + def __init__(self, syd, require_auth=False): self.sydent = syd + self.require_auth = require_auth def render_GET(self, request): send_cors(request) @@ -142,7 +145,8 @@ def render_GET(self, request): def render_POST(self, request): send_cors(request) - authIfV2(self.sydent, request) + if self.require_auth: + authV2(self.sydent, request) return self.do_validate_request(request) diff --git a/sydent/http/servlets/registerservlet.py b/sydent/http/servlets/registerservlet.py index d385a478..ebd07cce 100644 --- a/sydent/http/servlets/registerservlet.py +++ b/sydent/http/servlets/registerservlet.py @@ -51,6 +51,7 @@ def render_POST(self, request): "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s" % ( args['matrix_server_name'], urllib.parse.quote(args['access_token']), ), + 1024 * 5, ) if 'sub' not in result: raise Exception("Invalid response from homeserver") diff --git a/sydent/http/servlets/store_invite_servlet.py b/sydent/http/servlets/store_invite_servlet.py index f373cbd1..d661ea0f 100644 --- a/sydent/http/servlets/store_invite_servlet.py +++ b/sydent/http/servlets/store_invite_servlet.py @@ -28,20 +28,22 @@ from sydent.db.threepid_associations import GlobalAssociationStore from sydent.http.servlets import get_args, send_cors, jsonwrap -from sydent.http.auth import authIfV2 +from sydent.http.auth import authV2 from sydent.util.emailutils import sendEmail class StoreInviteServlet(Resource): - def __init__(self, syd): + def __init__(self, syd, require_auth=False): self.sydent = syd self.random = random.SystemRandom() + self.require_auth = require_auth @jsonwrap def render_POST(self, request): send_cors(request) - authIfV2(self.sydent, request) + if self.require_auth: + authV2(self.sydent, request) args = get_args(request, ("medium", "address", "room_id", "sender",)) medium = args["medium"] diff --git a/sydent/http/servlets/termsservlet.py b/sydent/http/servlets/termsservlet.py index cb325187..802d3933 100644 --- a/sydent/http/servlets/termsservlet.py +++ b/sydent/http/servlets/termsservlet.py @@ -21,7 +21,7 @@ from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError from sydent.terms.terms import get_terms -from sydent.http.auth import authIfV2 +from sydent.http.auth import authV2 from sydent.db.terms import TermsStore from sydent.db.accounts import AccountStore @@ -54,7 +54,7 @@ def render_POST(self, request): """ send_cors(request) - account = authIfV2(self.sydent, request, False) + account = authV2(self.sydent, request, False) args = get_args(request, ("user_accepts",)) diff --git a/sydent/http/servlets/threepidbindservlet.py b/sydent/http/servlets/threepidbindservlet.py index 4cc02597..0e77dbd6 100644 --- a/sydent/http/servlets/threepidbindservlet.py +++ b/sydent/http/servlets/threepidbindservlet.py @@ -20,21 +20,24 @@ from sydent.db.valsession import ThreePidValSessionStore from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError -from sydent.http.auth import authIfV2 +from sydent.http.auth import authV2 from sydent.util.stringutils import is_valid_client_secret from sydent.validators import SessionExpiredException, IncorrectClientSecretException, InvalidSessionIdException,\ SessionNotValidatedException class ThreePidBindServlet(Resource): - def __init__(self, sydent): + def __init__(self, sydent, require_auth=False): self.sydent = sydent + self.require_auth = require_auth @jsonwrap def render_POST(self, request): send_cors(request) - account = authIfV2(self.sydent, request) + account = None + if self.require_auth: + account = authV2(self.sydent, request) args = get_args(request, ('sid', 'client_secret', 'mxid')) diff --git a/sydent/sydent.py b/sydent/sydent.py index 3eabd3e1..e842715e 100644 --- a/sydent/sydent.py +++ b/sydent/sydent.py @@ -278,9 +278,13 @@ def __init__(self, cfg, reactor=twisted.internet.reactor): self.servlets.v1 = V1Servlet(self) self.servlets.v2 = V2Servlet(self) self.servlets.emailRequestCode = EmailRequestCodeServlet(self) + self.servlets.emailRequestCodeV2 = EmailRequestCodeServlet(self, require_auth=True) self.servlets.emailValidate = EmailValidateCodeServlet(self) + self.servlets.emailValidateV2 = EmailValidateCodeServlet(self, require_auth=True) self.servlets.msisdnRequestCode = MsisdnRequestCodeServlet(self) + self.servlets.msisdnRequestCodeV2 = MsisdnRequestCodeServlet(self, require_auth=True) self.servlets.msisdnValidate = MsisdnValidateCodeServlet(self) + self.servlets.msisdnValidateV2 = MsisdnValidateCodeServlet(self, require_auth=True) self.servlets.lookup = LookupServlet(self) self.servlets.bulk_lookup = BulkLookupServlet(self) self.servlets.hash_details = HashDetailsServlet(self, lookup_pepper) @@ -289,11 +293,15 @@ def __init__(self, cfg, reactor=twisted.internet.reactor): self.servlets.pubkeyIsValid = PubkeyIsValidServlet(self) self.servlets.ephemeralPubkeyIsValid = EphemeralPubkeyIsValidServlet(self) self.servlets.threepidBind = ThreePidBindServlet(self) + self.servlets.threepidBindV2 = ThreePidBindServlet(self, require_auth=True) self.servlets.threepidUnbind = ThreePidUnbindServlet(self) self.servlets.replicationPush = ReplicationPushServlet(self) self.servlets.getValidated3pid = GetValidated3pidServlet(self) + self.servlets.getValidated3pidV2 = GetValidated3pidServlet(self, require_auth=True) self.servlets.storeInviteServlet = StoreInviteServlet(self) + self.servlets.storeInviteServletV2 = StoreInviteServlet(self, require_auth=True) self.servlets.blindlySignStuffServlet = BlindlySignStuffServlet(self) + self.servlets.blindlySignStuffServletV2 = BlindlySignStuffServlet(self, require_auth=True) self.servlets.termsServlet = TermsServlet(self) self.servlets.accountServlet = AccountServlet(self) self.servlets.registerServlet = RegisterServlet(self)