Skip to content

Commit

Permalink
goal_object refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
gasse committed Sep 23, 2024
1 parent 2132c95 commit 4be7ad5
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 117 deletions.
132 changes: 53 additions & 79 deletions browsergym/core/src/browsergym/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
)
from .spaces import AnyBox, AnyDict, Unicode
from .task import AbstractBrowserTask
from ..utils.obs import b64_to_pil

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -123,7 +122,7 @@ def __init__(
)
),
"goal": Unicode(min_length=0, max_length=TEXT_MAX_LENGTH),
"goal_data": AnyDict(),
"goal_object": gym.spaces.Sequence(AnyDict()),
"open_pages_urls": gym.spaces.Sequence(
Unicode(min_length=0, max_length=TEXT_MAX_LENGTH)
),
Expand Down Expand Up @@ -267,14 +266,37 @@ def override_property(task, env, property):
# setup the task
task_goal, task_info = self.task.setup(page=self.page)

# process the task goal

# no goal specified
if task_goal is None:
goal = []
# convert text-only goal (legacy) to new format
elif isinstance(task_goal, str):
goal = [{"type": "text", "text": task_goal}]
# new format goal with multiple texts and images (OpenAI style)
elif isinstance(task_goal, list):
goal = task_goal
else:
raise ValueError(f"task_goal should be of type str or list, got {task_goal.__class__}")

# initialize the chat
self.chat.add_message(
role="assistant",
msg="Hi! I am your UI assistant, I can perform web tasks for you. What can I help you with?",
)

# process the task goal
self.goal, self.goal_data = self._process_task_goal(task_goal, send_to_chat=True)
# send task goal (if any) to the chat
for message in goal:
match message["type"]:
case "text":
self.chat.add_message(role="user", msg=message["text"])
case "image_url":
self.chat.add_message(role="user_image", msg=message["image_url"])
case _:
raise ValueError(
f"Unknown message type {repr(message['type'])} in the task goal."
)

self._wait_dom_loaded()

Expand Down Expand Up @@ -498,13 +520,14 @@ def _get_obs(self):

# if no goal has been set yet, try to extract it from the chat
if not self.goal:
self.goal, self.goal_data = self._try_and_extract_goal_from_chat()
logger.warning(f"Empty goal, trying to extract goal from chat as a fallback.")
self.goal = self._try_and_extract_goal_from_chat()

# obs is generic to all tasks
obs = {
"chat_messages": copy.deepcopy(self.chat.messages),
"goal": self.goal,
"goal_data": self.goal_data,
"goal": self._goal_to_text(self.goal), # legacy goal, deprecated
"goal_object": self.goal, # new goal format, liust of messages openai style
"open_pages_urls": [page.url for page in self.context.pages],
"active_page_index": np.asarray([self.context.pages.index(self.page)]),
"url": self.page.url,
Expand All @@ -520,83 +543,34 @@ def _get_obs(self):

return obs

def _process_task_goal(self, task_goal, send_to_chat: bool = False):
# no goal specified
if task_goal is None:
goal = ""
goal_data = {}

# text-only goal
elif isinstance(task_goal, str):
goal = task_goal
goal_data = {}

if send_to_chat:
self.chat.add_message(role="user", msg=task_goal)

# goal with text and images
elif isinstance(task_goal, dict):
goal = task_goal["text"]
goal_data = {
"image_urls": task_goal["image_urls"],
"image_paths": [],
}
# save images from the goal to files in a local directory
temp_dir = Path(tempfile.mkdtemp())
for image_i, image_url in enumerate(task_goal["image_urls"]):
# download remotely hosted images
if image_url.startswith("http"):
image = Image.open(requests.get(image_url, stream=True).raw)
# decode base64-encoded images
elif image_url.startswith("data:image"):
image = b64_to_pil(image_url)
else:
raise ValueError(f"Unexpected image_url: {image_url}")
# write images to local files
format = image.format.lower()
image_path = temp_dir / f"input_image_{image_i}.{format}"
image.save(image_path)
# add image path to the goal
goal_data["image_paths"].append(image_path)

if send_to_chat:
for image_url in task_goal["image_urls"]:
# send goal images to the chat
self.chat.add_message(role="user_image", msg=image_url)
# send goal text to the chat, after the images
self.chat.add_message(role="user", msg=task_goal["text"])
else:
raise ValueError(f"task_goal should be of type str or dict, got {task_goal.__class__}")

return goal, goal_data

def _try_and_extract_goal_from_chat(self):
# as a fallback, when a task does not specify a goal, we try to convert the first chat message into a goal
first_user_message = None
first_user_images = []
goal = []
for msg in self.chat.messages:
# extract first user message as goal, if any
if msg["role"] == "user":
first_user_message = msg["message"]
goal = [{"type": "text", "text": msg["message"]}]
break
# extract any user_image message present before as a goal image
elif msg["role"] == "user_image":
first_user_images.append(msg["message"])
else:
pass

# convert chat messages to a task_goal, if any
if first_user_message is None:
task_goal = None
elif not first_user_images:
task_goal = first_user_message
else:
task_goal = {
"text": first_user_message,
"image_urls": first_user_images,
}

# process the task_goal into a proper goal, if any
goal, goal_data = self._process_task_goal(task_goal=task_goal, send_to_chat=False)
return goal

def _goal_to_text(self, goal: list):
goal_text_strings = []
for message in goal:
match message["type"]:
case "text":
goal_text_strings.append(message["text"])
case "image_url":
if message["image_url"].startswith("data:image"):
goal_text_strings.append(
"image_url: " + message["image_url"][:30] + "... (truncated)"
)
else:
goal_text_strings.append("image_url: " + message["image_url"])
case _:
raise ValueError(
f"Unknown message type {repr(message['type'])} in the task goal."
)
goal_text = "\n".join(goal_text_strings)

return goal, goal_data
return goal_text
25 changes: 0 additions & 25 deletions browsergym/core/src/browsergym/utils/obs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import ast
import base64
import io
import logging
import numpy as np
import PIL.Image
Expand All @@ -10,7 +8,6 @@

from collections import defaultdict
from bs4 import BeautifulSoup
from typing import Literal

from browsergym.core.constants import BROWSERGYM_ID_ATTRIBUTE as BID_ATTR
from browsergym.core.constants import BROWSERGYM_VISIBILITY_ATTRIBUTE as VIS_ATTR
Expand Down Expand Up @@ -548,25 +545,3 @@ def prune_html(html):
html = soup.prettify()

return html


def pil_to_b64(img: PIL.Image.Image, format: Literal["png", "jpeg"] = "png") -> str:
assert format in ("png", "jpeg")
with io.BytesIO() as image_buffer:
img.save(image_buffer, format=format.upper())
byte_data = image_buffer.getvalue()
img_b64 = base64.b64encode(byte_data).decode("utf-8")
img_b64 = f"data:image/{format};base64," + img_b64
return img_b64


def b64_to_pil(img_b64: str) -> str:
if img_b64.startswith("data:image/png;base64,"):
img_b64 = img_b64.removeprefix("data:image/png;base64,")
elif img_b64.startswith("data:image/jpeg;base64,"):
img_b64 = img_b64.removeprefix("data:image/jpeg;base64,")
else:
raise ValueError(f"Unexpected base64 encoding: {img_b64}")
img_data = base64.b64decode(img_b64)
img = PIL.Image.open(io.BytesIO(img_data))
return img
62 changes: 49 additions & 13 deletions browsergym/visualwebarena/src/browsergym/visualwebarena/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import playwright.sync_api
import importlib.resources
import pathlib
import tempfile
import requests

Expand All @@ -10,6 +11,7 @@
from browsergym.core.task import AbstractBrowserTask

from .instance import VisualWebArenaInstance
from .utils import image_url_to_pil_image

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -129,26 +131,60 @@ def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]:
if i < len(start_urls) - 1:
page = page.context.new_page()

