Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make forwardmodelrunner async #9198

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 71 additions & 20 deletions src/_ert/forward_model_runner/cli.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,39 @@
import argparse
import asyncio
import json
import logging
import os
import signal
import sys
import time
import typing
from datetime import datetime
from typing import List, Sequence

from _ert.forward_model_runner import reporting
from _ert.forward_model_runner.reporting.message import Finish, ProcessTreeStatus
from _ert.forward_model_runner.reporting.message import (
Finish,
Message,
ProcessTreeStatus,
)
from _ert.forward_model_runner.runner import ForwardModelRunner

JOBS_FILE = "jobs.json"

logger = logging.getLogger(__name__)


class ForwardModelRunnerException(Exception):
pass


def _setup_reporters(
is_interactive_run,
ens_id,
dispatch_url,
ee_token=None,
ee_cert_path=None,
experiment_id=None,
) -> typing.List[reporting.Reporter]:
reporters: typing.List[reporting.Reporter] = []
) -> List[reporting.Reporter]:
reporters: List[reporting.Reporter] = []
if is_interactive_run:
reporters.append(reporting.Interactive())
elif ens_id and experiment_id is None:
Expand Down Expand Up @@ -71,26 +79,26 @@ def _setup_logging(directory: str = "logs"):
JOBS_JSON_RETRY_TIME = 30


def _wait_for_retry():
time.sleep(JOBS_JSON_RETRY_TIME)
async def _wait_for_retry():
Copy link
Contributor

@xjules xjules Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if we need this helper function at all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we need it for one of the tests. test_job_dispatch.py::test_retry_of_jobs_json_file_read

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this usage of that function is a bit strange though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We mock it to lock.acquire in a test, so that it will stop here

await asyncio.sleep(JOBS_JSON_RETRY_TIME)


async def _read_jobs_file(retry=True):
    """Read and parse the JOBS_FILE (jobs.json) in the current directory.

    If the file is not present yet — it may still be in transit to the
    compute node when the runner starts — wait JOBS_JSON_RETRY_TIME seconds
    and retry exactly once.

    Args:
        retry: when True (default), retry once on FileNotFoundError.

    Returns:
        The parsed JSON content as a Python object.

    Raises:
        IOError: the file exists but does not contain valid JSON.
        FileNotFoundError: the file is still missing after the retry.
    """
    try:
        # Synchronous open is deliberate: this runs once at startup before
        # other tasks are scheduled, hence the ASYNC230 lint waiver.
        with open(JOBS_FILE, "r", encoding="utf-8") as json_file:  # noqa: ASYNC230
            return json.load(json_file)
    except json.JSONDecodeError as e:
        raise IOError("Job Runner cli failed to load JSON-file.") from e
    except FileNotFoundError as e:
        if retry:
            logger.error(f"Could not find file {JOBS_FILE}, retrying")
            await _wait_for_retry()
            return await _read_jobs_file(retry=False)
        else:
            raise e


def main(args):
async def main(args):
parser = argparse.ArgumentParser(
description=(
"Run all the jobs specified in jobs.json, "
Expand Down Expand Up @@ -118,7 +126,7 @@ def main(args):
# Make sure that logging is setup _after_ we have moved to the runpath directory
_setup_logging()

jobs_data = _read_jobs_file()
jobs_data = await _read_jobs_file()

experiment_id = jobs_data.get("experiment_id")
ens_id = jobs_data.get("ens_id")
Expand All @@ -135,21 +143,64 @@ def main(args):
ee_cert_path,
experiment_id,
)
reporter_queue: asyncio.Queue[Message] = asyncio.Queue()

is_done = asyncio.Event()
forward_model_runner = ForwardModelRunner(jobs_data, reporter_queue=reporter_queue)
forward_model_runner_task = asyncio.create_task(
forward_model_runner.run(parsed_args.job)
)
reporting_task = asyncio.create_task(
handle_reporting(reporters, reporter_queue, is_done)
)

def handle_sigterm(*args, **kwargs):
nonlocal reporters, forward_model_runner_task
forward_model_runner_task.cancel()
for reporter in reporters:
reporter.cancel()
Copy link
Contributor

@xjules xjules Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try:
  await reporter
except asyncio.CancelledError:
  pass

or maybe just asyncio.gather(*reporters, return ....)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signal handler has to be synced, but we await the task anyways so it should be fine.

Copy link
Contributor

@xjules xjules Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To shutdown gracefully, this is what chatgpt suggests:

def setup_signal_handlers(loop):
    """
    Setup signal handlers for graceful shutdown.
    """
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, lambda: asyncio.create_task(shutdown(loop, signal=sig)))

wherein shutdown is an async function.


job_runner = ForwardModelRunner(jobs_data)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, handle_sigterm)

await forward_model_runner_task

is_done.set()
await reporting_task


