Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
test harness improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
ljwoods2 committed Aug 7, 2024
1 parent ec3de2c commit d3a292c
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 50 deletions.
41 changes: 15 additions & 26 deletions imdreader/IMDClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,19 @@ def stop(self):

def _connect_to_server(self, host, port, socket_bufsize):
"""
Establish connection with the server, failing out if this
does not occur within 5 seconds.
Establish connection with the server, failing out instantly if server is not running
"""
conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if socket_bufsize is not None:
conn.setsockopt(
socket.SOL_SOCKET, socket.SO_RCVBUF, self._socket_bufsize
)
conn.settimeout(60)
try:
logger.debug(f"IMDClient: Connecting to {host}:{port}")
conn.connect((host, port))
except ConnectionRefusedError:
raise ConnectionRefusedError(
f"IMDReader: Connection to {host}:{port} refused"
f"IMDClient: Connection to {host}:{port} refused"
)
return conn

Expand Down Expand Up @@ -226,10 +225,8 @@ def run(self):
self._unpause()
self._paused = False

logger.debug(f"IMDProducer: Attempting to get timestep")
imdf = self._buf.pop_empty_imdframe()

logger.debug(f"IMDProducer: Attempting to read nrg and pos")
# NOTE: This can be replaced with a simple parser if
# the server doesn't send the final frame with all data
# as in xtc
Expand All @@ -238,35 +235,28 @@ def run(self):
)
read_into_buf(self._conn, self._energies)

logger.debug("read energy data")
imdf.energies.update(
parse_energy_bytes(
self._energies, self._imdsinfo.endianness
)
)
logger.debug(f"IMDProducer: added energies to {imdf.energies}")

logger.debug(f"IMDProducer: added energies to imdf")

self._expect_header(
IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms
)

logger.debug(f"IMDProducer: Expected header")
read_into_buf(self._conn, self._positions)

logger.debug(f"IMDProducer: attempting to load ts")

imdf.positions = np.frombuffer(
self._positions, dtype=f"{self._imdsinfo.endianness}f"
).reshape((self._n_atoms, 3))

logger.debug(f"IMDProducer: ts loaded- inserting it")
logger.debug(
f"IMDProducer: positions for frame {self._frame}: {imdf.positions}"
)

self._buf.push_full_imdframe(imdf)

logger.debug(f"IMDProducer: ts inserted")

self._frame += 1
except EOFError:
# Don't raise error if simulation ended in a way
Expand All @@ -276,7 +266,7 @@ def run(self):
pass
finally:

logger.debug("IMDProducer: simulation ended")
logger.debug("IMDProducer: Simulation ended, cleaning up")

# Tell reader not to expect more frames to be added
self._buf.notify_producer_finished()
Expand All @@ -290,11 +280,8 @@ def _expect_header(self, expected_type, expected_value=None):

read_into_buf(self._conn, self._header)

logger.debug(f"IMDProducer: header: {self._header}")
header = IMDHeader(self._header)

logger.debug(f"IMDProducer: header parsed")

if header.type != expected_type:
raise RuntimeError