# recover goal
goal = {
"text": self.config["intent"],
"image_urls": self.config.get("image", []),
}
# fix goal if needed
if goal["image_urls"] is None:
goal["image_urls"] = []
elif isinstance(goal["image_urls"], str):
goal["image_urls"] = [goal["image_urls"]]
# recover goal text
goal_text = self.config["intent"]

# This note is present in some of webarena's agent prompts
if self.with_na_hint:
goal[
"text"
] += """\
goal_text += """\
If you believe the task is impossible to complete, provide the answer "N/A".
"""

# recover goal image urls
image_urls = self.config.get("image", [])

# fix image list if needed
if image_urls is None:
image_urls = []
elif isinstance(image_urls, str):
image_urls = [image_urls]

# save goal images to local files in a temporary directory
image_paths = []
temp_dir = pathlib.Path(tempfile.mkdtemp())
for i, image_url in enumerate(image_urls):
# extract image content from url
image = image_url_to_pil_image(image_url)
# write image to local file
format = image.format.lower()
image_path = temp_dir / f"input_image_{i}.{format}"
image.save(image_path)
# add image path to the goal
image_paths.append(image_path)

# build an OpenAI-style structured goal
# textual goal first
goal = [{"type": "text", "text": goal_text}]
# then goal images
for i, (image_url, image_path) in enumerate(zip(image_urls, image_paths)):
goal.extend(
[
# image description (id and filepath)
{
"type": "text",
"text": f"Input image {i}/{len(image_urls)} below (local path: {repr(image_path)})",
},
# actual image (image_url)
{
"type": "image_url",
"image_url": {
"url": image_url,
},
},
]
)

return goal, {}

def cheat(self, page: playwright.sync_api.Page, chat_messages: list[str]) -> None:
Expand Down
20 changes: 20 additions & 0 deletions browsergym/visualwebarena/src/browsergym/visualwebarena/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import base64
import io
import PIL.Image
import requests


def image_url_to_pil_image(image_url: str) -> str:
if image_url.startswith("http"):
image_data = requests.get(image_url, stream=True).raw
elif image_url.startswith("data:image/png;base64,"):
image_data = base64.b64decode(image_url.removeprefix("data:image/png;base64,"))
elif image_url.startswith("data:image/jpeg;base64,"):
image_data = base64.b64decode(image_url.removeprefix("data:image/jpeg;base64,"))
else:
if image_url.startswith("data:image/"):
raise ValueError(f"Unexpected image encoding: {image_url}")
else:
raise ValueError(f"Unexpected image URL: {image_url}")
img = PIL.Image.open(io.BytesIO(image_data))
return img

0 comments on commit 4be7ad5

Please sign in to comment.