async def handle_reporting(
    reporters: Sequence[reporting.Reporter],
    message_queue: asyncio.Queue[Message],
    is_done: asyncio.Event,
):
    """Consume status messages from the queue and fan them out to reporters.

    Loops until ``is_done`` is set and the queue has gone quiet (the 2 s
    ``wait_for`` timeout is only used to re-check ``is_done`` periodically).

    On a reporter OSError, or when a Finish message signals failure, the
    whole process group is killed with SIGKILL after letting queue-based
    reporters drain, and a ForwardModelRunnerException is raised.

    Raises:
        ForwardModelRunnerException: reporting failed or the run finished
            unsuccessfully.
    """
    while True:
        try:
            job_status = await asyncio.wait_for(message_queue.get(), timeout=2)
        except asyncio.TimeoutError:
            # No message within the poll window: exit only if the runner
            # has signalled completion, otherwise keep waiting.
            if is_done.is_set():
                break
            continue

        logger.info(f"Job status: {job_status}")
        for reporter in reporters:
            try:
                await reporter.report(job_status)
            except OSError as oserror:
                print(
                    f"job_dispatch failed due to {oserror}. Stopping and cleaning up."
                )
                pgid = os.getpgid(os.getpid())
                os.killpg(pgid, signal.SIGKILL)
                await let_reporters_finish(reporters)
                raise ForwardModelRunnerException from oserror

        message_queue.task_done()
        if isinstance(job_status, Finish) and not job_status.success():
            pgid = os.getpgid(os.getpid())
            os.killpg(pgid, signal.SIGKILL)
            await let_reporters_finish(reporters)
            raise ForwardModelRunnerException(job_status.error_message)

    await let_reporters_finish(reporters)


async def let_reporters_finish(reporters):
    """Block until every queue-based (Event) reporter has drained its queue."""
    event_reporters = [r for r in reporters if isinstance(r, reporting.Event)]
    for event_reporter in event_reporters:
        await event_reporter.join()
5 changes: 1 addition & 4 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def get_websocket(self) -> ClientConnection:
close_timeout=self.CONNECTION_TIMEOUT,
)

async def _send(self, msg: AnyStr) -> None:
async def send(self, msg: AnyStr) -> None:
for retry in range(self._max_retries + 1):
try:
if self.websocket is None:
Expand Down Expand Up @@ -133,6 +133,3 @@ async def _send(self, msg: AnyStr) -> None:
raise ClientConnectionError(_error_msg) from exception
await asyncio.sleep(0.2 + self._timeout_multiplier * retry)
self.websocket = None

def send(self, msg: AnyStr) -> None:
self.loop.run_until_complete(self._send(msg))
70 changes: 40 additions & 30 deletions src/_ert/forward_model_runner/forward_model_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import contextlib
import io
import json
Expand All @@ -11,10 +12,19 @@
import time
from datetime import datetime as dt
from pathlib import Path
from subprocess import Popen, run
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple, cast
from subprocess import run
from typing import (
TYPE_CHECKING,
AsyncGenerator,
Dict,
List,
Optional,
Sequence,
Tuple,
cast,
)

from psutil import AccessDenied, NoSuchProcess, Process, TimeoutExpired, ZombieProcess
from psutil import AccessDenied, NoSuchProcess, Process, ZombieProcess

