diff --git a/browsergym/core/src/browsergym/core/env.py b/browsergym/core/src/browsergym/core/env.py index feeb3151..537cb2a6 100644 --- a/browsergym/core/src/browsergym/core/env.py +++ b/browsergym/core/src/browsergym/core/env.py @@ -30,7 +30,6 @@ ) from .spaces import AnyBox, AnyDict, Unicode from .task import AbstractBrowserTask -from ..utils.obs import b64_to_pil logger = logging.getLogger(__name__) @@ -123,7 +122,7 @@ def __init__( ) ), "goal": Unicode(min_length=0, max_length=TEXT_MAX_LENGTH), - "goal_data": AnyDict(), + "goal_object": gym.spaces.Sequence(AnyDict()), "open_pages_urls": gym.spaces.Sequence( Unicode(min_length=0, max_length=TEXT_MAX_LENGTH) ), @@ -267,14 +266,37 @@ def override_property(task, env, property): # setup the task task_goal, task_info = self.task.setup(page=self.page) + # process the task goal + + # no goal specified + if task_goal is None: + goal = [] + # convert text-only goal (legacy) to new format + elif isinstance(task_goal, str): + goal = [{"type": "text", "text": task_goal}] + # new format goal with multiple texts and images (OpenAI style) + elif isinstance(task_goal, list): + goal = task_goal + else: + raise ValueError(f"task_goal should be of type str or list, got {task_goal.__class__}") + # initialize the chat self.chat.add_message( role="assistant", msg="Hi! I am your UI assistant, I can perform web tasks for you. What can I help you with?", ) - # process the task goal - self.goal, self.goal_data = self._process_task_goal(task_goal, send_to_chat=True) + # send task goal (if any) to the chat + for message in goal: + match message["type"]: + case "text": + self.chat.add_message(role="user", msg=message["text"]) + case "image_url": + self.chat.add_message(role="user_image", msg=message["image_url"]) + case _: + raise ValueError( + f"Unknown message type {repr(message['type'])} in the task goal." + ) self._wait_dom_loaded() @@ -498,13 +520,14 @@ def _get_obs(self): # if no goal has been set yet, try to extract it from the chat if not self.goal: - self.goal, self.goal_data = self._try_and_extract_goal_from_chat() + logger.warning(f"Empty goal, trying to extract goal from chat as a fallback.") + self.goal = self._try_and_extract_goal_from_chat() # obs is generic to all tasks obs = { "chat_messages": copy.deepcopy(self.chat.messages), - "goal": self.goal, - "goal_data": self.goal_data, + "goal": self._goal_to_text(self.goal), # legacy goal, deprecated + "goal_object": self.goal, # new goal format, liust of messages openai style "open_pages_urls": [page.url for page in self.context.pages], "active_page_index": np.asarray([self.context.pages.index(self.page)]), "url": self.page.url, @@ -520,83 +543,34 @@ def _get_obs(self): return obs - def _process_task_goal(self, task_goal, send_to_chat: bool = False): - # no goal specified - if task_goal is None: - goal = "" - goal_data = {} - - # text-only goal - elif isinstance(task_goal, str): - goal = task_goal - goal_data = {} - - if send_to_chat: - self.chat.add_message(role="user", msg=task_goal) - - # goal with text and images - elif isinstance(task_goal, dict): - goal = task_goal["text"] - goal_data = { - "image_urls": task_goal["image_urls"], - "image_paths": [], - } - # save images from the goal to files in a local directory - temp_dir = Path(tempfile.mkdtemp()) - for image_i, image_url in enumerate(task_goal["image_urls"]): - # download remotely hosted images - if image_url.startswith("http"): - image = Image.open(requests.get(image_url, stream=True).raw) - # decode base64-encoded images - elif image_url.startswith("data:image"): - image = b64_to_pil(image_url) - else: - raise ValueError(f"Unexpected image_url: {image_url}") - # write images to local files - format = image.format.lower() - image_path = temp_dir / f"input_image_{image_i}.{format}" - image.save(image_path) - # add image path to the goal - goal_data["image_paths"].append(image_path) - - if send_to_chat: - for image_url in task_goal["image_urls"]: - # send goal images to the chat - self.chat.add_message(role="user_image", msg=image_url) - # send goal text to the chat, after the images - self.chat.add_message(role="user", msg=task_goal["text"]) - else: - raise ValueError(f"task_goal should be of type str or dict, got {task_goal.__class__}") - - return goal, goal_data - def _try_and_extract_goal_from_chat(self): # as a fallback, when a task does not specify a goal, we try to convert the first chat message into a goal - first_user_message = None - first_user_images = [] + goal = [] for msg in self.chat.messages: # extract first user message as goal, if any if msg["role"] == "user": - first_user_message = msg["message"] + goal = [{"type": "text", "text": msg["message"]}] break - # extract any user_image message present before as a goal image - elif msg["role"] == "user_image": - first_user_images.append(msg["message"]) - else: - pass - # convert chat messages to a task_goal, if any - if first_user_message is None: - task_goal = None - elif not first_user_images: - task_goal = first_user_message - else: - task_goal = { - "text": first_user_message, - "image_urls": first_user_images, - } - - # process the task_goal into a proper goal, if any - goal, goal_data = self._process_task_goal(task_goal=task_goal, send_to_chat=False) + return goal + + def _goal_to_text(self, goal: list): + goal_text_strings = [] + for message in goal: + match message["type"]: + case "text": + goal_text_strings.append(message["text"]) + case "image_url": + if message["image_url"].startswith("data:image"): + goal_text_strings.append( + "image_url: " + message["image_url"][:30] + "... (truncated)" + ) + else: + goal_text_strings.append("image_url: " + message["image_url"]) + case _: + raise ValueError( + f"Unknown message type {repr(message['type'])} in the task goal." + ) + goal_text = "\n".join(goal_text_strings) - return goal, goal_data + return goal_text diff --git a/browsergym/core/src/browsergym/utils/obs.py b/browsergym/core/src/browsergym/utils/obs.py index db37a386..1799cbb6 100644 --- a/browsergym/core/src/browsergym/utils/obs.py +++ b/browsergym/core/src/browsergym/utils/obs.py @@ -1,6 +1,4 @@ import ast -import base64 -import io import logging import numpy as np import PIL.Image @@ -10,7 +8,6 @@ from collections import defaultdict from bs4 import BeautifulSoup -from typing import Literal from browsergym.core.constants import BROWSERGYM_ID_ATTRIBUTE as BID_ATTR from browsergym.core.constants import BROWSERGYM_VISIBILITY_ATTRIBUTE as VIS_ATTR @@ -548,25 +545,3 @@ def prune_html(html): html = soup.prettify() return html - - -def pil_to_b64(img: PIL.Image.Image, format: Literal["png", "jpeg"] = "png") -> str: - assert format in ("png", "jpeg") - with io.BytesIO() as image_buffer: - img.save(image_buffer, format=format.upper()) - byte_data = image_buffer.getvalue() - img_b64 = base64.b64encode(byte_data).decode("utf-8") - img_b64 = f"data:image/{format};base64," + img_b64 - return img_b64 - - -def b64_to_pil(img_b64: str) -> str: - if img_b64.startswith("data:image/png;base64,"): - img_b64 = img_b64.removeprefix("data:image/png;base64,") - elif img_b64.startswith("data:image/jpeg;base64,"): - img_b64 = img_b64.removeprefix("data:image/jpeg;base64,") - else: - raise ValueError(f"Unexpected base64 encoding: {img_b64}") - img_data = base64.b64decode(img_b64) - img = PIL.Image.open(io.BytesIO(img_data)) - return img diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py index 11608b17..a189945e 100644 --- a/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py +++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py @@ -2,6 +2,7 @@ import logging import playwright.sync_api import importlib.resources +import pathlib import tempfile import requests @@ -10,6 +11,7 @@ from browsergym.core.task import AbstractBrowserTask from .instance import VisualWebArenaInstance +from .utils import image_url_to_pil_image logger = logging.getLogger(__name__) @@ -129,26 +131,60 @@ def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]: if i < len(start_urls) - 1: page = page.context.new_page() - # recover goal - goal = { - "text": self.config["intent"], - "image_urls": self.config.get("image", []), - } - # fix goal if needed - if goal["image_urls"] is None: - goal["image_urls"] = [] - elif isinstance(goal["image_urls"], str): - goal["image_urls"] = [goal["image_urls"]] + # recover goal text + goal_text = self.config["intent"] # This note is present in some of webarena's agent prompts if self.with_na_hint: - goal[ - "text" - ] += """\ + goal_text += """\ If you believe the task is impossible to complete, provide the answer "N/A". """ + # recover goal image urls + image_urls = self.config.get("image", []) + + # fix image list if needed + if image_urls is None: + image_urls = [] + elif isinstance(image_urls, str): + image_urls = [image_urls] + + # save goal images to local files in a temporary directory + image_paths = [] + temp_dir = pathlib.Path(tempfile.mkdtemp()) + for i, image_url in enumerate(image_urls): + # extract image content from url + image = image_url_to_pil_image(image_url) + # write image to local file + format = image.format.lower() + image_path = temp_dir / f"input_image_{i}.{format}" + image.save(image_path) + # add image path to the goal + image_paths.append(image_path) + + # build an OpenAI-style structured goal + # textual goal first + goal = [{"type": "text", "text": goal_text}] + # then goal images + for i, (image_url, image_path) in enumerate(zip(image_urls, image_paths)): + goal.extend( + [ + # image description (id and filepath) + { + "type": "text", + "text": f"Input image {i}/{len(image_urls)} below (local path: {repr(image_path)})", + }, + # actual image (image_url) + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + }, + ] + ) + return goal, {} def cheat(self, page: playwright.sync_api.Page, chat_messages: list[str]) -> None: diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/utils.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/utils.py new file mode 100644 index 00000000..abb80fbd --- /dev/null +++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/utils.py @@ -0,0 +1,20 @@ +import base64 +import io +import PIL.Image +import requests + + +def image_url_to_pil_image(image_url: str) -> str: + if image_url.startswith("http"): + image_data = requests.get(image_url, stream=True).raw + elif image_url.startswith("data:image/png;base64,"): + image_data = base64.b64decode(image_url.removeprefix("data:image/png;base64,")) + elif image_url.startswith("data:image/jpeg;base64,"): + image_data = base64.b64decode(image_url.removeprefix("data:image/jpeg;base64,")) + else: + if image_url.startswith("data:image/"): + raise ValueError(f"Unexpected image encoding: {image_url}") + else: + raise ValueError(f"Unexpected image URL: {image_url}") + img = PIL.Image.open(io.BytesIO(image_data)) + return img