diff --git a/README.md b/README.md
index f0fbd2a3..d711b045 100644
--- a/README.md
+++ b/README.md
@@ -158,24 +158,44 @@ print("\n".join(env_ids))
## Demo
-If you want to experiment with an agent in BrowserGym, follow these steps:
+If you want to experiment with a demo agent in BrowserGym, follow these steps:
```sh
cd demo-agent
-conda env create -f environment.yml; conda activate demo-agent
+conda env create -f environment.yml
+conda activate demo-agent
# or simply use `pip install -r requirements.txt`
playwright install chromium
```
-Optional: Set your `OPENAI_API_KEY` to use a GPT agent.
-
-Launch the demo on the open web:
+Our demo agent uses `openai` as a backend, so be sure to set your `OPENAI_API_KEY`.
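+For example:
+```sh
+export OPENAI_API_KEY=<your_openai_api_key>
+```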
+Launch the demo agent on the open web:
```sh
python run_demo.py --task_name openended --start_url https://www.google.com
```
-You can customize your experience by changing the `model_name` to your preferred LLM, toggling Chain-of-thought with `use_thinking`, adding screenshots for your VLMs with `use_screenshot`, and much more!
+Or use it to solve a simple MiniWoB task:
+```sh
+python run_demo.py --task_name miniwob.click-test
+```
+
+A VisualWebArena task:
+```sh
+python run_demo.py --task_name visualwebarena.398
+```
+
+A WebArena task:
+```sh
+python run_demo.py --task_name webarena.4
+```
+
+A WorkArena task:
+```sh
+python run_demo.py --task_name workarena.servicenow.order-standard-laptop
+```
+
+You can customize your experience by changing the `model_name` to your preferred LLM (`gpt-4o-mini` by default), adding screenshots for your VLMs with `use_screenshot`, and much more (see `python run_demo.py --help`).
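+For instance, to run the open-ended demo with a different OpenAI model (assuming your API key has access to it):
+```sh
+python run_demo.py --task_name openended --start_url https://www.google.com --model_name gpt-4o
+```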
## Citing This Work
diff --git a/browsergym/core/src/browsergym/core/env.py b/browsergym/core/src/browsergym/core/env.py
index 662293f7..22fffaaa 100644
--- a/browsergym/core/src/browsergym/core/env.py
+++ b/browsergym/core/src/browsergym/core/env.py
@@ -1,38 +1,54 @@
import copy
-import gymnasium as gym
import logging
-import numpy as np
-import playwright.sync_api
-import time
import re
-
+import time
from abc import ABC
from pathlib import Path
-from typing import Optional, Literal
+from typing import Literal, Optional
+import gymnasium as gym
+import numpy as np
+import playwright.sync_api
+
+from . import _get_global_playwright
+from .action.base import execute_python_code
+from .action.highlevel import HighLevelActionSet
from .chat import Chat
-from .task import AbstractBrowserTask
-from .spaces import Unicode, AnyDict, AnyBox
-from .constants import TEXT_MAX_LENGTH, BROWSERGYM_ID_ATTRIBUTE, EXTRACT_OBS_MAX_TRIES
+from .constants import BROWSERGYM_ID_ATTRIBUTE, EXTRACT_OBS_MAX_TRIES, TEXT_MAX_LENGTH
from .observation import (
- _pre_extract,
+ MarkingError,
_post_extract,
- extract_screenshot,
- extract_dom_snapshot,
+ _pre_extract,
extract_dom_extra_properties,
- extract_merged_axtree,
+ extract_dom_snapshot,
extract_focused_element_bid,
- MarkingError,
+ extract_merged_axtree,
+ extract_screenshot,
)
-from .action.base import execute_python_code
-from .action.highlevel import HighLevelActionSet
-from .action.base import execute_python_code
-from . import _get_global_playwright
-
+from .spaces import AnyBox, AnyDict, Unicode
+from .task import AbstractBrowserTask
logger = logging.getLogger(__name__)
+def _try_to_extract_legacy_goal(goal: list):
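+    """Best-effort conversion of a structured goal (a list of OpenAI-style messages, e.g.
+    [{"type": "text", "text": ...}, {"type": "image_url", "image_url": ...}]) into the
+    legacy text-only goal format."""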
+ legacy_goal_strings = []
+ for message in goal:
+ if message["type"] == "text":
+ legacy_goal_strings.append(message["text"])
+ else:
+ logger.debug(
+ f"Message type {repr(message['type'])} present in the goal, cannot be converted to legacy text-only format."
+ )
+ legacy_goal_strings.append(
+ 'WARNING: This goal cannot be converted to a text-only goal format. Use the new goal format instead ("goal_object" field). Any agent reading this should abort immediately.'
+ )
+ break
+ legacy_goal = "\n".join(legacy_goal_strings)
+
+ return legacy_goal
+
+
class BrowserEnv(gym.Env, ABC):
"""The main BrowserGym class, which encapsulates instruction-following Web browsing into a Gymnasium environment."""
@@ -120,11 +136,8 @@ def __init__(
}
)
),
- # TODO: this is redundant with chat messages, to be removed
"goal": Unicode(min_length=0, max_length=TEXT_MAX_LENGTH),
- "goal_image_urls": gym.spaces.Sequence(
- Unicode(min_length=0, max_length=TEXT_MAX_LENGTH)
- ),
+ "goal_object": gym.spaces.Sequence(AnyDict()),
"open_pages_urls": gym.spaces.Sequence(
Unicode(min_length=0, max_length=TEXT_MAX_LENGTH)
),
@@ -266,27 +279,42 @@ def override_property(task, env, property):
recording_start_time = time.time()
# setup the task
- goal, task_info = self.task.setup(page=self.page)
+ task_goal, task_info = self.task.setup(page=self.page)
+
+ # process the task goal
+
+ # no goal specified
+ if task_goal is None:
+ self.goal_object = []
+ # convert text-only goal (legacy) to new format
+ elif isinstance(task_goal, str):
+ self.goal_object = [{"type": "text", "text": task_goal}]
+ # new format goal with multiple texts and images (OpenAI style)
+ elif isinstance(task_goal, list):
+ self.goal_object = task_goal
+ else:
+ raise ValueError(f"task_goal should be of type str or list, got {task_goal.__class__}")
# initialize the chat
self.chat.add_message(
role="assistant",
msg="Hi! I am your UI assistant, I can perform web tasks for you. What can I help you with?",
)
- # if any, add the task's goal to the chat
- if goal:
-
- # goal is text-only
- if isinstance(goal, str):
- goal_msg = goal
- # goal is text + images
- elif isinstance(goal, dict):
- goal_msg = goal["message"]
- for image_url in goal["image_urls"]:
- self.chat.add_message(role="user_image", msg=image_url)
-
- self.chat.add_message(role="user", msg=goal_msg)
+ # send task goal (if any) to the chat
+ for message in self.goal_object:
+ match message["type"]:
+ case "text":
+ self.chat.add_message(role="user", msg=message["text"])
+ case "image_url":
+ image_src = message["image_url"]
+ if isinstance(image_src, dict):
+ image_src = image_src["url"]
+ self.chat.add_message(role="user_image", msg=image_src)
+ case _:
+ raise ValueError(
+ f"Unknown message type {repr(message['type'])} in the task goal."
+ )
self._wait_dom_loaded()
@@ -508,26 +536,11 @@ def _get_obs(self):
# post-extraction cleanup of temporary info in dom
_post_extract(self.page)
- # use first user message as goal, if any
- # use all user images before first user message as goal images, if any
- goal_msg = "There is no goal."
- goal_image_urls = []
- _prev_image_urls = []
- for msg in self.chat.messages:
- if msg["role"] == "user_image":
- _prev_image_urls.append(msg["message"])
- elif msg["role"] == "user":
- goal_msg = msg["message"]
- goal_image_urls = _prev_image_urls
- break
- else:
- pass
-
# obs is generic to all tasks
obs = {
"chat_messages": copy.deepcopy(self.chat.messages),
- "goal": goal_msg, # TODO: redundant with chat messages, to be removed?
- "goal_image_urls": goal_image_urls, # TODO: redundant with chat messages, to be removed?
+ "goal": _try_to_extract_legacy_goal(self.goal_object), # legacy goal, deprecated
+ "goal_object": self.goal_object, # new goal format, list of messages openai style
"open_pages_urls": [page.url for page in self.context.pages],
"active_page_index": np.asarray([self.context.pages.index(self.page)]),
"url": self.page.url,
diff --git a/browsergym/core/src/browsergym/core/registration.py b/browsergym/core/src/browsergym/core/registration.py
index dd0e36ed..8bc23ca6 100644
--- a/browsergym/core/src/browsergym/core/registration.py
+++ b/browsergym/core/src/browsergym/core/registration.py
@@ -6,7 +6,12 @@
def register_task(
- id: str, task_class: Type[AbstractBrowserTask], nondeterministic: bool = True, *args, **kwargs
+ id: str,
+ task_class: Type[AbstractBrowserTask],
+ task_kwargs: dict = None,
+ nondeterministic: bool = True,
+ *args,
+ **kwargs,
):
"""
Registers a browser task as a gym environment with its unique id.
@@ -19,9 +24,16 @@ def register_task(
*kwargs: additional arguments for the browsergym environment.
"""
+    # these environment arguments will be fixed, and an error will be raised if they are also set when calling gym.make()
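+    # For example (hypothetical names), register_task("mytask", MyTask, task_kwargs={"task_id": 4})
+    # pins the task's constructor arguments, while gym.make("browsergym/mytask", headless=True)
+    # can still pass other environment arguments.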
+ fixed_env_kwargs = {}
+ if task_kwargs is not None:
+ fixed_env_kwargs["task_kwargs"] = task_kwargs
+
gym.register(
id=f"browsergym/{id}",
- entry_point=lambda *env_args, **env_kwargs: BrowserEnv(task_class, *env_args, **env_kwargs),
+ entry_point=lambda *env_args, **env_kwargs: BrowserEnv(
+ task_class, *env_args, **fixed_env_kwargs, **env_kwargs
+ ),
nondeterministic=nondeterministic,
*args,
**kwargs,
diff --git a/browsergym/core/src/browsergym/core/spaces.py b/browsergym/core/src/browsergym/core/spaces.py
index 177959e5..fb3ee7fe 100644
--- a/browsergym/core/src/browsergym/core/spaces.py
+++ b/browsergym/core/src/browsergym/core/spaces.py
@@ -79,6 +79,19 @@ def __eq__(self, other: Any) -> bool:
return isinstance(other, AnyDict)
+class Anything(Space):
+ """A space representing an arbitrary dictionary object."""
+
+ def contains(self, x: Any) -> bool:
+ return True
+
+ def __repr__(self) -> str:
+ return f"Anything()"
+
+ def __eq__(self, other: Any) -> bool:
+ return isinstance(other, Anything)
+
+
class AnyBox(Space[NDArray[Any]]):
"""A space representing an arbitrary dictionary object."""
diff --git a/browsergym/core/src/browsergym/core/task.py b/browsergym/core/src/browsergym/core/task.py
index 6555223d..6f7f9fec 100644
--- a/browsergym/core/src/browsergym/core/task.py
+++ b/browsergym/core/src/browsergym/core/task.py
@@ -1,9 +1,9 @@
-import numpy as np
-import playwright.sync_api
-
from abc import ABC, abstractmethod
from typing import Tuple
+import numpy as np
+import playwright.sync_api
+
class AbstractBrowserTask(ABC):
"""
diff --git a/browsergym/experiments/src/browsergym/experiments/loop.py b/browsergym/experiments/src/browsergym/experiments/loop.py
index f39d4622..95b81c5d 100644
--- a/browsergym/experiments/src/browsergym/experiments/loop.py
+++ b/browsergym/experiments/src/browsergym/experiments/loop.py
@@ -432,7 +432,18 @@ def save_step_info(self, exp_dir, save_json=False, save_screenshot=True, save_so
img = Image.fromarray(screenshot_som)
img.save(exp_dir / f"screenshot_som_step_{self.step}.png")
+ # save goal object (which might contain images) to a separate file to save space
+ if self.obs is not None and self.obs.get("goal_object", False):
+            # save the goal object only once (the goal should never change after setup)
+ goal_object_file = Path(exp_dir) / "goal_object.pkl.gz"
+ if not goal_object_file.exists():
+ with gzip.open(goal_object_file, "wb") as f:
+ pickle.dump(self.obs["goal_object"], f)
+ # set goal_object to a special placeholder value, which indicates it should be loaded from a separate file
+ self.obs["goal_object"] = None
+
with gzip.open(exp_dir / f"step_{self.step}.pkl.gz", "wb") as f:
+            # TODO: should we pop the screenshots too before this to save space?
pickle.dump(self, f)
if save_json:
@@ -584,6 +595,16 @@ def get_step_info(self, step: int) -> StepInfo:
)
except FileNotFoundError:
pass
+ # if goal_object is set to None, it indicates it has been saved into a separate file
+ if (
+ self._steps_info[step].obs
+ and "goal_object" in self._steps_info[step].obs
+ and self._steps_info[step].obs["goal_object"] is None
+ ):
+ with gzip.open(self.exp_dir / "goal_object.pkl.gz", "rb") as f:
+ goal_object = pickle.load(f)
+ self._steps_info[step].obs["goal_object"] = goal_object
+
return self._steps_info[step]
@property
diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py
index e229da99..c2ec0c88 100644
--- a/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py
+++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py
@@ -13,7 +13,7 @@
register_task(
gym_id,
task.GenericVisualWebArenaTask,
- kwargs={"task_kwargs": {"task_id": task_id}},
+ task_kwargs={"task_id": task_id},
)
ALL_VISUALWEBARENA_TASK_IDS.append(gym_id)
if task_id in config.TASK_IDS_WITH_RESET:
diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py
index 90a8624a..00a3107e 100644
--- a/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py
+++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py
@@ -2,6 +2,7 @@
import logging
import playwright.sync_api
import importlib.resources
+import pathlib
import tempfile
import requests
@@ -10,10 +11,83 @@
from browsergym.core.task import AbstractBrowserTask
from .instance import VisualWebArenaInstance
+from .utils import image_url_to_pil_image, pil_image_to_data_uri
logger = logging.getLogger(__name__)
+def _build_goal(config, with_na_hint: bool):
+ """
+ Build an openai-style goal (list of messages)
+ - recovers the goal text from config
+ - download goal images if any
+ - save goal images to local files
+ - expose goal images as image_url messages using base64 encoding
+ - expose goal images as local file paths (if task requires to upload them)
+ """
+
+ # recover goal text
+ goal_text = config["intent"]
+
+ # This note is present in some of webarena's agent prompts
+ if with_na_hint:
+ goal_text += """\
+
+If you believe the task is impossible to complete, provide the answer "N/A".
+"""
+
+ # recover goal image urls
+ image_urls = config.get("image", [])
+ image_data_uris = []
+ image_paths = []
+
+ # fix image list if needed
+ if image_urls is None:
+ image_urls = []
+ elif isinstance(image_urls, str):
+ image_urls = [image_urls]
+
+ # save images to local files in a temporary directory
+ temp_dir = pathlib.Path(tempfile.mkdtemp())
+ for i, image_url in enumerate(image_urls):
+ # extract image content from url
+ image = image_url_to_pil_image(image_url)
+ format = image.format.lower()
+ # write image to local file
+ image_path = temp_dir / f"input_image_{i+1}.{format}"
+ image.save(image_path)
+ # save image path for the goal
+ image_paths.append(image_path)
+ # save image data as base64 for the goal
+ image_data_uris.append(pil_image_to_data_uri(image))
+
+ # build an OpenAI-style structured goal
+ # textual goal first
+ goal = [{"type": "text", "text": goal_text}]
+ # then goal images
+ for i, (image_url, image_data_uri, image_path) in enumerate(
+ zip(image_urls, image_data_uris, image_paths)
+ ):
+ goal.extend(
+ [
+ # image description (id, filepath, url)
+ {
+ "type": "text",
+ "text": f"Input image {i+1}/{len(image_urls)} below (local path: {repr(image_path)}, url: {repr(image_url)})",
+ },
+ # actual image (base64 image data URI)
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": image_data_uri, # send data URI instead of URL (local urls might be inaccessible from the outside)
+ },
+ },
+ ]
+ )
+
+ return goal
+
+
class GenericVisualWebArenaTask(AbstractBrowserTask):
"""
Base class for all WebArena tasks.
@@ -52,13 +126,14 @@ def __init__(
)
# substitute URLs
- for pattern, url_key in {
- "__REDDIT__": "reddit",
- "__SHOPPING__": "shopping",
- "__WIKIPEDIA__": "wikipedia",
- "__CLASSIFIEDS__": "classifieds",
+ for pattern, url in {
+ "__REDDIT__": self.webarena_instance.urls["reddit"],
+ "__SHOPPING__": self.webarena_instance.urls["shopping"],
+ "__WIKIPEDIA__": self.webarena_instance.urls["wikipedia"],
+ "__CLASSIFIEDS__": self.webarena_instance.urls["classifieds"],
+ "__HOMEPAGE__": self.webarena_instance.home_url,
}.items():
- all_configs_str = all_configs_str.replace(pattern, self.webarena_instance.urls[url_key])
+ all_configs_str = all_configs_str.replace(pattern, url)
# load all task configs to JSON
all_configs = json.loads(all_configs_str)
@@ -129,25 +204,7 @@ def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]:
if i < len(start_urls) - 1:
page = page.context.new_page()
- # recover goal
- goal = {
- "message": self.config["intent"],
- "image_urls": self.config.get("image", []),
- }
- # fix goal if needed
- if goal["image_urls"] is None:
- goal["image_urls"] = []
- elif isinstance(goal["image_urls"], str):
- goal["image_urls"] = [goal["image_urls"]]
-
- # This note is present in some of webarena's agent prompts
- if self.with_na_hint:
- goal[
- "message"
- ] += """\
-
-If you believe the task is impossible to complete, provide the answer "N/A".
-"""
+ goal = _build_goal(self.config, with_na_hint=self.with_na_hint)
return goal, {}
diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/utils.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/utils.py
new file mode 100644
index 00000000..c02d2f72
--- /dev/null
+++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/utils.py
@@ -0,0 +1,39 @@
+import base64
+import io
+import PIL.Image
+import requests
+
+from typing import Literal
+
+
+def image_url_to_pil_image(image_url: str) -> PIL.Image:
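+    """Download an image from an http(s) URL and return it as a PIL image."""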
+ if not image_url.startswith("http"):
+ raise ValueError(f"Unexpected image URL: {image_url}")
+ response = requests.get(image_url, stream=True)
+ if response.status_code != 200:
+ raise ValueError(
+ f"Could not download image from url {image_url} (status code {response.status_code})"
+ )
+ img = PIL.Image.open(io.BytesIO(response.content))
+ return img
+
+
+def data_uri_to_pil_image(data_uri: str) -> PIL.Image:
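+    """Decode a base64-encoded PNG or JPEG data URI into a PIL image."""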
+ if data_uri.startswith("data:image/png;base64,"):
+ image_data = base64.b64decode(data_uri.removeprefix("data:image/png;base64,"))
+ elif data_uri.startswith("data:image/jpeg;base64,"):
+ image_data = base64.b64decode(data_uri.removeprefix("data:image/jpeg;base64,"))
+ else:
+ raise ValueError(f"Unexpected image encoding: {data_uri}")
+ img = PIL.Image.open(io.BytesIO(image_data))
+ return img
+
+
+def pil_image_to_data_uri(image: PIL.Image, format: Literal["png", "jpeg"] = "png") -> str:
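+    """Encode a PIL image as a base64 data URI (PNG or JPEG)."""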
+ assert format in ("png", "jpeg")
+ with io.BytesIO() as image_buffer:
+ image.save(image_buffer, format=format.upper())
+ byte_data = image_buffer.getvalue()
+ image_b64 = base64.b64encode(byte_data).decode("utf-8")
+ image_b64 = f"data:image/{format};base64," + image_b64
+ return image_b64
diff --git a/browsergym/webarena/src/browsergym/webarena/__init__.py b/browsergym/webarena/src/browsergym/webarena/__init__.py
index 9695ce30..26c9319a 100644
--- a/browsergym/webarena/src/browsergym/webarena/__init__.py
+++ b/browsergym/webarena/src/browsergym/webarena/__init__.py
@@ -11,6 +11,6 @@
register_task(
gym_id,
task.GenericWebArenaTask,
- kwargs={"task_kwargs": {"task_id": task_id}},
+ task_kwargs={"task_id": task_id},
)
ALL_WEBARENA_TASK_IDS.append(gym_id)
diff --git a/demo_agent/__init__.py b/demo_agent/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/demo_agent/agents/__init__.py b/demo_agent/agents/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/demo_agent/agents/basic/__init__.py b/demo_agent/agents/basic/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/demo_agent/agents/basic/agent.py b/demo_agent/agents/basic/agent.py
deleted file mode 100644
index b5e1e6be..00000000
--- a/demo_agent/agents/basic/agent.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import dataclasses
-
-from browsergym.experiments import Agent, AbstractAgentArgs
-from browsergym.core.action.highlevel import HighLevelActionSet
-from browsergym.core.action.python import PythonActionSet
-from browsergym.utils.obs import flatten_axtree_to_str
-
-
-class DemoAgent(Agent):
- """A basic agent using OpenAI API, to demonstrate BrowserGym's functionalities."""
-
- action_set = HighLevelActionSet(
- subsets=["chat", "bid"], # define a subset of the action space
- # subsets=["chat", "bid", "coord"] # allow the agent to also use x,y coordinates
- strict=False, # less strict on the parsing of the actions
- multiaction=True, # enable to agent to take multiple actions at once
- demo_mode="default", # add visual effects
- )
- # use this instead to allow the agent to directly use Python code
- # action_set = PythonActionSet())
-
- def obs_preprocessor(self, obs: dict) -> dict:
- return {
- "goal": obs["goal"],
- "axtree_txt": flatten_axtree_to_str(obs["axtree_object"]),
- }
-
- def __init__(self, model_name) -> None:
- super().__init__()
- self.model_name = model_name
-
- from openai import OpenAI
-
- self.openai_client = OpenAI()
-
- def get_action(self, obs: dict) -> tuple[str, dict]:
- system_msg = f"""\
-# Instructions
-Review the current state of the page and all other information to find the best
-possible next action to accomplish your goal. Your answer will be interpreted
-and executed by a program, make sure to follow the formatting instructions.
-
-# Goal:
-{obs["goal"]}"""
-
- prompt = f"""\
-# Current Accessibility Tree:
-{obs["axtree_txt"]}
-
-# Action Space
-{self.action_set.describe(with_long_description=False, with_examples=True)}
-
-Here is an example with chain of thought of a valid action when clicking on a button:
-"
-In order to accomplish my goal I need to click on the button with bid 12
-```click("12")```
-"
-"""
-
- # query OpenAI model
- response = self.openai_client.chat.completions.create(
- model=self.model_name,
- messages=[
- {"role": "system", "content": system_msg},
- {"role": "user", "content": prompt},
- ],
- )
- action = response.choices[0].message.content
-
- return action, {}
-
-
-@dataclasses.dataclass
-class DemoAgentArgs(AbstractAgentArgs):
- """
- This class is meant to store the arguments that define the agent.
-
- By isolating them in a dataclass, this ensures serialization without storing
- internal states of the agent.
- """
-
- model_name: str = "gpt-3.5-turbo"
-
- def make_agent(self):
- return DemoAgent(model_name=self.model_name)
-
-
-def main():
- from browsergym.experiments import EnvArgs, ExpArgs, get_exp_result
- from pathlib import Path
-
- exp_root = Path().home() / "agent_experiments"
- exp_root.mkdir(exist_ok=True)
-
- exp_args = ExpArgs(
- agent_args=DemoAgentArgs(model_name="gpt-3.5-turbo"),
- env_args=EnvArgs(
- task_name="miniwob.click-test",
- task_seed=42,
- headless=False, # shows the browser
- ),
- )
-
- exp_args.prepare(exp_root=exp_root)
- exp_args.run()
-
- exp_result = get_exp_result(exp_args.exp_dir)
- exp_record = exp_result.get_exp_record()
-
- for key, val in exp_record.items():
- print(f"{key}: {val}")
diff --git a/demo_agent/agents/legacy/__init__.py b/demo_agent/agents/legacy/__init__.py
deleted file mode 100644
index 005e6fb1..00000000
--- a/demo_agent/agents/legacy/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .agent import GenericAgentArgs
-from .dynamic_prompting import Flags
diff --git a/demo_agent/agents/legacy/agent.py b/demo_agent/agents/legacy/agent.py
deleted file mode 100644
index a6b263a8..00000000
--- a/demo_agent/agents/legacy/agent.py
+++ /dev/null
@@ -1,149 +0,0 @@
-"""
-WARNING DEPRECATED WILL BE REMOVED SOON
-"""
-
-from dataclasses import asdict, dataclass, field
-import traceback
-from warnings import warn
-from langchain.schema import HumanMessage, SystemMessage
-
-from browsergym.core.action.base import AbstractActionSet
-from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html
-from browsergym.experiments import Agent, AbstractAgentArgs
-
-from ..legacy import dynamic_prompting
-from .utils.llm_utils import ParseError, retry
-from .utils.chat_api import ChatModelArgs
-
-
-@dataclass
-class GenericAgentArgs(AbstractAgentArgs):
- chat_model_args: ChatModelArgs = None
- flags: dynamic_prompting.Flags = field(default_factory=lambda: dynamic_prompting.Flags())
- max_retry: int = 4
-
- def make_agent(self):
- return GenericAgent(
- chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry
- )
-
-
-class GenericAgent(Agent):
-
- def obs_preprocessor(self, obs: dict) -> dict:
- """
- Augment observations with text HTML and AXTree representations, which will be stored in
- the experiment traces.
- """
-
- obs = obs.copy()
- obs["dom_txt"] = flatten_dom_to_str(
- obs["dom_object"],
- with_visible=self.flags.extract_visible_tag,
- with_center_coords=self.flags.extract_coords == "center",
- with_bounding_box_coords=self.flags.extract_coords == "box",
- filter_visible_only=self.flags.extract_visible_elements_only,
- )
- obs["axtree_txt"] = flatten_axtree_to_str(
- obs["axtree_object"],
- with_visible=self.flags.extract_visible_tag,
- with_center_coords=self.flags.extract_coords == "center",
- with_bounding_box_coords=self.flags.extract_coords == "box",
- filter_visible_only=self.flags.extract_visible_elements_only,
- )
- obs["pruned_html"] = prune_html(obs["dom_txt"])
-
- return obs
-
- def __init__(
- self,
- chat_model_args: ChatModelArgs = None,
- flags: dynamic_prompting.Flags = None,
- max_retry: int = 4,
- ):
- self.chat_model_args = chat_model_args if chat_model_args is not None else ChatModelArgs()
- self.flags = flags if flags is not None else dynamic_prompting.Flags()
- self.max_retry = max_retry
-
- self.chat_llm = chat_model_args.make_chat_model()
- self.action_set = dynamic_prompting._get_action_space(self.flags)
-
- # consistency check
- if self.flags.use_screenshot:
- if not self.chat_model_args.has_vision():
- warn(
- """\
-
-Warning: use_screenshot is set to True, but the chat model \
-does not support vision. Disabling use_screenshot."""
- )
- self.flags.use_screenshot = False
-
- # reset episode memory
- self.obs_history = []
- self.actions = []
- self.memories = []
- self.thoughts = []
-
- def get_action(self, obs):
-
- self.obs_history.append(obs)
-
- main_prompt = dynamic_prompting.MainPrompt(
- obs_history=self.obs_history,
- actions=self.actions,
- memories=self.memories,
- thoughts=self.thoughts,
- flags=self.flags,
- )
-
- # Determine the minimum non-None token limit from prompt, total, and input tokens, or set to None if all are None.
- maxes = (
- self.flags.max_prompt_tokens,
- self.chat_model_args.max_total_tokens,
- self.chat_model_args.max_input_tokens,
- )
- maxes = [m for m in maxes if m is not None]
- max_prompt_tokens = min(maxes) if maxes else None
-
- prompt = dynamic_prompting.fit_tokens(
- main_prompt,
- max_prompt_tokens=max_prompt_tokens,
- model_name=self.chat_model_args.model_name,
- )
-
- chat_messages = [
- SystemMessage(content=dynamic_prompting.SystemPrompt().prompt),
- HumanMessage(content=prompt),
- ]
-
- def parser(text):
- try:
- ans_dict = main_prompt._parse_answer(text)
- except ParseError as e:
- # these parse errors will be caught by the retry function and
- # the chat_llm will have a chance to recover
- return None, False, str(e)
-
- return ans_dict, True, ""
-
- try:
- ans_dict = retry(self.chat_llm, chat_messages, n_retry=self.max_retry, parser=parser)
- # inferring the number of retries, TODO: make this less hacky
- ans_dict["n_retry"] = (len(chat_messages) - 3) / 2
- except ValueError as e:
- # Likely due to maximum retry. We catch it here to be able to return
- # the list of messages for further analysis
- ans_dict = {"action": None}
- ans_dict["err_msg"] = str(e)
- ans_dict["stack_trace"] = traceback.format_exc()
- ans_dict["n_retry"] = self.max_retry
-
- self.actions.append(ans_dict["action"])
- self.memories.append(ans_dict.get("memory", None))
- self.thoughts.append(ans_dict.get("think", None))
-
- ans_dict["chat_messages"] = [m.content for m in chat_messages]
- ans_dict["chat_model_args"] = asdict(self.chat_model_args)
-
- return ans_dict["action"], ans_dict
diff --git a/demo_agent/agents/legacy/dynamic_prompting.py b/demo_agent/agents/legacy/dynamic_prompting.py
deleted file mode 100644
index 20e3f7bc..00000000
--- a/demo_agent/agents/legacy/dynamic_prompting.py
+++ /dev/null
@@ -1,757 +0,0 @@
-"""
-WARNING DEPRECATED WILL BE REMOVED SOON
-"""
-
-import abc
-import difflib
-import logging
-import platform
-
-from copy import deepcopy
-from dataclasses import asdict, dataclass
-from textwrap import dedent
-from typing import Literal
-from warnings import warn
-
-from browsergym.core.action.base import AbstractActionSet
-from browsergym.core.action.highlevel import HighLevelActionSet
-from browsergym.core.action.python import PythonActionSet
-
-from .utils.llm_utils import ParseError
-from .utils.llm_utils import (
- count_tokens,
- image_to_jpg_base64_url,
- parse_html_tags_raise,
-)
-
-
-@dataclass
-class Flags:
- use_html: bool = True
- use_ax_tree: bool = False
- drop_ax_tree_first: bool = True # This flag is no longer active TODO delete
- use_thinking: bool = False
- use_error_logs: bool = False
- use_past_error_logs: bool = False
- use_history: bool = False
- use_action_history: bool = False
- use_memory: bool = False
- use_diff: bool = False
- html_type: str = "pruned_html"
- use_concrete_example: bool = True
- use_abstract_example: bool = False
- multi_actions: bool = False
- action_space: Literal[
- "python", "bid", "coord", "bid+coord", "bid+nav", "coord+nav", "bid+coord+nav"
- ] = "bid"
- is_strict: bool = False
- # This flag will be automatically disabled `if not chat_model_args.has_vision()`
- use_screenshot: bool = True
- enable_chat: bool = False
- max_prompt_tokens: int = None
- extract_visible_tag: bool = False
- extract_coords: Literal["False", "center", "box"] = "False"
- extract_visible_elements_only: bool = False
- demo_mode: Literal["off", "default", "only_visible_elements"] = "off"
- retry_with_force: bool = False
-
- def copy(self):
- return deepcopy(self)
-
- def asdict(self):
- """Helper for JSON serializble requirement."""
- return asdict(self)
-
- @classmethod
- def from_dict(self, flags_dict):
- """Helper for JSON serializble requirement."""
- if isinstance(flags_dict, Flags):
- return flags_dict
-
- if not isinstance(flags_dict, dict):
- raise ValueError(f"Unregcognized type for flags_dict of type {type(flags_dict)}.")
- return Flags(**flags_dict)
-
-
-class PromptElement:
- """Base class for all prompt elements. Prompt elements can be hidden.
-
- Prompt elements are used to build the prompt. Use flags to control which
- prompt elements are visible. We use class attributes as a convenient way
- to implement static prompts, but feel free to override them with instance
- attributes or @property decorator."""
-
- _prompt = ""
- _abstract_ex = ""
- _concrete_ex = ""
-
- def __init__(self, visible: bool = True) -> None:
- """Prompt element that can be hidden.
-
- Parameters
- ----------
- visible : bool, optional
- Whether the prompt element should be visible, by default True. Can
- be a callable that returns a bool. This is useful when a specific
- flag changes during a shrink iteration.
- """
- self._visible = visible
-
- @property
- def prompt(self):
- """Avoid overriding this method. Override _prompt instead."""
- return self._hide(self._prompt)
-
- @property
- def abstract_ex(self):
- """Useful when this prompt element is requesting an answer from the llm.
- Provide an abstract example of the answer here. See Memory for an
- example.
-
- Avoid overriding this method. Override _abstract_ex instead
- """
- return self._hide(self._abstract_ex)
-
- @property
- def concrete_ex(self):
- """Useful when this prompt element is requesting an answer from the llm.
- Provide a concrete example of the answer here. See Memory for an
- example.
-
- Avoid overriding this method. Override _concrete_ex instead
- """
- return self._hide(self._concrete_ex)
-
- @property
- def is_visible(self):
- """Handle the case where visible is a callable."""
- visible = self._visible
- if callable(visible):
- visible = visible()
- return visible
-
- def _hide(self, value):
- """Return value if visible is True, else return empty string."""
- if self.is_visible:
- return value
- else:
- return ""
-
- def _parse_answer(self, text_answer) -> dict:
- if self.is_visible:
- return self._parse_answer(text_answer)
- else:
- return {}
-
-
-class Shrinkable(PromptElement, abc.ABC):
- @abc.abstractmethod
- def shrink(self) -> None:
- """Implement shrinking of this prompt element.
-
- You need to recursively call all shrinkable elements that are part of
- this prompt. You can also implement a shriking startegy for this prompt.
- Shrinking is can be called multiple times to progressively shrink the
- prompt until it fits max_tokens. Default max shrink iterations is 20.
- """
- pass
-
-
-class Trunkater(Shrinkable):
- def __init__(self, visible, shrink_speed=0.3, start_trunkate_iteration=10):
- super().__init__(visible=visible)
- self.shrink_speed = shrink_speed
- self.start_trunkate_iteration = start_trunkate_iteration
- self.shrink_calls = 0
- self.deleted_lines = 0
-
- def shrink(self) -> None:
- if self.is_visible and self.shrink_calls >= self.start_trunkate_iteration:
- # remove the fraction of _prompt
- lines = self._prompt.splitlines()
- new_line_count = int(len(lines) * (1 - self.shrink_speed))
- self.deleted_lines += len(lines) - new_line_count
- self._prompt = "\n".join(lines[:new_line_count])
- self._prompt += f"\n... Deleted {self.deleted_lines} lines to reduce prompt size."
-
- self.shrink_calls += 1
-
-
-def fit_tokens(
- shrinkable: Shrinkable, max_prompt_tokens=None, max_iterations=20, model_name="openai/gpt-4"
-):
- """Shrink a prompt element until it fits max_tokens.
-
- Parameters
- ----------
- shrinkable : Shrinkable
- The prompt element to shrink.
- max_tokens : int
- The maximum number of tokens allowed.
- max_iterations : int, optional
- The maximum number of shrink iterations, by default 20.
- model_name : str, optional
- The name of the model used when tokenizing.
-
- Returns
- -------
- str : the prompt after shrinking.
- """
-
- if max_prompt_tokens is None:
- return shrinkable.prompt
-
- for _ in range(max_iterations):
- prompt = shrinkable.prompt
- if isinstance(prompt, str):
- prompt_str = prompt
- elif isinstance(prompt, list):
- prompt_str = "\n".join([p["text"] for p in prompt if p["type"] == "text"])
- else:
- raise ValueError(f"Unrecognized type for prompt: {type(prompt)}")
- n_token = count_tokens(prompt_str, model=model_name)
- if n_token <= max_prompt_tokens:
- return prompt
- shrinkable.shrink()
-
- logging.info(
- dedent(
- f"""\
- After {max_iterations} shrink iterations, the prompt is still
- {count_tokens(prompt_str)} tokens (greater than {max_prompt_tokens}). Returning the prompt as is."""
- )
- )
- return prompt
-
-
-class HTML(Trunkater):
- def __init__(self, html, visible: bool = True, prefix="") -> None:
- super().__init__(visible=visible, start_trunkate_iteration=5)
- self._prompt = f"\n{prefix}HTML:\n{html}\n"
-
-
-class AXTree(Trunkater):
- def __init__(self, ax_tree, visible: bool = True, coord_type=None, prefix="") -> None:
- super().__init__(visible=visible, start_trunkate_iteration=10)
- if coord_type == "center":
- coord_note = """\
-Note: center coordinates are provided in parenthesis and are
- relative to the top left corner of the page.\n\n"""
- elif coord_type == "box":
- coord_note = """\
-Note: bounding box of each object are provided in parenthesis and are
- relative to the top left corner of the page.\n\n"""
- else:
- coord_note = ""
- self._prompt = f"\n{prefix}AXTree:\n{coord_note}{ax_tree}\n"
-
-
-class Error(PromptElement):
- def __init__(self, error, visible: bool = True, prefix="") -> None:
- super().__init__(visible=visible)
- self._prompt = f"\n{prefix}Error from previous action:\n{error}\n"
-
-
-class Observation(Shrinkable):
- """Observation of the current step.
-
- Contains the html, the accessibility tree and the error logs.
- """
-
- def __init__(self, obs, flags: Flags) -> None:
- super().__init__()
- self.flags = flags
- self.obs = obs
- self.html = HTML(obs[flags.html_type], visible=lambda: flags.use_html, prefix="## ")
- self.ax_tree = AXTree(
- obs["axtree_txt"],
- visible=lambda: flags.use_ax_tree,
- coord_type=flags.extract_coords,
- prefix="## ",
- )
- self.error = Error(
- obs["last_action_error"],
- visible=lambda: flags.use_error_logs and obs["last_action_error"],
- prefix="## ",
- )
-
- def shrink(self):
- self.ax_tree.shrink()
- self.html.shrink()
-
- @property
- def _prompt(self) -> str:
- return f"\n# Observation of current step:\n{self.html.prompt}{self.ax_tree.prompt}{self.error.prompt}\n\n"
-
- def add_screenshot(self, prompt):
- if self.flags.use_screenshot:
- if isinstance(prompt, str):
- prompt = [{"type": "text", "text": prompt}]
- img_url = image_to_jpg_base64_url(self.obs["screenshot"])
- prompt.append({"type": "image_url", "image_url": {"url": img_url}})
-
- return prompt
-
-
-class MacNote(PromptElement):
- def __init__(self) -> None:
- super().__init__(visible=platform.system() == "Darwin")
- self._prompt = (
- "\nNote: you are on mac so you should use Meta instead of Control for Control+C etc.\n"
- )
-
-
-class BeCautious(PromptElement):
- def __init__(self, visible: bool = True) -> None:
- super().__init__(visible=visible)
- self._prompt = f"""\
-\nBe very cautious. Avoid submitting anything before verifying the effect of your
-actions. Take the time to explore the effect of safe actions first. For example
-you can fill a few elements of a form, but don't click submit before verifying
-that everything was filled correctly.\n"""
-
-
-class GoalInstructions(PromptElement):
- def __init__(self, goal, visible: bool = True) -> None:
- super().__init__(visible)
- self._prompt = f"""\
-# Instructions
-Review the current state of the page and all other information to find the best
-possible next action to accomplish your goal. Your answer will be interpreted
-and executed by a program, make sure to follow the formatting instructions.
-
-## Goal:
-{goal}
-"""
-
-
-class ChatInstructions(PromptElement):
- def __init__(self, chat_messages, visible: bool = True) -> None:
- super().__init__(visible)
- self._prompt = f"""\
-# Instructions
-
-You are a UI Assistant, your goal is to help the user perform tasks using a web browser. You can
-communicate with the user via a chat, in which the user gives you instructions and in which you
-can send back messages. You have access to a web browser that both you and the user can see,
-and with which only you can interact via specific commands.
-
-Review the instructions from the user, the current state of the page and all other information
-to find the best possible next action to accomplish your goal. Your answer will be interpreted
-and executed by a program, make sure to follow the formatting instructions.
-
-## Chat messages:
-
-"""
- self._prompt += "\n".join(
- [
- f"""\
- - [{msg['role']}] {msg['message']}"""
- for msg in chat_messages
- ]
- )
-
-
-class SystemPrompt(PromptElement):
- _prompt = """\
-You are an agent trying to solve a web task based on the content of the page and
-a user instructions. You can interact with the page and explore. Each time you
-submit an action it will be sent to the browser and you will receive a new page."""
-
-
-class MainPrompt(Shrinkable):
- def __init__(
- self,
- obs_history,
- actions,
- memories,
- thoughts,
- flags: Flags,
- ) -> None:
- super().__init__()
- self.flags = flags
- self.history = History(obs_history, actions, memories, thoughts, flags)
- if self.flags.enable_chat:
- self.instructions = ChatInstructions(obs_history[-1]["chat_messages"])
- else:
- if sum([msg["role"] == "user" for msg in obs_history[-1]["chat_messages"]]) > 1:
- logging.warning(
- "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`."
- )
- self.instructions = GoalInstructions(obs_history[-1]["goal"])
-
- self.obs = Observation(obs_history[-1], self.flags)
- self.action_space = ActionSpace(self.flags)
-
- self.think = Think(visible=lambda: flags.use_thinking)
- self.memory = Memory(visible=lambda: flags.use_memory)
-
- @property
- def _prompt(self) -> str:
- prompt = f"""\
-{self.instructions.prompt}\
-{self.obs.prompt}\
-{self.history.prompt}\
-{self.action_space.prompt}\
-{self.think.prompt}\
-{self.memory.prompt}\
-"""
-
- if self.flags.use_abstract_example:
- prompt += f"""
-# Abstract Example
-
-Here is an abstract version of the answer with description of the content of
-each tag. Make sure you follow this structure, but replace the content with your
-answer:
-{self.think.abstract_ex}\
-{self.memory.abstract_ex}\
-{self.action_space.abstract_ex}\
-"""
-
- if self.flags.use_concrete_example:
- prompt += f"""
-# Concrete Example
-
-Here is a concrete example of how to format your answer.
-Make sure to follow the template with proper tags:
-{self.think.concrete_ex}\
-{self.memory.concrete_ex}\
-{self.action_space.concrete_ex}\
-"""
- return self.obs.add_screenshot(prompt)
-
- def shrink(self):
- self.history.shrink()
- self.obs.shrink()
-
- def _parse_answer(self, text_answer):
- ans_dict = {}
- ans_dict.update(self.think._parse_answer(text_answer))
- ans_dict.update(self.memory._parse_answer(text_answer))
- ans_dict.update(self.action_space._parse_answer(text_answer))
- return ans_dict
-
-
-class ActionSpace(PromptElement):
- def __init__(self, flags: Flags) -> None:
- super().__init__()
- self.flags = flags
- self.action_space = _get_action_space(flags)
-
- self._prompt = f"# Action space:\n{self.action_space.describe()}{MacNote().prompt}\n"
- self._abstract_ex = f"""
-