Expand Down Expand Up @@ -353,7 +340,7 @@ def __init__(
imdf_memsize = imdframe_memsize(n_atoms, imdsinfo)
self._total_imdf = buffer_size // imdf_memsize
logger.debug(
f"IMDFRAMEBuffer: Total timesteps allocated: {self._total_imdf}"
f"IMDFrameBuffer: Total timesteps allocated: {self._total_imdf}"
)
if self._total_imdf == 0:
raise ValueError(
Expand Down Expand Up @@ -436,9 +423,11 @@ def pop_full_imdframe(self):

imdf = self._full_q.get()

self._prev_empty_imdf = imdf
logger.debug(
f"IMDFrameBuffer: positions for frame {self._frame}: {imdf.positions}"
)

logger.debug(f"IMDReader: Got frame {self._frame}")
self._prev_empty_imdf = imdf

return imdf

Expand Down Expand Up @@ -519,17 +508,17 @@ def read_into_buf(sock, buf) -> bool:
# Server called close()
# Server is definitely done sending frames
logger.debug(
"IMDProducer: recv excepting due to server calling close()"
"read_into_buf excepting due to server calling close()"
)
raise EOFError
except TimeoutError:
# Server is *likely* done sending frames
logger.debug("IMDProducer: recv excepting due to timeout")
logger.debug("read_into_buf excepting due to timeout")
raise EOFError
except BlockingIOError:
# Occurs when timeout is 0 in place of a TimeoutError
# Server is *likely* done sending frames
logger.debug("IMDProducer: recv excepting due to blocking")
logger.debug("read_into_buf excepting due to blocking")
raise EOFError
total_received += received

Expand Down
2 changes: 1 addition & 1 deletion imdreader/IMDREADER.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(

super(IMDReader, self).__init__(filename, **kwargs)

logger.debug("Reader initializing")
logger.debug("IMDReader initializing")

if n_atoms is None:
raise ValueError("IMDReader: n_atoms must be specified")
Expand Down
26 changes: 22 additions & 4 deletions imdreader/tests/test_imdreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import imdreader
from imdreader.IMDClient import imdframe_memsize
from .utils import (
IMDServerEventType,
DummyIMDServer,
get_free_port,
ExpectPauseLoopV2Behavior,
Expand Down Expand Up @@ -39,6 +40,8 @@ def log_config():
logger.removeHandler(file_handler)


logger = logging.getLogger("imdreader.IMDREADER")

IMDENERGYKEYS = [
"step",
"temperature",
Expand Down Expand Up @@ -76,11 +79,18 @@ def server(self, traj):
server = DummyIMDServer(traj, 2)
return server

@pytest.mark.parametrize("endianness", ["<", ">"])
def test_endianness_traj_unchanged(self, server, endianness, ref, port):
@pytest.fixture(params=[">", "<"])
def setup_test_endianness_traj_unchanged(self, request, server, port):
server.port = port
server.imdsessioninfo.endianness = endianness
server.imdsessioninfo.endianness = request.param
server.start()
server.wait_for_event(IMDServerEventType.LISTENING)
return server, port

def test_endianness_traj_unchanged(
self, setup_test_endianness_traj_unchanged, ref
):
server, port = setup_test_endianness_traj_unchanged

reader = imdreader.IMDREADER.IMDReader(
f"localhost:{port}",
Expand All @@ -94,6 +104,9 @@ def test_endianness_traj_unchanged(self, server, endianness, ref, port):
timesteps = []

for ts in reader:
logger.debug(
f"test_imdreader: positions for frame {i}: {ts.positions}"
)
timesteps.append(ts.copy())
i += 1

Expand All @@ -106,11 +119,16 @@ def test_endianness_traj_unchanged(self, server, endianness, ref, port):
assert timesteps[j].data[energy_key] == j + offset
offset += 1

def test_pause_traj_unchanged(self, server, ref, port):
@pytest.fixture
def setup_test_pause_traj_unchanged(self, server, port):
server.port = port
server.loop_behavior = ExpectPauseLoopV2Behavior()
server.start()
server.wait_for_event(IMDServerEventType.LISTENING)
return server, port

def test_pause_traj_unchanged(self, setup_test_pause_traj_unchanged, ref):
server, port = setup_test_pause_traj_unchanged
# Give the reader only 1 IMDFrame of memory
# We expect the producer thread to have to
# pause every frame (except the first)
Expand Down
70 changes: 51 additions & 19 deletions imdreader/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
import abc
import imdreader
import logging
from enum import Enum

logger = logging.getLogger(imdreader.IMDClient.__name__)
logger = logging.getLogger("imdreader.IMDREADER")


class Behavior(abc.ABC):
class IMDServerBehavior(abc.ABC):
"""Abstract base class for behaviors for the DummyIMDServer to perform.
Ensure that behaviors do not contain potentially infinite loops- they should be
testing for a specific sequence of events"""
Expand All @@ -24,17 +25,19 @@ def perform(self, *args, **kwargs):
pass


class DefaultConnectionBehavior(Behavior):
class DefaultConnectionBehavior(IMDServerBehavior):

def perform(self, host, port, event_q):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((host, port))
logger.debug(f"DummyIMDServer: Listening on {host}:{port}")
event_q.append(IMDServerEvent(IMDServerEventType.LISTENING, 120))
s.listen(120)
conn, addr = s.accept()
return (conn, addr)


class DefaultHandshakeV2Behavior(Behavior):
class DefaultHandshakeV2Behavior(IMDServerBehavior):
def perform(self, conn, imdsessioninfo, event_q):
header = struct.pack("!i", IMDHeaderType.IMD_HANDSHAKE.value)
if imdsessioninfo.endianness == "<":
Expand All @@ -44,12 +47,12 @@ def perform(self, conn, imdsessioninfo, event_q):
conn.sendall(header)


class DefaultHandshakeV3Behavior(Behavior):
class DefaultHandshakeV3Behavior(IMDServerBehavior):
def perform(self, conn, imdsessioninfo, event_q):
pass


class DefaultAwaitGoBehavior(Behavior):
class DefaultAwaitGoBehavior(IMDServerBehavior):
def perform(self, conn, event_q):
conn.settimeout(IMDAWAITGOTIME)
head_buf = bytearray(IMDHEADERSIZE)
Expand All @@ -60,18 +63,15 @@ def perform(self, conn, event_q):
logger.debug("DummyIMDServer: Received IMD_GO")


class DefaultLoopV2Behavior(Behavior):
class DefaultLoopV2Behavior(IMDServerBehavior):
"""Default behavior doesn't allow pausing"""

def perform(self, conn, traj, imdsessioninfo, event_q):
conn.settimeout(1)
headerbuf = bytearray(IMDHEADERSIZE)
paused = False

logger.debug("DummyIMDServer: Starting loop")

for i in range(len(traj)):
logger.debug(f"DummyIMDServer: generating frame {i}")

energy_header = create_header_bytes(IMDHeaderType.IMD_ENERGIES, 1)

Expand All @@ -92,6 +92,11 @@ def perform(self, conn, traj, imdsessioninfo, event_q):
pos_header = create_header_bytes(
IMDHeaderType.IMD_FCOORDS, traj.n_atoms
)

logger.debug(
f"DummyIMDServer: positions for frame {i}: {traj[i].positions}"
)

pos = np.ascontiguousarray(
traj[i].positions, dtype=f"{imdsessioninfo.endianness}f"
).tobytes()
Expand All @@ -108,8 +113,6 @@ def perform(self, conn, traj, imdsessioninfo, event_q):
conn.settimeout(10)
headerbuf = bytearray(IMDHEADERSIZE)

logger.debug("DummyIMDServer: Starting loop")

for i in range(len(traj)):
if i != 0:
read_into_buf(conn, headerbuf)
Expand Down Expand Up @@ -147,13 +150,17 @@ def perform(self, conn, traj, imdsessioninfo, event_q):
pos_header = create_header_bytes(
IMDHeaderType.IMD_FCOORDS, traj.n_atoms
)

logger.debug(
f"DummyIMDServer: positions for frame {i}: {traj[i].positions}"
)

pos = np.ascontiguousarray(
traj[i].positions, dtype=f"{imdsessioninfo.endianness}f"
).tobytes()

conn.sendall(energy_header + energies)
conn.sendall(pos_header + pos)
logger.debug(f"Sent frame {i}")


class ExpectPauseUnpauseAfterLoopV2Behavior(DefaultLoopV2Behavior):
Expand Down Expand Up @@ -208,13 +215,26 @@ def perform(self, conn, traj, imdsessioninfo, event_q):
)


class DefaultDisconnectBehavior(Behavior):
class DefaultDisconnectBehavior(IMDServerBehavior):
def perform(self, conn, event_q):
# Gromacs uses the c equivalent of the SHUT_WR flag
conn.shutdown(socket.SHUT_WR)
conn.close()


class IMDServerEventType(Enum):
LISTENING = 0


class IMDServerEvent:
def __init__(self, event_type, data):
self.event_type = event_type
self.data = data

def __str__(self):
return f"IMDServerEvent: {self.event_type}, {self.data}"


def create_default_imdsinfo_v2():
return IMDSessionInfo(
version=2,
Expand All @@ -234,11 +254,11 @@ def create_default_imdsinfo_v2():
class DummyIMDServer(threading.Thread):
"""Performs the following steps in order:
1. ConnectionBehavior.perform_connection()
2. HandshakeBehavior.perform_handshake()
3. AwaitGoBehavior.perform_await_go()
4. LoopBehavior.perform_loop()
5. DisconnectBehavior.perform_disconnect()
1. ConnectionBehavior.perform()
2. HandshakeBehavior.perform()
3. AwaitGoBehavior.perform()
4. LoopBehavior.perform()
5. DisconnectBehavior.perform()
Start the server by calling DummyIMDServer.start().
"""
Expand Down Expand Up @@ -280,6 +300,7 @@ def __init__(
self._event_q = []

def run(self):
logger.debug("DummyIMDServer: Starting")
conn = self.connection_behavior.perform(
self.host, self.port, self._event_q
)[0]
Expand All @@ -293,6 +314,17 @@ def run(self):
self.disconnect_behavior.perform(conn, self._event_q)
return

def wait_for_event(self, event_type, timeout=10, poll_interval=0.1):
end_time = time.time() + timeout
while time.time() < end_time:
time.sleep(poll_interval)
for event in self._event_q:
if event.event_type == event_type:
return event
raise TimeoutError(
f"DummyIMDServer: Timeout after {timeout} seconds waiting for event {event_type}"
)

@property
def port(self):
return self._port
Expand Down

0 comments on commit d3a292c

Please sign in to comment.