diff --git a/changelog.d/444.misc b/changelog.d/444.misc new file mode 100644 index 00000000..4f884113 --- /dev/null +++ b/changelog.d/444.misc @@ -0,0 +1 @@ +Get `sydent.http.matrixfederationagent` to pass `mypy --strict`. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 119f97ff..05722726 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,16 +49,11 @@ strict = true files = [ # Find files that pass with # find sydent tests -type d -not -name __pycache__ -exec bash -c "mypy --strict '{}' > /dev/null" \; -print + # TODO "sydent/*.py" "sydent/config", "sydent/db", - "sydent/http/auth.py", - "sydent/http/blacklisting_reactor.py", - "sydent/http/federation_tls_options.py", - "sydent/http/httpclient.py", - "sydent/http/httpcommon.py", - "sydent/http/httpsclient.py", - "sydent/http/httpserver.py", - "sydent/http/srvresolver.py", + "sydent/http/*.py", + # TODO "sydent/http/servlets", "sydent/hs_federation", "sydent/replication", "sydent/sms", diff --git a/stubs/twisted/internet/endpoints.pyi b/stubs/twisted/internet/endpoints.pyi new file mode 100644 index 00000000..2efb5017 --- /dev/null +++ b/stubs/twisted/internet/endpoints.pyi @@ -0,0 +1,32 @@ +from typing import Any, AnyStr, Optional + +from twisted.internet import interfaces +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import ( + IOpenSSLClientConnectionCreator, + IProtocol, + IProtocolFactory, + IStreamClientEndpoint, +) +from zope.interface import implementer + +@implementer(interfaces.IStreamClientEndpoint) +class HostnameEndpoint: + # Reactor should be a "provider of L{IReactorTCP}, L{IReactorTime} and + # either L{IReactorPluggableNameResolver} or L{IReactorPluggableResolver}." + # I don't know how to encode that in the type system. + def __init__( + self, + reactor: object, + host: AnyStr, + port: int, + timeout: float = 30, + bindAddress: Optional[bytes] = None, + attemptDelay: Optional[float] = None, + ): ... + def connect(self, protocol_factory: IProtocolFactory) -> Deferred[IProtocol]: ... + +def wrapClientTLS( + connectionCreator: IOpenSSLClientConnectionCreator, + wrappedEndpoint: IStreamClientEndpoint, +) -> IStreamClientEndpoint: ... diff --git a/stubs/twisted/python/log.pyi b/stubs/twisted/python/log.pyi index 9d9fa686..c719b2c2 100644 --- a/stubs/twisted/python/log.pyi +++ b/stubs/twisted/python/log.pyi @@ -5,5 +5,5 @@ from twisted.python.failure import Failure def err( _stuff: Union[None, Exception, Failure] = None, _why: Optional[str] = None, - **kw: Any, + **kw: object, ) -> None: ... diff --git a/stubs/twisted/web/client.pyi b/stubs/twisted/web/client.pyi index e46f994c..fce41be2 100644 --- a/stubs/twisted/web/client.pyi +++ b/stubs/twisted/web/client.pyi @@ -9,7 +9,13 @@ from twisted.internet.interfaces import ( ) from twisted.internet.task import Cooperator from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse +from twisted.web.iweb import ( + IAgent, + IAgentEndpointFactory, + IBodyProducer, + IPolicyForHTTPS, + IResponse, +) from zope.interface import implementer C = TypeVar("C") @@ -20,13 +26,23 @@ class BrowserLikePolicyForHTTPS: self, hostname: bytes, port: int ) -> IOpenSSLClientConnectionCreator: ... -class HTTPConnectionPool: ... +class HTTPConnectionPool: + persistent: bool + maxPersistentPerHost: int + cachedConnectionTimeout: float + retryAutomatically: bool + def __init__(self, reactor: object, persistent: bool = True): ... @implementer(IAgent) class Agent: + # Here and in `usingEndpointFactory`, reactor should be a "provider of + # L{IReactorTCP}, L{IReactorTime} and either + # L{IReactorPluggableNameResolver} or L{IReactorPluggableResolver}." + # I don't know how to encode that in the type system; see also + # https://github.com/Shoobx/mypy-zope/issues/58 def __init__( self, - reactor: Any, + reactor: object, contextFactory: IPolicyForHTTPS = BrowserLikePolicyForHTTPS(), connectTimeout: Optional[float] = None, bindAddress: Optional[bytes] = None, @@ -39,17 +55,20 @@ class Agent: headers: Optional[Headers] = None, bodyProducer: Optional[IBodyProducer] = None, ) -> Deferred[IResponse]: ... + @classmethod + def usingEndpointFactory( + cls: Type[C], + reactor: object, + endpointFactory: IAgentEndpointFactory, + pool: Optional[HTTPConnectionPool] = None, + ) -> C: ... @implementer(IBodyProducer) class FileBodyProducer: def __init__( self, inputFile: BinaryIO, - # Type safety: twisted.internet.task.cooperate is a function with the - # same signature as Cooperator.cooperate. (It just wraps a module-level - # global cooperator.) But there's no easy way to annotate "either this - # type or a specific module". - cooperator: Cooperator = twisted.internet.task, # type: ignore[assignment] + cooperator: Cooperator = ..., readSize: int = 2 ** 16, ): ... # Length is either `int` or the opaque object UNKNOWN_LENGTH. @@ -95,3 +114,14 @@ class URI: ): ... @classmethod def fromBytes(cls: Type[C], uri: bytes, defaultPort: Optional[int] = None) -> C: ... + +@implementer(IAgent) +class RedirectAgent: + def __init__(self, Agent: Agent, redirectLimit: int = 20): ... + def request( + self, + method: bytes, + uri: bytes, + headers: Optional[Headers] = None, + bodyProducer: Optional[IBodyProducer] = None, + ) -> Deferred[IResponse]: ... diff --git a/stubs/twisted/web/http.pyi b/stubs/twisted/web/http.pyi index 7d809836..487ef069 100644 --- a/stubs/twisted/web/http.pyi +++ b/stubs/twisted/web/http.pyi @@ -51,3 +51,5 @@ class Request: class PotentialDataLoss(Exception): ... CACHED: object + +def stringToDatetime(dateString: bytes) -> int: ... diff --git a/sydent/http/httpclient.py b/sydent/http/httpclient.py index ca2182cb..7001f74f 100644 --- a/sydent/http/httpclient.py +++ b/sydent/http/httpclient.py @@ -177,7 +177,12 @@ class FederationHttpClient(HTTPClient[MatrixFederationAgent]): def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent self.agent = MatrixFederationAgent( - BlacklistingReactorWrapper( + # Type-safety: I don't have a good way of expressing that + # the reactor is IReactorTCP, IReactorTime and + # IReactorPluggableNameResolver all at once. But it is, because + # it wraps the sydent reactor. + # TODO: can we introduce a SydentReactor type like SynapseReactor? + BlacklistingReactorWrapper( # type: ignore[arg-type] reactor=self.sydent.reactor, ip_whitelist=sydent.config.general.ip_whitelist, ip_blacklist=sydent.config.general.ip_blacklist, diff --git a/sydent/http/matrixfederationagent.py b/sydent/http/matrixfederationagent.py index ff07f7e7..ddd96a85 100644 --- a/sydent/http/matrixfederationagent.py +++ b/sydent/http/matrixfederationagent.py @@ -15,19 +15,31 @@ import logging import random import time -from typing import Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Optional, Tuple import attr -from netaddr import IPAddress # type: ignore +from netaddr import IPAddress from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, Response +from twisted.internet.interfaces import ( + IProtocol, + IProtocolFactory, + IReactorTime, + IStreamClientEndpoint, +) +from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer +from twisted.web.iweb import ( + IAgent, + IAgentEndpointFactory, + IBodyProducer, + IPolicyForHTTPS, + IResponse, +) from zope.interface import implementer +from sydent.http.federation_tls_options import ClientTLSOptionsFactory from sydent.http.httpcommon import read_body_with_max_size from sydent.http.srvresolver import SrvResolver, pick_server_from_list from sydent.util import json_decoder @@ -49,7 +61,7 @@ WELL_KNOWN_MAX_SIZE = 50 * 1024 # 50 KiB logger = logging.getLogger(__name__) -well_known_cache = TTLCache("well-known") +well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known") @implementer(IAgent) @@ -59,32 +71,28 @@ class MatrixFederationAgent: Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) :param reactor: twisted reactor to use for underlying requests - :type reactor: IReactor :param tls_client_options_factory: Factory to use for fetching client tls options, or none to disable TLS. - :type tls_client_options_factory: ClientTLSOptionsFactory, None :param _well_known_tls_policy: TLS policy to use for fetching .well-known files. None to use a default (browser-like) implementation. - :type _well_known_tls_policy: IPolicyForHTTPS, None - - :param _srv_resolver: SRVResolver impl to use for looking up SRV records. - None to use a default implementation. - :type _srv_resolver: SrvResolver, None :param _well_known_cache: TTLCache impl for storing cached well-known lookups. Omit to use a default implementation. - :type _well_known_cache: TTLCache """ def __init__( self, - reactor, - tls_client_options_factory, - _well_known_tls_policy=None, - _srv_resolver: Optional["SrvResolver"] = None, - _well_known_cache: "TTLCache" = well_known_cache, + # This reactor should also be IReactorTCP and IReactorPluggableNameResolver + # because it eventually makes its way to HostnameEndpoint.__init__. + # But that's not easy to express with an annotation. We use the + # `seconds` attribute below, so mark this as IReactorTime for now. + reactor: IReactorTime, + tls_client_options_factory: Optional[ClientTLSOptionsFactory], + _well_known_tls_policy: Optional[IPolicyForHTTPS] = None, + _srv_resolver: Optional[SrvResolver] = None, + _well_known_cache: TTLCache[bytes, Optional[bytes]] = well_known_cache, ) -> None: self._reactor = reactor @@ -98,15 +106,15 @@ def __init__( self._pool.maxPersistentPerHost = 5 self._pool.cachedConnectionTimeout = 2 * 60 - agent_args = {} if _well_known_tls_policy is not None: # the param is called 'contextFactory', but actually passing a # contextfactory is deprecated, and it expects an IPolicyForHTTPS. - agent_args["contextFactory"] = _well_known_tls_policy - _well_known_agent = RedirectAgent( - Agent(self._reactor, pool=self._pool, **agent_args), - ) - self._well_known_agent = _well_known_agent + _well_known_agent = Agent( + self._reactor, pool=self._pool, contextFactory=_well_known_tls_policy + ) + else: + _well_known_agent = Agent(self._reactor, pool=self._pool) + self._well_known_agent = RedirectAgent(_well_known_agent) # our cache of .well-known lookup results, mapping from server name # to delegated name. The values can be: @@ -121,7 +129,7 @@ def request( uri: bytes, headers: Optional["Headers"] = None, bodyProducer: Optional["IBodyProducer"] = None, - ) -> Response: + ) -> Generator["defer.Deferred[Any]", Any, IResponse]: """ :param method: HTTP method (GET/POST/etc). @@ -141,7 +149,8 @@ def request( (including problems that prevent the request from being sent). """ parsed_uri = URI.fromBytes(uri, defaultPort=-1) - res = yield defer.ensureDeferred(self._route_matrix_uri(parsed_uri)) + routing: _RoutingResult + routing = yield defer.ensureDeferred(self._route_matrix_uri(parsed_uri)) # set up the TLS connection params # @@ -152,32 +161,37 @@ def request( tls_options = None else: tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii") + routing.tls_server_name.decode("ascii") ) # make sure that the Host header is set correctly if headers is None: headers = Headers() else: - headers = headers.copy() + # Type safety: Headers.copy doesn't have a return type annotated, + # and I don't want to stub web.http_headers. Could use stubgen? It's + # a pretty simple file. + headers = headers.copy() # type: ignore[no-untyped-call] assert headers is not None if not headers.hasHeader(b"host"): - headers.addRawHeader(b"host", res.host_header) + headers.addRawHeader(b"host", routing.host_header) + @implementer(IAgentEndpointFactory) class EndpointFactory: @staticmethod - def endpointForURI(_uri): - ep = LoggingHostnameEndpoint( + def endpointForURI(_uri: URI) -> IStreamClientEndpoint: + ep: IStreamClientEndpoint = LoggingHostnameEndpoint( self._reactor, - res.target_host, - res.target_port, + routing.target_host, + routing.target_port, ) if tls_options is not None: ep = wrapClientTLS(tls_options, ep) return ep agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool) + res: IResponse res = yield agent.request(method, uri, headers, bodyProducer) return res @@ -232,9 +246,11 @@ async def _route_matrix_uri( # parse the server name in the .well-known response into host/port. # (This code is lifted from twisted.web.client.URI.fromBytes). if b":" in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b":", 1) + well_known_host, well_known_port_raw = well_known_server.rsplit( + b":", 1 + ) try: - well_known_port = int(well_known_port) + well_known_port = int(well_known_port_raw) except ValueError: # the part after the colon could not be parsed as an int # - we assume it is an IPv6 literal with no port (the closing @@ -308,7 +324,7 @@ async def _get_well_known(self, server_name: bytes) -> Optional[bytes]: async def _do_get_well_known( self, server_name: bytes - ) -> Tuple[Union[bytes, None, object], int]: + ) -> Tuple[Optional[bytes], float]: """Actually fetch and parse a .well-known, without checking the cache :param server_name: Name of the server, from the requested url @@ -321,6 +337,7 @@ async def _do_get_well_known( uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") logger.info("Fetching %s", uri_str) + cache_period: Optional[float] try: response = await self._well_known_agent.request(b"GET", uri) body = await read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE) @@ -338,7 +355,7 @@ async def _do_get_well_known( # add some randomness to the TTL to avoid a stampeding herd every hour # after startup - cache_period: float = WELL_KNOWN_INVALID_CACHE_PERIOD + cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) return (None, cache_period) @@ -363,27 +380,33 @@ async def _do_get_well_known( class LoggingHostnameEndpoint: """A wrapper for HostnameEndpint which logs when it connects""" - def __init__(self, reactor, host, port, *args, **kwargs): + def __init__( + self, reactor: IReactorTime, host: bytes, port: int, *args: Any, **kwargs: Any + ): self.host = host self.port = port self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs) logger.info("Endpoint created with %s:%d", host, port) - def connect(self, protocol_factory): + def connect( + self, protocol_factory: IProtocolFactory + ) -> "defer.Deferred[IProtocol]": logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port) return self.ep.connect(protocol_factory) -def _cache_period_from_headers(headers, time_now=time.time): +def _cache_period_from_headers( + headers: Headers, time_now: Callable[[], float] = time.time +) -> Optional[float]: cache_controls = _parse_cache_control(headers) if b"no-store" in cache_controls: return 0 - if b"max-age" in cache_controls: + max_age = cache_controls.get(b"max-age") + if max_age is not None: try: - max_age = int(cache_controls[b"max-age"]) - return max_age + return int(max_age) except ValueError: pass @@ -401,8 +424,8 @@ def _cache_period_from_headers(headers, time_now=time.time): return None -def _parse_cache_control(headers): - cache_controls = {} +def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: + cache_controls: Dict[bytes, Optional[bytes]] = {} for hdr in headers.getRawHeaders(b"cache-control", []): for directive in hdr.split(b","): splits = [x.strip() for x in directive.split(b"=", 1)] @@ -412,7 +435,7 @@ def _parse_cache_control(headers): return cache_controls -@attr.s +@attr.s(frozen=True, slots=True, auto_attribs=True) class _RoutingResult: """The result returned by `_route_matrix_uri`. Contains the parameters needed to direct a federation connection to a particular @@ -421,30 +444,26 @@ class _RoutingResult: chosen from the list. """ - host_header = attr.ib() + host_header: bytes """ The value we should assign to the Host header (host:port from the matrix URI, or .well-known). - :type: bytes """ - tls_server_name = attr.ib() + tls_server_name: bytes """ The server name we should set in the SNI (typically host, without port, from the matrix URI or .well-known) - :type: bytes """ - target_host = attr.ib() + target_host: bytes """ The hostname (or IP literal) we should route the TCP connection to (the target of the SRV record, or the hostname from the URL/.well-known) - :type: bytes """ - target_port = attr.ib() + target_port: int """ The port we should route the TCP connection to (the target of the SRV record, or the port from the URL/.well-known, or 8448) - :type: int """