Skip to content

Commit

Permalink
add ssl unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Florian Agbuya <[email protected]>
  • Loading branch information
fsagbuya committed Jan 6, 2025
1 parent 094a6cd commit 794f617
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 22 deletions.
1 change: 1 addition & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
version = "1.8";
src = self;
propagatedBuildInputs = with pkgs.python3Packages; [ pybase64 numpy ];
nativeCheckInputs = [ pkgs.openssl ];
};
sipyco-aarch64 = (with nixpkgs.legacyPackages.aarch64-linux; python3Packages.buildPythonPackage {
inherit (sipyco) pname version src;
Expand Down
39 changes: 39 additions & 0 deletions sipyco/test/ssl_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
import atexit
import tempfile
import subprocess
from pathlib import Path


def create_ssl_certs():
"""Generate temporary SSL certificates for testing.
Returns dict with cert paths or None if certs exist in env."""
if all(x in os.environ for x in ['SERVER_KEY', 'SERVER_CERT', 'CLIENT_KEY', 'CLIENT_CERT']):
return None

cert_dir = tempfile.mkdtemp()
for cert_name in ["server", "client"]:
subprocess.run([
"openssl", "req", "-x509", "-newkey", "rsa:2048",
"-keyout", os.path.join(cert_dir, f"{cert_name}.key"),
"-nodes",
"-out", os.path.join(cert_dir, f"{cert_name}.pem"),
"-sha256", "-days", "1",
"-subj", "/"
], check=True)

certs = {
"SERVER_KEY": os.path.join(cert_dir, "server.key"),
"SERVER_CERT": os.path.join(cert_dir, "server.pem"),
"CLIENT_KEY": os.path.join(cert_dir, "client.key"),
"CLIENT_CERT": os.path.join(cert_dir, "client.pem")
}

def cleanup():
if os.path.exists(cert_dir):
for file in Path(cert_dir).glob("*"):
file.unlink()
os.rmdir(cert_dir)

atexit.register(cleanup)
return certs
43 changes: 37 additions & 6 deletions sipyco/test/test_pc_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import sys
import time
import unittest
import os

import numpy as np

from sipyco import pc_rpc, pyon
from sipyco.test.ssl_utils import create_ssl_certs


test_address = "::1"
Expand All @@ -18,7 +20,15 @@


class RPCCase(unittest.TestCase):
def _run_server_and_test(self, test, *args):
@classmethod
def setUpClass(cls):
cls.ssl_certs = create_ssl_certs()

def _run_server_and_test(self, test, *args, ssl_certs=None):
env = os.environ.copy()
if ssl_certs:
env.update(ssl_certs)

# running this file outside of unittest starts the echo server
with subprocess.Popen([sys.executable,
sys.modules[__name__].__file__]) as proc:
Expand All @@ -31,12 +41,19 @@ def _run_server_and_test(self, test, *args):
proc.kill()
raise

def _blocking_echo(self, target, die_using_sys_exit=False):
def _blocking_echo(self, target, die_using_sys_exit=False, ssl_certs=None):
ssl_args = {}
if ssl_certs:
ssl_args = {
"local_cert": ssl_certs["CLIENT_CERT"],
"local_key": ssl_certs["CLIENT_KEY"],
"peer_cert": ssl_certs["SERVER_CERT"]
}

for attempt in range(100):
time.sleep(.2)
try:
remote = pc_rpc.Client(test_address, test_port,
target)
remote = pc_rpc.Client(test_address, test_port, target, **ssl_args)
except ConnectionRefusedError:
pass
else:
Expand All @@ -61,18 +78,29 @@ def _blocking_echo(self, target, die_using_sys_exit=False):
def test_blocking_echo(self):
self._run_server_and_test(self._blocking_echo, "test")

def test_ssl_blocking_echo(self):
self._run_server_and_test(self._blocking_echo, "test", ssl_certs=self.ssl_certs)

def test_sys_exit(self):
self._run_server_and_test(self._blocking_echo, "test", True)

def test_blocking_echo_autotarget(self):
self._run_server_and_test(self._blocking_echo, pc_rpc.AutoTarget)

async def _asyncio_echo(self, target):
async def _asyncio_echo(self, target, ssl_certs=None):
remote = pc_rpc.AsyncioClient()
ssl_args = {}
if ssl_certs:
ssl_args = {
"local_cert": ssl_certs["CLIENT_CERT"],
"local_key": ssl_certs["CLIENT_KEY"],
"peer_cert": ssl_certs["SERVER_CERT"]
}

for attempt in range(100):
await asyncio.sleep(.2)
try:
await remote.connect_rpc(test_address, test_port, target)
await remote.connect_rpc(test_address, test_port, target, **ssl_args)
except ConnectionRefusedError:
pass
else:
Expand Down Expand Up @@ -101,6 +129,9 @@ def _loop_asyncio_echo(self, target):
def test_asyncio_echo(self):
self._run_server_and_test(self._loop_asyncio_echo, "test")

def test_ssl_asyncio_echo(self):
self._run_server_and_test(self._loop_asyncio_echo, "test", ssl_certs=self.ssl_certs)

def test_asyncio_echo_autotarget(self):
self._run_server_and_test(self._loop_asyncio_echo, pc_rpc.AutoTarget)

Expand Down
49 changes: 36 additions & 13 deletions sipyco/test/test_rpctool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import sys
import asyncio
import unittest
import os

from sipyco.pc_rpc import Server
from sipyco.test.ssl_utils import create_ssl_certs


class Target:
Expand All @@ -11,25 +13,46 @@ def output_value(self):


class TestRPCTool(unittest.TestCase):
async def check_value(self):
proc = await asyncio.create_subprocess_exec(
sys.executable, "-m", "sipyco.sipyco_rpctool", "::1", "7777", "call", "output_value",
stdout = asyncio.subprocess.PIPE)
@classmethod
def setUpClass(cls):
cls.ssl_certs = create_ssl_certs()

def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

async def check_value(self, ssl_certs=None):
cmd = [sys.executable, "-m", "sipyco.sipyco_rpctool", "::1", "7777"]
if ssl_certs:
cmd.extend(["--ssl", ssl_certs["CLIENT_CERT"],
ssl_certs["CLIENT_KEY"],
ssl_certs["SERVER_CERT"]])
cmd.extend(["call", "output_value"])

proc = await asyncio.create_subprocess_exec(*cmd,
stdout=asyncio.subprocess.PIPE, env=os.environ)
(value, err) = await proc.communicate()
self.assertEqual(value.decode('ascii').rstrip(), '4125380')
await proc.wait()

async def do_test(self):
async def do_test(self, ssl_certs=None):
ssl_args = {}
if ssl_certs:
ssl_args = {
"local_cert": ssl_certs["SERVER_CERT"],
"local_key": ssl_certs["SERVER_KEY"],
"peer_cert": ssl_certs["CLIENT_CERT"]
}
server = Server({"target": Target()})
await server.start("::1", 7777)
await self.check_value()
await server.start("::1", 7777, **ssl_args)
await self.check_value(ssl_certs)
await server.stop()

def test_rpc(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.do_test())
finally:
loop.close()
self.loop.run_until_complete(self.do_test())

def test_ssl_rpc(self):
self.loop.run_until_complete(self.do_test(ssl_certs=self.ssl_certs))

def tearDown(self):
self.loop.close()
31 changes: 28 additions & 3 deletions sipyco/test/test_sync_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from sipyco import sync_struct
from sipyco.test.ssl_utils import create_ssl_certs


test_address = "::1"
Expand Down Expand Up @@ -48,17 +49,37 @@ def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

async def _do_test_recv(self):
@classmethod
def setUpClass(cls):
cls.ssl_certs = create_ssl_certs()

async def _do_test_recv(self, ssl_certs=None):
self.init_done = asyncio.Event()
self.receiving_done = asyncio.Event()

test_dict = sync_struct.Notifier(dict())
publisher = sync_struct.Publisher({"test": test_dict})
await publisher.start(test_address, test_port)

pub_ssl = {}
sub_ssl = {}

if ssl_certs:
pub_ssl = {
"local_cert": ssl_certs["SERVER_CERT"],
"local_key": ssl_certs["SERVER_KEY"],
"peer_cert": ssl_certs["CLIENT_CERT"]
}
sub_ssl = {
"local_cert": ssl_certs["CLIENT_CERT"],
"local_key": ssl_certs["CLIENT_KEY"],
"peer_cert": ssl_certs["SERVER_CERT"]
}

await publisher.start(test_address, test_port, **pub_ssl)

subscriber = sync_struct.Subscriber("test", self.init_test_dict,
self.notify)
await subscriber.connect(test_address, test_port)
await subscriber.connect(test_address, test_port, **sub_ssl)

# Wait for the initial replication to be completed so we actually
# exercise the various actions instead of sending just one init mod.
Expand All @@ -75,5 +96,9 @@ async def _do_test_recv(self):
def test_recv(self):
self.loop.run_until_complete(self._do_test_recv())

def test_ssl_recv(self):
ssl_certs = create_ssl_certs()
self.loop.run_until_complete(self._do_test_recv(ssl_certs=self.ssl_certs))

def tearDown(self):
self.loop.close()

0 comments on commit 794f617

Please sign in to comment.