diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index f6336beb..b2024f83 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -720,7 +720,17 @@ def tls_set_context(self, context=None): if hasattr(context, 'check_hostname'): self._tls_insecure = not context.check_hostname - def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tls_version=None, ciphers=None, keyfile_password=None): + def tls_set( + self, + ca_certs=None, + certfile=None, + keyfile=None, + cert_reqs=None, + tls_version=None, + ciphers=None, + keyfile_password=None, + alpn_protocols=None, + ) -> None: """Configure network encryption and authentication options. Enables SSL/TLS support. ca_certs : a string path to the Certificate Authority certificate files @@ -799,6 +809,11 @@ def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tl else: context.load_default_certs() + if alpn_protocols is not None: + if not getattr(ssl, "HAS_ALPN", None): + raise ValueError("SSL library has no support for ALPN") + context.set_alpn_protocols(alpn_protocols) + self.tls_set_context(context) if cert_reqs != ssl.CERT_NONE: diff --git a/tests/lib/clients/08-ssl-connect-alpn.py b/tests/lib/clients/08-ssl-connect-alpn.py new file mode 100755 index 00000000..513f2b4c --- /dev/null +++ b/tests/lib/clients/08-ssl-connect-alpn.py @@ -0,0 +1,23 @@ +import os + +import paho.mqtt.client as mqtt + +from tests.paho_test import get_test_server_port, loop_until_keyboard_interrupt + + +def on_connect(mqttc, obj, flags, rc): + assert rc == 0, f"Connect failed ({rc})" + mqttc.disconnect() + + +mqttc = mqtt.Client("08-ssl-connect-alpn", clean_session=True) +mqttc.tls_set( + os.path.join(os.environ["PAHO_SSL_PATH"], "all-ca.crt"), + os.path.join(os.environ["PAHO_SSL_PATH"], "client.crt"), + os.path.join(os.environ["PAHO_SSL_PATH"], "client.key"), + alpn_protocols=["paho-test-protocol"], +) +mqttc.on_connect = on_connect + +mqttc.connect("localhost", get_test_server_port()) +loop_until_keyboard_interrupt(mqttc) diff --git a/tests/lib/conftest.py b/tests/lib/conftest.py index 171cd0e7..dfdd4f5a 100644 --- a/tests/lib/conftest.py +++ b/tests/lib/conftest.py @@ -32,6 +32,15 @@ def ssl_server_socket(monkeypatch): yield from _yield_server(monkeypatch, create_server_socket_ssl()) +@pytest.fixture() +def alpn_ssl_server_socket(monkeypatch): + if ssl is None: + pytest.skip("no ssl module") + if not getattr(ssl, "HAS_ALPN", False): + pytest.skip("ALPN not supported in this version of Python") + yield from _yield_server(monkeypatch, create_server_socket_ssl(alpn_protocols=["paho-test-protocol"])) + + def stop_process(proc: subprocess.Popen) -> None: if sys.platform == "win32": proc.send_signal(signal.CTRL_C_EVENT) diff --git a/tests/lib/test_08_ssl_connect_alpn.py b/tests/lib/test_08_ssl_connect_alpn.py new file mode 100755 index 00000000..af1ecc4f --- /dev/null +++ b/tests/lib/test_08_ssl_connect_alpn.py @@ -0,0 +1,38 @@ +# Test whether a client produces a correct connect and subsequent disconnect when using SSL. +# Client must provide a certificate. +# +# The client should connect with keepalive=60, clean session set, +# and client id 08-ssl-connect-alpn +# It should use the CA certificate ssl/all-ca.crt for verifying the server. +# The test will send a CONNACK message to the client with rc=0. Upon receiving +# the CONNACK and verifying that rc=0, the client should send a DISCONNECT +# message. If rc!=0, the client should exit with an error. +# +# Additionally, the secure socket must have been negotiated with the "paho-test-protocol" + + +from tests import paho_test +from tests.paho_test import ssl + + +def test_08_ssl_connect_alpn(alpn_ssl_server_socket, start_client): + connect_packet = paho_test.gen_connect("08-ssl-connect-alpn", keepalive=60) + connack_packet = paho_test.gen_connack(rc=0) + disconnect_packet = paho_test.gen_disconnect() + + start_client("08-ssl-connect-alpn.py") + + (conn, address) = alpn_ssl_server_socket.accept() + conn.settimeout(10) + + paho_test.expect_packet(conn, "connect", connect_packet) + conn.send(connack_packet) + + paho_test.expect_packet(conn, "disconnect", disconnect_packet) + + if ssl.HAS_ALPN: + negotiated_protocol = conn.selected_alpn_protocol() + if negotiated_protocol != "paho-test-protocol": + raise Exception(f"Unexpected protocol '{negotiated_protocol}'") + + conn.close() diff --git a/tests/paho_test.py b/tests/paho_test.py index f0282c32..9274fed6 100644 --- a/tests/paho_test.py +++ b/tests/paho_test.py @@ -32,7 +32,7 @@ def create_server_socket(): return (sock, port) -def create_server_socket_ssl(*, verify_mode=None): +def create_server_socket_ssl(*, verify_mode=None, alpn_protocols=None): assert ssl, "SSL not available" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -46,6 +46,9 @@ def create_server_socket_ssl(*, verify_mode=None): if verify_mode: context.verify_mode = verify_mode + if alpn_protocols is not None: + context.set_alpn_protocols(alpn_protocols) + ssock = context.wrap_socket(sock, server_side=True) ssock.settimeout(10) port = bind_to_any_free_port(ssock)