Skip to content

Commit

Permalink
Merge remote-tracking branch 'gitlab/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Mar 24, 2021
2 parents d078590 + e63908a commit 809ad96
Show file tree
Hide file tree
Showing 21 changed files with 222 additions and 93 deletions.
2 changes: 1 addition & 1 deletion sydent/hs_federation/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _getKeysForServer(self, server_name):
defer.returnValue(self.cache[server_name]['verify_keys'])

client = FederationHttpClient(self.sydent)
result = yield client.get_json("matrix://%s/_matrix/key/v2/server/" % server_name)
result = yield client.get_json("matrix://%s/_matrix/key/v2/server/" % server_name, 1024 * 50)
if 'verify_keys' not in result:
raise SignatureVerifyException("No key found in response")

Expand Down
34 changes: 16 additions & 18 deletions sydent/http/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def tokenFromRequest(request):
return token


def authIfV2(sydent, request, requireTermsAgreed=True):
def authV2(sydent, request, requireTermsAgreed=True):
"""For v2 APIs check that the request has a valid access token associated with it
:param sydent: The Sydent instance to use.
Expand All @@ -67,25 +67,23 @@ def authIfV2(sydent, request, requireTermsAgreed=True):
:raises MatrixRestError: If the request is v2 but could not be authed or the user has
not accepted terms.
"""
if request.path.startswith(b'/_matrix/identity/v2'):
token = tokenFromRequest(request)
token = tokenFromRequest(request)

if token is None:
raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
if token is None:
raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")

accountStore = AccountStore(sydent)
accountStore = AccountStore(sydent)

account = accountStore.getAccountByToken(token)
if account is None:
raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
account = accountStore.getAccountByToken(token)
if account is None:
raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")

if requireTermsAgreed:
terms = get_terms(sydent)
if (
terms.getMasterVersion() is not None and
account.consentVersion != terms.getMasterVersion()
):
raise MatrixRestError(403, "M_TERMS_NOT_SIGNED", "Terms not signed")
if requireTermsAgreed:
terms = get_terms(sydent)
if (
terms.getMasterVersion() is not None and
account.consentVersion != terms.getMasterVersion()
):
raise MatrixRestError(403, "M_TERMS_NOT_SIGNED", "Terms not signed")

return account
return None
return account
14 changes: 11 additions & 3 deletions sydent/http/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sydent.http.matrixfederationagent import MatrixFederationAgent

from sydent.http.federation_tls_options import ClientTLSOptionsFactory
from sydent.http.httpcommon import BodyExceededMaxSize, read_body_with_max_size

logger = logging.getLogger(__name__)

