diff --git a/.github/workflows/gh-ci.yaml b/.github/workflows/gh-ci.yaml index d34577a..73d892d 100644 --- a/.github/workflows/gh-ci.yaml +++ b/.github/workflows/gh-ci.yaml @@ -39,7 +39,7 @@ jobs: release: "latest" main-tests: - if: "github.repository == 'becksteinlab/imdreader'" + if: github.repository == 'becksteinlab/imdreader' needs: environment-config runs-on: ${{ matrix.os }} strategy: @@ -113,7 +113,7 @@ jobs: pylint_check: - if: "github.repository == 'becksteinlab/imdreader'" + if: github.repository == 'becksteinlab/imdreader' needs: environment-config runs-on: ubuntu-latest @@ -139,7 +139,7 @@ jobs: pypi_check: - if: "github.repository == 'becksteinlab/imdreader'" + if: github.repository == 'becksteinlab/imdreader' needs: environment-config runs-on: ubuntu-latest diff --git a/imdreader/IMDClient.py b/imdreader/IMDClient.py new file mode 100644 index 0000000..382f660 --- /dev/null +++ b/imdreader/IMDClient.py @@ -0,0 +1,541 @@ +import socket +import threading +from .IMDProtocol import * +import logging +import queue +import select +import time +import numpy as np +from typing import Union, Dict + +logger = logging.getLogger("imdreader.IMDREADER") + + +class IMDClient: + def __init__( + self, + host, + port, + n_atoms, + socket_bufsize=None, + buffer_size=(10 * 1024**2), + pause_empty_proportion=0.25, + unpause_empty_proportion=0.5, + **kwargs, + ): + + conn = self._connect_to_server(host, port, socket_bufsize) + self._imdsinfo = self._await_IMD_handshake(conn) + self._buf = IMDFrameBuffer( + buffer_size, + self._imdsinfo, + n_atoms, + pause_empty_proportion, + unpause_empty_proportion, + ) + producer = IMDProducer( + conn, + self._buf, + self._imdsinfo, + n_atoms, + ) + self._go(conn) + producer.start() + + def get_imdframe(self): + return self._buf.pop_full_imdframe() + + def get_imdsessioninfo(self): + return self._imdsinfo + + def stop(self): + self._buf.notify_consumer_finished() + + def _connect_to_server(self, host, port, socket_bufsize): + """ + Establish connection with the server, failing out if this + does not occur within 5 seconds. + """ + conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if socket_bufsize is not None: + conn.setsockopt( + socket.SOL_SOCKET, socket.SO_RCVBUF, self._socket_bufsize + ) + conn.settimeout(5) + try: + conn.connect((host, port)) + except ConnectionRefusedError: + raise ConnectionRefusedError( + f"IMDReader: Connection to {host}:{port} refused" + ) + return conn + + def _await_IMD_handshake(self, conn) -> IMDSessionInfo: + """ + Wait for the server to send a handshake packet, then determine + IMD session information. + """ + end = ">" + ver = None + + h_buf = bytearray(IMDHEADERSIZE) + try: + read_into_buf(conn, h_buf) + except IndexError: + raise ConnectionError("IMDReader: No handshake received.") + + header = IMDHeader(h_buf) + + if header.type != IMDHeaderType.IMD_HANDSHAKE: + raise ValueError( + f"Expected header type `IMD_HANDSHAKE`, got {header.type}" + ) + + if header.length not in IMDVERSIONS: + # Try swapping endianness + swapped = struct.unpack("i", header.length))[0] + if swapped not in IMDVERSIONS: + err_version = min(swapped, header.length) + # NOTE: Add test for this + raise ValueError( + f"Incompatible IMD version. Expected version in {IMDVERSIONS}, got {err_version}" + ) + else: + end = "<" + ver = swapped + else: + ver = header.length + + sinfo = None + if ver == 2: + # IMD v2 does not send a configuration packet + sinfo = IMDSessionInfo( + version=ver, + endianness=end, + imdterm=None, + imdwait=None, + imdpull=None, + wrapped_coords=False, + energies=1, + dimensions=0, + positions=1, + velocities=0, + forces=0, + ) + elif ver == 3: + sinfo = parse_imdv3_session_info(conn, end) + + return sinfo + + def _go(self, conn): + """ + Send a go packet to the client to start the simulation + and begin receiving data. + """ + go = create_header_bytes(IMDHeaderType.IMD_GO, 0) + conn.sendall(go) + logger.debug("IMDProducer: Sent go packet to server") + + +class IMDProducer(threading.Thread): + + def __init__( + self, + conn, + buffer, + sinfo, + n_atoms, + ): + super(IMDProducer, self).__init__() + self._conn = conn + self._imdsinfo = sinfo + self._paused = False + + # Timeout for first frame should be longer + # than rest of frames + self._timeout = 5 + self._conn.settimeout(self._timeout) + + self._buf = buffer + + self._frame = 0 + self._parse_frame_time = 0 + + # The body of an x/v/f packet should contain + # (4 bytes per float * 3 atoms * n_atoms) bytes + self._n_atoms = n_atoms + xvf_bytes = 12 * n_atoms + + self._header = bytearray(IMDHEADERSIZE) + if self._imdsinfo.energies > 0: + self._energies = bytearray(40) + if self._imdsinfo.dimensions > 0: + self._dimensions = bytearray(36) + if self._imdsinfo.positions > 0: + self._positions = bytearray(xvf_bytes) + if self._imdsinfo.velocities > 0: + self._velocities = bytearray(xvf_bytes) + if self._imdsinfo.forces > 0: + self._forces = bytearray(xvf_bytes) + + def _pause(self): + """ + Block the simulation until the buffer has more space. + """ + self._conn.settimeout(0) + logger.debug( + "IMDProducer: Pausing simulation because buffer is almost full" + ) + pause = create_header_bytes(IMDHeaderType.IMD_PAUSE, 0) + try: + self._conn.sendall(pause) + except ConnectionResetError as e: + # Simulation has already ended by the time we paused + raise IndexError + # Edge case: pause occured in the time between server sends its last frame + # and closing socket + # Simulation is not actually paused but is over, but we still want to read remaining data + # from the socket + + def _unpause(self): + self._conn.settimeout(self._timeout) + logger.debug("IMDProducer: Unpausing simulation, buffer has space") + unpause = create_header_bytes(IMDHeaderType.IMD_PAUSE, 0) + try: + self._conn.sendall(unpause) + except ConnectionResetError as e: + # Edge case: pause occured in the time between server sends its last frame + # and closing socket + # Simulation was never actually paused in this case and is now over + raise IndexError + # Edge case: pause & unpause occured in the time between server sends its last frame and closing socket + # in this case, the simulation isn't actually unpaused but over + + def run(self): + try: + while True: + if not self._paused: + if self._buf.is_full(): + self._pause() + self._paused = True + + if self._paused: + # wait for socket to empty before unpausing + if not sock_contains_data(self._conn, 0): + self._buf.wait_for_space() + self._unpause() + self._paused = False + + logger.debug(f"IMDProducer: Attempting to get timestep") + imdf = self._buf.pop_empty_imdframe() + + logger.debug(f"IMDProducer: Attempting to read nrg and pos") + # NOTE: This can be replaced with a simple parser if + # the server doesn't send the final frame with all data + # as in xtc + self._expect_header( + IMDHeaderType.IMD_ENERGIES, expected_value=1 + ) + read_into_buf(self._conn, self._energies) + + logger.debug("read energy data") + imdf.energies.update( + parse_energy_bytes( + self._energies, self._imdsinfo.endianness + ) + ) + logger.debug(f"IMDProducer: added energies to {imdf.energies}") + + logger.debug(f"IMDProducer: added energies to imdf") + + self._expect_header( + IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms + ) + + logger.debug(f"IMDProducer: Expected header") + read_into_buf(self._conn, self._positions) + + logger.debug(f"IMDProducer: attempting to load ts") + + imdf.positions = np.frombuffer( + self._positions, dtype=f"{self._imdsinfo.endianness}f" + ).reshape((self._n_atoms, 3)) + + logger.debug(f"IMDProducer: ts loaded- inserting it") + + self._buf.push_full_imdframe(imdf) + + logger.debug(f"IMDProducer: ts inserted") + + self._frame += 1 + except EOFError: + # Don't raise error if simulation ended in a way + # that we expected + # i.e. consumer stopped or read_into_buf didn't find + # full token of data + pass + finally: + + logger.debug("IMDProducer: simulation ended") + + # Tell reader not to expect more frames to be added + self._buf.notify_producer_finished() + # MUST disconnect before stopping run loop + # if simulation already ended, this method will do nothing + self._disconnect() + + return + + def _expect_header(self, expected_type, expected_value=None): + + read_into_buf(self._conn, self._header) + + logger.debug(f"IMDProducer: header: {self._header}") + header = IMDHeader(self._header) + + logger.debug(f"IMDProducer: header parsed") + + if header.type != expected_type: + raise RuntimeError + + if expected_value is not None and header.length != expected_value: + raise RuntimeError + + def _disconnect(self): + try: + disconnect = create_header_bytes(IMDHeaderType.IMD_DISCONNECT, 0) + self._conn.sendall(disconnect) + logger.debug("IMDProducer: Disconnected from server") + except (ConnectionResetError, BrokenPipeError): + logger.debug( + f"IMDProducer: Attempted to disconnect but server already terminated the connection" + ) + finally: + self._conn.close() + + +class IMDFrameBuffer: + """ + Acts as interface between producer and consumer threads + """ + + def __init__( + self, + buffer_size, + imdsinfo, + n_atoms, + pause_empty_proportion, + unpause_empty_proportion, + ): + + # Syncing reader and producer + self._producer_finished = False + self._consumer_finished = False + + self._prev_empty_imdf = None + + self._empty_q = queue.Queue() + self._full_q = queue.Queue() + self._empty_imdf_avail = threading.Condition(threading.Lock()) + self._full_imdf_avail = threading.Condition(threading.Lock()) + + if pause_empty_proportion < 0 or pause_empty_proportion > 1: + raise ValueError("pause_empty_proportion must be between 0 and 1") + self._pause_empty_proportion = pause_empty_proportion + if unpause_empty_proportion < 0 or unpause_empty_proportion > 1: + raise ValueError("unpause_empty_proportion must be between 0 and 1") + self._unpause_empty_proportion = unpause_empty_proportion + + if buffer_size <= 0: + raise ValueError("Buffer size must be positive") + # Allocate IMDFrames with all of xvf present in imdsinfo + # even if they aren't sent every frame. Can be optimized if needed + imdf_memsize = imdframe_memsize(n_atoms, imdsinfo) + self._total_imdf = buffer_size // imdf_memsize + logger.debug( + f"IMDFRAMEBuffer: Total timesteps allocated: {self._total_imdf}" + ) + if self._total_imdf == 0: + raise ValueError( + "Buffer size is too small to hold a single IMDFrame" + ) + for i in range(self._total_imdf): + self._empty_q.put(IMDFrame(n_atoms, imdsinfo)) + + # Timing for analysis + self._t1 = None + self._t2 = None + + self._frame = 0 + + def is_full(self): + if ( + self._empty_q.qsize() / self._total_imdf + <= self._pause_empty_proportion + ): + return True + return False + + def wait_for_space(self): + with self._empty_imdf_avail: + while ( + self._empty_q.qsize() / self._total_imdf + < self._unpause_empty_proportion + ) and not self._consumer_finished: + self._empty_imdf_avail.wait() + + if self._consumer_finished: + raise EOFError + + def pop_empty_imdframe(self): + with self._empty_imdf_avail: + while self._empty_q.qsize() == 0 and not self._consumer_finished: + self._empty_imdf_avail.wait() + + if self._consumer_finished: + raise EOFError + + imdf = self._empty_q.get() + + return imdf + + def push_full_imdframe(self, imdf): + self._full_q.put(imdf) + with self._full_imdf_avail: + self._full_imdf_avail.notify() + + def pop_full_imdframe(self): + """Put empty_ts in the empty_q and get the next full timestep""" + # Start timer- one frame of analysis is starting (including removal + # from buffer) + + self._t1 = self._t2 + self._t2 = time.time() + if self._t1 is not None: + logger.debug( + f"IMDReader: Frame #{self._frame} analyzed in {self._t2 - self._t1} seconds" + ) + + self._frame += 1 + + # Return the processed IMDFrame + if self._prev_empty_imdf is not None: + self._empty_q.put(self._prev_empty_imdf) + with self._empty_imdf_avail: + self._empty_imdf_avail.notify() + + # Get the next IMDFrame + logger.debug("IMDReader: Attempting to get next frame") + with self._full_imdf_avail: + while self._full_q.qsize() == 0 and not self._producer_finished: + self._full_imdf_avail.wait() + + if self._producer_finished and self._full_q.qsize() == 0: + logger.debug("IMDReader: Producer finished") + raise EOFError + + imdf = self._full_q.get() + + self._prev_empty_imdf = imdf + + logger.debug(f"IMDReader: Got frame {self._frame}") + + return imdf + + def notify_producer_finished(self): + self._producer_finished = True + with self._full_imdf_avail: + self._full_imdf_avail.notify() + + def notify_consumer_finished(self): + self._consumer_finished = True + with self._empty_imdf_avail: + # noop if producer isn't waiting + self._empty_imdf_avail.notify() + + +class IMDFrame: + def __init__(self, n_atoms, imdsinfo): + if imdsinfo.energies > 0: + self.energies = { + "step": 0, + "temperature": 0.0, + "total_energy": 0.0, + "potential_energy": 0.0, + "van_der_walls_energy": 0.0, + "coulomb_energy": 0.0, + "bonds_energy": 0.0, + "angles_energy": 0.0, + "dihedrals_energy": 0.0, + "improper_dihedrals_energy": 0.0, + } + else: + self.energies = None + if imdsinfo.dimensions > 0: + self.dimensions = np.empty((3, 3), dtype=np.float32) + else: + self.dimensions = None + if imdsinfo.positions > 0: + self.positions = np.empty((n_atoms, 3), dtype=np.float32) + else: + self.positions = None + if imdsinfo.velocities > 0: + self.velocities = np.empty((n_atoms, 3), dtype=np.float32) + else: + self.velocities = None + if imdsinfo.forces > 0: + self.forces = np.empty((n_atoms, 3), dtype=np.float32) + else: + self.forces = None + + +def imdframe_memsize(n_atoms, imdsinfo) -> int: + """ + Calculate the memory size of an IMDFrame in bytes + """ + memsize = 0 + if imdsinfo.energies > 0: + memsize += 4 * 10 + if imdsinfo.dimensions > 0: + memsize += 4 * 9 + if imdsinfo.positions > 0: + memsize += 4 * 3 * n_atoms + if imdsinfo.velocities > 0: + memsize += 4 * 3 * n_atoms + if imdsinfo.forces > 0: + memsize += 4 * 3 * n_atoms + + return memsize + + +def read_into_buf(sock, buf) -> bool: + """Receives len(buf) bytes into buf from the socket sock""" + view = memoryview(buf) + total_received = 0 + while total_received < len(view): + try: + received = sock.recv_into(view[total_received:]) + if received == 0: + # Server called close() + # Server is definitely done sending frames + logger.debug( + "IMDProducer: recv excepting due to server calling close()" + ) + raise EOFError + except TimeoutError: + # Server is *likely* done sending frames + logger.debug("IMDProducer: recv excepting due to timeout") + raise EOFError + except BlockingIOError: + # Occurs when timeout is 0 in place of a TimeoutError + # Server is *likely* done sending frames + logger.debug("IMDProducer: recv excepting due to blocking") + raise EOFError + total_received += received + + +def sock_contains_data(sock, timeout) -> bool: + ready_to_read, ready_to_write, in_error = select.select( + [sock], [], [], timeout + ) + return sock in ready_to_read diff --git a/imdreader/IMDProtocol.py b/imdreader/IMDProtocol.py index 331cd0f..c08e33a 100644 --- a/imdreader/IMDProtocol.py +++ b/imdreader/IMDProtocol.py @@ -1,12 +1,8 @@ -import select -import socket import struct import logging from enum import Enum, auto from typing import Union from dataclasses import dataclass -import abc -import threading """ IMD Packets have an 8 byte header and a variable length payload @@ -24,6 +20,7 @@ IMDENERGYPACKETLENGTH = 40 IMDBOXPACKETLENGTH = 36 IMDVERSIONS = {2, 3} +IMDAWAITGOTIME = 1 class IMDHeaderType(Enum): @@ -38,10 +35,55 @@ class IMDHeaderType(Enum): IMD_TRATE = 8 IMD_IOERROR = 9 # New in IMD v3 - IMD_BOX = 10 - IMD_VELS = 11 - IMD_FORCES = 12 - IMD_EOS = 13 + # IMD_BOX = 10 + # IMD_VELS = 11 + # IMD_FORCES = 12 + # IMD_EOS = 13 + + +def parse_energy_bytes(data, endianness): + keys = [ + "step", + "temperature", + "total_energy", + "potential_energy", + "van_der_walls_energy", + "coulomb_energy", + "bonds_energy", + "angles_energy", + "dihedrals_energy", + "improper_dihedrals_energy", + ] + values = struct.unpack(f"{endianness}ifffffffff", data) + return dict(zip(keys, values)) + + +def create_energy_bytes( + step, + temperature, + total_energy, + potential_energy, + van_der_walls_energy, + coulomb_energy, + bonds_energy, + angles_energy, + dihedrals_energy, + improper_dihedrals_energy, + endianness, +): + return struct.pack( + f"{endianness}ifffffffff", + step, + temperature, + total_energy, + potential_energy, + van_der_walls_energy, + coulomb_energy, + bonds_energy, + angles_energy, + dihedrals_energy, + improper_dihedrals_energy, + ) class IMDHeader: diff --git a/imdreader/IMDREADER.py b/imdreader/IMDREADER.py index a116bca..ff3983d 100644 --- a/imdreader/IMDREADER.py +++ b/imdreader/IMDREADER.py @@ -53,6 +53,7 @@ from MDAnalysis.coordinates import core from MDAnalysis.lib.util import store_init_arguments from .IMDProtocol import * +from .IMDClient import * from .util import * import socket import threading @@ -60,6 +61,8 @@ import signal import logging import time +import select +import warnings logger = logging.getLogger(__name__) @@ -83,8 +86,6 @@ def __init__( filename, convert_units=True, n_atoms=None, - buffer_size=2**26, - socket_bufsize=None, **kwargs, ): """ @@ -93,138 +94,111 @@ def __init__( filename : a string of the form "host:port" where host is the hostname or IP address of the listening GROMACS server and port is the port number. - + convert_units : bool (optional) + convert units to MDAnalysis units [``True``] + n_atoms : int (optional) + number of atoms in the system. defaults to number of atoms + in the topology. don't set this unless you know what you're doing. """ self._producer = None super(IMDReader, self).__init__(filename, **kwargs) + logger.debug("Reader initializing") + + if n_atoms is None: + raise ValueError("IMDReader: n_atoms must be specified") self.n_atoms = n_atoms - logger.debug(f"IMDReader: n_atoms: {self.n_atoms}") + host, port = parse_host_port(filename) + + # This starts the simulation + self._imdclient = IMDClient(host, port, n_atoms, **kwargs) + + self._imdsinfo = self._imdclient.get_imdsessioninfo() + + self.convert_units = convert_units + # NOTE: changme after deciding how imdreader will handle units self.units = { "time": "ps", "length": "nm", "force": "kJ/(mol*nm)", + "velocity": "nm/ps", } - - self._host, self._port = parse_host_port(filename) - self._buffer_size = buffer_size - self._socket_bufsize = socket_bufsize + self.ts = self._Timestep( + self.n_atoms, + positions=(self._imdsinfo.positions > 0), + velocities=(self._imdsinfo.velocities > 0), + forces=(self._imdsinfo.forces > 0), + **self._ts_kwargs, + ) self._frame = -1 + self._init_scope = True + self._reopen_called = False + self._read_next_timestep() def _read_next_timestep(self): - if self._frame == -1: - # Reader is responsible for performing handshake - # and parsing the configuration before - # passing the connection off to the appropriate producer - # and allocating an appropriate buffer - conn = self._connect_to_server() - imdsinfo = self._await_IMD_handshake(conn) - - if imdsinfo.version == 2: - self._buffer = TimestepBuffer( - self._buffer_size, imdsinfo, self._Timestep, self.n_atoms - ) - self._producer = IMDProducer( - conn, - self._buffer, - imdsinfo, - self.n_atoms, - ) - # Producer responsible for sending go packet - self._producer.start() + # No rewinding- to both load the first frame on __init__ + # and access it during iteration, we need to store first ts in mem + if not self._init_scope and self._frame == -1: + self._frame += 1 + return self.ts return self._read_frame(self._frame + 1) def _read_frame(self, frame): - self._ts = self._buffer.consume_next_timestep() + try: + imdf = self._imdclient.get_imdframe() + except EOFError: + # Not strictly necessary, but for clarity + raise StopIteration - # Must set frame after read occurs successfully - # Since buffer raises IO error - # after producer is finished and there are no more frames - self._frame = frame + self._load_imdframe_into_ts(imdf) - return self._ts + if self.convert_units: + self._convert_units() - def _connect_to_server(self): - """ - Establish connection with the server, failing out if this - does not occur within 5 seconds. - """ - conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - if self._socket_bufsize is not None: - conn.setsockopt( - socket.SOL_SOCKET, socket.SO_RCVBUF, self._socket_bufsize - ) - conn.settimeout(5) - try: - conn.connect((self._host, self._port)) - except ConnectionRefusedError: - logger.error( - f"IMDReader: Connection to {self._host}:{self._port} refused" - ) - raise ConnectionRefusedError( - f"IMDReader: Connection to {self._host}:{self._port} refused" - ) - return conn - - def _await_IMD_handshake(self, conn) -> IMDSessionInfo: - """ - Wait for the server to send a handshake packet, then parse - endianness and version information and IMD session configuration. + self._frame = frame + + if self._init_scope: + self._init_scope = False + + logger.debug(f"IMDReader: Loaded frame {self._frame}") + return self.ts + + def _load_imdframe_into_ts(self, imdf): + self.ts.frame = self._frame + # NOTE: need time. + if imdf.energies is not None: + self.ts.data.update(imdf.energies) + if imdf.dimensions is not None: + self.ts.dimensions = core.triclinic_box(*imdf.dimensions) + if imdf.positions is not None: + self.ts.positions = imdf.positions + if imdf.velocities is not None: + self.ts.velocities = imdf.velocities + if imdf.forces is not None: + self.ts.forces = imdf.forces + + def _convert_units(self): + """converts time, position, velocity, and force values if they + are not given in MDAnalysis standard units """ - end = ">" - ver = None - - h_buf = bytearray(IMDHEADERSIZE) - handshake_recieved = read_into_buf(conn, h_buf) - if not handshake_recieved: - raise ConnectionError("IMDReader: No handshake received.") - - header = IMDHeader(h_buf) - - if header.type != IMDHeaderType.IMD_HANDSHAKE: - raise ValueError( - f"Expected header type `IMD_HANDSHAKE`, got {header.type}" - ) - - if header.length not in IMDVERSIONS: - # Try swapping endianness - swapped = struct.unpack("i", header.length))[0] - if swapped not in IMDVERSIONS: - err_version = min(swapped, header.length) - # NOTE: Add test for this - raise ValueError( - f"Incompatible IMD version. Expected version in {IMDVERSIONS}, got {err_version}" - ) - else: - end = "<" - ver = swapped - else: - ver = header.length - - sinfo = None - if ver == 2: - # IMD v2 does not send a configuration packet - sinfo = IMDSessionInfo( - version=ver, - endianness=end, - imdterm=None, - imdwait=None, - imdpull=None, - wrapped_coords=False, - energies=1, - dimensions=0, - positions=1, - velocities=0, - forces=0, - ) - elif ver == 3: - sinfo = parse_imdv3_session_info(conn, end) - - return sinfo + + self.ts.time = self.convert_time_from_native(self.ts.time) + + if self.ts.dimensions is not None: + self.convert_pos_from_native(self.ts.dimensions[:3]) + + if self.ts.has_positions: + self.convert_pos_from_native(self.ts.positions) + + if self.ts.has_velocities: + self.convert_velocities_from_native(self.ts.velocities) + + if self.ts.has_forces: + self.convert_forces_from_native(self.ts.forces) @property def n_frames(self): @@ -244,19 +218,23 @@ def _format_hint(thing): def close(self): """Gracefully shut down the reader. Stops the producer thread.""" - if self._producer is not None: - self._buffer.notify_consumer_finished() - # NOTE: is join necessary here? - # self._producer.join() + # Don't stop client if only first ts was loaded + # from __init__ + if self._init_scope: + return + self._imdclient.stop() # NOTE: removeme after testing - print("IMDReader shut down gracefully.") + logger.debug("IMDReader shut down gracefully.") # Incompatible methods def copy(self): raise NotImplementedError("IMDReader does not support copying") def _reopen(self): - pass + if self._reopen_called: + raise RuntimeError("IMDReader: Cannot reopen IMD stream") + self._frame = -1 + self._reopen_called = True def __getitem__(self, frame): """This method from ProtoReader must be overridden @@ -267,335 +245,401 @@ def __getitem__(self, frame): # NOTE: prevent auxiliary iteration methods from being called -class IMDProducer(threading.Thread): - - def __init__(self, conn, buffer, sinfo, n_atoms): - super(IMDProducer, self).__init__() - self._conn = conn - self.sinfo = sinfo - self._should_stop = False - self._paused = False - - # Timeout for first frame should be longer - # than rest of frames - self._timeout = 5 - self._conn.settimeout(self._timeout) - - self._buffer = buffer - - self._frame = 0 - self._parse_frame_time = 0 - - # The body of an x/v/f packet should contain - # (4 bytes per float * 3 atoms * n_atoms) bytes - self.n_atoms = n_atoms - xvf_bytes = 12 * n_atoms - - # NOTE: does 0 mean every frame? - self._header = bytearray(IMDHEADERSIZE) - if self.sinfo.energies > 0: - self._energies = bytearray(40) - if self.sinfo.dimensions > 0: - self._dimensions = bytearray(36) - if self.sinfo.positions > 0: - self._positions = bytearray(xvf_bytes) - if self.sinfo.velocities > 0: - self._velocities = bytearray(xvf_bytes) - if self.sinfo.forces > 0: - self._forces = bytearray(xvf_bytes) - - def _go(self): - """ - Send a go packet to the client to start the simulation - and begin receiving data. - """ - # NOTE: removeme after testing - print("sending go packet...") - go = create_header_bytes(IMDHeaderType.IMD_GO, 0) - self._conn.sendall(go) - logger.debug("IMDProducer: Sent go packet to server") - - def _pause(self): - """ - Block the simulation until the buffer has more space. - """ - logger.debug( - "IMDProducer: Pausing simulation because buffer is almost full" - ) - pause = create_header_bytes(IMDHeaderType.IMD_PAUSE, 0) - try: - self._conn.sendall(pause) - except ConnectionResetError as e: - # Simulation has already ended by the time we paused - return False - # Edge case: pause occured in the time between server sends its last frame - # and closing socket - # Simulation is not actually paused but is over, but we still want to read remaining data - # from the socket - return True - - def _unpause(self): - logger.debug("IMDProducer: Unpausing simulation, buffer has space") - unpause = create_header_bytes(IMDHeaderType.IMD_PAUSE, 0) - try: - self._conn.sendall(unpause) - except ConnectionResetError as e: - # Edge case: pause occured in the time between server sends its last frame - # and closing socket - # Simulation was never actually paused in this case and is now over - return False - # Edge case: pause & unpause occured in the time between server sends its last frame and closing socket - # in this case, the simulation isn't actually unpaused but over - return True - - def run(self): - self._go() - - while not self._should_stop: - logger.debug(f"IMDProducer: Attempting to get timestep") - ts = self._buffer.get_timestep() - # Reader is closed - if ts is None: - break - - logger.debug(f"IMDProducer: Got timstep") - - # This value is approximate - n_empty_ts = self._buffer.get_empty_qsize() - - logger.debug(f"IMDProducer: Got empty qsize") - - # If buffer is more than 50% full, pause the simulation - if not self._paused and n_empty_ts < self._buffer.capacity // 2: - # if pause succeeds, simulation may still have ended - pause_success = self._pause() - if pause_success: - self._paused = True - - # If buffer is less than 25% full, unpause the simulation - if self._paused and n_empty_ts >= self._buffer.capacity // 4: - unpause_success = self._unpause() - if unpause_success: - self._paused = False - - logger.debug(f"IMDProducer: Attempting to read nrg and pos") - # NOTE: This can be replaced with a simple parser if - # the server doesn't send the final frame with all data - e_header_success = self._expect_header( - IMDHeaderType.IMD_ENERGIES, expected_value=1 - ) - if not e_header_success: - break - logger.debug(f"IMDProducer: Expected header") - - energies_successs = read_into_buf(self._conn, self._energies) - if not energies_successs: - break - - logger.debug(f"IMDProducer: Read nrg, reading pos") - - p_header_success = self._expect_header( - IMDHeaderType.IMD_FCOORDS, expected_value=self.n_atoms - ) - if not p_header_success: - break - logger.debug(f"IMDProducer: Expected header") - positions_success = read_into_buf(self._conn, self._positions) - - if not positions_success: - break - - logger.debug(f"IMDProducer: attempting to load ts") - - ts.frame = self._frame - ts.positions = np.frombuffer( - self._positions, dtype=f"{self.sinfo.endianness}f4" - ).reshape((self.n_atoms, 3)) +# class IMDProducer(threading.Thread): + +# def __init__( +# self, +# conn, +# buffer, +# sinfo, +# n_atoms, +# pause_empty_proportion, +# unpause_empty_proportion, +# ): +# super(IMDProducer, self).__init__() +# self._conn = conn +# self.sinfo = sinfo +# self._paused = False + +# self._pause_empty_proportion = pause_empty_proportion +# self._unpause_empty_proportion = unpause_empty_proportion + +# # Timeout for first frame should be longer +# # than rest of frames +# self._timeout = 5 +# self._conn.settimeout(self._timeout) + +# self._buffer = buffer + +# self._frame = 0 +# self._parse_frame_time = 0 + +# # The body of an x/v/f packet should contain +# # (4 bytes per float * 3 atoms * n_atoms) bytes +# self.n_atoms = n_atoms +# xvf_bytes = 12 * n_atoms + +# self._header = bytearray(IMDHEADERSIZE) +# if self.sinfo.energies > 0: +# self._energies = bytearray(40) +# if self.sinfo.dimensions > 0: +# self._dimensions = bytearray(36) +# if self.sinfo.positions > 0: +# self._positions = bytearray(xvf_bytes) +# if self.sinfo.velocities > 0: +# self._velocities = bytearray(xvf_bytes) +# if self.sinfo.forces > 0: +# self._forces = bytearray(xvf_bytes) + +# def _go(self): +# """ +# Send a go packet to the client to start the simulation +# and begin receiving data. +# """ +# # NOTE: removeme after testing +# print("sending go packet...") +# go = create_header_bytes(IMDHeaderType.IMD_GO, 0) +# self._conn.sendall(go) +# logger.debug("IMDProducer: Sent go packet to server") + +# def _pause(self): +# """ +# Block the simulation until the buffer has more space. +# """ +# self._conn.settimeout(0) +# logger.debug( +# "IMDProducer: Pausing simulation because buffer is almost full" +# ) +# pause = create_header_bytes(IMDHeaderType.IMD_PAUSE, 0) +# try: +# self._conn.sendall(pause) +# except ConnectionResetError as e: +# # Simulation has already ended by the time we paused +# raise IndexError +# # Edge case: pause occured in the time between server sends its last frame +# # and closing socket +# # Simulation is not actually paused but is over, but we still want to read remaining data +# # from the socket + +# def _unpause(self): +# self._conn.settimeout(self._timeout) +# logger.debug("IMDProducer: Unpausing simulation, buffer has space") +# unpause = create_header_bytes(IMDHeaderType.IMD_PAUSE, 0) +# try: +# self._conn.sendall(unpause) +# except ConnectionResetError as e: +# # Edge case: pause occured in the time between server sends its last frame +# # and closing socket +# # Simulation was never actually paused in this case and is now over +# raise IndexError +# # Edge case: pause & unpause occured in the time between server sends its last frame and closing socket +# # in this case, the simulation isn't actually unpaused but over + +# def run(self): +# self._go() + +# try: +# while True: + +# logger.debug(f"IMDProducer: Got timstep") + +# # This value is approximate, doesn't acquire lock +# n_empty_ts = self._buffer.get_empty_qsize() + +# logger.debug(f"IMDProducer: Got empty qsize {n_empty_ts}") + +# logger.debug( +# f"IMDProducer: {n_empty_ts} // {self._buffer.n_ts} = {n_empty_ts // self._buffer.n_ts}" +# ) +# logger.debug(f"IMDProducer: {self._pause_empty_proportion}") +# # If buffer is more than 50% full, pause the simulation +# if ( +# not self._paused +# and n_empty_ts / self._buffer.n_ts +# <= self._pause_empty_proportion +# ): +# # if pause succeeds, simulation may still have ended +# self._pause() +# self._paused = True + +# if self._paused: +# # If buffer is less than 25% full, unpause the simulation +# if ( +# n_empty_ts / self._buffer.n_ts +# >= self._unpause_empty_proportion +# ): +# self._unpause() +# self._paused = False +# # If the buffer is still full but we've run out +# # of frames to read, wait until the buffer is less full +# elif not sock_contains_data(self._conn, 0): +# self._buffer.wait_for_space( +# self._unpause_empty_proportion +# ) +# self._unpause() +# self._paused = False + +# logger.debug(f"IMDProducer: Attempting to get timestep") +# ts = self._buffer.get_timestep() + +# logger.debug(f"IMDProducer: Attempting to read nrg and pos") +# # NOTE: This can be replaced with a simple parser if +# # the server doesn't send the final frame with all data +# # as in xtc +# self._expect_header( +# IMDHeaderType.IMD_ENERGIES, expected_value=1 +# ) +# logger.debug(f"IMDProducer: Expected header") + +# read_into_buf(self._conn, self._energies) + +# self._load_energies(ts) +# logger.debug(f"IMDProducer: Read nrg, reading pos") + +# self._expect_header( +# IMDHeaderType.IMD_FCOORDS, expected_value=self.n_atoms +# ) + +# logger.debug(f"IMDProducer: Expected header") +# read_into_buf(self._conn, self._positions) + +# logger.debug(f"IMDProducer: attempting to load ts") + +# ts.frame = self._frame +# ts.positions = np.frombuffer( +# self._positions, dtype=f"{self.sinfo.endianness}f" +# ).reshape((self.n_atoms, 3)) + +# logger.debug(f"IMDProducer: ts loaded- inserting it") + +# self._buffer.insert(ts) + +# logger.debug(f"IMDProducer: ts inserted") + +# if self._frame == 0: +# self._conn.settimeout(1) + +# self._frame += 1 +# except IndexError: +# # Don't raise error if simulation ended in a way +# # that we expected +# pass +# finally: + +# logger.debug("IMDProducer: simluation ended") + +# # Tell reader not to expect more frames to be added +# self._buffer.notify_producer_finished() +# # MUST disconnect before stopping run loop +# # if simulation already ended, this method will do nothing +# self._disconnect() + +# return + +# def _expect_header(self, expected_type, expected_value=None): + +# read_into_buf(self._conn, self._header) + +# logger.debug(f"IMDProducer: header: {self._header}") +# header = IMDHeader(self._header) + +# logger.debug(f"IMDProducer: header parsed") + +# if header.type != expected_type: +# raise RuntimeError + +# if expected_value is not None and header.length != expected_value: +# raise RuntimeError + +# def _disconnect(self): +# try: +# disconnect = create_header_bytes(IMDHeaderType.IMD_DISCONNECT, 0) +# self._conn.sendall(disconnect) +# logger.debug("IMDProducer: Disconnected from server") +# except (ConnectionResetError, BrokenPipeError): +# logger.debug( +# f"IMDProducer: Attempted to disconnect but server already terminated the connection" +# ) +# finally: +# self._conn.close() + +# def _load_energies(self, ts): + +# energy_dict = IMDEnergyPacket( +# self._energies, self.sinfo.endianness +# ).data +# logger.debug(f"IMDProducer: Loaded energies {energy_dict}") +# ts.data.update(energy_dict) +# logger.debug(f"IMDProducer: Updated ts with energies") + + +# class TimestepBuffer: +# """ +# Acts as interface between producer and consumer threads +# """ + +# def __init__(self, buffer_size, imdsinfo, ts_class, n_atoms, ts_kwargs): + +# # Syncing reader and producer +# self._producer_finished = False +# self._consumer_finished = False + +# self._prev_empty_ts = None + +# self._empty_q = queue.Queue() +# self._full_q = queue.Queue() +# self._empty_ts_avail = threading.Condition(threading.Lock()) +# self._full_ts_avail = threading.Condition(threading.Lock()) - logger.debug(f"IMDProducer: ts loaded- inserting it") - - self._buffer.insert(ts) - - logger.debug(f"IMDProducer: ts inserted") - - if self._frame == 0: - self._conn.settimeout(1) - - self._frame += 1 - - logger.debug("IMDProducer: break occuurred") - - # Tell reader not to expect more frames to be added - self._buffer.notify_producer_finished() - # MUST disconnect before stopping run loop - # if simulation already ended, this method will do nothing - self._disconnect() - - return - - def _expect_header(self, expected_type, expected_value=None): - - recv_success = read_into_buf(self._conn, self._header) - logger.debug(f"IMDProducer: recv success: {recv_success}") - if not recv_success: - return False - - logger.debug(f"IMDProducer: header: {self._header}") - header = IMDHeader(self._header) - - logger.debug(f"IMDProducer: header parsed") - - if header.type != expected_type: - return False - - if expected_value is not None and header.length != expected_value: - return False - - return True - - def _disconnect(self): - try: - disconnect = create_header_bytes(IMDHeaderType.IMD_DISCONNECT, 0) - self._conn.sendall(disconnect) - logger.debug("IMDProducer: Disconnected from server") - except ConnectionResetError: - logger.debug( - f"IMDProducer: Attempted to disconnect but server already terminated the connection" - ) - finally: - self._conn.close() - - -class TimestepBuffer: - """ - Acts as interface between producer and consumer threads - """ - - # NOTE: Use 1 buffer for pos, vel, force rather than 3 - def __init__(self, buffer_size, imdsinfo, ts_class, n_atoms): - self._buffer_size = buffer_size - - # Syncing reader and producer - self._producer_finished = False - self._consumer_finished = False - - self._prev_empty_ts = None - - self.imdsinfo = imdsinfo - - self._empty_q = queue.Queue() - self._full_q = queue.Queue() - - self._full_ts_avail = threading.Condition(threading.Lock()) - self._empty_ts_avail = threading.Condition(threading.Lock()) - - # NOTE: hardcoded for testing - self._total_ts = 101 - for i in range(101): - self._empty_q.put(ts_class(n_atoms, positions=True)) - - # Timing for analysis - self._t1 = None - self._t2 = None - self._start = True - self._analyze_frame_time = None - - self._frame = 0 - - def get_empty_qsize(self): - return self._empty_q.qsize() - - def get_timestep(self): - with self._empty_ts_avail: - while self._empty_q.qsize() == 0 and not self._consumer_finished: - self._empty_ts_avail.wait() - - if self._consumer_finished: - return None - - ts = self._empty_q.get() - return ts - - def insert(self, ts): - self._full_q.put(ts) - with self._full_ts_avail: - self._full_ts_avail.notify() - - def consume_next_timestep(self): - """Put empty_ts in the empty_q and get the next full timestep""" - # Start timer- one frame of analysis is starting (including removal - # from buffer) - - self._t1 = self._t2 - self._t2 = time.time() - if self._t1 is not None: - logger.debug( - f"IMDReader: Frame #{self._frame - 1} analyzed in {self._t2 - self._t1} seconds" - ) - self._analyze_frame_time = self._t2 - self._t1 - - self._frame += 1 - - # Return the processed timestep - if self._prev_empty_ts is not None: - self._empty_q.put(self._prev_empty_ts) - with self._empty_ts_avail: - self._empty_ts_avail.notify() - - # Get the next timestep - with self._full_ts_avail: - while self._full_q.qsize() == 0 and not self._producer_finished: - self._full_ts_avail.wait() - - # Buffer is responsible for stopping iteration - if self._producer_finished and self._full_q.qsize() == 0: - raise StopIteration from None - - ts = self._full_q.get() - - self._prev_empty_ts = ts - - return ts - - def notify_producer_finished(self): - self._producer_finished = True - with self._full_ts_avail: - self._full_ts_avail.notify() - - def notify_consumer_finished(self): - self._consumer_finished = True - with self._empty_ts_avail: - # noop if producer isn't waiting - self._empty_ts_avail.notify() - - @property - def analyze_frame_time(self): - if self._analyze_frame_time is not None: - return self._analyze_frame_time - else: - return None - - @property - def capacity(self): - return self._total_ts - - -def read_into_buf(sock, buf) -> bool: - """Receives len(buf) bytes into buf from the socket sock""" - view = memoryview(buf) - total_received = 0 - while total_received < len(view): - try: - received = sock.recv_into(view[total_received:]) - if received == 0: - # Server called close() - logger.debug( - "IMDProducer: recv returning false due to server calling close()" - ) - return False - except TimeoutError: - # Server is *likely* done sending frames - logger.debug("IMDProducer: recv returning false due to timeout") - return False - total_received += received - return True +# # Allocate timesteps with all of xvf present in imdsinfo +# # even if they aren't sent every frame. Can be optimized if needed +# ts_memsize = approximate_timestep_memsize( +# n_atoms, +# (imdsinfo.energies > 0), +# (imdsinfo.dimensions > 0), +# (imdsinfo.positions > 0), +# (imdsinfo.velocities > 0), +# (imdsinfo.forces > 0), +# ) +# self._total_ts = buffer_size // ts_memsize +# logger.debug( +# f"Timestepbuffer: Total timesteps allocated: {self._total_ts}" +# ) +# for i in range(self._total_ts): +# self._empty_q.put( +# ts_class( +# n_atoms, +# positions=(imdsinfo.positions > 0), +# velocities=(imdsinfo.velocities > 0), +# forces=(imdsinfo.forces > 0), +# **ts_kwargs, +# ) +# ) + +# # Timing for analysis +# self._t1 = None +# self._t2 = None +# self._start = True +# self._analyze_frame_time = None + +# self._frame = 0 + +# def get_empty_qsize(self): +# return self._empty_q.qsize() + +# def get_timestep(self): +# with self._empty_ts_avail: +# while self._empty_q.qsize() == 0 and not self._consumer_finished: +# self._empty_ts_avail.wait() + +# if self._consumer_finished: +# raise IndexError + +# ts = self._empty_q.get() + +# return ts + +# def wait_for_space(self, unpause_empty_proportion): +# with self._empty_ts_avail: +# while ( +# self._empty_q.qsize() / self.n_ts < unpause_empty_proportion +# ) and not self._consumer_finished: +# self._empty_ts_avail.wait() + +# if self._consumer_finished: +# raise IndexError + +# def insert(self, ts): +# self._full_q.put(ts) +# with self._full_ts_avail: +# self._full_ts_avail.notify() + +# def consume_next_timestep(self): +# """Put empty_ts in the empty_q and get the next full timestep""" +# # Start timer- one frame of analysis is starting (including removal +# # from buffer) + +# self._t1 = self._t2 +# self._t2 = time.time() +# if self._t1 is not None: +# logger.debug( +# f"IMDReader: Frame #{self._frame} analyzed in {self._t2 - self._t1} seconds" +# ) +# self._analyze_frame_time = self._t2 - self._t1 + +# self._frame += 1 + +# # Return the processed timestep +# if self._prev_empty_ts is not None: +# self._empty_q.put(self._prev_empty_ts) +# with self._empty_ts_avail: +# self._empty_ts_avail.notify() + +# # Get the next timestep +# with self._full_ts_avail: +# while self._full_q.qsize() == 0 and not self._producer_finished: +# self._full_ts_avail.wait() + +# # Buffer is responsible for stopping iteration +# if self._producer_finished and self._full_q.qsize() == 0: +# raise StopIteration from None + +# ts = self._full_q.get() + +# self._prev_empty_ts = ts + +# return ts + +# def notify_producer_finished(self): +# self._producer_finished = True +# with self._full_ts_avail: +# self._full_ts_avail.notify() + +# def notify_consumer_finished(self): +# self._consumer_finished = True +# with self._empty_ts_avail: +# # noop if producer isn't waiting +# self._empty_ts_avail.notify() + +# @property +# def analyze_frame_time(self): +# if self._analyze_frame_time is not None: +# return self._analyze_frame_time +# else: +# return None + +# @property +# def n_ts(self): +# return self._total_ts + + +# def read_into_buf(sock, buf) -> bool: +# """Receives len(buf) bytes into buf from the socket sock""" +# view = memoryview(buf) +# total_received = 0 +# while total_received < len(view): +# try: +# received = sock.recv_into(view[total_received:]) +# if received == 0: +# # Server called close() +# logger.debug( +# "IMDProducer: recv excepting due to server calling close()" +# ) +# raise IndexError +# except TimeoutError: +# # Server is *likely* done sending frames +# logger.debug("IMDProducer: recv excepting due to timeout") +# raise IndexError +# except BlockingIOError: +# # Server is done sending frames +# logger.debug("IMDProducer: recv excepting due to blocking") +# raise IndexError +# total_received += received +# return True + + +# def sock_contains_data(sock, timeout) -> bool: +# ready_to_read, ready_to_write, in_error = select.select( +# [sock], [], [], timeout +# ) +# return sock in ready_to_read diff --git a/imdreader/__init__.py b/imdreader/__init__.py index 43c46dc..75b2339 100644 --- a/imdreader/__init__.py +++ b/imdreader/__init__.py @@ -2,8 +2,7 @@ IMDReader """ -from .IMDProtocol import * -from .IMDREADER import * +from .IMDREADER import IMDReader from importlib.metadata import version diff --git a/imdreader/tests/test_imdreader.py b/imdreader/tests/test_imdreader.py index 5e3b5ef..3ec357e 100644 --- a/imdreader/tests/test_imdreader.py +++ b/imdreader/tests/test_imdreader.py @@ -1,21 +1,30 @@ from MDAnalysisTests.datafiles import COORDINATES_TOPOLOGY, COORDINATES_TRR import MDAnalysis as mda -from .utils import DummyIMDServer +import imdreader +from imdreader.IMDClient import imdframe_memsize +from .utils import ( + DummyIMDServer, + get_free_port, + ExpectPauseLoopV2Behavior, +) from MDAnalysisTests.coordinates.base import ( MultiframeReaderTest, BaseReference, BaseWriterTest, assert_timestep_almost_equal, ) +from MDAnalysisTests.coordinates.test_xdr import TRRReference import numpy as np import logging import pytest +import time +# NOTE: removeme after initial testing @pytest.fixture(autouse=True) def log_config(): logger = logging.getLogger("imdreader.IMDREADER") - file_handler = logging.FileHandler("test.log") + file_handler = logging.FileHandler("tmp/test.log") formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) @@ -26,16 +35,107 @@ def log_config(): logger.removeHandler(file_handler) -def test_traj_unchanged(): - u = mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_TRR) - u_imd = mda.Universe(COORDINATES_TOPOLOGY, "localhost:8888") - server = DummyIMDServer() - server.start() +IMDENERGYKEYS = [ + "step", + "temperature", + "total_energy", + "potential_energy", + "van_der_walls_energy", + "coulomb_energy", + "bonds_energy", + "angles_energy", + "dihedrals_energy", + "improper_dihedrals_energy", +] + + +class TestIMDReaderV2: + + @pytest.fixture + def port(self): + return get_free_port() + + @pytest.fixture + def traj(self): + return mda.coordinates.TRR.TRRReader( + COORDINATES_TRR, convert_units=False + ) + + @pytest.fixture + def ref(self): + return mda.coordinates.TRR.TRRReader( + COORDINATES_TRR, convert_units=True + ) + + @pytest.fixture + def server(self, traj): + server = DummyIMDServer(traj, 2) + return server + + @pytest.mark.parametrize("endianness", ["<", ">"]) + def test_endianness_traj_unchanged(self, server, endianness, ref, port): + server.port = port + server.imdsessioninfo.endianness = endianness + server.start() + + reader = imdreader.IMDREADER.IMDReader( + f"localhost:{port}", + convert_units=True, + n_atoms=ref.trajectory.n_atoms, + ) + + i = 0 + # Can't call assert in loop- this prevents reader's __exit__ from being called + # if assert fails. Instead copy timesteps and then assert them + timesteps = [] + + for ts in reader: + timesteps.append(ts.copy()) + i += 1 + + assert i == len(ref) + + for j in range(len(ref)): + np.testing.assert_allclose(timesteps[j].positions, ref[j].positions) + offset = 0 + for energy_key in IMDENERGYKEYS: + assert timesteps[j].data[energy_key] == j + offset + offset += 1 + + def test_pause_traj_unchanged(self, server, ref, port): + server.port = port + server.loop_behavior = ExpectPauseLoopV2Behavior() + server.start() + + # Give the reader only 1 IMDFrame of memory + # We expect the producer thread to have to + # pause every frame (except the first) + reader = imdreader.IMDREADER.IMDReader( + f"localhost:{port}", + convert_units=True, + n_atoms=ref.trajectory.n_atoms, + buffer_size=imdframe_memsize( + ref.trajectory.n_atoms, server.imdsessioninfo + ), + ) + + i = 0 + timesteps = [] + + for ts in reader: + time.sleep(1) + timesteps.append(ts.copy()) + i += 1 + + assert i == len(ref) - i = 0 - for ts in u_imd.trajectory: - print(ts.dimensions) - np.testing.assert_allclose(ts.positions, u.trajectory[i].positions) - i += 1 + for j in range(len(ref)): + np.testing.assert_allclose(timesteps[j].positions, ref[j].positions) + offset = 0 + for energy_key in IMDENERGYKEYS: + assert timesteps[j].data[energy_key] == j + offset + offset += 1 - assert i == len(u.trajectory) + def test_no_connection(self): + with pytest.raises(ConnectionRefusedError): + imdreader.IMDREADER.IMDReader("localhost:12345", n_atoms=1) diff --git a/imdreader/tests/test_integration.py b/imdreader/tests/test_integration.py index 7082b1b..6a77ecb 100644 --- a/imdreader/tests/test_integration.py +++ b/imdreader/tests/test_integration.py @@ -143,6 +143,8 @@ def run_gmx(tmpdir): def test_traj_len(run_gmx): + logger = logging.getLogger("imdreader.IMDREADER") + logger.debug("test_traj_len") port = run_gmx recvuntil( "gmx_output.log", @@ -158,64 +160,3 @@ def test_traj_len(run_gmx): pass assert len(u2.trajectory) == len(u.trajectory) - - -def test_pause(run_gmx, caplog): - port = run_gmx - recvuntil( - "gmx_output.log", - "IMD: Will wait until I have a connection and IMD_GO orders.", - 60, - ) - u2 = mda.Universe(IMDGROUP_GRO, OUT_TRR) - u = mda.Universe( - IMDGROUP_GRO, - f"localhost:{port}", - # 1240 bytes per frame - buffer_size=62000, - ) - i = 0 - for ts in u.trajectory: - time.sleep(0.05) - assert_timestep_almost_equal(ts, u2.trajectory[i], decimal=5) - i += 1 - - assert len(u.trajectory) == 101 - assert ( - "IMDProducer: Pausing simulation because buffer is almost full" - in caplog.text - ) - assert "IMDProducer: Unpausing simulation, buffer has space" in caplog.text - assert "data likely lost in frame" not in caplog.text - - -def test_no_connection(caplog): - u = mda.Universe( - IMDGROUP_GRO, - "localhost:8888", - buffer_size=62000, - ) - for ts in u.trajectory: - with pytest.raises(ConnectionRefusedError): - pass - # NOTE: assert this in output: No connection received. Pausing simulation. - assert "IMDReader: Connection to localhost:8888 refused" in caplog.text - - -""" -import socket -import struct - -conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - -def connect(): - conn.connect(("localhost",8888)) - conn.recv(8) - go = struct.pack("!ii", 3, 0) - conn.sendall(go) - -def pause(): - pause = struct.pack("!ii", 7, 0) - conn.sendall(pause) - -""" diff --git a/imdreader/tests/utils.py b/imdreader/tests/utils.py index f80e149..b8780a6 100644 --- a/imdreader/tests/utils.py +++ b/imdreader/tests/utils.py @@ -5,116 +5,232 @@ import threading import time from imdreader.IMDProtocol import * -from imdreader.IMDREADER import read_into_buf +from imdreader.IMDREADER import read_into_buf, sock_contains_data from MDAnalysisTests.datafiles import COORDINATES_TOPOLOGY, COORDINATES_TRR +import abc +import imdreader +import logging + +logger = logging.getLogger(imdreader.IMDClient.__name__) class Behavior(abc.ABC): + """Abstract base class for behaviors for the DummyIMDServer to perform. + Ensure that behaviors do not contain potentially infinite loops- they should be + testing for a specific sequence of events""" + @abc.abstractmethod - def perform(self): + def perform(self, *args, **kwargs): pass class DefaultConnectionBehavior(Behavior): - def __init__(self, host, port): - self.host = host - self.port = port - def perform(self): + def perform(self, host, port, event_q): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind((self.host, self.port)) + s.bind((host, port)) s.listen(60) conn, addr = s.accept() return (conn, addr) class DefaultHandshakeV2Behavior(Behavior): - def __init__(self, imdsessioninfo): - self.version = imdsessioninfo.version - self.endianness = imdsessioninfo.endianness - - def perform(self, conn): + def perform(self, conn, imdsessioninfo, event_q): header = struct.pack("!i", IMDHeaderType.IMD_HANDSHAKE.value) - if self.endianness == "<": - header += struct.pack("i", self.version) + header += struct.pack(">i", 2) conn.sendall(header) class DefaultHandshakeV3Behavior(Behavior): - def __init__(self, imdsessioninfo): - pass - - def perform(self, conn): + def perform(self, conn, imdsessioninfo, event_q): pass class DefaultAwaitGoBehavior(Behavior): - def __init__(self, timeout=1): - self.timeout = timeout - - def perform(self, conn): - conn.settimeout(self.timeout) + def perform(self, conn, event_q): + conn.settimeout(IMDAWAITGOTIME) head_buf = bytearray(IMDHEADERSIZE) read_into_buf(conn, head_buf) header = IMDHeader(head_buf) if header.type != IMDHeaderType.IMD_GO: raise ValueError("Expected IMD_GO packet, got something else") - conn.settimeout(None) + logger.debug("DummyIMDServer: Received IMD_GO") -class DefaultLoopBehaviorV2(Behavior): - def __init__(self, trajectory, imdsessioninfo): - self.traj = trajectory - self.endianness = imdsessioninfo.endianness - self.imdterm = imdsessioninfo.imdterm - self.imdwait = imdsessioninfo.imdwait - self.imdpull = imdsessioninfo.imdpull +class DefaultLoopV2Behavior(Behavior): + """Default behavior doesn't allow pausing""" - def perform(self, conn): + def perform(self, conn, traj, imdsessioninfo, event_q): conn.settimeout(1) headerbuf = bytearray(IMDHEADERSIZE) paused = False - energies = np.zeros((len(self.traj), 10), dtype=np.float32) + logger.debug("DummyIMDServer: Starting loop") - for i in range(len(self.traj)): - while sock_contains_data(conn, 1) or paused: - header_success = read_into_buf(conn, headerbuf) - if header_success: - header = IMDHeader(headerbuf) - if header.type == IMDHeaderType.IMD_PAUSE: - paused = not paused + for i in range(len(traj)): + logger.debug(f"DummyIMDServer: generating frame {i}") energy_header = create_header_bytes(IMDHeaderType.IMD_ENERGIES, 1) - energy = np.ascontiguousarray( - energies[i], dtype=f"{self.endianness}f4" + + energies = create_energy_bytes( + i, + i + 1, + i + 2, + i + 3, + i + 4, + i + 5, + i + 6, + i + 7, + i + 8, + i + 9, + imdsessioninfo.endianness, + ) + + pos_header = create_header_bytes( + IMDHeaderType.IMD_FCOORDS, traj.n_atoms + ) + pos = np.ascontiguousarray( + traj[i].positions, dtype=f"{imdsessioninfo.endianness}f" ).tobytes() + + conn.sendall(energy_header + energies) + conn.sendall(pos_header + pos) + logger.debug(f"Sent frame {i}") + + +class ExpectPauseLoopV2Behavior(DefaultLoopV2Behavior): + """Waits for a pause & unpause in all frames after the first frame.""" + + def perform(self, conn, traj, imdsessioninfo, event_q): + conn.settimeout(10) + headerbuf = bytearray(IMDHEADERSIZE) + + logger.debug("DummyIMDServer: Starting loop") + + for i in range(len(traj)): + if i != 0: + read_into_buf(conn, headerbuf) + header = IMDHeader(headerbuf) + if header.type != IMDHeaderType.IMD_PAUSE: + logger.debug( + f"DummyIMDServer: Expected IMD_PAUSE, got {header.type}" + ) + + read_into_buf(conn, headerbuf) + header = IMDHeader(headerbuf) + if header.type != IMDHeaderType.IMD_PAUSE: + logger.debug( + f"DummyIMDServer: Expected IMD_PAUSE, got {header.type}" + ) + + logger.debug(f"DummyIMDServer: generating frame {i}") + + energy_header = create_header_bytes(IMDHeaderType.IMD_ENERGIES, 1) + + energies = create_energy_bytes( + i, + i + 1, + i + 2, + i + 3, + i + 4, + i + 5, + i + 6, + i + 7, + i + 8, + i + 9, + imdsessioninfo.endianness, + ) + pos_header = create_header_bytes( - IMDHeaderType.IMD_FCOORDS, self.traj.n_atoms + IMDHeaderType.IMD_FCOORDS, traj.n_atoms ) pos = np.ascontiguousarray( - self.traj[i].positions, dtype=f"{self.endianness}f4" + traj[i].positions, dtype=f"{imdsessioninfo.endianness}f" ).tobytes() - conn.sendall(energy_header + energy) + conn.sendall(energy_header + energies) conn.sendall(pos_header + pos) + logger.debug(f"Sent frame {i}") -class DefaultDisconnectBehavior(Behavior): - def __init__(self, pre_shutdown_wait=0, pre_close_wait=0): - self.pre_shutdown_wait = pre_shutdown_wait - self.pre_close_wait = pre_close_wait +class ExpectPauseUnpauseAfterLoopV2Behavior(DefaultLoopV2Behavior): + def perform(self, conn, traj, imdsessioninfo, event_q): + conn.settimeout(10) + headerbuf = bytearray(IMDHEADERSIZE) + + logger.debug("DummyIMDServer: Starting loop") + + for i in range(len(traj)): + logger.debug(f"DummyIMDServer: generating frame {i}") + + energy_header = create_header_bytes(IMDHeaderType.IMD_ENERGIES, 1) + + energies = create_energy_bytes( + i, + i + 1, + i + 2, + i + 3, + i + 4, + i + 5, + i + 6, + i + 7, + i + 8, + i + 9, + imdsessioninfo.endianness, + ) - def perform(self, conn): - time.sleep(self.pre_shutdown_wait) + pos_header = create_header_bytes( + IMDHeaderType.IMD_FCOORDS, traj.n_atoms + ) + pos = np.ascontiguousarray( + traj[i].positions, dtype=f"{imdsessioninfo.endianness}f" + ).tobytes() + + conn.sendall(energy_header + energies) + conn.sendall(pos_header + pos) + logger.debug(f"Sent frame {i}") + + read_into_buf(conn, headerbuf) + header = IMDHeader(headerbuf) + if header.type != IMDHeaderType.IMD_PAUSE: + logger.debug( + f"DummyIMDServer: Expected IMD_PAUSE, got {header.type}" + ) + + read_into_buf(conn, headerbuf) + header = IMDHeader(headerbuf) + if header.type != IMDHeaderType.IMD_PAUSE: + logger.debug( + f"DummyIMDServer: Expected IMD_PAUSE, got {header.type}" + ) + + +class DefaultDisconnectBehavior(Behavior): + def perform(self, conn, event_q): # Gromacs uses the c equivalent of the SHUT_WR flag conn.shutdown(socket.SHUT_WR) - time.sleep(self.pre_close_wait) conn.close() +def create_default_imdsinfo_v2(): + return IMDSessionInfo( + version=2, + endianness="<", + imdterm=None, + imdwait=None, + imdpull=None, + wrapped_coords=True, + energies=1, + dimensions=0, + positions=1, + velocities=0, + forces=0, + ) + + class DummyIMDServer(threading.Thread): """Performs the following steps in order: @@ -124,89 +240,126 @@ class DummyIMDServer(threading.Thread): 4. LoopBehavior.perform_loop() 5. DisconnectBehavior.perform_disconnect() - Each if these behaviors can be changed by calling DummyIMDServer.set_x_behavior(y) where x is the behavior name - and y is the behavior object. - Start the server by calling DummyIMDServer.start(). """ def __init__( self, - host="localhost", - port=8888, - imdsessioninfo=None, - traj=None, + traj, + version, ): """ If passing `traj` kwarg, ensure it is a copy of the trajectory used - in the test to avoid modifying the original trajectory "head" in the + in the test to avoid moving the original trajectory "head" in the main thread. """ super().__init__(daemon=True) - if traj is None: - traj = mda.Universe( - COORDINATES_TOPOLOGY, COORDINATES_TRR - ).trajectory - - if imdsessioninfo is None: - imdsessioninfo = IMDSessionInfo( - version=2, - endianness="<", - imdterm=None, - imdwait=None, - imdpull=None, - wrapped_coords=True, - energies=1, - dimensions=0, - positions=1, - velocities=0, - forces=0, - ) + logger.debug("DummyIMDServer: Initializing") - self.connection_behavior = DefaultConnectionBehavior(host, port) + self._traj = traj - if imdsessioninfo.version == 2: - self.handshake_behavior = DefaultHandshakeV2Behavior(imdsessioninfo) - self.loop_behavior = DefaultLoopBehaviorV2(traj, imdsessioninfo) - elif imdsessioninfo.version == 3: - self.connection_behavior = DefaultConnectionBehavior( - "localhost", port - ) + self._host = "localhost" + self._port = 8888 + + self.connection_behavior = DefaultConnectionBehavior() + + if version == 2: + self._imdsessioninfo = create_default_imdsinfo_v2() + self.handshake_behavior = DefaultHandshakeV2Behavior() + self.loop_behavior = DefaultLoopV2Behavior() + elif version == 3: + # self.imdsessioninfo = create_default_imdsinfo_v3() # self.handshake_behavior = DefaultHandshakeV3Behavior() + # self.loop_behavior = DefaultLoopBehaviorV3() + pass self.await_go_behavior = DefaultAwaitGoBehavior() - self.disconnect_behavior = DefaultDisconnectBehavior() - def run(self): - conn = self.connection_behavior.perform()[0] - self.handshake_behavior.perform(conn) - self.await_go_behavior.perform(conn) - self.loop_behavior.perform(conn) - self.disconnect_behavior.perform(conn) - - def set_connection_behavior(self, connection_behavior): - self.connection_behavior = connection_behavior - - def set_handshake_behavior(self, handshake_behavior): - self.handshake_behavior = handshake_behavior - - def set_await_go_behavior(self, await_go_behavior): - self.await_go_behavior = await_go_behavior + self._event_q = [] - def set_loop_behavior(self, loop_behavior): - self.loop_behavior = loop_behavior - - def set_disconnect_behavior(self, disconnect_behavior): - self.disconnect_behavior = disconnect_behavior - - -def sock_contains_data(sock, timeout) -> bool: - ready_to_read, ready_to_write, in_error = select.select( - [sock], [], [], timeout - ) - return sock in ready_to_read + def run(self): + conn = self.connection_behavior.perform( + self.host, self.port, self._event_q + )[0] + self.handshake_behavior.perform( + conn, self.imdsessioninfo, self._event_q + ) + self.await_go_behavior.perform(conn, self._event_q) + self.loop_behavior.perform( + conn, self._traj, self.imdsessioninfo, self._event_q + ) + self.disconnect_behavior.perform(conn, self._event_q) + return + + @property + def port(self): + return self._port + + @port.setter + def port(self, port): + self._port = port + + @property + def host(self): + return self._host + + @host.setter + def host(self, host): + self._host = host + + @property + def imdsessioninfo(self): + return self._imdsessioninfo + + @imdsessioninfo.setter + def imdsessioninfo(self, imdsessioninfo): + self._imdsessioninfo = imdsessioninfo + + @property + def connection_behavior(self): + return self._connection_behavior + + @connection_behavior.setter + def connection_behavior(self, connection_behavior): + self._connection_behavior = connection_behavior + + @property + def handshake_behavior(self): + return self._handshake_behavior + + @handshake_behavior.setter + def handshake_behavior(self, handshake_behavior): + self._handshake_behavior = handshake_behavior + + @property + def await_go_behavior(self): + return self._await_go_behavior + + @await_go_behavior.setter + def await_go_behavior(self, await_go_behavior): + self._await_go_behavior = await_go_behavior + + @property + def loop_behavior(self): + return self._loop_behavior + + @loop_behavior.setter + def loop_behavior(self, loop_behavior): + self._loop_behavior = loop_behavior + + @property + def disconnect_behavior(self): + return self._disconnect_behavior + + @disconnect_behavior.setter + def disconnect_behavior(self, disconnect_behavior): + self._disconnect_behavior = disconnect_behavior + + @property + def event_q(self): + return self._event_q def recvuntil(file_path, target_line, timeout): @@ -240,21 +393,6 @@ def recvuntil(file_path, target_line, timeout): ) -def check_port_availability(port): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - s.bind(("0.0.0.0", port)) - except socket.error as e: - if e.errno == socket.errno.EADDRINUSE: - print(f"Port {port} is already in use") - return False - else: - raise - finally: - s.close() - return True - - def get_free_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) diff --git a/imdreader/util.py b/imdreader/util.py index d46d5e3..c90a23b 100644 --- a/imdreader/util.py +++ b/imdreader/util.py @@ -57,3 +57,22 @@ def parse_host_port(filename): else: # Handle the case where the format does not match "host:port" raise ValueError("Filename must be in the format 'host:port'") + + +def approximate_timestep_memsize( + n_atoms, energies, dimensions, positions, velocities, forces +): + total_size = 0 + + if energies: + total_size += 36 + + if dimensions: + # dimensions in the form (*A*, *B*, *C*, *alpha*, *beta*, *gamma*) + total_size += 24 + + for dset in (positions, velocities, forces): + if dset: + total_size += n_atoms * 12 + + return total_size