From 9f9342648a0515a083ed3954a7ff46c6c1b57811 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Fri, 25 Feb 2022 17:30:50 +0000 Subject: [PATCH] Add ALPN support Signed-off-by: Michael Boulton --- src/paho/mqtt/client.py | 7 ++- test/lib/08-ssl-connect-alpn.py | 54 ++++++++++++++++++++++++ test/lib/Makefile | 1 + test/lib/python/08-ssl-connect-alpn.test | 39 +++++++++++++++++ test/paho_test.py | 20 ++++++--- 5 files changed, 114 insertions(+), 7 deletions(-) create mode 100755 test/lib/08-ssl-connect-alpn.py create mode 100755 test/lib/python/08-ssl-connect-alpn.test diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index 1c0236e4..37ebdfb3 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -732,7 +732,7 @@ def tls_set_context(self, context=None): if hasattr(context, 'check_hostname'): self._tls_insecure = not context.check_hostname - def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tls_version=None, ciphers=None, keyfile_password=None): + def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tls_version=None, ciphers=None, keyfile_password=None, alpn_protocols=None): """Configure network encryption and authentication options. Enables SSL/TLS support. ca_certs : a string path to the Certificate Authority certificate files @@ -808,6 +808,11 @@ def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tl if ciphers is not None: context.set_ciphers(ciphers) + if alpn_protocols is not None: + if not getattr(ssl, "HAS_ALPN", None): + raise ValueError("SSL library has no support for ALPN") + context.set_alpn_protocols(alpn_protocols) + self.tls_set_context(context) if cert_reqs != ssl.CERT_NONE: diff --git a/test/lib/08-ssl-connect-alpn.py b/test/lib/08-ssl-connect-alpn.py new file mode 100755 index 00000000..770fa576 --- /dev/null +++ b/test/lib/08-ssl-connect-alpn.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 + +# Test whether a client produces a correct connect and subsequent disconnect when using SSL. +# Client must provide a certificate. + +# The client should connect to port 1888 with keepalive=60, clean session set, +# and client id 08-ssl-connect-alpn +# It should use the CA certificate ssl/all-ca.crt for verifying the server. +# The test will send a CONNACK message to the client with rc=0. Upon receiving +# the CONNACK and verifying that rc=0, the client should send a DISCONNECT +# message. If rc!=0, the client should exit with an error. +# +# Additionally, the secure socket must have been negotiated with the "paho-test-protocol" + +import context +import paho_test +from paho_test import ssl + +context.check_ssl() + +rc = 1 +keepalive = 60 +connect_packet = paho_test.gen_connect("08-ssl-connect-alpn", keepalive=keepalive) +connack_packet = paho_test.gen_connack(rc=0) +disconnect_packet = paho_test.gen_disconnect() + +ssock = paho_test.create_server_socket_ssl(cert_reqs=ssl.CERT_REQUIRED, alpn_protocols=["paho-test-protocol"]) + +client = context.start_client() + +try: + (conn, address) = ssock.accept() + conn.settimeout(10) + + paho_test.expect_packet(conn, "connect", connect_packet) + conn.send(connack_packet) + + paho_test.expect_packet(conn, "disconnect", disconnect_packet) + rc = 0 + + if getattr(ssl, "HAS_ALPN"): + negotiated_protocol = conn.selected_alpn_protocol() + if negotiated_protocol != "paho-test-protocol": + raise Exception( + "Unexpected protocol '{}'".format(negotiated_protocol) + ) + + conn.close() +finally: + client.terminate() + client.wait() + ssock.close() + +exit(rc) diff --git a/test/lib/Makefile b/test/lib/Makefile index 1e41b367..d6ae61df 100644 --- a/test/lib/Makefile +++ b/test/lib/Makefile @@ -34,4 +34,5 @@ test : $(PYTHON) ./08-ssl-bad-cacert.py python/08-ssl-bad-cacert.test $(PYTHON) ./08-ssl-connect-cert-auth-pw.py python/08-ssl-connect-cert-auth-pw.test $(PYTHON) ./08-ssl-connect-cert-auth.py python/08-ssl-connect-cert-auth.test + $(PYTHON) ./08-ssl-connect-alpn.py python/08-ssl-connect-alpn.test $(PYTHON) ./08-ssl-connect-no-auth.py python/08-ssl-connect-no-auth.test diff --git a/test/lib/python/08-ssl-connect-alpn.test b/test/lib/python/08-ssl-connect-alpn.test new file mode 100755 index 00000000..2973426d --- /dev/null +++ b/test/lib/python/08-ssl-connect-alpn.test @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +import sys + +import paho.mqtt.client as mqtt + +if sys.version_info < (2, 7, 9): + print("WARNING: SSL/TLS not supported on Python 2.6") + exit(0) + +import ssl + +if not getattr(ssl, "HAS_ALPN"): + print("ALPN not supported in this version of Python") + exit(0) + + +def on_connect(mqttc, obj, flags, rc): + if rc != 0: + exit(rc) + else: + mqttc.disconnect() + + +def on_disconnect(mqttc, obj, rc): + obj = rc + + +run = -1 +mqttc = mqtt.Client("08-ssl-connect-alpn", run) +mqttc.tls_set("../ssl/all-ca.crt", "../ssl/client.crt", "../ssl/client.key", alpn_protocols=["paho-test-protocol"]) +mqttc.on_connect = on_connect +mqttc.on_disconnect = on_disconnect + +mqttc.connect("localhost", 1888) +while run == -1: + mqttc.loop() + +exit(run) diff --git a/test/paho_test.py b/test/paho_test.py index 65c753d7..6b606a0c 100644 --- a/test/paho_test.py +++ b/test/paho_test.py @@ -34,7 +34,7 @@ def create_server_socket(): return sock -def create_server_socket_ssl(*args, **kwargs): +def create_server_socket_ssl(cert_reqs=None, alpn_protocols=None): if ssl is None: raise RuntimeError @@ -44,10 +44,18 @@ def create_server_socket_ssl(*args, **kwargs): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - ssock = ssl.wrap_socket( - sock, ca_certs="../ssl/all-ca.crt", - keyfile="../ssl/server.key", certfile="../ssl/server.crt", - server_side=True, ssl_version=ssl_version, **kwargs) + + context = ssl.SSLContext(ssl_version) + if cert_reqs is not None: + context.options |= cert_reqs + + if alpn_protocols is not None: + context.set_alpn_protocols(alpn_protocols) + + context.load_verify_locations(cafile="../ssl/all-ca.crt") + context.load_cert_chain(certfile="../ssl/server.crt", keyfile="../ssl/server.key") + + ssock = context.wrap_socket(sock, server_side=True) ssock.settimeout(10) ssock.bind(('', 1888)) ssock.listen(5) @@ -63,7 +71,7 @@ def expect_packet(sock, name, expected): packet_recvd = b"" try: while len(packet_recvd) < rlen: - data = sock.recv(rlen-len(packet_recvd)) + data = sock.recv(rlen - len(packet_recvd)) if len(data) == 0: break packet_recvd += data