From d3a292ca077e6e9ecc9d839a9a865ccaa5a54b9b Mon Sep 17 00:00:00 2001 From: Lawson Woods Date: Tue, 6 Aug 2024 20:28:25 -0700 Subject: [PATCH] test harness improvement --- imdreader/IMDClient.py | 41 +++++++----------- imdreader/IMDREADER.py | 2 +- imdreader/tests/test_imdreader.py | 26 ++++++++++-- imdreader/tests/utils.py | 70 ++++++++++++++++++++++--------- 4 files changed, 89 insertions(+), 50 deletions(-) diff --git a/imdreader/IMDClient.py b/imdreader/IMDClient.py index 93467af..3ab345d 100644 --- a/imdreader/IMDClient.py +++ b/imdreader/IMDClient.py @@ -53,20 +53,19 @@ def stop(self): def _connect_to_server(self, host, port, socket_bufsize): """ - Establish connection with the server, failing out if this - does not occur within 5 seconds. + Establish connection with the server, failing out instantly if server is not running """ conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if socket_bufsize is not None: conn.setsockopt( socket.SOL_SOCKET, socket.SO_RCVBUF, self._socket_bufsize ) - conn.settimeout(60) try: + logger.debug(f"IMDClient: Connecting to {host}:{port}") conn.connect((host, port)) except ConnectionRefusedError: raise ConnectionRefusedError( - f"IMDReader: Connection to {host}:{port} refused" + f"IMDClient: Connection to {host}:{port} refused" ) return conn @@ -226,10 +225,8 @@ def run(self): self._unpause() self._paused = False - logger.debug(f"IMDProducer: Attempting to get timestep") imdf = self._buf.pop_empty_imdframe() - logger.debug(f"IMDProducer: Attempting to read nrg and pos") # NOTE: This can be replaced with a simple parser if # the server doesn't send the final frame with all data # as in xtc @@ -238,35 +235,28 @@ def run(self): ) read_into_buf(self._conn, self._energies) - logger.debug("read energy data") imdf.energies.update( parse_energy_bytes( self._energies, self._imdsinfo.endianness ) ) - logger.debug(f"IMDProducer: added energies to {imdf.energies}") - - logger.debug(f"IMDProducer: added energies to imdf") self._expect_header( IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms ) - logger.debug(f"IMDProducer: Expected header") read_into_buf(self._conn, self._positions) - logger.debug(f"IMDProducer: attempting to load ts") - imdf.positions = np.frombuffer( self._positions, dtype=f"{self._imdsinfo.endianness}f" ).reshape((self._n_atoms, 3)) - logger.debug(f"IMDProducer: ts loaded- inserting it") + logger.debug( + f"IMDProducer: positions for frame {self._frame}: {imdf.positions}" + ) self._buf.push_full_imdframe(imdf) - logger.debug(f"IMDProducer: ts inserted") - self._frame += 1 except EOFError: # Don't raise error if simulation ended in a way @@ -276,7 +266,7 @@ def run(self): pass finally: - logger.debug("IMDProducer: simulation ended") + logger.debug("IMDProducer: Simulation ended, cleaning up") # Tell reader not to expect more frames to be added self._buf.notify_producer_finished() @@ -290,11 +280,8 @@ def _expect_header(self, expected_type, expected_value=None): read_into_buf(self._conn, self._header) - logger.debug(f"IMDProducer: header: {self._header}") header = IMDHeader(self._header) - logger.debug(f"IMDProducer: header parsed") - if header.type != expected_type: raise RuntimeError @@ -353,7 +340,7 @@ def __init__( imdf_memsize = imdframe_memsize(n_atoms, imdsinfo) self._total_imdf = buffer_size // imdf_memsize logger.debug( - f"IMDFRAMEBuffer: Total timesteps allocated: {self._total_imdf}" + f"IMDFrameBuffer: Total timesteps allocated: {self._total_imdf}" ) if self._total_imdf == 0: raise ValueError( @@ -436,9 +423,11 @@ def pop_full_imdframe(self): imdf = self._full_q.get() - self._prev_empty_imdf = imdf + logger.debug( + f"IMDFrameBuffer: positions for frame {self._frame}: {imdf.positions}" + ) - logger.debug(f"IMDReader: Got frame {self._frame}") + self._prev_empty_imdf = imdf return imdf @@ -519,17 +508,17 @@ def read_into_buf(sock, buf) -> bool: # Server called close() # Server is definitely done sending frames logger.debug( - "IMDProducer: recv excepting due to server calling close()" + "read_into_buf excepting due to server calling close()" ) raise EOFError except TimeoutError: # Server is *likely* done sending frames - logger.debug("IMDProducer: recv excepting due to timeout") + logger.debug("read_into_buf excepting due to timeout") raise EOFError except BlockingIOError: # Occurs when timeout is 0 in place of a TimeoutError # Server is *likely* done sending frames - logger.debug("IMDProducer: recv excepting due to blocking") + logger.debug("read_into_buf excepting due to blocking") raise EOFError total_received += received diff --git a/imdreader/IMDREADER.py b/imdreader/IMDREADER.py index 95c0234..591732a 100644 --- a/imdreader/IMDREADER.py +++ b/imdreader/IMDREADER.py @@ -105,7 +105,7 @@ def __init__( super(IMDReader, self).__init__(filename, **kwargs) - logger.debug("Reader initializing") + logger.debug("IMDReader initializing") if n_atoms is None: raise ValueError("IMDReader: n_atoms must be specified") diff --git a/imdreader/tests/test_imdreader.py b/imdreader/tests/test_imdreader.py index 3acdd65..91b80a3 100644 --- a/imdreader/tests/test_imdreader.py +++ b/imdreader/tests/test_imdreader.py @@ -7,6 +7,7 @@ import imdreader from imdreader.IMDClient import imdframe_memsize from .utils import ( + IMDServerEventType, DummyIMDServer, get_free_port, ExpectPauseLoopV2Behavior, @@ -39,6 +40,8 @@ def log_config(): logger.removeHandler(file_handler) +logger = logging.getLogger("imdreader.IMDREADER") + IMDENERGYKEYS = [ "step", "temperature", @@ -76,11 +79,18 @@ def server(self, traj): server = DummyIMDServer(traj, 2) return server - @pytest.mark.parametrize("endianness", ["<", ">"]) - def test_endianness_traj_unchanged(self, server, endianness, ref, port): + @pytest.fixture(params=[">", "<"]) + def setup_test_endianness_traj_unchanged(self, request, server, port): server.port = port - server.imdsessioninfo.endianness = endianness + server.imdsessioninfo.endianness = request.param server.start() + server.wait_for_event(IMDServerEventType.LISTENING) + return server, port + + def test_endianness_traj_unchanged( + self, setup_test_endianness_traj_unchanged, ref + ): + server, port = setup_test_endianness_traj_unchanged reader = imdreader.IMDREADER.IMDReader( f"localhost:{port}", @@ -94,6 +104,9 @@ def test_endianness_traj_unchanged(self, server, endianness, ref, port): timesteps = [] for ts in reader: + logger.debug( + f"test_imdreader: positions for frame {i}: {ts.positions}" + ) timesteps.append(ts.copy()) i += 1 @@ -106,11 +119,16 @@ def test_endianness_traj_unchanged(self, server, endianness, ref, port): assert timesteps[j].data[energy_key] == j + offset offset += 1 - def test_pause_traj_unchanged(self, server, ref, port): + @pytest.fixture + def setup_test_pause_traj_unchanged(self, server, port): server.port = port server.loop_behavior = ExpectPauseLoopV2Behavior() server.start() + server.wait_for_event(IMDServerEventType.LISTENING) + return server, port + def test_pause_traj_unchanged(self, setup_test_pause_traj_unchanged, ref): + server, port = setup_test_pause_traj_unchanged # Give the reader only 1 IMDFrame of memory # We expect the producer thread to have to # pause every frame (except the first) diff --git a/imdreader/tests/utils.py b/imdreader/tests/utils.py index cafe158..06cfcbf 100644 --- a/imdreader/tests/utils.py +++ b/imdreader/tests/utils.py @@ -10,11 +10,12 @@ import abc import imdreader import logging +from enum import Enum -logger = logging.getLogger(imdreader.IMDClient.__name__) +logger = logging.getLogger("imdreader.IMDREADER") -class Behavior(abc.ABC): +class IMDServerBehavior(abc.ABC): """Abstract base class for behaviors for the DummyIMDServer to perform. Ensure that behaviors do not contain potentially infinite loops- they should be testing for a specific sequence of events""" @@ -24,17 +25,19 @@ def perform(self, *args, **kwargs): pass -class DefaultConnectionBehavior(Behavior): +class DefaultConnectionBehavior(IMDServerBehavior): def perform(self, host, port, event_q): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind((host, port)) + logger.debug(f"DummyIMDServer: Listening on {host}:{port}") + event_q.append(IMDServerEvent(IMDServerEventType.LISTENING, 120)) s.listen(120) conn, addr = s.accept() return (conn, addr) -class DefaultHandshakeV2Behavior(Behavior): +class DefaultHandshakeV2Behavior(IMDServerBehavior): def perform(self, conn, imdsessioninfo, event_q): header = struct.pack("!i", IMDHeaderType.IMD_HANDSHAKE.value) if imdsessioninfo.endianness == "<": @@ -44,12 +47,12 @@ def perform(self, conn, imdsessioninfo, event_q): conn.sendall(header) -class DefaultHandshakeV3Behavior(Behavior): +class DefaultHandshakeV3Behavior(IMDServerBehavior): def perform(self, conn, imdsessioninfo, event_q): pass -class DefaultAwaitGoBehavior(Behavior): +class DefaultAwaitGoBehavior(IMDServerBehavior): def perform(self, conn, event_q): conn.settimeout(IMDAWAITGOTIME) head_buf = bytearray(IMDHEADERSIZE) @@ -60,7 +63,7 @@ def perform(self, conn, event_q): logger.debug("DummyIMDServer: Received IMD_GO") -class DefaultLoopV2Behavior(Behavior): +class DefaultLoopV2Behavior(IMDServerBehavior): """Default behavior doesn't allow pausing""" def perform(self, conn, traj, imdsessioninfo, event_q): @@ -68,10 +71,7 @@ def perform(self, conn, traj, imdsessioninfo, event_q): headerbuf = bytearray(IMDHEADERSIZE) paused = False - logger.debug("DummyIMDServer: Starting loop") - for i in range(len(traj)): - logger.debug(f"DummyIMDServer: generating frame {i}") energy_header = create_header_bytes(IMDHeaderType.IMD_ENERGIES, 1) @@ -92,6 +92,11 @@ def perform(self, conn, traj, imdsessioninfo, event_q): pos_header = create_header_bytes( IMDHeaderType.IMD_FCOORDS, traj.n_atoms ) + + logger.debug( + f"DummyIMDServer: positions for frame {i}: {traj[i].positions}" + ) + pos = np.ascontiguousarray( traj[i].positions, dtype=f"{imdsessioninfo.endianness}f" ).tobytes() @@ -108,8 +113,6 @@ def perform(self, conn, traj, imdsessioninfo, event_q): conn.settimeout(10) headerbuf = bytearray(IMDHEADERSIZE) - logger.debug("DummyIMDServer: Starting loop") - for i in range(len(traj)): if i != 0: read_into_buf(conn, headerbuf) @@ -147,13 +150,17 @@ def perform(self, conn, traj, imdsessioninfo, event_q): pos_header = create_header_bytes( IMDHeaderType.IMD_FCOORDS, traj.n_atoms ) + + logger.debug( + f"DummyIMDServer: positions for frame {i}: {traj[i].positions}" + ) + pos = np.ascontiguousarray( traj[i].positions, dtype=f"{imdsessioninfo.endianness}f" ).tobytes() conn.sendall(energy_header + energies) conn.sendall(pos_header + pos) - logger.debug(f"Sent frame {i}") class ExpectPauseUnpauseAfterLoopV2Behavior(DefaultLoopV2Behavior): @@ -208,13 +215,26 @@ def perform(self, conn, traj, imdsessioninfo, event_q): ) -class DefaultDisconnectBehavior(Behavior): +class DefaultDisconnectBehavior(IMDServerBehavior): def perform(self, conn, event_q): # Gromacs uses the c equivalent of the SHUT_WR flag conn.shutdown(socket.SHUT_WR) conn.close() +class IMDServerEventType(Enum): + LISTENING = 0 + + +class IMDServerEvent: + def __init__(self, event_type, data): + self.event_type = event_type + self.data = data + + def __str__(self): + return f"IMDServerEvent: {self.event_type}, {self.data}" + + def create_default_imdsinfo_v2(): return IMDSessionInfo( version=2, @@ -234,11 +254,11 @@ def create_default_imdsinfo_v2(): class DummyIMDServer(threading.Thread): """Performs the following steps in order: - 1. ConnectionBehavior.perform_connection() - 2. HandshakeBehavior.perform_handshake() - 3. AwaitGoBehavior.perform_await_go() - 4. LoopBehavior.perform_loop() - 5. DisconnectBehavior.perform_disconnect() + 1. ConnectionBehavior.perform() + 2. HandshakeBehavior.perform() + 3. AwaitGoBehavior.perform() + 4. LoopBehavior.perform() + 5. DisconnectBehavior.perform() Start the server by calling DummyIMDServer.start(). """ @@ -280,6 +300,7 @@ def __init__( self._event_q = [] def run(self): + logger.debug("DummyIMDServer: Starting") conn = self.connection_behavior.perform( self.host, self.port, self._event_q )[0] @@ -293,6 +314,17 @@ def run(self): self.disconnect_behavior.perform(conn, self._event_q) return + def wait_for_event(self, event_type, timeout=10, poll_interval=0.1): + end_time = time.time() + timeout + while time.time() < end_time: + time.sleep(poll_interval) + for event in self._event_q: + if event.event_type == event_type: + return event + raise TimeoutError( + f"DummyIMDServer: Timeout after {timeout} seconds waiting for event {event_type}" + ) + @property def port(self): return self._port