From 642b5f54a96c2ab5506ccf4bcc4cd49c8b465e23 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Fri, 4 Oct 2024 10:57:44 -0400 Subject: [PATCH 1/6] images aren't saved in pkl files anymore, and are stuffed back in at load time --- .../src/browsergym/experiments/loop.py | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/browsergym/experiments/src/browsergym/experiments/loop.py b/browsergym/experiments/src/browsergym/experiments/loop.py index 9aeb2245..dbb1c44a 100644 --- a/browsergym/experiments/src/browsergym/experiments/loop.py +++ b/browsergym/experiments/src/browsergym/experiments/loop.py @@ -17,10 +17,11 @@ import gymnasium as gym import numpy as np -from browsergym.core.chat import Chat from PIL import Image from tqdm import tqdm +from browsergym.core.chat import Chat + from .agent import Agent from .utils import count_messages_token, count_tokens @@ -412,17 +413,22 @@ def make_stats(self): self.stats = stats - def save_step_info(self, exp_dir, save_json=False, save_jpg=True): + def save_step_info(self, exp_dir, save_json=False, save_jpg=True, save_som=False): + + screenshot = self.obs.pop("screenshot", None) + screenshot_som = self.obs.pop("screenshot_som", None) + + if save_jpg and screenshot is not None: + img = Image.fromarray(screenshot) + img.save(exp_dir / f"screenshot_step_{self.step}.jpg") + + if save_som and screenshot_som is not None: + img = Image.fromarray(screenshot_som) + img.save(exp_dir / f"screenshot_som_step_{self.step}.jpg") with gzip.open(exp_dir / f"step_{self.step}.pkl.gz", "wb") as f: pickle.dump(self, f) - if save_jpg and self.obs is not None: - for name in ("screenshot", "screenshot_som"): - if name in self.obs: - img = Image.fromarray(self.obs[name]) - img.save(exp_dir / f"{name}_step_{self.step}.jpg") - if save_json: with open(exp_dir / "steps_info.json", "w") as f: json.dump(self, f, indent=4, cls=DataclassJSONEncoder) @@ -552,6 +558,20 @@ def get_step_info(self, step: int) -> StepInfo: if self._steps_info.get(step, None) is None: with gzip.open(self.exp_dir / f"step_{step}.pkl.gz", "rb") as f: self._steps_info[step] = pickle.load(f) + if "screenshot" not in self._steps_info[step].obs: + try: + self._steps_info[step].obs["screenshot"] = np.array( + Image.open(self.exp_dir / f"screenshot_step_{step}.jpg"), dtype=np.uint8 + ) + except FileNotFoundError: + pass + if "screenshot_som" not in self._steps_info[step].obs: + try: + self._steps_info[step].obs["screenshot_som"] = np.array( + Image.open(self.exp_dir / f"screenshot_som_step_{step}.jpg"), dtype=np.uint8 + ) + except FileNotFoundError: + pass return self._steps_info[step] @property From 4f8f7a4c04e7bcd03a537f82ff443b55b7e86b46 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Fri, 4 Oct 2024 10:59:50 -0400 Subject: [PATCH 2/6] added kwargs to control img/som saving --- .../experiments/src/browsergym/experiments/loop.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/browsergym/experiments/src/browsergym/experiments/loop.py b/browsergym/experiments/src/browsergym/experiments/loop.py index dbb1c44a..9da83769 100644 --- a/browsergym/experiments/src/browsergym/experiments/loop.py +++ b/browsergym/experiments/src/browsergym/experiments/loop.py @@ -143,6 +143,8 @@ class ExpArgs: logging_level: int = logging.INFO exp_id: str = None depends_on: tuple[str] = () + save_jpg: bool = True + save_som: bool = False def prepare(self, exp_root): """Prepare the experiment directory and save the experiment arguments. @@ -221,7 +223,9 @@ def run(self): # will end the episode after saving the step info. step_info.truncated = True - step_info.save_step_info(self.exp_dir) + step_info.save_step_info( + self.exp_dir, save_jpg=self.save_jpg, save_som=self.save_som + ) logger.debug(f"Step info saved.") _send_chat_info(env.unwrapped.chat, action, step_info.agent_info) @@ -252,7 +256,9 @@ def run(self): finally: try: if step_info is not None: - step_info.save_step_info(self.exp_dir) + step_info.save_step_info( + self.exp_dir, save_jpg=self.save_jpg, save_som=self.save_som + ) except Exception as e: logger.error(f"Error while saving step info in the finally block: {e}") try: From dae2ee2cff002debded7348e8e12ae58e8baa40a Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Fri, 4 Oct 2024 12:13:19 -0400 Subject: [PATCH 3/6] saving as png, adding screenshots back into obs --- .../src/browsergym/experiments/loop.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/browsergym/experiments/src/browsergym/experiments/loop.py b/browsergym/experiments/src/browsergym/experiments/loop.py index 9da83769..4fc72b09 100644 --- a/browsergym/experiments/src/browsergym/experiments/loop.py +++ b/browsergym/experiments/src/browsergym/experiments/loop.py @@ -143,7 +143,7 @@ class ExpArgs: logging_level: int = logging.INFO exp_id: str = None depends_on: tuple[str] = () - save_jpg: bool = True + save_screenshot: bool = True save_som: bool = False def prepare(self, exp_root): @@ -224,7 +224,7 @@ def run(self): step_info.truncated = True step_info.save_step_info( - self.exp_dir, save_jpg=self.save_jpg, save_som=self.save_som + self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som ) logger.debug(f"Step info saved.") @@ -257,7 +257,7 @@ def run(self): try: if step_info is not None: step_info.save_step_info( - self.exp_dir, save_jpg=self.save_jpg, save_som=self.save_som + self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som ) except Exception as e: logger.error(f"Error while saving step info in the finally block: {e}") @@ -419,18 +419,18 @@ def make_stats(self): self.stats = stats - def save_step_info(self, exp_dir, save_json=False, save_jpg=True, save_som=False): + def save_step_info(self, exp_dir, save_json=False, save_screenshot=True, save_som=False): screenshot = self.obs.pop("screenshot", None) screenshot_som = self.obs.pop("screenshot_som", None) - if save_jpg and screenshot is not None: + if save_screenshot and screenshot is not None: img = Image.fromarray(screenshot) - img.save(exp_dir / f"screenshot_step_{self.step}.jpg") + img.save(exp_dir / f"screenshot_step_{self.step}.png") if save_som and screenshot_som is not None: img = Image.fromarray(screenshot_som) - img.save(exp_dir / f"screenshot_som_step_{self.step}.jpg") + img.save(exp_dir / f"screenshot_som_step_{self.step}.png") with gzip.open(exp_dir / f"step_{self.step}.pkl.gz", "wb") as f: pickle.dump(self, f) @@ -439,6 +439,12 @@ def save_step_info(self, exp_dir, save_json=False, save_jpg=True, save_som=False with open(exp_dir / "steps_info.json", "w") as f: json.dump(self, f, indent=4, cls=DataclassJSONEncoder) + # add the screenshots back to the obs + if screenshot is not None: + self.obs["screenshot"] = screenshot + if screenshot_som is not None: + self.obs["screenshot_som"] = screenshot_som + def _extract_err_msg(episode_info: list[StepInfo]): """Extract the last error message from the episode info.""" @@ -567,14 +573,14 @@ def get_step_info(self, step: int) -> StepInfo: if "screenshot" not in self._steps_info[step].obs: try: self._steps_info[step].obs["screenshot"] = np.array( - Image.open(self.exp_dir / f"screenshot_step_{step}.jpg"), dtype=np.uint8 + Image.open(self.exp_dir / f"screenshot_step_{step}.png"), dtype=np.uint8 ) except FileNotFoundError: pass if "screenshot_som" not in self._steps_info[step].obs: try: self._steps_info[step].obs["screenshot_som"] = np.array( - Image.open(self.exp_dir / f"screenshot_som_step_{step}.jpg"), dtype=np.uint8 + Image.open(self.exp_dir / f"screenshot_som_step_{step}.png"), dtype=np.uint8 ) except FileNotFoundError: pass @@ -602,12 +608,12 @@ def summary_info(self) -> dict: def get_screenshot(self, step: int, som=False) -> Image: key = (step, som) if self._screenshots.get(key, None) is None: - file_name = f"screenshot_{'som_' if som else ''}step_{step}.jpg" + file_name = f"screenshot_{'som_' if som else ''}step_{step}.png" self._screenshots[key] = Image.open(self.exp_dir / file_name) return self._screenshots[key] def get_screenshots(self, som=False): - files = list(self.exp_dir.glob("screenshot_step_*.jpg")) + files = list(self.exp_dir.glob("screenshot_step_*.png")) max_step = 0 for file in files: step = int(file.name.split("_")[-1].split(".")[0]) From c42496b383c40e3871f6727fb4a3c0030cb19c5c Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Fri, 4 Oct 2024 13:46:26 -0400 Subject: [PATCH 4/6] retrocompatibility for image loading --- .../src/browsergym/experiments/loop.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/browsergym/experiments/src/browsergym/experiments/loop.py b/browsergym/experiments/src/browsergym/experiments/loop.py index 4fc72b09..fe1ae809 100644 --- a/browsergym/experiments/src/browsergym/experiments/loop.py +++ b/browsergym/experiments/src/browsergym/experiments/loop.py @@ -572,15 +572,13 @@ def get_step_info(self, step: int) -> StepInfo: self._steps_info[step] = pickle.load(f) if "screenshot" not in self._steps_info[step].obs: try: - self._steps_info[step].obs["screenshot"] = np.array( - Image.open(self.exp_dir / f"screenshot_step_{step}.png"), dtype=np.uint8 - ) + self._steps_info[step].obs["screenshot"] = self.get_screenshot(step) except FileNotFoundError: pass if "screenshot_som" not in self._steps_info[step].obs: try: - self._steps_info[step].obs["screenshot_som"] = np.array( - Image.open(self.exp_dir / f"screenshot_som_step_{step}.png"), dtype=np.uint8 + self._steps_info[step].obs["screenshot_som"] = self.get_screenshot( + step, som=True ) except FileNotFoundError: pass @@ -608,8 +606,11 @@ def summary_info(self) -> dict: def get_screenshot(self, step: int, som=False) -> Image: key = (step, som) if self._screenshots.get(key, None) is None: - file_name = f"screenshot_{'som_' if som else ''}step_{step}.png" - self._screenshots[key] = Image.open(self.exp_dir / file_name) + file_name = f"screenshot_{'som_' if som else ''}step_{step}" + try: + self._screenshots[key] = Image.open(self.exp_dir / (file_name + ".png")) + except FileNotFoundError: + self._screenshots[key] = Image.open(self.exp_dir / (file_name + ".jpg")) return self._screenshots[key] def get_screenshots(self, som=False): From 49ada25b185889f83b2cd7bcb4f58296b63feaf0 Mon Sep 17 00:00:00 2001 From: Thibault LSDC <78021491+ThibaultLSDC@users.noreply.github.com> Date: Fri, 4 Oct 2024 14:55:28 -0400 Subject: [PATCH 5/6] making get_screenshots work for png and jpg --- browsergym/experiments/src/browsergym/experiments/loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/browsergym/experiments/src/browsergym/experiments/loop.py b/browsergym/experiments/src/browsergym/experiments/loop.py index fe1ae809..033e5366 100644 --- a/browsergym/experiments/src/browsergym/experiments/loop.py +++ b/browsergym/experiments/src/browsergym/experiments/loop.py @@ -614,7 +614,7 @@ def get_screenshot(self, step: int, som=False) -> Image: return self._screenshots[key] def get_screenshots(self, som=False): - files = list(self.exp_dir.glob("screenshot_step_*.png")) + files = list(self.exp_dir.glob("screenshot_step_*")) max_step = 0 for file in files: step = int(file.name.split("_")[-1].split(".")[0]) From 417177a2bf64c137542e5fe0426c48d46225a1c2 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Mon, 7 Oct 2024 11:49:21 -0400 Subject: [PATCH 6/6] fixing image types and closing files --- .../experiments/src/browsergym/experiments/loop.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/browsergym/experiments/src/browsergym/experiments/loop.py b/browsergym/experiments/src/browsergym/experiments/loop.py index 033e5366..f39d4622 100644 --- a/browsergym/experiments/src/browsergym/experiments/loop.py +++ b/browsergym/experiments/src/browsergym/experiments/loop.py @@ -572,13 +572,15 @@ def get_step_info(self, step: int) -> StepInfo: self._steps_info[step] = pickle.load(f) if "screenshot" not in self._steps_info[step].obs: try: - self._steps_info[step].obs["screenshot"] = self.get_screenshot(step) + self._steps_info[step].obs["screenshot"] = np.array( + self.get_screenshot(step), dtype=np.uint8 + ) except FileNotFoundError: pass if "screenshot_som" not in self._steps_info[step].obs: try: - self._steps_info[step].obs["screenshot_som"] = self.get_screenshot( - step, som=True + self._steps_info[step].obs["screenshot_som"] = np.array( + self.get_screenshot(step, som=True), dtype=np.uint8 ) except FileNotFoundError: pass @@ -608,9 +610,11 @@ def get_screenshot(self, step: int, som=False) -> Image: if self._screenshots.get(key, None) is None: file_name = f"screenshot_{'som_' if som else ''}step_{step}" try: - self._screenshots[key] = Image.open(self.exp_dir / (file_name + ".png")) + with Image.open(self.exp_dir / (file_name + ".png")) as img: + self._screenshots[key] = img.copy() except FileNotFoundError: - self._screenshots[key] = Image.open(self.exp_dir / (file_name + ".jpg")) + with Image.open(self.exp_dir / (file_name + ".jpg")) as img: + self._screenshots[key] = img.copy() return self._screenshots[key] def get_screenshots(self, som=False):