from .io import check_executable
from .reporting.message import (
Expand Down Expand Up @@ -89,9 +99,9 @@ def __init__(
self.std_err = job_data.get("stderr")
self.std_out = job_data.get("stdout")

async def run(self) -> AsyncGenerator[Start | Exited | Running | None]:
    """Run the forward model step, yielding status messages as they occur.

    Any unexpected exception escaping the underlying ``_run()`` generator
    is converted into an Exited message with exit code 1 rather than
    propagating, so consumers always receive a terminating message.
    """
    try:
        async for msg in self._run():
            yield msg
    except Exception as e:
        yield Exited(self, exit_code=1).with_error(str(e))
Expand Down Expand Up @@ -151,9 +161,8 @@ def _create_environment(self) -> Optional[Dict[str, str]]:
combined_environment = {**os.environ, **environment}
return combined_environment

def _run(self) -> Generator[Start | Exited | Running | None]:
async def _run(self) -> AsyncGenerator[Start | Exited | Running | None]:
start_message = self.create_start_message_and_check_job_files()

yield start_message
if not start_message.success():
return
Expand All @@ -167,14 +176,14 @@ def _run(self) -> Generator[Start | Exited | Running | None]:
target_file_mtime: Optional[int] = _get_target_file_ntime(target_file)

try:
proc = Popen(
arg_list,
proc = await asyncio.create_subprocess_exec(
*arg_list,
stdin=stdin,
stdout=stdout,
stderr=stderr,
env=self._create_environment(),
)
process = Process(proc.pid)

except OSError as e:
exited_message = self._handle_process_io_error_and_create_exited_message(
e, stderr
Expand All @@ -186,9 +195,9 @@ def _run(self) -> Generator[Start | Exited | Running | None]:
exit_code = None

max_memory_usage = 0
fm_step_pids = {int(process.pid)}
fm_step_pids = {int(proc.pid)}
while exit_code is None:
(memory_rss, cpu_seconds, oom_score) = _get_processtree_data(process)
(memory_rss, cpu_seconds, oom_score) = _get_processtree_data(proc.pid)
max_memory_usage = max(memory_rss, max_memory_usage)
yield Running(
self,
Expand All @@ -203,18 +212,22 @@ def _run(self) -> Generator[Start | Exited | Running | None]:
)

try:
exit_code = process.wait(timeout=self.MEMORY_POLL_PERIOD)
except TimeoutExpired:
exit_code = await asyncio.wait_for(
proc.wait(), timeout=self.MEMORY_POLL_PERIOD
)
except asyncio.TimeoutError:
potential_exited_msg = (
self.handle_process_timeout_and_create_exited_msg(exit_code, proc)
)
if isinstance(potential_exited_msg, Exited):
yield potential_exited_msg

return
fm_step_pids |= {
int(child.pid) for child in process.children(recursive=True)
}
with contextlib.suppress(NoSuchProcess):
proccess = Process(proc.pid)
fm_step_pids |= {
int(child.pid) for child in proccess.children(recursive=True)
}

ensure_file_handles_closed([stdin, stdout, stderr])
exited_message = self._create_exited_message_based_on_exit_code(
Expand Down Expand Up @@ -274,7 +287,7 @@ def _create_exited_msg_for_non_zero_exit_code(
)

def handle_process_timeout_and_create_exited_msg(
self, exit_code: Optional[int], proc: Popen[Process]
self, exit_code: Optional[int], proc: asyncio.subprocess.Process
) -> Exited | None:
max_running_minutes = self.job_data.get("max_running_minutes")
run_start_time = dt.now()
Expand Down Expand Up @@ -349,7 +362,6 @@ def _check_job_files(self) -> list[str]:

if executable_error := check_executable(self.job_data.get("executable")):
errors.append(executable_error)

return errors

def _check_target_file_is_written(
Expand Down Expand Up @@ -428,7 +440,7 @@ def ensure_file_handles_closed(file_handles: Sequence[io.TextIOWrapper | None])


def _get_processtree_data(
process: Process,
pid: int,
) -> Tuple[int, float, Optional[int]]:
"""Obtain the oom_score (the Linux kernel uses this number to
decide which process to kill first in out-of-memory situations).
Expand All @@ -450,21 +462,19 @@ def _get_processtree_data(
memory_rss = 0
cpu_seconds = 0.0
with contextlib.suppress(ValueError, FileNotFoundError):
oom_score = int(
Path(f"/proc/{process.pid}/oom_score").read_text(encoding="utf-8")
)
with (
contextlib.suppress(
ValueError, NoSuchProcess, AccessDenied, ZombieProcess, ProcessLookupError
),
process.oneshot(),
oom_score = int(Path(f"/proc/{pid}/oom_score").read_text(encoding="utf-8"))
with contextlib.suppress(
ValueError, NoSuchProcess, AccessDenied, ZombieProcess, ProcessLookupError
):
memory_rss = process.memory_info().rss
cpu_seconds = process.cpu_times().user
process = Process(pid)
with process.oneshot():
memory_rss = process.memory_info().rss
cpu_seconds = process.cpu_times().user

with contextlib.suppress(
NoSuchProcess, AccessDenied, ZombieProcess, ProcessLookupError
):
process = Process(pid)
for child in process.children(recursive=True):
with contextlib.suppress(
ValueError,
Expand Down
10 changes: 4 additions & 6 deletions src/_ert/forward_model_runner/job_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import contextlib
import os
import signal
import sys
Expand All @@ -13,12 +15,8 @@ def sigterm_handler(_signo, _stack_frame):
def main():
    """Entry point for job_dispatch: run the forward model runner to completion.

    Lowers scheduling priority so forward-model jobs do not starve other
    processes on the compute node, installs a SIGTERM handler, and runs the
    async runner. A SIGTERM cancels the runner task; the resulting
    CancelledError is suppressed so shutdown is silent and graceful.
    """
    os.nice(19)
    signal.signal(signal.SIGTERM, sigterm_handler)
    with contextlib.suppress(asyncio.CancelledError):
        asyncio.run(job_runner_main(sys.argv))


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion src/_ert/forward_model_runner/reporting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@

class Reporter(ABC):
    """Abstract interface for reporting forward-model status messages."""

    @abstractmethod
    async def report(self, msg: Message):
        """Report a message."""

    @abstractmethod
    def cancel(self):
        """Safely shut down the reporter"""
Loading
Loading