diff --git a/browsergym/experiments/src/browsergym/experiments/benchmark/utils.py b/browsergym/experiments/src/browsergym/experiments/benchmark/utils.py index 88cd2b94..396accaa 100644 --- a/browsergym/experiments/src/browsergym/experiments/benchmark/utils.py +++ b/browsergym/experiments/src/browsergym/experiments/benchmark/utils.py @@ -1,8 +1,10 @@ import logging +import multiprocessing as mp import os +import traceback +import typing from typing import Literal -import gymnasium as gym import numpy as np from browsergym.experiments.loop import SEED_MAX, EnvArgs @@ -206,35 +208,73 @@ def prepare_backend(backend: str): raise NotImplementedError(f"Unknown benchmark backend {repr(backend)}") -def massage_tasks(task_ids: list[str], max_retries: int = 1): +def massage_tasks(task_ids: list[str], max_retries: int = 1, timeout: int = 60): for i, task_id in enumerate(task_ids): - gym_id = f"browsergym/{task_id}" - logger.info(f"Massaging task {i + 1} / {len(task_ids)}: {gym_id}") - task_retries = 0 - while True: - env = gym.make(gym_id) - try: - env.reset() # task setup - try: - no_action = "noop()" - # check if action space exists and is compatible with "noop()" - env.unwrapped.action_mapping(no_action) - except: - # fallback plan - no_action = "" - env.step(no_action) # task validation - env.step(no_action) # task validation again - logger.info(f"Massage successful") + logger.info(f"Massaging task {i + 1} / {len(task_ids)}: {task_id}") + for retries in range(max_retries + 1): + outcome, err_msg = massage_task_within_subprocess(task_id=task_id, timeout=timeout) + if outcome == "success": break - except Exception as e: - if task_retries < max_retries: - task_retries += 1 - logger.info(f"Massage failed, retrying ({task_retries} / {max_retries})") - continue - else: - logger.warning( - f"Error during task massage after {task_retries} retries ({gym_id}): {e}" - ) - break - finally: - env.close() + if retries < max_retries: + logger.info( + f"Massage resulted in {outcome}, retrying ({retries + 1} / {max_retries} retries)" + ) + else: + logger.warning( + f"Massage unsuccessful after {retries} retries, skipping. Last error message: {err_msg}" + ) + + +def massage_task_within_subprocess( + task_id: str, timeout: int, kill_timeout: int = 10 +) -> typing.Tuple[str, str]: + """Massages a BrowserGym task (reset, noop, noop) inside a subprocess to monitor execution + times and kill the process after a timeout. + + Returns: an (outcome, err_msg) tuple. + - outcome: the outcome of the massage, one of 'success', 'exception' or 'timeout'. + - err_msg: error message if any, or None. + """ + + def run_massage(outcome_queue: mp.Queue): + import gymnasium as gym + + gym_id = f"browsergym/{task_id}" + env = gym.make(gym_id) + no_action = "noop()" + # check if action space exists and is compatible with "noop()" + try: + env.unwrapped.action_mapping(no_action) + except: + no_action = "" # fallback plan + # run massage + try: + env.reset() # task setup + env.step(no_action) # task validation + env.step(no_action) # task validation again + outcome = "success", None + except Exception as e: + outcome = "exception", traceback.format_exception(e) + finally: + env.close() + outcome_queue.put(outcome) + + queue = mp.Queue() + process = mp.Process(target=run_massage, args=queue) + process.start() + process.join(timeout=timeout) + + if process.is_alive(): + # if the process is still alive after the timeout + outcome = "timeout", f"Timeout {timeout} seconds exceeded" + process.kill() + process.join(timeout=kill_timeout) + if process.is_alive(): + # if the process is still alive after the kill + logger.warning( + f"Massage sub-process still alive {kill_timeout} seconds after kill(), you might have a zombie process now." + ) + else: + outcome = queue.get_nowait() + + return outcome