diff --git a/README.md b/README.md index f0fbd2a3..d711b045 100644 --- a/README.md +++ b/README.md @@ -158,24 +158,44 @@ print("\n".join(env_ids)) ## Demo -If you want to experiment with an agent in BrowserGym, follow these steps: +If you want to experiment with a demo agent in BrowserGym, follow these steps: ```sh cd demo-agent -conda env create -f environment.yml; conda activate demo-agent +conda env create -f environment.yml +conda activate demo-agent # or simply use `pip install -r requirements.txt` playwright install chromium ``` -Optional: Set your `OPENAI_API_KEY` to use a GPT agent. - -Launch the demo on the open web: +Our demo agent uses `openai` as a backend, be sure to set your `OPENAI_API_KEY`. +Launch the demo agent on the open web: ```sh python run_demo.py --task_name openended --start_url https://www.google.com ``` -You can customize your experience by changing the `model_name` to your preferred LLM, toggling Chain-of-thought with `use_thinking`, adding screenshots for your VLMs with `use_screenshot`, and much more! +Or use it to solve a simple MiniWoB task: +```sh +python run_demo.py --task_name miniwob.click-test +``` + +A VisualWebArena task: +```sh +python run_demo.py --task_name visualwebarena.398 +``` + +A WebArena task: +```sh +python run_demo.py --task_name webarena.4 +``` + +A WorkArena task: +```sh +python run_demo.py --task_name workarena.servicenow.order-standard-laptop +``` + +You can customize your experience by changing the `model_name` to your preferred LLM (it uses `gpt-4o-mini` by default), adding screenshots for your VLMs with `use_screenshot`, and much more! (see `python run_demo.py --help`) ## Citing This Work diff --git a/browsergym/core/src/browsergym/core/env.py b/browsergym/core/src/browsergym/core/env.py index 662293f7..22fffaaa 100644 --- a/browsergym/core/src/browsergym/core/env.py +++ b/browsergym/core/src/browsergym/core/env.py @@ -1,38 +1,54 @@ import copy -import gymnasium as gym import logging -import numpy as np -import playwright.sync_api -import time import re - +import time from abc import ABC from pathlib import Path -from typing import Optional, Literal +from typing import Literal, Optional +import gymnasium as gym +import numpy as np +import playwright.sync_api + +from . import _get_global_playwright +from .action.base import execute_python_code +from .action.highlevel import HighLevelActionSet from .chat import Chat -from .task import AbstractBrowserTask -from .spaces import Unicode, AnyDict, AnyBox -from .constants import TEXT_MAX_LENGTH, BROWSERGYM_ID_ATTRIBUTE, EXTRACT_OBS_MAX_TRIES +from .constants import BROWSERGYM_ID_ATTRIBUTE, EXTRACT_OBS_MAX_TRIES, TEXT_MAX_LENGTH from .observation import ( - _pre_extract, + MarkingError, _post_extract, - extract_screenshot, - extract_dom_snapshot, + _pre_extract, extract_dom_extra_properties, - extract_merged_axtree, + extract_dom_snapshot, extract_focused_element_bid, - MarkingError, + extract_merged_axtree, + extract_screenshot, ) -from .action.base import execute_python_code -from .action.highlevel import HighLevelActionSet -from .action.base import execute_python_code -from . import _get_global_playwright - +from .spaces import AnyBox, AnyDict, Unicode +from .task import AbstractBrowserTask logger = logging.getLogger(__name__) +def _try_to_extract_legacy_goal(goal: list): + legacy_goal_strings = [] + for message in goal: + if message["type"] == "text": + legacy_goal_strings.append(message["text"]) + else: + logger.debug( + f"Message type {repr(message['type'])} present in the goal, cannot be converted to legacy text-only format." + ) + legacy_goal_strings.append( + 'WARNING: This goal cannot be converted to a text-only goal format. Use the new goal format instead ("goal_object" field). Any agent reading this should abort immediately.' + ) + break + legacy_goal = "\n".join(legacy_goal_strings) + + return legacy_goal + + class BrowserEnv(gym.Env, ABC): """The main BrowserGym class, which encapsulates instruction-following Web browsing into a Gymnasium environment.""" @@ -120,11 +136,8 @@ def __init__( } ) ), - # TODO: this is redundant with chat messages, to be removed "goal": Unicode(min_length=0, max_length=TEXT_MAX_LENGTH), - "goal_image_urls": gym.spaces.Sequence( - Unicode(min_length=0, max_length=TEXT_MAX_LENGTH) - ), + "goal_object": gym.spaces.Sequence(AnyDict()), "open_pages_urls": gym.spaces.Sequence( Unicode(min_length=0, max_length=TEXT_MAX_LENGTH) ), @@ -266,27 +279,42 @@ def override_property(task, env, property): recording_start_time = time.time() # setup the task - goal, task_info = self.task.setup(page=self.page) + task_goal, task_info = self.task.setup(page=self.page) + + # process the task goal + + # no goal specified + if task_goal is None: + self.goal_object = [] + # convert text-only goal (legacy) to new format + elif isinstance(task_goal, str): + self.goal_object = [{"type": "text", "text": task_goal}] + # new format goal with multiple texts and images (OpenAI style) + elif isinstance(task_goal, list): + self.goal_object = task_goal + else: + raise ValueError(f"task_goal should be of type str or list, got {task_goal.__class__}") # initialize the chat self.chat.add_message( role="assistant", msg="Hi! I am your UI assistant, I can perform web tasks for you. What can I help you with?", ) - # if any, add the task's goal to the chat - if goal: - - # goal is text-only - if isinstance(goal, str): - goal_msg = goal - # goal is text + images - elif isinstance(goal, dict): - goal_msg = goal["message"] - for image_url in goal["image_urls"]: - self.chat.add_message(role="user_image", msg=image_url) - - self.chat.add_message(role="user", msg=goal_msg) + # send task goal (if any) to the chat + for message in self.goal_object: + match message["type"]: + case "text": + self.chat.add_message(role="user", msg=message["text"]) + case "image_url": + image_src = message["image_url"] + if isinstance(image_src, dict): + image_src = image_src["url"] + self.chat.add_message(role="user_image", msg=image_src) + case _: + raise ValueError( + f"Unknown message type {repr(message['type'])} in the task goal." + ) self._wait_dom_loaded() @@ -508,26 +536,11 @@ def _get_obs(self): # post-extraction cleanup of temporary info in dom _post_extract(self.page) - # use first user message as goal, if any - # use all user images before first user message as goal images, if any - goal_msg = "There is no goal." - goal_image_urls = [] - _prev_image_urls = [] - for msg in self.chat.messages: - if msg["role"] == "user_image": - _prev_image_urls.append(msg["message"]) - elif msg["role"] == "user": - goal_msg = msg["message"] - goal_image_urls = _prev_image_urls - break - else: - pass - # obs is generic to all tasks obs = { "chat_messages": copy.deepcopy(self.chat.messages), - "goal": goal_msg, # TODO: redundant with chat messages, to be removed? - "goal_image_urls": goal_image_urls, # TODO: redundant with chat messages, to be removed? + "goal": _try_to_extract_legacy_goal(self.goal_object), # legacy goal, deprecated + "goal_object": self.goal_object, # new goal format, list of messages openai style "open_pages_urls": [page.url for page in self.context.pages], "active_page_index": np.asarray([self.context.pages.index(self.page)]), "url": self.page.url, diff --git a/browsergym/core/src/browsergym/core/registration.py b/browsergym/core/src/browsergym/core/registration.py index dd0e36ed..8bc23ca6 100644 --- a/browsergym/core/src/browsergym/core/registration.py +++ b/browsergym/core/src/browsergym/core/registration.py @@ -6,7 +6,12 @@ def register_task( - id: str, task_class: Type[AbstractBrowserTask], nondeterministic: bool = True, *args, **kwargs + id: str, + task_class: Type[AbstractBrowserTask], + task_kwargs: dict = None, + nondeterministic: bool = True, + *args, + **kwargs, ): """ Registers a browser task as a gym environment with its unique id. @@ -19,9 +24,16 @@ def register_task( *kwargs: additional arguments for the browsergym environment. """ + # these environment arguments will be fixed, and error will be raised if they are set when calling gym.make() + fixed_env_kwargs = {} + if task_kwargs is not None: + fixed_env_kwargs["task_kwargs"] = task_kwargs + gym.register( id=f"browsergym/{id}", - entry_point=lambda *env_args, **env_kwargs: BrowserEnv(task_class, *env_args, **env_kwargs), + entry_point=lambda *env_args, **env_kwargs: BrowserEnv( + task_class, *env_args, **fixed_env_kwargs, **env_kwargs + ), nondeterministic=nondeterministic, *args, **kwargs, diff --git a/browsergym/core/src/browsergym/core/spaces.py b/browsergym/core/src/browsergym/core/spaces.py index 177959e5..fb3ee7fe 100644 --- a/browsergym/core/src/browsergym/core/spaces.py +++ b/browsergym/core/src/browsergym/core/spaces.py @@ -79,6 +79,19 @@ def __eq__(self, other: Any) -> bool: return isinstance(other, AnyDict) +class Anything(Space): + """A space representing an arbitrary dictionary object.""" + + def contains(self, x: Any) -> bool: + return True + + def __repr__(self) -> str: + return f"Anything()" + + def __eq__(self, other: Any) -> bool: + return isinstance(other, Anything) + + class AnyBox(Space[NDArray[Any]]): """A space representing an arbitrary dictionary object.""" diff --git a/browsergym/core/src/browsergym/core/task.py b/browsergym/core/src/browsergym/core/task.py index 6555223d..6f7f9fec 100644 --- a/browsergym/core/src/browsergym/core/task.py +++ b/browsergym/core/src/browsergym/core/task.py @@ -1,9 +1,9 @@ -import numpy as np -import playwright.sync_api - from abc import ABC, abstractmethod from typing import Tuple +import numpy as np +import playwright.sync_api + class AbstractBrowserTask(ABC): """ diff --git a/browsergym/experiments/src/browsergym/experiments/loop.py b/browsergym/experiments/src/browsergym/experiments/loop.py index f39d4622..95b81c5d 100644 --- a/browsergym/experiments/src/browsergym/experiments/loop.py +++ b/browsergym/experiments/src/browsergym/experiments/loop.py @@ -432,7 +432,18 @@ def save_step_info(self, exp_dir, save_json=False, save_screenshot=True, save_so img = Image.fromarray(screenshot_som) img.save(exp_dir / f"screenshot_som_step_{self.step}.png") + # save goal object (which might contain images) to a separate file to save space + if self.obs is not None and self.obs.get("goal_object", False): + # save the goal object only once (goal should never change once setup) + goal_object_file = Path(exp_dir) / "goal_object.pkl.gz" + if not goal_object_file.exists(): + with gzip.open(goal_object_file, "wb") as f: + pickle.dump(self.obs["goal_object"], f) + # set goal_object to a special placeholder value, which indicates it should be loaded from a separate file + self.obs["goal_object"] = None + with gzip.open(exp_dir / f"step_{self.step}.pkl.gz", "wb") as f: + # TODO should we pop the screenshots too before this to save space ? pickle.dump(self, f) if save_json: @@ -584,6 +595,16 @@ def get_step_info(self, step: int) -> StepInfo: ) except FileNotFoundError: pass + # if goal_object is set to None, it indicates it has been saved into a separate file + if ( + self._steps_info[step].obs + and "goal_object" in self._steps_info[step].obs + and self._steps_info[step].obs["goal_object"] is None + ): + with gzip.open(self.exp_dir / "goal_object.pkl.gz", "rb") as f: + goal_object = pickle.load(f) + self._steps_info[step].obs["goal_object"] = goal_object + return self._steps_info[step] @property diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py index e229da99..c2ec0c88 100644 --- a/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py +++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py @@ -13,7 +13,7 @@ register_task( gym_id, task.GenericVisualWebArenaTask, - kwargs={"task_kwargs": {"task_id": task_id}}, + task_kwargs={"task_id": task_id}, ) ALL_VISUALWEBARENA_TASK_IDS.append(gym_id) if task_id in config.TASK_IDS_WITH_RESET: diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py index 90a8624a..00a3107e 100644 --- a/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py +++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py @@ -2,6 +2,7 @@ import logging import playwright.sync_api import importlib.resources +import pathlib import tempfile import requests @@ -10,10 +11,83 @@ from browsergym.core.task import AbstractBrowserTask from .instance import VisualWebArenaInstance +from .utils import image_url_to_pil_image, pil_image_to_data_uri logger = logging.getLogger(__name__) +def _build_goal(config, with_na_hint: bool): + """ + Build an openai-style goal (list of messages) + - recovers the goal text from config + - download goal images if any + - save goal images to local files + - expose goal images as image_url messages using base64 encoding + - expose goal images as local file paths (if task requires to upload them) + """ + + # recover goal text + goal_text = config["intent"] + + # This note is present in some of webarena's agent prompts + if with_na_hint: + goal_text += """\ + +If you believe the task is impossible to complete, provide the answer "N/A". +""" + + # recover goal image urls + image_urls = config.get("image", []) + image_data_uris = [] + image_paths = [] + + # fix image list if needed + if image_urls is None: + image_urls = [] + elif isinstance(image_urls, str): + image_urls = [image_urls] + + # save images to local files in a temporary directory + temp_dir = pathlib.Path(tempfile.mkdtemp()) + for i, image_url in enumerate(image_urls): + # extract image content from url + image = image_url_to_pil_image(image_url) + format = image.format.lower() + # write image to local file + image_path = temp_dir / f"input_image_{i+1}.{format}" + image.save(image_path) + # save image path for the goal + image_paths.append(image_path) + # save image data as base64 for the goal + image_data_uris.append(pil_image_to_data_uri(image)) + + # build an OpenAI-style structured goal + # textual goal first + goal = [{"type": "text", "text": goal_text}] + # then goal images + for i, (image_url, image_data_uri, image_path) in enumerate( + zip(image_urls, image_data_uris, image_paths) + ): + goal.extend( + [ + # image description (id, filepath, url) + { + "type": "text", + "text": f"Input image {i+1}/{len(image_urls)} below (local path: {repr(image_path)}, url: {repr(image_url)})", + }, + # actual image (base64 image data URI) + { + "type": "image_url", + "image_url": { + "url": image_data_uri, # send data URI instead of URL (local urls might be inaccessible from the outside) + }, + }, + ] + ) + + return goal + + class GenericVisualWebArenaTask(AbstractBrowserTask): """ Base class for all WebArena tasks. @@ -52,13 +126,14 @@ def __init__( ) # substitute URLs - for pattern, url_key in { - "__REDDIT__": "reddit", - "__SHOPPING__": "shopping", - "__WIKIPEDIA__": "wikipedia", - "__CLASSIFIEDS__": "classifieds", + for pattern, url in { + "__REDDIT__": self.webarena_instance.urls["reddit"], + "__SHOPPING__": self.webarena_instance.urls["shopping"], + "__WIKIPEDIA__": self.webarena_instance.urls["wikipedia"], + "__CLASSIFIEDS__": self.webarena_instance.urls["classifieds"], + "__HOMEPAGE__": self.webarena_instance.home_url, }.items(): - all_configs_str = all_configs_str.replace(pattern, self.webarena_instance.urls[url_key]) + all_configs_str = all_configs_str.replace(pattern, url) # load all task configs to JSON all_configs = json.loads(all_configs_str) @@ -129,25 +204,7 @@ def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]: if i < len(start_urls) - 1: page = page.context.new_page() - # recover goal - goal = { - "message": self.config["intent"], - "image_urls": self.config.get("image", []), - } - # fix goal if needed - if goal["image_urls"] is None: - goal["image_urls"] = [] - elif isinstance(goal["image_urls"], str): - goal["image_urls"] = [goal["image_urls"]] - - # This note is present in some of webarena's agent prompts - if self.with_na_hint: - goal[ - "message" - ] += """\ - -If you believe the task is impossible to complete, provide the answer "N/A". -""" + goal = _build_goal(self.config, with_na_hint=self.with_na_hint) return goal, {} diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/utils.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/utils.py new file mode 100644 index 00000000..c02d2f72 --- /dev/null +++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/utils.py @@ -0,0 +1,39 @@ +import base64 +import io +import PIL.Image +import requests + +from typing import Literal + + +def image_url_to_pil_image(image_url: str) -> PIL.Image: + if not image_url.startswith("http"): + raise ValueError(f"Unexpected image URL: {image_url}") + response = requests.get(image_url, stream=True) + if response.status_code != 200: + raise ValueError( + f"Could not download image from url {image_url} (status code {response.status_code})" + ) + img = PIL.Image.open(io.BytesIO(response.content)) + return img + + +def data_uri_to_pil_image(data_uri: str) -> PIL.Image: + if data_uri.startswith("data:image/png;base64,"): + image_data = base64.b64decode(data_uri.removeprefix("data:image/png;base64,")) + elif data_uri.startswith("data:image/jpeg;base64,"): + image_data = base64.b64decode(data_uri.removeprefix("data:image/jpeg;base64,")) + else: + raise ValueError(f"Unexpected image encoding: {data_uri}") + img = PIL.Image.open(io.BytesIO(image_data)) + return img + + +def pil_image_to_data_uri(image: PIL.Image, format: Literal["png", "jpeg"] = "png") -> str: + assert format in ("png", "jpeg") + with io.BytesIO() as image_buffer: + image.save(image_buffer, format=format.upper()) + byte_data = image_buffer.getvalue() + image_b64 = base64.b64encode(byte_data).decode("utf-8") + image_b64 = f"data:image/{format};base64," + image_b64 + return image_b64 diff --git a/browsergym/webarena/src/browsergym/webarena/__init__.py b/browsergym/webarena/src/browsergym/webarena/__init__.py index 9695ce30..26c9319a 100644 --- a/browsergym/webarena/src/browsergym/webarena/__init__.py +++ b/browsergym/webarena/src/browsergym/webarena/__init__.py @@ -11,6 +11,6 @@ register_task( gym_id, task.GenericWebArenaTask, - kwargs={"task_kwargs": {"task_id": task_id}}, + task_kwargs={"task_id": task_id}, ) ALL_WEBARENA_TASK_IDS.append(gym_id) diff --git a/demo_agent/__init__.py b/demo_agent/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/demo_agent/agents/__init__.py b/demo_agent/agents/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/demo_agent/agents/basic/__init__.py b/demo_agent/agents/basic/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/demo_agent/agents/basic/agent.py b/demo_agent/agents/basic/agent.py deleted file mode 100644 index b5e1e6be..00000000 --- a/demo_agent/agents/basic/agent.py +++ /dev/null @@ -1,111 +0,0 @@ -import dataclasses - -from browsergym.experiments import Agent, AbstractAgentArgs -from browsergym.core.action.highlevel import HighLevelActionSet -from browsergym.core.action.python import PythonActionSet -from browsergym.utils.obs import flatten_axtree_to_str - - -class DemoAgent(Agent): - """A basic agent using OpenAI API, to demonstrate BrowserGym's functionalities.""" - - action_set = HighLevelActionSet( - subsets=["chat", "bid"], # define a subset of the action space - # subsets=["chat", "bid", "coord"] # allow the agent to also use x,y coordinates - strict=False, # less strict on the parsing of the actions - multiaction=True, # enable to agent to take multiple actions at once - demo_mode="default", # add visual effects - ) - # use this instead to allow the agent to directly use Python code - # action_set = PythonActionSet()) - - def obs_preprocessor(self, obs: dict) -> dict: - return { - "goal": obs["goal"], - "axtree_txt": flatten_axtree_to_str(obs["axtree_object"]), - } - - def __init__(self, model_name) -> None: - super().__init__() - self.model_name = model_name - - from openai import OpenAI - - self.openai_client = OpenAI() - - def get_action(self, obs: dict) -> tuple[str, dict]: - system_msg = f"""\ -# Instructions -Review the current state of the page and all other information to find the best -possible next action to accomplish your goal. Your answer will be interpreted -and executed by a program, make sure to follow the formatting instructions. - -# Goal: -{obs["goal"]}""" - - prompt = f"""\ -# Current Accessibility Tree: -{obs["axtree_txt"]} - -# Action Space -{self.action_set.describe(with_long_description=False, with_examples=True)} - -Here is an example with chain of thought of a valid action when clicking on a button: -" -In order to accomplish my goal I need to click on the button with bid 12 -```click("12")``` -" -""" - - # query OpenAI model - response = self.openai_client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": system_msg}, - {"role": "user", "content": prompt}, - ], - ) - action = response.choices[0].message.content - - return action, {} - - -@dataclasses.dataclass -class DemoAgentArgs(AbstractAgentArgs): - """ - This class is meant to store the arguments that define the agent. - - By isolating them in a dataclass, this ensures serialization without storing - internal states of the agent. - """ - - model_name: str = "gpt-3.5-turbo" - - def make_agent(self): - return DemoAgent(model_name=self.model_name) - - -def main(): - from browsergym.experiments import EnvArgs, ExpArgs, get_exp_result - from pathlib import Path - - exp_root = Path().home() / "agent_experiments" - exp_root.mkdir(exist_ok=True) - - exp_args = ExpArgs( - agent_args=DemoAgentArgs(model_name="gpt-3.5-turbo"), - env_args=EnvArgs( - task_name="miniwob.click-test", - task_seed=42, - headless=False, # shows the browser - ), - ) - - exp_args.prepare(exp_root=exp_root) - exp_args.run() - - exp_result = get_exp_result(exp_args.exp_dir) - exp_record = exp_result.get_exp_record() - - for key, val in exp_record.items(): - print(f"{key}: {val}") diff --git a/demo_agent/agents/legacy/__init__.py b/demo_agent/agents/legacy/__init__.py deleted file mode 100644 index 005e6fb1..00000000 --- a/demo_agent/agents/legacy/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .agent import GenericAgentArgs -from .dynamic_prompting import Flags diff --git a/demo_agent/agents/legacy/agent.py b/demo_agent/agents/legacy/agent.py deleted file mode 100644 index a6b263a8..00000000 --- a/demo_agent/agents/legacy/agent.py +++ /dev/null @@ -1,149 +0,0 @@ -""" -WARNING DEPRECATED WILL BE REMOVED SOON -""" - -from dataclasses import asdict, dataclass, field -import traceback -from warnings import warn -from langchain.schema import HumanMessage, SystemMessage - -from browsergym.core.action.base import AbstractActionSet -from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html -from browsergym.experiments import Agent, AbstractAgentArgs - -from ..legacy import dynamic_prompting -from .utils.llm_utils import ParseError, retry -from .utils.chat_api import ChatModelArgs - - -@dataclass -class GenericAgentArgs(AbstractAgentArgs): - chat_model_args: ChatModelArgs = None - flags: dynamic_prompting.Flags = field(default_factory=lambda: dynamic_prompting.Flags()) - max_retry: int = 4 - - def make_agent(self): - return GenericAgent( - chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry - ) - - -class GenericAgent(Agent): - - def obs_preprocessor(self, obs: dict) -> dict: - """ - Augment observations with text HTML and AXTree representations, which will be stored in - the experiment traces. - """ - - obs = obs.copy() - obs["dom_txt"] = flatten_dom_to_str( - obs["dom_object"], - with_visible=self.flags.extract_visible_tag, - with_center_coords=self.flags.extract_coords == "center", - with_bounding_box_coords=self.flags.extract_coords == "box", - filter_visible_only=self.flags.extract_visible_elements_only, - ) - obs["axtree_txt"] = flatten_axtree_to_str( - obs["axtree_object"], - with_visible=self.flags.extract_visible_tag, - with_center_coords=self.flags.extract_coords == "center", - with_bounding_box_coords=self.flags.extract_coords == "box", - filter_visible_only=self.flags.extract_visible_elements_only, - ) - obs["pruned_html"] = prune_html(obs["dom_txt"]) - - return obs - - def __init__( - self, - chat_model_args: ChatModelArgs = None, - flags: dynamic_prompting.Flags = None, - max_retry: int = 4, - ): - self.chat_model_args = chat_model_args if chat_model_args is not None else ChatModelArgs() - self.flags = flags if flags is not None else dynamic_prompting.Flags() - self.max_retry = max_retry - - self.chat_llm = chat_model_args.make_chat_model() - self.action_set = dynamic_prompting._get_action_space(self.flags) - - # consistency check - if self.flags.use_screenshot: - if not self.chat_model_args.has_vision(): - warn( - """\ - -Warning: use_screenshot is set to True, but the chat model \ -does not support vision. Disabling use_screenshot.""" - ) - self.flags.use_screenshot = False - - # reset episode memory - self.obs_history = [] - self.actions = [] - self.memories = [] - self.thoughts = [] - - def get_action(self, obs): - - self.obs_history.append(obs) - - main_prompt = dynamic_prompting.MainPrompt( - obs_history=self.obs_history, - actions=self.actions, - memories=self.memories, - thoughts=self.thoughts, - flags=self.flags, - ) - - # Determine the minimum non-None token limit from prompt, total, and input tokens, or set to None if all are None. - maxes = ( - self.flags.max_prompt_tokens, - self.chat_model_args.max_total_tokens, - self.chat_model_args.max_input_tokens, - ) - maxes = [m for m in maxes if m is not None] - max_prompt_tokens = min(maxes) if maxes else None - - prompt = dynamic_prompting.fit_tokens( - main_prompt, - max_prompt_tokens=max_prompt_tokens, - model_name=self.chat_model_args.model_name, - ) - - chat_messages = [ - SystemMessage(content=dynamic_prompting.SystemPrompt().prompt), - HumanMessage(content=prompt), - ] - - def parser(text): - try: - ans_dict = main_prompt._parse_answer(text) - except ParseError as e: - # these parse errors will be caught by the retry function and - # the chat_llm will have a chance to recover - return None, False, str(e) - - return ans_dict, True, "" - - try: - ans_dict = retry(self.chat_llm, chat_messages, n_retry=self.max_retry, parser=parser) - # inferring the number of retries, TODO: make this less hacky - ans_dict["n_retry"] = (len(chat_messages) - 3) / 2 - except ValueError as e: - # Likely due to maximum retry. We catch it here to be able to return - # the list of messages for further analysis - ans_dict = {"action": None} - ans_dict["err_msg"] = str(e) - ans_dict["stack_trace"] = traceback.format_exc() - ans_dict["n_retry"] = self.max_retry - - self.actions.append(ans_dict["action"]) - self.memories.append(ans_dict.get("memory", None)) - self.thoughts.append(ans_dict.get("think", None)) - - ans_dict["chat_messages"] = [m.content for m in chat_messages] - ans_dict["chat_model_args"] = asdict(self.chat_model_args) - - return ans_dict["action"], ans_dict diff --git a/demo_agent/agents/legacy/dynamic_prompting.py b/demo_agent/agents/legacy/dynamic_prompting.py deleted file mode 100644 index 20e3f7bc..00000000 --- a/demo_agent/agents/legacy/dynamic_prompting.py +++ /dev/null @@ -1,757 +0,0 @@ -""" -WARNING DEPRECATED WILL BE REMOVED SOON -""" - -import abc -import difflib -import logging -import platform - -from copy import deepcopy -from dataclasses import asdict, dataclass -from textwrap import dedent -from typing import Literal -from warnings import warn - -from browsergym.core.action.base import AbstractActionSet -from browsergym.core.action.highlevel import HighLevelActionSet -from browsergym.core.action.python import PythonActionSet - -from .utils.llm_utils import ParseError -from .utils.llm_utils import ( - count_tokens, - image_to_jpg_base64_url, - parse_html_tags_raise, -) - - -@dataclass -class Flags: - use_html: bool = True - use_ax_tree: bool = False - drop_ax_tree_first: bool = True # This flag is no longer active TODO delete - use_thinking: bool = False - use_error_logs: bool = False - use_past_error_logs: bool = False - use_history: bool = False - use_action_history: bool = False - use_memory: bool = False - use_diff: bool = False - html_type: str = "pruned_html" - use_concrete_example: bool = True - use_abstract_example: bool = False - multi_actions: bool = False - action_space: Literal[ - "python", "bid", "coord", "bid+coord", "bid+nav", "coord+nav", "bid+coord+nav" - ] = "bid" - is_strict: bool = False - # This flag will be automatically disabled `if not chat_model_args.has_vision()` - use_screenshot: bool = True - enable_chat: bool = False - max_prompt_tokens: int = None - extract_visible_tag: bool = False - extract_coords: Literal["False", "center", "box"] = "False" - extract_visible_elements_only: bool = False - demo_mode: Literal["off", "default", "only_visible_elements"] = "off" - retry_with_force: bool = False - - def copy(self): - return deepcopy(self) - - def asdict(self): - """Helper for JSON serializble requirement.""" - return asdict(self) - - @classmethod - def from_dict(self, flags_dict): - """Helper for JSON serializble requirement.""" - if isinstance(flags_dict, Flags): - return flags_dict - - if not isinstance(flags_dict, dict): - raise ValueError(f"Unregcognized type for flags_dict of type {type(flags_dict)}.") - return Flags(**flags_dict) - - -class PromptElement: - """Base class for all prompt elements. Prompt elements can be hidden. - - Prompt elements are used to build the prompt. Use flags to control which - prompt elements are visible. We use class attributes as a convenient way - to implement static prompts, but feel free to override them with instance - attributes or @property decorator.""" - - _prompt = "" - _abstract_ex = "" - _concrete_ex = "" - - def __init__(self, visible: bool = True) -> None: - """Prompt element that can be hidden. - - Parameters - ---------- - visible : bool, optional - Whether the prompt element should be visible, by default True. Can - be a callable that returns a bool. This is useful when a specific - flag changes during a shrink iteration. - """ - self._visible = visible - - @property - def prompt(self): - """Avoid overriding this method. Override _prompt instead.""" - return self._hide(self._prompt) - - @property - def abstract_ex(self): - """Useful when this prompt element is requesting an answer from the llm. - Provide an abstract example of the answer here. See Memory for an - example. - - Avoid overriding this method. Override _abstract_ex instead - """ - return self._hide(self._abstract_ex) - - @property - def concrete_ex(self): - """Useful when this prompt element is requesting an answer from the llm. - Provide a concrete example of the answer here. See Memory for an - example. - - Avoid overriding this method. Override _concrete_ex instead - """ - return self._hide(self._concrete_ex) - - @property - def is_visible(self): - """Handle the case where visible is a callable.""" - visible = self._visible - if callable(visible): - visible = visible() - return visible - - def _hide(self, value): - """Return value if visible is True, else return empty string.""" - if self.is_visible: - return value - else: - return "" - - def _parse_answer(self, text_answer) -> dict: - if self.is_visible: - return self._parse_answer(text_answer) - else: - return {} - - -class Shrinkable(PromptElement, abc.ABC): - @abc.abstractmethod - def shrink(self) -> None: - """Implement shrinking of this prompt element. - - You need to recursively call all shrinkable elements that are part of - this prompt. You can also implement a shriking startegy for this prompt. - Shrinking is can be called multiple times to progressively shrink the - prompt until it fits max_tokens. Default max shrink iterations is 20. - """ - pass - - -class Trunkater(Shrinkable): - def __init__(self, visible, shrink_speed=0.3, start_trunkate_iteration=10): - super().__init__(visible=visible) - self.shrink_speed = shrink_speed - self.start_trunkate_iteration = start_trunkate_iteration - self.shrink_calls = 0 - self.deleted_lines = 0 - - def shrink(self) -> None: - if self.is_visible and self.shrink_calls >= self.start_trunkate_iteration: - # remove the fraction of _prompt - lines = self._prompt.splitlines() - new_line_count = int(len(lines) * (1 - self.shrink_speed)) - self.deleted_lines += len(lines) - new_line_count - self._prompt = "\n".join(lines[:new_line_count]) - self._prompt += f"\n... Deleted {self.deleted_lines} lines to reduce prompt size." - - self.shrink_calls += 1 - - -def fit_tokens( - shrinkable: Shrinkable, max_prompt_tokens=None, max_iterations=20, model_name="openai/gpt-4" -): - """Shrink a prompt element until it fits max_tokens. - - Parameters - ---------- - shrinkable : Shrinkable - The prompt element to shrink. - max_tokens : int - The maximum number of tokens allowed. - max_iterations : int, optional - The maximum number of shrink iterations, by default 20. - model_name : str, optional - The name of the model used when tokenizing. - - Returns - ------- - str : the prompt after shrinking. - """ - - if max_prompt_tokens is None: - return shrinkable.prompt - - for _ in range(max_iterations): - prompt = shrinkable.prompt - if isinstance(prompt, str): - prompt_str = prompt - elif isinstance(prompt, list): - prompt_str = "\n".join([p["text"] for p in prompt if p["type"] == "text"]) - else: - raise ValueError(f"Unrecognized type for prompt: {type(prompt)}") - n_token = count_tokens(prompt_str, model=model_name) - if n_token <= max_prompt_tokens: - return prompt - shrinkable.shrink() - - logging.info( - dedent( - f"""\ - After {max_iterations} shrink iterations, the prompt is still - {count_tokens(prompt_str)} tokens (greater than {max_prompt_tokens}). Returning the prompt as is.""" - ) - ) - return prompt - - -class HTML(Trunkater): - def __init__(self, html, visible: bool = True, prefix="") -> None: - super().__init__(visible=visible, start_trunkate_iteration=5) - self._prompt = f"\n{prefix}HTML:\n{html}\n" - - -class AXTree(Trunkater): - def __init__(self, ax_tree, visible: bool = True, coord_type=None, prefix="") -> None: - super().__init__(visible=visible, start_trunkate_iteration=10) - if coord_type == "center": - coord_note = """\ -Note: center coordinates are provided in parenthesis and are - relative to the top left corner of the page.\n\n""" - elif coord_type == "box": - coord_note = """\ -Note: bounding box of each object are provided in parenthesis and are - relative to the top left corner of the page.\n\n""" - else: - coord_note = "" - self._prompt = f"\n{prefix}AXTree:\n{coord_note}{ax_tree}\n" - - -class Error(PromptElement): - def __init__(self, error, visible: bool = True, prefix="") -> None: - super().__init__(visible=visible) - self._prompt = f"\n{prefix}Error from previous action:\n{error}\n" - - -class Observation(Shrinkable): - """Observation of the current step. - - Contains the html, the accessibility tree and the error logs. - """ - - def __init__(self, obs, flags: Flags) -> None: - super().__init__() - self.flags = flags - self.obs = obs - self.html = HTML(obs[flags.html_type], visible=lambda: flags.use_html, prefix="## ") - self.ax_tree = AXTree( - obs["axtree_txt"], - visible=lambda: flags.use_ax_tree, - coord_type=flags.extract_coords, - prefix="## ", - ) - self.error = Error( - obs["last_action_error"], - visible=lambda: flags.use_error_logs and obs["last_action_error"], - prefix="## ", - ) - - def shrink(self): - self.ax_tree.shrink() - self.html.shrink() - - @property - def _prompt(self) -> str: - return f"\n# Observation of current step:\n{self.html.prompt}{self.ax_tree.prompt}{self.error.prompt}\n\n" - - def add_screenshot(self, prompt): - if self.flags.use_screenshot: - if isinstance(prompt, str): - prompt = [{"type": "text", "text": prompt}] - img_url = image_to_jpg_base64_url(self.obs["screenshot"]) - prompt.append({"type": "image_url", "image_url": {"url": img_url}}) - - return prompt - - -class MacNote(PromptElement): - def __init__(self) -> None: - super().__init__(visible=platform.system() == "Darwin") - self._prompt = ( - "\nNote: you are on mac so you should use Meta instead of Control for Control+C etc.\n" - ) - - -class BeCautious(PromptElement): - def __init__(self, visible: bool = True) -> None: - super().__init__(visible=visible) - self._prompt = f"""\ -\nBe very cautious. Avoid submitting anything before verifying the effect of your -actions. Take the time to explore the effect of safe actions first. For example -you can fill a few elements of a form, but don't click submit before verifying -that everything was filled correctly.\n""" - - -class GoalInstructions(PromptElement): - def __init__(self, goal, visible: bool = True) -> None: - super().__init__(visible) - self._prompt = f"""\ -# Instructions -Review the current state of the page and all other information to find the best -possible next action to accomplish your goal. Your answer will be interpreted -and executed by a program, make sure to follow the formatting instructions. - -## Goal: -{goal} -""" - - -class ChatInstructions(PromptElement): - def __init__(self, chat_messages, visible: bool = True) -> None: - super().__init__(visible) - self._prompt = f"""\ -# Instructions - -You are a UI Assistant, your goal is to help the user perform tasks using a web browser. You can -communicate with the user via a chat, in which the user gives you instructions and in which you -can send back messages. You have access to a web browser that both you and the user can see, -and with which only you can interact via specific commands. - -Review the instructions from the user, the current state of the page and all other information -to find the best possible next action to accomplish your goal. Your answer will be interpreted -and executed by a program, make sure to follow the formatting instructions. - -## Chat messages: - -""" - self._prompt += "\n".join( - [ - f"""\ - - [{msg['role']}] {msg['message']}""" - for msg in chat_messages - ] - ) - - -class SystemPrompt(PromptElement): - _prompt = """\ -You are an agent trying to solve a web task based on the content of the page and -a user instructions. You can interact with the page and explore. Each time you -submit an action it will be sent to the browser and you will receive a new page.""" - - -class MainPrompt(Shrinkable): - def __init__( - self, - obs_history, - actions, - memories, - thoughts, - flags: Flags, - ) -> None: - super().__init__() - self.flags = flags - self.history = History(obs_history, actions, memories, thoughts, flags) - if self.flags.enable_chat: - self.instructions = ChatInstructions(obs_history[-1]["chat_messages"]) - else: - if sum([msg["role"] == "user" for msg in obs_history[-1]["chat_messages"]]) > 1: - logging.warning( - "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." - ) - self.instructions = GoalInstructions(obs_history[-1]["goal"]) - - self.obs = Observation(obs_history[-1], self.flags) - self.action_space = ActionSpace(self.flags) - - self.think = Think(visible=lambda: flags.use_thinking) - self.memory = Memory(visible=lambda: flags.use_memory) - - @property - def _prompt(self) -> str: - prompt = f"""\ -{self.instructions.prompt}\ -{self.obs.prompt}\ -{self.history.prompt}\ -{self.action_space.prompt}\ -{self.think.prompt}\ -{self.memory.prompt}\ -""" - - if self.flags.use_abstract_example: - prompt += f""" -# Abstract Example - -Here is an abstract version of the answer with description of the content of -each tag. Make sure you follow this structure, but replace the content with your -answer: -{self.think.abstract_ex}\ -{self.memory.abstract_ex}\ -{self.action_space.abstract_ex}\ -""" - - if self.flags.use_concrete_example: - prompt += f""" -# Concrete Example - -Here is a concrete example of how to format your answer. -Make sure to follow the template with proper tags: -{self.think.concrete_ex}\ -{self.memory.concrete_ex}\ -{self.action_space.concrete_ex}\ -""" - return self.obs.add_screenshot(prompt) - - def shrink(self): - self.history.shrink() - self.obs.shrink() - - def _parse_answer(self, text_answer): - ans_dict = {} - ans_dict.update(self.think._parse_answer(text_answer)) - ans_dict.update(self.memory._parse_answer(text_answer)) - ans_dict.update(self.action_space._parse_answer(text_answer)) - return ans_dict - - -class ActionSpace(PromptElement): - def __init__(self, flags: Flags) -> None: - super().__init__() - self.flags = flags - self.action_space = _get_action_space(flags) - - self._prompt = f"# Action space:\n{self.action_space.describe()}{MacNote().prompt}\n" - self._abstract_ex = f""" - -{self.action_space.example_action(abstract=True)} - -""" - self._concrete_ex = f""" - -{self.action_space.example_action(abstract=False)} - -""" - - def _parse_answer(self, text_answer): - ans_dict = parse_html_tags_raise(text_answer, keys=["action"], merge_multiple=True) - - try: - # just check if action can be mapped to python code but keep action as is - # the environment will be responsible for mapping it to python - self.action_space.to_python_code(ans_dict["action"]) - except Exception as e: - raise ParseError( - f"Error while parsing action\n: {e}\n" - "Make sure your answer is restricted to the allowed actions." - ) - - return ans_dict - - -def _get_action_space(flags: Flags) -> AbstractActionSet: - match flags.action_space: - case "python": - action_space = PythonActionSet(strict=flags.is_strict) - if flags.multi_actions: - warn( - f"Flag action_space={repr(flags.action_space)} incompatible with multi_actions={repr(flags.multi_actions)}." - ) - if flags.demo_mode != "off": - warn( - f"Flag action_space={repr(flags.action_space)} incompatible with demo_mode={repr(flags.demo_mode)}." - ) - return action_space - case "bid": - action_subsets = ["chat", "bid"] - case "coord": - action_subsets = ["chat", "coord"] - case "bid+coord": - action_subsets = ["chat", "bid", "coord"] - case "bid+nav": - action_subsets = ["chat", "bid", "nav"] - case "coord+nav": - action_subsets = ["chat", "coord", "nav"] - case "bid+coord+nav": - action_subsets = ["chat", "bid", "coord", "nav"] - case _: - raise NotImplementedError(f"Unknown action_space {repr(flags.action_space)}") - - action_space = HighLevelActionSet( - subsets=action_subsets, - multiaction=flags.multi_actions, - strict=flags.is_strict, - demo_mode=flags.demo_mode, - retry_with_force=flags.retry_with_force, - ) - - return action_space - - -class Memory(PromptElement): - _prompt = "" # provided in the abstract and concrete examples - - _abstract_ex = """ - -Write down anything you need to remember for next steps. You will be presented -with the list of previous memories and past actions. - -""" - - _concrete_ex = """ - -I clicked on bid 32 to activate tab 2. The accessibility tree should mention -focusable for elements of the form at next step. - -""" - - def _parse_answer(self, text_answer): - return parse_html_tags_raise(text_answer, optional_keys=["memory"], merge_multiple=True) - - -class Think(PromptElement): - _prompt = "" - - _abstract_ex = """ - -Think step by step. If you need to make calculations such as coordinates, write them here. Describe the effect -that your previous action had on the current content of the page. - -""" - _concrete_ex = """ - -My memory says that I filled the first name and last name, but I can't see any -content in the form. I need to explore different ways to fill the form. Perhaps -the form is not visible yet or some fields are disabled. I need to replan. - -""" - - def _parse_answer(self, text_answer): - return parse_html_tags_raise(text_answer, optional_keys=["think"], merge_multiple=True) - - -def diff(previous, new): - """Return a string showing the difference between original and new. - - If the difference is above diff_threshold, return the diff string.""" - - if previous == new: - return "Identical", [] - - if len(previous) == 0 or previous is None: - return "previous is empty", [] - - diff_gen = difflib.ndiff(previous.splitlines(), new.splitlines()) - - diff_lines = [] - plus_count = 0 - minus_count = 0 - for line in diff_gen: - if line.strip().startswith("+"): - diff_lines.append(line) - plus_count += 1 - elif line.strip().startswith("-"): - diff_lines.append(line) - minus_count += 1 - else: - continue - - header = f"{plus_count} lines added and {minus_count} lines removed:" - - return header, diff_lines - - -class Diff(Shrinkable): - def __init__( - self, previous, new, prefix="", max_line_diff=20, shrink_speed=2, visible=True - ) -> None: - super().__init__(visible=visible) - self.max_line_diff = max_line_diff - self.header, self.diff_lines = diff(previous, new) - self.shrink_speed = shrink_speed - self.prefix = prefix - - def shrink(self): - self.max_line_diff -= self.shrink_speed - self.max_line_diff = max(1, self.max_line_diff) - - @property - def _prompt(self) -> str: - diff_str = "\n".join(self.diff_lines[: self.max_line_diff]) - if len(self.diff_lines) > self.max_line_diff: - original_count = len(self.diff_lines) - diff_str = f"{diff_str}\nDiff truncated, {original_count - self.max_line_diff} changes now shown." - return f"{self.prefix}{self.header}\n{diff_str}\n" - - -class HistoryStep(Shrinkable): - def __init__( - self, previous_obs, current_obs, action, memory, flags: Flags, shrink_speed=1 - ) -> None: - super().__init__() - self.html_diff = Diff( - previous_obs[flags.html_type], - current_obs[flags.html_type], - prefix="\n### HTML diff:\n", - shrink_speed=shrink_speed, - visible=lambda: flags.use_html and flags.use_diff, - ) - self.ax_tree_diff = Diff( - previous_obs["axtree_txt"], - current_obs["axtree_txt"], - prefix=f"\n### Accessibility tree diff:\n", - shrink_speed=shrink_speed, - visible=lambda: flags.use_ax_tree and flags.use_diff, - ) - self.error = Error( - current_obs["last_action_error"], - visible=( - lambda: flags.use_error_logs - and current_obs["last_action_error"] - and flags.use_past_error_logs - ), - prefix="### ", - ) - self.shrink_speed = shrink_speed - self.action = action - self.memory = memory - self.flags = flags - - def shrink(self): - super().shrink() - self.html_diff.shrink() - self.ax_tree_diff.shrink() - - @property - def _prompt(self) -> str: - prompt = "" - - if self.flags.use_action_history: - prompt += f"\n### Action:\n{self.action}\n" - - prompt += f"{self.error.prompt}{self.html_diff.prompt}{self.ax_tree_diff.prompt}" - - if self.flags.use_memory and self.memory is not None: - prompt += f"\n### Memory:\n{self.memory}\n" - - return prompt - - -class History(Shrinkable): - def __init__( - self, history_obs, actions, memories, thoughts, flags: Flags, shrink_speed=1 - ) -> None: - super().__init__(visible=lambda: flags.use_history) - assert len(history_obs) == len(actions) + 1 - assert len(history_obs) == len(memories) + 1 - - self.shrink_speed = shrink_speed - self.history_steps: list[HistoryStep] = [] - - for i in range(1, len(history_obs)): - self.history_steps.append( - HistoryStep( - history_obs[i - 1], - history_obs[i], - actions[i - 1], - memories[i - 1], - flags, - ) - ) - - def shrink(self): - """Shrink individual steps""" - # TODO set the shrink speed of older steps to be higher - super().shrink() - for step in self.history_steps: - step.shrink() - - @property - def _prompt(self): - prompts = ["# History of interaction with the task:\n"] - for i, step in enumerate(self.history_steps): - prompts.append(f"## step {i}") - prompts.append(step.prompt) - return "\n".join(prompts) + "\n" - - -if __name__ == "__main__": - html_template = """ - - -
- Hello World. - Step {}. -
- - - """ - - OBS_HISTORY = [ - { - "goal": "do this and that", - "pruned_html": html_template.format(1), - "axtree_txt": "[1] Click me", - "last_action_error": "", - }, - { - "goal": "do this and that", - "pruned_html": html_template.format(2), - "axtree_txt": "[1] Click me", - "last_action_error": "", - }, - { - "goal": "do this and that", - "pruned_html": html_template.format(3), - "axtree_txt": "[1] Click me", - "last_action_error": "Hey, there is an error now", - }, - ] - ACTIONS = ["click('41')", "click('42')"] - MEMORIES = ["memory A", "memory B"] - THOUGHTS = ["thought A", "thought B"] - - flags = Flags( - use_html=True, - use_ax_tree=True, - use_thinking=True, - use_error_logs=True, - use_past_error_logs=True, - use_history=True, - use_action_history=True, - use_memory=True, - use_diff=True, - html_type="pruned_html", - use_concrete_example=True, - use_abstract_example=True, - multi_actions=True, - ) - - print( - MainPrompt( - obs_history=OBS_HISTORY, - actions=ACTIONS, - memories=MEMORIES, - thoughts=THOUGHTS, - step=0, - flags=flags, - ).prompt - ) diff --git a/demo_agent/agents/legacy/utils/__init__.py b/demo_agent/agents/legacy/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/demo_agent/agents/legacy/utils/chat_api.py b/demo_agent/agents/legacy/utils/chat_api.py deleted file mode 100644 index 3c969aff..00000000 --- a/demo_agent/agents/legacy/utils/chat_api.py +++ /dev/null @@ -1,292 +0,0 @@ -from dataclasses import asdict, dataclass -import io -import json -from .prompt_templates import PromptTemplate, get_prompt_template -from langchain.schema import BaseMessage, SystemMessage, HumanMessage, AIMessage -from functools import partial -from typing import Optional, List, Any -import logging -from typing import Tuple -import time - -from langchain_community.llms import HuggingFaceHub, HuggingFacePipeline -from langchain_openai import ChatOpenAI -from langchain.schema import BaseMessage -from langchain.chat_models.base import SimpleChatModel -from langchain.callbacks.manager import CallbackManagerForLLMRun -from pydantic import Field -from transformers import pipeline -from dataclasses import dataclass -from huggingface_hub import InferenceClient -from transformers import AutoTokenizer -from transformers import GPT2TokenizerFast - - -@dataclass -class ChatModelArgs: - """Serializable object for instantiating a generic chat model. - - Attributes - ---------- - model_name : str - The name or path of the model to use. - model_url : str, optional - The url of the model to use, e.g. via TGI. If None, then model_name or model_path must - be specified. - eai_token: str, optional - The EAI token to use for authentication on Toolkit. Defaults to snow.optimass_account.cl4code's token. - temperature : float - The temperature to use for the model. - max_new_tokens : int - The maximum number of tokens to generate. - hf_hosted : bool - Whether the model is hosted on HuggingFace Hub. Defaults to False. - info : dict, optional - Any other information about how the model was finetuned. - DGX related args - n_gpus : int - The number of GPUs to use. Defaults to 1. - tgi_image : str - The TGI image to use. Defaults to "e3cbr6awpnoq/research/text-generation-inference:1.1.0". - ace : str - The ACE to use. Defaults to "servicenow-scus-ace". - workspace : str - The workspace to use. Defaults to UI_COPILOT_SCUS_WORKSPACE. - max_total_tokens : int - The maximum number of total tokens (input + output). Defaults to 4096. - """ - - model_name: str = "openai/gpt-3.5-turbo" - model_url: str = None - temperature: float = 0.1 - max_new_tokens: int = None - max_total_tokens: int = None - max_input_tokens: int = None - hf_hosted: bool = False - info: dict = None - n_retry_server: int = 4 - - def __post_init__(self): - if self.model_url is not None and self.hf_hosted: - raise ValueError("model_url cannot be specified when hf_hosted is True") - - def make_chat_model(self): - if self.model_name.startswith("openai"): - _, model_name = self.model_name.split("/") - return ChatOpenAI( - model_name=model_name, - temperature=self.temperature, - max_tokens=self.max_new_tokens, - ) - else: - return HuggingFaceChatModel( - model_name=self.model_name, - hf_hosted=self.hf_hosted, - temperature=self.temperature, - max_new_tokens=self.max_new_tokens, - max_total_tokens=self.max_total_tokens, - max_input_tokens=self.max_input_tokens, - model_url=self.model_url, - n_retry_server=self.n_retry_server, - ) - - @property - def model_short_name(self): - if "/" in self.model_name: - return self.model_name.split("/")[1] - else: - return self.model_name - - def key(self): - """Return a unique key for these arguments.""" - return json.dumps(asdict(self), sort_keys=True) - - def has_vision(self): - # TODO make sure to upgrade this as we add more models - name_patterns_with_vision = [ - "vision", - "4o", - ] - return any(pattern in self.model_name for pattern in name_patterns_with_vision) - - -class HuggingFaceChatModel(SimpleChatModel): - """ - Custom LLM Chatbot that can interface with HuggingFace models. - - This class allows for the creation of a custom chatbot using models hosted - on HuggingFace Hub or a local checkpoint. It provides flexibility in defining - the temperature for response sampling and the maximum number of new tokens - in the response. - - Attributes: - llm (Any): The HuggingFaceHub model instance. - prompt_template (Any): Template for the prompt to be used for the model's input sequence. - """ - - llm: Any = Field(description="The HuggingFaceHub model instance") - tokenizer: Any = Field( - default=None, - description="The tokenizer to use for the model", - ) - prompt_template: Optional[PromptTemplate] = Field( - default=None, - description="Template for the prompt to be used for the model's input sequence", - ) - n_retry_server: int = Field( - default=4, - description="The number of times to retry the server if it fails to respond", - ) - - def __init__( - self, - model_name: str, - hf_hosted: bool, - temperature: float, - max_new_tokens: int, - max_total_tokens: int, - max_input_tokens: int, - model_url: str, - eai_token: str, - n_retry_server: int, - ): - """ - Initializes the CustomLLMChatbot with the specified configurations. - - Args: - model_name (str): The path to the model checkpoint. - prompt_template (PromptTemplate, optional): A string template for structuring the prompt. - hf_hosted (bool, optional): Whether the model is hosted on HuggingFace Hub. Defaults to False. - temperature (float, optional): Sampling temperature. Defaults to 0.1. - max_new_tokens (int, optional): Maximum length for the response. Defaults to 64. - model_url (str, optional): The url of the model to use. If None, then model_name or model_name will be used. Defaults to None. - """ - super().__init__() - - self.n_retry_server = n_retry_server - - if max_new_tokens is None: - max_new_tokens = max_total_tokens - max_input_tokens - logging.warning( - f"max_new_tokens is not specified. Setting it to {max_new_tokens} (max_total_tokens - max_input_tokens)." - ) - - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - if isinstance(self.tokenizer, GPT2TokenizerFast): - # TODO: make this less hacky once tokenizer.apply_chat_template is more mature - logging.warning( - f"No chat template is defined for {model_name}. Resolving to the hard-coded templates." - ) - self.tokenizer = None - self.prompt_template = get_prompt_template(model_name) - - if temperature < 1e-3: - logging.warning( - "some weird things might happen when temperature is too low for some models." - ) - - model_kwargs = { - "temperature": temperature, - } - - if model_url is not None: - logging.info("Loading the LLM from a URL") - client = InferenceClient(model=model_url, token=eai_token) - self.llm = partial( - client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens - ) - elif hf_hosted: - logging.info("Serving the LLM on HuggingFace Hub") - model_kwargs["max_length"] = max_new_tokens - self.llm = HuggingFaceHub(repo_id=model_name, model_kwargs=model_kwargs) - else: - logging.info("Loading the LLM locally") - pipe = pipeline( - task="text-generation", - model=model_name, - device_map="auto", - max_new_tokens=max_new_tokens, - model_kwargs=model_kwargs, - ) - self.llm = HuggingFacePipeline(pipeline=pipe) - - def _call( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - if stop is not None or run_manager is not None or kwargs: - logging.warning( - "The `stop`, `run_manager`, and `kwargs` arguments are ignored in this implementation." - ) - - if self.tokenizer: - messages_formated = _convert_messages_to_dict(messages) - prompt = self.tokenizer.apply_chat_template(messages_formated, tokenize=False) - - elif self.prompt_template: - prompt = self.prompt_template.construct_prompt(messages) - - itr = 0 - while True: - try: - response = self.llm(prompt) - return response - except Exception as e: - if itr == self.n_retry_server - 1: - raise e - logging.warning( - f"Failed to get a response from the server: \n{e}\n" - f"Retrying... ({itr+1}/{self.n_retry_server})" - ) - time.sleep(5) - itr += 1 - - def _llm_type(self): - return "huggingface" - - -def _convert_messages_to_dict(messages): - """ - Converts a list of message objects into a list of dictionaries, categorizing each message by its role. - - Each message is expected to be an instance of one of the following types: SystemMessage, HumanMessage, AIMessage. - The function maps each message to its corresponding role ('system', 'user', 'assistant') and formats it into a dictionary. - - Args: - messages (list): A list of message objects. - - Returns: - list: A list of dictionaries where each dictionary represents a message and contains 'role' and 'content' keys. - - Raises: - ValueError: If an unsupported message type is encountered. - - Example: - >>> messages = [SystemMessage("System initializing..."), HumanMessage("Hello!"), AIMessage("How can I assist?")] - >>> _convert_messages_to_dict(messages) - [ - {"role": "system", "content": "System initializing..."}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "How can I assist?"} - ] - """ - - # Mapping of message types to roles - message_type_to_role = { - SystemMessage: "system", - HumanMessage: "user", - AIMessage: "assistant", - } - - chat = [] - for message in messages: - message_role = message_type_to_role.get(type(message)) - if message_role: - chat.append({"role": message_role, "content": message.content}) - else: - raise ValueError(f"Message type {type(message)} not supported") - - return chat diff --git a/demo_agent/agents/legacy/utils/llm_utils.py b/demo_agent/agents/legacy/utils/llm_utils.py deleted file mode 100644 index c02603a2..00000000 --- a/demo_agent/agents/legacy/utils/llm_utils.py +++ /dev/null @@ -1,425 +0,0 @@ -import collections -import json -from pathlib import Path -import re -import time -from warnings import warn -import logging - -from functools import cache -import numpy as np -import tiktoken -import yaml -from langchain_openai import ChatOpenAI - -from langchain.schema import SystemMessage, HumanMessage -from openai import BadRequestError -from joblib import Memory -from transformers import AutoModel -from transformers import AutoTokenizer -import io -import base64 -from PIL import Image -from openai import RateLimitError - - -def _extract_wait_time(error_message, min_retry_wait_time=60): - """Extract the wait time from an OpenAI RateLimitError message.""" - match = re.search(r"try again in (\d+(\.\d+)?)s", error_message) - if match: - return max(min_retry_wait_time, float(match.group(1))) - return min_retry_wait_time - - -def retry( - chat: ChatOpenAI, - messages, - n_retry, - parser, - log=True, - min_retry_wait_time=60, - rate_limit_max_wait_time=60 * 30, -): - """Retry querying the chat models with the response from the parser until it - returns a valid value. - - If the answer is not valid, it will retry and append to the chat the retry - message. It will stop after `n_retry`. - - Note, each retry has to resend the whole prompt to the API. This can be slow - and expensive. - - Parameters: - ----------- - chat (function) : a langchain ChatOpenAI taking a list of messages and - returning a list of answers. - messages (list) : the list of messages so far. - n_retry (int) : the maximum number of sequential retries. - parser (function): a function taking a message and returning a tuple - with the following fields: - value : the parsed value, - valid : a boolean indicating if the value is valid, - retry_message : a message to send to the chat if the value is not valid - log (bool): whether to log the retry messages. - min_retry_wait_time (float): the minimum wait time in seconds - after RateLimtError. will try to parse the wait time from the error - message. - - Returns: - -------- - value: the parsed value - """ - tries = 0 - rate_limit_total_delay = 0 - while tries < n_retry and rate_limit_total_delay < rate_limit_max_wait_time: - try: - answer = chat.invoke(messages) - except RateLimitError as e: - wait_time = _extract_wait_time(e.args[0], min_retry_wait_time) - logging.warning(f"RateLimitError, waiting {wait_time}s before retrying.") - time.sleep(wait_time) - rate_limit_total_delay += wait_time - if rate_limit_total_delay >= rate_limit_max_wait_time: - logging.warning( - f"Total wait time for rate limit exceeded. Waited {rate_limit_total_delay}s > {rate_limit_max_wait_time}s." - ) - raise - continue - - messages.append(answer) - - value, valid, retry_message = parser(answer.content) - if valid: - return value - - tries += 1 - if log: - msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.content}\n[User]:\n{retry_message}" - logging.info(msg) - messages.append(HumanMessage(content=retry_message)) - - raise ValueError(f"Could not parse a valid value after {n_retry} retries.") - - -def retry_parallel(chat: ChatOpenAI, messages, n_retry, parser): - """Retry querying the chat models with the response from the parser until it returns a valid value. - - It will stop after `n_retry`. It assuemes that chat will generate n_parallel answers for each message. - The best answer is selected according to the score returned by the parser. If no answer is valid, the - it will retry with the best answer so far and append to the chat the retry message. If there is a - single parallel generation, it behaves like retry. - - This function is, in principle, more robust than retry. The speed and cost overhead is minimal with - the prompt is large and the length of the generated message is small. - - Parameters: - ----------- - chat (function) : a langchain ChatOpenAI taking a list of messages and returning a list of answers. - The number of parallel generations is specified at the creation of the chat object. - messages (list) : the list of messages so far. - n_retry (int) : the maximum number of sequential retries. - parser (function): a function taking a message and returning a tuple with the following fields: - value : the parsed value, - valid : a boolean indicating if the value is valid, - retry_message : a message to send to the chat if the value is not valid, - score : a score to select the best answer from the parallel generations - - Returns: - -------- - value: the parsed value - """ - - for i in range(n_retry): - try: - answers = chat.generate([messages]).generations[0] # chat.n parallel completions - except BadRequestError as e: - # most likely, the added messages triggered a message too long error - # we thus retry without the last two messages - if i == 0: - raise e - msg = f"BadRequestError, most likely the message is too long retrying with previous query." - warn(msg) - messages = messages[:-2] - answers = chat.generate([messages]).generations[0] - - values, valids, retry_messages, scores = zip( - *[parser(answer.message.content) for answer in answers] - ) - idx = np.argmax(scores) - value = values[idx] - valid = valids[idx] - retry_message = retry_messages[idx] - answer = answers[idx].message - - if valid: - return value - - msg = f"Query failed. Retrying {i+1}/{n_retry}.\n[LLM]:\n{answer.content}\n[User]:\n{retry_message}" - warn(msg) - messages.append(answer) # already of type AIMessage - messages.append(SystemMessage(content=retry_message)) - - raise ValueError(f"Could not parse a valid value after {n_retry} retries.") - - -def truncate_tokens(text, max_tokens=8000, start=0, model_name="gpt-4"): - """Use tiktoken to truncate a text to a maximum number of tokens.""" - enc = tiktoken.encoding_for_model(model_name) - tokens = enc.encode(text) - if len(tokens) - start > max_tokens: - return enc.decode(tokens[start : (start + max_tokens)]) - else: - return text - - -@cache -def get_tokenizer(model_name="openai/gpt-4"): - if model_name.startswith("openai"): - return tiktoken.encoding_for_model(model_name.split("/")[-1]) - else: - return AutoTokenizer.from_pretrained(model_name) - - -def count_tokens(text, model="openai/gpt-4"): - enc = get_tokenizer(model) - return len(enc.encode(text)) - - -def count_messages_token(messages, model="openai/gpt-4"): - """Count the number of tokens in a list of messages. - - Args: - messages (list): a list of messages, each message can be a string or a - list of dicts or an object with a content attribute. - model (str): the model to use for tokenization. - - Returns: - int: the number of tokens. - """ - token_count = 0 - for message in messages: - if hasattr(message, "content"): - message = message.content - - if isinstance(message, str): - token_count += count_tokens(message, model) - # handles messages with image content - elif isinstance(message, (list, tuple)): - for part in message: - if not isinstance(part, dict): - raise ValueError( - f"The message is expected to be a list of dicts, but got list of {type(message)}" - ) - if part["type"] == "text": - token_count += count_tokens(part["text"], model) - else: - raise ValueError( - f"The message is expected to be a string or a list of dicts, but got {type(message)}" - ) - return token_count - - -def json_parser(message): - """Parse a json message for the retry function.""" - - try: - value = json.loads(message) - valid = True - retry_message = "" - except json.JSONDecodeError as e: - warn(e) - value = {} - valid = False - retry_message = "Your response is not a valid json. Please try again and be careful to the format. Don't add any apology or comment, just the answer." - return value, valid, retry_message - - -def yaml_parser(message): - """Parse a yaml message for the retry function.""" - - # saves gpt-3.5 from some yaml parsing errors - message = re.sub(r":\s*\n(?=\S|\n)", ": ", message) - - try: - value = yaml.safe_load(message) - valid = True - retry_message = "" - except yaml.YAMLError as e: - warn(str(e)) - value = {} - valid = False - retry_message = "Your response is not a valid yaml. Please try again and be careful to the format. Don't add any apology or comment, just the answer." - return value, valid, retry_message - - -def _compress_chunks(text, identifier, skip_list, split_regex="\n\n+"): - """Compress a string by replacing redundant chunks by identifiers. Chunks are defined by the split_regex.""" - text_list = re.split(split_regex, text) - text_list = [chunk.strip() for chunk in text_list] - counter = collections.Counter(text_list) - def_dict = {} - id = 0 - - # Store items that occur more than once in a dictionary - for item, count in counter.items(): - if count > 1 and item not in skip_list and len(item) > 10: - def_dict[f"{identifier}-{id}"] = item - id += 1 - - # Replace redundant items with their identifiers in the text - compressed_text = "\n".join(text_list) - for key, value in def_dict.items(): - compressed_text = compressed_text.replace(value, key) - - return def_dict, compressed_text - - -def compress_string(text): - """Compress a string by replacing redundant paragraphs and lines with identifiers.""" - - # Perform paragraph-level compression - def_dict, compressed_text = _compress_chunks( - text, identifier="§", skip_list=[], split_regex="\n\n+" - ) - - # Perform line-level compression, skipping any paragraph identifiers - line_dict, compressed_text = _compress_chunks( - compressed_text, "¶", list(def_dict.keys()), split_regex="\n+" - ) - def_dict.update(line_dict) - - # Create a definitions section - def_lines = [""] - for key, value in def_dict.items(): - def_lines.append(f"{key}:\n{value}") - def_lines.append("") - definitions = "\n".join(def_lines) - - return definitions + "\n" + compressed_text - - -def extract_html_tags(text, keys): - """Extract the content within HTML tags for a list of keys. - - Parameters - ---------- - text : str - The input string containing the HTML tags. - keys : list of str - The HTML tags to extract the content from. - - Returns - ------- - dict - A dictionary mapping each key to a list of subset in `text` that match the key. - - Notes - ----- - All text and keys will be converted to lowercase before matching. - - """ - content_dict = {} - # text = text.lower() - # keys = set([k.lower() for k in keys]) - for key in keys: - pattern = f"<{key}>(.*?)" - matches = re.findall(pattern, text, re.DOTALL) - if matches: - content_dict[key] = [match.strip() for match in matches] - return content_dict - - -class ParseError(Exception): - pass - - -def parse_html_tags_raise(text, keys=(), optional_keys=(), merge_multiple=False): - """A version of parse_html_tags that raises an exception if the parsing is not successful.""" - content_dict, valid, retry_message = parse_html_tags( - text, keys, optional_keys, merge_multiple=merge_multiple - ) - if not valid: - raise ParseError(retry_message) - return content_dict - - -def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False): - """Satisfy the parse api, extracts 1 match per key and validates that all keys are present - - Parameters - ---------- - text : str - The input string containing the HTML tags. - keys : list of str - The HTML tags to extract the content from. - optional_keys : list of str - The HTML tags to extract the content from, but are optional. - - Returns - ------- - dict - A dictionary mapping each key to subset of `text` that match the key. - bool - Whether the parsing was successful. - str - A message to be displayed to the agent if the parsing was not successful. - """ - all_keys = tuple(keys) + tuple(optional_keys) - content_dict = extract_html_tags(text, all_keys) - retry_messages = [] - - for key in all_keys: - if not key in content_dict: - if not key in optional_keys: - retry_messages.append(f"Missing the key <{key}> in the answer.") - else: - val = content_dict[key] - content_dict[key] = val[0] - if len(val) > 1: - if not merge_multiple: - retry_messages.append( - f"Found multiple instances of the key {key}. You should have only one of them." - ) - else: - # merge the multiple instances - content_dict[key] = "\n".join(val) - - valid = len(retry_messages) == 0 - retry_message = "\n".join(retry_messages) - return content_dict, valid, retry_message - - -class ChatCached: - # I wish I could extend ChatOpenAI, but it is somehow locked, I don't know if it's pydantic soercey. - - def __init__(self, chat, memory=None): - self.chat = chat - self.memory = memory if memory else Memory(location=Path.home() / "llm-cache", verbose=10) - self._call = self.memory.cache(self.chat.__call__, ignore=["self"]) - self._generate = self.memory.cache(self.chat.generate, ignore=["self"]) - - def __call__(self, messages): - return self._call(messages) - - def generate(self, messages): - return self._generate(messages) - - -def download_and_save_model(model_name: str, save_dir: str = "."): - model = AutoModel.from_pretrained(model_name) - model.save_pretrained(save_dir) - print(f"Model downloaded and saved to {save_dir}") - - -def image_to_jpg_base64_url(image: np.ndarray | Image.Image): - """Convert a numpy array to a base64 encoded image url.""" - - if isinstance(image, np.ndarray): - image = Image.fromarray(image) - if image.mode in ("RGBA", "LA"): - image = image.convert("RGB") - buffered = io.BytesIO() - image.save(buffered, format="JPEG") - - image_base64 = base64.b64encode(buffered.getvalue()).decode() - return f"data:image/jpeg;base64,{image_base64}" diff --git a/demo_agent/agents/legacy/utils/prompt_templates.py b/demo_agent/agents/legacy/utils/prompt_templates.py deleted file mode 100644 index 23f9e6f9..00000000 --- a/demo_agent/agents/legacy/utils/prompt_templates.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import List - -from langchain.schema import BaseMessage, SystemMessage, HumanMessage, AIMessage -from dataclasses import dataclass - -""" -To use this class, you should have the ``openai`` python package installed, and the -environment variable ``OPENAI_API_KEY`` set with your API key. -""" - - -@dataclass -class PromptTemplate: - """ - Base class for prompt templates. - - Defines a standard interface for prompt templates, ensuring that they contain - the required fields for the CustomLLMChatbot. - """ - - system: str - human: str - ai: str - prompt_end: str = "" - - def format_message(self, message: BaseMessage) -> str: - """ - Formats a given message based on its type. - - Args: - message (BaseMessage): The message to be formatted. - - Returns: - str: The formatted message. - - Raises: - ValueError: If the message type is not supported. - """ - if isinstance(message, SystemMessage): - return self.system.format(input=message.content) - elif isinstance(message, HumanMessage): - return self.human.format(input=message.content) - elif isinstance(message, AIMessage): - return self.ai.format(input=message.content) - else: - raise ValueError(f"Message type {type(message)} not supported") - - def construct_prompt(self, messages: List[BaseMessage]) -> str: - """ - Constructs a prompt from a list of messages. - - Args: - messages (List[BaseMessage]): The list of messages to be formatted. - - Returns: - str: The constructed prompt. - """ - if not all(isinstance(m, BaseMessage) for m in messages): - raise ValueError("All elements in the list must be of type BaseMessage") - - prompt = "".join([self.format_message(m) for m in messages]) - prompt += self.prompt_end - return prompt - - -def get_prompt_template(model_name): - for key, value in MODEL_PREFIX_TO_PROMPT_TEMPLATES.items(): - if key in model_name: - return value - raise NotImplementedError(f"Model {model_name} has no supported chat template") - - -## Prompt templates - -STARCHAT_PROMPT_TEMPLATE = PromptTemplate( - system="<|system|>\n{input}<|end|>\n", - human="<|user|>\n{input}<|end|>\n", - ai="<|assistant|>\n{input}<|end|>\n", - prompt_end="<|assistant|>", -) - - -## Model prefix to prompt template mapping - -MODEL_PREFIX_TO_PROMPT_TEMPLATES = { - "starcoder": STARCHAT_PROMPT_TEMPLATE, - "starchat": STARCHAT_PROMPT_TEMPLATE, -} diff --git a/demo_agent/basic_agent.py b/demo_agent/basic_agent.py new file mode 100644 index 00000000..59ad59e8 --- /dev/null +++ b/demo_agent/basic_agent.py @@ -0,0 +1,330 @@ +import base64 +import dataclasses +import numpy as np +import io +import logging + +from PIL import Image + +from browsergym.experiments import Agent, AbstractAgentArgs +from browsergym.core.action.highlevel import HighLevelActionSet +from browsergym.core.action.python import PythonActionSet +from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html + +logger = logging.getLogger(__name__) + + +def image_to_jpg_base64_url(image: np.ndarray | Image.Image): + """Convert a numpy array to a base64 encoded image url.""" + + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + if image.mode in ("RGBA", "LA"): + image = image.convert("RGB") + + with io.BytesIO() as buffer: + image.save(buffer, format="JPEG") + image_base64 = base64.b64encode(buffer.getvalue()).decode() + + return f"data:image/jpeg;base64,{image_base64}" + + +class DemoAgent(Agent): + """A basic agent using OpenAI API, to demonstrate BrowserGym's functionalities.""" + + def obs_preprocessor(self, obs: dict) -> dict: + + return { + "chat_messages": obs["chat_messages"], + "screenshot": obs["screenshot"], + "goal_object": obs["goal_object"], + "last_action": obs["last_action"], + "last_action_error": obs["last_action_error"], + "axtree_txt": flatten_axtree_to_str(obs["axtree_object"]), + "pruned_html": prune_html(flatten_dom_to_str(obs["dom_object"])), + } + + def __init__( + self, + model_name: str, + chat_mode: bool, + demo_mode: str, + use_html: bool, + use_axtree: bool, + use_screenshot: bool, + ) -> None: + super().__init__() + self.model_name = model_name + self.chat_mode = chat_mode + self.use_html = use_html + self.use_axtree = use_axtree + self.use_screenshot = use_screenshot + + if not (use_html or use_axtree): + raise ValueError(f"Either use_html or use_axtree must be set to True.") + + from openai import OpenAI + + self.openai_client = OpenAI() + + self.action_set = HighLevelActionSet( + subsets=["chat", "bid", "infeas"], # define a subset of the action space + # subsets=["chat", "bid", "coord", "infeas"] # allow the agent to also use x,y coordinates + strict=False, # less strict on the parsing of the actions + multiaction=False, # does not enable the agent to take multiple actions at once + demo_mode=demo_mode, # add visual effects + ) + # use this instead to allow the agent to directly use Python code + # self.action_set = PythonActionSet()) + + self.action_history = [] + + def get_action(self, obs: dict) -> tuple[str, dict]: + system_msgs = [] + user_msgs = [] + + if self.chat_mode: + system_msgs.append( + { + "type": "text", + "text": f"""\ +# Instructions + +You are a UI Assistant, your goal is to help the user perform tasks using a web browser. You can +communicate with the user via a chat, to which the user gives you instructions and to which you +can send back messages. You have access to a web browser that both you and the user can see, +and with which only you can interact via specific commands. + +Review the instructions from the user, the current state of the page and all other information +to find the best possible next action to accomplish your goal. Your answer will be interpreted +and executed by a program, make sure to follow the formatting instructions. +""", + } + ) + # append chat messages + user_msgs.append( + { + "type": "text", + "text": f"""\ +# Chat Messages +""", + } + ) + for msg in obs["chat_messages"]: + if msg["role"] in ("user", "assistant", "infeasible"): + user_msgs.append( + { + "type": "text", + "text": f"""\ +- [{msg['role']}] {msg['message']} +""", + } + ) + elif msg["role"] == "user_image": + user_msgs.append({"type": "image_url", "image_url": msg["message"]}) + else: + raise ValueError(f"Unexpected chat message role {repr(msg['role'])}") + + else: + assert obs["goal_object"], "The goal is missing." + system_msgs.append( + { + "type": "text", + "text": f"""\ +# Instructions + +Review the current state of the page and all other information to find the best +possible next action to accomplish your goal. Your answer will be interpreted +and executed by a program, make sure to follow the formatting instructions. +""", + } + ) + # append goal + user_msgs.append( + { + "type": "text", + "text": f"""\ +# Goal +""", + } + ) + # goal_object is directly presented as a list of openai-style messages + user_msgs.extend(obs["goal_object"]) + + # append page AXTree (if asked) + if self.use_axtree: + user_msgs.append( + { + "type": "text", + "text": f"""\ +# Current page Accessibility Tree + +{obs["axtree_txt"]} + +""", + } + ) + # append page HTML (if asked) + if self.use_html: + user_msgs.append( + { + "type": "text", + "text": f"""\ +# Current page DOM + +{obs["pruned_html"]} + +""", + } + ) + + # append page screenshot (if asked) + if self.use_screenshot: + user_msgs.append( + { + "type": "text", + "text": """\ +# Current page Screenshot +""", + } + ) + user_msgs.append( + { + "type": "image_url", + "image_url": { + "url": image_to_jpg_base64_url(obs["screenshot"]), + "detail": "auto", + }, # Literal["low", "high", "auto"] = "auto" + } + ) + + # append action space description + user_msgs.append( + { + "type": "text", + "text": f"""\ +# Action Space + +{self.action_set.describe(with_long_description=False, with_examples=True)} + +Here are examples of actions with chain-of-thought reasoning: + +I now need to click on the Submit button to send the form. I will use the click action on the button, which has bid 12. +```click("12")``` + +I found the information requested by the user, I will send it to the chat. +```send_msg_to_user("The price for a 15\\" laptop is 1499 USD.")``` + +""", + } + ) + + # append past actions (and last error message) if any + if self.action_history: + user_msgs.append( + { + "type": "text", + "text": f"""\ +# History of past actions +""", + } + ) + user_msgs.extend( + [ + { + "type": "text", + "text": f"""\ +{action} +""", + } + for action in self.action_history + ] + ) + + if obs["last_action_error"]: + user_msgs.append( + { + "type": "text", + "text": f"""\ +# Error message from last action + +{obs["last_action_error"]} + +""", + } + ) + + # ask for the next action + user_msgs.append( + { + "type": "text", + "text": f"""\ +# Next action + +You will now think step by step and produce your next best action. Reflect on your past actions, any resulting error message, the current state of the page before deciding on your next action. +""", + } + ) + + prompt_text_strings = [] + for message in system_msgs + user_msgs: + match message["type"]: + case "text": + prompt_text_strings.append(message["text"]) + case "image_url": + image_url = message["image_url"] + if isinstance(message["image_url"], dict): + image_url = image_url["url"] + if image_url.startswith("data:image"): + prompt_text_strings.append( + "image_url: " + image_url[:30] + "... (truncated)" + ) + else: + prompt_text_strings.append("image_url: " + image_url) + case _: + raise ValueError( + f"Unknown message type {repr(message['type'])} in the task goal." + ) + full_prompt_txt = "\n".join(prompt_text_strings) + logger.info(full_prompt_txt) + + # query OpenAI model + response = self.openai_client.chat.completions.create( + model=self.model_name, + messages=[ + {"role": "system", "content": system_msgs}, + {"role": "user", "content": user_msgs}, + ], + ) + action = response.choices[0].message.content + + self.action_history.append(action) + + return action, {} + + +@dataclasses.dataclass +class DemoAgentArgs(AbstractAgentArgs): + """ + This class is meant to store the arguments that define the agent. + + By isolating them in a dataclass, this ensures serialization without storing + internal states of the agent. + """ + + model_name: str = "gpt-4o-mini" + chat_mode: bool = False + demo_mode: str = "off" + use_html: bool = False + use_axtree: bool = True + use_screenshot: bool = False + + def make_agent(self): + return DemoAgent( + model_name=self.model_name, + chat_mode=self.chat_mode, + demo_mode=self.demo_mode, + use_html=self.use_html, + use_axtree=self.use_axtree, + use_screenshot=self.use_screenshot, + ) diff --git a/demo_agent/run_demo.py b/demo_agent/run_demo.py index 1767fb4c..a8702cd9 100644 --- a/demo_agent/run_demo.py +++ b/demo_agent/run_demo.py @@ -1,15 +1,10 @@ -""" -WARNING DEPRECATED WILL BE REMOVED SOON -""" - import argparse -from pathlib import Path -from browsergym.experiments import ExpArgs, EnvArgs +# browsergym experiments utils +from browsergym.experiments import EnvArgs, ExpArgs, get_exp_result -from agents.legacy.agent import GenericAgentArgs -from agents.legacy.dynamic_prompting import Flags -from agents.legacy.utils.chat_api import ChatModelArgs +# locally defined agent +from basic_agent import DemoAgentArgs def str2bool(v): @@ -28,8 +23,8 @@ def parse_args(): parser.add_argument( "--model_name", type=str, - default="openai/gpt-4o", - help="Model name for the chat model.", + default="gpt-4o-mini", + help="OpenAI model name.", ) parser.add_argument( "--task_name", @@ -44,57 +39,29 @@ def parse_args(): help="Starting URL (only for the openended task).", ) parser.add_argument( - "--slow_mo", type=int, default=500, help="Slow motion delay for the playwright actions." - ) - parser.add_argument( - "--headless", - type=str2bool, - default=False, - help="Run the experiment in headless mode (hides the browser windows).", - ) - parser.add_argument( - "--demo_mode", + "--visual_effects", type=str2bool, default=True, help="Add visual effects when the agents performs actions.", ) parser.add_argument( - "--use_html", type=str2bool, default=True, help="Use HTML in the agent's observation space." + "--use_html", + type=str2bool, + default=False, + help="Use HTML in the agent's observation space.", ) parser.add_argument( - "--use_ax_tree", + "--use_axtree", type=str2bool, default=True, - help="Use AX tree in the agent's observation space.", + help="Use AXTree in the agent's observation space.", ) parser.add_argument( "--use_screenshot", type=str2bool, - default=True, + default=False, help="Use screenshot in the agent's observation space.", ) - parser.add_argument( - "--multi_actions", type=str2bool, default=True, help="Allow multi-actions in the agent." - ) - parser.add_argument( - "--action_space", - type=str, - default="bid", - choices=["python", "bid", "coord", "bid+coord", "bid+nav", "coord+nav", "bid+coord+nav"], - help="", - ) - parser.add_argument( - "--use_history", - type=str2bool, - default=True, - help="Use history in the agent's observation space.", - ) - parser.add_argument( - "--use_thinking", - type=str2bool, - default=True, - help="Use thinking in the agent (chain-of-thought prompting).", - ) return parser.parse_args() @@ -102,57 +69,56 @@ def parse_args(): def main(): print( """\ -WARNING this demo agent will soon be moved elsewhere. Expect it to be removed at some point.""" +--- WARNING --- +This is a basic agent for demo purposes. +Visit AgentLab for more capable agents with advanced features. +https://github.com/ServiceNow/AgentLab""" ) args = parse_args() + # setting up agent config + agent_args = DemoAgentArgs( + model_name=args.model_name, + chat_mode=False, + demo_mode="default" if args.visual_effects else "off", + use_html=args.use_html, + use_axtree=args.use_axtree, + use_screenshot=args.use_screenshot, + ) + + # setting up environment config env_args = EnvArgs( task_name=args.task_name, task_seed=None, max_steps=100, - headless=args.headless, - viewport={"width": 1500, "height": 1280}, - slow_mo=args.slow_mo, + headless=False, # keep the browser open + # viewport={"width": 1500, "height": 1280}, # can be played with if needed ) + # for openended task, set environment and agent to interactive chat mode on a start url if args.task_name == "openended": + agent_args.chat_mode = True env_args.wait_for_user_message = True env_args.task_kwargs = {"start_url": args.start_url} + # setting up the experiment exp_args = ExpArgs( env_args=env_args, - agent_args=GenericAgentArgs( - chat_model_args=ChatModelArgs( - model_name=args.model_name, - max_total_tokens=128_000, # "Maximum total tokens for the chat model." - max_input_tokens=126_000, # "Maximum tokens for the input to the chat model." - max_new_tokens=2_000, # "Maximum total tokens for the chat model." - ), - flags=Flags( - use_html=args.use_html, - use_ax_tree=args.use_ax_tree, - use_thinking=args.use_thinking, # "Enable the agent with a memory (scratchpad)." - use_error_logs=True, # "Prompt the agent with the error logs." - use_memory=False, # "Enables the agent with a memory (scratchpad)." - use_history=args.use_history, - use_diff=False, # "Prompt the agent with the difference between the current and past observation." - use_past_error_logs=True, # "Prompt the agent with the past error logs." - use_action_history=True, # "Prompt the agent with the action history." - multi_actions=args.multi_actions, - action_space="bid+nav", - use_abstract_example=True, # "Prompt the agent with an abstract example." - use_concrete_example=True, # "Prompt the agent with a concrete example." - use_screenshot=args.use_screenshot, - enable_chat=True, - demo_mode="default" if args.demo_mode else "off", - ), - ), + agent_args=agent_args, ) - exp_args.prepare(Path("./results")) + # running and logging results + exp_args.prepare("./results") exp_args.run() + # loading and printing results + exp_result = get_exp_result(exp_args.exp_dir) + exp_record = exp_result.get_exp_record() + + for key, val in exp_record.items(): + print(f"{key}: {val}") + if __name__ == "__main__": main() diff --git a/tests/core/test_task.py b/tests/core/test_task.py new file mode 100644 index 00000000..4e6549e3 --- /dev/null +++ b/tests/core/test_task.py @@ -0,0 +1,73 @@ +from typing import Tuple + +import playwright +import pytest + +from browsergym.core.env import BrowserEnv +from browsergym.core.task import AbstractBrowserTask + + +class MockImageGoalTask(AbstractBrowserTask): + @classmethod + def get_task_id(cls): + return "mockimagegoal" + + def __init__(self, seed: int = 0, start_url: str = "https://www.google.com") -> None: + """ + Args: + seed: random seed. + start_url: str, the url for the starting page. + goal: str, the initial goal. + + """ + super().__init__(seed) + self.start_url = start_url + self.goal = [ + {"type": "text", "text": "This is a mock task with an image goal."}, + { + "type": "image_url", + "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=", + }, + ] + + def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]: + page.goto(self.start_url, timeout=10000) + return self.goal, {} + + def teardown(self) -> None: + pass + + def validate( + self, page: playwright.sync_api.Page, chat_messages: list[str] + ) -> Tuple[float, bool, str, dict]: + reward, done, msg, info = 0, False, "", {} + + for message in chat_messages: + if message["role"] == "user" and message["message"] == "exit": + done = True + break + + return reward, done, msg, info + + +def test_mock_image_goal_task(): + env = BrowserEnv(MockImageGoalTask) + obs, _ = env.reset() + + assert "goal_object" in obs + assert len(obs["goal_object"]) == 2 + assert obs["goal_object"][0]["type"] == "text" + assert obs["goal_object"][0]["text"] == "This is a mock task with an image goal." + assert obs["goal_object"][1]["type"] == "image_url" + + env.chat.add_message("user", "exit") + obs, reward, terminated, _, _ = env.step("send_msg_to_user('bye')") + + assert reward == 0 + assert terminated is True + + env.close() + + +if __name__ == "__main__": + test_mock_image_goal_task()