diff --git a/src/_ert/forward_model_runner/cli.py b/src/_ert/forward_model_runner/cli.py index a41b0ca4b16..4760212fb98 100644 --- a/src/_ert/forward_model_runner/cli.py +++ b/src/_ert/forward_model_runner/cli.py @@ -1,15 +1,19 @@ import argparse +import asyncio import json import logging import os import signal import sys -import time -import typing from datetime import datetime +from typing import List, Sequence from _ert.forward_model_runner import reporting -from _ert.forward_model_runner.reporting.message import Finish, ProcessTreeStatus +from _ert.forward_model_runner.reporting.message import ( + Finish, + Message, + ProcessTreeStatus, +) from _ert.forward_model_runner.runner import ForwardModelRunner JOBS_FILE = "jobs.json" @@ -17,6 +21,10 @@ logger = logging.getLogger(__name__) +class ForwardModelRunnerException(Exception): + pass + + def _setup_reporters( is_interactive_run, ens_id, @@ -24,8 +32,8 @@ def _setup_reporters( ee_token=None, ee_cert_path=None, experiment_id=None, -) -> typing.List[reporting.Reporter]: - reporters: typing.List[reporting.Reporter] = [] +) -> List[reporting.Reporter]: + reporters: List[reporting.Reporter] = [] if is_interactive_run: reporters.append(reporting.Interactive()) elif ens_id and experiment_id is None: @@ -71,26 +79,26 @@ def _setup_logging(directory: str = "logs"): JOBS_JSON_RETRY_TIME = 30 -def _wait_for_retry(): - time.sleep(JOBS_JSON_RETRY_TIME) +async def _wait_for_retry(): + await asyncio.sleep(JOBS_JSON_RETRY_TIME) -def _read_jobs_file(retry=True): +async def _read_jobs_file(retry=True): try: - with open(JOBS_FILE, "r", encoding="utf-8") as json_file: + with open(JOBS_FILE, "r", encoding="utf-8") as json_file: # noqa: ASYNC230 return json.load(json_file) except json.JSONDecodeError as e: raise IOError("Job Runner cli failed to load JSON-file.") from e except FileNotFoundError as e: if retry: logger.error(f"Could not find file {JOBS_FILE}, retrying") - _wait_for_retry() - return _read_jobs_file(retry=False) + await _wait_for_retry() + return await _read_jobs_file(retry=False) else: raise e -def main(args): +async def main(args): parser = argparse.ArgumentParser( description=( "Run all the jobs specified in jobs.json, " @@ -118,7 +126,7 @@ def main(args): # Make sure that logging is setup _after_ we have moved to the runpath directory _setup_logging() - jobs_data = _read_jobs_file() + jobs_data = await _read_jobs_file() experiment_id = jobs_data.get("experiment_id") ens_id = jobs_data.get("ens_id") @@ -135,21 +143,64 @@ def main(args): ee_cert_path, experiment_id, ) + reporter_queue: asyncio.Queue[Message] = asyncio.Queue() + + is_done = asyncio.Event() + forward_model_runner = ForwardModelRunner(jobs_data, reporter_queue=reporter_queue) + forward_model_runner_task = asyncio.create_task( + forward_model_runner.run(parsed_args.job) + ) + reporting_task = asyncio.create_task( + handle_reporting(reporters, reporter_queue, is_done) + ) + + def handle_sigterm(*args, **kwargs): + nonlocal reporters, forward_model_runner_task + forward_model_runner_task.cancel() + for reporter in reporters: + reporter.cancel() - job_runner = ForwardModelRunner(jobs_data) + asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, handle_sigterm) + + await forward_model_runner_task + + is_done.set() + await reporting_task + + +async def handle_reporting( + reporters: Sequence[reporting.Reporter], + message_queue: asyncio.Queue[Message], + is_done: asyncio.Event, +): + while True: + try: + job_status = await asyncio.wait_for(message_queue.get(), timeout=2) + 
except asyncio.TimeoutError: + if is_done.is_set(): + break + continue - for job_status in job_runner.run(parsed_args.job): logger.info(f"Job status: {job_status}") for reporter in reporters: try: - reporter.report(job_status) + await reporter.report(job_status) except OSError as oserror: print( f"job_dispatch failed due to {oserror}. Stopping and cleaning up." ) - pgid = os.getpgid(os.getpid()) - os.killpg(pgid, signal.SIGKILL) + await let_reporters_finish(reporters) + raise ForwardModelRunnerException from oserror + message_queue.task_done() if isinstance(job_status, Finish) and not job_status.success(): - pgid = os.getpgid(os.getpid()) - os.killpg(pgid, signal.SIGKILL) + await let_reporters_finish(reporters) + raise ForwardModelRunnerException(job_status.error_message) + + await let_reporters_finish(reporters) + + +async def let_reporters_finish(reporters): + for reporter in reporters: + if isinstance(reporter, reporting.Event): + await reporter.join() diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 60b1042ab91..0accdf59a2d 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -96,7 +96,7 @@ async def get_websocket(self) -> ClientConnection: close_timeout=self.CONNECTION_TIMEOUT, ) - async def _send(self, msg: AnyStr) -> None: + async def send(self, msg: AnyStr) -> None: for retry in range(self._max_retries + 1): try: if self.websocket is None: @@ -133,6 +133,3 @@ async def _send(self, msg: AnyStr) -> None: raise ClientConnectionError(_error_msg) from exception await asyncio.sleep(0.2 + self._timeout_multiplier * retry) self.websocket = None - - def send(self, msg: AnyStr) -> None: - self.loop.run_until_complete(self._send(msg)) diff --git a/src/_ert/forward_model_runner/forward_model_step.py b/src/_ert/forward_model_runner/forward_model_step.py index bc4b649e0e8..b26bdf5c7cc 100644 --- a/src/_ert/forward_model_runner/forward_model_step.py +++ b/src/_ert/forward_model_runner/forward_model_step.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import contextlib import io import json @@ -11,10 +12,19 @@ import time from datetime import datetime as dt from pathlib import Path -from subprocess import Popen, run -from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple, cast +from subprocess import run +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + Dict, + List, + Optional, + Sequence, + Tuple, + cast, +) -from psutil import AccessDenied, NoSuchProcess, Process, TimeoutExpired, ZombieProcess +from psutil import AccessDenied, NoSuchProcess, Process, ZombieProcess from .io import check_executable from .reporting.message import ( @@ -89,9 +99,9 @@ def __init__( self.std_err = job_data.get("stderr") self.std_out = job_data.get("stdout") - def run(self) -> Generator[Start | Exited | Running | None]: + async def run(self) -> AsyncGenerator[Start | Exited | Running | None]: try: - for msg in self._run(): + async for msg in self._run(): yield msg except Exception as e: yield Exited(self, exit_code=1).with_error(str(e)) @@ -151,9 +161,8 @@ def _create_environment(self) -> Optional[Dict[str, str]]: combined_environment = {**os.environ, **environment} return combined_environment - def _run(self) -> Generator[Start | Exited | Running | None]: + async def _run(self) -> AsyncGenerator[Start | Exited | Running | None]: start_message = self.create_start_message_and_check_job_files() - yield start_message if not start_message.success(): return @@ 
-167,14 +176,14 @@ def _run(self) -> Generator[Start | Exited | Running | None]: target_file_mtime: Optional[int] = _get_target_file_ntime(target_file) try: - proc = Popen( - arg_list, + proc = await asyncio.create_subprocess_exec( + *arg_list, stdin=stdin, stdout=stdout, stderr=stderr, env=self._create_environment(), ) - process = Process(proc.pid) + except OSError as e: exited_message = self._handle_process_io_error_and_create_exited_message( e, stderr @@ -186,9 +195,9 @@ def _run(self) -> Generator[Start | Exited | Running | None]: exit_code = None max_memory_usage = 0 - fm_step_pids = {int(process.pid)} + fm_step_pids = {int(proc.pid)} while exit_code is None: - (memory_rss, cpu_seconds, oom_score) = _get_processtree_data(process) + (memory_rss, cpu_seconds, oom_score) = _get_processtree_data(proc.pid) max_memory_usage = max(memory_rss, max_memory_usage) yield Running( self, @@ -203,8 +212,10 @@ def _run(self) -> Generator[Start | Exited | Running | None]: ) try: - exit_code = process.wait(timeout=self.MEMORY_POLL_PERIOD) - except TimeoutExpired: + exit_code = await asyncio.wait_for( + proc.wait(), timeout=self.MEMORY_POLL_PERIOD + ) + except asyncio.TimeoutError: potential_exited_msg = ( self.handle_process_timeout_and_create_exited_msg(exit_code, proc) ) @@ -212,9 +223,11 @@ def _run(self) -> Generator[Start | Exited | Running | None]: yield potential_exited_msg return - fm_step_pids |= { - int(child.pid) for child in process.children(recursive=True) - } + with contextlib.suppress(NoSuchProcess): + process = Process(proc.pid) + fm_step_pids |= { + int(child.pid) for child in process.children(recursive=True) + } ensure_file_handles_closed([stdin, stdout, stderr]) exited_message = self._create_exited_message_based_on_exit_code( @@ -274,7 +287,7 @@ def _create_exited_msg_for_non_zero_exit_code( ) def handle_process_timeout_and_create_exited_msg( - self, exit_code: Optional[int], proc: Popen[Process] + self, exit_code: Optional[int], proc: asyncio.subprocess.Process ) -> Exited | None: max_running_minutes = self.job_data.get("max_running_minutes") run_start_time = dt.now() @@ -349,7 +362,6 @@ def _check_job_files(self) -> list[str]: if executable_error := check_executable(self.job_data.get("executable")): errors.append(executable_error) - return errors def _check_target_file_is_written( @@ -428,7 +440,7 @@ def ensure_file_handles_closed(file_handles: Sequence[io.TextIOWrapper | None]) def _get_processtree_data( - process: Process, + pid: int, ) -> Tuple[int, float, Optional[int]]: """Obtain the oom_score (the Linux kernel uses this number to decide which process to kill first in out-of-memory situations). 
@@ -450,21 +462,19 @@ def _get_processtree_data( memory_rss = 0 cpu_seconds = 0.0 with contextlib.suppress(ValueError, FileNotFoundError): - oom_score = int( - Path(f"/proc/{process.pid}/oom_score").read_text(encoding="utf-8") - ) - with ( - contextlib.suppress( - ValueError, NoSuchProcess, AccessDenied, ZombieProcess, ProcessLookupError - ), - process.oneshot(), + oom_score = int(Path(f"/proc/{pid}/oom_score").read_text(encoding="utf-8")) + with contextlib.suppress( + ValueError, NoSuchProcess, AccessDenied, ZombieProcess, ProcessLookupError ): - memory_rss = process.memory_info().rss - cpu_seconds = process.cpu_times().user + process = Process(pid) + with process.oneshot(): + memory_rss = process.memory_info().rss + cpu_seconds = process.cpu_times().user with contextlib.suppress( NoSuchProcess, AccessDenied, ZombieProcess, ProcessLookupError ): + process = Process(pid) for child in process.children(recursive=True): with contextlib.suppress( ValueError, diff --git a/src/_ert/forward_model_runner/job_dispatch.py b/src/_ert/forward_model_runner/job_dispatch.py index ccd1e5044c2..c83f7601484 100644 --- a/src/_ert/forward_model_runner/job_dispatch.py +++ b/src/_ert/forward_model_runner/job_dispatch.py @@ -1,3 +1,5 @@ +import asyncio +import contextlib import os import signal import sys @@ -13,12 +15,8 @@ def sigterm_handler(_signo, _stack_frame): def main(): os.nice(19) signal.signal(signal.SIGTERM, sigterm_handler) - try: - job_runner_main(sys.argv) - except Exception as e: - pgid = os.getpgid(os.getpid()) - os.killpg(pgid, signal.SIGTERM) - raise e + with contextlib.suppress(asyncio.CancelledError): + asyncio.run(job_runner_main(sys.argv)) if __name__ == "__main__": diff --git a/src/_ert/forward_model_runner/reporting/base.py b/src/_ert/forward_model_runner/reporting/base.py index 5b7dd1e3dc8..65e0e54d825 100644 --- a/src/_ert/forward_model_runner/reporting/base.py +++ b/src/_ert/forward_model_runner/reporting/base.py @@ -5,5 +5,9 @@ class Reporter(ABC): @abstractmethod - def report(self, msg: Message): + async def report(self, msg: Message): """Report a message.""" + + @abstractmethod + def cancel(self): + """Safely shut down the reporter""" diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index 8bf13dee238..ffce533c3a6 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -1,8 +1,7 @@ from __future__ import annotations +import asyncio import logging -import queue -import threading from datetime import datetime, timedelta from pathlib import Path from typing import Final, Union @@ -28,11 +27,10 @@ Exited, Finish, Init, + Message, Running, Start, ) -from _ert.forward_model_runner.reporting.statemachine import StateMachine -from _ert.threading import ErtThread logger = logging.getLogger(__name__) @@ -68,45 +66,48 @@ def __init__(self, evaluator_url, token=None, cert_path=None): else: self._cert = None - self._statemachine = StateMachine() - self._statemachine.add_handler((Init,), self._init_handler) - self._statemachine.add_handler((Start, Running, Exited), self._job_handler) - self._statemachine.add_handler((Checksum,), self._checksum_handler) - self._statemachine.add_handler((Finish,), self._finished_handler) - self._ens_id = None self._real_id = None - self._event_queue: queue.Queue[events.Event | EventSentinel] = queue.Queue() - self._event_publisher_thread = ErtThread(target=self._event_publisher) + self._event_queue: asyncio.Queue[events.Event | 
EventSentinel] = asyncio.Queue() + self._timeout_timestamp = None - self._timestamp_lock = threading.Lock() # seconds to timeout the reporter the thread after Finish() was received self._reporter_timeout = 60 - def _event_publisher(self): + self._queue_polling_timeout = 2 + self._event_publishing_task = asyncio.create_task(self.async_event_publisher()) + + async def join(self) -> None: + await self._event_publishing_task + + async def async_event_publisher(self): logger.debug("Publishing event.") - with Client( + async with Client( url=self._evaluator_url, token=self._token, cert=self._cert, ) as client: event = None - while True: - with self._timestamp_lock: - if ( - self._timeout_timestamp is not None - and datetime.now() > self._timeout_timestamp - ): - self._timeout_timestamp = None - break + while ( + self._timeout_timestamp is None + or datetime.now() <= self._timeout_timestamp + ): if event is None: # if we successfully sent the event we can proceed # to next one - event = self._event_queue.get() + try: + event = await asyncio.wait_for( + self._event_queue.get(), timeout=self._queue_polling_timeout + ) + except asyncio.TimeoutError: + continue + if event is self._sentinel: + self._event_queue.task_done() break try: - client.send(event_to_json(event)) + await client.send(event_to_json(event)) + self._event_queue.task_done() event = None except ClientConnectionError as exception: # Possible intermittent failure, we retry sending the event @@ -115,21 +116,31 @@ def _event_publisher(self): # The receiving end has closed the connection, we stop # sending events logger.debug(str(exception)) + self._event_queue.task_done() break - def report(self, msg): - self._statemachine.transition(msg) + async def report(self, msg: Message): + await self._report(msg) + + async def _report(self, msg: Message): + if isinstance(msg, Init): + await self._init_handler(msg) + elif isinstance(msg, (Start, Running, Exited)): + await self._job_handler(msg) + elif isinstance(msg, Checksum): + await self._checksum_handler(msg) + elif isinstance(msg, Finish): + await self._finished_handler() - def _dump_event(self, event: events.Event): + async def _dump_event(self, event: events.Event): logger.debug(f'Schedule "{type(event)}" for delivery') - self._event_queue.put(event) + await self._event_queue.put(event) - def _init_handler(self, msg: Init): + async def _init_handler(self, msg: Init): self._ens_id = str(msg.ens_id) self._real_id = str(msg.real_id) - self._event_publisher_thread.start() - def _job_handler(self, msg: Union[Start, Running, Exited]): + async def _job_handler(self, msg: Union[Start, Running, Exited]): assert msg.job job_name = msg.job.name() job_msg = { @@ -144,16 +155,16 @@ def _job_handler(self, msg: Union[Start, Running, Exited]): std_out=str(Path(msg.job.std_out).resolve()), std_err=str(Path(msg.job.std_err).resolve()), ) - self._dump_event(event) + await self._dump_event(event) if not msg.success(): logger.error(f"Job {job_name} FAILED to start") event = ForwardModelStepFailure(**job_msg, error_msg=msg.error_message) - self._dump_event(event) + await self._dump_event(event) elif isinstance(msg, Exited): if msg.success(): logger.debug(f"Job {job_name} exited successfully") - self._dump_event(ForwardModelStepSuccess(**job_msg)) + await self._dump_event(ForwardModelStepSuccess(**job_msg)) else: logger.error( _JOB_EXIT_FAILED_STRING.format( @@ -165,7 +176,7 @@ def _job_handler(self, msg: Union[Start, Running, Exited]): event = ForwardModelStepFailure( **job_msg, exit_code=msg.exit_code, 
error_msg=msg.error_message ) - self._dump_event(event) + await self._dump_event(event) elif isinstance(msg, Running): logger.debug(f"{job_name} job is running") @@ -175,21 +186,21 @@ def _job_handler(self, msg: Union[Start, Running, Exited]): current_memory_usage=msg.memory_status.rss, cpu_seconds=msg.memory_status.cpu_seconds, ) - self._dump_event(event) + await self._dump_event(event) - def _finished_handler(self, _): - self._event_queue.put(Event._sentinel) - with self._timestamp_lock: - self._timeout_timestamp = datetime.now() + timedelta( - seconds=self._reporter_timeout - ) - if self._event_publisher_thread.is_alive(): - self._event_publisher_thread.join() + async def _finished_handler(self): + await self._event_queue.put(Event._sentinel) + self._timeout_timestamp = datetime.now() + timedelta( + seconds=self._reporter_timeout + ) - def _checksum_handler(self, msg: Checksum): + async def _checksum_handler(self, msg: Checksum) -> None: fm_checksum = ForwardModelStepChecksum( ensemble=self._ens_id, real=self._real_id, checksums={msg.run_path: msg.data}, ) - self._dump_event(fm_checksum) + await self._dump_event(fm_checksum) + + def cancel(self) -> None: + self._event_publishing_task.cancel() diff --git a/src/_ert/forward_model_runner/reporting/file.py b/src/_ert/forward_model_runner/reporting/file.py index e6e601fe0f2..0427be570a5 100644 --- a/src/_ert/forward_model_runner/reporting/file.py +++ b/src/_ert/forward_model_runner/reporting/file.py @@ -39,9 +39,8 @@ def __init__(self): self.status_dict = {} self.node = socket.gethostname() - def report(self, msg: Message): + async def report(self, msg: Message): fm_step_status = {} - if msg.job: logger.debug("Adding message job to status dictionary.") fm_step_status = self.status_dict["jobs"][msg.job.index] @@ -217,3 +216,6 @@ def _dump_ok_file(): def _dump_status_json(self): with open(STATUS_json, "wb") as fp: fp.write(orjson.dumps(self.status_dict, option=orjson.OPT_INDENT_2)) + + def cancel(self): + pass diff --git a/src/_ert/forward_model_runner/reporting/interactive.py b/src/_ert/forward_model_runner/reporting/interactive.py index fd489c78378..004a7546a7a 100644 --- a/src/_ert/forward_model_runner/reporting/interactive.py +++ b/src/_ert/forward_model_runner/reporting/interactive.py @@ -26,7 +26,10 @@ def _report(msg: Message) -> Optional[str]: ) return f"Running job: {msg.job.name()} ... 
" - def report(self, msg: Message): + async def report(self, msg: Message): _msg = self._report(msg) if _msg is not None: print(_msg) + + def cancel(self): + pass diff --git a/src/_ert/forward_model_runner/reporting/statemachine.py b/src/_ert/forward_model_runner/reporting/statemachine.py deleted file mode 100644 index 4d749414e4d..00000000000 --- a/src/_ert/forward_model_runner/reporting/statemachine.py +++ /dev/null @@ -1,62 +0,0 @@ -import logging -from typing import Callable, Dict, Tuple, Type - -from _ert.forward_model_runner.reporting.message import ( - Checksum, - Exited, - Finish, - Init, - Message, - Running, - Start, -) - -logger = logging.getLogger(__name__) - - -class TransitionError(ValueError): - pass - - -class StateMachine: - def __init__(self) -> None: - logger.debug("Initializing state machines") - initialized = (Init,) - jobs = (Start, Running, Exited) - checksum = (Checksum,) - finished = (Finish,) - self._handler: Dict[Message, Callable[[Message], None]] = {} - self._transitions = { - None: initialized, - initialized: jobs + checksum + finished, - jobs: jobs + checksum + finished, - checksum: checksum + finished, - } - self._state = None - - def add_handler( - self, states: Tuple[Type[Message], ...], handler: Callable[[Message], None] - ) -> None: - if states in self._handler: - raise ValueError(f"{states} already handled by {self._handler[states]}") - self._handler[states] = handler - - def transition(self, message: Message): - new_state = None - for state in self._handler: - if isinstance(message, state): - new_state = state - - if self._state not in self._transitions or not isinstance( - message, self._transitions[self._state] - ): - logger.error( - f"{message} illegal state transition: {self._state} -> {new_state}" - ) - raise TransitionError( - f"Illegal transition {self._state} -> {new_state} for {message}, " - f"expected to transition into {self._transitions[self._state]}" - ) - - self._handler[new_state](message) - self._state = new_state diff --git a/src/_ert/forward_model_runner/runner.py b/src/_ert/forward_model_runner/runner.py index bd304f3c7d3..62badc28c40 100644 --- a/src/_ert/forward_model_runner/runner.py +++ b/src/_ert/forward_model_runner/runner.py @@ -1,3 +1,4 @@ +import asyncio import hashlib import json import os @@ -5,11 +6,18 @@ from typing import Any, Dict, List from _ert.forward_model_runner.forward_model_step import ForwardModelStep -from _ert.forward_model_runner.reporting.message import Checksum, Finish, Init +from _ert.forward_model_runner.reporting.message import ( + Checksum, + Finish, + Init, + Message, +) class ForwardModelRunner: - def __init__(self, steps_data: Dict[str, Any]): + def __init__( + self, steps_data: Dict[str, Any], reporter_queue: asyncio.Queue[Message] + ): self.steps_data = ( steps_data # On disk, this is called jobs.json for legacy reasons ) @@ -25,7 +33,7 @@ def __init__(self, steps_data: Dict[str, Any]): self.steps: List[ForwardModelStep] = [] for index, step_data in enumerate(steps_data["jobList"]): self.steps.append(ForwardModelStep(step_data, index)) - + self._reporter_queue = reporter_queue self._set_environment() def _read_manifest(self): @@ -49,46 +57,60 @@ def _populate_checksums(self, manifest): info["error"] = f"Expected file {path} not created by forward model!" 
return manifest - def run(self, names_of_steps_to_run: List[str]): - if not names_of_steps_to_run: - step_queue = self.steps - else: - step_queue = [ - step for step in self.steps if step.name() in names_of_steps_to_run - ] - init_message = Init( - step_queue, - self.simulation_id, - self.ert_pid, - self.ens_id, - self.real_id, - self.experiment_id, - ) - - unused = set(names_of_steps_to_run) - {step.name() for step in step_queue} - if unused: - init_message.with_error( - f"{unused} does not exist. " - f"Available forward_model steps: {[step.name() for step in self.steps]}" + async def run(self, names_of_steps_to_run: List[str]) -> None: + try: + if not names_of_steps_to_run: + step_queue = self.steps + else: + step_queue = [ + step for step in self.steps if step.name() in names_of_steps_to_run + ] + init_message = Init( + step_queue, + self.simulation_id, + self.ert_pid, + self.ens_id, + self.real_id, + self.experiment_id, ) - yield init_message - return - else: - yield init_message - for step in step_queue: - for status_update in step.run(): - yield status_update - if not status_update.success(): - yield Checksum(checksum_dict={}, run_path=os.getcwd()) - yield Finish().with_error( - "Not all forward model steps completed successfully." - ) - return + unused = set(names_of_steps_to_run) - {step.name() for step in step_queue} + if unused: + init_message.with_error( + f"{unused} does not exist. " + f"Available forward_model steps: {[step.name() for step in self.steps]}" + ) + await self.put_event(init_message) + return - checksum_dict = self._populate_checksums(self._read_manifest()) - yield Checksum(checksum_dict=checksum_dict, run_path=os.getcwd()) - yield Finish() + await self.put_event(init_message) + for step in step_queue: + async for status_update in step.run(): + await self.put_event(status_update) + if not status_update.success(): + await self.put_event( + Checksum(checksum_dict={}, run_path=os.getcwd()) + ) + await self.put_event( + Finish().with_error( + f"Not all forward model steps completed successfully ({status_update.error_message})." + ) + ) + return + checksum_dict = self._populate_checksums(self._read_manifest()) + await self.put_event( + Checksum(checksum_dict=checksum_dict, run_path=os.getcwd()) + ) + await self.put_event(Finish()) + return + except asyncio.CancelledError: + await self.put_event(Checksum(checksum_dict={}, run_path=os.getcwd())) + await self.put_event( + Finish().with_error( + "Not all forward model steps completed successfully." 
+ ) + ) + return def _set_environment(self): if self.global_environment: @@ -96,3 +118,6 @@ def _set_environment(self): for env_key, env_val in os.environ.items(): value = value.replace(f"${env_key}", env_val) os.environ[key] = value + + async def put_event(self, event: Message): + await self._reporter_queue.put(event) diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index 846740f7479..40887b098c3 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -208,7 +208,7 @@ async def send_event( retries: int = 10, ) -> None: async with Client(url, token, cert, max_retries=retries) as client: - await client._send(event_to_json(event)) + await client.send(event_to_json(event)) def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]: def event_builder(status: str) -> Event: diff --git a/tests/ert/ui_tests/cli/test_field_parameter.py b/tests/ert/ui_tests/cli/test_field_parameter.py index 6c1c7cb193a..2a490ad2521 100644 --- a/tests/ert/ui_tests/cli/test_field_parameter.py +++ b/tests/ert/ui_tests/cli/test_field_parameter.py @@ -23,6 +23,7 @@ from .run_cli import run_cli +@pytest.mark.timeout(600) def test_field_param_update_using_heat_equation(heat_equation_storage): config = ErtConfig.from_file("config.ert") with open_storage(config.ens_path, mode="w") as storage: diff --git a/tests/ert/ui_tests/gui/test_restart_esmda.py b/tests/ert/ui_tests/gui/test_restart_esmda.py index 05723301520..784150dc477 100644 --- a/tests/ert/ui_tests/gui/test_restart_esmda.py +++ b/tests/ert/ui_tests/gui/test_restart_esmda.py @@ -69,7 +69,7 @@ def test_custom_weights_stored_and_retrieved_from_metadata_esmda( qtbot.mouseClick(run_experiment, Qt.MouseButton.LeftButton) qtbot.waitUntil(lambda: gui.findChild(RunDialog) is not None, timeout=5000) run_dialog = gui.findChild(RunDialog) - qtbot.waitUntil(lambda: run_dialog.is_simulation_done() == True, timeout=20000) + qtbot.waitUntil(lambda: run_dialog.is_simulation_done() == True, timeout=60000) assert ( run_dialog._total_progress_label.text() == "Total progress 100% — Experiment completed." 
diff --git a/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py b/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py index 6d15180ce41..7fe541604d1 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py +++ b/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py @@ -76,14 +76,14 @@ async def evaluate(self, config, _, __): ) async with Client(config.url + "/dispatch") as dispatch: event = EnsembleStarted(ensemble=self.id_) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event_id += 1 for real in range(0, self.test_reals): real = str(real) event = RealizationUnknown(ensemble=self.id_, real=real) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event_id += 1 for fm_step in range(0, self.fm_steps): @@ -95,7 +95,7 @@ async def evaluate(self, config, _, __): fm_step=fm_step, current_memory_usage=1000, ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event_id += 1 event = ForwardModelStepSuccess( @@ -104,16 +104,16 @@ async def evaluate(self, config, _, __): fm_step=fm_step, current_memory_usage=1000, ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event_id += 1 event_id += 1 event = RealizationSuccess(ensemble=self.id_, real=real) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event_id += 1 event = EnsembleSucceeded(ensemble=self.id_) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) @property def cancellable(self) -> bool: diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py index 0e66cc99b46..0385d35019f 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py @@ -1,66 +1,45 @@ -from functools import partial - import pytest from _ert.forward_model_runner.client import Client, ClientConnectionError -from _ert.threading import ErtThread - -from .ensemble_evaluator_utils import _mock_ws +from tests.ert.utils import _mock_ws_task -def test_invalid_server(): +async def test_invalid_server(): port = 7777 host = "localhost" url = f"ws://{host}:{port}" - with ( - Client(url, max_retries=2, timeout_multiplier=2) as c1, - pytest.raises(ClientConnectionError), - ): - c1.send("hei") + async with Client(url, max_retries=2, timeout_multiplier=2) as c1: + with pytest.raises(ClientConnectionError): + await c1.send("hei") -def test_successful_sending(unused_tcp_port): +async def test_successful_sending(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" messages = [] - mock_ws_thread = ErtThread( - target=partial(_mock_ws, messages=messages), args=(host, unused_tcp_port) - ) - - mock_ws_thread.start() - messages_c1 = ["test_1", "test_2", "test_3", "stop"] - with Client(url) as c1: + messages_c1 = ["test_1", "test_2", "test_3"] + async with _mock_ws_task(host, unused_tcp_port, messages), Client(url) as c1: for msg in messages_c1: - c1.send(msg) - - mock_ws_thread.join() + await c1.send(msg) for msg in messages_c1: assert msg in messages -def test_retry(unused_tcp_port): +async def test_retry(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" messages = [] - mock_ws_thread = ErtThread( - target=partial(_mock_ws, messages=messages, delay_startup=2), - args=( - host, - 
unused_tcp_port, - ), - ) - - mock_ws_thread.start() - messages_c1 = ["test_1", "test_2", "test_3", "stop"] - with Client(url, max_retries=2, timeout_multiplier=2) as c1: + messages_c1 = ["test_1", "test_2", "test_3"] + async with ( + _mock_ws_task(host, unused_tcp_port, messages, delay_startup=2), + Client(url, max_retries=2, timeout_multiplier=2) as c1, + ): for msg in messages_c1: - c1.send(msg) - - mock_ws_thread.join() + await c1.send(msg) for msg in messages_c1: assert msg in messages diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py index fdae28e50a0..0b8711dbeb3 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py @@ -172,7 +172,7 @@ async def test_restarted_jobs_do_not_have_error_msgs(evaluator_to_use): fm_step="0", current_memory_usage=1000, ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = ForwardModelStepFailure( ensemble=evaluator.ensemble.id_, @@ -180,7 +180,7 @@ async def test_restarted_jobs_do_not_have_error_msgs(evaluator_to_use): fm_step="0", error_msg="error", ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) def is_completed_snapshot(snapshot: EnsembleSnapshot) -> bool: try: @@ -212,7 +212,7 @@ def is_completed_snapshot(snapshot: EnsembleSnapshot) -> bool: fm_step="0", current_memory_usage=1000, ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) # reconnect new monitor async with Monitor(config_info) as new_monitor: @@ -270,7 +270,7 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use): fm_step="0", current_memory_usage=1000, ) - await dispatch1._send(event_to_json(event)) + await dispatch1.send(event_to_json(event)) # second dispatch endpoint client informs that forward model 0 is running event = ForwardModelStepRunning( ensemble=evaluator.ensemble.id_, @@ -278,7 +278,7 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use): fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that forward model 1 is running event = ForwardModelStepRunning( ensemble=evaluator.ensemble.id_, @@ -286,7 +286,7 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use): fm_step="1", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) final_snapshot = EnsembleSnapshot() @@ -330,12 +330,12 @@ def check_if_all_fm_running(snapshot: EnsembleSnapshot) -> bool: fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that job 1 is failed event = ForwardModelStepFailure( ensemble=evaluator.ensemble.id_, real="1", fm_step="1", error_msg="error" ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) def check_if_final_snapshot_is_complete(final_snapshot: EnsembleSnapshot) -> bool: try: @@ -408,7 +408,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="0", current_memory_usage=1000, ) - await dispatch1._send(event_to_json(event)) + await dispatch1.send(event_to_json(event)) # second dispatch endpoint client informs that real 1 fm 0 
is running event = ForwardModelStepRunning( ensemble=evaluator.ensemble.id_, @@ -416,7 +416,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that real 1 fm 0 is done event = ForwardModelStepSuccess( ensemble=evaluator.ensemble.id_, @@ -424,7 +424,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that real 1 fm 1 is failed event = ForwardModelStepFailure( ensemble=evaluator.ensemble.id_, @@ -432,7 +432,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="1", error_msg="error", ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) event = await anext(events) snapshot = EnsembleSnapshot.from_nested_dict(event.snapshot) @@ -496,17 +496,17 @@ async def test_ensure_multi_level_events_in_order(evaluator_to_use): assert type(snapshot_event) is EESnapshot async with Client(url + "/dispatch", cert=cert, token=token) as dispatch: event = EnsembleStarted(ensemble=evaluator.ensemble.id_) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = RealizationSuccess( ensemble=evaluator.ensemble.id_, real="0", queue_event_type="" ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = RealizationSuccess( ensemble=evaluator.ensemble.id_, real="1", queue_event_type="" ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = EnsembleSucceeded(ensemble=evaluator.ensemble.id_) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) await monitor.signal_done() diff --git a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py index 0575e78b954..59bb49899d6 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py +++ b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py @@ -1,6 +1,6 @@ +import asyncio import os import sys -import time from unittest.mock import patch import pytest @@ -26,18 +26,10 @@ Running, Start, ) -from _ert.forward_model_runner.reporting.statemachine import TransitionError -from tests.ert.utils import _mock_ws_thread +from tests.ert.utils import _mock_ws_task, async_wait_until -def _wait_until(condition, timeout, fail_msg): - start = time.time() - while not condition(): - assert start + timeout > time.time(), fail_msg - time.sleep(0.1) - - -def test_report_with_successful_start_message_argument(unused_tcp_port): +async def test_report_with_successful_start_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -45,10 +37,12 @@ def test_report_with_successful_start_message_argument(unused_tcp_port): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Start(fmstep1)) - reporter.report(Finish()) + + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 
19, ens_id="ens_id", real_id=0)) + await reporter.report(Start(fmstep1)) + await reporter.report(Finish()) + await reporter.join() assert len(lines) == 1 event = event_from_json(lines[0]) @@ -58,9 +52,10 @@ def test_report_with_successful_start_message_argument(unused_tcp_port): assert event.fm_step == "0" assert os.path.basename(event.std_out) == "stdout" assert os.path.basename(event.std_err) == "stderr" + reporter.cancel() -def test_report_with_failed_start_message_argument(unused_tcp_port): +async def test_report_with_failed_start_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -70,13 +65,13 @@ def test_report_with_failed_start_message_argument(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) msg = Start(fmstep1).with_error("massive_failure") - - reporter.report(msg) - reporter.report(Finish()) + await reporter.report(msg) + await reporter.report(Finish()) + await reporter.join() assert len(lines) == 2 event = event_from_json(lines[1]) @@ -84,7 +79,7 @@ def test_report_with_failed_start_message_argument(unused_tcp_port): assert event.error_msg == "massive_failure" -def test_report_with_successful_exit_message_argument(unused_tcp_port): +async def test_report_with_successful_exit_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -93,17 +88,18 @@ def test_report_with_successful_exit_message_argument(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Exited(fmstep1, 0)) - reporter.report(Finish().with_error("failed")) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Exited(fmstep1, 0)) + await reporter.report(Finish().with_error("failed")) + await reporter.join() assert len(lines) == 1 event = event_from_json(lines[0]) assert type(event) is ForwardModelStepSuccess -def test_report_with_failed_exit_message_argument(unused_tcp_port): +async def test_report_with_failed_exit_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -112,10 +108,11 @@ def test_report_with_failed_exit_message_argument(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Exited(fmstep1, 1).with_error("massive_failure")) - reporter.report(Finish()) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Exited(fmstep1, 1).with_error("massive_failure")) + await reporter.report(Finish()) + await reporter.join() assert len(lines) == 1 event = event_from_json(lines[0]) @@ -123,7 +120,7 @@ def test_report_with_failed_exit_message_argument(unused_tcp_port): assert event.error_msg == "massive_failure" -def test_report_with_running_message_argument(unused_tcp_port): +async def test_report_with_running_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = 
Event(evaluator_url=url) @@ -132,10 +129,11 @@ def test_report_with_running_message_argument(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Finish()) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + await reporter.report(Finish()) + await reporter.join() assert len(lines) == 1 event = event_from_json(lines[0]) @@ -144,7 +142,7 @@ def test_report_with_running_message_argument(unused_tcp_port): assert event.current_memory_usage == 10 -def test_report_only_job_running_for_successful_run(unused_tcp_port): +async def test_report_only_job_running_for_successful_run(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -153,15 +151,16 @@ def test_report_only_job_running_for_successful_run(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Finish()) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + await reporter.report(Finish()) + await reporter.join() assert len(lines) == 1 -def test_report_with_failed_finish_message_argument(unused_tcp_port): +async def test_report_with_failed_finish_message_argument(unused_tcp_port): host = "localhost" url = f"ws://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -170,32 +169,19 @@ def test_report_with_failed_finish_message_argument(unused_tcp_port): ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Finish().with_error("massive_failure")) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + await reporter.report(Finish().with_error("massive_failure")) + await reporter.join() assert len(lines) == 1 -def test_report_inconsistent_events(unused_tcp_port): - host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" - reporter = Event(evaluator_url=url) - - lines = [] - with ( - _mock_ws_thread(host, unused_tcp_port, lines), - pytest.raises( - TransitionError, - match=r"Illegal transition None -> \(MessageType,\)", - ), - ): - reporter.report(Finish()) - - @pytest.mark.integration_test -def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): +async def test_report_with_failed_reporter_but_finished_jobs( + unused_tcp_port, monkeypatch +): # this is to show when the reporter fails ert won't crash nor # staying hanging but instead finishes up the job; # see reporter._event_publisher_thread.join() @@ -204,8 +190,8 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): # which then sets _timeout_timestamp=None mock_send_retry_time = 2 - def mock_send(msg): - time.sleep(mock_send_retry_time) + 
async def mock_send(msg): + await asyncio.sleep(mock_send_retry_time) raise ClientConnectionError("Sending failed!") host = "localhost" @@ -216,20 +202,24 @@ def mock_send(msg): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + async with _mock_ws_task(host, unused_tcp_port, lines): with patch( "_ert.forward_model_runner.client.Client.send", lambda x, y: mock_send(y) ): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10)) + ) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10)) + ) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10)) + ) # set _stop_timestamp - reporter.report(Finish()) - if reporter._event_publisher_thread.is_alive(): - reporter._event_publisher_thread.join() + await reporter.report(Finish()) + await reporter.join() # set _stop_timestamp to None only when timer stopped - assert reporter._timeout_timestamp is None assert len(lines) == 0, "expected 0 Job running messages" @@ -238,7 +228,7 @@ def mock_send(msg): @pytest.mark.skipif( sys.platform.startswith("darwin"), reason="Performance can be flaky" ) -def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): +async def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): # this is to show when the reporter fails but reconnects # reporter still manages to send events and completes fine # see assert reporter._timeout_timestamp is not None @@ -246,27 +236,33 @@ def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): # it finished succesfully mock_send_retry_time = 0.1 - def send_func(msg): - time.sleep(mock_send_retry_time) + async def send_func(msg): + await asyncio.sleep(mock_send_retry_time) raise ClientConnectionError("Sending failed!") host = "localhost" url = f"ws://{host}:{unused_tcp_port}" - reporter = Event(evaluator_url=url) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + async with _mock_ws_task(host, unused_tcp_port, lines): with patch("_ert.forward_model_runner.client.Client.send") as patched_send: + reporter = Event(evaluator_url=url) patched_send.side_effect = send_func - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10))) + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10)) + ) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10)) + ) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10)) + ) - _wait_until( + await async_wait_until( condition=lambda: patched_send.call_count == 3, timeout=10, fail_msg="10 seconds should be sufficient to send three events", @@ -274,23 +270,22 @@ def 
send_func(msg): # reconnect and continue sending events # set _stop_timestamp - reporter.report(Finish()) - if reporter._event_publisher_thread.is_alive(): - reporter._event_publisher_thread.join() + await reporter.report(Finish()) + await reporter.join() # set _stop_timestamp was not set to None since the reporter finished on time assert reporter._timeout_timestamp is not None assert len(lines) == 3, "expected 3 Job running messages" @pytest.mark.integration_test -def test_report_with_closed_received_exiting_gracefully(unused_tcp_port): +async def test_report_with_closed_received_exiting_gracefully(unused_tcp_port): # Whenever the receiver end closes the connection, a ConnectionClosedOK is raised # The reporter should exit the publisher thread gracefully and not send any # more events mock_send_retry_time = 3 - def mock_send(msg): - time.sleep(mock_send_retry_time) + async def mock_send(msg): + await asyncio.sleep(mock_send_retry_time) raise ClientConnectionClosedOK("Connection Closed") host = "localhost" @@ -300,13 +295,13 @@ def mock_send(msg): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) + async with _mock_ws_task(host, unused_tcp_port, lines): + await reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) # sleep until both Running events have been received - _wait_until( + await async_wait_until( condition=lambda: len(lines) == 2, timeout=10, fail_msg="Should not take 10 seconds to send two events", @@ -315,21 +310,19 @@ def mock_send(msg): with patch( "_ert.forward_model_runner.client.Client.send", lambda x, y: mock_send(y) ): - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10))) + await reporter.report( + Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10)) + ) # Make sure the publisher thread exits because it got # ClientConnectionClosedOK. If it hangs it could indicate that the # exception is not caught/handled correctly - if reporter._event_publisher_thread.is_alive(): - reporter._event_publisher_thread.join() - - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=400, rss=10))) - reporter.report(Finish()) + await reporter.join() - # set _stop_timestamp was not set to None since the reporter finished on time - assert reporter._timeout_timestamp is not None + await reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=400, rss=10))) + await reporter.report(Finish()) # The Running(fmstep1, 300, 10) is popped from the queue, but never sent. 
# The following Running is added to queue along with the sentinel - assert reporter._event_queue.qsize() == 2 + assert reporter._event_queue.qsize() == 2, reporter._event_queue # None of the messages after ClientConnectionClosedOK was raised, has been sent assert len(lines) == 2, "expected 2 Job running messages" diff --git a/tests/ert/unit_tests/forward_model_runner/test_file_reporter.py b/tests/ert/unit_tests/forward_model_runner/test_file_reporter.py index d64af04d83c..589afb79852 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_file_reporter.py +++ b/tests/ert/unit_tests/forward_model_runner/test_file_reporter.py @@ -22,24 +22,24 @@ def reporter(): @pytest.mark.usefixtures("use_tmpdir") -def test_report_with_init_message_argument(reporter): +async def test_report_with_init_message_argument(reporter): r = reporter fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "/stdout", "stderr": "/stderr"}, 0 ) - r.report(Init([fmstep1], 1, 19)) + await r.report(Init([fmstep1], 1, 19)) - with open(STATUS_file, "r", encoding="utf-8") as f: + with open(STATUS_file, "r", encoding="utf-8") as f: # noqa: ASYNC230 assert "Current host" in f.readline(), "STATUS file missing expected value" - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, "r", encoding="utf-8") as f: # noqa: ASYNC230 content = "".join(f.readlines()) assert '"name": "fmstep1"' in content, "status.json missing fmstep1" assert '"status": "Waiting"' in content, "status.json missing Waiting status" @pytest.mark.usefixtures("use_tmpdir") -def test_report_with_successful_start_message_argument(reporter): +async def test_report_with_successful_start_message_argument(reporter): msg = Start( ForwardModelStep( { @@ -54,33 +54,33 @@ def test_report_with_successful_start_message_argument(reporter): ) reporter.status_dict = reporter._init_job_status_dict(msg.timestamp, 0, [msg.job]) - reporter.report(msg) + await reporter.report(msg) - with open(STATUS_file, "r", encoding="utf-8") as f: + with open(STATUS_file, "r", encoding="utf-8") as f: # noqa: ASYNC230 assert "fmstep1" in f.readline(), "STATUS file missing fmstep1" - with open(LOG_file, "r", encoding="utf-8") as f: + with open(LOG_file, "r", encoding="utf-8") as f: # noqa: ASYNC230 assert ( "Calling: /bin/sh --foo 1 --bar 2" in f.readline() ), """JOB_LOG file missing executable and arguments""" - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, "r", encoding="utf-8") as f: # noqa: ASYNC230 content = "".join(f.readlines()) assert '"status": "Running"' in content, "status.json missing Running status" assert '"start_time": null' not in content, "start_time not set" @pytest.mark.usefixtures("use_tmpdir") -def test_report_with_failed_start_message_argument(reporter): +async def test_report_with_failed_start_message_argument(reporter): msg = Start(ForwardModelStep({"name": "fmstep1"}, 0)).with_error("massive_failure") reporter.status_dict = reporter._init_job_status_dict(msg.timestamp, 0, [msg.job]) - reporter.report(msg) + await reporter.report(msg) - with open(STATUS_file, "r", encoding="utf-8") as f: + with open(STATUS_file, "r", encoding="utf-8") as f: # noqa: ASYNC230 assert ( "EXIT: -10/massive_failure" in f.readline() ), "STATUS file missing EXIT message" - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, "r", encoding="utf-8") as f: # noqa: ASYNC230 content = "".join(f.readlines()) assert '"status": "Failure"' in content, "status.json missing Failure status" assert ( @@ -92,29 +92,29 @@ def 
test_report_with_failed_start_message_argument(reporter): @pytest.mark.usefixtures("use_tmpdir") -def test_report_with_successful_exit_message_argument(reporter): +async def test_report_with_successful_exit_message_argument(reporter): msg = Exited(ForwardModelStep({"name": "fmstep1"}, 0), 0) reporter.status_dict = reporter._init_job_status_dict(msg.timestamp, 0, [msg.job]) - reporter.report(msg) + await reporter.report(msg) - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, "r", encoding="utf-8") as f: # noqa: ASYNC230 content = "".join(f.readlines()) assert '"status": "Success"' in content, "status.json missing Success status" @pytest.mark.usefixtures("use_tmpdir") -def test_report_with_failed_exit_message_argument(reporter): +async def test_report_with_failed_exit_message_argument(reporter): msg = Exited(ForwardModelStep({"name": "fmstep1"}, 0), 1).with_error( "massive_failure" ) reporter.status_dict = reporter._init_job_status_dict(msg.timestamp, 0, [msg.job]) - reporter.report(msg) + await reporter.report(msg) - with open(STATUS_file, "r", encoding="utf-8") as f: + with open(STATUS_file, "r", encoding="utf-8") as f: # noqa: ASYNC230 assert "EXIT: 1/massive_failure" in f.readline() - with open(ERROR_file, "r", encoding="utf-8") as f: + with open(ERROR_file, "r", encoding="utf-8") as f: # noqa: ASYNC230 content = "".join(f.readlines()) assert "fmstep1" in content, "ERROR file missing fmstep" assert ( @@ -123,7 +123,7 @@ def test_report_with_failed_exit_message_argument(reporter): assert ( "stderr: Not redirected" in content ), "ERROR had invalid stderr information" - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, "r", encoding="utf-8") as f: # noqa: ASYNC230 content = "".join(f.readlines()) assert '"status": "Failure"' in content, "status.json missing Failure status" assert ( @@ -133,16 +133,16 @@ def test_report_with_failed_exit_message_argument(reporter): @pytest.mark.usefixtures("use_tmpdir") -def test_report_with_running_message_argument(reporter): +async def test_report_with_running_message_argument(reporter): msg = Running( ForwardModelStep({"name": "fmstep1"}, 0), ProcessTreeStatus(max_rss=100, rss=10, cpu_seconds=1.1), ) reporter.status_dict = reporter._init_job_status_dict(msg.timestamp, 0, [msg.job]) - reporter.report(msg) + await reporter.report(msg) - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, "r", encoding="utf-8") as f: # noqa: ASYNC230 content = "".join(f.readlines()) assert '"status": "Running"' in content, "status.json missing status" assert ( @@ -155,11 +155,11 @@ def test_report_with_running_message_argument(reporter): @pytest.mark.usefixtures("use_tmpdir") -def test_report_with_successful_finish_message_argument(reporter): +async def test_report_with_successful_finish_message_argument(reporter): msg = Finish() reporter.status_dict = reporter._init_job_status_dict(msg.timestamp, 0, []) - reporter.report(msg) + await reporter.report(msg) @pytest.mark.usefixtures("use_tmpdir") @@ -198,7 +198,7 @@ def test_old_file_deletion(reporter): @pytest.mark.usefixtures("use_tmpdir") -def test_status_file_is_correct(reporter): +async def test_status_file_is_correct(reporter): """The STATUS file is a file to which we append data about steps as they are run. So this involves multiple reports, and should be tested as such. 
@@ -213,7 +213,7 @@ def test_status_file_is_correct(reporter): exited_j_2 = Exited(j_2, 9).with_error("failed horribly") for msg in [init, start_j_1, exited_j_1, start_j_2, exited_j_2]: - reporter.report(msg) + await reporter.report(msg) expected_j1_line = ( f"{j_1.name():32}: {start_j_1.timestamp:%H:%M:%S} .... " @@ -226,7 +226,7 @@ def test_status_file_is_correct(reporter): f"EXIT: {exited_j_2.exit_code}/{exited_j_2.error_message}\n" ) - with open(STATUS_file, "r", encoding="utf-8") as f: + with open(STATUS_file, "r", encoding="utf-8") as f: # noqa: ASYNC230 for expected in [ "Current host", expected_j1_line, diff --git a/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner.py b/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner.py index dc4e9103b40..709c88fe0fb 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner.py +++ b/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner.py @@ -1,3 +1,4 @@ +import asyncio import json import os import os.path @@ -75,13 +76,13 @@ def set_up_environ(): @pytest.mark.usefixtures("use_tmpdir") -def test_missing_joblist_json(): +async def test_missing_joblist_json(): with pytest.raises(KeyError): - ForwardModelRunner({}) + ForwardModelRunner({}, asyncio.Queue()) @pytest.mark.usefixtures("use_tmpdir") -def test_run_output_rename(): +async def test_run_output_rename(): fm_step = { "name": "TEST_FMSTEP", "executable": "mkdir", @@ -90,17 +91,22 @@ def test_run_output_rename(): } fm_step_list = [fm_step, fm_step, fm_step, fm_step, fm_step] - fmr = ForwardModelRunner(create_jobs_json(fm_step_list)) + event_queue = asyncio.Queue() + fmr = ForwardModelRunner(create_jobs_json(fm_step_list), event_queue) + await fmr.run([]) - for status in enumerate(fmr.run([])): + index = 0 + while not event_queue.empty(): + status = (index, await event_queue.get()) if isinstance(status, Start): assert status.job is not None assert status.job.std_err == f"err.{status.job.index}" assert status.job.std_out == f"out.{status.job.index}" + index += 1 @pytest.mark.usefixtures("use_tmpdir") -def test_run_multiple_ok(): +async def test_run_multiple_ok(): fm_step_list = [] dir_list = ["1", "2", "3", "4", "5"] for fm_step_index in dir_list: @@ -112,14 +118,18 @@ def test_run_multiple_ok(): "argList": ["-p", "-v", fm_step_index], } fm_step_list.append(fm_step) + event_queue = asyncio.Queue() + fmr = ForwardModelRunner(create_jobs_json(fm_step_list), event_queue) + await fmr.run([]) - fmr = ForwardModelRunner(create_jobs_json(fm_step_list)) + exit_message_count = 0 + while not event_queue.empty(): + event = await event_queue.get() + if isinstance(event, Exited): + exit_message_count += 1 + assert event.exit_code == 0 - statuses = [s for s in list(fmr.run([])) if isinstance(s, Exited)] - - assert len(statuses) == 5 - for status in statuses: - assert status.exit_code == 0 + assert exit_message_count == 5 for dir_number in dir_list: assert os.path.isdir(dir_number) @@ -129,7 +139,7 @@ def test_run_multiple_ok(): @pytest.mark.usefixtures("use_tmpdir") -def test_when_forward_model_contains_multiple_steps_just_one_checksum_status_is_given(): +async def test_when_forward_model_contains_multiple_steps_just_one_checksum_status_is_given(): fm_step_list = [] file_list = ["1", "2", "3", "4", "5"] manifest = {} @@ -143,18 +153,23 @@ def test_when_forward_model_contains_multiple_steps_just_one_checksum_status_is_ "argList": [fm_step_index], } fm_step_list.append(fm_step) - with open("manifest.json", "w", encoding="utf-8") as f: + with 
open("manifest.json", "w", encoding="utf-8") as f: # noqa: ASYNC230 json.dump(manifest, f) - fmr = ForwardModelRunner(create_jobs_json(fm_step_list)) - - statuses = [s for s in list(fmr.run([])) if isinstance(s, Checksum)] + event_queue = asyncio.Queue() + fmr = ForwardModelRunner(create_jobs_json(fm_step_list), event_queue) + await fmr.run([]) + statuses = [] + while not event_queue.empty(): + event = await event_queue.get() + if isinstance(event, Checksum): + statuses.append(event) assert len(statuses) == 1 assert len(statuses[0].data) == 5 @pytest.mark.usefixtures("use_tmpdir") -def test_when_manifest_file_is_not_created_by_fm_runner_checksum_contains_error(): +async def test_when_manifest_file_is_not_created_by_fm_runner_checksum_contains_error(): fm_step_list = [] file_name = "test" manifest = {"file_1": f"{file_name}"} @@ -168,12 +183,16 @@ def test_when_manifest_file_is_not_created_by_fm_runner_checksum_contains_error( "argList": ["not_test"], } ) - with open("manifest.json", "w", encoding="utf-8") as f: + with open("manifest.json", "w", encoding="utf-8") as f: # noqa: ASYNC230 json.dump(manifest, f) - - fmr = ForwardModelRunner(create_jobs_json(fm_step_list)) - - checksum_msg = [s for s in list(fmr.run([])) if isinstance(s, Checksum)] + event_queue = asyncio.Queue() + fmr = ForwardModelRunner(create_jobs_json(fm_step_list), event_queue) + await fmr.run([]) + checksum_msg = [] + while not event_queue.empty(): + event = await event_queue.get() + if isinstance(event, Checksum): + checksum_msg.append(event) assert len(checksum_msg) == 1 info = checksum_msg[0].data["file_1"] assert "md5sum" not in info @@ -185,7 +204,7 @@ def test_when_manifest_file_is_not_created_by_fm_runner_checksum_contains_error( @pytest.mark.usefixtures("use_tmpdir") -def test_run_multiple_fail_only_runs_one(): +async def test_run_multiple_fail_only_runs_one(): fm_step_list = [] for index in range(1, 6): fm_step = { @@ -200,10 +219,14 @@ def test_run_multiple_fail_only_runs_one(): ], } fm_step_list.append(fm_step) - - fmr = ForwardModelRunner(create_jobs_json(fm_step_list)) - - statuses = [s for s in list(fmr.run([])) if isinstance(s, Exited)] + event_queue = asyncio.Queue() + fmr = ForwardModelRunner(create_jobs_json(fm_step_list), event_queue) + await fmr.run([]) + statuses = [] + while not event_queue.empty(): + event = await event_queue.get() + if isinstance(event, Exited): + statuses.append(event) assert len(statuses) == 1 for i, status in enumerate(statuses): @@ -211,8 +234,8 @@ def test_run_multiple_fail_only_runs_one(): @pytest.mark.usefixtures("use_tmpdir") -def test_exec_env(): - with open("exec_env.py", "w", encoding="utf-8") as f: +async def test_exec_env(): + with open("exec_env.py", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write( """#!/usr/bin/env python\n import os @@ -224,7 +247,7 @@ def test_exec_env(): ) os.chmod("exec_env.py", stat.S_IEXEC + stat.S_IREAD) - with open("EXEC_ENV", "w", encoding="utf-8") as f: + with open("EXEC_ENV", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write("EXECUTABLE exec_env.py\n") f.write("EXEC_ENV TEST_ENV 123\n") @@ -232,7 +255,7 @@ def test_exec_env(): name=None, config_file="EXEC_ENV" ) - with open("jobs.json", mode="w", encoding="utf-8") as fptr: + with open("jobs.json", mode="w", encoding="utf-8") as fptr: # noqa: ASYNC230 ert_config = ErtConfig(forward_model_steps=[forward_model]) json.dump( forward_model_data_to_json( @@ -245,21 +268,26 @@ def test_exec_env(): fptr, ) - with open("jobs.json", "r", encoding="utf-8") as f: + with open("jobs.json", "r", 
encoding="utf-8") as f: # noqa: ASYNC230 jobs_json = json.load(f) - for msg in list(ForwardModelRunner(jobs_json).run([])): - if isinstance(msg, Start): - with open("exec_env_exec_env.json", encoding="utf-8") as f: + event_queue = asyncio.Queue() + forward_model_runner = ForwardModelRunner(jobs_json, event_queue) + await forward_model_runner.run([]) + + while not event_queue.empty(): + event = await event_queue.get() + if isinstance(event, Start): + with open("exec_env_exec_env.json", encoding="utf-8") as f: # noqa: ASYNC230 exec_env = json.load(f) assert exec_env["TEST_ENV"] == "123" - if isinstance(msg, Exited): - assert msg.exit_code == 0 + if isinstance(event, Exited): + assert event.exit_code == 0 @pytest.mark.usefixtures("use_tmpdir") -def test_env_var_available_inside_step_context(): - with open("run_me.py", "w", encoding="utf-8") as f: +async def test_env_var_available_inside_step_context(): + with open("run_me.py", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write( """#!/usr/bin/env python\n import os @@ -268,12 +296,12 @@ def test_env_var_available_inside_step_context(): ) os.chmod("run_me.py", stat.S_IEXEC + stat.S_IREAD) - with open("RUN_ENV", "w", encoding="utf-8") as f: + with open("RUN_ENV", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write("EXECUTABLE run_me.py\n") f.write("ENV TEST_ENV 123\n") step = _forward_model_step_from_config_file(name=None, config_file="RUN_ENV") - with open("jobs.json", mode="w", encoding="utf-8") as fptr: + with open("jobs.json", mode="w", encoding="utf-8") as fptr: # noqa: ASYNC230 ert_config = ErtConfig(forward_model_steps=[step]) json.dump( forward_model_data_to_json( @@ -286,23 +314,26 @@ def test_env_var_available_inside_step_context(): fptr, ) - with open("jobs.json", "r", encoding="utf-8") as f: + with open("jobs.json", "r", encoding="utf-8") as f: # noqa: ASYNC230 jobs_json = json.load(f) # Check ENV variable not available outside of step context assert "TEST_ENV" not in os.environ - - for msg in list(ForwardModelRunner(jobs_json).run([])): - if isinstance(msg, Exited): - assert msg.exit_code == 0 - + event_queue = asyncio.Queue() + fmr = ForwardModelRunner(jobs_json, event_queue) + await fmr.run([]) + + while not event_queue.empty(): + event = await event_queue.get() + if isinstance(event, Exited): + assert event.exit_code == 0 # Check ENV variable not available outside of step context assert "TEST_ENV" not in os.environ @pytest.mark.usefixtures("use_tmpdir") -def test_default_env_variables_available_inside_fm_step_context(): - with open("run_me.py", "w", encoding="utf-8") as f: +async def test_default_env_variables_available_inside_fm_step_context(): + with open("run_me.py", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write( textwrap.dedent( """\ @@ -316,11 +347,11 @@ def test_default_env_variables_available_inside_fm_step_context(): ) os.chmod("run_me.py", stat.S_IEXEC + stat.S_IREAD) - with open("RUN_ENV", "w", encoding="utf-8") as f: + with open("RUN_ENV", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write("EXECUTABLE run_me.py\n") step = _forward_model_step_from_config_file(name=None, config_file="RUN_ENV") - with open("jobs.json", mode="w", encoding="utf-8") as fptr: + with open("jobs.json", mode="w", encoding="utf-8") as fptr: # noqa: ASYNC230 ert_config = ErtConfig( forward_model_steps=[step], substitutions=Substitutions({"": "./"}), @@ -336,16 +367,20 @@ def test_default_env_variables_available_inside_fm_step_context(): fptr, ) - with open("jobs.json", "r", encoding="utf-8") as f: + with open("jobs.json", "r", 
encoding="utf-8") as f: # noqa: ASYNC230 jobs_json = json.load(f) # Check default ENV variable not available outside of step context for k in ForwardModelStep.default_env: assert k not in os.environ - - for msg in list(ForwardModelRunner(jobs_json).run([])): - if isinstance(msg, Exited): - assert msg.exit_code == 0 + event_queue = asyncio.Queue() + fmr = ForwardModelRunner(jobs_json, event_queue) + await fmr.run([]) + + while not event_queue.empty(): + event = await event_queue.get() + if isinstance(event, Exited): + assert event.exit_code == 0 # Check default ENV variable not available outside of step context for k in ForwardModelStep.default_env: diff --git a/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner_runtime_kw.py b/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner_runtime_kw.py index c5526e7dbff..2ecad513e73 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner_runtime_kw.py +++ b/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner_runtime_kw.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from _ert.forward_model_runner.reporting.message import Exited, Finish, Start @@ -5,7 +7,7 @@ @pytest.mark.usefixtures("use_tmpdir") -def test_run_one_fm_step_with_an_integer_arg_is_actually_a_fractional(): +async def test_run_one_fm_step_with_an_integer_arg_is_actually_a_fractional(): fm_step_1 = { "name": "FM_STEP_1", "executable": "echo", @@ -18,18 +20,22 @@ def test_run_one_fm_step_with_an_integer_arg_is_actually_a_fractional(): } data = {"jobList": [fm_step_1]} + event_queue = asyncio.Queue() + runner = ForwardModelRunner(data, event_queue) + await runner.run([]) + start_msg_count = 0 + while not event_queue.empty(): + event = await event_queue.get() + if isinstance(event, Start): + start_msg_count += 1 + assert not event.success(), "fm_step should not start with success" - runner = ForwardModelRunner(data) - statuses = list(runner.run([])) - starts = [e for e in statuses if isinstance(e, Start)] - - assert len(starts) == 1, "There should be 1 start message" - assert not starts[0].success(), "fm_step should not start with success" + assert start_msg_count == 1, "There should be 1 start message" @pytest.mark.usefixtures("use_tmpdir") -def test_run_given_one_fm_step_with_missing_file_and_one_file_present(): - with open("a_file", "w", encoding="utf-8") as f: +async def test_run_given_one_fm_step_with_missing_file_and_one_file_present(): + with open("a_file", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write("Hello") executable = "echo" @@ -59,10 +65,14 @@ def test_run_given_one_fm_step_with_missing_file_and_one_file_present(): data = { "jobList": [fm_step_0, fm_step_1], } - - runner = ForwardModelRunner(data) - - statuses = list(runner.run([])) + event_queue = asyncio.Queue() + runner = ForwardModelRunner(data, event_queue) + await runner.run([]) + + statuses = [] + while not event_queue.empty(): + event = await event_queue.get() + statuses.append(event) starts = [e for e in statuses if isinstance(e, Start)] assert len(starts) == 2, "There should be 2 start messages" diff --git a/tests/ert/unit_tests/forward_model_runner/test_forward_model_step.py b/tests/ert/unit_tests/forward_model_runner/test_forward_model_step.py index 00718d4a95a..c10a82ee56a 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_forward_model_step.py +++ b/tests/ert/unit_tests/forward_model_runner/test_forward_model_step.py @@ -9,6 +9,7 @@ import pytest +from _ert.forward_model_runner import forward_model_step from 
_ert.forward_model_runner.forward_model_step import ( ForwardModelStep, _get_processtree_data, @@ -17,33 +18,41 @@ @patch("_ert.forward_model_runner.forward_model_step.check_executable") -@patch("_ert.forward_model_runner.forward_model_step.Popen") +@patch("_ert.forward_model_runner.forward_model_step.asyncio.create_subprocess_exec") @patch("_ert.forward_model_runner.forward_model_step.Process") @pytest.mark.usefixtures("use_tmpdir") -def test_run_with_process_failing(mock_process, mock_popen, mock_check_executable): +async def test_run_with_process_failing( + mock_process, mock_create_subprocess_exec, mock_check_executable +): fmstep = ForwardModelStep({}, 0) mock_check_executable.return_value = "" type(mock_process.return_value.memory_info.return_value).rss = PropertyMock( return_value=10 ) - mock_process.return_value.wait.return_value = 9 - run = fmstep.run() + async def mocked_subprocess_wait(*args, **kwargs): + return 9 - assert isinstance(next(run), Start), "run did not yield Start message" - assert isinstance(next(run), Running), "run did not yield Running message" - exited = next(run) + mock_create_subprocess_exec.return_value.wait = mocked_subprocess_wait + + run = (a async for a in fmstep.run()) + + assert isinstance(await anext(run), Start), "run did not yield Start message" + assert isinstance(await anext(run), Running), "run did not yield Running message" + exited = await anext(run) assert isinstance(exited, Exited), "run did not yield Exited message" - assert exited.exit_code == 9, "Exited message had unexpected exit code" + assert ( + exited.exit_code == 9 + ), f"Exited message had unexpected exit code '{exited.exit_code}'" - with pytest.raises(StopIteration): - next(run) + with pytest.raises(StopAsyncIteration): + await anext(run) @pytest.mark.flaky(reruns=10) @pytest.mark.integration_test @pytest.mark.usefixtures("use_tmpdir") -def test_cpu_seconds_can_detect_multiprocess(): +async def test_cpu_seconds_can_detect_multiprocess(): """Run a fm step that sets of two simultaneous processes that each run for 2 second. We should be able to detect the total cpu seconds consumed to be roughly 2 seconds. @@ -55,7 +64,7 @@ def test_cpu_seconds_can_detect_multiprocess(): sub-processes. 
""" pythonscript = "busy.py" - with open(pythonscript, "w", encoding="utf-8") as pyscript: + with open(pythonscript, "w", encoding="utf-8") as pyscript: # noqa: ASYNC230 pyscript.write( textwrap.dedent( """\ @@ -66,7 +75,7 @@ def test_cpu_seconds_can_detect_multiprocess(): ) ) scriptname = "saturate_cpus.sh" - with open(scriptname, "w", encoding="utf-8") as script: + with open(scriptname, "w", encoding="utf-8") as script: # noqa: ASYNC230 script.write( textwrap.dedent( """\ @@ -85,7 +94,7 @@ def test_cpu_seconds_can_detect_multiprocess(): ) fmstep.MEMORY_POLL_PERIOD = 0.05 cpu_seconds = 0.0 - for status in fmstep.run(): + async for status in fmstep.run(): if isinstance(status, Running): cpu_seconds = max(cpu_seconds, status.memory_status.cpu_seconds) assert 2.5 < cpu_seconds < 4.5 @@ -94,10 +103,10 @@ def test_cpu_seconds_can_detect_multiprocess(): @pytest.mark.integration_test @pytest.mark.flaky(reruns=5) @pytest.mark.usefixtures("use_tmpdir") -def test_memory_usage_counts_grandchildren(): +async def test_memory_usage_counts_grandchildren(): scriptname = "recursive_memory_hog.py" blobsize = 1e7 - with open(scriptname, "w", encoding="utf-8") as script: + with open(scriptname, "w", encoding="utf-8") as script: # noqa: ASYNC230 script.write( textwrap.dedent( """\ @@ -122,7 +131,7 @@ def test_memory_usage_counts_grandchildren(): executable = os.path.realpath(scriptname) os.chmod(scriptname, stat.S_IRWXU | stat.S_IRWXO | stat.S_IRWXG) - def max_memory_per_subprocess_layer(layers: int) -> int: + async def max_memory_per_subprocess_layer(layers: int) -> int: fmstep = ForwardModelStep( { "executable": executable, @@ -132,7 +141,7 @@ def max_memory_per_subprocess_layer(layers: int) -> int: ) fmstep.MEMORY_POLL_PERIOD = 0.01 max_seen = 0 - for status in fmstep.run(): + async for status in fmstep.run(): if isinstance(status, Running): max_seen = max(max_seen, status.memory_status.max_rss) return max_seen @@ -143,7 +152,7 @@ def max_memory_per_subprocess_layer(layers: int) -> int: # when running the program. 
memory_per_numbers_list = sys.getsizeof(int(0)) * blobsize * 0.90 - max_seens = [max_memory_per_subprocess_layer(layers) for layers in range(3)] + max_seens = [await max_memory_per_subprocess_layer(layers) for layers in range(3)] assert max_seens[0] + memory_per_numbers_list < max_seens[1] assert max_seens[1] + memory_per_numbers_list < max_seens[2] @@ -174,13 +183,17 @@ def oneshot(self): return contextlib.nullcontext() -def test_cpu_seconds_for_process_with_children(): - (_, cpu_seconds, _) = _get_processtree_data(MockedProcess(123)) +def test_cpu_seconds_for_process_with_children(monkeypatch): + def mocked_process(pid): + return MockedProcess(123) + + monkeypatch.setattr(forward_model_step, "Process", mocked_process) + (_, cpu_seconds, _) = _get_processtree_data(123) assert cpu_seconds == 123 / 10.0 + 124 / 10.0 @pytest.mark.skipif(sys.platform.startswith("darwin"), reason="No oom_score on MacOS") -def test_oom_score_is_max_over_processtree(): +def test_oom_score_is_max_over_processtree(monkeypatch): def read_text_side_effect(self: pathlib.Path, *args, **kwargs): if self.absolute() == pathlib.Path("/proc/123/oom_score"): return "234" @@ -189,13 +202,18 @@ def read_text_side_effect(self: pathlib.Path, *args, **kwargs): with patch("pathlib.Path.read_text", autospec=True) as mocked_read_text: mocked_read_text.side_effect = read_text_side_effect - (_, _, oom_score) = _get_processtree_data(MockedProcess(123)) + + def mocked_process(pid): + return MockedProcess(123) + + monkeypatch.setattr(forward_model_step, "Process", mocked_process) + (_, _, oom_score) = _get_processtree_data(123) assert oom_score == 456 @pytest.mark.usefixtures("use_tmpdir") -def test_run_fails_using_exit_bash_builtin(): +async def test_run_fails_using_exit_bash_builtin(): fmstep = ForwardModelStep( { "name": "exit 1", @@ -207,7 +225,7 @@ def test_run_fails_using_exit_bash_builtin(): 0, ) - statuses = list(fmstep.run()) + statuses = [status async for status in fmstep.run()] assert len(statuses) == 3, "Wrong statuses count" assert statuses[2].exit_code == 1, "Exited status wrong exit_code" @@ -217,7 +235,7 @@ def test_run_fails_using_exit_bash_builtin(): @pytest.mark.usefixtures("use_tmpdir") -def test_run_with_defined_executable_but_missing(): +async def test_run_with_defined_executable_but_missing(): executable = os.path.join(os.getcwd(), "this/is/not/a/file") fmstep = ForwardModelStep( { @@ -229,15 +247,15 @@ def test_run_with_defined_executable_but_missing(): 0, ) - start_message = next(fmstep.run()) + start_message = await anext(fmstep.run()) assert isinstance(start_message, Start) assert "this/is/not/a/file is not a file" in start_message.error_message @pytest.mark.usefixtures("use_tmpdir") -def test_run_with_empty_executable(): +async def test_run_with_empty_executable(): empty_executable = os.path.join(os.getcwd(), "foo") - with open(empty_executable, "a", encoding="utf-8"): + with open(empty_executable, "a", encoding="utf-8"): # noqa: ASYNC230 pass st = os.stat(empty_executable) os.chmod(empty_executable, st.st_mode | stat.S_IEXEC) @@ -251,7 +269,7 @@ def test_run_with_empty_executable(): }, 0, ) - run_status = list(fmstep.run()) + run_status = [status async for status in fmstep.run()] assert len(run_status) == 2 start_msg, exit_msg = run_status assert isinstance(start_msg, Start) @@ -261,9 +279,9 @@ def test_run_with_empty_executable(): @pytest.mark.usefixtures("use_tmpdir") -def test_run_with_defined_executable_no_exec_bit(): +async def test_run_with_defined_executable_no_exec_bit(): non_executable = 
os.path.join(os.getcwd(), "foo") - with open(non_executable, "a", encoding="utf-8"): + with open(non_executable, "a", encoding="utf-8"): # noqa: ASYNC230 pass fmstep = ForwardModelStep( @@ -275,7 +293,7 @@ def test_run_with_defined_executable_no_exec_bit(): }, 0, ) - start_message = next(fmstep.run()) + start_message = await anext(fmstep.run()) assert isinstance(start_message, Start) assert "foo is not an executable" in start_message.error_message @@ -301,7 +319,7 @@ def test_init_fmstep_with_std(): assert fmstep.std_out == "exit_out" -def test_makedirs(monkeypatch, tmp_path): +async def test_makedirs(monkeypatch, tmp_path): """ Test that the directories for the output process streams are created if they don't exist @@ -315,7 +333,7 @@ def test_makedirs(monkeypatch, tmp_path): }, 0, ) - for _ in fmstep.run(): + async for _ in fmstep.run(): pass assert (tmp_path / "a/file").is_file() assert (tmp_path / "b/c/file").is_file() diff --git a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py index 0befe45c5a9..1c7a0cb7f12 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py +++ b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py @@ -1,37 +1,38 @@ from __future__ import annotations +import asyncio import glob import importlib import json import os -import signal import stat import subprocess import sys from subprocess import Popen from textwrap import dedent -from threading import Lock -from unittest.mock import mock_open, patch import pandas as pd import psutil import pytest import _ert.forward_model_runner.cli -from _ert.forward_model_runner.cli import JOBS_FILE, _setup_reporters, main +from _ert.forward_model_runner.cli import ( + JOBS_FILE, + ForwardModelRunnerException, + _setup_reporters, + main, +) from _ert.forward_model_runner.forward_model_step import killed_by_oom from _ert.forward_model_runner.reporting import Event, Interactive from _ert.forward_model_runner.reporting.message import Finish, Init -from _ert.threading import ErtThread -from tests.ert.utils import _mock_ws_thread, wait_until - -from .test_event_reporter import _wait_until +from _ert.forward_model_runner.runner import ForwardModelRunner +from tests.ert.utils import _mock_ws_task, async_wait_until, wait_until @pytest.mark.usefixtures("use_tmpdir") -def test_terminate_steps(): +async def test_terminate_steps(): # Executes itself recursively and sleeps for 100 seconds - with open("dummy_executable", "w", encoding="utf-8") as f: + with open("dummy_executable", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write( """#!/usr/bin/env python import sys, os, time @@ -73,11 +74,11 @@ def test_terminate_steps(): "ert_pid": "", } - with open(JOBS_FILE, "w", encoding="utf-8") as f: + with open(JOBS_FILE, "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write(json.dumps(step_list)) # macOS doesn't provide /usr/bin/setsid, so we roll our own - with open("setsid", "w", encoding="utf-8") as f: + with open("setsid", "w", encoding="utf-8") as f: # noqa: ASYNC230 f.write( dedent( """\ @@ -95,7 +96,7 @@ def test_terminate_steps(): "_ert.forward_model_runner.job_dispatch" ).origin # (we wait for the process below) - job_dispatch_process = Popen( + job_dispatch_process = Popen( # noqa: ASYNC220 [ os.getcwd() + "/setsid", sys.executable, @@ -113,7 +114,8 @@ def test_terminate_steps(): wait_until(lambda: len(p.children(recursive=True)) == 0) - os.wait() # allow os to clean up zombie processes + # allow os to clean up zombie processes 
+ os.wait() # noqa: ASYNC222 @pytest.mark.usefixtures("use_tmpdir") @@ -277,26 +279,29 @@ def test_job_dispatch_run_subset_specified_as_parameter(): assert os.path.isfile("step_C.out") -def test_no_jobs_json_file_raises_IOError(tmp_path): +async def test_no_jobs_json_file_raises_IOError(tmp_path): with pytest.raises(IOError): - main(["script.py", str(tmp_path)]) + await main(["script.py", str(tmp_path)]) -def test_invalid_jobs_json_raises_OSError(tmp_path): +async def test_invalid_jobs_json_raises_OSError(tmp_path): (tmp_path / JOBS_FILE).write_text("not json") with pytest.raises(OSError): - main(["script.py", str(tmp_path)]) + await main(["script.py", str(tmp_path)]) -def test_missing_directory_exits(tmp_path): +async def test_missing_directory_exits(tmp_path): with pytest.raises(SystemExit): - main(["script.py", str(tmp_path / "non_existent")]) + await main(["script.py", str(tmp_path / "non_existent")]) -def test_retry_of_jobs_json_file_read(unused_tcp_port, tmp_path, monkeypatch, caplog): - lock = Lock() - lock.acquire() +async def test_retry_of_jobs_json_file_read( + unused_tcp_port, tmp_path, monkeypatch, caplog +): + lock = asyncio.Lock() + await lock.acquire() + monkeypatch.setattr(_ert.forward_model_runner.cli, "_wait_for_retry", lock.acquire) jobs_json = json.dumps( { @@ -306,24 +311,26 @@ def test_retry_of_jobs_json_file_read(unused_tcp_port, tmp_path, monkeypatch, ca } ) - with _mock_ws_thread("localhost", unused_tcp_port, []): - thread = ErtThread(target=main, args=[["script.py", str(tmp_path)]]) - thread.start() - _wait_until( + async with _mock_ws_task("localhost", unused_tcp_port, []): + fm_runner_task = asyncio.create_task(main(["script.py", str(tmp_path)])) + + await async_wait_until( lambda: f"Could not find file {JOBS_FILE}, retrying" in caplog.text, 2, "Did not get expected log message from missing jobs.json", ) (tmp_path / JOBS_FILE).write_text(jobs_json) + await asyncio.sleep(0) lock.release() - thread.join() + + await fm_runner_task @pytest.mark.parametrize( "is_interactive_run, ens_id", [(False, None), (False, "1234"), (True, None), (True, "1234")], ) -def test_setup_reporters(is_interactive_run, ens_id): +async def test_setup_reporters(is_interactive_run, ens_id): reporters = _setup_reporters(is_interactive_run, ens_id, "") if not is_interactive_run and not ens_id: @@ -338,29 +345,32 @@ def test_setup_reporters(is_interactive_run, ens_id): assert len(reporters) == 1 assert any(isinstance(r, Interactive) for r in reporters) + for reporter in reporters: + reporter.cancel() + @pytest.mark.usefixtures("use_tmpdir") -def test_job_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port): +async def test_job_dispatch_kills_itself_after_unsuccessful_job( + unused_tcp_port, monkeypatch +): host = "localhost" port = unused_tcp_port - jobs_json = json.dumps({"ens_id": "_id_", "dispatch_url": f"ws://localhost:{port}"}) - - with ( - patch("_ert.forward_model_runner.cli.os.killpg") as mock_killpg, - patch("_ert.forward_model_runner.cli.os.getpgid") as mock_getpgid, - patch("_ert.forward_model_runner.cli.open", new=mock_open(read_data=jobs_json)), - patch("_ert.forward_model_runner.cli.ForwardModelRunner") as mock_runner, - ): - mock_runner.return_value.run.return_value = [ - Init([], 0, 0), - Finish().with_error("overall bad run"), - ] - mock_getpgid.return_value = 17 + jobs_obj = { + "ens_id": "_id_", + "dispatch_url": f"ws://localhost:{port}", + "jobList": [], + } + with open("jobs.json", mode="w+", encoding="utf-8") as f: # noqa: ASYNC230 + json.dump(jobs_obj, f) - with 
_mock_ws_thread(host, port, []): - main(["script.py"]) + async def mock_run_method(self: ForwardModelRunner, *args, **kwargs): + await self.put_event(Init([], 0, 0)) + await self.put_event(Finish().with_error("overall bad run")) - mock_killpg.assert_called_with(17, signal.SIGKILL) + monkeypatch.setattr(ForwardModelRunner, "run", mock_run_method) + async with _mock_ws_task(host, port, []): + with pytest.raises(ForwardModelRunnerException): + await main(["script.py"]) @pytest.mark.skipif(sys.platform.startswith("darwin"), reason="No oom_score on MacOS") diff --git a/tests/ert/unit_tests/simulator/test_batch_sim.py b/tests/ert/unit_tests/simulator/test_batch_sim.py index 4ff43593752..416b26e0f6b 100644 --- a/tests/ert/unit_tests/simulator/test_batch_sim.py +++ b/tests/ert/unit_tests/simulator/test_batch_sim.py @@ -1,3 +1,4 @@ +import asyncio import os import sys import time @@ -491,7 +492,7 @@ def assertContextStatusOddFailures(batch_ctx: BatchContext, final_state_only=Fal @pytest.mark.integration_test -def test_batch_ctx_status_failing_jobs(setup_case, storage): +async def test_batch_ctx_status_failing_jobs(setup_case, storage): ert_config = setup_case("batch_sim", "batch_sim_sleep_and_fail.ert") external_parameters = { @@ -515,6 +516,6 @@ def test_batch_ctx_status_failing_jobs(setup_case, storage): batch_ctx = rsim.start("case_name", ensembles) while batch_ctx.running(): assertContextStatusOddFailures(batch_ctx) - time.sleep(1) + await asyncio.sleep(1) assertContextStatusOddFailures(batch_ctx, final_state_only=True) diff --git a/tests/ert/utils.py b/tests/ert/utils.py index 732f816f8cd..88abe6c6f8b 100644 --- a/tests/ert/utils.py +++ b/tests/ert/utils.py @@ -3,14 +3,12 @@ import asyncio import contextlib import time -from functools import partial from pathlib import Path from typing import TYPE_CHECKING import websockets.server from _ert.forward_model_runner.client import Client -from _ert.threading import ErtThread from ert.scheduler.event import FinishedEvent, StartedEvent if TYPE_CHECKING: @@ -61,9 +59,18 @@ def wait_until(func, interval=0.5, timeout=30): ) -def _mock_ws(host, port, messages, delay_startup=0): - loop = asyncio.new_event_loop() - done = loop.create_future() +async def async_wait_until(condition, timeout, fail_msg, interval=0.1): + t = 0 + while t < timeout: + await asyncio.sleep(interval) + if condition(): + return + t += interval + raise AssertionError(fail_msg) + + +async def _mock_ws_async(host, port, messages, delay_startup=0): + done = asyncio.Future() async def _handler(websocket, path): while True: @@ -73,33 +80,24 @@ async def _handler(websocket, path): done.set_result(None) break - async def _run_server(): - await asyncio.sleep(delay_startup) - async with websockets.server.serve(_handler, host, port): - await done - - loop.run_until_complete(_run_server()) - loop.close() + await asyncio.sleep(delay_startup) + async with websockets.server.serve(_handler, host, port): + await done -@contextlib.contextmanager -def _mock_ws_thread(host, port, messages): - mock_ws_thread = ErtThread( - target=partial(_mock_ws, messages=messages), - args=( - host, - port, - ), +@contextlib.asynccontextmanager +async def _mock_ws_task(host, port, messages, delay_startup=0): + mock_ws_task = asyncio.create_task( + _mock_ws_async(host, port, messages, delay_startup) ) - mock_ws_thread.start() try: yield # Make sure to join the thread even if an exception occurs finally: url = f"ws://{host}:{port}" - with Client(url) as client: - client.send("stop") - mock_ws_thread.join() + async with 
Client(url) as client:
+            await client.send("stop")
-        mock_ws_thread.join()
+        await mock_ws_task
         messages.pop()
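The reworked helpers above replace the thread-based mock websocket server with an asyncio task wrapped in an async context manager, and `wait_until` gains a cooperative counterpart. A minimal usage sketch under the same assumptions as the tests in this patch (an async test context; the code under test is only hinted at):

import asyncio

from tests.ert.utils import _mock_ws_task, async_wait_until


async def example_usage(unused_tcp_port):
    received = []
    # The context manager starts the mock websocket server as a task, tells it
    # to "stop" on exit, and awaits the task instead of joining a thread.
    async with _mock_ws_task("localhost", unused_tcp_port, received):
        ...  # start the code under test, e.g. asyncio.create_task(main([...]))
        # Poll a condition without blocking the event loop, replacing the
        # thread-oriented wait_until helper.
        await async_wait_until(
            lambda: len(received) > 0,
            timeout=5,
            fail_msg="no websocket message received",
        )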