Expand All @@ -34,12 +35,15 @@ class HTTPClient(object):
requests.
"""
@defer.inlineCallbacks
def get_json(self, uri):
def get_json(self, uri, max_size = None):
"""Make a GET request to an endpoint returning JSON and parse result
:param uri: The URI to make a GET request to.
:type uri: unicode
:param max_size: The maximum size (in bytes) to allow as a response.
:type max_size: int
:return: A deferred containing JSON parsed into a Python object.
:rtype: twisted.internet.defer.Deferred[dict[any, any]]
"""
Expand All @@ -49,7 +53,7 @@ def get_json(self, uri):
b"GET",
uri.encode("utf8"),
)
body = yield readBody(response)
body = yield read_body_with_max_size(response, max_size)
try:
# json.loads doesn't allow bytes in Python 3.5
json_body = json.loads(body.decode("UTF-8"))
Expand Down Expand Up @@ -94,7 +98,11 @@ def post_json_get_nothing(self, uri, post_json, opts):
# Ensure the body object is read otherwise we'll leak HTTP connections
# as per
# https://twistedmatrix.com/documents/current/web/howto/client.html
yield readBody(response)
try:
# TODO Will this cause the server to think the request was a failure?
yield read_body_with_max_size(response, 0)
except BodyExceededMaxSize:
pass

defer.returnValue(response)

Expand Down
101 changes: 101 additions & 0 deletions sydent/http/httpcommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@
# limitations under the License.

import logging
from io import BytesIO

import twisted.internet.ssl
from twisted.internet import defer, protocol
from twisted.internet.protocol import connectionDone
from twisted.web._newclient import ResponseDone
from twisted.web.http import PotentialDataLoss
from twisted.web.iweb import UNKNOWN_LENGTH

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,3 +68,98 @@ def makeTrustRoot(self):
return twisted.internet._sslverify.OpenSSLCertificateAuthorities([caCert.original])
else:
return twisted.internet.ssl.OpenSSLDefaultPaths()



class BodyExceededMaxSize(Exception):
"""The maximum allowed size of the HTTP body was exceeded."""


class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""

def __init__(self, deferred):
self.deferred = deferred

def _maybe_fail(self):
"""
Report a max size exceed error and disconnect the first time this is called.
"""
if not self.deferred.called:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
self.transport.abortConnection()

def dataReceived(self, data) -> None:
self._maybe_fail()

def connectionLost(self, reason) -> None:
self._maybe_fail()


class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""

def __init__(self, deferred, max_size):
self.stream = BytesIO()
self.deferred = deferred
self.length = 0
self.max_size = max_size

def dataReceived(self, data) -> None:
# If the deferred was called, bail early.
if self.deferred.called:
return

self.stream.write(data)
self.length += len(data)
# The first time the maximum size is exceeded, error and cancel the
# connection. dataReceived might be called again if data was received
# in the meantime.
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
self.transport.abortConnection()

def connectionLost(self, reason = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return

if reason.check(ResponseDone):
self.deferred.callback(self.stream.getvalue())
elif reason.check(PotentialDataLoss):
# stolen from https://github.com/twisted/treq/pull/49/files
# http://twistedmatrix.com/trac/ticket/4840
self.deferred.callback(self.stream.getvalue())
else:
self.deferred.errback(reason)


def read_body_with_max_size(response, max_size):
"""
Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
If the maximum file size is reached, the returned Deferred will resolve to a
Failure with a BodyExceededMaxSize exception.
Args:
response: The HTTP response to read from.
max_size: The maximum file size to allow.
Returns:
A Deferred which resolves to the read body.
"""
d = defer.Deferred()

# If the Content-Length header gives a size larger than the maximum allowed
# size, do not bother downloading the body.
if max_size is not None and response.length != UNKNOWN_LENGTH:
if response.length > max_size:
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
return d

response.deliverBody(_ReadBodyWithMaxSizeProtocol(d, max_size))
return d
58 changes: 28 additions & 30 deletions sydent/http/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,30 +45,19 @@ def __init__(self, sydent):
v2 = self.sydent.servlets.v2

validate = Resource()
validate_v2 = Resource()
email = Resource()
email_v2 = Resource()
msisdn = Resource()
emailReqCode = self.sydent.servlets.emailRequestCode
emailValCode = self.sydent.servlets.emailValidate
msisdnReqCode = self.sydent.servlets.msisdnRequestCode
msisdnValCode = self.sydent.servlets.msisdnValidate
getValidated3pid = self.sydent.servlets.getValidated3pid

lookup = self.sydent.servlets.lookup
bulk_lookup = self.sydent.servlets.bulk_lookup

hash_details = self.sydent.servlets.hash_details
lookup_v2 = self.sydent.servlets.lookup_v2
msisdn_v2 = Resource()

threepid_v1 = Resource()
threepid_v2 = Resource()
bind = self.sydent.servlets.threepidBind
unbind = self.sydent.servlets.threepidUnbind

pubkey = Resource()
ephemeralPubkey = Resource()

pk_ed25519 = self.sydent.servlets.pubkey_ed25519

root.putChild(b'_matrix', matrix)
matrix.putChild(b'identity', identity)
identity.putChild(b'api', api)
Expand All @@ -78,33 +67,42 @@ def __init__(self, sydent):
validate.putChild(b'email', email)
validate.putChild(b'msisdn', msisdn)

validate_v2.putChild(b'email', email_v2)
validate_v2.putChild(b'msisdn', msisdn_v2)

v1.putChild(b'validate', validate)

v1.putChild(b'lookup', lookup)
v1.putChild(b'bulk_lookup', bulk_lookup)
v1.putChild(b'lookup', self.sydent.servlets.lookup)
v1.putChild(b'bulk_lookup', self.sydent.servlets.bulk_lookup)

v1.putChild(b'pubkey', pubkey)
pubkey.putChild(b'isvalid', self.sydent.servlets.pubkeyIsValid)
pubkey.putChild(b'ed25519:0', pk_ed25519)
pubkey.putChild(b'ed25519:0', self.sydent.servlets.pubkey_ed25519)
pubkey.putChild(b'ephemeral', ephemeralPubkey)
ephemeralPubkey.putChild(b'isvalid', self.sydent.servlets.ephemeralPubkeyIsValid)

threepid_v2.putChild(b'getValidated3pid', getValidated3pid)
threepid_v2.putChild(b'bind', bind)
threepid_v2.putChild(b'getValidated3pid', self.sydent.servlets.getValidated3pidV2)
threepid_v2.putChild(b'bind', self.sydent.servlets.threepidBindV2)
threepid_v2.putChild(b'unbind', unbind)

threepid_v1.putChild(b'getValidated3pid', getValidated3pid)
threepid_v1.putChild(b'getValidated3pid', self.sydent.servlets.getValidated3pid)
threepid_v1.putChild(b'unbind', unbind)
if self.sydent.enable_v1_associations:
threepid_v1.putChild(b'bind', bind)
threepid_v1.putChild(b'bind', self.sydent.servlets.threepidBind)

v1.putChild(b'3pid', threepid_v1)

email.putChild(b'requestToken', emailReqCode)
email.putChild(b'submitToken', emailValCode)
email.putChild(b'requestToken', self.sydent.servlets.emailRequestCode)
email.putChild(b'submitToken', self.sydent.servlets.emailValidate)

email_v2.putChild(b'requestToken', self.sydent.servlets.emailRequestCodeV2)
email_v2.putChild(b'submitToken', self.sydent.servlets.emailValidateV2)

msisdn.putChild(b'requestToken', self.sydent.servlets.msisdnRequestCode)
msisdn.putChild(b'submitToken', self.sydent.servlets.msisdnValidate)

msisdn.putChild(b'requestToken', msisdnReqCode)
msisdn.putChild(b'submitToken', msisdnValCode)
msisdn_v2.putChild(b'requestToken', self.sydent.servlets.msisdnRequestCodeV2)
msisdn_v2.putChild(b'submitToken', self.sydent.servlets.msisdnValidateV2)

v1.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet)

Expand All @@ -122,13 +120,13 @@ def __init__(self, sydent):
account.putChild(b'logout', self.sydent.servlets.logoutServlet)

# v2 versions of existing APIs
v2.putChild(b'validate', validate)
v2.putChild(b'validate', validate_v2)
v2.putChild(b'pubkey', pubkey)
v2.putChild(b'3pid', threepid_v2)
v2.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet)
v2.putChild(b'sign-ed25519', self.sydent.servlets.blindlySignStuffServlet)
v2.putChild(b'lookup', lookup_v2)
v2.putChild(b'hash_details', hash_details)
v2.putChild(b'store-invite', self.sydent.servlets.storeInviteServletV2)
v2.putChild(b'sign-ed25519', self.sydent.servlets.blindlySignStuffServletV2)
v2.putChild(b'lookup', self.sydent.servlets.lookup_v2)
v2.putChild(b'hash_details', self.sydent.servlets.hash_details)

self.factory = Site(root)
self.factory.displayTracebacks = False
Expand Down
9 changes: 7 additions & 2 deletions sydent/http/matrixfederationagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent

from sydent.http.httpcommon import BodyExceededMaxSize, read_body_with_max_size
from sydent.http.srvresolver import SrvResolver, pick_server_from_list
from sydent.util.ttlcache import TTLCache

Expand All @@ -46,6 +47,9 @@
# cap for .well-known cache period
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600

# The maximum size (in bytes) to allow a well-known file to be.
WELL_KNOWN_MAX_SIZE = 50 * 1024 # 50 KiB

logger = logging.getLogger(__name__)
well_known_cache = TTLCache('well-known')

Expand Down Expand Up @@ -316,7 +320,7 @@ def _do_get_well_known(self, server_name):
logger.info("Fetching %s", uri_str)
try:
response = yield self._well_known_agent.request(b"GET", uri)
body = yield readBody(response)
body = yield read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE)
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code, ))

Expand All @@ -334,6 +338,7 @@ def _do_get_well_known(self, server_name):
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
defer.returnValue((None, cache_period))
return

result = parsed_body["m.server"].encode("ascii")

Expand Down
5 changes: 2 additions & 3 deletions sydent/http/servlets/accountservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from twisted.web.resource import Resource

from sydent.http.servlets import jsonwrap, send_cors
from sydent.http.auth import authIfV2
from sydent.http.auth import authV2


class AccountServlet(Resource):
Expand All @@ -36,7 +36,7 @@ def render_GET(self, request):
"""
send_cors(request)

account = authIfV2(self.sydent, request)
account = authV2(self.sydent, request)

return {
"user_id": account.userId,
Expand All @@ -45,4 +45,3 @@ def render_GET(self, request):
def render_OPTIONS(self, request):
send_cors(request)
return b''

Loading

0 comments on commit 809ad96

Please sign in to comment.